In [2]:
#from fcd import get_fcd, load_ref_model, canonical_smiles, get_predictions, calculate_frechet_distance
import warnings
import os
import pandas as pd
import numpy as np
import numpy
import pickle
import rdkit
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split


# Ignore some warnings from RDKIT and keras
from rdkit import RDLogger, Chem
from rdkit.Chem import rdMolDescriptors
RDLogger.DisableLog('rdApp.*')

warnings.filterwarnings("ignore")

# Load methods from the FCD library

np.random.seed(1234)

print("RDKit: ", rdkit.__version__)


RDKit:  2022.03.5


In [3]:
data = pd.read_csv("smiles_train.txt", header=None)[0]
data


0                COc1ccc(N2CCN(C(=O)c3cc4ccccc4[nH]3)CC2)cc1
1            c1ccc(CCCNC2CCCCN(CCOC(c3ccccc3)c3ccccc3)C2)cc1
2                             Nc1nc(O)c(Br)c(-c2cccc(O)c2)n1
3               CCc1nc2ccc(Br)cc2c(=O)n1-c1nc2c(C)cc(C)cc2s1
4            O=c1cnn(-c2ccc(S(=O)(=O)N3CCCCC3)cc2)c(=O)[nH]1
                                 ...                        
1036638                  CCOc1ccc(-n2c(SC)nc3c(c2=O)SCC3)cc1
1036639        Nc1ncnc2c1nc(I)n2C1SC(COC(=O)c2ccccc2)C(O)C1O
1036640              O=C(O)CCc1sc(C=C2NC(=O)CS2)nc1-c1ccccn1
1036641    CN(c1ncnc2[nH]ccc12)C1CC(CS(=O)(=O)N2CCC(C#N)C...
1036642    CCc1ccc(S(=O)(=O)NC2c3cc(C(=O)NCCc4c[nH]cn4)cc...
Name: 0, Length: 1036643, dtype: object

In [9]:
train_data, val_data = train_test_split(data, test_size=1/6, random_state=42)
train_data = train_data.reset_index(drop=True)
val_data = val_data.reset_index(drop=True)


In [5]:
charset = sorted(list(set(''.join(data))))
max_length = max([len(smile) for smile in data]) + 1


In [12]:
def vectorize_smiles(smiles, charset, max_length):
    indices = [charset.index(char) for char in smiles]
    padded_indices = indices + [0] * (max_length - len(indices))
    data = torch.tensor(padded_indices[:-1], dtype=torch.long)
    target = torch.tensor(padded_indices[1:], dtype=torch.long)
    return data, target


In [14]:
class SMILESDataset(Dataset):
    def __init__(self, smiles_data, charset, max_length):
        self.smiles_data = smiles_data
        self.charset = charset
        self.max_length = max_length

    def __len__(self):
        return len(self.smiles_data)

    def __getitem__(self, index):
        input_smiles = self.smiles_data[index % len(self.smiles_data)]
        data, target = vectorize_smiles(
            input_smiles, self.charset, self.max_length)
        one_hot_data = torch.zeros(data.size(0), len(self.charset))
        one_hot_data.scatter_(1, data.unsqueeze(1), 1)
        return one_hot_data, target


In [15]:
train_dataset = SMILESDataset(train_data, charset, max_length)
val_dataset = SMILESDataset(val_data, charset, max_length)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)


In [16]:
class SMILESLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(SMILESLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size,
                            num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0),
                         self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0),
                         self.hidden_size).to(device)

        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out)

        return out


In [17]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

input_size = len(charset)
hidden_size = 128
output_size = len(charset)
num_layers = 1
num_epochs = 50
learning_rate = 0.001


cuda


