## Notebook to score cells based on gene set

Attempt to use this functionality for determining scores for
- disease associated microglia (DAM) 
    - Hu Y, Fryatt GL, Ghorbani M et al. Replicative senescence dictates the emergence of disease-associated microglia and contributes to Aβ pathology. Cell Rep 2021;35:109228.
    https://www.sciencedirect.com/science/article/pii/S2211124721005799
        - "DAM signature (Cst7, Csf1, Lpl, Apoe, Spp1, Cd74, Itgax)"
- senescence
    - Hu Y, Fryatt GL, Ghorbani M et al. Replicative senescence dictates the emergence of disease-associated microglia and contributes to Aβ pathology. Cell Rep 2021;35:109228.
    https://www.sciencedirect.com/science/article/pii/S2211124721005799?via%3Dihub
        - "custom senescence signature (Cdkn2a, Cdkn1a, Cdkn2d, Casp8, Il1b, Glb1, Serpine1)"
    - Dehkordi SK, Walker J, Sah E et al. Profiling senescent cells in human brains reveals neurons with CDKN2D/p19 and tau neuropathology. Nat Aging 2021;1:1107–16.
    https://pubmed.ncbi.nlm.nih.gov/35531351/
    - Casella G, Munk R, Kim KM et al. Transcriptome signature of cellular senescence. Nucleic Acids Res 2019;47:7294–305.
    https://pubmed.ncbi.nlm.nih.gov/31251810/
        - Canonical Senescence Pathway (CDKN2D, ETS2, RB1, E2F3, CDK6, RBL2, ATM, BMI1, MDM2, CDK4, CCNE1)
        - Senescence Response Pathway (IGFBP7, VIM, FN1, SPARC, IGFBP4, TIMP1, TBX2, TBX3, COL1A1, COL3A1, IGFBP2, TGFB1I1, PTEN, CD44, NFIA, CALR, TIMP2, CXCL8)
        - Senescence Initiating Pathway (SOD1, MAP2K1, GSK3B, PIK3CA, SOD2, MAPK14, IGF1R, TP53BP1, NBN, HRAS, CITED2, CREG1, ABL1, MORC3, NFKB1, AKT1, CDKN1B, EGR1, RBL1, MAP2K6, IGF1, IRF3, PCNA, GADD45A, MAP2K3, IGFBP5, SIRT1, ING1, TGFB1, TERF2)
        - Cell Age (PEBP1, PKM, CKB, AAK1, NUAK1, MAST1, SORBS2, BRAF, SPIN1, MAP2K1, YPEL3, MAPK14, PDPK1, TOP1, ITPK1, MATK, RPS6KA6, SPOP, ITSN2, PDZD2, MAP2K2, LIMK1, DHCR24, PBRM1, MAP3K7, SIN3B, SOX5, EWSR1, PDCD10, CPEB1, NEK4, RB1, MCRS1, PNPT1, HRAS, STK32C, RAF1, ETS2, SMARCB1, FASTK, SLC13A3, TRIM28, MORC3, MAPKAPK5, MAP2K7, STK40, PMVK, CEBPB, GRK6, STAT5B, CDKN1B, PDIK1L, AKT1, MAPK12, MAP2K6, PIAS4, ADCK5, SMURF2, PCGF2, IRF3, PLA2R1, TYK2, ERRFI1, BRD7, ING2, FBXO31, NADK, PTTG1, BHLHE40, ASF1A, ING1, NINJ1, MXD4)
        - UniUp ['TMEM159', 'CHPF2', 'SLC9A7', 'PLOD1', 'FAM234B', 'DHRS7', 'SRPX', 'SRPX2', 'TNFSF13B', 'PDLIM1', 'ELMOD1', 'CCND3', 'TMEM30A', 'STAT1', 'RND3', 'TMEM59', 'SARAF', 'SLC16A14', 'SLCO2B1', 'ARRDC4', 'PAM', 'WDR78', 'NCSTN', 'GPR155', 'CLDN1', 'JCAD', 'BLCAP', 'FILIP1L', 'TAP1', 'TNFRSF10C', 'SAMD9L', 'SMCO3', 'POFUT2', 'KIAA1671', 'LRP10', 'BMS1P9', 'MAP4K3-DT', 'AC002480.1', 'LINC02154', 'TM4SF1-AS1', 'PTCHD4', 'H2AFJ', 'PURPL']

