In [1]:
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random
import re
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
print(torch.__version__)

2.0.0


# data process

In [2]:
def replace_halogen(string):
    """Regex to replace Br and Cl with single letters"""
    br = re.compile('Br')
    cl = re.compile('Cl')
    string = br.sub('R', string)
    string = cl.sub('L', string)

    return string

class Vocabulary:
    def __init__(self,file,max_len=140):
        self.file = file
        self.max_len = max_len
        self.vocab_size,self.vocab,self.reversed_vocab = self.get_voc(file)
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_len = max_len
    
    def get_voc(self,file):
        with open(file, 'r') as f:
            chars = f.read().split()
        chars +=  ['EOS', 'GO']
        chars.sort()
        vocab_size = len(chars)
        vocab = dict(zip(chars, range(len(chars))))
        reversed_vocab = {v: k for k, v in vocab.items()}
        return vocab_size,vocab,reversed_vocab
        
    def tokenize(self, smiles):
        """Takes a SMILES and return a list of characters/tokens"""
        regex = '(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)
            else:
                chars = [unit for unit in char]
                [tokenized.append(unit) for unit in chars]
        tokenized.append('EOS')
        return tokenized
    
    def encode(self,char_list):
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix
    
    def decode(self,matrix):
        chars = []
        for i in matrix:
            i = i.item()
            if i == self.vocab['EOS']: break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles
                          
    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)
    
class MolData(Dataset):
    """Custom PyTorch Dataset that takes a file containing SMILES.

        Args:
                fname : path to a file containing \n separated SMILES.
                voc   : a Vocabulary instance

        Returns:
                A custom PyTorch dataset for training the Prior.
    """
    def __init__(self, fname, voc):
        self.voc = voc
        self.smiles = []
        with open(fname, 'r') as f:
            for line in f:
                self.smiles.append(line.split()[0])

    def __getitem__(self, i):
        mol = self.smiles[i]
        tokenized = self.voc.tokenize(mol)
        encoded = self.voc.encode(tokenized)
        return encoded

    def __len__(self):
        return len(self.smiles)

    def __str__(self):
        return "Dataset containing {} structures.".format(len(self))

    @classmethod
    def collate_fn(cls, arr):
        """Function to take a list of encoded sequences and turn them into a batch"""
        max_length = max([seq.size for seq in arr])
        collated_arr = np.zeros((len(arr), max_length))
        for i, seq in enumerate(arr):
            collated_arr[i, :seq.size] = seq
        return torch.from_numpy(collated_arr)

# define model

In [3]:
class GRULayer(nn.Module):
    def __init__(self, voc_size):
        super().__init__()
        self.embedding = nn.Embedding(voc_size,128)
        self.gru1 = nn.GRUCell(128,512)
        self.gru2 = nn.GRUCell(512,512)
        self.gru3 = nn.GRUCell(512,512)
        self.fcn = nn.Linear(512,voc_size)
        
    def forward(self, x, h):
        x = x.to(device)
        h = h.to(device)
        h_out = torch.zeros(h.shape).to(device)
        x = self.embedding(x).to(device)
        x = h_out[0] = self.gru1(x,h[0])
        x = h_out[1] = self.gru2(x,h[1])
        x = h_out[2] = self.gru3(x,h[2])
        x = self.fcn(x)
        return x,h_out

class RNN():
    def __init__(self,voc,batch_size,max_len=140):
        global device
        self.rnn = GRULayer(voc.vocab_size).to(device)
        self.voc = voc
        self.batch_size = batch_size
        self.max_len = max_len
        self.loss_fn = nn.BCELoss()
    
    def bce_loss(self,prob,target):
        onehot_target = torch.zeros(prob.shape)
        onehot_target.scatter_(1, target.contiguous().view(-1, 1).data, 1.0)
        onehot_target = onehot_target.to(device)
        return self.loss_fn(prob,onehot_target)
    
    def forward(self,target):
        self.rnn.train()
        loss = 0
        h = torch.zeros(3, self.batch_size, 512).float()
        start_token = torch.zeros(self.batch_size,1).long()
        start_token[:] = self.voc.vocab['GO']
        x = torch.cat((start_token, target[:, :-1]), 1)
        outs = []
        for step in range(target.shape[1]):
            logits,h = self.rnn(x[:, step],h)
            prob = F.softmax(logits,dim=1)
            out = torch.argmax(prob,dim=1)
            outs.append(out)
            loss += self.bce_loss(prob, target[:, step])
        outs = torch.transpose(torch.stack(outs),1,0)
        acc = torch.sum(outs==target.to(device))/outs.numel()
        return loss,acc
    
    def sample(self,batch_size):
        output = []
        finished = torch.zeros(batch_size).byte().to(device)
        h = torch.zeros(3, batch_size, 512).float()
        start_token = torch.zeros(batch_size).long()
        start_token[:] = self.voc.vocab['GO']
        x = start_token
        with torch.no_grad():
            self.rnn.eval()
            for step in range(self.max_len):
                logits,_ = self.rnn(x,h)
                prob = F.softmax(logits,dim=1)
                x = torch.multinomial(prob, num_samples=1).view(-1)
                output.append(x.view(-1, 1))

                EOS_sampled = (x == self.voc.vocab['EOS']).byte()
                finished = torch.ge(finished + EOS_sampled, 1)
                if torch.prod(finished) == 1: break
                
        return torch.cat(output, 1)

In [4]:
num_epoch = 3
batch_size = 512
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else 'cpu'

