# 0. Imports

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('WebAgg')
import numpy as np
import pandas as pd
import copy

In [2]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi/external


In [3]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi


***import ete3 Tree***

In [4]:
from ete3 import Tree

tree_name = "/home/eecs/khalil.ouardini/cas_scvi_topologies/newick_objects/500cells/high_fitness/topology1.nwk"
tree = Tree(tree_name, 1)

for i, n in enumerate(tree.traverse('levelorder')):
    n.add_features(index=i)
    n.name = str(i)

eps = 1e-3
branch_length = {}
for node in tree.traverse('levelorder'):
    if node.is_root():
        branch_length[node.name] = 0.0
    else:
        branch_length[node.name] = node.dist
branch_length['prior_root'] = 1.0

In [5]:
# Data
from anndata import AnnData
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from external.dataset.tree import TreeDataset, GeneExpressionDataset
from external.dataset.poisson_glm import Poisson_GLM
from external.dataset.anndataset import AnnDatasetFromAnnData

# Models
from models.vae import VAE
import scanpy as sc
from external.inference.tree_inference import TreeTrainer
from inference.inference import UnsupervisedTrainer
from scvi.inference import posterior
from external.models.treevae import TreeVAE

# Utils
from external.utils.data_util import get_leaves, get_internal
from external.utils.metrics import ks_pvalue, accuracy_imputation, correlations, knn_purity, knn_purity_stratified
from external.utils.plots_util import plot_histograms, plot_scatter_mean, plot_ecdf_ks, plot_density
from external.utils.plots_util import plot_losses, plot_elbo, plot_common_ancestor, plot_one_gene, training_dashboard
from external.utils.baselines import avg_weighted_baseline, scvi_baseline, scvi_baseline_z, cascvi_baseline_z, avg_baseline_z, construct_latent



In [6]:
import torch
    
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f0e8fb0f0b0>

# 1. Simulations (Poisson GLM)

In [20]:
d = 10
g = 1000
vis = False
leaves_only = False
var = 1.0
alpha = 1.0

glm = Poisson_GLM(tree, g, d, vis, leaves_only, branch_length, alpha)

glm.simulate_latent()

***Generate gene expression count data***

In [21]:
glm.simulate_ge(negative_binomial=False)
# Quality Control (i.e Gene Filtering)
#glm.gene_qc()

glm.X.shape, glm.W.shape, glm.beta.shape

((1000, 1000), (1000, 10), (1000,))

***Binomial thinning***

In [22]:
print("Proportion of dropouts: {}".format(np.mean(glm.X == 0)))
glm.binomial_thinning(p=0.1)
print("Proportion of dropouts after Binomial thinning: {}".format(np.mean(glm.X == 0)))

Proportion of dropouts: 0.40074
Proportion of dropouts after Binomial thinning: 0.74143


***Get the data and the indexes at the leaves***

In [23]:
# Latent vectors
leaves_z, _, _ = get_leaves(glm.z, glm.mu, tree)

#FIXED training set
leaves_X, leaves_idx, mu = get_leaves(glm.X, glm.mu, tree)

# internal nodes data (for imputation)
internal_X, internal_idx, internal_mu = get_internal(glm.X, glm.mu, tree)

# Additional data for clade samling
n_leaves_X = glm.generate_ge(n_samples=100, leaves_idx=leaves_idx)

leaves_X.shape, mu.shape, internal_X.shape, internal_mu.shape, leaves_z.shape

((500, 1000), (500, 1000), (500, 1000), (500, 1000), (500, 10))

# 2. Fitting CascVI

In [24]:
gene_dataset = GeneExpressionDataset()

leaves = [n for n in tree.traverse('levelorder') if n.is_leaf()]
cell_names = np.array([[n.name for n in leaves] * len(n_leaves_X)]).flatten()
gene_dataset.populate_from_data(X=np.vstack(n_leaves_X),
                               gene_names=[str(i) for i in range(glm.X.shape[1])])
gene_dataset.initialize_cell_attribute('barcodes', cell_names)

***Create a TreeDataset object***

In [25]:
# treeVAE
import copy

