# Load libraries

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

from joblib import Parallel, delayed
from numpy.random import default_rng

from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score, cross_val_predict
from sklearn.utils import shuffle
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline

from scipy.stats import spearmanr, pearsonr

import seaborn as sns

from utils import loess_ci

import gseapy as gp
# NOTE(review): lists the human gene-set library names known to gseapy; `lib_names` is
# not referenced elsewhere in this notebook section (presumably for interactive lookup).
# This may require network access -- confirm before running offline.
lib_names = gp.get_library_name(organism='Human')

# Load data

In [None]:
# read the single-cell AnnData object from disk
data_raw = sc.read_h5ad('adata.h5ad')

# drop the celltypes that are not present in every subject
excluded_celltypes = ['End', 'Per']
data_raw = data_raw[~data_raw.obs['celltype'].isin(excluded_celltypes)]

# remaining celltypes, and how many cells each one has
celltypes = data_raw.obs['celltype'].unique().tolist()
celltype_counts = data_raw.obs['celltype'].value_counts()
print(celltype_counts)

# Estimate disease pseudo-progression

In [None]:
# For each celltype: fit a PLS regression of diagnosis on gene expression and collect
# 10-fold out-of-sample predictions for every cell. Each subject's predictions are then
# summarised by their median (celltype_preds), which feeds the pseudodx PCA below.
subj_col = 'id'
ids = data_raw.obs[subj_col].unique().tolist()
# diagnosis of each subject (first value seen among that subject's cells)
dx = [data_raw[data_raw.obs[subj_col] == subj].obs['diagnosis'].unique()[0] for subj in ids]
dx = np.array(dx, dtype=int)

celltype_preds = {}

# whether to perform grid search to select optimal number of PLS components
pls_comp_search = False

# cap on the number of cells per celltype (keeps the dense expression matrix tractable)
n_cells_max = 35000
# number of PLS components previously selected for this dataset (clinical diagnosis target)
celltype_n_comps = {'Ex': 6, 'Oli': 3, 'In': 4, 'Ast': 2, 'Opc': 2, 'Mic': 2}

preds = {}
for cell in celltypes:
    data = data_raw[data_raw.obs['celltype'] == cell]

    if data.shape[0] > n_cells_max:
        sc.pp.subsample(data, n_obs=n_cells_max, random_state=0)

    # balance classes: subsample the larger class down to the size of the smaller one
    data_AD = data[data.obs['diagnosis'] > 0]
    data_ctrl = data[data.obs['diagnosis'] <= 0]

    n_AD = data_AD.shape[0]
    n_ctrl = data_ctrl.shape[0]

    if n_AD > n_ctrl:
        sc.pp.subsample(data_AD, fraction=n_ctrl / n_AD)
    else:
        sc.pp.subsample(data_ctrl, fraction=n_AD / n_ctrl)

    data = data_AD.concatenate(data_ctrl)

    # remove genes present in less than 0.1% of cells
    sc.pp.filter_genes(data, min_cells=int(1 + data.shape[0] / 1000))

    X = data.X.toarray()
    y = data.obs['diagnosis'].values

    if pls_comp_search:
        max_comp = 8
        scPLS = make_pipeline(StandardScaler(), PLSRegression(scale=True))
        gridcv = GridSearchCV(estimator=scPLS, param_grid={'plsregression__n_components': list(range(2, max_comp + 1))},
                              refit=False, scoring='r2', cv=KFold(n_splits=5, shuffle=True, random_state=1),
                              n_jobs=5, return_train_score=False)
        _ = gridcv.fit(X, y)
        best_n_comp = gridcv.best_params_['plsregression__n_components']
        print(f"{cell}: {best_n_comp} selected")

        # refit with the same PLS configuration (scale=True) that the grid search evaluated
        scPLS_optimal = make_pipeline(StandardScaler(), PLSRegression(n_components=best_n_comp, scale=True))
    else:
        scPLS_optimal = make_pipeline(StandardScaler(), PLSRegression(n_components=celltype_n_comps[cell], scale=True))

    out = cross_val_predict(scPLS_optimal, X, y, cv=KFold(n_splits=10, shuffle=True, random_state=0), n_jobs=5)

    # group predictions by subject id; a subject with no cells of this type gets the
    # placeholder [0] so every celltype has one entry per subject
    preds[cell] = []
    for subj in ids:
        id_mask = (data.obs[subj_col] == subj).values
        if id_mask.sum() < 1:
            p = [0]
        else:
            p = np.squeeze(out[id_mask])

        preds[cell].append(p)

    celltype_preds[cell] = [np.median(p) for p in preds[cell]]

In [None]:
# create pseudodx: summarise each subject by the first principal component of its
# per-celltype median predictions, then attach subject-level clinical metadata
pseudodx_df = pd.DataFrame(celltype_preds, index=ids)

