# A first attempt at generating some data

Based on this model. Some notation

M: # regions  
N: # of bins per region  
L: # of signals  

alpha: params of dirichlet prior over reference epigenomes  
beta: ref->sample state categorical distribution  
p: state->signal bernoulli distribution
r: reference state at each region/bin

```
for each region d do
  draw ref distr pi_d ~ Dir(alpha)
  for each genomic position i in d do
    sample ref z_di ~ Categorical(p_d)
    sample state s_di ~ Categorical(beta_{r_{z_{di}}})
    for each signal j do
      sample m_{dij} ~ Bernoulli(p_{s_{di}})
```


In [67]:
# set some parameters
num_regions = 12
num_bins_per_region = 5
num_signals = 3
num_states = 5
num_references = 10

In [68]:
import pyro
import pyro.distributions as dist
import torch
import numpy as np

alpha = torch.from_numpy(np.arange(num_references) + 1).float()
beta = torch.zeros((num_states, num_states))
for i in range(num_states):
    w = torch.ones(num_states)
    w[i] = 10
    beta[i,:] = w / torch.sum(w)

print('alpha:')
print(alpha)

print('beta:')
print(beta)

alpha:
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
beta:
tensor([[0.7143, 0.0714, 0.0714, 0.0714, 0.0714],
        [0.0714, 0.7143, 0.0714, 0.0714, 0.0714],
        [0.0714, 0.0714, 0.7143, 0.0714, 0.0714],
        [0.0714, 0.0714, 0.0714, 0.7143, 0.0714],
        [0.0714, 0.0714, 0.0714, 0.0714, 0.7143]])


In [69]:
p = torch.zeros((num_states, num_signals))
for i in range(num_states):
    w = torch.ones(num_signals)
    w[i % num_signals] = 10
    p[i,:] = w / torch.sum(w)

print('p')
print(p)

p
tensor([[0.8333, 0.0833, 0.0833],
        [0.0833, 0.8333, 0.0833],
        [0.0833, 0.0833, 0.8333],
        [0.8333, 0.0833, 0.0833],
        [0.0833, 0.8333, 0.0833]])


In [78]:
ref_states = torch.zeros((num_regions, num_bins_per_region, num_references))
for i in range(num_references):
    ref_states[:,:,i] = i % num_states
ref_states.shape  

torch.Size([12, 5, 10])

In [128]:
ref_states_indicator = torch.zeros((num_regions, num_bins_per_region, num_references, num_states))
for i in range(num_regions):
    for j in range(num_bins_per_region):
        for k in range(num_references):
            ref_states_indicator[i, j, k, ref_states[i,j,k].long()] = 1.
print(ref_states_indicator.shape)
ref_states_indicator[i,j]

torch.Size([12, 5, 10, 5])


tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

In [133]:
pi = dist.Dirichlet(alpha).sample()
collapsed_pi = torch.matmul(ref_states_indicator[i,j,:,:].T, pi)
assert torch.sum(collapsed_pi) == 1

In [1361]:
toy_parms = {
    'num_bins': 25,
    'num_references': 10,
    'num_signals': 3,
    'num_states': 4
}