In [None]:
# print the current date and time via the shell, as a record of when the notebook was run
!date

#### import libraries

In [None]:
import scanpy as sc
from anndata import AnnData
from numpy import ndarray
from random import seed, randint
from scipy.stats import zscore, shapiro, kstest, norm
from pandas import DataFrame
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
from seaborn import scatterplot, displot, heatmap
from pandas import read_csv, DataFrame

# render matplotlib figures inline in the notebook output
%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# emit inline figures in the 'retina' (high resolution) format
%config InlineBackend.figure_format='retina'

import warnings
# NOTE(review): this silences every warning category for the whole notebook,
# including deprecation and pandas/anndata copy warnings; consider narrowing it.
warnings.filterwarnings(action='ignore')

#### set notebook variables

In [None]:
# naming
project = 'aging'

# directories
wrk_dir = '/home/jupyter/brain_aging_phase1'
quants_dir = f'{wrk_dir}/demux'
results_dir = f'{wrk_dir}/results'

# in files
in_file = f'{quants_dir}/{project}.pegasus.leiden_085.subclustered.h5ad'
glmmtmb_file = f'{results_dir}/{project}.glmmtmb_age_diffs_fdr.csv'

# variables
DEBUG = False
dpi_value = 50
# Gene sets keyed by short name; the symbols are the upper-cased signatures cited
# in the notebook header. Symbols that do not exactly match the data's gene index
# are silently dropped when scoring, so spelling matters here:
#   CSS   - 'CASP8' (was mistyped 'CASO8')
#   UniUp - 'FAM234B' (was 'FAM2234B'), 'SLCO2B1' and 'SMCO3' (were typed with a
#           digit zero instead of the letter O)
gene_sets = {'DAM': ['CST7', 'CSF1', 'LPL', 'APOE', 'SPP1', 'CD74', 'ITGAX'],
             'CSS': ['CDKN2A', 'CDKN1A', 'CDKN2D', 'CASP8', 'IL1B', 'GLB1', 'SERPINE1'],
             'CSP': ['CDKN2D', 'ETS2', 'RB1', 'E2F3', 'CDK6', 'RBL2', 'ATM',
                     'BMI1', 'MDM2', 'CDK4', 'CCNE1'],
             'SRP': ['IGFBP7', 'VIM', 'FN1', 'SPARC', 'IGFBP4', 'TIMP1', 'TBX2',
                     'TBX3', 'COL1A1', 'COL3A1', 'IGFBP2', 'TGFB1I1', 'PTEN',
                     'CD44', 'NFIA', 'CALR', 'TIMP2', 'CXCL8'],
             'SIP': ['SOD1', 'MAP2K1', 'GSK3B', 'PIK3CA', 'SOD2', 'MAPK14',
                     'IGF1R', 'TP53BP1', 'NBN', 'HRAS', 'CITED2', 'CREG1',
                     'ABL1', 'MORC3', 'NFKB1', 'AKT1', 'CDKN1B', 'EGR1',
                     'RBL1', 'MAP2K6', 'IGF1', 'IRF3', 'PCNA', 'GADD45A',
                     'MAP2K3', 'IGFBP5', 'SIRT1', 'ING1', 'TGFB1', 'TERF2'],
             'CellAge': ['PEBP1', 'PKM', 'CKB', 'AAK1', 'NUAK1', 'MAST1',
                         'SORBS2', 'BRAF', 'SPIN1', 'MAP2K1', 'YPEL3', 'MAPK14',
                         'PDPK1', 'TOP1', 'ITPK1', 'MATK', 'RPS6KA6', 'SPOP',
                         'ITSN2', 'PDZD2', 'MAP2K2', 'LIMK1', 'DHCR24', 'PBRM1',
                         'MAP3K7', 'SIN3B', 'SOX5', 'EWSR1', 'PDCD10', 'CPEB1',
                         'NEK4', 'RB1', 'MCRS1', 'PNPT1', 'HRAS', 'STK32C', 'RAF1',
                         'ETS2', 'SMARCB1', 'FASTK', 'SLC13A3', 'TRIM28', 'MORC3',
                         'MAPKAPK5', 'MAP2K7', 'STK40', 'PMVK', 'CEBPB', 'GRK6',
                         'STAT5B', 'CDKN1B', 'PDIK1L', 'AKT1', 'MAPK12', 'MAP2K6',
                         'PIAS4', 'ADCK5', 'SMURF2', 'PCGF2', 'IRF3', 'PLA2R1',
                         'TYK2', 'ERRFI1', 'BRD7', 'ING2', 'FBXO31', 'NADK', 'PTTG1',
                         'BHLHE40', 'ASF1A', 'ING1', 'NINJ1', 'MXD4'],
             'UniUp': ['TMEM159', 'CHPF2', 'SLC9A7', 'PLOD1', 'FAM234B', 'DHRS7',
                       'SRPX', 'SRPX2', 'TNFSF13B', 'PDLIM1', 'ELMOD1', 'CCND3',
                       'TMEM30A', 'STAT1', 'RND3', 'TMEM59', 'SARAF', 'SLC16A14',
                       'SLCO2B1', 'ARRDC4', 'PAM', 'WDR78', 'NCSTN', 'GPR155',
                       'CLDN1', 'JCAD', 'BLCAP', 'FILIP1L', 'TAP1', 'TNFRSF10C',
                       'SAMD9L', 'SMCO3', 'POFUT2', 'KIAA1671', 'LRP10', 'BMS1P9',
                       'MAP4K3-DT', 'AC002480.1', 'LINC02154', 'TM4SF1-AS1',
                       'PTCHD4', 'H2AFJ', 'PURPL'],
             'UniDown': ['MCUB', 'FBL', 'HIST1H1D', 'HIST1H1A', 'FAM129A',
                         'ANP32B', 'PARP1', 'LBR', 'SSRP1', 'TMSB15A',
                         'CBS', 'CDCA7L', 'HIST1H1E', 'CBX2', 'PTMA', 'HIST2H2AB',
                         'ITPRIPL1', 'AC074135.1']
            }
