## Notebook to run post-processing of differential expression in single-cell data using glmmTMB for the replication cohort data

basically 
- read the glmmTMB R script results per region and cell-type and then integrate them
- apply B&H FDR 
- take a look at the overlap between brain regions and cell-types and do some sample plotting
- this is a reduced copy of the discovery cohort post-processing

In [None]:
# print the current date/time at the start of the notebook run
!date

#### import libraries

In [None]:
from anndata import AnnData
import numpy as np
import pandas as pd
import scanpy as sc
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import rc_context
import json
from os.path import exists

import warnings
# silence library warnings so the notebook output stays readable
warnings.simplefilter('ignore')

import random
# seed both generators: pandas' DataFrame.sample() draws from numpy's global
# random state, which random.seed() alone does not control
random.seed(420)
np.random.seed(420)

#### set notebook variables

In [None]:
# parameters
project = 'aging_phase1'
set_name = f'{project}_replication'
cohort = 'aging'

# directories for initial setup
wrk_dir = '/home/jupyter/brain_aging_phase1'
replication_dir = f'{wrk_dir}/replication'

# in files
anndata_file = f'{replication_dir}/{set_name}.full.h5ad'
# NOTE(review): this template ends in .csv but read_feature_renamed_map parses the
# file with json.load; confirm the upstream step really writes JSON under this name
temp_name_remap_json = '{this_dir}/{name}_gene_name_remap_temp.csv'
temp_r_out_file = '{this_dir}/{chrt}.{name}_glmmtmb_results_temp.csv'

# out files
results_file = f'{replication_dir}/{set_name}.glmmtmb_age_diffs.csv'
results_fdr_file = f'{replication_dir}/{set_name}.glmmtmb_age_diffs_fdr.csv'

# constants
DEBUG = True
min_cell_count = 3
young_age_limit = 30.0


sns.set_theme(style='white', palette='Paired', font_scale=1.2)

# allow for more rows in output
# use the fully qualified option name: the bare 'max_rows' pattern matches more
# than one option (e.g. 'styler.render.max_rows') on newer pandas and raises OptionError
prev_default = pd.get_option('display.max_rows')
pd.set_option('display.max_rows', 2000)
# # restore default setting
# pd.set_option('display.max_rows', prev_default)

#### analysis functions

In [None]:
def read_feature_renamed_map(group_name: str) -> dict:
    """Load the feature-name remap dictionary stored for one analysis group.

    The path is built from the notebook-level ``temp_name_remap_json`` template
    and ``replication_dir``; spaces in ``group_name`` are replaced by underscores.

    Args:
        group_name: name of the group whose remap file should be read.

    Returns:
        The dictionary parsed from the file with ``json.load``.

    Raises:
        FileNotFoundError: if the remap file does not exist.
    """
    remap_file = temp_name_remap_json.format(this_dir=replication_dir,
                                             name=group_name.replace(" ", "_"))
    # read dict from json file; the context manager closes the handle even if
    # parsing fails, the previous bare open() left it to the garbage collector
    with open(remap_file) as in_file:
        rename_cols = json.load(in_file)
    return rename_cols