celltype_dx_pc1 = PCA(n_components=1).fit(pseudodx_df)
pseudodx_pc = np.squeeze(celltype_dx_pc1.transform(pseudodx_df))
# ensure positive direction corresponds to increasing disease severity
pseudodx_pc = pseudodx_pc * np.sign(pseudodx_pc[dx > 0].mean())
sort_idx = np.argsort(pseudodx_pc, axis=0)
ids_sorted = np.array(ids)[sort_idx].tolist()

pseudodx_df['dx'] = dx
pseudodx_df['dx'] = pseudodx_df['dx'].astype(int)

# obs row of the first cell of each subject, in `ids` order: the same value that
# `.unique()[0]` picked per subject, computed in one pass instead of re-slicing
# data_raw once per subject for every metadata field
subject_meta = data_raw.obs.drop_duplicates(subset=subj_col).set_index(subj_col).loc[ids]
pseudodx_df['sex'] = [float(v == 'female') for v in subject_meta['sex']]
pseudodx_df['pseudodx'] = pseudodx_pc

# add ROSMAP metadata from Kellis19 paper
rosmap_cols = ['disease_progression', 'nft', 'tangles', 'cogn_global_lv', 'gpath',
               'amyloid', 'plaq_n', 'ceradsc', 'braaksc']
for col in rosmap_cols:
    pseudodx_df[col] = [float(v) for v in subject_meta[col]]

# order subjects by pseudodx
pseudodx_df = pseudodx_df.iloc[sort_idx]

plt.figure(figsize=(24, 8))
sns.heatmap(pseudodx_df.loc[:, celltypes[0]:'dx'].T, cmap='coolwarm', vmax=1, vmin=-1, cbar=True)
plt.xlabel('Subject ID');

# per-celltype median predictions in pseudodx order; filled by the boxplot cell
save_df = pd.DataFrame(index=pseudodx_df.index, columns=celltypes)

In [None]:
# Per-subject distributions of out-of-sample predictions for every celltype, rows in
# pseudodx order; boxes of AD subjects (dx > 0) are filled red.
f, axs = plt.subplots(ncols=len(celltypes), figsize=(20, 20), sharey=True)
axs = np.atleast_1d(axs)

for ax, cell in zip(axs, celltypes):
    preds_cell = [preds[cell][i] for i in sort_idx]
    ax.vlines(x=0, ymin=0, ymax=len(ids) + 1, color='k');
    bp = ax.boxplot(preds_cell, vert=False, showfliers=False, patch_artist=True, whis=(10, 90))
    ax.set_title(f'celltype {cell}')
    ax.set_ylim([0, len(ids) + 1])
    ax.set_yticks(np.arange(1, 1 + len(ids)))
    # the boxes are drawn in sort_idx order, so the labels must be sorted the same way
    ax.set_yticklabels(ids_sorted)
    ax.set_ylabel('subject id')

    save_df[cell] = [np.median(preds[cell][i]) for i in sort_idx]

    # pseudodx_df rows are already in sort_idx order, so row i matches box i positionally
    for i, is_ad in enumerate(pseudodx_df['dx'].to_numpy() > 0):
        if is_ad:
            bp['boxes'][i].set_facecolor('red')
        bp['medians'][i].set_color('black')
        bp['medians'][i].set_lw(2)

# broadcast each subject's pseudodx to all of its cells
id_pseudodx_map = dict(zip(pseudodx_df.index, pseudodx_df['pseudodx']))

data_raw.obs['pseudodx'] = data_raw.obs[subj_col].map(id_pseudodx_map).astype(float)

### Verify correlation of disease pseudo-progression with clinical metrics

In [None]:
# Scatter each clinical / pathology metric (y) against pseudodx (x), with LOESS
# confidence bands; red = AD (dx > 0), black = control.
x = ['nft', 'tangles', 'gpath', 'amyloid', 'plaq_n', 'cogn_global_lv', 'braaksc', 'ceradsc']
percentile = [1, 99]

f, axs = plt.subplots(nrows=len(x), figsize=(15, 3 * len(x)))
for ax, metric in zip(axs, x):
    dx = pseudodx_df['dx']
    ax.plot(pseudodx_df['pseudodx'][dx > 0], pseudodx_df[metric][dx > 0], 'ro')
    ax.plot(pseudodx_df['pseudodx'][dx <= 0], pseudodx_df[metric][dx <= 0], 'ko')

    y_smoothed = loess_ci(pseudodx_df['pseudodx'], pseudodx_df[metric], percentile=percentile, frac=2/3)
    # plot confidence intervals
    # NOTE(review): assumes loess_ci returns columns [fit, lo_0, hi_0, lo_1, hi_1, ...]; confirm in utils
    for i_ci in range(len(percentile)):
        ax.fill_between(pseudodx_df['pseudodx'], y_smoothed[:, 1 + 2 * i_ci], y_smoothed[:, 2 + 2 * i_ci],
                        alpha=0.5 / (i_ci + 1), color='g')

    ax.set_title(metric)
    # pseudodx (principal component 1) is on the x-axis; the metric is on the y-axis
    ax.set_xlabel('principal component 1')
    ax.set_ylabel(metric)
