# 0. Standard imports

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import copy
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('WebAgg')
import numpy as np
import pandas as pd

In [2]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi/external


In [3]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi


In [4]:
%reload_ext autoreload
%matplotlib inline

***import ete3 Tree***

In [10]:
from ete3 import Tree

tree_name = "/home/eecs/khalil.ouardini/cas_scvi_topologies/newick_objects/100cells/no_fitness/topology7.nwk"
tree = Tree(tree_name, 1)

#tree = Tree()
#tree.populate(30)

leaves = tree.get_leaves()

for i, n in enumerate(tree.traverse('levelorder')):
    n.add_features(index=i)
    if not n.is_leaf():
        n.name = str(i)

In [11]:
# Data
from anndata import AnnData
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from external.dataset.tree import TreeDataset, GeneExpressionDataset
from external.dataset.ppca import PPCA
from external.dataset.anndataset import AnnDatasetFromAnnData

# Models
import scanpy as sc
from external.inference.gaussian_inference import GaussianTrainer
from external.inference.gaussian_tree_inference import GaussianTreeTrainer
from external.inference.gaussian_tree_inference import GaussianTreePosterior
from inference import posterior
from external.models.treevae import TreeVAE
from external.models.gaussian_vae import GaussianVAE
from external.models.gaussian_treevae import GaussianTreeVAE

# Utils
from external.utils.data_util import get_leaves, get_internal
from external.utils.metrics import ks_pvalue, accuracy_imputation, correlations, mse, knn_purity, knn_purity_stratified
from external.utils.plots_util import plot_histograms, plot_scatter_mean, plot_ecdf_ks, plot_density, plot_embedding
from external.utils.plots_util import plot_losses, plot_elbo, plot_common_ancestor, plot_one_gene, training_dashboard
from external.utils.baselines import avg_weighted_baseline, scvi_baseline, scvi_baseline_z, cascvi_baseline_z, avg_baseline_z

# 1. Simulations (Gaussian Likelihood model)

We assume that the latent variables $z \in \mathbb{R}^{N \times D}$ are gaussian (correlated). A phylogenetic tree $\tau$ (with $N$ nodes) encodes the covariance $\Sigma$ of $z$. 

$$\mathbf{z}=(z_1, ..., z_N) \sim \mathcal{N}(0, \Sigma)$$

$z$ is partitionned into two groups:

- the leaves $\mathcal{L} = {1, ..., L}$
- the internal nodes $\mathcal{I} = {L + 1, ..., N}$

***

***We describe the generative model***:

Consider a dataset of $ X={x_n}_{n=1}^{L} $ (also partitioned such that $1, ..., N = \mathcal{L} \bigcup \mathcal{I}$) such that $x_n \in \mathbb{R}^{P}$. We aim to represent each $x_n$ under a latent variable $z_n \in \mathbb{R}^{D}$ with  with $D << P$ lower dimension. 
We only observe data at the leaves. the generative model is defined $\forall n \in \mathcal{L}$

The set of principal axes $W$ relates the latent variables to the data.

The corresponding data point is generated via a projection:

$$
\forall n \in \mathcal{L}, x_n =  W z_n + e_n
$$

with $W \in \mathbb{R}^{P x D}$ and $e_n \sim \mathcal{N}(0, \sigma^2 I_P)$. Thus:


$$
\forall n \in \mathcal{L},  x_n | z_n \sim \mathcal{N}(W z_n, \sigma^2 I_P)
$$

After marginalization

$$
\forall n \in \mathcal{L}, x_n \sim \mathcal{N}(0, W^T W + \sigma^2 I_P)
$$

The posterior $p(z_n|x_n)$ for each $n$ is also ***tractable***, indeed

$\begin{pmatrix} x_n \\ z_n \end{pmatrix} = \begin{pmatrix} W z_n + e \\ z_n \end{pmatrix}$ is a gaussian vector (because for $a \in \mathbb{R}$, $b \in \mathbb{R}$, $a(W z_n + e) + bz_n$ is still gaussian) such that:

