## Notebook to run post-processing of the age regression in single-cell multiome data for the GLM pseudo-bulk based analysis

Basically:
- read the GLM results per region and cell-type and then integrate them
- apply the Benjamini-Hochberg (B&H) FDR correction
- take a look at the overlap between brain regions and cell-types, and do some sample plotting

In [None]:
!date

#### import libraries

In [None]:
from anndata import AnnData
import numpy as np
from pandas import DataFrame, concat, read_csv, Series, read_parquet, set_option as pd_set_option
import scanpy as sc
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
from seaborn import scatterplot, lmplot, displot
from matplotlib.pyplot import rc_context
import json
from os.path import exists
from sklearn.preprocessing import MinMaxScaler

import warnings
warnings.simplefilter('ignore')

import random
random.seed(420)
# pandas `.sample()` (used further down) draws from numpy's global RNG, which
# the stdlib seed above does not touch; seed it too so sampled rows are stable
np.random.seed(420)

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# parameters
# data modality to post-process; it is part of the result and figure file
# names and selects which pool covariate is kept further down
modality = 'ATAC' # 'GEX' or 'ATAC'

In [None]:
# parameters
# project name; used as the prefix of every input and output file name
project = 'aging_phase2'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'
# point scanpy's figure output at the same figures directory
sc.settings.figdir = f'{figures_dir}/'

# in files
anndata_file = f'{quants_dir}/{project}.multivi.curated.h5ad'

# out files
# all integrated results with the FDR column added
results_file = f'{results_dir}/{project}.{modality}.glm_age.csv'
# only the results passing the FDR threshold
results_fdr_file = f'{results_dir}/{project}.{modality}.glm_age_fdr.csv'

# constants
# when True the cells below display extra intermediate tables
DEBUG = False
# categories = ['curated_type', 'cluster_name'] # 'curated_type' for broad and 'cluster_name' for specific
# obs columns whose labels name the per-cell-type result files to load
categories = ['curated_type']
# allow long tables to be displayed without truncation
pd_set_option('display.max_rows', 500)

#### functions

In [None]:
def compute_bh_fdr(df: DataFrame, alpha: float=0.05, p_col: str='p-value',
                   method: str='fdr_bh', verbose: bool=True) -> DataFrame:
    """Return a copy of `df` with multiple-testing adjusted p-values added.

    The adjusted values come from statsmodels' `multipletests` and are stored
    in a new column named after `method` (so 'fdr_bh' by default).

    Args:
        df: table holding one test per row.
        alpha: error rate handed to `multipletests`; also the cutoff used for
            the optional summary print.
        p_col: name of the column holding the unadjusted p-values.
        method: correction method name understood by `multipletests`.
        verbose: when True, print the shape of the rows whose adjusted value
            is below `alpha`.

    Returns:
        A new DataFrame; the input is left untouched.
    """
    corrected = df.copy()
    raw_p_values = np.array(corrected[p_col])
    # element 1 of the returned tuple holds the adjusted p-values
    adjusted = multipletests(raw_p_values, alpha=alpha, method=method)[1]
    corrected[method] = adjusted
    if verbose:
        passing = corrected.loc[corrected[method] < alpha]
        print(f'total significant after correction: {passing.shape}')
    return corrected

def plot_feature_by_age(df: DataFrame, y_term: str):
    """Show a seaborn `lmplot` of the `y_term` column against the 'age' column.

    Args:
        df: table containing an 'age' column and the `y_term` column.
        y_term: name of the column plotted on the y axis.
    """
    age_term = 'age'
    figure_settings = {'figure.figsize': (8, 8), 'figure.dpi': 50}
    with rc_context(figure_settings):
        plt.style.use('seaborn-v0_8-talk')
        lmplot(data=df, x=age_term, y=y_term, palette='Purples')
        plt.title(f'{y_term} ~ {age_term}', fontsize='large')
        plt.xlabel(age_term)
        plt.ylabel(y_term)
        plt.show()
        