pseudodx_df.corr('spearman')['pseudodx']

In [None]:
# Ordering of subjects from the clinical severity metrics alone: the first principal
# component of nft, gpath, braaksc and ceradsc.
x = ['nft', 'gpath', 'braaksc', 'ceradsc']

clinical_metric_pc = np.ravel(PCA(n_components=1).fit_transform(pseudodx_df.loc[:, x]))
# the sign of a principal component is arbitrary; orient it so AD subjects (dx > 0) score
# positive on average (same convention as pseudodx) rather than hard-coding a negation
clinical_metric_pc = clinical_metric_pc * np.sign(clinical_metric_pc[pseudodx_df['dx'].to_numpy() > 0].mean())
pseudodx_df['clinical_pc'] = clinical_metric_pc

f, axs = plt.subplots(nrows=len(x), figsize=(15, 3 * len(x)))
for ax, metric in zip(axs, x):
    dx = pseudodx_df['dx']
    ax.plot(pseudodx_df['clinical_pc'][dx > 0], pseudodx_df[metric][dx > 0], 'ro')
    ax.plot(pseudodx_df['clinical_pc'][dx <= 0], pseudodx_df[metric][dx <= 0], 'ko')
    ax.set_title(metric)
    ax.set_xlabel('clinical PC1')
# a mid-cell expression is not rendered by the notebook, so display it explicitly
display(pseudodx_df.corr('spearman')['clinical_pc'])

# broadcast each subject's clinical PC to all of its cells
id_pseudodx_map = dict(zip(pseudodx_df.index, pseudodx_df['clinical_pc']))

data_raw.obs['clinical_pc'] = data_raw.obs['id'].map(id_pseudodx_map).astype(float)

# Fit PLS-DA model

In [None]:
def filter_shuffle_data(adata_raw, genes, group, subgroup):
    """Select one group of cells, drop rare genes, balance the classes and shuffle.

    Parameters
    ----------
    adata_raw : AnnData
        Full dataset; cells with ``adata_raw.obs[group] == subgroup`` are kept.
    genes : dict
        Updated in place: ``genes[subgroup]`` is set to the array of gene symbols
        that survive the gene filter.
    group : str
        obs column used to select cells (e.g. 'celltype').
    subgroup : str
        Value of ``group`` to keep.

    Returns
    -------
    AnnData
        Cells of ``subgroup`` where the AD (diagnosis > 0) and control (diagnosis <= 0)
        classes have been subsampled to equal size, in shuffled order.

    Raises
    ------
    ValueError
        If one of the two classes has no cells, so balancing is impossible.
    """
    # select from the argument (not the notebook-global data_raw); copy so the in-place
    # gene filter below acts on an independent object rather than on a view
    data = adata_raw[adata_raw.obs[group] == subgroup].copy()

    # remove genes present in less than 0.1% of cells
    sc.pp.filter_genes(data, min_cells=int(1 + data.shape[0] / 1000))

    genes[subgroup] = np.array(data.var.index.tolist())

    # balance classes by subsampling the larger one down to the size of the smaller one
    binary_target = 'diagnosis'
    data_AD = data[data.obs[binary_target] > 0]
    data_ctrl = data[data.obs[binary_target] <= 0]

    n_AD = data_AD.shape[0]
    n_ctrl = data_ctrl.shape[0]
    if n_AD == 0 or n_ctrl == 0:
        # otherwise the subsampling would yield an empty dataset or divide by zero
        raise ValueError(f'{subgroup}: cannot balance classes ({n_AD} AD cells, {n_ctrl} control cells)')

    if n_AD > n_ctrl:
        sc.pp.subsample(data_AD, fraction=n_ctrl / n_AD, random_state=0)
    else:
        sc.pp.subsample(data_ctrl, fraction=n_AD / n_ctrl, random_state=0)

    data_sub = data_AD.concatenate(data_ctrl)

    # shuffle cells so the classes are interleaved
    data_shuffled = shuffle(data_sub, random_state=0)

    return data_shuffled

In [None]:
# define model parameters
group = 'celltype'
target = 'diagnosis'
scorer = 'roc_auc'
celltype_n_comps = {'Mic':2, 'Ast':2, 'Ex':4, 'Oli':3, 'In':3, 'Opc':2} # precomputed
# number of principal components PHATE uses for each celltype
pcs = {'Mic':5, 'Ast':5, 'Opc':5, 'Oli':10, 'In':15, 'Ex':15}

loadings_dict = {}
gene_symbols = {}
scores_df = pd.DataFrame(index=celltypes, columns=[f'r2_fold{i+1}' for i in range(5)]+[f'roc_auc_fold{i+1}' for i in range(5)])

