## Notebook to convert cell data to pseudobulk based on broad cell-types ('curated_type') and cluster-specific cell-types ('cluster_name')

In [None]:
!date

#### import libraries

In [None]:
from anndata import AnnData
from pandas import DataFrame as PandasDF, concat
from polars import read_parquet, DataFrame as PolarsDF, col as pl_col
import scanpy as sc
from matplotlib.pyplot import rc_context
from sklearn.preprocessing import MinMaxScaler
from multiprocessing import Process

import warnings
warnings.simplefilter('ignore')

import random
random.seed(420)

#### set notebook variables

In [None]:
# parameters
# modality decides which quantification parquet file is loaded and which
# adata.var features are kept by the cells further down
modality = 'GEX' # 'GEX' or 'ATAC'

In [None]:
# naming
project = 'aging_phase2'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'
# scanpy writes any saved figures under this directory
sc.settings.figdir = f'{figures_dir}/'

# in files
anndata_file = f'{quants_dir}/{project}.multivi.curated.h5ad'
if modality == 'GEX':
    quants_file = f'{quants_dir}/{project}.multivi_norm_exp.parquet'
elif modality == 'ATAC':
    quants_file = f'{quants_dir}/{project}.multivi_peak_est.parquet'
else:
    # fail here with a clear message instead of a NameError on quants_file
    # when the quantifications are loaded later
    raise ValueError(f"unsupported modality {modality!r}, expected 'GEX' or 'ATAC'")

# out files
# the per cell-type pseudobulk parquet files are named and written by save_results()

# variables
DEBUG = False  # when True, display a few sampled rows after each load
TESTING = False  # when True, restrict the run to a random subset of features
TEST_FEATURE_SIZE = 1000  # number of features sampled when TESTING
categories = ['curated_type', 'cluster_name'] # 'curated_type' for broad and 'cluster_name' for specific

#### analysis functions

In [None]:
def subset_data(adata: AnnData, quants: PolarsDF, cell_name: str, obs_col: str,
                reapply_filter: bool=True, min_cell_count: int=3,
                verbose: bool=False) -> tuple[PandasDF, PolarsDF]:
    """Subset the AnnData and the quantification table to one cell-type.

    Args:
        adata: annotated data whose ``obs[obs_col]`` holds the cell labels.
        quants: polars frame with one column per feature plus a 'barcode' column.
        cell_name: label in ``obs_col`` identifying the cells to keep.
        obs_col: name of the ``adata.obs`` column to match against.
        reapply_filter: when True, re-run scanpy's gene and cell filters on
            the subset.
        min_cell_count: value passed as ``min_counts`` to both scanpy filters.
        verbose: currently unused.

    Returns:
        ``(obs, sub_quants)``: the subset's ``obs`` table and the rows and
        columns of ``quants`` matching the subset's cells and features.
        ``(None, None)`` when no features or no cells remain.
    """
    sub_adata = adata[(adata.obs[obs_col] == cell_name)].copy()
    if reapply_filter:
        # NOTE(review): min_cell_count is applied as a min_counts threshold for
        # both filters, not min_cells -- confirm this is the intended filter
        sc.pp.filter_genes(sub_adata, min_counts=min_cell_count)
        sc.pp.filter_cells(sub_adata, min_counts=min_cell_count)
    # now that have subset of anndata, grab the same cells and features from the separate data matrix
    features = list(sub_adata.var.index.values)
    # nothing to convert if either the features or the cells were all filtered
    # out; an empty cell set would otherwise only fail later during scaling
    if not features or sub_adata.n_obs == 0:
        return None, None
    # the barcode column is needed to match rows back to cells
    columns = [*features, 'barcode']
    sub_quants = (quants.select(columns)
                  .filter(pl_col('barcode').is_in(sub_adata.obs.index.values)))
    return sub_adata.obs, sub_quants

def scale_dataframe(this_df : PolarsDF) -> PandasDF:
    """Scale every non-barcode column with a default MinMaxScaler.

    Args:
        this_df: polars frame with feature columns and a 'barcode' column.

    Returns:
        pandas frame of the scaled values, with the same feature columns and
        indexed by the values of the 'barcode' column.
    """
    # drop the identifier column once and reuse the result for data and names
    feature_frame = this_df.drop('barcode')
    scaled_values = MinMaxScaler().fit_transform(feature_frame.to_pandas())
    return PandasDF(data=scaled_values, columns=feature_frame.columns,
                    index=this_df['barcode'])

