In [99]:
#toy model bayesian signature inference

import torch
import pyro
import pyro.distributions as dist 
from torch.distributions import constraints
import numpy as np
from pyro.optim import Adam


#define the model in terms of M = phylogenetic matrix ,K = number of signatures besides the aging signatures SBS1,SBS5.
# The aging signatures are described by the fixed matrix Beta_aging. Beta_aging has dim (2,96).
# In this toy model the rows of M are rank 96 vectors M_truncal,M_plus,M_mignus. M_truncal contains shared mutations,
# M_plus,M_mignus contain the private mutations of two distinct subpopulations. We call t_0 the split time between the two 
# subpopulations, t_1 and t_2 the MRCA times of the two branches. Times are measured in terms of number of mutations.
# the overlaps between the two branches starting at t_0 is overlap = min(t_2,t_1). 
# The overap of the truncal branch with the others is clearly 0. So we define overlaps as the 3X3 overlap symmetric matrix

def model(M, K, Beta_aging, overlaps):
    
    num_sumples = 3
    
    # construct covariance matrix for alpha sampling with RBF-like kernel
    
    cov = torch.zeros(num_samples*(K+2),num_samples*(K+2))
    
    for i in range(num_samples*(K+2)):
        
        for j in range(num_samples*(K+2)):
            
            if i==j : cov[i,j] = 1     
                
                # correlation structure : we correlate activities of the same signature 
                # according to phylogenetic tree
                
            elif  (i != j) and ((j-i)%(K + 2) == 0) : 
                
                m = int(i/(K+2))
                n = int(j/(K+2))
                
                cov[i,j] = np.heaviside(overlaps[m,n], 0)*np.exp(-1/( 1 + overlaps[m,n]))
                
            else: cov[i,j] = 0
    
      #sample phylogeny activities as a 3*(K+2) vector from multivariate Lognormal
    
    log_alpha = pyro.sample("activities", dist.MultivariateNormal(torch.zeros(num_samples*(K+2)),cov))
    
    # sample the extra signature profiles
    
    with pyro.plate("signal samplig", K):
    
        Beta_signals = pyro.sample("extra_signal", dist.Dirichlet(torch.ones(96)))    
        
        
   # write the likelihood


    with pyro.plate("context",96):
        
        with pyro.plate("sample",num_samples):
    
            pyro.sample("obs", dist.Poisson(torch.matmul(torch.exp(log_alpha.reshape(num_samples,K+2)),
                                                         torch.cat((Beta_aging,Beta_signals),0))),obs=M)
    

In [127]:
#define the guide. We perform a MAP estimate

from pyro.infer.autoguide import AutoDelta
from pyro.infer.autoguide.initialization import init_to_sample

guide = AutoDelta(model,init_loc_fn = init_to_sample)
    
    
    
# SVI

from pyro.infer import SVI, Trace_ELBO



def inference(model,guide,M,K,overlaps,Beta_aging,lr=0.05,num_steps=500):

    pyro.clear_param_store()  # always clear the store before the inference

    # learning global parameters

    adam_params = {"lr": lr}
    optimizer = Adam(adam_params)
    elbo = Trace_ELBO()

    svi = SVI(model, guide, optimizer, loss=elbo)

#    inference

#   do gradient steps

    for step in range(num_steps):
        loss = svi.step(M, K,Beta_aging,overlaps)  # get the loss function after a gradient step
        if step % 10 == 0:
            print("loss=", loss)  # check the progress

    print("final loss=", svi.evaluate_loss(M, K,Beta_aging,overlaps)) 

       #get MAP estimates
 
    

    

In [49]:
# example inference

import pandas as pd

X_1 = pd.read_csv("/Users/riccardobergamin/sample_1.csv")


In [50]:
X_1

Unnamed: 0.1,Unnamed: 0,A[C>A]A,A[C>A]C,A[C>A]G,A[C>A]T,A[C>G]A,A[C>G]C,A[C>G]G,A[C>G]T,A[C>T]A,...,T[T>A]G,T[T>A]T,T[T>C]A,T[T>C]C,T[T>C]G,T[T>C]T,T[T>G]A,T[T>G]C,T[T>G]G,T[T>G]T
0,BRMA_1 clonal,11,10,4,11,16,5,2,5,52,...,2,15,10,3,9,17,6,3,3,9
1,BRMA_2_subclonal,369,19,3,26,17,6,7,9,25,...,17,25,5,6,14,15,6,3,5,8
2,BRMA_3 shared,407,27,3,44,11,1,2,14,15,...,5,11,6,8,2,8,3,0,4,6