def volcano_plot(df: DataFrame, x_term: str='coef', y_term: str='p-value', 
                 alpha: float=0.05, adj_p_col: str='fdr_bh', title: str=None, 
                 filter_nseeff: bool=True, extreme_size: float=10.0):
    """Draw, save and show a volcano plot (effect vs -log10 p-value).

    The figure is written to the module-level `figures_dir`, with `project`,
    `modality` and `title` forming the file name.

    Args:
        df: results table; must hold `x_term`, `y_term` and `adj_p_col`
            columns, plus a 'z' column when `filter_nseeff` is True.
        x_term: column plotted on the x axis.
        y_term: column whose -log10 is plotted on the y axis.
        alpha: threshold on `adj_p_col` used to colour the points.
        adj_p_col: column holding the adjusted p-values.
        title: plot title, also embedded in the saved file name.
        filter_nseeff: when True, keep only rows that either pass the
            threshold, or have an effect strictly inside
            (-extreme_size, extreme_size) together with a non-missing 'z'.
        extreme_size: absolute effect bound used by the filter.
    """
    plot_df = df.copy().reset_index(drop=True)
    if filter_nseeff:
        effect = plot_df[x_term]
        within_bounds = (-extreme_size < effect) & (effect < extreme_size)
        has_z = ~plot_df['z'].isna()
        passes_threshold = plot_df[adj_p_col] < alpha
        # `&` binds tighter than `|`: bounded rows with a z value, or any passing row
        plot_df = plot_df.loc[(within_bounds & has_z) | passes_threshold]
    plt.figure(figsize=(9,9))
    neg_log_p = -np.log10(plot_df[y_term])
    significant = plot_df[adj_p_col] < alpha
    figure_settings = {'figure.figsize': (8, 8), 'figure.dpi': 50}
    with rc_context(figure_settings):
        plt.style.use('seaborn-v0_8-talk')
        scatterplot(data=plot_df, x=x_term, y=neg_log_p, hue=significant,
                    palette='Purples')
        plt.title(title)
        plt.xlabel('effect')
        plt.ylabel('-log10(p-value)')
        fig_file = f'{figures_dir}/{project}.{modality}.glmpb_volcano.{title}.png'
        plt.savefig(fig_file)
        plt.show()
    
def prep_plot_feature(feature_results: Series, covars: DataFrame):
    """Load the pseudobulk table for one result row and plot its feature.

    The parquet file is located from the module-level `quants_dir`, `project`
    and `modality` plus the row's 'type' and 'tissue'. It is inner-joined with
    `covars` on the index; the feature is then plotted against age and as a
    kernel density estimate.

    Args:
        feature_results: one row of the results table; its 'type', 'tissue'
            and 'feature' entries are read.
        covars: covariates table sharing its index with the pseudobulk table
            (presumably sample ids — verify against the parquet files).

    Raises:
        ValueError: if the row's 'type' is neither 'curated_type' nor
            'cluster_name'. Previously this left the file prefix unbound and
            surfaced as an unrelated error.
    """
    prefix_by_type = {'curated_type': 'broad', 'cluster_name': 'specific'}
    result_type = feature_results['type']
    if result_type not in prefix_by_type:
        raise ValueError(f'unknown result type {result_type!r}; '
                         f'expected one of {sorted(prefix_by_type)}')
    prefix_type = prefix_by_type[result_type]
    # load the pseudobulk quantifications
    this_df = read_parquet((f'{quants_dir}/{project}.{modality}.{prefix_type}'
                            f'.{feature_results["tissue"]}.pb.parquet'))
    # keep only the rows present in both the quantifications and the covariates
    this_df = this_df.merge(covars, how='inner', left_index=True, right_index=True)
    print(feature_results)
    if DEBUG:
        print(f'shape the quantified pseudobulk {this_df.shape}')
        display(this_df.sample(5))
    feature = feature_results['feature']
    plot_feature_by_age(this_df, feature)
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 50}):
        plt.style.use('seaborn-v0_8-talk')
        displot(data=this_df[feature], kind='kde')
        plt.show()
    

### load discovery cohort data

#### read the anndata (h5ad) file

In [None]:
%%time
# load the curated multiome object; its obs table supplies the cell-type
# labels and the sample covariates used below
adata = sc.read(anndata_file, cache=True)
print(adata)
if DEBUG:
    # peek at a few random cells' annotations
    display(adata.obs.sample(5))

#### take a look at the cell counts by cell type

