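## Run Enrichr gene set enrichment analysis (GSEA) with gseapy on the partitioned aging components, using the features grouped into each partition

By default, each partition's gene list is the union of the age-associated features of the cell types in that partition. An alternative (commented-out) cell below instead groups the feature nodes of the partitioned graph, i.e. the latent factor loadings.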

In [None]:
# record when this notebook was run (provenance for the outputs below)
!date

#### import libraries

In [None]:
from json import load as json_load
from math import ceil
from time import sleep

import gseapy
import matplotlib.pyplot as plt
import statsmodels.stats.multitest as smm
from gseapy.enrichr import Enrichr
from igraph import Graph
from matplotlib.pyplot import rc_context
from numpy import log10
from pandas import read_csv, concat, DataFrame, pivot
from seaborn import clustermap, heatmap

# render matplotlib figures inline in the notebook
%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# emit high-resolution (retina) inline figures
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# ---- parameters ----
project = 'aging_phase2'
latent_type = 'all'

# ---- directories ----
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
results_dir = wrk_dir + '/results'
figures_dir = wrk_dir + '/figures'

# ---- input files ----
# JSON mapping of partition index -> latent factors, and the partitioned latent graph
latent_part_file = (f'{figures_dir}/{project}.latents.{latent_type}'
                    '.partitioned_factors.json')
graphml_file = f'{figures_dir}/{project}.latents.{latent_type}.graphml'

# ---- output files ----
# none defined yet; the commented-out savefig calls below expect a `figure_file`

# ---- constants and variables ----
DEBUG = True  # when True, cells display extra diagnostic output
# one FDR-filtered age regression results file is read per category
categories = ['broad', 'specific']
# only gene features are tested here; ATAC is intentionally skipped
# modalities = ['GEX', 'ATAC']
MODALITY = 'GEX'
REGRESSION_TYPE = 'glm_tweedie'
# Enrichr libraries queried for every aging partition
marker_sets = [
    'MSigDB_Hallmark_2020',
    'GO_Biological_Process_2023',
    'GO_Cellular_Component_2023',
    'GO_Molecular_Function_2023',
]
PAUSE_AMT = 2  # seconds to sleep between consecutive Enrichr requests
dpi_value = 50  # figure DPI used for the heatmaps

### load age-associated features

In [None]:
# read the FDR-filtered age regression results for each category and stack them,
# tagging every row with the category it came from
category_frames = []
for category_name in categories:
    print(category_name)
    csv_path = (f'{results_dir}/{project}.{MODALITY}.{category_name}.'
                f'{REGRESSION_TYPE}_fdr_filtered.age.csv')
    category_frames.append(read_csv(csv_path).assign(category=category_name))
age_glm_df = concat(category_frames)
print(f'shape of all age associated features {age_glm_df.shape}')
if DEBUG:
    display(age_glm_df.sample(4), age_glm_df.category.value_counts())

### load the graph of partitioned age-associated latent factors

In [None]:
latent_graph = Graph.Read_GraphML(graphml_file)
if DEBUG:
    # vertex count, then edge count
    for graph_size in (latent_graph.vcount(), latent_graph.ecount()):
        print(graph_size)
# non-latent nodes carry the vertex attribute type == 'feature'
feature_nodes = latent_graph.vs.select(type='feature')
print(f'length non-latent feature nodes in the latent graph is {len(feature_nodes)}')

### load the partitioned age-associated latent factors

In [None]:
# presumably {partition index: [latent factor ids, ...]} (see how it is iterated below)
with open(latent_part_file) as json_file:
    partitioned_factors = json_load(json_file)
print(f'length of partitioned_factors is {len(partitioned_factors)}')

#### extract the partition groups and cell types

In [None]:
# map each aging partition ('Aging-<index>') to the distinct cell types of its latent factors
age_latents = {}
for part_index, latents in partitioned_factors.items():
    latent_name = f'Aging-{part_index}'
    # the cell type is taken as the text before the first ':' of each latent factor id
    # (assumed '<cell type>:<...>' format — confirm against the partitioning output)
    part_cell_types = {element.split(':')[0] for element in latents}
    # sorted so the order (and the displayed output) is reproducible across kernels;
    # plain set order depends on per-process string hashing
    age_latents[latent_name] = sorted(part_cell_types)
print(f'age_latents length is {len(age_latents)}')
if DEBUG:
    display(age_latents)

### resolve the cell types to their age-associated features

In [None]:
# for each aging partition, collect the union of age-associated features across its cell types
latent_features = {}
for latent, cell_types in age_latents.items():
    print(latent, cell_types)
    # one mask over all of the partition's cell types; always yields a set (the old
    # accumulator started as a dict `{}` and stayed one when there were no cell types)
    in_partition = age_glm_df.tissue.isin(cell_types)
    age_features = set(age_glm_df.loc[in_partition, 'feature'])
    latent_features[latent] = age_features
    print(f'{latent} has {len(age_features)} features')

### alternatively, group features from the partitioned graph, i.e. based on the latent age factor loadings

In [None]:
# latent_features = {}
# for part_index in partitioned_factors.keys():
#     latent_name = f'Aging-{part_index}'
#     member_nodes = feature_nodes.select(membership=float(part_index))
#     print(latent_name, len(member_nodes))
#     age_features = []
#     for node in member_nodes:
#         age_features.append(node['name'])
#     latent_features[latent_name] = list(set(age_features))
#     print(f'{latent_name} has {len(age_features)} feature loadings')

### run the GSEA Enrichr

#### if debugging see available GSEA libraries

In [None]:
if DEBUG:
    # list the human Enrichr libraries, e.g. to choose entries for marker_sets
    import gseapy
    print(gseapy.get_library_name(organism='Human'))

#### utility functions for accessing and scoring GSEA Enrichr

In [None]:
def find_enrichment(name: str, genes: list, sets,
                    verbose: bool=False, alpha: float=0.05,
                    cutoff: float=0.5) -> DataFrame:
    """Run one Enrichr query for a gene list against the given gene set library(ies).

    Parameters
    ----------
    name : label written into the 'factor' column of the returned results
    genes : gene symbols to test (any sized iterable; converted to a list)
    sets : Enrichr library name(s), forwarded to gseapy.enrichr as `gene_sets`
    verbose : if True, print the result shapes and display the rows whose
        'Adjusted P-value' is <= `alpha`
    alpha : significance threshold used only for the verbose summary
    cutoff : forwarded unchanged to gseapy.enrichr

    Returns
    -------
    DataFrame
        gseapy's Enrichr results table with an added 'factor' column. When `genes`
        is empty no request is made and an empty frame with only a 'factor'
        column is returned.
    """
    if len(genes) == 0:
        # nothing to test; skip the web request entirely
        return DataFrame(columns=['factor'])
    # requires the module-level `import gseapy` (top import cell)
    enr_res = gseapy.enrichr(gene_list=list(genes),
                             organism='Human',
                             gene_sets=sets,
                             cutoff=cutoff)
    enr_res.results['factor'] = name
    if verbose:
        print(f'full {sets} results shape{enr_res.results.shape}')
        sig = enr_res.results.loc[enr_res.results['Adjusted P-value'] <= alpha]
        print(f'significant {sets} results shape{sig.shape}')
        display(sig)
    return enr_res.results

In [None]:
# query every aging partition against every library; one results frame per query
results = []
for latent, gene_list in latent_features.items():
    print(f'\n########### {latent} ###########')
    for gene_set in marker_sets:
        print(f'\n+++++++++++ {gene_set} +++++++++++')
        enrichment_df = find_enrichment(latent, list(gene_list), gene_set,
                                        verbose=False)
        results.append(enrichment_df)
        # throttle successive requests to the Enrichr web service
        sleep(PAUSE_AMT)

#### convert full enrichment results into combined data frame

In [None]:
# stack the per-query Enrichr frames; per-query frames can share row labels, so
# renumber the combined frame to give every row a unique label
results_df = concat(results, ignore_index=True)
print(f'full results shape {results_df.shape}')
if DEBUG:
    # sample() raises when asked for more rows than exist
    display(results_df.sample(min(5, len(results_df))))
    display(results_df.Gene_set.value_counts())

#### how many terms are statistically significant?

In [None]:
alpha = 0.05  # significance threshold on the Enrichr 'Adjusted P-value' column
significant_df = results_df.loc[results_df['Adjusted P-value'] <= alpha]
print(significant_df.shape)
# strongest enrichment first
display(significant_df.sort_values('Odds Ratio', ascending=False))

### clean up the gene set and term labels

In [None]:
# shorten library names, e.g. 'GO_Biological_Process_2023' -> 'Biological_Process';
# literal (non-regex) replacements, which are idempotent on re-run
results_df['Gene_set'] = (results_df.Gene_set
                          .str.replace('GO_', '', regex=False)
                          .str.replace('_2020', '', regex=False)
                          .str.replace('_2023', '', regex=False))
# keep the untouched Enrichr term so re-running this cell does not prefix the label twice
if 'raw_term' not in results_df.columns:
    results_df['raw_term'] = results_df.Term
results_df['Term'] = results_df.Gene_set + ': ' + results_df.raw_term
print(f'shape of GSEA post Term naming cleanup {results_df.shape}')
if DEBUG:
    # sample() raises when asked for more rows than exist
    display(results_df.sample(min(5, len(results_df))))

### reshape the dataframe from long to wide

In [None]:
# terms that are significant (adjusted P <= alpha) for at least one factor
sig_terms = results_df.loc[results_df['Adjusted P-value'] <= alpha, 'Term']
# heatmap intensity is -log10 of the nominal p-value
# NOTE(review): a P-value of exactly 0 would give inf here and break the colour scale
results_df['log10_pvalue'] = -log10(results_df['P-value'])
# one row per significant term, one column per factor; drop rows with no values
w_df = (
    pivot(results_df.loc[results_df.Term.isin(sig_terms)],
          index=['Term'],
          columns=['factor'], values='log10_pvalue')
    .round(2)
    .dropna(how='all')
)
print(f'shape of wide reformated results {w_df.shape}')
if DEBUG:
    display(w_df)

### visualize the reformatted data as a heatmap

In [None]:
# figure height grows with the number of terms so the row labels stay legible;
# `height` is reused by the clustered heatmap cell below
n_terms = w_df.shape[0]
height = 9 + ceil(n_terms / 5) if n_terms > 9 else 9
print(height)
if w_df.empty:
    # seaborn's heatmap fails on an empty matrix (no significant terms)
    print('no significant terms to plot')
else:
    with rc_context({'figure.figsize': (11, height), 'figure.dpi': dpi_value}):
        plt.style.use('seaborn-v0_8-bright')
        fig, ax = plt.subplots()
        heatmap(w_df, linecolor='grey', linewidths=0.05, cmap='Purples', ax=ax)
        ax.set_title('GSEA Enrichr for latent age factors')
        # plt.savefig(figure_file, dpi=dpi_value, bbox_inches='tight', 
        #             transparent=True, pad_inches=1)
        plt.show()

### visualize as clustered heatmap

In [None]:
# clustering needs a complete matrix: a factor with no value for a term counts as 0.
# A new name keeps w_df (with NaNs) intact so the heatmap cell above can be re-run.
w_filled_df = w_df.fillna(0)

if w_filled_df.empty:
    print('no significant terms to cluster')
else:
    with rc_context({'figure.figsize': (11, height), 'figure.dpi': dpi_value}):
        plt.style.use('seaborn-v0_8-bright')
        # clustermap builds its own figure from its `figsize` argument (not the rcParams
        # figsize), so pass the size explicitly; linkage needs at least 2 rows/columns
        grid = clustermap(w_filled_df, cmap='Purples', cbar_pos=None,
                          linecolor='grey', linewidths=0.05,
                          figsize=(11, height),
                          row_cluster=w_filled_df.shape[0] > 1,
                          col_cluster=w_filled_df.shape[1] > 1)
        # rotate the factor labels on the heatmap axes explicitly
        plt.setp(grid.ax_heatmap.get_xticklabels(), rotation=90)
        # plt.savefig(figure_file, dpi=dpi_value, bbox_inches='tight', 
        #             transparent=True, pad_inches=1)
        plt.show()