## Summarize some of the latent factor metrics and compare the latent factors across model types and cell types

In [None]:
!date

#### import libraries

In [None]:
from pandas import read_csv, concat, DataFrame
from seaborn import barplot, scatterplot
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import statsmodels.api as sm
from pandas import DataFrame
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from itertools import combinations
from statsmodels.stats.multitest import multipletests
import statsmodels.api as sm
from pandas import Series
from sklearn.linear_model import ElasticNetCV
from sklearn.metrics import mean_squared_error, r2_score
from math import sqrt

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# analysis constants: debug switch, significance threshold, figure resolution
DEBUG = False
ALPHA = 0.05
DPI = 100

# labels used across the project
categories = {
    'curated_type': 'broad',
    'cluster_name': 'specific',
}
modalities = [
    'GEX',
    'ATAC',
]
model_types = [
    'PCA',
    'NMF',
    'ICA',
]

# project name, used as the prefix of every input and output file
project = 'aging_phase2'

# working and results directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
results_dir = f'{wrk_dir}/results'

# inputs: the latent factor age GLM results and the latent model metrics
assoc_file = f'{results_dir}/{project}.latent.age_glm.csv'
metrics_file = f'{results_dir}/{project}.latent.metrics.csv'

# output: the pairwise latent factor regression results
results_file = f'{results_dir}/{project}.associated_latent_factors.csv'

#### functions

In [None]:
def model_scores(actual, predicted):
    """Score predictions against the observed values.

    Returns a ``(r2, rmse)`` tuple: the value of ``r2_score`` and the square
    root of ``mean_squared_error`` for ``actual`` versus ``predicted``.
    """
    r_squared = r2_score(actual, predicted)
    mse = mean_squared_error(actual, predicted)
    return r_squared, sqrt(mse)

def plot_feature_importance(feature_values: Series, model_name: str):
    """Draw a horizontal bar chart of the feature values, sorted ascending.

    Args:
        feature_values: one value per feature; the index supplies the bar
            labels.
        model_name: name inserted into the plot title.
    """
    # sort_values already returns a new Series, so the input is not modified
    ordered_values = feature_values.sort_values()
    # apply the style sheet first and the explicit rc settings second;
    # calling plt.style.use() inside rc_context lets the style sheet
    # overwrite any rc key it also defines, such as the figure size
    with plt.style.context('seaborn-v0_8-talk'), \
            rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI}):
        ordered_values.plot(kind='barh')
        plt.title(f'Feature importance using {model_name}')

### load the input files

#### load the summary metrics for the latent factors

In [None]:
# read the latent model summary metrics; the first csv column is the index
factor_metrics = read_csv(metrics_file, index_col=0)
print(f'shape of factor_metrics is {factor_metrics.shape}')
if DEBUG:
    # peek at a few random rows
    display(factor_metrics.sample(n=5))

#### load the latent factor GLM age association results

In [None]:
# read the per latent factor age GLM results; the first csv column is the index
age_glm = read_csv(assoc_file, index_col=0)
print(f'shape of age_glm is {age_glm.shape}')
if DEBUG:
    # peek at a few random rows
    display(age_glm.sample(n=4))

### visualize the reduction accuracy of the latent models

#### by number of components selected

In [None]:
# number of components per cell type and model type, largest first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=factor_metrics.sort_values('n_comp', ascending=False),
            x='cell_type', y='n_comp', hue='model_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('Number of components selected for model types based on Reduction accuracy',
              fontsize='large')
    plt.xlabel('Cell types')
    plt.ylabel('Number of components')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

#### by R-squared

In [None]:
# R-squared per cell type and model type, best first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=factor_metrics.sort_values('R2', ascending=False),
            x='cell_type', y='R2', hue='model_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('Reduction accuracy of model types, R-squared', fontsize='large')
    plt.xlabel('Cell types')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

#### by RMSE

