In [1]:
import torch 
import torch.nn as nn
import numpy as np
from rdkit import Chem
from rdkit import rdBase
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random
import re
from torch.utils.data import DataLoader,Dataset
from scoring_functions import get_scoring_function
import torch.nn.functional as F
print(torch.__version__)

2.0.0


# data process

In [2]:
def replace_halogen(string):
    """Regex to replace Br and Cl with single letters"""
    br = re.compile('Br')
    cl = re.compile('Cl')
    string = br.sub('R', string)
    string = cl.sub('L', string)

    return string

class Vocabulary:
    def __init__(self,file,max_len=140):
        self.file = file
        self.max_len = max_len
        self.vocab_size,self.vocab,self.reversed_vocab = self.get_voc(file)
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_len = max_len
    
    def get_voc(self,file):
        with open(file, 'r') as f:
            chars = f.read().split()
        chars +=  ['EOS', 'GO']
        chars.sort()
        vocab_size = len(chars)
        vocab = dict(zip(chars, range(len(chars))))
        reversed_vocab = {v: k for k, v in vocab.items()}
        return vocab_size,vocab,reversed_vocab
        
    def tokenize(self, smiles):
        """Takes a SMILES and return a list of characters/tokens"""
        regex = '(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)
            else:
                chars = [unit for unit in char]
                [tokenized.append(unit) for unit in chars]
        tokenized.append('EOS')
        return tokenized
    
    def encode(self,char_list):
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix
    
    def decode(self,matrix):
        chars = []
        for i in matrix:
            i = i.item()
            if i == self.vocab['EOS']: break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles
                          
    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)
    
class MolData(Dataset):
    """Custom PyTorch Dataset that takes a file containing SMILES.

        Args:
                fname : path to a file containing \n separated SMILES.
                voc   : a Vocabulary instance

        Returns:
                A custom PyTorch dataset for training the Prior.
    """
    def __init__(self, fname, voc):
        self.voc = voc
        self.smiles = []
        with open(fname, 'r') as f:
            for line in f:
                self.smiles.append(line.split()[0])

    def __getitem__(self, i):
        mol = self.smiles[i]
        tokenized = self.voc.tokenize(mol)
        encoded = self.voc.encode(tokenized)
        return encoded

    def __len__(self):
        return len(self.smiles)

    def __str__(self):
        return "Dataset containing {} structures.".format(len(self))

    @classmethod
    def collate_fn(cls, arr):
        """Function to take a list of encoded sequences and turn them into a batch"""
        max_length = max([seq.size for seq in arr])
        collated_arr = np.zeros((len(arr), max_length))
        for i, seq in enumerate(arr):
            collated_arr[i, :seq.size] = seq
        return torch.from_numpy(collated_arr)

# define model

In [3]:
class GRULayer(nn.Module):
    def __init__(self, voc_size):
        super().__init__()
        self.embedding = nn.Embedding(voc_size,128)
        self.gru1 = nn.GRUCell(128,512)
        self.gru2 = nn.GRUCell(512,512)
        self.gru3 = nn.GRUCell(512,512)
        self.fcn = nn.Linear(512,voc_size)
        
    def forward(self, x, h):
        x = x.to(device)
        h = h.to(device)
        h_out = torch.zeros(h.shape).to(device)
        x = self.embedding(x).to(device)
        x = h_out[0] = self.gru1(x,h[0])
        x = h_out[1] = self.gru2(x,h[1])
        x = h_out[2] = self.gru3(x,h[2])
        x = self.fcn(x)
        return x,h_out

