# Clustering HCA Mouse Atlas with scVI and ScanPy

Disclaimer: some of the code in this notebook was taken from Scanpy's clustering tutorial (https://scanpy-tutorials.readthedocs.io/en/latest/pbmc3k.html), which is itself based on Seurat's clustering tutorial in R.

This notebook is designed as a demonstration of scVI's potency on the tasks considered in the Scanpy PBMC 3K Clustering notebook.
In order to do so, we follow the same workflow adopted by scanpy in their clustering tutorial while performing the analysis using scVI as often as possible.
Specifically, we use scVI's latent representation and differential expression analysis (which computes a Bayes Factor on imputed values). 
For visualisation, pre-processing and for some canonical analysis, we use the Scanpy package directly.

When useful, we provide high-level wrappers around scVI's analysis tools. These functions are designed to make standard use of scVI as easy as possible.
For specific use cases, we encourage the reader to take a closer look at those functions and modify them according to their needs.

In [1]:
cd ../../

/nfs/team205/zx3/PycharmProject/scVI


## Automated testing configuration

In [2]:
# This is for notebook automated testing purpose
def allow_notebook_for_test():
    """Print the marker message used for the notebook's automated testing."""
    banner = "Testing the annotation notebook"
    print(banner)

import sys, os

sys.path.append(os.path.abspath("../.."))
n_epochs_all = None
test_mode = False


def if_not_test_else(x, y):
    """Pick between a regular value and its test-mode substitute.

    Returns ``x`` while the module-level ``test_mode`` flag is falsy and ``y``
    once it is truthy.
    """
    return y if test_mode else x


save_path = "data/"

# End of configuration

## Initialization

In [3]:
# Uncomment to download the data (only works on Unix system)
# !mkdir data
# !wget http://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/pbmc3k_filtered_gene_bc_matrices.tar.gz -O data/pbmc3k_filtered_gene_bc_matrices.tar.gz
# !cd data; tar -xzf pbmc3k_filtered_gene_bc_matrices.tar.gz

In [4]:
# Seed for reproducibility
import torch
import numpy as np

# Fix the global random seeds of both PyTorch and NumPy
torch.manual_seed(0)
np.random.seed(0)

In [5]:
import pandas as pd
import scanpy as sc

sc.settings.verbosity = 0  # verbosity: errors (0), warnings (1), info (2), hints (3)

In [6]:
if not test_mode:
    %matplotlib inline
    sc.settings.set_figure_params(dpi=60)

In [7]:
test_mode

False

# Load the data
adata = sc.read_10x_mtx(
    os.path.join(
        save_path, "filtered_gene_bc_matrices/hg19/"
    ),  # the directory with the `.mtx` file
    var_names="gene_symbols",  # use gene symbols for the variable names (variables-axis index)
)
adata.var_names_make_unique()

In [8]:
save_path = "/lustre/scratch117/cellgen/team205/tpcg/backup/backup_20190401/sc_sclassification/CellTypist/data_repo/MouseAtlas/MouseAtlas.total.h5ad"

In [34]:
adata = sc.read_h5ad(save_path)

In [84]:
adata.raw.X.shape

(41282, 37878)

In [35]:
adata.var_names_make_unique()
adata.obs_names_make_unique()

In [61]:
from collections import Counter
Counter(adata.obs['Organ'])