In [51]:
X_2 = pd.read_csv("/Users/riccardobergamin/sample_2.csv")
X_2

Unnamed: 0.1,Unnamed: 0,A[C>A]A,A[C>A]C,A[C>A]G,A[C>A]T,A[C>G]A,A[C>G]C,A[C>G]G,A[C>G]T,A[C>T]A,...,T[T>A]G,T[T>A]T,T[T>C]A,T[T>C]C,T[T>C]G,T[T>C]T,T[T>G]A,T[T>G]C,T[T>G]G,T[T>G]T
0,BRMA_1 clonal,11,10,4,11,16,5,2,5,52,...,2,15,10,3,9,17,6,3,3,9
1,BRMA_2 private_lowfreq,295,35,11,50,53,24,17,26,120,...,26,88,19,37,38,138,15,17,26,86
2,BRMA_3 private,21,10,2,11,2,1,5,2,12,...,10,24,9,3,3,8,0,3,6,7


In [61]:
c_1 = X_1.values[:,1:]
c_2 = X_2.values[:,1:]

In [62]:
M_1 = torch.tensor(np.array(c_1,dtype=float))
M_2 = torch.tensor(np.array(c_2,dtype=float))

In [63]:
M_1.size()

torch.Size([3, 96])

In [64]:
M_2.size()

torch.Size([3, 96])

In [105]:
M_1= M_1.float()

In [106]:
M_2 = M_2.float()

In [147]:
# construct overlap matrix


# example 1 

num_samples = 3
    
t_0 = torch.sum(M_1[0,])
    
t_1 = torch.sum(M_1[1,])
    
t_2 = torch.sum(M_1[2,])
    
overlaps_0 = torch.tensor([t_0,0,0])
    
overlaps_1 = torch.tensor([0,t_1,min(t_1,t_2)])
    
overlaps_3 = torch.tensor([0,min(t_1,t_2),t_2])
    
overlaps = torch.cat((overlaps_0,overlaps_1,overlaps_3),0).reshape(3,3)
    
overlaps

tensor([[1334.,    0.,    0.],
        [   0., 4008., 4008.],
        [   0., 4008., 4401.]])

In [148]:
#look for chemo
K = 1

#aging signature

aging = pd.read_csv("/Users/riccardobergamin/beta_aging.csv")

In [149]:
b = aging.values[:,1:]
beta_aging = torch.tensor(np.array(b,dtype=float))
beta_aging = beta_aging.float()

In [150]:
beta_aging.size()

torch.Size([2, 96])

In [151]:
inference(model,guide,M_1,K,overlaps,beta_aging,lr=0.05,num_steps=500)

loss= 5078.93505859375
loss= 1625.1857147216797
loss= 751.9563446044922
loss= 651.204704284668
loss= 613.8875198364258
loss= 596.5420074462891
loss= 581.9448852539062
loss= 580.1666641235352
loss= 578.6744689941406
loss= 578.0636215209961
loss= 577.892692565918
loss= 577.7688140869141
loss= 577.7285232543945
loss= 577.710807800293
loss= 577.6980667114258
loss= 577.695686340332
loss= 577.6925506591797
loss= 577.6919937133789
loss= 577.6896057128906
loss= 577.6900024414062
loss= 577.6881866455078
loss= 577.6864929199219
loss= 577.6855010986328
loss= 577.6854934692383
loss= 577.6844177246094
loss= 577.6836547851562
loss= 577.6813049316406
loss= 577.6808013916016
loss= 577.6819534301758
loss= 577.6793365478516
loss= 577.6801834106445
loss= 577.6797027587891
loss= 577.6789245605469
loss= 577.6774444580078
loss= 577.6763458251953
loss= 577.6747741699219
loss= 577.676513671875
loss= 577.6734771728516
loss= 577.6742248535156
loss= 577.6744842529297
loss= 577.6720199584961
loss= 577.67283630371

In [152]:
parameters={}
        
for key in pyro.get_param_store().get_all_param_names() :
    
             parameters.update({key : torch.tensor(pyro.param(key))})
        
parameters  

  parameters.update({key : torch.tensor(pyro.param(key))})


