# the VAE model suggested by the paper on the Tabula Sapiens Pseudobulks Dataset

## importing and general setup

In [21]:
import torch
import matplotlib.pyplot as plt
from plotnine import *
from sklearn.metrics import r2_score
import torch.nn as nn
import numpy as np
import pandas as pd
import scanpy as sc

import optuna
import polars as pl
from dataset import TS_Compressed_Dataset

## the model

In [22]:
class GeneEmbedding(nn.Module):
    def __init__(self, n_genes, gene_emb_dim, gene_list, gene_emb_init = None):
        super(GeneEmbedding, self).__init__()

        self.n_genes = n_genes
        self.gene_emb_dim = gene_emb_dim
        self.genes = gene_list

        """
        model the gene embedding as gaussian distributions computed from the mean embeddings and sd embeddings
        """
        if gene_emb_init is not None:
            self.emb_mu = nn.Embedding.from_pretrained(gene_emb_init, freeze = False)
        else:
            # init as [0, 1], as the dirichlet prior is in [0, 1]
            self.emb_mu = nn.Embedding.from_pretrained(
                torch.rand((n_genes, gene_emb_dim)), freeze = False
            )

        self.emb_log_sigma = nn.Embedding.from_pretrained(
            torch.full((n_genes, gene_emb_dim), np.log(0.5)), freeze = False
        )


    def get_emb_table(self):
        emb_df = pd.DataFrame(self.emb_mu.weight.detach().cpu(),
                              index = self.genes,
                              columns = [f'FACT_EMB_{i}' for i in range(self.gene_emb_dim)]
                              )
        emb_df.index.name = 'gene_id'
        return emb_df


    def get_log_sigma_table(self):
        emb_df = pd.DataFrame(self.emb_log_sigma.weight.detach().cpu(),
                              index = self.genes,
                              columns = [f'FACT_EMB_{i}' for i in range(self.gene_emb_dim)]
                              )
        emb_df.index.name = 'gene_id'
        return emb_df


    def get_shape(self):
        return (self.n_genes, self.gene_emb_dim)



In [23]:
class MODEL(nn.Module):
    def __init__(self, gene_emb, n_samples, sample_emb_dim, NUM_LAYERS, sample_emb_init = None):
        super().__init__()
        #initialize the gene embedding
        self.gene_emb = gene_emb
        #create sample embedding
        if sample_emb_init is not None:
            self.sample_emb = nn.Embedding.from_pretrained(sample_emb_init, freeze=False)
        else:
            self.sample_emb = nn.Embedding(n_samples, sample_emb_dim)

        joint_emb_dim = sample_emb_dim + self.gene_emb.gene_emb_dim

        #the decoder
        self.model = nn.Sequential()
        for i in range(NUM_LAYERS - 1):
            self.model.add_module(f"layer_{i}", nn.Linear(joint_emb_dim, joint_emb_dim))
            self.model.add_module(f"relu_{i}", nn.ReLU())
        self.model.add_module(f"layer_{NUM_LAYERS - 1}", nn.Linear(joint_emb_dim, 1))


    def reparameterization(self, mean, sd):
        epsilon = torch.randn_like(sd)    # sampling epsilon
        z = mean + sd * epsilon           # reparameterization trick
        return z


    def get_emb(self, idx):
        emb_mu = self.gene_emb.emb_mu.weight
        emb_sigma = self.gene_emb.emb_log_sigma.weight.exp()
        emb_z = self.reparameterization(emb_mu[idx, :], emb_sigma[idx, :])
        return emb_z


    def forward(self, gene_index, sample_index):
        #get the gene embeddings and sample embeddings from indices
        gene_emb_batch = self.get_emb(gene_index)
        sample_emb_batch = self.sample_emb(sample_index)
        joint_emb = torch.cat((gene_emb_batch, sample_emb_batch), dim = 1)

        #use the joint embeddings to predict the score
        pred = self.model(joint_emb)
        return pred.squeeze()



In [24]:
#load the dataframe that contains one to one corresponding of gene id and gene index
gene_index = pd.read_csv("gene_index.csv")
path = "TabulaSapiens_pb_normalized.h5ad"

dataset = TS_Compressed_Dataset(gene_index, path)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "mps"

