# 0. Standard imports

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import copy
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('WebAgg')
import numpy as np
import pandas as pd

In [2]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi/external


In [4]:
cd ..

/home/eecs/khalil.ouardini/Cassiopeia_Transcriptome/scvi


In [5]:
%reload_ext autoreload
%matplotlib inline

***import ete3 Tree***

In [6]:
from ete3 import Tree

# Newick file with the tree topology (absolute path on the original author's machine).
tree_name = "/home/eecs/khalil.ouardini/cas_scvi_topologies/newick_objects/100cells/high_fitness/topology16.nwk"
tree = Tree(tree_name, format=1)

# Alternative kept for reference: a random 30-leaf topology instead of the file above.
#tree = Tree()
#tree.populate(30)

leaves = tree.get_leaves()

# Attach the level-order rank to every node as the `index` feature.
# Internal nodes are additionally renamed to that rank (as a string); leaves keep their names.
for rank, vertex in enumerate(tree.traverse('levelorder')):
    vertex.add_features(index=rank)
    if vertex.is_leaf():
        continue
    vertex.name = str(rank)

In [7]:
# Data
from anndata import AnnData
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from external.dataset.tree import TreeDataset, GeneExpressionDataset
from external.dataset.ppca import PPCA
from external.dataset.anndataset import AnnDatasetFromAnnData

# Models
import scanpy as sc
from external.inference.gaussian_inference import GaussianTrainer
from external.inference.gaussian_tree_inference import GaussianTreeTrainer
from external.inference.gaussian_tree_inference import GaussianTreePosterior
from inference import posterior
from external.models.treevae import TreeVAE
from external.models.gaussian_vae import GaussianVAE
from external.models.gaussian_treevae import GaussianTreeVAE

# Utils
from external.utils.data_util import get_leaves, get_internal
from external.utils.metrics import ks_pvalue, accuracy_imputation, correlations, mse, knn_purity, knn_purity_stratified
from external.utils.plots_util import plot_histograms, plot_scatter_mean, plot_ecdf_ks, plot_density, plot_embedding
from external.utils.plots_util import plot_losses, plot_elbo, plot_common_ancestor, plot_one_gene, training_dashboard
from external.utils.baselines import avg_weighted_baseline, scvi_baseline, scvi_baseline_z, cascvi_baseline_z



# 1. Simulations (Gaussian Likelihood model)

We assume that the latent variables $z \in \mathbb{R}^{N \times D}$ are gaussian (correlated). A phylogenetic tree $\tau$ (with $N$ nodes) encodes the covariance $\Sigma$ of $z$. 

$$\mathbf{z}=(z_1, ..., z_N) \sim \mathcal{N}(0, \Sigma)$$

$z$ is partitioned into two groups:

- the leaves $\mathcal{L} = \{1, ..., L\}$
- the internal nodes $\mathcal{I} = \{L + 1, ..., N\}$

***

***We describe the generative model***:

Consider a dataset $X=\{x_n\}_{n=1}^{N}$ (indexed like $z$, so that $\{1, ..., N\} = \mathcal{L} \cup \mathcal{I}$) such that $x_n \in \mathbb{R}^{P}$. We aim to represent each $x_n$ by a latent variable $z_n \in \mathbb{R}^{D}$ of lower dimension, $D \ll P$.
We only observe data at the leaves, so the generative model is defined $\forall n \in \mathcal{L}$.

The set of principal axes $W$ relates the latent variables to the data.

The corresponding data point is generated via a projection:

$$
\forall n \in \mathcal{L}, x_n =  W z_n + e_n
$$

with $W \in \mathbb{R}^{P \times D}$ and $e_n \sim \mathcal{N}(0, \sigma^2 I_P)$. Thus:


$$
\forall n \in \mathcal{L},  x_n | z_n \sim \mathcal{N}(W z_n, \sigma^2 I_P)
$$

After marginalizing out $z_n$ (whose marginal covariance is denoted $\Sigma_n$):

$$
\forall n \in \mathcal{L}, x_n \sim \mathcal{N}(0, W \Sigma_n W^T + \sigma^2 I_P)
$$

The posterior $p(z_n|x_n)$ for each $n$ is also ***tractable***, indeed

