## Notebook to compare the preliminary CellAssign cell-type predictions and add them as annotations to the h5ad files

In [None]:
# log when this run of the notebook started
!date

#### import libraries

In [None]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
from anndata import AnnData
from matplotlib.pyplot import rc_context

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# render inline figures at retina (high-DPI) resolution
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# naming
project = 'aging_phase2'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
figures_dir = f'{wrk_dir}/figures'
sc.settings.figdir = f'{figures_dir}/'

# in files
raw_anndata_file = f'{quants_dir}/{project}.raw.h5ad'
multivi_anndata_file = f'{quants_dir}/{project}.multivi.h5ad'

# out files
new_raw_anndata_file = f'{quants_dir}/{project}.raw.cellassign.h5ad'
new_multivi_anndata_file = f'{quants_dir}/{project}.multivi.cellassign.h5ad'

# variables
DEBUG = True
marker_sets = ['sctypes', 'pangloadb', 'bakken']

### load data

In [None]:
%%time
adata_raw = sc.read_h5ad(raw_anndata_file)
print(adata_raw)
if DEBUG:
    display(adata_raw.obs.head())

In [None]:
%%time
adata_multivi = sc.read_h5ad(multivi_anndata_file)
print(adata_multivi)
if DEBUG:
    display(adata_multivi.obs.head())

#### load cellassign results

In [None]:
%%time
cell_assignments = {}
for marker_set in marker_sets:
    print(marker_set)
    cell_file = f'{quants_dir}/{project}_GEX.{marker_set}.cellassign.h5ad'
    this_adata = sc.read_h5ad(cell_file)
    # rename cell prediction column to marker_set
    this_adata.obs = this_adata.obs.rename(columns={'cellassign_predictions': marker_set})
    # for merging with multi-modal data need modified barcode IDs
    this_adata.obs['barcode'] = this_adata.obs.index.values + '_expression'
    print(this_adata)
    cell_assignments[marker_set] = this_adata.obs
    if DEBUG:
        display(this_adata.obs.head())

### update anndata objects with cell-type predictions

In [None]:
# Add one obs column per marker set holding that set's CellAssign prediction,
# aligned on the '_expression'-suffixed barcode. Cells without a prediction get NaN.
# Assigning a reindexed Series (instead of DataFrame.merge) never touches the obs
# index and makes the cell idempotent: re-running overwrites the column rather than
# producing '<marker_set>_x' / '<marker_set>_y' suffixed duplicates.
prev_raw_index = adata_raw.obs.index.copy()
prev_mvi_index = adata_multivi.obs.index.copy()
for marker_set in marker_sets:
    cell_predictions = cell_assignments[marker_set].set_index('barcode')[marker_set]
    # reindex needs unique labels; a duplicated barcode would otherwise be ambiguous
    if not cell_predictions.index.is_unique:
        raise ValueError(f'{marker_set}: duplicate barcodes in cellassign predictions')
    # number of prediction barcodes found in each object
    print(marker_set,
          len(adata_raw.obs.index.intersection(cell_predictions.index)),
          len(adata_multivi.obs.index.intersection(cell_predictions.index)))

    adata_raw.obs[marker_set] = cell_predictions.reindex(adata_raw.obs.index)
    adata_multivi.obs[marker_set] = cell_predictions.reindex(adata_multivi.obs.index)
    print(f'raw index still good {prev_raw_index.equals(adata_raw.obs.index)}')
    print(f'multivi index still good {prev_mvi_index.equals(adata_multivi.obs.index)}')
    if DEBUG:
        # dropna=False so cells with no prediction are counted too
        display(adata_multivi.obs[marker_set].value_counts(dropna=False))

### save data with updated cell-type predictions

In [None]:
%%time
adata_raw.write(new_raw_anndata_file)
adata_multivi.write(new_multivi_anndata_file)

### Visualize clusters and cell-type predictions

In [None]:
# UMAP coloured by the leiden_MultiVI obs column, with labels drawn on the clusters.
# Both the size/dpi overrides and the style are reverted when the block exits.
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}), \
        plt.style.context('seaborn-v0_8-bright'):
    sc.pl.umap(
        adata_multivi,
        color=['leiden_MultiVI'],
        frameon=False,
        legend_loc='on data',
    )

In [None]:
# update the phase1_celltype so the phase2 cells are NA for plotting
# (in-memory only; later groupby-based comparisons will also drop these NaN labels).
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
phase1_labels = adata_multivi.obs['phase1_celltype']
adata_multivi.obs['phase1_celltype'] = np.where(phase1_labels == 'phase2',
                                                np.nan, phase1_labels)

# one UMAP per annotation: phase-1 labels first, then each marker set's predictions
for color_key in ['phase1_celltype'] + marker_sets:
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        plt.style.use('seaborn-v0_8-bright')
        sc.pl.umap(adata_multivi, color=[color_key], frameon=False)

### compare the predicted cell-types

In [None]:
def heatmap_compare(adata: AnnData, set1: str, set2: str):
    """Plot a heatmap of how the labels in two ``adata.obs`` columns overlap.

    Cells are cross-tabulated by their (``set1``, ``set2``) labels. Each column
    (one ``set2`` label) is then scaled to sum to 1, so each heatmap entry is the
    fraction of that ``set2`` label's cells carrying the given ``set1`` label.
    Cells with a missing label in either column are excluded by ``groupby``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs`` holds both annotation columns; only ``obs`` is read.
    set1 : str
        obs column shown on the y-axis (heatmap rows).
    set2 : str
        obs column shown on the x-axis (heatmap columns).
    """
    # observed=True: only label combinations that actually occur. This avoids the
    # pandas FutureWarning for categoricals and all-zero columns, which would give 0/0.
    pair_counts = adata.obs.groupby([set1, set2], observed=True).size()
    if pair_counts.empty:
        print(f'no cells labelled in both {set1} and {set2}; skipping heatmap')
        return
    counts = pair_counts.unstack(fill_value=0)
    fractions = counts / counts.sum(axis=0)

    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        plt.style.use('seaborn-v0_8-bright')
        fig, ax = plt.subplots()
        mesh = ax.pcolor(fractions.to_numpy(), edgecolor='black')
        # ticks at cell centres (0.5, 1.5, ...)
        ax.set_xticks(np.arange(0.5, len(counts.columns), 1))
        ax.set_xticklabels(counts.columns, rotation=90)
        ax.set_yticks(np.arange(0.5, len(counts.index), 1))
        ax.set_yticklabels(counts.index)
        ax.set_xlabel(set2)
        ax.set_ylabel(set1)
        ax.set_title(f'{set1} vs {set2}')
        fig.colorbar(mesh, ax=ax, label=f'fraction of {set2} label')
        plt.show()

In [None]:
# Compare every pair of annotations: phase-1 labels, Cell_type and each marker set.
# NOTE(review): assumes adata_multivi.obs has a 'Cell_type' column — confirm.
annotation_cols = ['phase1_celltype', 'Cell_type'] + marker_sets
for set1, set2 in itertools.combinations(annotation_cols, 2):
    print(set1, set2)
    # heatmap_compare only reads .obs, so there is no need to copy the whole AnnData (incl. X)
    heatmap_compare(adata_multivi, set1, set2)

### compare marker_sets with the Leiden clusters

In [None]:
# Compare each annotation against the leiden_MultiVI clusters.
# heatmap_compare only reads .obs, so the AnnData is passed without copying it.
for annotation in ['phase1_celltype', 'Cell_type'] + marker_sets:
    heatmap_compare(adata_multivi, 'leiden_MultiVI', annotation)

In [None]:
# log when this run of the notebook finished
!date