# A Standard or Variational Autoencoder for the Tabula Sapiens Pseudobulk Dataset

## The Model

In [1]:
import copy
import math
from collections import OrderedDict

import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
from sklearn.metrics import r2_score

from dataset import TS_Compressed_VAE_Dataset

In [2]:
class MLP(nn.Module):
    """
    Fully connected network built from a list of layer widths.

    Every hidden layer is ``Linear -> [BatchNorm1d] -> ReLU`` (BatchNorm only
    when ``batch_norm`` is true); the output layer is a bare ``Linear``.

    @param sizes: layer widths, input first and output last.
    @param batch_norm: insert BatchNorm1d after each hidden Linear.
    @param last_layer_act: "linear" leaves the output untouched. "ReLU"
        applies ReLU to the SECOND half of the output columns only, i.e. the
        standard-deviation half of a ``[mean | sd]`` vector; the mean half
        stays linear.
    @raises ValueError: if ``last_layer_act`` is neither "linear" nor "ReLU".
    """

    def __init__(
        self,
        sizes,
        batch_norm=True,
        last_layer_act="linear",
    ):
        super().__init__()

        n_layers = len(sizes) - 1
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if idx == n_layers - 1:
                # Output layer: no normalisation and no activation here.
                break
            if batch_norm:
                modules.append(nn.BatchNorm1d(fan_out))
            modules.append(nn.ReLU())

        self.activation = last_layer_act
        if last_layer_act == "ReLU":
            self.relu = nn.ReLU()
        elif last_layer_act != "linear":
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        # Positional Sequential names its children "0", "1", ... in order.
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        out = self.network(x)
        if self.activation != "ReLU":
            return out
        half = out.size(1) // 2
        mean_part, sd_part = out[:, :half], out[:, half:]
        return torch.cat((mean_part, self.relu(sd_part)), dim=1)

In [3]:
class VAE(nn.Module):
    """
    Autoencoder for expression profiles, optionally variational.

    Architecture, built from ``hparams``:
      encoder: dim -> encoder_width x encoder_depth -> emb_dim
               (2 * emb_dim when Variational, laid out as [mu | sd])
      decoder: emb_dim -> decoder_width x decoder_depth -> dim

    Required hparams keys: "dim", "encoder_width", "encoder_depth", "emb_dim",
    "decoder_width", "decoder_depth", "batch_norm", "Variational".

    NOTE(review): this class computes no KL-divergence term. Unless the
    training loss adds one, Variational=True behaves like an autoencoder with
    noise injected into the latent code.
    """

    def __init__(
        self,
        hparams: dict
    ):
        super(VAE, self).__init__()
        self.hparams = hparams
        self.batch_norm = hparams["batch_norm"]
        self.Variational = hparams["Variational"]

        emb_dim = hparams["emb_dim"]
        # In variational mode the encoder emits a mean and a standard deviation
        # per latent dimension; its MLP applies ReLU to the sd half only.
        encoder_out = 2 * emb_dim if self.Variational else emb_dim
        encoder_act = "ReLU" if self.Variational else "linear"

        self.encoder_sizes = (
            [hparams["dim"]]
            + [hparams["encoder_width"]] * hparams["encoder_depth"]
            + [encoder_out]
        )
        self.decoder_sizes = (
            [emb_dim]
            + [hparams["decoder_width"]] * hparams["decoder_depth"]
            + [hparams["dim"]]
        )
        self.encoder = MLP(self.encoder_sizes, batch_norm=self.batch_norm, last_layer_act=encoder_act)
        self.decoder = MLP(self.decoder_sizes, batch_norm=self.batch_norm, last_layer_act="linear")

    def reparametrize(self, mu, sd):
        """Draw z = mu + sd * eps with eps ~ N(0, 1), elementwise."""
        epsilon = torch.randn_like(sd)
        z = mu + sd * epsilon
        return z

    def get_emb(self, x):
        """
        Return the embedding of the given expression profiles.

        In variational mode this is the mean half of the encoder output.

        @param x: tensor of shape [batch_size, hparams["dim"]]
        @return: tensor of shape [batch_size, hparams["emb_dim"]]
        """
        return self.encoder(x)[:, 0:self.hparams["emb_dim"]]

    def forward(self, x):
        """
        Reconstruct the given expression profiles.

        In variational mode the latent code is sampled while training. In eval
        mode the mean is used, so reconstructions (e.g. for test-set R^2) are
        deterministic.

        @param x: tensor of shape [batch_size, hparams["dim"]]
        @return: tensor of shape [batch_size, hparams["dim"]]
        """
        latent = self.encoder(x)
        if self.Variational:
            emb_dim = self.hparams["emb_dim"]
            mu = latent[:, :emb_dim]
            sd = latent[:, emb_dim:]
            assert mu.shape == sd.shape
            latent = self.reparametrize(mu, sd) if self.training else mu
        return self.decoder(latent)

