## Notebook to run correlation between age and feature quantification (expression and accessibility) in single-cell data using a GLM and pseudo-bulk quantifications per sample

Based on some of the observations related to pseudoreplication and zero-inflation from:

[Zimmerman KD, Espeland MA, Langefeld CD. A practical solution to pseudoreplication bias in single-cell studies. Nat Commun 2021;12:738.](https://pubmed.ncbi.nlm.nih.gov/33531494/)

[Murphy AE, Skene NG. A balanced measure shows superior performance of pseudobulk methods in single-cell RNA-sequencing analysis. Nat Commun 2022;13:7851.](https://pubmed.ncbi.nlm.nih.gov/36550119/)

In [None]:
# print the current date and time via the shell
!date

#### import libraries

In [None]:
from anndata import AnnData
import numpy as np
from pandas import DataFrame as PandasDF, concat, read_csv
from polars import read_parquet, DataFrame as PolarsDF, col as pl_col
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import statsmodels.api as sm
import statsmodels.formula.api as smf
from numba import jit
from sklearn.preprocessing import MinMaxScaler
from multiprocessing import Process

import warnings
warnings.simplefilter('ignore')

import random
random.seed(420)

#### set notebook variables

In [None]:
# parameters
# adata.obs column used to group cells; each unique value is analyzed separately
category = 'curated_type' # 'curated_type' for broad and 'cluster_name' for specific
# which quantified feature matrix to analyze; selects the parquet file that is loaded
modality = 'GEX' # 'GEX' or 'ATAC'

In [None]:
# naming
project = 'aging_phase2'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'
# point scanpy's figure output at the project figures directory
sc.settings.figdir = f'{figures_dir}/'

# in files
anndata_file = f'{quants_dir}/{project}.multivi.curated.h5ad'
# the quantified feature matrix depends on the modality being analyzed
if modality == 'GEX':
    quants_file = f'{quants_dir}/{project}.multivi_norm_exp.parquet'
elif modality == 'ATAC':
    quants_file = f'{quants_dir}/{project}.multivi_peak_est.parquet'
else:
    # fail here with a clear message rather than a later NameError on quants_file
    raise ValueError(f"modality must be 'GEX' or 'ATAC', got {modality!r}")

# out files

# variables
# DEBUG gates the extra display() calls used to eyeball loaded data
DEBUG = True
# TESTING restricts the analysis to a random subset of TEST_FEATURE_SIZE features
TESTING = True
TEST_FEATURE_SIZE = 1000

#### analysis functions

In [None]:
def subset_data(adata: AnnData, quants: PolarsDF, cell_name: str, obs_col: str, 
                reapply_filter: bool=True, min_cell_count: int=3, 
                verbose: bool=False) -> tuple[PandasDF, PolarsDF]:
    """Subset the annotations and the quantified features to a single cell type.

    Args:
        adata: annotated data whose ``obs`` holds ``obs_col`` and whose ``var``
            index names the features.
        quants: polars frame of quantified features with a 'barcode' column
            identifying each cell.
        cell_name: value of ``obs_col`` selecting the cells to keep.
        obs_col: name of the ``adata.obs`` column to match against.
        reapply_filter: when True, re-run scanpy's gene and cell filters on the
            subset before selecting features.
        min_cell_count: passed as ``min_counts`` to both scanpy filters.
            NOTE(review): the name suggests a cell count but it is used as a
            count threshold - confirm ``min_counts`` is what is intended.
        verbose: when True, print the shapes of the returned objects.

    Returns:
        Tuple of (obs annotations of the subset, matching rows/columns of
        ``quants``).
    """
    sub_adata = adata[adata.obs[obs_col] == cell_name].copy()
    if reapply_filter:
        sc.pp.filter_genes(sub_adata, min_counts=min_cell_count)
        sc.pp.filter_cells(sub_adata, min_counts=min_cell_count)
    # now that have subset of anndata, grab the same cells and features from
    # the separate data matrix; use the `quants` argument, not a notebook global
    keep_columns = list(sub_adata.var.index.values)
    keep_columns.append('barcode')
    keep_barcodes = sub_adata.obs.index.to_list()
    sub_quants = (quants.select(keep_columns)
                  .filter(pl_col('barcode').is_in(keep_barcodes)))
    if verbose:
        print(f'{cell_name}: obs {sub_adata.obs.shape}, quants {sub_quants.shape}')
    return sub_adata.obs, sub_quants

def scale_dataframe(this_df: PandasDF) -> PandasDF:
    """Min-max scale every column of a pandas frame.

    Uses scikit-learn's ``MinMaxScaler`` with its default settings and returns
    a new frame that keeps the input's column labels and index.

    Args:
        this_df: numeric pandas frame to scale.

    Returns:
        New pandas frame of scaled values with the same columns and index.
    """
    scaled_values = MinMaxScaler().fit_transform(this_df)
    # the scaler returns a bare array, so re-attach the labels
    return PandasDF(data=scaled_values, columns=this_df.columns,
                    index=this_df.index)

def convert_ad_to_df(data: AnnData, young_age_limit: float=30.0, 
                     scale: bool=True, verbose: bool=False) -> PandasDF:
    """Convert an AnnData layer plus sample annotations into one pandas frame.

    Args:
        data: annotated data with 'Sample_ID', 'Age' and 'Sex' in ``obs``.
        young_age_limit: cells with ``Age`` greater than this get ``old = 1``.
        scale: when True, min-max scale the feature values first.
        verbose: when True, print the shape and show the head of the result.

    Returns:
        Frame of features plus 'Sample_ID', 'Age', 'Sex', 'old' and 'female'
        columns, indexed by 'barcodekey'; ``None`` when the feature frame and
        the annotations do not share an identical index.
    """
    # NOTE(review): SCVI_NORMALIZED_KEY is not defined anywhere in this
    # notebook; it must be provided before this function can run
    data_df = data.to_df(SCVI_NORMALIZED_KEY)
    if scale:
        data_df = scale_dataframe(data_df)
    annots = data.obs[['Sample_ID', 'Age','Sex']].copy()
    # binary covariates used by the model formula
    annots['old'] = np.where(annots['Age'] > young_age_limit, 1, 0)
    annots['female'] = np.where(annots['Sex'] == 'female', 1, 0)
    # only combine when rows line up exactly; otherwise signal with None
    if not data_df.index.equals(annots.index):
        return None
    this_df = concat([data_df, annots], axis='columns')
    this_df.index.name = 'barcodekey'
    if verbose:
        print(f'anndata to pandas df complete: {this_df.shape}')
        display(this_df.head())
    return this_df

def feature_detected(feature_col, features: list=None, df: PandasDF=None, 
                     min_cell_count: int=3, min_sample_det_rate: float=0.5,
                     verbose: bool=False):
    """Decide whether a feature is detected in enough samples.

    A sample counts as detecting the feature when strictly more than
    ``min_cell_count`` of its cells have a value above zero. The feature is
    good when the fraction of such samples, out of all samples in ``df``, is at
    least ``min_sample_det_rate``.

    Args:
        feature_col: one column (pandas Series) of ``df``.
        features: names of the columns that are features; any other column
            (or ``features=None``) is reported as good without being checked.
        df: frame that ``feature_col`` came from; must hold 'Sample_ID'.
        min_cell_count: per-sample non-zero cell count that must be exceeded.
        min_sample_det_rate: minimum fraction of samples that must pass.
        verbose: when True, print the counts behind the decision.

    Returns:
        True when the feature passes (or is not a listed feature), else False.
    """
    # annotation columns and the no-feature-list case are never flagged
    if features is None or feature_col.name not in features:
        return True
    nz_df = feature_col[feature_col > 0]
    cells_per_sample = df.loc[nz_df.index].Sample_ID.value_counts()
    ok_sample_cnt = (cells_per_sample > min_cell_count).sum()
    unique_sample_id_count = df.Sample_ID.nunique()
    # plain bool so callers get one consistent type from every branch
    good_feature = bool(ok_sample_cnt / unique_sample_id_count >= min_sample_det_rate)
    if verbose:
        print(feature_col.name, end=', ')
        print(f'nz_df.shape = {nz_df.shape}', end=', ')
        print(f'{ok_sample_cnt}/{unique_sample_id_count}', end=', ')
        print(good_feature)
    return good_feature

def poorly_detected_features(features: list=None, df: PandasDF=None, 
                             verbose=False) -> list:
    """List the features that fail the per-sample detection check.

    Args:
        features: names of the columns in ``df`` that are features.
        df: cells x (features + annotations) frame holding 'Sample_ID'.
        verbose: when True, print how many features were flagged.

    Returns:
        Names of the feature columns for which ``feature_detected`` is False,
        in ``df`` column order.
    """
    # build the lookup once; membership is tested for every column
    feature_set = set(features) if features is not None else set()
    bad_features = [col for col in df.columns
                    if col in feature_set
                    and not feature_detected(df[col], features=feature_set, df=df)]
    if verbose:
        print(f'bad features counts is {len(bad_features)}')
    return bad_features

def glm_model(formula: str, df: PandasDF, use_tweedie: bool=True):
    """Fit a statsmodels GLM from a formula.

    Args:
        formula: patsy-style model formula.
        df: frame holding the columns named in ``formula``.
        use_tweedie: when True, use a Tweedie family with a log link,
            ``var_power=1.6`` and ``eql=True``; otherwise use the family that
            statsmodels applies by default.

    Returns:
        The fitted statsmodels results object.
    """
    if use_tweedie:
        # Log is the class name; the lowercase alias is deprecated
        family = sm.families.Tweedie(link=sm.families.links.Log(),
                                     var_power=1.6, eql=True)
        model = smf.glm(formula=formula, data=df, family=family)
    else:
        model = smf.glm(formula=formula, data=df)
    return model.fit()

def compute_frmt_pb(df: PandasDF, feature: str) -> PandasDF:
    """Build the per-sample pseudobulk frame for one feature.

    Args:
        df: cell-level frame holding ``feature``, 'Sample_ID', 'old' and
            'female' columns.
        feature: name of the feature column to aggregate.

    Returns:
        One row per sample with the mean of ``feature``, 'cell_count' (number
        of non-null ``feature`` values in the sample), 'Sample_ID', 'old' and
        'female'.
    """
    per_sample = df[[feature, 'Sample_ID']].groupby('Sample_ID')[feature]
    ret_df = per_sample.mean().to_frame()
    ret_df['cell_count'] = per_sample.count()
    # covariates are per-sample, so one row per distinct combination is enough
    sample_annots = df[['Sample_ID', 'old', 'female']].drop_duplicates()
    ret_df = ret_df.merge(sample_annots, how='left', 
                          left_index=True, right_on='Sample_ID')
    return ret_df

def glm_age(df: PandasDF, feature: str, verbose: bool=False) -> list:
    """Fit the pseudobulk age GLM for one feature.

    The feature is averaged per sample and modeled as
    ``feature ~ old + female + cell_count``.

    Args:
        df: cell-level frame holding ``feature``, 'Sample_ID', 'old', 'female'.
        feature: name of the feature column to model.
        verbose: when True, print shapes, the model summary and the result.

    Returns:
        Seven-item list: feature, intercept, coefficient of 'old', its standard
        error, test statistic, p-value, and log2 fold change. If anything fails
        the six numeric entries are NaN, so both paths have the same width.
    """
    dep_term = feature
    indep_term = 'old'
    this_formula = f'Q("{dep_term}") ~ {indep_term} + female + cell_count'
    try:
        pb_df = compute_frmt_pb(df, feature)
        # run GLM via statsmodel
        result = glm_model(this_formula, pb_df)
        # NOTE(review): log2_fc is taken as log2(mean pseudobulk value in old
        # samples / mean in young samples) - confirm this is the intended
        # definition for the 'log2_fc' results column
        old_mean = pb_df.loc[pb_df[indep_term] == 1, feature].mean()
        young_mean = pb_df.loc[pb_df[indep_term] == 0, feature].mean()
        with np.errstate(divide='ignore', invalid='ignore'):
            # a zero or missing group mean yields inf/NaN instead of raising,
            # so a valid model fit is not thrown away
            log2_fc = np.log2(np.divide(old_mean, young_mean))
        ret_list = [dep_term, result.params['Intercept'], 
                    result.params[indep_term], result.bse[indep_term], 
                    result.tvalues[indep_term], result.pvalues[indep_term],
                    log2_fc]
        if verbose:
            print(f'df shape {df.shape}')
            print(f'pseudobulk df shape {pb_df.shape}')
            print(result.summary())
            print(['feature', 'intercept', 'coef', 'stderr', 'z', 'p-value', 
                   'log2_fc'])
            print(ret_list)
    except Exception:
        # best effort: one feature that cannot be fit must not stop the run
        ret_list = [dep_term] + [np.nan] * 6

    return ret_list

def regress_age(obs_pdf: PandasDF, data_plf: PolarsDF, cell_name: str, cat_type: str,
                verbose: bool=False, young_age_limit: float=30.0, 
                scale: bool=True) -> PandasDF:
    """Run the pseudobulk age GLM for every well-detected feature of a cell type.

    Args:
        obs_pdf: per-cell annotations indexed by barcode, holding 'Sample_ID',
            'Age' and 'Sex'.
        data_plf: polars frame of quantified features with a 'barcode' column.
        cell_name: cell-type label; stored in the results and used in the
            output file name.
        cat_type: category label stored in the 'type' column of the results.
        verbose: when True, print progress messages.
        young_age_limit: cells with ``Age`` greater than this get ``old = 1``.
        scale: when True, min-max scale the feature values first.

    Returns:
        Frame with one row per tested feature; also written via
        ``save_results``.
    """
    result_columns = ['feature', 'intercept', 'coef', 'stderr', 'z', 
                      'p-value', 'log2_fc']
    if verbose:
        print('converting polars quants to pandas df')
    # cells x features, indexed by barcode so it can be joined to the annotations
    data_df = data_plf.to_pandas().set_index('barcode')
    features = list(data_df.columns)
    if scale:
        data_df = scale_dataframe(data_df)
    annots = obs_pdf[['Sample_ID', 'Age', 'Sex']].copy()
    annots['old'] = np.where(annots['Age'] > young_age_limit, 1, 0)
    annots['female'] = np.where(annots['Sex'] == 'female', 1, 0)
    # keep only cells present in both the quantifications and the annotations
    type_df = data_df.join(annots, how='inner')
    if verbose:
        print(f'finding poorly detected features from cells x features {type_df.shape}')
    bad_features = set(poorly_detected_features(features, type_df))
    keep_features = [feature for feature in features if feature not in bad_features]
    type_results = [glm_age(type_df, feature) for feature in keep_features]
    # glm_age rows may be shorter than the column list; pad with NaN so the
    # frame can always be constructed
    type_results = [list(row) + [np.nan] * (len(result_columns) - len(row))
                    for row in type_results]
    results_df = PandasDF(data=type_results, columns=result_columns)
    results_df['tissue'] = cell_name
    results_df['type'] = cat_type
    save_results(results_df, cell_name)
    if verbose:
        print(f'done', end='. ')
    return results_df

def regression_wrapper(data: AnnData, cell_name: str, quants: PolarsDF=None, 
                       obs_col: str=None):
    """Subset one cell type and run its age regression.

    Intended as a single-call entry point, e.g. a ``Process`` target.

    Args:
        data: annotated data containing the cells of ``cell_name``.
        cell_name: cell-type label to analyze.
        quants: polars frame of quantified features; defaults to the
            notebook-level ``quants_df``.
        obs_col: ``data.obs`` column holding the cell-type labels; defaults to
            the notebook-level ``category``.
    """
    # fall back to the notebook globals so the two-argument call still works
    quants = quants_df if quants is None else quants
    obs_col = category if obs_col is None else obs_col
    obs_sub, quants_sub = subset_data(data, quants, cell_name, obs_col)
    regress_age(obs_sub, quants_sub, cell_name, obs_col)
    
def save_results(df: PandasDF, cell_name: str):
    """Write a results frame to a CSV named after the cell type.

    Spaces in ``cell_name`` become underscores in the file name.

    Args:
        df: results frame to write (the index is not written).
        cell_name: cell-type label used to build the file name.
    """
    # results_dir is the output directory defined in the notebook settings
    # NOTE(review): the file name carries no modality, so GEX and ATAC runs
    # for the same cell type write to the same path - confirm this is intended
    out_file = f'{results_dir}/{cell_name.replace(" ", "_")}_glm_pb_age_diffs.csv'
    df.to_csv(out_file, index=False)


### load data
read the anndata (h5ad) file

In [None]:
%%time
# read the curated AnnData (h5ad) file and print its summary
adata = sc.read(anndata_file, cache=True)
print(adata)
if DEBUG:
    # show a few randomly chosen cell annotations as a sanity check
    display(adata.obs.sample(5))

#### take a look at the cell counts by cell type

In [None]:
# number of cells in each group of the selected obs category
display(adata.obs[category].value_counts())

In [None]:
# UMAP colored by the selected category, with the group labels drawn on the data
with rc_context({'figure.figsize': (9, 9)}):
    sc.pl.umap(adata, color=[category], legend_loc='on data', 
               add_outline=True, legend_fontsize=10)

### load the quantified features
either expression (GEX) or chromatin accessibility (ATAC), based on the modality that was specified

In [None]:
%%time
# load the quantified feature matrix as a polars dataframe
quants_df = read_parquet(quants_file)
# for a polars dataframe read from parquet, the saved index column needs to be
# renamed so the cell identifiers are available as 'barcode'
quants_df = quants_df.rename({'__index_level_0__': 'barcode'})
print(f'shape of quantified {modality} features: {quants_df.shape}')
if DEBUG:
    # show a few randomly chosen rows as a sanity check
    display(quants_df.sample(5))

### if testing the notebook for debugging purposes, subset the features

In [None]:
if TESTING:
    # adata.var.modality label that corresponds to each supported modality
    modality_labels = {'GEX': 'Gene Expression', 'ATAC': 'Peaks'}
    modality_features = list(
        adata.var.loc[adata.var.modality == modality_labels[modality]].index.values)
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available features
    features = random.sample(modality_features, 
                             min(TEST_FEATURE_SIZE, len(modality_features)))
    adata = adata[:,features]
    # need to keep the barcode as well
    features.append('barcode')
    quants_df = quants_df.select(features)
    print(adata)
    display(adata.var.modality.value_counts())
    print(f'shape of testing quants dataframe is: {quants_df.shape}')
    display(quants_df.sample(5))

### for each cell-type compute the differential expression 
using pseudobulk and GLM

cell types are currently processed serially in the loop below; per-cell-type parallelization is not active

In [None]:
%%time

cmds = {}
for cell_type in adata.obs[category].unique():
    print(cell_type)
    obs_sub, quants_sub = subset_data(adata, quants_df, cell_type, category)
    regress_age(obs_sub, quants_sub, cell_name, category)
#     p = Process(target=diffexp_group_wrapper,args=(adata_sub, cell_name))
#     p.start()
#     # Append process and key to keep track
#     cmds[cell_type] = p    
#     # diffexp_group(adata_sub, cell_name)
# # Wait for all processes to finish
# for key, p in cmds.items():
#     p.join()    

In [None]:
import sys
# sys.getsizeof is shallow and does not follow a polars frame into its data
# buffers, so ask polars for its own size estimate (bytes)
print(quants_df.estimated_size())
quants_pdf = quants_df.to_pandas()
# NOTE(review): pandas objects report their memory usage through getsizeof;
# confirm for the installed pandas version
print(sys.getsizeof(quants_pdf))

In [None]:
# print the current date and time via the shell
!date