def feature_detected(feature_col, features: list=None, df: PandasDF=None,
                     min_cell_count: int=3, min_sample_det_rate: float=0.5,
                     verbose: bool=False):
    """Decide whether one feature column is detected in enough samples.

    Written to be used with ``DataFrame.apply`` so it receives one column at
    a time.

    Args:
        feature_col: a single column of ``df``; its ``name`` is the feature.
        features: collection of feature names eligible to be kept.
        df: frame holding the cells as rows and a 'sample_id' column.
        min_cell_count: minimum number of non-zero cells required overall
            and within a sample for that sample to count.
        min_sample_det_rate: minimum fraction of all samples in ``df`` that
            must reach ``min_cell_count`` non-zero cells.
        verbose: print the per-feature counts and decision.

    Returns:
        False when the column is not in ``features`` or has too few non-zero
        cells, otherwise whether the fraction of passing samples reaches
        ``min_sample_det_rate``.
    """
    name = feature_col.name
    # columns that are not candidate features are rejected before any
    # numeric comparison is attempted on them
    if name not in features:
        return False
    expressed = feature_col[feature_col > 0]
    if len(expressed) < min_cell_count:
        return False
    # number of non-zero cells contributed by each sample
    cells_per_sample = df.loc[expressed.index, 'sample_id'].value_counts()
    ok_sample_cnt = (cells_per_sample >= min_cell_count).sum()
    unique_sample_id_count = df['sample_id'].nunique()
    good_feature = ok_sample_cnt / unique_sample_id_count >= min_sample_det_rate
    if verbose:
        print((f'{name}, nz_df.shape={expressed.shape}, '
               f'{ok_sample_cnt}/{unique_sample_id_count}, {good_feature}'))
    return good_feature

def remove_poorly_detected_features(features: list=None, df: PandasDF=None,
                                    info_df: PandasDF=None, verbose=False) -> PandasDF:
    """Drop the columns of ``df`` that ``feature_detected`` rejects.

    Args:
        features: names of the features eligible to be kept.
        df: cells-by-features frame.
        info_df: per-cell annotations with a 'sample_id' column; its index
            must equal the index of ``df``.
        verbose: print how many columns were dropped.

    Returns:
        ``df`` without the rejected columns. The temporary 'sample_id' column
        is dropped as well unless 'sample_id' is itself listed in
        ``features``. Returns None, after printing a message, when the two
        indices differ.
    """
    if not df.index.equals(info_df.index):
        print('indices unequal could not add sample_id column')
        return None
    df = concat([df, info_df['sample_id']], axis='columns')
    # feature_detected tests membership once per column; a frozenset makes
    # each test O(1) instead of a linear scan of the feature list
    allowed_features = frozenset(features)
    feature_detect_df = df.apply(feature_detected, features=allowed_features, df=df)
    bad_features = feature_detect_df.loc[~feature_detect_df].index.to_list()
    if verbose:
        print(f'bad features counts is {len(bad_features)}')
    return df.drop(columns=bad_features)

def compute_frmt_pb(df: PandasDF=None,
                    info_df: PandasDF=None) -> PandasDF:
    """Compute the per-sample mean (pseudobulk) of every column of ``df``.

    Args:
        df: cells-by-features frame.
        info_df: per-cell annotations with a 'sample_id' column; its index
            must equal the index of ``df``.

    Returns:
        Frame indexed by sample_id with the mean of each feature and a
        'cell_count' column holding the number of cells per sample. Returns
        None, after printing a message, when the two indices differ.
    """
    if not df.index.equals(info_df.index):
        print('indices unequal could not add sample_id column')
        return None
    df = concat([df, info_df['sample_id']], axis='columns')
    by_sample = df.groupby('sample_id')
    ret_df = by_sample.mean()
    # group sizes share the sample_id index of the means
    ret_df['cell_count'] = by_sample.size()
    return ret_df

def pseudobulk_conversion(quants: PolarsDF, info_df: PandasDF, cell_name: str,
                          features: list, category: str):
    """Scale, filter, pseudobulk and save the quantifications of one cell-type.

    Args:
        quants: polars frame of the cell-type's features plus 'barcode'.
        info_df: per-cell annotations (pandas) with a 'sample_id' column.
        cell_name: cell-type label, used in the output file name.
        features: names of the features eligible to be kept.
        category: annotation level, passed on to ``save_results``.

    Nothing is saved, and a message is printed, when the cell annotations
    cannot be aligned with the quantifications.
    """
    # minmax scale the subset
    quants = scale_dataframe(quants)
    # drop any poorly detected features across samples
    quants = remove_poorly_detected_features(features, quants, info_df)
    if quants is None:
        # the helper already reported the index mismatch; stop here instead
        # of failing on None in the next step
        print(f'{cell_name}: feature filtering failed, skipping')
        return
    pseudo_bulk = compute_frmt_pb(quants, info_df)
    if pseudo_bulk is None:
        print(f'{cell_name}: pseudobulk computation failed, skipping')
        return
    # save the pseudo bulk data for the cell-type
    save_results(pseudo_bulk, cell_name, category)
    