In [None]:
# show the number of cells carrying each label, for every grouping in use
for category in categories:
    label_counts = adata.obs[category].value_counts()
    display(label_counts)

### read the age regressions results by cell-type

In [None]:
%%time

# file-name prefix used for each grouping column
prefix_by_category = {'curated_type': 'broad', 'cluster_name': 'specific'}
# collect the per-cell-type tables and concatenate once at the end, instead of
# re-copying the growing table on every iteration
result_frames = []
for category in categories:
    print(f'### {category}')
    # the prefix depends only on the category, so resolve it once per category
    prefix_type = prefix_by_category[category]
    for cell_type in adata.obs[category].unique():
        print(f'--- {cell_type}')
        in_file = f'{results_dir}/{project}.{modality}.{prefix_type}.{cell_type}.glm_age.csv'
        # cell types without a result file are skipped
        if exists(in_file):
            result_frames.append(read_csv(in_file))
if not result_frames:
    # fail here with a clear message rather than later on a None table
    raise FileNotFoundError(f'no {modality} glm_age result files found in {results_dir}')
# a fresh index avoids duplicate row labels across the per-file tables
glm_results = concat(result_frames, ignore_index=True)

In [None]:
# report the size of the integrated results table
print(f'shape of all load results {glm_results.shape}')
if DEBUG:
    # rows per grouping type, rows per tissue within each type, and a random peek
    display(glm_results.type.value_counts())
    display(glm_results.groupby('type').tissue.value_counts())    
    display(glm_results.sample(5))

### compute the FDR values

In [None]:
# missing p-values are set to 1 before the correction, so those rows still
# count as tests (presumably they come from fits that failed — TODO confirm)
glm_results['p-value'] = glm_results['p-value'].fillna(1)
# adds the adjusted p-values as the 'fdr_bh' column (function defaults)
glm_results = compute_bh_fdr(glm_results)
print(glm_results.shape)
if DEBUG:
    # smallest adjusted p-values first
    display(glm_results.sort_values('fdr_bh').head())

In [None]:
# adjusted against raw p-values; the dashed lines mark 0.05 on both axes
fdr_plot_settings = {'figure.figsize': (9, 9)}
with rc_context(fdr_plot_settings):
    scatterplot(data=glm_results, x='fdr_bh', y='p-value')
    plt.axhline(y=0.05, linestyle='--')
    plt.axvline(x=0.05, linestyle='--')
    plt.show()

### count of significant genes by broad curated cell-type

In [None]:
# rows passing the 0.05 threshold on the adjusted p-value
fdr_hits = glm_results.loc[glm_results['fdr_bh'] < 0.05]
# number of distinct tissues with at least one passing row
print(fdr_hits['tissue'].nunique())
# passing rows per tissue within each grouping type
display(fdr_hits.groupby('type').tissue.value_counts())

#### format tested versus significant as table

In [None]:
# number of tested rows per (type, tissue) pair
tested = glm_results.groupby('type').tissue.value_counts()
tested.name = 'tested'
# number of rows passing the 0.05 threshold per (type, tissue) pair
significant = glm_results.loc[glm_results['fdr_bh'] < 0.05].groupby('type').tissue.value_counts()
significant.name = 'significant'
combined = concat([tested, significant], axis='columns')
# pairs without any passing row are absent from `significant` and come out of
# the alignment as NaN; report them as 0 hits (and so 0.0 percent) instead
combined['significant'] = combined['significant'].fillna(0).astype(int)
combined['percent'] = round(combined.significant/combined.tested * 100, 2)
display(combined.sort_values('significant', ascending=False))

### save results

#### save the full results

In [None]:
# write every integrated result row, including the 'fdr_bh' column, without the index
glm_results.to_csv(results_file, index=False)

#### save the statistically significant results

In [None]:
# write only the rows whose adjusted p-value is below 0.05, without the index
glm_results.loc[glm_results['fdr_bh'] < 0.05].to_csv(results_fdr_file, index=False)

### visualize results

#### visualize volcano plots

In [None]:
# one volcano plot over every result row
volcano_plot(glm_results, title='all_results')

