In [2]:
import pyro
import pyro.distributions as dist
import torch
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import torch.nn.functional as F
from pyro.infer import SVI, TraceMeanField_ELBO
from tqdm import trange
import math

## Class to generate data

In [4]:
class ToyGenerator:
    def __init__(self,  
                 num_bins=5, 
                 num_references=3, 
                 num_signals=3,
                 num_states=3,
                 high_w=100):
        self.num_bins = num_bins
        self.num_references = num_references
        self.num_signals = num_signals
        self.num_states = num_states
        self.high_w = high_w
        self.sample = None
        self.params = self.set_params()
    
        
    # parameter of state->signal distributions
    # shape is (num_states, num_signals)
    def generate_param_p(self):
        p = torch.zeros((self.num_states, self.num_signals))
        for i in range(self.num_states):
            w = -self.high_w * torch.ones(self.num_signals)
            w[i % self.num_signals] = self.high_w
            p[i,:] = w
        return p
    
    # generate a state assignment tensor
    # shape is (num_bins, num_references)
    # helper function for set_params
    def generate_ref_states(self):
        ref_states = torch.zeros(
            (self.num_bins, 
             self.num_references))
        
        for i in range(self.num_references):
            ref_states[:,i] = i % self.num_states
        return ref_states.long()
    
    # set parameters of the data generator
    def set_params(self):
        # parameters of the dirichlet over references
        # same one for every region
        # very high probability that generated sample looks like
        # reference 0
        # shape is (num_references,)
        alpha = torch.ones(self.num_references)
        # alpha[0] = self.high_w
        
        # parameters of bernoulli distribution for each signal
        # for each state
        # shape is (num_states, num_signals)
        p = self.generate_param_p()
        
        # an indicator matrix along genome of the state for 
        # each refenrece
        # shape is (num_regions, num_bins_per_region, num_states, num_references)
        ref_states_indicator = F.one_hot(self.generate_ref_states(), self.num_states)
        params = {
            'alpha': alpha,
            'p': p,
            'ref_states_indicator': ref_states_indicator
        }
        self.params = params
        return params
        
    # collapse a prob vector over references to a prob vector over states
    # takes the cross product of prob vector theta and reference state indicator matrix r
    # shapes:
    #  theta: (None, num_references)
    #  r: (None, num_references, num_states)
    #  out: (None, num_states)
    def collapse_theta(self, theta, r=None):
        if r is None:
            assert self.params is not None
            r = self.params['ref_states_indicator']
            
        r = r.float()
        collapsed_theta = torch.zeros(theta.shape[0], r.shape[2])
        for i in range(theta.shape[0]):
            collapsed_theta[i,:] = torch.matmul(r[i,:,:].T, theta[i,:])
        return collapsed_theta
    
    def generate_sample(self):
        if self.params is None:
            self.set_params()
            
        r = self.params['ref_states_indicator']
                
        # generate reference distribution for each region
        with pyro.plate('bins', self.num_bins):
            # theta is shape (num_regions, num_references)
            theta = pyro.sample('theta', dist.Dirichlet(self.params['alpha']))
            # collapse the reference distribution for each bin to a 
            # state distribution 
            collapsed_theta = self.collapse_theta(theta, r)

            signal_params = torch.sigmoid(torch.matmul(collapsed_theta, self.params['p']))
            m = pyro.sample('m', dist.Bernoulli(signal_params).to_event(1))

        result = {
            'theta': theta,
            'm': m,
        }
        self.sample = result
        return self.sample
    
    def get_sampled_collapsed_theta(self):
        if self.sample is None:
            self.generate_sample()
        theta = self.sample['theta']
        return self.collapse_theta(theta)
    
    def get_sampled_signals(self):
        if self.sample is None:
            self.generate_sample()
        return self.sample['m']
    
    def get_sampled_theta(self):
        if self.sample is None:
            self.generate_sample()
        return self.sample['theta']
    
    def get_signal_parms(self):
        collapsed_theta = self.get_sampled_collapsed_theta()
        return torch.sigmoid(torch.matmul(collapsed_theta, self.params['p']))
    
    def get_ref_state_indicators(self):
        if self.params is None:
            self.set_params()
        return self.params['ref_states_indicator']

