## Notebook to cluster and transfer cell-type labels for the replication data along with the discovery data and other public human brain data

In [None]:
# record when this notebook run started
!date

#### import libraries

In [None]:
from pandas import read_csv, DataFrame
import scanpy as sc
from anndata import AnnData
import scvi
import torch
from matplotlib.pyplot import rc_context
import matplotlib.pyplot as plt
from seaborn import barplot
from numpy import arange

import random
random.seed(42)

import warnings
warnings.filterwarnings('ignore')

scvi.settings.seed = 42
torch.set_float32_matmul_precision('high')
print(f'Last run with scvi-tools version: {scvi.__version__}')

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : 'w'}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# ---- naming ----
project = 'aging_phase1'
set_name = f'{project}_replication'

# ---- directories ----
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase1'
replication_dir = f'{wrk_dir}/replication'
figures_dir = f'{wrk_dir}/figures'

# ---- input files ----
raw_anndata_file = f'{replication_dir}/{set_name}.raw.h5ad'

# ---- output files ----
trained_model_path = f'{replication_dir}/{set_name}_trained_scvi'
out_anndata_file = f'{replication_dir}/{set_name}.scvi.h5ad'
out_all_anndata_file = f'{replication_dir}/{project}_ref_disc_rep.scvi.h5ad'
cluster_de_file = f'{replication_dir}/{set_name}.cluster_de_markers.csv'
celltype_de_file = f'{replication_dir}/{set_name}.celltype_de_markers.csv'

# ---- run switches ----
DEBUG = True     # show extra sample tables / value counts
TESTING = False  # when True, analyse only a random subset of cells

# ---- analysis parameters ----
HVF_PERCENT = 0.10     # fraction of genes kept as highly variable features
MAX_MITO_PERCENT = 10  # cells need pct_counts_mt below this to be kept

# ---- AnnData keys for scVI outputs ----
SCVI_LATENT_KEY = 'X_scVI'                # .obsm entry for the latent space
SCVI_CLUSTERS_KEY = 'leiden_scVI'         # .obs column for Leiden clusters
SCVI_NORMALIZED_KEY = 'scvi_normalized'   # .layers entry for normalized expression

# scanpy writes figures requested via save= into this directory
sc.settings.figdir = f'{figures_dir}/'

### load data

In [None]:
adata = sc.read(raw_anndata_file)
print(adata)
if DEBUG:
    # peek at a few random cell and feature annotations
    for frame in (adata.obs, adata.var):
        display(frame.sample(5))

#### if testing, subset the cells

In [None]:
def random_cells_subset(adata: AnnData, num_cells: int=10000) -> AnnData:
    """Return a random subset of cells (observations) as an independent AnnData.

    Sampling uses Python's ``random`` module, so the subset is reproducible
    when that module has been seeded.

    Args:
        adata: the AnnData to subsample.
        num_cells: number of cells to keep. Capped at ``adata.n_obs`` so that
            asking for more cells than exist keeps all of them instead of
            raising ``ValueError`` from ``random.sample``.

    Returns:
        A copy (not a view) of ``adata`` restricted to the sampled cells, so
        the in-place preprocessing that follows acts on a real AnnData.
    """
    num_cells = min(num_cells, adata.n_obs)
    cells_subset = random.sample(list(adata.obs.index.values), num_cells)
    return adata[cells_subset].copy()


if TESTING:
    adata = random_cells_subset(adata)
    if DEBUG:
        print(adata)
        display(adata.obs.head())

### simple filters and prep for SCVI

In [None]:
%%time
# annotate the group of mitochondrial genes as 'mt'
adata.var['mt'] = adata.var_names.str.startswith('MT-')  
# With pp.calculate_qc_metrics, we can compute many metrics very efficiently.
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, 
                           log1p=False, inplace=True)
adata = adata[adata.obs.pct_counts_mt < MAX_MITO_PERCENT, :]
sc.pp.filter_genes(adata, min_counts=3)
sc.pp.filter_cells(adata, min_counts=3)
adata.layers['counts'] = adata.X.copy()  # preserve counts
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata  # freeze the state in `.raw`


#### use highly variable features

In [None]:
# keep the top HVF_PERCENT of genes; n_top_genes must be an integer count, not a float
top_gene_count = int(adata.var.shape[0] * HVF_PERCENT)
# seurat_v3 ranks genes on raw counts (layer='counts'), with ranking done per Study batch
sc.pp.highly_variable_genes(adata, n_top_genes=top_gene_count, subset=True,
                            layer='counts', flavor='seurat_v3', batch_key='Study')

