## Notebook to analyze latent factors, if present, among the age-associated features (both GEX and ATAC) using multiple dimensionality reduction methods, run per cell type (both broad and cluster-specific)

- PCA
- NMF
- ICA
- VAE

In [None]:
!date

#### import libraries

In [None]:
from sklearn.decomposition import PCA, NMF, FastICA
from pandas import DataFrame as PandasDF, read_csv, concat, read_parquet, Series
from matplotlib import pyplot as plt
from matplotlib.pyplot import rc_context
from numpy import cumsum, arange, argsort, abs as np_abs
from sklearn.metrics import r2_score, mean_squared_error
from kneed import KneeLocator
from pickle import dump as pkl_dump

#### set notebook variables

In [None]:
# parameters
# NOTE(review): presumably overridden by a notebook runner (e.g. papermill) - confirm
# category selects the annotation level and determines the file-name prefix used below
category = 'curated_type' # 'curated_type' for broad and 'cluster_name' for specific
# cell type to analyze; matched against the `tissue` column of the age results
# and used in the quantification and output file names
cell_type = 'ExN'

In [None]:
# parameters
project = 'aging_phase2'
# map the analysis category onto the prefix used in the input/output file names
if category == 'curated_type':
    prefix_type = 'broad'
elif category == 'cluster_name':
    prefix_type = 'specific'
else:
    # fail here with a clear message instead of a NameError on prefix_type later
    raise ValueError(f"category must be 'curated_type' or 'cluster_name', got {category!r}")

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'

# in files

# out files

# constants
DEBUG = True
modalities = ['GEX', 'ATAC']

#### functions

In [None]:
def important_loadings(components, feature_names, feature_types, comp_prefix: str) -> dict:
    """Select the largest-magnitude feature loadings of every component.

    For each row of `components` the absolute loadings are sorted in ascending
    order and a knee is located on their cumulative sum. The knee's x value is
    used as the number of features to keep, taken from the loadings sorted by
    decreasing absolute value. When no knee is detected all features are kept.

    Args:
        components: iterable of 1-D loading arrays, one per component, each
            aligned with `feature_names`.
        feature_names: labels used as the index of each loadings Series.
        feature_types: mapping with 'GEX' and 'ATAC' keys listing the feature
            names belonging to each modality.
        comp_prefix: prefix of the component names, e.g. 'PCA' gives 'PCA_0'.

    Returns:
        dict mapping '<comp_prefix>_<index>' to a Series of the selected
        loadings ordered by decreasing absolute value.
    """
    # build the modality look-ups once instead of once per component
    gex_names = feature_types.get('GEX') or ()
    atac_names = feature_types.get('ATAC') or ()
    gex_set, atac_set = set(gex_names), set(atac_names)

    loadings_dict = {}
    for index, component_loadings in enumerate(components):
        comp_name = f'{comp_prefix}_{index}'
        print(f'#### component {comp_name} ####')
        loadings_abs = np_abs(component_loadings)
        sorted_abs_loadings = loadings_abs[argsort(loadings_abs)]
        print(sorted_abs_loadings.shape)
        # find the knee for loadings
        knee = KneeLocator(arange(1, len(sorted_abs_loadings)+1), cumsum(sorted_abs_loadings),
                           S=1.0, curve='convex', direction='increasing')
        if knee.knee is None:
            # KneeLocator reports None when it cannot find a knee; keep every
            # feature instead of failing on int(None)
            feature_count = len(sorted_abs_loadings)
            print(f'no knee detected, keeping all {feature_count} features')
        else:
            knee.plot_knee()
            plt.show()
            knee.plot_knee_normalized()
            plt.show()
            print(f'knee at feature {knee.knee}')
            feature_count = int(knee.knee)
        # NOTE(review): the knee is a position in the ascending ordering, yet it
        # is used as the number of top features; confirm that
        # len(loadings) - knee is not what was intended
        features_sers = Series(data=component_loadings, index=feature_names)
        selected_features = features_sers.sort_values(key=abs, ascending=False).head(feature_count)
        print(selected_features.shape)
        display(selected_features.sort_values(key=abs, ascending=False).head())
        print(f'best number of features is {feature_count} with minimum absolute loading of {selected_features.abs().min()}')
        # percentage of each modality's features that were selected
        selected_names = set(selected_features.index)
        gex_pct = len(selected_names & gex_set)/len(gex_names)*100 if gex_names else 0.0
        atac_pct = len(selected_names & atac_set)/len(atac_names)*100 if atac_names else 0.0
        print(f'includes {gex_pct:.1f}% of genes and {atac_pct:.1f}% of peaks')
        loadings_dict[comp_name] = selected_features
    return loadings_dict
        

