In [None]:
# Mount Google Drive (Colab only) so that inputs/outputs under /content/drive are accessible.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Demo for large 10X data. In large datasets (i.e., ≥ 100k SCRaPL parameters) it is common to run out of memory or to observe ineffective exploration of the state space. NUTS parameters are tuned to achieve a prespecified acceptance probability; in high dimensions this can translate into tiny step sizes, making state exploration ineffective. To overcome this, we exploit the independence between features to reduce the dimensionality of the problem. More precisely, we infer the posterior in batches of genes (without replacement).

In [None]:
from IPython import display
import pandas as pd
import numpy as np
import scipy
import scipy.stats

from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from tensorflow import keras

from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import pickle
from timeit import default_timer as timer

In [None]:
# Short aliases for the TFP sub-modules used throughout the notebook.
tfd, tfb = tfp.distributions, tfp.bijectors

# Root folder on the mounted Google Drive; demo data and results are written below it.
Folder = '/content/drive/MyDrive/SCRaPL/'

In [None]:
# Overall problem size for the demo, and how many gene batches to split it into.
a1, a2 = 1000, 800  # number of genes, number of cells
prts = 20           # number of gene batches inferred one after another

In [None]:
# Bijectors used to move SCRaPL's constrained parameters onto an unconstrained
# (real-valued) scale, so the sampler can work without bounds.
# Note: tfb.Chain applies its bijectors right-to-left.

# Elementwise building blocks and their inverses.
exp = tfb.Exp()
log = tfb.Invert(exp)            # x -> log(x): (0, inf) -> R

tanh = tfb.Tanh()
tanh_inv = tfb.Invert(tanh)      # x -> atanh(x): (-1, 1) -> R

sigm = tfb.Sigmoid()
sigm_inv = tfb.Invert(sigm)      # x -> logit(x): (0, 1) -> R

# Affine map x -> 2x - 1, i.e. (0, 1) -> (-1, 1); its inverse is y -> (y + 1) / 2.
aff = tfb.Chain([tfb.Shift(-1.0), tfb.Scale(scale=2.0)])
aff_inv = tfb.Invert(aff)

# x -> (tanh(x / 2) + 1) / 2, mapping R -> (0, 1). The inverse sends a value in
# (0, 1) (e.g. a Beta draw) to the real line.
cor_trsf = tfb.Chain([aff_inv, tanh, tfb.Scale(scale=0.5)])
cor_trsf_inv = tfb.Invert(cor_trsf)

# x -> eps/2 + (1 - eps) * Phi(x): squeezes the standard-normal CDF into (eps/2, 1 - eps/2).
eps = 0.001
bin_bij = tfb.Chain([tfb.Shift(eps / 2.0), tfb.Scale(scale=1.0 - eps), tfb.NormalCDF()])

cor_bij = tfb.Chain([tanh, tfb.Scale(scale=0.5)])  # x -> tanh(x / 2):  R -> (-1, 1)
std_bij = tfb.Chain([exp, tfb.Scale(scale=-1.0)])  # x -> exp(-x):      R -> (0, inf)
sqr_bij = tfb.Square()                             # x -> x ** 2

In [None]:
#Define SCRaPL's graphical model.

