# 0. Imports

In [153]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
# NOTE(review): `%matplotlib inline` is issued just above and pyplot is already
# imported, so this call presumably switches the backend away from inline to
# WebAgg — confirm this is intended, otherwise drop it.
matplotlib.use('WebAgg')
import numpy as np
import pandas as pd
import copy

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [154]:
cd ..

/home/eecs


In [155]:
cd ..

/home


***Load the tree with ete3***

In [156]:
from ete3 import Tree

# Newick file with the topology used throughout the notebook
# (path points at the 500cells / high_fitness / topology8 file).
tree_name = "/home/eecs/khalil.ouardini/cas_scvi_topologies/newick_objects/500cells/high_fitness/topology8.nwk"
tree = Tree(tree_name, format=1)

# NOTE(review): eps is not referenced in the visible cells.
eps = 1e-3

# One level-order pass: tag each node with its rank, rename it after that
# rank, and record the length of the branch leading to it (0.0 for the root).
branch_length = {}
for i, n in enumerate(tree.traverse('levelorder')):
    n.add_features(index=i)
    n.name = str(i)
    branch_length[n.name] = 0.0 if n.is_root() else n.dist
# Extra entry stored under a dedicated key alongside the per-node lengths.
branch_length['prior_root'] = 1.0

In [157]:
# Data
from anndata import AnnData
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from external.dataset.tree import TreeDataset, GeneExpressionDataset
from external.dataset.poisson_glm import Poisson_GLM
from external.dataset.anndataset import AnnDatasetFromAnnData

# Models
from models.vae import VAE
import scanpy as sc
from external.inference.tree_inference import TreeTrainer
from inference.inference import UnsupervisedTrainer
from scvi.inference import posterior
from external.models.treevae import TreeVAE

# Utils
from external.utils.data_util import get_leaves, get_internal
from external.utils.metrics import ks_pvalue, accuracy_imputation, correlations, knn_purity, knn_purity_stratified
from external.utils.plots_util import plot_histograms, plot_scatter_mean, plot_ecdf_ks, plot_density
from external.utils.plots_util import plot_losses, plot_elbo, plot_common_ancestor, plot_one_gene, training_dashboard
from external.utils.baselines import avg_weighted_baseline, scvi_baseline, scvi_baseline_z, cascvi_baseline_z, avg_baseline_z, construct_latent

In [158]:
import torch

# Seed NumPy's global RNG and PyTorch's default generator with the same value.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f9fdbeeb250>

# 1. Simulations (Poisson GLM)

In [159]:
# Simulation settings; they are handed positionally to Poisson_GLM below.
d = 10               # presumably the latent dimension (glm.W is reported as (1000, 10)) — TODO confirm
g = 1000             # presumably the number of genes (glm.X is reported with 1000 columns) — TODO confirm
vis = False
leaves_only = False
var = 1.0            # NOTE(review): not referenced anywhere in the visible cells
alpha = 1.0

# Build the simulator on the tree and the branch_length dict built earlier.
glm = Poisson_GLM(tree, g, d, vis, leaves_only, branch_length, alpha)

# Must run before simulate_ge in the next cell — presumably fills glm.z; verify against Poisson_GLM.
glm.simulate_latent()

***Generate gene expression count data***

In [160]:
# Simulate gene expression; negative_binomial=False is passed explicitly.
glm.simulate_ge(negative_binomial=False)
# Quality Control (i.e. gene filtering) — currently disabled.
#glm.gene_qc()

# Cell output: shapes of the simulated arrays.
glm.X.shape, glm.W.shape, glm.beta.shape

((1000, 1000), (1000, 10), (1000,))

***Binomial thinning***

In [161]:
# Fraction of zero entries in the simulated count matrix glm.X.
print("Proportion of dropouts: {}".format(np.mean(glm.X == 0)))
# NOTE(review): the thinning call below is commented out, so the second print
# reports the same value as the first (the recorded output confirms this).
#glm.binomial_thinning(p=0.1)
print("Proportion of dropouts after Binomial thinning: {}".format(np.mean(glm.X == 0)))