In [None]:
print(adata)
if DEBUG:
    # spot-check annotations after HVG subsetting
    for frame in (adata.obs, adata.var):
        display(frame.sample(5))

### setup the SCVI model

In [None]:
# register the raw-count layer, the Study batch key and the covariates with scVI
categorical_covariates = ['Sample_ID', 'Sex', 'Brain_region', 'Batch']
continuous_covariates = ['pct_counts_mt']
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    batch_key='Study',
    categorical_covariate_keys=categorical_covariates,
    continuous_covariate_keys=continuous_covariates,
)

### create and train the model

In [None]:
# two hidden layers, layer norm instead of batch norm, covariates also fed to the encoder
model = scvi.model.SCVI(
    adata,
    n_layers=2,
    dropout_rate=0.2,
    encode_covariates=True,
    use_batch_norm='none',
    use_layer_norm='both',
)

In [None]:
# display the model summary
model

In [None]:
# show how the AnnData fields were registered for this model
model.view_anndata_setup(adata)

In [None]:
%%time
# train using scvi-tools' default training settings (none overridden here)
model.train()

### save the model and reload it

In [None]:
# persist the trained model, replacing any earlier save at this path
model.save(trained_model_path, overwrite=True)

In [None]:
# reload the saved model attached to the current adata
# NOTE(review): `use_gpu` is the older scvi-tools argument; newer releases use
# `accelerator`/`device` instead — confirm against the installed version
model = scvi.model.SCVI.load(trained_model_path, adata=adata, use_gpu=True)

### Extracting and visualizing the latent space
We can now use `get_latent_representation` to obtain the latent space from the trained model, and visualize it using scanpy functions:

In [None]:
# store the per-cell scVI latent representation in .obsm
adata.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()

### add quantification layer for scVI normalized

In [None]:
# store scVI normalized expression as a layer (used by the marker heatmap below)
adata.layers[SCVI_NORMALIZED_KEY] = model.get_normalized_expression()

In [None]:
# summary of the AnnData after adding the scVI outputs
print(adata)

#### embed the graph based on latent representation

In [None]:
%%time
# kNN graph on the scVI latent space, then a UMAP embedding of that graph
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata, min_dist=0.3)

#### visualize the latent representation

In [None]:
# UMAP of the scVI embedding colored by study
figure_file = f'_{project}.umap.study.png'
plot_rc = {'figure.figsize': (8, 8), 'figure.dpi': 100}
with rc_context(plot_rc):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata, save=figure_file, color=['Study'])

In [None]:
# UMAP colored by brain region
figure_file = f'_{project}.umap.region.png'
with rc_context({'figure.dpi': 100, 'figure.figsize': (8, 8)}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata, save=figure_file, color=['Brain_region'])

In [None]:
# UMAP colored by cell-type label, using scanpy's default legend placement
figure_file = f'_{project}.umap.celltype_off.png'
umap_kwargs = dict(color=['Cell_type'], save=figure_file)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata, **umap_kwargs)

In [None]:
# UMAP colored by cell-type label, with labels drawn on the clusters
figure_file = f'_{project}.umap.celltype_on.png'
with rc_context({'figure.dpi': 100, 'figure.figsize': (8, 8)}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata, save=figure_file, legend_loc='on data', color=['Cell_type'])

In [None]:
# UMAP colored by the previously assigned cluster, labels on data (larger canvas)
figure_file = f'_{project}.umap.prev_cluster_on.png'
big_rc = {'figure.figsize': (12, 12), 'figure.dpi': 100}
with rc_context(big_rc):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata, legend_loc='on data', color=['Cluster'], save=figure_file)

### Clustering on the scVI latent space

In [None]:
%%time
# neighbors were already computed using scVI
leiden_res = 0.6
sc.tl.leiden(adata, key_added=SCVI_CLUSTERS_KEY, resolution=leiden_res)