$$
\begin{pmatrix} x_n \\ z_n \end{pmatrix} \sim \mathcal{N}(\begin{pmatrix} 0 \\ 0 \end{pmatrix}, \begin{pmatrix} W^T W + \sigma^2 I_P  & W\Sigma_n \\ (W\Sigma_n)^T & \Sigma_n \end{pmatrix})
$$

where $\Sigma_n$ is the marginalized covariance $\Sigma$ of $z_n$

Therefore we can use the conditioning formula to infer the mean and the covariance of the (gaussian) posterior $p(z_n|x_n)$:

$$
\mu_{z_n|x_n} = (W\Sigma_{n}(W^{T} W + \sigma^{2} I_{P})^{-1}\Sigma_{n}^{T}W^{T}) x_{n} \\
\Sigma_{z_n|x_n} = \Sigma_n - W\Sigma_{n}(W^{T} W + \sigma^{2} I_{P})^{-1}\Sigma_{n}^{T}W^{T}
$$

***

***Imputation at internal nodes***

Let $j \in \mathcal{I}$, and $X_{\mathcal{L}} = {x_1, ... x_L}$ the set of leaves.
We want to infer $p(x_j|X_{\mathcal{L}})$. If we consider that the data at the internal nodes is "seen" and that the generative model is also known $\forall n \in \mathcal{I}$, we could easily (and accurately) compute $p(x_j|X_{\mathcal{L}})$ by using the gaussian conditioning formula on the gaussian vector:

$$
\begin{pmatrix} x_j \\ X_{\mathcal{L}} \end{pmatrix}
$$

In the case of unseen data at the internal nodes, one can estimate the posterior predictive density:

1. $$
p(x_j|X_{\mathcal{L}}) = p(x_j|x_1, ..., x_L) = \int p(x_j|z_j)p(z_j|z_1,...,z_L)\prod_{i=1}^{L}p(z_i|x_i)(dz_j,dz_1,...,dz_L)
$$

Therefore:
$$
p(x_j|x_1, ..., x_L) \approx  p(x_j|z_j)p(z_j|z_1,...,z_L)\prod_{i=1}^{L}p(z_i|x_i)
$$

$$
p(x_j|x_1, ..., x_L) \approx  \mathcal{N}(x_j|Wz_j, \sigma^2I_P)  \mathcal{N}(z_j|\mu_{j|\mathcal{I}}, \Sigma_{j|\mathcal{I}}) \prod_{i=1}^{L} \mathcal{N}(z_i|\mu_{z_i|x_i}, \Sigma_{z_i|x_i})
$$

2. $ p(x_j|X_{\mathcal{L}}) = Wp(z_j|X_{\mathcal{L}}) + p(e_j)$

In [12]:
print(tree)


                  /-c214
               /-|
              |  |   /-c218
              |   \-|
              |      \-c219
              |
            /-|   /-c56
           |  |  |
           |  |  |            /-c178
           |  |  |         /-|
           |  |  |        |  |   /-c186
           |  |  |      /-|   \-|
           |   \-|     |  |      \-c187
           |     |   /-|  |
           |     |  |  |   \-c161
           |     |  |  |
           |     |  |   \-c141
           |     |  |
           |      \-|      /-c146
           |        |   /-|
           |        |  |   \-c147
           |        |  |
         /-|        |  |      /-c172
        |  |         \-|   /-|
        |  |           |  |  |   /-c190
        |  |           |  |   \-|
        |  |           |  |     |   /-c208
        |  |            \-|      \-|
        |  |              |         \-c209
        |  |              |
        |  |              |   /-c210
        |  |               \-|
        |  |  

***Branch Length***

In [13]:
eps = 1e-3
branch_length = {}
for node in tree.traverse('levelorder'):
    if node.name == '0':
        branch_length[node.name] = 0.1
        continue
    branch_length[node.name] = node.dist