gene_list = gene_index.gene_id.to_list()
n_genes = len(gene_list) #the number of total genes, for TS is around 6000
n_samples = dataset.dataset["sample_idx"].unique().shape[0] #the number of samples

In [25]:
def define_model(trial):
    gene_emb_dim = trial.suggest_int("gene_emb_dim", 128, 256)
    sample_emb_dim = trial.suggest_int("sample_emb_dim", 128, 164)
    num_layers = trial.suggest_int("num_layers", 3, 6)
    sample_emb_init = torch.from_numpy(dataset.compute_sample_init_pca(dim=sample_emb_dim).astype('float32'))

    gene_emb = GeneEmbedding(n_genes,
                            gene_emb_dim,
                            gene_list,
                            gene_emb_init = torch.rand(n_genes, gene_emb_dim) - 0.5,
                            )


    model= MODEL(gene_emb,
                n_samples,
                sample_emb_dim,
                num_layers,
                sample_emb_init = sample_emb_init
                )

    return model


In [26]:
def train_epoch(model, opt, loss, BATCH_SIZE, dataset):
    for g_idx, s_idx, target in dataset.get_batches(BATCH_SIZE, 'train'):
        model.to(device)
        opt.zero_grad()
        mu = model(g_idx.to(device, non_blocking=True), s_idx.to(device, non_blocking=True))
        target = target.to(device, non_blocking=True)

        mse = loss(mu, target)
        mse.backward()
        opt.step()




In [27]:
def eval_r2(model, dataset):
    with torch.no_grad():
        model.eval()
        test_gene_idx = torch.from_numpy(np.array(dataset.test_table["gene_idx"]))
        test_sample_idx = torch.from_numpy(np.array(dataset.test_table["sample_idx"]))
        test_target = torch.from_numpy(np.array(dataset.test_table["score"]))
        mu = model(test_gene_idx.to(device, non_blocking=True), test_sample_idx.to(device, non_blocking=True))
        val_pred = mu.detach().cpu().numpy()
        val_target = test_target.detach().cpu().numpy()
        model.train()
    return r2_score(val_target, val_pred)

In [29]:
def objective(trial):
    model = define_model(trial)

    epochs = trial.suggest_int("epochs", 10, 40, step=2)
    batch_size = trial.suggest_int("batch_size", 500, 3000)
    lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)

    loss = nn.MSELoss(reduction="mean")
    opt = torch.optim.Adam(model.parameters(), lr = lr)

    for epoch in range(epochs):
        train_epoch(model, opt, loss, batch_size, dataset)
        test_r2 = eval_r2(model, dataset)
        trial.report(test_r2, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return test_r2

In [None]:
study = optuna.create_study(direction="maximize",
                            sampler=optuna.samplers.TPESampler(),
                            pruner=optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=10))
study.optimize(objective, n_trials=80)

In [12]:
trial = study.best_trial

In [13]:
trial

FrozenTrial(number=5, state=TrialState.COMPLETE, values=[0.2634236128235621], datetime_start=datetime.datetime(2024, 1, 19, 23, 13, 1, 544116), datetime_complete=datetime.datetime(2024, 1, 19, 23, 20, 17, 593133), params={'gene_emb_dim': 256, 'sample_emb_dim': 164, 'num_layers': 3, 'epochs': 26, 'batch_size': 1000, 'lr': 0.001500448427487926}, user_attrs={}, system_attrs={}, intermediate_values={0: 0.12378107219811518, 1: 0.16998195973091423, 2: 0.20459919216446731, 3: 0.20182239226753584, 4: 0.22310728990017947, 5: 0.19709450815781115, 6: 0.23848472583931135, 7: 0.19549991225558838, 8: 0.2217207901312387, 9: 0.24973191560933583, 10: 0.22235081518008548, 11: 0.2611747011608512, 12: 0.22903808949077953, 13: 0.2364694721583065, 14: 0.26619711387640843, 15: 0.26871686721773813, 16: 0.2913349473263752, 17: 0.27012272445508345, 18: 0.28642319677329886, 19: 0.3082091615504078, 20: 0.24299475203081178, 21: 0.1619381531810068, 22: 0.23979524310584532, 23: 0.23179623319745912, 24: -0.3997103720

## The visualization

In [15]:
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice


In [None]:
plot_optimization_history(study)

In [None]:
plot_param_importances(study)

In [None]:
plot_slice(study)