tree_bis = copy.deepcopy(tree)
cas_dataset = TreeDataset(gene_dataset, tree=tree_bis, filtering=False)

# No batches beacause of the message passing
use_cuda = True
use_MP = True
ldvae = False

cas_dataset

GeneExpressionDataset object with n_cells x nb_genes = 50000 x 1000
    gene_attribute_names: 'gene_names'
    cell_attribute_names: 'local_means', 'local_vars', 'labels', 'batch_indices', 'barcodes'
    cell_categorical_attribute_names: 'batch_indices', 'labels'

***Initialize model***

In [26]:
treevae = TreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=glm.latent,
              n_hidden=128,
              n_layers=1,
              reconstruction_loss='poisson',
              prior_t = branch_length,
              ldvae = ldvae,
              use_MP=use_MP,
              use_clades=False
             )

In [27]:
import torch

freeze = False
if freeze:
    new_weight = torch.from_numpy(glm.W).float()
    new_bias = torch.from_numpy(glm.beta).float()

    with torch.no_grad():
        treevae.decoder.factor_regressor.fc_layers[0][0].weight = torch.nn.Parameter(new_weight)
        treevae.decoder.factor_regressor.fc_layers[0][0].bias = torch.nn.Parameter(new_bias)
        
    for param in treevae.decoder.factor_regressor.fc_layers[0][0].parameters():
        param.requires_grad = False

In [28]:
#assert(treevae.decoder.factor_regressor.fc_layers[0][0].weight.numpy().all() == glm.W.T.all())
#assert(treevae.decoder.factor_regressor.fc_layers[0][0].bias.numpy().all() == glm.beta.all())

***Are we able to generate the gene expression data by decoding the simulated latent space?***

In [29]:
px_scale, px_rate, raw_px_scale = treevae.decoder(treevae.dispersion,
                                        torch.from_numpy(leaves_z).float(),
                                        torch.from_numpy(np.array([np.log(10000)])).float()
                                       )

from sklearn.metrics import mean_squared_error

if ldvae:
    foo = np.clip(a=np.exp(raw_px_scale.detach().cpu().numpy()),
            a_min=0,
            a_max=1e8
    )
    mse = mean_squared_error(mu, foo)
else:
    mse = mean_squared_error(mu, px_rate.detach().numpy())

print("the distance between the Poisson and the NB means is {}".format(mse))

the distance between the Poisson and the NB means is 17083.17867126643


***Hyperparameters***

In [30]:
n_epochs = 1000
lr = 1e-3
lambda_ = 1.0

***trainer***

In [34]:
freq = 100
trainer = TreeTrainer(
    model = treevae,
    gene_dataset = cas_dataset,
    lambda_ = lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150
)

KeyboardInterrupt: 

***Start training***

In [None]:
trainer.train(n_epochs=n_epochs,
              lr=lr
              )

***Loss Functions***

In [85]:
training_dashboard(trainer, treevae.encoder_variance)

### 3. Posterior and MV imputation

In [32]:
from sklearn.metrics import mean_squared_error

full_posterior = trainer.create_posterior(trainer.model, cas_dataset, trainer.clades,
                                indices=np.arange(len(cas_dataset))
                                         )
error = mean_squared_error(full_posterior.get_latent(), leaves_z)
print("the distance is {}".format(error))

the distance is 0.8535115741934923


***Missing Value imputation By Posterior Predictive sampling***

Encoding test set

In [33]:
input = torch.from_numpy(leaves_X).float().to('cuda:0')
with torch.no_grad():
    known_latent = treevae.z_encoder(input)[2]

mean_squared_error(known_latent.cpu().numpy(), full_posterior.get_latent()), mean_squared_error(known_latent.cpu().numpy(), leaves_z), mean_squared_error(leaves_z, full_posterior.get_latent())

(266.05035, 276.98887707457567, 0.8525989193607199)

PP sampling

In [34]:
empirical_l = np.mean(np.sum(glm.X, axis=1))

# CascVI impitations
imputed = {}
imputed_z = {}
imputed_gt = {}

