In [None]:
# Mount Google Drive at /content/drive (Colab only) so the paths built from
# `Folder` in the cells below resolve.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


This notebook performs SCRaPL inference without inflation on mEBC data.

In [None]:
from IPython import display
import pandas as pd
import numpy as np
import scipy
import scipy.stats

from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from tensorflow import keras

from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import pickle
from timeit import default_timer as timer

In [None]:
# Short aliases for the TensorFlow Probability sub-modules used throughout the notebook.
tfd = tfp.distributions
tfb = tfp.bijectors
# Root directory on the mounted drive; every input and output path is built by appending to it.
Folder = '/content/drive/MyDrive/'

In [None]:
# Load accessibility counts, expression counts and cell normalization constants.
# `sep` is passed by keyword: pandas >= 2.0 only accepts the path positionally,
# so the former positional ',' raises a TypeError there.
yy_acc_pd = pd.read_csv(Folder+'SCRaPL/Real/Data/Acc_atac.csv',sep=',',header=[0],index_col=[0])
yy_exp_pd = pd.read_csv(Folder+'SCRaPL/Real/Data/Rna_atac.csv',sep=',',header=[0],index_col=[0])
Norm = pd.read_csv(Folder+'SCRaPL/Real/Data/Nrm_atac.csv',sep=',',index_col=[1])
# Drop the first remaining column so that only the normalization columns are left
# (presumably it is a saved row counter -- TODO confirm against the CSV).
Norm = Norm.drop(columns=Norm.columns[0])

In [None]:
def _as_float32(values):
    """Return `values` converted to a float32 tensor."""
    return tf.convert_to_tensor(values, dtype=tf.float32)


# Count matrices.
yy_acc = _as_float32(yy_acc_pd)
yy_exp = _as_float32(yy_exp_pd)

# Normalization constants: first column -> accessibility, second column -> expression.
Norm_acc = _as_float32(Norm.iloc[:, 0])
Norm_exp = _as_float32(Norm.iloc[:, 1])

In [None]:
# Unpack the two entries of the accessibility matrix shape: number of rows
# (genomic regions) and number of columns (cells). Both are scalar tensors.
x_genes,x_cells = tf.shape(yy_acc)

In [None]:
# Bijectors shared by the model below. tfb.Chain applies the right-most
# bijector first, so Chain([f, g]) is x -> f(g(x)).

# Affine map x -> 2*x - 1 and its inverse y -> (y + 1) / 2.
aff = tfb.Chain([tfb.Shift(-1.),tfb.Scale(scale=2.)])
aff_inv = tfb.Invert(aff)

# exp and its inverse (log).
exp = tfb.Exp()
log = tfb.Invert(exp)

# tanh and its inverse.
tanh = tfb.Tanh()
tanh_inv = tfb.Invert(tanh)

# Sigmoid and its inverse.
sigm = tfb.Sigmoid()
sigm_inv = tfb.Invert(sigm)

# x -> (tanh(x/2) + 1) / 2. Its inverse is applied to the Beta prior in SCRaPL
# to obtain the latent correlation parameter.
cor_trsf = tfb.Chain([aff_inv,tanh,tfb.Scale(scale=0.5)])
cor_trsf_inv = tfb.Invert(cor_trsf)

# x -> eps/2 + (1 - eps) * NormalCDF(x).
eps=0.001
bin_bij = tfb.Chain([tfb.Shift(eps/2.0),tfb.Scale(scale=1.0-eps),tfb.NormalCDF()])

# Maps from the latent parameters to the quantities used in SCRaPL:
# cor_bij: x -> tanh(x/2), std_bij: x -> exp(-x), sqr_bij: x -> x**2.
cor_bij = tfb.Chain([tanh,tfb.Scale(scale=0.5)])
std_bij = tfb.Chain([exp,tfb.Scale(scale=-1.0)])
sqr_bij = tfb.Square()