def iterate_model_component_counts(max_count: int, data_df: PandasDF, 
                                   model_type: str=['PCA', 'NMF', 'ICA']) -> (list, list):
    """Fit the chosen model at 1..max_count components and collect its scores.

    Each component count is fitted with `generate_selected_model`; only the
    R2 and RMSE values of each fit are kept.

    Returns:
        Two lists (R2 values, RMSE values), ordered by component count.
    """
    r2_scores, rmse_scores = [], []
    for n_components in arange(1, max_count+1):
        fit_result = generate_selected_model(n_components, data_df, model_type)
        # fit_result is (model, embedding, r2, rmse); the first two are not needed
        _model, _embedding, fit_r2, fit_rmse = fit_result
        r2_scores.append(fit_r2)
        rmse_scores.append(fit_rmse)
    return r2_scores, rmse_scores

def generate_selected_model(n_comps: int, data_df: PandasDF, 
                            model_type: str=['PCA', 'NMF', 'ICA']) -> (object, PandasDF, float, float):
    """Fit a decomposition model and score how well it reconstructs the input.

    Args:
        n_comps: number of components to fit.
        data_df: input matrix; its index is reused for the returned embedding.
        model_type: one of 'PCA', 'NMF' or 'ICA'.

    Returns:
        (fitted model, embedding DataFrame rounded to 4 decimals with columns
        prefixed '<model_type>_', R2 of the reconstruction, RMSE of the
        reconstruction).

    Raises:
        ValueError: if `model_type` is not one of the supported names.
    """
    if model_type == 'PCA':
        model = PCA(n_components=n_comps, random_state=42)
    elif model_type == 'NMF':
        model = NMF(n_components=n_comps, init='random', random_state=42, max_iter=500)
    elif model_type == 'ICA':
        model = FastICA(n_components=n_comps, random_state=42)
    else:
        # previously an unknown type fell through to an UnboundLocalError on `model`
        raise ValueError(f"model_type must be one of 'PCA', 'NMF' or 'ICA', got {model_type!r}")
    components = model.fit_transform(data_df)
    recon_input = model.inverse_transform(components)
    r2 = r2_score(y_true=data_df, y_pred=recon_input)
    # per-feature RMSE averaged over features; this is what squared=False
    # returned, written without that argument because newer scikit-learn
    # releases deprecated and then removed it
    per_feature_mse = mean_squared_error(data_df, recon_input, multioutput='raw_values')
    rmse = float((per_feature_mse ** 0.5).mean())
    print(f'{model_type} with {n_comps} components accuracy is {r2:.4f}, RMSE is {rmse:.4f}')  
    
    ret_df = PandasDF(data=components, index=data_df.index).round(4)
    ret_df = ret_df.add_prefix(f'{model_type}_')
    return model, ret_df, r2, rmse

def component_from_max_curve(scores, label: str=['R2', 'RMSE']) -> int:
    """Pick the number of components at the knee of a score curve.

    Args:
        scores: sequence of scores where position i holds the score for
            i + 1 components.
        label: 'R2' (searched as a concave, increasing curve) or 'RMSE'
            (searched as a convex, decreasing curve).

    Returns:
        The 1-based component count at the knee, or `len(scores)` when no
        knee is detected.

    Raises:
        ValueError: if `label` is not supported or `scores` is empty.
    """
    if label == 'R2':
        data_curve, data_direction = 'concave', 'increasing'
    elif label == 'RMSE':
        data_curve, data_direction = 'convex', 'decreasing'
    else:
        # previously fell through to an UnboundLocalError on data_curve
        raise ValueError(f"label must be 'R2' or 'RMSE', got {label!r}")
    if len(scores) == 0:
        raise ValueError('scores is empty, there is no curve to search')
    knee = KneeLocator(arange(1, len(scores)+1), scores, 
                       S=1.0, curve=data_curve, direction=data_direction)
    print(f'best curve at knee {knee.knee}')
    if knee.knee is None:
        # no knee found: use every evaluated component instead of int(None)
        num_comp = len(scores)
        print(f'no knee detected for {label}, using all {num_comp} components')
        return num_comp
    num_comp = int(knee.knee)
    exp_value = scores[num_comp-1]
    print(f'best number of components is {num_comp} at {label} of {exp_value}')
    knee.plot_knee()
    plt.show()
    knee.plot_knee_normalized()
    plt.show()
    return num_comp