branch_length['prior_root'] = 1.0


In [14]:
import torch
    
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7ff3544a9dd0>

In [15]:
d = 5
p = 100
vis = True
leaves_only = False
var = 1.0
sigma_scale = 2.0

#ppca = PPCA(tree, p, d, vis, leaves_only, var, sigma_scale)
ppca = PPCA(tree=tree, 
            dim=p, 
            latent=d, 
            vis=vis, 
            only=leaves_only,
            branch_length=branch_length, 
            sigma_scale=sigma_scale
            )

ppca.simulate_latent()

***Marginalization***

In [16]:
ppca.simulate_normal()
ppca.W.shape

(100, 5)

In [17]:
lik_tree = ppca.likelihood_obs(leaves_only=False)
lik_leaves = ppca.likelihood_obs(leaves_only=True)

print("Log-Likelihood of the tree {}".format(lik_tree))
print("LogLikelihood of the leaves {}".format(lik_leaves))

Log-Likelihood of the tree -37571.643014424335
LogLikelihood of the leaves -18761.67733934441


***Get data***

In [18]:
# Latent vectors
leaves_z, _, _ = get_leaves(ppca.z, ppca.mu, tree)

#FIXED training set
leaves_X, leaves_idx, mu = get_leaves(ppca.X, ppca.mu, tree)

# internal nodes data (for imputation)
internal_X, internal_idx, internal_mu = get_internal(ppca.X, ppca.mu, tree)

# internal nodes z
internal_z, _, _ = get_internal(ppca.z, ppca.mu, tree)

leaves_X.shape, mu.shape, internal_X.shape, internal_mu.shape, leaves_z.shape

((100, 100), (100, 100), (100, 100), (100, 100), (100, 5))

***Posterior Distributions***

***evidence***

In [19]:
evidence_leaves = ppca.get_evidence_leaves_levelorder(X=ppca.X, dim=ppca.dim)
evidence_leaves.shape

(10000,)

***Leaves covariance***

In [20]:
import time

t = time.time()
ppca.compute_leaves_covariance()

print('Data covariance computation + inversion took {}'.format(time.time() - t))

Data covariance computation + inversion took 73.86596417427063


***Posterior mean and covariance***

In [21]:
posterior_mean, posterior_cov = ppca.compute_posterior()

***Posterior predictive density***

In [22]:
#predictive_mean, predictive_cov = ppca.compute_posterior_predictive()

# Preliminary: Baselines

## Baseline 1: Unweighted Average of gene expression in Clade

The simple idea here is to impute the value of an internal node, with the (un)weighted average of the gene expression values of the leaves, taking the query internal node as the root of the subtree.

In [23]:
imputed_avg = avg_weighted_baseline(tree=tree, 
                                    weighted=False, 
                                    X=ppca.X,
                                    rounding=False
                                   )

#get internal nodes
avg_X = np.array([x for x in imputed_avg.values()]).reshape(-1, ppca.X.shape[1])
internal_avg_X, _, _ = get_internal(avg_X, ppca.mu, tree)

## Baseline 2: (groundtruth) posterior predictive density

In [24]:
#imputed_ppca = {}
#for n in tree.traverse('levelorder'):
#    if not n.is_leaf():
#        samples = np.array([np.random.multivariate_normal(mean=predictive_mean[n.name],
#                                                            cov=predictive_cov[n.name])
#                           for i in range(20)])
#        imputed_ppca[n.name] = np.mean(samples, axis=0)

#internal_ppca_X = np.array([x for x in imputed_ppca.values()]).reshape(-1, ppca.X.shape[1])

## Baseline 3: Approximation through Message Passing (Oracle)


i.e, 

1. sample from $z_1, ..., z_n \sim p(z_1, ..., z_n|x_1, ..., x_n)$ (conditionning formula)
2. impute $z_i \sim p(z_i | z_1, ..., z_n)$ (Message Passing)
3. Decode $p(x_i|z_i) = W z_i + \sigma^2 I_P$ (Generative model)