# cell-type annotations (values of obs.new_anno) dropped before scoring
# exclude_cell_types = ['uncertain', 'uncertain-2', 'uncertain-3', 'Astrocyte-GFAP-Hi']
exclude_cell_types = ['Astrocyte-GFAP-Hi']
# seed Python's `random` module; score_gene_set draws its per-iteration
# random_state values from it
seed(42)

#### utility functions

In [None]:
def test_score_normality(scores):
    stat, p = shapiro(scores)
    if p > 0.05:
        print(f'by Shapiro is Gaussian p={p}, stat={stat}')
    else:
        print(f'by Shapiro is not Gaussian p={p}, stat={stat}')
    stat, p = kstest(scores, norm.cdf)
    if p > 0.05:
        print(f'by Kolmogorov-Smirnov is Gaussian p={p}, stat={stat}')
    else:
        print(f'by Kolmogorov-Smirnov is not Gaussian p={p}, stat={stat}')
    
    
def score_gene_set(data: AnnData, set_name: str, genes: list, iter_cnt: int=100, 
                   num_devs: int=2, verbose: bool=False):
    """Score every cell for a gene set, flag high scorers and plot the results.

    The score of a cell is the mean over `iter_cnt` runs of
    `sc.tl.score_genes`, each run given a `random_state` drawn with
    `randint(0, 9999)`. Cells whose z-scored mean score exceeds `num_devs`
    are labelled `set_name`; all other cells are labelled `non<set_name>`.

    Parameters
    ----------
    data : AnnData
        Cells to score. Modified in place: the obs columns `<set_name>_score`
        and `is<set_name>` are written, and `sc.tl.embedding_density` is run
        on it. The plots and tables also read the obs columns 'Age',
        'new_anno', 'Age_group' and 'Brain_region', and `sc.pl.umap` is called.
    set_name : str
        Short name of the gene set, used in column names, labels and messages.
    genes : list
        Gene symbols of the set; only those present in `data.var.index` are used.
    iter_cnt : int
        Number of `sc.tl.score_genes` runs to average; must be at least 1.
    num_devs : int
        Z-score threshold above which a cell is labelled positive.
    verbose : bool
        When True, also print how many genes of the set were found and which
        ones are missing from the data.

    Returns
    -------
    float
        Fraction of cells labelled positive, rounded to 3 decimals.

    Raises
    ------
    ValueError
        If `iter_cnt` is smaller than 1 or none of `genes` is in the data.
    """
    if iter_cnt < 1:
        raise ValueError(f'iter_cnt must be at least 1, got {iter_cnt}')
    print(f'scoring {set_name}')
    score_name = f'{set_name}_score'
    flag_name = f'is{set_name}'
    negative_label = f'non{set_name}'

    # The usable genes do not change between iterations, so resolve them once.
    # Sorting makes the input to score_genes identical from run to run
    # (iteration order of a set of strings is not stable across kernels).
    requested = set(genes)
    present_genes = sorted(requested & set(data.var.index))
    if verbose:
        missing_genes = sorted(requested - set(present_genes))
        print(f'{len(present_genes)} of {len(requested)} {set_name} genes found '
              f'in the data, missing: {missing_genes}')
    if not present_genes:
        raise ValueError(f'none of the {set_name} genes are present in the data')

    # one column of scores per iteration; every column is filled by the loop below
    scores = ndarray(shape=(data.obs.shape[0], iter_cnt), dtype=float)
    print('bootstrapping score_genes', end='.')
    for index in range(iter_cnt):
        print(index, end='.')
        temp_data = sc.tl.score_genes(data, present_genes, score_name=score_name, 
                                      copy=True, random_state=randint(0, 9999))
        scores[:, index] = temp_data.obs[score_name]
    mean_scores = scores.mean(axis=1)
    standardized_scores = zscore(mean_scores)
    data.obs[score_name] = mean_scores
    # binarize score; the labels are assigned as a whole column rather than via
    # an in-place replace on a column selection, which pandas copy-on-write
    # does not propagate back to the frame
    is_positive = standardized_scores > num_devs
    data.obs[flag_name] = [set_name if flag else negative_label for flag in is_positive]
    # test normality of score
    test_score_normality(data.obs[score_name])

    fig_rc = {'figure.figsize': (9, 9), 'figure.dpi': dpi_value}
    # Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
    # later releases dropped the old names; use whichever this install provides
    style_name = ('seaborn-bright' if 'seaborn-bright' in plt.style.available
                  else 'seaborn-v0_8-bright')

    # plot score distribution
    with rc_context(fig_rc):
        plt.style.use(style_name)
        displot(x=score_name, data=data.obs, kind='kde')
    # plot standardized score distribution
    with rc_context(fig_rc):
        plt.style.use(style_name)
        displot(x=standardized_scores, kind='kde')
        plt.show()

    # find proportion of total cells where condition exists
    positive_cnt = int(is_positive.sum())
    total_cnt = data.obs.shape[0]
    positive_frac = round(positive_cnt/total_cnt, 3)
    print(f'\nFractions of cells positive for {set_name} is {positive_frac} or '
          f'{positive_cnt} of {total_cnt}')
    print('visualizing', end='.')

    # show positive cells by Age
    with rc_context(fig_rc):
        plt.style.use(style_name)
        scatterplot(x='Age', y=score_name, data=data.obs, hue=flag_name)
    # show dotplots of genes by cell-type, age group and positive/negative label
    for grouping in ('new_anno', 'Age_group', flag_name):
        with rc_context(fig_rc):
            plt.style.use(style_name)
            sc.pl.dotplot(data, present_genes, groupby=grouping, 
                          mean_only_expressed=True)

    # visualize umap of positive cells and score
    pos_data = data[data.obs[flag_name] == set_name]
    with rc_context(fig_rc):
        plt.style.use(style_name)
        sc.pl.umap(pos_data, color=['new_anno', score_name], legend_loc='on data')
    # visualize binarized score
    sc.tl.embedding_density(data, groupby=flag_name)
    with rc_context(fig_rc):
        plt.style.use(style_name)
        sc.pl.embedding_density(data, groupby=flag_name)

    # get counts by cell-type and brain region
    display(data.obs.groupby('new_anno')[flag_name].value_counts())
    display(data.obs.groupby('Brain_region')[flag_name].value_counts())
    return positive_frac