Counter({'Lung': 1816,
         'Thymus': 404,
         'Brain_Microglia': 188,
         'Skin': 1177,
         'Pancreas': 1311,
         'Muscle': 2869,
         'Fat': 3060,
         'Bladder': 4207,
         'Tongue': 4958,
         'Marrow': 380,
         'Heart': 3843,
         'Liver': 1353,
         'Trachea': 1093,
         'Mammary': 3428,
         'Colon': 2138,
         'Spleen': 28,
         'Kidney': 1475,
         'Brain_Neurons': 2658,
         'Uterus': 396,
         'TrophoblastStemCells': 1,
         'Testis': 360,
         'Stomach': 128,
         'SmallIntestine': 78,
         'Placenta': 12,
         'Ovary': 182,
         'NeontalBrain': 1,
         'NeonatalSkin': 15,
         'NeonatalRib': 145,
         'NeonatalMuscle': 26,
         'NeonatalHeart': 10,
         'NeonatalCalvaria': 12,
         'Mouse3T3': 20,
         'MesenchymalStemCellsPrimary': 10,
         'MesenchymalStemCells': 6,
         'MammaryGland.Virgin': 74,
         'MammaryGland.Pregnancy': 

In [37]:
Counter(adata.obs['Dataset'])

Counter({'Tabula(Plate)': 45959,
         'Tabula(Droplet)': 41079,
         'MCA': 58809,
         'HSC': 61077,
         'Embryo': 287,
         'Gastrula': 456,
         'Brain': 2931,
         'Kidney': 39651,
         'Thymus': 17441})

In [62]:
adata.X.shape

(41282, 2777)

In [63]:
save_path2 = "/lustre/scratch117/cellgen/team205/tpcg/human_data/HumanAtlas.h5ad"

adata_human = sc.read_h5ad(save_path2)

In [66]:
pd.crosstab(adata_human.obs['Dataset'], adata_human.obs['Tissue'])

Tissue,Blood,Brain,Brain_Microglia,Colon,Decidua,ES cells,Intestine,Kidney,Liver,Lung Parenchyma,Pancreas,Placenta,Prostate,Skin,Testis,Tumour,Upper airway,Ventral Midbrain,mLN
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
Baron,0,0,0,0,0,0,0,0,0,0,8569,0,0,0,0,0,0,0,0
Muraro,0,0,0,0,0,0,0,0,0,0,2126,0,0,0,0,0,0,0,0
Segerstolpe,0,0,0,0,0,0,0,0,0,0,3363,0,0,0,0,0,0,0,0
Wang,0,0,0,0,0,0,0,0,0,0,635,0,0,0,0,0,0,0,0
felipe-lung,0,0,0,0,0,0,0,0,0,19451,0,0,0,0,0,0,6562,0,0
gierahn17,5584,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
guo18,0,0,0,0,0,0,0,0,0,0,0,0,0,0,12985,0,0,0,0
kylie-colon,0,0,0,21554,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10674
lamanno16a,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1977,0
lamanno16b,0,0,0,0,0,1715,0,0,0,0,0,0,0,0,0,0,0,0,0


In [73]:
adata_human.obs["CellType"].value_counts()

Unnamed: 0_level_0,Barcodes,Dataset,Barcode,Tissue,CellType,Protocol,DonorType,Other,n_counts,leiden
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
ARMS032_AACCATGCAATAACGA-felipe-lung,ARMS032_AACCATGCAATAACGA-felipe-lung,felipe-lung,ARMS032_AACCATGCAATAACGA-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,10969.0,4
ARMS032_ACACCCTCAGGTGCCT-felipe-lung,ARMS032_ACACCCTCAGGTGCCT-felipe-lung,felipe-lung,ARMS032_ACACCCTCAGGTGCCT-felipe-lung,Upper airway,Neutrophils,10X,biopsy,Ctrl,31608.0,4
ARMS032_ACATGGTAGGCAATTA-felipe-lung,ARMS032_ACATGGTAGGCAATTA-felipe-lung,felipe-lung,ARMS032_ACATGGTAGGCAATTA-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,38928.0,4
ARMS032_ACTTACTAGTTTGCGT-felipe-lung,ARMS032_ACTTACTAGTTTGCGT-felipe-lung,felipe-lung,ARMS032_ACTTACTAGTTTGCGT-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,20231.0,4
ARMS032_AGTAGTCCATATGAGA-felipe-lung,ARMS032_AGTAGTCCATATGAGA-felipe-lung,felipe-lung,ARMS032_AGTAGTCCATATGAGA-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,28352.0,4


In [74]:
adata_human.obs.head()

Unnamed: 0_level_0,Barcodes,Dataset,Barcode,Tissue,CellType,Protocol,DonorType,Other,n_counts,leiden
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
ARMS032_AACCATGCAATAACGA-felipe-lung,ARMS032_AACCATGCAATAACGA-felipe-lung,felipe-lung,ARMS032_AACCATGCAATAACGA-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,10969.0,4
ARMS032_ACACCCTCAGGTGCCT-felipe-lung,ARMS032_ACACCCTCAGGTGCCT-felipe-lung,felipe-lung,ARMS032_ACACCCTCAGGTGCCT-felipe-lung,Upper airway,Neutrophils,10X,biopsy,Ctrl,31608.0,4
ARMS032_ACATGGTAGGCAATTA-felipe-lung,ARMS032_ACATGGTAGGCAATTA-felipe-lung,felipe-lung,ARMS032_ACATGGTAGGCAATTA-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,38928.0,4
ARMS032_ACTTACTAGTTTGCGT-felipe-lung,ARMS032_ACTTACTAGTTTGCGT-felipe-lung,felipe-lung,ARMS032_ACTTACTAGTTTGCGT-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,20231.0,4
ARMS032_AGTAGTCCATATGAGA-felipe-lung,ARMS032_AGTAGTCCATATGAGA-felipe-lung,felipe-lung,ARMS032_AGTAGTCCATATGAGA-felipe-lung,Upper airway,Secretory,10X,biopsy,Ctrl,28352.0,4


In [77]:
adata_human.raw.var.head()

Unnamed: 0_level_0,gene_name
index,Unnamed: 1_level_1
MIR1302-2HG,MIR1302-2HG
FAM138A,FAM138A
OR4F5,OR4F5
RP11-34P13.7,RP11-34P13.7
RP11-34P13.8,RP11-34P13.8


In [79]:
sc.tl.leiden(adata, resolution = 1.5)

In [83]:
adata_human.raw.X

<827448x65104 sparse matrix of type '<class 'numpy.float32'>'
	with 1534508369 stored elements in Compressed Sparse Column format>

## Preprocessing

In the following section, we reproduce the preprocessing steps adopted in the scanpy notebook. 


Basic filtering: we remove cells with a low number of genes expressed and genes which are expressed in a low number of cells.

In [44]:
min_genes = if_not_test_else(200, 0)
min_cells = if_not_test_else(20, 0)

In [45]:
# Raise scanpy's verbosity so the filtering steps report what they remove
sc.settings.verbosity = 2
# Remove cells with fewer than `min_genes` genes, then genes detected in fewer than `min_cells` cells
sc.pp.filter_cells(adata, min_genes=min_genes)
sc.pp.filter_genes(adata, min_cells=min_cells)
# Second cell filter after the gene filter -- presumably drops cells left with
# no detected gene once genes were removed (TODO confirm intent)
sc.pp.filter_cells(adata, min_genes=1)

filtered out 14 genes that are detected in less than 20 cells


In [52]:
adata.shape

(41282, 2777)

As in the scanpy notebook, we then look for high levels of mitochondrial genes and high number of expressed genes which are indicators of poor quality cells.

#### Non applicable step
mito_genes = adata.var_names.str.startswith("MT-")
adata.obs["percent_mito"] = (
    np.sum(adata[:, mito_genes].X, axis=1).A1 / np.sum(adata.X, axis=1).A1
)
adata.obs["n_counts"] = adata.X.sum(axis=1).A1

In [48]:
adata = adata[adata.obs["n_genes"] < 2500, :]

adata = adata[adata.obs["percent_mito"] < 0.05, :]

## ⚠ scVI uses non normalized data so we keep the original data in a separate `AnnData` object, then the normalization steps are performed

#####  Normalization and more filtering

We only keep highly variable genes

In [50]:
adata_original = adata.copy()

sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
sc.pp.log1p(adata)



filtered out 3128 cells that have less than 1 counts


  np.log1p(X, out=X)
  np.log1p(X, out=X)


In [54]:
adata_original.X


array([[-0.08209129, -0.06849433, -0.02150778, ..., -0.09477201,
         0.6100729 ,  1.1633312 ],
       [-0.08209129, -0.06849433, -0.02150778, ..., -0.09477201,
         0.08112828, -0.1112038 ],
       [-0.08209129, -0.06849433, -0.02150778, ..., -0.09477201,
        -0.06557146, -0.1112038 ],
       ...,
       [-0.08209129, -0.06849433, -0.02150778, ..., -0.09477201,
        -0.7475515 , -0.1112038 ],
       [-0.08209129, -0.06849433, -0.02150778, ..., -0.09477201,
        -0.7475515 , -0.1112038 ],
       [-0.08209129, -0.06849433, -0.02150778, ...,  4.2228007 ,
        -0.7475515 , -0.1112038 ]], dtype=float32)

In [51]:

min_mean = if_not_test_else(0.0125, -np.inf)
max_mean = if_not_test_else(3, np.inf)
min_disp = if_not_test_else(0.5, -np.inf)
max_disp = if_not_test_else(None, np.inf)

sc.pp.highly_variable_genes(
    adata,
    min_mean=min_mean,
    max_mean=max_mean,
    min_disp=min_disp,
    max_disp=max_disp
    # n_top_genes=500
)


ValueError: Bin edges must be unique: array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan]).
You can drop duplicate edges by setting the 'duplicates' kwarg