In [5]:
'''
M: # regions
N: # bins per region
L: # signals (marks)
alpha: params of dirichlet prior over reference epigenomics
beta: ref --> sample state categorical distribution
p: state --> signal bernoulli distribution 
r: reference state at each bin. one-hot encoding, matrix size : #bins * #ref * #states
theta: the mixture probabilities of reference ethetagenome
'''

class CircularStateGenerator:
    # Within the number of references, there is a group of references that will be similar to the 
    # sample of interests in terms of state assignments
    def __init__(self,  
                 num_bins=5, 
                 num_references=10, 
                 num_groups=3,
                 state_vary_rate=0.01, 
                 # fraction of the genome where the state assignments among references of the same group are diff
                 num_signals=3,
                 num_states=5,
                 high_w=100):
        self.num_bins = num_bins
        self.num_references = num_references
        self.num_groups = num_groups
        self.state_vary_rate = state_vary_rate
        self.num_signals = num_signals
        self.num_states = num_states
        self.high_w = high_w
        self.sample = None
        self.params = self.set_params()
    
        
    # parameter of state->signal distributions
    # shape is (num_states, num_signals)
    def generate_param_p(self):
        p = torch.zeros((self.num_states, self.num_signals))
        for i in range(self.num_states):
            w = -self.high_w * torch.ones(self.num_signals)
            w[i % self.num_signals] = self.high_w
            p[i,:] = w
        return p
    
    # generate a state assignment tensor
    # shape is (num_regions, num_bins_per_region, num_references)
    def generate_ref_states(self):
        # this is code for the case where we want varied state patterns from each reference
        # and that there are actually groups of references that are similar to each other
        num_ref_per_groups = np.ceil(self.num_references/self.num_groups).astype(int)
        sample_r = torch.zeros(self.num_states, self.num_groups)
        for i in range(self.num_groups):
            sample_r[:,i] = torch.arange(self.num_states).roll(i)
            # each group has a circular permutation of states that are characteristics to that group
        sample_r = sample_r.repeat(np.ceil(self.num_bins / self.num_states).astype(int), 1)
        # now r is just a repeated sequence of sample_r
        r = torch.zeros(sample_r.shape[0], self.num_references)
        for i in range(self.num_references):
            r[:,i] = sample_r[:, i % self.num_groups]
        # now we will start to introduce some random changes to the state assignments among references from
        # the same groups
        num_change = int(self.state_vary_rate * self.num_bins)
        for i in range(self.num_states, self.num_references): 
            # for the first num_states columns, keep all the state assignments
            # if num_references < num_states, this loop will not be called
            org_r = r[:,i]
            indices_to_change = np.random.choice(self.num_bins, num_change)
            indices_to_change = torch.tensor(indices_to_change).type(torch.LongTensor)
            states_to_change = torch.tensor(np.random.choice(self.num_states, num_change)).float()
            r[indices_to_change,i] = states_to_change
        r = r[:self.num_bins,:self.num_references]
        return r.long() # num_bins, num_references --> values: state-0-based 
    
    # set parameters of the data generator
    def set_params(self):
        # parameters of the dirichlet over references
        # same one for every region
        # very high probability that generated sample looks like
        # reference 0
        # shape is (num_references,)
        alpha = torch.ones(self.num_references)
        num_ref_per_groups = np.ceil(self.num_references/self.num_groups).astype(int)
        for i in range(self.num_references):
            if i % self.num_groups == 0:
                alpha[i] = self.high_w # all refs in group 1 will be more similar to sample of interest
        
        # parameters of bernoulli distribution for each signal
        # for each state
        # shape is (num_states, num_signals)
        p = self.generate_param_p()
        
        # an indicator matrix along genome of the state for 
        # each refenrece
        # shape is (num_regions, num_bins_per_region, num_states, num_references)
        ref_states_indicator = F.one_hot(self.generate_ref_states(), self.num_states)
        params = {
            'alpha': alpha,
            'p': p,
            'ref_states_indicator': ref_states_indicator
        }
        self.params = params
        return params
        
    # collapse a prob vector over references to a prob vector over states
    # takes the cross product of prob vector theta and reference state indicator matrix r
    # shapes:
    #  theta: (None, num_references)
    #  r: (None, num_references, num_states)
    #  out: (None, num_states)
    def collapse_theta(self, theta, r=None):
        if r is None:
            assert self.params is not None
            r = self.params['ref_states_indicator']
            
        r = r.float()
        collapsed_theta = torch.zeros(theta.shape[0], r.shape[2])
        for i in range(theta.shape[0]):
            collapsed_theta[i,:] = torch.matmul(r[i,:,:].T, theta[i,:])
        return collapsed_theta
    
    def generate_sample(self):
        if self.params is None:
            self.set_params()
            
        r = self.params['ref_states_indicator']
                
        # generate reference distribution for each region
        with pyro.plate('bins', self.num_bins):
            # theta is shape (num_regions, num_references)
            theta = pyro.sample('theta', dist.Dirichlet(self.params['alpha']))
            # collapse the reference distribution for each bin to a 
            # state distribution 
            collapsed_theta = self.collapse_theta(theta, r)

            signal_params = torch.sigmoid(torch.matmul(collapsed_theta, self.params['p']))
            m = pyro.sample('m', dist.Bernoulli(signal_params).to_event(1))

        result = {
            'theta': theta,
            'm': m
        }
        self.sample = result
        return self.sample
    

    def get_sampled_collapsed_theta(self):
        if self.sample is None:
            self.generate_sample()
        theta = self.sample['theta']
        return self.collapse_theta(theta)
    
    def get_sampled_signals(self):
        if self.sample is None:
            self.generate_sample()
        return self.sample['m']
    
    def get_sampled_theta(self):
        if self.sample is None:
            self.generate_sample()
        return self.sample['theta']
    
    def get_signal_parms(self):
        collapsed_theta = self.get_sampled_collapsed_theta()
        return torch.sigmoid(torch.matmul(collapsed_theta, self.params['p']))
    
    def get_ref_state_indicators(self):
        if self.params is None:
            self.set_params()
        return self.params['ref_states_indicator']

