# 0. Standard imports

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('WebAgg')
import numpy as np
import pandas as pd

In [5]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi/external


In [6]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi


In [7]:
# Data
from anndata import AnnData
import matplotlib.pyplot as plt
from external.dataset.tree import TreeDataset, GeneExpressionDataset
from external.dataset.anndataset import AnnDatasetFromAnnData

# Models
from models.vae import VAE
import scanpy as sc
from external.inference.tree_inference import TreeTrainer
from inference.inference import UnsupervisedTrainer
from external.inference import posterior
from external.models.treevae import TreeVAE

# Utils
from external.utils.data_util import get_leaves, get_internal
from external.utils.plots_util import plot_histograms, plot_density
from external.utils.plots_util import plot_one_gene, training_dashboard
from external.utils.baselines import scvi_baseline_z, cascvi_baseline_z, avg_baseline_z, construct_latent
from external.utils.metrics import knn_purity_tree



### 0. import ete3 Tree

In [26]:
from ete3 import Tree
import ete3

#tree_name = "/data/yosef2/users/mattjones/projects/metastasis/JQ19/5k/trees/lg4lg4_tree_hybrid_priors.alleleThresh.processed.txt"
#tree_name = "/home/eecs/khalil.ouardini/cas_scvi_topologies/Cassiopeia_trees/lg7_branch_length.txt"
tree_name = "/home/eecs/khalil.ouardini/cas_scvi_topologies/Cassiopeia_trees/lg7_tree_hybrid_priors.alleleThresh.processed.ultrametric.tree"

tree = Tree(tree_name, 1)

N = len([n for n in tree.traverse()])

#leaves = [n.name for n in tree.traverse('levelorder')]
leaves = [n for n in tree.traverse('levelorder') if n.is_leaf()]

### 1. Gene expression data

In [27]:
barcodes = pd.read_csv("/data/yosef2/users/mattjones/projects/metastasis/JQ19/5k/RNA/ALL_Samples/GRCh38/barcodes.tsv", names=['Barcodes'])

barcodes.head(5)

Unnamed: 0,Barcodes
0,LL.AAACCTGCAAGAAAGG-1
1,LL.AAACCTGCATGAAGTA-1
2,LL.AAACCTGGTACATGTC-1
3,LL.AAACCTGTCAAGGTAA-1
4,LL.AAACGGGGTCTTGATG-1


In [28]:
genes = pd.read_csv("/data/yosef2/users/mattjones/projects/metastasis/JQ19/5k/RNA/ALL_Samples/GRCh38/genes.tsv", sep='\t', names=['id', 'Gene'])

genes.head(5)

Unnamed: 0,id,Gene
0,ENSG00000243485,RP11-34P13.3
1,ENSG00000237613,FAM138A
2,ENSG00000186092,OR4F5
3,ENSG00000238009,RP11-34P13.7
4,ENSG00000239945,RP11-34P13.8


In [29]:
hotspot_genes = pd.read_csv("/data/yosef2/users/mattjones/projects/metastasis/hotspot_gene_sets/lg7_hotspot_genes.tsv", sep='\t')
hotspot_genes.head(5)

Unnamed: 0,Gene,C,Z,Pval,FDR
0,TFF3,0.174633,42.425806,0.0,0.0
1,TFF1,0.227374,41.599879,0.0,0.0
2,MUC5AC,0.157811,37.94052,0.0,0.0
3,CEACAM6,0.245428,32.255498,1.472918e-228,2.2185829999999997e-225
4,MUC5B,0.0846,24.601674,6.06069e-134,7.303131e-131


Gene expression

In [30]:
#from scipy.io import mmread

#X = mmread('/data/yosef2/users/mattjones/projects/metastasis/JQ19/5k/RNA/ALL_Samples/GRCh38/matrix.mtx')
#Y = X.todense()
#print(Y.shape)

# 1. Data preprocessing

### a. Filter out cells not present in gene expression matrix

In [31]:
#N = 0
#leaves_to_delete = []
#leaves_X = {}
#for barcode in leaves:
#    foo = barcodes.index[barcodes['Barcodes'] == barcode].tolist()
#    if foo == []:
#        N += 1
#        leaves_to_delete.append(barcode)
#        continue
#    else:
#        idx = foo[0]
    #x = np.squeeze(np.array(Y[:, idx]))
    #leaves_X[barcode] = x