In [None]:

# NOTE: this first assignment to adata.raw is overwritten at the end of the cell
adata.raw = adata

# Boolean mask of the genes flagged by sc.pp.highly_variable_genes
highly_variable_genes = adata.var["highly_variable"]
adata = adata[:, highly_variable_genes]

# NOTE(review): in this notebook `percent_mito` is only assigned in the snippet
# marked "Non applicable step" -- confirm both columns exist in adata.obs
sc.pp.regress_out(adata, ["n_counts", "percent_mito"])
sc.pp.scale(adata, max_value=10)

# Also filter the original adata genes
adata_original = adata_original[:, highly_variable_genes]
print(highly_variable_genes.sum())

# We also store adata_original into adata.raw
# (which was designed for this purpose but actually has limited functionalities)
adata.raw = adata_original

## Compute the scVI latent space

Below we provide then use a wrapper function designed to compute scVI's latent representation of the non-normalized data. Specifically, we train scVI's VAE, compute and store the latent representation then return the posterior which will later be used for further inference.

In [3]:

from scvi.dataset.anndata import AnnDataset
from scvi.inference import UnsupervisedTrainer
from scvi.models.vae import VAE
from typing import Tuple

In [None]:
def compute_scvi_latent(
    adata: sc.AnnData,
    n_latent: int = 5,
    n_epochs: int = 100,
    lr: float = 1e-3,
    use_batches: bool = False,
    use_cuda: bool = False,
) -> "Tuple[scvi.inference.Posterior, np.ndarray]":
    """Train a scVI VAE on ``adata`` and return its posterior and latent space.

    The return annotation is quoted so that it is not evaluated when the
    function is defined: this notebook only uses ``from scvi... import ...``
    statements, which do not bind the bare name ``scvi``.

    :param adata: sc.AnnData object non-normalized
    :param n_latent: dimension of the latent space
    :param n_epochs: number of training epochs
    :param lr: learning rate
    :param use_batches: when False, the model is built with ``n_batch=0``;
        when True, with the number of batches of the scVI dataset
    :param use_cuda: forwarded to ``UnsupervisedTrainer``
    :return: (scvi.Posterior, latent_space)
    """
    # Convert easily to scvi dataset
    scviDataset = AnnDataset(adata)

    # Train a model; multiplying by the boolean zeroes n_batch when batches are ignored
    vae = VAE(
        scviDataset.nb_genes,
        n_batch=scviDataset.n_batches * use_batches,
        n_latent=n_latent,
    )
    # train_size=1.0: no cells are held out from training
    trainer = UnsupervisedTrainer(vae, scviDataset, train_size=1.0, use_cuda=use_cuda)
    trainer.train(n_epochs=n_epochs, lr=lr)

    # Extract latent space over every cell of the dataset
    # (presumably `.sequential()` keeps dataset order so rows align with adata -- TODO confirm)
    posterior = trainer.create_posterior(
        trainer.model, scviDataset, indices=np.arange(len(scviDataset))
    ).sequential()

    # get_latent() returns three values; only the first (the latent matrix) is used
    latent, _, _ = posterior.get_latent()

    return posterior, latent

