In [1]:
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random
import re
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
print(torch.__version__)

2.0.0
2.0.0


# data process

In [2]:
def replace_halogen(string):
    """Regex to replace Br and Cl with single letters"""
    br = re.compile('Br')
    cl = re.compile('Cl')
    string = br.sub('R', string)
    string = cl.sub('L', string)

    return string

class Vocabulary:
    def __init__(self,file,max_len=140):
        self.file = file
        self.max_len = max_len
        self.vocab_size,self.vocab,self.reversed_vocab = self.get_voc(file)
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_len = max_len
    
    def get_voc(self,file):
        with open(file, 'r') as f:
            chars = f.read().split()
        chars +=  ['EOS', 'GO']
        chars.sort()
        vocab_size = len(chars)
        vocab = dict(zip(chars, range(len(chars))))
        reversed_vocab = {v: k for k, v in vocab.items()}
        return vocab_size,vocab,reversed_vocab
        
    def tokenize(self, smiles):
        """Takes a SMILES and return a list of characters/tokens"""
        regex = '(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)
            else:
                chars = [unit for unit in char]
                [tokenized.append(unit) for unit in chars]
        tokenized.append('EOS')
        return tokenized
    
    def encode(self,char_list):
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix
    
    def decode(self,matrix):
        chars = []
        for i in matrix:
            if i == self.vocab['EOS']: break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles
                          
    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)
    
class MolData(Dataset):
    """Custom PyTorch Dataset that takes a file containing SMILES.

        Args:
                fname : path to a file containing \n separated SMILES.
                voc   : a Vocabulary instance

        Returns:
                A custom PyTorch dataset for training the Prior.
    """
    def __init__(self, fname, voc):
        self.voc = voc
        self.smiles = []
        with open(fname, 'r') as f:
            for line in f:
                self.smiles.append(line.split()[0])

    def __getitem__(self, i):
        mol = self.smiles[i]
        tokenized = self.voc.tokenize(mol)
        encoded = self.voc.encode(tokenized)
        return encoded

    def __len__(self):
        return len(self.smiles)

    def __str__(self):
        return "Dataset containing {} structures.".format(len(self))

    @classmethod
    def collate_fn(cls, arr):
        """Function to take a list of encoded sequences and turn them into a batch"""
        max_length = max([seq.size for seq in arr])
        collated_arr = np.zeros((len(arr), max_length))
        for i, seq in enumerate(arr):
            collated_arr[i, :seq.size] = seq
        return torch.from_numpy(collated_arr)

# define model

In [3]:
class GRULayer(nn.Module):
    def __init__(self, voc_size):
        super().__init__()
        self.embedding = nn.Embedding(voc_size,128)
        self.gru1 = nn.GRUCell(128,512)
        self.gru2 = nn.GRUCell(512,512)
        self.gru3 = nn.GRUCell(512,512)
        self.fcn = nn.Linear(512,voc_size)
        
    def forward(self, x, h):
        x = x.to(device)
        h = h.to(device)
        h_out = torch.zeros(h.shape).to(device)
        x = self.embedding(x).to(device)
        x = h_out[0] = self.gru1(x,h[0])
        x = h_out[1] = self.gru2(x,h[1])
        x = h_out[2] = self.gru3(x,h[2])
        x = self.fcn(x)
        return x,h_out