In [None]:
#SCRaPL's graphical model
def SCRaPL(N_genes,N_cells,Nrm_acc,Nrm_rna):
    """Build the SCRaPL joint distribution for a block of genomic regions.

    Args:
        N_genes: number of genomic regions (rows) in the block.
        N_cells: number of cells (columns).
        Nrm_acc: normalization constants for the accessibility counts; their
            log is added to the Poisson log-rate of `y_acc`. Assumed to hold
            one value per cell so it broadcasts against [N_genes, N_cells].
        Nrm_rna: normalization constants for the expression counts; their log
            is added to the Poisson log-rate of `y_exp`. Same shape assumption.

    Returns:
        A `tfd.JointDistributionCoroutineAutoBatched` whose components are, in
        order: cor_lt, m_acc_lt, m_exp_lt, s_acc_lt, s_exp_lt, x_acc, x_exp,
        y_acc, y_exp.
    """
    def prior():
        # Per-region latent parameters, each of shape [N_genes, 1].
        cor_lt = yield tfd.TransformedDistribution( distribution = tfd.Beta( concentration0 = 15.0*tf.ones([N_genes,1]), concentration1=15.0*tf.ones([N_genes,1])), bijector= cor_trsf_inv, name = "cor_lt" )
        m_acc_lt = yield tfd.Normal(loc=3*tf.ones([N_genes,1]),scale=tf.ones([N_genes,1]), name = "m_acc_lt")
        m_exp_lt = yield tfd.Normal(loc=4*tf.ones([N_genes,1]),scale=tf.ones([N_genes,1]), name = "m_exp_lt")
        s_acc_lt = yield tfd.TransformedDistribution( distribution = tfd.InverseGamma(concentration=2.5*tf.ones([N_genes,1]),scale=4.5*tf.ones([N_genes,1])),bijector= log, name = "s_acc_lt" )
        s_exp_lt = yield tfd.TransformedDistribution( distribution = tfd.InverseGamma(concentration=2.5*tf.ones([N_genes,1]),scale=4.5*tf.ones([N_genes,1])),bijector= log , name = "s_exp_lt")

        # Map the latent parameters to a correlation and two standard deviations.
        cor = cor_bij.forward(cor_lt)
        s_acc = std_bij.forward(s_acc_lt)
        s_exp = std_bij.forward(s_exp_lt)

        # Expand every per-region parameter to the full [N_genes, N_cells] grid.
        mm_acc = tf.math.multiply( m_acc_lt,tf.ones([N_genes,N_cells]))
        mm_exp = tf.math.multiply( m_exp_lt,tf.ones([N_genes,N_cells]))
        ss_acc = tf.math.multiply( s_acc   ,tf.ones([N_genes,N_cells]))
        ss_exp = tf.math.multiply( s_exp   ,tf.ones([N_genes,N_cells]))
        ccor =   tf.math.multiply( cor     ,tf.ones([N_genes,N_cells]))

        # Log normalization constants, repeated for every region.
        nrm_acc = tf.math.multiply( log.forward(Nrm_acc),tf.ones([N_genes,1]))
        nrm_rna = tf.math.multiply( log.forward(Nrm_rna),tf.ones([N_genes,1]))

        # Bivariate normal over (x_acc, x_exp), written as the marginal of
        # x_acc followed by the conditional of x_exp given x_acc.
        x_acc = yield tfd.Normal(loc = mm_acc, scale = ss_acc,name="x_acc")
        m_cnd_exp = mm_exp+tf.math.multiply(tf.math.divide(tf.math.multiply(ss_exp,x_acc-mm_acc),ss_acc),ccor)
        s_cnd_exp = tf.math.sqrt(tf.math.multiply(1-tf.math.square(ccor),tf.math.square(ss_exp)))

        x_exp = yield tfd.Normal(loc = m_cnd_exp, scale = s_cnd_exp,name="x_exp")

        # Observed counts. The expression rate must use the normalizer built
        # from the `Nrm_rna` argument; no other expression normalizer exists
        # in this scope.
        y_acc = yield tfd.Poisson(log_rate=x_acc+nrm_acc, name="y_acc")
        y_exp = yield tfd.Poisson(log_rate=x_exp+nrm_rna, name="y_exp")

    comp_var_coroutine = tfd.JointDistributionCoroutineAutoBatched(prior)
    return comp_var_coroutine