Proportion of dropouts: 0.397263
Proportion of dropouts after Binomial thinning: 0.397263


***Get the data and the indexes at the leaves***

In [189]:
# Latent vectors at the leaves; the other two return values are discarded.
leaves_z, _, _ = get_leaves(glm.z, glm.mu, tree)

# Fixed set of leaf counts, with the leaf indices and the matching rows of glm.mu.
leaves_X, leaves_idx, mu = get_leaves(glm.X, glm.mu, tree)

# Internal-node data (for imputation).
internal_X, internal_idx, internal_mu = get_internal(glm.X, glm.mu, tree)

# Additional data for clade sampling: 50 samples requested for the leaf indices.
# These are the arrays stacked into the training dataset in the next section.
n_leaves_X = glm.generate_ge(n_samples=50, leaves_idx=leaves_idx)

# Cell output: shapes of the extracted arrays.
leaves_X.shape, mu.shape, internal_X.shape, internal_mu.shape, leaves_z.shape

((500, 1000), (500, 1000), (500, 1000), (500, 1000), (500, 10))

# 2. Fitting CascVI

In [190]:
# Start from an empty dataset and fill it with the stacked leaf samples.
gene_dataset = GeneExpressionDataset()

# Leaf nodes in level order; their names are used as barcodes.
leaves = [node for node in tree.traverse('levelorder') if node.is_leaf()]
leaf_names = [leaf.name for leaf in leaves]
# Repeat the full list of leaf names once per array in n_leaves_X.
cell_names = np.tile(leaf_names, len(n_leaves_X))

stacked_counts = np.vstack(n_leaves_X)
gene_labels = [str(j) for j in range(glm.X.shape[1])]
gene_dataset.populate_from_data(X=stacked_counts, gene_names=gene_labels)
gene_dataset.initialize_cell_attribute('barcodes', cell_names)

***Create a TreeDataset object***

In [191]:
# treeVAE
import copy

# Deep-copy so the TreeDataset receives its own tree object, separate from `tree`.
tree_bis = copy.deepcopy(tree)
cas_dataset = TreeDataset(gene_dataset, tree=tree_bis, filtering=False)

# No batches because of the message passing
use_cuda = True   # forwarded to every trainer below
use_MP = True     # forwarded to TreeVAE as use_MP
ldvae = False     # forwarded to both TreeVAE and VAE

# Cell output: summary of the dataset.
cas_dataset

GeneExpressionDataset object with n_cells x nb_genes = 25000 x 1000
    gene_attribute_names: 'gene_names'
    cell_attribute_names: 'local_vars', 'barcodes', 'local_means', 'batch_indices', 'labels'
    cell_categorical_attribute_names: 'batch_indices', 'labels'

***Initialize model***

In [192]:
# Tree-structured VAE on the TreeDataset: Poisson reconstruction loss,
# one hidden layer of 128 units, latent size taken from the simulator.
treevae = TreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=glm.latent,
              n_hidden=128,
              n_layers=1,
              reconstruction_loss='poisson',
              prior_t = branch_length,  # dict of branch lengths built when the tree was loaded
              ldvae = ldvae,
              use_MP=use_MP,  # True here (set in the TreeDataset cell)
              use_clades=False
             )

***Hyperparameters***

In [193]:
# NOTE(review): n_epochs is reassigned to 600 in the scVI section further down,
# so re-running cells out of order changes the TreeVAE training length.
n_epochs = 1000
lr = 1e-3
# Passed to TreeTrainer as lambda_ and reused as the weight on the qz / mp_lik
# terms in the manual ELBO cells.
lambda_ = 1.0

***trainer***