In [None]:
# RMSE per cell type and model type, lowest error first.
# NOTE(review): the column is spelled 'RSME' here; presumably that matches
# the header of the metrics file - confirm before renaming.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=factor_metrics.sort_values('RSME', ascending=True),
            x='cell_type', y='RSME', hue='model_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('Reduction accuracy of model types, RMSE', fontsize='large')
    plt.xlabel('Cell types')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

### visualize the latent factors associated with age

In [None]:
# coefficient versus z for the age associated factors, coloured by model type.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    scatterplot(data=age_glm.loc[age_glm.fdr_bh <= ALPHA],
                x='coef', y='z', hue='model_type', palette='colorblind')
    plt.legend(bbox_to_anchor=(1.15, 1), loc='upper right', borderaxespad=0)
    plt.tight_layout()
    plt.show()

In [None]:
# coefficient versus z for the age associated factors, coloured by cell type
# with the marker style showing the model type.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    scatterplot(data=age_glm.loc[age_glm.fdr_bh <= ALPHA],
                x='coef', y='z', hue='cell_type', palette='colorblind', style='model_type')
    plt.legend(bbox_to_anchor=(1.15, 1), loc='upper right', borderaxespad=0, ncol=1, fontsize=9)
    plt.tight_layout()
    plt.show()

In [None]:
# count the age associated factors per cell type and model type.
# reset_index turns the group keys back into ordinary columns so that
# 'cell_type' and 'model_type' can be referenced by name in the plot below.
age_factor_counts = (age_glm.loc[age_glm.fdr_bh <= ALPHA]
                     .groupby(['cell_type', 'model_type'])
                     .count()
                     .sort_values('feature', ascending=False)
                     .reset_index())
if DEBUG:
    display(age_factor_counts)

# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=age_factor_counts,
            x='cell_type', y='feature', hue='model_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('Number of components selected for model types that are age associated')
    plt.xlabel('Cell types')
    plt.ylabel('Number of components')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

### use regularized modeling to determine which latent model type gives most accurate age predictions

#### load the sample information

In [None]:
info_file = f'{wrk_dir}/sample_info/{project}.sample_info.csv'
info_df = read_csv(info_file, index_col=0)
# replace missing smoker and bmi entries with the column mean, rounded to one decimal
for covariate in ('smoker', 'bmi'):
    missing = info_df[covariate].isna()
    info_df.loc[missing, covariate] = info_df[covariate].mean().round(1)
if DEBUG:
    print(info_df.shape)
    display(info_df.head())

In [None]:
# keep only the latent factors whose corrected p-value passes the threshold
is_age_associated = age_glm.fdr_bh <= ALPHA
age_assoc_glm = age_glm.loc[is_age_associated]
print(f'shape of age_assoc_glm {age_assoc_glm.shape}')
if DEBUG:
    display(age_assoc_glm.sample(n=5))
    # how many associated factors each model type has within each cell type
    per_type_counts = age_assoc_glm.groupby('cell_type').model_type.value_counts()
    display(per_type_counts.sort_values(ascending=False))

In [None]:
%%time
latents_perf = []
lowercase_models = [name.lower() for name in model_types]
types_by_category = age_assoc_glm.groupby('type').cell_type.unique()
for category, cell_types in types_by_category.items():
    for cell_type in cell_types:
        for mdl_type in lowercase_models:
            this_file = f'{results_dir}/latents/{project}.{category}.{cell_type}.{mdl_type}_components.csv'
            raw_factors = read_csv(this_file, index_col=0)
            # ages of the samples present in this factor table
            this_target = info_df.loc[raw_factors.index, 'age']
            # min-max scale the factors, keeping the row and column labels
            this_factors = DataFrame(data=MinMaxScaler().fit_transform(raw_factors),
                                     columns=raw_factors.columns, index=raw_factors.index)
            regr = ElasticNetCV(cv=5, random_state=42)
            regr.fit(this_factors, this_target)
            # note: the scores are computed on the same samples used for fitting
            pred_target = regr.predict(this_factors)
            x_r, x_e = model_scores(this_target, pred_target)
            coef = Series(regr.coef_, index=this_factors.columns)
            # to inspect the coefficients: plot_feature_importance(coef, 'ElasticNetCV Model')
            n_picked = sum(coef != 0)
            n_dropped = sum(coef == 0)
            latents_perf.append([category, cell_type, mdl_type, x_r, x_e,
                                 regr.alpha_, n_picked, n_dropped])