def score_broad_celltypes_for_set(name: str, genes: list, adata: AnnData,
                                  iter_cnt: int=10) -> dict:
    """Run `score_gene_set` separately on each broad cell type of `adata`.

    Parameters
    ----------
    name : str
        Short name of the gene set, passed on to `score_gene_set`.
    genes : list
        Gene symbols of the set.
    adata : AnnData
        Data whose `obs.broad_celltype` column defines the groups; it is only
        read, every group is scored on its own copy.
    iter_cnt : int
        Number of scoring iterations per cell type (default 10).

    Returns
    -------
    dict
        Maps each value of `obs.broad_celltype` to the fraction of its cells
        that `score_gene_set` labelled positive.
    """
    set_scores = {}
    for target_cell_type in adata.obs.broad_celltype.unique():
        print(f'#### {target_cell_type} ####')
        # explicit copy: score_gene_set writes obs columns, which should not be
        # done on a view of adata (same pattern as the microglia subset used
        # for the DAM scoring)
        sdata = adata[adata.obs.broad_celltype == target_cell_type].copy()
        set_scores[target_cell_type] = score_gene_set(sdata, name, genes,
                                                      iter_cnt=iter_cnt, verbose=DEBUG)
    return set_scores

### read the anndata (h5ad) file

In [None]:
%%time
# read the h5ad file named by `in_file` into an AnnData object and print its summary
adata = sc.read(in_file, cache=True)
print(adata)

