In [1]:
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit import rdBase
from tqdm.auto import tqdm
import random
import re
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
from scoring_functions import get_scoring_function
print(torch.__version__)

2.0.0


# data process

In [2]:
def replace_halogen(string):
    """Regex to replace Br and Cl with single letters"""
    br = re.compile('Br')
    cl = re.compile('Cl')
    string = br.sub('R', string)
    string = cl.sub('L', string)

    return string

class Vocabulary:
    def __init__(self,file,max_len=140):
        self.file = file
        self.max_len = max_len
        self.vocab_size,self.vocab,self.reversed_vocab = self.get_voc(file)
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_len = max_len
    
    def get_voc(self,file):
        with open(file, 'r') as f:
            chars = f.read().split()
        chars +=  ['EOS', 'GO']
        chars.sort()
        vocab_size = len(chars)
        vocab = dict(zip(chars, range(len(chars))))
        reversed_vocab = {v: k for k, v in vocab.items()}
        return vocab_size,vocab,reversed_vocab
        
    def tokenize(self, smiles):
        """Takes a SMILES and return a list of characters/tokens"""
        regex = '(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)
            else:
                chars = [unit for unit in char]
                [tokenized.append(unit) for unit in chars]
        tokenized.append('EOS')
        return tokenized
    
    def encode(self,char_list):
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix
    
    def decode(self,matrix):
        chars = []
        for i in matrix:
            if i == self.vocab['EOS']: break
            chars.append(self.reversed_vocab[i.item()])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles
                          
    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)
    
class MolData(Dataset):
    """Custom PyTorch Dataset that takes a file containing SMILES.

        Args:
                fname : path to a file containing \n separated SMILES.
                voc   : a Vocabulary instance

        Returns:
                A custom PyTorch dataset for training the Prior.
    """
    def __init__(self, fname, voc):
        self.voc = voc
        self.smiles = []
        with open(fname, 'r') as f:
            for line in f:
                self.smiles.append(line.split()[0])

    def __getitem__(self, i):
        mol = self.smiles[i]
        tokenized = self.voc.tokenize(mol)
        encoded = self.voc.encode(tokenized)
        return encoded

    def __len__(self):
        return len(self.smiles)

    def __str__(self):
        return "Dataset containing {} structures.".format(len(self))

    @classmethod
    def collate_fn(cls, arr):
        """Function to take a list of encoded sequences and turn them into a batch"""
        max_length = max([seq.size for seq in arr])
        collated_arr = np.zeros((len(arr), max_length))
        for i, seq in enumerate(arr):
            collated_arr[i, :seq.size] = seq
        return torch.from_numpy(collated_arr)

# define model

In [11]:
class GRULayer(nn.Module):
    def __init__(self, voc_size):
        super().__init__()
        self.embedding = nn.Embedding(voc_size,128)
        self.gru1 = nn.GRUCell(128,512)
        self.gru2 = nn.GRUCell(512,512)
        self.gru3 = nn.GRUCell(512,512)
        self.fcn = nn.Linear(512,voc_size)
        
    def forward(self, x, h):
        x = x.to(device)
        h = h.to(device)
        h_out = torch.zeros(h.shape).to(device)
        x = self.embedding(x).to(device)
        x = h_out[0] = self.gru1(x,h[0])
        x = h_out[1] = self.gru2(x,h[1])
        x = h_out[2] = self.gru3(x,h[2])
        x = self.fcn(x)
        return x,h_out

class RNN():
    def __init__(self,voc,batch_size,max_len=140):
        global device
        self.rnn = GRULayer(voc.vocab_size).to(device)
        self.voc = voc
        self.batch_size = batch_size
        self.max_len = max_len
        self.loss_fn = nn.BCELoss(reduction='none')
    
    def bce_loss(self,prob,target):
        onehot_target = torch.zeros(prob.shape).to(device)
        onehot_target.scatter_(1, target.contiguous().view(-1, 1).data, 1.0)
        onehot_target = onehot_target.to(device)
#         print(prob,onehot_target)
        return self.loss_fn(prob,onehot_target)
    
    def forward(self,target):