# NOTE: only 'Mic' is fitted here; iterate over `celltypes` to fill every row of scores_df
# (the score plots below draw nothing for rows that are still NaN).
for subgroup in ['Mic']:
    print(subgroup)
    # get balanced, shuffled data for the target celltype
    data_shuffled = filter_shuffle_data(data_raw, gene_symbols, group, subgroup)
    X = data_shuffled.X.toarray()
    y = (data_shuffled.obs[target] > 0).astype(int) # binary

    # estimate out-of-sample performance with nested CV: the inner GridSearchCV picks
    # n_components, the outer 5-fold KFold scores the selected model
    inner_loop = GridSearchCV(PLSRegression(scale=True), param_grid={'n_components':[1,2,3,4,5]})
    for metric in ['r2','roc_auc']:
        print(f'Metric {metric}')
        # fixed random_state so the reported fold scores are reproducible
        scores = cross_val_score(inner_loop, X, y, cv=KFold(n_splits=5, shuffle=True, random_state=0), scoring=metric, n_jobs=-1)
        print(f'\t{np.mean(scores):.3f} ({np.std(scores):.3f})')
        scores_df.loc[subgroup,f'{metric}_fold1':f'{metric}_fold5'] = scores

    # obtain component loadings from a model fitted on all cells
    optimal_n_comp = celltype_n_comps[subgroup]
    scPLS_optimal = PLSRegression(n_components=optimal_n_comp, scale=True)
    true_pls = scPLS_optimal.fit(X,y)
    true_loadings = true_pls.x_loadings_

    loadings_dict[subgroup] = true_loadings

    # investigate inhibitory neuron markers
    if subgroup == 'In':
        marker_genes = ['SST','PVALB','KIT','VIP']
        # Spearman correlation of each marker gene with each module's cell scores; one
        # column per fitted module, NaN until filled (markers missing after filtering stay NaN)
        x = pd.DataFrame(np.nan, index=marker_genes, columns=[f'Module {ic+1}' for ic in range(optimal_n_comp)])
        n_null = min(1000, len(gene_symbols[subgroup]))
        for ic in range(optimal_n_comp):
            # null for association strength: |r| between module scores and random genes
            rs = []
            for g in np.random.choice(len(gene_symbols[subgroup]), replace=False, size=n_null):
                r = spearmanr(true_pls.x_scores_[:,ic], X[:,g])[0]
                rs.append(np.abs(r))
            print(f'Module {ic+1} null correlation: {np.mean(rs):.2f} ({np.std(rs):.2f} std)')

            for g in marker_genes:
                gene_idx = np.flatnonzero(gene_symbols[subgroup] == g)
                if gene_idx.size == 0:
                    print(f'\t{g}: not present after gene filtering')
                    continue
                r = spearmanr(true_pls.x_scores_[:,ic], X[:,gene_idx[0]])[0]
                print(f'\t{g}: {r:.2f}')

                x.loc[g,f'Module {ic+1}'] = r

        # create heatmap
        plt.figure()
        sns.heatmap(x, cmap='bwr', vmin=-0.35, vmax=0.35, annot=True)
        plt.show()

    # PHATE visualization
    sc.external.tl.phate(data_shuffled, k=15, t='auto', n_pca=pcs[subgroup], a=100, n_jobs=6)

    pd.DataFrame(data_shuffled.obsm['X_phate']).to_csv(f'Fig2_{subgroup}_PHATE_data.csv')

    for i_comp in range(optimal_n_comp):
        data_shuffled.obs['comp_score'] = true_pls.x_scores_[:,i_comp]

        # PHATE visualization, colored by module score
        with plt.rc_context({'figure.figsize':[6,6]}):
            plt.figure()
            sc.external.pl.phate(data_shuffled, color='comp_score', cmap='coolwarm', norm=None, \
                                size=20, title=f'Cell scores component {i_comp+1}', show=False)

    # PHATE visualization, colored by diagnosis (does not depend on the component, so drawn once)
    with plt.rc_context({'figure.figsize':[6,6]}):
        plt.figure()
        sc.external.pl.phate(data_shuffled, color='diagnosis', cmap='coolwarm', size=30, alpha=0.5, show=False)


In [None]:
# fold-wise scores per celltype: for the i-th celltype, r2 folds are drawn at x = 2*i and
# roc_auc folds at x = 2*i + 1
f, ax = plt.subplots()
for i1, c in enumerate(celltypes):
    # the loop variable must be the one used in the column slice (previously a stale
    # global `metric` was read, so roc_auc was plotted twice)
    for i2, metric in enumerate(['r2', 'roc_auc']):
        ax.plot(i2 + 2*np.array(5*[i1]), scores_df.loc[c, f'{metric}_fold1':f'{metric}_fold5'], '.')
plt.show()

In [None]:
# plot of out-of-sample classification performance (Figure 1b):
# for the i-th celltype, r2 folds (left axis) at x = 2.5*i and ROC AUC folds (right axis)
# at x = 2.5*i + 1; positions scale with the number of celltypes
positions = 2.5 * np.arange(len(celltypes))

