# A Normal Variational Autoencoder for the compressed Tabula Sapiens Dataset

## The Model

In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict
import polars as pl
from dataset import TS_Compressed_VAE_Dataset
import numpy as np
from sklearn.metrics import r2_score
import pandas as pd

In [2]:
class MLP(nn.Module):
    """
    Fully connected feed-forward network.

    Each pair of adjacent entries in ``sizes`` becomes one ``nn.Linear``.
    Every linear layer except the last one is followed by an optional
    ``nn.BatchNorm1d`` and then a ``nn.ReLU``; the last linear layer is left
    without normalisation or activation inside ``self.network``.

    Careful: ``last_layer_act="ReLU"`` does NOT rectify the whole output.
    Only the second half of the output columns is passed through ReLU; the
    first half is returned unchanged. The VAE in this file uses the first
    half as the mean and the second half as the standard deviation, so the
    ReLU is applied to the standard deviation and not to the mean.

    @param sizes: layer widths, input width first and output width last
    @param batch_norm: insert BatchNorm1d after every hidden linear layer
    @param last_layer_act: "linear" or "ReLU"; anything else raises ValueError
    """

    def __init__(
        self,
        sizes,
        batch_norm=True,
        last_layer_act="linear",
    ):
        super().__init__()

        last_idx = len(sizes) - 2  # index of the output linear layer
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if idx == last_idx:
                # output layer: no normalisation and no ReLU in the stack
                break
            if batch_norm:
                modules.append(nn.BatchNorm1d(fan_out))
            modules.append(nn.ReLU())

        self.activation = last_layer_act
        if last_layer_act == "ReLU":
            self.relu = nn.ReLU()
        elif last_layer_act != "linear":
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        # keys "0", "1", ... keep the parameter names network.<i>.weight etc.
        self.network = nn.Sequential(
            OrderedDict((str(pos), mod) for pos, mod in enumerate(modules))
        )

    def forward(self, x):
        out = self.network(x)
        if self.activation != "ReLU":
            return out
        # rectify only the second half of the columns
        half = out.size(1) // 2
        head, tail = out[:, :half], out[:, half:]
        return torch.cat((head, self.relu(tail)), dim=1)

In [3]:
class VAE(nn.Module):
    """
    Autoencoder built from two MLPs, optionally variational.

    ``hparams`` must provide the keys "dim", "encoder_width",
    "encoder_depth", "emb_dim", "decoder_width", "decoder_depth",
    "batch_norm" and "Variational".

    When ``hparams["Variational"]`` is truthy the encoder emits
    ``2 * emb_dim`` columns: the first ``emb_dim`` are used as the mean and
    the remaining ones (rectified by the encoder's partial ReLU) as the
    standard deviation. Otherwise the encoder emits ``emb_dim`` columns that
    are fed to the decoder directly.
    """

    def __init__(
        self,
        hparams: dict
    ):
        super().__init__()
        self.hparams = hparams
        self.batch_norm = hparams["batch_norm"]
        self.Variational = hparams["Variational"]

        emb_dim = hparams["emb_dim"]
        # The two modes differ only in the encoder head: a variational
        # encoder outputs mean and sd side by side and rectifies the sd half.
        if self.Variational:
            encoder_out, encoder_act = emb_dim * 2, "ReLU"
        else:
            encoder_out, encoder_act = emb_dim, "linear"

        self.encoder_sizes = (
            [hparams["dim"]]
            + [hparams["encoder_width"]] * hparams["encoder_depth"]
            + [encoder_out]
        )
        self.decoder_sizes = (
            [emb_dim]
            + [hparams["decoder_width"]] * hparams["decoder_depth"]
            + [hparams["dim"]]
        )
        # Encoder is created before the decoder so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.encoder = MLP(self.encoder_sizes, batch_norm=self.batch_norm, last_layer_act=encoder_act)
        self.decoder = MLP(self.decoder_sizes, batch_norm=self.batch_norm, last_layer_act="linear")

    def reparametrize(self, mu, sd):
        """
        draw z = mu + sd * epsilon with epsilon ~ N(0, I)
        @param mu: mean tensor
        @param sd: standard deviation tensor, same shape as mu
        """
        epsilon = torch.randn_like(sd)
        z = mu + sd * epsilon
        return z

    def get_emb(self, x):
        """
        get the embedding of given expression profiles of genes
        @param x: should be the shape [batch_size, hparams["dim"]]
        @return: the first hparams["emb_dim"] columns of the encoder output
                 (the mean columns in the variational case); no sampling
        """
        return self.encoder(x)[:, 0:self.hparams["emb_dim"]]

    def forward(self, x):
        """
        get the reconstruction of the expression profile of a gene
        @param x: should be the shape [batch_size, hparams["dim"]]
        @return: decoder output, shape [batch_size, hparams["dim"]]

        NOTE: in the variational case a latent sample is drawn on every
        call, regardless of train/eval mode, so the output is stochastic.
        """
        latent = self.encoder(x)
        if self.Variational:
            mu = latent[:, 0:self.hparams["emb_dim"]]
            sd = latent[:, self.hparams["emb_dim"]:]
            assert mu.shape == sd.shape
            latent = self.reparametrize(mu, sd)
        reconstructed = self.decoder(latent)
        return reconstructed