In [32]:
#print("Kept {} leaves".format(len(leaves_X)))
#print("{} Cells were not present in the gene expression dataset".format(N))

### b. Prune cells from the tree

In [33]:
#keep_leaves = [n for n in leaves if n not in leaves_to_delete]
#tree.prune(keep_leaves) 

Resolve unifurcations in tree

In [35]:
def collapse_unifurcations(tree: ete3.Tree) -> ete3.Tree:
    """Collapse unifurcations.
    Collapse all unifurcations in the tree, namely any node with only one child
    should be removed and all children should be connected to the parent node.
    Args:
        tree: tree to be collapsed
    Returns:
        A collapsed tree.
    """
    collapse_fn = lambda x: (len(x.children) == 1)
    collapsed_tree = tree.copy()
    to_collapse = [n for n in collapsed_tree.traverse() if collapse_fn(n)]
    for n in to_collapse:
        n.delete()
    return collapsed_tree

import copy
collapsed_tree = copy.deepcopy(tree) #collapse_unifurcations(tree)

### c. Filter out genes not present in hot spot gene set

***--> genes to keep***

In [37]:
keep_genes = []
max = 100
for i, g in enumerate(hotspot_genes['Gene'].values):
    if i > max:
        break
    idx_g = genes.index[genes['Gene'] == g].tolist()[0]
    keep_genes.append(idx_g)

In [38]:
n_genes = len(keep_genes)
n_cells = len(keep_leaves)

print("Dataset has {} genes and {} cells".format(n_genes, n_cells))

Dataset has 101 genes and 663 cells


***--> Rearrange cells in array - in level order***

In [39]:
#X = np.zeros((n_cells, genes.shape[0]))
#idx = 0
#for n in tree.traverse('levelorder'):
#    if n.is_leaf():
#        barcode = n.name
#        X[idx] = leaves_X[barcode]
#        idx += 1
#print("Dropout rate in gene expression matrix: {}".format(np.mean(X == 0)))

***--> Filtering genes***

In [40]:
#X = X[:, keep_genes]
#print("shape; {}".format(X.shape))
#print("Dropout rate in gene expression matrix: {}".format(np.mean(X == 0)))

In [76]:
#
X = np.load('/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi/metastasis_data/Metastasis_lg7_100g.npy')
X.shape

(603, 100)

In [74]:
#np.save('Metastasis_lg7_100g.npy', X)

***--> Historgram***

In [42]:
#plot_histograms(X, 'Metastasis', 'histogram.png')

In [100]:
for i, n in enumerate(collapsed_tree.traverse('levelorder')):
    n.add_features(index=i)
    if not n.is_leaf():
        n.name = str(i)

In [101]:
diameter = collapsed_tree.get_leaves()[0].get_distance(collapsed_tree)
#diameter = 200
diameter

14.0

In [103]:
branch_length = {}
for node in collapsed_tree.traverse('levelorder'):
    if node.is_root():
        branch_length[node.name] = 0.0
    else:
        branch_length[node.name] = node.dist / diameter
branch_length['prior_root'] = 1.0

In [104]:
import torch
    
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f121767d3d0>

# 2. Fitting CascVI

In [82]:
import scanpy as sc

# anndata + gene and celle filtering
adata = AnnData(X)
leaves = [n for n in collapsed_tree.traverse('levelorder') if n.is_leaf()]
adata.obs_names = [n.name for n in leaves]
sc.pp.filter_genes(adata, min_counts=3)
sc.pp.filter_cells(adata, min_counts=0)

***Create a TreeDataset object***

In [83]:
# treeVAE
import copy

n_cells, n_genes = X.shape
tree_bis = copy.deepcopy(collapsed_tree)
scvi_dataset = AnnDatasetFromAnnData(adata, filtering=False)
scvi_dataset.initialize_cell_attribute('barcodes', adata.obs_names)
cas_dataset = TreeDataset(scvi_dataset, tree=tree_bis, filtering=False)
cas_dataset

# No batches beacause of the message passing
use_cuda = True
use_MP = True
ldvae = False

n_counts is a protected attribute or already exists as a cell attribute and cannot be set with this name in initialize_gene_attribute, changing name to n_counts_gene and setting
go


***Initialize model***

In [84]:
treevae = TreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=10,
              n_hidden=64,
              n_layers=1,
              reconstruction_loss='poisson',
              prior_t = branch_length,
              ldvae = ldvae,
              use_MP=use_MP
             )

***Hyperparameters***