In [194]:
freq = 100
# Trainer for the message-passing TreeVAE; all data go to the training split
# (train_size=1.0, test_size=0).
# NOTE(review): the recorded output of this cell is a very long dump of index
# lists, presumably printed during construction — consider silencing it.
trainer = TreeTrainer(
    model = treevae,
    gene_dataset = cas_dataset,
    lambda_ = lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150
)

8951, 20451, 24951, 951, 6451, 18951, 12451, 3451, 11951, 18451, 10951, 9951, 4951, 19951, 23451, 1951, 451, 23951, 22451], [14452, 5952, 5452, 20952, 1452, 13952, 19452, 15952, 11452, 2452, 16952, 17952, 13452, 17452, 9452, 3952, 7452, 22952, 24452, 14952, 7952, 15452, 16452, 8452, 21452, 10452, 21952, 4452, 6952, 12952, 2952, 8952, 20452, 24952, 952, 6452, 18952, 12452, 3452, 11952, 18452, 10952, 9952, 4952, 19952, 23452, 1952, 452, 23952, 22452], [14453, 5953, 5453, 20953, 1453, 13953, 19453, 15953, 11453, 2453, 16953, 17953, 13453, 17453, 9453, 3953, 7453, 22953, 24453, 14953, 7953, 15453, 16453, 8453, 21453, 10453, 21953, 4453, 6953, 12953, 2953, 8953, 20453, 24953, 953, 6453, 18953, 12453, 3453, 11953, 18453, 10953, 9953, 4953, 19953, 23453, 1953, 453, 23953, 22453], [14454, 5954, 5454, 20954, 1454, 13954, 19454, 15954, 11454, 2454, 16954, 17954, 13454, 17454, 9454, 3954, 7454, 22954, 24454, 14954, 7954, 15454, 16454, 8454, 21454, 10454, 21954, 4454, 6954, 12954, 2954, 8954, 2045

***Start training***

In [195]:
# Fit the message-passing TreeVAE with the hyperparameters set above.
trainer.train(n_epochs=n_epochs, lr=lr)

619392161788492
ELBO Loss: 547.8903127655192
training:  85%|████████▌ | 850/1000 [03:06<00:32,  4.58it/s]Encodings MP Likelihood: 3.1238883796369796
ELBO Loss: 548.6046658091453
training:  85%|████████▌ | 851/1000 [03:06<00:32,  4.55it/s]Encodings MP Likelihood: 3.0160012275014925
ELBO Loss: 547.821594901664
training:  85%|████████▌ | 852/1000 [03:07<00:32,  4.54it/s]Encodings MP Likelihood: 3.159468114825604
ELBO Loss: 547.5561929569711
training:  85%|████████▌ | 853/1000 [03:07<00:32,  4.56it/s]Encodings MP Likelihood: 2.9598585675118576
ELBO Loss: 548.5187769716201
training:  85%|████████▌ | 854/1000 [03:07<00:32,  4.53it/s]Encodings MP Likelihood: 2.846599525361267
ELBO Loss: 546.6036250994849
training:  86%|████████▌ | 855/1000 [03:07<00:31,  4.54it/s]Encodings MP Likelihood: 3.190666396941343
ELBO Loss: 546.0382478930685
training:  86%|████████▌ | 856/1000 [03:07<00:31,  4.50it/s]Encodings MP Likelihood: 3.466221648148704
ELBO Loss: 548.1130543513677
training:  86%|████████▌ | 85

***Loss Functions***

In [196]:
#training_dashboard(trainer, treevae.encoder_variance)

In [197]:
from sklearn.metrics import mean_squared_error  # NOTE(review): not used in the visible cells

# Posterior over every cell of cas_dataset (indices 0 .. len-1), built with the
# trainer's own model and clades.
full_posterior = trainer.create_posterior(trainer.model, cas_dataset, trainer.clades,
                                indices=np.arange(len(cas_dataset))
                                         )

# Cell output. NOTE(review): the recorded tensor carries a grad_fn, i.e. it was
# computed with autograd enabled.
full_posterior.compute_elbo(treevae)

computing elbo


tensor(541.6196, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)

***ELBO on held-out data (`leaves_X`)***

In [198]:
# ELBO of the message-passing TreeVAE on leaves_X, averaged over cells.
# The device follows the notebook's `use_cuda` flag (and CUDA availability)
# instead of a hard-coded 'cuda:0', so the cell also runs on CPU-only machines.
device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')

with torch.no_grad():
    # NOTE: `input` shadows the builtin; the name is kept because a later cell
    # (the scVI held-out ELBO) reads this global.
    input = torch.from_numpy(leaves_X).float().to(device=device)
    reconst_loss, qz, mp_lik = treevae.forward(input)
    # Reconstruction term plus the lambda-weighted qz term, minus the
    # lambda-weighted message-passing likelihood.
    elbo_cascvi = torch.sum(reconst_loss) + lambda_ * torch.sum(qz) - lambda_ * mp_lik

n_samples = leaves_X.shape[0]
elbo_cascvi /= n_samples

elbo_cascvi

tensor(2384.3051, device='cuda:0', dtype=torch.float64)

# 3. Fitting scVI - full batch (TreeVAE with `use_MP=False`)

In [199]:
# Same construction as `treevae`, except that message passing is switched off
# (use_MP=False); the training log for this model reports an MP likelihood of 0.0.
treevae_full = TreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=glm.latent,
              n_hidden=128,
              n_layers=1,
              reconstruction_loss='poisson',
              prior_t = branch_length,
              ldvae = ldvae,
              use_MP=False,
              use_clades=False
             )