In [18]:
model = SMILESLSTM(input_size, hidden_size, output_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [19]:
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0
    train_count = 0
    for one_hot_data, target in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        one_hot_data, target = one_hot_data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(one_hot_data.float())

        loss = criterion(output.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_count += 1

    avg_train_loss = train_loss / train_count

    # Validation
    model.eval()
    val_loss = 0
    val_count = 0
    with torch.no_grad():
        for one_hot_data, target in tqdm(val_loader, desc=f"Validating Epoch {epoch + 1}/{num_epochs}", leave=False):
            one_hot_data, target = one_hot_data.to(device), target.to(device)
            output = model(one_hot_data.float())
            loss = criterion(output.transpose(1, 2), target)

            val_loss += loss.item()
            val_count += 1

    avg_val_loss = val_loss / val_count

    print(
        f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")


                                                                          

Epoch [1/50], Train Loss: 0.5234, Validation Loss: 0.3964


                                                                          

Epoch [2/50], Train Loss: 0.3730, Validation Loss: 0.3574


                                                                          

Epoch [3/50], Train Loss: 0.3490, Validation Loss: 0.3417


                                                                          

Epoch [4/50], Train Loss: 0.3378, Validation Loss: 0.3341


                                                                          

Epoch [5/50], Train Loss: 0.3312, Validation Loss: 0.3289


                                                                          

Epoch [6/50], Train Loss: 0.3267, Validation Loss: 0.3255


                                                                          

Epoch [7/50], Train Loss: 0.3235, Validation Loss: 0.3235


                                                                          

Epoch [8/50], Train Loss: 0.3210, Validation Loss: 0.3210


                                                                          

Epoch [9/50], Train Loss: 0.3190, Validation Loss: 0.3197


                                                                           

Epoch [10/50], Train Loss: 0.3175, Validation Loss: 0.3251


                                                                           

Epoch [11/50], Train Loss: 0.3162, Validation Loss: 0.3174


                                                                           

Epoch [12/50], Train Loss: 0.3150, Validation Loss: 0.3152


                                                                           

Epoch [13/50], Train Loss: 0.3141, Validation Loss: 0.3144


                                                                           

Epoch [14/50], Train Loss: 0.3134, Validation Loss: 0.3141


                                                                           

Epoch [15/50], Train Loss: 0.3129, Validation Loss: 0.3129


                                                                           

Epoch [16/50], Train Loss: 0.3121, Validation Loss: 0.3126


                                                                           

Epoch [17/50], Train Loss: 0.3113, Validation Loss: 0.3122


                                                                           

Epoch [18/50], Train Loss: 0.3108, Validation Loss: 0.3114


                                                                           

Epoch [19/50], Train Loss: 0.3104, Validation Loss: 0.3130


                                                                           

Epoch [20/50], Train Loss: 0.3099, Validation Loss: 0.3110


                                                                           

Epoch [21/50], Train Loss: 0.3095, Validation Loss: 0.3106


                                                                           

Epoch [22/50], Train Loss: 0.3091, Validation Loss: 0.3100


                                                                           

Epoch [23/50], Train Loss: 0.3088, Validation Loss: 0.3095


                                                                           

Epoch [24/50], Train Loss: 0.3087, Validation Loss: 0.3095


                                                                           

Epoch [25/50], Train Loss: 0.3081, Validation Loss: 0.3088


                                                                           

Epoch [26/50], Train Loss: 0.3078, Validation Loss: 0.3100


                                                                           

Epoch [27/50], Train Loss: 0.3076, Validation Loss: 0.3087


                                                                           

Epoch [28/50], Train Loss: 0.3073, Validation Loss: 0.3082


                                                                           

Epoch [29/50], Train Loss: 0.3071, Validation Loss: 0.3081


                                                                           

Epoch [30/50], Train Loss: 0.3069, Validation Loss: 0.3079


                                                                           

Epoch [31/50], Train Loss: 0.3066, Validation Loss: 0.3080


                                                                           

Epoch [32/50], Train Loss: 0.3072, Validation Loss: 0.3074


                                                                           

Epoch [33/50], Train Loss: 0.3075, Validation Loss: 0.3074


                                                                           

Epoch [34/50], Train Loss: 0.3062, Validation Loss: 0.3069


                                                                           

Epoch [35/50], Train Loss: 0.3060, Validation Loss: 0.3072


                                                                           

Epoch [36/50], Train Loss: 0.3057, Validation Loss: 0.3070


                                                                           

Epoch [37/50], Train Loss: 0.3058, Validation Loss: 0.3069


                                                                           

Epoch [38/50], Train Loss: 0.3055, Validation Loss: 0.3067


                                                                           

Epoch [39/50], Train Loss: 0.3053, Validation Loss: 0.3068


                                                                           

Epoch [40/50], Train Loss: 0.3051, Validation Loss: 0.3064


                                                                           

Epoch [41/50], Train Loss: 0.3051, Validation Loss: 0.3066


                                                                           

Epoch [42/50], Train Loss: 0.3059, Validation Loss: 0.3076


                                                                           

Epoch [43/50], Train Loss: 0.3047, Validation Loss: 0.3065


                                                                           

Epoch [44/50], Train Loss: 0.3049, Validation Loss: 0.3060


                                                                           

Epoch [45/50], Train Loss: 0.3047, Validation Loss: 0.3061


                                                                           

Epoch [46/50], Train Loss: 0.3046, Validation Loss: 0.3058


                                                                           

Epoch [47/50], Train Loss: 0.3044, Validation Loss: 0.3058


                                                                           

Epoch [48/50], Train Loss: 0.3043, Validation Loss: 0.3055


                                                                           

Epoch [49/50], Train Loss: 0.3041, Validation Loss: 0.3056


                                                                           

Epoch [50/50], Train Loss: 0.3044, Validation Loss: 0.3083




In [20]:
torch.save(model.state_dict(), "trained_model_2.pth")


In [21]:
def unvectorize_smiles(vector, charset):
    smiles = ""
    for index in vector:
        if index == len(charset) - 1:
        smiles += charset[index]
    return smiles


In [38]:

def predict(model, input_smiles, charset, max_length):
    model.eval()
    with torch.no_grad():
        data = vectorize_smiles(input_smiles, charset, max_length)
        one_hot_data = torch.zeros(1, data[0].size(0), len(charset)).to(device)
        one_hot_data.scatter_(2, data[0].to(
            device).unsqueeze(0).unsqueeze(2), 1)
        output = model(one_hot_data.float())
        pred_indices = output.argmax(dim=2).squeeze(0).cpu().tolist()
        predicted_smiles = unvectorize_smiles(pred_indices, charset)

        # Remove "#" characters
        cleaned_smiles = predicted_smiles.replace("#", "")

        # Check for "nan" or "inf" in the SMILES string
        if "nan" not in cleaned_smiles.lower() and "inf" not in cleaned_smiles.lower():
            return cleaned_smiles
        else:
            return None


In [40]:
def is_valid_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None


In [41]:
with open("predictions_2.txt", "w") as f:
    for input_smiles in val_data:
        predicted_smiles = predict(model, input_smiles, charset, max_length)
        if predicted_smiles is not None and is_valid_smiles(predicted_smiles):
            f.write(predicted_smiles + "\n")