class ToyGenerator:
    def __init__(self,  
                 num_bins=5, 
                 num_references=10, 
                 num_signals=3,
                 num_states=5,
                 high_w=100):
        self.num_bins = num_bins
        self.num_references = num_references
        self.num_signals = num_signals
        self.num_states = num_states
        self.high_w = high_w
        self.sample = None
        self.params = self.set_params()
    
        
    # parameter of state->signal distributions
    # shape is (num_states, num_signals)
    def generate_param_p(self):
        p = torch.zeros((self.num_states, self.num_signals))
        for i in range(self.num_states):
            w = -self.high_w * torch.ones(self.num_signals)
            w[i % self.num_signals] = self.high_w
            p[i,:] = w
        return p
    
    # generate a state assignment tensor
    # shape is (num_regions, num_bins_per_region, num_references)
    def generate_ref_states(self):
        ref_states = torch.zeros(
            (self.num_bins, 
             self.num_references))
        
        for i in range(self.num_references):
            ref_states[:,i] = i % self.num_states
        return ref_states.long()
    
    # set parameters of the data generator
    def set_params(self):
        # parameters of the dirichlet over references
        # same one for every region
        # very high probability that generated sample looks like
        # reference 0
        # shape is (num_references,)
        alpha = torch.ones(self.num_references)
        alpha[0] = self.high_w
        
        # parameters of bernoulli distribution for each signal
        # for each state
        # shape is (num_states, num_signals)
        p = self.generate_param_p()
        
        # an indicator matrix along genome of the state for 
        # each refenrece
        # shape is (num_regions, num_bins_per_region, num_states, num_references)
        ref_states_indicator = F.one_hot(self.generate_ref_states(), self.num_states)
        params = {
            'alpha': alpha,
            'p': p,
            'ref_states_indicator': ref_states_indicator
        }
        self.params = params
        return params
        
    # collapse a prob vector over references to a prob vector over states
    # takes the cross product of prob vector pi and reference state indicator matrix r
    # shapes:
    #  pi: (None, num_references)
    #  r: (None, num_references, num_states)
    #  out: (None, num_states)
    def collapse_pi(self, pi, r=None):
        if r is None:
            assert self.params is not None
            r = self.params['ref_states_indicator']
            
        r = r.float()
        collapsed_pi = torch.zeros(pi.shape[0], r.shape[2])
        for i in range(pi.shape[0]):
            collapsed_pi[i,:] = torch.matmul(r[i,:,:].T, pi[i,:])
        return collapsed_pi
    
    def generate_sample(self):
        if self.params is None:
            self.set_params()
            
        r = self.params['ref_states_indicator']
                
        # generate reference distribution for each region
        with pyro.plate('bins', self.num_bins):
            # pi is shape (num_regions, num_references)
            pi = pyro.sample('pi', dist.Dirichlet(self.params['alpha']))
            # collapse the reference distribution for each bin to a 
            # state distribution 
            collapsed_pi = self.collapse_pi(pi, r)

            signal_params = torch.sigmoid(torch.matmul(collapsed_pi, self.params['p']))
            m = pyro.sample('m', dist.Bernoulli(signal_params).to_event(1))

        result = {
            'pi': pi,
            'm': m
        }
        self.sample = result
        return self.sample
    
    def get_sampled_collapsed_pi(self):
        if self.sample is None:
            self.generate_sample()
        pi = self.sample['pi']
        return self.collapse_pi(pi)
    
    def get_sampled_signals(self):
        if self.sample is None:
            self.generate_sample()
        return self.sample['m']
    
    def get_sampled_pi(self):
        if self.sample is None:
            self.generate_sample()
        return self.sample['pi']
    
    def get_signal_parms(self):
        collapsed_pi = self.get_sampled_collapsed_pi()
        return torch.sigmoid(torch.matmul(collapsed_pi, self.params['p']))
    
    def get_ref_state_indicators(self):
        if self.params is None:
            self.set_params()
        return self.params['ref_states_indicator']
    

In [1364]:
generator = ToyGenerator(**toy_parms, high_w=1000)
m = generator.get_sampled_signals()
r = generator.get_ref_state_indicators()
print(m.shape)
print(m[0,:])

torch.Size([25, 3])
tensor([1., 0., 0.])


In [1368]:
print(generator.get_sampled_collapsed_pi()[0:3,:,])
print(generator.params['p'])
print(generator.get_signal_parms()[0:3,:])

tensor([[9.9728e-01, 2.4500e-03, 3.9277e-05, 2.3388e-04],
        [9.9439e-01, 1.7245e-03, 2.6350e-03, 1.2457e-03],
        [9.9264e-01, 2.3736e-03, 1.0253e-03, 3.9576e-03]])
tensor([[ 1000., -1000., -1000.],
        [-1000.,  1000., -1000.],
        [-1000., -1000.,  1000.],
        [ 1000., -1000., -1000.]])
tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]])


In [1369]:
import math
import torch.nn as nn
import torch.nn.functional as F