for n in tree.traverse('levelorder'):
    if not n.is_leaf():
        imputed[n.name], imputed_z[n.name] = full_posterior.imputation_internal(n,
                                                            give_mean=False,
                                                            library_size=empirical_l)#,
                                                            #known_latent=known_latent
                                                           #)
        imputed_gt[n.name] = glm.X[n.index]

In [38]:
imputed_X = [x for x in imputed.values()]
imputed_X = np.array(imputed_X).reshape(-1, cas_dataset.X.shape[1])
#plot_histograms(imputed_X, "Histogram of CasscVI imputed gene expression data")

***CascVI Baseline 1 (MP Oracle)***

In [39]:
# CascVI impitations
imputed_cascvi_1 = {}
imputed_cascvi_1_z ={}

for n in tree.traverse('levelorder'):
    if not n.is_leaf():
        _, imputed_cascvi_1_z[n.name] = full_posterior.imputation_internal(n,
                                                                    give_mean=False,
                                                                    library_size=empirical_l,
                                                                    known_latent=leaves_z
        )
        mu_z = np.clip(a=np.exp(glm.W @ imputed_cascvi_1_z[n.name].cpu().numpy() + glm.beta),
                        a_min=0,
                        a_max=1e8
                        )
        samples = np.array([np.random.poisson(mu_z) for i in range(100)])
        imputed_cascvi_1[n.name] = np.clip(a=np.mean(samples, axis=0),
                                           a_min=0,
                                           a_max=1e8
                                           )


***CascVI Baseline 2 (Reconstruction of Averaged latent space)***

In [40]:
imputed_cascvi_2, imputed_cascvi_2_z = avg_baseline_z(tree=tree,
                                   model=treevae,
                                   posterior=full_posterior,
                                   weighted=False,
                                   n_samples_z=1,
                                   library_size=empirical_l,
                                   gaussian=False,
                                   use_cuda=True)#,
                                   #known_latent=True,
                                   #latent=np.array([known_latent.cpu().numpy()])
                                  #)

# 4. Baselines

### Baseline 1: Unweighted Average of gene expression in Clade

The simple idea here is to impute the value of an internal node, with the (un)weighted average of the gene expression values of the leaves, taking the query internal node as the root of the subtree.

In [41]:
weighted = False
imputed_avg = avg_weighted_baseline(tree, weighted, glm.X, rounding=True)

#get internal nodes
avg_X = np.array([x for x in imputed_avg.values()]).reshape(-1, glm.X.shape[1])
internal_avg_X, _, _ = get_internal(avg_X, glm.mu, tree)

### Baseline 2: (Un)weighted Average of decoded latent vectors, with scVI

We use the same averaging of the subtrees leaves in **Baseline 1**, only this time, the gene expression data is recovered with scVI

In [42]:
# anndata
gene_dataset = GeneExpressionDataset()
gene_dataset.populate_from_data(leaves_X)

In [43]:
import torch

n_epochs =600
use_batches = False

vae = VAE(gene_dataset.nb_genes,
                  n_batch=cas_dataset.n_batches * use_batches,
                  n_hidden=128,
                  n_layers=1,
                  reconstruction_loss='poisson',
                  n_latent=glm.latent,
                  ldvae=ldvae
              )

if freeze:
    new_weight = torch.from_numpy(glm.W).float()
    new_bias = torch.from_numpy(glm.beta).float()

    with torch.no_grad():
        vae.decoder.factor_regressor.fc_layers[0][0].weight = torch.nn.Parameter(new_weight)
        vae.decoder.factor_regressor.fc_layers[0][0].bias = torch.nn.Parameter(new_bias)
        
    for param in vae.decoder.factor_regressor.fc_layers[0][0].parameters():
        param.requires_grad = False

In [44]:
px_scale, px_r, px_rate, px_dropout = vae.decoder.forward(vae.dispersion,
                                        torch.from_numpy(leaves_z).float(),
                                        torch.from_numpy(np.array([np.log(10000)])).float(),
                                        None
                                        )

from sklearn.metrics import mean_squared_error



if ldvae:
    foo = np.clip(a=np.exp(px_r.detach().numpy()),
            a_min=0,
            a_max=5000
    )
    mse = mean_squared_error(mu, foo)
