# To do: 
- build Ball Tree for cosine similarity
- implement Bayesian optimisation 

In [None]:
# GOOGLE COLAB: Install RDKit. Takes 2-3 minutes 
# !wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
# !chmod +x Miniconda3-latest-Linux-x86_64.sh
# !time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
# !time conda install -q -y -c conda-forge python=3.7
# !time conda install -q -y -c conda-forge rdkit

In [None]:
# GOOGLE COLAB
# from google.colab import drive
# drive.mount('/content/gdrive', force_remount=True)

In [3]:
# GOOGLE COLAB
# !cp '/content/gdrive/My Drive/rxn_ebm/USPTO_50k_Schneider/clean_rxn_50k_nomap_noreagent.pickle' '/content/'

# import sys
# sys.path.append('/usr/local/lib/python3.7/site-packages/') 

In [1]:
import sys
import os

import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG=True
from rdkit.Chem import Descriptors
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import rdChemReactions
from rdkit.Chem import rdqueries # faster than iterating atoms https://sourceforge.net/p/rdkit/mailman/message/34538007/ 
from rdkit.Chem.rdchem import Atom
from rdkit import DataStructs
import numpy as np

from itertools import chain
import random

from tqdm import tqdm
import csv
import re 
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

### References for papers using fingerprints as inputs
https://pubs.acs.org/doi/pdf/10.1021/acscentsci.6b00219 and https://github.com/jnwei/neural_reaction_fingerprint
https://pubs.acs.org/doi/pdf/10.1021/ci5006614 (IPython notebooks in Supplementary Info)

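### Why are MorganFP values not fixed (results not static)? Magnitudes can be large, even with int8. Do we need some form of scaling? Need to look into the dtype as well.
- Likely cause of the non-static values: `create_rxn_MorganFP` initialised its accumulator arrays (`diff_fp`, `rcts_fp`) with `np.empty` (uninitialised memory) instead of `np.zeros`. The fingerprint sums therefore start from leftover garbage; compare the huge leading values in the `dtype='int'` output below.
- `int` gives alarmingly large numbers
- Need to use int8 with a large enough fp_size (e.g. 16384, which was used by <u>Reaction Condition Recommender ACS Cent. Sci. 2018, 4, 1465−1476</u>)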

In [None]:
# print(create_rxn_MorganFP(sample_rxns[0], fp_size=16384)[:100])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [None]:
# print(create_rxn_MorganFP(sample_rxns[0], fp_size=16384, dtype='int')[:100])

[196411728       370 325594256       370         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
         0         0         0         0         0         0         0
      

### utils

In [2]:
import shutil
import torch
import torch.nn as nn

def get_activation_function(activation: str) -> nn.Module:
    """
    Builds a new activation-function module from its (case-sensitive) name.
    Supported names:
    * :code:`ReLU`
    * :code:`LeakyReLU` (negative slope 0.1)
    * :code:`PReLU`
    * :code:`tanh`
    * :code:`SELU`
    * :code:`ELU`
    :param activation: The name of the activation function.
    :return: A freshly constructed activation module.
    :raises ValueError: If the name is not one of the supported activations.
    """
    # ordered (name, factory) pairs; each factory builds a new module instance per call
    known_activations = (
        ('ReLU', nn.ReLU),
        ('LeakyReLU', lambda: nn.LeakyReLU(0.1)),
        ('PReLU', nn.PReLU),
        ('tanh', nn.Tanh),
        ('SELU', nn.SELU),
        ('ELU', nn.ELU),
    )
    for name, build in known_activations:
        if activation == name:
            return build()
    raise ValueError(f'Activation "{activation}" not supported.')
    
def initialize_weights(model: nn.Module) -> None:
    """
    Re-initialises every parameter of a model in place.
    1-D parameters (e.g. Linear biases) are set to 0; every other parameter
    gets Xavier/Glorot normal initialisation.
    :param model: A PyTorch model.
    """
    for weight in model.parameters():
        if weight.dim() != 1:
            nn.init.xavier_normal_(weight)
            continue
        nn.init.constant_(weight, 0)
            
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    '''
    Saves a training checkpoint with torch.save and optionally keeps a copy as the best model.

    state: object passed to torch.save (e.g. a dict with epoch / state_dict / optimizer state)
    is_best: if True, the saved checkpoint is also copied to best_filename
    filename: path the checkpoint is written to
    best_filename: path of the "best model" copy; defaults to 'model_best.pth.tar'
        (relative to the current working directory, as before), but can now be placed
        next to the checkpoints instead of always landing in the cwd
    '''
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)

### model

In [3]:
import torch
import torch.nn as nn

class FF_ebm(nn.Module):
    '''
    Feed-forward energy-based model: maps each reaction fingerprint to an energy score.

    trainargs: dictionary containing hyperparameters to be optimised. Keys read here:
        'model': 'FF_sep' (input length rctfp_size + prodfp_size) or
                 'FF_diff' (input length fp_size, falling back to prodfp_size if fp_size is absent)
        'hidden_sizes': list of hidden layer widths, e.g. [1024, 512, 256]; [] gives one linear layer
        'output_size', 'dropout', 'activation'

    Raises ValueError for an unsupported 'model' type (or, via get_activation_function,
    an unsupported activation name).

    To do: bayesian optimisation
    '''
    def __init__(self, trainargs):
        super(FF_ebm, self).__init__()
        self.output_size = trainargs['output_size']
        self.num_layers = len(trainargs['hidden_sizes']) + 1

        if trainargs['model'] == 'FF_sep':
            self.input_dim = trainargs['rctfp_size'] + trainargs['prodfp_size']
        elif trainargs['model'] == 'FF_diff':
            # a difference FP has one length; configs may only define prodfp_size
            if 'fp_size' in trainargs:
                self.input_dim = trainargs['fp_size']
            else:
                self.input_dim = trainargs['prodfp_size']
        else:
            # fail here rather than with an AttributeError on input_dim later
            raise ValueError("Model type '{}' not supported.".format(trainargs['model']))

        self.create_ffn(trainargs)
        initialize_weights(self)

    def create_ffn(self, trainargs):
        '''
        Creates the feed-forward network from the trainargs dict:
        [dropout, linear] for the first layer, then [activation, dropout, linear] for each later layer.
        Every layer gets its own activation/dropout module, so parametrised activations
        (e.g. PReLU) are not silently shared between layers.
        '''
        activation_name = trainargs['activation']
        # validate the name up front, even when there are no hidden layers to use it
        get_activation_function(activation_name)

        layer_sizes = [self.input_dim] + list(trainargs['hidden_sizes']) + [self.output_size]
        ffn = []
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            if i > 0:
                ffn.append(get_activation_function(activation_name))
            ffn.extend([
                nn.Dropout(trainargs['dropout']),
                nn.Linear(in_size, out_size),
            ])

        self.ffn = nn.Sequential(*ffn)

    def forward(self, batch):
        '''
        Runs FF_ebm on input

        batch: a N x K x input_dim tensor of N training samples, where each sample contains
        a positive rxn fingerprint at index 0 along dim 1, and K-1 negative rxn fingerprints
        after it, supplied by DataLoader on custom ReactionDataset
        Returns a N x K x output_size tensor of energy scores.
        '''
        energy_scores = self.ffn(batch)
        return energy_scores

### train

In [4]:
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import time

def train_one(model, batch, optimizer=None, val=False):
    '''
    Runs the model on one batch and computes the softmax contrastive loss; unless val is True,
    also takes one optimiser step (one batch, not one epoch).

    batch: tensor of size N x K x input_dim; index 0 along dim 1 is the positive reaction,
           indices 1..K-1 are the negatives (as produced by ReactionDataset)
    optimizer: torch optimiser; only used (and required) when val is False
    val: if True, only computes the loss (no backward pass, no parameter update)

    Returns the batch-mean loss as a detached 0-dim CPU tensor.
    Raises ValueError if val is False and no optimizer is given.

    TO DO: learning rate scheduler + logger
    '''
    if not val and optimizer is None:
        raise ValueError('optimizer is required when val=False')

    scores = model(batch).squeeze(dim=-1)  # scores: N x K x 1 --> N x K after squeezing

    # log_softmax is the numerically stable form of log(softmax(.)): no epsilon needed, and the
    # loss keeps growing (instead of saturating) when the positive's probability underflows
    log_probs = scores.log_softmax(dim=1)  # size N x K
    loss = -log_probs[:, 0].mean()  # positives are the 0-th index of each sample; N values --> 1

    if not val:
        optimizer.zero_grad()
        loss.backward()
        # if args.grad_clip: # gradient clipping if needed
        #     nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

    return loss.detach().cpu()
    
def train(model, trainargs):
    '''
    Trains model for up to trainargs['epochs'] epochs, with optional early stopping and checkpointing.
    Currently supports feed-forward networks:
        FF_diff: takes as input a difference FP of fp_size & fp_radius
        FF_sep: takes as input a concatenation of [reactants FP, product FP]

    trainargs: dict of params. Keys read here: model_seed, random_seed, optimizer, learning_rate,
        path_to_pickle, batch_size, epochs, early_stop, min_delta, patience, checkpoint,
        checkpoint_path, model, expt_name (plus whatever ReactionDataset reads)

    Returns a stats dict containing trainargs, per-epoch 'mean_train_loss' and 'mean_val_loss',
    'min_val_loss', 'train_time' (seconds) and, if training stopped early, 'early_stop_epoch'.
    The stats dict is also saved under checkpoint_path.
    '''
    start = time.time()
    stats = {'trainargs': trainargs} # to store training statistics
    torch.manual_seed(trainargs['model_seed'])
    random.seed(trainargs['random_seed'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # batches are moved to `device` below, so the model must live there too

    optimizer = trainargs['optimizer'](model.parameters(), lr=trainargs['learning_rate'])

    # checkpoint_path is used as a filename prefix; make sure its directory exists before saving
    checkpoint_dir = os.path.dirname(trainargs['checkpoint_path'])
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    train_dataset = ReactionDataset(trainargs['path_to_pickle'], 'train', trainargs)
    train_loader = DataLoader(train_dataset, trainargs['batch_size'], shuffle=True)
    mean_train_loss = []  # one entry per epoch

    # NOTE: ReactionDataset samples negatives randomly on every access, so the val loss
    # also carries sampling noise from epoch to epoch
    val_dataset = ReactionDataset(trainargs['path_to_pickle'], 'valid', trainargs)
    val_loader = DataLoader(val_dataset, 2 * trainargs['batch_size'], shuffle=False)
    min_val_loss = 1e9
    mean_val_loss = []  # one entry per epoch
    wait = 0  # epochs since the last val-loss improvement of at least min_delta

    for epoch in range(trainargs['epochs']):
        model.train() # set model to training mode
        train_loss = []
        for batch in tqdm(train_loader):
            batch = batch.to(device)
            train_loss.append(train_one(model, batch, optimizer, val=False))
        epoch_train_loss = np.mean(train_loss)
        mean_train_loss.append(epoch_train_loss)

        model.eval() # validation mode
        val_loss = []
        with torch.no_grad():
            for batch in tqdm(val_loader):
                batch = batch.to(device)
                val_loss.append(train_one(model, batch, optimizer, val=True))
        epoch_val_loss = np.mean(val_loss)
        mean_val_loss.append(epoch_val_loss)  # recorded before a possible early stop

        if trainargs['early_stop'] and min_val_loss - epoch_val_loss < trainargs['min_delta']:
            if trainargs['patience'] <= wait:
                print('Early stopped at the end of epoch: ', epoch)
                stats['early_stop_epoch'] = epoch
                break
            else:
                wait += 1
                print('Decrease in val loss < min_delta, patience count: ', wait)
        else:
            wait = 0
            min_val_loss = min(min_val_loss, epoch_val_loss)

        if trainargs['checkpoint']: # adapted from moco: main_moco.py
            save_checkpoint({
                    'epoch': epoch + 1,
                    'model': trainargs['model'],
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                    'stats' : stats,
                }, is_best=False, filename=trainargs['checkpoint_path']+'{}_{}_checkpoint_{:04d}.pth.tar'.format(trainargs['model'], trainargs['expt_name'], epoch))

        print('Epoch: {}, train_loss: {}, val_loss: {}'.format(epoch,
                                         np.around(epoch_train_loss, decimals=4),
                                         np.around(epoch_val_loss, decimals=4)))

    stats['mean_train_loss'] = mean_train_loss
    stats['mean_val_loss'] = mean_val_loss
    stats['min_val_loss'] = min_val_loss
    stats['train_time'] = time.time() - start
    # save training stats
    torch.save(stats, trainargs['checkpoint_path']+'{}_{}_stats.pkl'.format(trainargs['model'], trainargs['expt_name']))
    return stats
              
def test(model, stats, trainargs):
    '''
    Evaluates the model on the test set and records the results in stats.

    model: trained model (evaluated on whatever device its parameters live on)
    stats: dict returned by train(); 'test_loss' (list of per-batch losses) and
           'mean_test_loss' are added to it
    trainargs: dict of params (path_to_pickle, batch_size, random_seed, checkpoint_path, model, expt_name)

    Returns the updated stats dict, which is also saved under checkpoint_path.
    '''
    # use the model's own device instead of relying on a notebook-global `device`
    device = next(model.parameters()).device
    # negatives are sampled with `random`; seed so the test loss is reproducible
    random.seed(trainargs['random_seed'])

    test_dataset = ReactionDataset(trainargs['path_to_pickle'], 'test', trainargs)
    test_loader = DataLoader(test_dataset, 2 * trainargs['batch_size'], shuffle=False)

    model.eval()  # disable dropout during evaluation
    test_loss = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            batch = batch.to(device)
            # no optimiser is needed in val mode
            test_loss.append(train_one(model, batch, None, val=True))

    stats['test_loss'] = test_loss
    stats['mean_test_loss'] = np.mean(test_loss)
    print('train_time: {}'.format(stats.get('train_time')))
    print('test_loss: {:.4f}'.format(stats['mean_test_loss']))
    # overrides training stats w/ training + test stats
    torch.save(stats, trainargs['checkpoint_path']+'{}_{}_stats.pkl'.format(trainargs['model'], trainargs['expt_name']))
    return stats

### data

In [5]:
# https://github.com/pytorch/tutorials/blob/master/beginner_source/data_loading_tutorial.py
import torch
from torch.utils.data import Dataset
import random
import pickle

import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdChemReactions
from rdkit import DataStructs
import numpy as np

def create_rxn_MorganFP(rxn_smi, fp_type='diff', radius=2, rctfp_size=16384, prodfp_size=16384, useChirality=True, dtype='int8', fp_size=None):
    '''
    Builds a reaction Morgan fingerprint (numpy array) from a reaction SMILES 'reactants>...>product'.

    fp_type: 'diff' or 'sep',
    'diff' (difference):
        Creates reaction MorganFP following Schneider et al in J. Chem. Inf. Model. 2015, 55, 1, 39–53
        reactionFP = productFP - sum(reactantFPs), of length fp_size (defaults to prodfp_size)

    'sep' (separate):
        Creates separate reactantsFP (sum over reactants, length rctfp_size) and productFP
        (length prodfp_size) following Gao et al in ACS Cent. Sci. 2018, 4, 11, 1465–1476;
        returned concatenated as [reactantsFP, productFP]

    radius, useChirality: passed to AllChem.GetMorganFingerprintAsBitVect
    dtype: numpy dtype of the result (sums/differences are computed in this dtype)

    Returns the fingerprint array, or None (after printing the error) if an RDKit fingerprint
    cannot be built. Raises ValueError for an unrecognised fp_type.
    '''
    if fp_type == 'diff':
        # reactant FPs are subtracted element-wise from the product FP, so all share one length
        if fp_size is None:
            fp_size = prodfp_size
        rctfp_size = prodfp_size = fp_size
    elif fp_type != 'sep':
        raise ValueError("fp_type must be 'diff' or 'sep', got {!r}".format(fp_type))

    def _smi_to_fp(smi, n_bits):
        '''Morgan bit-vector fingerprint of one molecule SMILES, as a numpy array of length n_bits.'''
        mol = Chem.MolFromSmiles(smi)
        fp_bit = AllChem.GetMorganFingerprintAsBitVect(
                        mol=mol, radius=radius, nBits=n_bits, useChirality=useChirality)
        fp = np.zeros(n_bits, dtype=dtype)
        DataStructs.ConvertToNumpyArray(fp_bit, fp)
        return fp

    # create product FP
    prod_smi = rxn_smi.split('>')[-1]
    try:
        prod_fp = _smi_to_fp(prod_smi, prodfp_size)
    except Exception as e:
        print("Cannot build product fp due to {}".format(e))
        return

    # sum reactant FPs; the accumulator must start from zeros -- np.empty would add
    # uninitialised memory to the fingerprint (non-deterministic, huge values)
    rcts_fp = np.zeros(rctfp_size, dtype=dtype)
    for rct_smi in rxn_smi.split('>')[0].split('.'):
        try:
            rcts_fp += _smi_to_fp(rct_smi, rctfp_size)
        except Exception as e:
            print("Cannot build reactant fp due to {}".format(e))
            return

    if fp_type == 'diff':
        return prod_fp - rcts_fp
    return np.concatenate([rcts_fp, prod_fp])

    
class ReactionDataset(Dataset):
    '''
    The Dataset class ReactionDataset prepares training samples of length K:
    [pos_rxn, neg_rxn_1, ..., neg_rxn_K-1], ... where K-1 = num_neg
    Negatives are re-sampled randomly (via the `random` module) on every __getitem__ call.

    TO DO: can this be further optimised? Augmentation is the clear bottleneck during training
    '''
    def __init__(self, path_to_pickle, key, trainargs):
        '''
        pickle is dict w/ keys 'train', 'valid', 'test' each storing a list of rxn_smiles (str)
        IMPORTANT: molAtomMapNumbers have been cleared during data pre-processing
        Raises ValueError if trainargs['num_neg'] is not positive.
        '''
        # feels like loading the entire pickle is not feasible when the dataset gets larger
        # is there a more memory-efficient way to do this?
        # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally generated files
        with open(path_to_pickle, 'rb') as handle:
            self.rxn_smiles = pickle.load(handle)[key]
        self.fp_radius = trainargs['fp_radius']
        self.rctfp_size = trainargs['rctfp_size']
        self.prodfp_size = trainargs['prodfp_size']
        self.fp_type = trainargs['fp_type']
        self.num_neg = trainargs['num_neg']
        # validated once here instead of on every __getitem__ call
        if self.num_neg <= 0:
            raise ValueError('num_neg must be positive!')

    def random_sample_negative(self, pos_rxn_smi, pos_rxn_idx):
        '''
        Generates 1 negative reaction given a positive reaction SMILES by swapping either one
        randomly chosen reactant or the product with one taken from another reaction in the dataset.
        Returns neg_rxn_smi (str) in the form 'reactants>>product'.
        Raises ValueError if the dataset has fewer than 2 reactions (no other reaction to swap from).
        NOTE: still loops forever if no other reaction offers a different molecule.
        '''
        rcts_smi = pos_rxn_smi.split('>')[0].split('.')
        prod_smi = pos_rxn_smi.split('>')[-1]
        num_rxns = len(self.rxn_smiles)
        if num_rxns < 2:
            raise ValueError('Need at least 2 reactions to sample negatives')

        # random.randrange(n) picks an index without building an n-element array per attempt
        rct_or_prod = random.choice([0, 1])
        if rct_or_prod == 0: # randomly change one of the reactant(s)
            orig_idx = random.randrange(len(rcts_smi)) # randomly choose 1 reactant to be replaced
            while True: # searches randomly to find a different rct molecule to swap with
                rdm_rxn_idx = random.randrange(num_rxns) # randomly choose 1 rxn
                if rdm_rxn_idx == pos_rxn_idx: continue # don't choose the original rxn

                new_rcts_smi = self.rxn_smiles[rdm_rxn_idx].split('>')[0].split('.')
                new_rct_smi = new_rcts_smi[random.randrange(len(new_rcts_smi))]
                if new_rct_smi != rcts_smi[orig_idx]:
                    rcts_smi[orig_idx] = new_rct_smi
                    break

        else: # randomly change the product
            while True: # searches randomly to find a different prod molecule to swap with
                rdm_rxn_idx = random.randrange(num_rxns)
                if rdm_rxn_idx == pos_rxn_idx: continue # don't choose the original rxn

                new_prod_smi = self.rxn_smiles[rdm_rxn_idx].split('>')[-1]
                if new_prod_smi != prod_smi:
                    prod_smi = new_prod_smi
                    break

        return '{}>>{}'.format('.'.join(rcts_smi), prod_smi)

    def __getitem__(self, idx):
        '''
        Returns 1 training sample in the form [pos_rxn, neg_rxn_1, ..., neg_rxn_K-1]
        as a float tensor of shape K x fp_length (K = 1 + num_neg)
        num_neg: a hyperparameter to be tuned
        Raises ValueError if a fingerprint cannot be built for one of the reactions.

        MAY DO: use while loops to retrieve MorganFP in case some of them fail (but this dataset is cleaned already)
        '''
        if torch.is_tensor(idx): # may not be needed, taken from data loading tutorial
            idx = idx.tolist()

        pos_rxn_smi = self.rxn_smiles[idx]
        rxn_smis = [pos_rxn_smi] + [self.random_sample_negative(pos_rxn_smi, idx)
                                    for _ in range(self.num_neg)]
        rxn_fps = [create_rxn_MorganFP(rxn_smi, radius=self.fp_radius,
                                       rctfp_size=self.rctfp_size, prodfp_size=self.prodfp_size, fp_type=self.fp_type)
                   for rxn_smi in rxn_smis]
        for rxn_smi, rxn_fp in zip(rxn_smis, rxn_fps):
            if rxn_fp is None:
                raise ValueError('Could not build fingerprint for reaction: {}'.format(rxn_smi))

        # stacking into one ndarray first avoids the slow list-of-ndarrays tensor constructor
        return torch.as_tensor(np.stack(rxn_fps), dtype=torch.get_default_dtype())

    def __len__(self):
        return len(self.rxn_smiles)

### Preliminary checks

In [11]:
# Smoke-test configuration: 1-hidden-layer FF_sep model on separate reactant/product FPs
trainargs = {
    'model': 'FF_sep',
    'hidden_sizes': [512],
    'output_size': 1,
    'dropout': 0.5, # adapted from Reaction Condition Recommender

    'batch_size': 256,
    'activation': 'ELU', # trying ELU for its differentiability everywhere (vs ReLU which is not differentiable at x=0)
    'optimizer': torch.optim.Adam,
    'learning_rate': 1e-6, # to try: integrate w/ fast.ai lr_finder & lr_schedulers
    'epochs': 50,
    'early_stop': True,
    'min_delta': 1e-5,
    'patience': 5,

    'checkpoint': True,
    'model_seed': 1337,
    'random_seed': 0, # affects neg rxn sampling since it is random

    'rctfp_size': 16384,
    'prodfp_size': 16384,
    'fp_radius': 3,
    'fp_type': 'sep',

    'num_neg': 9, # to be tuned, 9 seems to be superior to 5 (overfitting occurred quickly)

    'path_to_pickle': os.getcwd()+'/clean_rxn_50k_nomap_noreagent.pickle',
    'checkpoint_path': os.getcwd()+'/checkpoints/',
    'expt_name': '1layer_rad3_ELU'
}

# seed both RNGs so model init and the sampled negatives are reproducible in this check
torch.manual_seed(trainargs['model_seed'])
random.seed(trainargs['random_seed'])

# reuse the configured path rather than repeating the literal
train_dataset = ReactionDataset(trainargs['path_to_pickle'],
                               'train',
                               trainargs)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FF_ebm(trainargs)
model.to(device)
train_loader = DataLoader(train_dataset, trainargs['batch_size'])

In [21]:
# Smoke test: push the first 3 batches through the (untrained) model and print the shapes
for i_batch, sample_batched in zip(range(3), train_loader):
    sample_batched = sample_batched.to(device)
    i_scores = model(sample_batched)
    print(i_batch, sample_batched.shape, i_scores.shape)

0 torch.Size([256, 6, 32768]) torch.Size([256, 6, 1])
1 torch.Size([256, 6, 32768]) torch.Size([256, 6, 1])
2 torch.Size([256, 6, 32768]) torch.Size([256, 6, 1])


In [23]:
scores = i_scores.squeeze(dim=-1)  # N x K x 1 --> N x K
softmax = nn.Softmax(dim=1)
probs = softmax(scores) # size N x K

# positives are the 0-th index of each sample; log_softmax avoids log(0) = -inf
# when a positive's probability underflows
loss = -scores.log_softmax(dim=1)[:, 0].mean() # [:, 0] has size N --> mean to 1 value
loss, probs[:, 0].shape

(tensor(3.3954, device='cuda:0', grad_fn=<NegBackward>), torch.Size([256]))

In [24]:
loss.detach().cpu()

tensor(3.3954)

### Train and Test

In [13]:
# Full training configuration (checkpointing off)
trainargs = dict(
    # architecture
    model='FF_sep',
    hidden_sizes=[512],
    output_size=1,
    dropout=0.5,  # adapted from Reaction Condition Recommender

    # optimisation
    batch_size=256,
    activation='ELU',  # ELU is differentiable everywhere (ReLU is not at x=0)
    optimizer=torch.optim.Adam,
    learning_rate=1e-6,  # to try: fast.ai lr_finder & lr_schedulers
    epochs=50,
    early_stop=True,
    min_delta=1e-5,
    patience=5,

    # bookkeeping / reproducibility
    checkpoint=False,
    model_seed=1337,  # seeds the pytorch random generator
    random_seed=0,  # seeds negative-reaction sampling

    # fingerprints
    rctfp_size=16384,
    prodfp_size=16384,
    fp_radius=3,
    fp_type='sep',

    num_neg=9,  # to be tuned; 9 seemed better than 5 (which overfit quickly)

    # paths
    path_to_pickle=os.getcwd() + '/clean_rxn_50k_nomap_noreagent.pickle',
    checkpoint_path=os.getcwd() + '/checkpoints/',
    expt_name='1layer_rad3_ELU',
)

In [14]:
# init fingerprint-based feedforward EBM model 
model = FF_ebm(trainargs)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

FF_ebm(
  (ffn): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=32768, out_features=512, bias=True)
    (2): ELU(alpha=1.0)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=512, out_features=1, bias=True)
  )
)

In [None]:
# train (with early stopping per trainargs); returns a dict of loss statistics
stats = train(model=model, trainargs=trainargs)

In [None]:
# evaluate on the test split; adds the test-loss entries to stats
stats = test(model=model, stats=stats, trainargs=trainargs)