In [25]:
posterior_mean_corr, posterior_cov_corr = ppca.compute_correlated_posterior()

In [36]:
posterior_cov_corr.shape

(500, 500)

In [26]:
imputed_mp, imputed_z_mp, predictive_mean_z, predictive_cov_z  = ppca.compute_approx_posterior_predictive(iid=False, use_MP=True, sample_size=200)

  cov=posterior_cov).reshape(-1, self.latent) for i in range(sample_size)])
go
[2021-05-09 23:25:17,802] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:25:17,804] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-09 23:25:17,817] INFO - scvi.dataset.dataset | Merging datasets. Input objects are modified in place.
[2021-05-09 23:25:17,819] INFO - scvi.dataset.dataset | Gene names and cell measurement names are assumed to have a non-null intersection between datasets.
[2021-05-09 23:25:17,821] INFO - scvi.dataset.dataset | Keeping 100 genes
[2021-05-09 23:25:17,825] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2021-05-09 23:25:17,829] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:25:17,831] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-09 23:25:17,835] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:25:17,837] INFO - scvi.dataset.dat

In [27]:
imputed_X = np.array([x for x in imputed_mp.values()]).reshape(-1, ppca.X.shape[1])

## Baseline 4: Approximation through Message Passing + iid posteriors

i.e, 

1. sample from marginal conditional $z_l \sim p(z_l|x_1) \forall l \in (1, ...,L)$ (conditionning formula)
2. impute $z_i \sim p(z_i | z_1, ..., z_n)$ (Message Passing)
3. Decode $p(x_i|z_i) = W z_i + \sigma^2 I_P$ (Generative model)

In [72]:
imputed_mp2, imputed_z_mp2, predictive_mean_z2, predictive_cov_z2  = ppca.compute_approx_posterior_predictive(iid=True, use_MP=True, sample_size=200)
imputed_X2 = np.array([x for x in imputed_mp2.values()]).reshape(-1, ppca.X.shape[1])

  cov=posterior_cov[k]) for i in range(sample_size)])
go
[2021-05-09 23:32:16,736] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:32:16,737] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-09 23:32:16,743] INFO - scvi.dataset.dataset | Merging datasets. Input objects are modified in place.
[2021-05-09 23:32:16,744] INFO - scvi.dataset.dataset | Gene names and cell measurement names are assumed to have a non-null intersection between datasets.
[2021-05-09 23:32:16,745] INFO - scvi.dataset.dataset | Keeping 100 genes
[2021-05-09 23:32:16,747] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2021-05-09 23:32:16,751] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:32:16,752] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-09 23:32:16,755] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:32:16,757] INFO - scvi.dataset.dataset | Remapping label

## Baseline 5: Gaussian VAE decoded averaged latent space

In [29]:
# anndata
gene_dataset = GeneExpressionDataset()
gene_dataset.populate_from_data(leaves_X)

[2021-05-09 23:25:34,315] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:25:34,316] INFO - scvi.dataset.dataset | Remapping labels to [0,N]


In [30]:
n_epochs = 400

vae = GaussianVAE(gene_dataset.nb_genes,
                  n_hidden=64,
                  n_layers=1,
                  n_latent=ppca.latent,
                  sigma_ldvae=None
              )

#new_weight = torch.from_numpy(ppca.W).float()

#with torch.no_grad():
    #vae.decoder.factor_regressor.fc_layers[0][0].weight = torch.nn.Parameter(new_weight)
    
#for param in vae.decoder.factor_regressor.fc_layers[0][0].parameters():
#    param.requires_grad = False
    
#vae.decoder

In [31]:
cuda_z = torch.from_numpy(leaves_z).float().to('cuda:0')
p_m, p_v = vae.decoder.forward( torch.from_numpy(leaves_z).float())
p_m = p_m.detach().cpu().numpy()

from sklearn.metrics import mean_squared_error

