# <center> Problem Set 5 </center>
<center> 3.C01/3.C51, 10.C01/10.C51 </center>

<b>Name:</b>

<b>Kerberos id:</b>

### Download required data & install packages

In [None]:
!wget https://raw.githubusercontent.com/coleygroup/ML4MolEng/main/psets/ps4/data/nonbio_version/zinc_50k.csv
!wget https://raw.githubusercontent.com/coleygroup/ML4MolEng/main/psets/ps4/data/nonbio_version/vae-050-0.06.pth
!wget https://raw.githubusercontent.com/coleygroup/ML4MolEng/main/psets/ps2/data/dna_binding.csv
!pip install rdkit

In [None]:
import os
import glob
import math
import random as r
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import Draw
from scipy.stats import norm
from sklearn import preprocessing
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

## Part 1: Variational auto-encoders for SMILES strings

In [None]:
################ Run #################

# character list: every single character that can appear in the SMILES strings
# (31 entries, matching nchar=31 passed to MolVAE below). Two-letter atoms such
# as Cl/Br are covered by their individual characters ('l', 'r'). The trailing
# ' ' is the padding character (index2smiles strips it back off).
moses_charset = ['2', 'o', 'C', 'I', 'O', 'H', 'n', 'N', '=', '+', '#', '-', 'c',
                 'B', 'l', '7', 'r', 'S', 's', '4', '6', '[', '5', ']', 'F', '3',
                 'P', '(', ')', '1', ' ']

# define encoder: LabelEncoder maps each character to an integer index
# (its position in the sorted character list) and back via inverse_transform
enc = preprocessing.LabelEncoder().fit(moses_charset)

# read data
df = pd.read_csv("./zinc_50k.csv")

################ Run #################

### 1.1 (5 points) One-hot encode SMILES strings into padded numerical vectors

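Pad every SMILES string (with the space character) to the length of the longest string, and encode it as a vector of integer character indices. The model's `nn.Embedding` layer expects indices rather than one-hot vectors.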

In [None]:
################ Solution #################

# find out the longest SMILES string, pad, and encode


################ Solution #################

Make train/validation/test Datasets and DataLoaders.

In [None]:
################ Solution #################

# NOTE(review): these are placeholders -- unpacking `None` raises TypeError, so
# this cell fails until every line is replaced. The splits should come from the
# padded, index-encoded SMILES of 1.1 (e.g. via train_test_split), wrapped in
# Datasets/DataLoaders. `loop` reads each batch as `data[0]`, so the Datasets
# should yield tuples (e.g. a TensorDataset of integer tensors).
X_train, X_test = None
X_train, X_val = None

train_data = None
train_loader = None

val_data = None
val_loader = None

test_data = None
test_loader = None

################ Solution #################

### 1.2 (15 points) Implement the reparametrization trick for VAE

In [None]:
class MolVAE(nn.Module):
    """ SMILES variational auto-encoder (VAE).

        A GRU encoder maps a batch of integer-encoded SMILES strings to the
        mean and log variance of a diagonal Gaussian over the latent space;
        a GRU decoder maps a latent vector back to per-position character
        logits.
    """
    def __init__(self, rnn_enc_hid_dim, enc_nconv, encoder_hid, z_dim,
                 rnn_dec_hid_dim, dec_nconv, smiles_len, nchar):
        """
            SMILES VAE model

                rnn_enc_hid_dim: embedding size and hidden dimension for the GRU encoder
                enc_nconv: number of recurrent layers for the GRU encoder
                encoder_hid: dimension of the GRU encoder readout (MLP) layer
                z_dim: number of latent variables
                rnn_dec_hid_dim: hidden dimension for the GRU decoder
                dec_nconv: number of recurrent layers for the GRU decoder
                smiles_len: total length of padded SMILES string
                nchar: number of possible characters
        """
        super(MolVAE, self).__init__()
        self.smiles_len = smiles_len
        self.nchar = nchar

        self.embed = nn.Embedding(self.nchar, rnn_enc_hid_dim)  # embedding layer
        self.rnn_enc = nn.GRU(rnn_enc_hid_dim, rnn_enc_hid_dim,
                              enc_nconv, batch_first=True)  # encoding GRU
        self.mlp0 = nn.Linear(rnn_enc_hid_dim, encoder_hid)  # transform hidden from encoding GRU
        self.mu_network = nn.Linear(encoder_hid, z_dim)  # to parametrize mu
        self.logvar_network = nn.Linear(encoder_hid, z_dim)  # to parametrize log variance
        self.rnn_dec = nn.GRU(z_dim, rnn_dec_hid_dim, dec_nconv,
                              batch_first=True)  # decoding GRU
        self.readout = nn.Linear(rnn_dec_hid_dim, self.nchar)  # output characters

    def encode(self, x):
        """ Output mean and log variance of the encoded SMILES

            x: embedded SMILES, (batch, seq_len, rnn_enc_hid_dim)
            Returns (mu, logvar), each (batch, z_dim).
        """
        output, hn = self.rnn_enc(x)
        # hn[-1] is the final hidden state of the last GRU layer
        h = torch.nn.functional.relu(self.mlp0(hn[-1]))

        return self.mu_network(h), self.logvar_network(h)

    def get_std(self, logvar):
        """ Transform log variance to standard deviation: std = exp(logvar / 2)
        """
        ################ Solution #################

        # exp(0.5 * log(sigma^2)) = sigma; always positive, and avoids the
        # extra sqrt(exp(logvar)) round trip.
        std = torch.exp(0.5 * logvar)

        ################ Solution #################
        return std

    def reparametrize(self, mu, std):
        """ The reparametrization trick

            In training mode, draw z = mu + std * eps with eps ~ N(0, I), so the
            sample stays differentiable w.r.t. mu and std. In eval mode, return
            mu (deterministic encoding).
        """
        if self.training:
            ################ Solution #################

            eps = torch.randn_like(std)  # same shape/dtype/device as std
            z = mu + eps * std

            ################ Solution #################
            return z
        else:
            return mu

    def decode(self, z):
        """ Decoder to reconstruct latent variable back to SMILES

            z: latent vectors, (batch, z_dim)
            Returns character logits, (batch, smiles_len, nchar).
        """
        # feed the same latent vector to the decoder at every position
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.smiles_len, 1)
        out, h = self.rnn_dec(z)
        out_reshape = out.contiguous().view(-1, out.size(-1))

        y0 = self.readout(out_reshape)
        y = y0.contiguous().view(out.size(0), -1, y0.size(-1))

        return y

    def forward(self, x):
        """ x: integer character indices, (batch, seq_len).
            Returns (reconstruction logits, mu, std).
        """
        x_embed = self.embed(x)  # get SMILES embedding
        mu, logvar = self.encode(x_embed)  # encoding SMILES to latent
        std = self.get_std(logvar)  # transform log variance to std

        z = self.reparametrize(mu, std)  # reparametrization trick
        smiles_recon = self.decode(z)  # reconstruct SMILES string

        return smiles_recon, mu, std

Test your reparametrization by comparing its samples with the standard normal distribution N(0, 1).

In [None]:
################ Run #################

# define your model (max_len is expected to come from part 1.1)
model = MolVAE(rnn_enc_hid_dim=256, enc_nconv=1, encoder_hid=256, z_dim=128,
               rnn_dec_hid_dim=512, dec_nconv=3, nchar=31, smiles_len=max_len)

# compare your sampling with N(0, 1); a freshly built nn.Module is in training
# mode, so reparametrize() draws samples here instead of returning mu
sample = model.reparametrize(torch.zeros(1000), torch.ones(1000))
plt.hist(sample.detach().cpu().numpy(), density=True)

# plot between -7 and 7 with .001 steps.
x_axis = np.arange(-7, 7, 0.001)
plt.plot(x_axis, norm.pdf(x_axis,0,1))  # mean = 0, std = 1
plt.show()

################ Run #################

### 1.3 (10 points) Implement the SMILES VAE loss function

Implement your loss function here.

In [None]:
def loss_function(recon_x, x, mu, std):
    """Compute the two terms of the SMILES VAE loss.

    Args:
        recon_x: decoder logits, (batch, smiles_len, nchar), as returned by
            MolVAE.forward
        x: integer character indices, (batch, smiles_len)
        mu: latent means, (..., z_dim)
        std: latent standard deviations, same shape as mu

    Returns:
        (BCE, KLD) scalar tensors:
        BCE -- reconstruction loss: categorical cross-entropy per character,
               averaged over every character in the batch;
        KLD -- KL(N(mu, std^2) || N(0, I)), summed over latent dimensions and
               averaged over the batch.
    """
    ################ Solution #################

    # Each position is an nchar-way classification, so the reconstruction term
    # is a categorical cross-entropy on the flattened (position, class) logits.
    BCE = F.cross_entropy(recon_x.reshape(-1, recon_x.size(-1)), x.reshape(-1))

    # Closed-form KL divergence between diagonal Gaussians; the clamp keeps
    # log() finite if std underflows to zero.
    var = std.pow(2)
    log_var = torch.log(var.clamp_min(1e-12))
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - var, dim=-1))

    ################ Solution #################
    return BCE, KLD

### 1.4 (5 points) Train your model

Simply run the following code chunks to train your model.

In [None]:
################ Run #################

def loop(model, loader, epoch, beta=0.05, evaluation=False):
    """ Train/test your VAE model for one pass over `loader`.

    Args:
        model: the MolVAE being trained or evaluated
        loader: DataLoader whose batches hold the input tensor at index 0
        epoch: epoch number, used only in the progress-bar label
        beta: weight of the KL term in the total loss
        evaluation: if True, run in eval mode with gradient tracking disabled
            and no parameter updates

    Returns:
        Mean total loss over all batches (a NumPy float, so `.item()` works).

    Relies on the module-level `device`, `optimizer` and `loss_function`.
    """
    if evaluation:
        model.eval()
        mode = "eval"
    else:
        model.train()
        mode = "train"
    batch_losses = []
    running_total = 0.0

    tqdm_data = tqdm(loader, position=0, leave=True, desc=f"{mode} (epoch #{epoch})")
    # During evaluation no autograd graph is needed, so don't build (and hold
    # memory for) one.
    with torch.set_grad_enabled(not evaluation):
        for data in tqdm_data:
            x = data[0].to(device)
            recon_batch, mu, std = model(x)
            loss_recon, loss_kl = loss_function(recon_batch, x, mu, std)
            loss = loss_recon + beta * loss_kl

            if not evaluation:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_value = loss.item()
            batch_losses.append(loss_value)
            # running sum avoids re-averaging the whole history every batch
            running_total += loss_value
            postfix = [f"recon loss={loss_recon.item():.3f}",
                       f"KL loss={loss_kl.item():.3f}",
                       f"total loss={loss_value:.3f}",
                       f"avg. loss={running_total / len(batch_losses):.3f}"]

            tqdm_data.set_postfix_str(" ".join(postfix))

    return np.array(batch_losses).mean()

################ Run #################

In [None]:
################ Run #################

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MolVAE(rnn_enc_hid_dim=367, enc_nconv=2, encoder_hid=512, z_dim=171,
               rnn_dec_hid_dim=512, dec_nconv=1, nchar=31, smiles_len=max_len)
model = model.to(device)

# load pretrained model; map_location lets weights saved on a GPU load on
# whatever device is in use here
model.load_state_dict(torch.load("./vae-050-0.06.pth", map_location=device))

################ Run #################

In [None]:
################ Run #################

optimizer = optim.Adam(model.parameters(), lr=5e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.5, patience=5)

################ Run #################

Mount your Google Drive to save your model and files (optional).

In [None]:
################# Run (optional) #################

from google.colab import drive
drive.mount("/content/drive")
mydrive = "/content/drive/MyDrive"

################ Run (optional) #################

In [None]:
################ Run #################

# Train the (pretrained) VAE; beta=0.001 weights the KL term in `loop`.
epochs = 50
for epoch in range(0, epochs):

    train_loss = loop(model, train_loader, epoch, 0.001)
    val_loss = loop(model, val_loader, epoch, 0.001,  evaluation=True)
    # ReduceLROnPlateau lowers the learning rate when val_loss stops improving
    scheduler.step(val_loss)

    # uncomment to save model (optional)
    # if epoch % 15 == 0:
    #     torch.save(model.state_dict(), f"{mydrive}/vae-{epoch:03d}-{train_loss:.2f}.pth")
    #     torch.save(optimizer.state_dict(), f"{mydrive}/optim-{epoch:03d}-{train_loss:.2f}.pth")

    # Track the lowest training loss seen so far (`loop` returns a NumPy
    # float, hence `.item()`).
    if epoch == 0:
        best_loss = train_loss.item()
    else:
        if train_loss.item() < best_loss:
            best_loss = train_loss.item()

################ Run #################

### 1.5 (20 points) Sample new molecules

Some helper functions for you.

In [None]:
################ Run #################

def index2smiles(mol_index, enc):
    """ Decode a sequence of character indices back into a SMILES string.

    Args:
        mol_index: iterable of integer character indices
        enc: fitted encoder whose `inverse_transform` maps indices to characters

    Returns:
        The concatenated characters with leading/trailing padding spaces
        removed. NOTE(review): spaces in the middle of the string are kept;
        check how the SMILES parser treats them before relying on validity.
    """
    characters = enc.inverse_transform(np.array(mol_index))
    joined = "".join(characters)
    return joined.strip(" ")

def check_smiles_valid(smiles):
    """ Return True if RDKit can parse `smiles` into a molecule, else False.
    """
    # MolFromSmiles yields None (falsy) for strings it cannot parse.
    parsed = Chem.MolFromSmiles(smiles)
    return bool(parsed)

################ Run #################

Randomly select two SMILES strings from your test data, encode them, interpolate 10 points between their latent vectors, and decode those points. Check which decoded SMILES are valid, and draw a scatter plot of the first two latent dimensions. Then visualize the molecules that are valid.

In [None]:
################ Solution #################

# select a starting and ending molecule: r.choices(..., k=1) returns a
# one-element list of indices, so (presumably, for a TensorDataset) the dataset
# returns tensors with a leading batch dimension of 1; [0] takes the input
# tensor, which is flattened and decoded back to a SMILES string.
# NOTE(review): start and end are drawn independently and may be the same molecule.
start = index2smiles(test_loader.dataset.__getitem__(r.choices(range(len(test_loader.dataset)), k=1))[0].numpy().reshape(-1), enc)
end = index2smiles(test_loader.dataset.__getitem__(r.choices(range(len(test_loader.dataset)), k=1))[0].numpy().reshape(-1), enc)
# eval mode makes MolVAE.reparametrize return mu, so encodings are deterministic
model.eval()

################ Solution #################

Produce a scatter plot with the first two dimensions of $z$ of your test molecules and newly sampled molecules in the same figure. Color differently the test points and generated points.

In [None]:
################ Solution #################



################ Solution #################

Draw different molecules you generated.

In [None]:
################ Solution #################



################ Solution #################

Why does the VAE sometimes fail to generate valid SMILES strings?

In [None]:
################ Solution #################



################ Solution #################

## Part 2: Predicting DNA binding sites with transformers

Load the ChIP-seq dataset.

In [None]:
################ Run #################

df = pd.read_csv("./dna_binding.csv")

sequences = df.seq.values
y = df.bind.values

################ Run #################

Build Datasets and DataLoaders in PyTorch (from Problem Set 2)

In [None]:
################ Run #################

def SeqEnc(sequences):
    '''
    One-hot encode DNA sequences.

    Args:
        sequences (list): DNA sequences made of the bases A, C, G and T

    Returns:
        np.array: float array of shape (N, C, 4), where N is the number of
        sequences and C the (shared) sequence length; column order is A, C, G, T.
        An unknown base raises KeyError.
    '''

    identity = np.eye(4)
    index_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

    # each base selects (a copy of) the matching row of the 4x4 identity matrix
    encoded = [np.array([identity[index_of[base]] for base in seq])
               for seq in sequences]

    return np.array(encoded)

X = SeqEnc(sequences)
print("Shape of X is {}.".format(X.shape))

################ Run #################

In [None]:
################ Run #################

# generate dataset
class SequenceDataset(Dataset):
    """ Map-style dataset pairing encoded sequences with their labels.

    Both inputs are converted to float32 PyTorch tensors; indexing returns
    an (X[index], y[index]) tuple.
    """
    def __init__(self, X, y):
        features = np.array(X)
        targets = np.array(y)
        self.X = torch.Tensor(features)  # inputs as a float tensor
        self.y = torch.Tensor(targets)   # labels as a float tensor
        self.len = len(self.X)           # number of samples

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample, target = self.X[index], self.y[index]
        return sample, target

################ Run #################

In [None]:
################ Run #################

X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.1, random_state=42)

# define dataset
train_data = SequenceDataset(X_train, y_train)
val_data = SequenceDataset(X_val, y_val)
test_data = SequenceDataset(X_test, y_test)

# train/test split
batch_size = 256
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

################ Run #################

Implement functions for training and testing (from Problem Set 2)

In [None]:
################ Run #################

def train(model, dataloader, optimizer, device):

    '''
    A function to train on the entire dataset for one epoch.

    Args:
        model (torch.nn.Module): Your sequence classifier (outputs probabilities)
        dataloader (torch.utils.data.Dataloader): DataLoader object for the train data
        optimizer (torch.optim.Optimizer): Optimizer object to interface gradient calculation and optimization
        device (str | int | torch.device): Your device

    Returns:
        float: loss averaged over all the batches

    '''

    epoch_loss = []
    model.train()

    for seq, label in dataloader:
        seq = seq.to(device)
        label = label.to(device)

        # Flatten predictions and labels to (batch,). Using squeeze() instead
        # would turn a final batch of one sample into a 0-d tensor, and
        # binary_cross_entropy rejects the resulting size mismatch.
        proba = model(seq).reshape(-1)

        loss = F.binary_cross_entropy(proba, label.reshape(-1))
        epoch_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return np.array(epoch_loss).mean()


def validate(model, dataloader, device):

    '''
    A function to validate on the validation dataset for one epoch.

    Args:
        model (torch.nn.Module): Your sequence classifier (outputs probabilities)
        dataloader (torch.utils.data.Dataloader): DataLoader object for the validation data
        device (str | int | torch.device): Your device

    Returns:
        float: loss averaged over all the batches

    '''

    val_loss = []
    model.eval()
    with torch.no_grad():
        for seq, label in dataloader:
            seq = seq.to(device)
            label = label.to(device)

            # reshape(-1) rather than squeeze() so a final batch of one sample
            # still has shape (1,) and matches its label
            proba = model(seq).reshape(-1)
            loss = F.binary_cross_entropy(proba, label.reshape(-1))

            val_loss.append(loss.item())

    return np.array(val_loss).mean()

def evaluate(model, dataloader, device):

    '''
    A function to return the classification probabilities and true labels (for evaluation).

    Args:
        model (torch.nn.Module): your sequence classifier (outputs probabilities)
        dataloader (torch.utils.data.Dataloader): DataLoader object for the data to evaluate
        device (str | int | torch.device): Your device

    Returns:
        (list, list): true labels, predicted probabilities (flat Python lists)
    '''

    pred_prob = []
    labels = []

    model.eval()
    with torch.no_grad():
        for seq, label in dataloader:
            seq = seq.to(device)

            # Forward pass
            proba = model(seq)
            # reshape(-1) (instead of squeeze) keeps a one-sample batch as a
            # list; squeeze().tolist() would return a bare float and `+=` on
            # a list would then raise TypeError.
            pred_prob += proba.reshape(-1).cpu().tolist()
            labels += label.reshape(-1).cpu().tolist()

    return labels, pred_prob

################ Run #################

### 2.1 (15 points) Implement a transformer encoder

In [None]:
################ Run #################

class PositionalEncoding(nn.Module):
    """ Defines sinusoidal positional encoding (adapted from PyTorch's
        documentation)

        Args:
            d_model: embedding dimension of the inputs
            max_len: longest sequence length supported (default 101, the
                originally hard-coded length)
            dropout: dropout probability applied after adding the encoding
    """
    def __init__(self, d_model, max_len=101, dropout=0.2):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[:, :, 0::2] = torch.sin(position * div_term)
        # for odd d_model there is one fewer cosine column than sine column
        pe[:, :, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """ x: (batch, seq_len, d_model) with seq_len <= max_len.
        """
        # Out-of-place add so the caller's tensor is left untouched; slicing
        # supports sequences shorter than max_len.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

################ Run #################

In [None]:
class TransformerSeq(nn.Module):
    """ Defines DNA sequence transformer

        Args:
            d_model: model (embedding) dimension
            nhead: number of attention heads (must divide d_model)
            num_layers: number of stacked transformer encoder layers
            positional: whether to add positional encodings to the input
    """
    def __init__(self, d_model, nhead, num_layers, positional=True):
        super(TransformerSeq, self).__init__()
        self.positional = positional

        # to prep transformer input
        self.in_full = nn.Linear(4, d_model)
        self.pe = PositionalEncoding(d_model)

        ################ Solution #################

        # transformer encoder; batch_first=True because inputs are laid out
        # as (batch, seq_len, features)
        self.layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                batch_first=True)
        self.encoder = nn.TransformerEncoder(self.layer, num_layers=num_layers)

        ################ Solution #################

        # to output probability
        self.out_full = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """ x: one-hot DNA, (batch, seq_len, 4). Returns probabilities, (batch, 1).
        """
        # apply embedding and positional encoding
        if self.positional:
            x_in = self.pe(self.in_full(x))
        else:
            x_in = self.in_full(x)

        ################ Solution #################

        # apply transformer encoder and pool output
        x_out = self.encoder(x_in)    # (batch, seq_len, d_model)
        pooled = x_out.mean(dim=1)    # average over sequence positions

        # get probability
        prob = self.sigmoid(self.out_full(pooled))  # (batch, 1)

        ################ Solution #################

        return prob

### 2.2 (15 points) Explore how positional encodings improve classification

Try training the transformer with/without positional encodings! Run both models for 100 epochs.

In [None]:
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model with positional encodings off
model_off = TransformerSeq(16, 4, 2, positional=False).to(device)
optimizer_off = torch.optim.Adam(model_off.parameters(), lr=0.001)

################ Solution #################

# model with positional encodings on: same architecture and optimizer settings,
# so the positional encoding is the only difference between the two runs
model_on = TransformerSeq(16, 4, 2, positional=True).to(device)
optimizer_on = torch.optim.Adam(model_on.parameters(), lr=0.001)

################ Solution #################

# use tqdm for progress bar
val_loss_off, val_loss_on = [], []
train_loss_off, train_loss_on = [], []
for epoch in tqdm(range(100), desc="Progress"):
    # compute training/validation loss for off model
    train_loss_off.append(train(model_off, train_loader, optimizer_off, device=device))
    val_loss_off.append(validate(model_off, val_loader, device=device))

    ################ Solution #################

    # compute training/validation loss for on model
    train_loss_on.append(train(model_on, train_loader, optimizer_on, device=device))
    val_loss_on.append(validate(model_on, val_loader, device=device))

    ################ Solution #################

Plot train and validation loss functions with and without the positional encodings. The first plot has been made for you.

In [None]:
# sharey=True puts both panels on the same loss scale so they can be compared directly
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

axes[0].plot(val_loss_off, label="Validation Loss")
axes[0].plot(train_loss_off, label="Training Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].legend(loc="upper right")
axes[0].set_title("Positional encoding OFF")

################ Solution #################

axes[1].plot(val_loss_on, label="Validation Loss")
axes[1].plot(train_loss_on, label="Training Loss")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Loss")
axes[1].legend(loc="upper right")
axes[1].set_title("Positional encoding ON")

################ Solution #################

fig.tight_layout()

From what you know about transformers, why are positional encodings necessary here? If the two plots were identical, what would that tell us?

In [None]:
################ Solution #################



################ Solution #################