f, ax = plt.subplots()
for pos, c in zip(positions, celltypes):
    ax.plot(np.full(5, pos), scores_df.loc[c, 'r2_fold1':'r2_fold5'], 'ok', ms=10)
ax.set_ylabel('r2')

ax2 = ax.twinx()
for pos, c in zip(positions, celltypes):
    ax2.plot(np.full(5, pos + 1), scores_df.loc[c, 'roc_auc_fold1':'roc_auc_fold5'], 'ok', ms=10)
ax2.set_ylabel('roc_auc')

# one tick per celltype, halfway between its r2 and ROC AUC columns
ax.set_xticks(positions + 0.5, labels=celltypes);
plt.show()

# Create heatmap of top module genes

In [None]:
# Refit each celltype's PLS model on all of its balanced, shuffled cells and record
# (a) the Pearson correlation of every module's cell scores with the binary diagnosis and
# (b) the module x-loadings, which the top-gene heatmap cell reads from `loadings`.
group = 'celltype'
target = 'diagnosis'
scorer = 'roc_auc'

celltype_n_comps = {'Mic':2, 'Ast':2, 'Ex':4, 'Oli':3, 'In':3, 'Opc':2}
module_correlations = {k:[] for k in celltype_n_comps.keys()}
gene_symbols = {}
loadings = {}

for subgroup in celltype_n_comps.keys():
    data_shuffled = filter_shuffle_data(data_raw, gene_symbols, group, subgroup)
    X = data_shuffled.X.toarray()
    y = (data_shuffled.obs[target] > 0).astype(int) # binary

    # fit model
    optimal_n_comp = celltype_n_comps[subgroup]
    scPLS_optimal = PLSRegression(n_components=optimal_n_comp, scale=True)
    true_pls = scPLS_optimal.fit(X,y)
    true_scores = true_pls.x_scores_
    # keep the loadings (n_genes x n_components); previously never stored, so
    # `loadings[c]` in the top-gene heatmap raised KeyError
    loadings[subgroup] = true_pls.x_loadings_

    # compute correlations
    for i_comp in range(optimal_n_comp):
        module_correlations[subgroup].append(np.corrcoef(true_scores[:,i_comp], y)[0,1])

In [None]:
# heatmap of |correlation| between module cell scores and the binary diagnosis
# (rows: celltypes, columns: modules; NaN where a celltype has fewer modules)
max_n_modules = max(len(v) for v in module_correlations.values())
corrs = pd.DataFrame(np.nan, index=list(module_correlations.keys()),
                     columns=[f'Module {i+1}' for i in range(max_n_modules)])
for subgroup, comp_corrs in module_correlations.items():
    for i_comp, corr in enumerate(comp_corrs):
        corrs.loc[subgroup, f'Module {i_comp+1}'] = np.abs(corr)

plt.figure()
sns.heatmap(corrs, cmap='vlag', annot=True, vmin=-0.70, vmax=0.70)
plt.xticks(rotation=45, ha='right');
plt.yticks(rotation=0);
plt.show()

In [None]:
# heatmap of the top_n highest-loading genes of every module, showing each such gene's
# loading in every module of every celltype (0 where the gene was filtered out)
top_n = 3
modules = [f'{k}{i+1}' for (k,v) in celltype_n_comps.items() for i in range(v)]
n_modules = len(modules)

# top genes of each module (highest loading first), in module order
index_tmp = []
for c in gene_symbols.keys():
    for m in range(celltype_n_comps[c]):
        index_tmp.extend(np.array(gene_symbols[c])[np.argsort(loadings[c][:,m])[-top_n:]][::-1])
# a gene can be a top gene of several modules: keep only its first occurrence so every
# row label is unique and each .loc assignment below targets exactly one cell
index_unique = list(dict.fromkeys(index_tmp))

top_gene_loadings = pd.DataFrame(np.zeros((len(index_unique), n_modules)), columns=modules, index=index_unique)

# fill in loadings
for c in gene_symbols.keys(): # celltypes
    symbols = np.asarray(gene_symbols[c])
    for g in index_unique:
        match = np.flatnonzero(symbols == g)
        if match.size == 0:
            continue
        for m in range(celltype_n_comps[c]): # modules
            top_gene_loadings.loc[g,f'{c}{m+1}'] = loadings[c][match[0],m]

plt.figure(figsize=(10,top_gene_loadings.shape[0]*0.2))
sns.heatmap(top_gene_loadings, cmap='vlag', vmin=-0.07, vmax=0.07)
plt.show()

# Run bootstrap analysis (per celltype)

In [None]:
subgroup = 'Mic'
true_loadings = loadings_dict[subgroup]

### perturbation (get bootstrap distribution of loadings)
data_shuffled = filter_shuffle_data(data_raw, gene_symbols, group, subgroup)

n_bootstrap = 1000
# the bootstrap models must have as many components as the model behind true_loadings
n_comp_bootstrap = celltype_n_comps[subgroup]


