In [None]:
# Make Google Drive available under /content/drive; every data/result path below lives there.
from google.colab import drive
drive.mount(mountpoint="/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


This notebook draws the posterior samples that are later used to estimate the DIC (deviance information criterion) of SCRaPL without inflation on the mESC data. The only difference from the standard inference notebooks is that it also stores some extra quantities: the log-likelihood of each individual posterior sample, and the posterior averages of the latent methylation and expression.

In [None]:
from IPython import display
import pandas as pd
import numpy as np
import scipy
import scipy.stats

from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from tensorflow import keras

from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import pickle
from timeit import default_timer as timer

In [None]:
# Short aliases for the TFP distribution and bijector namespaces.
tfd, tfb = tfp.distributions, tfp.bijectors
# Root of the mounted Google Drive; the data and result paths below are built from it.
Folder = '/content/drive/MyDrive/'

In [None]:
# Read the methylation and expression count matrices, the CpG coverage matrix
# (later used as the Binomial total count) and the per-cell normalisation constants.
yy_met = pd.read_csv(Folder+'SCRaPL/Real/Data/Met.csv', index_col=[0], sep=',')
yy_exp = pd.read_csv(Folder+'SCRaPL/Real/Data/Rna.csv', index_col=[0], sep=',')
CpG = pd.read_csv(Folder+'SCRaPL/Real/Data/CpG.csv', index_col=[0], sep=',')
# NOTE(review): unlike the other files, nrm.csv is indexed by its SECOND column
# (index_col=[1]); presumably its column layout differs — confirm against the data.
nrm = pd.read_csv(Folder+'SCRaPL/Real/Data/nrm.csv', index_col=[1], sep=',')


In [None]:
# Convert the loaded frames to float32 tensors. The count matrices rebind their own names;
# rows are genomic regions and columns are cells (this is how they are sliced further down).
yy_met = tf.convert_to_tensor(yy_met, dtype=tf.float32)
yy_exp = tf.convert_to_tensor(yy_exp, dtype=tf.float32)
CpG = tf.convert_to_tensor(CpG, dtype=tf.float32)
# Normalisation constants are transposed so the cell axis comes last
# (presumably giving shape [1, n_cells] — TODO confirm).
Norm = tf.convert_to_tensor(nrm, dtype=tf.float32)
Norm = tf.transpose(Norm)

In [None]:
# Number of genomic regions (rows) and cells (columns), kept as scalar int32 tensors.
x_genes, x_cells = tf.unstack(tf.shape(CpG))

In [None]:
# Bijectors shared by the model definitions.
# Reminder: tfb.Chain([f, g]).forward(x) == f(g(x)) — the LAST bijector in the list runs first.

# aff: x -> 2*x - 1 (maps [0, 1] onto [-1, 1]); aff_inv: y -> (y + 1) / 2.
aff = tfb.Chain([tfb.Shift(-1.), tfb.Scale(2.)])
aff_inv = tfb.Invert(aff)

# Elementwise exp and its inverse; `log.forward` is the natural logarithm.
exp = tfb.Exp()
log = tfb.Invert(exp)

tanh = tfb.Tanh()
tanh_inv = tfb.Invert(tanh)

sigm = tfb.Sigmoid()
sigm_inv = tfb.Invert(sigm)

# cor_trsf: x -> (tanh(x / 2) + 1) / 2, i.e. the logistic sigmoid, mapping R onto (0, 1).
# Its inverse carries SCRaPL's Beta prior from (0, 1) to the unconstrained real line.
cor_trsf = tfb.Chain([aff_inv, tanh, tfb.Scale(0.5)])
cor_trsf_inv = tfb.Invert(cor_trsf)

# bin_bij: x -> eps/2 + (1 - eps) * Phi(x), with Phi the standard normal CDF; this keeps the
# Binomial probability inside [eps/2, 1 - eps/2] instead of touching 0 or 1.
eps = 0.001
bin_bij = tfb.Chain([tfb.Shift(eps / 2.0), tfb.Scale(1.0 - eps), tfb.NormalCDF()])

# cor_bij: x -> tanh(x / 2), mapping R onto (-1, 1).
cor_bij = tfb.Chain([tanh, tfb.Scale(0.5)])
# std_bij: x -> exp(-x), always positive.
std_bij = tfb.Chain([exp, tfb.Scale(-1.0)])
sqr_bij = tfb.Square()


In [None]:
#SCRaPL's graphical model
def SCRaPL(N_genes,N_cells,Cover,Nrm):
    """Build SCRaPL's full generative model (priors, latents and likelihood).

    The returned joint distribution yields, in this order:
      cor_lt, m_met_lt, m_exp_lt, s_met_lt, s_exp_lt -- per-gene parameters on an
          unconstrained scale, shape [N_genes, 1];
      x_met, x_exp -- per-gene, per-cell latent methylation / expression,
          shape [N_genes, N_cells];
      y_met, y_exp -- observed methylation (Binomial) and expression (Poisson) counts.

    Args:
      N_genes: number of genomic regions in this model (used as a tensor dimension).
      N_cells: number of cells (used as a tensor dimension).
      Cover: Binomial total counts for y_met (the CpG coverage); presumably shaped
        [N_genes, N_cells] — TODO confirm.
      Nrm: positive per-cell normalisation constants; their log is added to the Poisson
        log-rate (presumably shape [1, N_cells] — TODO confirm).

    Returns:
      A tfd.JointDistributionCoroutineAutoBatched over the nine components above.
    """
    def prior():
        per_gene = tf.ones([N_genes, 1])
        per_cell = tf.ones([N_genes, N_cells])

        # Gene-level parameters. cor_lt is a Beta(15, 15) draw pushed through
        # cor_trsf_inv; s_*_lt are logs of InverseGamma(2.5, 4.5) draws.
        cor_lt = yield tfd.TransformedDistribution(
            distribution=tfd.Beta(concentration0=15.0 * per_gene,
                                  concentration1=15.0 * per_gene),
            bijector=cor_trsf_inv,
            name="cor_lt")
        m_met_lt = yield tfd.Normal(loc=0 * per_gene, scale=per_gene, name="m_met_lt")
        m_exp_lt = yield tfd.Normal(loc=4 * per_gene, scale=per_gene, name="m_exp_lt")
        s_met_lt = yield tfd.TransformedDistribution(
            distribution=tfd.InverseGamma(concentration=2.5 * per_gene,
                                          scale=4.5 * per_gene),
            bijector=log,
            name="s_met_lt")
        s_exp_lt = yield tfd.TransformedDistribution(
            distribution=tfd.InverseGamma(concentration=2.5 * per_gene,
                                          scale=4.5 * per_gene),
            bijector=log,
            name="s_exp_lt")

        # Map to the constrained scale (correlation in (-1, 1), positive scales) and
        # broadcast every gene-level quantity across the cells.
        rho = cor_bij.forward(cor_lt) * per_cell
        met_mean = m_met_lt * per_cell
        exp_mean = m_exp_lt * per_cell
        met_sd = std_bij.forward(s_met_lt) * per_cell
        exp_sd = std_bij.forward(s_exp_lt) * per_cell
        log_norm = log.forward(Nrm) * per_gene

        # Latent methylation, then latent expression from its bivariate-normal conditional:
        # mean exp_mean + rho * exp_sd * (x_met - met_mean) / met_sd,
        # variance (1 - rho^2) * exp_sd^2.
        x_met = yield tfd.Normal(loc=met_mean, scale=met_sd, name="x_met")
        cond_loc = exp_mean + exp_sd * (x_met - met_mean) / met_sd * rho
        cond_scale = tf.math.sqrt((1 - tf.math.square(rho)) * tf.math.square(exp_sd))
        x_exp = yield tfd.Normal(loc=cond_loc, scale=cond_scale, name="x_exp")

        # Observation models.
        yield tfd.Binomial(total_count=Cover, probs=bin_bij.forward(x_met), name="y_met")
        yield tfd.Poisson(log_rate=x_exp + log_norm, name="y_exp")

    return tfd.JointDistributionCoroutineAutoBatched(prior)


In [None]:
#SCRaPL without priors used in DIC estimation.
def SCRaPL_dic(N_genes,N_cells,Cover,Nrm,param):
    """SCRaPL's conditional model with the gene-level parameters held fixed (no priors).

    Used to score posterior draws for DIC. The gene-level parameters are taken as given
    (typically a whole NUTS trace), and only the per-cell latents and observations are
    modelled. Each component reinterprets the last (cell) axis as an event dimension, so
    `log_prob(x_met, x_exp, y_met, y_exp)` sums over cells.

    Args:
      N_genes: number of genomic regions (used as a tensor dimension).
      N_cells: number of cells (used as a tensor dimension).
      Cover: Binomial total counts for y_met (the CpG coverage).
      Nrm: positive per-cell normalisation constants; their log is added to the Poisson
        log-rate.
      param: sequence (cor_lt, m_met_lt, m_exp_lt, s_met_lt, s_exp_lt) of unconstrained
        gene-level parameters. In this notebook these are NUTS traces, presumably of shape
        [num_samples, N_genes, 1] — TODO confirm. They are broadcast against
        [1, N_genes, N_cells].

    Returns:
      A tfd.JointDistributionCoroutine over (x_met, x_exp, y_met, y_exp).
    """
    cor_lt,m_met_lt,m_exp_lt,s_met_lt,s_exp_lt = param
    def prior():
        Root = tfd.JointDistributionCoroutine.Root
        cor = cor_bij.forward(cor_lt)
        s_met = std_bij.forward(s_met_lt)
        s_exp = std_bij.forward(s_exp_lt)

        # The leading singleton axis lets a leading sample axis on the parameters broadcast.
        mm_met = tf.math.multiply( m_met_lt,tf.ones([1,N_genes,N_cells]))
        mm_exp = tf.math.multiply( m_exp_lt,tf.ones([1,N_genes,N_cells]))
        ss_met = tf.math.multiply( s_met   ,tf.ones([1,N_genes,N_cells]))
        ss_exp = tf.math.multiply( s_exp   ,tf.ones([1,N_genes,N_cells]))
        ccor =   tf.math.multiply( cor     ,tf.ones([1,N_genes,N_cells]))
        nrm = tf.math.multiply( log.forward(Nrm),tf.ones([1,N_genes,1]))

        # x_met is the only component that does not depend on an earlier yield, so it is
        # the only one wrapped in Root. Marking dependent components as Root is incorrect:
        # it makes sample(sample_shape) add the sample dimensions twice.
        x_met = yield Root(tfd.Independent(tfd.Normal(loc = mm_met, scale = ss_met,name="x_met"),reinterpreted_batch_ndims=1))

        # Bivariate-normal conditional of the latent expression given the latent methylation.
        m_cnd_exp = mm_exp+tf.math.multiply(tf.math.divide(tf.math.multiply(ss_exp,x_met-mm_met),ss_met),ccor)
        s_cnd_exp = tf.math.sqrt(tf.math.multiply(1-tf.math.square(ccor),tf.math.square(ss_exp)))

        x_exp = yield tfd.Independent(tfd.Normal(loc = m_cnd_exp, scale = s_cnd_exp,name="x_exp"),reinterpreted_batch_ndims=1)

        rt_bin = bin_bij.forward(x_met)

        yield tfd.Independent(tfd.Binomial(total_count=Cover,probs=rt_bin,name="y_met"),reinterpreted_batch_ndims=1)
        yield tfd.Independent(tfd.Poisson(log_rate=x_exp+nrm, name="y_exp"),reinterpreted_batch_ndims=1)

    return tfd.JointDistributionCoroutine(prior)

In [None]:
# The full dataset does not fit in memory, so we rely on the assumption that genomic
# regions are independent and sample them in chunks of `www` regions at a time.
# Sampling is memory-hungry: depending on the available hardware, the chunk size (www)
# may need to be lowered to avoid running out of memory.
# A chunk whose sampler fails the quality checks is re-run from a fresh random initial
# state. After MAX_QC_ATTEMPTS consecutive failures it is skipped and its index is
# recorded in `skp_ind`.

RESULTS_DIR = Folder + 'SCRaPL/Real/Results_DIC/'
MAX_QC_ATTEMPTS = 5  # consecutive QC failures tolerated before a chunk is skipped

prt = 1          # 1-based index of the chunk currently being sampled
time = []        # wall-clock seconds of every sampling run, including failed ones
skp_ind = []     # chunks skipped after repeatedly failing QC
www = 30         # chunk size (genomic regions per chunk)
prts_tot = tf.math.ceil(x_genes/www)
jj = 0           # consecutive QC failures of the current chunk

while prt < prts_tot+1:
    aa = (prt-1)*www
    aa1 = prt*www
    yy_met_prt = yy_met[aa:aa1,:]
    yy_exp_prt = yy_exp[aa:aa1,:]
    CpG_prt = CpG[aa:aa1,:]

    batch_num = tf.shape(CpG_prt)[0]

    mdl_tr = SCRaPL(batch_num,x_cells,CpG_prt,Norm)
    # Random initial state drawn from the model. Only the first 7 components are latent;
    # the last two (y_met, y_exp) are pinned to the data in log_post.
    init = mdl_tr.sample()
    vrr = [init[i] for i in range(7)]

    # All latent components already live on an unconstrained scale.
    unconstrained_bijectors = [tfb.Identity() for _ in range(7)]

    # Defined inside the loop because it closes over this chunk's model and data, so it is
    # re-traced and XLA-compiled for every chunk and every retry.
    @tf.function(autograph=False, jit_compile=True)
    def sample_nuts():
        """Run step-size-adapted NUTS on the current chunk; returns (states, trace)."""
        init_x = vrr
        num_burnin_iter = 3000
        num_warmup_iter = int(0.8*num_burnin_iter)  # step-size adaptation happens within burn-in
        num_chain_iter = 2000

        target_accept_rate = 0.65

        # Unnormalised posterior of the latents: joint log density with the observations pinned.
        log_post = lambda *latents: mdl_tr.log_prob(*latents, yy_met_prt, yy_exp_prt)

        def trace_fn(_, pkr):
            # pkr nests as DualAveraging -> TransformedTransitionKernel -> NUTS results.
            nuts_res = pkr.inner_results.inner_results
            return (
                nuts_res.target_log_prob,
                nuts_res.leapfrogs_taken,
                nuts_res.has_divergence,
                nuts_res.energy,
                nuts_res.log_accept_ratio,
                nuts_res.step_size,
            )

        nuts = tfp.mcmc.NoUTurnSampler(
            target_log_prob_fn=log_post,
            step_size=0.05,
            max_tree_depth=6)
        ttk = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=nuts,
            bijector=unconstrained_bijectors)
        adapted_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            inner_kernel=ttk,
            num_adaptation_steps=num_warmup_iter,
            target_accept_prob=target_accept_rate)

        states, sampler_stat = tfp.mcmc.sample_chain(
            num_results=num_chain_iter,
            num_burnin_steps=num_burnin_iter,
            current_state=init_x,
            kernel=adapted_kernel,
            trace_fn=trace_fn)

        return states, sampler_stat

    start = timer()
    samples, sampler_stat = sample_nuts()
    end = timer()
    ttime = end-start
    time.append(ttime)

    # Mean acceptance probability over the kept draws (log ratios capped at 0).
    p_accept = tf.math.exp(tfp.math.reduce_logmeanexp(tf.minimum(sampler_stat[4], 0.)))
    # Last traced step size of the first state component (step sizes are presumably traced
    # per state component — TODO confirm).
    stp_sz = sampler_stat[5][0][-1]

    hmc_cor,hmc_m_met,hmc_m_exp,hmc_s_met,hmc_s_exp,hmc_x_met,hmc_x_exp = samples

    # QC: acceptance rate inside (0.4, 0.9) and a step size that has not collapsed.
    qc_acc_rt = tf.math.logical_and(p_accept<0.9,p_accept>0.4)
    qc_stp_sz = stp_sz>0.00001
    qc_ok = tf.math.logical_and(qc_acc_rt,qc_stp_sz)
    passed_qc = bool(qc_ok)

    if passed_qc:
        # Drop only the trailing singleton axis of the gene-level traces. A bare tf.squeeze
        # would also drop the gene axis when a chunk contains a single region.
        crr_nuts = tf.squeeze(hmc_cor, axis=-1)
        m_met_nuts = tf.squeeze(hmc_m_met, axis=-1)
        m_exp_nuts = tf.squeeze(hmc_m_exp, axis=-1)
        s_met_nuts = tf.squeeze(hmc_s_met, axis=-1)
        s_exp_nuts = tf.squeeze(hmc_s_exp, axis=-1)

        # Conditional log-likelihood of each posterior draw, needed later for DIC.
        mdl_dic = SCRaPL_dic(batch_num,x_cells,CpG_prt,Norm,[hmc_cor,hmc_m_met,hmc_m_exp,hmc_s_met,hmc_s_exp])
        llk_gastr = mdl_dic.log_prob(hmc_x_met,hmc_x_exp,yy_met_prt,yy_exp_prt)

        # Posterior means of the latent methylation and expression.
        x_met_mn = tf.reduce_mean(hmc_x_met,axis=0)
        x_exp_mn = tf.reduce_mean(hmc_x_exp,axis=0)

        # File-name prefixes are kept exactly as before so downstream readers still match.
        outputs = {
            'nuts_cor_gastr_ninf_': crr_nuts,
            'nuts_m_met_gastr_ninf_': m_met_nuts,
            'nuts_m_exp_gastr_ninf_': m_exp_nuts,
            'nuts_s_met_gastr_ninf_': s_met_nuts,
            'nuts_s_exp_gastr_ninf_': s_exp_nuts,
            'log_prob_gastr_ninf_': llk_gastr,
            'avg_exp_ninf_gastr_': x_exp_mn,
            'avg_met_ninf_gastr_': x_met_mn,
        }
        for prefix, value in outputs.items():
            with open(RESULTS_DIR+prefix+str(prt)+'.pickle', 'wb') as handle:
                pickle.dump(value, handle)

    # Progress report: estimated remaining seconds, chunk index, step size, acceptance rate.
    avg_time = tf.reduce_mean(tf.stack(time))
    rem_time = (tf.cast(prts_tot,dtype=tf.float32)-tf.cast(prt,dtype=tf.float32)) *avg_time
    print(tf.squeeze([rem_time,prt,stp_sz,p_accept]),qc_ok)

    if passed_qc:
        prt += 1
        jj = 0
    else:
        jj += 1
        if jj >= MAX_QC_ATTEMPTS:
            # Give up on this chunk and start the next one with a fresh retry budget.
            skp_ind.append(prt)
            prt += 1
            jj = 0

tf.Tensor([2.5467537e+03 3.0100000e+02 1.0642152e-02 6.5197062e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([2.3759587e+03 3.0200000e+02 9.7204298e-03 6.4480865e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([2.2052170e+03 3.0300000e+02 9.2363283e-03 6.5441215e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([2.0393785e+03 3.0400000e+02 9.1340207e-03 6.5191185e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([1.8678354e+03 3.0500000e+02 8.7258648e-03 7.3664391e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([1.6987128e+03 3.0600000e+02 1.4317469e-02 6.7490411e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([1.5276816e+03 3.0700000e+02 1.1304198e-02 6.2265599e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([1.3569080e+03 3.0800000e+02 1.1461093e-02 7.1654934