$\begin{pmatrix} x_n \\ z_n \end{pmatrix} = \begin{pmatrix} W z_n + e_n \\ z_n \end{pmatrix}$ is a gaussian vector (it is a linear transformation of the gaussian vector $(z_n, e_n)$, so any linear combination of its coordinates is still gaussian) such that:

$$
\begin{pmatrix} x_n \\ z_n \end{pmatrix} \sim \mathcal{N}(\begin{pmatrix} 0 \\ 0 \end{pmatrix}, \begin{pmatrix} W \Sigma_n W^T + \sigma^2 I_P  & W\Sigma_n \\ (W\Sigma_n)^T & \Sigma_n \end{pmatrix})
$$

where $\Sigma_n$ is the marginal covariance of $z_n$, i.e. the block of $\Sigma$ corresponding to node $n$.

Therefore we can use the conditioning formula to infer the mean and the covariance of the (gaussian) posterior $p(z_n|x_n)$:

$$
\mu_{z_n|x_n} = \Sigma_{n} W^{T} (W \Sigma_{n} W^{T} + \sigma^{2} I_{P})^{-1} x_{n} \\
\Sigma_{z_n|x_n} = \Sigma_n - \Sigma_{n} W^{T} (W \Sigma_{n} W^{T} + \sigma^{2} I_{P})^{-1} W \Sigma_{n}
$$

***

***Imputation at internal nodes***

Let $j \in \mathcal{I}$, and $X_{\mathcal{L}} = \{x_1, ..., x_L\}$ the set of observations at the leaves.
We want to infer $p(x_j|X_{\mathcal{L}})$. If we consider that the data at the internal nodes is "seen" and that the generative model is also known $\forall n \in \mathcal{I}$, we could easily (and accurately) compute $p(x_j|X_{\mathcal{L}})$ by using the gaussian conditioning formula on the gaussian vector:

$$
\begin{pmatrix} x_j \\ X_{\mathcal{L}} \end{pmatrix}
$$

In the case of unseen data at the internal nodes, one can estimate the posterior predictive density:

1. $$
p(x_j|X_{\mathcal{L}}) = p(x_j|x_1, ..., x_L) = \int p(x_j|z_j)p(z_j|z_1,...,z_L)\prod_{i=1}^{L}p(z_i|x_i)(dz_j,dz_1,...,dz_L)
$$

Therefore:
$$
p(x_j|x_1, ..., x_L) \approx  p(x_j|z_j)p(z_j|z_1,...,z_L)\prod_{i=1}^{L}p(z_i|x_i)
$$

$$
p(x_j|x_1, ..., x_L) \approx  \mathcal{N}(x_j|Wz_j, \sigma^2I_P)  \mathcal{N}(z_j|\mu_{j|\mathcal{L}}, \Sigma_{j|\mathcal{L}}) \prod_{i=1}^{L} \mathcal{N}(z_i|\mu_{z_i|x_i}, \Sigma_{z_i|x_i})
$$

2. $ p(x_j|X_{\mathcal{L}}) = Wp(z_j|X_{\mathcal{L}}) + p(e_j)$

In [8]:
print(tree)


                  /-c148
               /-|
              |   \-c149
            /-|
           |  |   /-c198
         /-|   \-|
        |  |      \-c199
        |  |
        |   \-c11
        |
        |         /-c46
        |        |
        |        |      /-c62
        |        |   /-|
        |      /-|  |   \-c63
        |     |  |  |
        |     |  |  |         /-c138
        |     |  |  |      /-|
        |     |  |  |     |   \-c139
        |     |   \-|   /-|
        |     |     |  |  |      /-c196
        |     |     |  |  |   /-|
        |     |     |  |   \-|   \-c197
        |     |     |  |     |
        |     |     |  |      \-c167
        |     |      \-|
        |     |        |            /-c176
        |     |        |         /-|
        |   /-|        |      /-|   \-c177
        |  |  |        |     |  |
        |  |  |        |   /-|   \-c141
        |  |  |        |  |  |
        |  |  |         \-|   \-c123
        |  |  |           |
        |  |  |      

***Branch Length***

In [9]:
eps = 1e-3  # not used in this cell; kept because other cells may rely on it

