In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LogNorm
import pickle
import seaborn as sns
import scipy

import gseapy as gp
# List the gseapy (Enrichr) gene-set library names available for human.
# NOTE(review): this presumably queries a remote service, so it needs network access.
# `lib_names` is not used by any cell shown here -- confirm it is needed before keeping it.
lib_names = gp.get_library_name(organism='Human')

# Module-level analyses

In [None]:
# Load the per-celltype GSEA term tables: a dict mapping celltype name -> DataFrame of
# enriched terms, produced by an earlier step that is not part of this notebook section.
path = 'data'
file = f'{path}/gsea-terms_all-celltypes_dict.pkl'

# NOTE: pickle.load can execute arbitrary code; only load files written by this project's own pipeline.
with open(file, 'rb') as f:
    multicelltype_gsea_term_sets = pickle.load(f)

## Keyword search

In [None]:
# Heatmap column labels: '<celltype> <component>' (component shown 1-based) for every
# component of every celltype, components in ascending order within each celltype.
cols = [
    f'{celltype} {comp + 1}'
    for celltype, celltype_terms in multicelltype_gsea_term_sets.items()
    for comp in np.sort(celltype_terms['component'].unique())
]

# Term filters used by the keyword search below
corr_threshold = 0.1    # |pseudodx_corr| must exceed this
max_superset_count = 0  # superset_count must not exceed this

def get_kw_names(kw_df, kws, gsea_term_sets=None, corr_min=None, max_supersets=None):
    """Accumulate signed GSEA significance scores for pathways whose names contain keywords.

    For every celltype table, terms are kept when ``|pseudodx_corr| > corr_min`` and
    ``superset_count <= max_supersets``. A kept term matches when any keyword is a substring
    of its lower-cased name. Each matching term adds ``-log(fdr + 1e-6) * sign(nes)`` to
    ``[<term name minus its last character>, '<celltype> <component + 1>']``. A matching term
    whose column label is not among ``kw_df.columns`` still gets a row, but adds no value.

    Each term is counted at most once, even when it matches several keywords
    (e.g. 'alzheimer' and 'amyloid'). Rows are appended in the order pathways are first
    matched, keyword by keyword, so related pathways stay grouped.

    Parameters
    ----------
    kw_df : pd.DataFrame
        Frame whose columns are '<celltype> <component>' labels. Its existing rows and
        values are kept and added to; the frame itself is not modified.
    kws : str or iterable of str
        Keyword(s), matched case-insensitively.
    gsea_term_sets : dict of str -> pd.DataFrame, optional
        Celltype -> term table. Defaults to the notebook-level ``multicelltype_gsea_term_sets``.
    corr_min : float, optional
        Defaults to the notebook-level ``corr_threshold``.
    max_supersets : int, optional
        Defaults to the notebook-level ``max_superset_count``.

    Returns
    -------
    pd.DataFrame
        New float frame with one row per matched pathway name.
    """
    if gsea_term_sets is None:
        gsea_term_sets = multicelltype_gsea_term_sets
    if corr_min is None:
        corr_min = corr_threshold
    if max_supersets is None:
        max_supersets = max_superset_count

    # A bare string used to be silently rejected (returning None); treat it as one keyword.
    if isinstance(kws, str):
        kws = [kws]
    keywords = [kw.lower() for kw in kws]

    # Filter each celltype table once and precompute per-term values (not once per keyword).
    per_celltype = []
    for celltype, terms in gsea_term_sets.items():
        keep = (np.abs(terms['pseudodx_corr']) > corr_min) & (terms['superset_count'] <= max_supersets)
        kept = terms[keep]
        per_celltype.append((
            [name.lower() for name in kept.index],
            [name[:-1] for name in kept.index],  # the last character of each term name is dropped
            [f'{celltype} {comp + 1}' for comp in kept['component']],
            # signed significance: -log(fdr), signed by the enrichment direction
            -np.log(kept['fdr'].to_numpy(dtype=float) + 1e-6) * np.sign(kept['nes'].to_numpy(dtype=float)),
            np.zeros(len(kept), dtype=bool),  # True once a term was counted under an earlier keyword
        ))

    row_order = list(kw_df.index)
    known_rows = set(row_order)
    known_cols = set(kw_df.columns)
    totals = {}  # (pathway name, column label) -> accumulated score

    for kw in keywords:
        for lower_names, names, labels, scores, counted in per_celltype:
            for i, lower_name in enumerate(lower_names):
                if counted[i] or kw not in lower_name:
                    continue
                counted[i] = True
                name = names[i]
                if name not in known_rows:
                    known_rows.add(name)
                    row_order.append(name)
                if labels[i] in known_cols:
                    key = (name, labels[i])
                    totals[key] = totals.get(key, 0.0) + scores[i]

    # Build the result once (float: the scores are not integers) instead of concat-ing row by row.
    result = kw_df.astype(float).reindex(row_order, fill_value=0.0)
    for (name, label), score in totals.items():
        result.loc[name, label] += score
    return result