In [None]:
# Use the global epoch override when the test configuration sets one
n_epochs = 10 if n_epochs_all is None else n_epochs_all

# Train on the GPU when one is available instead of unconditionally requiring CUDA
scvi_posterior, scvi_latent = compute_scvi_latent(
    adata_original,
    n_epochs=n_epochs,
    n_latent=6,
    use_cuda=torch.cuda.is_available(),
)
adata.obsm["X_scvi"] = scvi_latent

## Principal component analysis to reproduce ScanPy results and compare them against scVI's

Below, we reproduce exactly scanpy's PCA on normalized data.

In [None]:
sc.tl.pca(adata, svd_solver="arpack")

In [None]:
sc.pl.pca(adata, color="CST3")

In [None]:
sc.pl.pca_variance_ratio(adata, log=True)

## Computing, embedding and clustering the neighborhood graph

The Scanpy API computes a neighborhood graph with `sc.pp.neighbors` which can be called to work on a specific representation `use_rep='your rep'`.
Once the neighbors graph has been computed, all Scanpy algorithms working on it can be called as usual (that is *louvain*, *paga*, *umap* ...)

### Using PCA representation (Scanpy tutorial)

In [None]:
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)
sc.tl.louvain(adata, key_added="louvain_pca")
sc.tl.umap(adata)