# Map every node name to the length of the branch leading to it.
# The root is identified structurally with `is_root()` rather than by the name '0',
# so the result no longer depends on how nodes were renamed earlier.
branch_length = {}
for node in tree.traverse('levelorder'):
    if node.is_root():
        # The root gets a fixed length of 1.0 instead of its own `dist`.
        branch_length[node.name] = 1.0
        continue
    branch_length[node.name] = node.dist
# Extra entry under the key 'prior_root', also fixed to 1.0.
branch_length['prior_root'] = 1.0


In [10]:
import torch

# Use one seed value for both the NumPy and the PyTorch global generators.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f65d2c383f0>

In [11]:
# Simulation settings for the tree-structured PPCA model.
d = 5                # latent dimension D (passed as `latent`)
p = 100              # observed dimension P (passed as `dim`)
vis = True           # forwarded as `vis` -- NOTE(review): meaning not visible from this notebook
leaves_only = False  # forwarded as `only` -- presumably restricts simulation to leaves; confirm in PPCA
var = 1.0            # only used by the commented-out positional call below; not passed to the active call
sigma_scale = 1.0    # forwarded as `sigma_scale`

# Older positional form of the constructor call, kept for reference.
#ppca = PPCA(tree, p, d, vis, leaves_only, var, sigma_scale)
ppca = PPCA(tree=tree, 
            dim=p, 
            latent=d, 
            vis=vis, 
            only=leaves_only,
            branch_length=branch_length, 
            sigma_scale=sigma_scale
            )

# Run the latent simulation step (populates `ppca.z`, which later cells read -- presumably here; confirm in PPCA).
ppca.simulate_latent()

***Marginalization***

In [12]:
ppca.simulate_normal()
ppca.W.shape

(100, 5)

In [13]:
# Log-likelihood of the observations: first with leaves_only=False, then with leaves_only=True.
lik_tree = ppca.likelihood_obs(leaves_only=False)
lik_leaves = ppca.likelihood_obs(leaves_only=True)

# Print both values, one line each.
for template, value in (
    ("Log-Likelihood of the tree {}", lik_tree),
    ("LogLikelihood of the leaves {}", lik_leaves),
):
    print(template.format(value))

Log-Likelihood of the tree -30893.548139778148
LogLikelihood of the leaves -15440.49366990377


***Get data***

In [14]:
# Latent vectors at the leaves (only the first returned value is kept)
leaves_z, _, _ = get_leaves(ppca.z, ppca.mu, tree)

# Fixed training set: observations at the leaves, with the extra values returned by the
# helper (named here as indices and `mu` -- exact semantics are defined in get_leaves)
leaves_X, leaves_idx, mu = get_leaves(ppca.X, ppca.mu, tree)

# Observations at the internal nodes (used as ground truth for imputation)
internal_X, internal_idx, internal_mu = get_internal(ppca.X, ppca.mu, tree)

# Latent vectors at the internal nodes
internal_z, _, _ = get_internal(ppca.z, ppca.mu, tree)

# Display the shapes as a quick sanity check
leaves_X.shape, mu.shape, internal_X.shape, internal_mu.shape, leaves_z.shape

((100, 100), (100, 100), (100, 100), (100, 100), (100, 5))

***Posterior Distributions***

***evidence***

In [15]:
evidence_leaves = ppca.get_evidence_leaves_levelorder(X=ppca.X, dim=ppca.dim)
evidence_leaves.shape

(10000,)

***Leaves covariance***

In [16]:
ppca.compute_leaves_covariance()

***Posterior mean and covariance***

In [17]:
posterior_mean, posterior_cov = ppca.compute_posterior()

***Posterior predictive density***

In [18]:
predictive_mean, predictive_cov = ppca.compute_posterior_predictive()

# Preliminary: Baselines

## Baseline 1: Unweighted Average of gene expression in Clade

The simple idea here is to impute the value of an internal node, with the (un)weighted average of the gene expression values of the leaves, taking the query internal node as the root of the subtree.

In [19]:
# Baseline 1: unweighted clade average, without rounding.
imputed_avg = avg_weighted_baseline(
    tree=tree,
    weighted=False,
    X=ppca.X,
    rounding=False,
)

# Stack the per-node values into a 2-D array (one row per entry), then keep the internal nodes.
n_genes = ppca.X.shape[1]
avg_X = np.array(list(imputed_avg.values())).reshape(-1, n_genes)
internal_avg_X, _, _ = get_internal(avg_X, ppca.mu, tree)