def parallel_bootstrap(j, data_bs):
    """Fit PLS on one class-stratified bootstrap resample and return its x-loadings.

    j : index of the bootstrap replicate; also used as the seed of its RNG.
    data_bs : AnnData to resample (cells are drawn with replacement within each
        value of ``obs[target]``).
    Returns the fitted model's ``x_loadings_`` (n_genes x n_comp_bootstrap).
    """
    if((100*j/n_bootstrap)%10==0):
        print(f'{100*j/n_bootstrap:.0f}% complete')
    rng = default_rng(j)

    # select bootstrap data: resample with replacement separately within each class
    idx_bootstrap = []
    for i in data_bs.obs[target].unique():
        idx_class = np.where(data_bs.obs[target] == i)[0]
        n_class = idx_class.shape[0]
        idx_class_bootstrap = rng.integers(0, n_class, n_class)
        idx_bootstrap.extend(idx_class[idx_class_bootstrap])

    data_bootstrap = data_bs[idx_bootstrap, :]
    data_bootstrap.obs_names_make_unique()

    X_bootstrap = data_bootstrap.X.toarray()
    y_bootstrap = (data_bootstrap.obs[target] > 0).astype(int)

    # fresh estimator for this subgroup, instead of the global scPLS_optimal, which holds
    # whichever celltype the last model-fitting loop ended on
    model = PLSRegression(n_components=n_comp_bootstrap, scale=True)
    bootstrap_loadings = model.fit(X_bootstrap, y_bootstrap).x_loadings_

    return bootstrap_loadings

bootstrap_loadings = Parallel(n_jobs=4)(delayed(parallel_bootstrap)(j, data_shuffled) for j in range(n_bootstrap))
bootstrap_loadings = np.array(bootstrap_loadings)

# bootstrap_loadings shape: [n_bootstrap, n_genes, n_components]

In [None]:
# A refitted PLS component is only defined up to sign, so some bootstrap loadings come out
# mirrored relative to the full-data fit. For every replicate and component, flip the
# loadings when -loading is closer (mean absolute difference over genes) to true_loadings.
mirror = True
if mirror:
    # both have shape (n_bootstrap, n_components)
    dist_same = np.mean(np.abs(bootstrap_loadings - true_loadings), axis=1)
    dist_flipped = np.mean(np.abs(-bootstrap_loadings - true_loadings), axis=1)
    # broadcast the per-(replicate, component) sign over the gene axis
    mirror_factor = np.where(dist_same > dist_flipped, -1, 1)[:, np.newaxis, :]

    bootstrap_loadings = bootstrap_loadings * mirror_factor

# per-gene, per-component MEDIAN of the bootstrap distribution (the variable keeps its
# historical name `bootstrap_mean` because later cells use it)
bootstrap_mean = np.median(bootstrap_loadings, axis=0)

# copies that are zeroed where a loading is not consistently signed
bootstrap_mean_zeroed = bootstrap_mean.copy()
bootstrap_loadings_zeroed = bootstrap_loadings.copy()

# lower/upper percentile of the bootstrap distribution used to decide whether a loading
# crosses zero (5 -> the central 90% interval)
zero_threshold = 5
limits = np.percentile(bootstrap_loadings_zeroed, q=[zero_threshold,100-zero_threshold], axis=0)

In [None]:
# display the per-celltype gene-symbol arrays (notebook cell output)
gene_symbols

In [None]:
# A feature is dropped (True) when either bootstrap percentile limit falls on the other
# side of zero from the median, or when any bootstrap replicate is exactly zero.
median_positive = bootstrap_mean > 0
crosses_zero = ((limits[0] > 0) != median_positive) | ((limits[1] > 0) != median_positive)
has_exact_zero = (bootstrap_loadings_zeroed == 0).any(axis=0)
zero_mask = crosses_zero | has_exact_zero

bootstrap_mean_zeroed[zero_mask] = 0
bootstrap_loadings_zeroed[:, zero_mask] = 0

for comp_idx, n_zeroed in enumerate(zero_mask.sum(axis=0)):
    print(f"component {comp_idx}: {n_zeroed} features zeroed")

# persist the significant (non-zeroed) loadings; the DEG comparison cells read this file
df = pd.DataFrame(bootstrap_mean_zeroed, index=gene_symbols[subgroup])
df.to_csv(f'data/{subgroup}_significant_loadings.csv')

# Compare PLS modules to DEGs

In [None]:
# correlate PLS module gene weights with the DEG results, per celltype and module
degs = {}
loadings = {}

