## Notebook to run differential expression analysis on single-cell data using a GLM and per-sample pseudobulk quantifications

Based on observations about pseudoreplication and zero-inflation from

[Zimmerman KD, Espeland MA, Langefeld CD. A practical solution to pseudoreplication bias in single-cell studies. Nat Commun 2021;12:738.](https://pubmed.ncbi.nlm.nih.gov/33531494/)


In [None]:
# record when the run started (IPython shell escape to `date`)
!date

#### import libraries

In [None]:
from anndata import AnnData
import numpy as np
from pandas import DataFrame, concat, read_csv
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import statsmodels.api as sm
import statsmodels.formula.api as smf
from numba import jit
from sklearn.preprocessing import MinMaxScaler
from multiprocessing import Process

import warnings
# silence every Python warning for the rest of the session; this also hides
# warnings raised by pandas, scanpy and statsmodels below
warnings.simplefilter('ignore')

import random
# seeds only Python's `random` module (used to sample test features when
# TESTING is on); numpy's global RNG is not seeded here
random.seed(420)

#### set notebook variables

In [None]:
# --- dataset naming ---
project = 'aging_phase1'
set_name = f'{project}_replication'

# --- locations ---
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase1'
replication_dir = f'{wrk_dir}/replication'

# --- inputs ---
anndata_file = f'{replication_dir}/{set_name}.scvi.h5ad'

# --- run switches ---
DEBUG = True
TESTING = False             # when True, analyze only a random feature subset
TEST_FEATURE_SIZE = 1000    # size of that subset

# --- AnnData keys ---
SCVI_NORMALIZED_KEY = 'scvi_normalized'   # layer read by convert_ad_to_df

#### analysis functions

In [None]:
def subset_anndata(data: AnnData, cell_name: str, reapply_filter: bool=True, 
                   min_cell_count: int=3, verbose: bool=False) -> AnnData:
    """Return a standalone copy of ``data`` holding only cells of one type.

    Cells are selected where ``data.obs.Cell_type == cell_name``. When
    ``reapply_filter`` is True, genes and then cells of the subset are
    re-filtered with ``sc.pp.filter_genes`` / ``sc.pp.filter_cells`` using
    ``min_counts=min_cell_count`` (scanpy evaluates these on ``.X``).
    NOTE(review): despite its name, ``min_cell_count`` is used as a minimum
    total-count threshold here, not a minimum number of cells; confirm intent.

    Args:
        data: full AnnData object with a ``Cell_type`` obs column.
        cell_name: value of ``Cell_type`` to keep.
        reapply_filter: re-run the gene/cell count filters on the subset.
        min_cell_count: ``min_counts`` passed to both scanpy filters.
        verbose: print the shape before/after filtering and the subset.

    Returns:
        A copied (non-view) AnnData for the requested cell type.
    """
    this_data = data[(data.obs.Cell_type == cell_name)].copy()
    shape_before = this_data.shape
    if reapply_filter:
        sc.pp.filter_genes(this_data, min_counts=min_cell_count)
        sc.pp.filter_cells(this_data, min_counts=min_cell_count)
    # Measured outside the `if` so it is always bound: verbose output needs
    # it even when filtering is skipped.
    shape_after = this_data.shape
    if verbose:
        print(f'subset complete, shape before and after: {shape_before} {shape_after}')
        print(this_data)
    return this_data

def scale_dataframe(this_df: DataFrame):
    """Min-max scale each column of ``this_df`` independently.

    Uses scikit-learn's ``MinMaxScaler`` with its default range, then rebuilds
    a DataFrame carrying the original row index and column labels.
    """
    scaler = MinMaxScaler()
    return DataFrame(
        scaler.fit_transform(this_df),
        index=this_df.index,
        columns=this_df.columns,
    )

def convert_ad_to_df(data: AnnData, young_age_limit: float=30.0, 
                     scale: bool=True, verbose: bool=False) -> DataFrame:
    """Build a cells x (features + annotations) DataFrame from an AnnData object.

    Expression values come from the ``SCVI_NORMALIZED_KEY`` layer and are
    optionally min-max scaled per feature (``scale_dataframe``). They are
    joined column-wise with the ``Sample_ID``, ``Age`` and ``Sex`` obs columns
    plus two 0/1 indicators: ``old`` (``Age > young_age_limit``) and
    ``female`` (``Sex == 'female'``). The index is named ``barcodekey``.

    Args:
        data: AnnData with the normalized layer and the obs columns above.
        young_age_limit: ages strictly above this value are flagged ``old``.
        scale: min-max scale the expression columns.
        verbose: print the shape and show the head (``display`` is the
            IPython builtin, so this only works inside a notebook/IPython).

    Returns:
        The combined DataFrame.

    Raises:
        ValueError: if the expression and annotation row indices differ.
            Returning ``None`` here would only fail later with an unrelated
            error in the caller.
    """
    data_df = data.to_df(SCVI_NORMALIZED_KEY)
    if scale:
        data_df = scale_dataframe(data_df)
    annots = data.obs[['Sample_ID', 'Age', 'Sex']].copy()
    annots['old'] = np.where((annots['Age'] > young_age_limit), 1, 0)
    annots['female'] = np.where((annots['Sex'] == 'female'), 1, 0)
    if not data_df.index.equals(annots.index):
        raise ValueError('expression and obs indices are not aligned; '
                         'cannot combine features with annotations')
    this_df = concat([data_df, annots], axis='columns')
    this_df.index.name = 'barcodekey'
    if verbose:
        print(f'anndata to pandas df complete: {this_df.shape}')
        display(this_df.head())
    return this_df

def feature_detected(feature_col, features: list=None, df: DataFrame=None, 
                     min_cell_count: int=3, min_sample_det_rate: float=0.5,
                     verbose: bool=False):
    """Decide whether one column is detected in enough samples.

    Columns whose name is not in ``features`` (e.g. annotation columns) always
    pass. For a feature column, a sample counts as detecting it when strictly
    more than ``min_cell_count`` of its cells have a value > 0. The feature
    passes when the fraction of such samples among all distinct
    ``df.Sample_ID`` values is at least ``min_sample_det_rate``.

    ``features`` must be a container (``None`` raises TypeError on the
    membership test). ``df`` must share ``feature_col``'s index and carry a
    ``Sample_ID`` column.
    """
    if feature_col.name not in features:
        return True
    expressing = feature_col[feature_col > 0]
    cells_per_sample = df.loc[expressing.index, 'Sample_ID'].value_counts()
    passing_samples = (cells_per_sample > min_cell_count).sum()
    total_samples = df['Sample_ID'].nunique()
    good_feature = passing_samples / total_samples >= min_sample_det_rate
    if verbose:
        print(feature_col.name, end=', ')
        print(f'nz_df.shape = {expressing.shape}', end=', ')
        print(f'{passing_samples}/{total_samples}', end=', ')
        print(good_feature)
    return good_feature

def poorly_detected_features(features: list=None, df: DataFrame=None, 
                             verbose=False, min_cell_count: int=3,
                             min_sample_det_rate: float=0.5) -> list:
    """Return the feature columns of ``df`` that fail the per-sample detection filter.

    A feature passes when, in at least ``min_sample_det_rate`` of the samples
    (distinct ``Sample_ID`` values), strictly more than ``min_cell_count``
    cells have a value > 0. This is the same rule as ``feature_detected``
    with its default thresholds. Here it is evaluated for all features with a
    single groupby, instead of one ``df.loc`` lookup per column. ``features``
    is converted to a set once, so membership is O(1) instead of scanning an
    array for every column.

    Args:
        features: names of feature columns to check. Names absent from ``df``
            are ignored; ``None`` means nothing is checked.
        df: cells x columns frame with the feature columns and ``Sample_ID``.
            Other columns (annotations) are never reported.
        verbose: print how many features failed.
        min_cell_count: a sample detects a feature only with strictly more
            than this many cells with value > 0.
        min_sample_det_rate: minimum fraction of samples that must detect it.

    Returns:
        Failing feature names, in ``df`` column order.
    """
    wanted = set() if features is None else set(features)
    feature_cols = [col for col in df.columns if col in wanted]
    bad_features = []
    if feature_cols:
        n_samples = df['Sample_ID'].nunique()
        # rows: samples, columns: features, values: cells with value > 0
        detected_cells = (df[feature_cols] > 0).groupby(
            df['Sample_ID'], observed=True).sum()
        ok_samples = (detected_cells > min_cell_count).sum(axis=0)
        # with no samples at all the ratio is NaN, which compares False,
        # so every feature is reported as poorly detected
        is_good = ok_samples / n_samples >= min_sample_det_rate
        bad_features = is_good[~is_good].index.to_list()
    if verbose:
        print(f'bad features counts is {len(bad_features)}')
    return bad_features

def glm_model(formula: str, df: DataFrame, verbose: bool=False, 
              use_tweedie: bool=True, var_power: float=1.6):
    """Fit a statsmodels formula GLM and return the fitted results.

    Args:
        formula: patsy formula, e.g. ``'Q("GENE") ~ old + female'``.
        df: data holding every column named in ``formula``.
        verbose: print the fit summary.
        use_tweedie: use a Tweedie family (family default link, ``eql=True``).
            Otherwise statsmodels' default GLM family (Gaussian) is used.
        var_power: Tweedie variance power; ignored when ``use_tweedie`` is
            False. The default keeps the previous hard-coded value.

    Returns:
        The statsmodels GLM results object from ``model.fit()``.
    """
    if use_tweedie:
        tweedie = sm.families.Tweedie(link=None, var_power=var_power, eql=True)
        model = smf.glm(formula=formula, data=df, family=tweedie)
    else:
        model = smf.glm(formula=formula, data=df)
    result = model.fit()
    if verbose:
        print(result.summary())
    return result

def compute_fold_change(intercept: float, coef: float) -> float:
    """Return log2((intercept + coef) / intercept).

    Intended for a linear fit where ``intercept`` is the baseline prediction
    and ``intercept + coef`` is the prediction with the indicator term
    switched on. Both former sign branches reduced to this same expression.

    Plain Python instead of numba: the function does two scalar operations,
    and compiling it buys nothing. Degenerate inputs now return a value
    instead of raising, so the caller keeps the rest of the fit statistics:

    * ``intercept == 0`` -> NaN (ratio undefined)
    * ratio < 0 -> NaN (log of a negative number)
    * ratio == 0 -> -inf

    >>> compute_fold_change(1.0, 1.0)
    1.0
    """
    if intercept == 0:
        return np.nan
    ratio = (intercept + coef) / intercept
    if ratio < 0:
        return np.nan
    if ratio == 0:
        return -np.inf
    return float(np.log2(ratio))

def compute_frmt_pb(df: DataFrame, feature: str) -> DataFrame:
    """Pseudobulk one feature: its mean per ``Sample_ID`` plus that sample's covariates.

    The per-sample means (indexed by ``Sample_ID``) are left-joined to the
    de-duplicated ``Sample_ID``/``old``/``female`` rows of ``df``.
    """
    sample_means = df[[feature, 'Sample_ID']].groupby('Sample_ID').mean()
    sample_covariates = df[['Sample_ID', 'old', 'female']].drop_duplicates()
    pseudobulk = sample_means.merge(
        sample_covariates,
        how='left',
        left_index=True,
        right_on='Sample_ID',
    )
    return pseudobulk

def glm_diff_expr_age(df: DataFrame, feature: str, verbose: bool=False) -> list:
    """Test one feature for an age (``old``) effect with a GLM on pseudobulk data.

    The cells in ``df`` are averaged per ``Sample_ID`` (``compute_frmt_pb``).
    The model ``feature ~ old + female`` is then fit with
    ``glm_model(..., use_tweedie=False)``. The reported statistics are those
    of the ``old`` term.

    Args:
        df: cells x columns frame with the feature column plus ``Sample_ID``,
            ``old`` and ``female``.
        feature: column to test.
        verbose: print the fit summary and result row, and report failures.

    Returns:
        ``[feature, intercept, coef, stderr, z, p-value, log2_fc]``. If the
        pseudobulk, the fit, or reading its statistics fails, the six
        statistics are NaN, so one problematic feature does not abort the run.
    """
    dep_term = feature
    indep_term = 'old'
    # Q("...") makes patsy treat the name as a column even when it is not a
    # valid Python identifier (e.g. gene symbols containing '-' or '.').
    this_formula = f'Q("{dep_term}") ~ {indep_term} + female'
    try:
        pb_df = compute_frmt_pb(df, feature)
        result = glm_model(this_formula, pb_df, use_tweedie=False)
        fold_change = compute_fold_change(result.params['Intercept'], 
                                          result.params[indep_term])
        ret_list = [dep_term, result.params['Intercept'], 
                    result.params[indep_term], result.bse[indep_term], 
                    result.tvalues[indep_term], result.pvalues[indep_term], 
                    fold_change]
    except Exception as err:
        # Exception, not a bare except: KeyboardInterrupt/SystemExit must
        # still stop the run instead of becoming a NaN row.
        if verbose:
            print(f'Caught Error for {dep_term}: {err!r}')
        return [dep_term] + [np.nan] * 6
    if verbose:
        print(f'df shape {df.shape}')
        print(f'pseudobulk df shape {pb_df.shape}')
        print(result.summary())
        print(['feature', 'intercept', 'coef', 'stderr', 'z', 'p-value', 'log2_fc'])
        print(ret_list)
    return ret_list

def diff_exp_of_features(df: DataFrame, features: set) -> list:
    """Run ``glm_diff_expr_age`` on each feature, one result row per feature in iteration order."""
    return [glm_diff_expr_age(df, name) for name in features]

def diffexp_group(data: AnnData, cell_name: str, min_cell_count: int=3, 
                  verbose: bool=False) -> DataFrame:
    """Run the pseudobulk GLM age test for every well-detected feature of one group.

    Steps:
    1. Convert ``data`` to a cells x (features + annotations) frame.
    2. Drop features that fail ``poorly_detected_features``.
    3. Fit ``glm_diff_expr_age`` for each remaining feature.
    4. Label the rows with ``cell_name`` and write them out via
       ``save_results``.

    Args:
        data: AnnData for one cell type (see ``subset_anndata``).
        cell_name: label stored in the ``tissue`` column and used for the
            output file name.
        min_cell_count: currently unused; kept for interface compatibility.
        verbose: print progress messages.

    Returns:
        The results DataFrame that was saved. Rows are sorted by feature
        name, so the output file is reproducible across runs.

    Raises:
        ValueError: if ``convert_ad_to_df`` could not align expression and
            annotations (it returns ``None`` in that case).
    """
    if verbose:
        print('converting anndata to pandas df')        
    type_df = convert_ad_to_df(data)
    if type_df is None:
        raise ValueError(f'{cell_name}: expression and obs indices are not aligned')
    if verbose:
        print(f'finding poorly detected features from cells x features {type_df.shape}')    
    bad_features = poorly_detected_features(data.var.index.values, type_df)
    type_clean_df = type_df.drop(columns=bad_features)
    # Features that are both in the AnnData and still present after dropping
    # poorly detected ones; sorted for a deterministic row order.
    features = sorted(set(data.var.index) & set(type_clean_df.columns))
    type_results = diff_exp_of_features(type_clean_df, features)
    results_df = DataFrame(data=type_results, 
                           columns=['feature', 'intercept', 'coef', 
                                    'stderr', 'z', 'p-value', 'log2_fc'])
    results_df['tissue'] = cell_name
    results_df['type'] = 'region_broad_celltype'
    save_results(results_df, cell_name)
    if verbose:
        print('done', end='. ')
    return results_df

def diffexp_group_wrapper(data: AnnData, cell_name: str):
    """Process target: call ``diffexp_group`` with its default options; returns None."""
    diffexp_group(data=data, cell_name=cell_name)
    
def save_results(df: DataFrame, cell_name: str):
    """Write ``df`` as CSV (no index) into ``replication_dir``.

    The file is named ``<cell_name>_glm_pb_age_diffs.csv``, with spaces in
    ``cell_name`` replaced by underscores.
    """
    file_stem = cell_name.replace(" ", "_")
    df.to_csv(f'{replication_dir}/{file_stem}_glm_pb_age_diffs.csv', index=False)


### load replication data
Read the AnnData (h5ad) file.

In [None]:
%%time
# Load the replication AnnData file configured in `anndata_file`.
# NOTE(review): cache=True asks scanpy to use its fast h5ad cache; for an input
# that is already .h5ad this presumably has no effect. Confirm.
adata = sc.read(anndata_file, cache=True)
print(adata)
if DEBUG:
    # sanity check: show 5 randomly chosen rows of the cell annotations
    display(adata.obs.sample(5))

#### take a look at the cell counts by cell type

In [None]:
# number of cells per Cell_type label
display(adata.obs.Cell_type.value_counts())

In [None]:
# number of cells per Sex value (the GLM's `female` indicator matches 'female')
display(adata.obs.Sex.value_counts())

In [None]:
# UMAP colored by Cell_type with labels drawn on the embedding; rc_context
# enlarges the figure to 9x9 inches for this plot only
with rc_context({'figure.figsize': (9, 9)}):
    sc.pl.umap(adata, color=['Cell_type'], legend_loc='on data', 
               add_outline=True, legend_fontsize=10)

### If testing the notebook for debugging purposes, subset the features

In [None]:
if TESTING:
    # Debug mode: keep a random subset of at most TEST_FEATURE_SIZE features.
    # The cap prevents random.sample from raising when the data holds fewer
    # features than requested.
    n_keep = min(TEST_FEATURE_SIZE, adata.n_vars)
    features = random.sample(list(adata.var.index.values), n_keep)
    # .copy() turns the view into a standalone AnnData for the steps below
    adata = adata[:, features].copy()
    print(adata)

### For each cell type, compute differential expression using pseudobulk and a GLM

Parallelized by cell type (stored in the `tissue` column of the results): one process per cell type.

In [None]:
%%time

cmds = {}
for cell_type in adata.obs.Cell_type.unique():
    print(cell_type)
    adata_sub = subset_anndata(adata, cell_type)
    cell_name = f'Frontal_cortex_{cell_type}'
    # diffexp_group(adata_sub, cell_name)
    p = Process(target=diffexp_group_wrapper,args=(adata_sub, cell_name))
    p.start()
    # Append process and key to keep track
    cmds[cell_type] = p    
    # diffexp_group(adata_sub, cell_name)
# Wait for all processes to finish
for key, p in cmds.items():
    p.join()    

#### save the results (each worker process has already written its own per-cell-type CSV)

In [None]:
# Nothing to save here: each worker process already wrote its per-cell-type CSV
# via diffexp_group() -> save_results() into replication_dir.

In [None]:
# record when the run finished (IPython shell escape to `date`)
!date