{'AutoDelta.activities': tensor([ 5.7040,  6.9274,  0.2099, -0.0809, -0.0950,  8.3182, -0.0810, -0.0951,
          8.3662]),
 'AutoDelta.extra_signal': tensor([[9.2316e-02, 5.4700e-03, 7.1349e-04, 8.3268e-03, 3.3288e-03, 8.3123e-04,
          1.0703e-03, 2.7337e-03, 4.7463e-03, 1.3034e-03, 8.7221e-03, 2.7296e-03,
          5.4711e-03, 4.4006e-03, 4.6377e-03, 6.3054e-03, 3.4397e-03, 1.1860e-03,
          2.8468e-03, 2.9645e-03, 7.1298e-04, 1.1834e-04, 4.7410e-04, 4.7481e-04,
          3.3860e-01, 3.0915e-03, 2.8541e-03, 9.0417e-03, 1.4261e-03, 1.1876e-03,
          4.7523e-04, 2.2582e-03, 1.6611e-03, 1.8994e-03, 6.9719e-03, 3.3267e-03,
          2.8546e-03, 1.6644e-03, 2.2592e-03, 3.8060e-03, 1.7815e-03, 2.3765e-03,
          2.1367e-03, 1.6622e-03, 4.7526e-04, 5.9371e-04, 1.1882e-03, 8.3113e-04,
          1.5562e-01, 1.2014e-02, 9.7540e-03, 3.3551e-02, 3.4489e-03, 1.3074e-03,
          3.5655e-04, 2.4969e-03, 5.1109e-03, 1.6609e-03, 5.6598e-03, 2.6137e-03,
          6.8997e-03, 7.7312e

In [153]:
alpha = torch.exp(parameters["AutoDelta.activities"]).reshape(num_samples,K+2) 
beta_chemo = parameters["AutoDelta.extra_signal"]
print(alpha)
print(beta_chemo)

tensor([[3.0007e+02, 1.0199e+03, 1.2335e+00],
        [9.2226e-01, 9.0934e-01, 4.0978e+03],
        [9.2217e-01, 9.0925e-01, 4.2992e+03]])
tensor([[9.2316e-02, 5.4700e-03, 7.1349e-04, 8.3268e-03, 3.3288e-03, 8.3123e-04,
         1.0703e-03, 2.7337e-03, 4.7463e-03, 1.3034e-03, 8.7221e-03, 2.7296e-03,
         5.4711e-03, 4.4006e-03, 4.6377e-03, 6.3054e-03, 3.4397e-03, 1.1860e-03,
         2.8468e-03, 2.9645e-03, 7.1298e-04, 1.1834e-04, 4.7410e-04, 4.7481e-04,
         3.3860e-01, 3.0915e-03, 2.8541e-03, 9.0417e-03, 1.4261e-03, 1.1876e-03,
         4.7523e-04, 2.2582e-03, 1.6611e-03, 1.8994e-03, 6.9719e-03, 3.3267e-03,
         2.8546e-03, 1.6644e-03, 2.2592e-03, 3.8060e-03, 1.7815e-03, 2.3765e-03,
         2.1367e-03, 1.6622e-03, 4.7526e-04, 5.9371e-04, 1.1882e-03, 8.3113e-04,
         1.5562e-01, 1.2014e-02, 9.7540e-03, 3.3551e-02, 3.4489e-03, 1.3074e-03,
         3.5655e-04, 2.4969e-03, 5.1109e-03, 1.6609e-03, 5.6598e-03, 2.6137e-03,
         6.8997e-03, 7.7312e-03, 8.3263e-03, 3.3318

In [154]:
# example 2 

num_samples = 3
    
t_0 = torch.sum(M_2[0,])
    
t_1 = torch.sum(M_2[1,])
    
t_2 = torch.sum(M_2[2,])
    
overlaps_0 = torch.tensor([t_0,0,0])
    
overlaps_1 = torch.tensor([0,t_1,min(t_1,t_2)])
    
overlaps_3 = torch.tensor([0,min(t_1,t_2),t_2])
    
overlaps = torch.cat((overlaps_0,overlaps_1,overlaps_3),0).reshape(3,3)
    
overlaps

tensor([[1334.,    0.,    0.],
        [   0., 5510.,  857.],
        [   0.,  857.,  857.]])

In [155]:
inference(model,guide,M_2,K,overlaps,beta_aging,lr=0.05,num_steps=800)

loss= 5967.743637084961
loss= 2672.9450073242188
loss= 1461.6221313476562
loss= 1242.6842651367188
loss= 1215.05322265625
loss= 1199.2042846679688
loss= 1187.131103515625
loss= 1182.2289428710938
loss= 1181.5432739257812
loss= 1181.3078002929688
loss= 1181.0877075195312
loss= 1181.0213623046875
loss= 1181.01708984375
loss= 1181.0081787109375
loss= 1181.0043334960938
loss= 1181.0032348632812
loss= 1181.002685546875
loss= 1181.0025634765625
loss= 1181.001953125
loss= 1181.0023803710938
loss= 1181.001708984375
loss= 1181.0014038085938
loss= 1181.0013427734375
loss= 1181.00146484375
loss= 1181.00048828125
loss= 1181.0008544921875
loss= 1181.0
loss= 1181.000244140625
loss= 1181.0006713867188
loss= 1181.0003662109375
loss= 1181.0001831054688
loss= 1180.9999389648438
loss= 1181.0001220703125
loss= 1181.0001831054688
loss= 1180.9998168945312
loss= 1181.0001831054688
loss= 1180.9999389648438
loss= 1180.9998168945312
loss= 1180.9999389648438
loss= 1180.9996948242188
loss= 1180.99951171875
loss= 

In [156]:
parameters={}
        
for key in pyro.get_param_store().get_all_param_names() :
    
             parameters.update({key : torch.tensor(pyro.param(key))})
        
parameters 

  parameters.update({key : torch.tensor(pyro.param(key))})


{'AutoDelta.activities': tensor([ 5.7038,  6.9276,  0.1446,  0.1300, -0.1729,  8.5091,  0.1302, -0.1733,
          7.2410]),
 'AutoDelta.extra_signal': tensor([[0.0497, 0.0071, 0.0020, 0.0096, 0.0086, 0.0039, 0.0035, 0.0044, 0.0207,
          0.0068, 0.0225, 0.0102, 0.0165, 0.0072, 0.0071, 0.0209, 0.0231, 0.0046,
          0.0129, 0.0096, 0.0046, 0.0014, 0.0030, 0.0025, 0.1460, 0.0050, 0.0030,
          0.0086, 0.0022, 0.0027, 0.0022, 0.0028, 0.0071, 0.0061, 0.0180, 0.0113,
          0.0101, 0.0086, 0.0116, 0.0099, 0.0055, 0.0055, 0.0077, 0.0085, 0.0009,
          0.0020, 0.0022, 0.0025, 0.0621, 0.0086, 0.0052, 0.0157, 0.0038, 0.0011,
          0.0016, 0.0027, 0.0074, 0.0060, 0.0170, 0.0083, 0.0110, 0.0093, 0.0135,
          0.0072, 0.0093, 0.0046, 0.0116, 0.0071, 0.0013, 0.0013, 0.0036, 0.0020,
          0.0361, 0.0072, 0.0025, 0.0145, 0.0064, 0.0058, 0.0020, 0.0124, 0.0082,
          0.0091, 0.0100, 0.0118, 0.0091, 0.0028, 0.0057, 0.0176, 0.0044, 0.0063,
          0.0064, 0.0229, 0.0

In [157]:
alpha = torch.exp(parameters["AutoDelta.activities"]).reshape(num_samples,K+2) 
beta_chemo = parameters["AutoDelta.extra_signal"]
print(alpha)
print(beta_chemo)

tensor([[3.0000e+02, 1.0201e+03, 1.1556e+00],
        [1.1388e+00, 8.4121e-01, 4.9597e+03],
        [1.1391e+00, 8.4090e-01, 1.3955e+03]])
tensor([[0.0497, 0.0071, 0.0020, 0.0096, 0.0086, 0.0039, 0.0035, 0.0044, 0.0207,
         0.0068, 0.0225, 0.0102, 0.0165, 0.0072, 0.0071, 0.0209, 0.0231, 0.0046,
         0.0129, 0.0096, 0.0046, 0.0014, 0.0030, 0.0025, 0.1460, 0.0050, 0.0030,
         0.0086, 0.0022, 0.0027, 0.0022, 0.0028, 0.0071, 0.0061, 0.0180, 0.0113,
         0.0101, 0.0086, 0.0116, 0.0099, 0.0055, 0.0055, 0.0077, 0.0085, 0.0009,
         0.0020, 0.0022, 0.0025, 0.0621, 0.0086, 0.0052, 0.0157, 0.0038, 0.0011,
         0.0016, 0.0027, 0.0074, 0.0060, 0.0170, 0.0083, 0.0110, 0.0093, 0.0135,
         0.0072, 0.0093, 0.0046, 0.0116, 0.0071, 0.0013, 0.0013, 0.0036, 0.0020,
         0.0361, 0.0072, 0.0025, 0.0145, 0.0064, 0.0058, 0.0020, 0.0124, 0.0082,
         0.0091, 0.0100, 0.0118, 0.0091, 0.0028, 0.0057, 0.0176, 0.0044, 0.0063,
         0.0064, 0.0229, 0.0024, 0.0031, 0.0050, 0.