def SCRaPL(N_genes,N_cells,Nrm_acc,Nrm_rna):
    """Build SCRaPL's joint generative model for one batch of genes.

    Args:
        N_genes: number of genes in the batch; leading dimension of every
            per-gene parameter (each has shape [N_genes, 1]).
        N_cells: number of cells; per-gene parameters are tiled to
            [N_genes, N_cells].
        Nrm_acc: positive cell-specific normalisation constants for the
            "acc" counts; their log is added to the latent log-rate. The
            demo driver passes a (1, N_cells) tensor that broadcasts over
            genes — TODO confirm other shapes are intended to work.
        Nrm_rna: same as Nrm_acc, for the "exp"/RNA counts.

    Returns:
        A tfd.JointDistributionCoroutineAutoBatched whose parts, in yield
        order, are: cor_lt, m_acc_lt, m_exp_lt, s_acc_lt, s_exp_lt,
        infl_acc_lt, infl_rna_lt, x_acc, x_exp, y_acc, y_exp.
        sample_nuts and the batch loop index these parts positionally, so
        the yield order below must not change.
    """
    def model():
        # --- Per-gene latent parameters, shape [N_genes, 1], on an unconstrained scale. ---
        # cor_lt: Beta(15, 15) on (0, 1) pushed to R by cor_trsf_inv. cor_bij later maps it
        # to tanh(cor_lt / 2), i.e. the Beta draw rescaled to (-1, 1).
        cor_lt = yield tfd.TransformedDistribution( distribution = tfd.Beta( concentration0 = 15.0*tf.ones([N_genes,1]), concentration1=15.0*tf.ones([N_genes,1])), bijector= cor_trsf_inv, name = "cor_lt" )
        # Means of the latent log-rates of the two modalities (Normal priors, means 3 and 4).
        m_acc_lt = yield tfd.Normal(loc=3*tf.ones([N_genes,1]),scale=tf.ones([N_genes,1]), name = "m_acc_lt")
        m_exp_lt = yield tfd.Normal(loc=4*tf.ones([N_genes,1]),scale=tf.ones([N_genes,1]), name = "m_exp_lt")
        # s_*_lt is the log of an InverseGamma(2.5, 4.5) draw. std_bij computes exp(-x),
        # so the scale used below is the reciprocal of that InverseGamma draw.
        s_acc_lt = yield tfd.TransformedDistribution( distribution = tfd.InverseGamma(concentration=2.5*tf.ones([N_genes,1]),scale=4.5*tf.ones([N_genes,1])),bijector= log, name = "s_acc_lt" )
        s_exp_lt = yield tfd.TransformedDistribution( distribution = tfd.InverseGamma(concentration=2.5*tf.ones([N_genes,1]),scale=4.5*tf.ones([N_genes,1])),bijector= log , name = "s_exp_lt")
        # Zero-inflation probabilities: Beta(concentration1=2, concentration0=8) on (0, 1),
        # sampled on the logit scale (sigm_inv) and mapped back with sigm.
        infl_acc_lt = yield tfd.TransformedDistribution( distribution = tfd.Beta( concentration0 =8.0*tf.ones([N_genes,1]), concentration1=2.0*tf.ones([N_genes,1])), bijector= sigm_inv, name = "infl_acc_lt" )
        infl_rna_lt = yield tfd.TransformedDistribution( distribution = tfd.Beta( concentration0 =8.0*tf.ones([N_genes,1]), concentration1=2.0*tf.ones([N_genes,1])), bijector= sigm_inv, name = "infl_rna_lt" )

        # Back-transform latents to their constrained scales.
        cor = cor_bij.forward(cor_lt)
        s_acc = std_bij.forward(s_acc_lt)
        s_exp = std_bij.forward(s_exp_lt)
        infl_acc = sigm.forward(infl_acc_lt)
        infl_rna = sigm.forward(infl_rna_lt)
        
        # Tile per-gene parameters [N_genes, 1] across cells -> [N_genes, N_cells],
        # so the cell-level distributions below get one entry per (gene, cell).
        mm_acc = tf.math.multiply( m_acc_lt,tf.ones([N_genes,N_cells]))
        mm_exp = tf.math.multiply( m_exp_lt,tf.ones([N_genes,N_cells]))
        ss_acc = tf.math.multiply( s_acc   ,tf.ones([N_genes,N_cells]))
        ss_exp = tf.math.multiply( s_exp   ,tf.ones([N_genes,N_cells]))
        ccor =   tf.math.multiply( cor     ,tf.ones([N_genes,N_cells]))
        p_acc =  tf.math.multiply( infl_acc,tf.ones([N_genes,N_cells]))
        p_rna =  tf.math.multiply( infl_rna,tf.ones([N_genes,N_cells]))  

        # Log of the normalisation constants; added to the latent log-rates below.
        nrm_acc = log.forward(Nrm_acc)
        nrm_rna = log.forward(Nrm_rna)

        # Latent log-rates: x_acc ~ N(m_acc, s_acc), and x_exp | x_acc is the conditional
        # of a bivariate normal with correlation cor:
        #   mean = m_exp + cor * s_exp * (x_acc - m_acc) / s_acc
        #   std  = sqrt(1 - cor^2) * s_exp
        x_acc = yield tfd.Normal(loc = mm_acc, scale = ss_acc,name="x_acc")
        m_cnd_exp = mm_exp+tf.math.multiply(tf.math.divide(tf.math.multiply(ss_exp,x_acc-mm_acc),ss_acc),ccor)
        s_cnd_exp = tf.math.sqrt(tf.math.multiply(1-tf.math.square(ccor),tf.math.square(ss_exp)))

        x_exp = yield tfd.Normal(loc = m_cnd_exp, scale = s_cnd_exp,name="x_exp")

        # Zero-inflated Poisson observation model, as a two-component mixture:
        #   component 0: weight p,     log-rate -20 (rate ~2e-9, i.e. effectively always 0)
        #   component 1: weight 1 - p, log-rate x + log(normalisation constant)
        pp_acc = tf.stack([p_acc,1-p_acc],axis=-1)
        x_acc_lt = tf.stack([-20*tf.ones_like(x_acc),x_acc+nrm_acc],axis=-1)

        pp_rna = tf.stack([p_rna,1-p_rna],axis=-1)
        x_exp_lt = tf.stack([-20*tf.ones_like(x_exp),x_exp+nrm_rna],axis=-1)

        # Observed counts (the last two parts; pinned to data when computing the posterior).
        y_acc = yield tfd.MixtureSameFamily(
                                                          mixture_distribution = tfd.Categorical(probs=pp_acc),
                                                          components_distribution = tfd.Poisson(log_rate=x_acc_lt),
                                                          name="y_acc")
        y_exp = yield tfd.MixtureSameFamily(
                                                          mixture_distribution = tfd.Categorical(probs=pp_rna),
                                                          components_distribution = tfd.Poisson(log_rate=x_exp_lt),
                                                          name="y_exp")

    comp_var_coroutine = tfd.JointDistributionCoroutineAutoBatched(model)
    return comp_var_coroutine

