## Detect cell states based on marker-set gene sum scoring

Roughly based on the method from: <br>
[Gross PS, Durán-Laforet V, Ho LT et al. Senescent-like microglia limit remyelination through the senescence associated secretory phenotype. Nat Commun 2025;16:2283.](https://pubmed.ncbi.nlm.nih.gov/40055369/)

In [None]:
# shell out to `date` to record when this notebook run started
!date

#### import libraries

In [None]:
import scanpy as sc
from anndata import AnnData
import json
import numpy as np
from pandas import Series, DataFrame
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
from kneed import KneeLocator
import statsmodels.api as sm
from seaborn import regplot, barplot

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# variables
# project: prefix used to build the input AnnData file name
project = 'aging_phase2'
# DEBUG: passed as the `verbose` flag to the peek_* helpers and load_marker_set
DEBUG = True
# DPI_VALUE: figure.dpi applied to the plots drawn under rc_context
DPI_VALUE = 100

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'

# in files
# anndata_file: h5ad read with sc.read_h5ad
anndata_file = f'{quants_dir}/{project}.multivi.curated_final.h5ad'
# markers_json: JSON mapping a gene-set name to a list of gene names
markers_json = '/home/gibbsr/working/ADRD_Brain_Aging/phase2/development/analyses/gene_sets.json'

# echo the resolved input paths so the run log shows which files were used
if DEBUG:
    print(f'{anndata_file=}')
    print(f'{markers_json=}')

#### functions

In [None]:
def peek_dataframe(df: DataFrame, verbose: bool=False):
    """Print the shape of ``df`` and, when ``verbose`` is set, display its head.

    Args:
        df: the table to summarise.
        verbose: when True, also render ``df.head()`` with ``display``.
    """
    print(f'{df.shape=}')
    if not verbose:
        return
    display(df.head())

def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print an optional heading followed by the AnnData summary.

    Args:
        adata: the object to print; its ``obs`` and ``var`` tables are shown
            when ``verbose`` is set.
        message: heading printed first; skipped when None or empty.
        verbose: when True, also display the head of ``adata.obs`` and
            ``adata.var`` (in that order).
    """
    has_heading = message is not None and len(message) > 0
    if has_heading:
        print(message)
    print(adata)
    if not verbose:
        return
    for table_name in ('obs', 'var'):
        display(getattr(adata, table_name).head())

def load_marker_set(marker_genes_json: str, adata: AnnData,
                    verbose: bool=False) -> (set, dict):
    """Load marker gene sets from JSON and keep only features present in ``adata``.

    Args:
        marker_genes_json: path to a JSON file mapping a set name to a list of
            feature names.
        adata: object whose ``var.index`` holds the available feature names.
        verbose: when True, also print the markers that were found.

    Returns:
        A ``(markers, markers_dict)`` tuple. ``markers`` is the set of all
        marker names that are present in ``adata``. ``markers_dict`` maps each
        set name to the list of its present markers, in the order given in the
        JSON file with duplicates removed; sets with no present markers are
        dropped.
    """
    possible_features = set(adata.var.index.values)
    print(f'{len(possible_features)} features are present')

    # JSON is UTF-8 by specification; do not depend on the locale default
    with open(marker_genes_json, 'r', encoding='utf-8') as json_file:
        loaded_sets = json.load(json_file)
    # get the set of all markers across the marker sets
    markers = {item for sublist in loaded_sets.values() for item in sublist}
    print(f'{len(markers)} marker features loaded')
    # report the marker genes that are absent from the available features
    missing_markers = markers - possible_features
    print(f'missing {len(missing_markers)} markers: {missing_markers}')
    # keep only the markers that are available
    markers = markers & possible_features
    print(f'{len(markers)} marker features found')
    if verbose:
        print(f'markers found: {markers}')
    # Build a fresh dict instead of editing the loaded one while iterating it.
    # dict.fromkeys de-duplicates while keeping the file's gene order, so the
    # per-set lists are identical from run to run (a set intersection is not).
    markers_dict = {}
    for cell_type, marker_list in loaded_sets.items():
        kept = list(dict.fromkeys(gene for gene in marker_list if gene in markers))
        if kept:
            markers_dict[cell_type] = kept

    return markers, markers_dict

def sum_gene_features(adata: AnnData, gene_set) -> Series:
    """Sum the expression of the genes in ``gene_set`` for every cell.

    Args:
        adata: object exposing ``X`` (cells x features; a NumPy array or
            anything offering ``toarray()`` after column selection),
            ``var_names`` and ``obs_names``.
        gene_set: iterable of feature names to sum.

    Returns:
        A Series named ``gene_set_score`` indexed by ``adata.obs_names`` with
        one summed value per cell.

    Raises:
        ValueError: if any requested gene is not in ``adata.var_names``.
    """
    # Build the name -> column lookup once; rebuilding list(var_names) for
    # every gene scans the whole feature list per gene. setdefault keeps the
    # first occurrence, matching list.index for a repeated name.
    column_of = {}
    for position, name in enumerate(adata.var_names):
        column_of.setdefault(name, position)

    gene_indices = []
    missing = []
    for gene in gene_set:
        position = column_of.get(gene)
        if position is None:
            missing.append(gene)
        else:
            gene_indices.append(position)
    if missing:
        raise ValueError(f'genes not present in adata.var_names: {missing}')

    # Read X once, then keep only the columns for the requested genes
    expression = adata.X
    if isinstance(expression, np.ndarray):
        filtered_expr_data = expression[:, gene_indices]
    else:
        filtered_expr_data = expression[:, gene_indices].toarray()
    # Sum the expression values per cell
    gene_set_scores = np.asarray(filtered_expr_data.sum(axis=1)).ravel()

    return Series(gene_set_scores, index=adata.obs_names, name='gene_set_score')

def array_summary(arr):
    """Print shape, dtype and basic descriptive statistics of a NumPy array.

    Args:
        arr: the array to describe.

    Raises:
        TypeError: if ``arr`` is not a ``numpy.ndarray``.
    """
    if not isinstance(arr, np.ndarray):
        raise TypeError("Input must be a NumPy ndarray")

    print(f"Shape of the array: {arr.shape}")
    print(f"Data type of elements: {arr.dtype}")

    # each statistic is printed with three decimals, in this fixed order
    statistics = (
        ("Minimum value in the array", np.min),
        ("Maximum value in the array", np.max),
        ("Mean of the array elements", np.mean),
        ("Median of the array elements", np.median),
        ("Standard deviation of the array elements", np.std),
        ("Variance of the array elements", np.var),
    )
    for label, reducer in statistics:
        print(f"{label}: {reducer(arr):.3f}")

def find_max_curve(scores: Series, show_plots: bool=False) -> np.float64:
    """Return the score at the knee of the descending-sorted score curve.

    The scores are sorted from highest to lowest and ranked 1..n; KneeLocator
    is run on rank vs. score as a convex, decreasing curve. The score at the
    knee rank is returned so callers can use it as a cut-off.

    Args:
        scores: per-cell scores.
        show_plots: when True, show KneeLocator's raw and normalized plots.

    Returns:
        The score of the cell at the knee rank.

    Raises:
        ValueError: if KneeLocator does not report a knee for these scores.
    """
    sorted_scores = scores.sort_values(ascending=False)

    data_curve = 'convex'
    data_direction = 'decreasing'
    knee = KneeLocator(np.arange(1, len(sorted_scores)+1), sorted_scores,
                       S=1.0, curve=data_curve, direction=data_direction)
    print(f'best curve at knee {knee.knee}')
    # knee.knee is None when no knee is detected; int(None) would otherwise
    # fail with a TypeError that says nothing about the cause
    if knee.knee is None:
        raise ValueError(
            f'no knee found in the sorted scores ({len(sorted_scores)} values); '
            'cannot derive a threshold'
        )
    num_comp = int(knee.knee)
    # ranks start at 1, so the knee rank maps to position num_comp-1
    exp_value = sorted_scores.iloc[num_comp-1]
    print(f'best number of cells is {num_comp} at sum of {exp_value}')
    if show_plots:
        knee.plot_knee()
        plt.show()
        knee.plot_knee_normalized()
        plt.show()
    return exp_value

## load the raw anndata object

In [None]:
%%time
# read the h5ad at anndata_file into memory and print a summary of it
adata = sc.read_h5ad(anndata_file)
peek_anndata(adata, '## input anndata:', DEBUG)

### subset to just the gene features

In [None]:
# Keep only the 'Gene Expression' features. Take a real copy rather than a
# view: later cells add a layer and normalise this object in place, and
# writing to a view makes anndata copy it implicitly.
adata = adata[:, adata.var.modality == 'Gene Expression'].copy()
peek_anndata(adata, '## adata just gene features:', DEBUG)

## load marker sets

In [None]:
# markers: all marker genes present in adata; markers_dict: set name -> present genes
markers, markers_dict = load_marker_set(markers_json, adata, DEBUG)

In [None]:
# show the names of the marker sets that kept at least one present gene
markers_dict.keys()

## normalize and transform the counts

In [None]:
# keep an untouched copy of X in the "counts" layer before X is modified below
adata.layers["counts"] = adata.X.copy()

# Normalize each cell to a total of 1e6, then apply log1p to X
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.log1p(adata)

In [None]:
# peek at one marker set; .get returns None if 'senescence score' was dropped or absent
gene_set = markers_dict.get('senescence score')
print(gene_set)

## identify cell state from marker set

In [None]:
# For every marker set: score all cells, flag the cells at or above the knee
# threshold, then test whether the per-sample flagged percentage tracks age.
# The obs column is named 'senescent' for every set, whatever state it scores.
for set_name, gene_set in markers_dict.items():
    print(f'### {set_name}: {gene_set}')
    gene_set_scores = sum_gene_features(adata, gene_set)
    threshold = find_max_curve(gene_set_scores, True)
    found = gene_set_scores[gene_set_scores >= threshold]
    print(f'{found.shape=}')
    print(f'{(found.shape[0]/gene_set_scores.shape[0])*100:.3f}% of cells matched')
    senescent_cell_ids = set(found.index)
    # update the obs for cells found as senescent (overwritten for each set)
    adata.obs['senescent'] = np.where(adata.obs.index.isin(found.index), 1, 0)
    if DEBUG:
        display(adata.obs.senescent.value_counts())
    # is age associated with the number of senescent cells
    # observed=True: only build rows for samples that actually have cells
    per_sample = adata.obs.groupby('sample_id', observed=True)
    counts_table = per_sample.agg({'senescent': 'sum', 'age': 'first'})
    # one grouped size() replaces a full boolean scan of obs per sample
    counts_table['percent_senescent'] = (counts_table.senescent / per_sample.size()) * 100
    X_exog = sm.add_constant(counts_table.age)
    model = sm.GLM(counts_table.percent_senescent, X_exog)
    result = model.fit()
    # only report and plot sets whose age term is nominally significant
    if result.pvalues['age'] <= 0.05:
        print(result.summary())
        with rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI_VALUE}):
            plt.style.use('seaborn-v0_8-talk')
            regplot(x='age', y='percent_senescent', data=counts_table, robust=True)
            # title and render here, inside the rc_context, so each figure is
            # tied to its gene set (same as the per-cell-type loop)
            plt.title(set_name)
            plt.show()

In [None]:
results = []
for set_name, gene_set in markers_dict.items():
    print(f'\n\n### {set_name}: {gene_set}')
    senescent_cell_ids = set()
    for cell_type in adata.obs.curated_type.unique():
        print(f'--- {cell_type}')
        adata_sub = adata[adata.obs.curated_type == cell_type].copy()
        gene_set_scores = sum_gene_features(adata_sub, gene_set)    
        # array_summary(gene_set_scores.to_numpy())
        threshold = find_max_curve(gene_set_scores)
        found = gene_set_scores[gene_set_scores >= threshold]
        print(f'{found.shape=}')
        percent_matched = round((found.shape[0]/gene_set_scores.shape[0])*100, 3)
        print(f'{percent_matched:.3f}% of cells in {cell_type} matched')
        results.append([set_name, cell_type, 'broad', percent_matched])
        senescent_cell_ids = senescent_cell_ids | set(found.index)
        # update the obs for cells found as senescent
        adata_sub.obs['senescent'] = np.where(adata_sub.obs.index.isin(list(senescent_cell_ids)), 1, 0)    
        # see in cell-type has cluster that are enriched for senescents
        if adata_sub.obs.cluster_name.nunique() > 1:
            for cluster in adata_sub.obs.cluster_name.unique():
                this_obs = adata_sub.obs.loc[adata_sub.obs.cluster_name == cluster]
                positive_cnt = this_obs.loc[adata_sub.obs.senescent == 1].shape[0]
                cluster_matched = round((positive_cnt/this_obs.shape[0])*100, 3)
                print(f'\t{cluster_matched:.3f}% of cells in {cluster} matched')
                results.append([set_name, cluster, 'cluster', cluster_matched])
        # is age associated with the number of senescent cells
        counts_table = adata_sub.obs.groupby('sample_id').agg({'senescent': 'sum', 'age': 'first'})
        percentages = []
        for row in counts_table.itertuples():
            percentages.append((row.senescent / adata_sub.obs.loc[adata_sub.obs.sample_id == row.Index].shape[0]) * 100)
        counts_table['percent_senescent'] = percentages    
        X_exog = sm.add_constant(counts_table.age)
        # model = sm.GLM(counts_table.senescent, X_exog)
        model = sm.GLM(counts_table.percent_senescent, X_exog)
        result = model.fit()
        if result.pvalues['age'] <= 0.05:
            print(result.summary())
            with rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI_VALUE}):
                plt.style.use('seaborn-v0_8-talk')
                regplot(x='age', y='percent_senescent', data=counts_table, robust=True)
                plt.title(f'{cell_type} ({set_name})')
                plt.show()
    
    print(f'In total {(len(senescent_cell_ids)/adata.n_obs)*100:.3f}% of cells matched')    

In [None]:
# one row per (marker set, cell type or cluster); annot_type is 'broad' or 'cluster'
results_df = DataFrame(results, columns=['cell_state', 'cell_type', 'annot_type', 'percentage'])
peek_dataframe(results_df, DEBUG)

In [None]:
# keep only the cell-type level rows, dropping the per-cluster rows
broad_df = results_df.loc[results_df.annot_type == 'broad']
peek_dataframe(broad_df, DEBUG)

In [None]:
# show the cell-type level rows ranked by matched percentage, highest first
broad_df.sort_values('percentage', ascending=False)

## visualize DAM percentages

In [None]:
# Bar plot of the 'disease associated microglia' percentages for every
# results row whose cell_type label starts with 'Micro'.
gene_sets = ['disease associated microglia']
cell_type = 'Micro'
in_gene_sets = results_df.cell_state.isin(gene_sets)
has_type_prefix = results_df.cell_type.str.startswith(cell_type)
dam_df = results_df.loc[in_gene_sets & has_type_prefix]
peek_dataframe(dam_df, DEBUG)
plot_settings = {'figure.figsize': (9, 9), 'figure.dpi': DPI_VALUE}
with rc_context(plot_settings):
    plt.style.use('seaborn-v0_8-talk')
    barplot(data=dam_df, x='cell_type', y='percentage',
            hue='annot_type', palette='colorblind')
    plt.title('disease associated microglia')
    plt.xlabel('Cell Type')
    plt.ylabel('Percentage')
    plt.show()

## visualize OPC state percentages
ODs are included alongside the OPCs.

In [None]:
# Bar plot of the OPC state percentages for every results row whose
# cell_type label starts with 'OPC' or 'OD', coloured by marker set.
gene_sets = ['Cycling OPC', 'Differentiating OPC', 'Quiescent OPC', 'Transitioning OPC']
in_gene_sets = results_df.cell_state.isin(gene_sets)
is_opc = results_df.cell_type.str.startswith('OPC')
is_od = results_df.cell_type.str.startswith('OD')
opc_df = results_df.loc[in_gene_sets & (is_opc | is_od)]
peek_dataframe(opc_df, DEBUG)
plot_settings = {'figure.figsize': (9, 9), 'figure.dpi': DPI_VALUE}
with rc_context(plot_settings):
    plt.style.use('seaborn-v0_8-talk')
    barplot(data=opc_df, x='cell_type', y='percentage',
            hue='cell_state', palette='colorblind')
    plt.title('OPC states')
    plt.xlabel('Cell Type')
    plt.ylabel('Percentage')
    plt.show()

## visualize senescence state percentages for the various gene sets

In [None]:
# One bar plot per senescence marker set, all drawn with the same y-axis
# maximum so the sets can be compared visually.
gene_sets = ['Canonical Senescence Pathway', 'Senescence Response Pathway',
             'Senescence Initiating Pathway', 'senescence signature',
             'SenMayo', 'senescence score']
# Take the shared maximum from results_df: sen_df is first assigned inside the
# loop below, so reading it here raises NameError on a fresh kernel.
max_y = round(results_df.loc[results_df.cell_state.isin(gene_sets)].percentage.max()+1, 0)
print(max_y)
for gene_set in gene_sets:
    sen_df = results_df.loc[results_df.cell_state == gene_set]
    peek_dataframe(sen_df, DEBUG)
    # a set listed above may have no rows in results_df; nothing to plot then
    if sen_df.empty:
        print(f'no results for {gene_set}, skipping plot')
        continue
    with rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI_VALUE}):
        plt.style.use('seaborn-v0_8-talk')
        barplot(data=sen_df, x='cell_type', y='percentage', hue='annot_type', palette='colorblind')
        plt.title(gene_set)
        plt.xlabel('Cell Type')
        plt.ylabel('Percentage')
        plt.ylim([None, max_y])
        plt.xticks(rotation=90)
        plt.show()

In [None]:
# sen_df still holds the rows of the last gene set handled by the plotting loop;
# list its cell-type level labels from highest to lowest percentage
display(sen_df.head())
broad_rows = sen_df.loc[sen_df.annot_type == 'broad']
ranked_rows = broad_rows.sort_values('percentage', ascending=False)
broad_order = ranked_rows.cell_type.to_list()
display(broad_order)

In [None]:
# shell out to `date` to record when this notebook run finished
!date