To work as intended:

Data must be stored in the data/{dataset} folder

Data must be prepared from the h5ad files using the approach shown in the preprocessing.ipynb notebook

A Neptune personal API token must be set up before running (alternatively, another logger can be used)

Model config should be prepared carefully

# Run First (imports, the model and functions)

## Imports

In [13]:
# Data manipulation and preparation
import pandas as pd
import numpy as np
from math import isinf
import itertools
import yaml
import csv
import os

# Visualization
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm_notebook as tqdm
from tabulate import tabulate

# Machine learning and statistics
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model import ElasticNet, Lasso, Ridge

# Dimensionality reduction
from umap import UMAP
from MulticoreTSNE import MulticoreTSNE as TSNE

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CyclicLR, LambdaLR

# Miscellaneous
from copy import deepcopy
from itertools import product
import warnings

# Scanpy for single-cell analysis
import scanpy as sc

# Neptune for experiment tracking
import neptune.new as neptune

# Suppress warnings
warnings.filterwarnings('ignore')

## The Model

In [14]:
class Decoder(nn.Module):
    """Two-layer MLP mapping a latent vector back to feature space.

    Layout: Linear(latent_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> features_dim).
    """

    def __init__(self, latent_dim, hidden_dim, features_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, features_dim)

    def forward(self, z):
        """Return the reconstruction for the latent batch ``z``."""
        hidden = self.relu(self.linear1(z))
        return self.linear2(hidden)
    
    
class Adversary(nn.Module):
    """MLP that predicts every confounder from the latent embedding.

    A single output layer of width ``sum(confounder_dims)`` is produced and
    then cut into one chunk per confounder.
    """

    def __init__(self, latent_dim, hidden_dim, confounder_dims):
        super().__init__()
        self.confounder_dims = confounder_dims
        total_outputs = sum(confounder_dims)
        self.layers = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, total_outputs),
        )

    def forward(self, z):
        """Return a tuple with one output tensor per entry of ``confounder_dims``."""
        stacked = self.layers(z)
        # Split along the feature axis into per-confounder blocks.
        return torch.split(stacked, self.confounder_dims, dim=1)
    
    
class Predictor(nn.Module):
    """MLP head producing one scalar per sample from the latent embedding.

    Layout: Linear -> LeakyReLU -> Dropout(dropout) -> Linear(hidden_dim -> 1).
    """

    def __init__(self, latent_dim, hidden_dim, dropout=0.5):
        super().__init__()
        stack = [
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, z):
        """Return the prediction for ``z`` with a trailing dimension of size 1."""
        prediction = self.layers(z)
        return prediction
    
# Class for evaluating Weighted Loss
class WeightedMAELoss(nn.Module):
    """Mean of the element-wise absolute error scaled by per-element weights."""

    def forward(self, predictions, targets, weights):
        """Return ``mean(weights * |predictions - targets|)`` as a scalar tensor."""
        abs_error = (predictions - targets).abs()
        return (weights * abs_error).mean()