def save_important_loadings(comp_features: dict, file_name: str):
    """Pickle the per-component loadings dictionary to `file_name`."""
    out_handle = open(file_name, 'wb')
    # the handle is closed on leaving the block, also when pickling fails
    with out_handle:
        pkl_dump(comp_features, out_handle)

#### load age associated feature results
get the age-associated GEX and ATAC features needed per cell type

In [None]:
%%time
# read the FDR-filtered age results of every modality and tag each row with its modality
age_results = []
for modality in modalities:
    print(modality)
    in_file = f'{results_dir}/{project}.{modality}.{prefix_type}.glm_tweedie_fdr_filtered.age.csv'
    this_df = read_csv(in_file).assign(modality=modality)
    age_results.append(this_df)
# stack the per-modality tables into a single frame
age_results_df = concat(age_results)
print(f'shape of the age results is {age_results_df.shape}')
if DEBUG:
    display(age_results_df.sample(5))
    display(age_results_df['modality'].value_counts())
    display(age_results_df['tissue'].value_counts())

### load the feature quantifications

In [None]:
%%time
# per modality: keep only the age-associated features of this cell type
cell_type_quants = []
modality_features = {}
for modality in modalities:
    in_cell_type = age_results_df.tissue == cell_type
    in_modality = age_results_df.modality == modality
    features_to_keep = age_results_df.loc[in_cell_type & in_modality].feature.to_list()
    print(modality, len(features_to_keep))
    modality_features[modality] = features_to_keep
    in_file = f'{quants_dir}/{project}.{modality}.{prefix_type}.{cell_type}.pb.parquet'
    df = read_parquet(in_file)[features_to_keep]
    # df = df[features_to_keep + ['cell_count']]
    # df = df.rename(columns={'cell_count': f'{modality}_cell_count'})
    print(modality, df.shape)
    cell_type_quants.append(df)
# join the modalities side by side, keeping only rows present in all of them
quants_df = concat(cell_type_quants, axis='columns', join='inner')
print(f'shape of feature quantifications for {cell_type} is {quants_df.shape}')
if DEBUG:
    display(quants_df.sample(5))

### using PCA

#### find number of components to use

In [None]:
%%time
# search up to half of the smaller dimension (rows or columns) of the quantifications
max_count = min(quants_df.shape) // 2
print(f'max count is {max_count}')

r2_values, rmse_values = iterate_model_component_counts(max_count=max_count, data_df=quants_df,
                                                        model_type='PCA')

In [None]:
# locate the knee of each reconstruction metric and keep the larger component count
knee_rmse = component_from_max_curve(scores=rmse_values, label='RMSE')
knee_r2 = component_from_max_curve(scores=r2_values, label='R2')
num_comp = max((knee_rmse, knee_r2))
print(num_comp)

#### regenerate the PCA model at the selected component size

In [None]:
# refit PCA at the selected size; the returned scores are not needed here, so
# discard them with `_` as the NMF and ICA cells do
pca_mdl,pca_df,_,_ = generate_selected_model(num_comp, quants_df, 'PCA')
if DEBUG:
    print(pca_df.shape)
    display(pca_df.head())

#### what are the 'important features' based on their loadings

In [None]:
# knee-based selection of the strongest loadings of every PCA component
pca_features = important_loadings(components=pca_mdl.components_,
                                  feature_names=pca_mdl.feature_names_in_,
                                  feature_types=modality_features,
                                  comp_prefix='PCA')

In [None]:
# persist the selected PCA loadings for this cell type
out_file = f'{results_dir}/{project}.{prefix_type}.{cell_type}.pca_loadings.pkl'
save_important_loadings(comp_features=pca_features, file_name=out_file)