## Baseline 2: (groundtruth) posterior predictive density

In [20]:
# Number of posterior-predictive draws averaged for each internal node.
N_PREDICTIVE_SAMPLES = 20

# Baseline 2: for every internal node, average draws from its posterior predictive density.
imputed_ppca = {}
for n in tree.traverse('levelorder'):
    if n.is_leaf():
        continue
    # A single call with `size=` draws all samples at once, so the covariance is
    # factorised once per node instead of once per sample.
    samples = np.random.multivariate_normal(mean=predictive_mean[n.name],
                                            cov=predictive_cov[n.name],
                                            size=N_PREDICTIVE_SAMPLES)
    imputed_ppca[n.name] = samples.mean(axis=0)

# One row per internal node, in traversal (insertion) order.
internal_ppca_X = np.array(list(imputed_ppca.values())).reshape(-1, ppca.X.shape[1])

  


## Baseline 3: Approximation through Message Passing (Oracle)


i.e, 

1. sample $z_1, ..., z_L \sim p(z_1, ..., z_L|x_1, ..., x_L)$ (conditioning formula)
2. impute $z_i \sim p(z_i | z_1, ..., z_L)$ (Message Passing)
3. Decode $x_i \sim p(x_i|z_i) = \mathcal{N}(W z_i, \sigma^2 I_P)$ (Generative model)

In [21]:
posterior_mean_corr, posterior_cov_corr = ppca.compute_correlated_posterior()

In [22]:
imputed_mp, imputed_z_mp = ppca.compute_approx_posterior_predictive(iid=False, use_MP=True, sample_size=200)

  cov=posterior_cov).reshape(-1, self.latent) for i in range(sample_size)])
go
[2021-05-03 21:45:53,286] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:45:53,287] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-03 21:45:53,294] INFO - scvi.dataset.dataset | Merging datasets. Input objects are modified in place.
[2021-05-03 21:45:53,294] INFO - scvi.dataset.dataset | Gene names and cell measurement names are assumed to have a non-null intersection between datasets.
[2021-05-03 21:45:53,295] INFO - scvi.dataset.dataset | Keeping 100 genes
[2021-05-03 21:45:53,296] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2021-05-03 21:45:53,298] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:45:53,299] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-03 21:45:53,301] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:45:53,301] INFO - scvi.dataset.dataset | 

In [23]:
# Stack the message-passing imputations into a 2-D array with one column per gene.
imputed_X = np.array(list(imputed_mp.values())).reshape(-1, ppca.X.shape[1])

## Baseline 4: Approximation through Message Passing + iid posteriors

i.e, 

1. sample from the marginal conditionals $z_l \sim p(z_l|x_l) \ \forall l \in (1, ..., L)$ (conditioning formula)
2. impute $z_i \sim p(z_i | z_1, ..., z_L)$ (Message Passing)
3. Decode $x_i \sim p(x_i|z_i) = \mathcal{N}(W z_i, \sigma^2 I_P)$ (Generative model)

In [24]:
imputed_mp2, imputed_z_mp2 = ppca.compute_approx_posterior_predictive(iid=True, use_MP=True, sample_size=200)

  cov=posterior_cov[k]) for i in range(sample_size)])
go
[2021-05-03 21:46:14,400] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:14,401] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-03 21:46:14,407] INFO - scvi.dataset.dataset | Merging datasets. Input objects are modified in place.
[2021-05-03 21:46:14,408] INFO - scvi.dataset.dataset | Gene names and cell measurement names are assumed to have a non-null intersection between datasets.
[2021-05-03 21:46:14,409] INFO - scvi.dataset.dataset | Keeping 100 genes
[2021-05-03 21:46:14,412] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2021-05-03 21:46:14,415] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:14,416] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-03 21:46:14,419] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:14,420] INFO - scvi.dataset.dataset | Remapping batch_indice

In [25]:
# Stack the iid-posterior message-passing imputations into a 2-D array with one column per gene.
imputed_X2 = np.array(list(imputed_mp2.values())).reshape(-1, ppca.X.shape[1])

## Baseline 5: Gaussian VAE decoded averaged latent space

In [26]:
# Wrap the leaf observations in a GeneExpressionDataset (not an AnnData object);
# this dataset is what the Gaussian VAE and its trainer consume below.
gene_dataset = GeneExpressionDataset()
gene_dataset.populate_from_data(leaves_X)

