In [None]:
from google.colab import drive
MOUNT_POINT = "/content/drive"  # where Colab exposes Google Drive in this notebook
drive.mount(MOUNT_POINT)

Mounted at /content/drive


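## Experiment 2: synthetic data

Simulate a dataset from the SCRaPL generative model, fit it with NUTS, and (if the sampler passes the QC checks) save the posterior samples together with the ground-truth values.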

In [None]:
from IPython import display
import pandas as pd
import numpy as np
import scipy
import scipy.stats

from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from tensorflow import keras

from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import pickle
from timeit import default_timer as timer

In [None]:
# Short aliases for the TFP submodules used throughout the notebook.
tfd, tfb = tfp.distributions, tfp.bijectors

# Root of the mounted Google Drive; every pickle path saved below starts here.
Folder = '/content/drive/MyDrive/'

In [None]:
# Run/replicate index: appended via str(a1) to every pickle file name saved below.
a1: int = 1

In [None]:
# Problem size: number of genes x number of cells.
x_genes = tf.constant(300, dtype=tf.int32)
x_cells = tf.constant(60, dtype=tf.int32)
grid_shape = (x_genes, x_cells)

# Per-(gene, cell) indicator: uniform on [1, 3) then rounded, so always 1, 2 or 3.
# NOTE(review): because minval=1 this is never 0, so the zero-coverage mask
# applied to CpG below never fires -- confirm whether minval=0 was intended.
CpG_vls = tf.math.round(
    tf.random.uniform(grid_shape, minval=1, maxval=3, dtype=tf.float32))

# CpG coverage per (gene, cell): integer-valued floats in [50, 500],
# zeroed wherever the indicator above is 0.
CpG = tf.math.round(
    tf.random.uniform(grid_shape, minval=50, maxval=500, dtype=tf.float32))
CpG = tf.where(tf.equal(CpG_vls, 0), tf.zeros_like(CpG), CpG)

# Per-cell normalisation factor, uniform on [0.5, 1.5), shape (1, x_cells).
Norm = tf.random.uniform((1, x_cells), minval=0.5, maxval=1.5, dtype=tf.float32)

In [None]:
# Bijectors used by the SCRaPL model. tfb.Chain applies its list right-to-left,
# so each comment gives the composed forward map.

# aff: x -> 2x - 1, mapping (0, 1) onto (-1, 1); aff_inv is its inverse.
aff = tfb.Chain([tfb.Shift(-1.), tfb.Scale(2.)])
aff_inv = tfb.Invert(aff)

# Elementwise exp/log, tanh/atanh and sigmoid/logit pairs.
exp, tanh, sigm = tfb.Exp(), tfb.Tanh(), tfb.Sigmoid()
log, tanh_inv, sigm_inv = (tfb.Invert(b) for b in (exp, tanh, sigm))

# cor_trsf: x -> (tanh(x / 2) + 1) / 2, which is mathematically the logistic
# sigmoid; cor_trsf_inv therefore maps (0, 1) onto the real line.
cor_trsf = tfb.Chain([aff_inv, tanh, tfb.Scale(0.5)])
cor_trsf_inv = tfb.Invert(cor_trsf)

# bin_bij: x -> eps/2 + (1 - eps) * Phi(x), a probit link squeezed into
# (eps/2, 1 - eps/2) so Binomial probabilities never reach exactly 0 or 1.
eps = 0.001
bin_bij = tfb.Chain([tfb.Shift(eps / 2.0), tfb.Scale(1.0 - eps), tfb.NormalCDF()])

# cor_bij: x -> tanh(x / 2);  std_bij: x -> exp(-x);  sqr_bij: x -> x ** 2.
cor_bij = tfb.Chain([tanh, tfb.Scale(0.5)])
std_bij = tfb.Chain([exp, tfb.Scale(-1.0)])
sqr_bij = tfb.Square()

In [None]:
#Graphical Model
def SCRaPL(N_genes,N_cells,Cover,Nrm):
    """Build the SCRaPL joint distribution used to simulate and to fit the data.

    Per gene, the model has a methylation/expression correlation, a mean and
    a standard deviation for each modality, and a zero-inflation probability.
    Per (gene, cell) it has latent methylation ``x_met`` and expression
    ``x_exp`` (jointly Gaussian given the gene parameters), a Binomial
    methylation count ``y_met`` and a zero-inflated Poisson expression count
    ``y_exp``.

    Relies on the module-level bijectors ``cor_trsf_inv``, ``log``,
    ``sigm_inv``, ``cor_bij``, ``std_bij``, ``sigm`` and ``bin_bij``.

    Args:
        N_genes: Number of genes (used as a shape in ``tf.ones``).
        N_cells: Number of cells (used as a shape in ``tf.ones``).
        Cover: Per-(gene, cell) CpG coverage, used as the Binomial
            ``total_count`` of ``y_met``. Where it is 0, the latent expression
            feeding the non-zero mixture component is set to -20 (before
            log(Nrm) is added).
        Nrm: Per-cell normalisation factor; its log is broadcast against
            (N_genes, 1) and added to the expression log-rate. Presumably of
            shape (1, N_cells) -- TODO confirm for other callers.

    Returns:
        A ``tfd.JointDistributionCoroutineAutoBatched`` whose samples are, in
        yield order: cor_lt, m_met_lt, m_exp_lt, s_met_lt, s_exp_lt, infl_lt,
        x_met, x_exp, y_met, y_exp. The first six have shape (N_genes, 1) and
        live on an unconstrained (real-valued) scale.
    """
    def prior():
        """Generative process; each ``yield`` declares one model variable."""
        # cor_lt: a Beta(15, 15) draw b mapped to the real line by cor_trsf_inv;
        # cor_bij below maps it back to the correlation cor = 2b - 1 in (-1, 1).
        cor_lt = yield tfd.TransformedDistribution( distribution = tfd.Beta( concentration0 = 15.0*tf.ones([N_genes,1]), concentration1=15.0*tf.ones([N_genes,1])), bijector= cor_trsf_inv, name = "cor_lt" )
        # Per-gene means of the latent methylation / expression.
        m_met_lt = yield tfd.Normal(loc=tf.ones([N_genes,1]),scale=tf.ones([N_genes,1]), name = "m_met_lt")
        m_exp_lt = yield tfd.Normal(loc=4*tf.ones([N_genes,1]),scale=tf.ones([N_genes,1]), name = "m_exp_lt")
        # s_*_lt = log(v) with v ~ InverseGamma(2.5, 4.5); std_bij below maps it
        # to the standard deviation s = exp(-s_*_lt) = 1 / v.
        s_met_lt = yield tfd.TransformedDistribution( distribution = tfd.InverseGamma(concentration=2.5*tf.ones([N_genes,1]),scale=4.5*tf.ones([N_genes,1])),bijector= log, name = "s_met_lt" )
        s_exp_lt = yield tfd.TransformedDistribution( distribution = tfd.InverseGamma(concentration=2.5*tf.ones([N_genes,1]),scale=4.5*tf.ones([N_genes,1])),bijector= log , name = "s_exp_lt")
        # infl_lt: logit of a Beta draw. TFP's Beta takes concentration1 as alpha
        # and concentration0 as beta, so this is Beta(8, 2) with mean 0.8.
        # NOTE(review): confirm that such a high prior zero-inflation is intended.
        infl_lt = yield tfd.TransformedDistribution( distribution = tfd.Beta( concentration0 =2.0*tf.ones([N_genes,1]), concentration1=8.0*tf.ones([N_genes,1])), bijector= sigm_inv, name = "infl_lt" )

        # Map the unconstrained per-gene latents back to their natural scales.
        cor = cor_bij.forward(cor_lt)
        s_met = std_bij.forward(s_met_lt)
        s_exp = std_bij.forward(s_exp_lt)
        infl = sigm.forward(infl_lt)
        
        # Broadcast the (N_genes, 1) per-gene quantities to (N_genes, N_cells).
        mm_met = tf.math.multiply( m_met_lt,tf.ones([N_genes,N_cells]))
        mm_exp = tf.math.multiply( m_exp_lt,tf.ones([N_genes,N_cells]))
        ss_met = tf.math.multiply( s_met   ,tf.ones([N_genes,N_cells]))
        ss_exp = tf.math.multiply( s_exp   ,tf.ones([N_genes,N_cells]))
        ccor =   tf.math.multiply( cor     ,tf.ones([N_genes,N_cells]))
        p =      tf.math.multiply( infl    ,tf.ones([N_genes,N_cells]))  
        # Log normalisation factor broadcast against (N_genes, 1); gives
        # (N_genes, N_cells) assuming Nrm is (1, N_cells).
        nrm = tf.math.multiply( log.forward(Nrm),tf.ones([N_genes,1]))

        x_met = yield tfd.Normal(loc = mm_met, scale = ss_met,name="x_met")
        # Conditional of x_exp given x_met for a bivariate normal with
        # correlation ccor: mean mm_exp + ccor * ss_exp / ss_met * (x_met - mm_met),
        # standard deviation ss_exp * sqrt(1 - ccor**2).
        m_cnd_exp = mm_exp+tf.math.multiply(tf.math.divide(tf.math.multiply(ss_exp,x_met-mm_met),ss_met),ccor)
        s_cnd_exp = tf.math.sqrt(tf.math.multiply(1-tf.math.square(ccor),tf.math.square(ss_exp)))

        x_exp = yield tfd.Normal(loc = m_cnd_exp, scale = s_cnd_exp,name="x_exp")

        # Methylation rate in (eps/2, 1 - eps/2) through the squeezed probit link.
        rt_bin = bin_bij.forward(x_met)

        # Two-component mixture for expression counts: component 0 is a Poisson
        # with log-rate -20 (essentially always 0) and weight p = infl;
        # component 1 has log-rate x_exp + log(Nrm) (x_exp forced to -20 where
        # Cover == 0) and weight 1 - p.
        pp = tf.stack([p,1-p],axis=-1)
        x_exp_msk= tf.where(tf.equal(Cover, 0), -20*tf.ones_like(x_exp),x_exp)
        x_exp_lt = tf.stack([-20*tf.ones_like(x_exp),x_exp_msk+nrm],axis=-1)

        # Observed data: methylated-CpG counts out of Cover, and expression counts.
        y_met = yield tfd.Binomial(total_count=Cover,probs=rt_bin,name="y_met")
        y_exp = yield tfd.MixtureSameFamily(
                                                          mixture_distribution = tfd.Categorical(probs=pp),
                                                          components_distribution = tfd.Poisson(log_rate=x_exp_lt) ,
                                                          name="y_exp")

    # AutoBatched: log_prob / sample treat the whole (gene, cell) grid as one event.
    comp_var_coroutine = tfd.JointDistributionCoroutineAutoBatched(prior)
    return comp_var_coroutine
# Model conditioned on this notebook's problem size, coverage and normalisation.
mdl_tr = SCRaPL(N_genes=x_genes, N_cells=x_cells, Cover=CpG, Nrm=Norm)

In [None]:
# Draw one ground-truth dataset from the generative model. Parts come back in
# the model's yield order; the six per-gene parameters are on their
# unconstrained (latent) scale, e.g. `cor` holds cor_lt, not the correlation.
ground_truth = mdl_tr.sample()
(cor, mm_met, mm_exp, ss_met, ss_exp, infl,
 xx_met, xx_exp, yy_met, yy_exp) = ground_truth



In [None]:
#Inference
# Every latent part of the SCRaPL model is real-valued, so the sampler works in
# the model's own parameterisation: one Identity bijector per latent part.
NUM_LATENT_PARTS = 8
unconstrained_bijectors = [tfb.Identity() for _ in range(NUM_LATENT_PARTS)]

# Initial MCMC state: the latent parts of a fresh draw from the model (the
# observed y_met / y_exp parts of that draw are discarded).
init = mdl_tr.sample()
vrr = [init[i] for i in range(NUM_LATENT_PARTS)]


@tf.function(autograph=False, jit_compile=True)
def sample_nuts(num_burnin_iter=3000, num_chain_iter=2000,
                target_accept_rate=0.651, step_size=0.05, max_tree_depth=6):
    """Sample the SCRaPL posterior of the latent parts given yy_met / yy_exp.

    Runs NUTS inside a TransformedTransitionKernel with dual-averaging
    step-size adaptation, starting from the global ``vrr``. Calling with no
    arguments reproduces the original fixed settings. Each distinct set of
    Python argument values triggers a fresh trace and XLA compilation.

    Args:
        num_burnin_iter: Burn-in steps (discarded); step-size adaptation runs
            for the first 80% of them. Must be a Python int (used in ``int``).
        num_chain_iter: Number of retained draws.
        target_accept_rate: Target acceptance probability for dual averaging.
        step_size: Initial NUTS leapfrog step size.
        max_tree_depth: Maximum NUTS tree depth.

    Returns:
        ``(states, sampler_stat)``. ``states`` has one tensor of draws per
        latent part. ``sampler_stat`` is the tuple built by ``trace_fn``:
        (target_log_prob, leapfrogs_taken, has_divergence, energy,
        log_accept_ratio, step_size), traced per retained draw.
    """
    init_x = vrr
    # Adaptation stops after 80% of burn-in, so the remaining burn-in steps
    # already run at the adapted step size.
    num_warmup_iter = int(0.8 * num_burnin_iter)

    def log_post(*latents):
        # Unnormalised log-posterior: joint log-prob with the data held fixed.
        return mdl_tr.log_prob(*latents, yy_met, yy_exp)

    def trace_fn(_, pkr):
        # pkr nests DualAveraging -> TransformedTransitionKernel -> NUTS results.
        nuts_results = pkr.inner_results.inner_results
        return (
            nuts_results.target_log_prob,
            nuts_results.leapfrogs_taken,
            nuts_results.has_divergence,
            nuts_results.energy,
            nuts_results.log_accept_ratio,
            nuts_results.step_size,
        )

    nuts = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=log_post,
        step_size=step_size,
        max_tree_depth=max_tree_depth)
    ttk = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=nuts,
        bijector=unconstrained_bijectors)
    adapted_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=ttk,
        num_adaptation_steps=num_warmup_iter,
        target_accept_prob=target_accept_rate)

    states, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_chain_iter,
        num_burnin_steps=num_burnin_iter,
        current_state=init_x,
        kernel=adapted_kernel,
        trace_fn=trace_fn)
    return states, sampler_stat


start = timer()
samples, sampler_stat = sample_nuts()
end = timer()
# Wall-clock seconds; the first call also includes tracing and XLA compilation.
ttime = end - start



In [None]:
# Mean Metropolis acceptance probability over all retained draws:
# exp(logmeanexp(min(log_accept_ratio, 0))) == mean(min(1, accept_ratio)).
log_accept_ratio = sampler_stat[4]
p_accept = tf.math.exp(
    tfp.math.reduce_logmeanexp(tf.minimum(log_accept_ratio, 0.)))

# First entry of the traced step size (recorded after burn-in, so adaptation
# is over). NOTE(review): assumes the trace is a single per-draw tensor.
stp_sz = sampler_stat[5][0]

# Posterior draws of each latent part, in the model's yield order.
(hmc_cor, hmc_m_met, hmc_m_exp, hmc_s_met,
 hmc_s_exp, hmc_inf, hmc_x_met, hmc_x_exp) = samples

In [None]:
# Drop singleton axes from every traced state part (e.g. the size-1 trailing
# axis of the per-gene parameters) before saving.
(crr_nuts, m_met_nuts, m_exp_nuts, s_met_nuts,
 s_exp_nuts, infl_nuts, x_met_nuts, x_exp_nuts) = map(
    tf.squeeze,
    (hmc_cor, hmc_m_met, hmc_m_exp, hmc_s_met,
     hmc_s_exp, hmc_inf, hmc_x_met, hmc_x_exp),
)

In [None]:
# QC gates for saving: mean acceptance probability strictly between 0.4 and
# 0.9, and an adapted step size that has not collapsed towards zero.
qc_acc_rt = tf.math.logical_and(p_accept > 0.4, p_accept < 0.9)
qc_stp_sz = stp_sz > 0.00001

In [None]:
#Save samples
import os


def _dump_pickles(directory, named_values, run_tag):
    """Pickle each ``stem -> value`` to ``<directory><stem>_<run_tag>.pickle``.

    ``directory`` must end with a path separator. It is created if missing, so
    a fresh Drive does not make the save fail part-way with FileNotFoundError.
    Files are written in the dict's insertion order.
    """
    os.makedirs(directory, exist_ok=True)
    for stem, value in named_values.items():
        with open(directory + stem + '_' + run_tag + '.pickle', 'wb') as handle:
            pickle.dump(value, handle)


results_dir = Folder + 'SCRaPL/Synth/Results/Inflation/Beta/'
data_dir = Folder + 'SCRaPL/Synth/Data/Inflation/Beta/'
run_tag = str(a1)

# Only keep runs whose sampler passed both QC gates.
if bool(tf.math.logical_and(qc_acc_rt, qc_stp_sz)):
    # NUTS posterior samples.
    _dump_pickles(results_dir, {
        'nuts_cor': crr_nuts,
        'nuts_m_met': m_met_nuts,
        'nuts_m_exp': m_exp_nuts,
        'nuts_s_met': s_met_nuts,
        'nuts_s_exp': s_exp_nuts,
        'nuts_inf': infl_nuts,
        'nuts_x_met': x_met_nuts,
        'nuts_x_exp': x_exp_nuts,
    }, run_tag)

    print(tf.squeeze([ttime, stp_sz, p_accept]), tf.math.logical_and(qc_acc_rt, qc_stp_sz))

    # Ground truth the data were simulated from, plus the simulation inputs.
    _dump_pickles(data_dir, {
        'cor': cor,
        'mm_met': mm_met,
        'mm_exp': mm_exp,
        'ss_met': ss_met,
        'ss_exp': ss_exp,
        'infl': infl,
        'xx_met': xx_met,
        'xx_exp': xx_exp,
        'yy_met': yy_met,
        'yy_exp': yy_exp,
        'yy_cpg': CpG,
        'Norm': Norm,
    }, run_tag)
else:
    # Previously a failed QC check saved nothing silently; make it visible.
    print('QC failed, nothing saved for run', run_tag,
          tf.squeeze([ttime, stp_sz, p_accept]))

"\n        with open('/content/drive/MyDrive/SCRaPL/Synth/Results/Inflation/Beta/nuts_m_met_'+str(a1)+'.pickle', 'wb') as handle:\n            pickle.dump(m_met_nuts, handle)\n        with open('/content/drive/MyDrive/SCRaPL/Synth/Results/Inflation/Beta/nuts_m_exp_'+str(a1)+'.pickle', 'wb') as handle:\n            pickle.dump(m_exp_nuts, handle)\n\n        with open('/content/drive/MyDrive/SCRaPL/Synth/Results/Inflation/Beta/nuts_s_met_'+str(a1)+'.pickle', 'wb') as handle:\n            pickle.dump(s_met_nuts, handle)\n        with open('/content/drive/MyDrive/SCRaPL/Synth/Results/Inflation/Beta/nuts_s_exp_'+str(a1)+'.pickle', 'wb') as handle:\n            pickle.dump(s_exp_nuts, handle)\n\n        with open('/content/drive/MyDrive/SCRaPL/Synth/Results/Inflation/Beta/nuts_inf_'+str(a1)+'.pickle', 'wb') as handle:\n            pickle.dump(infl_nuts, handle)\n\n        with open('/content/drive/MyDrive/SCRaPL/Synth/Results/Inflation/Beta/nuts_x_met_'+str(a1)+'.pickle', 'wb') as handle:\n 