Install Dependencies

In [None]:
!pip install --quiet anndata torch scrublet scanpy gdown bbknn scikit-misc scib-metrics leidenalg lightning ml_collections docrep mudata pyro-ppl numpyro sparse

Install an editable (development) version of scvi-tools from a local clone

In [None]:
%cd "PATH-TO-scvi-tools"
!ls
!pip install -e .

Import the necessary libraries

In [None]:
import sys
import os
import warnings
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet as scr
import scvi
from scib_metrics.benchmark import Benchmarker, BioConservation
import torch

Set up configuration for performance and visualization

In [None]:
sc.set_figure_params(figsize=(6, 6))
torch.set_float32_matmul_precision("high")
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

In [None]:
# Load the Smart-seq2 and 10x datasets.
adata_ss2 = sc.read_h5ad(".../neftel_ss2.h5ad")
adata_10x = sc.read_h5ad(".../neftel_10x.h5ad")

# Tag every cell with the sequencing technology of the dataset it came from.
for tech_label, dataset in (("SS2", adata_ss2), ("10X", adata_10x)):
    dataset.obs['tech'] = tech_label

# Stack the cells of both datasets into one AnnData (outer join keeps the union of genes).
adata = sc.concat([adata_ss2, adata_10x], axis=0, join='outer')
adata

Pre-processing

In [None]:
# Remove actively cycling cells (S / G2M phase); cells in any other phase are kept.
cycling_phases = ["S", "G2M"]
non_cycling_mask = ~adata.obs["phase"].isin(cycling_phases)
adata = adata[non_cycling_mask].copy()

# Compute QC metrics explicitly from the raw-counts layer rather than from
# adata.X, whose contents at this point depend on how the input files were saved.
sc.pp.calculate_qc_metrics(adata, layer="counts", inplace=True)

# Log-normalize the raw counts into X (used for PCA / neighbors / UMAP below)
# and keep a copy of the log-normalized values in their own layer.
adata.X = adata.layers["counts"].copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
adata.layers["logcounts"] = adata.X.copy()

# Select 5,000 highly variable genes with the seurat_v3 flavor, which expects
# raw counts (hence layer="counts"), using "sample" as the batch key.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=5000,
    layer="counts",
    subset=False,
    flavor="seurat_v3",
    batch_key="sample",
)
sc.pl.highly_variable_genes(adata, log=True)

# Keep only the highly variable genes.
adata = adata[:, adata.var["highly_variable"]].copy()

# Recompute each cell's library size on the HVG subset. `.sum(axis=1)` works for
# dense arrays and scipy sparse matrices alike; np.asarray(...).ravel() flattens
# the (n_cells, 1) matrix returned for sparse input into a 1-D vector.
adata.obs["total_counts"] = np.asarray(adata.layers["counts"].sum(axis=1)).ravel()

# Dimensionality reduction and embedding on the log-normalized HVG matrix.
sc.tl.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

# UMAP colored by batch (sample), cell type, library size and sequencing technology.
sc.pl.umap(
    adata,
    color=["sample", "celltype", "total_counts", "tech"],
    title=["Batch", "Cell Type", "Library Size", "Sequencing Method"],
    legend_loc="right margin",
    wspace=0.6,
    cmap="viridis",
    ncols=2,
)



Set up the AnnData object in the format scVI expects


In [None]:
# Register the AnnData with scVI: train on the raw counts layer and treat "sample"
# as the batch variable, matching the batch_key used for HVG selection and the
# "Batch" panel of the UMAP. NOTE(review): no "batch" column is created in this
# notebook; change the key if a different batch column is intended.
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="sample")

Train the model with random gene-swapping augmentation and save it.


In [None]:
# Output directory for the trained model. TODO: change to the desired location.
model_dir = "scvi_random_swap_model"

# "random_swap" is one of the augmentations provided by the editable scvi-tools
# build installed above (see the Augmentation class in scvi/module/_vae.py).
model = scvi.model.SCVI(adata=adata, augmentation_to_apply=["random_swap"])
# Run validation after every training epoch.
model.train(check_val_every_n_epoch=1)
model.save(model_dir, overwrite=True)

For the full list of available augmentations, see `scvi/module/_vae.py` in the scvi-tools source tree; the augmentations are defined in the `Augmentation` class.