else:
    mse = mean_squared_error(mu, px_rate.detach().numpy())

print("the distance between the Poisson and the NB means is {}".format(mse))

the distance between the Poisson and the NB means is 416.5329559785676


In [45]:
trainer_scvi = UnsupervisedTrainer(model=vae,
                              gene_dataset=gene_dataset,
                              train_size=1.0,
                              use_cuda=use_cuda,
                              frequency=10,
                              n_epochs_kl_warmup=200
                              )

# train scVI
trainer_scvi.train(n_epochs=n_epochs, lr=1e-3) 
                                        
elbo_train_scvi = trainer_scvi.history["elbo_train_set"]
x = np.linspace(0, 100, (len(elbo_train_scvi)))
plt.plot(np.log(elbo_train_scvi), 
         label="train", color='blue',
         linestyle=':',
         linewidth=3
        )
        
plt.xlabel('Epoch')
plt.ylabel("ELBO")
plt.legend()
plt.title("Train history scVI")
plt.show()

training: 100%|██████████| 600/600 [00:28<00:00, 20.84it/s]
Press Ctrl+C to stop WebAgg server


RuntimeError: This event loop is already running

In [52]:
scvi_posterior = trainer_scvi.create_posterior(model=vae,
                                               gene_dataset=gene_dataset 
                                                )

error = mean_squared_error(scvi_posterior.get_latent()[0], leaves_z)
print("the distance is {}".format(error))

the distance is 1.8439556837797675


Encode test data

In [53]:
with torch.no_grad():
    known_latent = vae.z_encoder(input)[2].cpu().numpy()

mean_squared_error(known_latent, full_posterior.get_latent()), mean_squared_error(known_latent, leaves_z), mean_squared_error(leaves_z, full_posterior.get_latent())

(54.08839, 54.37805583585522, 0.8575174677695735)

***scVI Baseline 2 (Decoded Average Latent space)***

In [54]:
library_size = np.mean(np.sum(glm.X, axis=1))

scvi_latent = np.array([scvi_posterior.get_latent(give_mean=False)[0] for i in range(10)])

imputed_scvi_2, imputed_scvi_2_z = scvi_baseline_z(tree,
                                        posterior=scvi_posterior,
                                        model=vae,
                                        weighted=False,
                                        n_samples_z=1,
                                        library_size=library_size,
                                        use_cuda=True)#,
                                        #known_latent=True,
                                        #latent=np.expand_dims(known_latent, axis=0)
                                        #)


# 5. Likelihood Ratio

In [55]:
cascvi_latent = full_posterior.get_latent()
scvi_latent = scvi_posterior.get_latent()[0]

scvi_latent.shape, cascvi_latent.shape

((500, 10), (500, 10))

In [56]:
treevae.initialize_visit()
treevae.initialize_messages(scvi_latent, cas_dataset.barcodes, scvi_latent.shape[1])
treevae.perform_message_passing((treevae.tree & treevae.root), scvi_latent.shape[1], False)
mp_lik_scvi = treevae.aggregate_messages_into_leaves_likelihood(d, add_prior=True)
print("Likelihood of scVI encodings: ", mp_lik_scvi.item())

Likelihood of scVI encodings:  -20368.280447135054


In [57]:
treevae.initialize_visit()
treevae.initialize_messages(cascvi_latent, cas_dataset.barcodes, cascvi_latent.shape[1])
treevae.perform_message_passing((treevae.tree & treevae.root), cascvi_latent.shape[1], False)
mp_lik_cascvi = treevae.aggregate_messages_into_leaves_likelihood(d, add_prior=True)
print("Likelihood of cascVI encodings: ", mp_lik_cascvi.item())

Likelihood of cascVI encodings:  -105.08357084181719


In [58]:
treevae.initialize_visit()
treevae.initialize_messages(leaves_z, cas_dataset.barcodes, cascvi_latent.shape[1])
treevae.perform_message_passing((treevae.tree & treevae.root), cascvi_latent.shape[1], False)
mp_lik_cascvi = treevae.aggregate_messages_into_leaves_likelihood(d, add_prior=True)
print("Likelihood of observations: ", mp_lik_cascvi.item())