mse = mean_squared_error(p_m, mu)
print("the distance is {}".format(mse))

the distance is 8.438525321796481


In [32]:
use_cuda = True

trainer = GaussianTrainer(model=vae,
                          gene_dataset=gene_dataset,
                          train_size=1.0,
                          use_cuda=use_cuda,
                          frequency=10,
                          n_epochs_kl_warmup=None,
                         )

In [33]:
# train VAE
trainer.train(n_epochs=n_epochs, lr=1e-2) 

computing elbo
ELBO: 63088.1171875
computing elbo
ELBO: 62904.734375
computing elbo
ELBO: 63180.953125
training:   2%|▏         | 9/400 [00:00<00:04, 88.79it/s]computing elbo
ELBO: 43775.8671875
computing elbo
ELBO: 43905.18359375
computing elbo
ELBO: 43930.2578125
training:   4%|▍         | 18/400 [00:00<00:04, 78.64it/s]computing elbo
ELBO: 35067.76953125
computing elbo
ELBO: 35050.078125
computing elbo
ELBO: 35004.34375
training:   6%|▋         | 26/400 [00:00<00:05, 68.83it/s]computing elbo
ELBO: 29081.15234375
computing elbo
ELBO: 29130.7734375
computing elbo
ELBO: 29184.95703125
training:   8%|▊         | 34/400 [00:00<00:05, 71.59it/s]computing elbo
ELBO: 27214.595703125
computing elbo
ELBO: 27207.9921875
computing elbo
ELBO: 27186.078125
training:  10%|█         | 42/400 [00:00<00:05, 67.30it/s]computing elbo
ELBO: 23923.00390625
computing elbo
ELBO: 23949.57421875
computing elbo
ELBO: 23939.91015625
training:  12%|█▎        | 50/400 [00:00<00:05, 69.75it/s]computing elbo
ELBO:

In [34]:
elbo_train = trainer.history["elbo_train_set"]
x = np.linspace(0, 100, (len(elbo_train)))
plt.plot(np.log(elbo_train), 
         label="train", color='blue',
         linestyle=':',
         linewidth=3
        )
        
plt.xlabel('Epoch')
plt.ylabel("ELBO")
plt.legend()
plt.title("Train history Gaussian VAE")
plt.show()

In [37]:
from sklearn.metrics import mean_squared_error

posterior =  trainer.create_posterior(model=vae,
                                      gene_dataset=gene_dataset
                                      )
                                      
qz_m, qz_v = posterior.get_latent(give_mean=True, give_cov=True)
mean_squared_error(qz_m, leaves_z)

3.105403640113538

In [38]:
imputed_avg_vae, imputed_avg_z, imputed_avg_cov_z = avg_baseline_z(tree=tree,
                                 model=vae,
                                 posterior=posterior,
                                 weighted=False,
                                 n_samples_z=1,
                                 gaussian=True,
                                 use_cuda=True
                                )

internal_vae_X = np.array([x for x in imputed_avg_vae.values()]).reshape(-1, ppca.X.shape[1])
internal_vae_X.shape

(100, 100)

# 3. Our Model: CascVI

In [39]:
import scanpy as sc

adata = AnnData(leaves_X)
adata.obs_names = [n.name for n in tree.traverse('levelorder') if n.is_leaf()]
scvi_dataset = AnnDatasetFromAnnData(adata, filtering=False)
scvi_dataset.initialize_cell_attribute('barcodes', adata.obs_names)

#TreeDataset
cas_dataset = TreeDataset(scvi_dataset, tree=tree, filtering=False)
cas_dataset

go
[2021-05-09 23:26:35,715] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:26:35,717] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-09 23:26:35,734] INFO - scvi.dataset.dataset | Merging datasets. Input objects are modified in place.
[2021-05-09 23:26:35,735] INFO - scvi.dataset.dataset | Gene names and cell measurement names are assumed to have a non-null intersection between datasets.
[2021-05-09 23:26:35,737] INFO - scvi.dataset.dataset | Keeping 100 genes
[2021-05-09 23:26:35,740] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2021-05-09 23:26:35,745] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:26:35,747] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-09 23:26:35,750] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-09 23:26:35,751] INFO - scvi.dataset.dataset | Remapping labels to [0,N]