class RNN():
    def __init__(self,voc,batch_size,max_len=140):
        global device
        self.rnn = GRULayer(voc.vocab_size).to(device)
        self.voc = voc
        self.batch_size = batch_size
        self.max_len = max_len
        self.loss_fn = nn.BCELoss(reduction='none')
    
    def bce_loss(self,prob,target):
        onehot_target = torch.zeros(prob.shape).to(device)
        onehot_target.scatter_(1, target.contiguous().view(-1, 1).data, 1.0)
        onehot_target = onehot_target.to(device)
        return self.loss_fn(prob,onehot_target)
    
    def forward(self,target):
        self.rnn.train()
        outs = []
        loss = 0
        target = target.to(device)
        h = torch.zeros(3, self.batch_size, 512).float()
        start_token = torch.zeros(self.batch_size,1).long().to(device)
        start_token[:] = self.voc.vocab['GO']
        x = torch.cat((start_token, target[:, :-1]), 1)
        for step in range(target.shape[1]):
            logits,h = self.rnn(x[:, step],h)
            prob = F.softmax(logits,dim=1)
            out = torch.argmax(prob,dim=1)
            outs.append(out)
            loss += self.bce_loss(prob, target[:, step])
        outs = torch.transpose(torch.stack(outs),1,0)
        acc = torch.sum(outs==target.to(device))/outs.numel()
        return loss,acc
    
    def sample(self,batch_size):
        output = []
        loss = 0
        h = torch.zeros(3, batch_size, 512).float()
        start_token = torch.zeros(batch_size).long()
        start_token[:] = self.voc.vocab['GO']
        x = start_token
        finished = torch.zeros(batch_size).byte().to(device)
        with torch.no_grad():
            self.rnn.eval()
            for step in range(self.max_len):
                logits,h = self.rnn(x,h)
                prob = F.softmax(logits,dim=1)
                x = torch.multinomial(prob, num_samples=1).view(-1)
                output.append(x.view(-1, 1))
                loss += self.bce_loss(prob, x.to(device))
                EOS_sampled = (x == self.voc.vocab['EOS']).byte()
                finished = torch.ge(finished + EOS_sampled, 1)
                if torch.prod(finished) == 1: break
                
        return torch.cat(output, 1),loss

In [4]:
num_epoch = 5
batch_size = 256
lr = 2e-3
device = "cuda" if torch.cuda.is_available() else 'cpu'

In [5]:
voc = Vocabulary("data/Voc")
moldata = MolData("data/ChEMBL_filtered", voc)
data = DataLoader(moldata, batch_size=batch_size, shuffle=True, drop_last=True,collate_fn=MolData.collate_fn)
model = RNN(voc,batch_size)
optimizer = torch.optim.Adam(model.rnn.parameters(), lr = lr)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lambda x: 1 - x/(len(data)*num_epoch),last_epoch=-1)

In [6]:
for epoch in range(num_epoch):
    bar = tqdm(enumerate(data), total=len(data))
    bar.set_description(f'[ Epoch: {epoch+1}]')
    for step,batch in bar:
        optimizer.zero_grad()
        seqs = batch.long()
        loss,acc = model.forward(seqs)
        loss = loss.mean()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        bar.set_postfix({"loss":loss.item(),'acc':acc.item()})
        
        if (step+1) % 200 ==0:
            seqs,_ = model.sample(batch_size)
            tqdm.write(f'-------{step+1}--------')
            valid = 0
            for i,seq in enumerate(seqs):
                smile = voc.decode(seq)
                smile = voc.decode(seq)
                if Chem.MolFromSmiles(smile):
                    valid += 1
                if i < 5:
                    tqdm.write(smile)
            tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
            torch.save(model.rnn.state_dict(), "data/Prior.ckpt")
        torch.save(model.rnn.state_dict(), "data/Prior.ckpt")

  0%|          | 0/4607 [00:00<?, ?it/s]

