## Notebook to cluster and transfer cell-type labels for the replication data along with the discovery data and other public human brain data

In [None]:
# shell out to `date` so the cell output shows when this run was executed
!date

#### import libraries

In [None]:
from pandas import read_csv, DataFrame
import scanpy as sc
import scvi
import torch
from matplotlib.pyplot import rc_context
import matplotlib.pyplot as plt
from seaborn import barplot

# seed Python's built-in random module
import random
random.seed(42)

# silence every warning category for the rest of the session
# NOTE(review): this also hides deprecation warnings from scanpy/scvi/matplotlib
import warnings
warnings.filterwarnings('ignore')

# set the scvi-tools global seed setting and torch's float32 matmul precision mode
scvi.settings.seed = 42
torch.set_float32_matmul_precision('high')
# log the scvi-tools version used for this run
print(f'Last run with scvi-tools version: {scvi.__version__}')

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : 'w'}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# naming: `set_name` prefixes the input/output file names below,
# `project` prefixes the figure file names used by the plotting cells
project = 'aging_phase1'
set_name = f'{project}_replication'

# directories (absolute paths; adjust `wrk_dir` when running elsewhere)
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase1'
replication_dir = f'{wrk_dir}/replication'
figures_dir = f'{wrk_dir}/figures'
# scanpy writes figures requested via `save=` under this directory
sc.settings.figdir = f'{figures_dir}/'

# in files: AnnData read by the load cell
raw_anndata_file = f'{replication_dir}/{set_name}.raw.h5ad'

# out files: trained scVI model location and the processed AnnData path
trained_model_path = f'{replication_dir}/{set_name}_trained_scvi'
out_anndata_file = f'{replication_dir}/{set_name}.scvi.h5ad'

# variables
# DEBUG: when True, cells display random samples of obs/var for inspection
DEBUG = True
# HVF_PERCENT: fraction of the post-filter genes kept as highly variable features
HVF_PERCENT = 0.15
# MAX_MITO_PERCENT: cells with pct_counts_mt at or above this value are dropped
MAX_MITO_PERCENT = 10
# keys for the scVI latent embedding (adata.obsm) and cluster labels
SCVI_LATENT_KEY = 'X_scVI'
SCVI_CLUSTERS_KEY = 'leiden_scVI'

### load data

In [None]:
# read the raw AnnData object and print its summary
adata = sc.read(raw_anndata_file)
print(adata)
if DEBUG:
    # peek at five random rows of the cell (obs) and then gene (var) annotations
    for annotation_table in (adata.obs, adata.var):
        display(annotation_table.sample(5))

### simple filters and prep for SCVI

In [None]:
%%time
# annotate the group of mitochondrial genes as 'mt'
adata.var['mt'] = adata.var_names.str.startswith('MT-')  
# With pp.calculate_qc_metrics, we can compute many metrics very efficiently.
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, 
                           log1p=False, inplace=True)
adata = adata[adata.obs.pct_counts_mt < MAX_MITO_PERCENT, :]
sc.pp.filter_genes(adata, min_counts=3)
adata.layers['counts'] = adata.X.copy()  # preserve counts
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata  # freeze the state in `.raw`
top_gene_count = adata.var.shape[0] * HVF_PERCENT
sc.pp.highly_variable_genes(adata, n_top_genes=top_gene_count, subset=True, 
                            layer='counts', flavor='seurat_v3', batch_key='Study')

In [None]:
# summarize the AnnData object as it stands after the filtering/prep cell
print(adata)
if DEBUG:
    # peek at five random rows of the cell (obs) and then gene (var) annotations
    for annotation_table in (adata.obs, adata.var):
        display(annotation_table.sample(5))

### setup the SCVI model

In [None]:
# register the AnnData with scVI: model the 'counts' layer and pass the
# listed obs columns as categorical / continuous covariates
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    categorical_covariate_keys=[
        'Sample_ID',
        'Sex',
        'Study',
        'Batch',
    ],
    continuous_covariate_keys=[
        'pct_counts_mt',
    ],
)

### create and train the model

In [None]:
# build an scVI model on the registered AnnData using the library's default arguments
model = scvi.model.SCVI(adata)

In [None]:
# bare expression: show the model's repr as the cell output
model

In [None]:
# train with the library's default training arguments
model.train()

### save the model and reload it

In [None]:
# persist the trained model; overwrite=True replaces any earlier save at this path
model.save(trained_model_path, overwrite=True)

In [None]:
# reload the saved model against the in-memory AnnData
# NOTE(review): `use_gpu` appears to be deprecated in newer scvi-tools releases
# in favor of `accelerator`/`device` -- confirm against the installed version
model = scvi.model.SCVI.load(trained_model_path, adata=adata, use_gpu=True)

### Extracting and visualizing the latent space
We can now call the model's `get_latent_representation` method to obtain the latent space from the trained model, and then visualize it using scanpy functions:

In [None]:
# store the model's latent representation in obsm under SCVI_LATENT_KEY
adata.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()

#### embed the graph based on latent representation

In [None]:
%%time
# build the neighbor graph from the scVI latent representation stored in obsm
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
# compute the UMAP embedding; the plotting cells below draw it
sc.tl.umap(adata)

#### visualize the latent representation

In [None]:
# UMAP colored by study, saved through scanpy under sc.settings.figdir
figure_file = f'{project}.umap.study.png'
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    # NOTE(review): newer matplotlib releases renamed the seaborn style sheets
    # (e.g. 'seaborn-v0_8-bright') -- confirm this name exists in the installed version
    plt.style.use('seaborn-bright')
    # plot `adata`, the object the embedding was computed on; the previous
    # `adata_mvi` is never defined in this notebook and raised NameError
    sc.pl.umap(adata, color=['Study'], save=figure_file)

In [None]:
# UMAP colored by brain region, saved through scanpy under sc.settings.figdir
figure_file = f'{project}.umap.region.png'
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    # NOTE(review): newer matplotlib releases renamed the seaborn style sheets
    # (e.g. 'seaborn-v0_8-bright') -- confirm this name exists in the installed version
    plt.style.use('seaborn-bright')
    # plot `adata`, the object the embedding was computed on; the previous
    # `adata_mvi` is never defined in this notebook and raised NameError
    sc.pl.umap(adata, color=['Brain_region'], save=figure_file)

In [None]:
# UMAP colored by cell type, saved through scanpy under sc.settings.figdir
figure_file = f'{project}.umap.celltype.png'
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    # NOTE(review): newer matplotlib releases renamed the seaborn style sheets
    # (e.g. 'seaborn-v0_8-bright') -- confirm this name exists in the installed version
    plt.style.use('seaborn-bright')
    # plot `adata`, the object the embedding was computed on; the previous
    # `adata_mvi` is never defined in this notebook and raised NameError
    sc.pl.umap(adata, color=['Cell_type'], save=figure_file)