[2021-05-03 21:46:26,480] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:26,481] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]


In [27]:
# Number of training epochs for the Gaussian VAE baseline.
n_epochs = 500

# Gaussian VAE whose latent size matches the latent dimension of the simulation.
vae = GaussianVAE(gene_dataset.nb_genes,
                  n_hidden=64,
                  n_layers=1,
                  n_latent=ppca.latent,
                  sigma_ldvae=None
              )

# Disabled experiment: overwrite the weight of the decoder's first layer with the
# simulation's loading matrix `ppca.W` and freeze it.
#new_weight = torch.from_numpy(ppca.W).float()

#with torch.no_grad():
    #vae.decoder.factor_regressor.fc_layers[0][0].weight = torch.nn.Parameter(new_weight)
    
#for param in vae.decoder.factor_regressor.fc_layers[0][0].parameters():
    #param.requires_grad = False
    
#vae.decoder

In [28]:
# Sanity check before training: decode the simulated leaf latents with the (still untrained)
# VAE decoder and compare the decoded means with `mu`.
p_m, p_v = vae.decoder.forward(torch.from_numpy(leaves_z).float())
p_m = p_m.detach().numpy()

from sklearn.metrics import mean_squared_error

# Stored under its own name: binding the result to `mse` would shadow the `mse`
# helper imported from external.utils.metrics.
decoder_mse = mean_squared_error(p_m, mu)
print("the distance is {}".format(decoder_mse))

the distance is 4.236195116252171


In [29]:
# Train on CPU.
use_cuda = False

# Trainer for the Gaussian VAE baseline on the leaf dataset.
trainer = GaussianTrainer(model=vae,
                          gene_dataset=gene_dataset,
                          train_size=1.0,           # presumably: use the whole dataset for training -- confirm in trainer
                          use_cuda=use_cuda,
                          frequency=10,             # presumably: metric logging interval in epochs -- confirm in trainer
                          n_epochs_kl_warmup=None,
                         )

In [30]:
# train VAE
trainer.train(n_epochs=n_epochs, lr=1e-2) 

computing elbo
ELBO: 36955.26953125
computing elbo
ELBO: 36971.0234375
computing elbo
ELBO: 36961.27734375
training:   1%|▏         | 7/500 [00:00<00:07, 64.87it/s]computing elbo
ELBO: 28474.59375
computing elbo
ELBO: 28790.3671875
computing elbo
ELBO: 28929.29296875
training:   4%|▍         | 19/500 [00:00<00:05, 93.02it/s]computing elbo
ELBO: 30706.1484375
computing elbo
ELBO: 32072.056640625
computing elbo
ELBO: 31596.259765625
computing elbo
ELBO: 29247.15625
computing elbo
ELBO: 29418.73046875
computing elbo
ELBO: 29420.63671875
training:   6%|▌         | 30/500 [00:00<00:04, 98.17it/s]computing elbo
ELBO: 25250.009765625
computing elbo
ELBO: 25239.40625
computing elbo
ELBO: 25224.63671875
training:   8%|▊         | 42/500 [00:00<00:04, 104.44it/s]computing elbo
ELBO: 24048.978515625
computing elbo
ELBO: 24010.6796875
computing elbo
ELBO: 23955.333984375
training:  11%|█         | 54/500 [00:00<00:04, 107.95it/s]computing elbo
ELBO: 22088.640625
computing elbo
ELBO: 22089.984375
c

In [31]:
# Training curve of the Gaussian VAE: log of the recorded train-set ELBO values.
elbo_train = trainer.history["elbo_train_set"]
# Spread the recorded values evenly over the training epochs so the x axis is in epochs
# (approximate: assumes the values were recorded at regular intervals).
x = np.linspace(0, n_epochs, len(elbo_train))
plt.plot(x, np.log(elbo_train),
         label="train", color='blue',
         linestyle=':',
         linewidth=3
        )

plt.xlabel('Epoch')
# The plotted quantity is the logarithm of the ELBO, so label it as such.
plt.ylabel("log(ELBO)")
plt.legend()
plt.title("Train history Gaussian VAE")
plt.show()

In [32]:
# NOTE(review): this import already appears in an earlier cell; repeated here.
from sklearn.metrics import mean_squared_error