## The Dataset and Hyperparameters 

In [4]:
# Normalized Tabula Sapiens pseudobulk data (.h5ad), loaded through the
# project's TS_Compressed_VAE_Dataset wrapper.
path = "TabulaSapiens_pb_normalized.h5ad"
dataset = TS_Compressed_VAE_Dataset(path)

In [5]:
# Model architecture. "dim" is the input/output width of the autoencoder,
# taken from the dataset.
hparams = dict(
    dim=dataset.num_celllines_in_assay(),
    encoder_width=128,
    encoder_depth=5,
    emb_dim=128,
    decoder_width=128,
    decoder_depth=6,
    batch_norm=True,
    Variational=False,
)

# Training-loop settings; passed to wandb.init, so a sweep may override them.
config = dict(epochs=500, batch_size=1500, lr=1e-3)

# Device used for training and evaluation (e.g. change to "cuda" or "mps").
device = "cpu"

## Training and Evaluation

In [6]:
# Authenticate with Weights & Biases; train() below logs its metrics to the
# "vae_ts" project.
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkemingzhang[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
def train_epoch(model, opt, loss, batch_size, dataset, epoch):
    """
    Run one epoch of reconstruction training and return the mean batch loss.

    Each batch's target is both the input and the reconstruction target.
    Per-batch losses are logged to wandb against a global batch counter that
    continues across epochs.

    @param model: the autoencoder; moved to the module-level ``device``.
    @param opt: optimizer over ``model``'s parameters.
    @param loss: criterion called as ``loss(pred, target)``.
    @param batch_size: rows per batch.
    @param dataset: exposes ``get_num_batches_per_epoch(batch_size)`` and
        ``get_batches(batch_size, 'train')`` yielding ``(_, target)`` pairs.
    @param epoch: zero-based epoch index, used to offset the batch counter.
    @return: mean loss over the batches processed (0.0 for an empty epoch).
    """
    num_batches = dataset.get_num_batches_per_epoch(batch_size)
    batch_ct = epoch * num_batches
    # Moving the model is loop-invariant; do it once instead of per batch.
    model.to(device)

    cumu_loss = 0.0
    n_seen = 0
    for _, target in dataset.get_batches(batch_size, 'train'):
        opt.zero_grad()
        target = target.to(device, non_blocking=True)
        pred = model(target)

        mse = loss(pred, target)
        # .item() forces a host sync; read the scalar once per batch.
        mse_value = mse.item()
        cumu_loss += mse_value
        mse.backward()
        opt.step()

        n_seen += 1
        batch_ct += 1
        wandb.log({"batch_loss": mse_value, "batch_ct": batch_ct})

    # Average over the batches actually seen, guarding against an empty epoch.
    return cumu_loss / max(n_seen, 1)


In [8]:
def eval_r2(model, dataset, multioutput="uniform_average"):
    """
    Compute the R^2 of the model's reconstruction of the test split.

    The model runs in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    @param model: the autoencoder; called on a float32 tensor on ``device``.
    @param dataset: exposes ``test_table``; its first column is skipped
        (presumably an identifier column -- TODO confirm) and the remaining
        columns are the reconstruction targets.
    @param multioutput: forwarded to ``sklearn.metrics.r2_score``. The default
        averages per-column R^2 uniformly, so a single near-constant test
        column can drag the score far negative; "variance_weighted" is less
        sensitive to that.
    @return: the R^2 score (an array when ``multioutput="raw_values"``).
    """
    target = torch.from_numpy(
        dataset.test_table[:, 1:].to_numpy().astype('float32')
    ).to(device, non_blocking=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(target).cpu().numpy()
    finally:
        model.train(was_training)

    return r2_score(target.cpu().numpy(), pred, multioutput=multioutput)

In [9]:
class EarlyStopping:
    """
    Signal when the test R^2 stops improving, keeping weight snapshots.

    Two snapshots are kept:
      * ``best_model``: weights with the highest R^2 seen so far (``best_r2``).
      * ``app_model``: the most recent weights whose score satisfied
        ``test_r2 - best_r2 >= min_delta`` (``app_r2``). With the default
        negative ``min_delta``, this model may score a bit worse on the test
        data than ``best_model``.

    An epoch with ``test_r2 - best_r2 < min_delta``, or a NaN score, is a bad
    epoch. ``patience`` consecutive bad epochs stop training. On stopping, the
    caller's model is loaded with the ``app_model`` weights when
    ``restore_app_weights``, else with the ``best_model`` weights when
    ``restore_best_weights``.
    """

    def __init__(self, patience=8, min_delta=-0.1, restore_best_weights=False, restore_app_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.restore_app_weights = restore_app_weights
        self.best_model = None  # deep copy tracking the best-scoring weights
        # Model allowed to score slightly worse than the best on the test data
        # (presumably better on training data; that is not checked here).
        self.app_model = None
        self.best_r2 = None
        self.app_r2 = None
        self.counter = 0  # consecutive bad epochs
        self.status = ""  # progress text, e.g. "3/8" or "Stopped on 8"

    def __call__(self, model, test_r2):
        """
        Record ``test_r2`` for ``model`` and report whether to stop.

        @param model: the model being trained. Its weights are snapshotted and,
            on stopping, may be overwritten in place.
        @param test_r2: the current epoch's score (higher is better). NaN is
            treated as -inf, i.e. as a non-improving epoch.
        @return: True when patience is exhausted, otherwise False.
        """
        # NaN compares False with everything; left as-is it would match no
        # branch below (never counted) or poison best_r2 on the first call.
        if math.isnan(test_r2):
            test_r2 = -math.inf

        if self.best_r2 is None:
            self.best_r2 = test_r2
            self.app_r2 = test_r2
            self.best_model = copy.deepcopy(model)
            self.app_model = copy.deepcopy(model)
        else:
            improvement = test_r2 - self.best_r2
            if improvement >= 0:
                self.best_r2 = test_r2
                self.app_r2 = test_r2
                self.counter = 0
                self.best_model.load_state_dict(model.state_dict())
                self.app_model.load_state_dict(model.state_dict())
            elif improvement >= self.min_delta:
                self.counter = 0
                self.app_r2 = test_r2
                self.app_model.load_state_dict(model.state_dict())
            else:
                # Also reached when improvement is NaN (-inf minus -inf).
                self.counter += 1
                if self.counter >= self.patience:
                    self.status = f"Stopped on {self.counter}"
                    if self.restore_app_weights:
                        model.load_state_dict(self.app_model.state_dict())
                    elif self.restore_best_weights:
                        model.load_state_dict(self.best_model.state_dict())
                    return True
        self.status = f"{self.counter}/{self.patience}"
        return False

In [10]:
def train(config=config):
    """
    Train a fresh VAE on ``dataset`` inside a Weights & Biases run.

    The architecture comes from the module-level ``hparams``. Loop settings
    (``epochs``, ``batch_size``, ``lr``) are read back from ``wandb.config``,
    so a sweep controller can override them.

    @param config: initial run configuration handed to ``wandb.init``.
    @return: the trained model.
    """
    with wandb.init(project="vae_ts", config=config):
        # Read the settings back from wandb so sweep overrides take effect.
        run_cfg = wandb.config

        model = VAE(hparams)
        loss = nn.MSELoss(reduction="mean")
        opt = torch.optim.Adam(model.parameters(), lr=run_cfg.lr)

        # Per-batch and per-epoch metrics each get their own x-axis.
        step_axes = (
            ("batch_loss", "batch_ct"),
            ("avg_loss", "epoch"),
            ("test_r2", "epoch"),
        )
        for metric_name, axis in step_axes:
            wandb.define_metric(metric_name, step_metric=axis)

        # NOTE: early stopping is currently disabled. To enable it, create an
        # EarlyStopping instance here and break when it returns True.
        for epoch in range(run_cfg.epochs):
            epoch_loss = train_epoch(model, opt, loss, run_cfg.batch_size, dataset, epoch)
            wandb.log({"avg_loss": epoch_loss, "epoch": epoch})
            epoch_r2 = eval_r2(model, dataset)
            wandb.log({"test_r2": epoch_r2, "epoch": epoch})

    return model

In [11]:
# Train for config["epochs"] epochs (metrics go to wandb) and keep the model.
model = train(config)



VBox(children=(Label(value='0.001 MB of 0.007 MB uploaded\r'), FloatProgress(value=0.13352272727272727, max=1.…

0,1
avg_loss,█▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▁▄▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁
batch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
batch_loss,▁▁▁▁▂▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_r2,█▇████████████▇██▆▇██▇█▃█▅█▇██▇█▇██▇███▁

0,1
avg_loss,0.03093
batch_ct,18000.0
batch_loss,0.02017
epoch,499.0
test_r2,-4035919.22502


In [12]:
# Display the trained model's module structure as the cell output.
model

VAE(
  (encoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=177, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=128, out_features=128, bias=True)
      (10): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Linear(in_features=128, out_features=128, bias=True)
      (13): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Linear(in_features=128, out_features=128, bias=True)
    )
  )


## Extracting the Gene Embeddings

In [13]:
# Gene identifiers, plus the full score matrix as a float32 tensor (the
# leading column of dataset.dataset is dropped, presumably the gene id).
genes = dataset.genes_in_assay()
gene_scores = torch.from_numpy(
    dataset.dataset[:, 1:]
    .to_numpy()
    .astype('float32')
)

In [14]:
# Embed every gene with the trained encoder. Eval mode makes BatchNorm use its
# running statistics; training mode is switched back on afterwards.
model.eval()
with torch.no_grad():
    # Inputs must be on the model's device; .cpu() brings the result back
    # for numpy/pandas. No .detach() is needed under no_grad.
    emb = model.get_emb(gene_scores.to(device)).cpu().numpy()
model.train()

VAE(
  (encoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=177, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=128, out_features=128, bias=True)
      (10): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Linear(in_features=128, out_features=128, bias=True)
      (13): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Linear(in_features=128, out_features=128, bias=True)
    )
  )


In [15]:
# One row per gene, one column per latent dimension: EMB_0, EMB_1, ...
emb_df = pd.DataFrame(
    emb,
    index=genes,
    columns=["EMB_" + str(col) for col in range(hparams["emb_dim"])],
)

In [16]:
# Name the index so it gets a "gene_id" header when written with to_csv.
emb_df.index.name = "gene_id"
# Display the embedding table as the cell output.
emb_df

Unnamed: 0_level_0,EMB_0,EMB_1,EMB_2,EMB_3,EMB_4,EMB_5,EMB_6,EMB_7,EMB_8,EMB_9,...,EMB_118,EMB_119,EMB_120,EMB_121,EMB_122,EMB_123,EMB_124,EMB_125,EMB_126,EMB_127
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSG00000265452,-6.919820,5.343982,10.204840,-5.058686,-5.970248,-13.626997,-8.005278,2.334924,8.291081,-11.489544,...,-7.788873,-2.794478,-6.156457,-8.815792,-6.537822,8.457936,-10.383499,11.630306,9.759427,-1.255642
ENSG00000260063,-5.339058,4.377153,8.083728,-4.016703,-4.803891,-10.884915,-6.794176,1.987034,6.935829,-8.553007,...,-6.121444,-2.508773,-4.429782,-7.239748,-5.213462,7.236993,-8.096043,9.682898,8.137159,-1.237945
ENSG00000260254,-5.315782,4.363211,8.042725,-4.002351,-4.769751,-10.819091,-6.750141,1.965974,6.901060,-8.503800,...,-6.094487,-2.506976,-4.401450,-7.199893,-5.171790,7.184188,-8.056158,9.621420,8.093013,-1.235950
ENSG00000228950,-5.323012,4.367267,8.056360,-4.006558,-4.781978,-10.842258,-6.766303,1.974133,6.913198,-8.520170,...,-6.103116,-2.507045,-4.410837,-7.213929,-5.187171,7.203930,-8.069219,9.643794,8.108402,-1.236462
ENSG00000110536,-5.561466,4.483913,8.529350,-4.139972,-5.232438,-11.683893,-7.368824,2.290424,7.345352,-9.103452,...,-6.394947,-2.488848,-4.749409,-7.718553,-5.768536,7.946247,-8.517799,10.473782,8.655459,-1.241451
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000161270,-5.320959,4.365307,8.055573,-4.004800,-4.784202,-10.844629,-6.770431,1.977348,6.914439,-8.518546,...,-6.101426,-2.505679,-4.409843,-7.215448,-5.190890,7.209930,-8.067607,9.648590,8.110155,-1.236302
ENSG00000225076,-5.327791,4.369851,8.065884,-4.009334,-4.791033,-10.858951,-6.778502,1.980442,6.922038,-8.531300,...,-6.108912,-2.506972,-4.417233,-7.224072,-5.198606,7.218955,-8.078146,9.660419,8.119690,-1.236883
ENSG00000142794,-5.368519,4.391788,8.145722,-4.032509,-4.865459,-10.998361,-6.878222,2.031758,6.994969,-8.626948,...,-6.158308,-2.505996,-4.472102,-7.308576,-5.293208,7.341596,-8.153842,9.797400,8.212143,-1.239571
ENSG00000261938,-5.313808,4.362024,8.039355,-4.001160,-4.767064,-10.813755,-6.746733,1.964369,6.898285,-8.499634,...,-6.092216,-2.506829,-4.399058,-7.196678,-5.168484,7.180127,-8.052823,9.616570,8.089529,-1.235829


In [18]:
# Write the embeddings as a tab-separated file named after the embedding size
# (e.g. TS_pb_ae_d128.tsv for emb_dim = 128).
emb_dim = hparams["emb_dim"]
emb_df.to_csv(f"TS_pb_ae_d{emb_dim}.tsv", sep="\t")