In [None]:
# UMAP colored by Leiden cluster, cluster ids drawn on data, no frame
figure_file = f'_{project}.umap.leiden_on.png'
with rc_context({'figure.dpi': 100, 'figure.figsize': (12, 12)}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(
        adata,
        color=[SCVI_CLUSTERS_KEY],
        legend_loc='on data',
        frameon=False,
        save=figure_file,
    )

In [None]:
# UMAP colored by Leiden cluster with the default legend placement
figure_file = f'_{project}.umap.leiden_off.png'
umap_kwargs = dict(color=[SCVI_CLUSTERS_KEY], frameon=False, save=figure_file)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-talk')
    sc.pl.umap(adata, **umap_kwargs)

In [None]:
# UMAP colored by donor age
figure_file = f'_{project}.umap.age.png'
with rc_context({'figure.dpi': 100, 'figure.figsize': (8, 8)}):
    plt.style.use('seaborn-talk')
    sc.pl.umap(adata, save=figure_file, frameon=False, color=['Age'])

### transfer cell types to the replication data

#### split data set

In [None]:
# Split the combined object by study type. The reference subset is only read, so a
# view is enough. The replication subset has labels written into its .obs below, so
# make it an independent copy rather than relying on implicit view-to-copy conversion.
adata_ref = adata[adata.obs.Study_type == 'reference']
adata_rep = adata[adata.obs.Study_type == 'replication'].copy()
assert adata_ref.n_obs > 0, "no cells with Study_type == 'reference'"
assert adata_rep.n_obs > 0, "no cells with Study_type == 'replication'"
print('#### reference ####')
print(adata_ref)
print('#### replication ####')
print(adata_rep)

#### for the reference and discovery data, what is the likely cell-type per cluster
For each Leiden cluster, find the most frequent labeled cell-type (and previous cluster) among the reference cells

In [None]:
# For every Leiden cluster that contains reference cells, record the most frequent
# reference cell-type label and the most frequent previous cluster label.
cluster_to_celltype = {}
cluster_to_refcluster = {}
ref_obs = adata_ref.obs
for cluster_num in ref_obs[SCVI_CLUSTERS_KEY].unique():
    # build the mask from the reference obs itself (not the full adata.obs) so it
    # aligns exactly with the rows being selected
    cluster_obs = ref_obs.loc[ref_obs[SCVI_CLUSTERS_KEY] == cluster_num]
    celltype_counts = cluster_obs.Cell_type.value_counts()
    refcluster_counts = cluster_obs.Cluster.value_counts()
    cluster_to_celltype[cluster_num] = celltype_counts.idxmax()
    cluster_to_refcluster[cluster_num] = refcluster_counts.idxmax()
    if DEBUG:
        display(celltype_counts.head())
        display(refcluster_counts.head())
display(cluster_to_celltype)
display(cluster_to_refcluster)

##### assign the labels

In [None]:
if DEBUG:
    # replication cell-type categories and counts before label transfer
    rep_celltypes = adata_rep.obs.Cell_type
    print(rep_celltypes.cat.categories)
    display(rep_celltypes.value_counts())

In [None]:
# Register the reference labels as categories of the replication columns.
# Categorical.add_categories raises ValueError for labels that are already
# categories, so add only the missing ones. Sorting gives a stable category order,
# whereas set iteration order can vary between runs.
for column, label_map in (('Cell_type', cluster_to_celltype),
                          ('Cluster', cluster_to_refcluster)):
    existing = set(adata_rep.obs[column].cat.categories)
    missing = sorted(set(label_map.values()) - existing, key=str)
    adata_rep.obs[column] = adata_rep.obs[column].cat.add_categories(missing)

# Transfer each cluster's majority reference labels to its replication cells,
# in both the replication subset and the full combined object.
is_replication = adata.obs.Study_type == 'replication'
unmatched_clusters = []
for cluster_num in adata_rep.obs[SCVI_CLUSTERS_KEY].unique():
    cell_type = cluster_to_celltype.get(cluster_num)
    ref_cluster = cluster_to_refcluster.get(cluster_num)
    print(cluster_num, cell_type, ref_cluster)
    if cell_type is None:
        # no reference cells in this cluster: its replication cells are set to NaN
        unmatched_clusters.append(cluster_num)
    rep_mask = adata_rep.obs[SCVI_CLUSTERS_KEY] == cluster_num
    adata_rep.obs.loc[rep_mask, 'Cell_type'] = cell_type
    adata_rep.obs.loc[rep_mask, 'Cluster'] = ref_cluster
    # also add to full object
    full_mask = (adata.obs[SCVI_CLUSTERS_KEY] == cluster_num) & is_replication
    adata.obs.loc[full_mask, 'Cell_type'] = cell_type
    adata.obs.loc[full_mask, 'Cluster'] = ref_cluster
if unmatched_clusters:
    print(f'clusters without reference cells (left unlabeled): {unmatched_clusters}')
if DEBUG:
    display(adata_rep.obs.Cell_type.value_counts())
    display(adata_rep.obs.Cluster.value_counts())

#### drop any unused cell-type or cluster categories

In [None]:
# drop label categories no longer used by any cell, in the replication subset and the full object
for anndata_obj in (adata_rep, adata):
    for column in ('Cell_type', 'Cluster'):
        anndata_obj.obs[column] = anndata_obj.obs[column].cat.remove_unused_categories()

In [None]:
# replication cells only, colored by their transferred cell-type labels
figure_file = f'_{project}.umap.replication_cell_types.png'
with rc_context({'figure.dpi': 100, 'figure.figsize': (8, 8)}):
    plt.style.use('seaborn-talk')
    sc.pl.umap(
        adata_rep,
        color=['Cell_type'],
        title='Replication data cell types',
        frameon=False,
        save=figure_file,
    )

In [None]:
# replication cells only, colored by Leiden cluster with ids on data
figure_file = f'_{project}.umap.replication_leiden.png'
umap_kwargs = dict(
    color=[SCVI_CLUSTERS_KEY],
    legend_loc='on data',
    frameon=False,
    save=figure_file,
    title='Replication data Leiden cluster',
)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-talk')
    sc.pl.umap(adata_rep, **umap_kwargs)

In [None]:
# replication cells only, colored by the transferred reference cluster
figure_file = f'_{project}.umap.replication_refcluster.png'
with rc_context({'figure.dpi': 100, 'figure.figsize': (8, 8)}):
    plt.style.use('seaborn-talk')
    sc.pl.umap(adata_rep, title='Replication data reference clusters',
               save=figure_file, frameon=False,
               legend_loc='on data', color=['Cluster'])

### save the modified anndata object

In [None]:
# persist the combined object (all studies) first, then the replication-only subset
for anndata_obj, out_path in ((adata, out_all_anndata_file),
                              (adata_rep, out_anndata_file)):
    anndata_obj.write(out_path)

### Differential expression of Leiden clusters

In [None]:
%%time
# scVI differential expression per Leiden cluster ("<cluster> vs Rest"), replication cells only
de_df = model.differential_expression(adata_rep, groupby=SCVI_CLUSTERS_KEY,)
display(de_df.head(15))

#### save cluster DE results

In [None]:
# write the full cluster DE table
de_df.to_csv(cluster_de_file)

#### We now extract top markers for each cluster using the DE results.

In [None]:
# For each Leiden cluster keep up to number_of_top_markers genes that are up-regulated
# (lfc_mean > 0), well supported (bayes_factor > 3) and expressed in >10% of the
# cluster's cells, in the row order of the DE table.
number_of_top_markers = 5
markers = {}
for cluster in adata_rep.obs[SCVI_CLUSTERS_KEY].cat.categories:
    hits = de_df[
        (de_df.comparison == f"{cluster} vs Rest")
        & (de_df.lfc_mean > 0)
        & (de_df["bayes_factor"] > 3)
        & (de_df["non_zeros_proportion1"] > 0.1)
    ]
    markers[cluster] = list(hits.index[:number_of_top_markers])

In [None]:
# dendrogram of the Leiden clusters on the scVI latent space (used by dendrogram=True plots below)
sc.tl.dendrogram(adata_rep, groupby=SCVI_CLUSTERS_KEY, use_rep=SCVI_LATENT_KEY)

In [None]:
# dot plot of the per-cluster markers; use_raw=False plots adata_rep.X rather than .raw
figure_file = f'_{project}.cluster_markers.png'
dotplot_kwargs = dict(
    groupby=SCVI_CLUSTERS_KEY,
    dendrogram=True,
    color_map='Blues',
    swap_axes=True,
    use_raw=False,
    standard_scale='var',
    save=figure_file,
)
with rc_context({'figure.dpi': 200, 'figure.figsize': (12, 12)}):
    plt.style.use('seaborn-talk')
    sc.pl.dotplot(adata_rep, markers, **dotplot_kwargs)

#### We can also visualize the scVI normalized gene expression values with the layer option.

In [None]:
# heatmap of the same markers using the scVI-normalized expression layer
# NOTE(review): both rcParams figure.figsize and figsize= are given; confirm which
# one scanpy's heatmap honours
figure_file = f'_{project}.cluster_markers_heatmap.png'
heatmap_rc = {'figure.figsize': (15, 15), 'figure.dpi': 200, 'font.size': 6}
with rc_context(heatmap_rc):
    plt.style.use('seaborn-talk')
    sc.pl.heatmap(
        adata_rep,
        markers,
        groupby=SCVI_CLUSTERS_KEY,
        layer=SCVI_NORMALIZED_KEY,
        standard_scale='var',
        dendrogram=True,
        figsize=(8, 12),
        show_gene_labels=True,
        save=figure_file,
    )

### Differential expression of broad cell-types

In [None]:
%%time
# scVI differential expression per (transferred) cell type ("<cell type> vs Rest"), replication cells only
cell_de_df = model.differential_expression(adata_rep, groupby='Cell_type',)
display(cell_de_df.head(15))

#### save cell-type DE results

In [None]:
# write the full cell-type DE table
cell_de_df.to_csv(celltype_de_file)

#### We now extract top markers for each cell-type using the DE results.

In [None]:
# For each cell type keep up to number_of_top_markers genes that are up-regulated
# (lfc_mean > 0), well supported (bayes_factor > 3) and expressed in >10% of the
# group's cells, in the row order of the DE table.
number_of_top_markers = 5
markers = {}
for label in adata_rep.obs.Cell_type.cat.categories:
    hits = cell_de_df[
        (cell_de_df.comparison == f"{label} vs Rest")
        & (cell_de_df.lfc_mean > 0)
        & (cell_de_df["bayes_factor"] > 3)
        & (cell_de_df["non_zeros_proportion1"] > 0.1)
    ]
    markers[label] = list(hits.index[:number_of_top_markers])

In [None]:
# dendrogram of the cell types on the scVI latent space (used by dendrogram=True below)
sc.tl.dendrogram(adata_rep, groupby='Cell_type', use_rep=SCVI_LATENT_KEY)

In [None]:
# dot plot of the per-cell-type markers; use_raw=False plots adata_rep.X rather than .raw
figure_file = f'_{project}.celltype_markers.png'
dotplot_kwargs = dict(
    groupby='Cell_type',
    dendrogram=True,
    color_map='Blues',
    swap_axes=True,
    use_raw=False,
    standard_scale='var',
    save=figure_file,
)
with rc_context({'figure.dpi': 200, 'figure.figsize': (12, 12)}):
    plt.style.use('seaborn-talk')
    sc.pl.dotplot(adata_rep, markers, **dotplot_kwargs)

### compare the Leiden clusters to the assigned cell-types and previous clusters

In [None]:
def heatmap_compare(adata: AnnData, set1: str, set2: str, save_figure: bool=True):
    """Plot how the labels of two ``.obs`` columns overlap.

    Cells are cross-tabulated by ``set1`` (rows) and ``set2`` (columns). Each
    column is then scaled to sum to 1, so every heatmap entry is the fraction
    of a ``set2`` group that falls into each ``set1`` group.

    Args:
        adata: AnnData whose ``.obs`` holds both columns; it is only read.
        set1: ``.obs`` column used for the heatmap rows.
        set2: ``.obs`` column used for the heatmap columns.
        save_figure: if True, also save the figure to
            ``{figures_dir}/{project}.{set1}_{set2}_leiden_heatmap.png``.
    """
    # observed=True keeps only label combinations that actually occur; with
    # categorical columns the default may add all-zero rows/columns, and an
    # all-zero column would be a division by zero in the normalisation below
    this_df = (
        adata.obs.groupby([set1, set2], observed=True)
        .size()
        .unstack(fill_value=0)
    )
    norm_df = this_df / this_df.sum(axis=0)

    with rc_context({'figure.figsize': (12, 12), 'figure.dpi': 100}):
        plt.style.use('seaborn-bright')
        mesh = plt.pcolor(norm_df, edgecolor='black')
        # without a colorbar the colour scale cannot be read
        plt.colorbar(mesh, label=f'fraction of {set2}')
        plt.xticks(arange(0.5, len(this_df.columns), 1), this_df.columns, rotation=90)
        plt.yticks(arange(0.5, len(this_df.index), 1), this_df.index)
        plt.xlabel(set2)
        plt.ylabel(set1)
        if save_figure:
            figure_file = f'{figures_dir}/{project}.{set1}_{set2}_leiden_heatmap.png'
            plt.savefig(figure_file, bbox_inches='tight')
        plt.show()

In [None]:
# heatmap_compare only reads .obs, so pass adata directly instead of copying the
# whole object with all of its layers
heatmap_compare(adata, SCVI_CLUSTERS_KEY, 'Cell_type')

In [None]:
# heatmap_compare only reads .obs, so pass adata directly instead of copying the
# whole object with all of its layers
heatmap_compare(adata, SCVI_CLUSTERS_KEY, 'Cluster')

In [None]:
# heatmap_compare only reads .obs, so pass adata directly instead of copying the
# whole object with all of its layers
heatmap_compare(adata, 'Cluster', 'Cell_type')

In [None]:
# record when this notebook run finished
!date