# NOTE(review): the name `posterior` was bound to a module earlier (`from inference import posterior`);
# this assignment rebinds it to a posterior object, which the next cell passes on.
posterior =  trainer.create_posterior(model=vae,
                                      gene_dataset=gene_dataset
                                      )
# Latent representation of the leaves from the trained VAE, compared with the simulated leaf latents.
latent = posterior.get_latent()
mean_squared_error(latent, leaves_z)

2.4091782390417955

In [33]:
# Baseline 5: impute internal nodes from the Gaussian VAE (unweighted, one latent sample).
imputed_avg_vae, _ = scvi_baseline_z(
    tree=tree,
    model=vae,
    posterior=posterior,
    weighted=False,
    n_samples_z=1,
    gaussian=True,
)

# Stack the per-node values into a 2-D array with one column per gene and show its shape.
internal_vae_X = np.array(list(imputed_avg_vae.values())).reshape(-1, ppca.X.shape[1])
internal_vae_X.shape

(100, 100)

# 3. Our Model: CascVI

In [34]:
import scanpy as sc

# AnnData over the leaf observations; rows are named after the leaves in level order.
leaf_names = [leaf.name for leaf in tree.traverse('levelorder') if leaf.is_leaf()]
adata = AnnData(leaves_X)
adata.obs_names = leaf_names

# Wrap it as a dataset and register the leaf names as the 'barcodes' cell attribute.
scvi_dataset = AnnDatasetFromAnnData(adata, filtering=False)
scvi_dataset.initialize_cell_attribute('barcodes', adata.obs_names)

# Tree-aware dataset combining the expression data with the topology.
cas_dataset = TreeDataset(scvi_dataset, tree=tree, filtering=False)
cas_dataset

go
[2021-05-03 21:46:31,788] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:31,788] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-03 21:46:31,798] INFO - scvi.dataset.dataset | Merging datasets. Input objects are modified in place.
[2021-05-03 21:46:31,799] INFO - scvi.dataset.dataset | Gene names and cell measurement names are assumed to have a non-null intersection between datasets.
[2021-05-03 21:46:31,801] INFO - scvi.dataset.dataset | Keeping 100 genes
[2021-05-03 21:46:31,803] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2021-05-03 21:46:31,805] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:31,806] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2021-05-03 21:46:31,809] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2021-05-03 21:46:31,810] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]


GeneExpressionDataset object with n_cells x nb_genes = 100 x 100
    gene_attribute_names: 'gene_names'
    cell_attribute_names: 'batch_indices', 'barcodes', 'local_means', 'local_vars', 'labels'
    cell_categorical_attribute_names: 'labels', 'batch_indices'

In [35]:
# Train on CPU; enable the message-passing option of the model.
use_cuda = False
use_MP = True

# Gaussian TreeVAE with the same latent size and hidden architecture as the Gaussian VAE baseline.
treevae = GaussianTreeVAE(cas_dataset.nb_genes,
              tree = cas_dataset.tree,
              n_latent=ppca.latent,
              n_hidden=64,
              n_layers=1,
              prior_t = branch_length,   # per-node branch-length dictionary built earlier
              use_MP=use_MP,
              sigma_ldvae=None
             )

***Freezing the decoder***

In [36]:
#new_weight = torch.from_numpy(ppca.W).float()

#with torch.no_grad():
    #treevae.decoder.factor_regressor.fc_layers[0][0].weight = torch.nn.Parameter(new_weight)
    
#for param in treevae.decoder.factor_regressor.fc_layers[0][0].parameters():
    #param.requires_grad = False
    
#treevae.decoder

In [37]:
#assert(treevae.decoder.factor_regressor.fc_layers[0][0].weight.numpy().all() == ppca.W.T.all())


***Are we able to generate the gene expression data by decoding the simulated latent space?***

In [38]:
#p_m, p_v = treevae.decoder.forward(torch.from_numpy(leaves_z).float())
#p_m = p_m.detach().numpy()
#p_m.shape, mu.shape

In [39]:
#mse = mean_squared_error(p_m, mu)
#print("the distance is {}".format(mse))

***Training***

In [40]:
# Training hyper-parameters for the Gaussian TreeVAE.
n_epochs = 500
lr = 1e-2
lambda_ = 1.0   # forwarded as `lambda_` -- NOTE(review): role not visible from this notebook
freq = 10       # forwarded as `frequency`