# Scores are signed -log(fdr) values, i.e. floats, so start from a float frame.
kw_df = pd.DataFrame(columns=cols, dtype=float)

# construct dataframe of keyword hits
kw_df = get_kw_names(kw_df, ['microglia', 'mapk', 'inflamm', 'tumor necrosis',
                             'mhc', 'toll-like', 'oligod', 'myelin', 'alzheimer',
                             'amyloid', 'lipid', 'cholesterol', 'neuron', 'actin',
                             'apopto', 'phagocyt', 'copper'])

# Remove gene set codes (WikiPathways ' WP...', GO ' (GO...', Panther ' P0...'):
# everything from the first such code to the end of the name is dropped, then names are title-cased.
# A raw-string regex avoids the invalid "\(" escape, and str.replace also works on an empty index.
kw_df.index = (
    kw_df.index.astype(str)
    .str.replace(r'(?s) (?:WP|\(GO|P0).*', '', regex=True)
    .str.title()
)

In [None]:
# Keyword heatmap: one row per matched pathway, one column per celltype component.
# Zero entries (no hit) are blanked so that only matched cells are coloured.
ccmap = 'coolwarm'

kw_df = kw_df.replace(0, np.nan)

plt.figure(figsize=(4, 25))
axr = sns.heatmap(
    kw_df,
    cmap=ccmap,
    vmin=-5,
    vmax=5,
    square=True,
    cbar=False,
    linewidths=1,
    linecolor='k',
)
# pathway names on the right-hand side only, unrotated
axr.tick_params(left=False, right=True, top=False,
                labelleft=False, labelright=True, labeltop=False,
                rotation=0)
# component labels vertical
plt.setp(axr.get_xticklabels(), rotation=90)
axr.tick_params(labelsize=12)

## GWAS search

In [None]:
# 2021 GWAS (38 loci)
# gene list from: Wightman, D. P. et al. A genome-wide association study
# with 1,126,563 individuals identifies new risk loci for Alzheimer’s disease. Nat. Genet. 53, 1276–1282 (2021).
# Genes substituted for loci reported under another name: EPHA1-AS1 -> EPHA1,
# TSPOAP1-AS1 -> SUPT4H1 (alternatives: RNF43, BZRAP1).
# 'INPP5D' is the official gene symbol. The previous spelling 'INPPD5' is not a gene symbol,
# so that locus could never match a leading-edge gene.
# NOTE(review): the list holds 42 genes for the 38 loci cited above -- confirm the extras are intended.
gwas_goi = ['AGRN','CR1','NCK2','BIN1','INPP5D','CLNK','TNIP1','HAVCR2','HLA-DRB1','TREM2','CD2AP',
    'TMEM106B','ZCWPW1', 'NYAP1','EPHA1','CLU','SHARPIN','USP6NL', 'ECHDC3','CCDC6','MADD', 'SPI1','MS4A4A','PICALM',
    'SORL1','FERMT2','RIN3','ADAM10','APH1B','SCIMP', 'RABEP1','GRN','ABI3','SUPT4H1','ACE','ABCA7','APOE','NTN5',
    'CD33','LILRB2','CASS4','APP',]

In [None]:
# rank pathways by number of gwas hits
def count_pathway_gwas(row):
    """Return how many entries a pathway row lists under 'AD GWAS genes'."""
    gwas_genes = row['AD GWAS genes']
    n_genes = len(gwas_genes)
    return n_genes

def test_pathway_gwas(list, gene):
    if(len(list)==0):
        return False
    return gene in list

# Filters applied to every celltype's GSEA term table
fdr_threshold = 0.05
corr_threshold = 0.1
max_superset_count = 0 