In [None]:
sc.pl.umap(adata, color=["louvain_pca", "CST3", "NKG7", "MS4A1"], ncols=4)

### Using scVI latent space representation

In [None]:
sc.pp.neighbors(adata, n_neighbors=20, n_pcs=40, use_rep="X_scvi")
sc.tl.umap(adata)

In [None]:
sc.tl.louvain(adata, key_added="louvain_scvi", resolution=0.7)

In [None]:
sc.pl.umap(adata, color=["louvain_scvi", "CST3", "NKG7", "MS4A1"], ncols=4)

## Finding marker genes

ScanPy tries to determine marker genes using a *t-test* and a *Wilcoxon* test.

For the same task, from scVI's trained VAE model we can sample the gene expression rate for each gene in each cell. For the two populations of interest, we can then randomly sample pairs of cells, one from each population, to compare their expression rate for a gene. The degree of **differential expression** is measured by the Bayes factor $\log\frac{p}{1-p}$ (i.e. $\mathrm{logit}(p)$), where $p$ is the probability of a cell from population $A$ having a higher expression than a cell from population $B$. We can form the null distribution of the DE values by sampling pairs randomly from the combined population.

Below, we provide a wrapper around scVI's differential expression process. Specifically, it computes the average of the Bayes factor where population $A$ covers each cluster in `adata.obs[label_name]` and is compared with the aggregate formed by all the other clusters.

In [None]:
def rank_genes_groups_bayes(
    adata: "sc.AnnData",
    scvi_posterior: "scvi.inference.Posterior",
    n_samples: int = None,
    M_permutation: int = None,
    n_genes: int = 25,
    label_name: str = "louvain_scvi",
) -> pd.DataFrame:
    """
    Rank genes for characterizing groups.
    Computes Bayes factor for each cluster against the others to test for differential expression.
    See Nature article (https://rdcu.be/bdHYQ)

    The scanpy/scVI annotations are quoted so they are not evaluated when the
    function is defined: this notebook only uses ``from scvi... import ...``
    statements, which do not bind the bare name ``scvi``.

    :param adata: sc.AnnData object non-normalized; ``adata.uns["rank_genes_groups_scvi"]``
                  is written in place
    :param scvi_posterior: object exposing ``one_vs_all_degenes``
    :param n_samples: forwarded to ``one_vs_all_degenes``
    :param M_permutation: forwarded to ``one_vs_all_degenes``
    :param n_genes: number of leading rows kept from each per-cluster table
    :param label_name: The groups tested are taken from adata.obs[label_name] which can be computed
                       using clustering like Louvain (Ex: sc.tl.louvain(adata, key_added=label_name) )
    :return: Summary of Bayes factor per gene, per cluster
    """

    # Call scvi function; labels are cast to int, so they must be integer-like
    per_cluster_de, _ = scvi_posterior.one_vs_all_degenes(
        cell_labels=np.asarray(adata.obs[label_name].values).astype(int).ravel(),
        min_cells=1,
        n_samples=n_samples,
        M_permutation=M_permutation,
    )

    # convert to ScanPy format -- this is just about feeding scvi results into a format readable by ScanPy
    markers = []
    scores = []
    names = []
    for cluster_de in per_cluster_de:
        subset_de = cluster_de[:n_genes]
        markers.append(subset_de)
        scores.append(tuple(subset_de["bayes1"].values))
        names.append(tuple(subset_de.index.values))

    markers = pd.concat(markers)
    # One record field per cluster, named by its position in `per_cluster_de`
    # (assumes that position matches the cluster label -- TODO confirm)
    dtypes_scores = [(str(i), "<f4") for i in range(len(scores))]
    dtypes_names = [(str(i), "<U50") for i in range(len(names))]
    # Transpose so that each record is one rank position across all clusters
    scores = np.array([tuple(row) for row in np.array(scores).T], dtype=dtypes_scores)
    scores = scores.view(np.recarray)
    names = np.array([tuple(row) for row in np.array(names).T], dtype=dtypes_names)
    names = names.view(np.recarray)

    adata.uns["rank_genes_groups_scvi"] = {
        "params": {
            "groupby": "",
            "reference": "rest",
            "method": "",
            "use_raw": True,
            "corr_method": "",
        },
        "scores": scores,
        "names": names,
    }
    return markers