# All leaves go to the training split (train_size=1.0, test_size=0); the recorded
# output lists every leaf under train_leaves and none under test/validation.
tree_trainer = GaussianTreeTrainer(
        model = treevae,
        gene_dataset = cas_dataset,
        lambda_ = lambda_,
        train_size=1.0,
        test_size=0,
        use_cuda=use_cuda,
        frequency=freq,
        n_epochs_kl_warmup=None
    )

train_leaves:  [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63], [64], [65], [66], [67], [68], [69], [70], [71], [72], [73], [74], [75], [76], [77], [78], [79], [80], [81], [82], [83], [84], [85], [86], [87], [88], [89], [90], [91], [92], [93], [94], [95], [96], [97], [98], [99]]
test_leaves:  []
validation leaves:  []


In [41]:
tree_trainer.train(n_epochs=n_epochs,
              lr=lr)

computing elbo
training:   2%|▏         | 9/500 [00:00<00:17, 27.93it/s]computing elbo
training:   4%|▎         | 18/500 [00:00<00:16, 28.37it/s]computing elbo
training:   5%|▌         | 27/500 [00:00<00:16, 28.16it/s]computing elbo
training:   8%|▊         | 39/500 [00:01<00:16, 28.40it/s]computing elbo
training:  10%|▉         | 48/500 [00:01<00:18, 24.92it/s]computing elbo
training:  11%|█▏        | 57/500 [00:02<00:19, 23.19it/s]computing elbo
training:  14%|█▍        | 69/500 [00:02<00:18, 23.21it/s]computing elbo
training:  16%|█▌        | 78/500 [00:03<00:18, 22.73it/s]computing elbo
training:  17%|█▋        | 87/500 [00:03<00:17, 22.97it/s]computing elbo
training:  20%|█▉        | 99/500 [00:04<00:17, 23.27it/s]computing elbo
training:  22%|██▏       | 108/500 [00:04<00:17, 22.94it/s]computing elbo
training:  23%|██▎       | 117/500 [00:04<00:16, 22.90it/s]computing elbo
training:  26%|██▌       | 129/500 [00:05<00:15, 23.21it/s]computing elbo
training:  28%|██▊       | 138/500

In [42]:
training_dashboard(tree_trainer, treevae.encoder_variance)

In [43]:
# Posterior object over every cell of the tree dataset (indices 0..len-1), using the trainer's clades.
tree_posterior = tree_trainer.create_posterior(model=treevae,
                                              gene_dataset=cas_dataset,
                                               clades=tree_trainer.clades,
                                               indices=np.arange(len(cas_dataset))
                                              )
# Latent representation of the leaves; shown next to the shape of the simulated internal latents.
tree_latent = tree_posterior.get_latent()
tree_latent.shape, internal_z.shape

((100, 5), (100, 5))

In [44]:
tree_latent = tree_posterior.get_latent()
mean_squared_error(tree_latent, leaves_z)

1.264154237669033

***Missing Value Imputation***

In [45]:
# CascVI imputations
# CascVI imputations: expression and latent value for every internal node, keyed by node name.
imputed = {}
imputed_z = {}

for n in tree.traverse('levelorder'):
    if n.is_leaf():
        continue
    x_hat, z_hat = tree_posterior.imputation_internal(
        query_node=n.name,
        pp_averaging=200,
        z_averaging=None,
    )
    imputed[n.name] = x_hat
    imputed_z[n.name] = z_hat

In [46]:
# Stack the CascVI imputations into a 2-D array with one column per gene.
internal_treevae_X = np.array(list(imputed.values())).reshape(-1, cas_dataset.X.shape[1])

***Evaluation: Correlations***

In [47]:
# Ground truth and every imputation method, each transposed before being handed to `correlations`
# (NOTE(review): the transpose presumably makes the correlations per gene rather than per node -- confirm).
data = {'groundtruth': internal_X.T, 'average': internal_avg_X.T, 'ppca':internal_ppca_X.T,
        'approx ppca Oracle':imputed_X.T, 'approx ppca iid': imputed_X2.T,
        'gaussian VAE': internal_vae_X.T
        , 'gaussian treeVAE': internal_treevae_X.T
       }

