#  <center> Problem Set 5 </center>

<center> 3.C01/3.C51, 7.C01/7.C51, 10.C01/10.C51, 20.C01/20.C51 </center>

## Part 1: Predicting molecular properties with Graph Convolutional Nets 

### 1.1 (5 points) Install and try out RDKit

This is a hack to install RDKit, without needing to install conda (which might take minutes). If you have anaconda installed, you can install RDKit from anaconda.

In [None]:
# Download a prebuilt RDKit conda package and unpack only its `lib` directory.
url = 'https://anaconda.org/rdkit/rdkit/2018.09.1.0/download/linux-64/rdkit-2018.09.1.0-py36h71b666b_1.tar.bz2'
!curl -L $url | tar xj lib
# NOTE(review): the package comes from a python3.6 tree (py36 build) but is moved
# into the python3.7 dist-packages directory -- confirm this matches the runtime's Python.
!mv lib/python3.6/site-packages/rdkit /usr/local/lib/python3.7/dist-packages/

# Move the bundled shared libraries into the system library directory.
x86 = '/usr/lib/x86_64-linux-gnu'
!mv lib/*.so.* $x86/
# Expose the py36-suffixed Boost.Python library under the unsuffixed name as well.
!ln -s $x86/libboost_python3-py36.so.1.65.1 $x86/libboost_python3.so.1.65.1

In [None]:
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors,Crippen
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import pandas as pd
import sys
import torch 
from tqdm import tqdm
import itertools
import matplotlib.pyplot as plt

from rdkit import RDLogger   
RDLogger.DisableLog('rdApp.*') # turn off RDKit warning message 

Optional: mount your Google Drive to save your model and files.

In [None]:
from google.colab import drive
drive.mount('/content/drive')
mydrive = '/content/drive/MyDrive'

Example: make a Mol object.

In [None]:
dopamine_mol = Chem.MolFromSmiles("C1=CC(=C(C=C1CCN)O)O") # Dopamine 
caffeine_mol = Chem.MolFromSmiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") # Caffeine 

Arrange molecules in a grid image.

In [None]:
# Arrange molecules in a grid image
Draw.MolsToGridImage([dopamine_mol, caffeine_mol])

Use RDKit to visualize molecule line drawings

In [None]:
################ Code #################




### 1.2 (10 points) Construct Molecular Graph Datasets and DataLoaders

In [None]:
! wget https://raw.githubusercontent.com/vikram-sundar/ML4MolEng_Spring2022/master/psets/ps4/data/qm9.csv

A SMILES to graph conversion function.

In [None]:
def smiles2graph(smiles):
    '''
    Transform a SMILES string into a list of atomic numbers and an edge array

    Args:
        smiles (str): SMILES string of one molecule

    Returns:
        z (np.array), A (np.array): atomic numbers, shape (n_atoms,), and edge
            array, shape (2, n_edges). A[0] holds the row index and A[1] the
            column index of every non-zero entry of the adjacency matrix.

    Raises:
        ValueError: if RDKit cannot parse `smiles`
    '''

    mol = Chem.MolFromSmiles(smiles)  # no hydrogen
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here with
        # a clear message instead of an AttributeError on the next line.
        raise ValueError("Could not parse SMILES string: {!r}".format(smiles))

    z = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()])
    # nonzero() yields a (rows, cols) tuple; stacking it makes A a real
    # (2, n_edges) array, as documented and as permute_graph (a.shape) expects.
    A = np.stack(Chem.GetAdjacencyMatrix(mol).nonzero())

    return z, A

Read in the DataFrame, shuffle its rows, and store its properties as lists.

In [None]:
import torch
from sklearn.utils import shuffle

df = pd.read_csv("qm9.csv")
df = shuffle(df).reset_index()

################ Code #################

AtomicNum_list = []
Edge_list = []
y_list = []
Natom_list = []



A GraphDataset class for you to store graphs in PyTorch.

In [None]:
class GraphDataset(torch.utils.data.Dataset):
    '''
    Dataset of molecular graphs stored as four parallel, per-molecule lists.

    Args:
        AtomicNum_list (list): atomic numbers of each molecule
        Edge_list (list): edge index array of each molecule
        Natom_list (list of int): number of atoms of each molecule
        y_list (list): properties to predict for each molecule
    '''

    def __init__(self, AtomicNum_list, Edge_list, Natom_list, y_list):
        self.AtomicNum_list = AtomicNum_list
        self.Edge_list = Edge_list
        self.Natom_list = Natom_list
        self.y_list = y_list

    def __len__(self):
        # Natom_list carries one entry per molecule
        return len(self.Natom_list)

    def __getitem__(self, idx):
        '''Return (AtomicNum, Edge, Natom, y) of molecule `idx`; all but Natom are converted to tensors.'''
        return (
            torch.LongTensor(self.AtomicNum_list[idx]),
            torch.LongTensor(self.Edge_list[idx]),
            self.Natom_list[idx],
            torch.Tensor(self.y_list[idx]),
        )

Split your dataset into train, validation, and test and define the GraphDataset class for each.

In [None]:
################ Code #################
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

train_dataset = # fill this in
val_dataset = # fill this in
test_dataset = # fill this in

A graph collation function to batch multiple graphs into one batch.

In [None]:
def collate_graphs(batch):
    '''Merge several graphs into one batched (disconnected) graph

    The node indices in each graph's edge array are shifted by the total number
    of atoms in the graphs that precede it, so they keep pointing at the right
    entries of the concatenated atomic-number tensor.

    Args:

        batch (tuple): tuples of AtomicNum, Edge, Natom and y obtained from GraphDataset.__getitem__()

    Return
        (tuple): concatenated AtomicNum, Edge concatenated along dim 1 with
            shifted indices, list of per-graph Natom, concatenated y

    '''
    atom_counts = [graph[2] for graph in batch]
    # Exclusive prefix sum: the offset of graph k is the atom total of graphs 0..k-1
    offsets = np.cumsum([0] + atom_counts)[:-1]

    z_parts, edge_parts, y_parts = [], [], []
    for (z, a, _, y), offset in zip(batch, offsets):
        z_parts.append(z)
        edge_parts.append(a + offset)
        y_parts.append(y)

    return (
        torch.cat(z_parts),
        torch.cat(edge_parts, dim=1),
        atom_counts,
        torch.cat(y_parts),
    )

An example use of collate_graphs.

In [None]:
# Define graph 1 
AtomicNum1 = torch.LongTensor([6, 6, 7])
Edge1 = torch.LongTensor([[0, 2, 2, 1], 
                       [2, 0, 1, 2]])
Natom1 = 3
y1 =  torch.Tensor([74.18])

# Define graph 2 
AtomicNum2 = torch.LongTensor([6, 6, 8])
Edge2 = torch.LongTensor([[0, 2, 2, 1], 
                       [2, 0, 1, 2]])
Natom2 = 3
y2 = torch.Tensor([64.32])

graph1 = (AtomicNum1, Edge1, Natom1, y1)
graph2 = (AtomicNum2, Edge2, Natom2, y2)

collate_graphs((graph1, graph2))  

Defining the train, validation, and test DataLoaders with the above functions.

In [None]:
from torch.utils.data import DataLoader
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset,
                          batch_size=512,
                          collate_fn=collate_graphs, shuffle=True)

# Evaluation loaders keep a fixed order: the reported loss is an unweighted mean
# of per-batch losses, so reshuffling would change it from pass to pass whenever
# the last batch is smaller than the others.
val_loader = DataLoader(val_dataset,
                        batch_size=512,
                        collate_fn=collate_graphs, shuffle=False)

test_loader = DataLoader(test_dataset,
                         batch_size=512,
                         collate_fn=collate_graphs, shuffle=False)

### 1.3 (20 points) Complete the definition of a GNN

The scatter_add function for use in your node updates.

In [None]:
from itertools import repeat
def scatter_add(src, index, dim_size, dim=-1, fill_value=0):

    '''
    Accumulate the entries of src that share the same index value along axis dim.

    The result has the shape of src, except that axis dim has length dim_size.
    It starts out filled with fill_value and every src entry is added at the
    position given by index.
    '''

    # View index as all-singleton except along dim, then broadcast it to src's shape
    view_shape = [1] * src.dim()
    view_shape[dim] = src.size(dim)
    expanded_index = index.view(view_shape).expand_as(src)

    # Turn a negative dim into its non-negative equivalent
    axis = dim if dim >= 0 else dim + src.dim()

    result_shape = list(src.size())
    result_shape[axis] = dim_size

    result = src.new_full(result_shape, fill_value)
    return result.scatter_add_(axis, expanded_index, src)

Example usage of scatter_add().

In [None]:
# Say you have a graph with 4 nodes, and there is an edge list that describes their connectivity.

Edge = torch.LongTensor([[0, 0, 1, 3], # index for i 
                         [1, 2, 2, 0]]) # index for j 

# It means that the 0th node is connected to the 1st node and the 2nd node; the 1st node is connected to the 2nd node; and the 3rd node is connected to the 0th node.
# For now, let us assume the connections are directed, i.e. the 0th node is connected to the 1st node, but the 1st node is not connected to the 0th node.
# We want to pass messages from the nodes in the first row to the nodes in the second row of Edge.

# And for each edge, we have a message we want to broadcast from i to j.
message_i2j = torch.Tensor([1000., 100., 10., 1.])

# We can use scatter_add() function to aggregate these pairwise messages onto each node. 

node_message = scatter_add(src=message_i2j, # message array for all the directed edge 
            index=Edge[1], # index to all the jth node to which you want to pass your message 
            dim=0,         # feature dimension you want to sum over 
            dim_size=4     # there are 4 nodes 
            ) 

print(node_message)

# Now you can look at your results, you can see the messages are assigned from message_i2j to all the jth nodes you specified

# see the graphical representation here: "https://github.com/vikram-sundar/ML4MolEng_Spring2022/blob/master/psets/ps4/scatter_add_demo.png"

In [None]:
# If you want your graph to be undirected, i.e. the ith node is connected to the jth node and vice versa, you can perform the summation in both directions like this:
node_message = scatter_add(src=message_i2j, index=Edge[1], dim=0, dim_size=4) +  scatter_add(src=message_i2j, index=Edge[0], dim=0, dim_size=4)

print(node_message)

Example usage of torch.split().

In [None]:
tensor = torch.ones((5, 2))
splits_idx = [2, 3] # list of integers 
print( torch.split(tensor, splits_idx) ) 

# you have two tensors with size (2,2) and (3,2) respectively 
for split in torch.split(tensor, splits_idx):
    print(split.shape)
    
# And you can sum the split arrays separately and stack them together
print( torch.stack([split.sum(0) for split in torch.split(tensor, splits_idx)], dim=0) )

Your GNN class.


In [None]:
from torch import nn
from torch.nn import ModuleDict

class GNN(torch.nn.Module):
    '''
        A GNN model

        Holds an atom embedding, `n_convs` convolution blocks (each a pair of
        two-layer MLPs named 'update_mlp' and 'message_mlp', mapping n_embed
        features to n_embed features) and a two-layer readout MLP that maps
        n_embed features to a single value.

        Args:
            n_convs (int): number of convolution blocks
            n_embed (int): width of the atom embedding and of every MLP
    '''
    def __init__(self, n_convs=3, n_embed=64):
        super(GNN, self).__init__()

        # Embedding table with 100 rows, looked up by atomic number in forward()
        self.atom_embed = nn.Embedding(100, n_embed)
        # Declare MLPs in a ModuleList
        self.convolutions = nn.ModuleList(
            [
                ModuleDict({
                    'update_mlp': nn.Sequential(nn.Linear(n_embed, n_embed),
                                                nn.ReLU(),
                                                nn.Linear(n_embed, n_embed)),
                    'message_mlp': nn.Sequential(nn.Linear(n_embed, n_embed),
                                                 nn.ReLU(),
                                                 nn.Linear(n_embed, n_embed))
                })
                for _ in range(n_convs)
            ]
            )
        # Declare readout layers
        self.readout = nn.Sequential(nn.Linear(n_embed, n_embed), nn.ReLU(), nn.Linear(n_embed, 1))

    def forward(self, AtomicNum, Edge, Natom):
        '''
            Forward pass (exercise scaffold, to be completed).

            Args:
                AtomicNum: atomic numbers, used as indices into the embedding table
                Edge: edge index tensor -- presumably (2, n_edges) as produced by
                    collate_graphs; TODO confirm
                Natom: number of atoms per graph -- presumably the list returned
                    by collate_graphs; TODO confirm
        '''
        ################ Code #################

        # Parametrize embedding
        h = self.atom_embed(AtomicNum) #eqn. 1




        ################ Code #################
        # NOTE(review): `output` is never assigned in this scaffold; the code filled
        # in above must define it, otherwise this line raises NameError.
        return output

### 1.4 (5 points) Verify that your GNN preserves permutational invariance

Run this cell as is to show that your GNN respects permutational invariance.

In [None]:
def permute_graph(z, a, perm):
    '''
        Permute the order of nodes in a molecular graph.

        Node perm[k] of the input graph becomes node k of the output graph, and
        every node index in the edge list is relabelled accordingly.

        Args:
            z(np.array): atomic number array
            a(np.array): edge index pairs
            perm(sequence of int): a permutation of range(len(z))

        Return:
            (np.array, np.array): permuted atomic number, and edge list
    '''

    z = np.array(z)
    perm = np.array(perm)
    a = np.asarray(a).astype(int)
    assert len(perm) == len(z)

    z_perm = z[perm]

    # argsort inverts the permutation: new_index[old] is the position k at which
    # perm[k] == old. Indexing it with `a` relabels every edge entry at once,
    # replacing the per-entry np.where scan (whose assignment of a 1-element
    # array to a scalar slot is deprecated in recent NumPy releases).
    new_index = np.argsort(perm)
    a_perm = new_index[a].astype(int)
    return z_perm, a_perm

# node input
z_orig = np.array([6, 6, 8, 7])
# edge input
a_orig = np.array([[0, 0, 1, 2, 3, 0], [1, 2, 0, 0, 0, 3]] )

permutation = itertools.permutations([0, 1 ,2, 3])
# Fall back to the CPU when no GPU is present so this check runs on any machine
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = GNN(n_convs=4, n_embed=128).to(device)
model.eval()

# Forward-only check: no gradients are needed
with torch.no_grad():
    for perm in permutation:
        z_perm, a_perm = permute_graph(z_orig, a_orig, perm)

        z = torch.LongTensor(z_perm).to(device)
        a = torch.LongTensor(a_perm).to(device)
        N = [z.shape[0]]  # a single graph, so one atom count

        output = model(z, a, N).item()

        # All printed outputs should agree if the model is permutation invariant
        print("model output: {:.5f} for perumutation: {}".format(output, perm))

### 1.5  (10 points) Train and test your GNN

The optimizer and scheduler setup.

In [None]:
from torch import optim

optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=50, verbose=True)

A combined train/validation loop, with progress bar.

In [None]:
def loop(model, loader, epoch, evaluation=False):
    '''
    Run one pass over `loader`, either training the model or only evaluating it.

    Uses the notebook-level globals `device` and, when training, `optimizer`.

    Args:
        model: network called as model(AtomicNumber, Edge, Natom)
        loader: iterable yielding (AtomicNumber, Edge, Natom, y) batches
        epoch (int): epoch number, only shown in the progress-bar label
        evaluation (bool): if True, run in eval mode with no gradient tracking
            and no parameter updates

    Returns:
        mean of the per-batch MSE losses
    '''

    if evaluation:
        model.eval()
        mode = "eval"
    else:
        model.train()
        mode = 'train'
    batch_losses = []

    # Define tqdm progress bar
    tqdm_data = tqdm(loader, position=0, leave=True, desc='{} (epoch #{})'.format(mode, epoch))

    for AtomicNumber, Edge, Natom, y in tqdm_data:

        AtomicNumber = AtomicNumber.to(device)
        Edge = Edge.to(device)
        y = y.to(device)

        # Skip autograd bookkeeping while evaluating
        with torch.set_grad_enabled(not evaluation):
            pred = model(AtomicNumber, Edge, Natom)
            # Align pred with y before subtracting: a (B, 1) prediction minus a
            # (B,) target would silently broadcast to (B, B) and corrupt the loss.
            loss = (pred.reshape(y.shape) - y).pow(2).mean() # MSE loss

        if not evaluation:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_losses.append(loss.item())

        postfix = ['batch loss={:.3f}'.format(loss.item()) ,
                   'avg. loss={:.3f}'.format(np.array(batch_losses).mean())]

        tqdm_data.set_postfix_str(' '.join(postfix))

    return np.array(batch_losses).mean()

Run this cell to train your model for 500 epochs.

In [None]:
for epoch in range(500):
    train_loss = loop(model, train_loader, epoch)
    val_loss = loop(model, val_loader, epoch, evaluation=True)
    # Feed the validation loss to the ReduceLROnPlateau scheduler; without this
    # call the scheduler is never consulted and the learning rate never drops.
    scheduler.step(val_loss)

    # save model (`mydrive` is set in the optional Google Drive cell)
    if epoch % 20 == 0:
        torch.save(model.state_dict(), "{}/gcn_model_{}.pt".format(mydrive, epoch))

Show us a scatter plot of the training and test data, with MSEs labeled.

In [None]:
################ Code #################

## Part 2: Variational auto-encoders for SMILES strings

In [None]:
# Get data 
! wget https://raw.githubusercontent.com/YitongTseo/ML4MolEng_Spring2023/master/psets/ps5/data/zinc_50k.csv
    
# Get pretrained model
! wget -O vae_checkpoint.pth https://raw.githubusercontent.com/YitongTseo/ML4MolEng_Spring2023/master/psets/ps5/pretrained_checkpoints/vae-050-0.06.pth

### 2.1 (5 points) One-hot encode SMILES strings into padded numerical vectors

In [None]:
from sklearn import preprocessing

# Character list for SMILES string
moses_charset = ['2', 'o', 'C', 'I', 'O', 'H', 'n', 'N', '=', '+', '#', '-', 'c',
                 'B', 'l', '7', 'r', 'S', 's', '4', '6', '[', '5', ']', 'F', '3', 
                 'P', '(', ')', '1', ' ']

# Define encoder 
enc = preprocessing.LabelEncoder().fit(moses_charset)

# Read data 
df = pd.read_csv("./zinc_50k.csv")


Encode SMILES strings into padded categorical vectors.

In [None]:
################ Code #################




Make train/validation/test Datasets and DataLoaders.

In [None]:
################ Code #################





### 2.2 (14 points) Implement the Reparametrization Trick for VAE

In [None]:
# Molecular VAE model 

class MolVAE(nn.Module):
    def __init__(self,  rnn_enc_hid_dim, enc_nconv,
                         encoder_hid, z_dim,
                         rnn_dec_hid_dim, dec_nconv, smiles_len, nchar
                         ):
        '''
            SMILES VAE model

                rnn_enc_hid_dim: hidden dimension for the GRU encoder
                    (also used as the character embedding size)
                enc_nconv: number of recurrent layers for the GRU encoder
                encoder_hid: dimension of GRU encoder readout
                z_dim: number of latent variables
                rnn_dec_hid_dim: hidden dimension for the GRU decoder
                dec_nconv: number of recurrent layers for the GRU decoder
                smiles_len: total length of padded SMILES string
                nchar: number of possible characters

        '''

        super(MolVAE, self).__init__()

        self.smiles_len = smiles_len
        self.nchar = nchar
        # Embedding layer
        self.embed = nn.Embedding(self.nchar, rnn_enc_hid_dim)
        # Encoding GRU
        self.rnn_enc = nn.GRU(rnn_enc_hid_dim, rnn_enc_hid_dim, enc_nconv, batch_first=True)
        # MLP to transform hidden output from Encoding GRU
        self.mlp0 = nn.Linear(rnn_enc_hid_dim, encoder_hid)
        # Network to parametrize mu
        self.mu_network = nn.Linear(encoder_hid, z_dim)
        # Network to parametrize log variance
        self.logvar_network = nn.Linear(encoder_hid, z_dim)
        # Decoding GRU
        self.rnn_dec = nn.GRU(z_dim, rnn_dec_hid_dim, dec_nconv, batch_first=True)
        # Output SMILES characters
        self.readout = nn.Linear(rnn_dec_hid_dim, self.nchar)

    def encode(self, x):
        '''Output mean and log variance of the encoded SMILES

        x is the embedded SMILES batch; only the final hidden state of the last
        GRU layer is used to parametrize the latent distribution.
        '''
        output, hn = self.rnn_enc(x)
        h = torch.nn.functional.relu(self.mlp0(hn[-1]))
        return self.mu_network(h), self.logvar_network(h)

    def get_std(self, logvar):
        '''Transform log variance to standard deviation'''
        # std = sqrt(var) = exp(0.5 * log(var)); always positive
        return torch.exp(0.5 * logvar)

    def reparametrize(self, mu, std):
        '''The reparametrization trick

        In training mode, sample z = mu + std * eps with eps ~ N(0, 1), which
        keeps the sample differentiable with respect to mu and std. In eval
        mode, return the mean so that encoding is deterministic.
        '''
        if self.training:
            eps = torch.randn_like(std)
            z = mu + eps * std
            return z
        else:
            return mu

    def decode(self, z):
        '''Decoder to reconstruct latent variable back to SMILES

        Returns unnormalized character scores of shape
        (batch, smiles_len, nchar).
        '''
        # Feed the same latent vector to the decoder GRU at every sequence position
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.smiles_len, 1)
        out, h = self.rnn_dec(z)
        # Flatten (batch, position) so the readout is applied per character
        out_reshape = out.contiguous().view(-1, out.size(-1))

        y0 = self.readout(out_reshape)
        y = y0.contiguous().view(out.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        '''Encode character indices x, sample a latent z and decode it.

        Returns:
            (smiles_recon, mu, std): reconstructed character scores, latent
            mean and latent standard deviation
        '''
        x_embed = self.embed(x) # Get SMILES embedding
        mu, logvar = self.encode(x_embed) # Encoding SMILES to latent representations
        std = self.get_std(logvar) # Transform log variance to std.
        z = self.reparametrize(mu, std) # Reparametrization trick
        smiles_recon = self.decode(z)  # Reconstruct SMILES string
        return smiles_recon, mu, std

Test your model by comparing your sampling with N(0, 1).

In [None]:
# Define your model 
model = MolVAE(rnn_enc_hid_dim=256, enc_nconv=1, 
                     encoder_hid=256, z_dim=128, rnn_dec_hid_dim=512,
                    dec_nconv=3, nchar=31, smiles_len=max_len)

# Compare your sampling with N(0, 1)

import matplotlib.pyplot as plt
from scipy.stats import norm

sample = model.reparametrize(torch.zeros(1000), torch.ones(1000))
plt.hist(sample.detach().cpu().numpy(), density=True)

# Plot between -7 and 7 with .001 steps.
x_axis = np.arange(-7, 7, 0.001)
plt.plot(x_axis, norm.pdf(x_axis,0,1)) # Mean = 0, SD = 1.
plt.show()


### 2.3 (14 points) Implement the SMILES VAE loss function

Implement your loss function here.

In [None]:
def loss_function(recon_x, x, mu, std):
    '''
    Compute the two VAE loss terms as differentiable scalar tensors.

    Args:
        recon_x (torch.Tensor): unnormalized character scores; the last
            dimension indexes characters
        x (torch.LongTensor): target character indices, one per position of
            recon_x
        mu (torch.Tensor): latent means
        std (torch.Tensor): latent standard deviations (must be positive)

    Returns:
        (torch.Tensor, torch.Tensor): reconstruction loss and KL loss
    '''
    # NOTE(review): the reductions (mean over characters, sum over latent
    # dimensions then mean over the batch) are a choice; confirm against the
    # course handout.
    n_char = recon_x.size(-1)
    # reconstruction loss: cross-entropy between predicted and true characters
    L_recon = torch.nn.functional.cross_entropy(recon_x.reshape(-1, n_char),
                                                x.reshape(-1))
    # KL loss: closed-form KL( N(mu, std^2) || N(0, 1) )
    kl_per_dim = -0.5 * (1 + 2 * torch.log(std) - mu.pow(2) - std.pow(2))
    L_kl = kl_per_dim.sum(dim=-1).mean()
    return L_recon, L_kl

### 2.4 (2 points) Train your model

Run the following cells to train your model.

In [None]:


def loop(model, loader, epoch, beta=0.05, evaluation=False):
    '''
        Train/test your VAE model for one pass over `loader`.

        Uses the notebook-level globals `device`, `loss_function` and, when
        training, `optimizer`.

        Args:
            model: VAE returning (reconstruction, mu, std) when called on a batch
            loader: iterable of batches whose first element is the input tensor
            epoch (int): epoch number, only shown in the progress-bar label
            beta (float): weight of the KL term in the total loss
            evaluation (bool): if True, run in eval mode with no gradient
                tracking and no parameter updates

        Returns:
            mean of the per-batch total losses
    '''

    if evaluation:
        model.eval()
        mode = "eval"
    else:
        model.train()
        mode = 'train'
    batch_losses = []

    tqdm_data = tqdm(loader, position=0, leave=True, desc='{} (epoch #{})'.format(mode, epoch))
    for data in tqdm_data:

        x = data[0].to(device)

        # Skip autograd bookkeeping while evaluating
        with torch.set_grad_enabled(not evaluation):
            recon_batch, mu, std = model(x)
            loss_recon, loss_kl = loss_function(recon_batch, x, mu, std)
            loss = loss_recon + beta * loss_kl

        if not evaluation:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_losses.append(loss.item())

        postfix = ['recon loss={:.3f}'.format(loss_recon.item()) ,
                   'KL loss={:.3f}'.format(loss_kl.item()) ,
                   'total loss={:.3f}'.format(loss.item()) ,
                   'avg. loss={:.3f}'.format(np.array(batch_losses).mean())]

        tqdm_data.set_postfix_str(' '.join(postfix))

    return np.array(batch_losses).mean()

In [None]:
# Use the first GPU when one is present, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = MolVAE(rnn_enc_hid_dim=367, enc_nconv=2,
                     encoder_hid=512, z_dim=171, rnn_dec_hid_dim=512,
                    dec_nconv=1, nchar=31, smiles_len=max_len)

model = model.to(device)

# load pretrained model; map_location remaps the stored tensors onto `device`,
# so a checkpoint saved from a GPU also loads on a CPU-only machine
model.load_state_dict(torch.load("./vae_checkpoint.pth", map_location=device))

In [None]:
optimizer = optim.Adam(model.parameters(),lr=5e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5, verbose=True)

In [None]:
epochs = 50

for epoch in range(0, epochs):
    
    train_loss = loop(model, train_loader, epoch, 0.001)
    val_loss = loop(model, val_loader, epoch, 0.001,  evaluation=True)
    scheduler.step(val_loss)
    
    # optional: save model 
#     if epoch % 15 == 0:
#         torch.save(model.state_dict(),
#                 './{}/vae-{:03d}-{:.2f}.pth'.format(mydrive, epoch, train_loss))
        
#         torch.save(optimizer.state_dict(),
#             './{}/optim-{:03d}-{:.2f}.pth'.format(mydrive, epoch, train_loss))

    if epoch == 0:
        best_loss = train_loss.item()
    else:
        if train_loss.item() < best_loss:
            best_loss = train_loss.item()


### 2.5 (10 points) Sample new molecules

Some helper functions for you.

In [None]:
def index2smiles(mol_index, enc):
    '''Decode an array of character indices into a SMILES string

    The indices are mapped back to characters with enc.inverse_transform and
    joined; spaces at either end of the result are removed.
    '''
    indices = np.array(mol_index)
    decoded = ''.join(enc.inverse_transform(indices))
    # only leading/trailing spaces are trimmed
    return decoded.strip(" ")

def check_smiles_valid(smiles):
    '''Return True if Chem.MolFromSmiles yields a truthy molecule for `smiles`, else False'''
    # The truthiness of the parse result decides validity
    return bool(Chem.MolFromSmiles(smiles))

Randomly select two SMILES in your test data, interpolate 10 points between them, and decode those points. Test them for accuracy and draw the scatter plot of the lower 2 dimensions. Then visualize any molecules that worked.

In [None]:
################ Code #################
# select a starting and ending molecule
import random  # `random` is used below but is not imported by the cells above

# random.choices(..., k=1) yields a one-element index list; the first field of the
# selected dataset item is flattened into a 1-D index array and decoded to SMILES
start = index2smiles(test_loader.dataset.__getitem__(random.choices(range(len(test_loader.dataset)), k=1))[0].numpy().reshape(-1), enc)
end = index2smiles(test_loader.dataset.__getitem__(random.choices(range(len(test_loader.dataset)), k=1))[0].numpy().reshape(-1), enc)

# eval mode makes the encoder return the latent mean instead of a random sample
model.eval()



Why does the VAE sometimes fail to generate valid SMILES strings?