### using NMF

#### find number of components to use

In [None]:
# score NMF fits for 1..max_count components (max_count comes from the PCA section)
r2_values, rmse_values = iterate_model_component_counts(max_count=max_count, data_df=quants_df,
                                                        model_type='NMF')

In [None]:
# locate the knee of each reconstruction metric and keep the larger component count
knee_rmse = component_from_max_curve(scores=rmse_values, label='RMSE')
knee_r2 = component_from_max_curve(scores=r2_values, label='R2')
num_comp = max((knee_rmse, knee_r2))
print(num_comp)

#### regenerate the NMF model at the selected component size

In [None]:
# refit NMF at the selected size; the returned scores are discarded
nmf_mdl,nmf_df,_,_ = generate_selected_model(n_comps=num_comp, data_df=quants_df,
                                             model_type='NMF')
if DEBUG:
    print(nmf_df.shape)
    display(nmf_df.head())

#### what are the 'important features' based on their loadings

In [None]:
# knee-based selection of the strongest loadings of every NMF component
nmf_features = important_loadings(components=nmf_mdl.components_,
                                  feature_names=nmf_mdl.feature_names_in_,
                                  feature_types=modality_features,
                                  comp_prefix='NMF')

In [None]:
# persist the selected NMF loadings for this cell type
out_file = f'{results_dir}/{project}.{prefix_type}.{cell_type}.nmf_loadings.pkl'
save_important_loadings(comp_features=nmf_features, file_name=out_file)

### using ICA

#### find number of components to use

In [None]:
# score ICA fits for 1..max_count components (max_count comes from the PCA section)
r2_values, rmse_values = iterate_model_component_counts(max_count=max_count, data_df=quants_df,
                                                        model_type='ICA')

In [None]:
# locate the knee of each reconstruction metric and keep the larger component count
knee_rmse = component_from_max_curve(scores=rmse_values, label='RMSE')
knee_r2 = component_from_max_curve(scores=r2_values, label='R2')
num_comp = max((knee_rmse, knee_r2))
print(num_comp)

#### regenerate ICA model at the selected component size

In [None]:
# refit ICA at the selected size; the returned scores are discarded
ica_mdl,ica_df,_,_ = generate_selected_model(n_comps=num_comp, data_df=quants_df,
                                             model_type='ICA')
if DEBUG:
    print(ica_df.shape)
    display(ica_df.head())

#### what are the 'important features' based on their loadings

In [None]:
# knee-based selection of the strongest loadings of every ICA component
ica_features = important_loadings(components=ica_mdl.components_,
                                  feature_names=ica_mdl.feature_names_in_,
                                  feature_types=modality_features,
                                  comp_prefix='ICA')

In [None]:
# persist the selected ICA loadings for this cell type
out_file = f'{results_dir}/{project}.{prefix_type}.{cell_type}.ica_loadings.pkl'
save_important_loadings(comp_features=ica_features, file_name=out_file)

### compare the different latent embedding
approximately assess whether the different methods are finding a similar latent space across the samples

In [None]:
# place the PCA, NMF and ICA embeddings side by side, aligned on their index
embeddings = [pca_df, nmf_df, ica_df]
latent_df = concat(embeddings, axis=1)
print(f'shape of combined latent space df {latent_df.shape}')
if DEBUG:
    display(latent_df.head())