# now convert the list of latent factor model scores into a dataframe
score_columns = ['category', 'cell_type', 'mdl_type', 'r2', 'rmse', 'alpha', 'picked', 'dropped']
scores_df = DataFrame(data=latents_perf, columns=score_columns)
print(f'shape of scores_df is {scores_df.shape}')
if DEBUG:
    display(scores_df.sort_values('r2', ascending=False).head())
    display(scores_df.sort_values('rmse').head())

#### visualize the age prediction accuracy of the latent models

##### by number of components kept

In [None]:
# number of factors with a non-zero ElasticNetCV coefficient, largest first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=scores_df.sort_values('picked', ascending=False),
            x='cell_type', y='picked', hue='mdl_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('ElasticNetCV of age prediction',
              fontsize='large')
    plt.xlabel('Cell types')
    plt.ylabel('Number of factors selected')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

##### by R-squared accuracy of age prediction

In [None]:
# ElasticNetCV age prediction R-squared, best first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=scores_df.sort_values('r2', ascending=False),
            x='cell_type', y='r2', hue='mdl_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('ElasticNetCV Age prediction accuracy of model types, R-squared', fontsize='large')
    plt.xlabel('Cell types')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

##### by RMSE accuracy of age prediction

In [None]:
# ElasticNetCV age prediction RMSE, lowest error first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=scores_df.sort_values('rmse'),
            x='cell_type', y='rmse', hue='mdl_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('ElasticNetCV Age prediction accuracy of model types, RMSE', fontsize='large')
    plt.xlabel('Cell types')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

### use GLM modeling to determine which latent model type gives most accurate age predictions

In [None]:
%%time
latents_perf = []
lowercase_models = [name.lower() for name in model_types]
types_by_category = age_assoc_glm.groupby('type').cell_type.unique()
for category, cell_types in types_by_category.items():
    for cell_type in cell_types:
        for mdl_type in lowercase_models:
            this_file = f'{results_dir}/latents/{project}.{category}.{cell_type}.{mdl_type}_components.csv'
            raw_factors = read_csv(this_file, index_col=0)
            # ages of the samples present in this factor table
            this_target = info_df.loc[raw_factors.index, 'age']
            # min-max scale the factors, keeping the row and column labels
            this_factors = DataFrame(data=MinMaxScaler().fit_transform(raw_factors),
                                     columns=raw_factors.columns, index=raw_factors.index)
            exog = sm.add_constant(this_factors)
            result = sm.GLM(this_target, exog).fit()
            # note: the scores are computed on the same samples used for fitting
            pred_target = result.predict(exog)
            x_r, x_e = model_scores(this_target, pred_target)
            latents_perf.append([category, cell_type, mdl_type, x_r, x_e])
# now convert the list of latent factor model scores into a dataframe
score_columns = ['category', 'cell_type', 'mdl_type', 'r2', 'rmse']
scores_df = DataFrame(data=latents_perf, columns=score_columns)
print(f'shape of scores_df is {scores_df.shape}')
if DEBUG:
    display(scores_df.sort_values('r2', ascending=False).head())
    display(scores_df.sort_values('rmse').head())

#### visualize the age prediction accuracy of the latent models

##### by R-squared accuracy of age prediction

In [None]:
# GLM age prediction R-squared, best first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=scores_df.sort_values('r2', ascending=False),
            x='cell_type', y='r2', hue='mdl_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('GLM Age prediction accuracy of model types, R-squared', fontsize='large')
    plt.xlabel('Cell types')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

##### by RMSE accuracy of age prediction

