## Notebook to assign cell types to single-cell GEX data using [CellAssign from scvi-tools](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/cellassign_tutorial.html)


In [None]:
# record when this run started (the `!date` in the final cell records when it finished)
!date

#### import libraries

In [None]:
import scanpy as sc
import scvi
import torch
from scvi.external import CellAssign
from pandas import DataFrame
from json import load as json_load
from numpy import zeros
from seaborn import clustermap, scatterplot
from matplotlib.pyplot import rc_context
import matplotlib.pyplot as plt

# fixed global seed so scvi-tools training is repeatable between runs
scvi.settings.seed = 0
print('Last run with scvi-tools version:', scvi.__version__)

# default scanpy figure size for plots in this notebook
sc.set_figure_params(figsize=(4, 4))
# allow faster, slightly lower-precision float32 matmuls (e.g. TF32 on supported GPUs)
torch.set_float32_matmul_precision('high')

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# parameters
# marker_set picks the marker-gene JSON to load and is embedded in every output filename;
# the empty default must be overridden (presumably injected by papermill -- TODO confirm)
marker_set = '' # 'sctypes', 'pangloadb', or 'bakken'

In [None]:
# naming
project = 'aging_phase2'

# fail fast: every in/out path below embeds marker_set, so an empty value would point at
# non-existent inputs (e.g. 'adrd_markers__SCRN.json') and produce ambiguous output names
if not marker_set:
    raise ValueError('marker_set is empty; set it in the parameters cell before running')

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
public_dir = f'{wrk_dir}/public'
model_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
sc.settings.figdir = f'{figures_dir}/'  # where scanpy writes plots saved with `save=`

# in files
cell_markers_file = f'{public_dir}/adrd_markers_{marker_set}_SCRN.json'
anndata_file = f'{quants_dir}/{project}_GEX.raw.h5ad'

# out files
out_file = f'{quants_dir}/{project}_GEX.{marker_set}.cellassign.h5ad'
pred_file = f'{quants_dir}/{project}_GEX.{marker_set}.cellassign.predictions.csv'

# variables
DEBUG = False  # when True, cells display extra intermediate tables
SCVI_LATENT_KEY = 'X_scVI'  # .obsm key for the scVI latent representation
SCVI_CLUSTERS_KEY = 'leiden_scVI'  # .obs key for Leiden clusters on the scVI graph
LEIDEN_RESOLUTION = 1.0

### load data

#### load cell-type markers and format for CellAssign

In [None]:
# cell_markers maps each cell-type name to a list of marker gene symbols
with open(cell_markers_file, 'r') as in_file:
    cell_markers = json_load(in_file)
# unique marker genes across all cell types, sorted so the order (and therefore the row
# order of the marker matrix built in the next cell) is identical on every run; iterating a
# set of strings is not, because Python randomises string hashing per process
marker_list = sorted({gene for genes in cell_markers.values() for gene in genes})
celltypes_list = list(cell_markers.keys())
print(f'number of markers: {len(marker_list)}')
print(f'number of cell types {len(celltypes_list)}')
if DEBUG:
    print(marker_list)
    print(celltypes_list)

In [None]:
# binary marker matrix for CellAssign: one row per marker gene, one column per cell type,
# 1 where the gene is listed as a marker of that cell type and 0 otherwise
markers_df = DataFrame(0, index=marker_list, columns=celltypes_list)
for cell_type in celltypes_list:
    markers_df[cell_type] = markers_df.index.isin(cell_markers[cell_type]).astype(int)
print(f'shape of marker df {markers_df.shape}')
if DEBUG:
    display(markers_df.head())

In [None]:
# Disabled option: drop selected cell-type columns from the marker matrix before modelling.
# Uncomment (and edit the list) to exclude those cell types from CellAssign.
# addl_remove = ['Neuroepithelial cells', 'Cancer stem cells', 
#                'Immune system cells', 'Neuroblasts', 'Neural Progenitor cells']
# markers_df = markers_df.drop(columns=addl_remove)
# print(f'new shape of markers df {markers_df.shape}')
# if DEBUG:
#     display(markers_df.head())

#### load the single-cell GEX data

In [None]:
%%time
# load the GEX AnnData; .X is expected to hold raw counts (it is copied to layers['counts'] later)
adata = sc.read(anndata_file, cache=True)

#### drop any cell not properly demultiplexed

In [None]:
# keep only cells that demultiplexing assigned to a donor; `.copy()` materialises the subset
# so the in-place steps that follow (gene filtering, new layers / obs columns) operate on a
# real AnnData instead of a view that anndata has to implicitly copy with warnings
adata = adata[~adata.obs.donor_id.isna()].copy()
print(adata)
if DEBUG:
    display(adata.obs.head())
    display(adata.var.head())

#### convert pool info to categoricals

In [None]:
# store pool ids as a categorical of strings (so ids stored as numbers are treated as labels);
# gex_pool is used as the scVI batch key, a CellAssign covariate and a UMAP colour
adata.obs['gex_pool'] = adata.obs['gex_pool'].astype(str).astype('category')

### prep data

In [None]:
%%time
sc.pp.filter_genes(adata, min_counts=3)
adata.layers['counts'] = adata.X.copy()  # preserve counts
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata  # freeze the state in `.raw`
sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=True, layer='counts', 
                            flavor='seurat_v3')
print(adata)
if DEBUG:
    display(adata.obs.sample(10))
    display(adata.var.sample(10))    

### latent modeling

In [None]:
# register the raw-count layer with scVI; gex_pool is the batch key and sample_id an
# additional categorical covariate
scvi.model.SCVI.setup_anndata(adata, layer='counts', batch_key='gex_pool',
                              categorical_covariate_keys=['sample_id'])

In [None]:
# scVI model with default architecture; the bare `model` line displays its summary
model = scvi.model.SCVI(adata)
model

In [None]:
%%time
# train scVI with the library's default training settings
model.train()

#### Inspecting the convergence

In [None]:
# training ELBO (first epoch skipped, presumably so its large initial value does not
# flatten the curve) alongside the training reconstruction loss
train_elbo = model.history["elbo_train"][1:]
recon_elbo = model.history["reconstruction_loss_train"]

# matplotlib 3.6 renamed the bundled seaborn styles ('seaborn-talk' -> 'seaborn-v0_8-talk')
# and 3.8 removed the old names, so use whichever this install provides
talk_style = 'seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in plt.style.available else 'seaborn-talk'
with plt.style.context(talk_style), rc_context({'figure.figsize': (9, 9), 'figure.dpi': 50}):
    ax = train_elbo.plot()
    recon_elbo.plot(ax=ax)
    ax.set(xlabel='epoch', ylabel='loss', title='scVI training convergence')
    plt.show()

#### save and re-load the model

In [None]:
# persist the trained scVI model, then reload it against this adata (a round-trip that
# confirms the saved model is usable with the current object)
model.save(model_dir, prefix=f'{project}.{marker_set}.scvi', overwrite=True)
model = scvi.model.SCVI.load(model_dir, adata=adata, prefix=f'{project}.{marker_set}.scvi')

#### Obtaining model outputs
It’s often useful to store the outputs of scvi-tools back into the original anndata, as it permits interoperability with Scanpy.

In [None]:
# keep the scVI embedding on the AnnData so scanpy can build its neighbour graph from it
print(f'SCVI_LATENT_KEY is {SCVI_LATENT_KEY}')
adata.obsm[SCVI_LATENT_KEY] = latent = model.get_latent_representation()
print(f'shape of latent {latent.shape}')

### simple clustering

In [None]:
%%time
# build the kNN graph from the scVI latent space; Leiden clustering and UMAP both use it
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.leiden(adata, key_added=SCVI_CLUSTERS_KEY, resolution=LEIDEN_RESOLUTION)
sc.tl.umap(adata)

In [None]:
# matplotlib 3.6 renamed the bundled seaborn styles ('seaborn-talk' -> 'seaborn-v0_8-talk')
# and 3.8 removed the old names, so use whichever this install provides
talk_style = 'seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in plt.style.available else 'seaborn-talk'
with plt.style.context(talk_style), rc_context({'figure.figsize': (12, 9), 'figure.dpi': 100}):
    # clusters, then the two technical covariates, to check scVI's batch correction
    sc.pl.umap(adata, color=SCVI_CLUSTERS_KEY, 
               frameon=False, legend_loc='on data')
    sc.pl.umap(adata, color='gex_pool', frameon=False)
    sc.pl.umap(adata, color='sample_id', frameon=False)

### Create and fit CellAssign model
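The AnnData object and the cell-type marker matrix must contain the same genes. We first drop markers that are not present in `adata`, along with any cell type left with no markers. We then subset `adata` to the remaining marker genes in `markers_df`.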

In [None]:
# keep only marker genes that are still present in the preprocessed data
genes_present = markers_df.index.isin(adata.var.index)
markers_df = markers_df.loc[genes_present]
print(f'new shape of markers df {markers_df.shape}')
# a cell type with no remaining marker genes gives CellAssign nothing to work with
zero_cols = markers_df.columns[~markers_df.ne(0).any()]
print(f'dropping {zero_cols}')
markers_df = markers_df.drop(columns=zero_cols)
if DEBUG:
    display(markers_df.head())

In [None]:
if DEBUG:
    display(adata.var.head())
# CellAssign needs exactly the genes in markers_df, so keep only those columns of adata
bdata = adata[:, adata.var.index.isin(markers_df.index)].copy()
# every marker row was filtered to genes present in adata, so the gene counts must agree
assert bdata.n_vars == markers_df.shape[0], 'marker genes and bdata genes do not match one-to-one'
print(bdata)
if DEBUG:
    display(bdata.var.head())

### set up CellAssign (size factors and covariates)

In [None]:
from numpy import asarray, mean

# CellAssign size factors: per-cell library size from the raw counts of all genes kept in
# `adata` (not only the marker genes in `bdata`), scaled to a mean of 1.
# `.sum(1)` on a sparse counts matrix returns an (n_cells, 1) np.matrix; flatten it to a
# plain 1-D array so it can be stored as a single obs column.
lib_size = asarray(adata.layers['counts'].sum(1)).ravel()
assert lib_size.shape[0] == bdata.n_obs, 'adata and bdata must contain the same cells'
bdata.obs['size_factor'] = lib_size / mean(lib_size)
scvi.external.CellAssign.setup_anndata(bdata, size_factor_key='size_factor',
                                       layer='counts',
                                       categorical_covariate_keys=['gex_pool', 'sample_id'])

### create and train the model

In [None]:
%%time
# fit CellAssign on the marker-gene subset using markers_df (genes x cell types) as the
# marker matrix; note this rebinds `model`, replacing the scVI model
model = CellAssign(bdata, markers_df)
print(model)
model.train()

#### Inspecting the convergence

In [None]:
# CellAssign training ELBO (first epoch skipped) against the validation ELBO
train_elbo = model.history["elbo_train"][1:]
test_elbo = model.history["elbo_validation"]

# matplotlib 3.6 renamed the bundled seaborn styles ('seaborn-talk' -> 'seaborn-v0_8-talk')
# and 3.8 removed the old names, so use whichever this install provides
talk_style = 'seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in plt.style.available else 'seaborn-talk'
with plt.style.context(talk_style), rc_context({'figure.figsize': (9, 9), 'figure.dpi': 50}):
    ax = train_elbo.plot()
    test_elbo.plot(ax=ax)
    ax.set(xlabel='epoch', ylabel='ELBO', title='CellAssign training convergence')
    plt.show()

### Predict and plot assigned cell types

In [None]:
# per-cell probabilities, one column per cell type; later cells assume the rows follow
# bdata's cell order
predictions = model.predict()
print(f'shape of predictions: {predictions.shape}')
if DEBUG:
    display(predictions.head())

In [None]:
%%time
# this can take forever for many cells and isn't really that interesting
if predictions.shape[0] < 250000:
    with rc_context({'figure.figsize': (9, 9), 'figure.dpi': 100}):
        plt.style.use('seaborn-talk')
        figure_file = f'{figures_dir}/{project}.{marker_set}.celltypes_heatmap.png'
        splot = clustermap(predictions, cmap="viridis")
        splot.figure.savefig(figure_file)

We then create UMAP plots colored by the maximum-probability cell-type assignment from the CellAssign model. There are no ground-truth labels here, so both plots show the model's predictions. The first puts the legend beside the plot; the second draws the labels on the clusters.

In [None]:
# hard call per cell: the cell type with the highest predicted probability, assigned
# positionally (predictions rows follow bdata's cell order)
bdata.obs['cellassign_predictions'] = predictions.idxmax(axis=1).values
if DEBUG:
    display(bdata.obs.cellassign_predictions.value_counts())

In [None]:
# matplotlib 3.6 renamed the bundled seaborn styles ('seaborn-talk' -> 'seaborn-v0_8-talk')
# and 3.8 removed the old names, so use whichever this install provides
talk_style = 'seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in plt.style.available else 'seaborn-talk'
with plt.style.context(talk_style), rc_context({'figure.figsize': (9, 9), 'figure.dpi': 100}):
    # legend beside the plot; saved under sc.settings.figdir
    figure_file = f'{project}.{marker_set}.celltypes_off.png' 
    sc.pl.umap(bdata, color=['cellassign_predictions'], frameon=False, 
               ncols=1, save=figure_file)

In [None]:
# matplotlib 3.6 renamed the bundled seaborn styles ('seaborn-talk' -> 'seaborn-v0_8-talk')
# and 3.8 removed the old names, so use whichever this install provides
talk_style = 'seaborn-v0_8-talk' if 'seaborn-v0_8-talk' in plt.style.available else 'seaborn-talk'
with plt.style.context(talk_style), rc_context({'figure.figsize': (9, 9), 'figure.dpi': 100}):
    # cell-type labels drawn on the clusters; saved under sc.settings.figdir
    figure_file = f'{project}.{marker_set}.celltypes_on.png' 
    sc.pl.umap(bdata, color=['cellassign_predictions'], frameon=False, 
               legend_loc='on data', ncols=1, save=figure_file)

### merge predicted cell-types onto the full clustered anndata object

In [None]:
# indices should still be the same, but double check
if adata.obs.index.equals(bdata.obs.index):
    adata.obs['cellassign_predictions'] = predictions.idxmax(axis=1).values
else:
    print('indices no longer match, CellAssign predictions not added.')

In [None]:
# number of cells assigned to each predicted cell type
display(adata.obs.cellassign_predictions.value_counts())

### save output

In [None]:
# write the processed AnnData (scVI latent, Leiden clusters, UMAP and CellAssign calls)
adata.write(out_file)

In [None]:
# prediction rows follow bdata's cell order (the object CellAssign was trained on), so label
# them with bdata's cell barcodes before writing the per-cell probability table
predictions = predictions.set_index(bdata.obs.index)
predictions.to_csv(pred_file)

In [None]:
# record when this run finished (compare with the `!date` in the first cell)
!date