### Plot the clusters

In [None]:
# UMAP coloured by fine annotation (new_anno) and by broad cell type.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names; use whichever name this install provides.
bright_style = ('seaborn-bright' if 'seaborn-bright' in plt.style.available
                else 'seaborn-v0_8-bright')
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use(bright_style)
    sc.pl.umap(adata, color=['new_anno', 'broad_celltype'], legend_loc='on data')

In [None]:
# number of observations (cells) per broad cell type, before any exclusion
adata.obs.broad_celltype.value_counts()

### remove cell-types that are known excludes

In [None]:
# keep only the cells whose new_anno label is not listed in exclude_cell_types;
# the name `adata` is rebound, so the unfiltered object is no longer referenced here
adata = adata[~adata.obs.new_anno.isin(exclude_cell_types)]
print(adata)

### replot the clusters without the excluded celltypes

In [None]:
# UMAP of the remaining cells, coloured by fine annotation and broad cell type.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names; use whichever name this install provides.
bright_style = ('seaborn-bright' if 'seaborn-bright' in plt.style.available
                else 'seaborn-v0_8-bright')
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use(bright_style)
    sc.pl.umap(adata, color=['new_anno', 'broad_celltype'], legend_loc='on data')

In [None]:
# number of observations (cells) per fine annotation after the exclusion
adata.obs.new_anno.value_counts()

In [None]:
# number of observations (cells) per broad cell type after the exclusion
adata.obs.broad_celltype.value_counts()

### score the DAM

In [None]:
%%time
# score the DAM signature on the microglia only, using an independent copy of that subset;
# 10 scoring iterations, and cells more than 1 standard deviation above the mean are flagged
mg_data = adata[adata.obs.broad_celltype == 'Microglia'].copy()
print(mg_data)
score_gene_set(mg_data, 'DAM', gene_sets['DAM'], iter_cnt=10, num_devs=1, verbose=DEBUG)

### score the CSS, custom senescence signature

In [None]:
%%time
# `scores` collects, per gene set name, the dict of positive-cell fractions by broad cell type;
# it is created here and extended by the following scoring cells
scores = {}
this_set = 'CSS'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)

### score the CSP, Canonical Senescence Pathway

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the CSP gene set
this_set = 'CSP'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)    

### score the SRP, Senescence Response pathway

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the SRP gene set
this_set = 'SRP'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)        

### score the SIP, Senescence Initiating pathway

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the SIP gene set
this_set = 'SIP'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)      

### score the CellAge

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the CellAge gene set
this_set = 'CellAge'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)      

### score the UniUp

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the UniUp gene set
this_set = 'UniUp'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)       

### score the UniDown

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the UniDown gene set
this_set = 'UniDown'
scores[this_set] = score_broad_celltypes_for_set(this_set, gene_sets[this_set], adata)     

### score all combined

#### combine the senescence gene sets into single marker set

In [None]:
# Union of the genes of every entry in gene_sets, with duplicates removed.
# NOTE(review): this takes all entries, so the 'DAM' and 'UniDown' genes are part
# of the combined set as well - confirm that is intended for a senescence marker set.
marker_set = list({gene for set_genes in gene_sets.values() for gene in set_genes})
print(f'lenght of marker set is {len(marker_set)}')