In [1518]:
class Encoder(nn.Module):
    def __init__(self, num_signals, num_states, hidden, dropout):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.fc1 = nn.Linear(num_signals, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fcmu = nn.Linear(hidden, num_states)
        self.fclv = nn.Linear(hidden, num_states)
        self.bnmu = nn.BatchNorm1d(num_states, affine=True)
        self.bnlv = nn.BatchNorm1d(num_states, affine=True)
        
    def forward(self, x, r):
        #inputs = torch.cat((x, r.reshape(r.shape[0], -1)), 1)
        inputs = x
        h = F.softplus(self.fc1(inputs))
        h = F.softplus(self.fc2(h))
        h = self.drop(h)
        logpi_loc = self.bnmu(self.fcmu(h))
        logpi_logvar = self.bnmu(self.fclv(h))
        logpi_scale = (0.5 * logpi_logvar).exp()
        return logpi_loc, logpi_scale

class Decoder(nn.Module):
    def __init__(self, num_states, num_signals, hidden, dropout):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.beta = nn.Linear(num_states, num_signals, bias=False)
        self.bn = nn.BatchNorm1d(num_signals, affine=True)
        
    def forward(self, inputs):
        inputs = self.drop(inputs)
        beta = self.beta(inputs)
        return torch.sigmoid(self.bn(beta))
    
class TransferStateModel(nn.Module):
    def __init__(self, 
                 num_signals, 
                 num_references,
                 num_states,
                 hidden, 
                 dropout):
        super().__init__()
        self.num_signals = num_signals
        self.num_references = num_references,
        self.num_states = num_states,
        self.hidden = hidden
        self.dropout = dropout
        self.encoder = Encoder(num_signals, num_states, hidden, dropout)
        self.decoder = Decoder(num_states, num_signals, hidden, dropout)
                
    # shapes: 
    #  m: (bins x signals) signal matrix
    #  r: (bins x reference x state) indicator matrix
    def model(self, m, r):
        # flatten out the r indicator matrix
        pyro.module("decoder", self.decoder)
        with pyro.plate('bins', m.shape[0]):
            logpi_loc = m.new_zeros((torch.Size(m.shape[0], self.num_states[0])))
            logpi_scale = m.new_ones((torch.Size(m.shape[0], self.num_states[0])))
            logpi = pyro.sample(
                'logpi', dist.Normal(logpi_loc, logpi_scale).to_event(1))
            pi = F.softmax(logpi, -1)
            signal_param = self.decoder(pi)          
            pyro.sample('m', dist.Bernoulli(signal_param).to_event(1), obs=m)
                
    def guide(self, m, r):
        pyro.module("encoder", self.encoder)
        with pyro.plate('regions', m.shape[0]):
            logpi_loc, logpi_scale = self.encoder(m, r)
            logpi = pyro.sample(
                'logpi', dist.Normal(logpi_loc, logpi_scale).to_event(1))
   # def p(self):
   #     return self.decoder.p.weight.cpu().detach().T
        

In [1519]:
serious_parms = {
    'num_bins': 100,
    'num_references': 12,
    'num_signals': 3,
    'num_states': 8
}

serious_generator = ToyGenerator(**serious_parms, high_w=100)

In [1520]:
seed = 0
torch.manual_seed(seed)
pyro.set_rng_seed(seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

m = serious_generator.get_sampled_signals()
r = serious_generator.get_ref_state_indicators()

print(m.shape)
print(r.shape)

torch.Size([100, 3])
torch.Size([100, 12, 8])


In [1521]:
pyro.clear_param_store()
state_model = TransferStateModel(
    num_signals = serious_generator.num_signals,
    num_references = serious_generator.num_references,
    num_states = serious_generator.num_states,
    hidden = 32,
    dropout = 0.2)

In [1522]:
from pyro.infer import SVI, TraceMeanField_ELBO
from tqdm import trange

In [1523]:
batch_size = 20
learning_rate = 1e-3
num_epochs = 1000

state_model.to(device)
optimizer = pyro.optim.Adam({"lr": learning_rate})
svi = SVI(state_model.model, state_model.guide, optimizer, loss=TraceMeanField_ELBO())
num_batches = int(math.ceil(m.shape[0] / batch_size))

bar = trange(num_epochs)
for epoch in bar:
    running_loss = 0.0
    for i in range(num_batches):
        batch_m = m[i * batch_size:(i+1) * batch_size, :]
        batch_r = r[i * batch_size:(i+1) * batch_size, :, :]
        loss = svi.step(batch_m, batch_r)
        running_loss += loss / batch_m.size(0)
        
    bar.set_postfix(epoch_loss='{:.2e}'.format(running_loss))

  0%|                                                                                                 | 0/1000 [00:00<?, ?it/s]


TypeError: tuple expected at most 1 argument, got 2

In [1505]:
class TransferStateModelPostHoc:
    def __init__(self, model):
        self.model = model
        
    def sample_one_posterior_pi(self, loc, scale):
        logpi = dist.Normal(loc, scale).sample()
        return F.softmax(logpi, -1)
    
    def get_posterior_pi(self, m, r, num_samples = 1000):
        running_sum = torch.zeros(m.shape[0], self.model.num_states[0])
        running_sum_squares = torch.zeros(m.shape[0], self.model.num_states[0])
        for _ in range(num_samples):
            loc, scale = self.model.encoder(m, r)
            logpi = dist.Normal(loc, scale).sample()
            pi = F.softmax(logpi, -1)
            running_sum += pi
            running_sum_squares += (pi ** 2)
            
        return {
            'posterior_mean': running_sum / num_samples, 
            'posterior_std': torch.sqrt(running_sum_squares / num_samples),
            'hacky_pi': F.softmax(running_sum, dim=1)
        }
    
    def do_posterior_stuff(self, m, r, num_samples=100):
        pi = self.get_posterior_pi(m, r, num_samples=num_samples)
        signal_param = self.model.decoder(pi['hacky_pi'])
        m = (signal_param > 0.5).float()
        return {
            'pi': pi,
            'signal_param': signal_param,
            'm': m
        }

In [1506]:
posterior_stuff=TransferStateModelPostHoc(state_model).do_posterior_stuff(m, r)

In [1507]:
posterior_stuff['pi']['hacky_pi'][0:5,:]

tensor([[0.5032, 0.0578, 0.0120, 0.0403, 0.3583, 0.0105, 0.0065, 0.0114],
        [0.0216, 0.0794, 0.1330, 0.0898, 0.1676, 0.2295, 0.1995, 0.0796],
        [0.1161, 0.0175, 0.0290, 0.0998, 0.0320, 0.3084, 0.3318, 0.0654],
        [0.1868, 0.1690, 0.1083, 0.0958, 0.0332, 0.0401, 0.0897, 0.2772],
        [0.1260, 0.0141, 0.3549, 0.0319, 0.2706, 0.0131, 0.0246, 0.1647]])

In [1508]:
posterior_stuff['signal_param'][0:5,:]

tensor([[0.5006, 0.4954, 0.5011],
        [0.5001, 0.4963, 0.5003],
        [0.4999, 0.4941, 0.5000],
        [0.5000, 0.5066, 0.4995],
        [0.5006, 0.4978, 0.5007]], grad_fn=<SliceBackward>)

In [1462]:
pi_posterior['pi']['hacky_pi'][0:5,:]

tensor([[0.0209, 0.0678, 0.5084, 0.1500, 0.0091, 0.0412, 0.0963, 0.1063],
        [0.0107, 0.3958, 0.1768, 0.0582, 0.0687, 0.0120, 0.1908, 0.0870],
        [0.0418, 0.0814, 0.0569, 0.2400, 0.4487, 0.0435, 0.0118, 0.0760],
        [0.0871, 0.0377, 0.0634, 0.0825, 0.2295, 0.0326, 0.0515, 0.4157],
        [0.0241, 0.7570, 0.0145, 0.0014, 0.0792, 0.0856, 0.0302, 0.0081]])

In [1443]:
pi_posterior['signal_param'][0:5,:]

tensor([[0.5001, 0.5001, 0.4995],
        [0.5002, 0.4997, 0.5001],
        [0.4997, 0.5002, 0.4995],
        [0.5000, 0.5008, 0.4993],
        [0.5002, 0.5001, 0.5012]], grad_fn=<SliceBackward>)

In [1444]:
pi_posterior['m'][0:5,:]

tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [1445]:
m[0:5,:]

tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]])

