In [102]:
import pandas as pd
from torch import LongTensor
import torch.nn as nn
from torch.utils.data import Dataset
from typing import Iterable, List, Dict, Tuple
import torch
import numpy as np
from tqdm import tqdm
import io

In [80]:
class SmilesRNNDataSet(Dataset):

    START:str = "G"
    END:str = "E"
    PADDING:str = "A"

    def __init__(self, smiles: Iterable[str])->None:
        self.smiles: Iterable[str] = smiles
        self.max_len: int = self._max_len
        self.alphabet = self._alphabet
        self.inv_alphabet = self._inv_alphabet

    def __len__(self)->int:
        return len(self.smiles)

    @property
    def _max_len(self)->int:
        return max([len(smile) for smile in self.smiles])

    @property
    def _alphabet(self)-> Dict[str, int]:
        alphabet =  list(set.union(*[set(smile)for smile in self.smiles]))+[self.START, self.END, self.PADDING]
        return {element:value for value, element in enumerate(alphabet)}

    @property
    def _inv_alphabet(self):
        return {value:key for key, value in self._alphabet.items()}

    def __getitem__(self, index):
        padding_len:int = self.max_len - len(self.smiles[index])
        preprocessed:str = self.START + self.smiles[index] + self.END + self.PADDING*padding_len
        return LongTensor([self.alphabet[letter] for letter in preprocessed][:-1]), LongTensor([self.alphabet[letter] for letter in preprocessed][1:])

In [81]:
class RNN(nn.Module):
    def __init__(self, input_size, embed_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.embed = nn.Embedding(input_size, embed_size)

        self.lstm = nn.LSTMCell(embed_size, hidden_size)

        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, character, hidden, cell):
        embedded = self.embed(character)
        hidden, cell = self.lstm(embedded, (hidden, cell))
        out = self.fc(hidden)
        return out, hidden, cell

    def init_zero_state(self):
        hidden = torch.zeros(1, self.hidden_size)
        cell = torch.zeros(1, self.hidden_size)
        return hidden, cell

In [106]:
def get_random_data(dataset: Dataset)->Tuple[torch.Tensor]:
    return dataset[np.random.randint(0, len(dataset))]

def evaluate(model, temperature):
    hidden, cell = model.init_zero_state()

    inp = torch.tensor([dataset._alphabet["G"]])
    predicted = "G"

    for _ in range(len(dataset[0][0])):
        out, hidden, cell = model(inp, hidden, cell)
        output_dist = out.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]

        predicted_char = dataset._inv_alphabet[top_i.item()]
        predicted += predicted_char
        inp = torch.tensor([dataset._alphabet[predicted_char]])
    
    return predicted


In [121]:
def trim_smiles(smile):
    trimmed_smile = ""
    for char in smile[1:]:
        if char == "E":
            return trimmed_smile
        trimmed_smile+=char    

In [110]:
data = pd.read_csv("../data/androgen_data.csv")
smiles = data["canonical_smiles"].to_list()
dataset = SmilesRNNDataSet(smiles)

ALPHABET_SIZE = len(dataset._alphabet)
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
LEARNING_RATE = 0.005


model = RNN(ALPHABET_SIZE, EMBEDDING_DIM, HIDDEN_DIM, ALPHABET_SIZE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [123]:
for iteration in range(5000):
    hidden, cell = model.init_zero_state()
    optimizer.zero_grad()

    loss = 0
    losses = []

    inputs, targets = get_random_data(dataset)

    for char, target in zip(inputs, targets):
        out, hidden, cell = model(char.unsqueeze(0), hidden, cell)
        loss+= nn.functional.cross_entropy(out, target.view(1))
    
    loss = loss/len(inputs)
    losses.append(loss)
    loss.backward()
    
    optimizer.step()

    with torch.no_grad():
        
        if iteration % 500 == 0:
            print(f"Iteration: {iteration}| Loss: {loss}")
            print(trim_smiles(evaluate(model, 0.8)))
            print("\n")

Iteration: 0| Loss: 0.07067514955997467
[C-]#[N+]c1ccc(N2C(F)(F)F)ccc1O


Iteration: 500| Loss: 0.16816778481006622
Cc1nn(-c2ccccc2Cl)cc1Cl


Iteration: 1000| Loss: 0.15978243947029114
C[C@]12CCC3C(CC=C4C[C@@H]4CC(=O)CC[C@@]43CC[C@@H](c4ccc3)C[C@@]21C


Iteration: 1500| Loss: 0.36353158950805664
C[C@](O)(C#C)[C@H]1CC(C)CCCCCCCC2=Cc3c(-c4ccc(F)cc4)N(C)C)ccc3c3)c2c1-c1ccc(F)c2c1[C@@H](c1ccccc1)C(=O)Nc1ccc(C(=O)O)c(C)c1


Iteration: 2000| Loss: 0.09330683201551437
None


Iteration: 2500| Loss: 0.18971037864685059
CNC(=O)c1cccc(C(=O)O)cc1C(F)(F)F


Iteration: 3000| Loss: 0.1421380490064621
COC(=O)c1c(O)cccc1O


Iteration: 3500| Loss: 0.3314056694507599
C[C@](O)(C(NC(=O)[C@H]1CC[C@H]1[C@H]2CC[C@]2(C)C(CC[C@@]21C)[C@@]1(C)CCC1=CC[C@@]21C=CCO1


Iteration: 4000| Loss: 0.11134600639343262
Nc1ccccc1-c1ccc2[nH]c3c(c21)C(=O)c1cc2ccccc2Cl)C(=O)C1CCCC2)c1ccc(Cl)c(Cl)c1


Iteration: 4500| Loss: 0.0832793191075325
Cc1cc(N2CCC3CC2CC2)cc1