Likelihood of observations:  -139.91774090346644


In [59]:
# Likelihood ratio
lambda_ = (mp_lik_cascvi - mp_lik_scvi)
print("Likelihood Ratio:", lambda_)

Likelihood Ratio: tensor(20228.3627, dtype=torch.float64)


# 6. Evaluation

***CPM Normalization (for sample-sample correlation)***

get imputations into an array

In [60]:
internal_scvi_X_2 = np.array([x for x in imputed_scvi_2.values()]).reshape(-1, glm.X.shape[1])
internal_cascvi_X = np.array([x for x in imputed_cascvi_1.values()]).reshape(-1, glm.X.shape[1])
internal_cascvi_X_2 = np.array([x for x in imputed_cascvi_2.values()]).reshape(-1, glm.X.shape[1])

internal_cascvi_X_2.shape, internal_cascvi_X.shape, internal_scvi_X_2.shape, imputed_X.shape, internal_avg_X.shape, internal_X.shape

((500, 1000), (500, 1000), (500, 1000), (500, 1000), (500, 1000), (500, 1000))

In [61]:
from sklearn.preprocessing import normalize

norm_internal_X = sc.pp.normalize_total(AnnData(internal_X), target_sum=1e4, inplace=False)['X'] 
norm_scvi_X_2 = sc.pp.normalize_total(AnnData(internal_scvi_X_2), target_sum=1e4, inplace=False)['X']
norm_avg_X = sc.pp.normalize_total(AnnData(internal_avg_X), target_sum=1e4, inplace=False)['X']
norm_imputed_X = sc.pp.normalize_total(AnnData(imputed_X), target_sum=1e4, inplace=False)['X']
norm_cascvi_X = sc.pp.normalize_total(AnnData(internal_cascvi_X), target_sum=1e4, inplace=False)['X']
norm_cascvi_X_2 = sc.pp.normalize_total(AnnData(internal_cascvi_X_2), target_sum=1e4, inplace=False)['X']

norm_internal_X.shape

(500, 1000)

## I. Sample-Sample Correlations

***1. Sample-Sample correlation (Without Normalization)***

We will use Scipy to compute a nonparametric rank correlation between the imputed and the groundtruth profiles. The correlation is based on the Spearman Correlation Coefficient.

In [62]:
data = {'groundtruth': internal_X.T, 'cascVI': imputed_X.T, 'scVI': internal_scvi_X_2.T,
        'Average': internal_avg_X.T , 'cascVI + Avg': internal_cascvi_X_2.T,
        'MP Oracle': internal_cascvi_X.T
        }
df1 = correlations(data, 'None', True)
#df1.head(5)
#plt.show()

***2. Sample-Sample correlation (With ScanPy Normalization)***

In [63]:
data = {'groundtruth': norm_internal_X.T, 'cascVI': norm_imputed_X.T, 'scVI': norm_scvi_X_2.T, 
        'Average': norm_avg_X.T , 'cascVI + Avg': norm_cascvi_X_2.T,
        'MP Oracle': norm_cascvi_X.T
        }

df2 = correlations(data, 'None', True)
#df2.head(5)
#plt.show()


## II. Gene-Gene Correlations

***2. Gene-Gene correlation (With Normalization)***

In [64]:
data = {'groundtruth': internal_X, 'cascVI': imputed_X, 'scVI': internal_scvi_X_2,
        'Average': internal_avg_X , 'cascVI + Avg': internal_cascvi_X_2,
        'MP Oracle': internal_cascvi_X
        }

df3 = correlations(data, 'None', True)
#df3.head(5)
#plt.show()

***2. Gene-Gene correlation (With Normalization)***

In [65]:
data = {'groundtruth': norm_internal_X, 'cascVI': norm_imputed_X, 'scVI': norm_scvi_X_2, 
        'Average': norm_avg_X , 'cascVI + Avg': norm_cascvi_X_2,
        'MP Oracle': norm_cascvi_X
        }

df4 = correlations(data, 'None', True)
#df4.head(5)
#plt.show()

***3. Gene-Gene correlation (With Rank Normalization)***

