# totalVI experiments: reference mapping with scArches / scvi-tools

scArches totalVI surgery pipeline tutorial: https://docs.scarches.org/en/latest/totalvi_surgery_pipeline.html

scvi-tools totalVI reference mapping tutorial (query cell-type prediction section): https://docs.scvi-tools.org/en/1.0.1/tutorials/notebooks/totalVI_reference_mapping.html#query-cell-type-prediction


In [1]:

import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import scanpy as sc
import anndata
import torch
import scarches as sca
import matplotlib.pyplot as plt
import numpy as np
import scvi as scv
import pandas as pd
import time

# Configure plotting and tensor printing once.
# scanpy's set_figure_params resets every argument that is not passed back to
# its default (dpi=80, frameon=True, ...). Calling it three times in a row, as
# before, left the final state at dpi 80 with frames on, silently undoing the
# first two calls. One call with all settings keeps dpi=200, no frames and a
# 4x4 figure size.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

# Compact tensor printing: 3 decimals, no scientific notation, and 7 items
# shown at each edge of summarized tensors.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)



(Truncated import output) … captum (see https://github.com/pytorch/captum).


## Data loading and preprocessing

In [None]:
# Metadata conventions for this dataset. 'orig.ident' holds the per-sample id
# (the cells below split on it); 'seurat_clusters' presumably holds Seurat
# cluster labels; 3228 is the sample used as the query.
condition_key, cell_type_key = 'orig.ident', 'seurat_clusters'
target_conditions = [3228]

# Load the object and continue from its ``.raw`` layer as a standalone AnnData
# (presumably the unnormalized counts, which seurat_v3 HVG selection expects
# -- TODO confirm).
adata_all = sc.read('/Users/evelynschmidt/Protein_folzconversion_fig5.h5ad')
adata = adata_all.raw.to_adata()

In [None]:
# One independent AnnData copy per sample id in ``orig.ident``.
adata_3228, adata_730, adata_451 = (
    adata[adata.obs['orig.ident'].isin([sample_id])].copy()
    for sample_id in (3228, 730, 451)
)

# Tag every subset with a string "batch" label; later cells use "batch" as
# the batch key.
for _part, _label in ((adata_3228, "3228"), (adata_730, "730"), (adata_451, "451")):
    _part.obs["batch"] = _label
del _part, _label

In [None]:
# create the reference
# Reference: samples 730 and 451, which keep their measured protein panel.
adata_ref = anndata.concat([adata_730, adata_451])

# Query: sample 3228. Its protein panel is treated as missing, so it is
# replaced by an all-zero table with the reference's protein columns.
# NOTE: adata_query is the same object as adata_3228 (no copy), so
# adata_3228's protein table is overwritten as well.
adata_query = adata_3228
pro_exp = adata_ref.obsm["protein_expression"]
data = np.zeros(shape=(adata_query.n_obs, len(pro_exp.columns)))
adata_query.obsm["protein_expression"] = pd.DataFrame(
    data=data,
    index=adata_query.obs_names,
    columns=pro_exp.columns,
)

In [None]:
# Display the all-zero protein table just attached to the query, to check
# its shape and that its columns match the reference panel.
adata_query.obsm["protein_expression"]

Unnamed: 0,CD38ADT,CD314ADT,HLA-DRADT,CD62LADT,CD45ROADT,CD337ADT,CD56ADT,CD335ADT,CD57ADT,CD45RAADT,...,IgG2aADT,IgG2bADT,KIR2DL1-S1-S3-S5ADT,KIR2DL2-3ADT,KIR3DL1ADT,KIR2DL5ADT,CD94ADT,NKG2CADT,PD-1ADT,CD8ADT
3228_AAAGTAGAGCTACCGC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_AAATGCCCAGATGGGT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_AAGACCTCATTCGACA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_ACACTGATCAGGTTCA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_ACGAGGAGTTACGCGC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3228_TTTGTCATCTTGTCAT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_ACTGAGTGTTACGACT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_CAGAATCAGCACCGTC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3228_GCGCGATTCACGGTTA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
# Pool reference and query so that gene selection sees every sample at once.
adata_full = anndata.concat([adata_ref, adata_query])

# Keep the 4000 most variable genes (seurat_v3 flavour, ranked with the
# "batch" column as batch key); subset=True trims adata_full in place.
sc.pp.highly_variable_genes(
    adata_full,
    flavor="seurat_v3",
    batch_key="batch",
    n_top_genes=4000,
    subset=True,
)

# Split back into reference (samples 451 and 730) and query (3228), now
# restricted to the shared gene set.
adata_ref = adata_full[adata_full.obs["batch"].isin(["451", "730"])].copy()
adata_query = adata_full[adata_full.obs["batch"] == "3228"].copy()

In [None]:
# Optional: train the reference model through scArches' TOTALVI wrapper,
# following the scArches totalVI surgery tutorial.
# This code used to sit inside a bare string literal. That disabled it, and
# its `TF_CPP_MIN_LOG_LEVEL=0` line would only have bound a Python variable,
# not an environment variable. It is now real code behind an explicit switch.
# The default (False) keeps the previous behaviour of training nothing here.
RUN_SCARCHES_REFERENCE = False

if RUN_SCARCHES_REFERENCE:
    # Environment variables must be set through os.environ and must be
    # strings. (Only relevant if TensorFlow is loaded in this process.)
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    # Register the reference with its batch column and protein table.
    sca.models.TOTALVI.setup_anndata(
        adata_ref,
        batch_key="batch",
        protein_expression_obsm_key="protein_expression",
    )

    # Layer norm instead of batch norm, as in the scArches surgery setup.
    arches_params = dict(
        use_layer_norm="both",
        use_batch_norm="none",
    )

    vae_ref = sca.models.TOTALVI(
        adata_ref,
        **arches_params,
    )
    vae_ref.train()

In [None]:
def run_totalvi_default(
    ref,
    query,
    *,
    batch_key="orig.ident",
    protein_expression_obsm_key="protein_expression",
    max_epochs=250,
):
    """Run "offline" totalVI: one model trained on reference + query together.

    This is the de novo baseline that reference mapping is compared against.
    A single TOTALVI model is fit to the concatenation of ``ref`` and
    ``query``. The query's cell types are then predicted from the shared
    latent space.

    Parameters
    ----------
    ref, query : anndata.AnnData
        Reference and query objects. Both must carry ``obs[batch_key]`` and
        ``obsm[protein_expression_obsm_key]``. They are modified in place:
        each gains ``obsm["X_totalvi_default"]``, and ``query.obs`` gains
        ``"predicted_l2_default"`` and ``"celltype.l2"``.
    batch_key : str
        ``obs`` column used as the batch covariate.
    protein_expression_obsm_key : str
        ``obsm`` key holding the protein table. It defaults to
        ``"protein_expression"``, the key this notebook populates.
    max_epochs : int
        Number of training epochs.

    Returns
    -------
    tuple
        ``(vae, adata_full_new)``: the trained ``scvi.model.TOTALVI``, and
        the concatenated AnnData carrying ``obsm["X_totalvi_default"]`` plus
        its neighbour graph and UMAP.
    """
    # The module imports scvi-tools as ``scv``. Bind the full name locally so
    # the ``scvi.`` references below resolve instead of raising NameError.
    import scvi

    adata_full_new = anndata.concat([ref, query])

    # scvi-tools >= 0.15 removed scvi.data.setup_anndata; data registration
    # now goes through the model class.
    scvi.model.TOTALVI.setup_anndata(
        adata_full_new,
        batch_key=batch_key,
        protein_expression_obsm_key=protein_expression_obsm_key,
    )

    # Same architecture choices as the scArches reference model, with
    # 2-layer encoder and decoder.
    arches_params = dict(
        use_layer_norm="both",
        use_batch_norm="none",
        n_layers_decoder=2,
        n_layers_encoder=2,
    )

    start = time.perf_counter()
    vae = scvi.model.TOTALVI(adata_full_new, **arches_params)
    vae.train(max_epochs=max_epochs, batch_size=256, lr=4e-3)
    end = time.perf_counter()
    print("\n Total default train time: {}".format(end - start))

    latent = vae.get_latent_representation()
    adata_full_new.obsm["X_totalvi_default"] = latent

    # Skip the first 10 epochs so early large values don't flatten the curve.
    # Assumes training used a validation split (so "elbo_validation" is
    # recorded) -- TODO confirm for the installed scvi-tools version.
    plt.plot(vae.history["elbo_validation"][10:], label="validation")
    plt.title("Negative ELBO over training epochs")
    plt.legend()

    # anndata.concat stacks observations in input order: the first
    # ref.n_obs latent rows belong to ref and the remainder to query.
    # Slicing avoids re-encoding ref/query, which were never registered
    # with this model; re-encoding them would trigger an in-place setup
    # transfer on those objects.
    ref.obsm["X_totalvi_default"] = latent[: ref.n_obs].copy()
    query.obsm["X_totalvi_default"] = latent[ref.n_obs:].copy()

    # Predict query cell types from the shared latent space.
    # NOTE(review): classify_from_latent is not defined in the visible part
    # of this notebook; it must be defined before this function is called.
    query.obs["predicted_l2_default"] = classify_from_latent(
        ref, query, ref_obsm_key="X_totalvi_default"
    )
    query.obs["celltype.l2"] = query.obs["predicted_l2_default"]

    print("Computing full umap")
    sc.pp.neighbors(adata_full_new, use_rep="X_totalvi_default", metric="cosine")
    sc.tl.umap(adata_full_new, min_dist=0.3)

    return vae, adata_full_new