In [85]:
n_epochs = 700
lr = 1e-3
lambda_ = 1.0

***trainer***

In [86]:
freq = 100
trainer = TreeTrainer(
    model = treevae,
    gene_dataset = cas_dataset,
    lambda_ = lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150
)

train_leaves:  [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63], [64], [65], [66], [67], [68], [69], [70], [71], [72], [73], [74], [75], [76], [77], [78], [79], [80], [81], [82], [83], [84], [85], [86], [87], [88], [89], [90], [91], [92], [93], [94], [95], [96], [97], [98], [99], [100], [101], [102], [103], [104], [105], [106], [107], [108], [109], [110], [111], [112], [113], [114], [115], [116], [117], [118], [119], [120], [121], [122], [123], [124], [125], [126], [127], [128], [129], [130], [131], [132], [133], [134], [135], [136], [137], [138], [139], [140], [141], [142], [143], [144], [145], [146], [147], [148], [149], [150], [151], [152], [153], [154], [155], [1

***Start training***

In [87]:
trainer.train(n_epochs=n_epochs,
              lr=lr
              )

944880375277
training:  78%|███████▊  | 548/700 [04:51<01:30,  1.67it/s]Encodings MP Likelihood: 9.147760027106251
ELBO Loss: 267.9500091650155
training:  78%|███████▊  | 549/700 [04:51<01:19,  1.90it/s]Encodings MP Likelihood: 9.235904625509509
ELBO Loss: 267.6481670220532
training:  79%|███████▊  | 550/700 [04:52<01:11,  2.10it/s]Encodings MP Likelihood: 9.277956094638759
ELBO Loss: 269.29523804088024
training:  79%|███████▊  | 551/700 [04:53<01:25,  1.75it/s]Encodings MP Likelihood: 9.132675539279857
ELBO Loss: 267.29917594452377
training:  79%|███████▉  | 552/700 [04:53<01:14,  1.98it/s]Encodings MP Likelihood: 9.108848282465495
ELBO Loss: 268.82251500296917
training:  79%|███████▉  | 553/700 [04:53<01:07,  2.17it/s]Encodings MP Likelihood: 9.207169932175836
ELBO Loss: 268.83989428057936
training:  79%|███████▉  | 554/700 [04:54<01:22,  1.76it/s]Encodings MP Likelihood: 9.216706872623845
ELBO Loss: 267.89750792742205
training:  79%|███████▉  | 555/700 [04:55<01:13,  1.98it/s]Encodi

***Loss Functions***

In [88]:
#training_dashboard(trainer, treevae.encoder_variance)

***Training ELBO***

In [93]:
full_posterior = trainer.create_posterior(trainer.model, cas_dataset, trainer.clades,
                                indices=np.arange(len(cas_dataset))
                                         )

print("CascVI train ELBO: {}".format(full_posterior.compute_elbo().item()))

CascVI train ELBO: 262.59528856120596


### 3. Posterior and MV imputation

***Missing Value imputation By Posterior Predictive sampling***

In [123]:
empirical_l = np.mean(np.sum(X, axis=1))

# CascVI imputations
imputed = {}
imputed_z = {}

imputed_mcmc_cov = {}

for n in collapsed_tree.traverse('levelorder'):
    if not n.is_leaf():
        imputed[n.name], imputed_z[n.name]  = full_posterior.imputation_internal(n,
                                                            give_mean=False,
                                                            library_size=empirical_l
                                                           )
        _, imputed_mcmc_cov[n.name] = full_posterior.mcmc_estimate(query_node=n,
                                                                    n_samples=20
                                                                    )

In [128]:
imputed_X = [x[0] for x in imputed.values()]
imputed_X = np.array(imputed_X).reshape(-1, cas_dataset.X.shape[1])
#plot_histograms(imputed_X, "Histogram of CasscVI imputed gene expression data", 'histogram.png')

ValueError: cannot reshape array of size 313 into shape (100)

# 3. Fitting scVI - full batch update

In [107]:
treevae_full = TreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=10,
              n_hidden=64,
              n_layers=1,
              reconstruction_loss='poisson',
              prior_t = branch_length,
              ldvae = ldvae,
              use_MP=False
             )

freq = 100
trainer_full = TreeTrainer(
    model = treevae_full,
    gene_dataset = cas_dataset,
    lambda_ = lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150
)

trainer_full.train(n_epochs=n_epochs,
              lr=lr
              )

