In [1]:
#from fcd import get_fcd, load_ref_model, canonical_smiles, get_predictions, calculate_frechet_distance
import warnings
import os
import pandas as pd
import numpy as np
import numpy
import pickle
import rdkit
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split


# Ignore some warnings from RDKIT and keras
from rdkit import RDLogger, Chem
from rdkit.Chem import rdMolDescriptors
RDLogger.DisableLog('rdApp.*')

warnings.filterwarnings("ignore")

# Load methods from the FCD library

np.random.seed(1234)

print("RDKit: ", rdkit.__version__)


RDKit:  2022.09.5


In [3]:
data = pd.read_csv("smiles_train.txt", header=None)[0]


In [3]:
def to_canonical(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Chem.MolToSmiles(mol, canonical=True)
    else:
        return None


In [None]:
data_canonical = data.apply(to_canonical).dropna().reset_index(drop=True)


In [22]:
data_canonical.to_csv("canonical_smiles_train.txt", index=False, header=False)


In [4]:
data_canonical = pd.read_csv(
    "canonical_smiles_train.txt", header=None, squeeze=True)


In [4]:
data_canonical

0                COc1ccc(N2CCN(C(=O)c3cc4ccccc4[nH]3)CC2)cc1
1            c1ccc(CCCNC2CCCCN(CCOC(c3ccccc3)c3ccccc3)C2)cc1
2                             Nc1nc(O)c(Br)c(-c2cccc(O)c2)n1
3               CCc1nc2ccc(Br)cc2c(=O)n1-c1nc2c(C)cc(C)cc2s1
4            O=c1cnn(-c2ccc(S(=O)(=O)N3CCCCC3)cc2)c(=O)[nH]1
                                 ...                        
1036638                  CCOc1ccc(-n2c(SC)nc3c(c2=O)SCC3)cc1
1036639        Nc1ncnc2c1nc(I)n2C1SC(COC(=O)c2ccccc2)C(O)C1O
1036640              O=C(O)CCc1sc(C=C2NC(=O)CS2)nc1-c1ccccn1
1036641    CN(c1ncnc2[nH]ccc12)C1CC(CS(=O)(=O)N2CCC(C#N)C...
1036642    CCc1ccc(S(=O)(=O)NC2c3cc(C(=O)NCCc4c[nH]cn4)cc...
Name: 0, Length: 1036643, dtype: object

In [5]:
#train_data, val_data = train_test_split(data, test_size=1/6, random_state=42)
train_data, val_data = train_test_split(
    data, test_size=0.2, random_state=42)
train_data = train_data.reset_index(drop=True)
val_data = val_data.reset_index(drop=True)


In [6]:
charset = sorted(list(set(''.join(data) + '^$')))
pad_char = '_'
charset.append(pad_char)
charset = sorted(list(set(charset)))

max_length = max([len(smile) for smile in data]) + 1


In [7]:
# def vectorize_smiles(smiles, charset, max_length):
#     indices = [charset.index(char) for char in smiles]
#     padded_indices = indices + [0] * (max_length - len(indices))
#     data = torch.tensor(padded_indices[:-1], dtype=torch.long)
#     target = torch.tensor(padded_indices[1:], dtype=torch.long)
#     return data, target

def vectorize_smiles(smiles, charset, max_length):
    smiles = '^' + smiles + '$'
    indices = [charset.index(char) for char in smiles]
    padded_indices = indices + [0] * (max_length - len(indices))
    data = torch.tensor(padded_indices[:-1], dtype=torch.long)
    target = torch.tensor(padded_indices[1:], dtype=torch.long)
    return data, target


In [8]:
class SMILESData(Dataset):
    def __init__(self, smiles_list, charset, max_length, pad_char='_'):
        self.smiles_list = smiles_list
        self.max_length = max_length
        self.pad_char = pad_char

        # Use the provided charset
        self.charset = charset
        self.char_to_int = {char: i for i, char in enumerate(self.charset)}
        self.int_to_char = {i: char for i, char in enumerate(self.charset)}

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        one_hot = self.smiles_to_one_hot(smiles)
        target = torch.tensor([self.char_to_int[char] for char in smiles] + [
                              self.char_to_int[self.pad_char]] * (self.max_length - len(smiles)), dtype=torch.long)


        return one_hot, target

    def smiles_to_one_hot(self, smiles):
        one_hot = np.zeros((self.max_length, len(self.charset)), dtype=np.float32)
        for i, char in enumerate(smiles):
            one_hot[i, self.char_to_int[char]] = 1.0
        return torch.tensor(one_hot)


In [9]:
train_dataset = SMILESData(train_data, charset, max_length, pad_char=pad_char)
val_dataset = SMILESData(val_data, charset, max_length, pad_char=pad_char)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


In [10]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

input_size = len(charset)
hidden_size = 512
output_size = len(charset)
num_layers = 3
num_epochs = 5
learning_rate = 0.001


cuda


In [11]:
class SimplifiedSMILESGRU(nn.Module):
    def __init__(self, input_size=input_size, hidden_size=512, output_size=output_size, num_layers=3, dropout=0.2):
        super(SimplifiedSMILESGRU, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.gru = nn.GRU(input_size, hidden_size, num_layers,
                          batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0),
                         self.hidden_size).to(device)
        out, _ = self.gru(x, h0)
        out = self.dropout(out)
        out = self.fc(out)
        return out


In [12]:
model = SimplifiedSMILESGRU(input_size, hidden_size,
                            output_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-4)

scheduler = ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2, verbose=True)


In [106]:
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0
    train_count = 0
    for one_hot_data, target in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        one_hot_data, target = one_hot_data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(one_hot_data.float())

        loss = criterion(output.transpose(1, 2), target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        train_loss += loss.item()
        train_count += 1

    avg_train_loss = train_loss / train_count

    # Validation
    model.eval()
    val_loss = 0
    val_count = 0
    with torch.no_grad():
        for one_hot_data, target in tqdm(val_loader, desc=f"Validating Epoch {epoch + 1}/{num_epochs}", leave=False):
            one_hot_data, target = one_hot_data.to(device), target.to(device)
            output = model(one_hot_data.float())
            loss = criterion(output.transpose(1, 2), target)

            val_loss += loss.item()
            val_count += 1

    avg_val_loss = val_loss / val_count

    print(
        f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
    
    scheduler.step(avg_val_loss)


                                                                       

Epoch [1/5], Train Loss: 0.0547, Validation Loss: 0.0014


                                                                       

Epoch [2/5], Train Loss: 0.0164, Validation Loss: 0.0012


                                                                       

Epoch [3/5], Train Loss: 0.0178, Validation Loss: 0.0012


                                                                     

KeyboardInterrupt: 

In [107]:
torch.save(model.state_dict(), "trained_model_lstm_8.pth")


In [13]:
model.load_state_dict(torch.load("trained_model_lstm_8.pth"))
model.to(device)


SimplifiedSMILESGRU(
  (gru): GRU(40, 512, num_layers=3, batch_first=True, dropout=0.2)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=512, out_features=40, bias=True)
)

In [14]:
def vectorize_smiles(smiles, charset, max_length):
    vector = [charset.index(char) for char in smiles]
    padding = [len(charset) - 1] * (max_length - len(vector))
    vector.extend(padding)
    return torch.tensor(vector, dtype=torch.long)


In [15]:
def unvectorize_smiles(one_hot_array, charset):
    smiles = ''
    for i in one_hot_array:
        if charset[i] == 's':
            break  # Stop decoding when the first padding character is encountered
        smiles += charset[i]
    return smiles


In [16]:
def is_valid_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None


In [17]:
def predict(model, input_smiles, charset, max_length):
    model.eval()
    with torch.no_grad():
        data = vectorize_smiles(input_smiles, charset, max_length)
        one_hot_data = torch.zeros(1, data.size(0), len(charset)).to(device)
        one_hot_data.scatter_(2, data.to(device).unsqueeze(0).unsqueeze(2), 1)
        output = model(one_hot_data.float())
        output_smiles = unvectorize_smiles(
            output[0].argmax(dim=1).tolist(), charset)
        return output_smiles


In [108]:
valid_count = 0
smiles_gen = set()
with open("predictions_lstm_8.txt", "w") as f:
    for input_smiles in val_data:
        predicted_smiles = predict(model, input_smiles, charset, max_length)
        #print(f"{predicted_smiles}")
        if predicted_smiles is not None and is_valid_smiles(predicted_smiles):
            if predicted_smiles not in data_canonical:
                smiles_gen.add(predicted_smiles)
                f.write(predicted_smiles + "\n")
                valid_count += 1
            if valid_count >= 10001:
                break


In [21]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs


In [24]:
def get_fingerprint(smiles):
    mol = Chem.MolFromSmiles(smiles)
    fp = AllChem.GetMorganFingerprint(mol, 2)
    return fp


threshold = 0.8

valid_count = 0
smiles_gen = set()
with open("predictions_lstm_test.txt", "w") as f:
    for input_smiles in val_data:
        predicted_smiles = predict(model, input_smiles, charset, max_length)
        # print(f"{predicted_smiles}")
        if predicted_smiles is not None and is_valid_smiles(predicted_smiles):
            if predicted_smiles not in data_canonical:
                # Compute the fingerprint of the predicted SMILES
                fp_pred = get_fingerprint(predicted_smiles)
                # Check if the predicted SMILES is similar to any of the previously generated SMILES
                is_similar = False
                for smiles in smiles_gen:
                    fp = get_fingerprint(smiles)
                    similarity = DataStructs.TanimotoSimilarity(fp, fp_pred)
                    if similarity > threshold:
                        is_similar = True
                        break
                if not is_similar:
                    smiles_gen.add(predicted_smiles)
                    f.write(predicted_smiles + "\n")
                    valid_count += 1
                if valid_count >= 10001:
                    break
