In [1]:
# Notebook parameters. The defaults are empty placeholders -- presumably they are
# injected by whatever pipeline runs this notebook (TODO confirm).
input_data_path = ""  # path handed to sc.read() to load the AnnData object
batch = ""  # name of the adata.obs column used as the batch/condition key
label = ""  # name of the adata.obs column copied onto the latent objects and used for plot colouring

Jupyter notebook for the trVAE integration algorithm

Author: Erno Hänninen

Title: run_trvae.ipynb

Created: 2022-12-16

In [None]:
import sys
sys.path.insert(0, "../../../../Scripts") #Adding a path to be able to import the jupyter_functions
from jupyter_functions import *
import scib
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
import pandas

In [None]:
# Load the AnnData object from disk and collect the names of all of its
# variables (genes) into a plain Python list.
# NOTE(review): the list is called hvgList, so the input is presumably already
# restricted to highly variable genes -- nothing here filters them; confirm upstream.
adata = sc.read(input_data_path)
hvgList = adata.var_names.tolist()
print(len(hvgList))

The trVAE algorithm requires randomized batches of data, together with their condition labels, as input for training.

During prediction, the batches from the source conditions are transformed to the target condition.

In [None]:
# Get the target condition: the most frequent value of adata.obs[batch].
# value_counts().idxmax() returns a single label, so it is wrapped in a
# one-element list. Calling list() on the label itself would split a string
# into its characters (or raise TypeError for a non-iterable label such as an
# int), and the later `.isin(target_condition)` filters would then match no
# cells at all.
target_condition = [adata.obs[batch].value_counts().idxmax()]

In [None]:
# Epoch budgets: trvae_epochs is used for training the reference model,
# surgery_epochs for fine-tuning on the query data.
trvae_epochs = surgery_epochs = 50

# Early-stopping settings, passed unchanged as `early_stopping_kwargs` to both
# train() calls further down.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [None]:
# Prepare the data for trVAE and split it into a reference (source) and a
# query (target) dataset.

# Restoring the raw matrix is currently disabled:
# adata = adata.raw.to_adata()

# If adata.X is a sparse matrix, convert it into a regular (dense) matrix.
adata = remove_sparsity(adata)

# Cells whose batch is one of the target conditions form the query set;
# every other cell goes to the reference set.
is_target = adata.obs[batch].isin(target_condition)
source_adata = adata[~is_target]
target_adata = adata[is_target]

# Distinct batch values present in the reference set.
source_batches = source_adata.obs[batch].unique()
source_conditions = source_batches.tolist()

In [None]:
# Display the reference (source) subset for a quick visual check.
source_adata

In [None]:
# Display the query (target) subset for a quick visual check.
target_adata

In [None]:
# Show which batches ended up on the source side and which on the target side
# (one list per line).
print(source_conditions, target_condition, sep="\n")

In [None]:
# Build the TRVAE reference model on the source data, conditioned on the
# batch column and restricted to the batches present in the source set.
trvae_model_kwargs = dict(
    condition_key=batch,
    conditions=source_conditions,
    hidden_layer_sizes=[128, 128],
)
trvae = sca.models.TRVAE(source_adata, **trvae_model_kwargs)

In [None]:
# Train trVAE on the reference dataset (source_adata) with a fixed seed.
# NOTE(review): alpha_epoch_anneal (200) is larger than n_epochs (50) --
# confirm this is intended.
reference_train_kwargs = dict(
    n_epochs=trvae_epochs,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    seed=42,
)
trvae.train(**reference_train_kwargs)

In [None]:
# Wrap the reference model's latent representation in an AnnData object and
# copy the label and batch annotations over from source_adata.
reference_latent_values = trvae.get_latent()
adata_latent = sc.AnnData(reference_latent_values)
for obs_key in (label, batch):
    adata_latent.obs[obs_key] = source_adata.obs[obs_key].tolist()

In [None]:
# Neighbourhood graph -> UMAP embedding -> plot coloured by batch and label.
# (Leiden clustering is intentionally not run here.)
sc.pp.neighbors(adata_latent, n_neighbors=8)
sc.tl.umap(adata_latent)
sc.pl.umap(adata_latent, color=[batch, label], frameon=False, wspace=0.6)

In [None]:
# Display the latent AnnData built from the reference model.
adata_latent

In [None]:
# Fine-tune the reference model with the query data: derive a new model from
# the trained reference model and the target dataset.
new_trvae = sca.models.TRVAE.load_query_data(
    adata=target_adata,
    reference_model=trvae,
)

In [None]:
# Train the query model; same settings as the reference training, except for
# the surgery epoch budget and weight decay switched off.
surgery_train_kwargs = dict(
    n_epochs=surgery_epochs,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0,
    seed=42,
)
new_trvae.train(**surgery_train_kwargs)

In [None]:
# Latent representation from the fine-tuned model, annotated with the label
# and batch columns of target_adata. Note: this rebinds `adata_latent`,
# replacing the reference-model latent object created earlier.
query_latent_values = new_trvae.get_latent()
adata_latent = sc.AnnData(query_latent_values)
for obs_key in (label, batch):
    adata_latent.obs[obs_key] = target_adata.obs[obs_key].tolist()

In [None]:
# Display the latent AnnData built from the fine-tuned (query) model.
adata_latent

In [None]:
# UMAP of the query latent space, coloured by batch and label.
# (Leiden clustering is intentionally not run here.)
query_umap_plot_kwargs = dict(color=[batch, label], frameon=False, wspace=0.6)

sc.pp.neighbors(adata_latent, n_neighbors=8)
sc.tl.umap(adata_latent)
sc.pl.umap(adata_latent, **query_umap_plot_kwargs)

In [None]:
# Latent representation of the complete dataset (reference + query), obtained
# by passing the expression matrix and batch labels of `adata` through the
# fine-tuned model; label and batch annotations are copied over afterwards.
full_latent_values = new_trvae.get_latent(adata.X, adata.obs[batch])
full_latent = sc.AnnData(full_latent_values)
for obs_key in (label, batch):
    full_latent.obs[obs_key] = adata.obs[obs_key].tolist()

In [None]:
# PCA on the joint latent space (the "X_pca" key is what gets passed to
# compute_metrics later), followed by neighbours, UMAP and the plot.
# (Leiden clustering is intentionally not run here.)
sc.tl.pca(full_latent)
sc.pp.neighbors(full_latent, n_neighbors=8)
sc.tl.umap(full_latent)

full_umap_plot_kwargs = dict(color=[batch, label], frameon=False, wspace=0.6)
sc.pl.umap(full_latent, **full_umap_plot_kwargs)

In [None]:
# Display the joint (reference + query) latent AnnData.
full_latent

In [None]:
# Persist the integrated latent AnnData as an .h5ad file.
integrated_output_path = "../Integrated_adata/trvae_tl_adata_final.h5ad"
full_latent.write(integrated_output_path)


In [None]:
# Compute the integration metrics with compute_metrics (presumably provided by
# the star-import of jupyter_functions -- TODO confirm) and print the result.
# Arguments, in order: method name, unintegrated adata, integrated latent
# adata, batch key, label key, embedding key, and the string "full".
metrics_args = ("trvae_tl", adata, full_latent, batch, label, "X_pca", "full")
df = compute_metrics(*metrics_args)

print(df)