In [7]:
serious_parms = {
    'num_bins': 10000,
    'num_references': 10,
    'num_groups': 3,
    'state_vary_rate': 0.003,
    'num_signals': 3,
    'num_states': 3,
    'high_w': 100
}

seed = 0
torch.manual_seed(seed)
pyro.set_rng_seed(seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = CircularStateGenerator(**serious_parms)

m = generator.get_sampled_signals()
r = generator.get_ref_state_indicators()
collapsed_theta = generator.get_sampled_collapsed_theta()
theta = generator.get_sampled_theta()
signal_params = generator.get_signal_parms()

In [8]:
print(serious_parms)
print('m: obs. signals at each position')
print(m.shape)
print(m[:10,:])
print('r: reference epigenome state indicator at each position')
print(r.shape)
print(r[:10,2,:])
print('collapsed_theta: state assignment at each position')
print(collapsed_theta.shape)
print(collapsed_theta)
print('theta: the reference mixture at each position')
print(theta.shape)
print(theta)
print('signal_params: bernoulli dist. params generating signal at each position')
print(signal_params.shape)
print('p')
p = generator.params['p']
print (p)

{'num_bins': 10000, 'num_references': 10, 'num_groups': 3, 'state_vary_rate': 0.003, 'num_signals': 3, 'num_states': 3, 'high_w': 100}
m: obs. signals at each position
torch.Size([10000, 3])
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.]])
r: reference epigenome state indicator at each position
torch.Size([10000, 10, 3])
tensor([[0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
collapsed_theta: state assignment at each position
torch.Size([10000, 3])
tensor([[0.9912, 0.0020, 0.0068],
        [0.0109, 0.9862, 0.0029],
        [0.0024, 0.0276, 0.9700],
        ...,
        [0.0027, 0.9882, 0.0091],
        [0.0047, 0.0035, 0.9917],
        [0.9931, 0.0049, 0.0019]])
theta: the reference m

## Model with data of signals and ref states

In [9]:
class Encoder(nn.Module):
	def __init__(self, num_signals, num_states, num_references, hidden, dropout):
		super().__init__()
		self.drop = nn.Dropout(dropout)
		input_dim = num_signals + num_states * num_references
		self.fc1 = nn.Linear(input_dim, hidden)
		self.fc2 = nn.Linear(hidden, hidden)
		self.fcmu = nn.Linear(hidden, num_states)
		self.fclv = nn.Linear(hidden, num_states)

	def forward(self, m, r):
		inputs = torch.cat((m, r.reshape(r.shape[0], -1)), 1)
		h = F.softplus(self.fc1(inputs))
		h = F.softplus(self.fc2(h))
		h = self.drop(h)
		logpi_loc = (self.fcmu(h))
		logpi_logvar = self.fclv(h)
		logpi_scale = (0.5 * logpi_logvar).exp()
		return logpi_loc, logpi_scale

class Decoder(nn.Module):
	def __init__(self, num_states, num_signals, num_references, hidden, dropout):
		super().__init__()
		self.num_states = num_states
		self.num_signals = num_signals
		self.num_references = num_references
		self.drop = nn.Dropout(dropout)
		self.fcih = nn.Linear(num_states, hidden) # input (state probabilities) --> hidden
		self.fchh = nn.Linear(hidden, hidden) # hiddent --> hidden
		self.fchs = nn.Linear(hidden, self.num_signals) # hidden --> signals
		self.fchr = nn.Linear(hidden, self.num_references * self.num_states) # hidden --> states in references



	def forward(self, inputs):
		# takes in the values of collapsed pi: probabilities of state 
		# assignments at each positions, and then apply a linear trans
		# to get the probabilities of observing signals at each position
		# --> vector size #signals
		# used as parameters for bernoulli dist. to get obs. signals
		# create multiple layers
		# inputs: bins, state probabilities
		# h: bins, hidden 
		# signal_param: bins, signals
		# ref_param: bins, num_references, num_states
		h = F.softplus(self.fcih(inputs)) # --> hidden element vector
		h = F.softplus(self.fchh(h)) # --> hidden element vector
		h = self.drop(h)
		signal_param = torch.sigmoid(self.fchs(h)) # hidden --> marks
		ref_param = torch.sigmoid(self.fchr(h)).reshape((h.shape[0], self.num_references, self.num_states)) 
		# hidden --> num_ref*num_states
		ref_param = F.normalize(ref_param, p = 1.0, dim = 2) # row normalize, sum over states per ref is 1
		return signal_param, ref_param

class Model_signals_refStates(nn.Module):
	def __init__(self, num_signals, num_references, num_states, hidden, dropout):
		super().__init__()
		self.num_signals = num_signals
		self.num_references = num_references
		self.num_states = num_states
		self.hidden = hidden
		self.dropout = dropout
		self.encoder = Encoder(num_signals, num_states, num_references, hidden, dropout)
		self.decoder = Decoder(num_states, num_signals, num_references, hidden, dropout)

	# shapes: 
	#  m: (bins x signals) signal matrix
	#  r: (bins x reference x state) indicator matrix
	def model(self, m, r):
		# flatten out the r indicator matrix
		pyro.module("decoder", self.decoder)
		with pyro.plate('bins', m.shape[0]):
			logCpi_loc = m.new_zeros((m.shape[0], self.num_states))
			logCpi_scale = m.new_ones((m.shape[0], self.num_states))
			logCpi = pyro.sample('log_collapsedPi', dist.Normal(logCpi_loc, logCpi_scale).to_event(1))
			Cpi = F.softmax(logCpi, -1) 
			# the softmax function should be used here because Cpi is lognormal
			signal_param, ref_param = self.decoder(Cpi) # vector of probabilities. 
			# signal_param: bins, signals
			# ref_param: bins, references, states
			# first num_signals elements: bernoulli params
			# each of the following num_states elements: multinomial params
			# for the state segmentation in a reference    
			pyro.sample('m', dist.Bernoulli(signal_param).to_event(1), obs=m)
			# plate across references
			with pyro.plate('refs', self.num_references):
				pyro.sample('r', dist.Multinomial(1, ref_param).to_event(1), obs = r)

	def guide(self, m, r):
		pyro.module("encoder", self.encoder)
		with pyro.plate('bins', m.shape[0]):
			logpi_loc, logpi_scale = self.encoder(m, r)
			logpi = pyro.sample('log_collapsedPi', dist.Normal(logpi_loc, logpi_scale).to_event(1))

	def predict_state_assignment(self, m, r):
		logpi_loc, logpi_scale = self.encoder(m, r)
		Cpi = F.softmax(logpi_loc, -1)
		return(Cpi)


	def generate_reconstructed_data(self, m, r):
		logpi_loc, logpi_scale = self.encoder(m, r)
		Cpi = F.softmax(logpi_loc, -1)
		signal_param, ref_param = self.decoder(Cpi) # vector of probabilities. 
		re_m = pyro.sample('re_m', dist.Bernoulli(signal_param).to_event(1))
		re_r = pyro.sample('re_r', dist.Multinomial(1, ref_param).to_event(1))
		return(re_m, re_r)


	def get_percentage_correct_reconstruct(self, m, r):
		# m and r can be different from the m and r used in training
		re_m, re_r = self.generate_reconstructed_data(m,r)
		total_m_entries = re_m.shape[0] * re_m.shape[1]
		signals_CR = (re_m==m).sum() # correct reconstruct entries of signals
		total_r_entries = re_r.shape[0] * self.num_references
		# for each reference at each position, if the state assignment is different between re_r and r, there are 2 out of num_states entries that are different between re_r and r
		wrong_r = ((re_r.shape[0] * re_r.shape[1] * re_r.shape[2]) - (re_r==r).sum()) / 2 # wrong reconstruct entries of reference states
		r_CR = total_r_entries - wrong_r
		ratio_m_CR = (signals_CR / total_m_entries).item()
		ratio_r_CR = (r_CR / total_r_entries).item()
		return ratio_m_CR, ratio_r_CR


In [11]:
batch_size = 200
learning_rate = 1e-3
num_epochs = 1000
pyro.clear_param_store()
state_model = Model_signals_refStates(
    num_signals = generator.num_signals,
    num_references = generator.num_references,
    num_states = generator.num_states,
    hidden = 32,
    dropout = 0.2)
state_model.to(device)
optimizer = pyro.optim.Adam({"lr": learning_rate})
svi = SVI(state_model.model, state_model.guide, optimizer, loss=TraceMeanField_ELBO())
num_batches = int(math.ceil(m.shape[0] / batch_size))

bar = trange(num_epochs)
for epoch in bar:
    running_loss = 0.0
    for i in range(num_batches):
        batch_m = m[i * batch_size:(i+1) * batch_size, :]
        batch_r = r[i * batch_size:(i+1) * batch_size, :, :]
        loss = svi.step(batch_m, batch_r)
        running_loss += loss / batch_m.size(0)
        
    bar.set_postfix(epoch_loss='{:.2e}'.format(running_loss))

100%|██████████████████| 1000/1000 [07:56<00:00,  2.10it/s, epoch_loss=1.79e+02]


In [12]:
# ratio_m_CR, ratio_r_CR = state_model.get_percentage_correct_reconstruct(m,r)
re_m, re_r = state_model.generate_reconstructed_data(m,r)
wrong_r = ((re_r.shape[0] * re_r.shape[1] * re_r.shape[2]) - (re_r==r).sum()) / 2
print(wrong_r)

tensor(208.)


In [13]:

Cpi = state_model.predict_state_assignment(m, r)
print(Cpi.shape)

torch.Size([10000, 3])


In [14]:
logpi_loc, logpi_scale = state_model.encoder(m, r)
Cpi = F.softmax(logpi_loc, -1)
signal_param, ref_param = state_model.decoder(Cpi) # vector of probabilities. 
re_m = pyro.sample('re_m', dist.Bernoulli(signal_param).to_event(1))
re_r = pyro.sample('re_r', dist.Multinomial(1, ref_param).to_event(1))

In [16]:
print(pd.DataFrame(Cpi.detach().numpy()))

             0         1         2
0     0.743172  0.135145  0.121683
1     0.121452  0.112941  0.765607
2     0.132993  0.740218  0.126790
3     0.723228  0.128154  0.148618
4     0.117799  0.110641  0.771560
...        ...       ...       ...
9995  0.096648  0.807010  0.096343
9996  0.773625  0.115915  0.110460
9997  0.077738  0.071258  0.851004
9998  0.072391  0.831101  0.096508
9999  0.816641  0.091635  0.091724

[10000 rows x 3 columns]


## Model with signals and ref states, fixed beta values

In [17]:
class Encoder(nn.Module):
	def __init__(self, num_signals, num_states, num_references, hidden, dropout):
		super().__init__()
		self.drop = nn.Dropout(dropout)
		input_dim = num_signals + num_states * num_references
		self.fc1 = nn.Linear(input_dim, hidden)
		self.fc2 = nn.Linear(hidden, hidden)
		self.fcmu = nn.Linear(hidden, num_states)
		self.fclv = nn.Linear(hidden, num_states)

	def forward(self, m, r):
		inputs = torch.cat((m, r.reshape(r.shape[0], -1)), 1)
		h = F.softplus(self.fc1(inputs))
		h = F.softplus(self.fc2(h))
		h = self.drop(h)
		logpi_loc = (self.fcmu(h))
		logpi_logvar = self.fclv(h)
		logpi_scale = (0.5 * logpi_logvar).exp()
		return logpi_loc, logpi_scale

class Decoder(nn.Module):
	def __init__(self, num_states, num_signals, num_references, hidden, dropout, fixed_signalP):
		super().__init__()
		self.num_states = num_states
		self.num_signals = num_signals
		self.num_references = num_references
		self.fixed_signalP = fixed_signalP
		self.drop = nn.Dropout(dropout)
		self.fcih = nn.Linear(num_states, hidden) # input (state probabilities) --> hidden
		self.fchh = nn.Linear(hidden, hidden) # hiddent --> hidden
		self.fchs = nn.Linear(hidden, self.num_signals) # hidden --> signals
		self.fchr = nn.Linear(hidden, self.num_references * self.num_states) # hidden --> states in references



	def forward(self, inputs):
		# takes in the values of collapsed pi: probabilities of state 
		# assignments at each positions, and then apply a linear trans
		# to get the probabilities of observing signals at each position
		# --> vector size #signals
		# used as parameters for bernoulli dist. to get obs. signals
		# create multiple layers
		# inputs: bins, state probabilities
		# h: bins, hidden 
		# signal_param: bins, signals
		# ref_param: bins, num_references, num_states
		h = F.softplus(self.fcih(inputs)) # --> hidden element vector
		h = F.softplus(self.fchh(h)) # --> hidden element vector
		h = self.drop(h)
		signal_param = torch.sigmoid(torch.matmul(inputs, self.fixed_signalP)) # hidden --> marks
		ref_param = torch.sigmoid(self.fchr(h)).reshape((h.shape[0], self.num_references, self.num_states)) 
		# hidden --> num_ref*num_states
		ref_param = F.normalize(ref_param, p = 1.0, dim = 2) # row normalize, sum over states per ref is 1
		return signal_param, ref_param



    
class Model_signals_refStates_fixedBeta(nn.Module):
	def __init__(self, num_signals, num_references, num_states, hidden, dropout, fixed_signalP):
		super().__init__()
		self.num_signals = num_signals
		self.num_references = num_references
		self.num_states = num_states
		self.hidden = hidden
		self.dropout = dropout
		self.fixed_signalP = fixed_signalP
		self.encoder = Encoder(num_signals, num_states, num_references, hidden, dropout)
		self.decoder = Decoder(num_states, num_signals, num_references, hidden, dropout, fixed_signalP)

	# shapes: 
	#  m: (bins x signals) signal matrix
	#  r: (bins x reference x state) indicator matrix
	def model(self, m, r):
		# flatten out the r indicator matrix
		pyro.module("decoder", self.decoder)
		with pyro.plate('bins', m.shape[0]):
			logCpi_loc = m.new_zeros((m.shape[0], self.num_states))
			logCpi_scale = m.new_ones((m.shape[0], self.num_states))
			logCpi = pyro.sample('log_collapsedPi', dist.Normal(logCpi_loc, logCpi_scale).to_event(1))
			Cpi = F.softmax(logCpi, -1) 
			# the softmax function should be used here because Cpi is lognormal
			signal_param, ref_param = self.decoder(Cpi) # vector of probabilities. 
			# signal_param: bins, signals
			# ref_param: bins, references, states
			# first num_signals elements: bernoulli params
			# each of the following num_states elements: multinomial params
			# for the state segmentation in a reference    
			t1 = pyro.sample('m', dist.Bernoulli(signal_param).to_event(1), obs=m)
			# plate across references
			with pyro.plate('refs', self.num_references):
				t2 = pyro.sample('r', dist.Multinomial(1, ref_param).to_event(1), obs = r)

	def guide(self, m, r):
		pyro.module("encoder", self.encoder)
		with pyro.plate('bins', m.shape[0]):
			logpi_loc, logpi_scale = self.encoder(m, r)
			logpi = pyro.sample('log_collapsedPi', dist.Normal(logpi_loc, logpi_scale).to_event(1))

	def predict_state_assignment(self, m, r):
		logpi_loc, logpi_scale = self.encoder(m, r)
		Cpi = F.softmax(logpi_loc, -1)
		return(Cpi)


	def generate_reconstructed_data(self, m, r):
		logpi_loc, logpi_scale = self.encoder(m, r)
		Cpi = F.softmax(logpi_loc, -1)
		signal_param, ref_param = self.decoder(Cpi) # vector of probabilities. 
		re_m = pyro.sample('re_m', dist.Bernoulli(signal_param).to_event(1))
		re_r = pyro.sample('re_r', dist.Multinomial(1, ref_param).to_event(1))
		return(re_m, re_r)


	def get_percentage_correct_reconstruct(self, m, r):
		# m and r can be different from the m and r used in training
		re_m, re_r = self.generate_reconstructed_data(m,r)
		total_m_entries = re_m.shape[0] * re_m.shape[1]
		signals_CR = (re_m==m).sum() # correct reconstruct entries of signals
		total_r_entries = re_r.shape[0] * self.num_references
		# for each reference at each position, if the state assignment is different between re_r and r, there are 2 out of num_states entries that are different between re_r and r
		wrong_r = ((re_r.shape[0] * re_r.shape[1] * re_r.shape[2]) - (re_r==r).sum()) / 2 # wrong reconstruct entries of reference states
		r_CR = total_r_entries - wrong_r
		ratio_m_CR = (signals_CR / total_m_entries).item()
		ratio_r_CR = (r_CR / total_r_entries).item()
		return ratio_m_CR, ratio_r_CR


In [None]:
batch_size = 200
learning_rate = 1e-3
num_epochs = 1000
pyro.clear_param_store()
state_model = Model_signals_refStates_fixedBeta(
    num_signals = generator.num_signals,
    num_references = generator.num_references,
    num_states = generator.num_states,
    hidden = 32,
    dropout = 0.2,
    fixed_signalP  = p)
state_model.to(device)
optimizer = pyro.optim.Adam({"lr": learning_rate})
svi = SVI(state_model.model, state_model.guide, optimizer, loss=TraceMeanField_ELBO())
num_batches = int(math.ceil(m.shape[0] / batch_size))

bar = trange(num_epochs)
for epoch in bar:
    running_loss = 0.0
    for i in range(num_batches):
        batch_m = m[i * batch_size:(i+1) * batch_size, :]
        batch_r = r[i * batch_size:(i+1) * batch_size, :, :]
        loss = svi.step(batch_m, batch_r)
        running_loss += loss / batch_m.size(0)
        
    bar.set_postfix(epoch_loss='{:.2e}'.format(running_loss))

  guide_vars - model_vars
  4%|▊                   | 42/1000 [00:16<06:21,  2.51it/s, epoch_loss=2.58e+02]

In [13]:
logpi_loc, logpi_scale = state_model.encoder(m, r)
Cpi = F.softmax(logpi_loc, -1)
signal_param, ref_param = state_model.decoder(Cpi) # vector of probabilities. 
re_m = pyro.sample('re_m', dist.Bernoulli(signal_param).to_event(1))
re_r = pyro.sample('re_r', dist.Multinomial(1, ref_param).to_event(1))

In [14]:
ratio_m_CR, ratio_r_CR = state_model.get_percentage_correct_reconstruct(m,r)

tensor(29997)
tensor(960000)


In [5]:
(1000000 - torch.tensor(960000) /2).item()

520000.0