## Attempt subclustering of the broad cell-types

In [None]:
!date

#### import libraries

In [None]:
import scvi
from numpy import where
import scanpy as sc
from anndata import AnnData
from anndata import concat as ad_concat
from pandas import read_csv, concat, DataFrame, Series
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import torch
from seaborn import lineplot, barplot
from sklearn.metrics import silhouette_score
from numpy import arange, mean

import random
# seed Python's built-in RNG for repeatable runs
random.seed(42)

import warnings
# NOTE(review): this silences every warning for the whole session
warnings.filterwarnings('ignore')

# seed used by scvi-tools; presumably also covers its numeric backends -- confirm in scvi docs
scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# parameters
# broad cell-type to subcluster; it is looked up in cell_types_clusters and is
# embedded in the model and output file names built below
broad_type = 'NonNeuronal'

In [None]:
# variables and constants
project = 'aging_phase2'
# when true, cells print/display extra diagnostics
DEBUG = True
# fraction of cells a gene must be detected in to be kept (0.005 == 0.5%)
MIN_CELL_PERCENT = 0.005
# cells at or above this percent of mitochondrial counts are dropped
MAX_MITO_PERCENT = 10
# NOTE(review): TESTING, testing_cell_size and BATCH_SIZE are not referenced in the visible cells
TESTING = False
testing_cell_size = 5000
# highly-variable feature detection, and whether to subset to those features
DETECT_HV_FEATURES = True
FILTER_HV_FEATURES = True
# fraction of the remaining features to flag as highly variable
TOP_FEATURES_PERCENT = 0.15
# NOTE(review): the resolution sweep further down reuses this name as its loop variable
leiden_res = 1.0
RUN_TRAINING = True
BATCH_SIZE = 10000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Leiden cluster labels (strings, as stored in obs.leiden_scvi of the processed
# object) that make up each broad cell-type. Labels must match exactly, so no
# padding whitespace is allowed (the original ' 6' entry never matched cluster 6).
# 'NonNeuronal' is the union of the Astro, Micro, Oligo, OPC, Endo and VLMC lists.
cell_types_clusters = {'ExN': ['1', '2', '3', '4', '5', '6', '7', '10', '11', '12',
                               '13', '14', '15', '24', '27', '28'],
                       'InN': ['4', '17', '18', '19', '20', '21', '24', '25', '26', '28'],
                       'Astro': ['8', '24'], 'Micro': ['16'], 'Oligo': ['0', '28'],
                       'OPC': ['9'], 'Endo': ['22'], 'VLMC': ['23', '28'],
                       'NonNeuronal' : ['8', '24', '16', '0', '28', '9', '22', '23']}

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
models_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
# point scanpy's figure directory setting at the project figures directory
sc.settings.figdir = f'{figures_dir}/'
public_dir = f'{wrk_dir}/public'
# per-resolution outputs (markers, average expression, cluster assignments) go here
resolution_dir = f'{quants_dir}/resolution_selection'

# in files
raw_anndata_file =f'{quants_dir}/{project}.dev.rna.raw.h5ad'
rna_scvi_h5ad_file = f'{quants_dir}/{project}.dev.rna.scvi.h5ad'

# out files
# every output name carries broad_type so different cell-types do not overwrite each other
trained_model_path = f'{models_dir}/{project}_{broad_type}_dev_trained_scvi'
celltype_var_features_file = f'{resolution_dir}/{broad_type}_varfeats.txt'
adata_out_file = f'{quants_dir}/{project}_{broad_type}.dev.rna.scvi.h5ad'
res_obs_file = f'{resolution_dir}/{broad_type}_res_obs.csv'

# echo the resolved paths and device so a run's log shows exactly what was used
if DEBUG:
    print(f'{raw_anndata_file=}')
    print(f'{rna_scvi_h5ad_file=}')
    print(f'{trained_model_path=}')
    print(f'{celltype_var_features_file=}')
    print(f'{adata_out_file=}')
    print(f'{res_obs_file=}')
    print(f'{device=}')    

#### functions

In [None]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print a quick summary of an AnnData object.

    Args:
        adata: object to summarise; it is printed as-is.
        message: optional heading printed first; skipped when None or empty.
        verbose: when true, also display the first rows of ``obs`` and ``var``.
    """
    has_heading = message is not None and len(message) > 0
    if has_heading:
        print(message)
    print(adata)
    if not verbose:
        return
    # display() is supplied by the notebook environment
    for table in (adata.obs, adata.var):
        display(table.head())

def peek_dataframe(df: DataFrame, message: str=None, verbose: bool=False):
    """Print the shape of a DataFrame, optionally preceded by a heading.

    Args:
        df: frame to summarise.
        message: optional heading printed first; skipped when None or empty.
        verbose: when true, also display the first rows of the frame.
    """
    has_heading = message is not None and len(message) > 0
    if has_heading:
        print(message)
    print(f'{df.shape=}')
    if not verbose:
        return
    # display() is supplied by the notebook environment
    display(df.head())

## load the raw anndata file

In [None]:
%%time
# read the AnnData stored at raw_anndata_file; it is subset to the chosen
# broad cell-type further down
adata_raw = sc.read_h5ad(raw_anndata_file)
peek_anndata(adata_raw, 'loaded raw anndata', DEBUG)

## load the scVI processed anndata file

In [None]:
%%time
# read the processed AnnData; its obs.leiden_scvi labels decide which cells
# belong to the chosen broad cell-type
adata_proc = sc.read_h5ad(rna_scvi_h5ad_file)
peek_anndata(adata_proc, 'loaded processed anndata', DEBUG)

## subset the raw anndata based on the broad cell-type annotation in the processed anndata for the specified broad cell type

In [None]:
# Select the cells of the requested broad cell-type: first in the processed
# object (by Leiden cluster label), then the same cell IDs in the raw object.
cluster_ids = cell_types_clusters.get(broad_type)
in_broad_type = adata_proc.obs.leiden_scvi.isin(cluster_ids)
adata_sub = adata_proc[in_broad_type]
peek_anndata(adata_sub, f'processed anndata cell subset for {broad_type} cell type', DEBUG)
# match on obs index so the raw subset holds exactly the selected cell IDs present in it
shared_cells = adata_raw.obs.index.isin(adata_sub.obs.index)
adata_gex = adata_raw[shared_cells]
peek_anndata(adata_gex, f'raw anndata cell subset for {broad_type} cell type', DEBUG)
if DEBUG:
    # requested labels versus the labels actually found in the subset
    print(cluster_ids)
    print(adata_sub.obs.leiden_scvi.unique())

## perform some typical pre-processing

We also filter features to remove those that appear in fewer than MIN% of the cells

In [None]:
%%time
# QC annotation and basic cell/gene filtering of the broad cell-type subset
print(adata_gex.shape)
# annotate the group of mitochondrial genes as 'mt'
adata_gex.var['mt'] = adata_gex.var_names.str.startswith('MT-')  
# ribosomal genes
adata_gex.var['ribo'] = adata_gex.var_names.str.startswith(('RPS', 'RPL'))
# hemoglobin genes
# NOTE(review): the character class [^(P)] excludes the literal characters '(', 'P' and ')',
# i.e. names starting with 'HB' whose third character is none of those three
adata_gex.var['hb'] = adata_gex.var_names.str.contains('^HB[^(P)]')

# With pp.calculate_qc_metrics, we can compute many metrics very efficiently.
sc.pp.calculate_qc_metrics(adata_gex, qc_vars=['mt', 'ribo', 'hb'], 
                           inplace=True, log1p=True)
# keep only cells below the mitochondrial-percent threshold
adata_gex = adata_gex[adata_gex.obs.pct_counts_mt < MAX_MITO_PERCENT, :]
# Basic filtering:
sc.pp.filter_cells(adata_gex, min_genes=200)
# min_cells is computed from the cell count remaining after the two cell filters above
sc.pp.filter_genes(adata_gex, min_cells=int(adata_gex.shape[0] * MIN_CELL_PERCENT))

peek_anndata(adata_gex, f'GEX anndata with QC metrics', DEBUG)

#### if flag set then subset to highest variance features

The MultiVI tutorial doesn't suggest this step, so this flag will probably be set to false in typical runs.

In [None]:
# Flag the top TOP_FEATURES_PERCENT fraction of features as highly variable,
# computed per 'gex_pool' batch; subset=FILTER_HV_FEATURES decides whether the
# object is also reduced to those features.
if DETECT_HV_FEATURES:
    n_top_genes = int(adata_gex.n_vars * TOP_FEATURES_PERCENT)
    sc.pp.highly_variable_genes(adata_gex, n_top_genes=n_top_genes, 
                                batch_key='gex_pool',flavor='seurat_v3', 
                                subset=FILTER_HV_FEATURES)                                
    peek_anndata(adata_gex, f'GEX anndata only HVF', DEBUG)

## Setup and Training scVI model

In [None]:
# Register the fields scVI should model: sample_id as the batch, gex_pool and
# Study as categorical covariates, and the mito/ribo count percentages (added
# by calculate_qc_metrics) as continuous covariates.
scvi.model.SCVI.setup_anndata(adata_gex, batch_key='sample_id',
                              categorical_covariate_keys = ['gex_pool', 'Study'],
                              continuous_covariate_keys=['pct_counts_mt', 'pct_counts_ribo'],)

In [None]:
# build an scVI model on the prepared subset using the library's default settings
model = scvi.model.SCVI(adata_gex)
print(model)

In [None]:
%%time
# training is optional; when RUN_TRAINING is false the model built above stays
# untrained until the load cell replaces it with the saved one
if RUN_TRAINING:
    model.train()

## Save and Load scVI models

Saving and loading models works the same way as for all other scvi-tools models, and is very straightforward:

In [None]:
# persist the freshly trained model, replacing any earlier save at the same path
if RUN_TRAINING:
    model.save(trained_model_path, overwrite=True)

In [None]:
# Reload the saved model against the current AnnData so the rest of the
# notebook behaves the same whether or not training ran in this session.
# The accelerator follows the detected `device` instead of a hard-coded 'gpu',
# so the notebook also runs on CPU-only hosts.
model = scvi.model.SCVI.load(trained_model_path, adata=adata_gex,
                             accelerator='gpu' if device == 'cuda' else 'cpu')
print(model)

## Extracting and visualizing the latent space

We can now use `get_latent_representation` to get the latent space from the trained model, and visualize it using scanpy functions:

In [None]:
%%time
# store the model's latent embedding under obsm['scvi_latent']; the neighbor
# graph and silhouette scores below read it from there
adata_gex.obsm['scvi_latent'] = model.get_latent_representation()

#### embed the graph based on latent representation

In [None]:
%%time
# build the neighbor graph on the scVI latent space, then embed it with UMAP
sc.pp.neighbors(adata_gex, use_rep='scvi_latent')
# sc.tl.umap(adata_gex, min_dist=0.3)
sc.tl.umap(adata_gex)

#### visualize the latent representation

In [None]:
# UMAPs colored by sample, study and the two existing cell-type annotations,
# to eyeball batch mixing and label agreement before clustering
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    sc.pl.umap(adata_gex, color=['sample_id'])
    sc.pl.umap(adata_gex, color=['Study'])
    sc.pl.umap(adata_gex, color=['ori_celltype'], frameon=False, 
               legend_loc='on data')
    sc.pl.umap(adata_gex, color=['phase1_celltype'], frameon=False, 
               legend_loc='on data')  

### Clustering on the scVI latent space
The user will note that we imported curated labels from the original publication. Our interface with scanpy makes it easy to cluster the data with scanpy from scVI's latent space and then reinject them into scVI (e.g., for differential expression).

In [None]:
%%time
# neighbors were already computed using scVI
# initial clustering at a fixed resolution; the 'leiden_scvi' column written here is
# overwritten by the resolution sweep and the final re-clustering further down
sc.tl.leiden(adata_gex, key_added='leiden_scvi', resolution=0.6, flavor='igraph', n_iterations=2)

#### check range of Leiden resolutions for clustering

add the normalized expression from the model as a layer to be used for marker gene info

In [None]:
%%time
# model-normalized expression, stored as a layer for the marker and average-expression
# exports in the resolution sweep; note 10e4 evaluates to 100000.0
adata_gex.layers['scvi_normalized'] = model.get_normalized_expression(library_size=10e4)

also save the variable features for the broad cell-type for usage with the marker gene info

In [None]:
# write the names of the features flagged highly variable, one per line (no header, no index)
# NOTE(review): needs var['highly_variable'], which the HVG cell adds when DETECT_HV_FEATURES
# is true; presumably this raises a KeyError otherwise -- confirm
variable_genes = adata_gex.var[adata_gex.var['highly_variable']].index.to_frame()
variable_genes.to_csv(celltype_var_features_file, index=False, header=False)

In [None]:
%%time
# Sweep Leiden resolutions from 0.1 up to (not including) 1.0 in steps of 0.1.
# For each resolution: re-cluster, score the clustering, keep the assignments,
# and export per-cluster average expression and marker genes.
resolutions_to_try = arange(0.1, 1.0, 0.1)
print(resolutions_to_try)
clust_assign_by_res = None
# resolution -> [silhouette, n clusters, mean samples/cluster, mean cells/cluster,
#                n clusters represented by fewer than 1/3 of the samples]
mean_scores = {}
# NOTE(review): starting at 0 means best_res stays 0 if no silhouette is positive;
# the summary cell below derives the best resolution from mean_scores instead
largest_score = 0
best_res = 0
new_leiden_key = 'leiden_scvi'
# NOTE(review): the loop variable reuses the name of the leiden_res constant set earlier
for leiden_res in resolutions_to_try:
    # use only 2 decimals
    leiden_res = round(leiden_res, 2)    
    print(f'### using Leiden resolution of {leiden_res}')
    # neighbors were already computed using scVI
    sc.tl.leiden(adata_gex, key_added=new_leiden_key, resolution=leiden_res, flavor='igraph', n_iterations=2)
    # NOTE(review): silhouette_score presumably needs at least two clusters -- confirm
    # what happens if a low resolution yields a single cluster
    silhouette_avg = silhouette_score(adata_gex.obsm['scvi_latent'], adata_gex.obs[new_leiden_key])
    print((f'For res = {leiden_res:.2f}, average silhouette: {silhouette_avg:.3f} '
           f'for {adata_gex.obs[new_leiden_key].nunique()} clusters'))
    # number of donors per cluster
    # cells per (cluster, sample) pair; only pairs with at least 30 cells count
    # towards a cluster's sample tally
    df_grouped = adata_gex.obs.groupby([new_leiden_key])['sample_id'].value_counts()
    df_grouped = df_grouped[df_grouped >= 30].to_frame().reset_index()
    df_grouped = df_grouped.groupby(new_leiden_key)['sample_id'].nunique()    
    mean_sample_per_cluster = df_grouped.mean()
    # despite the name, the threshold is one third of all samples
    less_than_half = df_grouped[df_grouped < adata_gex.obs.sample_id.nunique()/3].shape[0]
    # mean cell count per cluster
    df_grouped = adata_gex.obs[new_leiden_key].value_counts()
    mean_cell_per_cluster = df_grouped.mean()        
    mean_scores[leiden_res] = [silhouette_avg, adata_gex.obs[new_leiden_key].nunique(), 
                               mean_sample_per_cluster, mean_cell_per_cluster, less_than_half]
    # retain cluster assignments at this resolution
    # one 'leiden_<res>' column per resolution, all sharing the obs index
    if clust_assign_by_res is None:
        clust_assign_by_res = (adata_gex.obs[[new_leiden_key]].copy()
                               .rename(columns={new_leiden_key: f'leiden_{leiden_res}'}))
    else:
        clust_assign_by_res = concat([clust_assign_by_res, 
                                      (adata_gex.obs[[new_leiden_key]].copy()
                                       .rename(columns={new_leiden_key: f'leiden_{leiden_res}'}))], 
                                     axis='columns')
    # update best resolution info
    if silhouette_avg > largest_score:
        largest_score = silhouette_avg
        best_res = leiden_res

    # generate markers and average expression per resolution tested
    # per-cluster mean of the scVI-normalized layer over all features
    avgexp = sc.get.obs_df(adata_gex, keys=list(adata_gex.var_names), 
                           layer='scvi_normalized').groupby(adata_gex.obs[new_leiden_key]).mean()
    res_avg_exp_file = f'{resolution_dir}/{broad_type}_avgexp_res{leiden_res}.csv'
    avgexp.to_csv(res_avg_exp_file)
    # Wilcoxon marker ranking for every cluster, also on the normalized layer
    sc.tl.rank_genes_groups(adata_gex, groupby=new_leiden_key, method='wilcoxon', 
                            pts=True, layer='scvi_normalized')
    markers_df = sc.get.rank_genes_groups_df(adata_gex, group=None)
    res_markers_file = f'{resolution_dir}/{broad_type}_markers_res{leiden_res}.csv'
    markers_df.to_csv(res_markers_file, index=False)

if DEBUG:
    print(f'{clust_assign_by_res.shape=}')
    display(clust_assign_by_res.head())

#### save the different resolution clusters with the info from the anndata obs
since the indices are the same can just append them

In [None]:
%%time
# put the per-resolution cluster columns next to the full obs table and save as CSV
res_obs_info = concat([clust_assign_by_res, adata_gex.obs], axis='columns')
peek_dataframe(res_obs_info, 'res_obs_info', DEBUG)    
if DEBUG:
    # both checks should print True if the column-wise concat kept the obs index unchanged
    print(clust_assign_by_res.index.equals(adata_gex.obs.index))
    print(res_obs_info.index.equals(adata_gex.obs.index))

res_obs_info.to_csv(res_obs_file)

In [None]:
# Tabulate the per-resolution metrics, pick the resolution with the highest
# silhouette score (first one on ties), and plot every metric against resolution.
score_columns = ['score', 'num_clusters', 'mean_samples', 'mean_cells', 'less_than_half']
scores_df = DataFrame(index=mean_scores.keys(), data=mean_scores.values())
scores_df.columns = score_columns
print('max score at')
is_top_score = scores_df.score == scores_df.score.max()
best_result = scores_df.loc[is_top_score]
display(best_result)
best_resolution = best_result.index.values[0]
print(f'best resolution found at {best_resolution}')
if DEBUG:
    display(scores_df)

# only the silhouette curve is styled and saved to disk
fig_filename = f'{figures_dir}/leiden_resolution_silhouette_score.png'
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    lineplot(x=scores_df.index, y='score', data=scores_df)
    plt.xlabel('resolution')
    plt.savefig(fig_filename)
    plt.show()

# remaining metrics: metric column -> custom y-axis label (None keeps the default)
other_metric_labels = {
    'num_clusters': None,
    'mean_samples': None,
    'mean_cells': None,
    'less_than_half': 'number clusters with less than 1/3 of donors',
}
for metric_name, y_axis_label in other_metric_labels.items():
    lineplot(x=scores_df.index, y=metric_name, data=scores_df)
    if y_axis_label is not None:
        plt.ylabel(y_axis_label)
    plt.xlabel('resolution')
    plt.show()

In [None]:
# re-cluster at the selected resolution so obs['leiden_scvi'] holds the final
# assignment rather than that of the last resolution tried in the sweep
best_resolution = round(best_resolution, 2)
print(f'{best_resolution=}')
sc.tl.leiden(adata_gex, key_added='leiden_scvi', resolution=best_resolution, flavor='igraph', n_iterations=2)

### visualize the cell counts per cluster

In [None]:
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 400}):
    plt.style.use('seaborn-v0_8-paper')
    # number of cells in each cluster
    barplot(data=adata_gex.obs.leiden_scvi.value_counts())
    plt.show()
    # number of distinct samples contributing to each cluster
    barplot(data=adata_gex.obs.groupby('leiden_scvi')['sample_id'].nunique())
    plt.show()
if DEBUG:
    display(adata_gex.obs.leiden_scvi.value_counts())

### visualize the clusters

In [None]:
# figure_file = f'_{project}.umap.leiden_on.png'
# UMAPs of the final clustering next to study, age and the existing cell-type annotations
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 400}):
    plt.style.use('seaborn-v0_8-paper')
    sc.pl.umap(adata_gex, color=['leiden_scvi'], frameon=False, 
               legend_loc='on data', legend_fontsize=6)   
    sc.pl.umap(adata_gex, color=['Study'], frameon=False)
    sc.pl.umap(adata_gex, color=['ori_celltype'], 
               frameon=False, legend_loc='on data', legend_fontsize=6)
    sc.pl.umap(adata_gex, color=['age'], frameon=False)
    sc.pl.umap(adata_gex, color=['phase1_celltype'], 
               frameon=False, legend_loc='on data', legend_fontsize=6)

## what would the new cluster labels be

### labels using any of the initial labels

In [None]:
# Label each cluster with its most frequent ori_celltype over all of its cells.
# Label format: '<top cell-type>-(<percent of the cluster's cells>%)-<cluster id>'
any_labels_dict = {}
for cluster in adata_gex.obs.leiden_scvi.unique():
    this_cluster = adata_gex.obs.loc[adata_gex.obs.leiden_scvi == cluster]
    # value_counts sorts descending, so position 0 is the dominant cell-type
    cell_type_cnts = this_cluster.ori_celltype.value_counts()
    percentages = (cell_type_cnts / cell_type_cnts.sum()) * 100
    print(f'\n###{cluster=}')
    cell_type_cnts = cell_type_cnts.to_frame()
    cell_type_cnts['percentages'] = percentages
    any_labels_dict[cluster] = f'{cell_type_cnts.index[0]}-({cell_type_cnts.iloc[0].percentages:.2f}%)-{cluster}'
    if DEBUG:
        display(cell_type_cnts.percentages)
if DEBUG:
    display(any_labels_dict)

### what if we labeled based on Liam's previous clusters

In [None]:
# Label each cluster with the most frequent ori_celltype among its cells that
# come from the reference studies (Siletti, Mathys, Franjic).
# Label format: '<top cell-type>-(<percent of reference cells>%)-(<percent of the
# cluster that is reference cells>%)', or 'Not Present' when the cluster has no
# reference cells.
liams_labels_dict = {}
reference_studies = ['Siletti', 'Mathys', 'Franjic']
for cluster in adata_gex.obs.leiden_scvi.unique():
    this_cluster = adata_gex.obs.loc[adata_gex.obs.leiden_scvi == cluster]
    total_cells_cnt = this_cluster.shape[0]
    this_cluster = this_cluster.loc[this_cluster.Study.isin(reference_studies)]
    fraq_ref = this_cluster.shape[0]/total_cells_cnt
    if fraq_ref == 0.0:
        # nothing to tabulate; skip the debug display so it cannot show
        # counts left over from a previous cluster
        liams_labels_dict[cluster] = 'Not Present'
        continue
    # value_counts sorts descending, so position 0 is the dominant cell-type
    cell_type_cnts = this_cluster.ori_celltype.value_counts()
    percentages = (cell_type_cnts / cell_type_cnts.sum()) * 100
    print(f'\n###{cluster=}, {fraq_ref=}')
    cell_type_cnts = cell_type_cnts.to_frame()
    cell_type_cnts['percentages'] = percentages
    liams_labels_dict[cluster] = f'{cell_type_cnts.index[0]}-({cell_type_cnts.iloc[0].percentages:.2f}%)-({fraq_ref*100:.2f}%)'
    if DEBUG:
        display(cell_type_cnts.percentages)
if DEBUG:
    display(liams_labels_dict)

### what if we labeled based on Phase1 cell labels

In [None]:
# Label each cluster with the most frequent phase1_celltype among its LNG-study
# cells that carry a phase 1 annotation (phase1_celltype != 'phase2').
# Label format: '<top cell-type>-(<percent of those cells>%)-(<percent of the
# cluster they make up>%)', or 'Not Present' when the cluster has no such cells.
phase1_labels_dict = {}
for cluster in adata_gex.obs.leiden_scvi.unique():
    this_cluster = adata_gex.obs.loc[adata_gex.obs.leiden_scvi == cluster]
    total_cells_cnt = this_cluster.shape[0]
    # build both conditions from the cluster subset itself so the mask shares its
    # index, instead of relying on implicit alignment with the full obs table
    this_cluster = this_cluster.loc[(this_cluster.Study == 'LNG') &
                                    (this_cluster.phase1_celltype != 'phase2')]
    fraq_ref = this_cluster.shape[0]/total_cells_cnt
    if fraq_ref == 0.0:
        # nothing to tabulate; skip the debug display so it cannot show
        # counts left over from a previous cluster
        phase1_labels_dict[cluster] = 'Not Present'
        continue
    # value_counts sorts descending, so position 0 is the dominant cell-type
    cell_type_cnts = this_cluster.phase1_celltype.value_counts()
    percentages = (cell_type_cnts / cell_type_cnts.sum()) * 100
    print(f'\n###{cluster=}, {fraq_ref=}')
    cell_type_cnts = cell_type_cnts.to_frame()
    cell_type_cnts['percentages'] = percentages
    phase1_labels_dict[cluster] = f'{cell_type_cnts.index[0]}-({cell_type_cnts.iloc[0].percentages:.2f}%)-({fraq_ref*100:.2f}%)'
    if DEBUG:
        display(cell_type_cnts.percentages)
if DEBUG:
    display(phase1_labels_dict)

### what if we labeled based on Phase2 celltypist cell labels

In [None]:
# Label each cluster with the most frequent ori_celltype among its LNG-study cells.
# Label format: '<top cell-type>-(<percent of LNG cells>%)-(<percent of the
# cluster that is LNG cells>%)', or 'Not Present' when the cluster has no LNG cells.
celltypist_labels_dict = {}
for cluster in adata_gex.obs.leiden_scvi.unique():
    this_cluster = adata_gex.obs.loc[adata_gex.obs.leiden_scvi == cluster]
    total_cells_cnt = this_cluster.shape[0]
    # build the mask from the cluster subset itself so it shares its index,
    # instead of relying on implicit alignment with the full obs table
    this_cluster = this_cluster.loc[this_cluster.Study == 'LNG']
    fraq_ref = this_cluster.shape[0]/total_cells_cnt
    if fraq_ref == 0.0:
        # nothing to tabulate; skip the debug display so it cannot show
        # counts left over from a previous cluster
        celltypist_labels_dict[cluster] = 'Not Present'
        continue
    # value_counts sorts descending, so position 0 is the dominant cell-type
    cell_type_cnts = this_cluster.ori_celltype.value_counts()
    percentages = (cell_type_cnts / cell_type_cnts.sum()) * 100
    print(f'\n###{cluster=}, {fraq_ref=}')
    cell_type_cnts = cell_type_cnts.to_frame()
    cell_type_cnts['percentages'] = percentages
    celltypist_labels_dict[cluster] = f'{cell_type_cnts.index[0]}-({cell_type_cnts.iloc[0].percentages:.2f}%)-({fraq_ref*100:.2f}%)'
    if DEBUG:
        display(cell_type_cnts.percentages)
if DEBUG:
    display(celltypist_labels_dict)

### add the possible labels to the anndata obs

In [None]:
# Attach each candidate labelling to obs by mapping every cell's cluster id
# through the corresponding cluster -> label dictionary.
label_sources = {
    'any_label': any_labels_dict,
    'liams_label': liams_labels_dict,
    'phase1_label': phase1_labels_dict,
    'celltypist_mtg_label': celltypist_labels_dict,
}
for obs_column, cluster_to_label in label_sources.items():
    adata_gex.obs[obs_column] = adata_gex.obs.leiden_scvi.map(cluster_to_label)
peek_anndata(adata_gex, 'anndata with possible cell-type labels added', DEBUG)

## visualize the new labels

In [None]:
# one UMAP per candidate labelling, all drawn with the same legend settings
label_columns = ('any_label', 'liams_label', 'phase1_label', 'celltypist_mtg_label')
with rc_context({'figure.figsize': (15, 15), 'figure.dpi': 400}):
    plt.style.use('seaborn-v0_8-paper')
    for label_column in label_columns:
        sc.pl.umap(adata_gex, color=[label_column], frameon=False,
                   legend_loc='on data', legend_fontsize=4, legend_fontweight='bold')

## save to processed anndata object

In [None]:
%%time
# write the subclustered object, including the latent embedding, normalized
# layer, clustering and candidate label columns added above
adata_gex.write(adata_out_file)

In [None]:
!date