# Metastasis data analysis

This notebook reproduces the results in Table 3 of the paper "Reconstructing unobserved cellular states from paired single-cell lineage tracing and transcriptomics data".

# 0. Standard imports

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd

*** Go to top level directory ***

In [1]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi/external


In [13]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi


*** Imports ***

In [14]:
# Tree
from ete3 import Tree
import ete3

# Data
from anndata import AnnData
import matplotlib.pyplot as plt
from external.dataset.tree import TreeDataset, GeneExpressionDataset
from external.dataset.anndataset import AnnDatasetFromAnnData

# Models
from external.models.vae import VAE
import scanpy as sc
from external.inference.tree_inference import TreeTrainer
from external.inference.inference import UnsupervisedTrainer
from external.inference import posterior
from external.models.treevae import TreeVAE

# Utils
from external.utils.data_util import get_leaves, get_internal
from external.utils.plots_util import plot_histograms, plot_density
from external.utils.plots_util import plot_one_gene, training_dashboard
from external.utils.baselines import scvi_baseline_z, cascvi_baseline_z, avg_baseline_z, construct_latent
from external.utils.metrics import knn_purity_tree

# 1. Data Loading

### Import tree

Unlike the trees used in the simulations, this tree is not binary; it therefore requires the generalization of the message passing algorithm to multifurcating trees described in the appendix.

The leaves of this tree correspond to a single clone of 603 cells from a recent dataset that traced the lineages of lung cancer tumors as they metastasized throughout a mouse.

In [15]:
tree_name = "data/metastasis/lg7_tree_hybrid_priors.alleleThresh.processed.ultrametric.annotated.tree"

# ete3 Newick format 1: internal nodes carry names as well as branch lengths.
tree = Tree(tree_name, format=1)

# Count every node (internal + leaves) and collect the leaves in level order.
N = sum(1 for _ in tree.traverse())
leaves = [node for node in tree.traverse('levelorder') if node.is_leaf()]
print("The tree contains {} nodes and {} leaves".format(N, len(leaves)))

The tree contains 916 nodes and 603 leaves


*Tree preprocessing*: 

Because it has empirically improved training stability, we normalize the branch lengths of the tree by its radius, i.e. the maximal distance between the root and its leaves (the tree is ultrametric, so all leaves are at the same distance from the root). The diameter of the tree is twice this radius.

In [16]:
# Radius of the tree: the largest root-to-leaf distance. The tree file is
# ultrametric, so every leaf gives the same value; taking the maximum (instead
# of the first leaf only) keeps the normalization correct if it is not exactly so.
radius = max(leaf.get_distance(tree) for leaf in tree.get_leaves())
diameter = 2 * radius
print("The diameter of the tree is {}".format(diameter))


# Map each node name to its incoming branch length, normalized by the radius.
# The root has no incoming branch, so it gets 0.0.
branch_length = {}
for node in tree.traverse('levelorder'):
    branch_length[node.name] = 0.0 if node.is_root() else node.dist / radius
# Extra entry consumed by TreeVAE via prior_t (presumably the branch above the
# root -- TODO confirm against the TreeVAE implementation).
branch_length['prior_root'] = 1.0

The diameter of the tree is 28.0


### Gene expression data

Since a large fraction of the genes detected by scRNA-seq have no strong relationship to the lineage, we only consider the 100 genes most autocorrelated with the phylogeny, as evaluated by *Hotspot* (https://github.com/YosefLab/Hotspot).

We load the preprocessed gene expression matrix.

In [17]:
# Preprocessed counts: one row per cell (tree leaf), one column per selected gene.
X = np.load('data/metastasis/Metastasis_lg7_100g.npy')
print("the matrix contains {} cells and {} genes".format(X.shape[0], X.shape[1]))
# Rows are labelled with the tree's leaf names when building the AnnData object,
# so fail fast here if the matrix and the tree disagree on the number of cells.
assert X.shape[0] == len(leaves), "expected one expression profile per tree leaf"

the matrix contains 603 cells and 100 genes


# 2. Fitting treeVAE

In [18]:
import torch

# Seed numpy and torch so model initialization and training are reproducible.
SEED = 42
np.random.seed(SEED)
# manual_seed returns the torch Generator; the trailing ';' keeps its repr out
# of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f850d50dd10>

***Anndata***

In [19]:
# Wrap the count matrix in AnnData and name each cell after its tree leaf.
# NOTE(review): assumes the rows of X follow the level-order leaf traversal -- TODO confirm.
adata = AnnData(X)
adata.obs_names = [leaf.name for leaf in leaves]

# Gene filter: keep genes with at least 3 total counts.
sc.pp.filter_genes(adata, min_counts=3)
# Cell filter with min_counts=0 keeps every cell (it is called for its
# per-cell count annotation, not to drop cells).
sc.pp.filter_cells(adata, min_counts=0)

***Create a TreeDataset object***

In [20]:
# treeVAE: wrap the filtered AnnData in the project's dataset classes and
# attach the lineage tree.
import copy

# Shape of the unfiltered matrix X (gene filtering above is not reflected here).
n_cells, n_genes = X.shape
# Independent copy of the tree, taken before TreeDataset receives `tree`.
tree_bis = copy.deepcopy(tree)

scvi_dataset = AnnDatasetFromAnnData(adata, filtering=False)
scvi_dataset.initialize_cell_attribute('barcodes', adata.obs_names)
cas_dataset = TreeDataset(scvi_dataset, tree=tree, filtering=False)

# No batches because of the message passing.
use_cuda, use_MP, ldvae = True, True, False

n_counts is a protected attribute or already exists as a cell attribute and cannot be set with this name in initialize_gene_attribute, changing name to n_counts_gene and setting
go


***Initialize model***

In [21]:
# treeVAE: Poisson reconstruction, 10-d latent space, one hidden layer of 64 units.
# The normalized branch lengths are passed as prior_t; use_MP enables message passing.
treevae = TreeVAE(
    cas_dataset.nb_genes,
    tree=cas_dataset.tree,
    n_latent=10,
    n_hidden=64,
    n_layers=1,
    reconstruction_loss='poisson',
    prior_t=branch_length,
    ldvae=ldvae,
    use_MP=use_MP,
)

***Training hyperparameters***

In [22]:
# Optimization settings, reused for both the treeVAE and the scVI runs.
# lambda_ is forwarded to TreeTrainer (presumably the weight of the tree term -- TODO confirm).
n_epochs, lr, lambda_ = 700, 1e-3, 1.0

***Trainer***

In [23]:
# Trainer for treeVAE: every cell is used for training (train_size=1.0, test_size=0);
# `frequency` and the 150-epoch KL warm-up are forwarded to TreeTrainer.
freq = 100
trainer = TreeTrainer(
    model=treevae,
    gene_dataset=cas_dataset,
    lambda_=lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150,
)

train_leaves:  [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63], [64], [65], [66], [67], [68], [69], [70], [71], [72], [73], [74], [75], [76], [77], [78], [79], [80], [81], [82], [83], [84], [85], [86], [87], [88], [89], [90], [91], [92], [93], [94], [95], [96], [97], [98], [99], [100], [101], [102], [103], [104], [105], [106], [107], [108], [109], [110], [111], [112], [113], [114], [115], [116], [117], [118], [119], [120], [121], [122], [123], [124], [125], [126], [127], [128], [129], [130], [131], [132], [133], [134], [135], [136], [137], [138], [139], [140], [141], [142], [143], [144], [145], [146], [147], [148], [149], [150], [151], [152], [153], [154], [155], [1

***Start training***

In [24]:
# Fit treeVAE (700 epochs took a few minutes in the recorded run).
trainer.train(n_epochs=n_epochs, lr=lr)

944880375277
training:  78%|███████▊  | 548/700 [03:30<00:55,  2.75it/s]Encodings MP Likelihood: 9.147760027106251
ELBO Loss: 267.9500091650155
training:  78%|███████▊  | 549/700 [03:31<00:56,  2.69it/s]Encodings MP Likelihood: 9.235904625509509
ELBO Loss: 267.6481670220532
training:  79%|███████▊  | 550/700 [03:31<00:53,  2.78it/s]Encodings MP Likelihood: 9.277956094638759
ELBO Loss: 269.29523804088024
training:  79%|███████▊  | 551/700 [03:31<00:52,  2.84it/s]Encodings MP Likelihood: 9.132675539279857
ELBO Loss: 267.29917594452377
training:  79%|███████▉  | 552/700 [03:32<00:51,  2.89it/s]Encodings MP Likelihood: 9.108848282465495
ELBO Loss: 268.82251500296917
training:  79%|███████▉  | 553/700 [03:32<00:50,  2.92it/s]Encodings MP Likelihood: 9.207169932175836
ELBO Loss: 268.83989428057936
training:  79%|███████▉  | 554/700 [03:32<00:50,  2.91it/s]Encodings MP Likelihood: 9.216706872623845
ELBO Loss: 267.89750792742205
training:  79%|███████▉  | 555/700 [03:33<00:49,  2.92it/s]Encodi

***Training dashboard***

The learning curves are saved to the temporary file '*training_cascvi.png*'.

In [25]:
# Learning curves of the treeVAE run, together with the encoder variance.
training_dashboard(trainer, treevae.encoder_variance)

***Training ELBO***

In [26]:
# Posterior of the trained treeVAE over every cell, used to report the training ELBO.
full_posterior = trainer.create_posterior(
    trainer.model,
    cas_dataset,
    trainer.clades,
    indices=np.arange(len(cas_dataset)),
)

print("treeVAE training ELBO: {}".format(full_posterior.compute_elbo().item()))

treeVAE training ELBO: 262.8700116225728


# 3. Fitting scVI

In [29]:
# scVI baseline: same architecture, likelihood and training schedule as treeVAE,
# but with use_MP=False (the logged "Encodings MP Likelihood" stays at 0.0).
scVI = TreeVAE(
    cas_dataset.nb_genes,
    tree=cas_dataset.tree,
    n_latent=10,
    n_hidden=64,
    n_layers=1,
    reconstruction_loss='poisson',
    prior_t=branch_length,
    ldvae=ldvae,
    use_MP=False,
)

freq = 100
trainer_scvi = TreeTrainer(
    model=scVI,
    gene_dataset=cas_dataset,
    lambda_=lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150,
)

trainer_scvi.train(n_epochs=n_epochs, lr=lr)

946945877661
Encodings MP Likelihood: 0.0
ELBO Loss: 277.70605286315475
Encodings MP Likelihood: 0.0
ELBO Loss: 276.8577173944936
Encodings MP Likelihood: 0.0
ELBO Loss: 276.9423020031589
Encodings MP Likelihood: 0.0
ELBO Loss: 277.3614496715459
Encodings MP Likelihood: 0.0
ELBO Loss: 276.27454413151384
Encodings MP Likelihood: 0.0
ELBO Loss: 276.99514272145177
training:  58%|█████▊    | 404/700 [00:05<00:04, 69.88it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 276.64412200050685
Encodings MP Likelihood: 0.0
ELBO Loss: 276.5200403695335
Encodings MP Likelihood: 0.0
ELBO Loss: 277.14111592118246
Encodings MP Likelihood: 0.0
ELBO Loss: 276.9975486851867
Encodings MP Likelihood: 0.0
ELBO Loss: 275.62292784403047
Encodings MP Likelihood: 0.0
ELBO Loss: 275.9117249377849
Encodings MP Likelihood: 0.0
ELBO Loss: 276.95446030031576
Encodings MP Likelihood: 0.0
ELBO Loss: 275.2379095965368
training:  59%|█████▉    | 412/700 [00:05<00:04, 69.71it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 275.809551

In [30]:
# Posterior of the scVI baseline over every cell. It is built from the scVI
# trainer (trainer_scvi), mirroring how the treeVAE posterior uses its own trainer.
scvi_posterior = trainer_scvi.create_posterior(
    trainer_scvi.model,
    cas_dataset,
    trainer_scvi.clades,
    indices=np.arange(len(cas_dataset)),
)

print("scVI - full - batch train ELBO: {}".format(scvi_posterior.compute_elbo().item()))

scVI - full - batch train ELBO: 266.2982126777756


# 5. Average ELBO

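Since repeated ELBO evaluations differ from one forward pass to the next, we report the mean and standard deviation of the ELBO over 100 forward passes on the training set.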

In [32]:
# compute_elbo() gives a different value on each call, so collect repeated
# estimates for both models. A dedicated name is used instead of reassigning `N`,
# which holds the number of nodes in the tree.
n_passes = 100
elbo_treevae = []
elbo_vae = []
for _ in range(n_passes):
    # Keep the original interleaving (scVI, then treeVAE) so random draws are
    # consumed in the same order.
    elbo_vae.append(scvi_posterior.compute_elbo().item())
    elbo_treevae.append(full_posterior.compute_elbo().item())

treeVAE averaged ELBO:

In [42]:
# Mean and spread of the repeated treeVAE ELBO estimates.
print(f"treeVAE Elbo: mean = {np.mean(elbo_treevae)} | std = {np.std(elbo_treevae)}")

treeVAE Elbo: mean = 262.96005592169513 | std = 0.40667664374527984


scVI averaged ELBO:

In [43]:
# Mean and spread of the repeated scVI ELBO estimates.
print(f"scVI Elbo: mean = {np.mean(elbo_vae)} | std = {np.std(elbo_vae)}")

scVI Elbo: mean = 266.4009421526538 | std = 0.5251558280809274


# 6. Cross-Entropy

In [38]:
# Latent encodings of every cell under each trained model (treeVAE first, then scVI).
treevae_latent, scvi_latent = full_posterior.get_latent(), scvi_posterior.get_latent()

We compute the cross-entropy for scVI's latent space

In [39]:
# Score scVI's encodings under treeVAE's tree model: reset the visit state, load
# the encodings as leaf messages, run message passing from the node
# `treevae.tree & treevae.root`, then aggregate the messages (prior included).
treevae.initialize_visit()
treevae.initialize_messages(scvi_latent, cas_dataset.barcodes, scvi_latent.shape[1])
treevae.perform_message_passing(treevae.tree & treevae.root, scvi_latent.shape[1], False)
# 10 presumably matches n_latent=10 of the models -- TODO confirm.
mp_lik_scvi = treevae.aggregate_messages_into_leaves_likelihood(10, add_prior=True)
print("scVI Cross-Entropy: ", mp_lik_scvi.item())

scVI Cross-Entropy:  -13536.449896980246


We compute the cross-entropy for treeVAE's latent space

In [41]:
# Same scoring as for scVI, applied to treeVAE's own encodings: reset the visit
# state, load the encodings as leaf messages, run message passing from the node
# `treevae.tree & treevae.root`, then aggregate the messages (prior included).
treevae.initialize_visit()
treevae.initialize_messages(treevae_latent, cas_dataset.barcodes, treevae_latent.shape[1])
treevae.perform_message_passing((treevae.tree & treevae.root), treevae_latent.shape[1], False)
# 10 presumably matches n_latent=10 of the models -- TODO confirm.
mp_lik_cascvi = treevae.aggregate_messages_into_leaves_likelihood(10, add_prior=True)
# Label matches the section heading and the scVI counterpart ("scVI Cross-Entropy").
print("treeVAE Cross-Entropy: ", mp_lik_cascvi.item())

Likelihood of cascVI encodings:  -4271.685849988737