#         self.rnn.train()
#         target = target.to(device)
#         batch_size, seq_length = target.size()
#         start_token = Variable(torch.zeros(batch_size, 1).long())
#         start_token[:] = self.voc.vocab['GO']
#         x = torch.cat((start_token, target[:, :-1]), 1)
#         h = torch.zeros(3, self.batch_size, 512).float()

#         log_probs = Variable(torch.zeros(batch_size))
#         loss = 0
#         outs = []
#         for step in range(seq_length):
#             logits, h = self.rnn(x[:, step], h)
#             log_prob = F.log_softmax(logits)
            
#             prob = F.softmax(logits,dim=1)
#             out = torch.argmax(prob,dim=1)
#             outs.append(out)
#             log_probs += NLLLoss(log_prob, target[:, step])
#         outs = torch.transpose(torch.stack(outs),1,0)
#         acc = torch.sum(outs==target.to(device))/outs.numel()
#         return log_probs,acc
        target = target.to(device)
        batch_size, seq_length = target.size()
        start_token = Variable(torch.zeros(batch_size, 1).long())
        start_token[:] = self.voc.vocab['GO']
        x = torch.cat((start_token, target[:, :-1]), 1)
        h = torch.zeros(3, batch_size, 512).float()
        log_probs = Variable(torch.zeros(batch_size))
        entropy = Variable(torch.zeros(batch_size))
        outs = []
        for step in range(seq_length):
            logits, h = self.rnn(x[:, step], h)
            log_prob = F.log_softmax(logits)
            prob = F.softmax(logits)
            out = torch.argmax(prob,dim=1)
            outs.append(out)
            log_probs += NLLLoss(log_prob, target[:, step])
            entropy += -torch.sum((log_prob * prob), 1)
        outs = torch.transpose(torch.stack(outs),1,0)
        acc = torch.sum(outs==target.to(device))/outs.numel()
        return log_probs, acc
    
    def sample(self,batch_size):
        output = []
        loss = 0
        log_probs = Variable(torch.zeros(batch_size))
        finished = torch.zeros(batch_size).byte().to(device)
        h = torch.zeros(3, batch_size, 512).float()
        start_token = torch.zeros(batch_size).long()
        start_token[:] = self.voc.vocab['GO']
        x = start_token
        with torch.no_grad():
            self.rnn.eval()
            for step in range(self.max_len):
                logits,h = self.rnn(x,h)
                prob = F.softmax(logits,dim=1)
                log_prob = F.log_softmax(logits)
                x = torch.multinomial(prob, num_samples=1).view(-1)
                output.append(x.view(-1, 1))
                log_probs += NLLLoss(log_prob, x)

                EOS_sampled = (x == self.voc.vocab['EOS']).byte()
                finished = torch.ge(finished + EOS_sampled, 1)
                if torch.prod(finished) == 1: break
                
        return log_probs.data,torch.cat(output, 1)
    
def NLLLoss(inputs, targets):
    """
        Custom Negative Log Likelihood loss that returns loss per example,
        rather than for the entire batch.

        Args:
            inputs : (batch_size, num_classes) *Log probabilities of each class*
            targets: (batch_size) *Target class index*

        Outputs:
            loss : (batch_size) *Loss for each example*
    """

    if torch.cuda.is_available():
        target_expanded = torch.zeros(inputs.size()).cuda()
    else:
        target_expanded = torch.zeros(inputs.size())

    target_expanded.scatter_(1, targets.contiguous().view(-1, 1).data, 1.0)
    loss = Variable(target_expanded) * inputs
    loss = torch.sum(loss, 1)
    return loss

    
def Variable(tensor):
    """Wrapper for torch.autograd.Variable that also accepts
       numpy arrays directly and automatically assigns it to
       the GPU. Be aware in case some operations are better
       left to the CPU."""
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    if torch.cuda.is_available():
        return torch.autograd.Variable(tensor).cuda()
    return torch.autograd.Variable(tensor)

In [12]:
num_epoch = 4
batch_size = 256
lr = 2e-3
device = "cuda" if torch.cuda.is_available() else 'cpu'

In [13]:
voc = Vocabulary("data/Voc")
moldata = MolData("data/ChEMBL_filtered", voc)
data = DataLoader(moldata, batch_size=batch_size, shuffle=True, drop_last=True,collate_fn=MolData.collate_fn)
model = RNN(voc,batch_size)
optimizer = torch.optim.Adam(model.rnn.parameters(), lr = lr)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lambda x: 1 - x/(len(data)*num_epoch))