heatmap = pd.DataFrame(np.full((4, 6), np.nan), index=[f'Module {i}' for i in range(1, 5)], columns=celltypes)
deg_workbook = 'data/41586_2019_1195_MOESM4_ESM.xlsx'
for c in ['Ex', 'In', 'Oli', 'Mic', 'Ast', 'Opc']:
    degs[c] = pd.read_excel(deg_workbook, sheet_name=c, index_col=0)
    loadings[c] = pd.read_csv(f'data/{c}_significant_loadings.csv', index_col=0)

    # genes present in both tables; the same list orders both vectors below
    shared_genes = list(set(degs[c].index).intersection(set(loadings[c].index)))

    deg_feature = 'IndModel.FC'  # DEG metric compared with the module weights
    shared_loadings = loadings[c].loc[shared_genes]
    for comp in range(shared_loadings.shape[1]):
        module_weights = shared_loadings.iloc[:, comp].values
        deg_values = degs[c].loc[shared_genes, deg_feature].values
        deg_values[np.isnan(deg_values)] = 0  # missing DEG values count as 0
        r, p = pearsonr(module_weights, deg_values)

        heatmap.loc[f'Module {comp + 1}', c] = r

plt.figure(constrained_layout=True)
sns.heatmap(heatmap.T, cmap='vlag', annot=True, vmin=-1, vmax=1)
plt.xticks(rotation=45, ha='right');
plt.yticks(rotation=0);
plt.ylabel('Celltype')
plt.show()

In [None]:
# For each module: overlap between its 10 genes with the largest |significant loading| and
# the 10 top DEGs (smallest |metric| for adjusted p-values).
n_top = 10
for c in ['Ex','In','Oli','Mic','Ast','Opc']:
    DEG = pd.read_excel('data/41586_2019_1195_MOESM4_ESM.xlsx', sheet_name=c, index_col=0)
    PLS = pd.read_csv(f'data/{c}_significant_loadings.csv', index_col=0)

    deg_feature = 'IndModel.adj.pvals' #'IndModel.FC' # which metric to use for comparison
    # rank on the metric column only: abs() over the whole sheet fails on non-numeric
    # columns, and the DEG ranking does not depend on the PLS component
    deg_genes_sorted = DEG[deg_feature].abs().sort_values(ascending=True).index.tolist()[:n_top]
    for comp in range(PLS.shape[1]):
        # loadings CSV columns are the component indices written as strings ('0', '1', ...)
        pls_genes_sorted = PLS[f'{comp}'].abs().sort_values(ascending=False).index.tolist()[:n_top]

        gene_overlap = list(set(pls_genes_sorted).intersection(set(deg_genes_sorted)))
        print(f'{c} module {comp+1}: {gene_overlap}')

# Plot top genes per module

In [None]:
# Violin plots of the bootstrap loading distributions of the N_TOP genes with the largest
# |median loading| in each component of the bootstrapped model; zeroed genes are skipped.
n_comp_plot = bootstrap_mean_zeroed.shape[1]
f, axs = plt.subplots(nrows=n_comp_plot, figsize=(25,n_comp_plot*15))
if(n_comp_plot==1):
    axs = [axs]

N_TOP = 15

# plot the significant genes for each component
for comp, ax in enumerate(axs):
    gene_order = np.argsort(np.abs(bootstrap_mean_zeroed[:,comp]))[::-1]
    # zeroed genes have all-zero bootstrap columns, for which the violin KDE is singular
    gene_order = gene_order[bootstrap_mean_zeroed[gene_order, comp] != 0][:N_TOP]
    n_plot = len(gene_order)
    if n_plot == 0:
        ax.set_title(f'component {comp}: no significant genes')
        continue

    ax.violinplot(bootstrap_loadings_zeroed[:, gene_order, comp], showextrema=False, widths=0.8, showmedians=True, points=200);
    ax.set_xticks(np.arange(1,n_plot+1))
    ax.set_xticklabels(labels=gene_symbols[subgroup][gene_order], rotation=65, fontsize=30)
    ax.set_xlim([0,n_plot+1])
plt.show()

# Obtain GSEA enriched pathways per module

In [None]:
gene_set_list = ['GO_Biological_Process_2021', 'WikiPathway_2021_Human', 'Panther_2016']
# number of components of the bootstrapped model (not the stale global optimal_n_comp)
n_comp_gsea = bootstrap_mean.shape[1]
gsea_terms_tmp = []
# perform GSEA for each component in the model
for comp in range(n_comp_gsea):
    gsea_terms_tmp_comp = pd.DataFrame()
    # perform GSEA using multiple gene set databases
    for gene_set in gene_set_list:
        # genes ranked by their median bootstrap loading, significant (non-zeroed) genes only
        ranked_genes = pd.DataFrame(data={'genes':gene_symbols[subgroup][~zero_mask[:,comp]].tolist(), 'PLS_weights':bootstrap_mean[:,comp][~zero_mask[:,comp]]})
        try:
            pre_res = gp.prerank(rnk=ranked_genes,
                            gene_sets=gene_set,
                            processes=4,
                            min_size=5,
                            no_plot=True,
                            permutation_num=1000,
                            outdir=None,
                            seed=1)
        except Exception as err:
            print(f'GSEA Error: No enriched gene modules found in component {comp}')
            print(f'\t{gene_set}: {err!r}')
            # skip this gene set; falling through would reuse the previous iteration's
            # pre_res (duplicating its terms) or raise NameError on the first iteration
            continue

        df = pre_res.res2d
        df['component'] = comp
        df['gene_set_source'] = gene_set
        df = df[df['fdr']<0.05]
        df = df[~df['pval'].isna()]
        df = df.sort_values('nes', ascending=False)
        display(df.head(10))
        gsea_terms_tmp_comp = pd.concat((gsea_terms_tmp_comp, df), axis=0)
    gsea_terms_tmp.append(gsea_terms_tmp_comp)

