# A Normal Variational Autoencoder for the DepMap Dataset

## The Model

In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict
import polars as pl
from dataset import DepMap_Data
import numpy as np
from sklearn.metrics import r2_score
import pandas as pd

In [2]:
class MLP(nn.Module):
    """
    A multilayer perceptron with ReLU activations and optional BatchNorm.

    Careful: if activation is set to ReLU, ReLU is only applied to the second half of NN outputs! 
            ReLU is applied to standard deviation not mean
    """

    def __init__(
        self,
        sizes,
        batch_norm=True,
        last_layer_act="linear",
    ):
        super(MLP, self).__init__()
        layers = []
        for s in range(len(sizes) - 1):
            layers += [
                nn.Linear(sizes[s], sizes[s + 1]),
                nn.BatchNorm1d(sizes[s + 1])
                if batch_norm and s < len(sizes) - 2
                else None,
                nn.ReLU(),
            ]

        layers = [l for l in layers if l is not None][:-1]
        
        self.activation = last_layer_act
        if self.activation == "linear":
            pass
        elif self.activation == "ReLU":
            self.relu = nn.ReLU()
        else:
            raise ValueError("last_layer_act must be one of 'linear' or 'ReLU'")

        
        layers_dict = OrderedDict(
                {str(i): module for i, module in enumerate(layers)}
            )

        self.network = nn.Sequential(layers_dict)

    def forward(self, x):
        if self.activation == "ReLU":
            x = self.network(x)
            dim = x.size(1) // 2
            return torch.cat((x[:, :dim], self.relu(x[:, dim:])), dim=1)
        return self.network(x)

In [3]:
class VAE(nn.Module):
    def __init__(
        self,
        hparams: dict()
    ):
        super(VAE, self).__init__()
        self.hparams = hparams
        self.batch_norm = hparams["batch_norm"]
        self.Variational = hparams["Variational"]

        if self.Variational:
            self.encoder_sizes = [self.hparams["dim"]]+[self.hparams["encoder_width"]]* self.hparams["encoder_depth"]+ [self.hparams["emb_dim"]*2]
            self.decoder_sizes = [self.hparams["emb_dim"]]+[self.hparams["decoder_width"]]* self.hparams["decoder_depth"]+ [self.hparams["dim"]]
            self.encoder = MLP(self.encoder_sizes, batch_norm=self.batch_norm, last_layer_act="ReLU")
            self.decoder = MLP(self.decoder_sizes, batch_norm=self.batch_norm, last_layer_act="linear")

        else:
            self.encoder_sizes = [self.hparams["dim"]]+[self.hparams["encoder_width"]]* self.hparams["encoder_depth"]+ [self.hparams["emb_dim"]]
            self.decoder_sizes = [self.hparams["emb_dim"]]+[self.hparams["decoder_width"]]* self.hparams["decoder_depth"]+ [self.hparams["dim"]]
            self.encoder = MLP(self.encoder_sizes, batch_norm=self.batch_norm, last_layer_act="linear")
            self.decoder = MLP(self.decoder_sizes, batch_norm=self.batch_norm, last_layer_act="linear")

    def reparametrize(self, mu, sd):
        epsilon = torch.randn_like(sd)    
        z = mu + sd * epsilon 
        return z

    def get_emb(self, x):
        """
        get the embedding of given expression profiles of genes
        @param x: should be the shape [batch_size, hparams["dim]]
        """
        return self.encoder(x)[:, 0:self.hparams["emb_dim"]]
        
    def forward(self, x):
        """
        get the reconstruction of the expression profile of a gene
        @param x: should be the shape [batch_size, hparams["dim]]
        """
        latent = self.encoder(x)
        if self.Variational:
            mu = latent[:, 0:self.hparams["emb_dim"]]
            sd = latent[:, self.hparams["emb_dim"]:]
            assert mu.shape == sd.shape
            latent = self.reparametrize(mu, sd)
        reconstructed = self.decoder(latent)
        return reconstructed

## The Dataset and Hyperparameters 

In [4]:
path = "DepMap18Q3_gene_effect_raw.tsv"
dataset = DepMap_Data(path)

In [5]:
hparams = {
    "dim": dataset.num_celllines_in_assay(),
    "encoder_width": 256,
    "encoder_depth": 4,
    "emb_dim": 256,
    "decoder_width": 256,
    "decoder_depth": 4,
    "batch_norm": True,
    "Variational": True,
    
}
config = {
    "epochs": 1000,
    "batch_size": 1e3,
    "lr": 1e-3
}

device = "cpu"

## The Training and Testing