# '{celltype} {k}' -> for each gene in gwas_goi (same order), the number of passing pathways
# of component k whose leading edge contains that gene
gwas_celltype_count_dict = {}
# GWAS gene -> '{celltype} comp-{k}: {pathway}' labels of the passing pathways containing it
gwas_celltype_dict = {hit: [] for hit in gwas_goi}

for celltype, gsea_terms_celltype in multicelltype_gsea_term_sets.items():
    # Components come from the unfiltered table. Counting them after filtering (as
    # range(n_remaining)) skipped or misnumbered components when one had no passing terms.
    celltype_comps = np.sort(gsea_terms_celltype['component'].dropna().unique())
    passing = gsea_terms_celltype[
        (gsea_terms_celltype['fdr'] < fdr_threshold)
        & (gsea_terms_celltype['pseudodx_corr'] > corr_threshold)
        & (gsea_terms_celltype['superset_count'] <= max_superset_count)
    ]
    for comp in celltype_comps:
        comp_label = int(comp) + 1  # 1-based, like the component labels elsewhere in the notebook
        comp_terms = passing[passing['component'] == comp]

        # look at which gwas genes are present in top pathways across celltypes
        counts = []
        for hit in gwas_goi:
            # membership evaluated once, reused for the count and the pathway names
            has_hit = comp_terms['ledge_genes'].apply(test_pathway_gwas, args=[hit]).to_numpy(dtype=bool)
            counts.append(int(has_hit.sum()))
            gwas_celltype_dict[hit].extend(
                f'{celltype} comp-{comp_label}: {x}' for x in comp_terms.index[has_hit]
            )
        gwas_celltype_count_dict[f'{celltype} {comp_label}'] = counts

gwas_celltype_df = pd.DataFrame(gwas_celltype_count_dict, index=gwas_goi)

df_plot = gwas_celltype_df[gwas_celltype_df.sum(axis=1) > 0]  # drop genes with no hits
# order genes (rows here, columns of the transposed heatmap) by the number of components they hit
df_plot = df_plot.loc[(df_plot > 0).sum(axis=1).sort_values(ascending=False).index]

plt.figure(figsize=(20,8))
sns.heatmap(df_plot.T, cmap='Reds', annot=True, norm=LogNorm(), cbar=False, linewidths=1, linecolor='k')
plt.xticks(rotation=60, ha='center')
plt.title('Number of pathways containing top Alzheimer GWAS hits')
plt.xlabel('GWAS hits')
plt.ylabel('Celltype');

# Celltype-celltype interaction

In [None]:
fdr_threshold = 0.05
corr_threshold = 0.1

# Look at coordination between the top N pathways (ranked by `rank_column`, ascending) of two celltype components
def cross_correlate_top_pathways(gsea_terms, cell1, comp1, cell2, comp2, rank_column, metric_list_column, n_compare=10):
    """Pearson-correlate per-pathway metric vectors between the top pathways of two celltype components.

    For each side, terms of the given component are sorted by ``rank_column`` (ascending) and
    kept when ``fdr < fdr_threshold`` and ``pseudodx_corr > corr_threshold`` (notebook-level
    globals). The first ``n_compare`` pathways of each side (capped by the shorter list) are
    compared pairwise. NaN positions in either vector are ignored.

    Parameters
    ----------
    gsea_terms : dict of str -> pd.DataFrame
        Celltype -> GSEA term table.
    cell1, cell2 : str
        Celltype keys.
    comp1, comp2 : int
        Component values (0-based; shown 1-based in the labels).
    rank_column : str
        Column used to rank pathways.
    metric_list_column : str
        Column holding, per pathway, a list of numbers to correlate.
    n_compare : int
        Maximum number of top pathways taken from each side.

    Returns
    -------
    list of tuple
        ``('<cell1> <comp1+1>', '<cell2> <comp2+1>', pathway1, pathway2, r)`` per pair.
        ``r`` is NaN when fewer than two paired non-NaN values exist (pearsonr would raise).
    """
    from scipy.stats import pearsonr

    def _top_terms(cell, comp):
        # rank first, then keep significant, positively correlated terms of one component
        terms = gsea_terms[cell].sort_values(by=rank_column, ascending=True)
        keep = (
            (terms['component'] == comp)
            & (terms['fdr'] < fdr_threshold)
            & (terms['pseudodx_corr'] > corr_threshold)
        )
        return terms[keep]

    top1 = _top_terms(cell1, comp1)
    top2 = _top_terms(cell2, comp2)
    n_compare = min(n_compare, top1.shape[0], top2.shape[0])

    # convert each metric list to a float array once (None -> NaN), not once per pair
    vecs1 = [np.asarray(top1[metric_list_column].iloc[i], dtype=float) for i in range(n_compare)]
    vecs2 = [np.asarray(top2[metric_list_column].iloc[i], dtype=float) for i in range(n_compare)]
    label1 = f'{cell1} {comp1+1}'
    label2 = f'{cell2} {comp2+1}'

    r_list = []
    for i1, v1 in enumerate(vecs1):
        for i2, v2 in enumerate(vecs2):
            valid = ~(np.isnan(v1) | np.isnan(v2))
            if valid.sum() < 2:
                r = np.nan
            else:
                r = pearsonr(v1[valid], v2[valid])[0]
            r_list.append((label1, label2, top1.index[i1], top2.index[i2], r))

    return r_list

