In [2]:
import os

import warnings
import os
import pandas as pd
import numpy as np
import numpy
import pickle
import rdkit
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import functools
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from rdkit.Chem import Draw
# from rdkit.Chem.Draw import IPythonConsole


# Ignore some warnings from RDKIT and keras
from rdkit import RDLogger, Chem
from torch.nn.functional import one_hot
import itertools
from functools import reduce
from rdkit.Chem import rdMolDescriptors
import torch.utils.data as torch_data

RDLogger.DisableLog('rdApp.*')

warnings.filterwarnings("ignore")

# Load methods from the FCD library

np.random.seed(1234)

print("RDKit: ", rdkit.__version__)


RDKit:  2022.03.5


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [None]:
with open("smiles_train.txt", "r") as f:
    smiles = []
    for line in f:
        mol = Chem.MolFromSmiles(line.strip())
        if mol is not None:
            smiles.append(Chem.MolToSmiles(mol))


with open("smiles_train.smi", "w") as f:
    for s in smiles:
        f.write(s + "\n")


In [24]:
__encoders__ = {
    0: "<PAD>",
    1: "<EOS>",
    2: "<BOS>",
}
max_length = 130

In [6]:
class SMILESDATA(torch_data.DataLoader):
    def __init__(self, smiles, max_length):
        self.smiles = open(smiles, 'r').read().split("\n")[:-1]
        self.max_length = max_length
        
        tokens = functools.reduce(
            lambda acc, s: acc.union(set(s)), self.smiles, set())
        self.vocsize = len(tokens) + len(__encoders__)
        self.index2token = dict(enumerate(tokens, start=3))
        self.index2token.update(__encoders__)
        self.token2index = {v: k for k, v in self.index2token.items()}
        self.ints = [torch.LongTensor([self.token2index[s] for s in line]) for line in 
            self.smiles]

    def __len__(self):
        return len(self.smiles)
    
    def decode(self, indexes):
        return "".join([self.index2token[index] for index in indexes if index not in __encoders__])

    def __getitem__(self, i):
        special_added = torch.cat((torch.LongTensor([self.token2index['<BOS>']]), self.ints[i], torch.LongTensor([self.token2index['<EOS>']]),
                                   torch.LongTensor([self.token2index["<PAD>"]]*(self.max_length-len(self.ints[i])-2))), dim=0)
        return one_hot(special_added, self.vocsize).float(), special_added


In [37]:

hidden_size = 512
num_layers = 3
num_epochs = 2
learning_rate = 0.001
batch_size = 256
dropout = 0.2


In [38]:
dataset = SMILESDATA('smiles_train.smi', max_length)
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True)


In [48]:
class SimplifiedSMILESGRU(nn.Module):
    def __init__(self, vocsize, hidden_size=512, num_layers=3, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocsize = vocsize

        self.gru = nn.GRU(vocsize, hidden_size, num_layers,
                          batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size, vocsize)

    def forward(self, x):
        out = self.gru(x)[0]
        out = self.dropout(out)
        out = self.fc(out)
        return out
    
    def sample(self, batch_size=128, max_len=130):
        bos_token = [k for k, v in __encoders__.items() if v == "<BOS>"][0]
        x = torch.LongTensor([bos_token]*batch_size)
        h = torch.zeros((self.num_layers, batch_size,
                        self.hidden_size)).to(device)
        accumulator = torch.zeros(batch_size, max_len)
        for i in range(max_len):
            x = one_hot(x, self.vocsize).float().unsqueeze(1).to(device)
            output, h = self.gru(x, h)
            next = F.softmax(self.linear(output).squeeze(1), dim=1)
            x = torch.multinomial(next, num_samples=1,
                                  replacement=True).squeeze(1)
            accumulator[:, i] = x
        return accumulator


In [41]:
model = SimplifiedSMILESGRU(dataset.vocsize, hidden_size, num_layers, dropout).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

scheduler = ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2, verbose=True)


In [43]:
for epoch in range(1, num_epochs+1):
    model.train()
    train_loss = 0
    train_count = 0
    for i, (batch, target) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch}/{num_epochs}", leave=False)):

        batch, target = batch.to(device), target.to(device)
        output = model(batch)
 
        #print("Output shape:", output.shape)
        #print("Target shape:", target.shape)
        output = output.transpose(2, 1)
        loss = criterion(output[:, :, :-1], target[:, 1:])
        optimizer.zero_grad()
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        train_loss += loss.item()
        train_count += 1

    avg_train_loss = train_loss / train_count

    print(
        f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}")

    scheduler.step(avg_train_loss)


                                                                       

Epoch [2/2], Train Loss: 0.3421


                                                                       

Epoch [3/2], Train Loss: 0.3193




In [44]:
torch.save({'tokenizer': dataset.index2token,
            'model': model.cpu()}, "gru_model_1.pt")


In [45]:
trained_model = torch.load('gru_model_1.pt')


In [46]:
def is_valid_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None


In [49]:

model, tokenizer = trained_model['model'], trained_model['tokenizer']
model = model.to(device)
model.eval()
res = model.sample(64)
valid_count = 0
with open("predictions_new_1.txt", "w") as f:
    while valid_count < 10001:
        print(smiles, is_valid_smiles(smiles))
        for i in range(res.size(0)):
            smiles = "".join([tokenizer.decode([index]).strip()
                            for index in res[i].tolist() if index not in __encoders__])
            if is_valid_smiles(smiles):
                f.write(smiles + os.linesep)
                valid_count += 1


AttributeError: 'SimplifiedSMILESGRU' object has no attribute 'device'