## Notebook to test the association between age and feature quantification (expression and accessibility) in single-cell data using a GLM on pseudobulk quantifications per sample

Based on some of the observations related to pseudoreplication and zero-inflation from:

[Zimmerman KD, Espeland MA, Langefeld CD. A practical solution to pseudoreplication bias in single-cell studies. Nat Commun 2021;12:738.](https://pubmed.ncbi.nlm.nih.gov/33531494/)

[Murphy AE, Skene NG. A balanced measure shows superior performance of pseudobulk methods in single-cell RNA-sequencing analysis. Nat Commun 2022;13:7851.](https://pubmed.ncbi.nlm.nih.gov/36550119/)

In [None]:
# print the wall-clock time at the start of the run
!date

#### import libraries

In [None]:
from anndata import AnnData
import numpy as np
from pandas import DataFrame as PandasDF, read_parquet
import scanpy as sc
from matplotlib.pyplot import rc_context
import statsmodels.api as sm
import statsmodels.formula.api as smf
from multiprocessing import Process
from os.path import exists

import warnings
warnings.simplefilter('ignore')

import random
random.seed(420)

#### set notebook variables

In [None]:
# parameters
# obs column used to group cells: 'curated_type' for broad and 'cluster_name' for specific
category = 'cluster_name'
# modality of the quantifications to analyze: 'GEX' or 'ATAC'
modality = 'GEX'

In [None]:
# naming
project = 'aging_phase2'
# prefix_type is part of every quantification and result file name
if category == 'curated_type':
    prefix_type = 'broad'
elif category == 'cluster_name':
    prefix_type = 'specific'
else:
    # fail here instead of with a NameError on prefix_type in a later cell
    raise ValueError(f"category must be 'curated_type' or 'cluster_name', got '{category}'")

# directories
# NOTE(review): absolute path on a shared filesystem; edit wrk_dir to run elsewhere
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'
# point scanpy's figure directory at the project figures directory
sc.settings.figdir = f'{figures_dir}/'

# in files
anndata_file = f'{quants_dir}/{project}.multivi.curated_final.h5ad'

# out files
# (per cell-type result file names are built in save_results)

# variables
# DEBUG: when True, display intermediate tables and summaries
DEBUG = False
# TESTING: when True, run on a random subset of TEST_FEATURE_SIZE features
TESTING = False
TEST_FEATURE_SIZE = 1000
# covariates included in every model alongside age
covariate_terms = ['sex', 'ancestry', 'pmi', 'ph', 'smoker', 'bmi', 'pool']
# right-hand-side fragment of the model formula, e.g. 'sex + ancestry + ...'
covar_term_formula = ' + '.join(covariate_terms)
if DEBUG:
    print(covar_term_formula)

#### functions

In [None]:
def load_quantification(cell_name: str, verbose: bool=False) -> PandasDF:
    """Load the pseudobulk quantification table for one cell type.

    The table is read from
    `{quants_dir}/{project}.{modality}.{prefix_type}.{cell_name}.pb.parquet`.

    Args:
        cell_name: cell type whose quantifications should be loaded.
        verbose: if True, print the table shape and display a few random rows.

    Returns:
        The quantifications as a DataFrame, or None if the file does not exist.
    """
    # build the path from the argument, not from a loop variable in the caller's scope
    this_file = f'{quants_dir}/{project}.{modality}.{prefix_type}.{cell_name}.pb.parquet'
    if not exists(this_file):
        return None
    df = read_parquet(this_file)
    if verbose:
        print(f'shape of read {cell_name} quantifications {df.shape}')
        # cap the sample size so tables with fewer than 5 rows do not raise
        display(df.sample(min(5, df.shape[0])))
    return df

def glm_model(formula: str, df: PandasDF, use_tweedie: bool=True):
    """Fit a statsmodels GLM described by `formula` on `df`.

    Args:
        formula: model formula passed to `smf.glm`.
        df: data frame holding every column named in the formula.
        use_tweedie: if True (default) fit with a Tweedie family using a log
            link, var_power=1.6 and eql=True; if False no family is passed and
            statsmodels' default family is used.

    Returns:
        The fitted results object returned by `model.fit()`.
    """
    if not use_tweedie:
        return smf.glm(formula=formula, data=df).fit()

    tweedie_family = sm.families.Tweedie(link=sm.families.links.log(),
                                         var_power=1.6,
                                         eql=True)
    return smf.glm(formula=formula, data=df, family=tweedie_family).fit()

def glm_age(df: PandasDF, feature: str, verbose: bool=False) -> list:
    """Regress one feature on age, the covariate terms and cell count.

    Args:
        df: per-sample table containing the feature, 'age', every entry of
            `covariate_terms` and 'cell_count'.
        feature: name of the column to use as the dependent variable.
        verbose: if True, print the model summary and the returned values.

    Returns:
        A six-element list `[feature, intercept, coef, stderr, z, p-value]`
        where the last four values describe the age term. If the model cannot
        be fit, the five statistics are NaN.
    """
    endo_term = feature
    exog_term = 'age'
    model_terms = [endo_term, exog_term] + covariate_terms + ['cell_count']
    # Q() quotes the feature name so names that are not valid Python identifiers work in the formula
    this_formula = f'Q("{endo_term}") ~ {exog_term} + {covar_term_formula} + cell_count'
    try:
        # run GLM via statsmodel
        result = glm_model(this_formula, df[model_terms])
        ret_list = [endo_term, result.params['Intercept'],
                    result.params[exog_term], result.bse[exog_term],
                    result.tvalues[exog_term], result.pvalues[exog_term]]
        if verbose:
            print(f'df shape {df.shape}')
            print(result.summary())
            print(['feature', 'intercept', 'coef', 'stderr', 'z', 'p-value'])
            print(ret_list)
    except Exception:
        # keep a placeholder row for features that fail to fit; it must have the
        # same six fields as a successful row or building the results table fails
        ret_list = [endo_term] + [np.nan] * 5

    return ret_list

def regress_age(quants: PandasDF, covars: PandasDF, cell_name: str, 
                cat_type: str) -> PandasDF:
    """Run the age regression for every feature of one cell type and save the results.

    Args:
        quants: quantifications with one column per feature plus 'cell_count'.
        covars: covariates, joined to `quants` on the index.
        cell_name: cell type being analyzed; stored in the 'tissue' column and
            used in the output file name.
        cat_type: cell grouping category; stored in the 'type' column.

    Returns:
        The results table, one row per feature, which is also written to disk
        via `save_results`.

    Raises:
        KeyError: if `quants` has no 'cell_count' column.
    """
    # cell count covariate is in the quantified features; Index.drop raises
    # KeyError if it is missing and keeps the remaining features in column
    # order so the output rows are reproducible between runs
    features = quants.columns.drop('cell_count')
    data_df = quants.merge(covars, how='inner', left_index=True, right_index=True)
    type_results = [glm_age(data_df, feature) for feature in features]
    results_df = PandasDF(data=type_results,
                          columns=['feature', 'intercept', 'coef', 'stderr', 
                                    'z', 'p-value'])
    results_df['tissue'] = cell_name
    results_df['type'] = cat_type
    save_results(results_df, cell_name)
    return results_df
        
def save_results(df: PandasDF, cell_name: str):
    """Write the results table for one cell type to a CSV file without the index.

    The destination is built from the notebook's results directory, project,
    modality and prefix type together with `cell_name`.
    """
    df.to_csv(
        f'{results_dir}/{project}.{modality}.{prefix_type}.{cell_name}.glm_age.csv',
        index=False,
    )

def subset_for_test(df: PandasDF, feature_cnt: int) -> PandasDF:
    """Return a random subset of columns for quick test runs.

    Args:
        df: quantification table, features as columns plus 'cell_count'.
        feature_cnt: number of columns to sample.

    Returns:
        `df` itself when it has fewer than `feature_cnt` columns; otherwise
        `feature_cnt` randomly chosen columns, with 'cell_count' appended if
        it was not among those sampled.
    """
    # check the frame that was passed in, not a notebook-level variable
    if len(df.columns) < feature_cnt:
        return df

    features = random.sample(list(df.columns), feature_cnt)
    # the models need the cell count covariate, so always keep it
    if 'cell_count' not in features:
        features.append('cell_count')
    return df[features]

def check_detection(data_df: PandasDF, total_cnt: int, min_cell_cnt: int=3, min_sample_frac: float=0.3) -> tuple:
    """Check whether enough rows have a sufficient cell count.

    Args:
        data_df: table with a 'cell_count' column (the caller passes per-sample
            quantifications, one row per sample).
        total_cnt: total number of samples used as the denominator.
        min_cell_cnt: minimum cell count for a row to count as detected.
        min_sample_frac: fraction of `total_cnt` that must be detected.

    Returns:
        `(passed, detected_cnt)`: `passed` is True when the number of detected
        rows is at least `round(min_sample_frac * total_cnt)`, and
        `detected_cnt` is that number of rows.
    """
    ret_cnt = data_df.loc[data_df.cell_count >= min_cell_cnt].shape[0]
    ret_check = ret_cnt >= round(min_sample_frac * total_cnt, 0)
    return ret_check, ret_cnt


### load data
read the anndata (h5ad) file

In [None]:
%%time
# read the AnnData (h5ad) object; its obs table provides the cell types and sample covariates used below
adata = sc.read(anndata_file, cache=True)
print(adata)
if DEBUG:
    display(adata.obs.sample(5))

#### take a look at the cell counts by cell type

In [None]:
# number of cells in each group of the selected category column
display(adata.obs[category].value_counts())

In [None]:
# UMAP colored by the selected category, with group labels drawn on the embedding
with rc_context({'figure.figsize': (9, 9)}):
    sc.pl.umap(adata, color=[category], legend_loc='on data', 
               add_outline=True, legend_fontsize=10)

### format sample covariates

sex, ancestry, age, (gex_pool or atac_pool), pmi, ph, smoker, bmi

In [None]:
keep_terms = ['sample_id','sex', 'ancestry', 'age', 'gex_pool', 'atac_pool', 
              'pmi', 'ph', 'smoker', 'bmi']
# collapse the per-cell obs table to its distinct covariate rows, indexed by sample
covars_df = (
    adata.obs[keep_terms]
    .drop_duplicates()
    .set_index('sample_id')
)

if DEBUG:
    print(covars_df.shape)
    display(covars_df.head())
    display(covars_df.info())
    display(covars_df.smoker.value_counts())
    display(covars_df.bmi.describe())

#### fill any missing covariate terms
It looks like smoker and bmi are missing for one sample; the missing values are set to the mean of each term across the other samples.

In [None]:
# replace missing smoker and bmi values with the column mean rounded to one decimal
for fill_term in ('smoker', 'bmi'):
    is_missing = covars_df[fill_term].isna()
    covars_df.loc[is_missing, fill_term] = covars_df[fill_term].mean().round(1)

if DEBUG:
    print(covars_df.shape)
    display(covars_df.info())
    display(covars_df.smoker.value_counts())
    display(covars_df.bmi.describe())

#### set the pool term based on modality being analyzed

In [None]:
# use the pool column that matches the modality being analyzed as the 'pool' covariate
if modality == 'GEX':
    covars_df['pool'] = covars_df.gex_pool
elif modality == 'ATAC':
    covars_df['pool'] = covars_df.atac_pool
else:
    # without a 'pool' column every per-feature model would fail and be recorded as NaN
    raise ValueError(f"modality must be 'GEX' or 'ATAC', got '{modality}'")
# the modality-specific pool columns are no longer needed once 'pool' is set
covars_df = covars_df.drop(columns=['gex_pool', 'atac_pool'])
print(f'shape of covariate terms is {covars_df.shape}')
if DEBUG:
    display(covars_df.head(40))

### check if age, the exogenous variable, is correlated with any of the covariate terms

none of the terms appear to have a statistically significant correlation with age

In [None]:
# regress age on the covariate terms to see whether any of them is associated with age
this_formula = f'age ~ {covar_term_formula}'
print(this_formula)
# no family is passed here, so statsmodels' default GLM family is used
model = smf.glm(formula=this_formula, data=covars_df)
result = model.fit()
display(result.summary())

### for each cell-type compute the differential expression 
using pseudobulk and GLM

parallelized by cell type (one process per cell type)

In [None]:
%%time

cmds = {}
for cell_type in adata.obs[category].unique():
    print(cell_type)
    quants_df = load_quantification(cell_type)    
    # if empty skip
    if quants_df is None or quants_df.shape[1] < 2:
        print(f'nothing to do for {cell_type} skipping')
        continue
    # if not enough samples skip
    this_check, this_cnt = check_detection(quants_df, covars_df.shape[0])
    if not this_check:
        print(f'skipping {cell_type}, cells from only {this_cnt} samples')
        continue
    if TESTING:
        quants_df = subset_for_test(quants_df, TEST_FEATURE_SIZE)
        print(quants_df.shape)
        display(quants_df.sample(5))
    # regress_age(quants_df, covars_df, cell_type, category)
    p = Process(target=regress_age,args=(quants_df, covars_df, cell_type, category))
    p.start()
    # Append process and key to keep track
    cmds[cell_type] = p    
    # diffexp_group(adata_sub, cell_name)
# Wait for all processes to finish
for key, p in cmds.items():
    p.join()    

In [None]:
# print the wall-clock time at the end of the run
!date