## The Dataset and Hyperparameters 

In [4]:
# File name of the compressed Tabula Sapiens AnnData (.h5ad) input.
path = "TabulaSapiens_CO_compressed.h5ad"
# Project-local dataset wrapper (imported from dataset.py). The rest of the
# notebook uses its train batches, test_table, and gene / cell-line metadata.
dataset = TS_Compressed_VAE_Dataset(path)

In [5]:
# Architecture settings consumed by VAE.__init__.
hparams = dict(
    # width of each input profile, and of the reconstruction
    dim=dataset.num_celllines_in_assay(),
    # hidden width and number of hidden layers of the encoder MLP
    encoder_width=128,
    encoder_depth=4,
    # size of the latent embedding
    emb_dim=128,
    # hidden width and number of hidden layers of the decoder MLP
    decoder_width=128,
    decoder_depth=4,
    # BatchNorm1d after every hidden linear layer
    batch_norm=True,
    # True: encoder emits mean and sd and a latent sample is decoded
    Variational=True,
)

# Optimisation settings; passed to wandb.init and read back in train().
config = dict(
    epochs=100,
    batch_size=1e3,
    lr=1e-3,
)

# Torch device used for the model and every batch.
device = "cpu"

## The Training and Testing

In [6]:
import wandb
# Authenticate with Weights & Biases before any run is started; as the last
# expression of the cell its return value is shown as the cell output.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkemingzhang[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
def train_epoch(model, opt, loss, batch_size, dataset, epoch):
    """
    Train `model` for one pass over the 'train' batches of `dataset`.

    Each batch is reconstructed by the model and scored against itself with
    `loss`; the per-batch loss is logged to wandb together with a running
    batch counter that continues across epochs.

    @param model: module to optimise; it is moved to the global `device`
    @param opt: optimizer over the model's parameters
    @param loss: callable loss(pred, target) returning a scalar tensor
    @param batch_size: forwarded unchanged to the dataset's batching methods
    @param dataset: must provide get_num_batches_per_epoch(batch_size) and
                    get_batches(batch_size, 'train') yielding (_, target)
    @param epoch: zero-based epoch index, used to offset the batch counter
    @return: summed batch loss divided by
             dataset.get_num_batches_per_epoch(batch_size)
    """
    # Loop-invariant work is done once instead of once per batch.
    model.to(device)
    batches_per_epoch = dataset.get_num_batches_per_epoch(batch_size)

    batch_ct = epoch * batches_per_epoch
    cumu_loss = 0
    for _, target in dataset.get_batches(batch_size, 'train'):
        opt.zero_grad()
        target = target.to(device, non_blocking=True)
        pred = model(target)

        mse = loss(pred, target)
        # read the scalar once; every .item() is a device-to-host transfer
        batch_loss = mse.item()
        cumu_loss += batch_loss
        mse.backward()
        opt.step()

        batch_ct += 1
        wandb.log({"batch_loss": batch_loss, "batch_ct": batch_ct})

    return cumu_loss / batches_per_epoch