In [None]:
# GLM age prediction RMSE, lowest error first.
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': DPI}):
    barplot(data=scores_df.sort_values('rmse'),
            x='cell_type', y='rmse', hue='mdl_type', palette='colorblind')
    plt.xticks(rotation=90)
    plt.title('GLM Age prediction accuracy of model types, RMSE', fontsize='large')
    plt.xlabel('Cell types')
    # lay out once the title and labels exist so they are accounted for
    plt.tight_layout()
    plt.show()

In [None]:
# show the scores of every model type except nmf, highest r2 first
not_nmf = scores_df.mdl_type != 'nmf'
display(scores_df.loc[not_nmf].sort_values('r2', ascending=False))

### identify which latent factors that are age associated are well correlated across cell and model types

#### load, label, and combine the latent factors into a dataframe

In [None]:
factors = []
lowercase_models = [name.lower() for name in model_types]
types_by_category = age_assoc_glm.groupby('type').cell_type.unique()
for category, cell_types in types_by_category.items():
    for cell_type in cell_types:
        for mdl_type in lowercase_models:
            this_file = f'{results_dir}/latents/{project}.{category}.{cell_type}.{mdl_type}_components.csv'
            # label every factor column with the cell type it came from
            labeled_factors = read_csv(this_file, index_col=0).add_prefix(f'{cell_type}:')
            factors.append(labeled_factors)
# join the per-file factor tables side by side into a single dataframe
factors_df = concat(factors, axis='columns')
# min-max scale for interpretability, keeping the row and column labels
scaled_values = MinMaxScaler().fit_transform(factors_df)
factors_df = DataFrame(data=scaled_values,
                       columns=factors_df.columns, index=factors_df.index)
print(f'shape of factors_df is {factors_df.shape}')
if DEBUG:
    display(factors_df.sample(n=5))

#### create list of pairings to run regressions for

In [None]:
# every unordered pair of latent factor columns
factor_names = factors_df.columns
pairings = list(combinations(factor_names, 2))

In [None]:
# number of factor pairings that will be regressed
len(pairings)

#### regress the pairings

In [None]:
%%time
def regress_pair(endog_name: str, exog_name: str, data: DataFrame) -> 'list | None':
    """Regress one column of ``data`` on another with ``sm.GLM``.

    Args:
        endog_name: name of the column used as the dependent variable.
        exog_name: name of the column used as the independent variable; a
            constant is added to it with ``sm.add_constant``.
        data: frame holding both columns.

    Returns:
        ``[endog_name, exog_name, coef, stderr, statistic, p-value]`` for
        the exog term. If fitting or reading the fit fails, the four numeric
        entries are NaN. ``None`` when both names are equal.
    """
    if endog_name == exog_name:
        # a column is never regressed on itself
        return None
    endog = data[endog_name].values
    exog = sm.add_constant(data[exog_name].values)
    try:
        result = sm.GLM(endog, exog).fit()
        # position 1 holds the exog term; position 0 is the added constant
        return [endog_name, exog_name,
                result.params[1], result.bse[1],
                result.tvalues[1], result.pvalues[1]]
    except Exception:
        # keep the pairing in the output with missing statistics rather than
        # aborting the whole run; unlike a bare except this still lets
        # KeyboardInterrupt and SystemExit propagate
        print(f'Caught Error for {endog_name} ~ {exog_name}')
        return [endog_name, exog_name] + [np.nan] * 4

# regress every pairing using only the rows where both factors are present
results = []
for endog_name, exog_name in pairings:
    pair_data = factors_df[[endog_name, exog_name]].dropna()
    results.append(regress_pair(endog_name, exog_name, pair_data))

#### convert regression results into a dataframe

In [None]:
# one row per pairing, in the order the values are returned by regress_pair
result_columns = ['endog', 'exog', 'coef', 'stderr', 'z', 'p-value']
results_df = DataFrame(data=results, columns=result_columns)
print(f'shape of results_df is {results_df.shape}')
if DEBUG:
    display(results_df.sample(n=5))