In [None]:
# The dataset is too large to fit in memory, so we rely on the assumption that
# genomic regions are independent and sample them in chunks of `www` regions.
# Smaller chunks also make it easier to reach the target acceptance rate,
# because the sampled space is smaller. Tune the chunk size (`www`) to the
# available computational resources.


def _save_chunk_results(results, part):
    """Pickle every entry of `results` for chunk number `part`.

    `results` maps a file-name stem to the object to store. Each object is
    written to Folder + 'SCRaPL/Real/Results_atac/' + stem + str(part) + '.pickle'.
    """
    for stem, value in results.items():
        with open(Folder+'SCRaPL/Real/Results_atac/'+stem+str(part)+'.pickle', 'wb') as handle:
            pickle.dump(value, handle)


prt = 1        # 1-based index of the chunk currently being sampled
time = []      # wall-clock seconds of every sampling attempt, retries included
skp_ind = []   # chunks given up on after repeated quality-control failures
www = 10       # genomic regions per chunk
prts_tot = tf.math.ceil(x_genes/www)
jj = 0         # failed attempts so far for the current chunk
while prt < prts_tot+1:
    # Rows of the count matrices that belong to this chunk; the slice of the
    # last chunk may contain fewer than `www` rows.
    aa = (prt-1)*www
    aa1 = prt*www
    yy_acc_prt = yy_acc[aa:aa1,:]
    yy_exp_prt = yy_exp[aa:aa1,:]
    batch_ft = tf.shape(yy_exp_prt)[0]

    mdl_tr = SCRaPL(batch_ft,x_cells,Norm_acc,Norm_exp)
    # A fresh draw from the model initialises the chain, so a retry of the same
    # chunk starts from a different point. Only the seven latent components are
    # kept; the two observed components are replaced by the data in log_post.
    init = mdl_tr.sample()
    vrr = [init[i] for i in range(7)]

    unconstrained_bijectors = [tfb.Identity() for _ in range(7)]

    # NOTE: this tf.function is rebuilt on every iteration because it closes
    # over the model, initial state and data of the current chunk.
    @tf.function(autograph=False, jit_compile=True)
    def sample_nuts():
        """Run step-size-adapted NUTS on the current chunk.

        Returns:
            (states, sampler_stat): the list of sampled latent components and
            the tuple of traced diagnostics produced by `trace_fn`.
        """
        init_x = vrr
        num_burnin_iter = 3000
        num_warmup_iter = int(0.8*num_burnin_iter)
        num_chain_iter = 2000

        target_accept_rate = 0.65

        # Joint log-density with the observed counts pinned to this chunk's data.
        log_post = lambda x0,x1,x2,x3,x4,x5,x6: mdl_tr.log_prob(x0,x1,x2,x3,x4,x5,x6,yy_acc_prt,yy_exp_prt)

        def trace_fn(_, pkr):
            # Kernel results are nested: step-size adaptation ->
            # transformed kernel -> NUTS.
            nuts_results = pkr.inner_results.inner_results
            return (
                nuts_results.target_log_prob,
                nuts_results.leapfrogs_taken,
                nuts_results.has_divergence,
                nuts_results.energy,
                nuts_results.log_accept_ratio,
                nuts_results.step_size
            )

        nuts = tfp.mcmc.NoUTurnSampler(
            target_log_prob_fn=log_post,
            step_size=0.05,
            max_tree_depth=6)
        ttk = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=nuts,
            bijector=unconstrained_bijectors)
        adapted_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            inner_kernel=ttk,
            num_adaptation_steps=num_warmup_iter,
            target_accept_prob=target_accept_rate)

        states, sampler_stat = tfp.mcmc.sample_chain(
            num_results=num_chain_iter,
            num_burnin_steps=num_burnin_iter,
            current_state=init_x,
            kernel=adapted_kernel,
            trace_fn=trace_fn)

        return states, sampler_stat

    start = timer()
    samples, sampler_stat = sample_nuts()
    end = timer()
    ttime = end-start
    time.append(ttime)

    # Mean of min(1, acceptance ratio) over the retained draws, and the step
    # size traced at the first retained draw.
    p_accept = tf.math.exp(tfp.math.reduce_logmeanexp(tf.minimum(sampler_stat[4], 0.)))
    stp_sz = sampler_stat[5][0]
    hmc_cor,hmc_m_acc,hmc_m_exp,hmc_s_acc,hmc_s_exp,hmc_x_acc,hmc_x_exp = samples

    crr_nuts = tf.squeeze(hmc_cor)
    s_exp_nuts = tf.squeeze(hmc_s_exp)
    s_acc_nuts = tf.squeeze(hmc_s_acc)
    m_acc_nuts = tf.squeeze(hmc_m_acc)
    m_exp_nuts = tf.squeeze(hmc_m_exp)

    # Only the average over the last axis of the latent x traces is kept; the
    # full x_acc / x_exp traces are not written to disk.
    x_acc_mn = tf.reduce_mean(hmc_x_acc,axis=2)
    x_exp_mn = tf.reduce_mean(hmc_x_exp,axis=2)

    # Quality control: acceptance rate inside (0.4, 0.9) and a step size that
    # has not collapsed.
    qc_acc_rt = tf.math.logical_and(p_accept<0.9,p_accept>0.4)
    qc_stp_sz = stp_sz>0.00001
    qc_passed = tf.math.logical_and(qc_acc_rt,qc_stp_sz)

    if bool(qc_passed):
        _save_chunk_results({
            'nuts_cor_atac_no_inf_': crr_nuts,
            'nuts_m_acc_no_inf_atac_': m_acc_nuts,
            'nuts_m_exp_no_inf_atac_': m_exp_nuts,
            'nuts_s_acc_no_inf_atac_': s_acc_nuts,
            'nuts_s_exp_no_inf_atac_': s_exp_nuts,
            'log_prob_no_inf_atac_': sampler_stat[0],
            'avg_exp_no_inf_atac_': x_exp_mn,
            'avg_acc_no_inf_atac_': x_acc_mn,
        }, prt)

    # Progress line for the chunk that was just sampled: estimated remaining
    # time, chunk index, step size, acceptance rate and the QC verdict.
    avg_time = tf.reduce_mean(tf.stack(time))
    rem_time = (tf.cast(prts_tot,dtype=tf.float32)-tf.cast(prt,dtype=tf.float32)) *avg_time
    print(tf.squeeze([rem_time,prt,stp_sz,p_accept]),qc_passed)

    if bool(qc_passed):
        prt += 1
        jj = 0
    elif jj > 3:
        # Fifth failed attempt on this chunk: record it and move on. The
        # counter restarts at zero so the next chunk gets the same number of
        # attempts.
        skp_ind.append(prt)
        prt += 1
        jj = 0
    else:
        # Retry the same chunk from a new initial state.
        jj += 1
     

tf.Tensor([1.9509584e+04 3.0000000e+02 4.0732618e-04 4.7486001e-01], shape=(4,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)


Exception ignored in: <function ScopedTFGraph.__del__ at 0x7f48cd005710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/c_api_util.py", line 57, in __del__
    def __del__(self):
KeyboardInterrupt
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-4bfded097d14>", line 65, in <module>
    samples, sampler_stat = sample_nuts()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py", line 760, in _initialize
    *args, **kwds))
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py", line 3066, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/pytho

KeyboardInterrupt: ignored