In [66]:
#data = {'groundtruth': norm_internal_X, 'cascVI': norm_imputed_X, 'scVI': norm_scvi_X_2, 
#        'Average': norm_avg_X , 'cascVI + Avg': norm_cascvi_X_2,
#        'MP Oracle': norm_cascvi_X
#        }

data = {'groundtruth': internal_X, 'cascVI': imputed_X, 'scVI': internal_scvi_X_2,
        'Average': internal_avg_X , 'cascVI + Avg': internal_cascvi_X_2,
        'MP Oracle': internal_cascvi_X
        }
        
df5 = correlations(data, 'rank', True)
#df5.head(5)
#plt.show()

### III. Table Summary

In [67]:
columns = ["Method", "Spearman CC", "Pearson CC", "Kendall Tau"]
data = [df1, df2, df3, df4, df5]
#data = [df2, df4]

data 
tables = [[] for i in range(len(data))]

#task = ["Sample-Sample (None)", "Sample-Sample (CPM)", "Gene-Gene (None)", 
           #"Gene-Gene(CPM)", "Gene-Gene (Rank)" ]

for (df, t) in zip(data, tables):
    for m in np.unique(df.Method):
        sub_df = np.round(df[df['Method'] == m].mean(), decimals=3)
        t.append([m, sub_df['Spearman CC'], sub_df['Pearson CC'], sub_df['Kendall Tau']])
        
# Create and style Data Frames
df_table1 = pd.DataFrame(tables[0], columns=columns)
df_table2 = pd.DataFrame(tables[1], columns=columns)
df_table3 = pd.DataFrame(tables[2], columns=columns)
df_table4 = pd.DataFrame(tables[3], columns=columns)
df_table5 = pd.DataFrame(tables[4], columns=columns)

In [68]:
print(" >>> Sample-Sample | No Normalization <<<")
df_table1.head(10)

 >>> Sample-Sample | No Normalization <<<


Unnamed: 0,Method,Spearman CC,Pearson CC,Kendall Tau
0,Average,0.721,0.919,0.63
1,MP Oracle,0.776,0.944,0.637
2,cascVI,0.757,0.91,0.616
3,cascVI + Avg,0.752,0.908,0.613
4,scVI,0.749,0.909,0.61


In [69]:
print(">>> Sample-Sample | CPM Normalization <<<")
df_table2.head(10)

>>> Sample-Sample | CPM Normalization <<<


Unnamed: 0,Method,Spearman CC,Pearson CC,Kendall Tau
0,Average,0.721,0.919,0.63
1,MP Oracle,0.776,0.944,0.637
2,cascVI,0.757,0.91,0.616
3,cascVI + Avg,0.752,0.908,0.613
4,scVI,0.749,0.909,0.61


In [70]:
print(">>> Gene-Gene | No Normalization <<<")
df_table3.head(10)

>>> Gene-Gene | No Normalization <<<


Unnamed: 0,Method,Spearman CC,Pearson CC,Kendall Tau
0,Average,0.476,0.583,0.404
1,MP Oracle,0.565,0.677,0.45
2,cascVI,0.477,0.537,0.37
3,cascVI + Avg,0.473,0.531,0.367
4,scVI,0.483,0.533,0.376


In [71]:
print(">>> Gene-Gene | CPM Normalization <<<")
df_table4.head(10)

>>> Gene-Gene | CPM Normalization <<<


Unnamed: 0,Method,Spearman CC,Pearson CC,Kendall Tau
0,Average,0.475,0.555,0.369
1,MP Oracle,0.565,0.65,0.429
2,cascVI,0.522,0.588,0.392
3,cascVI + Avg,0.519,0.585,0.388
4,scVI,0.528,0.595,0.398


In [72]:
print(">>> Gene-Gene | Rank Normalization <<<")
df_table5.head(10)

>>> Gene-Gene | Rank Normalization <<<


Unnamed: 0,Method,Spearman CC,Pearson CC,Kendall Tau
0,Average,0.476,0.476,0.404
1,MP Oracle,0.565,0.565,0.45
2,cascVI,0.477,0.477,0.37
3,cascVI + Avg,0.473,0.473,0.367
4,scVI,0.483,0.483,0.376