-------200--------
COc1ccc([O-])cc1CNC(=O)C(Cc1ccccc1)c1ccsc1
COc1cccc(C(=O)NC(COc2cncn2)c1CCN(Cc1ccc(F)cc1)NC2=CC(=O)c1ccccc1
COc1cccc(NC(=O)COCC2=Nc3c(O)nc(-n5nc(C(=O)N3CCOCC55)c3)n2C)n1
CCOC(=O)CCOC(=O)C1(N1=CCC(=O)c2cccc(C(=O)NCc3ccccc3)c(=O)c2cc(O)ccc21
Cn1c(NCCNn2cccc2Nc2nc3ccccc3c4c(c2)NCC1CN3

 6.6% valid SMILES
-------400--------
COC(=O)C1=CC(=O)C(=Cc2ccc(O)cc2)N1Cc1cccs1
Cc1ccc(NC(=O)C=Cc2ccc3c(cc2OC)OO4)cc1
Cc1cccnc1Nc1ccc(-n2c(C#N)c3c(N)nccc32)cc1
COC(=O)c1cccc(-c2cc(CS(C)(=O)=O)ccc2N2CCC(O)CC2)c1
CCNC(=O)CSCC(NC(=O)CNC(=O)c1c2ccccc2Cl)cc1C(=O)OCC(Oc1ccccc1)C(C)C

45.7% valid SMILES
-------600--------
Cc1ccc(CNN=c2nn3n(C)c4ccccc4c3nn2)cc1
COC1CC(=O)OC(c2ccc(OC(=O)O)cc2)CC1Cc1ccncc1
CCOc1cc(CNC(=S)NC#N)cc(Br)c1O
O=C1NC(c2ccc(Cl)cc2)N1CCOc1nccnc1Sc1cccc([S+]([O-])(F)F)c1
Cc1cccc(C2(Cn3cnc(NC(C)(C)C)n3)CN(C)CC(C)NC(C)C)N2S(=O)(=O)c2cccc(N)c2)c1

48.8% valid SMILES
-------800--------
c1ccc2nnc(C#CC3=NCCc4c[nH]c(N)n4)cc3c2(O)C1CC1
CC(=O)NCc1ccnc(N=C2SC(CN(CC(=O)c3ccccc3)N3CCC(CO

  0%|          | 0/4607 [00:00<?, ?it/s]

-------200--------
CCOc1ccc(CCNCC(O)COc2ccccc2OC)cc1
Cc1ccc(-c2nnc(O)c3c2C(C)OC3(C)C)cc1
O=C1C2CC(OC3OC(CO)C(O)C(O)C3O)C3OC(C)(C)OC3C2C(C)(O)C12
NC(CCC(=O)NO)C(=O)Nc1cccc(CNS(=O)(=O)c2ccc(C(N)=NO)cc2)c1
CC(C)NC(=O)C(=O)Nc1cccnc1

82.8% valid SMILES
-------400--------
COc1cc(F)c(N(C)CCCN2CCN(S(C)(=O)=O)CC2)cc1C(=O)c1ccccc1
CCN(CC)C(=O)Cn1nc(Oc2ccccc2)c2ccccc2c1=O
O=[N+]([O-])c1ccccc1CNc1ccc(Cn2cccc2)cc1Br
Cc1cccc(-n2cnc(-c3ccccc3)c2C=CC(=O)c2ccccc2[N+](=O)[O-])c1
Cc1ccc(-c2nc3ccccc3[nH]2)cc1

84.4% valid SMILES
-------600--------
CC(N=C(S)Nc1ccc(CCN2CCOCC2)cc1)c1cc(C(F)(F)F)cc(C(F)(F)F)c1
CCOCCCNC(=O)CNC(=O)NCC1CNc2c(O)nc(C)nc2O1
Cc1ccccc1N1CCN(C(=O)CN2c3ccccc3C(=O)N2C)C(c2ccc(OCC(=O)O)cc2)C1
O=C(CC(c1ccccc1)c1ccccc1)Nc1ccc(C(F)(F)F)cc1
COCOc1ccc(-c2cc3cc(O)ccc3[nH]2)cc1

85.2% valid SMILES
-------800--------
COc1cccc(C2C3Cc4ccc(OC)cc4C32CC(C)(O)CC2(O)C#N)c1
CC(Nc1cc(NCCCO)nc(-c2cccnc2)n1)C1CCC(C)CC1
CCCc1cc2nc(N)nc(N)c2nc1C=CC=C(C)C
COc1cc(-c2nc(CSCCc3ccccc3)c(C)o2)cc(OC)c1OC
Cc1cccc(O

  0%|          | 0/4607 [00:00<?, ?it/s]

-------200--------
CCCc1cc(=O)oc2cc(OCC(=O)NC3CC3)ccc12
Cc1cc(C)n(CC2CCC2)c1C
N#Cc1cc(C(=O)c2ccnc3ccccc23)c(Cl)cc1F
CCOC(=O)C(C)(C)NC(=O)c1ccccc1NC(=O)Cc1ccc(OC)cc1
O=C(NC(=S)Nc1ccc(O)nc1)CC(c1ccccc1)c1ccccc1

87.9% valid SMILES
-------400--------
O=C(NN=Cc1ccc([N+](=O)[O-])s1)Nc1ccccc1C(F)(F)F
COc1ccc(N2CCN(C(S)=Nc3ccc(Br)cc3)CC2)c(OC)c1
Cc1cc(C)cc(N2CCN(c3ncnc4c(S(C)(=O)=O)cccc34)CC2)c1C
CCOC(=O)C1(CC)CCN(Cc2ccc(Cl)s2)CC1
O=C(Nc1nccs1)C(=O)c1ccc(Oc2cccc3cccnc23)nc1

84.8% valid SMILES
-------600--------
O=C(O)CCNCc1ccc2[nH]c(O)nc2c1
O=c1[nH]c(=O)n(CCC#Cc2cccc3ccccc23)nc1-c1ccc2oc3ccccc3c2c1
COc1ccc(Cn2ccccc2=NNC(=N)N)cc1
COc1ccc(N2C(=N)c3ccccc3C2N)cc1
C=CCON1C(=O)C(=Cc2ccc(OC(=O)c3cccc(Br)o3)cc2)C=C1O

87.1% valid SMILES
-------800--------
CC(C(=O)NCc1ccco1)=NOC(C)CBr
CCS(=O)(=O)NC1CC2CN(C(=O)c3cccc(F)c3)CC2N(C(=O)NCc2ccccc2)C1
Cn1nnc(CNC(=O)c2ccoc2Cn2cnnn2)n1
CCc1cc(-c2cc(C)c3onc(-c4cnn(CC(=O)NCC)c4)c3c2)cc(N(C)Cc2ccc3ccnn3c2)n1
CC(=O)N1CCCn2nc(CN3CCN(C(C)=O)CC3)nc21

84.4% valid SM

  0%|          | 0/4607 [00:00<?, ?it/s]

-------200--------
O=C(Nc1cccc(OCCn2cccn2)c1)C1CCc2ccccc2C1
COC(=O)c1ccc2c(c1)C(C)=CC(C)(C)N2
O=C1NC(=O)C(=Cc2ccc(OCc3ccccc3Cl)cc2)C(=O)N1
CC(C)C(=O)N1CCc2nc(COc3ccc(CC4COC(c5ccccc5)4CCO)cc3)sc2C1
CC(C)(C)OC(=O)N1CC(CNCc2ccccn2)NCC1Cc1ccccc1

87.9% valid SMILES
-------400--------
N#Cc1c(-c2ccccc2)cn(-c2ccc(N)cc2)c(=O)c1Cl
CSc1nc2ccccc2s1
c1ccc(COC2C3CCC(C3)N2Cc2ccccc2F)cc1
O=C1c2c(ccc(-c3ccccc3)n2)CC(c2ccc(O)cc2)N1c1ccc(-c2cc[n+]([O-])cc2)cc1
C#CCOC(=O)Nc1cccc(-c2ccccn2)c1

86.7% valid SMILES
-------600--------
Oc1nc2ccccc2n1-c1nc(Nc2cc3c(cn2)[nH]c2ccccc23)n[nH]1
CC(=O)Oc1cccc(NC(=O)c2ccc3c(c2)OCO3)c1
c1ccc(Cc2nc3cccnc3nc2Cl)cc1
COc1ccccc1N(CC(O)Cn1ccnc1)c1ccc(Cl)cc1
Cn1nc(C(=O)c2ccco2)c2ccccc2c1=O

88.7% valid SMILES
-------800--------
CCCN(Cc1ccccc1)Cc1csc(Br)c1
CN(C)C1CCc2ccc3[nH]cc(C(=O)c4ccc5ccccc5c4)c3c2C1
CC(C)(C)OC(=O)N1CSCC1C(=O)NC(COCc1ccccc1)C(=O)O
CCN(CC)CCNC(=O)C1CCN(c2cc3c(cc2Cl)NC(=O)C(=Cc2cccn2C)N(C(=O)Nc2cc(C)ccc2Cl)C(C)C1=O)C(C)C
CON=C(C#N)C(=O)NC1C(C)OC(COC(C)=O)OC1C

  0%|          | 0/4607 [00:00<?, ?it/s]

-------200--------
O=C(O)C1=C(C(=O)Nc2ccc(-c3ccccc3S(=O)(=O)NC(=O)c3ccc(Cl)cc3)cc2)C(c2ccc(Cl)cc2Cl)NC(=O)N1
Cc1cc2n[nH]c(=O)n2c2cnn(-c3ccccc3)c12
NC1CCN(Cc2coc(-c3nonc3S(N)(=O)=O)n2)CC1
O=C(Nc1cccc(-c2ccc3nnc(C4CCC4)n3c2)c1)c1ccncc1
CCC1CCc2cc(F)cc(C(=O)NCCC(C)C)c2C1

92.6% valid SMILES
-------400--------
CC[N+](CC)(CC)CC(CN(C)C)[NH+](C)CCCc1ccccc1
Cc1ccc(C(=O)CSc2nnc(NC(=O)Nc3ccc(C)cc3C)s2)cc1
O=C(O)C(OCCO)=C1C(=O)NN=C1SCC(=O)N1c1ccccc1[N+](=O)[O-]
CC(C)(C)c1cc(NC(=O)N2CCOCC2)ccc1-c1nnc(-c2ccccc2-c2ccc(Cl)cc2)o1
CSc1ccc(CC(=O)N2CC(=O)Nc3ccccc32)cc1

91.0% valid SMILES
-------600--------
CCOC(=O)C1CCCN(C(=S)SC)C1
C=CCN1CC2CCC1C(C(=O)OC)(OC(=O)CC(C)O)C2
Cc1nc2ccc(Cl)cc2c(-c2ccccc2)c1CN
COc1ccc(CN2CCN(CCCn3cnc4ccccc43)CC2)cc1
C=CCOc1ccc(C2C(C(=O)c3ccc4c(c3)CC(C)O4)=C(O)C(=O)N2CCN2CCOCC2)cc1Br

93.0% valid SMILES
-------800--------
Cc1[nH]c2ncn(C)c(=O)c2c1C(=O)C(=O)NNC(=O)c1ccccc1
Cc1n[nH]c(C)c1SC1=C(C(=O)O)N2C(=O)C(C(C)O)C2S1
CCC1C(=O)N(Cc2ccc(C)o2)C2CCCN(Cc3ccccc3)C12
COc1ccc(C=C(NC(=O

# Train Agent

In [22]:
n_steps = 3000
scoring_function='tanimoto'
scoring_function_kwargs=None
sigma = 50
num_processes=0

In [23]:
def unique(arr):
    # Finds unique rows in arr and return their indices
    arr = arr.cpu().numpy()
    arr_ = np.ascontiguousarray(arr).view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[1])))
    _, idxs = np.unique(arr_, return_index=True)
    if torch.cuda.is_available():
        return torch.LongTensor(np.sort(idxs)).cuda()
    return torch.LongTensor(np.sort(idxs))

def seq_to_smiles(seqs, voc):
    """Takes an output sequence from the RNN and returns the
       corresponding SMILES."""
    smiles = []
    for seq in seqs.cpu().numpy():
        smiles.append(voc.decode(seq))
    return smiles


In [24]:
Prior = RNN(voc,batch_size)
Agent = RNN(voc,batch_size)

# logger = VizardLog('data/logs')


Prior.rnn.load_state_dict(torch.load(f'data/PriorBest.ckpt'))
Agent.rnn.load_state_dict(torch.load(f'data/PriorBest.ckpt'))
if torch.cuda.is_available():
    Prior.rnn.to(device)
    Agent.rnn.to(device)
    Prior.rnn.eval()

scoring_function = get_scoring_function(scoring_function=scoring_function, num_processes=num_processes)
    
optimizer = torch.optim.Adam(Agent.rnn.parameters(), lr=0.0005)

In [27]:
bar = tqdm(range(n_steps))
bar.set_description(f'[ Training agent ]')
for step in bar:
    seqs,agent_likelihood = Agent.sample(batch_size)
    prior_likelihood,acc = Prior.forward(seqs)
    agent_likelihood,prior_likelihood = agent_likelihood.mean(axis=1),prior_likelihood.mean(axis=1)
    
    smiles = seq_to_smiles(seqs, voc)
    score = scoring_function(smiles)

    augmented_likelihood = prior_likelihood + sigma * torch.tensor(score).to(device)
    loss = torch.pow((augmented_likelihood - agent_likelihood), 2)
    loss = loss.mean()

    loss_p = (1 / agent_likelihood).mean()
    loss += 5 * 1e3 * loss_p

    # Calculate gradients and make an update to the network weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    bar.set_postfix({'loss':loss.item(),'acc':acc.item(),'score':sum(score)/len(score)})
    
    if (step +1) % 100 ==0:
        valid = 0
        for i,seq in enumerate(seqs):
            smile = voc.decode(seq)
            smile = voc.decode(seq)
            if Chem.MolFromSmiles(smile):
                valid += 1
            if i < 5:
                tqdm.write(smile)
        tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
    torch.save(Agent.rnn.state_dict(), "data/Agent.ckpt")

  0%|          | 0/3000 [00:00<?, ?it/s]

COc1ccc(NC(=O)CSc2nc3c(c(=O)n(C)c(=O)n3C)n2C)cc1OC
Cc1ccc2nc3c(nc2c1)C(=O)C1=NNC=C13
C=CCn1c(SCC(=O)OC(C)C)nc2sc3c(c2c1=O)CCC3
Cc1c(NS(=O)(=O)c2ccc(Cl)s2)c(=O)n(-c2ccccc2)n1C
Cc1cc(C(=O)OC(C)C(=O)NC2CCN(C(=N)N)CC2)c(C)o1

93.4% valid SMILES
COc1ccc(C2CC3=Nc4c(c(OC)c(C)c(OC)c4C3=[N+](C)C)NC2)c(O)c1
Cc1cccc(-c2cc(=O)c3ccccc3o2)n1
Nc1ccc(S(=O)(=O)NNS(=O)(=O)c2cccc3ccccc23)cc1
Cc1cc(C)c(OCC(=O)N2CC3(C)CC2CC(C)(C)C3)c(C)c1
CC(OC(=O)CNc1nc2ccc(S(N)(=O)=O)cc2s1)C(=O)Nc1ccc(C(F)(F)F)cc1

91.8% valid SMILES


KeyboardInterrupt: 