### Use a t-test on scvi_clusters like in the ScanPy tutorial

In [None]:
n_genes = 20
sc.tl.rank_genes_groups(
    adata,
    "louvain_scvi",
    method="t-test",
    use_raw=False,
    key_added="rank_genes_groups_ttest",
    n_genes=n_genes,
)
sc.tl.rank_genes_groups(
    adata,
    "louvain_scvi",
    method="wilcoxon",
    use_raw=False,
    key_added="rank_genes_groups_wilcox",
    n_genes=n_genes,
)
sc.pl.rank_genes_groups(
    adata, key="rank_genes_groups_ttest", sharey=False, n_genes=n_genes
)
sc.pl.rank_genes_groups(
    adata, key="rank_genes_groups_wilcox", sharey=False, n_genes=n_genes
)

### Use differential expression from the scVI posterior

In [None]:
rank_genes_groups_bayes(
    adata, scvi_posterior, label_name="louvain_scvi", n_genes=n_genes
)
sc.pl.rank_genes_groups(
    adata, key="rank_genes_groups_scvi", sharey=False, n_genes=n_genes
)

### Measure similarity between *scVI differential expression*, *t-test* and *wilcoxon-test*

In [None]:
# We compute the rank of every gene to perform analysis after
all_genes = len(adata.var_names)

sc.tl.rank_genes_groups(adata, 'louvain_scvi', method='t-test',   use_raw=False, key_added='rank_genes_groups_ttest',  n_genes=all_genes)
sc.tl.rank_genes_groups(adata, 'louvain_scvi', method='wilcoxon', use_raw=False, key_added='rank_genes_groups_wilcox', n_genes=all_genes)
differential_expression = rank_genes_groups_bayes(adata, scvi_posterior, label_name='louvain_scvi', n_genes=all_genes)

In [None]:
def ratio(A, B):
    """Return the percentage of distinct elements of ``A`` that also occur in ``B``.

    Both inputs are de-duplicated first. An empty ``A`` yields ``0.0`` rather
    than raising ``ZeroDivisionError``.
    """
    A, B = set(A), set(B)
    if not A:
        return 0.0
    return len(A & B) / len(A) * 100

In [None]:
cluster_distrib = adata.obs.groupby("louvain_scvi").count()["n_genes"]

For each cluster, we compute the percentage of genes that appear in the top `n_genes` genes of both rankings being compared: t-test vs. Wilcoxon, and t-test vs. scVI's differential expression.

In [None]:
n_genes = 25

sc.pl.umap(adata, color=["louvain_scvi"], ncols=1)
for c in cluster_distrib.index:
    print(
        "Cluster %s (%d cells): t-test / wilcox %6.2f %%  & t-test / scvi %6.2f %%"
        % (
            c,
            cluster_distrib[c],
            ratio(
                adata.uns["rank_genes_groups_ttest"]["names"][c][:n_genes],
                adata.uns["rank_genes_groups_wilcox"]["names"][c][:n_genes],
            ),
            ratio(
                adata.uns["rank_genes_groups_ttest"]["names"][c][:n_genes],
                adata.uns["rank_genes_groups_scvi"]["names"][c][:n_genes],
            ),
        )
    )

## Plot px_scale for most expressed genes and less expressed genes by cluster

Sample the scale for all the data (all genes, cells), average on multiple samples

``` python
scale = scvi_posterior.get_sample_scale()
for _ in range(9):
    scale += scvi_posterior.get_sample_scale()
scale /= 10

for gene, gene_scale in zip(adata.var.index, np.squeeze(scale).T):
    adata.obs["scale_" + gene] = gene_scale
    
```

This is not tractable for large datasets, so we provide another function below that processes the cells in batches.

### The code below doesn't work

In [None]:
from typing import List