In [8]:
def eval_r2(model, dataset):
    """
    Score the model's reconstruction of the test table with sklearn's R^2.

    The model is switched to eval mode for the forward pass and afterwards
    put back into whatever mode it was in before the call, also when the
    forward pass raises.

    @param model: module mapping a float32 tensor to a tensor of equal shape
    @param dataset: must expose `test_table`; all columns from index 1 on are
                    used as input and as target (column 0 is presumably the
                    gene identifier -- TODO confirm against the dataset class)
    @return: r2_score(target, pred) with sklearn's default settings

    NOTE: with the variational VAE of this file the forward pass samples
    noise even in eval mode, so repeated calls give different scores.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            target = torch.from_numpy(dataset.test_table[:, 1:].to_numpy().astype('float32')).to(device, non_blocking=True)
            pred = model(target).detach().cpu().numpy()
            target = target.detach().cpu().numpy()
    finally:
        # restore the caller's mode instead of forcing training mode
        model.train(was_training)
    return r2_score(target, pred)

In [9]:
def train(config=config):
    """
    Train a fresh VAE inside a wandb run and return it.

    Per epoch the average training loss and the test R^2 are logged against
    the "epoch" step metric; per-batch losses are logged by train_epoch
    against "batch_ct". After the last epoch the model is exported to
    "model.onnx" and the file is handed to wandb.save.

    @param config: run configuration passed to wandb.init; it is read back
                   from wandb.config and must provide `lr`, `epochs` and
                   `batch_size`
    @return: the trained VAE
    """
    with wandb.init(project="vae_ts", config = config):
        #this config will be set by Sweep Controller
        config = wandb.config

        # Move the model to the target device before the optimizer is built,
        # so it is on `device` even if no training epoch runs.
        model = VAE(hparams).to(device)

        loss = nn.MSELoss(reduction="mean")
        opt = torch.optim.Adam(model.parameters(), lr = config.lr)


        wandb.define_metric("batch_loss", step_metric="batch_ct")
        wandb.define_metric("avg_loss", step_metric="epoch")
        wandb.define_metric("test_r2", step_metric="epoch")

        for epoch in range(config.epochs):
            avg_loss = train_epoch(model, opt, loss, config.batch_size, dataset, epoch)
            wandb.log({"avg_loss": avg_loss, "epoch": epoch})
            test_r2 = eval_r2(model, dataset)
            wandb.log({"test_r2": test_r2, "epoch":epoch})


        #save the model in the exchangable ONNX format
        # the test table doubles as the example input for the export
        target = torch.from_numpy(dataset.test_table[:, 1:].to_numpy().astype('float32')).to(device, non_blocking=True)
        torch.onnx.export(model, target, "model.onnx")
        wandb.save("model.onnx")

    return model

In [10]:
# Run the full training loop with the notebook-level config and keep the
# trained VAE for the embedding cells below.
model = train(config)

  assert mu.shape == sd.shape


VBox(children=(Label(value='0.768 MB of 0.768 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
avg_loss,█▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▄▃▃▃▂▂▃▄▂▂▂▂▁▂▂▂▁▁▂▂
batch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch_loss,▃▂▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_r2,▆▃▆▁▇▆▇█▇█▇▁▇▇█▇▇█████▇██████████▇███▇██

0,1
avg_loss,0.37985
batch_ct,5300.0
batch_loss,0.02249
epoch,99.0
test_r2,-2000.27312


## Get the Embeddings

In [11]:
# Gene identifiers; used further down as the index of the embedding table.
genes = dataset.genes_in_assay()
# Whole dataset table from column 1 onward as a float32 tensor
# (column 0 is presumably the gene id column -- TODO confirm).
gene_scores = torch.from_numpy(dataset.dataset[:, 1:].to_numpy().astype('float32'))

In [12]:
# Embed every gene with the encoder only, in eval mode and without gradients.
# NOTE(review): gene_scores is not moved to `device`; this works while
# device is "cpu".
model.eval()
with torch.no_grad():
    # get_emb returns the first emb_dim encoder columns; no sampling is done
    emb = model.get_emb(gene_scores).detach().cpu().numpy()
# Back to training mode; as the last expression of the cell the returned
# module is what the notebook prints.
model.train()

VAE(
  (encoder): MLP(
    (relu): ReLU()
    (network): Sequential(
      (0): Linear(in_features=177, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=128, out_features=128, bias=True)
      (10): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Linear(in_features=128, out_features=256, bias=True)
    )
  )
  (decoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0

In [13]:
# One row per gene, one column EMB_<i> per embedding dimension.
emb_df = pd.DataFrame(
    data = emb,
    index = genes,
    columns = [f'EMB_{i}' for i in range(hparams["emb_dim"])]
)

In [14]:
# Name the index so it gets a "gene_id" header when the table is written out.
emb_df.index.name = "gene_id"
# Display the table as the cell output.
emb_df

Unnamed: 0_level_0,EMB_0,EMB_1,EMB_2,EMB_3,EMB_4,EMB_5,EMB_6,EMB_7,EMB_8,EMB_9,...,EMB_118,EMB_119,EMB_120,EMB_121,EMB_122,EMB_123,EMB_124,EMB_125,EMB_126,EMB_127
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSG00000268723,-2.447452,-2.578989,1.441438,-1.551750,2.606443,1.790830,-3.063021,1.277864,0.830067,-2.299615,...,0.042474,1.454578,0.510296,-2.255242,3.855302,-1.307074,2.449574,-2.159902,1.629684,-2.412930
ENSG00000272472,-2.445033,-2.576698,1.440027,-1.550353,2.604102,1.789134,-3.060102,1.276379,0.829069,-2.297524,...,0.042579,1.453420,0.509974,-2.252960,3.851781,-1.305450,2.447245,-2.158240,1.628374,-2.410498
ENSG00000252002,-2.474335,-2.604754,1.456906,-1.567647,2.633294,1.809640,-3.095583,1.294581,0.841056,-2.323334,...,0.040572,1.467866,0.513787,-2.280678,3.894751,-1.324960,2.475613,-2.178764,1.644406,-2.440076
ENSG00000180806,-2.601565,-2.740588,1.539280,-1.651261,2.761152,1.904528,-3.251044,1.383297,0.887083,-2.428858,...,0.035545,1.525795,0.536068,-2.408858,4.079382,-1.411488,2.598793,-2.269048,1.715945,-2.569330
ENSG00000165525,-1.986893,-0.233340,0.212840,-0.884871,3.655952,0.128181,-2.359196,0.913187,1.621459,-3.916359,...,-2.070255,2.006452,-0.443681,-1.327221,4.845542,-0.971142,3.540757,-3.027183,1.400047,-3.174893
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000273513,-2.443378,-2.575095,1.439057,-1.549378,2.602487,1.787957,-3.058109,1.275344,0.828405,-2.296116,...,0.042658,1.452626,0.509730,-2.251409,3.849388,-1.304365,2.445667,-2.157098,1.627464,-2.408860
ENSG00000223313,-2.448873,-2.580260,1.442225,-1.552513,2.607687,1.791833,-3.064644,1.278627,0.830607,-2.300694,...,0.042431,1.455190,0.510514,-2.256455,3.857208,-1.307922,2.450789,-2.160752,1.630385,-2.414178
ENSG00000158623,-2.646688,-2.789868,1.570103,-1.681018,2.802729,1.938740,-3.305467,1.414327,0.901650,-2.462649,...,0.036495,1.542633,0.544556,-2.455058,4.141959,-1.442622,2.640727,-2.298578,1.739371,-2.613292
ENSG00000284922,-2.443314,-2.575032,1.439018,-1.549340,2.602424,1.787911,-3.058031,1.275303,0.828379,-2.296062,...,0.042660,1.452595,0.509720,-2.251348,3.849296,-1.304322,2.445606,-2.157054,1.627427,-2.408796


In [15]:
# Embedding size, used in the export file name below.
emb_dim = hparams["emb_dim"]
# Export is disabled; uncomment to write the embeddings as a tab-separated file.
#emb_df.to_csv(f"TS_compressed_vae_d{emb_dim}.tsv", sep="\t")