In [1350]:
a = torch.zeros(10, 12)
for i in range(10):
    for j in range(12):
        a[i,j] = j % 4
        
print(a.long())
b = F.one_hot(a.long(), num_classes=4)
print(b.shape)
c = torch.rand(10,12)
print(c.shape)
torch.matmul(torch.transpose(b,2,1).float(),c.T).shape

tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]])
torch.Size([10, 12, 4])
torch.Size([10, 12])


torch.Size([10, 4, 10])

In [1356]:
torch.matmul(b[0,:,:].T.float(), c[0,:])

tensor([1.0481, 0.7089, 1.9728, 1.4895])

In [1414]:
torch.eye(3,3)

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

In [1423]:
a=torch.arange(2*3*4).reshape(2,3,4)
print(a)
print(a.reshape(a.shape[0], -1))
c=torch.arange(2*3).reshape(2,3)
print(c)
torch.cat((c, a.reshape(a.shape[0], -1)), 1)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
tensor([[0, 1, 2],
        [3, 4, 5]])


tensor([[ 0,  1,  2,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 3,  4,  5, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])

In [1446]:
m[0:5,:], r[0:5,:]

(tensor([[1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.]]),
 tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0],
          [0, 0, 0, 0, 1, 0, 0, 0],
          [0, 0, 0, 0, 0, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 0],
          [0, 0, 0, 0, 0, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0]],
 
         [[1, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0],
          [0, 0, 0, 0, 1, 0, 0, 0],
          [0, 0, 0, 0, 0, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 1, 0],
          [0, 0, 0, 0, 0, 0, 0, 1],
          [1, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0]],
 
         [[1,