In [36]:
#Get SMILES data
data = open("SMILES.txt").read()

#List of all characters in dataset
chars = list(set(data))

#String of all characters in dataset
all_characters = "".join(str(x) for x in chars)

#Number of unique characters in dataset
n_characters = len(all_characters)

#Size of dataset, in number of characters
data_len = len(data)

#Size of dataset, in number of molecules
data_size = len(open("SMILES.txt").readlines())

In [37]:
import random

#Gets random molecule from dataset
def random_mol():
    start_index = int(random.randint(0, data_size) * (data_len / data_size))
    end_index = int(start_index + (data_len / data_size))
    
    return data[start_index:end_index]

In [92]:
import torch
import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)
    
    def forward(self, input, hidden):
        input = self.encoder(input.view(1, -1))
        output, hidden = self.lstm(input.view(1, 1, -1), hidden)
        output = self.decoder(output.view(1, -1))
        return output, hidden

    def init_hidden(self):
        return Variable(torch.zeros(self.n_layers, 1, self.hidden_size))

In [93]:
#Converts molecule in SMILES format to tensor 
def mol_tensor(mol):
    tensor = torch.zeros(len(mol)).long()

    for i in range(len(mol)):
        tensor[i] = all_characters.index(mol[i])
        
    return Variable(tensor)

In [94]:
def random_training_set():    
    mol = random_mol()
    inp = mol_tensor(mol[:-1])
    target = mol_tensor(mol[1:])
    return inp, target

In [95]:
def evaluate(prime_str="{", predict_len=(data_len / data_size), temperature=1.0):
    hidden = decoder.init_hidden()
    prime_input = mol_tensor(prime_str)
    predicted = prime_str

    # Use priming string to "build up" hidden state
    for p in range(len(prime_str) - 1):
        _, hidden = decoder(prime_input[p], hidden)
    inp = prime_input[-1]
    
    for p in range(predict_len):
        output, hidden = decoder(inp, hidden)
        
        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]
        
        # Add predicted character to string and use as next input
        predicted_char = all_characters[top_i]
        predicted += predicted_char
        inp = mol_tensor(predicted_char)

    return predicted

In [96]:
import time, math

def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [97]:
def train(inp, target):
    hidden = decoder.init_hidden()
    decoder.zero_grad()
    loss = 0

    for i in range(int((data_len / data_size))):
        output, hidden = decoder(inp[i], hidden)
        loss += criterion(output, target[i])

    loss.backward()
    decoder_optimizer.step()

    return loss.data[0] / (data_len / data_size)

In [98]:
n_epochs = 100
print_every = 10
plot_every = 10
hidden_size = 1024
n_layers = 3
lr = 0.001

decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

start = time.time()
all_losses = []
loss_avg = 0

for epoch in range(1, n_epochs + 1):
    loss = train(*random_training_set())       
    loss_avg += loss

    if epoch % print_every == 0:
        print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / n_epochs * 100, loss))
        print(evaluate('Wh', 100), '\n')

    if epoch % plot_every == 0:
        all_losses.append(loss_avg / plot_every)
        loss_avg = 0

RuntimeError: Expected hidden[0] size (3, 1, 1024), got (1, 1024)