GeneExpressionDataset object with n_cells x nb_genes = 100 x 100
    gene_attribute_names: 'gene_names'
    cell_attribute_names: 'labels', 'local_means', 'batch_indices', 'barcodes', 'local_vars'
    cell_categorical_attribute_names: 'batch_indices', 'labels'

In [40]:
use_cuda = True
use_MP = True

treevae = GaussianTreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=ppca.latent,
              n_hidden=64,
              n_layers=1,
              prior_t = branch_length,
              use_MP=use_MP,
              sigma_ldvae=None
             )

***Freezing the decoder***

In [41]:
#new_weight = torch.from_numpy(ppca.W).float()

#with torch.no_grad():
    #treevae.decoder.factor_regressor.fc_layers[0][0].weight = torch.nn.Parameter(new_weight)
    
#for param in treevae.decoder.factor_regressor.fc_layers[0][0].parameters():
    #param.requires_grad = False
    
#treevae.decoder

In [42]:
#assert(treevae.decoder.factor_regressor.fc_layers[0][0].weight.numpy().all() == ppca.W.T.all())


***Are we able to generate the gene expression data by decoding the simulated latent space?***

In [43]:
#p_m, p_v = treevae.decoder.forward(torch.from_numpy(leaves_z).float())
#p_m = p_m.detach().numpy()
#p_m.shape, mu.shape

In [44]:
#mse = mean_squared_error(p_m, mu)
#print("the distance is {}".format(mse))

***Training***

In [45]:
n_epochs = 500
lr = 1e-2
lambda_ = 1.0
freq = 10

tree_trainer = GaussianTreeTrainer(
        model = treevae,
        gene_dataset = cas_dataset,
        lambda_ = lambda_,
        train_size=1.0,
        test_size=0,
        use_cuda=use_cuda,
        frequency=freq,
        n_epochs_kl_warmup=None
    )

train_leaves:  [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63], [64], [65], [66], [67], [68], [69], [70], [71], [72], [73], [74], [75], [76], [77], [78], [79], [80], [81], [82], [83], [84], [85], [86], [87], [88], [89], [90], [91], [92], [93], [94], [95], [96], [97], [98], [99]]
test_leaves:  []
validation leaves:  []


In [46]:
tree_trainer.train(n_epochs=n_epochs,
              lr=lr)

computing elbo
training:   1%|▏         | 7/500 [00:00<00:26, 18.30it/s]computing elbo
training:   4%|▍         | 19/500 [00:00<00:21, 22.65it/s]computing elbo
training:   6%|▌         | 28/500 [00:01<00:19, 23.86it/s]computing elbo
training:   7%|▋         | 37/500 [00:01<00:21, 21.93it/s]computing elbo
training:  10%|▉         | 49/500 [00:02<00:18, 24.30it/s]computing elbo
training:  12%|█▏        | 58/500 [00:02<00:20, 21.32it/s]computing elbo
training:  13%|█▎        | 67/500 [00:03<00:18, 22.98it/s]computing elbo
training:  16%|█▌        | 79/500 [00:03<00:17, 24.35it/s]computing elbo
training:  18%|█▊        | 88/500 [00:03<00:18, 22.44it/s]computing elbo
training:  19%|█▉        | 97/500 [00:04<00:17, 22.79it/s]computing elbo
training:  22%|██▏       | 109/500 [00:04<00:16, 23.29it/s]computing elbo
training:  24%|██▎       | 118/500 [00:05<00:16, 23.61it/s]computing elbo
training:  25%|██▌       | 127/500 [00:05<00:16, 22.36it/s]computing elbo
training:  28%|██▊       | 139/500