df1 = correlations(data, 'None', True)
plt.show()

# Keep the preview as the last expression of the cell so that it is actually displayed.
df1.head(5)

In [48]:
# Mean correlation per method; numeric_only=True skips the string `Method` column explicitly
# (recent pandas versions raise instead of silently dropping it).
tuple(df1[df1.Method == method].mean(numeric_only=True)
      for method in ('average', 'ppca', 'gaussian VAE'))

(Spearman CC    0.756487
 Pearson CC     0.767067
 Kendall Tau    0.574368
 dtype: float64,
 Spearman CC    0.812028
 Pearson CC     0.820406
 Kendall Tau    0.630683
 dtype: float64,
 Spearman CC    0.656615
 Pearson CC     0.603462
 Kendall Tau    0.490053
 dtype: float64)

In [49]:
# Mean correlation per method; numeric_only=True skips the string `Method` column explicitly
# (recent pandas versions raise instead of silently dropping it).
tuple(df1[df1.Method == method].mean(numeric_only=True)
      for method in ('approx ppca Oracle', 'approx ppca iid', 'gaussian treeVAE'))

(Spearman CC    0.815085
 Pearson CC     0.824786
 Kendall Tau    0.633762
 dtype: float64,
 Spearman CC    0.815705
 Pearson CC     0.824596
 Kendall Tau    0.634432
 dtype: float64,
 Spearman CC    0.782457
 Pearson CC     0.791754
 Kendall Tau    0.599818
 dtype: float64)

In [50]:
# Summary table: one row per imputation method, one column per correlation measure.
data_dict = {}
# Skip the first key of `data` ('groundtruth'): it is the reference, not a method.
methods = list(data.keys())[1:]
for method in methods:
    # numeric_only=True averages only the correlation columns and skips the string `Method`
    # column explicitly (recent pandas versions raise instead of silently dropping it).
    data_dict[method] = list(df1[df1.Method==method].mean(numeric_only=True))
results_corr = pd.DataFrame.from_dict(data_dict, orient='index', columns=['Spearman CC', 'Pearson CC', 'Kendal Tau CC'])

results_corr.head(10)

Unnamed: 0,Spearman CC,Pearson CC,Kendal Tau CC
average,0.756487,0.767067,0.574368
ppca,0.812028,0.820406,0.630683
approx ppca Oracle,0.815085,0.824786,0.633762
approx ppca iid,0.815705,0.824596,0.634432
gaussian VAE,0.656615,0.603462,0.490053
gaussian treeVAE,0.782457,0.791754,0.599818


***Evaluation 2: MSE***

In [51]:
# Transposed variant of the inputs, kept for reference (not used for the MSE evaluation).
#data = {'groundtruth': internal_X.T, 'average': internal_avg_X.T, 'ppca':internal_ppca_X.T,
#        'approx ppca Oracle':imputed_X.T, 'approx ppca mean field': imputed_X2.T,
#        'gaussian VAE': internal_vae_X.T, 'gaussian treeVAE': internal_treevae_X.T
#       }

# Importing here guarantees that `mse` refers to the metrics helper, even if the name
# was rebound to something else earlier in the session.
from external.utils.metrics import mse

# NOTE(review): `data` is rebuilt here without the transposes used for the correlation evaluation.
data = {'groundtruth': internal_X, 'average': internal_avg_X, 'ppca':internal_ppca_X,
        'approx ppca Oracle':imputed_X, 'approx ppca iid': imputed_X2,
        'gaussian VAE': internal_vae_X, 'gaussian treeVAE': internal_treevae_X
      }

# MSE table of every method against the ground truth (displayed in the next cell).
results = mse(data)

In [52]:
results

Unnamed: 0,average,ppca,approx ppca Oracle,approx ppca iid,gaussian VAE,gaussian treeVAE
MSE,1.676866,1.265429,1.257161,1.260095,34573470.0,1.517241
std,0.505498,0.405923,0.430971,0.439821,301290200.0,0.507593


***Testing Message Passing***

In [53]:
#evidence = ppca.get_evidence_leaves_levelorder(ppca.z, ppca.latent)
#mean1, scale1 = ppca.compute_posterior_predictive_z_MP(evidence.reshape(-1, ppca.latent))
#mean2, scale2 = ppca.compute_posterior_predictive_z(evidence)