In [5]:
voc = Vocabulary("data/Voc")
moldata = MolData("data/ChEMBL_filtered", voc)
data = DataLoader(moldata, batch_size=batch_size, shuffle=True, drop_last=True,collate_fn=MolData.collate_fn)
model = RNN(voc,batch_size)
optimizer = torch.optim.Adam(model.rnn.parameters(), lr = lr)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lambda x: 1 - x/(len(data)*num_epoch),last_epoch=-1)

In [6]:
for epoch in range(num_epoch):
    bar = tqdm(enumerate(data), total=len(data))
    bar.set_description(f'[ Epoch: {epoch+1}]')
    for step,batch in bar:
        optimizer.zero_grad()
        seqs = batch.long()
        loss,acc = model.forward(seqs)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        bar.set_postfix({"loss":loss.item(),'acc':acc.item()})
        
        if (step+1) % 100 ==0:
            seqs = model.sample(5)
            tqdm.write(f'-------{step+1}--------')
            for i,seq in enumerate(seqs):
                smile = voc.decode(seq)
                print(smile)

  0%|          | 0/2303 [00:00<?, ?it/s]

-------100--------
C1[CH]#8#####################1NOO)c3)c([NH+][NH-][NH3+]F)c1CCl
O)n1Oc(8[nH][N+]SCNCC1###########[nH+]OCN2
NCNOC6)c2[NH-]S##########[N-][N+]O(=F)CC(s5GOCO)1####################
C==[NH-]c4####
CCN[nH]o[n+]Br[n+]o9[s+][SH]F)nc3[O+][SH+][o+]2[o+]OCc2ClNn3
-------200--------
C(Cl[nH+]Cl=S
CNCC(F[O-]CCCNc(ON[n+]7[nH+]CCNCCCNCCCC1###%[S+][s+]6[n-][SH]S8CCCCNCCNCCCOCCCOC6[s+]1
C1
CO=F-(OCc3
CC(Cc1[N+](=)o(Cl)c2Cl
-------300--------
CCC1####
CC(F7GOC(F)o[n+][NH-]-c2#####[SH]3[NH-]=SC1
CCc([N-]SCO)S3############
O
C1OCC(F)c4#
-------400--------
C([N-]CC1-4#########[NH+][NH+]n1#######[N-]CCC(F)n[N-]NCCC(CCC(SCNNCCCCCc32#OCCC(5###########################)c2[nH+](N([s+]n2#######[nH+][n+]4[SH][SH+]2[CH]c4-2GOO
C([O-][O]c[N]C(OCCCC([N+]([O-][S-]F)c1[S-]O[O]3##GOO####
CCC1)85)c[NH+][S-][CH]Br[n-]CC0)c3[O+]c2
C([N-]
C(O)o6)c1###n[s+]s(F3
-------500--------
CCCC(=[O][n+]3F##=[O-])n1
CCCC(F)c4)N[S-]CCC[NH-]n[nH+](F)o1##
N(F)c[nH]NCCCCC=OBr)n%####
CCCCCC(F)c1
CCCCCC(=ONC(Clo[n-]CC(O)n3#

  0%|          | 0/2303 [00:00<?, ?it/s]

-------100--------
OCN=NC[S+](OCCCCC############
CCC(F)s[NH2+]C(Cl)n1##########
CCCCC(F)c1###############################################################
C(N1######
CCNCCC##########################
-------200--------
C([O-]###
C(F)c1#
OC(F)c3)s2
CN=[N-]C(N1
CCC(Cl)c1#########
-------300--------
CO1
CCCCOCCC([O-])n1#####C(C(F)n6#################[N]C1
C=O)c3[O-])c3
C(=[N-]CC([N+](=OCC1#
CNCCC[O+][N-]CCCC(F)s3#######################CC(F)n3)c[SH]2
-------400--------
CF)s3)n1
CCC###############
OCCCC(CCC(F)o3########[O+]o[n+]1#########[NH3+]#############
C([O-])s3###########
CCC(F#######################################CCCCC(C(=CCCCC(C([O+]1
-------500--------
CC###############################################################################
CCCC(N(F)c7[O]######################################################################)c1#######################################################
C(Cl9####################################GOCC([NH3+]##################
CCCCCCCC([O-])[nH]7######################

  0%|          | 0/2303 [00:00<?, ?it/s]

-------100--------
CC(F)c1
CCC(Br)o1################
O
CC#########
CCCCCCCC(Cl)[nH]4N[NH3+]############################################################################################
-------200--------
CCCC(C([O-]
CCCC##############
S(F
C=[N-]CC(N=O
CC(F)n1########
-------300--------
CC(OC(F)n4######
CCCCCCCC(F)s3
OCCCCCC([NH3+]
OCCOCCCC(F)[nH]4)s2
CCC###########################################################################
-------400--------
CCC([NH2+]CC(F)s8######################################################################
CCCCCC(F)s[SH+][nH]GOCCCCCCC(F)[nH]1
C([O-])c1
CNCCC(Cl)s8
CCC1####################=OCC([O-])n1#############
-------500--------
N###########################################################################################################################################
CCCCC##############################
C(F)[nH]3######################
C(=[NH+]###################CCC(OC(F#########
OCC(F)[nH]GOCCC(F
-------600--------
C(F)[nH][S-]##
OCCC########################

In [7]:
target = torch.ones(128,1).long()
onehot_target = torch.zeros((128,50)).to(device)
onehot_target.scatter_(1, target.contiguous().view(-1, 1).data, 1.0)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_scatter__value)

In [None]:
torch.tensor([[1,2,3],[1,1,1]]).numel()