# <center> Problem Set 4 </center>
<center> Spring 2024 </center>
<center> 3.C01/3.C51, 10.C01/10.C51 </center>

<b>Name:</b>

<b>Kerberos id:</b>

In [1]:
!pip install wget
!pip install molvs
!pip install rdkit

from torch.utils.data import Dataset, DataLoader
import matplotlib
from torch import nn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestClassifier
import wget
from rdkit.Chem import AllChem
from rdkit import DataStructs

import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw
import pandas as pd
import torch 
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from torch import optim
from molvs import standardize_smiles

# Global plot style for every figure in this notebook: larger fonts, thicker
# lines, axes and ticks, and plain-text (non-LaTeX) rendering. One update()
# call sets each key exactly once.
matplotlib.rcParams.update({
    'font.size': 15,
    'lines.linewidth': 3,
    'lines.color': 'g',
    'axes.linewidth': 2.0,
    'xtick.major.size': 6,
    'ytick.major.size': 6,
    'xtick.major.width': 2,
    'ytick.major.width': 2,
    'text.usetex': False,
})



## Part 1: Dimensionality Reduction for Molecular Representations

In [None]:
# Download the drug dataset only if it is not already present. The `wget`
# package does not overwrite an existing file; it saves a renamed copy
# (e.g. "drug (1).csv") that pd.read_csv("drug.csv") would never read.
import os

if not os.path.exists("drug.csv"):
    wget.download("https://raw.githubusercontent.com/coleygroup/ML4MolEng/main/psets/ps4/data/nonbio_version/drug.csv", "./")

In [None]:
# Read the drug dataset into a DataFrame; its 'SMILES' column is fingerprinted below.
df = pd.read_csv(filepath_or_buffer="drug.csv")

In [None]:
########## Simply run this chunk ##########
# Morgan fingerprint settings used by the ECFP class below:
# radius_pset4 is passed as `radius` and num_bits_pset4 as `nBits`
# to AllChem.GetMorganFingerprintAsBitVect.
radius_pset4, num_bits_pset4 = 3, 512

class ECFP:
    """Compute Morgan bit-vector fingerprints for a list of SMILES strings.

    Radius and bit length come from the module-level settings
    ``radius_pset4`` and ``num_bits_pset4``.

    Attributes:
        mols: RDKit molecules parsed from ``smiles`` with Chem.MolFromSmiles.
        smiles: the input SMILES, kept for the output DataFrame.
    """

    def __init__(self, smiles):
        self.mols = [Chem.MolFromSmiles(i) for i in smiles]
        self.smiles = smiles

    def mol2fp(self, mol):
        """Return ``(array, bitInfo)`` for one RDKit molecule.

        ``array`` is a float NumPy vector holding the fingerprint bits;
        ``bitInfo`` is the dict passed to RDKit as ``bitInfo`` and filled in
        by GetMorganFingerprintAsBitVect.
        """
        bi = {}
        fp = AllChem.GetMorganFingerprintAsBitVect(mol,
                                                   radius=radius_pset4,
                                                   bitInfo=bi,
                                                   nBits=num_bits_pset4)
        # Placeholder buffer; ConvertToNumpyArray is expected to resize it to
        # the fingerprint length.
        array = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(fp, array)
        return array, bi

    def compute_ECFP(self):
        """Fingerprint every molecule and return one DataFrame row per SMILES.

        Columns are ``smiles``, ``mol``, ``bitInfo`` followed by integer bit
        columns ``bit0`` ... ``bit{num_bits_pset4 - 1}``.
        """
        bit_headers = ['bit' + str(i) for i in range(num_bits_pset4)]
        fps = []
        bitInfo_all = []
        for mol in self.mols:
            fp, bi = self.mol2fp(mol)
            fps.append(fp)
            bitInfo_all.append(bi)
        # Stack once after the loop: np.vstack inside the loop re-copied the
        # whole growing matrix for every molecule (quadratic in dataset size).
        # reshape keeps a (0, num_bits_pset4) shape when there are no molecules.
        arr = np.asarray(fps, dtype=float).reshape(-1, num_bits_pset4)
        df_fp = pd.DataFrame(arr.astype(int), columns=bit_headers)
        df_fp.insert(loc=0, column='smiles', value=self.smiles)
        df_fp.insert(loc=1, column='mol', value=list(self.mols))
        df_fp.insert(loc=2, column='bitInfo', value=bitInfo_all)
        return df_fp

# Canonicalize every SMILES with MolVS before fingerprinting.
smiles_standarized = df['SMILES'].map(standardize_smiles).tolist()
fp_descriptor = ECFP(smiles_standarized)
# Keep only the bit columns: the SMILES are still available in `df`, and the
# RDKit Mol objects / bitInfo dicts are not needed for this exercise.
fp = (
    fp_descriptor.compute_ECFP()
    .drop(columns=['smiles', 'mol', 'bitInfo'])
    .to_numpy(dtype=float)
)
# `fp` is a float NumPy array: one row per molecule, num_bits_pset4 columns.

### 1.1 (5 points, Grad students only) Choosing radius and number of bits for Morgan fingerprints


Provide a one-sentence description of what the radius represents and another of what the number of bits represents. How does adjusting the radius parameter affect the granularity of the motifs captured by the fingerprints, and how does this relate to the choice of the number of bits?

*INSERT YOUR ANSWER HERE*

### 1.2 (10 points) Principal Component Analysis on Molecular Fingerprints

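Perform PCA to reduce the fingerprint matrix `fp` to 100 dimensions, then plot the first two principal components, coloring active and inactive molecules differently.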

In [None]:
########## Modify this code chunk ##########

# Run PCA
# TODO: fit a 100-component PCA on the fingerprint matrix `fp`.

# Plot PCA
# TODO: fill in the empty x/y arguments of both scatter calls with the first
# two principal components of the inactive and active molecules. The
# active/inactive label column of drug.csv is not shown here; check `df`.
# As written, the empty arguments are a SyntaxError, so this cell will not
# run until they are completed.

fig, ax = plt.subplots(figsize=(5,5))
ax.scatter(, , s=3, label='inactive') 
ax.scatter(, , color='red', s= 3, label='active')
ax.legend()

What is the explained variance ratio of the 100 principal components?

In [None]:
########## Insert your code in this chunk ##########


What patterns do you observe?

*INSERT YOUR ANSWER HERE*

### 1.3 (10 points) t-SNE analysis on Molecular Fingerprints

Perform t-SNE on the obtained principal components with perplexity values of 2, 30, and 500. Plot the results and label your plots.

In [None]:
########## Insert your code in this chunk ##########

What differences do you see between the 3 t-SNE plots? What patterns do you observe in the perplexity = 30 plot?

In [None]:
########## Insert your answer ##########

### 1.4 Graduate (20 points) Are the low dimensional embeddings meaningful?

Split the data into 10 folds. For each fold, train a classifier on the other 9 folds and use it to predict the held-out fold. Record the prediction for every molecule.

In [None]:
########## Insert your code in this chunk ##########

Classify your predictions into True Positives (TP), True Negatives (TN), False Positives (FP) and False Negatives (FN).

In [None]:
########## Insert your code in this chunk ##########

Plot the 2D t-SNE embeddings (perplexity = 30) colored by the four classification classes.

In [None]:
########## Insert your code in this chunk ##########

What pattern do you observe? 

In [None]:
########## Insert your answer ##########

## Part 2: Variational auto-encoders for SMILES strings


In [None]:
# Get data and the pretrained model, fetching each file at most once. The
# `wget` package does not overwrite existing files; it saves renamed copies
# (e.g. "zinc_50k (1).csv") that the cells below would never read.
import os

_PS4_DATA_URL = ("https://raw.githubusercontent.com/coleygroup/ML4MolEng/"
                 "main/psets/ps4/data/nonbio_version/")

for _fname in ("zinc_50k.csv",       # SMILES dataset
               "vae-050-0.06.pth"):  # pretrained MolVAE weights
    if not os.path.exists(_fname):
        wget.download(_PS4_DATA_URL + _fname, "./")

### 2.1 (5 points) One-hot encode SMILES strings into padded numerical vectors

In [None]:
from sklearn import preprocessing

# SMILES character vocabulary (31 characters). The final ' ' is the padding
# character that index2smiles strips when decoding.
moses_charset = list("2oCIOHnN=+#-cBl7rSs46[5]F3P()1 ")

# LabelEncoder maps each character to an integer index and back via
# inverse_transform.
enc = preprocessing.LabelEncoder().fit(moses_charset)

# Load the zinc_50k SMILES dataset.
df = pd.read_csv("./zinc_50k.csv")

Encode SMILES strings into padded categorical vectors.

In [None]:
########## Insert your code in this chunk ##########

# Find out the longest SMILES string, pad and encode into categorical vectors

Make train/validation/test Datasets and DataLoaders.

In [None]:
# TODO: replace every `None` below with your split / Dataset / DataLoader.
# As written, `X_train, X_test = None` raises TypeError (None cannot be
# unpacked), so this cell fails until it is filled in.
# The second line is meant to carve a validation set out of X_train.
X_train, X_test = None
X_train, X_val = None

# The training loop reads `data[0]` from each batch and MolVAE feeds it to
# nn.Embedding, so the first element of every item must be a tensor of
# integer character indices (e.g. a TensorDataset of encoded SMILES).
train_data = None
train_loader = None

val_data = None
val_loader = None

# The sampling cell indexes test_loader.dataset and takes element [0].
test_data = None
test_loader = None

### 2.2 (15 points) Implement the Reparametrization Trick for VAE

In [None]:
# Molecular VAE model 

class MolVAE(nn.Module):
    def __init__(self, rnn_enc_hid_dim, enc_nconv,
                 encoder_hid, z_dim,
                 rnn_dec_hid_dim, dec_nconv, smiles_len, nchar
                 ):
        '''
            SMILES VAE model: a GRU encoder maps an embedded SMILES string to
            the mean and log variance of a Gaussian latent code, and a GRU
            decoder maps a latent vector back to per-position character logits.

                rnn_enc_hid_dim: embedding size and hidden dimension of the GRU encoder
                enc_nconv: number of recurrent layers in the GRU encoder
                encoder_hid: output dimension of the MLP applied to the encoder's final hidden state
                z_dim: dimension of the latent vector z
                rnn_dec_hid_dim: hidden dimension of the GRU decoder
                dec_nconv: number of recurrent layers in the GRU decoder
                smiles_len: total length of the padded SMILES string
                nchar: number of possible characters (vocabulary size)
        '''
        super().__init__()

        self.smiles_len = smiles_len
        self.nchar = nchar
        # Embedding layer: character index -> rnn_enc_hid_dim vector
        self.embed = nn.Embedding(self.nchar, rnn_enc_hid_dim)
        # Encoding GRU
        self.rnn_enc = nn.GRU(rnn_enc_hid_dim, rnn_enc_hid_dim, enc_nconv, batch_first=True)
        # MLP to transform the final hidden state of the encoding GRU
        self.mlp0 = nn.Linear(rnn_enc_hid_dim, encoder_hid)
        # Network to parametrize mu
        self.mu_network = nn.Linear(encoder_hid, z_dim)
        # Network to parametrize log variance
        self.logvar_network = nn.Linear(encoder_hid, z_dim)
        # Decoding GRU
        self.rnn_dec = nn.GRU(z_dim, rnn_dec_hid_dim, dec_nconv, batch_first=True)
        # Output SMILES characters
        self.readout = nn.Linear(rnn_dec_hid_dim, self.nchar)

    def encode(self, x):
        '''Output mean and log variance of the encoded SMILES.

        x: embedded SMILES, (batch, smiles_len, rnn_enc_hid_dim) since the GRU
        is batch_first. Returns (mu, logvar), each (batch, z_dim).
        '''
        output, hn = self.rnn_enc(x)
        # hn[-1] is the final hidden state of the top GRU layer.
        h = torch.nn.functional.relu(self.mlp0(hn[-1]))
        return self.mu_network(h), self.logvar_network(h)

    def get_std(self, logvar):
        '''Transform log variance to standard deviation: std = exp(logvar / 2).'''
        return torch.exp(0.5 * logvar)

    def reparametrize(self, mu, std):
        '''The reparametrization trick: z = mu + std * eps with eps ~ N(0, I).

        Writing the sample as a deterministic function of (mu, std) plus
        independent noise lets gradients flow back into the encoder. In eval
        mode the mean is returned, so encoding is deterministic.
        '''
        if self.training:
            eps = torch.randn_like(std)
            return mu + std * eps
        return mu

    def decode(self, z):
        '''Decoder to reconstruct latent variable back to SMILES.

        z: (batch, z_dim). The same latent vector is fed at every one of the
        smiles_len time steps; returns logits of shape (batch, smiles_len, nchar).
        '''
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.smiles_len, 1)
        out, h = self.rnn_dec(z)
        out_reshape = out.contiguous().view(-1, out.size(-1))

        y0 = self.readout(out_reshape)
        y = y0.contiguous().view(out.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        '''Encode, sample and decode a batch of character-index sequences.

        Returns (character logits, mu, std).
        '''
        x_embed = self.embed(x)  # Get SMILES embedding
        mu, logvar = self.encode(x_embed)  # Encoding SMILES to latent representations
        std = self.get_std(logvar)  # Transform log variance to std.
        z = self.reparametrize(mu, std)  # Reparametrization trick
        smiles_recon = self.decode(z)  # Reconstruct SMILES string
        return smiles_recon, mu, std

Test your implementation by comparing samples drawn with `reparametrize` against the standard normal distribution N(0, 1).

In [None]:
########## Simply run this chunk ##########

# Define your model
model = MolVAE(rnn_enc_hid_dim=256, enc_nconv=1,
               encoder_hid=256, z_dim=128, rnn_dec_hid_dim=512,
               dec_nconv=3, nchar=31, smiles_len=max_len)

# Compare your sampling with N(0, 1)

import matplotlib.pyplot as plt
from scipy.stats import norm

# A freshly constructed nn.Module is in training mode, so reparametrize
# draws mu + std * noise here instead of returning mu.
sample = model.reparametrize(torch.zeros(1000), torch.ones(1000))
plt.hist(sample.detach().cpu().numpy(), density=True)

# Overlay the standard normal pdf on [-7, 7) in 0.001 steps.
x_axis = np.arange(-7, 7, 0.001)
plt.plot(x_axis, norm.pdf(x_axis, 0, 1))  # Mean = 0, SD = 1.
plt.show()


### 2.3 (10 points) Implement the SMILES VAE loss function

Implement your loss function here.

In [None]:
def loss_function(recon_x, x, mu, std):
    '''Return the two VAE loss terms (reconstruction, KL) for a batch.

    recon_x: character logits from MolVAE.forward, (batch, smiles_len, nchar)
    x: target character indices, (batch, smiles_len)
    mu, std: posterior mean and standard deviation, (batch, z_dim)

    The reconstruction term keeps the name BCE but is the multi-class
    cross-entropy over the nchar characters, averaged over every character
    position in the batch. The KL term is KL(N(mu, std^2) || N(0, I)),
    summed over latent dimensions and averaged over the batch. The training
    loop combines them as BCE + beta * KLD.
    '''
    # Flatten to (batch * smiles_len, nchar) so each character position is
    # one classification example.
    BCE = nn.functional.cross_entropy(
        recon_x.reshape(-1, recon_x.size(-1)), x.reshape(-1), reduction='mean'
    )
    # Closed form: -0.5 * sum(1 + log(std^2) - mu^2 - std^2) per sample.
    KLD = -0.5 * torch.sum(
        1 + 2 * torch.log(std) - mu.pow(2) - std.pow(2), dim=-1
    ).mean()
    return BCE, KLD

### 2.4 (5 points) Train your model

Run the following cells to train your model.

In [None]:
########## Simply run this chunk ##########
def loop(model, loader, epoch, beta=0.05, evaluation=False):
    '''
        Train/test your VAE model for one pass over `loader`.

        model: the MolVAE to run
        loader: iterable of batches whose first element is the input tensor
        epoch: epoch number, only used in the progress-bar label
        beta: weight of the KL term in the total loss recon + beta * KL
        evaluation: if True, run in eval mode without gradient updates

        Relies on the notebook-level `device`, `optimizer` and
        `loss_function`. Returns the mean total loss over all batches as a
        NumPy scalar.
    '''
    if evaluation:
        model.eval()
        mode = "eval"
    else:
        model.train()
        mode = 'train'
    batch_losses = []
    loss_sum = 0.0

    tqdm_data = tqdm(loader, position=0, leave=True, desc='{} (epoch #{})'.format(mode, epoch))
    # No autograd graph is needed when evaluating; building one anyway only
    # costs memory and time.
    with torch.set_grad_enabled(not evaluation):
        for data in tqdm_data:

            x = data[0].to(device)
            recon_batch, mu, std = model(x)
            loss_recon, loss_kl = loss_function(recon_batch, x, mu, std)
            loss = loss_recon + beta * loss_kl

            if not evaluation:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_losses.append(loss.item())
            loss_sum += batch_losses[-1]

            postfix = ['recon loss={:.3f}'.format(loss_recon.item()),
                       'KL loss={:.3f}'.format(loss_kl.item()),
                       'total loss={:.3f}'.format(loss.item()),
                       # running mean, updated in O(1) per batch
                       'avg. loss={:.3f}'.format(loss_sum / len(batch_losses))]

            tqdm_data.set_postfix_str(' '.join(postfix))

    return np.array(batch_losses).mean()

In [None]:
########## Simply run this chunk ##########
# Run on the first GPU when one is available, otherwise on the CPU.
# (An integer device such as 0 selects a CUDA GPU and fails on CPU-only runtimes.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MolVAE(rnn_enc_hid_dim=367, enc_nconv=2,
               encoder_hid=512, z_dim=171, rnn_dec_hid_dim=512,
               dec_nconv=1, nchar=31, smiles_len=max_len)

model = model.to(device)

# Load the pretrained weights. map_location remaps tensors that were saved on
# a GPU so the checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load("./vae-050-0.06.pth", map_location=device))

In [None]:
########## Simply run this chunk ##########
# Adam with a small learning rate. The scheduler halves the LR when the loss
# passed to scheduler.step() (the validation loss in the training cell) has
# not improved for 5 epochs.
# NOTE(review): newer PyTorch releases deprecate `verbose` for LR schedulers;
# confirm against the installed version.
optimizer = optim.Adam(params=model.parameters(), lr=5e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

Optional: mount your Google Drive to save your model and files.

In [None]:
########## Optional: simply run this chunk ##########
# Optional: mount your google drive to save model and files
# Uncomment below lines if of interest

#from google.colab import drive
#drive.mount('/content/drive')
#mydrive = '/content/drive/MyDrive'

In [None]:
########## Simply run this chunk ##########

epochs = 50

for epoch in range(epochs):

    train_loss = loop(model, train_loader, epoch, 0.001)
    val_loss = loop(model, val_loader, epoch, 0.001, evaluation=True)
    # The plateau scheduler watches the validation loss.
    scheduler.step(val_loss)

    # optional: save model (needs the Google Drive cell, which defines
    # `mydrive` as an absolute path, so it must not be prefixed with './')
#     if epoch % 15 == 0:
#         torch.save(model.state_dict(),
#                    '{}/vae-{:03d}-{:.2f}.pth'.format(mydrive, epoch, train_loss))
#
#         torch.save(optimizer.state_dict(),
#                    '{}/optim-{:03d}-{:.2f}.pth'.format(mydrive, epoch, train_loss))

    # Track the lowest training loss seen so far.
    if epoch == 0 or train_loss.item() < best_loss:
        best_loss = train_loss.item()


### 2.5 (20 points) Sample new molecules

Some helper functions for you.

In [None]:
# Helper functions 
def index2smiles(mol_index, enc):
    '''Decode a sequence of character indices back into a SMILES string.

    `enc` maps indices to characters via inverse_transform; space characters
    (the padding) are stripped from both ends of the result.
    '''
    chars = enc.inverse_transform(np.array(mol_index))
    return ''.join(chars).strip(' ')

def check_smiles_valid(smiles):
    '''Return True if RDKit can parse `smiles` into a molecule, else False.'''
    # MolFromSmiles yields a falsy result (None) for strings it cannot parse.
    return bool(Chem.MolFromSmiles(smiles))

Randomly select two SMILES strings from your test data, interpolate 10 points between their latent vectors, and decode those points. Check which decoded strings are valid SMILES, and draw a scatter plot of the first two latent dimensions. Then visualize the molecules that decoded successfully.

In [ ]:
########## Modify this code chunk ##########

# select a starting and ending molecule
# random.sample draws two *different* indices; two independent random draws
# could pick the same molecule and leave nothing to interpolate between.
test_dataset = test_loader.dataset
start_idx, end_idx = random.sample(range(len(test_dataset)), 2)
start = index2smiles(test_dataset[start_idx][0].numpy().reshape(-1), enc)
end = index2smiles(test_dataset[end_idx][0].numpy().reshape(-1), enc)

model.eval()

################ Your code #################


Produce a scatter plot of the first two dimensions of $z$ for your test molecules and for the newly sampled molecules in the same figure, using different colors for the test points and the generated points.

In [ ]:
########## Insert your code in this chunk ##########


Draw the different molecules you generated.

In [ ]:
########## Insert your code in this chunk ##########


Why does the VAE sometimes fail to generate valid SMILES strings?

*INSERT YOUR ANSWER HERE*