# Main Model 
class ADAE(pl.LightningModule):
    """Autoencoder trained together with a confounder adversary and an age predictor.

    The module owns four sub-networks: ``encoder`` (features -> latent),
    ``decoder`` (latent -> features), ``adversary`` (latent -> confounder
    logits) and ``predictor`` (latent -> age). ``self.mode`` selects what is
    computed and which networks are optimised:

    * ``training_step`` / ``validation_step``: ``'ae'``, ``'adv'``, ``'age'``,
      ``'full'``, ``'triple'``, ``'aeage'``.
    * ``forward``: ``'ae'``, ``'encode'``, ``'adv'``, ``'age'``, ``'full'``,
      ``'aeage'``.

    Any other mode raises ``NotImplementedError``. Optimisation is manual
    (``automatic_optimization = False``) with one Adam optimiser and one
    exponential-decay scheduler per network group.

    Batches are expected as ``(x, age, age_weights, *y)`` where ``y`` holds one
    target tensor per confounder.
    """

    # Ceiling, as a multiple of the reconstruction loss, applied to the scaled
    # auxiliary losses when validating in 'triple' mode.
    VAL_CLIP_FACTOR = 50

    def __init__(self,
                 features_dim, hidden_dim, latent_dim,
                 confounder_dims, n_cat_conf,
                 reg_constant,
                 age_coef,
                 cat_weights,
                 age_pred_dropout,
                 encoder_dropout,
                 lr,
                 plot_dir,
                 starting_epoch,
                 plot_every_e,
                 eval_every_e,
                 clip_loss,
                 epoch_for_decay_total,
                 steps_per_epoch,
                 lr_decrease_factor,
                 plot_features,
                 model_mode,
                 model_dir,
                 l1_reg_coef=0,
                 l2_reg_coef=1e-4,
                 model_eval_func=None,
                 save_model_interval=10):
        """Build the sub-networks and store the training configuration.

        Args:
            features_dim: Width of the input / reconstructed feature vector.
            hidden_dim: Hidden width shared by all sub-networks.
            latent_dim: Width of the embedding.
            confounder_dims: Output width per confounder (adversary split sizes).
            n_cat_conf: Number of categorical confounders used in training losses.
            reg_constant: Multiplier of the adversary loss (stored as ``lamb``).
            age_coef: Multiplier of the age loss.
            cat_weights: Per-confounder class weights for the cross-entropy.
            age_pred_dropout: Dropout probability of the predictor.
            encoder_dropout: Dropout probability applied to the encoder input.
            lr: Mapping with the keys ``'ae'``, ``'adv'`` and ``'age'``.
            plot_dir, starting_epoch, plot_every_e, plot_features: Stored on the
                instance; not read by any method of this class.
            eval_every_e: Call ``model_eval_func`` every this many epochs.
            clip_loss: Ceiling of the scaled auxiliary losses in training, as a
                multiple of the reconstruction loss.
            epoch_for_decay_total, steps_per_epoch: Their product is the number
                of scheduler steps over which the LR factor decays to
                ``lr_decrease_factor``.
            lr_decrease_factor: LR multiplier reached after that many steps.
            model_mode: Initial value of ``self.mode``.
            model_dir: Directory receiving the periodic ``epoch_*.pth`` files.
            l1_reg_coef: Coefficient of the L1 penalty on predictor parameters.
            l2_reg_coef: ``weight_decay`` of all three Adam optimisers.
            model_eval_func: Optional callable ``f(model, epoch)`` returning
                ``(train_mae, test_mae, r2)``; skipped when ``None``.
            save_model_interval: Save the state dict every this many epochs.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model_eval_func = model_eval_func
        self.mode = model_mode
        self.weights = [torch.tensor(item) for item in cat_weights]
        self.lamb = reg_constant
        self.confounder_dims = confounder_dims
        self.n_cat_conf = n_cat_conf
        self.output_adv_dim = sum(confounder_dims)
        self.age_coef = age_coef
        self.plot_features = plot_features
        self.plot_dir = plot_dir
        self.starting_epoch = starting_epoch
        self.plot_every_e = plot_every_e
        self.eval_every_e = eval_every_e
        self.clip_loss = clip_loss
        self.epoch_for_decay_total = epoch_for_decay_total
        self.steps_per_epoch = steps_per_epoch
        self.model_dir = model_dir
        self.lr_decrease_factor = lr_decrease_factor
        self.l1_reg_coef = l1_reg_coef
        self.l2_reg_coef = l2_reg_coef
        self.save_model_interval = save_model_interval

        self.lr = {
            'ae': lr['ae'],
            'adv': lr['adv'],
            'age': lr['age'],
        }

        self.lr_decrease = lr_decrease_factor

        # Data to Embedding
        self.encoder = nn.Sequential(
            #nn.BatchNorm1d(features_dim),
            nn.Dropout(p=encoder_dropout),
            nn.Linear(features_dim, hidden_dim),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_dim, latent_dim))

        # Embedding to Data
        self.decoder = Decoder(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            features_dim=features_dim
        )

        # Embedding to age prediction
        self.predictor = Predictor(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            dropout=age_pred_dropout
        )

        # Emb to confounder prediction
        self.adversary = Adversary(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            confounder_dims=confounder_dims,)

        self.automatic_optimization = False
        self.counter = 0

    class WeightedMAELoss(nn.Module):
        """Weighted mean absolute error; kept for ``ADAE.WeightedMAELoss`` access."""

        def __init__(self):
            # Zero-argument super(): the two-argument form named the
            # module-level class of the same name and raised TypeError.
            super().__init__()

        def forward(self, predictions, targets, weights):
            return torch.mean(weights * torch.abs(predictions - targets))

    # ------------------------------------------------------------------ helpers

    def _ae_parameters(self):
        """Return encoder and decoder parameters as one list."""
        return list(self.encoder.parameters()) + list(self.decoder.parameters())

    def _categorical_adv_loss(self, y_hat, y, n_confounders):
        """Sum the class-weighted cross-entropy of the first ``n_confounders``.

        The cross-entropy is evaluated on CPU, as in the original code, and the
        result is moved back to the device the logits live on instead of a
        hard-coded ``'mps'``, so CUDA and CPU runs work as well.
        """
        losses = []
        for i in range(n_confounders):
            home_device = y_hat[i].device
            ce = F.cross_entropy(
                y_hat[i].to('cpu'),
                y[i].to('cpu'),
                weight=self.weights[i].to('cpu'),
            )
            losses.append(ce.to(home_device))
        # TODO: add losses for continuous confounders
        return sum(losses)

    def _optimizer_step(self, optimizer, scheduler, loss, params, retain_graph=False):
        """Zero grads, backpropagate ``loss``, clip ``params`` to norm 1, step both."""
        optimizer.zero_grad()
        if retain_graph:
            self.manual_backward(loss, retain_graph=True)
        else:
            self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        optimizer.step()
        scheduler.step()

    @staticmethod
    def _clip_to(value, ceiling):
        """Return ``ceiling`` when ``value`` exceeds it, otherwise ``value``."""
        return ceiling if value > ceiling else value

    # ------------------------------------------------------------------ forward

    def forward(self, x):
        """Encode ``x`` and return the output selected by ``self.mode``.

        Raises:
            NotImplementedError: for a mode not handled here.
        """
        emb = self.encoder(x)
        if self.mode == 'ae':
            return self.decoder(emb)
        if self.mode == 'encode':
            return emb
        if self.mode == 'adv':
            return self.adversary(emb)
        if self.mode == 'age':
            return self.predictor(emb)
        if self.mode == 'full':
            return self.decoder(emb), self.adversary(emb)
        if self.mode == 'aeage':
            return self.decoder(emb), self.predictor(emb)
        raise NotImplementedError()

    def configure_optimizers(self):
        """Create one Adam optimiser and one ``LambdaLR`` scheduler per group.

        The LR multiplier at scheduler step ``s`` is ``(1 - decay_rate) ** s``,
        which equals ``lr_decrease_factor`` after
        ``epoch_for_decay_total * steps_per_epoch`` steps.
        """
        ae_params = self._ae_parameters()
        adv_params = list(self.adversary.parameters())
        age_params = list(self.predictor.parameters())

        final_lr = self.lr_decrease_factor

        num_epochs = self.epoch_for_decay_total
        num_steps_per_epoch = self.steps_per_epoch
        num_steps = num_epochs * num_steps_per_epoch

        decay_rate = 1 - (final_lr) ** (1 / num_steps)

        def exp_decay(step):
            return (1 - decay_rate) ** step

        opt_ae = torch.optim.Adam(ae_params, lr=self.lr['ae'], weight_decay=self.l2_reg_coef)
        opt_ad = torch.optim.Adam(adv_params, lr=self.lr['adv'], weight_decay=self.l2_reg_coef)
        opt_age = torch.optim.Adam(age_params, lr=self.lr['age'], weight_decay=self.l2_reg_coef)

        scheduler_ae = torch.optim.lr_scheduler.LambdaLR(opt_ae, lr_lambda=exp_decay)
        scheduler_ad = torch.optim.lr_scheduler.LambdaLR(opt_ad, lr_lambda=exp_decay)
        scheduler_age = torch.optim.lr_scheduler.LambdaLR(opt_age, lr_lambda=exp_decay)

        return [opt_ae, opt_ad, opt_age], [scheduler_ae, scheduler_ad, scheduler_age]

    # ----------------------------------------------------------------- training

    def training_step(self, batch, batch_idx):
        """Run one manually-optimised step for the current ``self.mode``.

        In the combined modes the network being updated alternates by epoch
        (``current_epoch % 2`` or ``% 3``); every scheduler advances only when
        its optimiser steps.
        """
        opt_ae, opt_ad, opt_age = self.optimizers()
        sch_ae, sch_ad, sch_age = self.lr_schedulers()

        current_epoch = self.current_epoch

        x, age, age_weights, *y = batch
        emb = self.encoder(x)

        # Module-level weighted MAE (named "mse_loss" in earlier revisions).
        weighted_mae = WeightedMAELoss()

        if self.mode == 'ae':
            x_hat = self.decoder(emb)
            loss = F.mse_loss(x_hat, x)

            self._optimizer_step(opt_ae, sch_ae, loss, self._ae_parameters())

            self.log(f"Train: {self.mode}_loss", loss)

        elif self.mode == 'adv':
            y_hat = self.adversary(emb)
            loss = self._categorical_adv_loss(y_hat, y, self.n_cat_conf)

            self._optimizer_step(opt_ad, sch_ad, loss, list(self.adversary.parameters()))

            self.log(f"Train: {self.mode}_loss", loss)

        elif self.mode == 'age':
            age_hat = self.predictor(emb)
            L1_loss = F.l1_loss(age_hat, age)
            loss = weighted_mae(age_hat, age, age_weights)

            self._optimizer_step(opt_age, sch_age, loss, list(self.predictor.parameters()))

            self.log(f"Train: {self.mode}_loss", loss)
            self.log(f"Train: {self.mode}_L1_loss", L1_loss)

        elif self.mode == 'full':
            x_hat = self.decoder(emb)
            y_hat = self.adversary(emb)

            ae_loss = F.mse_loss(x_hat, x)
            adv_loss = self._categorical_adv_loss(y_hat, y, self.n_cat_conf)

            # The autoencoder is rewarded for a *high* adversary loss, capped
            # relative to the reconstruction loss.
            scaled_adv_loss = self._clip_to(adv_loss * self.lamb, ae_loss * self.clip_loss)
            loss = ae_loss - scaled_adv_loss

            if (current_epoch) % 2 == 0:
                self._optimizer_step(opt_ae, sch_ae, loss, self._ae_parameters(), retain_graph=True)
            else:
                self._optimizer_step(opt_ad, sch_ad, adv_loss, list(self.adversary.parameters()))

            self.log(f"Train: {self.mode}_total_loss", loss)
            self.log(f"Train: {self.mode}_adv_loss", adv_loss * self.lamb)
            self.log(f"Train: {self.mode}_ae_loss", ae_loss)

        elif self.mode == 'triple':
            x_hat = self.decoder(emb)
            y_hat = self.adversary(emb)
            age_hat = self.predictor(emb)

            ae_loss = F.mse_loss(x_hat, x)
            adv_loss = self._categorical_adv_loss(y_hat, y, self.n_cat_conf)
            L1_age_loss = F.l1_loss(age_hat, age)
            age_loss = weighted_mae(age_hat, age, age_weights)

            # L1 penalty over all predictor parameters.
            # NOTE(review): this term only enters `loss`, which is stepped by
            # the autoencoder optimiser; the predictor step below uses the
            # unregularised `age_loss` — confirm this is intended.
            l1_age_loss = sum(
                self.l1_reg_coef * param.abs().sum()
                for param in self.predictor.parameters()
            )

            scaled_age_loss = self._clip_to(age_loss * self.age_coef, ae_loss * self.clip_loss)
            scaled_adv_loss = self._clip_to(adv_loss * self.lamb, ae_loss * self.clip_loss)

            age_loss_reg = scaled_age_loss + l1_age_loss

            loss = ae_loss - scaled_adv_loss + age_loss_reg

            if (current_epoch) % 3 == 0:
                self._optimizer_step(opt_ae, sch_ae, loss, self._ae_parameters(), retain_graph=True)
            elif (current_epoch) % 3 == 1:
                self._optimizer_step(opt_ad, sch_ad, adv_loss, list(self.adversary.parameters()))
            else:
                self._optimizer_step(opt_age, sch_age, age_loss, list(self.predictor.parameters()))

            self.log(f"Train: {self.mode}_total_loss", loss)
            self.log(f"Train: {self.mode}_adv_loss", adv_loss * self.lamb)
            self.log(f"Train: {self.mode}_age__weighted_loss", age_loss * self.age_coef)
            self.log(f"Train: {self.mode}_age_l1_weighted", age_loss_reg)
            self.log(f"Train: {self.mode}_age_L1_loss", L1_age_loss)
            self.log(f"Train: {self.mode}_ae_loss", ae_loss)

        elif self.mode == 'aeage':
            x_hat = self.decoder(emb)
            age_hat = self.predictor(emb)

            ae_loss = F.mse_loss(x_hat, x)
            L1_age_loss = F.l1_loss(age_hat, age)
            age_loss = weighted_mae(age_hat, age, age_weights)
            loss = ae_loss + age_loss * self.age_coef

            if (current_epoch) % 2 == 0:
                self._optimizer_step(opt_ae, sch_ae, loss, self._ae_parameters())
            else:
                self._optimizer_step(opt_age, sch_age, age_loss, list(self.predictor.parameters()))

            self.log(f"Train: {self.mode}_total_loss", loss)
            self.log(f"Train: {self.mode}_age__weighted_loss", age_loss * self.age_coef)
            self.log(f"Train: {self.mode}_age_L1_loss", L1_age_loss)
            self.log(f"Train: {self.mode}_ae_loss", ae_loss)

        else:
            raise NotImplementedError()

    # --------------------------------------------------------------- validation

    def validation_step(self, batch, batch_idx):
        """Log validation losses and, on the first batch, save / evaluate.

        Only the first confounder is scored by the adversary here: using
        ``self.n_cat_conf`` requires the categories to be identical between
        the training and validation data.
        """
        x, age, age_weights, *y = batch

        emb = self.encoder(x)
        if self.mode == 'ae':
            x_hat = self.decoder(emb)
            loss = F.mse_loss(x_hat, x)
            self.log(f"Val: {self.mode}_loss", loss)
        elif self.mode == 'adv':
            y_hat = self.adversary(emb)
            loss = self._categorical_adv_loss(y_hat, y, 1)
            self.log(f"Val: {self.mode}_loss", loss)
        elif self.mode == 'age':
            age_hat = self.predictor(emb)
            loss = F.l1_loss(age_hat, age)
            self.log(f"Val: {self.mode}_loss", loss)
        elif self.mode == 'full':
            x_hat = self.decoder(emb)
            y_hat = self.adversary(emb)
            ae_loss = F.mse_loss(x_hat, x)
            adv_loss = self._categorical_adv_loss(y_hat, y, 1)
            loss = ae_loss - adv_loss * self.lamb
            self.log(f"Val: {self.mode}_total_loss", loss)
            self.log(f"Val: {self.mode}_adv_loss", adv_loss * self.lamb)
            self.log(f"Val: {self.mode}_ae_loss", ae_loss)
        elif self.mode == 'triple':
            x_hat = self.decoder(emb)
            y_hat = self.adversary(emb)
            age_hat = self.predictor(emb)
            ae_loss = F.mse_loss(x_hat, x)

            adv_loss = self._categorical_adv_loss(y_hat, y, 1)
            age_loss = F.l1_loss(age_hat, age)

            ceiling = ae_loss * self.VAL_CLIP_FACTOR
            scaled_age_loss = self._clip_to(age_loss * self.age_coef, ceiling)
            # Previously the age clamp was repeated here and the adversary
            # term was left uncapped; cap it as the training step does.
            scaled_adv_loss = self._clip_to(adv_loss * self.lamb, ceiling)

            loss = ae_loss - scaled_adv_loss + scaled_age_loss
            self.log(f"Val: {self.mode}_total_loss", loss)
            self.log(f"Val: {self.mode}_adv_loss", adv_loss * self.lamb)
            self.log(f"Val: {self.mode}_age_loss", scaled_age_loss)
            self.log(f"Val: {self.mode}_age_real_loss", age_loss)
            self.log(f"Val: {self.mode}_ae_loss", ae_loss)
        elif self.mode == 'aeage':
            x_hat = self.decoder(emb)
            age_hat = self.predictor(emb)
            ae_loss = F.mse_loss(x_hat, x)
            age_loss = F.l1_loss(age_hat, age)
            loss = ae_loss + age_loss * self.age_coef
            self.log(f"Val: {self.mode}_total_loss", loss)
            self.log(f"Val: {self.mode}_age_loss", age_loss * self.age_coef)
            self.log(f"Val: {self.mode}_age_real_loss", age_loss)
            self.log(f"Val: {self.mode}_ae_loss", ae_loss)
        else:
            raise NotImplementedError()

        # Periodic work runs once per validation pass, on its first batch.
        if batch_idx != 0:
            return

        current_epoch = self.current_epoch

        if (current_epoch + 1) % self.save_model_interval == 0:
            # Make sure the target directory exists before writing.
            os.makedirs(self.model_dir, exist_ok=True)
            model_path = f"{self.model_dir}/epoch_{current_epoch}.pth"
            torch.save(self.state_dict(), model_path)

        # model_eval_func is optional (defaults to None), so guard the call.
        if (current_epoch + 1) % self.eval_every_e == 0 and self.model_eval_func is not None:
            # Use the externally provided function to evaluate the model.
            self_mae, mae, r2 = self.model_eval_func(self, current_epoch)

            self.log(f"Lasso_train_MAE", float(self_mae))
            self.log(f"Lasso_test_MAE", float(mae))
            self.log(f"Lasso_test_R2", float(r2))

## Functions

In [15]:
class GeneExpressionData(Dataset):
    """Dataset pairing expression rows with age, age weight and confounders.

    Each item is the flat tuple
    ``(expr, age, age_weight, *categorical_one_hots, *numerical_values)``.

    Args:
        expr: Dense 2-D expression matrix (one row per sample).
        metainfo: DataFrame with an ``'age'`` column plus every column named
            in ``cat_features`` / ``num_features``; rows are matched to
            ``expr`` by position.
        age_weight_dict: Mapping from age value to its sample weight; a
            missing age raises ``KeyError``.
        cat_features: Columns one-hot encoded with ``pd.get_dummies``. The
            encoding is derived from ``metainfo`` alone, so two datasets only
            share a layout when they contain the same categories.
        num_features: Columns passed through as float scalars.
    """

    def __init__(self, expr, metainfo, age_weight_dict, cat_features=('sex',), num_features=()):
        # Immutable defaults replace the former shared list defaults.
        self.dataset = expr
        n_samples = len(self.dataset)

        categorical_oh = [
            torch.tensor(pd.get_dummies(metainfo[c]).to_numpy(), dtype=torch.float32)
            for c in cat_features
        ]
        numerical = [
            torch.tensor(metainfo[c].to_numpy()).to(torch.float)
            for c in num_features
        ]

        # Column vectors of shape (n_samples, 1) so each item is a 1-element tensor.
        self.age = torch.tensor(metainfo['age'].to_numpy(), dtype=torch.float32).reshape(-1, 1)
        self.age_weights = torch.tensor(
            [[age_weight_dict[age]] for age in metainfo['age']],
            dtype=torch.float32,
        )

        # Convert the whole matrix once (copying) instead of once per row.
        expr_tensor = torch.tensor(np.asarray(self.dataset), dtype=torch.float)
        self.expr = [expr_tensor[idx] for idx in range(n_samples)]
        self.cat = [[feature[idx] for feature in categorical_oh] for idx in range(n_samples)]
        self.num = [[feature[idx] for feature in numerical] for idx in range(n_samples)]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return (self.expr[idx], self.age[idx], self.age_weights[idx], *self.cat[idx], *self.num[idx])
    

# Load pretrained weights into the given model.
# Use ae_only=True when only the encoder/decoder weights should be restored
# (for example, when the number of confounders has changed).
def load_pretrained(model: nn.Module, pretrained_path: str, ae_only=False):
    """Load a saved state dict from ``pretrained_path`` into ``model`` in place.

    Args:
        model: Module to update; with ``ae_only=True`` it must expose
            ``encoder`` and ``decoder`` sub-modules.
        pretrained_path: Path of a file written with ``torch.save(state_dict)``.
        ae_only: When true, restore only the ``encoder.*`` and ``decoder.*``
            entries and leave every other parameter untouched.
    """
    # Deserialize to CPU so a checkpoint written on another accelerator still
    # loads here; load_state_dict copies values into the model's own tensors.
    pretrained_dict = torch.load(pretrained_path, map_location='cpu')

    if not ae_only:
        model.load_state_dict(pretrained_dict)
        return

    for submodule_name in ('encoder', 'decoder'):
        prefix = f'{submodule_name}.'
        # Match on the exact prefix and strip exactly that prefix.
        sub_state = {
            key[len(prefix):]: value
            for key, value in pretrained_dict.items()
            if key.startswith(prefix)
        }
        getattr(model, submodule_name).load_state_dict(sub_state)

# Get loaders and datasets for the given features and batch size
def get_loaders(X_train, X_test, y_train, y_test, age_weights_dict, cat_features=(), num_features=(), batch_size=64):
    """Build the train/validation datasets and their data loaders.

    Args:
        X_train, X_test: Expression matrices for training and validation.
        y_train, y_test: Matching metadata frames.
        age_weights_dict: Mapping from age to sample weight, forwarded to
            ``GeneExpressionData``.
        cat_features: Categorical metadata columns (default: none).
        num_features: Numerical metadata columns (default: none).
        batch_size: Batch size of both loaders.

    Returns:
        ``(train_loader, val_loader, train_dataset, val_dataset)``. Only the
        training loader shuffles.
    """
    # Immutable defaults replace the former shared list defaults; the dataset
    # still receives plain lists.
    cat_features = list(cat_features)
    num_features = list(num_features)

    train_dataset = GeneExpressionData(
        X_train, y_train, age_weights_dict,
        cat_features=cat_features, num_features=num_features,
    )
    val_dataset = GeneExpressionData(
        X_test, y_test, age_weights_dict,
        cat_features=cat_features, num_features=num_features,
    )

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        pin_memory=False,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)

    return train_loader, val_loader, train_dataset, val_dataset


def tt_split_setup(full_X, full_y, test_size=0.20, random_state=42, donor_test_list=None):
    """Split the data into train and test parts.

    Args:
        full_X: Feature matrix, row-aligned with ``full_y``.
        full_y: Metadata frame; needs a ``donor`` column when
            ``donor_test_list`` is given.
        test_size: Test fraction for the random split.
        random_state: Seed for the random split.
        donor_test_list: When given, every sample whose donor is in this list
            goes to the test set and all others to the training set;
            ``test_size`` and ``random_state`` are then unused.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.
    """
    if donor_test_list is None:
        # Forward the caller's settings; they were previously ignored in
        # favour of hard-coded 0.20 / 42 (which remain the defaults).
        return train_test_split(full_X, full_y, test_size=test_size, random_state=random_state)

    # Compute the membership mask once and reuse it for all four selections.
    in_test = full_y.donor.isin(donor_test_list)
    X_train = full_X[~in_test]
    y_train = full_y[~in_test]
    X_test = full_X[in_test]
    y_test = full_y[in_test]
    return X_train, X_test, y_train, y_test


def load_data(dataset, tissue, filtered=True, normalize=False, verbose=True):
    """Load expression matrix, metadata and gene names for one tissue.

    Files are read from ``data/{dataset}``; the ``filtered`` flag selects the
    ``*_filtered*`` or the ``*_full*`` arrays.

    Args:
        dataset: Sub-folder of ``data``.
        tissue: File-name prefix.
        filtered: Pick the filtered (True) or the full (False) arrays.
        normalize: Apply ``np.log1p`` to the matrix.
        verbose: Print progress information.

    Returns:
        ``(X, y, gene_list)`` where ``y`` is a deep copy of the metadata.
    """
    data_dir = f'data/{dataset}'
    df_type = 'filtered' if filtered else 'full'

    gene_list = np.load(f'{data_dir}/{tissue}_{df_type}_genes.npy')
    X = np.load(f'{data_dir}/{tissue}_{df_type}.npy')
    y = deepcopy(pd.read_csv(f'{data_dir}/{tissue}_meta.csv', index_col=0))

    if normalize:
        X = np.log1p(X)

    if verbose:
        if normalize:
            print('Data normalized.')
        print(f'Data shape: {X.shape}')

    return (X, y, gene_list)


# Load data and align its genes (features) with those of another tissue, e.g. load
# cross-tissue validation data in the training data's layout so the model can be used.
def load_data_match_genes(dataset, tissue, match_genes_with_tissue, filtered=True, verbose=True, normalize=False):
    """Load one tissue and re-arrange its columns to another tissue's gene order.

    Output column ``i`` holds the values of gene ``i`` of
    ``match_genes_with_tissue``; genes absent from ``tissue`` stay zero.

    Returns:
        ``(new_data, y, gene_list)`` with ``new_data`` as float32 and ``y`` a
        deep copy of the metadata.

    NOTE(review): ``gene_list`` is the gene list of ``tissue`` itself, not the
    reference list that the columns of ``new_data`` follow — confirm callers
    expect that.
    """
    data_dir = f'data/{dataset}'
    df_type = 'filtered' if filtered else 'full'

    reference_genes = np.load(f'{data_dir}/{match_genes_with_tissue}_{df_type}_genes.npy')
    gene_list = np.load(f'{data_dir}/{tissue}_{df_type}_genes.npy')
    X = np.load(f'{data_dir}/{tissue}_{df_type}.npy')
    y = deepcopy(pd.read_csv(f'{data_dir}/{tissue}_meta.csv', index_col=0))

    # Gene name -> source column (the last occurrence wins for duplicates).
    source_column = {}
    for column, gene in enumerate(gene_list):
        source_column[gene] = column

    # Collect (target column, source column) pairs for the shared genes and
    # copy all of them in a single indexed assignment.
    targets = []
    sources = []
    for column, gene in enumerate(reference_genes):
        if gene in source_column:
            targets.append(column)
            sources.append(source_column[gene])

    new_data = np.zeros((X.shape[0], len(reference_genes)), dtype=np.float32)
    if targets:
        new_data[:, targets] = X[:, sources]

    n_missing_genes = len(set(reference_genes) - set(gene_list))

    if normalize:
        new_data = np.log1p(new_data)
        if verbose:
            print('Data normalized.')

    if verbose:
        print(f'Data shape: {X.shape}')
        print("New Dataset Shape:", new_data.shape)
        print(f"Number of missing genes: {n_missing_genes} ({n_missing_genes/len(reference_genes)*100:.1f}%)")

    return (new_data, y, gene_list)


### TUNING FUNCS ###

def save_metrics(scenario, hyperparams, epoch, metrics, header=False):
    """Append one row of metrics to the experiment's ``metrics.csv``.

    The file lives in ``experiments/{scenario}/hyperparams_{hyperparams}``,
    which is created on demand. With ``header=True`` a header row
    (``Epoch`` plus the metric names) is written first, but only while the
    file is still empty.
    """
    metrics_dir = f"experiments/{scenario}/hyperparams_{hyperparams}"
    os.makedirs(metrics_dir, exist_ok=True)

    with open(f"{metrics_dir}/metrics.csv", "a", newline="") as handle:
        rows = []
        # Position 0 in append mode means nothing has been written yet.
        if header and handle.tell() == 0:
            rows.append(["Epoch", *metrics.keys()])
        rows.append([epoch, *metrics.values()])
        csv.writer(handle).writerows(rows)



### TRAINING FUNCS ###

def run_model(model, train_loader, val_loader, neptune_logger, callbacks, epochs, mode, full_epochs, scenario):
    """Train ``model`` in the given ``mode`` and save its final state dict.

    Args:
        model: Module to train; its ``mode`` attribute is set to ``mode``.
        train_loader, val_loader: Data loaders passed to ``trainer.fit``.
        neptune_logger: Logger handed to the trainer.
        callbacks: Trainer callbacks.
        epochs: ``max_epochs`` of this run.
        mode: Training mode to activate on the model.
        full_epochs: Used only in the output file name.
        scenario: Used only in the output file name.

    The state dict is written with ``torch.save`` to
    ``models/mice/2024_{scenario}_{mode}_{full_epochs}.joblib``.
    """
    # Pick the accelerator: Apple MPS first, then CUDA, otherwise CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model.mode = mode
    trainer = pl.Trainer(accelerator=device,
                        devices=1,
                        #gpus=1,
                        precision=16,
                        #limit_train_batches=0.5,
                        enable_checkpointing=True,
                        max_epochs=epochs,
                        logger=neptune_logger,
                        #gradient_clip_val=1.0,
                        callbacks=callbacks,
                        log_every_n_steps=10,
                        num_sanity_val_steps=0
                        )
    trainer.fit(model, train_loader, val_loader)

    # Create the output folder first so a finished run is not lost to a
    # missing directory.
    output_dir = "models/mice"
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), f"{output_dir}/2024_{scenario}_{mode}_{full_epochs}.joblib")

# Write a configuration mapping to a YAML file
def save_config_to_yaml(config, file_path):
    """Dump ``config`` to ``file_path`` in block (non-flow) YAML style."""
    with open(file_path, 'w') as handle:
        yaml.dump(config, handle, default_flow_style=False)

def load_config(config_path):
    """Read the YAML file at ``config_path`` and return the parsed object.

    NOTE: parsing uses ``yaml.FullLoader``; only use it on trusted files.
    """
    with open(config_path, 'r') as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)


def generate_variations(param_values):
    """Expand a mapping of parameter -> candidate values into a full grid.

    Returns a list with one dict per element of the Cartesian product of the
    candidate lists, in ``itertools.product`` order.
    """
    names = list(param_values.keys())
    return [
        dict(zip(names, combo))
        for combo in itertools.product(*param_values.values())
    ]


def create_hyperparameter_variations(scenario, base_hyperparams, variations, hypername='hyperparameters'):
    """Create one experiment folder per variation of a base configuration.

    The base config is read from
    ``experiments/{scenario}/{base_hyperparams}/config.yaml``. For every
    variation a folder ``experiments/{scenario}/{name}`` is created, where
    ``name`` is ``{hypername}_{i}`` (``i`` starting at 1) or, when
    ``hypername`` is None, the variation's ``'varname'`` entry. Each folder
    receives a ``config.yaml`` (base config overridden by the variation, with
    ``model_dir`` / ``plot_dir`` pointing inside the folder), the ``models``
    and ``plots`` sub-folders and an empty ``metrics.csv`` (an existing one
    is truncated).
    """
    base_config = load_config(f"experiments/{scenario}/{base_hyperparams}/config.yaml")

    for index, variation in enumerate(variations, start=1):
        folder_name = f"{hypername}_{index}" if hypername is not None else variation['varname']
        target_dir = f"experiments/{scenario}/{folder_name}"
        os.makedirs(target_dir, exist_ok=True)

        models_dir = f"{target_dir}/models"
        plots_dir = f"{target_dir}/plots"

        # Base values first, then the variation, then the folder-specific paths.
        new_config = base_config.copy()
        new_config.update(variation)
        new_config['model_dir'] = models_dir
        new_config['plot_dir'] = plots_dir
        save_config_to_yaml(new_config, f"{target_dir}/config.yaml")

        for sub_dir in (models_dir, plots_dir):
            os.makedirs(sub_dir, exist_ok=True)

        # Opening in write mode creates (or empties) the metrics file.
        open(f"{target_dir}/metrics.csv", 'w').close()


def generate_variations(param_values):
    """Return every combination of the candidate values as a list of dicts.

    NOTE: a function with this name and the same behaviour is already defined
    earlier in this file; this later definition is the one in effect.
    """
    keys = list(param_values.keys())
    variations = []
    for values in itertools.product(*param_values.values()):
        variations.append({key: value for key, value in zip(keys, values)})
    return variations


def get_random_samples(X, y, sample_size):
    """Draw ``sample_size`` row positions with a fixed seed and index X and y.

    Positions come from ``sklearn.utils.resample`` with ``random_state=42``,
    so repeated calls return the same selection.

    NOTE(review): ``y`` is indexed directly with the list of positions, which
    presumably requires an array-like with positional indexing — confirm
    against callers.
    """
    positions = resample(range(len(y)), n_samples=sample_size, random_state=42)
    return X[positions], y[positions]

# class ValidationProgressBarDisabler(Callback):
#     def on_validation_start(self, trainer, pl_module):
#         trainer.progress_bar_callback.disable()

#     def on_validation_end(self, trainer, pl_module):
#         trainer.progress_bar_callback.enable()


def run_trainings(config_paths, X_test_dataset=None, ae_only=False, keyword=None, from_full=False):
    """Train and save one ADAE model for every experiment config file.

    For each config the function loads the data, derives age/category
    weights, writes ``model_config.yaml`` next to the model directory, trains
    with PyTorch Lightning (logging to Neptune) and finally stores the model
    ``state_dict`` as ``<model_dir>/epoch_<epochs_total>.pth``.

    Args:
        config_paths: Iterable of paths to ``config.yaml`` files.
        X_test_dataset: If None, train/test are produced by ``tt_split_setup``
            using ``config['test_data_donors']``. Otherwise the whole loaded
            data is used for training and the test set is loaded from this
            dataset via ``load_data_match_genes``.
        ae_only: Forwarded to ``load_pretrained`` when ``config['basic_model']``
            is not None.
        keyword: If not None, configs whose path does not contain this
            substring are skipped.
        from_full: If True, load the ``'ALL_METHODS'`` data and keep only the
            rows whose ``tissue`` equals ``config['train_tissue']``.
    """
    for config_path in tqdm(config_paths):
        if keyword is not None and keyword not in config_path:
            continue
        print(config_path)
        with open(config_path, "r") as file:
            # NOTE: FullLoader is not meant for untrusted YAML; only point this
            # at locally generated experiment configs.
            config = yaml.load(file, Loader=yaml.FullLoader)

        cat_features = config['cat_features']
        batch_size = config['batch_size']

        # Get cpu or gpu device for training
        if torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {device} device")

        # Data preparation
        load_tissue = 'ALL_METHODS' if from_full else config['train_tissue']
        full_X, full_y, _ = load_data(
            dataset=config['input_dataset'],
            tissue=load_tissue,
            filtered=config['filtered'],
            normalize=False)
        if from_full:
            tissue_mask = full_y['tissue'] == config['train_tissue']
            full_X = full_X[tissue_mask]
            full_y = full_y[tissue_mask]

        if X_test_dataset is None:
            X_train, X_test, y_train, y_test = tt_split_setup(
                full_X, full_y, donor_test_list=config['test_data_donors'])
        else:
            X_train, y_train = full_X, full_y
            X_test, y_test, _ = load_data_match_genes(
                config['input_dataset'], X_test_dataset,
                match_genes_with_tissue=config['train_tissue'], normalize=False)
            # NOTE(review): this draw is resampled again to 4096 rows below;
            # kept as-is to preserve the existing sampling sequence.
            X_test, y_test = resample(X_test, y_test, n_samples=3200, random_state=42)

        del full_X, full_y
        if X_train.shape[0] > 128000:
            X_train, y_train = resample(X_train, y_train, n_samples=128000, random_state=42)
        # Fixed-size subsets used by the periodic evaluation callback below.
        X_eval_train, y_eval_train = resample(X_train, y_train, n_samples=4096, random_state=42)
        X_test, y_test = resample(X_test, y_test, n_samples=4096, random_state=42)

        config['steps_per_epoch'] = X_train.shape[0] // batch_size

        # Inverse-frequency weights per age, normalised to sum to 1.
        age_weights = 1 / pd.get_dummies(y_train['age']).sum(0)
        age_weights = age_weights / age_weights.sum()
        age_weights_dict = dict(age_weights)
        # Ages seen only in the test set get weight 0. The membership test must
        # be against the ages themselves: `age in y_train['age']` would check
        # the Series index and could zero out weights of ages that do occur in
        # training.
        for age in y_test['age'].unique():
            age_weights_dict.setdefault(age, 0)

        # Compute category weights for imbalanced data
        cat_weights = []
        for category in cat_features:
            weights = 1 / pd.get_dummies(y_train[category]).sum(0)
            weights = torch.tensor(weights / weights.sum(), dtype=torch.float).to(device)
            cat_weights.append(weights)

        config['cat_weights'] = [t.tolist() for t in cat_weights]
        config['features_dim'] = X_train.shape[1]
        config['confounder_dims'] = [len(y_train[feature].unique()) for feature in cat_features]

        train_loader, val_loader, _, _ = get_loaders(X_train, X_test, y_train, y_test,
                                                    age_weights_dict=age_weights_dict,
                                                    cat_features=cat_features,
                                                    num_features=[],
                                                    batch_size=batch_size)
        del X_train, y_train

        # Number of sub-models trained in this mode. The previous
        # `mode == 'ae' or 'age' or 'adv'` test was always truthy, so every
        # mode was treated as a single-model run.
        model_mode = config['model_mode']
        if model_mode in ('ae', 'age', 'adv'):
            num_models = 1
        elif model_mode in ('aeage', 'full'):
            num_models = 2
        else:
            num_models = 3

        config['epoch_for_decay_total'] = config['epochs_total'] // num_models

        # Parameters to exclude from the new configuration
        setup_only_params = ['input_dataset',
                             'train_tissue',
                             'filtered',
                             'test_data_donors',
                             'cat_features',
                             'num_features',
                             'scenario',
                             'basic_model',
                             'batch_size',
                             'epochs_total',
                             'varname'
                             ]

        model_config = {key: value for key, value in config.items() if key not in setup_only_params}
        model_config_path = f"{os.path.dirname(config['model_dir'])}/model_config.yaml"
        save_config_to_yaml(model_config, model_config_path)

        def model_eval(model, current_epoch):
            """Fit a Lasso on encoder outputs and return (train MAE, test MAE, test R2)."""
            tranformed_X_train = model.encoder(torch.tensor(X_eval_train).to(device)).to('cpu').detach().numpy()
            tranformed_X_test = model.encoder(torch.tensor(X_test).to(device)).to('cpu').detach().numpy()

            eval_model = Lasso()
            eval_model.fit(tranformed_X_train, y_eval_train.age)

            r2 = eval_model.score(tranformed_X_test, y_test.age)
            self_mae = mean_absolute_error(eval_model.predict(tranformed_X_train), y_eval_train.age)
            mae = mean_absolute_error(eval_model.predict(tranformed_X_test), y_test.age)

            return self_mae, mae, r2

        # Model creation
        model = ADAE(model_eval_func=model_eval, **model_config)
        path_pretrained = config['basic_model']
        if path_pretrained is not None:
            load_pretrained(model, path_pretrained, ae_only=ae_only)

        # Tags include the scenario and hyperparameter values, one "key=value" per entry
        tag_keys = [
            'input_dataset', 'train_tissue', 'filtered', 'test_data_donors',
            'cat_features', 'num_features', 'scenario', 'basic_model',
            'model_mode', 'model_dir', 'plot_dir', 'features_dim',
            'hidden_dim', 'latent_dim', 'confounder_dims', 'n_cat_conf',
            'reg_constant', 'age_coef', 'cat_weights', 'encoder_dropout',
            'age_pred_dropout', 'batch_size', 'lr', 'starting_epoch',
            'plot_every_e', 'eval_every_e', 'plot_features', 'clip_loss',
            'epochs_total', 'epoch_for_decay_total', 'steps_per_epoch',
        ]
        tags = [f"{key}={config[key]}" for key in tag_keys]

        neptune_logger = NeptuneLogger(
            project="hermanmoiseev/nnclocks",
            api_token="#TOKEN",
            tags=tags
        )

        callbacks = [
            pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        ]

        trainer = pl.Trainer(accelerator=device,
                             devices=1,
                             precision=16,
                             enable_checkpointing=False,
                             max_epochs=config['epochs_total'],
                             logger=neptune_logger,
                             callbacks=callbacks,
                             log_every_n_steps=10,
                             num_sanity_val_steps=0
                             )
        print('No problem :/')
        trainer.fit(model, train_loader, val_loader)
        model_path = f"{config['model_dir']}/epoch_{config['epochs_total']}.pth"
        torch.save(model.state_dict(), model_path)
        neptune_logger.experiment.stop()


# Get all config paths from the given scenario folder
def get_config_paths(experiments_dir, scenarios, model=False):
    """Collect config file paths located under any of the given scenario folders.

    Args:
        experiments_dir: Root directory that is walked recursively.
        scenarios: Iterable of folder names; a directory qualifies when one of
            its path components is exactly equal to one of these names.
        model: If True look for ``model_config.yaml``, otherwise ``config.yaml``.

    Returns:
        List of matching file paths in ``os.walk`` order.
    """
    target_name = 'model_config.yaml' if model else 'config.yaml'

    found = []
    for folder, _subdirs, filenames in os.walk(experiments_dir):
        if target_name not in filenames:
            continue
        # Compare whole path components so that e.g. "scen" does not match "scenA".
        components = folder.split(os.path.sep)
        if any(name in components for name in scenarios):
            found.append(os.path.join(folder, target_name))

    return found

In [16]:
def create_model(model_config):
    """Build an ``ADAE`` instance, passing ``model_config`` as keyword arguments."""
    return ADAE(**model_config)


def evaluate_model(X_train, y_train, X_test, y_test, epochs_range, scenario, configuration,
                   experiments_dir='/Users/mindblaze/Desktop/Thesis/clocks/experiments'):
    """Score saved ADAE checkpoints on age prediction and plot the metrics.

    For every epoch in ``range(*epochs_range)`` the checkpoint
    ``<experiments_dir>/<scenario>/<configuration>/models/epoch_<epoch>.pth``
    is loaded into a model built from ``model_config.yaml`` (run on CPU with
    ``model.mode = 'age'``). Train MAE, test MAE, test R2 and test MSE are
    recorded, together with the test MAE of each distinct age in ``y_test``.
    The best value of each metric is printed and all curves are plotted.

    Args:
        X_train, X_test: Inputs convertible with ``torch.tensor``.
        y_train, y_test: Objects exposing an ``age`` column/attribute;
            ``y_test`` must additionally support ``y_test['age']``.
        epochs_range: Arguments unpacked into ``range`` (e.g. ``(1, 51)``).
        scenario: Scenario folder name.
        configuration: Configuration folder name inside the scenario.
        experiments_dir: Root folder of the experiments; defaults to the
            previously hard-coded location.

    Returns:
        None. Results are printed and plotted.
    """
    run_dir = f'{experiments_dir}/{scenario}/{configuration}'

    # Load and evaluate ADAE model
    config = load_config(f'{run_dir}/model_config.yaml')
    model = ADAE(**config)
    model.to('cpu')
    model.eval()
    model.mode = 'age'

    train_age_mean = np.mean(y_train.age)
    train_age_median = np.median(y_train.age)

    print(f"Train Age: Mean: {train_age_mean}, Median: {train_age_median}")

    epochs = []
    self_maes = []
    maes = []
    r2s = []
    mses = []

    best_mae = float('inf')
    best_mae_epoch = None
    best_self_mae = float('inf')
    best_self_mae_epoch = None
    best_r2 = float('-inf')
    best_r2_epoch = None
    best_mse = float('inf')
    best_mse_epoch = None

    age_groups = y_test['age'].unique()
    num_groups = len(age_groups)
    group_maes = {age_group: [] for age_group in age_groups}
    best_group_maes = {age_group: float('inf') for age_group in age_groups}
    best_group_mae_epochs = {age_group: None for age_group in age_groups}

    # Loop-invariant inputs: tensors and per-age boolean masks are built once.
    X_train_tensor = torch.tensor(X_train)
    X_test_tensor = torch.tensor(X_test)
    test_ages = y_test['age']
    group_masks = {age_group: (test_ages == age_group).to_numpy() for age_group in age_groups}

    for epoch in tqdm(range(*epochs_range)):
        load_pretrained(model, f"{run_dir}/models/epoch_{epoch}.pth")

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            predicted_test_age = model(X_test_tensor).detach().numpy()
            predicted_train_age = model(X_train_tensor).detach().numpy()

        self_mae = mean_absolute_error(predicted_train_age, y_train.age)
        mae = mean_absolute_error(predicted_test_age, y_test.age)
        r2 = r2_score(y_true=y_test.age, y_pred=predicted_test_age)
        mse = mean_squared_error(predicted_test_age, y_test.age)

        epochs.append(epoch)
        self_maes.append(self_mae)
        maes.append(mae)
        r2s.append(r2)
        mses.append(mse)

        if mae < best_mae:
            best_mae = mae
            best_mae_epoch = epoch
        if self_mae < best_self_mae:
            best_self_mae = self_mae
            best_self_mae_epoch = epoch
        if r2 > best_r2:
            best_r2 = r2
            best_r2_epoch = epoch
        if mse < best_mse:
            best_mse = mse
            best_mse_epoch = epoch

        for age_group in age_groups:
            group_mask = group_masks[age_group]
            group_mae = mean_absolute_error(predicted_test_age[group_mask], test_ages[group_mask])
            group_maes[age_group].append(group_mae)

            if group_mae < best_group_maes[age_group]:
                best_group_maes[age_group] = group_mae
                best_group_mae_epochs[age_group] = epoch

    print(f"Best MAE: {best_mae} at Epoch: {best_mae_epoch}")
    print(f"Best Self MAE: {best_self_mae} at Epoch: {best_self_mae_epoch}")
    print(f"Best R2: {best_r2} at Epoch: {best_r2_epoch}")
    print(f"Best MSE: {best_mse} at Epoch: {best_mse_epoch}")

    for age_group in age_groups:
        print(f"Best MAE for Age Group {age_group}: {best_group_maes[age_group]} at Epoch: {best_group_mae_epochs[age_group]}")

    # 2x2 grid with one panel per global metric
    metric_panels = [
        (self_maes, 'Self MAE'),
        (maes, 'MAE'),
        (r2s, 'R2'),
        (mses, 'MSE'),
    ]
    plt.figure(figsize=(12, 8))
    for position, (values, label) in enumerate(metric_panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs, values)
        plt.xlabel('Epochs')
        plt.ylabel(label)
        plt.title(f'{label} vs. Epochs')

    plt.tight_layout()
    plt.show()

    # Plot MAEs for each age group. squeeze=False keeps `axs` two-dimensional,
    # so indexing also works when there is only a single age group.
    fig, axs = plt.subplots(num_groups, 1, figsize=(8, 4 * num_groups), squeeze=False)

    for i, age_group in enumerate(age_groups):
        ax = axs[i, 0]
        ax.plot(epochs, group_maes[age_group])
        ax.set_xlabel('Epochs')
        ax.set_ylabel('MAE')
        ax.set_title(f'MAE vs. Epochs (Age Group: {age_group})')

    plt.tight_layout()
    plt.show()


def _epoch_from_input_type(input_type):
    """Return the integer that follows ``'Epoch '`` in an input-type label."""
    return int(input_type.split('Epoch ')[1].split(' ')[0])


def get_best_results(config, results_df, test_tissues_list):
    """Summarise the best test MAE and best R2 per test tissue.

    The training tissue (``config['train_tissue']``) is excluded when it is
    present in ``test_tissues_list``; the input list is not modified. For each
    remaining tissue, rows of ``results_df`` whose ``'Input Type'`` contains
    ``'Plain Data'`` and rows whose ``'Input Type'`` contains ``'Epoch'`` are
    summarised separately.

    Args:
        config: Mapping with a ``'train_tissue'`` entry.
        results_df: DataFrame with columns ``'Tissue'``, ``'Input Type'``,
            ``'Model'``, ``'MAE Test'`` and ``'R2'``.
        test_tissues_list: List of tissue names to report on.

    Returns:
        DataFrame with one row per (tissue, data kind) that has results. Rows
        for plain data are labelled ``"<tissue> (Raw)"``; epoch rows also
        carry the ``'Best Test MAE Epoch'`` and ``'Best R2 Epoch'`` columns.
    """
    test_tissues = test_tissues_list.copy()
    train_tissue = config['train_tissue']
    # list.remove raises ValueError for a missing item; the train tissue is
    # not guaranteed to be part of the evaluated tissues.
    if train_tissue in test_tissues:
        test_tissues.remove(train_tissue)

    best_results = []
    for data_type in test_tissues:
        is_tissue = results_df['Tissue'] == data_type
        raw_data = results_df[is_tissue & (results_df['Input Type'].str.contains('Plain Data'))]
        epoch_data = results_df[is_tissue & (results_df['Input Type'].str.contains('Epoch'))]

        if len(raw_data) > 0:
            mae_row = raw_data.loc[raw_data['MAE Test'].idxmin()]
            r2_row = raw_data.loc[raw_data['R2'].idxmax()]

            best_results.append({
                'Tissue': f"{data_type} (Raw)",
                'Best Test MAE': raw_data['MAE Test'].min(),
                'Best Test MAE Model': mae_row['Model'],
                'Best Test MAE Input': mae_row['Input Type'],
                'Best R2': raw_data['R2'].max(),
                'Best R2 Model': r2_row['Model'],
                'Best R2 Input': r2_row['Input Type']
            })

        if len(epoch_data) > 0:
            mae_row = epoch_data.loc[epoch_data['MAE Test'].idxmin()]
            r2_row = epoch_data.loc[epoch_data['R2'].idxmax()]

            best_results.append({
                'Tissue': data_type,
                'Best Test MAE': epoch_data['MAE Test'].min(),
                'Best Test MAE Model': mae_row['Model'],
                'Best Test MAE Input': mae_row['Input Type'],
                'Best Test MAE Epoch': _epoch_from_input_type(mae_row['Input Type']),
                'Best R2': epoch_data['R2'].max(),
                'Best R2 Model': r2_row['Model'],
                'Best R2 Input': r2_row['Input Type'],
                'Best R2 Epoch': _epoch_from_input_type(r2_row['Input Type'])
            })

    best_results_df = pd.DataFrame(best_results)
    return best_results_df


def iterate_scenarios_and_evaluate(experiments_dir, scenarios, test_tissues_list, eval_models_list, epoch_interval=10, max_epoch=None, sample_size=1000, X_test_input=None, y_test_input=None, allowed_config=None, deep_model=None):
    """Evaluate the saved checkpoints of every configuration in the scenarios.

    Every ``model_config.yaml`` found by ``get_config_paths`` is processed:
    the sibling ``config.yaml`` is loaded, the ``epoch_<n>.pth`` files in the
    ``models`` sub-folder are listed, and the selected epochs are passed to
    ``evaluate_model``; ``get_best_results`` then summarises the outcome.

    NOTE(review): ``evaluate_model`` is called here as
    ``evaluate_model(config, model_config, epochs, test_tissues_list,
    eval_models_list, sample_size=..., ...)``; presumably a definition with
    that signature is in scope when this runs -- confirm.

    Args:
        experiments_dir: Root directory of the experiments.
        scenarios: Scenario folder names to include.
        test_tissues_list: Tissues to evaluate on.
        eval_models_list: Forwarded to ``evaluate_model``.
        epoch_interval: Only epochs divisible by this value are evaluated.
        max_epoch: Highest epoch to evaluate; None means no upper limit.
        sample_size, X_test_input, y_test_input, deep_model: Forwarded to
            ``evaluate_model``.
        allowed_config: If not None, only config paths containing this
            substring are processed.

    Returns:
        Tuple ``(evaluation_results, evaluation_best)`` of dicts keyed by the
        configuration folder name. NOTE(review): configurations with the same
        folder name in different scenarios overwrite each other.
    """
    model_config_paths = get_config_paths(experiments_dir, scenarios, model=True)
    if allowed_config is not None:
        model_config_paths = [path for path in model_config_paths if allowed_config in path]
    evaluation_results = {}
    evaluation_best = {}

    for model_config_path in tqdm(model_config_paths, desc='Configurations'):
        model_config = load_config(model_config_path)
        model_dir = os.path.dirname(model_config_path)
        config = load_config(os.path.join(model_dir, 'config.yaml'))

        # Get the list of available epochs
        epoch_files = [f for f in os.listdir(os.path.join(model_dir, 'models')) if f.startswith('epoch_') and f.endswith('.pth')]
        epochs = sorted(int(f.split('_')[1].split('.')[0]) for f in epoch_files)

        # Filter epochs based on the specified interval. `max_epoch` is left
        # untouched: assigning a per-configuration maximum to it would cap all
        # following configurations at the first configuration's last epoch.
        filtered_epochs = [
            epoch for epoch in epochs
            if epoch % epoch_interval == 0 and (max_epoch is None or epoch <= max_epoch)
        ]

        # Evaluate the model
        evaluation_result = evaluate_model(config, model_config, filtered_epochs, test_tissues_list, eval_models_list, sample_size=sample_size, X_test_input=X_test_input, y_test_input=y_test_input, deep_model=deep_model)
        config_name = os.path.basename(model_dir)
        evaluation_results[config_name] = evaluation_result
        evaluation_best[config_name] = get_best_results(config, evaluation_result, test_tissues_list)

    return evaluation_results, evaluation_best

In [None]:
# Functions to plot numerical evaluation results

def plot_evaluations(df):
    """Plot test MAE against epoch, one figure per (tissue, model) pair.

    Each figure contains one line per input type (``'Encoded'``,
    ``'Transformed'``, ``'Transformed PCA'``), selected by the suffix of the
    ``'Input Type'`` column. Rows whose ``'Input Type'`` contains
    ``'Plain Data'`` are ignored.

    Args:
        df: DataFrame with columns ``'Tissue'``, ``'Model'``, ``'Input Type'``
            (labels such as ``'Epoch 10 Encoded'``) and ``'MAE Test'``.
    """
    # Get unique tissues, models, and input types
    tissues = df['Tissue'].unique()
    models = df['Model'].unique()
    input_types = ['Encoded', 'Transformed', 'Transformed PCA']

    for tissue in tissues:
        tissue_df = df[df['Tissue'] == tissue]

        for model in models:
            model_df = tissue_df[tissue_df['Model'] == model]

            # One figure per (tissue, model)
            fig, ax = plt.subplots(figsize=(8, 3))

            for input_type in input_types:
                input_type_df = model_df[model_df['Input Type'].str.endswith(input_type)]

                # Exclude rows with "Plain Data" in the Input Type
                input_type_df = input_type_df[~input_type_df['Input Type'].str.contains('Plain Data')]

                # Extract the epoch numbers (as text) from the Input Type column
                epochs = input_type_df['Input Type'].str.extract(r'Epoch (\d+)', expand=False)

                # The null check has to happen before the integer conversion:
                # casting a missing value to int raises instead of warning.
                if epochs.isnull().any():
                    print(f"Warning: Missing epochs for {tissue} - {model} - {input_type}")
                    continue

                epochs = epochs.astype(int)

                # Plot the MAE Test values against the epochs for the current input type
                ax.plot(epochs, input_type_df['MAE Test'], marker='o', label=input_type)

            ax.set_title(f'{tissue} - {model}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('MAE Test')
            ax.legend()

            plt.tight_layout()
            plt.show()

def plot_best_predictions(df):
    """Plot, per tissue, the best (lowest) test MAE of each input type.

    For every tissue one figure is drawn with a single annotated point per
    input type (``'Encoded'``, ``'Transformed'``, ``'Transformed PCA'``),
    placed at the epoch where that input type reached its minimum
    ``'MAE Test'``. Rows containing ``'Plain Data'`` are ignored, and input
    types without any scored rows for a tissue are skipped.

    Args:
        df: DataFrame with columns ``'Tissue'``, ``'Input Type'`` (labels such
            as ``'Epoch 10 Encoded'``) and ``'MAE Test'``.
    """
    tissues = df['Tissue'].unique()
    input_types = ['Encoded', 'Transformed', 'Transformed PCA']

    for tissue in tissues:
        tissue_df = df[df['Tissue'] == tissue]

        # One figure per tissue
        fig, ax = plt.subplots(figsize=(8, 3))

        for input_type in input_types:
            input_type_df = tissue_df[tissue_df['Input Type'].str.endswith(input_type)]

            # Exclude rows with "Plain Data" in the Input Type
            input_type_df = input_type_df[~input_type_df['Input Type'].str.contains('Plain Data')]

            min_mae_test = input_type_df['MAE Test'].min()
            # min() is NaN when there are no rows (or only NaN scores); there
            # is then no best row to look up, so skip this input type.
            if pd.isnull(min_mae_test):
                continue

            # Epoch of the first row that reaches the minimum MAE Test value
            best_rows = input_type_df.loc[input_type_df['MAE Test'] == min_mae_test, 'Input Type']
            best_epoch = best_rows.str.extract(r'Epoch (\d+)', expand=False).astype(int).iloc[0]

            # Plot and annotate the best prediction point for the current input type
            ax.plot(best_epoch, min_mae_test, marker='o', label=input_type)
            ax.annotate(f'{min_mae_test:.2f}', (best_epoch, min_mae_test), textcoords="offset points", xytext=(0, 10), ha='center')

        ax.set_title(f'Best Predictions - {tissue}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('MAE Test')
        ax.legend()

        plt.tight_layout()
        plt.show()
        
def plot_best_of_best_predictions(df):
    """Plot, per tissue, the lowest test MAE reached at each epoch.

    For every tissue one figure is drawn with one annotated point per epoch:
    the minimum ``'MAE Test'`` over all rows of that tissue and epoch. Rows
    whose ``'Input Type'`` contains ``'Plain Data'`` are ignored, and epochs
    without scored rows for a tissue are skipped.

    Args:
        df: DataFrame with columns ``'Tissue'``, ``'Input Type'`` (labels such
            as ``'Epoch 10 Encoded'``) and ``'MAE Test'``.
    """
    tissues = df['Tissue'].unique()

    # Numeric epoch per row (NaN where the label has no "Epoch <n>" part).
    # Rows without an epoch are dropped before the integer cast, which would
    # otherwise fail on missing values.
    epoch_numbers = pd.to_numeric(df['Input Type'].str.extract(r'Epoch (\d+)', expand=False))
    epochs = epoch_numbers.dropna().astype(int).unique()
    is_plain = df['Input Type'].str.contains('Plain Data')

    for tissue in tissues:
        is_tissue = df['Tissue'] == tissue

        # One figure per tissue
        fig, ax = plt.subplots(figsize=(8, 3))

        for epoch in epochs:
            # Compare the parsed number rather than searching for the text
            # "Epoch <n>", which would also match e.g. "Epoch 10" for epoch 1.
            epoch_df = df[is_tissue & (epoch_numbers == epoch) & ~is_plain]

            min_mae_test = epoch_df['MAE Test'].min()
            if pd.isnull(min_mae_test):
                continue

            # Plot and annotate the best prediction point for the current epoch
            ax.plot(epoch, min_mae_test, marker='o', label=f'Epoch {epoch}')
            ax.annotate(f'{min_mae_test:.2f}', (epoch, min_mae_test), textcoords="offset points", xytext=(0, 10), ha='center')

        ax.set_title(f'Best of Best Predictions - {tissue}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('MAE Test')
        ax.legend()

        plt.tight_layout()
        plt.show()


def plot_best_predictions_per_model(model_dfs):
    """Plot, per tissue, each model's best test MAE as a function of epoch.

    Tissues and epochs are taken from the first DataFrame in ``model_dfs``.
    For every tissue one figure is drawn with one line per model; each point
    is the minimum ``'MAE Test'`` over that model's rows for the tissue and
    epoch (NaN when the model has no such rows). Rows whose ``'Input Type'``
    contains ``'Plain Data'`` are ignored.

    Args:
        model_dfs: Mapping from model name to a DataFrame with columns
            ``'Tissue'``, ``'Input Type'`` and ``'MAE Test'``. An empty
            mapping produces no plots.
    """
    if not model_dfs:
        return

    # Get unique tissues and epochs from the first DataFrame, excluding rows
    # without an epoch number
    first_df = next(iter(model_dfs.values()))
    tissues = first_df['Tissue'].unique()
    epochs = first_df['Input Type'].str.extract(r'Epoch (\d+)', expand=False)
    epochs = epochs[~epochs.isnull()].astype(int).unique()

    # Per model: the non-"Plain Data" rows and their parsed epoch numbers.
    prepared = {}
    for model_name, df in model_dfs.items():
        keep = ~df['Input Type'].str.contains('Plain Data')
        epoch_numbers = pd.to_numeric(df['Input Type'].str.extract(r'Epoch (\d+)', expand=False))
        prepared[model_name] = (df[keep], epoch_numbers[keep])

    for tissue in tissues:
        # One figure per tissue
        fig, ax = plt.subplots(figsize=(8, 6))

        for model_name, (scored_df, epoch_numbers) in prepared.items():
            is_tissue = scored_df['Tissue'] == tissue

            # Compare the parsed number rather than searching for the text
            # "Epoch <n>", which would also match e.g. "Epoch 10" for epoch 1.
            best_mae_test = [
                scored_df.loc[is_tissue & (epoch_numbers == epoch), 'MAE Test'].min()
                for epoch in epochs
            ]
            best_epochs = list(epochs)

            # Plot the best prediction trend for the current model
            ax.plot(best_epochs, best_mae_test, marker='o', label=model_name)

        ax.set_title(f'Best Predictions per Model - {tissue}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('MAE Test')
        ax.legend()

        plt.tight_layout()
        plt.show()

In [18]:
def generate_simple_variations(tissue_list, epochs_per_mode, scenario):
    """Build the config overrides for every (tissue, training mode) pair.

    Args:
        tissue_list: Tissues to train on.
        epochs_per_mode: Mapping from mode name (``'ae'``, ``'adv'``,
            ``'age'``, ``'triple'``, ``'aeage'``, ``'full'``) to its number
            of epochs.
        scenario: Scenario folder name used in the checkpoint paths.

    Returns:
        List of dicts with ``model_mode``, ``train_tissue``, ``epochs_total``
        and ``varname``. Every mode except ``'ae'`` additionally gets
        ``basic_model``: the path of the last-epoch checkpoint of the mode it
        starts from (``adv`` from ``ae``, ``age`` from ``adv``, all others
        from ``age``).
    """
    # Mode whose final checkpoint each mode is initialised from.
    starts_from = {
        'adv': 'ae',
        'age': 'adv',
        'triple': 'age',
        'aeage': 'age',
        'full': 'age',
    }
    mode_order = ('ae', 'adv', 'age', 'triple', 'aeage', 'full')

    variations = []
    for tissue in tissue_list:
        for mode in mode_order:
            entry = {
                'model_mode': mode,
                'train_tissue': tissue,
                'epochs_total': epochs_per_mode[mode],
                'varname': f'{mode}_{tissue}',
            }
            parent = starts_from.get(mode)
            if parent is not None:
                entry['basic_model'] = (
                    f'experiments/{scenario}/{parent}_{tissue}'
                    f'/models/epoch_{epochs_per_mode[parent]}.pth'
                )
            variations.append(entry)

    return variations

In [19]:
# Functions to help with SHAP preparation

def evaluate_category_counts(y_train, categories):
    """Print sample counts per category value and per value combination.

    Args:
        y_train: Data passed to ``pd.DataFrame(..., columns=categories)``.
        categories: Column names to analyse.

    Prints the ``value_counts`` of every category, followed by a grid table
    with the number of rows for each element of the Cartesian product of the
    categories' unique values (combinations that never occur are listed with
    a count of 0).
    """
    frame = pd.DataFrame(y_train, columns=categories)

    # Counts of each individual category value
    print("Number of samples per category:")
    for name in categories:
        per_value = frame[name].value_counts()
        print(f"\n{name}:")
        print(per_value)

    # Counts of every possible combination of category values
    print("\nNumber of samples per category combination:")
    levels = [frame[name].unique() for name in categories]
    tallies = {}
    for combo in product(*levels):
        selector = True
        for name, level in zip(categories, combo):
            selector = selector & (frame[name] == level)
        tallies[combo] = selector.sum()

    # Tabulate the combination counts with one index level per category
    table = pd.DataFrame.from_dict(tallies, orient='index', columns=['Count'])
    table.index = pd.MultiIndex.from_tuples(table.index, names=categories)

    print(tabulate(table, headers='keys', tablefmt='grid'))


import pandas as pd

import numpy as np
import pandas as pd

def create_balanced_sample(X, y, categories, samples_per_category):
    """Draw up to ``samples_per_category`` rows for every category combination.

    Rows with a missing value in any of ``categories`` are discarded from both
    ``X`` and ``y``. For every distinct combination of category values (in
    order of first appearance) at most ``samples_per_category`` rows are drawn
    without replacement using ``np.random.choice``.

    Args:
        X: Array-like of features with one row per row of ``y``.
        y: DataFrame containing the ``categories`` columns.
        categories: List of column names defining the groups.
        samples_per_category: Maximum number of rows per combination.

    Returns:
        Tuple ``(balanced_X, balanced_y)``: a NumPy array of the selected rows
        and a DataFrame (columns ``categories``, fresh integer index) holding
        the category values of exactly those rows.
    """
    X = np.asarray(X)

    # Drop incomplete rows from X and y together; the positions computed
    # below index both, so they must stay aligned.
    complete = y[categories].notna().all(axis=1).to_numpy()
    X = X[complete]
    labels = y[categories][complete]

    label_values = labels.values
    unique_combinations = labels.drop_duplicates().values

    chosen = []
    for combination in tqdm(unique_combinations):
        # Positions of the rows matching the current combination
        indices = np.flatnonzero(np.all(label_values == combination, axis=1))

        # Groups smaller than the requested size contribute all their rows
        num_samples = min(len(indices), samples_per_category)
        chosen.append(np.random.choice(indices, size=num_samples, replace=False))

    if not chosen:
        # Nothing to sample from (empty input or every row incomplete)
        return X[:0], pd.DataFrame(columns=categories)

    selected = np.concatenate(chosen)
    balanced_X = X[selected]

    # Take the labels from the selected rows themselves. Repeating every
    # combination `samples_per_category` times would mislabel rows whenever a
    # group has fewer samples than requested.
    balanced_y = labels.iloc[selected].reset_index(drop=True)

    return balanced_X, balanced_y

import numpy as np

def get_worst_predictions_subset(model, X, y_age, percentile=80):
    """Select the samples on which ``model`` predicts age worst.

    Args:
        model: Callable returning a tensor of predictions for ``X``.
        X: Model input; also indexed with the selected row positions.
        y_age: True ages, subtracted from the squeezed predictions.
        percentile: Samples whose absolute error is at or above this
            percentile of all absolute errors are returned.

    Returns:
        The rows of ``X`` with the largest absolute prediction errors.
    """
    predictions = model(X).squeeze().detach().numpy()
    abs_errors = np.abs(predictions - y_age)

    # Error level separating the worst predictions from the rest
    cutoff = np.percentile(abs_errors, percentile)
    is_worst = np.asarray(abs_errors >= cutoff)
    worst_rows = np.nonzero(is_worst)[0]

    return X[worst_rows]


In [21]:
def get_predictions(X_train, y_train, X_test, y_test, configuration, epoch, name):
    """Compare Ridge, Lasso and the model's own age head on encoded data.

    The ADAE model of ``configuration`` is loaded, ``X_train`` and ``X_test``
    are passed through its encoder, Ridge and Lasso regressors are fitted on
    the encoded training data against ``y_train.age``, and all three
    predictors are scored on the encoded test data against ``y_test.age``.

    Args:
        X_train, X_test: Inputs convertible with ``torch.tensor``.
        y_train, y_test: Objects exposing an ``age`` attribute.
        configuration: Experiment sub-folder holding ``model_config.yaml``.
        epoch: Checkpoint number to load (``5`` is special-cased, see below).
        name: Label stored under the ``"name"`` key of the result.

    Returns:
        dict with ``name`` plus ``{ridge,lasso,agenn}_{pred,abs_diff,mae,r2}``.
    """
    scenario = 'WHOLE_CROSS'
    config = load_config(f'/Users/mindblaze/Desktop/Thesis/clocks/experiments/{scenario}/{configuration}/model_config.yaml')
    model = ADAE(**config)
    model.to('cpu')
    model.eval()
    model.mode = 'age'

    # epoch == 5 loads the fixed age_BOTH_NOFILTER/epoch_2 checkpoint instead of
    # the requested configuration.
    # NOTE(review): presumably this supplies a trained age head for the AE-only
    # run - confirm.
    if epoch != 5:
        load_pretrained(model, f"/Users/mindblaze/Desktop/Thesis/clocks/experiments/{scenario}/{configuration}/models/epoch_{epoch}.pth")
    else:
        load_pretrained(model, f"/Users/mindblaze/Desktop/Thesis/clocks/experiments/{scenario}/age_BOTH_NOFILTER/models/epoch_2.pth")

    def encode(features):
        # Embed raw features with the loaded encoder and hand back a NumPy array
        return model.encoder(torch.tensor(features)).detach().numpy()

    X_train_encoded = encode(X_train)
    X_test_encoded = encode(X_test)

    # Linear baselines trained on the embedding
    ridge_pred = Ridge().fit(X_train_encoded, y_train.age).predict(X_test_encoded)
    lasso_pred = Lasso().fit(X_train_encoded, y_train.age).predict(X_test_encoded)
    # The model's own age predictor applied to the same embedding
    agenn_pred = model.predictor(torch.tensor(X_test_encoded)).detach().numpy().flatten()

    true_age = y_test.age
    return {
        "name": name,
        "ridge_pred": ridge_pred,
        "lasso_pred": lasso_pred,
        "agenn_pred": agenn_pred,
        "ridge_abs_diff": np.abs(ridge_pred - true_age),
        "lasso_abs_diff": np.abs(lasso_pred - true_age),
        "agenn_abs_diff": np.abs(agenn_pred - true_age),
        "ridge_mae": mean_absolute_error(true_age, ridge_pred),
        "lasso_mae": mean_absolute_error(true_age, lasso_pred),
        "agenn_mae": mean_absolute_error(true_age, agenn_pred),
        "ridge_r2": r2_score(true_age, ridge_pred),
        "lasso_r2": r2_score(true_age, lasso_pred),
        "agenn_r2": r2_score(true_age, agenn_pred)
    }

# Model Config setup and Training

## Generate Configs

In [None]:
# With the standard approach this base config does not need to specify the epoch number,
# basic model, model_mode etc.; that is done in the next cell.

scenario = 'WHOLE_CROSS' # Any name
hyperparams_name = 'default' # Any name
cat_features = ('method',) # Confounders

# List of donors used for testing
test_donors = ['1-M-62', '3_10_M', '3_11_M', '3_39_F', '3-M-5/6', '18_53_M', '18-M-53', '21_48_F', '24_61_M', '24-M-61', '30-M-4']

# Output folders are defined once so that the paths stored in the config and the
# folders created below cannot drift apart (the config previously pointed plots
# at ".../{hyperparams_name}t/plots", a folder that was never created).
experiment_dir = f"experiments/{scenario}/{hyperparams_name}"
model_dir = f"{experiment_dir}/models"
plot_dir = f"{experiment_dir}/plots"

config = {
    'input_dataset': "mouse", # Data folder name
    'train_tissue': 'BOTH_NOFILTER', # Data file name used in preprocessing.ipynb
    'filtered': True,
    'test_data_donors': test_donors,
    'cat_features': cat_features,
    'num_features': tuple(), # Not used for now
    'scenario': scenario,
    'basic_model': None, # If None, the model is trained with the basic architecture
    'model_mode': 'no_run_default_config', # Any for standard approach; specified in the next cell
    'model_dir': model_dir, # Folder to save models into
    'plot_dir': plot_dir, # Folder to save plots into (not used in the final version)
    'hidden_dim': 300, # Size of hidden layer (in encoder, decoder, age predictor, confounder predictor)
    'latent_dim': 100, # Size of the embedding
    'n_cat_conf': len(cat_features),
    'reg_constant': 0.25, # D/R ratio (degree of deconfoundment for ADAE and AgeADAE)
    'age_coef': 8, # E/R ratio (degree of age-enhancement for AgeAE and AgeADAE)
    'encoder_dropout': 0.2, # Dropout in encoder
    'age_pred_dropout': 0.2, # Dropout in age predictor
    'batch_size': 128, # Batch size
    'lr': {
        'ae': 1e-3, # LR of autoencoder; main LR, utilized also for deconfoundment and age-enhancement
        'adv': 1e-2, # LR of confounder predictor; should be set so that confounders are predicted with high accuracy
        'age': 3e-3 # LR of age predictor; same as for adv LR
    },
    'lr_decrease_factor': 3e-2, # LR decrease factor
    'starting_epoch': 0, # Zero if training is started from scratch
    'plot_every_e': 99999, # Plot interval (not used in the final version)
    'eval_every_e': 1, # Evaluation interval
    'plot_features': ('age', 'method'), # Features to plot (not used in the final version)
    'clip_loss': 10000, # Clip loss to avoid exploding gradients (set to a very high value as is not used in the final version)
    'epochs_total': -999, # Not set in the default config
    'epoch_for_decay_total': -999, # Not set in the default config
    'steps_per_epoch': -999, # Not set in the default config
    'l1_reg_coef': 1e-2, # L1 Regularization
    'l2_reg_coef': 1e-4, # L2 Regularization
}

# Create dirs if they do not exist
os.makedirs(model_dir, exist_ok=True)
os.makedirs(plot_dir, exist_ok=True)

# Save config to yaml
config_path = f"{experiment_dir}/config.yaml"
save_config_to_yaml(config, config_path)

print(config_path)

In [None]:
# Generate the basic iterative setup (one per tissue listed in tissue_list).
# First the AE is trained, then ADV trains on the pretrained AE, then AGE trains on the pretrained ADV+AE.
# Finally either Triple, AeAge or Full is trained on the pretrained AE+ADV+AGE.

scenario = "WHOLE_CROSS"
base_hyperparams = "default"

# Form such a setup for each of these tissues
tissue_list = ['BOTH_NOFILTER']

# Value passed per training mode to generate_simple_variations
# (presumably the number of epochs for that mode - confirm against the helper)
epochs_per_mode = {
    'ae': 5,
    'adv': 2,
    'age': 2,
    'triple': 240,
    'aeage': 150,
    'full': 150
}

# Build the per-mode variations and hand them to create_hyperparameter_variations
# together with the scenario and the base hyperparameter set name
variations = generate_simple_variations(tissue_list, epochs_per_mode, scenario)
create_hyperparameter_variations(scenario, base_hyperparams, variations, hypername=None)

## Run Training

In [None]:
experiments_dir = 'experiments'
scenarios = ['WHOLE_CROSS']  # Specify the desired scenarios
all_configs = get_config_paths(experiments_dir, scenarios)

# Pretraining: each call is restricted by a keyword passed to run_trainings
run_trainings(all_configs, keyword='/ae_') # AE
run_trainings(all_configs, keyword='/adv_') # Confounder predictor
run_trainings(all_configs, keyword='/age_') # AgeNN

# Training (one of those or all three to compare; they all start with the pretrained AE and separately pretrained adv and age based on the embedding)
run_trainings(all_configs, keyword='/triple') # AgeADAE
#run_trainings(all_configs, keyword='/aeage') # AgeAE
#run_trainings(all_configs, keyword='/full') # ADAE

# Evaluations

### MAE dynamics with epochs (for one specific model configuration)

In [None]:
# Load the training dataset
full_X, full_y, _= load_data(
    dataset='mouse', 
    tissue='BOTH_NOFILTER')

# Load the ALL_FACS data with genes matched to the BOTH_NOFILTER tissue
X_test, y_test, _ = load_data_match_genes('mouse',
                    'ALL_FACS',
                    match_genes_with_tissue='BOTH_NOFILTER')

# NOTE(review): this list also contains '3_10_M/3_11_M', which is absent from
# `test_donors` in the default config cell - confirm which one is intended.
donor_tests = ['1-M-62', '3_10_M', '3_10_M/3_11_M', '3_11_M', '3_39_F', '3-M-5/6', '18_53_M', '18-M-53', '21_48_F', '24_61_M', '24-M-61', '30-M-4']

# Keep the train part of the main dataset and the test part of the ALL_FACS dataset
X_train, _, y_train, _ = tt_split_setup(full_X, full_y, donor_test_list=donor_tests)
_, X_test, _, y_test = tt_split_setup(X_test, y_test, donor_test_list=donor_tests)
del full_X, full_y
# Resample the training data to a fixed size of 128000 rows with a fixed seed
X_train, y_train = resample(X_train, y_train, n_samples=128000, random_state=42)


# Train baseline model
baseline_model = Lasso()
baseline_model.fit(X_train, y_train.age)

train_baseline_predictions = baseline_model.predict(X_train)
test_baseline_predictions = baseline_model.predict(X_test)

# Baseline metrics on the training data
train_r2 = r2_score(y_true=y_train.age, y_pred=train_baseline_predictions)
train_mse = mean_squared_error(train_baseline_predictions, y_train.age)
train_mae = mean_absolute_error(train_baseline_predictions, y_train.age)

# Baseline metrics on the test data
test_r2 = r2_score(y_true=y_test.age, y_pred=test_baseline_predictions)
test_mse = mean_squared_error(test_baseline_predictions, y_test.age)
test_mae = mean_absolute_error(test_baseline_predictions, y_test.age)

print(f"Train: R2: {train_r2}, MSE: {train_mse}, MAE: {train_mae}")
print(f"Test: R2: {test_r2}, MSE: {test_mse}, MAE: {test_mae}")

scenario = 'WHOLE_CROSS'
configuration = 'triple_BOTH_NOFILTER'

# epochs_range is passed as a (0, 151, 3) tuple
# (presumably start, stop, step over saved checkpoints - confirm in evaluate_model)
evaluate_model(X_train, y_train, X_test, y_test, epochs_range=(0, 151, 3), scenario=scenario, configuration=configuration)

### MAE and R2 compared across train, test and cross-tissue data for AE, AgeAE and AgeADAE

1. Evaluate Predictions

In [None]:
# Evaluate Predictions
#
# For every embedding model (AgeADAE, AgeAE, AE) collect Ridge / Lasso / AgeNN
# predictions and metrics on the train, test and cross-tissue data.
# `results` keeps the order: model by model, and Train, Test, Cross within a model.

import numpy as np
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.linear_model import Ridge, Lasso

# General loading
full_X, full_y, _ = load_data(
    dataset='mouse',
    tissue='BOTH_NOFILTER',
    filtered=True,
    normalize=False)

donor_tests = ['1-M-62', '3_10_M', '3_10_M/3_11_M', '3_11_M', '3_39_F', '3-M-5/6', '18_53_M', '18-M-53', '21_48_F', '24_61_M', '24-M-61', '30-M-4']
# Test Data
X_train, X_test, y_train, y_test = tt_split_setup(full_X, full_y, donor_test_list=donor_tests)
del full_X, full_y
# Train Data
X_train, y_train = resample(X_train, y_train, n_samples=128000, random_state=42)

# Cross-tissue Data
X_test_cross, y_test_cross, _ = load_data_match_genes('mouse',
                    'ALL_FACS',
                    'BOTH_NOFILTER',
                    True)
_, X_test_cross, _, y_test_cross = tt_split_setup(X_test_cross, y_test_cross, donor_test_list=donor_tests)

# (display label, experiment configuration, checkpoint epoch) for each embedding model
evaluated_models = (
    ("AgeADAE", 'triple_BOTH_NOFILTER', 168),
    ("AgeAE", 'aeage_BOTH_NOFILTER', 123),
    ("AE", 'ae_BOTH_NOFILTER', 5),
)

# (display label, features, targets) for each evaluation split
evaluation_splits = (
    ("Train", X_train, y_train),
    ("Test", X_test, y_test),
    ("Cross", X_test_cross, y_test_cross),
)

# The regressors inside get_predictions are always fitted on the training data;
# only the data they are evaluated on changes between splits.
results = []
for model_label, configuration, epoch in evaluated_models:
    for split_label, X_eval, y_eval in evaluation_splits:
        results.append(get_predictions(X_train, y_train, X_eval, y_eval, configuration, epoch, f"{model_label} {split_label}"))


2. Plot Results

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Plots the metrics collected in `results` (one dict per model/dataset pair, as
# returned by get_predictions): MAE bars, R2 bars and absolute-error box plots.

# Display order of the embedding models, the regressors fitted on the embedding
# (result-key prefix, legend label, colour) and the evaluated datasets
model_order = ['AE', 'AgeAE', 'AgeADAE']
regressors = [('ridge', 'Ridge', 'b'), ('lasso', 'Lasso', 'orange'), ('agenn', 'AgeNN', 'g')]
datasets = ['Train', 'Test', 'Cross']

# Keep only the summary entries of every result in a DataFrame
summary_keys = [
    "name",
    "ridge_mae", "lasso_mae", "agenn_mae",
    "ridge_r2", "lasso_r2", "agenn_r2",
    "ridge_abs_diff", "lasso_abs_diff", "agenn_abs_diff",
]
results_data = [{key: result[key] for key in summary_keys} for result in results]
df_results = pd.DataFrame(results_data)


def _model_rank(name):
    # Rank by the model prefix of the name. A substring test such as 'AE ' in name
    # also matches 'AgeAE ...' and 'AgeADAE ...', so it cannot tell the models apart.
    prefix = name.split(' ', 1)[0]
    return model_order.index(prefix) if prefix in model_order else len(model_order)


# Sort the DataFrame for consistent order: AE, AgeAE, AgeADAE
df_results['name_order'] = df_results['name'].apply(_model_rank)
df_results = df_results.sort_values(by=['name_order', 'name'])

# Remove the additional column used for sorting
df_results.drop(columns='name_order', inplace=True)


def _dataset_title(dataset):
    # Subplot title for a dataset label
    return 'Cross-Tissue Data' if dataset == 'Cross' else f'{dataset} Data'


def _plot_metric_bars(metric, ylabel, suptitle, label_offset):
    # One subplot per dataset, one group of bars per model, one bar per regressor.
    # `metric` is the column suffix ('mae' or 'r2'); `label_offset` lifts the
    # numeric annotation above its bar.
    fig, ax = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
    width = 0.25  # width of each bar

    for i, dataset in enumerate(datasets):
        dataset_results = df_results[df_results['name'].str.contains(dataset)]

        # Ensure the correct order: AE, AgeAE, AgeADAE
        dataset_results = (
            dataset_results.set_index('name')
            .reindex([f'{model} {dataset}' for model in model_order])
            .reset_index()
        )
        x = np.arange(len(dataset_results))

        for shift, (key, label, color) in zip((-width, 0, width), regressors):
            values = dataset_results[f'{key}_{metric}']
            valid = values.notna().to_numpy()
            if not valid.any():
                continue  # nothing to draw for this regressor on this dataset
            ax[i].bar(x[valid] + shift, values[valid], width, label=label, alpha=0.7, color=color)

            # Numerical annotations on top of the bars
            for position, value in zip(x[valid], values[valid]):
                ax[i].text(position + shift, value + label_offset, round(value, 2), ha='center', fontsize=12)

        ax[i].set_xticks(x)
        ax[i].set_xticklabels([name.replace(dataset, '').strip() for name in dataset_results['name']], fontsize=14)
        ax[i].set_title(_dataset_title(dataset), fontsize=16, fontweight='bold')
        ax[i].set_ylabel(ylabel, fontsize=14)
        ax[i].legend(fontsize=14)
        ax[i].grid(True, which='both', axis='y', linestyle='--', linewidth=0.5)

    plt.suptitle(suptitle, fontsize=20, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Plotting MAE comparison
_plot_metric_bars('mae', 'MAE', 'MAE Comparison across Models and Datasets', 0.1)

# Plotting R2 comparison
_plot_metric_bars('r2', 'R2 Score', 'R2 Score Comparison across Models and Datasets', 0.01)

# Additional insights using absolute differences
fig, ax = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
width = 0.2  # width of each boxplot group
for i, dataset in enumerate(datasets):
    dataset_results = [result for result in results if dataset in result["name"]]
    dataset_results = sorted(dataset_results, key=lambda r: model_order.index(r['name'].split()[0]))
    x = np.arange(len(dataset_results))

    legend_handles = []
    legend_labels = []
    for shift, (key, label, color) in zip((-width, 0, width), regressors):
        # Skip missing entries but keep every box at the position of its own model
        present = [k for k, result in enumerate(dataset_results) if result[f"{key}_abs_diff"] is not None]
        if not present:
            continue
        box = ax[i].boxplot([dataset_results[k][f"{key}_abs_diff"] for k in present],
                            positions=x[present] + shift, widths=width, patch_artist=True,
                            boxprops=dict(facecolor=color, color=color),
                            medianprops=dict(color='red'))

        # Make outlier dots grey; done per boxplot so every subplot is styled,
        # not only the last one drawn
        for flier in box['fliers']:
            flier.set(marker='o', color='grey', alpha=0.5)

        legend_handles.append(box["boxes"][0])
        legend_labels.append(label)

    # boxplot() moves the ticks to the box positions, so reset them afterwards
    ax[i].set_xticks(x)
    ax[i].set_xticklabels([result["name"].replace(dataset, '').strip() for result in dataset_results], fontsize=14)
    ax[i].set_title(_dataset_title(dataset), fontsize=16, fontweight='bold')
    ax[i].set_ylabel('Absolute Differences', fontsize=14)
    ax[i].legend(legend_handles, legend_labels, fontsize=14)

for a in ax:
    a.grid(True, which='both', axis='y', linestyle='--', linewidth=0.5)
plt.suptitle('Absolute Prediction Errors Distribution', fontsize=20, fontweight='bold')
plt.tight_layout()
plt.show()


3. AgeNN visual evaluation

In [None]:
# Load the AgeADAE checkpoint and prepare 2-D PCA projections of the train and
# test embeddings for the plots in the next cell.
scenario = 'WHOLE_CROSS'
configuration = 'triple_BOTH_NOFILTER'
epoch = 168

config = load_config(f'/Users/mindblaze/Desktop/Thesis/clocks/experiments/{scenario}/{configuration}/model_config.yaml')
model = ADAE(**config)
model.to('cpu')
model.eval()
model.mode = 'age'
load_pretrained(model, f"/Users/mindblaze/Desktop/Thesis/clocks/experiments/{scenario}/{configuration}/models/epoch_{epoch}.pth")

# results[0] / results[1] are the first two entries appended by the
# "Evaluate Predictions" cell (the "AgeADAE Train" and "AgeADAE Test" runs)
real_age = y_train.age
predicted_age = results[0]['agenn_pred']
real_test_age = y_test.age
predicted_test_age = results[1]['agenn_pred']

# Embed the train and test data with the encoder
X_emb = model.encoder(torch.tensor(X_train)).detach().numpy()
X_test_emb = model.encoder(torch.tensor(X_test)).detach().numpy()

# NOTE(review): train and test embeddings are projected with two independently
# fitted PCAs, so their axes are not comparable between the two figures - confirm
# this is intended (alternative: pca.transform(X_test_emb)).
pca = PCA(n_components=2)
X_train_pca = pca.fit_transform(X_emb)
test_pca = PCA(n_components=2)
X_test_pca = test_pca.fit_transform(X_test_emb)

In [None]:
# Figure 1: train embedding (PCA) coloured by the two categorical annotations.
# The titles name what is actually used as hue; both panels get a legend because
# there is no colorbar to explain categorical colours.
fig, axes = plt.subplots(1,2,figsize=(14,6))

sns.scatterplot(x=X_train_pca[:, 0], y=X_train_pca[:, 1], hue=y_train.method, ax=axes[0], palette='viridis_r').set(title='Method')
sns.scatterplot(x=X_train_pca[:, 0], y=X_train_pca[:, 1], hue=y_train.sex, ax=axes[1], palette='viridis_r').set(title='Sex')

plt.subplots_adjust(wspace=0.01, hspace=0.1)

for ax in axes.flat:
    ax.set_xticklabels([])  # Remove x-axis tick labels
    ax.set_yticklabels([])  # Remove y-axis tick labels

plt.show()

# Figure 2: test embedding (PCA) coloured by real and predicted age.
# The same normalisation is passed to the scatter plots (hue_norm) and to the
# colorbar; without hue_norm seaborn scales the colours to each panel's own data
# range and the 0-30 colorbar would not describe them.
norm = plt.Normalize(0, 30)

fig, axes = plt.subplots(1,2,figsize=(14,6))

sns.scatterplot(x=X_test_pca[:, 0], y=X_test_pca[:, 1], hue=real_test_age, hue_norm=norm, ax=axes[0], palette='viridis_r').set(title=f'Real Age')
sns.scatterplot(x=X_test_pca[:, 0], y=X_test_pca[:, 1], hue=predicted_test_age, hue_norm=norm, ax=axes[1], palette='viridis_r', legend=False).set(title=f'Predicted Age')

sm = plt.cm.ScalarMappable(cmap="viridis_r", norm=norm)
sm.set_array([])
fig.colorbar(sm, ax=axes[0])
fig.colorbar(sm, ax=axes[1])
plt.subplots_adjust(wspace=0.01, hspace=0.1)

for ax in axes.flat:
    ax.set_xticklabels([])  # Remove x-axis tick labels
    ax.set_yticklabels([])  # Remove y-axis tick labels

plt.show()


### Visual for a specific model

In [None]:
def eval_model(model_path, features=('method', 'age',), cross_tissue=False, tsne=False, tissue='Train_LSK'):
    """Plot 2-D PCA projections of original, encoded and restored data for a checkpoint.

    The model config is read from ``model_config.yaml`` two directory levels
    above ``model_path``. For every entry of ``features`` one figure with three
    scatter plots (original / encoded / restored) coloured by that metadata
    column is shown.

    Args:
        model_path: Path to the checkpoint passed to ``load_pretrained``.
        features: Metadata columns used as hue, one figure per column.
        cross_tissue: Currently unused; kept for backward compatibility.
        tsne: Currently unused; kept for backward compatibility.
        tissue: Name passed to ``load_data_match_genes`` both as the tissue to
            load and as the tissue to match genes with.
    """
    model_config_path = os.path.join(os.path.dirname(os.path.dirname(model_path)), 'model_config.yaml')
    model_config = load_config(model_config_path)
    model = ADAE(**model_config)
    load_pretrained(model, model_path)
    # Inference only: eval mode switches off training-only layers such as dropout
    # (the config defines an encoder dropout), which would otherwise add noise to
    # the embedding and the reconstruction.
    model.to('cpu')
    model.eval()

    data, metadata, _ = load_data_match_genes(
                            dataset='mouse',
                            tissue=tissue,
                            match_genes_with_tissue=tissue,
                            filtered=True,
                            verbose=False)

    # Encode once and reuse the embedding for the reconstruction; no gradients
    # are needed for plotting
    with torch.no_grad():
        encoded = model.encoder(torch.tensor(data))
        restored = model.decoder(encoded)

    pca_orig = PCA(n_components=2)
    pca_enc = PCA(n_components=2)
    pca_res = PCA(n_components=2)

    p_orig = pca_orig.fit_transform(data)
    p_enc = pca_enc.fit_transform(encoded.numpy())
    p_res = pca_res.fit_transform(restored.numpy())

    explained_variance_ratio_orig = pca_orig.explained_variance_ratio_
    explained_variance_ratio_enc = pca_enc.explained_variance_ratio_
    explained_variance_ratio_res = pca_res.explained_variance_ratio_

    for feature in features:

        fig, axes = plt.subplots(1,3,figsize=(16,6))

        sns.scatterplot(x=p_orig[:, 0], y=p_orig[:, 1], hue=metadata[feature], ax=axes[0]).set(title=f'Original | C1: {explained_variance_ratio_orig[0]*100:.1f}% C2: {explained_variance_ratio_orig[1]*100:.1f}%')
        sns.scatterplot(x=p_enc[:, 0], y=p_enc[:, 1], hue=metadata[feature], ax=axes[1]).set(title=f'Encoded | C1: {explained_variance_ratio_enc[0]*100:.1f}% C2: {explained_variance_ratio_enc[1]*100:.1f}%')
        sns.scatterplot(x=p_res[:, 0], y=p_res[:, 1], hue=metadata[feature], ax=axes[2]).set(title=f'Restored | C1: {explained_variance_ratio_res[0]*100:.1f}% C2: {explained_variance_ratio_res[1]*100:.1f}%')
        fig.suptitle(f'PCA comparison: train')

        fig.show()

In [None]:
# Plot the original / encoded / restored PCA comparison for one saved checkpoint,
# coloured in turn by each of the listed metadata columns.
eval_model(f'experiments/TRIPLE_COMBO/new_best_normalized/models/epoch_150.pth', features=('method','age', 'tissue', 'sex'))

### AgeNN Visuals Train Test

In [None]:
# Load the dataset and keep the third return value as feature_names
full_X, full_y, feature_names= load_data(
    dataset='mouse', 
    tissue='BOTH_NOFILTER', 
    filtered=True,
    normalize=False)

# NOTE(review): this donor list differs from the one used in the evaluation cells
# above - confirm it matches the split the loaded model was trained with.
donor_tests = ['1-M-62', '3_57_F', '3-F-57', '18_45_M', '21_48_F', '24_60_M', '24-M-60', '30-M-4']

X_train, X_test, y_train, y_test = tt_split_setup(full_X, full_y, donor_test_list=donor_tests)
del full_X, full_y

# Resample the training data to a fixed size of 128000 rows with a fixed seed
X_train, y_train = resample(X_train, y_train, n_samples=128000, random_state=42)

In [None]:
# AgeNN visuals: PCA of the train / test embeddings, coloured by method (left
# column) and by the age predicted from the embedding (right column).

# NOTE(review): the config comes from 'aeage_BOTH_NOFILTER' while the weights come
# from 'aeage_BOTH_NOFILTER_cont' - confirm both describe the same architecture.
config = load_config('experiments/WHOLE_IDEAL/aeage_BOTH_NOFILTER/model_config.yaml')
model = ADAE(**config)
load_pretrained(model, 'experiments/WHOLE_IDEAL/aeage_BOTH_NOFILTER_cont/models/epoch_8.pth')
model.to('cpu')
model.eval()
model.mode = 'age'

pca_train = PCA(n_components=2)
pca_test = PCA(n_components=2)

# Encode each split once and reuse the embedding for both the PCA and the age
# prediction; gradients are not needed for inference
with torch.no_grad():
    emb_train = model.encoder(torch.tensor(X_train))
    emb_test = model.encoder(torch.tensor(X_test))
    age_train = model.predictor(emb_train).numpy().flatten()
    age_test = model.predictor(emb_test).numpy().flatten()

p_train = pca_train.fit_transform(emb_train.numpy())
p_test = pca_test.fit_transform(emb_test.numpy())

# Placeholders kept from the original cell; not used by the plots below
explained_variance_ratio_train = 1
explained_variance_ratio_test = 1

# One normalisation shared by the age-coloured scatter plots and their colorbars;
# without hue_norm seaborn scales colours to each panel's own data range and the
# 0-40 colorbar would not describe them.
norm = plt.Normalize(0, 40)

fig, axes = plt.subplots(2,2,figsize=(14,12))

# Titles name the variable that is actually used as hue
sns.scatterplot(x=p_train[:, 0], y=p_train[:, 1], hue=y_train['method'], ax=axes[0][0], palette='viridis_r').set(title='Train Method')
sns.scatterplot(x=p_train[:, 0], y=p_train[:, 1], hue=age_train, hue_norm=norm, ax=axes[0][1], palette='viridis_r', legend=False).set(title=f'Train Predicted Age')
sns.scatterplot(x=p_test[:, 0], y=p_test[:, 1], hue=y_test['method'], ax=axes[1][0], palette='viridis_r').set(title='Test Method')
sns.scatterplot(x=p_test[:, 0], y=p_test[:, 1], hue=age_test, hue_norm=norm, ax=axes[1][1], palette='viridis_r', legend=False).set(title=f'Test Predicted Age')

# The age colorbar only belongs to the age-coloured panels; the method panels
# are explained by their legends
sm = plt.cm.ScalarMappable(cmap="viridis_r", norm=norm)
sm.set_array([])
fig.colorbar(sm, ax=axes[0][1])
fig.colorbar(sm, ax=axes[1][1])
plt.subplots_adjust(wspace=0.01, hspace=0.1)

for ax in axes.flat:
    ax.set_xticklabels([])  # Remove x-axis tick labels
    ax.set_yticklabels([])  # Remove y-axis tick labels

plt.show()

### PCAs for Original, encoded and restored data

In [None]:
# PCA of the original, encoded and restored training data for one checkpoint,
# one figure per metadata column.
config = load_config('experiments/WHOLE_IDEAL/aeage_BOTH_NOFILTER_cont/model_config.yaml')
model = ADAE(**config)
load_pretrained(model, 'experiments/WHOLE_IDEAL/aeage_BOTH_NOFILTER_cont/models/epoch_8.pth')
model.to('cpu')
model.eval()
model.mode = 'ae'

pca_orig = PCA(n_components=2)
pca_enc = PCA(n_components=2)
pca_res = PCA(n_components=2)

# Encode once and decode that same embedding instead of running the encoder
# twice over the whole training set; no gradients are needed for plotting
with torch.no_grad():
    encoded_train = model.encoder(torch.tensor(X_train))
    restored_train = model.decoder(encoded_train)

p_orig = pca_orig.fit_transform(X_train)
p_enc = pca_enc.fit_transform(encoded_train.numpy())
p_res = pca_res.fit_transform(restored_train.numpy())

explained_variance_ratio_orig = pca_orig.explained_variance_ratio_
explained_variance_ratio_enc = pca_enc.explained_variance_ratio_
explained_variance_ratio_res = pca_res.explained_variance_ratio_

for feature in ('method','age', 'tissue', 'sex'):

    fig, axes = plt.subplots(1,3,figsize=(18,6))

    sns.scatterplot(x=p_orig[:, 0], y=p_orig[:, 1], hue=y_train[feature], ax=axes[0]).set(title=f'Original | C1: {explained_variance_ratio_orig[0]*100:.1f}% C2: {explained_variance_ratio_orig[1]*100:.1f}%')
    sns.scatterplot(x=p_enc[:, 0], y=p_enc[:, 1], hue=y_train[feature], ax=axes[1]).set(title=f'Encoded | C1: {explained_variance_ratio_enc[0]*100:.1f}% C2: {explained_variance_ratio_enc[1]*100:.1f}%')
    sns.scatterplot(x=p_res[:, 0], y=p_res[:, 1], hue=y_train[feature], ax=axes[2]).set(title=f'Restored | C1: {explained_variance_ratio_res[0]*100:.1f}% C2: {explained_variance_ratio_res[1]*100:.1f}%')
    fig.suptitle(f'PCA comparison: train')

    fig.show()

# Further analyses

### Categorical distribution analysis

1. Prep

In [None]:
import pandas as pd
from itertools import product
from tabulate import tabulate

def evaluate_category_counts(y_train, categories):
    """Print per-category and per-combination sample counts.

    Args:
        y_train: Data from which a DataFrame restricted to ``categories`` is built.
        categories: Column names to summarise.

    Prints the value counts of every column, followed by a grid table with the
    number of rows for every combination of the columns' unique values
    (combinations that never occur are listed with a count of 0).
    """
    frame = pd.DataFrame(y_train, columns=categories)

    # Counts per single category
    print("Number of samples per category:")
    for column in categories:
        per_value = frame[column].value_counts()
        print(f"\n{column}:")
        print(per_value)

    # Counts per combination: every element of the cartesian product of the
    # unique values is checked, whether or not it occurs in the data
    print("\nNumber of samples per category combination:")
    unique_levels = [frame[column].unique() for column in categories]
    counts_by_combo = {}
    for combo in product(*unique_levels):
        matches = True
        for column, level in zip(categories, combo):
            matches = matches & (frame[column] == level)
        counts_by_combo[combo] = matches.sum()

    # Tabulate the counts with one index level per category
    table = pd.DataFrame.from_dict(counts_by_combo, orient='index', columns=['Count'])
    table.index = pd.MultiIndex.from_tuples(table.index, names=categories)

    print(tabulate(table, headers='keys', tablefmt='grid'))


import pandas as pd

import numpy as np
import pandas as pd

def create_balanced_sample(X, y, categories, samples_per_category):
    """Draw at most ``samples_per_category`` rows for every category combination.

    Rows with a missing value in any of ``categories`` are ignored. Sampling is
    done without replacement through ``np.random.choice``, so results depend on
    NumPy's global random state.

    Args:
        X: Array-like of samples; row ``i`` belongs to row ``i`` of ``y``.
        y: DataFrame with one row per sample, containing ``categories``.
        categories: List of column names of ``y`` that define a combination.
        samples_per_category: Maximum number of rows drawn per combination.

    Returns:
        Tuple ``(balanced_X, balanced_y)``: the selected rows of ``X`` and a
        DataFrame holding the ``categories`` values of exactly those rows, in
        the same order.

    Raises:
        ValueError: If ``X`` and ``y`` do not have the same number of rows.
    """
    # Convert X to a NumPy array
    X = np.asarray(X)

    labels = y[categories].reset_index(drop=True)
    if len(labels) != len(X):
        raise ValueError(
            f"X and y must have the same number of rows, got {len(X)} and {len(labels)}"
        )

    # Remember the ORIGINAL row positions of the complete rows. Selecting by
    # position inside the filtered frame would index the wrong rows of X as
    # soon as a single row had been dropped.
    valid_positions = np.flatnonzero(labels.notna().all(axis=1).to_numpy())
    valid_labels = labels.iloc[valid_positions]
    label_values = valid_labels.values

    # Get the unique category combinations (in order of first appearance)
    unique_combinations = valid_labels.drop_duplicates().values

    selected_positions = []
    for combination in tqdm(unique_combinations):
        # Original positions of the samples matching the current combination
        mask = np.all(label_values == combination, axis=1)
        indices = valid_positions[mask]

        # A combination with fewer rows than requested contributes all of them
        num_samples = min(len(indices), samples_per_category)
        selected_positions.append(np.random.choice(indices, size=num_samples, replace=False))

    if selected_positions:
        positions = np.concatenate(selected_positions)
    else:
        positions = np.empty(0, dtype=np.intp)

    balanced_X = X[positions]
    # Take the labels from the very rows that were selected; rebuilding them by
    # repeating every combination `samples_per_category` times mislabels rows
    # whenever a combination is smaller than that.
    balanced_y = labels.iloc[positions].reset_index(drop=True)

    return balanced_X, balanced_y

import numpy as np

def get_worst_predictions_subset(model, X, y_age, percentile=80):
    """Select the samples of ``X`` that ``model`` predicts worst.

    The absolute error between ``model(X)`` (squeezed, detached and converted
    to NumPy) and ``y_age`` is computed per sample; samples whose error reaches
    the given percentile of all errors are returned.

    Args:
        model: Callable producing a tensor-like output for ``X``.
        X: Inputs, indexable by an array of row positions.
        y_age: True values aligned with the rows of ``X``.
        percentile: Error percentile used as the inclusive threshold.

    Returns:
        The subset of ``X`` with errors greater than or equal to the threshold.
    """
    model_output = model(X).squeeze().detach().numpy()
    deviation = np.abs(model_output - y_age)

    # Inclusive comparison: a sample exactly at the threshold is kept
    is_worst = deviation >= np.percentile(deviation, percentile)
    row_positions = np.where(is_worst)[0]

    return X[row_positions]


2. Test

In [None]:
# load_data is unpacked into three values everywhere else in this notebook
# (features, metadata, feature names); unpacking into two names would raise.
X, y, _ = load_data('mouse', 'WHOLE_CROSS')

In [None]:
# Print sample counts per category and per category combination for the loaded metadata
categories = ['age', 'sex', 'method', 'tissue']
evaluate_category_counts(y, categories)

In [None]:
# Draw up to 150 training samples per (age, sex, method, tissue) combination.
# NOTE(review): the balanced labels are assigned to `y`, which replaces the
# metadata loaded two cells above - confirm this rebinding is intended.
categories = ['age', 'sex', 'method', 'tissue']
X_train_balanced, y = create_balanced_sample(X_train, y_train, categories, 150)

### SHAP for Human (basic summary)

In [None]:
# Donors held out for testing
donor_tests = ['TSP3', 'TSP6', 'TSP9']

full_X, full_y, feature_names = load_data(
    dataset='sapiens', 
    tissue='TS_EES_hv', 
    filtered=True)

X_train, X_test, y_train, y_test = tt_split_setup(full_X, full_y, donor_test_list=donor_tests)
del full_X, full_y

import shap

# Background set for the explainer: up to 30 training samples per
# (age, method, cell_ontology_class) combination
categories = ['age', 'method', 'cell_ontology_class']
train_background, _ = create_balanced_sample(X_train, y_train, categories, 30)
print(train_background.shape)

# Samples to explain: 2000 test rows resampled with a fixed seed
test_data_split = resample(X_test, n_samples=2000, random_state=42)

# Convert both sets to float tensors
train_background = torch.from_numpy(train_background).float()
test_data_split = torch.from_numpy(test_data_split).float()

# Drop the full arrays; only the background and the explained samples are used from here on
del y_train, X_train, X_test, y_test

In [None]:
# Load the checkpoint to explain and switch it to inference in 'age' mode
config = load_config(f'experiments/TS_EES_hv2/triple_TS_EES_hv/model_config.yaml')
model = ADAE(**config)
load_pretrained(model, f'experiments/TS_EES_hv2/triple_TS_EES_hv/models/epoch_138.pth')
model.to('cpu')
model.eval()
model.mode = 'age'

# Assuming model, train_background, and test_data_split are already defined
explainer = shap.DeepExplainer(model, train_background)
# The explainer has been built from the background; drop this notebook's reference to it
del train_background

In [None]:
# Compute SHAP values one sample at a time, with a tqdm progress bar.
# A failing sample is retried: up to three failures are reported (with the
# sample index and the error) and a fourth attempt is allowed to raise.
retry_prefixes = ('', '-> ', '->-> ')

shap_values = []
for i in tqdm(range(len(test_data_split)), desc="Computing SHAP values"):
    sample = test_data_split[i:i+1]
    for prefix in retry_prefixes:
        try:
            shap_value = explainer.shap_values(sample)
            break
        except Exception as err:  # KeyboardInterrupt etc. still stop the loop
            print(f'{prefix}error at sample {i}: {err}')
    else:
        # All reported attempts failed: try once more and let the error propagate
        shap_value = explainer.shap_values(sample)
    shap_values.append(shap_value)

# One row per explained sample; the feature dimension is inferred instead of hard-coded
shap_values = np.array(shap_values).reshape(len(test_data_split), -1)

np.save(f'HUMAN_SHAP_EES_hv.npy', shap_values)

shap.summary_plot(shap_values, test_data_split, feature_names=feature_names, show=False)
plt.title(f'SHAP Summary for Human')
plt.show()

## SHAP & Correlations for Cell Types: evaluation and analysis

### Evaluate

In [None]:
import shap
def shap_cell_type(
    train_background,
    test_data_split,
    tissue,
    config_path='/Users/mindblaze/Desktop/Thesis/clocks/experiments/WHOLE_CROSS/triple_BOTH_NOFILTER/model_config.yaml',
    weights_path='/Users/mindblaze/Desktop/Thesis/clocks/experiments/WHOLE_CROSS/triple_BOTH_NOFILTER/models/epoch_235.pth',
    output_dir='SHAP_test_cell_type',
):
    """Compute, save and plot SHAP values of the age model for one group of cells.

    Args:
        train_background: NumPy array used as the DeepExplainer background set.
        test_data_split: NumPy array with the cells to explain.
        tissue: Label of the group; used in the output file name and plot title.
        config_path: Model config YAML passed to ``load_config``.
        weights_path: Checkpoint passed to ``load_pretrained``.
        output_dir: Folder receiving ``whole_<tissue>.npy``; created if missing.

    Returns:
        The value returned by ``explainer.shap_values(test_data_split)``.
    """
    train_background = torch.from_numpy(train_background).float()
    test_data_split = torch.from_numpy(test_data_split).float()

    config = load_config(config_path)
    model = ADAE(**config)
    load_pretrained(model, weights_path)
    model.to('cpu')
    model.eval()
    # NOTE(review): presumably selects the age output as the one to explain — confirm in ADAE
    model.mode = 'age'

    explainer = shap.DeepExplainer(model, train_background)

    shap_values_train = explainer.shap_values(test_data_split)

    # Create the folder first so the expensive SHAP result is not lost to a
    # missing-directory error at save time
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, f'whole_{tissue}.npy'), shap_values_train)

    shap.summary_plot(shap_values_train, test_data_split, feature_names=feature_names, show=False)
    plt.title(f'SHAP Summary for {tissue}')
    plt.show()

    return shap_values_train

In [None]:
# Cell types that are skipped below as already completed
completed = ('B cell', 'T cell', 'macrophage', 'granulocyte')

# Background set for the explainer, drawn from the training data by
# create_balanced_sample over these metadata columns (100 is its size argument)
categories = ['age', 'sex', 'method', 'tissue']
train_background, _ = create_balanced_sample(X_train, y_train, categories, 100)

# Per-cell-type SHAP arrays as returned by shap_cell_type
shap_values_dict = {}

# Perform SHAP analysis for each cell type
for cell_type in y_test['cell_ontology_class'].unique():
    # Skip cell types with fewer than 300 test cells
    if y_test[y_test['cell_ontology_class'] == cell_type].shape[0] < 300:
        print(f"Excluded: {cell_type} (N = {y_test[y_test['cell_ontology_class'] == cell_type].shape[0]})")
        continue
    # Skip cell types listed in `completed`
    if cell_type in completed:
        print(f"Excluded: {cell_type} (already completed)")
        continue

    print(X_test[y_test.cell_ontology_class == cell_type].shape[0], cell_type)
    # Draw 300 cells of this type with a fixed seed (sklearn's resample samples with replacement by default)
    test_data_split = resample(X_test[y_test.cell_ontology_class == cell_type], n_samples=300, random_state=42)

    # Compute SHAP values for the specific cell type
    shap_values = shap_cell_type(train_background, test_data_split, cell_type)
    shap_values_dict[cell_type] = shap_values

### Analyse

#### Basic Plots

In [None]:
# Only the feature (gene) names are kept; the other two values returned by
# load_data are discarded
_, _, feature_names = load_data(
        dataset='mouse', 
        tissue='BOTH_NOFILTER', 
        filtered=True,
        normalize=False)

In [None]:
import os
import numpy as np

# Folder holding the saved SHAP arrays, one ``whole_<cell type>.npy`` per cell type
directory = 'SHAP_test_cell_type/'

# Map each cell type (file name without the "whole_" prefix and ".npy" suffix)
# to the array stored in that file
shap_values_dict = {
    filename[len("whole_"):-len(".npy")]: np.load(os.path.join(directory, filename))
    for filename in os.listdir(directory)
    if filename.endswith(".npy")
}

In [None]:
# One DataFrame per cell type, with columns labelled by feature_names
shap_dfs = {k: pd.DataFrame(v, columns=feature_names) for k, v in shap_values_dict.items()}

# Calculate mean absolute SHAP values for each feature and cell type
mean_shap_values = {k: df.abs().mean() for k, df in shap_dfs.items()}

# Combine into a single DataFrame: features as rows, cell types as columns
mean_shap_df = pd.DataFrame(mean_shap_values)

In [None]:
# Convert mean SHAP values DataFrame for heatmap plotting
# Transpose so that cell types are rows and features are columns
mean_shap_matrix = mean_shap_df.T

# Plot the heatmap
plt.figure(figsize=(15, 10))
sns.heatmap(mean_shap_matrix, cmap='viridis')
plt.title('Heatmap of Mean Absolute SHAP Values Across Cell Types')
plt.xlabel('Features')
plt.ylabel('Cell Type')
plt.show()

#### Normalized Plot

In [None]:
N = 20000  # Number of top features to visualize

# Convert SHAP values to DataFrame for easier processing
shap_dfs = {k: pd.DataFrame(v, columns=feature_names) for k, v in shap_values_dict.items()}

# Calculate mean absolute SHAP values for each feature and cell type
mean_shap_values = {k: df.abs().mean() for k, df in shap_dfs.items()}

# Find the top N features by mean absolute SHAP value across all cell types
# (the per-cell-type means are averaged, so every cell type has equal weight)
top_features = pd.concat(mean_shap_values, axis=1).mean(axis=1).nlargest(N).index

# Filter SHAP values to include only the top N features
# (these are the signed per-cell SHAP values, not the absolute ones)
filtered_shap_values_dict = {k: df[top_features] for k, df in shap_dfs.items()}

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Combine SHAP values for all cell types into a single DataFrame
# (concatenating a dict keys the rows by (cell type, row number))
combined_shap_df = pd.concat(filtered_shap_values_dict, axis=0)
# Name the two index levels
combined_shap_df.index = pd.MultiIndex.from_tuples(combined_shap_df.index, names=['Cell Type', 'Sample'])

# Calculate the mean (signed) SHAP value for each feature within each cell type
mean_shap_matrix = combined_shap_df.groupby(level=0).mean()

# Plot the heatmap with clustering
# standard_scale=1 asks seaborn to rescale each column (feature) — see the clustermap docs
g = sns.clustermap(
    mean_shap_matrix, 
    cmap='coolwarm', 
    figsize=(30, 20), 
    standard_scale=1, 
    method='average', 
    metric='euclidean'
)

# Hide the top dendrogram
g.ax_col_dendrogram.set_visible(False)

# Hide the left dendrogram
g.ax_row_dendrogram.set_visible(False)

# Move y labels to the left
g.ax_heatmap.yaxis.set_label_position("left")
g.ax_heatmap.yaxis.tick_left()

# Adjust the position of the colorbar to the right of the plot
g.cax.set_position([.88, .2, .02, .45])

# Remove the title
g.ax_heatmap.set_title("")

# Remove x labels
g.ax_heatmap.set_xticklabels([])

plt.show()


#### Significantly high/low SHAP genes

1. Find cell-type-specific genes with extreme SHAP values (unusually low or unusually high compared with the other cell types)

In [None]:
from scipy.stats import zscore

# Assuming mean_shap_matrix is a DataFrame with genes as columns and cell types as rows

#normalized_shap_matrix = mean_shap_matrix.apply(normalize_min_max, axis=1)
# z-score every gene (column) across cell types, so each value says how far a
# cell type deviates from the other cell types for that gene
normalized_shap_matrix = mean_shap_matrix.apply(zscore, axis=0)

# z-score cut-offs used to flag a gene as extreme for a cell type.
# These are far stricter than |z| > 1.96 (the usual two-sided 5% level).
low_shap_threshold = -7
high_shap_threshold = 7

# Collect, per cell type, the genes whose score falls outside the two thresholds
def find_significant_low_shap_genes(normalized_shap_matrix, low_threshold, high_threshold):
    """Return the genes scoring below/above the given thresholds for every row.

    Args:
        normalized_shap_matrix: DataFrame indexed by cell type with one column per gene.
        low_threshold: Genes with a value strictly below this are reported as low.
        high_threshold: Genes with a value strictly above this are reported as high.

    Returns:
        Tuple ``(low, high)`` of dicts mapping each index label to a list of
        column names, in column order.
    """
    below, above = {}, {}
    for label in normalized_shap_matrix.index:
        scores = normalized_shap_matrix.loc[label]
        below[label] = scores[scores < low_threshold].index.tolist()
        above[label] = scores[scores > high_threshold].index.tolist()
    return below, above

# Get, per cell type, the genes below the low threshold and above the high threshold
low_genes, high_genes = find_significant_low_shap_genes(normalized_shap_matrix, low_shap_threshold, high_shap_threshold)

# Print both gene lists for each cell type
for cell_type in low_genes:
    print(f"Cell Type: {cell_type}, Genes with significantly low SHAP values: {low_genes[cell_type]}")
    print(f"Cell Type: {cell_type}, Genes with significantly high SHAP values: {high_genes[cell_type]}")

In [None]:
# Pool the flagged genes over every cell type
all_low_genes = {gene for genes in low_genes.values() for gene in genes}
all_high_genes = {gene for genes in high_genes.values() for gene in genes}
extreme_genes = all_low_genes | all_high_genes

# Column subsets of the z-scored matrix restricted to the pooled genes
low_shap_df = normalized_shap_matrix[[*all_low_genes]]
high_shap_df = normalized_shap_matrix[[*all_high_genes]]
extreme_shap_df = normalized_shap_matrix[[*extreme_genes]]

In [None]:
# Cell types selected for the downstream enrichment analyses
cell_types = [
    "hepatocyte",
    "keratinocyte",
    "pancreatic A cell",
    "enterocyte of epithelium of large intestine",
    "precursor B cell",
    "alveolar macrophage",
    "large intestine goblet cell",
    "chondrocyte",
    "pancreatic acinar cell",
    "kidney collecting duct principal cell",
]

# Flagged genes pooled over the selected cell types only
filt_low_genes = {
    gene
    for cell_type, genes in low_genes.items()
    for gene in genes
    if cell_type in cell_types
}
filt_high_genes = {
    gene
    for cell_type, genes in high_genes.items()
    for gene in genes
    if cell_type in cell_types
}
filt_extreme_genes = filt_low_genes | filt_high_genes

# Column subsets of the z-scored matrix restricted to those genes
filt_low_shap_df = normalized_shap_matrix[[*filt_low_genes]]
filt_high_shap_df = normalized_shap_matrix[[*filt_high_genes]]
filt_extreme_shap_df = normalized_shap_matrix[[*filt_extreme_genes]]

2. GO analysis (enrichr)

In [None]:
import gseapy
from gseapy.plot import barplot, dotplot
def plot_go_aging(cell_type):
    """Plot Enrichr aging-perturbation enrichments for one cell type's extreme genes.

    The genes flagged high (``high_genes[cell_type]``, labelled "positive") and
    low (``low_genes[cell_type]``, labelled "negative") are each queried against
    the ``Aging_Perturbations_from_GEO_up`` and ``..._down`` libraries, giving
    four bar plots. A query or plot that fails is reported and skipped so the
    remaining ones still run.

    Args:
        cell_type: Key into the module-level ``high_genes`` / ``low_genes`` dicts.

    Raises:
        KeyError: If ``cell_type`` is missing from ``high_genes`` or ``low_genes``.
    """
    positive = high_genes[cell_type]
    negative = low_genes[cell_type]

    # (gene list, label shown in the title, library direction, bar colour)
    runs = [
        (positive, 'positive', 'up', 'r'),
        (positive, 'positive', 'down', 'b'),
        (negative, 'negative', 'up', 'r'),
        # This run previously queried the positive list under a "(negative)" title
        (negative, 'negative', 'down', 'b'),
    ]

    for gene_list, label, direction, color in runs:
        try:
            enr = gseapy.enrichr(
                gene_list=gene_list,
                gene_sets=[f'Aging_Perturbations_from_GEO_{direction}'],
                organism='Mouse',
                cutoff=0.05,
            )
            barplot(
                enr.res2d,
                title=f'GEO Aging Perturbations {direction.upper()}: {cell_type} ({label})',
                color=color,
            )
            plt.show()
        except Exception:
            # Best effort: one failed query or plot must not stop the others.
            # Exception (not a bare except) lets KeyboardInterrupt through.
            print(f"LOL no {cell_type}")

# Produce the aging-perturbation enrichment plots for every selected cell type
for cell_type in cell_types:
    plot_go_aging(cell_type)

3. Plot with only extreme genes

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Combine SHAP values for all cell types into a single DataFrame
# NOTE(review): these two frames are rebuilt here but the plot below uses
# extreme_shap_df, not mean_shap_matrix
combined_shap_df = pd.concat(filtered_shap_values_dict, axis=0)
combined_shap_df.index = pd.MultiIndex.from_tuples(combined_shap_df.index, names=['Cell Type', 'Sample'])

# Calculate the mean SHAP value for each feature within each cell type
mean_shap_matrix = combined_shap_df.groupby(level=0).mean()

# Plot the clustered heatmap of the extreme genes only; no extra scaling is
# requested (standard_scale=None)
g = sns.clustermap(
    extreme_shap_df, 
    cmap='coolwarm', 
    figsize=(30, 20), 
    standard_scale=None, 
    method='average', 
    metric='euclidean'
)

# Hide the top dendrogram
g.ax_col_dendrogram.set_visible(False)

# Hide the left dendrogram
g.ax_row_dendrogram.set_visible(False)

# Move y labels to the left
g.ax_heatmap.yaxis.set_label_position("left")
g.ax_heatmap.yaxis.tick_left()

# Adjust the position of the colorbar to the right of the plot
g.cax.set_position([.88, .2, .02, .45])

# Remove the title
g.ax_heatmap.set_title("")

# Remove x labels
g.ax_heatmap.set_xticklabels([])

plt.show()

4. GO analysis (GProfiler)

In [None]:
from gprofiler import GProfiler
import pandas as pd

# Initialize the g:Profiler client; return_dataframe=True requests results as
# pandas DataFrames
gp = GProfiler(return_dataframe=True)

# Run one g:Profiler query for a gene list and tag the result with its origin
def perform_go_analysis_for_genes(gene_set, cell_type, category):
    """Profile ``gene_set`` against mouse annotations and label the result.

    Args:
        gene_set: Gene identifiers to query; an empty list skips the query.
        cell_type: Value written to the ``'Cell Type'`` column of the result.
        category: Value written to the ``'Category'`` column of the result.

    Returns:
        The DataFrame returned by ``gp.profile`` with the two extra columns, or
        an empty DataFrame when ``gene_set`` is empty.
    """
    # Nothing to query: hand back an empty frame so callers can still concatenate
    if not gene_set:
        return pd.DataFrame()

    enrichment = gp.profile(organism='mmusculus', query=gene_set)
    enrichment['Cell Type'] = cell_type
    enrichment['Category'] = category
    return enrichment

# Prepare lists to collect all results (one DataFrame per selected cell type)
all_low_results = []
all_high_results = []

# Perform GO analysis for low and high SHAP genes for each cell type
for cell_type in cell_types:
    # Low SHAP genes
    low_shap_genes = low_genes[cell_type]
    low_results = perform_go_analysis_for_genes(low_shap_genes, cell_type, 'Low SHAP')
    all_low_results.append(low_results)

    # High SHAP genes
    high_shap_genes = high_genes[cell_type]
    high_results = perform_go_analysis_for_genes(high_shap_genes, cell_type, 'High SHAP')
    all_high_results.append(high_results)

# Combine all results into one table per direction
combined_low_results = pd.concat(all_low_results, ignore_index=True)
combined_high_results = pd.concat(all_high_results, ignore_index=True)

# Save the results to CSV files in the working directory
combined_low_results.to_csv('low_shap_go_results_by_cell_type.csv', index=False)
combined_high_results.to_csv('high_shap_go_results_by_cell_type.csv', index=False)


## Gene Ontology (GO) enrichment analysis based on significant SHAP values (whole dataset)

In [None]:
import numpy as np
import pandas as pd

# Assuming shap_values_train and test_data_split are already loaded

def get_gene_lists(shap_values_train, feature_names, test_data_split, threshold_percentile=75):
    """Rank genes by total |SHAP| and split them by effect direction at high expression.

    A gene counts as "upregulated" when, among the cells whose expression lies
    strictly above the gene's ``threshold_percentile`` percentile, more than
    half of the SHAP values are positive. It counts as "downregulated" when
    more than half are negative. Genes meeting neither condition (including
    genes with no cell above the threshold) appear only in the overall ranking.

    Args:
        shap_values_train: 2-D array-like of SHAP values (cells x genes).
        feature_names: Gene names, indexable by column position.
        test_data_split: 2-D array-like of the expression values that were
            explained, with the same layout as ``shap_values_train``.
        threshold_percentile: Percentile defining "high expression" (default 75).

    Returns:
        Tuple of six items, all ordered by descending sum of absolute SHAP:
        all gene names (list), upregulated gene names (list), downregulated
        gene names (list), sums for all genes (ndarray), sums for the
        upregulated genes (list), sums for the downregulated genes (list).
    """
    shap_values_array = np.array(shap_values_train)
    feature_values_array = np.array(test_data_split)

    # Sets, not lists: the ranking loop below tests membership once per gene,
    # which is quadratic in the number of genes when done against a list.
    upregulated_genes = set()
    downregulated_genes = set()

    for i in range(feature_values_array.shape[1]):
        shap_values_gene = shap_values_array[:, i]
        feature_values_gene = feature_values_array[:, i]

        # SHAP values of the cells in which this gene is highly expressed
        high_feature_threshold = np.percentile(feature_values_gene, threshold_percentile)
        shap_values_high_feature = shap_values_gene[feature_values_gene > high_feature_threshold]

        # No cell lies strictly above the threshold (e.g. a constant gene):
        # leave the gene unclassified instead of averaging an empty array
        if shap_values_high_feature.size == 0:
            continue

        if np.mean(shap_values_high_feature > 0) > 0.5:
            upregulated_genes.add(feature_names[i])
        elif np.mean(shap_values_high_feature < 0) > 0.5:
            downregulated_genes.add(feature_names[i])

    # Overall importance of each gene: sum of absolute SHAP values over cells
    feature_sums = np.sum(np.abs(shap_values_array), axis=0)

    # Column indices sorted by descending importance
    feature_order = np.argsort(feature_sums)[::-1]
    ordered_feature_names = [feature_names[i] for i in feature_order]
    ordered_feature_sums = feature_sums[feature_order]

    ordered_upregulated_genes = []
    ordered_downregulated_genes = []
    ordered_upregulated_sums = []
    ordered_downregulated_sums = []

    # Walk the ranking once and route each gene to its direction-specific list
    for name, sum_value in zip(ordered_feature_names, ordered_feature_sums):
        if name in upregulated_genes:
            ordered_upregulated_genes.append(name)
            ordered_upregulated_sums.append(sum_value)
        elif name in downregulated_genes:
            ordered_downregulated_genes.append(name)
            ordered_downregulated_sums.append(sum_value)

    return ordered_feature_names, ordered_upregulated_genes, ordered_downregulated_genes, ordered_feature_sums, ordered_upregulated_sums, ordered_downregulated_sums


In [None]:
def get_gene_lists_for_data(dataset, data_name, donor_tests, shap_path, balanced_num=300, gene_number_threshold=300, parameter_name=None, parameter_value=None, y_parameter_name=None, y_parameter_value=None):
    """Load a dataset plus saved SHAP values and return its top up/down genes.

    Args:
        dataset: Passed to ``load_data`` as ``dataset``.
        data_name: Passed to ``load_data`` as ``tissue``.
        donor_tests: Donor IDs passed to ``tt_split_setup`` as ``donor_test_list``.
        shap_path: ``.npy`` file with the SHAP values to analyse.
        balanced_num: Unused; kept so existing calls keep working.
        gene_number_threshold: Maximum number of genes returned per direction.
        parameter_name: Optional metadata column used to filter the data
            BEFORE the train/test split.
        parameter_value: Value ``parameter_name`` must equal.
        y_parameter_name: Optional metadata column used to filter the test
            set AFTER the split.
        y_parameter_value: Value ``y_parameter_name`` must equal.

    Returns:
        Tuple ``(upregulated_genes, downregulated_genes)``: gene-name lists as
        ordered by ``get_gene_lists``, truncated to ``gene_number_threshold``.
    """
    full_X, full_y, feature_names = load_data(
        dataset=dataset,
        tissue=data_name,
        filtered=True,
        normalize=False)

    if parameter_name is not None:
        keep = full_y[parameter_name] == parameter_value
        full_X = full_X[keep]
        full_y = full_y[keep]

    _, X_test, _, y_test = tt_split_setup(full_X, full_y, donor_test_list=donor_tests)
    del full_X, full_y

    if y_parameter_name is not None:
        X_test = X_test[y_test[y_parameter_name] == y_parameter_value]

    shap_values = np.load(shap_path)

    # Draw as many cells as there are SHAP rows, with the fixed seed.
    # NOTE(review): presumably this reproduces the sample the SHAP values were
    # computed on, so rows line up — confirm the SHAP run used the same
    # filtering, size and seed.
    test_data_split = resample(X_test, n_samples=shap_values.shape[0], random_state=42)
    del X_test, y_test

    _, upregulated_genes, downregulated_genes, _, _, _ = get_gene_lists(shap_values, feature_names, test_data_split)

    # The importance sums used to be normalised by their first element here but
    # were never returned; that step raised IndexError whenever a direction had
    # no genes, so it has been removed.
    return upregulated_genes[:gene_number_threshold], downregulated_genes[:gene_number_threshold]
    

In [None]:
# Donor IDs passed as the test-donor list for the two dataset variants used below
donor_tests_chosen_tissues = ['1-M-62', '3_9_M', '3-M-9', '3_38_F', '3_38_F/3_39_F', '3_39_F', '18_47_F', '18_53_M', '18-M-53', '21_55_F', '21-F-55', '24_61_M', '24-M-61', '30-M-5']
donor_tests_whole = ['1-M-62', '3_56_F', '3-M-5/6', '18_53_M', '21_55_F', '24_60_M', '30-M-4']

# BOTH_NOFILTER data: the test set is restricted to Lung after the split (y_parameter_*)
upregulated_whole, downregulated_whole = get_gene_lists_for_data('mouse', 'BOTH_NOFILTER', donor_tests_whole,
                                                        shap_path = 'whole_Lung.npy',
                                                        y_parameter_name = 'tissue',
                                                        y_parameter_value = 'Lung')

# ALL_METHODS data: restricted to Lung before the split (parameter_*)
upregulated, downregulated = get_gene_lists_for_data('mouse', 'ALL_METHODS', donor_tests_chosen_tissues,
                                                        shap_path = 'SHAP_X_Lung.npy',
                                                        balanced_num = 325,
                                                        parameter_name = 'tissue',
                                                        parameter_value = 'Lung')

# Same data and filtering as the previous call, with a different SHAP file
upregulated_fr, downregulated_fr = get_gene_lists_for_data('mouse', 'ALL_METHODS', donor_tests_chosen_tissues,
                                                        shap_path = 'SHAP_X_Lung_fr.npy',
                                                        balanced_num = 325,
                                                        parameter_name = 'tissue',
                                                        parameter_value = 'Lung')

In [None]:
import gseapy
from gseapy.plot import barplot, dotplot

def all_enr_analysis(organism, outdir_parent_name, upregulated_age_markers, downregulated_age_markers):
    """Run the Enrichr libraries below on the up and down gene lists and save bar plots.

    Args:
        organism: Passed to ``gseapy.enrichr`` as ``organism``.
        outdir_parent_name: Name used in the output paths
            (``SHAPS/<name>/<analysis>`` and ``SHAPS/plots/<name>_<analysis>.png``).
        upregulated_age_markers: Gene list queried against the "up" libraries.
        downregulated_age_markers: Gene list queried against the "down" libraries.

    Returns:
        Dict mapping ``"<library> UP"`` / ``"<library> DOWN"`` to the object
        returned by ``gseapy.enrichr`` for that query, UP entries first.
    """
    # Enrichr library -> folder / file stem for its output
    analyses_up = {
        'GO_Biological_Process_2023': 'enr_GOBP_up',
        'GO_Molecular_Function_2023': 'enr_GOMF_up',
        'GO_Cellular_Component_2023': 'enr_GOCC_up',
        'Aging_Perturbations_from_GEO_up': 'enr_Aging_GEO_up',
        'RNAseq_Automatic_GEO_Signatures_Mouse_Up': 'enr_GEO_up',
    }
    analyses_down = {
        'GO_Biological_Process_2023': 'enr_GOBP_down',
        'GO_Molecular_Function_2023': 'enr_GOMF_down',
        'GO_Cellular_Component_2023': 'enr_GOCC_down',
        'Aging_Perturbations_from_GEO_down': 'enr_Aging_GEO_down',
        'RNAseq_Automatic_GEO_Signatures_Mouse_Down': 'enr_GEO_down',
    }

    # The bar plots are written under SHAPS/plots; create it up front so saving
    # a figure cannot fail on a missing folder
    os.makedirs('SHAPS/plots', exist_ok=True)

    results = dict()

    # (label used in keys/titles, gene list, libraries, bar colour)
    runs = (
        ('UP', upregulated_age_markers, analyses_up, 'r'),
        ('DOWN', downregulated_age_markers, analyses_down, 'b'),
    )
    for direction, gene_list, analyses, color in runs:
        for key, value in analyses.items():
            enr_res = gseapy.enrichr(
                gene_list=gene_list,
                gene_sets=[key],
                organism=organism,
                outdir=f'SHAPS/{outdir_parent_name}/{value}',
                cutoff=0.05,
            )
            results[f"{key} {direction}"] = enr_res
            barplot(
                enr_res.res2d,
                title=f"{key} {direction}",
                color=color,
                ofname=f'SHAPS/plots/{outdir_parent_name}_{value}.png',
            )

    return results

In [None]:
# Keep the returned dict: the network-plot cell further down reads `results`
results = all_enr_analysis('Mouse', 'lung', upregulated, downregulated)
# Bare expression so the notebook still displays the result as before
results

Network plot sample

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

# Draw a term-gene graph for the leading rows of an enrichment table
def create_network_plot(df, top_terms=10):
    """Plot a graph linking enriched terms to their genes.

    Args:
        df: Enrichment table with a ``'Term'`` column and a ``'Genes'`` column
            holding ``;``-separated gene names.
        top_terms: Number of leading rows of ``df`` to include. Rows are taken
            in their existing order, so ``df`` should already be sorted by
            significance.
    """
    leading_rows = df.head(top_terms)

    # One edge per (term, gene) pair
    graph = nx.Graph()
    graph.add_edges_from(
        (record['Term'], gene_symbol)
        for _, record in leading_rows.iterrows()
        for gene_symbol in record['Genes'].split(';')
    )

    layout = nx.spring_layout(graph)

    plt.figure(figsize=(14, 10))
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_size=1000,
        node_color="skyblue",
        font_size=10,
        font_weight="bold",
        edge_color="gray",
    )

    plt.title("Network of Top Enriched GO Terms")
    plt.show()

# Create the network plot for the up-regulated GO Biological Process enrichment.
# NOTE(review): `results` must be the dict returned by all_enr_analysis — make
# sure that return value has been assigned to `results` before running this cell.
create_network_plot(results['GO_Biological_Process_2023 UP'].results, top_terms=10)