# Introduction to Variational AutoEncoders (VAE)

The goal of this notebook is to demonstrate how VAEs work in practice through a coding example.

In [None]:
# download a utils.py file containing some utility functions we will need
!curl -O https://raw.githubusercontent.com/chimie-paristech-CTM/PSL_notebooks/main/generative_models/utils.py
# download a pre-trained RNN model file
!curl -O https://raw.githubusercontent.com/chimie-paristech-CTM/PSL_notebooks/main/generative_models/pretrained.zinc.rnn.pth
# download the pre-trained VAE model
!curl -O https://raw.githubusercontent.com/chimie-paristech-CTM/PSL_notebooks/main/generative_models/pretrained.vae.pt
# clone repository to extract the compressed molecular data
!git clone https://github.com/aksub99/molecular-vae.git

In [None]:
# install other packages required
!pip install rdkit
!pip install molplotly
!pip install torch==2.1
!pip install numpy==1.26
!pip install scikit-learn
!pip install h5py
!pip install dash==2.9.2

We will start by presenting a high-level overview of VAEs (the figure below is taken from [here](https://towardsdatascience.com/vae-variational-autoencoders-how-to-employ-neural-networks-to-generate-new-images-bdeb216ed2c0)).



<div align="middle">
<img src="https://towardsdatascience.com/wp-content/uploads/2022/04/1qtXrzMLorYDl4SzKqoZxBg-1536x1110.png" width="900"/>
</div>

As already explained during the lecture, the `Encoder` takes a molecule and converts it into a low-dimensional vector. The job of the `Decoder` is to take this `Latent Vector` and `reconstruct` the input.

In practice, to enable a robust introduction of noise (which is an essential element of the VARIATIONAL autoencoder), each molecule is mapped onto a `Gaussian Distribution` in the latent space. In other words, instead of assigning the molecule to a single point, we assign it a probability distribution.

`Gaussian Distributions` (or `Normal Distributions`) are completely defined by their `mean` and `variance`: if you know both, you can construct the full `Gaussian Distribution`.
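
As a small illustration (plain Python, independent of the notebook's code), the probability density of a one-dimensional Gaussian can be written down from its mean and variance alone:

```python
import math

def gaussian_pdf(x, mean, variance):
    """Density of a 1-D Gaussian N(mean, variance) evaluated at x."""
    return math.exp(-(x - mean) ** 2 / (2 * variance)) / math.sqrt(2 * math.pi * variance)

# the density peaks at the mean ...
peak = gaussian_pdf(0.0, mean=0.0, variance=1.0)
# ... and is symmetric around it
left = gaussian_pdf(-1.0, mean=0.0, variance=1.0)
right = gaussian_pdf(1.0, mean=0.0, variance=1.0)
print(round(peak, 4))  # 0.3989
```

In the VAE below, every one of the 292 latent dimensions gets its own mean and variance, so a molecule is described by a diagonal Gaussian, i.e., one such one-dimensional Gaussian per dimension.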

Once every molecule is assigned a `mean` and `variance`, noise can be added according to the formula provided in the figure. The noise parameter `epsilon` is typically drawn from a standard `Gaussian Distribution` itself (mean 0, variance 1).
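
The standard form of this step (often called the reparameterization trick) is `z = mean + std * epsilon` with `epsilon` drawn from a standard normal distribution. Below is a minimal plain-Python sketch, assuming the encoder outputs the log-variance (as `MolecularVAE` does below), so that `std = exp(0.5 * logvar)`. Note that the notebook's `sampling` method additionally multiplies `epsilon` by `1e-2`; this sketch uses the unscaled standard normal. The means and log-variances are toy values.

```python
import math
import random

def sample_latent(mean, logvar, rng):
    """Draw one latent vector z = mean + std * epsilon, with epsilon ~ N(0, 1) per dimension.

    mean and logvar are lists of equal length (one entry per latent dimension);
    std is recovered from the log-variance as exp(0.5 * logvar).
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, logvar)]

rng = random.Random(0)
toy_mean = [1.0, -2.0]
toy_logvar = [0.0, math.log(0.25)]  # variances 1.0 and 0.25, i.e., std 1.0 and 0.5

samples = [sample_latent(toy_mean, toy_logvar, rng) for _ in range(20000)]
# the empirical mean and std of each dimension should approach the predicted values
emp_mean = [sum(s[d] for s in samples) / len(samples) for d in range(2)]
emp_std_1 = math.sqrt(sum((s[1] - emp_mean[1]) ** 2 for s in samples) / len(samples))
print([round(v, 2) for v in emp_mean], round(emp_std_1, 2))
```

Because the randomness sits entirely in `epsilon`, `z` is a differentiable function of the mean and log-variance, which is what lets the encoder be trained by backpropagation.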

In the [original molecular `VAE` paper](https://pubs.acs.org/doi/10.1021/acscentsci.7b00572), a neural network model was trained to predict molecular properties from latent vectors. Additionally, it was demonstrated how you can move in the `Latent Space` to go from some starting molecule to another molecule with desired properties. Here, we omit further details and instead try to visually demonstrate what the `Latent Space` is.
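
One simple way to "move" through the latent space is to interpolate linearly between the latent vectors of two molecules and decode each intermediate point. This is an illustrative assumption, not necessarily the procedure used in the paper. With the model defined below, each intermediate vector would be passed to `decode`; that step is left out here so the sketch runs on its own, and the 3-dimensional vectors are toy values.

```python
def interpolate_latent(z_start, z_end, n_steps):
    """Return n_steps latent vectors evenly spaced on the straight line from z_start to z_end (inclusive)."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    path = []
    for k in range(n_steps):
        t = k / (n_steps - 1)  # t goes from 0.0 (start) to 1.0 (end)
        path.append([(1 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return path

# toy 3-dimensional latent vectors (the model below uses 292 dimensions)
toy_path = interpolate_latent([0.0, 1.0, -1.0], [1.0, 3.0, 1.0], n_steps=5)
for z in toy_path:
    print(z)
```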

Note: the `VAE` code below is taken from [here](https://aksub99.github.io/).

In [None]:
# imports
# (the `from __future__` import was dropped: it is unnecessary in Python 3 and raises a
#  SyntaxError when it is not the first statement of a cell)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gzip
import argparse
import os
import h5py
import numpy as np
import pandas
from sklearn import model_selection

In [None]:
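# these are utility functions
def one_hot_array(i, n):
    # one-hot vector of length n with a 1 at position i
    return [int(ix == i) for ix in range(n)]

def one_hot_index(vec, charset):
    # map every character of vec to its index in charset
    return [charset.index(c) for c in vec]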

def from_one_hot_array(vec):
    oh = np.where(vec == 1)
    if oh[0].shape == (0, ):
        return None
    return int(oh[0][0])

def decode_smiles_from_indexes(vec, charset):
    # look up the character (bytes) for every index and strip the trailing padding
    return b"".join(charset[int(x)] for x in vec).strip()

def load_dataset(filename, split = True):
    h5f = h5py.File(filename, 'r')
    if split:
        data_train = h5f['data_train'][:]
    else:
        data_train = None
    data_test = h5f['data_test'][:]
    charset =  h5f['charset'][:]
    h5f.close()
    if split:
        return (data_train, data_test, charset)
    else:
        return (data_test, charset)
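
To make the data format concrete, the sketch below encodes a SMILES string into a padded one-hot layout (one row per position, one column per character) and decodes it back by taking the position of the maximum in each row, which is what the notebook does further down with `argmax`. The 8-character charset and the maximum length of 10 are toy assumptions; the real data use 120 positions and a 33-character charset. Padding with spaces is also an assumption, consistent with the `.strip()` call in `decode_smiles_from_indexes`.

```python
def encode_smiles(smiles, charset, max_len):
    """One-hot encode a SMILES string, padded with spaces to max_len positions."""
    padded = smiles.ljust(max_len)  # pad with spaces up to max_len characters
    if len(padded) > max_len:
        raise ValueError("SMILES string longer than max_len")
    return [[int(ch == c) for c in charset] for ch in padded]

def decode_one_hot(matrix, charset):
    """Invert encode_smiles: pick the highest-scoring character per row and strip the padding."""
    chars = [charset[max(range(len(row)), key=row.__getitem__)] for row in matrix]
    return "".join(chars).strip()

# toy charset: the padding character ' ' must be part of it
toy_charset = [" ", "C", "O", "N", "(", ")", "=", "1"]
encoded = encode_smiles("CC(=O)O", toy_charset, max_len=10)
print(len(encoded), len(encoded[0]))          # 10 8
print(decode_one_hot(encoded, toy_charset))   # CC(=O)O
```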

In [None]:
# the main code for the VAE is found below
class MolecularVAE(nn.Module):
    def __init__(self):
        super().__init__()

        # encoder related blocks
        # every SMILES string is padded to a fixed length of 120 tokens (shorter strings are filled with padding)
        # and every token is one-hot encoded over 33 possible characters, i.e., every molecule is a tensor of shape [120, 33]
        # Conv1d treats the 120 positions as input channels: the first convolutional layer maps them to 9 channels,
        # while the second axis (33) shrinks with every convolution according to the kernel size
        self.conv_1 = nn.Conv1d(120, 9, kernel_size=9) 
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
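        self.linear_0 = nn.Linear(70, 435)
        # two parallel heads: linear_1 predicts the mean, linear_2 the log-variance of the 292-dimensional latent Gaussian
        self.linear_1 = nn.Linear(435, 292)
        self.linear_2 = nn.Linear(435, 292)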

        # decoder related blocks
        # now we go back from 292 dimensions to the original size of the molecule
        self.linear_3 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, 33)

        # activation function 
        self.relu = nn.ReLU()
        
    def encode(self, x):
        # forward pass through encoder
        x = self.relu(self.conv_1(x)) # input shape [120, 33], output shape [9, 25] 
        x = self.relu(self.conv_2(x)) # input shape [9, 25], output shape [9, 17]
        x = self.relu(self.conv_3(x)) # input shape [9, 17], output shape [10, 7]
        x = x.view(x.size(0), -1) # turns the [10, 7] tensor into a [70] one
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        # recall in the VAE figure, noise is added: z = mean + std * epsilon
        # epsilon is the noise; note that this implementation scales the standard normal draw by 1e-2
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        # std = exp(0.5 * log-variance); return the latent vector (this is what the decoder will use to reconstruct the input)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
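        # forward pass through decoder to go from latent vector back to a molecule
        z = F.selu(self.linear_3(z))
        # repeat the latent vector at each of the 120 positions: [batch, 292] -> [batch, 120, 292]
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 120, 1)
        # 3-layer GRU with hidden size 501 reads the sequence: output shape [batch, 120, 501]
        output, hn = self.gru(z)
        # flatten the positions so linear_4 maps every 501-d GRU output to scores over the 33 characters
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        # restore the shape [batch, 120, 33]: a probability distribution over characters at every position
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y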

    def forward(self, x):
        # the overall forward pass takes the input, passes it to the encoder and then the decoder
        # first encode the input to get the mean and log-variance of the Gaussian distribution it is mapped to
        z_mean, z_logvar = self.encode(x)
        # get the latent vector by taking the mean and log-variance above and adding noise to it
        z = self.sampling(z_mean, z_logvar)
        # decode the latent vector, z, to reconstruct a molecule
        return self.decode(z), z_mean, z_logvar
    
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    # the loss function is a combination of 2 quantities:
    #     1. "reconstruction loss", which measures how different the reconstructed molecule
    #        is from the original. We want them to be similar.

    #     2. "Kullback–Leibler (KL) divergence" between the Gaussian predicted by the encoder
    #        and a standard normal prior N(0, I). It measures how "off" the latent distribution
    #        is from that prior and keeps the latent space well-behaved.
    # reduction='sum' sums the binary cross-entropy over all elements (no averaging)
    reconstruction_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return reconstruction_loss + kl_loss
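
The `kl_loss` line is the closed-form KL divergence between a diagonal Gaussian (mean `z_mean`, log-variance `z_logvar`) and a standard normal N(0, I), summed over latent dimensions. As a sanity check in plain Python, the sketch below compares that formula with a Monte Carlo estimate of `E_q[log q(z) - log p(z)]` for a single latent dimension; the mean 0.5 and log-variance -0.5 are arbitrary toy values.

```python
import math
import random

def kl_closed_form(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, 1)) summed over dimensions, same formula as kl_loss."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mean, logvar))

def log_normal_pdf(x, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def kl_monte_carlo(mean, logvar, n_samples, rng):
    """Estimate KL(q || p) for one dimension by averaging log q(z) - log p(z) over z ~ q."""
    var = math.exp(logvar)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mean, math.sqrt(var))
        total += log_normal_pdf(z, mean, var) - log_normal_pdf(z, 0.0, 1.0)
    return total / n_samples

kl_rng = random.Random(42)
exact = kl_closed_form([0.5], [-0.5])
estimate = kl_monte_carlo(0.5, -0.5, n_samples=100_000, rng=kl_rng)
print(round(exact, 4), round(estimate, 3))
```

The KL term is zero only when the encoder predicts mean 0 and log-variance 0 (variance 1) in every dimension, which is why it pulls the latent distributions of all molecules toward the same prior.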

In [None]:
!unzip molecular-vae/data/processed.zip  -d molecular-vae/data/

In [None]:
# this was used when we pre-trained the VAE
# it initializes a PyTorch DataLoader so we can read batches of molecules at a time during training
data_train, data_test, charset = load_dataset('molecular-vae/data/processed.h5')
data_train = torch.utils.data.TensorDataset(torch.from_numpy(data_train))
train_loader = torch.utils.data.DataLoader(data_train, batch_size=500, shuffle=True)

In [None]:
# initiate an instance of the MolecularVAE 
pretrained_vae = MolecularVAE()
# load the pre-trained model (we provide this)
# map_location='cpu' allows the weights to be loaded on machines without a GPU
pretrained_vae.load_state_dict(torch.load('pretrained.vae.pt', map_location='cpu'))
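
The shape comments in `encode` follow from the output-length rule of `nn.Conv1d`: with PyTorch's defaults (stride 1, no padding, dilation 1), `L_out = L_in - kernel_size + 1`. The plain-Python sketch below applies the general formula from the PyTorch documentation to the three encoder convolutions and checks that the flattened output has the 70 features expected by `linear_0`.

```python
def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (same formula as the nn.Conv1d documentation)."""
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

seq_len = 33  # one-hot dimension, treated as the "length" axis by Conv1d
channels_and_kernels = [(9, 9), (9, 9), (10, 11)]  # (out_channels, kernel_size) of conv_1..conv_3
for out_channels, kernel_size in channels_and_kernels:
    seq_len = conv1d_out_len(seq_len, kernel_size)
    print([out_channels, seq_len])
# flattened size fed into linear_0
flat_features = channels_and_kernels[-1][0] * seq_len
print(flat_features)  # 70
```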

Below, we visualize the `Latent Space`.

In [None]:
# RERUN HERE FOR NEW MOLECULES!
# this bit of code takes one random batch of 500 molecules from the training data
# (the DataLoader shuffles the data, so every rerun gives a different batch)
training_data_molecules = next(iter(train_loader))

In [None]:
# manually add some noise to the one-hot encodings of the first 10 training data molecules --> we will see
# what these "noised" molecules look like in the latent space later
num_noised = 10
molecules_to_noise = training_data_molecules[0][:num_noised]

# you can check that molecules_to_noise is a tensor with a shape of [10, 120, 33], i.e., number of molecules, default token length, types of tokens
print(molecules_to_noise.shape)

In [None]:
# now, we add Gaussian noise (mean 0, standard deviation 0.0001) to every element of this tensor
noised_molecules = molecules_to_noise + torch.normal(0, 0.0001, (num_noised, 120, len(charset)))
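
Every row of a molecule's tensor is one-hot, so its largest entry exceeds all others by 1, and Gaussian noise with a standard deviation of 0.0001 essentially never changes which character wins the `argmax`. The plain-Python sketch below (a toy one-hot tensor of the same [120, 33] size, fixed seed) illustrates this. It is why the SMILES decoded from the noised tensors are expected to match the original ones; the two groups still differ in latent space because the encoder sees slightly perturbed inputs and `sampling` adds its own noise.

```python
import random

def argmax(values):
    """Index of the largest entry (first one in case of ties)."""
    return max(range(len(values)), key=values.__getitem__)

noise_rng = random.Random(0)
n_positions, n_chars = 120, 33
# toy one-hot "molecule": position i holds character i % n_chars
one_hot = [[float(c == i % n_chars) for c in range(n_chars)] for i in range(n_positions)]
# add N(0, 0.0001) noise to every entry, mirroring torch.normal(0, 0.0001, ...) in the cell above
noised = [[v + noise_rng.gauss(0.0, 0.0001) for v in row] for row in one_hot]

unchanged = sum(argmax(a) == argmax(b) for a, b in zip(one_hot, noised))
print(unchanged, "of", n_positions, "positions keep their character")
```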

In [None]:
# this bit of code gets the SMILES back from the 500 training data molecules we got above
smiles_list = []
for idx in range(training_data_molecules[0].shape[0]):
    vector = training_data_molecules[0][idx].reshape(1, 120, len(charset)).argmax(axis=2)[0]
    # decode the bytes returned by decode_smiles_from_indexes into a regular Python string
    smiles = decode_smiles_from_indexes(vector, charset).decode('utf-8')
    smiles_list.append(smiles)

In [None]:
# this bit of code gets the SMILES from the "noised" training data molecules
noised_smiles_list = []
for idx in range(noised_molecules.shape[0]):
    vector = noised_molecules[idx].reshape(1, 120, len(charset)).argmax(axis=2)[0]
    # decode the bytes returned by decode_smiles_from_indexes into a regular Python string
    smiles = decode_smiles_from_indexes(vector, charset).decode('utf-8')
    noised_smiles_list.append(smiles)

In [None]:
# torch.no_grad() skips gradient tracking, which is not needed here
with torch.no_grad():
    # encode the training data molecules
    z_mean, z_logvar = pretrained_vae.encode(training_data_molecules[0])
    # sample their latent vectors
    latent_space = pretrained_vae.sampling(z_mean, z_logvar)

    # encode the noised data
    noised_z_mean, noised_z_logvar = pretrained_vae.encode(noised_molecules)
    # sample the latent vectors of the "noised" molecules
    noised_latent_space = pretrained_vae.sampling(noised_z_mean, noised_z_logvar)

In [None]:
# the code here plots an interactive (2-D cross-section of the) latent space: hover over the points and explore the molecules!
# only the first two of the 292 latent dimensions are shown
import plotly.express as px
import molplotly
import pandas as pd

all_smiles = smiles_list + noised_smiles_list
full_latent_space = torch.vstack([latent_space, noised_latent_space])

print(full_latent_space)

plotting_df = pd.DataFrame({'smiles': all_smiles,
                           'group': ['Training Data']*len(smiles_list) + ['Sampled from Latent Space']*num_noised,
                           'latent_space_x': full_latent_space[:, 0].detach().numpy(),
                           'latent_space_y': full_latent_space[:, 1].detach().numpy()})

fig_scatter = px.scatter(plotting_df,
                         x='latent_space_x',
                         y='latent_space_y',
                         color='group')

app_scatter = molplotly.add_molecules(fig=fig_scatter,
                                      df=plotting_df,
                                      smiles_col='smiles',
                                      title_col='group',
                                      color_col='group')

app_scatter.run_server(mode='inline', height=400)

# the red points are the "noised" molecules: small Gaussian noise was added to the one-hot encodings of
# 10 training data molecules before encoding them, and the sampling step adds further noise in the latent space.

# Locate a red point and look at the training data points around it. You should be able to see some
# structural similarities. One can think of the red point as a "hybrid" between its surrounding neighbours
# of blue points.

# Note: sometimes "close" points are not that similar. The latent space is not perfectly smooth, so small
#       moves can lead to abrupt structural changes; moreover, the plot shows only 2 of the 292 latent
#       dimensions, so points that look close here can be far apart in the full latent space.

# Finally, if you want to see new molecules, re-run the cell above marked with "RERUN HERE FOR NEW MOLECULES!"