if DEBUG:
    print(marker_set)

In [None]:
%%time
# per broad cell type, fraction of cells flagged positive for the combined marker set
this_set = 'All'
scores[this_set] = score_broad_celltypes_for_set(this_set, marker_set, adata)

### visualize the fraction of senescent cells by cell-type and senescence marker gene set

In [None]:
# turn the nested `scores` dict into a table: one column per gene set name,
# one row per broad cell type, rows sorted by cell-type label
fracs_by_set = DataFrame.from_dict(scores).sort_index()
print(f'shape of fractions dataframe is {fracs_by_set.shape}')
if DEBUG:
    display(fracs_by_set)

In [None]:
# Heatmap of the positive-cell fractions (rows: broad cell types, columns: gene sets).
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names; use whichever name this install provides.
bright_style = ('seaborn-bright' if 'seaborn-bright' in plt.style.available
                else 'seaborn-v0_8-bright')
# dpi_value is the notebook-wide figure dpi, used here instead of a repeated literal
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use(bright_style)
    heatmap(fracs_by_set, annot=True, 
            annot_kws={"fontsize":10}, linewidths=0.05, cmap='Blues')
    plt.title('Fraction of cells identified as senescent')
    plt.show()

### check diff expression by age results

#### load the GLMMTMB results

In [None]:
# load the differential-expression-by-age results table from `glmmtmb_file`;
# later cells use its columns `type`, `tissue` and `feature`
diff_df = read_csv(glmmtmb_file)
print(f'diff df shape {diff_df.shape}')
if DEBUG:
    display(diff_df.head())

#### check each gene set in both region and cell-type

In [None]:
# For every result type, every tissue within it and every gene set, record the
# number of that tissue's rows whose feature belongs to the gene set, divided by
# the size of the gene set. Rows count regardless of the direction of change
# (an earlier variant additionally required estimate > 0).
# NOTE(review): the ratio counts rows, so it equals the fraction of genes only if
# each feature occurs at most once per tissue and result type - confirm.
results = []
for result_type in ('brain_region', 'broad_type'):
    type_diffs = diff_df.loc[diff_df.type == result_type]
    for tissue in type_diffs.tissue.unique():
        in_tissue = type_diffs.tissue == tissue
        for name, gene_set in gene_sets.items():
            set_size = len(gene_set)
            found = type_diffs.loc[in_tissue & type_diffs.feature.isin(gene_set)]
            proportion = round(len(found) / set_size, 3)
            print(result_type, tissue, name, set_size, proportion)
            results.append([result_type, tissue, name, set_size, proportion])

#### format results as dataframe and pivot as necessary

In [None]:
# one row per (result type, tissue, gene set): the gene set size and the fraction computed above
results_df = DataFrame(results, columns=['type', 'tissue', 'score_type', 'score_features', 'fraction'])
print(f'results df shape{results_df.shape}')
if DEBUG:
    display(results_df.head())

In [None]:
# wide layout: rows indexed by (type, tissue), one column per gene set, values are the fractions
results_pv = results_df.pivot(index=['type', 'tissue'], columns='score_type', values='fraction')
print(f'result pivot shape {results_pv.shape}')
if DEBUG:
    display(results_pv.head())

#### visualize the reformatted data as a heatmap

In [None]:
# Heatmap of the per-tissue gene set fractions; the first index level (result type)
# is dropped so the rows are labelled by tissue only.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names; use whichever name this install provides.
bright_style = ('seaborn-bright' if 'seaborn-bright' in plt.style.available
                else 'seaborn-v0_8-bright')
# dpi_value is the notebook-wide figure dpi, used here instead of a repeated literal
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use(bright_style)
    heatmap(results_pv.droplevel(0), annot=True, 
            annot_kws={"fontsize":10}, linewidths=0.05, cmap='Blues')
    plt.title('Fraction of gene set with changed expression by age')
    plt.show()

In [None]:
# list each gene set name with the number of genes it contains
for name, gene_set in gene_sets.items():
    print(name, len(gene_set))

In [None]:
# print the current date and time via the shell, as a record of when the run finished
!date