In [200]:
freq = 100
# Trainer for the no-message-passing model; arguments mirror those of `trainer`.
# NOTE(review): the recorded output is again a very long dump of index lists,
# presumably printed during construction — consider silencing it.
trainer_full = TreeTrainer(
    model = treevae_full,
    gene_dataset = cas_dataset,
    lambda_ = lambda_,
    train_size=1.0,
    test_size=0,
    use_cuda=use_cuda,
    frequency=freq,
    n_epochs_kl_warmup=150
)

8951, 20451, 24951, 951, 6451, 18951, 12451, 3451, 11951, 18451, 10951, 9951, 4951, 19951, 23451, 1951, 451, 23951, 22451], [14452, 5952, 5452, 20952, 1452, 13952, 19452, 15952, 11452, 2452, 16952, 17952, 13452, 17452, 9452, 3952, 7452, 22952, 24452, 14952, 7952, 15452, 16452, 8452, 21452, 10452, 21952, 4452, 6952, 12952, 2952, 8952, 20452, 24952, 952, 6452, 18952, 12452, 3452, 11952, 18452, 10952, 9952, 4952, 19952, 23452, 1952, 452, 23952, 22452], [14453, 5953, 5453, 20953, 1453, 13953, 19453, 15953, 11453, 2453, 16953, 17953, 13453, 17453, 9453, 3953, 7453, 22953, 24453, 14953, 7953, 15453, 16453, 8453, 21453, 10453, 21953, 4453, 6953, 12953, 2953, 8953, 20453, 24953, 953, 6453, 18953, 12453, 3453, 11953, 18453, 10953, 9953, 4953, 19953, 23453, 1953, 453, 23953, 22453], [14454, 5954, 5454, 20954, 1454, 13954, 19454, 15954, 11454, 2454, 16954, 17954, 13454, 17454, 9454, 3954, 7454, 22954, 24454, 14954, 7954, 15454, 16454, 8454, 21454, 10454, 21954, 4454, 6954, 12954, 2954, 8954, 2045

In [201]:
# Fit the no-message-passing model with the same hyperparameters as the tree model.
trainer_full.train(n_epochs=n_epochs, lr=lr)

elihood: 0.0
ELBO Loss: 551.904085081295
training:  71%|███████   | 708/1000 [00:14<00:05, 49.74it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 552.8278068298582
Encodings MP Likelihood: 0.0
ELBO Loss: 551.9224001873963
Encodings MP Likelihood: 0.0
ELBO Loss: 552.670422462832
Encodings MP Likelihood: 0.0
ELBO Loss: 552.3407046941792
Encodings MP Likelihood: 0.0
ELBO Loss: 552.529149061548
Encodings MP Likelihood: 0.0
ELBO Loss: 551.7435128710008
training:  71%|███████▏  | 714/1000 [00:14<00:05, 50.47it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 551.7305169586552
Encodings MP Likelihood: 0.0
ELBO Loss: 552.0041177024121
Encodings MP Likelihood: 0.0
ELBO Loss: 551.2963111416195
Encodings MP Likelihood: 0.0
ELBO Loss: 554.1573420084993
Encodings MP Likelihood: 0.0
ELBO Loss: 551.6719670772698
Encodings MP Likelihood: 0.0
ELBO Loss: 551.9197260178236
training:  72%|███████▏  | 720/1000 [00:14<00:05, 51.97it/s]Encodings MP Likelihood: 0.0
ELBO Loss: 552.7957194130665
Encodings MP Likelihood: 0.

In [202]:
# Build the posterior through trainer_full itself. The original went through
# `trainer` (the message-passing trainer) while passing trainer_full's model
# and clades — a copy-paste slip that mixed the two trainers.
full_batch_posterior = trainer_full.create_posterior(
    trainer_full.model,
    cas_dataset,
    trainer_full.clades,
    indices=np.arange(len(cas_dataset)),
)

# Cell output: ELBO of the no-message-passing model on the full dataset.
full_batch_posterior.compute_elbo(treevae_full)

computing elbo


tensor(547.3132, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)

In [203]:
# ELBO of the no-message-passing model on leaves_X, averaged over cells.
# The device follows the notebook's `use_cuda` flag (and CUDA availability)
# instead of a hard-coded 'cuda:0', so the cell also runs on CPU-only machines.
device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')

with torch.no_grad():
    # NOTE: `input` shadows the builtin; the name is kept because a later cell
    # (the scVI held-out ELBO) reads this global.
    input = torch.from_numpy(leaves_X).float().to(device=device)
    # The third return value (message-passing likelihood) is not used here.
    reconst_loss, qz, _ = treevae_full.forward(input)
    elbo_full = torch.sum(reconst_loss) + lambda_ * torch.sum(qz)

n_samples = leaves_X.shape[0]
elbo_full /= n_samples

elbo_full

tensor(2213.0098, device='cuda:0', dtype=torch.float64)

# 4. Fitting scVI

### Baseline 2: (Un)weighted Average of decoded latent vectors, with scVI

We use the same averaging of the subtree leaves as in **Baseline 1**; only this time, the gene expression data is recovered with scVI.

In [204]:
# anndata
scvi_dataset = GeneExpressionDataset()

scvi_dataset.populate_from_data(X=np.vstack(n_leaves_X),
                               gene_names=[str(i) for i in range(glm.X.shape[1])])
scvi_dataset.initialize_cell_attribute('barcodes', cell_names)

scvi_dataset

GeneExpressionDataset object with n_cells x nb_genes = 25000 x 1000
    gene_attribute_names: 'gene_names'
    cell_attribute_names: 'local_vars', 'barcodes', 'local_means', 'batch_indices', 'labels'
    cell_categorical_attribute_names: 'batch_indices', 'labels'

In [205]:
import torch

# NOTE(review): this overwrites the n_epochs = 1000 used for the TreeVAE runs;
# re-running the earlier training cells after this one trains them for 600 epochs.
n_epochs =600
use_batches = False

# Plain scVI VAE with the same width, depth, loss and latent size as the tree models.
# n_batch evaluates to 0 here because use_batches is False (assuming n_batches is an integer).
vae = VAE(gene_dataset.nb_genes,
                  n_batch=cas_dataset.n_batches * use_batches,
                  n_hidden=128,
                  n_layers=1,
                  reconstruction_loss='poisson',
                  n_latent=glm.latent,
                  ldvae=ldvae
              )

In [206]:
# Trainer for the scVI baseline on gene_dataset; metrics are recorded every
# 10 epochs (frequency=10).
trainer_scvi = UnsupervisedTrainer(model=vae,
                              gene_dataset=gene_dataset,
                              train_size=1.0,
                              use_cuda=use_cuda,
                              frequency=10,
                              n_epochs_kl_warmup=200
                              )

# train scVI
trainer_scvi.train(n_epochs=n_epochs, lr=1e-3)

elbo_train_scvi = trainer_scvi.history["elbo_train_set"]
# The history holds one entry per evaluation, not per epoch, so spread the
# entries evenly over [0, n_epochs] to get an (approximate) epoch axis.
# Previously an unused 0-100 axis was computed and the curve was drawn against
# the raw history index while the axis was labelled "Epoch".
epochs_axis = np.linspace(0, n_epochs, len(elbo_train_scvi))
plt.plot(epochs_axis,
         np.log(elbo_train_scvi),
         label="train", color='blue',
         linestyle=':',
         linewidth=3
        )

plt.xlabel('Epoch')
# The curve is the logarithm of the ELBO, so the label says so.
plt.ylabel("log(ELBO)")
plt.legend()
plt.title("Train history scVI")

training: 100%|██████████| 600/600 [28:16<00:00,  2.83s/it]


Text(0.5, 1.0, 'Train history scVI')

In [207]:
# Posterior object for the trained scVI model, built over scvi_dataset.
scvi_posterior = trainer_scvi.create_posterior(model=vae, gene_dataset=scvi_dataset)

In [212]:
# Cell output: ELBO reported by the scVI posterior built over scvi_dataset.
scvi_posterior.elbo()

545.93428234375

In [216]:
# Fetch only the first minibatch of the posterior. `next(iter(...))` replaces
# the loop-and-break and raises StopIteration on an empty posterior instead of
# silently leaving stale (or undefined) variables behind.
tensors = next(iter(scvi_posterior))
# Only the first three entries are used afterwards; the remaining ones are
# collected under a throwaway name so IPython's `_` is no longer shadowed.
x, local_l_mean, local_l_var, *_unused = tensors

In [227]:
# ELBO of the scVI baseline on leaves_X, averaged over cells.
# The device follows the notebook's `use_cuda` flag (and CUDA availability)
# instead of a hard-coded 'cuda:0'.
device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')

# Build the evaluation tensor first. Previously `input` was read (for its
# length) before this cell assigned it, so the cell only worked when an earlier
# ELBO cell had already left `input` in the kernel namespace.
input = torch.from_numpy(leaves_X).float().to(device=device)
n_samples = leaves_X.shape[0]

# Constant library-size terms: the first entry of the first minibatch's
# local_l_mean / local_l_var, repeated once per evaluated cell.
# NOTE(review): these are 1-D vectors of length n_samples; confirm that
# vae.forward does not expect shape (n_samples, 1), as a mismatch could
# broadcast inside the KL term.
l1 = local_l_mean[0][0].item() * torch.ones(n_samples).to(device)
l2 = local_l_var[0][0].item() * torch.ones(n_samples).to(device)

elbo = 0
with torch.no_grad():
    reconst_loss, kl_divergence, kl_divergence_global = vae.forward(input, l1, l2)
    # kl_divergence_global is not added to the total here.
    elbo += torch.sum(reconst_loss + kl_divergence).item()
elbo /= n_samples

elbo

61485.956

In [219]:
# Cell output: third value returned by vae.forward in the held-out ELBO cell above.
kl_divergence_global

0.0