gsea_celltypes = list(multicelltype_gsea_term_sets.keys())
# NOTE(review): assumes components are numbered 0..n-1 within each celltype -- confirm
celltype_n_comps = {cell:multicelltype_gsea_term_sets[cell]['component'].unique().shape[0] for cell in gsea_celltypes}
gsea_celltypes_comps = [f'{cell} {comp+1}' for cell in gsea_celltypes for comp in range(celltype_n_comps[cell])]
# Float frame pre-filled with NaN. An int frame of zeros triggers pandas' incompatible-dtype
# warning when float medians are assigned, and NaN already marks pairs that were not computed.
celltype_pair_df = pd.DataFrame(np.nan, index=gsea_celltypes_comps, columns=gsea_celltypes_comps, dtype=float)

pathway_pair_columns = ['celltype1', 'celltype2', 'pathway1', 'pathway2', 'correlation']
# Rows are collected in a list and the frame is built once. Concatenating inside the loop is
# quadratic, and concatenating onto an empty frame is deprecated in recent pandas.
pathway_pair_records = []

n_celltypes = len(gsea_celltypes)
# each unordered pair of distinct celltypes once (cell1 precedes cell2 in gsea_celltypes)
for i_cell1 in range(n_celltypes):
    for i_cell2 in range(i_cell1+1,n_celltypes):
        cell1 = gsea_celltypes[i_cell1]
        cell2 = gsea_celltypes[i_cell2]

        for comp1 in range(celltype_n_comps[cell1]):
            for comp2 in range(celltype_n_comps[cell2]):
                # evaluate correlation between pairs of top pathways
                r_list = cross_correlate_top_pathways(multicelltype_gsea_term_sets, cell1, comp1, cell2, comp2, 
                                                      rank_column='fdr', metric_list_column='dx_pred_score', n_compare=50)
                pathway_pair_records.extend(r_list)

                # Median absolute correlation for this component pair. Undefined (NaN)
                # correlations are ignored rather than turning the whole median into NaN.
                abs_r = np.abs(np.array([rec[-1] for rec in r_list], dtype=float))
                if (~np.isnan(abs_r)).any():
                    celltype_pair_df.loc[f'{cell1} {comp1+1}', f'{cell2} {comp2+1}'] = np.nanmedian(abs_r)

# save pathway pairs
pathway_pair_df = pd.DataFrame(pathway_pair_records, columns=pathway_pair_columns)

In [None]:
# Entries that were never computed hold 0 or NaN: same-celltype component pairs, the reverse
# order of each celltype pair, and pairs with no comparable pathways. Zeros become NaN so all
# of them are drawn in the colormap's "bad" colour.
celltype_pair_df = celltype_pair_df.replace(0,np.nan)

# Copy the colormap before modifying it, and draw NaN (masked) cells white.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap remains.
cmap = plt.get_cmap('Reds').copy()
cmap.set_bad('w')

plt.figure(figsize=(10,10))
# Transposed: computed entries (row celltype listed before column celltype) end up below the diagonal.
plt.imshow(celltype_pair_df.T, cmap=cmap)
plt.title('Median absolute correlation between top pathway pairs', fontsize=16)
plt.xticks(ticks=np.arange(len(gsea_celltypes_comps)), labels=gsea_celltypes_comps, rotation=90, fontsize=14)
plt.yticks(ticks=np.arange(len(gsea_celltypes_comps)), labels=gsea_celltypes_comps, fontsize=14)
plt.colorbar();