def reformat_glmmtmb_df(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse long-format glmmTMB output into one row per feature.

    Rows whose ``term`` is 'old' supply the effect columns; the '(Intercept)'
    estimate of the same feature is attached as the ``intercept`` column.
    Features missing either term are dropped by the inner merge.
    """
    output_columns = ['feature', 'intercept', 'estimate',
                      'std.error', 'statistic', 'p.value']
    # intercept estimates, renamed so they do not clash with the 'old' estimate
    is_intercept = df['term'] == '(Intercept)'
    intercepts = df.loc[is_intercept, ['feature', 'estimate']]
    intercepts = intercepts.rename(columns={'estimate': 'intercept'})
    # the age term rows carry the effect statistics that are kept
    old_terms = df.loc[df['term'] == 'old']
    merged = old_terms.merge(intercepts, how='inner', on='feature')
    return merged[output_columns]

def read_glmmtmb_results(group_name: str, group_type: str, cols_to_rename: dict) -> pd.DataFrame:
    """Read and reformat the glmmTMB results file for one group.

    The path comes from the notebook-level ``temp_r_out_file`` template,
    ``replication_dir`` and ``cohort``. Returns None when the file does not
    exist, otherwise one row per feature with ``tissue`` and ``type`` columns
    set to ``group_name`` and ``group_type``.
    """
    results_path = temp_r_out_file.format(this_dir=replication_dir, chrt=cohort,
                                          name=group_name.replace(" ", "_"))
    # nothing was written for this group
    if not exists(results_path):
        return None
    results = pd.read_csv(results_path)
    # invert the key/values of the remap so the feature names used in R
    # ('-' -> '_') are flipped back to the originals
    r_name_to_original = dict(zip(cols_to_rename.values(), cols_to_rename.keys()))
    results['feature'] = results['feature'].replace(r_name_to_original)
    results = reformat_glmmtmb_df(results)
    results['tissue'] = group_name
    results['type'] = group_type
    return results

def compute_bh_fdr(df: pd.DataFrame, alpha: float=0.05, p_col: str='p.value',
                   method: str='fdr_bh', verbose: bool=True) -> pd.DataFrame:
    """Return a copy of ``df`` with multiple-testing adjusted p-values added.

    The values in ``p_col`` are passed to statsmodels ``multipletests`` and the
    second item of its result is stored in a new column named after ``method``.
    With ``verbose`` the shape of the rows below ``alpha`` is printed.
    """
    ret_df = df.copy()
    raw_pvalues = np.array(ret_df[p_col])
    # index 1 of the multipletests result holds the corrected p-values
    ret_df[method] = multipletests(raw_pvalues, alpha=alpha, method=method)[1]
    if verbose:
        sig_shape = ret_df.loc[ret_df[method] < alpha].shape
        print(f'total significant after correction: {sig_shape}')
    return ret_df

def plot_feature_by_age_group(df: pd.DataFrame, x_term: str, y_term: str):
    """Draw a boxen plot of ``y_term`` per ``x_term`` level with the points overlaid."""
    plt.figure(figsize=(9,9))
    # both layers share the same axes mapping
    shared_args = dict(x=x_term, y=y_term, data=df)
    sns.boxenplot(scale='exponential', k_depth='trustworthy', **shared_args)
    # individual observations on top of the boxen plot
    sns.stripplot(alpha=0.75, jitter=True, color='darkgrey', **shared_args)
    plt.title(f'{y_term} ~ {x_term}', fontsize='large')
    plt.xlabel(x_term)
    plt.ylabel(y_term)
    plt.show()
    
def plot_feature_by_sample(df: pd.DataFrame, x_term: str, y_term: str):
    """Draw a boxen plot of ``y_term`` per sample, young samples first.

    Args:
        df: frame with 'Age_group' and 'Sample_ID' columns plus ``x_term``
            and ``y_term``.
        x_term: column used as the hue of the boxen plot.
        y_term: column plotted on the y axis.
    """
    # set up order by young then old
    ids_by_group = df.groupby('Age_group')['Sample_ID'].unique()
    sample_order = []
    for age_group in ('young', 'old'):
        # an age group may be absent from the subset being plotted
        if age_group in ids_by_group.index:
            # list() works for both numpy arrays and Categoricals, whereas
            # .to_list() does not exist on the numpy arrays returned for
            # non-categorical columns
            sample_order.extend(list(ids_by_group[age_group]))
    plt.figure(figsize=(9,9))
    sns.boxenplot(x='Sample_ID',y=y_term, scale='exponential', data=df,
                  k_depth='trustworthy', hue=x_term, order=sample_order)
    sns.stripplot(x='Sample_ID',y=y_term, data=df, alpha=0.75,
                  jitter=True, color='darkgrey', order=sample_order)
    plt.xticks(rotation=75)
    plt.title(f'{y_term} ~ {x_term}', fontsize='large')
    plt.xlabel('Sample')
    plt.ylabel(y_term)
    plt.show()
    
def volcano_plot(df: pd.DataFrame, x_term: str='estimate', y_term: str='p.value', 
                 alpha: float=0.05, adj_p_col: str='fdr_bh', title: str=None, 
                 filter_nseeff: bool=True, extreme_size: float=10.0):
    """Scatter the effect (``x_term``) against -log10 of ``y_term``.

    Points are coloured by whether ``adj_p_col`` is below ``alpha``. With
    ``filter_nseeff`` a row is kept when it is significant, or when its effect
    lies strictly within +/- ``extreme_size`` and its 'statistic' is not NaN.
    """
    if filter_nseeff:
        effect = df[x_term]
        within_range = (effect > -extreme_size) & (effect < extreme_size)
        has_statistic = df['statistic'].notna()
        significant = df[adj_p_col] < alpha
        # '&' binds tighter than '|': significant rows are always kept
        df = df.loc[(within_range & has_statistic) | significant]
    plt.figure(figsize=(9,9))
    neg_log_p = -np.log10(df[y_term])
    sig_flag = df[adj_p_col] < alpha
    sns.scatterplot(x=x_term, y=neg_log_p, data=df, hue=sig_flag)
    plt.title(title)
    plt.xlabel('effect')
    plt.ylabel('-log10(p-value)')
    plt.show()
    
def prep_plot_feature(data: AnnData, feature_results: pd.Series, 
                      group: str='old', filter_zeros: bool=False):
    """Plot one result row's feature by age group and by sample.

    The data is subset with ``subset_ad_by_type`` using the row's ``tissue``
    and ``type``, converted with ``convert_ad_to_df``, and optionally limited
    to rows where the feature value is above zero.
    """
    cell_data = subset_ad_by_type(data, feature_results.tissue, feature_results.type)
    plot_df = convert_ad_to_df(cell_data)
    feature_name = feature_results.feature
    if filter_zeros:
        # keep only rows where the feature was detected
        plot_df = plot_df.loc[plot_df[feature_name] > 0]
    print(feature_results)
    sns.set_theme(style='white', palette='Paired', font_scale=1.2)
    for plot_func in (plot_feature_by_age_group, plot_feature_by_sample):
        plot_func(plot_df, group, feature_name)
    
def subset_ad_by_type(data: AnnData, group_name: str, type_name: str,
                      reapply_filter: bool=True, min_cell_count: int=3,
                      verbose: bool=False) -> AnnData:
    """Return a copy of ``data`` limited to the cell-type named in ``group_name``.

    Args:
        data: AnnData object with a ``Cell_type`` column in ``obs``.
        group_name: group label; its last space-separated token is used as
            the cell-type to select.
        type_name: accepted for interface compatibility, not used here.
        reapply_filter: when True, run scanpy's gene and cell filters on the
            subset with ``min_counts=min_cell_count``.
        min_cell_count: threshold handed to the filters as ``min_counts``.
            NOTE(review): the name suggests ``min_cells`` - confirm intent.
        verbose: print the shapes before/after filtering and the subset.

    Returns:
        The (optionally filtered) subset as an independent AnnData copy.
    """
    # the broad cell-type is the last token of the group name
    broad_cell_name = group_name.split(' ')[-1]
    this_data = data[data.obs.Cell_type == broad_cell_name].copy()
    shape_before = this_data.shape
    if reapply_filter:
        sc.pp.filter_genes(this_data, min_counts=min_cell_count)
        sc.pp.filter_cells(this_data, min_counts=min_cell_count)
    # taken outside the branch so it is defined even when no filter is applied;
    # previously verbose=True with reapply_filter=False raised UnboundLocalError
    shape_after = this_data.shape
    if verbose:
        print(f'shape before and after: {shape_before} {shape_after}')
        print(this_data)
    return this_data

def convert_ad_to_df(data: AnnData, young_age_limit: float=30.0, 
                     verbose: bool=False) -> pd.DataFrame:
    """Combine the expression matrix of ``data`` with selected ``obs`` annotations.

    Adds 'old' (1/0) and 'Age_group' ('old'/'young') from ``Age`` compared with
    ``young_age_limit`` and 'female' (1/0) from ``Sex``. Returns None when the
    expression and annotation indexes are not equal.
    """
    expression_df = data.to_df()
    annots = data.obs[['Brain_region', 'Age', 'Sample_ID', 'Sex']].copy()
    is_old = annots['Age'] > young_age_limit
    annots['old'] = np.where(is_old, 1, 0)
    annots['Age_group'] = np.where(is_old, 'old', 'young')
    annots['female'] = np.where(annots['Sex'] == 'Female', 1, 0)
    # only combine when both frames describe the same cells in the same order
    if not expression_df.index.equals(annots.index):
        return None
    combined = pd.concat([expression_df, annots], axis='columns')
    if verbose:
        print(combined.shape)
        display(combined.head())
    return combined

#### read the anndata (h5ad) file

In [None]:
%%time
# load the replication cohort AnnData (h5ad) file and show its summary
adata = sc.read(anndata_file, cache=True)
print(adata)

#### take a look at the cell counts by cell type
only single region and broad cell-types

In [None]:
# number of cells per cluster and per broad cell-type
for obs_column in ('Cluster', 'Cell_type'):
    display(adata.obs[obs_column].value_counts())

#### get sample counts per age group by cell-type

In [None]:
# 1/0 indicator columns for age above the young limit and for female sex
age_is_old = adata.obs['Age'] > young_age_limit
sex_is_female = adata.obs['Sex'] == 'Female'
adata.obs['old'] = np.where(age_is_old, 1, 0)
adata.obs['female'] = np.where(sex_is_female, 1, 0)
# number of distinct samples per cell-type within each age group
samples_per_group = adata.obs.groupby(['Cell_type','old'])['Sample_ID'].nunique()
display(samples_per_group)

In [None]:
# UMAP of all cells labelled by cell-type, drawn with a larger figure size
umap_rc = {'figure.figsize': (12, 12)}
with rc_context(umap_rc):
    sc.pl.umap(adata, color=['Cell_type'], add_outline=True,
               legend_loc='on data', legend_fontsize=10)

##### find cell-types not used in analysis
remove them, and then refilter genes based on cell count

In [None]:
# keep only the cells that have a cell-type assigned
has_cell_type = adata.obs.Cell_type.notna()
adata = adata[has_cell_type, :]
# refilter genes on the remaining cells using the minimum cell count
sc.pp.filter_genes(adata, min_cells=min_cell_count)
print(adata)

### read the diff by age results by region and cell-type

In [None]:
%%time
glmmtmb_results = None
this_type = 'region_broad_celltype'
for region in adata.obs.Brain_region.unique():
    for broad_type in adata.obs.Cell_type.unique():
        this_tissue = f'{region.capitalize()} {broad_type}'
        print(this_tissue)
        renamed_features = read_feature_renamed_map(this_tissue)
        glmmtmb_results = pd.concat([glmmtmb_results, 
                                     read_glmmtmb_results(this_tissue, this_type,
                                                          renamed_features)])

In [None]:
# when debugging, show the size of the combined results and five random rows
if DEBUG:
    print(glmmtmb_results.shape)
    display(glmmtmb_results.sample(5))

#### count of glmmTMB results by type

In [None]:
# number of result rows per tissue within each result type
display(glmmtmb_results.groupby('type')['tissue'].value_counts())

#### compute the FDR values

In [None]:
# missing p-values are set to 1 before the correction is computed
filled_pvalues = glmmtmb_results['p.value'].fillna(1)
glmmtmb_results['p.value'] = filled_pvalues
# adds the 'fdr_bh' column computed over all rows together
glmmtmb_results = compute_bh_fdr(glmmtmb_results)
print(glmmtmb_results.shape)
if DEBUG:
    top_by_fdr = glmmtmb_results.sort_values('fdr_bh').head()
    display(top_by_fdr)

#### count of significant genes by type for glmmTMB

In [None]:
# per-tissue count of results with an adjusted p-value below 0.05
display(glmmtmb_results.loc[glmmtmb_results['fdr_bh'] < 0.05].groupby('type')['tissue'].value_counts())

In [None]:
# per-tissue count of results with an unadjusted p-value below 0.05, for comparison
display(glmmtmb_results.loc[glmmtmb_results['p.value'] < 5e-02].groupby('type')['tissue'].value_counts())

### save the full results

In [None]:
# write every result row, including the fdr_bh column, without the index
glmmtmb_results.to_csv(results_file, index=False)

### save the statistically significant results

In [None]:
# write only the rows with an adjusted p-value below 0.05, without the index
glmmtmb_results.loc[glmmtmb_results['fdr_bh'] < 0.05].to_csv(results_fdr_file, index=False)

In [None]:
# when debugging, preview the first rows that pass the 0.05 adjusted p-value cut
if DEBUG:
    display(glmmtmb_results.loc[glmmtmb_results['fdr_bh'] < 0.05].head())

### visualize volcano plots

In [None]:
# one volcano plot over every result, then one per tissue
print('---- all glmmTMB results ----')
volcano_plot(glmmtmb_results)

print('---- per region broad cell-type glmmTMB results ----')
tissue_names = glmmtmb_results['tissue'].unique()
for tissue in tissue_names:
    print(f'*** {tissue} glmmTMB results ***')
    in_tissue = glmmtmb_results['tissue'] == tissue
    volcano_plot(glmmtmb_results.loc[in_tissue], title=tissue)

In [None]:
# when debugging, show ten random result rows
if DEBUG:
    display(glmmtmb_results.sample(10))

### look at some of the individual results

##### max significant by p-value

In [None]:
# rows sharing the smallest p-value; Series.min() is vectorized and skips NaN,
# whereas the builtin min() iterates in Python and depends on where NaNs sit
min_pvalue = glmmtmb_results['p.value'].min()
this_results = glmmtmb_results.loc[glmmtmb_results['p.value'] == min_pvalue]
# among ties on the p-value take the row with the largest estimate
this_hit = this_results.sort_values(by=['estimate'], ascending=False).iloc[0]
prep_plot_feature(adata, this_hit)
prep_plot_feature(adata, this_hit, filter_zeros=True)

##### max significant by estimate (increasing)

In [None]:
# results with an adjusted p-value below 0.05
sig_results = glmmtmb_results.loc[glmmtmb_results['fdr_bh'] < 0.05]
if sig_results.empty:
    # the builtin max() raised an opaque ValueError on an empty selection
    print('no results pass the FDR threshold, nothing to plot')
else:
    # Series.max() is vectorized and skips NaN, unlike the builtin max()
    max_estimate = sig_results['estimate'].max()
    this_results = sig_results.loc[sig_results['estimate'] == max_estimate]
    this_hit = this_results.sort_values(by=['estimate'], ascending=False).iloc[0]
    prep_plot_feature(adata, this_hit)
    prep_plot_feature(adata, this_hit, filter_zeros=True)

##### random results

In [None]:
# random
# draw one row at random from the significant results and plot it with and
# without the zero-valued observations
this_hit = sig_results.sample().iloc[0]
prep_plot_feature(adata, this_hit)
prep_plot_feature(adata, this_hit, filter_zeros=True)

##### max non-significant by coef (increasing)

In [None]:
# non-significant results that still have a test statistic; '>=' makes this the
# exact complement of the '< 0.05' filter used for the significant results
nonsig_results = glmmtmb_results.loc[(glmmtmb_results['fdr_bh'] >= 0.05) & 
                                     (~glmmtmb_results['statistic'].isna())]
# Series.max() is vectorized and skips NaN, unlike the builtin max()
max_estimate = nonsig_results['estimate'].max()
this_results = nonsig_results.loc[nonsig_results['estimate'] == max_estimate]
this_hit = this_results.sort_values(by=['estimate'], ascending=True).iloc[0]
prep_plot_feature(adata, this_hit)
prep_plot_feature(adata, this_hit, filter_zeros=True)

In [None]:
# print the current date/time at the end of the notebook run
!date