In [None]:
from seaborn import heatmap
# correlations whose absolute value does not exceed this are left blank in the heatmap
min_pearson = 0.22
cor = latent_df.corr(method='pearson')
# drop rows that are entirely NaN
cor.dropna(how='all', inplace=True)
modified_title = ''
print(cor.shape)
# size each figure dimension from its own axis with a floor of 12; the height
# used to be gated on the column count, and rows can differ from columns after
# the dropna above
fig_width = max(cor.shape[1], 12)
fig_height = max(cor.shape[0], 12)
with rc_context({'figure.figsize': (fig_width, fig_height), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')       
    ax = heatmap(cor[(cor > min_pearson) | (cor < -min_pearson)], annot=True, 
            annot_kws={"fontsize":10}, linewidths=0.05, cmap='Purples')
    # the plotted matrix is the correlation between the latent components
    plt.title(f'Pearson heatmap of latent components {modified_title}')
    plt.yticks(rotation=90)
    plt.show()

In [None]:
%%time
import torch
# use the GPU for the embedding when torch reports one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# build the path from `project` like the other input paths in this notebook
info_file = f'{wrk_dir}/sample_info/{project}.sample_info.csv'
info_df = read_csv(info_file, index_col=0)
# align the sample info rows with the rows of the ICA embedding
info_df = info_df.reindex(ica_df.index)
print(info_df.shape)
display(info_df.head())

import pymde
pymde.seed(42)
# neighbor-preserving embedding of the ICA components, colored by age
mde = pymde.preserve_neighbors(ica_df.to_numpy(), device=device, verbose=True)
embedding = mde.embed(verbose=True)
pymde.plot(embedding, color_by=info_df.age, marker_size=50)
plt.show()

### using a VAE

In [None]:
import torch
# detect whether torch reports a CUDA device
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): this unconditionally overrides the detected device, so the cell
# always reports 'cpu' - confirm the override is intentional
device = 'cpu'
print(device)

In [None]:
# from pythae.models import VAE, VAEConfig
# from pythae.trainers import BaseTrainerConfig
# from pythae.pipelines.training import TrainingPipeline
# from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST

In [None]:
# import torchvision.datasets as datasets
# mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.
# eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255.

In [None]:
# config = BaseTrainerConfig(
#     output_dir='my_model',
#     learning_rate=1e-4,
#     per_device_train_batch_size=64,
#     per_device_eval_batch_size=64,
#     num_epochs=10, # Change this to train the model a bit more
#     optimizer_cls="AdamW",
#     optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.99)}
# )


# model_config = VAEConfig(
#     input_dim=(1, 28, 28),
#     latent_dim=16
# )

# model = VAE(
#     model_config=model_config,
#     encoder=Encoder_ResNet_VAE_MNIST(model_config), 
#     decoder=Decoder_ResNet_AE_MNIST(model_config)     
# )

In [None]:
# pipeline = TrainingPipeline(
#     training_config=config,
#     model=model
# )

In [None]:
# pipeline(
#     train_data=train_dataset,
#     eval_data=eval_dataset
# )

In [None]:
# import os
# from pythae.models import AutoModel

In [None]:
# last_training = sorted(os.listdir('my_model'))[-1]
# trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model'))

In [None]:
# from pythae.samplers import NormalSampler

In [None]:
# # create normal sampler
# normal_samper = NormalSampler(
#     model=trained_model
# )

In [None]:
# # sample
# gen_data = normal_samper.sample(
#     num_samples=25
# )

In [None]:
# # show results with normal sampler
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# for i in range(5):
#     for j in range(5):
#         axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')
#         axes[i][j].axis('off')
# plt.tight_layout(pad=0.)

In [None]:
# from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig

In [None]:
# # set up GMM sampler config
# gmm_sampler_config = GaussianMixtureSamplerConfig(
#     n_components=10
# )

# # create gmm sampler
# gmm_sampler = GaussianMixtureSampler(
#     sampler_config=gmm_sampler_config,
#     model=trained_model
# )

# # fit the sampler
# gmm_sampler.fit(train_dataset)

In [None]:
# # sample
# gen_data = gmm_sampler.sample(
#     num_samples=25
# )

In [None]:
# # show results with gmm sampler
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# for i in range(5):
#     for j in range(5):
#         axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')
#         axes[i][j].axis('off')
# plt.tight_layout(pad=0.)

In [None]:
# reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()

In [None]:
# # show reconstructions
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# for i in range(5):
#     for j in range(5):
#         axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')
#         axes[i][j].axis('off')
# plt.tight_layout(pad=0.)

In [None]:
# # show the true data
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# for i in range(5):
#     for j in range(5):
#         axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')
#         axes[i][j].axis('off')
# plt.tight_layout(pad=0.)

In [None]:
# interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()

In [None]:
# # show interpolations
# fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))

# for i in range(5):
#     for j in range(10):
#         axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')
#         axes[i][j].axis('off')
# plt.tight_layout(pad=0.)