def save_results(df: PandasDF, cell_name: str, category: str):
    """Write one cell-type's pseudobulk frame to a parquet file.

    The file is written to ``quants_dir`` and named from the project,
    modality, a prefix derived from ``category`` ('broad' for 'curated_type',
    'specific' for 'cluster_name') and ``cell_name``.

    Args:
        df: pseudobulk frame to save.
        cell_name: cell-type label used in the file name.
        category: 'curated_type' or 'cluster_name'.

    Raises:
        ValueError: if ``category`` is not one of the two supported values.
    """
    prefix_by_category = {'curated_type': 'broad', 'cluster_name': 'specific'}
    try:
        prefix_type = prefix_by_category[category]
    except KeyError:
        # previously an unknown category surfaced as an UnboundLocalError
        raise ValueError(
            f'unsupported category {category!r}, '
            f'expected one of {sorted(prefix_by_category)}') from None
    out_file = f'{quants_dir}/{project}.{modality}.{prefix_type}.{cell_name}.pb.parquet'
    df.to_parquet(out_file)


### load data
read the anndata (h5ad) file

In [None]:
%%time
# load the curated AnnData (h5ad) file defined in the configuration cell
adata = sc.read(anndata_file, cache=True)
print(adata)
if DEBUG:
    # peek at a few randomly sampled cell annotations
    display(adata.obs.sample(5))

### subset feature set by modality (GEX or ATAC)

In [None]:
# value of adata.var.modality that corresponds to each supported modality
modality_labels = {'GEX': 'Gene Expression', 'ATAC': 'Peaks'}
if modality not in modality_labels:
    # without this check an unknown modality leaves `features` undefined
    raise ValueError(f'unsupported modality {modality!r}, '
                     f'expected one of {sorted(modality_labels)}')
features = list(adata.var.loc[adata.var.modality == modality_labels[modality]].index.values)
# keep only the features of the selected modality
adata = adata[:,features]
print(adata)
display(adata.var.modality.value_counts())

#### take a look at the cell counts by cell type

In [None]:
# show how many cells carry each label, for every annotation column in use
for category in categories:
    display(adata.obs[category].value_counts())

In [None]:
# one UMAP per annotation column, coloured by that column with the labels
# drawn on the embedding; the figure size applies only inside the context
for category in categories:
    with rc_context({'figure.figsize': (9, 9)}):
        sc.pl.umap(adata, color=[category], legend_loc='on data', 
                   add_outline=True, legend_fontsize=10)

### load the quantified features
either expression (GEX) or chromatin accessibility (ATAC), based on the modality that was specified

In [None]:
%%time
quants_df = read_parquet(quants_file)
# for a polars dataframe read from parquet need to fix index column name;
# the index arrives as a regular column named '__index_level_0__'
# (presumably a pandas index written upstream -- confirm) and the analysis
# functions select and filter on the name 'barcode'
quants_df = quants_df.rename({'__index_level_0__': 'barcode'})
print(f'shape of quantified {modality} features: {quants_df.shape}')
if DEBUG:
    # peek at a few randomly sampled rows
    display(quants_df.sample(5))

### if testing notebooks for debugging purpose subset the features

In [None]:
if TESTING:
    # value of adata.var.modality that corresponds to each supported modality
    testing_modality_labels = {'GEX': 'Gene Expression', 'ATAC': 'Peaks'}
    if modality in testing_modality_labels:
        candidates = list(adata.var.loc[adata.var.modality
                                        == testing_modality_labels[modality]].index.values)
        # random.sample raises if asked for more items than exist, so cap the
        # test size at the number of available features
        features = random.sample(candidates, min(TEST_FEATURE_SIZE, len(candidates)))
    adata = adata[:,features]
    # need to keep the barcode as well
    features.append('barcode')
    quants_df = quants_df.select(features)
    print(adata)
    display(adata.var.modality.value_counts())
    print(f'shape of testing quants dataframe is: {quants_df.shape}')
    display(quants_df.sample(5))

### for each cell-type pseudobulk (mean) convert the single-cell quantifications

parallelized by cell-type

In [None]:
%%time

# the candidate feature list is the same for every cell-type, build it once
all_features = list(adata.var.index.values)

for category in categories:
    print(f'#### processing {category}')
    cmds = {}
    for cell_type in adata.obs[category].unique():
        print(f'--- {cell_type}')
        # subset by cell-type or cluster
        obs_sub, quants_sub = subset_data(adata, quants_df, cell_type, category)
        if quants_sub is None:
            print(f'{cell_type} is empty, skipping')
            continue
        # one worker process per cell-type; to debug serially call
        # pseudobulk_conversion directly with the same arguments instead
        p = Process(target=pseudobulk_conversion,
                    args=(quants_sub, obs_sub, cell_type, all_features, category))
        p.start()
        # Append process and key to keep track
        cmds[cell_type] = p
    # Wait for all processes to finish
    for key, p in cmds.items():
        p.join()
        # a worker that raised exits non-zero; report it so a missing
        # output file does not go unnoticed
        if p.exitcode != 0:
            print(f'!!! {key} worker failed with exit code {p.exitcode}')

In [None]:
!date