In [47]:
#training_dashboard(tree_trainer, treevae.encoder_variance)

In [91]:
tree_posterior = tree_trainer.create_posterior(model=treevae,
                                              gene_dataset=cas_dataset,
                                               clades=tree_trainer.clades,
                                               indices=np.arange(len(cas_dataset))
                                              )
tree_latent = tree_posterior.get_latent()
tree_latent.shape, internal_z.shape

((100, 5), (100, 5))

In [92]:
tree_latent = tree_posterior.get_latent()
mean_squared_error(tree_latent, leaves_z)

2.2012335955775315

***Missing Value Imputation***

In [93]:
# CascVI imputations
imputed = {}
imputed_z = {}
imputed_cov_z = {}
imputed_mean_z = {}

for n in tree.traverse('levelorder'):
    if not n.is_leaf():
        imputed[n.name], imputed_z[n.name], imputed_mean_z[n.name], imputed_cov_z[n.name] = tree_posterior.imputation_internal(query_node=n.name,
                                                                                                                                pp_averaging=200,
                                                                                                                                z_averaging=100,
                                                                                                                                give_mean=True                           
                                                                                                                            )

In [51]:
internal_treevae_X = [x for x in imputed.values()]
internal_treevae_X = np.array(internal_treevae_X).reshape(-1, cas_dataset.X.shape[1])

# Evaluation

## Evaluation 1.a.i: MSE/MAE L2/L1 error

In [52]:
from external.utils.metrics import mse

data = {'groundtruth': imputed_X, 'average': internal_avg_X,
        'gaussian VAE': internal_vae_X
        , 'gaussian treeVAE': internal_treevae_X
       }

In [53]:
results = mse(data=data, metric='MSE')
print('L2')
results

L2


Unnamed: 0,average,gaussian VAE,gaussian treeVAE
MSE,0.784928,2.205904,0.530174
std,0.389073,0.662961,0.208879


In [54]:
results = mse(data=data, metric='L1')
print('L1')
results

L1


Unnamed: 0,average,gaussian VAE,gaussian treeVAE
MSE,39.3426,115.3276,65.32607
std,7.105427e-15,1.421085e-14,1.421085e-14


## Evaluation 1.a.ii: Correlations 

In [55]:
data = {'groundtruth': imputed_X, 'average': internal_avg_X,
        'gaussian VAE': internal_vae_X
        , 'gaussian treeVAE': internal_treevae_X
       }

df1 = correlations(data, 'None', True)

In [56]:
data_dict = {}
methods = list(data.keys())[1:]
for method in methods:
    data_dict[method] = list(df1[df1.Method==method].mean())
results_corr = pd.DataFrame.from_dict(data_dict, orient='index', columns=['Spearman CC', 'Pearson CC', 'Kendal Tau CC'])

print('gene-gene correlation')
results_corr.head(10)

gene-gene correlation


Unnamed: 0,Spearman CC,Pearson CC,Kendal Tau CC
average,0.845984,0.871949,0.678044
gaussian VAE,0.748121,0.76538,0.569386
gaussian treeVAE,0.887167,0.914792,0.734691


## Evaluation 1.a.iii: MSE/MAE (L1/L2) of variance in latent space

In [168]:
data_var = {'groundtruth': predictive_cov_z, 'gaussian VAE':imputed_avg_cov_z, 'gaussian treeVAE': imputed_cov_z}

In [181]:
predictive_cov_z['0'].shape[0]

5

In [175]:
def mean_variance_latent(tree, predictive_cov_z, imputed_avg_cov_z, imputed_cov_z):
    mse_treevae = 0
    mse_vae = 0
    N = 0
    for n in tree.traverse('levelorder'):
        if not n.is_leaf():
            true_cov = np.diag(predictive_cov_z[n.name])
            vae_cov = imputed_avg_cov_z[n.name].cpu().numpy()
            treevae_cov = imputed_cov_z[n.name] * np.ones((d))

            mse_treevae += mean_squared_error(true_cov, treevae_cov)
            mse_vae += mean_squared_error(true_cov, vae_cov)
            N += 1
    mse_treevae /= N
    mse_vae /= N

    return mse_treevae, mse_vae