In [None]:
# process terms for each component, then combine terms from all components into one dataframe
gsea_terms = pd.DataFrame()
for t in gsea_terms_tmp:
    if t.empty:
        # no significant terms for this component
        continue

    # remove duplicated terms (same term reported by more than one gene set source)
    t = t[~t.index.duplicated()]
    # suffix term names with the component index so they stay unique across components
    t.index = t.index.values + str(t['component'].iloc[0])

    # sort by absolute value of gsea score
    t = t.reindex(t['nes'].abs().sort_values(ascending=False).index)

    gsea_term_sets = [set(x.split(';')) for x in t['genes'].values]

    # how many more-enriched terms each term overlaps with (at >= overlap_threshold)
    superset_count_list = [0]*len(gsea_term_sets)

    # optionally drop a less-enriched term when at least overlap_threshold of the
    # smaller term's genes are shared with a more-enriched term
    drop_idx = set()
    drop_bool = False  # set True to actually drop the redundant terms
    overlap_threshold = 0.9

    # terms are in decreasing |nes| order, so j > i is always the less enriched term
    for i in range(len(gsea_term_sets)):
        # skip terms that will be dropped
        if drop_bool and i in drop_idx:
            continue

        for j in range(i+1, len(gsea_term_sets)):
            # size of smaller set
            min_size = min(len(gsea_term_sets[i]), len(gsea_term_sets[j]))
            # number of genes common between sets
            overlap = len(gsea_term_sets[i] & gsea_term_sets[j])
            if overlap/min_size >= overlap_threshold:
                drop_idx.add(j)
                superset_count_list[j] += 1

    # drop terms with high overlap
    if drop_bool:
        keep_idx = [i for i in range(len(gsea_term_sets)) if i not in drop_idx]
        t = t.iloc[keep_idx]
        # keep the counts aligned with the surviving rows
        superset_count_list = [superset_count_list[i] for i in keep_idx]

        print(f'dropping {len(drop_idx)} terms, {len(keep_idx)} remain')

    gsea_terms = pd.concat((gsea_terms, pd.concat((t['nes'], t['fdr'], t['component'], t['gene_set_source'], t['genes'].str.split(';'), t['ledge_genes'].str.split(';'), pd.Series(data=superset_count_list, index=t.index)), axis=1)), axis=0)

gsea_terms.rename(columns={0:'superset_count'}, inplace=True)

print(gsea_terms.shape)

In [None]:
# AD GWAS genes to look for in the GSEA terms.
# NOTE(review): the original heading said "leading edge genes", but identify_genes checks
# each term's full 'genes' list, not 'ledge_genes'; confirm which one is intended.
# gene list from: Wightman, D. P. et al. A genome-wide association study
# with 1,126,563 individuals identifies new risk loci for Alzheimer’s disease. Nat. Genet. 53, 1276–1282 (2021).
goi = {
    # INPP5D was misspelled 'INPPD5' (not a gene symbol, so it could never match)
    'AD GWAS genes':{'AGRN','CR1','NCK2','BIN1','INPP5D','CLNK','TNIP1','HAVCR2','HLA-DRB1','TREM2','CD2AP',
    'TMEM106B','ZCWPW1', 'NYAP1','EPHA1-AS1','CLU','SHARPIN','USP6NL', 'ECHDC3','CCDC6','MADD', 'SPI1','MS4A4A','PICALM',
    'SORL1','FERMT2','RIN3','ADAM10','APH1B','SCIMP', 'RABEP1','GRN','ABI3','TSPOAP1-AS1','ACE','ABCA7','APOE','NTN5',
    'CD33','LILRB2','CASS4','APP', 'IFNB1'},
    }

def identify_genes(row, gene_list):
    """Return the members of ``gene_list`` that occur in ``row['genes']``.

    The result follows the iteration order of ``gene_list``. ``row`` is any object
    indexable by 'genes' whose value supports ``in``; in this notebook that is a
    gsea_terms row whose 'genes' entry holds the term's list of gene symbols.
    """
    row_genes = row['genes']
    return [gene for gene in gene_list if gene in row_genes]

# for each gene-of-interest list, add a column with the listed genes found in every term
for label, gene_pool in goi.items():
    gsea_terms[label] = gsea_terms.apply(identify_genes, axis=1, args=(gene_pool,))