#### compute the FDR values

In [None]:
def compute_bh_fdr(df: DataFrame, alpha: float=0.05, p_col: str='p-value',
                   method: str='fdr_bh', verbose: bool=True) -> DataFrame:
    """Return a copy of ``df`` with multiple-testing corrected p-values added.

    The values in ``p_col`` are passed to ``multipletests`` together with
    ``alpha`` and ``method``; the second element of its result is stored in
    a new column named after ``method``. With ``verbose`` set, the shape of
    the rows whose corrected value is below ``alpha`` is printed.
    """
    corrected_df = df.copy()
    p_values = np.array(corrected_df[p_col])
    adjusted = multipletests(p_values, alpha=alpha, method=method)[1]
    corrected_df[method] = adjusted
    if verbose:
        significant_shape = corrected_df.loc[corrected_df[method] < alpha].shape
        print(f'total significant after correction: {significant_shape}')
    return corrected_df

In [None]:
# pairings without a p-value are given p = 1 before the correction is run
results_df = results_df.fillna({'p-value': 1})
results_df = compute_bh_fdr(results_df)
print(f'shape of results_df is {results_df.shape}')
if DEBUG:
    display(results_df.sort_values(by=['fdr_bh']).head())

#### save the results

In [None]:
# write the pairwise regression results to the results file
results_df.to_csv(results_file)

#### visualize a random result

In [None]:
# pick one significant pairing at random and plot the two factors against
# each other; sampling from an empty selection would raise, so guard it
significant_results = results_df.loc[results_df.fdr_bh <= ALPHA]
if significant_results.empty:
    print('no significant pairings to visualize')
else:
    random_result = significant_results.sample(n=1).iloc[0]
    print(random_result)

    # The style sheet is applied first and the explicit rc settings second
    # so that the requested figure size is not overwritten by the style sheet.
    with plt.style.context('seaborn-v0_8-talk'), \
            rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI}):
        scatterplot(data=factors_df, x=random_result.exog, y=random_result.endog)
        plt.show()

#### which ones are shared across different cell-types

In [None]:
# significant pairings whose endog factor is an 'ExN' factor and whose exog
# factor is not, ordered by z
is_significant = results_df.fdr_bh <= ALPHA
endog_is_exn = results_df.endog.str.startswith('ExN')
exog_is_exn = results_df.exog.str.startswith('ExN')
diff_celltypes = results_df.loc[is_significant & endog_is_exn & ~exog_is_exn]
diff_celltypes = diff_celltypes.sort_values('z')
if DEBUG:
    display(diff_celltypes.head())

In [None]:
# pick one significant cross cell-type pairing at random and plot the two
# factors against each other; sampling from an empty selection would raise,
# so guard it
significant_diff = diff_celltypes.loc[diff_celltypes.fdr_bh <= ALPHA]
if significant_diff.empty:
    print('no significant cross cell-type pairings to visualize')
else:
    random_result = significant_diff.sample(n=1).iloc[0]
    print(random_result)

    # The style sheet is applied first and the explicit rc settings second
    # so that the requested figure size is not overwritten by the style sheet.
    with plt.style.context('seaborn-v0_8-talk'), \
            rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI}):
        scatterplot(data=factors_df, x=random_result.exog, y=random_result.endog)
        plt.show()

### cluster based on the latent factors

In [None]:
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