mean_variance_latent(tree, data_var, predictive_cov_z, imputed_avg_cov_z, imputed_cov_z)

(7.294980110242236e-05, 0.012943967976536946)

## Evaluation 1.a.iv: Averaged KL divergence (internal nodes)

***TreeVAE***

In [111]:
from torch.distributions import Normal, kl_divergence
from sklearn.preprocessing import normalize

kl_mean = 0
N = 0
for n in tree.traverse('levelorder'):
    if not n.is_leaf():
        mean_true = normalize(np.array([predictive_mean_z[n.name].cpu().numpy()]))
        cov_true = torch.diagonal(torch.from_numpy(predictive_cov_z[n.name]))
        dist_true = Normal(torch.from_numpy(mean_true),
                        cov_true
                        )

        # Approx
        mean_approx = normalize(np.array([imputed_mean_z[n.name].cpu().numpy()]))                
        dist_approx = Normal(torch.from_numpy(mean_approx),
                        imputed_cov_z[n.name]* torch.ones((d,))
                        )
        kl_mean += kl_divergence(dist_true, dist_approx).sum()
        N += 1
kl_mean /= N
print('Average Kl divergence {}'.format(kl_mean))

Average Kl divergence 22917.077406557542


***VAE***

In [87]:
kl_mean = 0
N = 0
for n in tree.traverse('levelorder'):
    if not n.is_leaf():
        cov_true = torch.diagonal(torch.from_numpy(predictive_cov_z[n.name]))
        dist_true = Normal(predictive_mean_z[n.name],
                        cov_true
                        )
                        
        dist_approx = Normal(torch.from_numpy(imputed_avg_z[n.name]),
                        torch.sqrt(imputed_avg_cov_z[n.name].cpu())
                        )
        kl_mean += torch.mean(kl_divergence(dist_true, dist_approx))
        N += 1
kl_mean /= N
print('Average Kl divergence {}'.format(kl_mean))

Average Kl divergence 88.56559050523421


## Evaluation 1.a.v: Likelihood (internal nodes)

In [179]:
from scipy.stats import multivariate_normal

def mean_posterior_lik(tree, predictive_mean_z, imputed_avg_z, imputed_mean_z, predictive_cov_z, imputed_avg_cov_z, imputed_cov_z):
    treevae_lik = 0
    vae_lik = 0
    N = 0
    for n in tree.traverse('levelorder'):
        if not n.is_leaf():
            # mean
            true_mean = predictive_mean_z[n.name].cpu().numpy() 
            vae_mean = imputed_avg_z[n.name][0]
            treevae_mean = imputed_mean_z[n.name].cpu().numpy()

            # covariance
            true_cov = np.diag(predictive_cov_z[n.name])
            vae_cov = np.diag(imputed_avg_cov_z[n.name].cpu().numpy())
            treevae_cov = np.diag(imputed_cov_z[n.name] * np.ones((d)))

            sample_treevae = np.random.multivariate_normal(mean=treevae_mean,
                                                            cov=treevae_cov)
            sample_vae = np.random.multivariate_normal(mean=vae_mean,
                                                        cov=vae_cov)

            treevae_lik += multivariate_normal.logpdf(sample_treevae,
                                                    true_mean,
                                                    true_cov)
            vae_lik += multivariate_normal.logpdf(sample_vae,
                                                    true_mean,
                                                    true_cov)
            
            N += 1

    vae_lik /= N
    treevae_lik /= N
    return [vae_lik, treevae_lik]

mean_posterior_lik(tree, predictive_mean_z, imputed_avg_z, imputed_mean_z, predictive_cov_z, imputed_avg_cov_z, imputed_cov_z)

[-560.7132005861627, -385.35406114699333]