class RNN():
    def __init__(self,voc,batch_size,max_len=140):
        global device
        self.rnn = GRULayer(voc.vocab_size).to(device)
        self.voc = voc
        self.batch_size = batch_size
        self.max_len = max_len
        self.loss_fn = nn.BCELoss()
    
    def bce_loss(self,prob,target):
        onehot_target = torch.zeros(prob.shape)
        onehot_target.scatter_(1, target.contiguous().view(-1, 1).data, 1.0)
        onehot_target = onehot_target.to(device)
        return self.loss_fn(prob,onehot_target)
    
    def forward(self,target):
        self.rnn.train()
        loss = 0
        h = torch.zeros(3, self.batch_size, 512).float()
        start_token = torch.zeros(self.batch_size,1).long()
        start_token[:] = self.voc.vocab['GO']
        x = torch.cat((start_token, target[:, :-1]), 1)
        outs = []
        for step in range(target.shape[1]):
            logits,h = self.rnn(x[:, step],h)
            prob = F.softmax(logits,dim=1)
            out = torch.argmax(prob,dim=1)
            outs.append(out)
            loss += self.bce_loss(prob, target[:, step])
        return loss
    
    def sample(self,batch_size):
        output = []
        finished = torch.zeros(batch_size).byte().to(device)
        h = torch.zeros(3, self.batch_size, 512).float()
        start_token = torch.zeros(self.batch_size).long()
        start_token[:] = self.voc.vocab['GO']
        x = start_token
        with torch.no_grad():
            self.rnn.eval()
            for step in range(self.max_len):
                logits,_ = self.rnn(x,h)
                prob = F.softmax(logits,dim=1)
                x = torch.multinomial(prob, num_samples=1).view(-1)
                output.append(x.view(-1, 1))

                EOS_sampled = (x == self.voc.vocab['EOS']).byte()
                finished = torch.ge(finished + EOS_sampled, 1)
                if torch.prod(finished) == 1: break
                
        return torch.cat(output, 1)

In [4]:
num_epoch = 3
batch_size = 128
lr = 1e-4
device = "cuda" if torch.cuda.is_available() else 'cpu'

In [5]:
voc = Vocabulary("data/Voc")
moldata = MolData("data/ChEMBL_filtered", voc)
data = DataLoader(moldata, batch_size=batch_size, shuffle=True, drop_last=True,collate_fn=MolData.collate_fn)
model = RNN(voc,batch_size)
optimizer = torch.optim.Adam(model.rnn.parameters(), lr = lr)

In [6]:
for epoch in range(num_epoch):
    bar = tqdm(enumerate(data), total=len(data))
    bar.set_description(f'[ Epoch: {epoch}]')
    for step,batch in bar:
        optimizer.zero_grad()
        seqs = batch.long()
        loss = model.forward(seqs)
        loss.backward()
        optimizer.step()
        bar.set_postfix({"loss":loss.item()})
        
        if (step+1) % 100 ==0:
            seqs = model.sample(128)
            tqdm.write(f'-------{step}--------')
#             for i,seq in enumerate(seqs):
#                 smile = voc.decode(seq)
            print(seqs)

  0%|          | 0/9214 [00:00<?, ?it/s]

  0%|          | 0/9214 [00:00<?, ?it/s]

-------99--------
tensor([[10,  1, 33,  ..., 41, 26, 11],
        [10, 17, 38,  ..., 43, 20, 38],
        [34, 36, 11,  ..., 37, 30, 16],
        ...,
        [ 7, 21, 13,  ..., 32, 23, 29],
        [12, 17, 49,  ..., 30,  4, 32],
        [21, 39, 12,  ..., 15, 13, 39]], device='cuda:0')
-------99--------
tensor([[10,  1, 33,  ..., 41, 26, 11],
        [10, 17, 38,  ..., 43, 20, 38],
        [34, 36, 11,  ..., 37, 30, 16],
        ...,
        [ 7, 21, 13,  ..., 32, 23, 29],
        [12, 17, 49,  ..., 30,  4, 32],
        [21, 39, 12,  ..., 15, 13, 39]], device='cuda:0')
-------199--------
tensor([[ 5, 29, 24,  ..., 36,  8,  5],
        [16, 16, 15,  ..., 48, 42, 35],
        [ 2, 36, 39,  ..., 18,  4, 18],
        ...,
        [ 2, 33, 30,  ..., 21, 41, 10],
        [ 0, 13, 25,  ..., 33, 14, 43],
        [21, 16, 22,  ..., 38, 10, 46]], device='cuda:0')
-------199--------
tensor([[ 5, 29, 24,  ..., 36,  8,  5],
        [16, 16, 15,  ..., 48, 42, 35],
        [ 2, 36, 39,  ..., 18,  4

KeyboardInterrupt: 

KeyboardInterrupt: 

In [None]:
target = torch.ones(128,1).long()
onehot_target = torch.zeros((128,50)).to(device)
onehot_target.scatter_(1, target.contiguous().view(-1, 1).data, 1.0)