In [None]:
#Use NUTS to sample from the model conditioned on observations.

#Inference
@tf.function(autograph=False, jit_compile=True)
def sample_nuts(model, acc, exp,
                num_burnin_iter=3000,
                num_chain_iter=2000,
                adaptation_fraction=0.8,
                target_accept_rate=0.651,
                init_step_size=0.05,
                max_tree_depth=6):
    """Draw NUTS posterior samples of a SCRaPL model's latent variables.

    The last two parts of ``model`` are treated as observed and pinned to
    ``acc`` and ``exp``. Every preceding part is sampled, starting from a single
    prior draw (``model.sample()``).

    Args:
        model: joint distribution built by ``SCRaPL``; its final two parts are
            the observed counts.
        acc: observed values for the second-to-last model part (``y_acc``).
        exp: observed values for the last model part (``y_exp``). Inside this
            function the name shadows the module-level ``exp`` bijector.
        num_burnin_iter: number of burn-in steps (discarded).
        num_chain_iter: number of kept samples.
        adaptation_fraction: fraction of the burn-in during which dual averaging
            adapts the step size. The rest of the burn-in uses the adapted step.
        target_accept_rate: acceptance probability targeted by the adaptation.
        init_step_size: initial leapfrog step size.
        max_tree_depth: maximum NUTS tree depth.

    Returns:
        ``(states, sampler_stat)``.
        ``states`` is a list with one tensor of ``num_chain_iter`` stacked draws
        per latent part, in model order.
        ``sampler_stat`` is the tuple built by ``trace_fn``, traced over the kept
        samples: target log prob, leapfrogs taken, divergence flag, energy, log
        accept ratio, and step size.

    Python-scalar arguments are part of the tf.function trace signature, so
    passing different values triggers a retrace (and XLA recompile).
    """
    init = model.sample()
    # Every part except the last two (observed) ones is a latent to be sampled.
    n_latent = len(init) - 2
    init_x = [init[i] for i in range(n_latent)]

    # SCRaPL's latents already live on an unconstrained scale, so identities suffice.
    # The TransformedTransitionKernel wrapper is kept so that trace_fn's
    # pkr.inner_results.inner_results path still reaches the NUTS results.
    unconstrained_bijectors = [tfb.Identity() for _ in range(n_latent)]

    num_warmup_iter = int(adaptation_fraction * num_burnin_iter)

    def log_post(*latent):
        # Joint log-density with the two observed parts pinned to the data.
        return model.log_prob(*latent, acc, exp)

    def trace_fn(_, pkr):
        # pkr: DualAveraging results -> TransformedTransitionKernel results -> NUTS results.
        nuts_pkr = pkr.inner_results.inner_results
        return (
            nuts_pkr.target_log_prob,
            nuts_pkr.leapfrogs_taken,
            nuts_pkr.has_divergence,
            nuts_pkr.energy,
            nuts_pkr.log_accept_ratio,
            nuts_pkr.step_size,
        )

    nuts = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=log_post,
        step_size=init_step_size,
        max_tree_depth=max_tree_depth,
    )
    ttk = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=nuts,
        bijector=unconstrained_bijectors,
    )
    adapted_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=ttk,
        num_adaptation_steps=num_warmup_iter,
        target_accept_prob=target_accept_rate,
    )

    states, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_chain_iter,
        num_burnin_steps=num_burnin_iter,
        current_state=init_x,
        kernel=adapted_kernel,
        trace_fn=trace_fn,
    )
    return states, sampler_stat

