In [None]:
from google.colab import drive
# Mount Google Drive so the data and model files under /content/drive/MyDrive/ can be read.
drive.mount(mountpoint="/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


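Generate synthetic expression data from the fitted VAE. Every cell is decoded into zero-inflated negative-binomial parameters and counts are sampled from them. Cells are then ranked by the log-likelihood of their observed counts. For the best-fitting cells, the decoder output and library sizes over a random subset of genes are saved.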

In [None]:
from IPython import display
import pandas as pd
import numpy as np
import datetime
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from matplotlib import cm
from matplotlib.colors import Normalize 
from tensorflow import keras
import tensorflow as tf
import scipy.stats as stats
from tensorflow.keras import layers
import tensorflow_probability as tfp
from scipy.stats import gaussian_kde
import pickle

# Root of the mounted Google Drive; all data/model paths below are relative to it.
Folder = '/content/drive/MyDrive/'

# pandas >= 2.0 makes every read_csv argument after the path keyword-only, so the
# separator must be passed as `sep=`; the previous positional "," raises TypeError.
Rna_train = pd.read_csv(Folder + 'SCRaPL/Synth/VAE/Data/Rna_train_gastr_new.csv', sep=",")
Rna_test = pd.read_csv(Folder + 'SCRaPL/Synth/VAE/Data/Rna_test_gastr_new.csv', sep=",")

In [None]:
class Sampling(layers.Layer):
    """Reparameterisation-trick layer.

    Called with a pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon`` drawn from a
    standard normal of shape ``(batch, dim)`` taken from ``z_mean``.
    """

    def __init__(self, name=None, **kwargs):
        # Initialise the base Layer exactly once. The previous version called
        # Layer.__init__ twice (first with `name`, then with only `**kwargs`);
        # the second call re-ran initialisation with name=None, so the layer
        # silently lost the name it was constructed / deserialised with.
        super(Sampling, self).__init__(name=name, **kwargs)

    def get_config(self):
        # The layer has no constructor arguments of its own, so the base
        # Layer config is sufficient for serialisation.
        return super(Sampling, self).get_config()

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # assumes z_mean is (batch, dim) — only the first two dims are read.
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


# Restore the trained VAE; Sampling is user-defined, so it is supplied as a custom object.
# NOTE(review): Keras resolves custom objects by the serialised class name, which for
# this class is 'Sampling'; confirm the lowercase 'sampling' key is actually used
# (loading may instead be relying on SavedModel revival of the layer).
vae = keras.models.load_model(
    Folder + "SCRaPL/Synth/VAE/scVI_gastr.pb",
    custom_objects={'sampling': Sampling},
)

In [None]:
# Drop the leading column of each table (presumably the row labels written by
# to_csv — TODO confirm) and turn the remaining values into float32 tensors.
MRna_train, MRna_test = (
    frame.drop(columns=frame.columns[0]) for frame in (Rna_train, Rna_test)
)
Rna_tr, Rna_tst = (
    tf.convert_to_tensor(frame, dtype=tf.float32) for frame in (MRna_train, MRna_test)
)

In [None]:
# Stack training and test rows into one matrix so every cell is decoded together below.
Rna = tf.concat(values=[Rna_tr, Rna_tst], axis=0)

In [None]:
def gen_lat(mdl,data):
    """Run the two encoder heads of ``mdl`` on ``data``.

    ``mdl.encoder_rna.predict`` and ``mdl.lib.predict`` must each return exactly
    three outputs; only the last output of each is kept.

    :param mdl: fitted VAE exposing ``encoder_rna`` and ``lib`` sub-models
    :param data: input matrix passed unchanged to both ``predict`` calls
    :return: ``(z_rna, lib_rna)`` — the third output of each head, in that order
    """
    def _last_of_three(outputs):
        # Strict three-way unpack: any other number of outputs raises ValueError.
        _, _, last = outputs
        return last

    # The encoder head is evaluated (and unpacked) before the library-size head.
    return (
        _last_of_three(mdl.encoder_rna.predict(data)),
        _last_of_three(mdl.lib.predict(data)),
    )

def gen_rts(mdl,data):
    """Decode ``data`` into zero-inflated negative-binomial parameters.

    The latent code and library term come from ``gen_lat``; the decoder's three
    raw outputs are then transformed as follows:

    * ``mu_rna    = exp(lib_rna + raw_mu)``
    * ``theta_rna = exp(raw_theta)``
    * ``infl_rna  = sigmoid(raw_infl)``
    * ``p_rna     = mu_rna / (mu_rna + theta_rna)``

    With this ``p_rna``, a tfp ``NegativeBinomial(total_count=theta_rna,
    probs=p_rna)`` has mean ``mu_rna``.

    :return: ``(mu_rna, theta_rna, infl_rna, p_rna, raw_mu, lib_rna)`` where
        ``raw_mu`` is the untransformed decoder mean output.
    """
    z_rna, lib_rna = gen_lat(mdl, data)
    raw_mu, raw_theta, raw_infl = mdl.decoder_rna(z_rna)

    mu_rna = tf.math.exp(lib_rna + raw_mu)
    theta_rna = tf.math.exp(raw_theta)
    infl_rna = tf.keras.activations.sigmoid(raw_infl)
    p_rna = tf.math.divide(mu_rna, mu_rna + theta_rna)

    return mu_rna, theta_rna, infl_rna, p_rna, raw_mu, lib_rna



In [None]:
def density_scatter_plot(x, y, **kwargs):
    """
    Scatter ``y`` against ``x``. Each point is coloured by the Gaussian KDE of the
    joint (x, y) sample evaluated at that point, a red y = x reference line is
    overlaid, and a colorbar shows the density scale.

    :param x: data positions on the x axis (array-like; converted with np.asarray)
    :param y: data positions on the y axis (same length as ``x``)
    :param kwargs: extra keyword arguments for ``plt.scatter``. Any ``color`` or
        ``c`` entry is replaced by the density colouring; ``cmap`` defaults to
        'viridis'.
    :return: the ``matplotlib.pyplot`` module (same return value as before)
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Kernel Density Estimate (KDE) of the 2-D sample, evaluated at every point.
    values = np.vstack((x, y))
    kde = gaussian_kde(values).evaluate(values)

    # Colour through c/cmap instead of a precomputed RGBA `color` array. The
    # scatter's colour limits autoscale to [kde.min(), kde.max()] (the same
    # mapping as the old Normalize), and the collection now carries the KDE
    # values, so the colorbar below shows the density scale. With `color=` the
    # collection had no data array and plt.colorbar() drew an unrelated 0-1 scale.
    kwargs.pop('color', None)
    kwargs.pop('c', None)
    kwargs.setdefault('cmap', 'viridis')
    points = plt.scatter(x, y, c=kde, **kwargs)

    # y = x reference line over the joint data range. The old tf.range(0, max, 1)
    # always started at 0 and stopped at the last integer below the maximum, so
    # the line was truncated (or empty) for small-valued or negative data.
    lo = min(x.min(), y.min())
    hi = max(x.max(), y.max())
    plt.plot([lo, hi], [lo, hi], color='red')

    plt.colorbar(points, orientation='vertical')
    return plt

In [None]:
# Decode every cell (train + test stacked in Rna) into ZINB parameters.
mu, r, pi, p_rna, mu_nrm, lib = gen_rts(vae, Rna)

# Negative-binomial component: draw one count for every cell/gene entry.
NBin_rna = tfp.distributions.NegativeBinomial(total_count=r, probs=p_rna)
RNA_pred_tst = NBin_rna.sample()

# Zero inflation: wherever the Bernoulli(pi) draw is positive, force the count to 0.
# (Draw order NB -> Bernoulli is kept so the random stream is consumed identically.)
Infl = tfp.distributions.Bernoulli(probs=pi)
iden = Infl.sample()
RNA_pred_tst = tf.where(iden > 0, tf.zeros_like(RNA_pred_tst), RNA_pred_tst)

# Per-cell log-likelihood of the observed counts, summed over axis 1.
# NOTE(review): this scores Rna under the plain NB only and ignores the
# zero-inflation probability pi used for sampling above — confirm intended.
llk = tf.reduce_sum(NBin_rna.log_prob(Rna), axis=1)



In [None]:
# Number of randomly chosen genes to keep.
num_genes = 300
# Negative slice start: np.argsort is ascending, so llk_sort_ind[num_cells:]
# selects the 800 cells with the highest log-likelihood.
num_cells = -800

llk_sort_ind = np.argsort(llk)
# NOTE(review): the shuffle is unseeded, so the gene subset (and hence the saved
# files) changes from run to run.
gene_ind = tf.random.shuffle(tf.range(mu.shape[1]))[:num_genes]

# Rows = top-likelihood cells, columns = sampled genes, taken from the raw decoder output.
x_exp_lt = tf.gather(
    tf.gather(mu_nrm, llk_sort_ind[num_cells:], axis=0),
    gene_ind,
    axis=1,
)
lib_lt = tf.gather(lib, llk_sort_ind[num_cells:])

with open(Folder+'SCRaPL/Synth/VAE/x_exp_scVI_'+str(num_genes)+'_'+str(-num_cells)+'.pickle', 'wb') as handle:
    pickle.dump(x_exp_lt, handle)

with open(Folder+'SCRaPL/Synth/VAE/lib_scVI_'+str(num_genes)+'_'+str(-num_cells)+'.pickle', 'wb') as handle:
    pickle.dump(lib_lt, handle)