198866184
Encodings MP Likelihood: 0.0
ELBO Loss: 280.8707211289305
Encodings MP Likelihood: 0.0
ELBO Loss: 279.2635907919513
training:  58%|█████▊    | 404/700 [00:06<00:05, 58.34it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 278.66221619907407
Encodings MP Likelihood: 0.0
ELBO Loss: 279.50475999816143
Encodings MP Likelihood: 0.0
ELBO Loss: 278.7267065054512
Encodings MP Likelihood: 0.0
ELBO Loss: 279.56555266147535
Encodings MP Likelihood: 0.0
ELBO Loss: 279.9251085059803
Encodings MP Likelihood: 0.0
ELBO Loss: 278.4613978120728
training:  59%|█████▊    | 410/700 [00:06<00:04, 58.77it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 280.3260823515753
Encodings MP Likelihood: 0.0
ELBO Loss: 279.12120274606576
Encodings MP Likelihood: 0.0
ELBO Loss: 279.16508523158035
Encodings MP Likelihood: 0.0
ELBO Loss: 278.6762903453123
Encodings MP Likelihood: 0.0
ELBO Loss: 279.64854628952804
Encodings MP Likelihood: 0.0
ELBO Loss: 278.8934236446005
training:  59%|█████▉    | 416/700 [00:06<00:05, 55.51

In [109]:
full_batch_posterior = trainer_full.create_posterior(trainer_full.model, cas_dataset, trainer_full.clades,
                                indices=np.arange(len(cas_dataset))
                                         )

print("scVI - full - batch train ELBO: {}".format(full_batch_posterior.compute_elbo().item()))

scVI - full - batch train ELBO: 266.6711219475009


In [110]:
elbo_treevae  = []
elbo_vae = []
N = 100
for i in range(N):
    elbo_vae.append(full_batch_posterior.compute_elbo().item())
    elbo_treevae.append(full_posterior.compute_elbo().item())

In [111]:
print("CascVI Elbo: mean = {} | std = {}".format(np.mean(elbo_treevae), np.std(elbo_treevae)))

CascVI Elbo: mean = 262.9916118672789 | std = 0.3982202545057567


In [112]:
print("scVI Elbo: mean = {} | std = {}".format(np.mean(elbo_vae), np.std(elbo_vae)))

scVI Elbo: mean = 266.3793400925319 | std = 0.4087974732964147


# 4. Fitting scVI

In [113]:
# anndata
gene_dataset = GeneExpressionDataset()
gene_dataset.populate_from_data(X)

In [114]:
import torch

n_epochs = 700
use_batches = False

vae = VAE(gene_dataset.nb_genes,
                  n_batch=cas_dataset.n_batches * use_batches,
                  n_hidden=64,
                  n_layers=1,
                  reconstruction_loss='poisson',
                  n_latent=10,
                  ldvae=ldvae
              )

In [115]:
trainer_scvi = UnsupervisedTrainer(model=vae,
                              gene_dataset=gene_dataset,
                              train_size=1.0,
                              use_cuda=use_cuda,
                              frequency=10,
                              n_epochs_kl_warmup=None
                              )

# train scVI
trainer_scvi.train(n_epochs=n_epochs, lr=1e-3, use_cuda=use_cuda,) 
                                        
elbo_train_scvi = trainer_scvi.history["elbo_train_set"]
x = np.linspace(0, 100, (len(elbo_train_scvi)))
plt.plot(np.log(elbo_train_scvi), 
         label="train", color='blue',
         linestyle=':',
         linewidth=3
        )
        
plt.xlabel('Epoch')
plt.ylabel("ELBO")
plt.legend()
plt.title("Train history scVI")
plt.show()
plt.savefig('training_scvi.png')

training: 100%|██████████| 700/700 [01:05<00:00, 10.67it/s]


***Training elbo***

In [116]:
scvi_posterior = trainer_scvi.create_posterior(model=vae,
                                               gene_dataset=gene_dataset 
                                                )

scvi_posterior.elbo()

635.9727080742123

***scVI Baseline 2 (Decoded Average Latent space)***

In [146]:
library_size = np.mean(np.sum(X, axis=1))
scvi_latent = np.array([scvi_posterior.get_latent(give_mean=False)[0] for i in range(10)])

imputed_scvi_2, imputed_scvi_2_z = scvi_baseline_z(collapsed_tree,
                                        posterior=scvi_posterior,
                                        model=vae,
                                        weighted=False,
                                        n_samples_z=1,
                                        library_size=library_size,
                                        use_cuda=True)


# 5. Likelihood Ratio

In [118]:
cascvi_latent = full_posterior.get_latent()
scvi_latent = scvi_posterior.get_latent()[0]
scvi_full_latent = full_batch_posterior.get_latent()

scvi_latent.shape, cascvi_latent.shape

((603, 10), (603, 10))

In [119]:
treevae.initialize_visit()
treevae.initialize_messages(scvi_latent, cas_dataset.barcodes, scvi_latent.shape[1])
treevae.perform_message_passing((treevae.tree & treevae.root), scvi_latent.shape[1], False)
mp_lik_scvi = treevae.aggregate_messages_into_leaves_likelihood(10, add_prior=True)
print("Likelihood of scVI encodings: ", mp_lik_scvi.item())

Likelihood of scVI encodings:  -17542.425066081545


In [120]:
treevae.initialize_visit()
treevae.initialize_messages(cascvi_latent, cas_dataset.barcodes, cascvi_latent.shape[1])
treevae.perform_message_passing((treevae.tree & treevae.root), cascvi_latent.shape[1], False)
mp_lik_cascvi = treevae.aggregate_messages_into_leaves_likelihood(10, add_prior=True)
print("Likelihood of cascVI encodings: ", mp_lik_cascvi.item())

Likelihood of cascVI encodings:  -4268.115672534094


In [121]:
treevae_full.initialize_visit()
treevae_full.initialize_messages(scvi_full_latent, cas_dataset.barcodes, cascvi_latent.shape[1])
treevae_full.perform_message_passing((treevae_full.tree & treevae_full.root), cascvi_latent.shape[1], False)
mp_lik_full = treevae_full.aggregate_messages_into_leaves_likelihood(10, add_prior=True)
print("Likelihood of cascVI encodings: ", mp_lik_full.item())

Likelihood of cascVI encodings:  -16072.930023850036


# 6. Output files

***Uncertainty***

In [129]:
qz_v = full_posterior.empirical_qz_v(n_samples=100,
                                    norm=True
                                    )

In [164]:
output = []
columns = ['node ID', 'variance', 'depth']
idx = 0
for n in collapsed_tree.traverse('levelorder'):
    depth = n.get_distance(collapsed_tree)
    if not n.is_leaf():
        output.append([n.name, np.linalg.norm(imputed_mcmc_cov[n.name]), depth])
    else:
        output.append([n.name, qz_v[idx], depth])
        idx += 1

df_uncertainty = pd.DataFrame(data=output, columns=columns)
df_uncertainty.head(50)

Unnamed: 0,node ID,variance,depth
0,0,0.205366,0.0
1,1,0.103199,12.0
2,2,0.205366,6.0
3,3,0.051055,6.0
4,4,0.187519,1.0
5,5,0.086763,13.0
6,RE.CTAGAGTCAGTAAGAT-1,1.2e-05,14.0
7,7,0.201125,11.0
8,8,0.233221,7.0
9,9,0.06412,10.0


In [165]:
df_uncertainty.to_csv('uncertainty_corrected.csv')

***Imputations***

In [143]:
n_nodes = len([n for n in collapsed_tree.traverse()])
n_nodes, n_genes

(916, 100)

In [155]:
imputed_X = np.zeros((n_nodes, n_genes))
imputed_avg = np.zeros((n_nodes, n_genes))

idx = 0
for i, n in enumerate(collapsed_tree.traverse('levelorder')):
    if not n.is_leaf():
        imputed_X[i] = imputed[n.name].astype(int)
        imputed_avg[i] = imputed_scvi_2[n.name].astype(int)
    else:
        imputed_X[i] = X[idx]
        imputed_avg[i] = X[idx]     
        idx += 1

In [158]:
top_100_genes = hotspot_genes['Gene'].values[:100]

df_imputed = pd.DataFrame(data=imputed_X, 
                          index=[n.name for n in collapsed_tree.traverse('levelorder')],
                         columns=top_100_genes
                         )

df_imputed_avg = pd.DataFrame(data=imputed_avg, 
                          index=[n.name for n in collapsed_tree.traverse('levelorder')],
                         columns=top_100_genes
                         )

In [161]:
df_imputed.to_csv('imputations_metastasis_treeVAE.txt')
df_imputed_avg.to_csv('imputations_metastasis_VAE.txt')