# Initial clustering of gene expression data with public reference data included

Here, initial clustering of the gene expression data, including public reference gene expression data, is performed to group cells into clusters of broad cell types.

In [None]:
# record when this notebook run started
!date

#### import libraries

In [None]:
import scvi
from numpy import where
import scanpy as sc
from anndata import AnnData
from anndata import concat as ad_concat
from pandas import read_csv, concat, DataFrame, Series
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import torch
from seaborn import lineplot
from sklearn.metrics import silhouette_score
from numpy import arange, mean

import random
# seed Python's `random` module (random_cells_subset draws from it when TESTING)
random.seed(42)

import warnings
# NOTE(review): this silences every warning, including anndata's implicit
# view-to-copy warnings — consider narrowing it while debugging
warnings.filterwarnings('ignore')

# seed scvi-tools so model training is reproducible
scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# variables and constants
project = 'aging_phase2'
# when True, print paths and display intermediate tables/objects
DEBUG = True
# fraction (not percent) of cells a gene must be detected in to be kept:
# 0.005 == 0.5%; converted to an absolute min_cells for sc.pp.filter_genes
MIN_CELL_PERCENT = 0.005
# cells whose pct_counts_mt is at or above this value are dropped
MAX_MITO_PERCENT = 10
# when True, run on a random subset of the cells for quick iteration
TESTING = False
# size of that random subset
testing_cell_size = 5000
# run highly-variable-feature selection
DETECT_HV_FEATURES = True
# subset the data to the HV features (instead of only flagging them)
FILTER_HV_FEATURES = True
# fraction of features kept as highly variable
TOP_FEATURES_PERCENT = 0.15
# NOTE(review): not read by the Leiden cells below, which use a fixed 0.6
# and then the best-silhouette resolution from the sweep
leiden_res = 1.0
# when False, skip training/saving and load the previously saved model
RUN_TRAINING = True
# NOTE(review): not passed to model.train() below, so scvi's default is used
BATCH_SIZE = 10000
# 'cuda' when torch can see a GPU, otherwise 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
models_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
sc.settings.figdir = f'{figures_dir}/'
public_dir = f'{wrk_dir}/public'
celltypist_dir = f'{wrk_dir}/celltypist'

# in files
arc_file = f'{quants_dir}/{project}_ARC.raw.h5ad'
arc_celltypist_file = f'{celltypist_dir}/{project}_ARC_MTG_predicted_labels.csv'
arc_scrublet_file = f'{quants_dir}/{project}_ARC.scrublet_scores.csv'
gex_file = f'{quants_dir}/{project}_GEX.raw.h5ad'
gex_celltypist_file = f'{celltypist_dir}/{project}_GEX_MTG_predicted_labels.csv'
gex_scrublet_file = f'{quants_dir}/{project}_GEX.scrublet_scores.csv'

# out files
# raw_anndata_file: combined, unfiltered object; trained_model_path: scVI model dir
# NOTE(review): out_h5ad_file is not written anywhere in this notebook
raw_anndata_file =f'{quants_dir}/{project}.dev.rna.raw.h5ad'
out_h5ad_file = f'{quants_dir}/{project}.dev.scvi.h5ad'
trained_model_path = f'{models_dir}/{project}_dev_trained_scvi'

if DEBUG:
    print(f'{arc_file=}')
    print(f'{arc_scrublet_file=}')
    print(f'{gex_file=}')
    print(f'{gex_scrublet_file=}')
    print(f'{arc_celltypist_file=}')
    print(f'{gex_celltypist_file=}')
    print(f'{device=}')

#### functions

In [None]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print a compact summary of an AnnData object.

    Args:
        adata: object to summarize; its repr is always printed.
        message: optional header line, skipped when None or empty.
        verbose: when True, also display the first rows of ``obs`` and ``var``.
    """
    if message:
        print(message)
    print(adata)
    if not verbose:
        return
    for annotation in (adata.obs, adata.var):
        display(annotation.head())

def peek_dataframe(df: DataFrame, message: str=None, verbose: bool=False):
    """Print a compact summary of a pandas DataFrame.

    Args:
        df: frame to summarize; its shape is always printed.
        message: optional header line, skipped when None or empty.
        verbose: when True, also display the first rows.
    """
    has_header = bool(message)
    if has_header:
        print(message)
    print(f'{df.shape=}')
    if not verbose:
        return
    display(df.head())

def heatmap_compare(adata: AnnData, set1: str, set2: str):
    """Plot how cells of each ``set2`` group are spread over the ``set1`` groups.

    Cell counts per (set1, set2) pair are taken from ``adata.obs``. Each column
    (one ``set2`` group) is then normalized to sum to 1, so a heatmap cell shows
    the fraction of that ``set2`` group that falls in the given ``set1`` group.

    Args:
        adata: AnnData whose ``obs`` holds both columns; it is only read.
        set1: obs column shown on the y-axis (rows).
        set2: obs column shown on the x-axis (columns).
    """
    # observed=True counts only pairs that actually occur. Unused categorical
    # levels therefore don't add all-zero columns that normalize to 0/0 = NaN,
    # and pandas' FutureWarning about the changing default is avoided.
    this_df = (
        adata.obs.groupby([set1, set2], observed=True)
        .size()
        .unstack(fill_value=0)
    )
    # column-normalize: each set2 group sums to 1
    norm_df = this_df / this_df.sum(axis=0)

    with rc_context({'figure.figsize': (12, 12), 'figure.dpi': 100}):
        plt.style.use('seaborn-v0_8-bright')
        _ = plt.pcolor(norm_df, edgecolor='black')
        # pcolor cells span [i, i + 1], so ticks at i + 0.5 centre the labels
        _ = plt.xticks(arange(0.5, len(this_df.columns), 1), this_df.columns, rotation=90)
        _ = plt.yticks(arange(0.5, len(this_df.index), 1), this_df.index)
        plt.xlabel(set2)
        plt.ylabel(set1)
        plt.show()

## load data

### load ARC data

In [None]:
%%time
# multiome (ARC) object; its var holds gene-expression and ATAC features
# (the ATAC ones are dropped further below using var.modality)
arc_data = sc.read_h5ad(arc_file)
# discovery data from Duffy
# tag cells with study and role so the datasets stay separable after concat
arc_data.obs['Study'] = 'LNG'
arc_data.obs['Study_type'] = 'discovery'
peek_anndata(arc_data, 'ARC anndata loaded', DEBUG)

In [None]:
# load and add the Celltypist human MTG predicted cell-labels
arc_celltypist_pred = read_csv(arc_celltypist_file, index_col=0)
peek_dataframe(arc_celltypist_pred, 'ARC CellTypist predictions', DEBUG)
arc_data.obs['ori_celltype'] = arc_data.obs.index.map(arc_celltypist_pred['majority_voting'])
peek_anndata(arc_data, 'ARC anndata updated with CellTypist predictions', DEBUG)
if DEBUG:
    display(arc_data.obs.ori_celltype.value_counts())
    display(arc_celltypist_pred.majority_voting.value_counts())

In [None]:
# Scrublet scores/calls for the ARC cells, indexed by barcode (index_col=0)
arc_srublet_df = read_csv(arc_scrublet_file, index_col=0)
print(f'shape of ARC scrublet df {arc_srublet_df.shape}')
if DEBUG:
    display(arc_srublet_df.predicted_doublet.value_counts())    
    display(arc_srublet_df.sample(5))

### drop the ATAC features

In [None]:
# keep only gene-expression features; copy so arc_data is a real object, not a view
arc_data = arc_data[:, arc_data.var.modality == 'Gene Expression'].copy()
peek_anndata(arc_data, 'just the gene expression', DEBUG)

### load GEX data

In [None]:
%%time
gex_data = sc.read_h5ad(gex_file)
print(gex_data)
# filter non-demultiplex cells
gex_data = gex_data[~gex_data.obs['donor_id'].isna()]
# discovery data from Duffy
gex_data.obs['Study'] = 'LNG'
gex_data.obs['Study_type'] = 'discovery'
peek_anndata(gex_data, 'GEX anndata loaded', DEBUG)

In [None]:
# load and add the Celltypist human MTG predicted cell-labels
gex_celltypist_pred = read_csv(gex_celltypist_file, index_col=0)
peek_dataframe(arc_celltypist_pred, 'GEX CellTypist predictions', DEBUG)
gex_data.obs['ori_celltype'] = gex_data.obs.index.map(gex_celltypist_pred['majority_voting'])
peek_anndata(gex_data, 'GEX anndata updated with CellTypist predictions', DEBUG)
if DEBUG:
    display(gex_data.obs.ori_celltype.value_counts())
    display(gex_celltypist_pred.majority_voting.value_counts())

In [None]:
# Scrublet scores/calls for the GEX cells, indexed by barcode (index_col=0)
gex_srublet_df = read_csv(gex_scrublet_file, index_col=0)
print(f'shape of GEX scrublet df {gex_srublet_df.shape}')
if DEBUG:
    display(gex_srublet_df.predicted_doublet.value_counts())    
    display(gex_srublet_df.sample(5))

#### filter any additional non-genotype-predicted doublets

The doublets detected from genotypes (using demux to demultiplex the GEX pools) have already been removed; here the Scrublet-predicted doublets are removed as well.

In [None]:
# keep only cells that Scrublet did not call doublets
# NOTE(review): assumes predicted_doublet was parsed as a bool column; missing
# values in the CSV would make it object dtype and `~` would misbehave
# (the subsets below are views of the previous objects)
# ARC
non_doublet_barcodes = arc_srublet_df.loc[~arc_srublet_df.predicted_doublet].index.values
arc_data = arc_data[arc_data.obs.index.isin(non_doublet_barcodes)]
# GEX
non_doublet_barcodes = gex_srublet_df.loc[~gex_srublet_df.predicted_doublet].index.values
gex_data = gex_data[gex_data.obs.index.isin(non_doublet_barcodes)]
print(f'ARC: {arc_data}')
print(f'GEX: {gex_data}')
if DEBUG:
    display(arc_data.obs.sample(5))
    display(gex_data.obs.sample(5))    

## load public reference GEX data
Leng et al. entorhinal cortex samples; only the Braak Stage 0 samples (n=3) are kept.

In [None]:
%%time
ec_file = f'{public_dir}/cellxgene_collections/Leng_entorhinal_cortex.h5ad'
adata_ref_ec = sc.read_h5ad(ec_file)
# retain original barcode
adata_ref_ec.obs['Barcode'] = adata_ref_ec.obs.index.astype('category')
# filter by Braak Stage
adata_ref_ec = adata_ref_ec[adata_ref_ec.obs.BraakStage == '0']
# reference data from Leng
adata_ref_ec.obs['Study'] = 'Leng'
adata_ref_ec.obs['Study_type'] = 'reference'
peek_anndata(adata_ref_ec, 'Leng et al adata', DEBUG)

### Other public EC prepped by Liam

In [None]:
%%time
pub_ec_file = f'{public_dir}/EC_merged_cleaned_NAMED.h5ad'
print(f'{pub_ec_file=}')
adata_pub_ec = sc.read_h5ad(pub_ec_file)
# exclude our study's data, for now also exclude Mathys until I can separate cases out
adata_pub_ec = adata_pub_ec[adata_pub_ec.obs.dataset != 'Duffy']
peek_anndata(adata_pub_ec, 'public EC prepped by Liam adata', DEBUG)
if DEBUG:
    display(adata_pub_ec.obs.dataset.value_counts())
    display(adata_pub_ec.obs.region.value_counts())
    print(f'{adata_pub_ec.obs.groupby('dataset').donor.nunique()=}')    

In [None]:
# # restore the raw counts
# adata_pub_ec.X = adata_pub_ec.layers['counts_RNA']
# peek_anndata(adata_pub_ec, 'public EC with counts restored', DEBUG)

In [None]:
# load additional donor info
pub_ec_info_df = read_csv(f'{public_dir}/EC_merged_cleaned_donor_info.csv')
rosmap_info_df = read_csv(f'{public_dir}/EC_merged_mathys_info.csv')

peek_dataframe(pub_ec_info_df, 'EC donor info', DEBUG)
# append additional donor info
pub_ec_info_df.sex = pub_ec_info_df.sex.replace({'M': 'male', 'F': 'female'})
pub_ec_info_df.age = pub_ec_info_df.age.replace('90+', '91').astype('int')
adata_pub_ec.obs['age'] = adata_pub_ec.obs.donor.map(pub_ec_info_df.set_index('donor')['age'])
adata_pub_ec.obs['sex'] = adata_pub_ec.obs.donor.map(pub_ec_info_df.set_index('donor')['sex'])
peek_anndata(adata_pub_ec, 'public EC with additional obs values', DEBUG)

peek_dataframe(rosmap_info_df, 'ROSMAP donor info', DEBUG)
# find the ROSMAP samples with path to exclude
display(rosmap_info_df.Braak_group.value_counts())
path_samples = rosmap_info_df.loc[rosmap_info_df.Braak_group == '2+'].Donor.to_list()
path_samples = list(map(str, path_samples))
if DEBUG:
    print(path_samples)

In [None]:
# exclude the ROSMAP that have BRAAK of 2+
adata_pub_ec = adata_pub_ec[~adata_pub_ec.obs.donor.isin(path_samples)]
peek_anndata(adata_pub_ec, 'public EC without pathology', DEBUG)
if DEBUG:
    display(adata_pub_ec.obs.dataset.value_counts())
    display(adata_pub_ec.obs.region.value_counts())
    print(f'{adata_pub_ec.obs.groupby('dataset').donor.nunique()=}')   

#### switch var attribute to use gene name instead of ID

In [None]:
# keep the original (Ensembl-style) IDs in a column and index var by gene symbol
adata_ref_ec.var['gene_id'] = adata_ref_ec.var.index.astype('category')
adata_ref_ec.var.index = adata_ref_ec.var.feature_name
if DEBUG:
    display(adata_ref_ec.var.head(10))

In [None]:
# try to clean up the var for the other public EC which is a mix of gene names and gene IDs
adata_pub_ec.var = adata_pub_ec.var.rename(columns={'names': 'feature_name'})
adata_pub_ec.var['gene_id'] = where(adata_pub_ec.var.feature_name.str.startswith('ENSG00'), 
                        adata_pub_ec.var.feature_name, 'missing')
adata_pub_ec.var['feature_name'] = where(adata_pub_ec.var.feature_name.str.startswith('ENSG00'), 
                        'missing', adata_pub_ec.var.feature_name)
# use the ref_ec's var info to clean up the pub_ec's var
ref_ec_var = adata_ref_ec.var.copy().set_index('feature_name')[ 'gene_id']
adata_pub_ec.var.loc[adata_pub_ec.var.gene_id == 'missing', 'gene_id'] = adata_pub_ec.var.feature_name.map(ref_ec_var)
ref_ec_var = adata_ref_ec.var.copy().set_index('gene_id')[ 'feature_name']
adata_pub_ec.var.loc[adata_pub_ec.var.feature_name == 'missing', 'feature_name'] = adata_pub_ec.var.gene_id.map(ref_ec_var)
with_missing = adata_pub_ec.var[adata_pub_ec.var.isnull().any(axis=1)]
print(f'{with_missing.shape=}')
adata_pub_ec = adata_pub_ec[:, ~adata_pub_ec.var.index.isin(with_missing.index)]
adata_pub_ec.var = adata_pub_ec.var.set_index('feature_name')
peek_anndata(adata_pub_ec, 'public EC with updated var', DEBUG)

### harmonize Leng et al reference data obs

In [None]:
# parse the age number out of the development_stage text by stripping the
# '-year-old human stage' / ' year-old and over human stage' suffixes;
# it is cast to float after concatenation
adata_ref_ec.obs['age'] = (adata_ref_ec.obs.development_stage.str.replace('-year-old human stage','')
.str.replace(' year-old and over human stage',''))
# obs columns not carried into the combined object ('cell_type' listed once)
drop_cols = ['donor_id', 'BraakStage', 'nUMI', 'nGene', 'initialClusterAssignments',
             'cell_type', 'tissue_ontology_term_id', 'cell_type_ontology_term_id',
             'assay_ontology_term_id', 'disease_ontology_term_id',
             'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id',
             'is_primary_data', 'organism_ontology_term_id', 'suspension_type',
             'assay', 'disease', 'organism', 'self_reported_ethnicity',
             'seurat.clusters', 'sex_ontology_term_id', 'development_stage', 'tissue', 'Barcode']
adata_ref_ec.obs = adata_ref_ec.obs.drop(columns=drop_cols)
# map Leng's column names onto the shared schema
adata_ref_ec.obs = adata_ref_ec.obs.rename(columns={'SampleID': 'sample_id',
                                                'SampleBatch': 'gex_pool',
                                                'clusterAssignment': 'ori_cluster',
                                                'clusterCellType': 'ori_celltype'})
drop_cols = ['feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype']
adata_ref_ec.var = adata_ref_ec.var.drop(columns=drop_cols)
# drop any gene's that aren't in our GEX; copy so the result is a real AnnData
intersect_features = gex_data.var.index.intersection(adata_ref_ec.var.index)
adata_ref_ec = adata_ref_ec[:, intersect_features].copy()
# discard embeddings from the source object
adata_ref_ec.obsm = None

if DEBUG:
    display(adata_ref_ec.obs.sample(10))
    display(adata_ref_ec.var.sample(10))

### harmonize the other EC data

In [None]:
# per-cell QC columns from the source object (QC metrics are recomputed with
# calculate_qc_metrics after concatenation)
drop_cols = ['nCount_RNA', 'nFeature_RNA', 'pct.mito']
adata_pub_ec.obs = adata_pub_ec.obs.drop(columns=drop_cols)
# map onto the shared schema; region stands in for gex_pool, dataset for Study
adata_pub_ec.obs = adata_pub_ec.obs.rename(columns={'donor': 'sample_id', 
                                                'region': 'gex_pool', 
                                                'subcluster': 'ori_cluster', 
                                                'broad_celltype': 'ori_celltype', 
                                                    'dataset': 'Study'})
# go ahead and set the cluster to the celltype for this data
# (the finer subcluster label replaces broad_celltype in ori_celltype)
adata_pub_ec.obs['ori_celltype'] = adata_pub_ec.obs['ori_cluster']
adata_pub_ec.obs['Study_type'] = 'reference'
# drop any gene's that aren't in our GEX
intersect_features = gex_data.var.index.intersection(adata_pub_ec.var.index)
adata_pub_ec = adata_pub_ec[:, intersect_features]
# discard embeddings from the source object
adata_pub_ec.obsm = None

peek_anndata(adata_pub_ec, 'public EC with updated harmonize obs', DEBUG)

## combine the discovery and reference GEX data

In [None]:
# stack the discovery and reference datasets along cells; join='outer' keeps
# the union of genes
# NOTE(review): per the anndata docs, genes absent from a dataset are padded
# with zeros for sparse X but with missing values for dense X — confirm every
# X is sparse before scVI
adata_gex = ad_concat([arc_data, gex_data, adata_ref_ec, adata_pub_ec], join='outer')
# barcodes may collide across datasets; suffix duplicates so obs names are unique
adata_gex.obs_names_make_unique()
peek_anndata(adata_gex, 'combined GEX anndata', DEBUG)

## if testing notebook subset the cells

In [None]:
def random_cells_subset(adata: AnnData, num_cells: int=5000, 
                        verbose: bool=False) -> AnnData:
    """Return a random subset of the cells (observations) in ``adata``.

    Sampling uses Python's ``random`` module, which is seeded at the top of
    the notebook.

    Args:
        adata: AnnData to subsample.
        num_cells: number of cells to draw. It is capped at ``adata.n_obs`` so
            a small object doesn't make ``random.sample`` raise ValueError.
        verbose: when True, print the subset and display a few obs rows.

    Returns:
        A new AnnData (a copy, not a view) holding the sampled cells.
    """
    num_cells = min(num_cells, adata.n_obs)
    cells_subset = random.sample(list(adata.obs.index.values), num_cells)
    adata = adata[cells_subset].copy()
    # honour the caller's flag rather than the global DEBUG
    if verbose:
        print(adata)
        display(adata.obs.sample(min(5, adata.n_obs)))
    return adata

if TESTING:
    adata_gex = random_cells_subset(adata_gex, num_cells=testing_cell_size, verbose=DEBUG)

#### clean up the obs as needed
The ARC samples are pooled, and the pool info will be used as a categorical variable, so create a placeholder pool value for the non-pooled samples.
Drop the batch ID.

In [None]:
if DEBUG:
    display(adata_gex.obs.gex_pool.value_counts())
    
# convert the age obs attribute feature to float from string
# (the Leng ages were parsed from text into strings earlier)
adata_gex.obs.age = adata_gex.obs.age.astype('float')
    
# cells without a pool get the placeholder pool 'non'
adata_gex.obs.gex_pool = adata_gex.obs.gex_pool.fillna('non')
adata_gex.obs.gex_pool = adata_gex.obs.gex_pool.astype('str')

# NOTE(review): numpy `where` below returns a plain array, so the categorical
# dtype (and the 'phase2' category added here) is not kept. Only the literal
# string 'NA' is replaced. NaN values — presumably cells from datasets that lack
# these columns before the outer concat — stay NaN; confirm that is intended.
adata_gex.obs.phase1_cluster = adata_gex.obs.phase1_cluster.cat.add_categories(['phase2'])
adata_gex.obs.phase1_celltype = adata_gex.obs.phase1_celltype.cat.add_categories(['phase2'])
adata_gex.obs.phase1_cluster = where(adata_gex.obs.phase1_cluster == 'NA', 
                                     'phase2', adata_gex.obs.phase1_cluster)
adata_gex.obs.phase1_celltype = where(adata_gex.obs.phase1_celltype == 'NA', 
                                     'phase2', adata_gex.obs.phase1_celltype)

if DEBUG:
    display(adata_gex.obs.sample(10))
    display(adata_gex.obs.gex_pool.value_counts())
    display(adata_gex.obs.phase1_cluster.value_counts())
    display(adata_gex.obs.phase1_celltype.value_counts())
    display(adata_gex.obs.ori_celltype.value_counts())

## save the organized but unprocessed anndata object (note that the subject is in the obs)

In [None]:
%%time
# snapshot the combined object before QC filtering, HVG selection and scVI
adata_gex.write(raw_anndata_file)

In [None]:
# confirm what was written
peek_anndata(adata_gex, f'GEX anndata that was just saved to {raw_anndata_file}', DEBUG)

## perform some typical pre-processing

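We remove cells with a high mitochondrial read percentage (MAX_MITO_PERCENT) or fewer than 200 detected genes. We also filter out features detected in fewer than MIN_CELL_PERCENT of the cells; this is a fraction, e.g. 0.005 = 0.5%.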

In [None]:
%%time
print(adata_gex.shape)
# annotate the group of mitochondrial genes as 'mt'
adata_gex.var['mt'] = adata_gex.var_names.str.startswith('MT-')  
# ribosomal genes
adata_gex.var['ribo'] = adata_gex.var_names.str.startswith(('RPS', 'RPL'))
# hemoglobin genes
adata_gex.var['hb'] = adata_gex.var_names.str.contains('^HB[^(P)]')

# With pp.calculate_qc_metrics, we can compute many metrics very efficiently.
sc.pp.calculate_qc_metrics(adata_gex, qc_vars=['mt', 'ribo', 'hb'], 
                           inplace=True, log1p=True)
adata_gex = adata_gex[adata_gex.obs.pct_counts_mt < MAX_MITO_PERCENT, :]
# Basic filtering:
sc.pp.filter_cells(adata_gex, min_genes=200)
sc.pp.filter_genes(adata_gex, min_cells=int(adata_gex.shape[0] * MIN_CELL_PERCENT))

peek_anndata(adata_gex, f'GEX anndata with QC metrics', DEBUG)

#### if flag set then subset to highest variance features

The MultiVI tutorial doesn't suggest this, so these flags will typically be set to False.

In [None]:
if DETECT_HV_FEATURES:
    # keep TOP_FEATURES_PERCENT of the current features
    n_top_genes = int(adata_gex.n_vars * TOP_FEATURES_PERCENT)
    # NOTE(review): the seurat_v3 flavor expects raw counts in X; presumably X
    # still holds counts here — confirm. batch_key ranks genes per gex_pool.
    sc.pp.highly_variable_genes(adata_gex, n_top_genes=n_top_genes, 
                                batch_key='gex_pool',flavor='seurat_v3', 
                                subset=FILTER_HV_FEATURES)                                
    peek_anndata(adata_gex, f'GEX anndata only HVF', DEBUG)

## Setup and Training scVI model

In [None]:
# register the data with scVI: sample_id is the batch key, gex_pool an extra
# categorical covariate, and the mito/ribo percentages computed by
# calculate_qc_metrics are continuous covariates
scvi.model.SCVI.setup_anndata(adata_gex, batch_key='sample_id',
                              categorical_covariate_keys = ['gex_pool'],
                              continuous_covariate_keys=['pct_counts_mt', 'pct_counts_ribo'],)

In [None]:
# scVI model with the library's default architecture on the registered data
model = scvi.model.SCVI(adata_gex)
print(model)

In [None]:
%%time
# NOTE(review): BATCH_SIZE from the config cell is not passed here, so scvi's
# default training batch size is used — confirm that is intended
if RUN_TRAINING:
    model.train()

## Save and load the scVI model

Saving and loading works the same way as for all other scvi-tools models and is very straightforward:

In [None]:
# persist the freshly trained model, replacing any earlier save
if RUN_TRAINING:
    model.save(trained_model_path, overwrite=True)

In [None]:
# load the saved model on the device detected in the config cell. A hard-coded
# 'gpu' accelerator fails on CPU-only machines.
accelerator = 'gpu' if device == 'cuda' else 'cpu'
model = scvi.model.SCVI.load(trained_model_path, adata=adata_gex, accelerator=accelerator)
print(model)

## Extracting and visualizing the latent space

We can now use the `get_latent_representation` to get the latent space from the trained model, and visualize it using scanpy functions:

In [None]:
%%time
# per-cell latent embedding from the trained model; drives neighbors/UMAP/Leiden
adata_gex.obsm['scvi_latent'] = model.get_latent_representation()

#### embed the graph based on latent representation

In [None]:
%%time
# kNN graph built on the scVI latent space rather than on PCA
sc.pp.neighbors(adata_gex, use_rep='scvi_latent')
# sc.tl.umap(adata_gex, min_dist=0.3)
sc.tl.umap(adata_gex)

#### visualize the latent representation

In [None]:
# check sample/study mixing and the original labels on the scVI UMAP
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    sc.pl.umap(adata_gex, color=['sample_id'])
    sc.pl.umap(adata_gex, color=['Study'])
    sc.pl.umap(adata_gex, color=['ori_celltype'], frameon=False, 
               legend_loc='on data')
    sc.pl.umap(adata_gex, color=['phase1_celltype'], frameon=False, 
               legend_loc='on data')  

### Clustering on the scVI latent space
The `ori_celltype` labels shown above were imported earlier: CellTypist predictions for the discovery data and published annotations for the reference data. Here we cluster the data with scanpy, using the neighbor graph built from scVI's latent space. The resulting cluster labels can be reinjected into scVI (e.g., for differential expression).

In [None]:
%%time
# neighbors were already computed using scVI
# NOTE(review): exploratory pass at a fixed resolution of 0.6; `leiden_scvi` is
# overwritten by the resolution sweep and the final clustering below
sc.tl.leiden(adata_gex, key_added='leiden_scvi', resolution=0.6, flavor='igraph', n_iterations=2)

#### check range of Leiden resolutions for clustering

In [None]:
%%time
resolutions_to_try = arange(0.3, 0.8, 0.1)
print(resolutions_to_try)
mean_scores = {}
largest_score = 0
best_res = 0
new_leiden_key = 'leiden_scvi'
for leiden_res in resolutions_to_try:
    # use only 2 decimals
    leiden_res = round(leiden_res, 2)    
    print(f'### using Leiden resolution of {leiden_res}')
    # neighbors were already computed using scVI
    sc.tl.leiden(adata_gex, key_added=new_leiden_key, resolution=leiden_res, flavor='igraph', n_iterations=2)
    silhouette_avg = silhouette_score(adata_gex.obsm['scvi_latent'], adata_gex.obs[new_leiden_key])
    print((f'For res = {leiden_res:.2f}, average silhouette: {silhouette_avg:.3f} '
           f'for {adata_gex.obs[new_leiden_key].nunique()} clusters'))
    # mean sample count per cluster
    df_grouped = adata_gex.obs.groupby(new_leiden_key)['sample_id'].count()
    mean_sample_per_cluster = df_grouped.mean()
    # mean cell count per cluster
    df_grouped = adata_gex.obs[new_leiden_key].value_counts()
    mean_cell_per_cluster = df_grouped.mean()        
    mean_scores[leiden_res] = [silhouette_avg, adata_gex.obs[new_leiden_key].nunique(), 
                               mean_sample_per_cluster, mean_cell_per_cluster]
    # update best resolution info
    if silhouette_avg > largest_score:
        largest_score = silhouette_avg
        best_res = leiden_res

In [None]:
# one row per resolution: silhouette score, cluster count, mean samples and mean cells per cluster
scores_df = DataFrame.from_dict(mean_scores, orient='index',
                                columns=['score', 'num_clusters', 'mean_samples', 'mean_cells'])
print('max score at')
top_score = scores_df.score.max()
best_result = scores_df.loc[scores_df.score == top_score]
display(best_result)
# on ties, the first (lowest) resolution with the top score wins
best_resolution = best_result.index.values[0]
print(f'best resolution found at {best_resolution}')
if DEBUG:
    display(scores_df)
fig_filename = f'{figures_dir}/leiden_resolution_silhouette_score.png'
# only the silhouette curve is styled and saved to disk
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    lineplot(x=scores_df.index, y='score', data=scores_df)
    plt.xlabel('resolution')
    plt.savefig(fig_filename)
    plt.show()
# the remaining diagnostics are only shown inline
for metric in ('num_clusters', 'mean_samples', 'mean_cells'):
    lineplot(x=scores_df.index, y=metric, data=scores_df)
    plt.xlabel('resolution')
    plt.show()

In [None]:
# keep two decimals for display and clustering
best_resolution = round(best_resolution, 2)
print(f'{best_resolution=}')
# final clustering at the best-silhouette resolution (overwrites leiden_scvi)
sc.tl.leiden(adata_gex, key_added='leiden_scvi', resolution=best_resolution, flavor='igraph', n_iterations=2)

### visualize the clusters

In [None]:
# figure_file = f'_{project}.umap.leiden_on.png'
# final clusters next to study, original labels, age and phase 1 labels
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 400}):
    plt.style.use('seaborn-v0_8-talk')
    sc.pl.umap(adata_gex, color=['leiden_scvi'], frameon=False, 
               legend_loc='on data')   
    sc.pl.umap(adata_gex, color=['Study'], frameon=False)
    sc.pl.umap(adata_gex, color=['ori_celltype'], 
               frameon=False, legend_loc='on data')
    sc.pl.umap(adata_gex, color=['age'], frameon=False)
    sc.pl.umap(adata_gex, color=['phase1_celltype'], 
               frameon=False, legend_loc='on data')

## what would the new cluster labels be

In [None]:
# Label each Leiden cluster with its most frequent ori_celltype:
# '<type>-<cluster>' when that type makes up more than 50% of the labelled
# cells, '<type>?-<cluster>' when it makes up more than 5%, otherwise 'unnamed'.
adata_gex.obs['new_cluster_name'] = 'unnamed'
for cluster in adata_gex.obs.leiden_scvi.unique():
    in_cluster = adata_gex.obs.leiden_scvi == cluster
    this_cluster = adata_gex.obs.loc[in_cluster]
    cell_type_cnts = this_cluster.ori_celltype.value_counts()
    percentages = (cell_type_cnts / cell_type_cnts.sum()) * 100
    print(f'\n###{cluster=}')
    cell_type_cnts = cell_type_cnts.to_frame()
    cell_type_cnts['percentages'] = percentages
    # a cluster whose cells all lack ori_celltype has no counts (iloc[0] would
    # raise IndexError); leave it 'unnamed'
    if not cell_type_cnts.empty:
        top_type = cell_type_cnts.index[0]
        top_pct = cell_type_cnts.iloc[0].percentages
        if top_pct > 50:
            adata_gex.obs.loc[in_cluster, 'new_cluster_name'] = f'{top_type}-{cluster}'
        elif top_pct > 5:
            adata_gex.obs.loc[in_cluster, 'new_cluster_name'] = f'{top_type}?-{cluster}'
    if DEBUG:
        display(cell_type_cnts.percentages)

## what if we labeled based on Liam's previous clusters

In [None]:
# Label each Leiden cluster by majority vote among its reference cells from the
# public EC studies (Siletti, Mathys, Franjic). A cluster is labelled only when
# more than 5% of its cells are reference cells. The label is '<type>' at a
# >50% majority, '<type>?' at >5%, otherwise 'unnamed'.
adata_gex.obs['reference_cluster'] = 'unnamed'
reference_studies = ['Siletti', 'Mathys', 'Franjic']
for cluster in adata_gex.obs.leiden_scvi.unique():
    in_cluster = adata_gex.obs.leiden_scvi == cluster
    total_cells_cnt = int(in_cluster.sum())
    # vote only among this cluster's reference cells ...
    ref_cells = adata_gex.obs.loc[in_cluster & adata_gex.obs.Study.isin(reference_studies)]
    fraq_ref = ref_cells.shape[0]/total_cells_cnt
    cell_type_cnts = ref_cells.ori_celltype.value_counts()
    percentages = (cell_type_cnts / cell_type_cnts.sum()) * 100
    print(f'\n###{cluster=}, {fraq_ref=}')
    cell_type_cnts = cell_type_cnts.to_frame()
    cell_type_cnts['percentages'] = percentages
    display(cell_type_cnts.loc[cell_type_cnts.percentages > 50])
    # ... but label every cell of the cluster, not only the reference cells.
    # Check the reference fraction and emptiness first: a cluster without
    # reference cells has no counts, and iloc[0] would raise IndexError.
    if fraq_ref > 0.05 and not cell_type_cnts.empty:
        top_type = cell_type_cnts.index[0]
        top_pct = cell_type_cnts.iloc[0].percentages
        if top_pct > 50:
            adata_gex.obs.loc[in_cluster, 'reference_cluster'] = f'{top_type}'
        elif top_pct > 5:
            adata_gex.obs.loc[in_cluster, 'reference_cluster'] = f'{top_type}?'
    if DEBUG:
        display(cell_type_cnts.percentages)

## visualize the new labels

In [None]:
# Leiden clusters next to the two majority-vote labelings
with rc_context({'figure.figsize': (15, 15), 'figure.dpi': 400}):
    plt.style.use('seaborn-v0_8-talk')
    sc.pl.umap(adata_gex, color=['leiden_scvi'], frameon=False, 
               legend_loc='on data')    
    sc.pl.umap(adata_gex, color=['new_cluster_name', 'reference_cluster'], frameon=False, 
               legend_loc='on data')   

In [None]:
# heatmap_compare only reads adata.obs, so pass the object directly instead of
# copying the whole AnnData (including X) for every plot. The duplicated
# phase1_celltype comparison has been removed.
heatmap_compare(adata_gex, 'new_cluster_name', 'leiden_scvi')
heatmap_compare(adata_gex, 'reference_cluster', 'leiden_scvi')
heatmap_compare(adata_gex, 'new_cluster_name', 'ori_celltype')
heatmap_compare(adata_gex, 'new_cluster_name', 'reference_cluster')
heatmap_compare(adata_gex, 'new_cluster_name', 'phase1_cluster')
heatmap_compare(adata_gex, 'new_cluster_name', 'phase1_celltype')

In [None]:
# record when this notebook run finished
!date