def find_optimal_clusters(data, max_k, random_state=42):
    """Fit Gaussian mixtures for k = 2..max_k and pick k by silhouette score.

    Args:
        data: observations to cluster, passed unchanged to
            ``GaussianMixture.fit`` and ``silhouette_score``.
        max_k: largest number of mixture components to try (inclusive).
        random_state: seed handed to every ``GaussianMixture`` so that
            repeated runs give the same clustering.

    Returns:
        ``(best_k, best_gmm, silhouette_scores)`` where ``silhouette_scores``
        has one entry per tried k, starting at k = 2.

    Raises:
        ValueError: if ``max_k`` is smaller than 2, since no model would be
            fitted.
    """
    if max_k < 2:
        raise ValueError(f'max_k must be at least 2, got {max_k}')
    iters = range(2, max_k + 1)
    gmm_models = [GaussianMixture(n_components=n, covariance_type='full',
                                  random_state=random_state).fit(data)
                  for n in iters]
    silhouette_scores = [silhouette_score(data, model.predict(data))
                         for model in gmm_models]

    # locate the best score once and reuse the position for both lookups
    best_idx = int(np.argmax(silhouette_scores))
    return iters[best_idx], gmm_models[best_idx], silhouette_scores

In [None]:
# fill any missing with the mean from that column
tfactors_df = factors_df.fillna(factors_df.mean()).transpose()
print(f'shape of tfactors_df is {tfactors_df.shape}')
if DEBUG:
    display(tfactors_df.head())

In [None]:
%%time
# the rows of tfactors_df are the items being clustered, so the number of
# clusters has to stay below the row count as well as the column based
# bound used originally; take whichever is smaller
max_clusters = min(tfactors_df.shape) - 1
best_k, best_gmm, silhouette_scores = find_optimal_clusters(tfactors_df, max_clusters)

print(f'Optimal number of clusters: {best_k}')

#### plot silhouette scores

In [None]:
from seaborn import lineplot
# the scores were computed for k = 2, 3, ... so plot them against k rather
# than against their position in the list
cluster_counts = list(range(2, len(silhouette_scores) + 2))
# The style sheet is applied first and the explicit rc settings second so
# that the requested figure size is not overwritten by the style sheet.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (9, 9), 'figure.dpi': DPI}):
    scatterplot(x=cluster_counts, y=silhouette_scores)
    lineplot(x=cluster_counts, y=silhouette_scores)
    plt.xlabel('Number of clusters')
    plt.ylabel('Silhouette Score')
    plt.title('Silhouette Scores for Different Numbers of Clusters')
    plt.show()

#### visualize the clusters

In [None]:
# cluster assignment for every latent factor (row of tfactors_df)
labels = best_gmm.predict(tfactors_df)
labels_df = DataFrame({'clust_num': labels})
labels_df['factor'] = tfactors_df.index
# factor names look like '<cell_type>:<model factor>'; split them apart
name_parts = labels_df['factor'].str.split(':', expand=True)
labels_df['cell_type'] = name_parts[0]
labels_df['model_factor'] = name_parts[1]
print(f'shape of labels_df is {labels_df.shape}')
if DEBUG:
    display(labels_df.head())

In [None]:
%%time
import torch
import pymde
pymde.seed(42)
# run on the gpu when torch reports one, otherwise on the cpu
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# embed the latent factors (rows of tfactors_df)
factor_matrix = tfactors_df.to_numpy()
mde = pymde.preserve_neighbors(factor_matrix, device=device, verbose=True)
embedding = mde.embed(verbose=True)

In [None]:
# embedding of the latent factors, coloured by their cell type
embedding_rc = {'figure.figsize': (9, 9), 'figure.dpi': DPI}
with rc_context(embedding_rc):
    plt.style.use('seaborn-v0_8-talk')
    pymde.plot(embedding, color_by=labels_df['cell_type'], marker_size=50)
    plt.show()

In [None]:
# embedding of the latent factors, coloured by their assigned cluster number
embedding_rc = {'figure.figsize': (9, 9), 'figure.dpi': DPI}
with rc_context(embedding_rc):
    plt.style.use('seaborn-v0_8-talk')
    pymde.plot(embedding, color_by=labels_df['clust_num'], marker_size=50)
    plt.show()

In [None]:
# print each cluster number followed by the factors assigned to it
for cluster in labels_df.clust_num.unique():
    in_cluster = labels_df.clust_num == cluster
    print(cluster)
    print(labels_df.loc[in_cluster, 'factor'].values)

In [None]:
!date