In [None]:
#Run SCRaPL batch by batch. Each batch holds a1 // prts genes and all a2 cells.
#A batch's results are saved only if the NUTS diagnostics pass QC.
#Otherwise the batch is retried up to MAX_RETRIES times and then skipped.

MAX_RETRIES = 3  # retries allowed per batch after its first failed attempt


def _pickle_file(subdir, stem, part):
    """Return Folder/Demo/<subdir>/<stem>_<a1>_<a2>_10X_<part>.pickle."""
    return Folder + 'Demo/' + subdir + '/' + stem + '_' + str(a1) + '_' + str(a2) + '_10X_' + str(part) + '.pickle'


def _dump_all(subdir, named_values, part):
    """Pickle each (stem, value) pair of `named_values` for batch `part`, in the given order."""
    for stem, value in named_values:
        with open(_pickle_file(subdir, stem, part), 'wb') as handle:
            pickle.dump(value, handle)


# Batches are equal-sized; a remainder would silently drop the trailing genes.
assert a1 % prts == 0, "a1 (number of genes) must be divisible by prts (number of batches)"

#Hyper-parameters shared by every batch: genes per batch and number of cells.
x_genes = tf.constant(a1 // prts, dtype=tf.int32)
x_cells = tf.constant(a2, dtype=tf.int32)

jj = 0                # failed attempts so far for the current batch
prt = 1               # current batch index (1-based; used in file names)
skipped_batches = []  # batches abandoned after exhausting their retries
while prt <= prts:

    #In case of real data, Norm_acc and Norm_exp should be replaced with the relevant chunk of data.
    #Since this demo has no raw coverage data or cell-specific normalization constants, we simulate them.
    Norm_acc = tf.random.uniform(shape=(1,x_cells),minval=0.5,maxval=1.5,dtype=tf.float32)
    Norm_exp = tf.random.uniform(shape=(1,x_cells),minval=0.5,maxval=1.5,dtype=tf.float32)

    mdl_tr = SCRaPL(x_genes,x_cells,Norm_acc,Norm_exp)

    #Sample data from the generative model. If observed readings (yy_acc, yy_exp) are available, this step can be omitted.
    #NOTE: the first seven values are the model's unconstrained latents (e.g. `cor` holds cor_lt, not the correlation itself).
    cor,mm_acc,mm_exp,ss_acc,ss_exp,infl_acc,infl_exp,xx_acc,xx_exp,yy_acc,yy_exp = mdl_tr.sample()

    start = timer()
    samples, sampler_stat = sample_nuts(mdl_tr,yy_acc,yy_exp)
    end = timer()
    ttime = end - start  # wall time; may include tf.function tracing / XLA compilation

    #Estimate mean acceptance probability (sampler_stat[4] = log accept ratios) and step size (sampler_stat[5]).
    p_accept = tf.math.exp(tfp.math.reduce_logmeanexp(tf.minimum(sampler_stat[4], 0.)))
    stp_sz = sampler_stat[5][0]
    hmc_cor,hmc_m_acc,hmc_m_exp,hmc_s_acc,hmc_s_exp,hmc_inf_acc,hmc_inf_exp,hmc_x_acc,hmc_x_exp = samples

    crr_nuts = tf.squeeze(hmc_cor)
    m_acc_nuts = tf.squeeze(hmc_m_acc)
    m_exp_nuts = tf.squeeze(hmc_m_exp)
    s_acc_nuts = tf.squeeze(hmc_s_acc)
    s_exp_nuts = tf.squeeze(hmc_s_exp)
    infl_acc_nuts = tf.squeeze(hmc_inf_acc)
    infl_exp_nuts = tf.squeeze(hmc_inf_exp)
    x_acc_nuts = tf.squeeze(hmc_x_acc)
    x_exp_nuts = tf.squeeze(hmc_x_exp)

    #QC: acceptance rate in a reasonable band and a step size that has not collapsed.
    qc_acc_rt = tf.math.logical_and(p_accept<0.9,p_accept>0.4)
    qc_stp_sz = stp_sz>0.00001
    qc_passed = tf.math.logical_and(qc_acc_rt, qc_stp_sz)

    if bool(qc_passed):
        #Save posterior samples.
        _dump_all('Results', [
            ('nuts_cor', crr_nuts),
            ('nuts_m_acc', m_acc_nuts),
            ('nuts_m_exp', m_exp_nuts),
            ('nuts_s_acc', s_acc_nuts),
            ('nuts_s_exp', s_exp_nuts),
            ('nuts_inf_acc', infl_acc_nuts),
            ('nuts_inf_exp', infl_exp_nuts),
        ], prt)

        print(tf.squeeze([ttime,stp_sz,p_accept]), qc_passed)

        #Save generating parameters, simulated data and normalization constants.
        _dump_all('Data', [
            ('cor', cor),
            ('mm_acc', mm_acc),
            ('mm_exp', mm_exp),
            ('ss_acc', ss_acc),
            ('ss_exp', ss_exp),
            ('infl_acc', infl_acc),
            ('infl_exp', infl_exp),
            ('xx_acc', xx_acc),
            ('xx_exp', xx_exp),
            ('yy_acc', yy_acc),
            ('yy_exp', yy_exp),
            ('Norm_acc', Norm_acc),
            ('Norm_exp', Norm_exp),
        ], prt)

        prt += 1
        jj = 0
    else:
        #Acceptance probability or step size outside reasonable values.
        jj += 1
        print("Attempt "+str(jj)+" failed for batch "+str(prt))
        if jj > MAX_RETRIES:
            #Give up on this batch so the loop terminates, and record it so the missing files are explained.
            print("Giving up on batch "+str(prt)+" after "+str(jj)+" attempts.")
            skipped_batches.append(prt)
            prt += 1
            jj = 0
        else:
            print("Retrying...")

if skipped_batches:
    print("Batches skipped after failing QC:", skipped_batches)

tf.Tensor([3.3146463e+02 1.4890161e-03 7.3972487e-01], shape=(3,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor([3.3832516e+02 2.4875011e-03 6.2752193e-01], shape=(3,), dtype=float32) tf.Tensor(True, shape=(), dtype=bool)