print('### broad cell-types')
for category in categories:
    print(f'### {category}')
    # the category mask is the same for every cell type, so build it once
    in_category = glm_results.type == category
    for cell_type in adata.obs[category].unique():
        print(f'--- {cell_type}')
        # one volcano plot per cell type, titled (and saved) by its name
        volcano_plot(glm_results.loc[(glm_results.tissue == cell_type) & in_category],
                     title=cell_type)

#### look at some of the individual results

##### format sample covariates

sex, ancestry, age, (gex_pool or atac_pool), pmi, ph, smoker, bmi

In [None]:
# per-cell annotation columns that describe the sample rather than the cell
keep_terms = ['sample_id','sex', 'ancestry', 'age', 'gex_pool', 'atac_pool', 
              'pmi', 'ph', 'smoker', 'bmi']
# collapse the per-cell rows to the distinct covariate combinations and index
# them by sample id (setting the index discards the old per-cell index)
covars_df = adata.obs[keep_terms].drop_duplicates().set_index('sample_id')

if DEBUG:
    print(covars_df.shape)
    display(covars_df.head())
    display(covars_df.info())
    display(covars_df.smoker.value_counts())
    display(covars_df.bmi.describe())

##### fill any missing covariate terms
It looks like smoker and bmi are missing for one sample; the missing values will be set to the mean of the respective column.

In [None]:
# replace missing smoker and bmi entries with the column mean, rounded to one decimal
for covar_term in ('smoker', 'bmi'):
    rounded_mean = covars_df[covar_term].mean().round(1)
    is_missing = covars_df[covar_term].isna()
    covars_df.loc[is_missing, covar_term] = rounded_mean

if DEBUG:
    print(covars_df.shape)
    display(covars_df.info())
    display(covars_df.smoker.value_counts())
    display(covars_df.bmi.describe())

##### set the pool term based on modality being analyzed

In [None]:
# copy the pool column matching the analysed modality into a common 'pool'
# column; any other modality value leaves 'pool' unset
pool_column = {'GEX': 'gex_pool', 'ATAC': 'atac_pool'}.get(modality)
if pool_column is not None:
    covars_df['pool'] = covars_df[pool_column]
# the modality specific columns are no longer needed
covars_df = covars_df.drop(columns=['gex_pool', 'atac_pool'])
print(f'shape of covariate terms is {covars_df.shape}')
if DEBUG:
    display(covars_df.head(40))

In [None]:
# most significant result: smallest p-value, ties resolved by the largest coefficient
smallest_p_value = glm_results['p-value'].min()
this_results = glm_results.loc[glm_results['p-value'] == smallest_p_value]
this_hit = this_results.sort_values(by=['coef'], ascending=False).iloc[0]
prep_plot_feature(this_hit, covars_df)

In [None]:
# significant result with the smallest (most negative) coefficient
sig_results = glm_results.loc[glm_results['fdr_bh'] < 0.05]
if sig_results.empty:
    # nothing passed the threshold, so there is no hit to plot
    print('no significant results to plot')
else:
    # Series.min() is vectorized and skips missing values, unlike builtin min();
    # every selected row shares the same coef, so no further sorting is needed
    this_results = sig_results.loc[sig_results['coef'] == sig_results['coef'].min()]
    this_hit = this_results.iloc[0]
    prep_plot_feature(this_hit, covars_df)

In [None]:
# one randomly chosen significant result
if sig_results.empty:
    # sampling from an empty table raises, so skip the plot instead
    print('no significant results to plot')
else:
    this_hit = sig_results.sample().iloc[0]
    prep_plot_feature(this_hit, covars_df)

In [None]:
# non-significant result with the largest coefficient
# `>=` so that rows sitting exactly on the 0.05 threshold (not counted as
# significant by the `< 0.05` test used elsewhere) are not dropped from both groups
nonsig_results = glm_results.loc[(glm_results['fdr_bh'] >= 0.05) & 
                                 (~glm_results['z'].isna())]
if nonsig_results.empty:
    print('no non-significant results to plot')
else:
    # Series.max() is vectorized and skips missing values, unlike builtin max()
    this_results = nonsig_results.loc[nonsig_results['coef'] == nonsig_results['coef'].max()]
    this_hit = this_results.iloc[0]
    prep_plot_feature(this_hit, covars_df)

In [None]:
!date