In [6]:
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkemingzhang[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
def train_epoch(model, opt, loss, batch_size, dataset, epoch):
    batch_ct = epoch * dataset.get_num_batches_per_epoch(batch_size)
    cumu_loss = 0
    for _, target in dataset.get_batches(batch_size, 'train'):
        model.to(device)
        opt.zero_grad()
        target = target.to(device, non_blocking=True)
        pred = model(target)
        
        mse = loss(pred, target)
        cumu_loss += mse.item()
        mse.backward()
        opt.step()

        batch_ct += 1
        wandb.log({"batch_loss": mse.item(), "batch_ct": batch_ct})


    #torch.mps.empty_cache()
    return cumu_loss / len(dataset.train_table)


In [8]:
def eval_r2(model, dataset):
    model.eval()
    with torch.no_grad():
        target = torch.from_numpy(dataset.test_table[:, 1:].to_numpy().astype('float32')).to(device, non_blocking=True)
        pred = model(target).detach().cpu().numpy()
        target = target.detach().cpu().numpy()
    model.train()
    return r2_score(target, pred)

In [9]:
def train(config=config):
    with wandb.init(project="vae_dm", config = config):
        #this config will be set by Sweep Controller
        config = wandb.config

        model = VAE(hparams)

        loss = nn.MSELoss(reduction="mean")
        opt = torch.optim.Adam(model.parameters(), lr = config.lr)


        wandb.define_metric("batch_loss", step_metric="batch_ct")
        wandb.define_metric("avg_loss", step_metric="epoch")
        wandb.define_metric("test_r2", step_metric="epoch")

        for epoch in range(config.epochs):
            avg_loss = train_epoch(model, opt, loss, config.batch_size, dataset, epoch)
            wandb.log({"avg_loss": avg_loss, "epoch": epoch})
            test_r2 = eval_r2(model, dataset)
            wandb.log({"test_r2": test_r2, "epoch":epoch})


        #save the model in the exchangable ONNX format
        target = torch.from_numpy(dataset.test_table[:, 1:].to_numpy().astype('float32')).to(device, non_blocking=True)
        torch.onnx.export(model, target, "model.onnx")
        wandb.save("model.onnx")

    return model

In [10]:
model = train(config)

  assert mu.shape == sd.shape


VBox(children=(Label(value='2.735 MB of 3.251 MB uploaded\r'), FloatProgress(value=0.8413183093468731, max=1.0…

0,1
avg_loss,█▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
batch_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch_loss,███▆▅▅▄▅▄▄▄▃▃▃▃▃▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_r2,▁▄▅▅▅▆▆▆▇▄▇▆▇▆▇▇▇▇▆▇▇▇▆████████▇▆▇▆█▇█▇█

0,1
avg_loss,9e-05
batch_ct,15000.0
batch_loss,0.08942
epoch,999.0
test_r2,0.88614


## Get the Embeddings

In [11]:
genes = dataset.genes_in_assay()
gene_scores = torch.from_numpy(dataset.dataset[:, 1:].to_numpy().astype('float32'))

In [12]:
model.eval()
with torch.no_grad():
    emb = model.get_emb(gene_scores).detach().cpu().numpy()
model.train()

VAE(
  (encoder): MLP(
    (relu): ReLU()
    (network): Sequential(
      (0): Linear(in_features=485, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=256, out_features=256, bias=True)
      (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Linear(in_features=256, out_features=512, bias=True)
    )
  )
  (decoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0

In [13]:
emb_df = pd.DataFrame(
    data = emb,
    index = genes,
    columns = [f'EMB_{i}' for i in range(hparams["emb_dim"])]
)

In [14]:
emb_df.index.name = "gene_id"
emb_df

Unnamed: 0_level_0,EMB_0,EMB_1,EMB_2,EMB_3,EMB_4,EMB_5,EMB_6,EMB_7,EMB_8,EMB_9,...,EMB_246,EMB_247,EMB_248,EMB_249,EMB_250,EMB_251,EMB_252,EMB_253,EMB_254,EMB_255
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSG00000166136,1.641138,0.113827,3.921674,-1.032297,-1.784951,2.761224,-0.963215,0.054067,0.886144,2.497651,...,-2.551810,-0.874080,0.123632,-1.525758,0.081557,1.207395,0.091882,-3.510774,1.390204,-0.308297
ENSG00000158497,1.227518,-0.201240,2.679383,0.710508,-1.599195,-0.127268,-0.670575,0.623313,0.310372,1.875062,...,-0.717993,0.807587,-0.043969,0.827601,-0.581406,-0.458845,-0.090785,-0.459049,-0.454707,-0.776513
ENSG00000130158,-0.305238,-0.931766,2.670334,-1.156555,-0.207786,0.460169,-0.771380,0.193283,0.701356,0.516930,...,0.092865,0.333991,0.705045,-0.334567,0.143183,0.900249,-2.049591,-1.107022,0.736480,-2.218306
ENSG00000163513,0.555090,-2.963492,1.886296,-0.194856,-1.277622,0.863973,0.157402,0.730744,0.604816,0.873439,...,-1.013703,-0.266838,-0.954578,0.458120,-0.886794,0.048439,-0.940706,-0.370678,-0.082977,-1.468873
ENSG00000196735,-0.011198,-1.249279,2.183905,-0.822078,-1.058541,-1.073211,-0.261656,0.421751,0.181857,0.720012,...,-1.387877,1.664654,0.181335,0.335060,-1.299922,-0.415443,-0.097746,-1.015305,0.270264,-1.403219
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000065911,0.247325,0.267196,1.401120,-0.907286,-0.706039,-0.966901,-0.713497,0.260983,0.510956,1.823110,...,0.185108,1.073838,0.564209,1.178624,0.162501,0.113417,0.518227,0.793102,1.065315,-1.298700
ENSG00000086159,0.639313,-0.212863,2.246478,-0.773693,-1.101509,-0.425882,-0.530464,0.446519,0.553853,2.969281,...,-0.321899,0.013963,0.731371,0.280500,-1.513927,0.211324,-0.295557,-0.082989,0.274435,0.277479
ENSG00000138964,0.172174,-0.531750,3.232540,0.176913,-1.699616,-1.263994,-0.608762,1.327598,0.686856,1.599976,...,0.405329,0.791420,1.229977,-0.765838,-0.179073,0.456911,-0.385508,0.206216,0.132343,-1.259318
ENSG00000163510,0.256619,0.799553,1.589412,4.564667,0.002477,-2.206850,3.108769,1.587506,-1.126365,1.531915,...,-3.390278,-3.083746,-4.117815,1.888183,-1.020239,1.859806,2.011737,-0.827389,1.435091,-0.296291


In [15]:
emb_dim = hparams["emb_dim"]
emb_df.to_csv(f"DepMap_unprocessed_vae_d{emb_dim}.tsv", sep="\t")