In [14]:
for epoch in range(num_epoch):
    bar = tqdm(enumerate(data), total=len(data))
    bar.set_description(f'[ Epoch: {epoch+1}]')
    for step,batch in bar:
        optimizer.zero_grad()
        seqs = batch.long()
#         log_p, _ = model.forward(seqs)
#         loss = - log_p.mean()
        log_p,acc = model.forward(seqs)
        loss = - log_p.mean()
        optimizer.step()
        lr_scheduler.step()
        bar.set_postfix({"loss":loss.item(),'acc':acc.item()})
        
        if (step +1) % 200 ==0:
#             seqs, likelihood, _ = model.sample(batch_size)
            _,seqs = model.sample(batch_size)
            tqdm.write(f'-------{step+1}--------')
            valid = 0
            for i,seq in enumerate(seqs):
                smile = voc.decode(seq)
                if Chem.MolFromSmiles(smile):
                    valid += 1
                if i < 5:
                    tqdm.write(smile)
            tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
            tqdm.write("*" * 50 + "\n")
            torch.save(model.rnn.state_dict(), "data/Prior.ckpt")
        torch.save(model.rnn.state_dict(), "data/Prior.ckpt")

  0%|          | 0/4607 [00:00<?, ?it/s]

  log_prob = F.log_softmax(logits)
  prob = F.softmax(logits)


KeyboardInterrupt: 

In [None]:
n_steps = 3000
scoring_function='tanimoto'
scoring_function_kwargs=None
sigma = 60
num_processes=0

In [None]:
def unique(arr):
    # Finds unique rows in arr and return their indices
    arr = arr.cpu().numpy()
    arr_ = np.ascontiguousarray(arr).view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[1])))
    _, idxs = np.unique(arr_, return_index=True)
    if torch.cuda.is_available():
        return torch.LongTensor(np.sort(idxs)).cuda()
    return torch.LongTensor(np.sort(idxs))

def seq_to_smiles(seqs, voc):
    """Takes an output sequence from the RNN and returns the
       corresponding SMILES."""
    smiles = []
    for seq in seqs.cpu().numpy():
        smiles.append(voc.decode(seq))
    return smiles

In [None]:
Prior = RNN(voc,batch_size)
Agent = RNN(voc,batch_size)

# logger = VizardLog('data/logs')


Prior.rnn.load_state_dict(torch.load(f'data/Prior.ckpt'))
Agent.rnn.load_state_dict(torch.load(f'data/Prior.ckpt'))
if torch.cuda.is_available():
    Prior.rnn.to(device)
    Agent.rnn.to(device)
    Prior.rnn.eval()

scoring_function = get_scoring_function(scoring_function=scoring_function, num_processes=num_processes)
    
optimizer = torch.optim.Adam(Agent.rnn.parameters(), lr=0.0005)
bar = tqdm(range(n_steps))
bar.set_description(f'[ Training agent ]')
for step in bar:

    # Sample from Agent
    agent_likelihood,seqs = Agent.sample(batch_size)
    

    # Remove duplicates, ie only consider unique seqs
#     print(seqs)
#     unique_idxs = unique(seqs)
#     seqs = seqs[unique_idxs]
#     agent_likelihood = agent_likelihood[unique_idxs]

    # Get prior likelihood and score
    prior_likelihood = Prior.forward(seqs)
    smiles = seq_to_smiles(seqs, voc)
    
    score = scoring_function(smiles)

    # Calculate augmented likelihood
    augmented_likelihood = prior_likelihood + sigma * Variable(score)
    loss = torch.pow((augmented_likelihood - agent_likelihood), 2)
    loss = loss.mean()

    # Add regularizer that penalizes high likelihood for the entire sequence
    loss_p = - (1 / agent_likelihood).mean()
    loss += 5 * 1e3 * loss_p

    # Calculate gradients and make an update to the network weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    bar.set_postfix({'loss':f'{loss.item():.5f}'})
    
    if (step +1) % 100 ==0:
        for i,smile in enumerate(smiles):
            print(smile)
            if i>5: break
    torch.save(Agent.rnn.state_dict(), "data/Agent.ckpt")