def get_scales_per_gene(
    gene_names: List[str],
    adata: "sc.AnnData",
    scvi_posterior: "scvi.inference.Posterior",
    n_samples: int = 10,
    batchsize: int = 32,
    device: str = None,
):
    """Get imputed values for each gene in gene_names - for each cell in adata. Performed inplace.
    Scales are added in adata.obs under the alias 'scale_' + gene_name.

    This function handles very large dataset thanks to batch size control

    Args:
        gene_names: list of gene names; each must be present in
            ``scvi_posterior.gene_dataset.gene_names``
        adata: scRNAseq dataset
        scvi_posterior: scVI Posterior object
        n_samples: number of samples to average on
        batchsize: for computation: number of cells to query in each iteration
        device: torch device the input batches are created on. Defaults to
            "cuda" when available, otherwise "cpu"; it must match the device
            the model lives on.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    all_gene_names = list(scvi_posterior.gene_dataset.gene_names)
    gene_idx = [all_gene_names.index(g) for g in gene_names]
    n_cells = scvi_posterior.gene_dataset.X.shape[0]
    px_scales = np.zeros((len(gene_names), n_samples, n_cells))

    # Step through the cells batch by batch; the last batch may be shorter, so
    # the cells beyond the last full batch are no longer left at zero.
    for start in range(0, n_cells, batchsize):
        current_slice = slice(start, min(start + batchsize, n_cells))
        # NOTE(review): assumes gene_dataset.X is a dense array -- confirm for sparse inputs
        x = torch.tensor(scvi_posterior.gene_dataset.X[current_slice], device=device)
        px_scales_batch = scvi_posterior.model.get_sample_scale(x, n_samples=n_samples)
        # The transpose implies a (n_samples, batch, n_genes) output, reordered
        # here to (selected genes, n_samples, batch)
        px_scales[:, :, current_slice] = np.transpose(
            px_scales_batch.detach().to("cpu").numpy()[:, :, gene_idx], (2, 0, 1)
        )

    # Average over the samples axis and store one column per requested gene
    for name, scales in zip(gene_names, px_scales):
        adata.obs["scale_" + name] = scales.mean(axis=0)

#### Most differentially expressed genes

In [None]:
cluster_id = 2
n_best_genes = 10
gene_names = differential_expression[
    differential_expression["clusters"] == cluster_id
].index.tolist()[:n_best_genes]
gene_names

In [None]:
get_scales_per_gene(gene_names, adata, scvi_posterior)

In [None]:
print("Top genes for cluster %d" % cluster_id)
sc.pl.umap(adata, color=["louvain_scvi"] + ["scale_" + g for g in gene_names], ncols=3)

#### Least differentially expressed genes

In [None]:
cluster_id = 2
n_best_genes = 10
gene_names = differential_expression[
    differential_expression["clusters"] == cluster_id
].index.tolist()[-n_best_genes:]
gene_names

In [None]:
get_scales_per_gene(gene_names, adata, scvi_posterior)

In [None]:
print("Top down regulated genes for cluster %d" % cluster_id)
sc.pl.umap(adata, color=["louvain_scvi"] + ["scale_" + g for g in gene_names], ncols=3)

### Analyze ranking difference between **t-test** and **scVI**

In [None]:
cluster_id = if_not_test_else("2", "0")

In [None]:
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt


def plot_ranking(method_1, method_2):
    """Scatter-plot each gene's rank under two differential expression methods.

    Reads the notebook-level ``adata``, ``cluster_id`` and ``all_genes``. The
    rankings are taken from ``adata.uns["rank_genes_groups_" + method]``.

    Args:
        method_1: suffix of the ranking shown on the x axis (e.g. "scvi")
        method_2: suffix of the ranking shown on the y axis (e.g. "ttest")
    """
    mapping = defaultdict(list)

    # For every gene collect [rank under method_1, rank under method_2].
    # Unpacking below requires every gene to appear in both rankings.
    for method in (method_1, method_2):
        ranked_genes = adata.uns["rank_genes_groups_" + method]["names"][cluster_id]
        for rank, gene in enumerate(ranked_genes):
            mapping[gene].append(rank)

    x, y = np.array(list(mapping.values())).T

    n_genes = all_genes

    plt.figure(figsize=(8, 8))
    # x and y are passed by keyword: seaborn >= 0.12 rejects positional data vectors
    sns.scatterplot(x=x, y=y, s=10)
    # Red guide lines at rank 100 on both axes
    plt.axhline(100, c="red")
    plt.axvline(100, c="red")
    plt.xlim(0, n_genes)
    plt.ylim(0, n_genes)
    plt.xlabel(method_1 + " ranking")
    plt.ylabel(method_2 + " ranking")

In [None]:
plot_ranking("scvi", "ttest")
plot_ranking("wilcox", "ttest")

### Investigating discrepancies

Cluster 4 top genes of t-test and scvi are totally different, but when we look closer at the data one can notice:
- The Bayes factor (or t-test score) are all very low for the cluster (no genes are significant)
- Plots confirm the latter point: the top genes are not specific to the cluster and are either noise or overlapping with other clusters

Specifically, we plot first the expression levels of genes selected by scVI, then of genes selected by the t-test. In both cases, genes seem irrelevant.

In [None]:
n_genes = 10

cluster_id = if_not_test_else(2, 0)

genes = differential_expression[
    differential_expression["clusters"] == cluster_id
].index.tolist()
sc.pl.umap(
    adata,
    color=["louvain_scvi"]
    + adata.uns["rank_genes_groups_scvi"]["names"][str(cluster_id)].tolist()[:n_genes],
    ncols=3,
)
sc.pl.umap(
    adata,
    color=["louvain_scvi"]
    + adata.uns["rank_genes_groups_ttest"]["names"][str(cluster_id)].tolist()[:n_genes],
    ncols=3,
)

scVI tends to select genes that are not expressed outside the cluster, whereas the t-test tends to select genes that are highly expressed in the cluster even if they are also expressed everywhere else.

### Store differential expression scores

In [None]:
def store_de_scores(
    adata: sc.AnnData, differential_expression: pd.DataFrame, save_path: str = None
):
    """Gather the scVI, t-test and Wilcoxon scores of every gene into one table.

    Args:
        adata: scRNAseq dataset whose ``uns`` holds the "rank_genes_groups_ttest"
            and "rank_genes_groups_wilcox" results
        differential_expression: Pandas Dataframe containing the bayes factor
            for all genes and clusters ("bayes1" and "clusters" columns)
        save_path: optional file path; when given, the table is written as CSV

    Returns:
        pandas.DataFrame with one row per (cluster, gene) holding the Bayes
        factor, the cluster and one score column per scanpy test.

    Raises:
        ValueError: if the cluster-0 rows of ``differential_expression`` do not
            cover every gene of ``adata``.
    """
    # The scVI table must rank every gene, otherwise the joins would be partial
    is_cluster_zero = differential_expression["clusters"] == 0
    scvi_gene_count = differential_expression[is_cluster_zero].shape[0]
    total_genes = adata.shape[1]
    if scvi_gene_count != total_genes:
        raise ValueError(
            "scvi differential expression has to have been run with n_genes=all_genes"
        )

    # Pull the scanpy test results out of the unstructured annotations
    test_types = ["ttest", "wilcox"]
    scanpy_results = []
    for test_type in test_types:
        stored = adata.uns["rank_genes_groups_" + test_type]
        scanpy_results.append((test_type, stored["scores"], stored["names"]))

    # Only the Bayes factor and the cluster label are kept from the scVI table
    bayes_table = differential_expression[["bayes1", "clusters"]]

    # Per cluster: attach one column per scanpy test, matched on gene name
    per_cluster_tables = []
    for cluster, cluster_table in bayes_table.groupby("clusters"):
        cluster_key = str(cluster)
        for test_type, test_scores, test_names in scanpy_results:
            test_column = pd.DataFrame(
                test_scores[cluster_key],
                index=test_names[cluster_key],
                columns=[test_type],
            )
            cluster_table = cluster_table.join(test_column)
        per_cluster_tables.append(cluster_table)

    combined = pd.concat(per_cluster_tables)
    if save_path:
        combined.to_csv(save_path)
    return combined

In [None]:
de_table = store_de_scores(adata, differential_expression, save_path=None)
de_table.head()

# Running other ScanPy algorithms is easy, binding the index keys

### PAGA

In [None]:
sc.tl.paga(adata, groups="louvain_scvi")
sc.pl.paga(adata)

###  HeatMap

In [None]:
marker_genes = gene_names[1:10]

In [None]:
# sc.pl.heatmap(adata, marker_genes, groupby="louvain_scvi", dendrogram=True)