In [4]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import time
import sys

#Set device
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    device = torch.device("cuda")
    cuda = True
else:
    device = torch.device("cpu")
    cuda = False
    
print("Device =",device)
gpus = [0]

Device = cuda


In [5]:
def time_elapsed(start_time):
    elapsed = time.time() - start_time
    hours = int(elapsed/3600)
    minutes = int(int(elapsed/60)%60)
    seconds = int(elapsed%60)
    
    return hours, minutes, seconds

In [6]:
#Load SMILES data as integer labels and as one-hot encoding
data = np.load("ohesmiles.npz")
data = data["arr_0"]

intdata = np.load("intsmiles.npz")
intdata = intdata["arr_0"]

data = torch.from_numpy(data).view(np.shape(data)[0], 1, np.shape(data)[1])
intdata = torch.from_numpy(intdata)

print("Dataset size: " + str(data.size()))
print("Integer dataset size: " + str(intdata.size()))

Dataset size: torch.Size([34131372, 1, 55])
Integer dataset size: torch.Size([34131372])


In [7]:
#Get input tensor
def inp(i, shuffle):
    
    #Input (does not include last character in SMILES)
    inp = data[int(shuffle[i] * seq_length) : int((shuffle[i] * seq_length) + seq_length - 1), :, :]

    return inp

In [8]:
#Get target tensor
def target(i, shuffle):
    
    #Target (does not include first character in SMILES)
    target = intdata[int((shuffle[i] * seq_length) + 1) : int((shuffle[i] * seq_length) + seq_length)]

    return target

In [9]:
#Define model
class Model(nn.Module):
    
    #Define model parameters
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(Model, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size, num_layers = num_layers, dropout = dropout)
        self.linear = nn.Linear(hidden_size, input_size)
        
        self.cuda()  
        
    #Define initial hidden and cell states
    def init_states(self, num_layers, hidden_size):
        hidden = [Variable(torch.zeros(num_layers, 1, hidden_size)),
                  Variable(torch.zeros(num_layers, 1, hidden_size))]
        
        return hidden
    
    #Define forward propagation
    def forward(self, inp, hidden):
        output, hidden = self.lstm(inp, hidden)
        output = self.linear(output)
        
        return output, hidden
    

In [10]:
ipython_vars = ['In', 'Out', 'exit', 'quit', 'get_ipython', 'ipython_vars']
print(sorted([(x, sys.getsizeof(globals().get(x))) for x in dir() if not x.startswith('_') and x not in sys.modules and x not in ipython_vars], key=lambda x: x[1], reverse=True))

[('Model', 1016), ('Variable', 1016), ('inp', 136), ('target', 136), ('time_elapsed', 136), ('F', 80), ('nn', 80), ('np', 80), ('data', 72), ('gpus', 72), ('intdata', 72), ('USE_CUDA', 28), ('cuda', 28), ('device', 24)]


In [11]:
#Set start time
start_time = time.time()

#Define training
def train(epochs):
    #Iterate over desired number of epochs 
    for e in range(epochs):
        
        #Get random order of SMILES molecules (shuffle data)
        shuffle = np.arange(np.shape(data)[0] / seq_length)
        random.shuffle(shuffle)
        
        #Iterate over each molecule in dataset
        for i in range(int(np.shape(data)[0] / seq_length)):
            
            #Initialize hidden and cell states
            hidden = model.init_states(num_layers, hidden_size)
            
            #Run on GPU if available
            if cuda:
                hidden = (hidden[0].cuda(), hidden[1].cuda())
        
            #Set initial gradients
            model.zero_grad()
    
            #Set initial loss
            loss = 0 
            
            #Get input and target
            input_data = inp(i, shuffle).float()
            target_data = target(i, shuffle).long()
            
            #Run on GPU if available
            if cuda:
                input_data = input_data.cuda()
                target_data = target_data.cuda()
                
            #Run model, calculate loss
            output, hidden = model(input_data, hidden)
            loss += criterion(output.squeeze(), target_data.squeeze())
                
            #Backpropagate loss
            loss.backward()
            optimizer.step()
            
            hours, minutes, seconds = time_elapsed(start_time)
            print("Loss: {:0.8f}".format(loss.data.item() / seq_length) + " | Epoch: " + str(e) + " | Iteration: " + str(i) + " | Time elapsed: " + str(hours) + " hours " + str(minutes) + " minutes " + str(seconds) + " seconds ")
        

In [12]:
#Initialize model parameters
input_size = np.shape(data)[2]
hidden_size = 256
num_layers = 3
dropout = .2
learning_rate = 0.001
epochs = 10
seq_length = 76

In [13]:
#Call model, set optimizer and loss function
model = Model(input_size, hidden_size, num_layers, dropout)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

#Run on GPU if available
if cuda:
    model.cuda()
    criterion.cuda()  

In [14]:
#Total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters in network: " + str(total_params))

Total number of parameters in network: 1387319


In [15]:
#Train
train(epochs)

Loss: 0.05259299 | Epoch: 0 | Iteration: 0 | Time elapsed: 0 hours 1 minutes 39 seconds 
Loss: 0.05197072 | Epoch: 0 | Iteration: 1 | Time elapsed: 0 hours 1 minutes 39 seconds 
Loss: 0.05177434 | Epoch: 0 | Iteration: 2 | Time elapsed: 0 hours 1 minutes 39 seconds 
Loss: 0.05026408 | Epoch: 0 | Iteration: 3 | Time elapsed: 0 hours 1 minutes 39 seconds 
Loss: 0.05029769 | Epoch: 0 | Iteration: 4 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.04959810 | Epoch: 0 | Iteration: 5 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.03839620 | Epoch: 0 | Iteration: 6 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.03233746 | Epoch: 0 | Iteration: 7 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.03872953 | Epoch: 0 | Iteration: 8 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.03565246 | Epoch: 0 | Iteration: 9 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.02443756 | Epoch: 0 | Iteration: 10 | Time elapsed: 0 hours 1 minutes 40 seconds 
Loss: 0.01795429 | E

Loss: 0.03744731 | Epoch: 0 | Iteration: 93 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.03976618 | Epoch: 0 | Iteration: 94 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.03449283 | Epoch: 0 | Iteration: 95 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.02458988 | Epoch: 0 | Iteration: 96 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.01826405 | Epoch: 0 | Iteration: 97 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.02370512 | Epoch: 0 | Iteration: 98 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.01269960 | Epoch: 0 | Iteration: 99 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.02200465 | Epoch: 0 | Iteration: 100 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.02016206 | Epoch: 0 | Iteration: 101 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.01920528 | Epoch: 0 | Iteration: 102 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 0.01367604 | Epoch: 0 | Iteration: 103 | Time elapsed: 0 hours 1 minutes 46 seconds 
Loss: 

Loss: 0.01447837 | Epoch: 0 | Iteration: 186 | Time elapsed: 0 hours 1 minutes 50 seconds 
Loss: 0.01850335 | Epoch: 0 | Iteration: 187 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.01599591 | Epoch: 0 | Iteration: 188 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.03261890 | Epoch: 0 | Iteration: 189 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.01970075 | Epoch: 0 | Iteration: 190 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.02637270 | Epoch: 0 | Iteration: 191 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.02279642 | Epoch: 0 | Iteration: 192 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.02195648 | Epoch: 0 | Iteration: 193 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.01823464 | Epoch: 0 | Iteration: 194 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.01868761 | Epoch: 0 | Iteration: 195 | Time elapsed: 0 hours 1 minutes 51 seconds 
Loss: 0.01034693 | Epoch: 0 | Iteration: 196 | Time elapsed: 0 hours 1 minutes 51 seconds 

Loss: 0.02571728 | Epoch: 0 | Iteration: 278 | Time elapsed: 0 hours 1 minutes 55 seconds 
Loss: 0.02090328 | Epoch: 0 | Iteration: 279 | Time elapsed: 0 hours 1 minutes 55 seconds 
Loss: 0.02099344 | Epoch: 0 | Iteration: 280 | Time elapsed: 0 hours 1 minutes 55 seconds 
Loss: 0.01976320 | Epoch: 0 | Iteration: 281 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.01993228 | Epoch: 0 | Iteration: 282 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.01928673 | Epoch: 0 | Iteration: 283 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.02041777 | Epoch: 0 | Iteration: 284 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.01545162 | Epoch: 0 | Iteration: 285 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.03080968 | Epoch: 0 | Iteration: 286 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.01918701 | Epoch: 0 | Iteration: 287 | Time elapsed: 0 hours 1 minutes 56 seconds 
Loss: 0.01692353 | Epoch: 0 | Iteration: 288 | Time elapsed: 0 hours 1 minutes 56 seconds 

Loss: 0.01810839 | Epoch: 0 | Iteration: 370 | Time elapsed: 0 hours 2 minutes 0 seconds 
Loss: 0.02641704 | Epoch: 0 | Iteration: 371 | Time elapsed: 0 hours 2 minutes 0 seconds 
Loss: 0.01790132 | Epoch: 0 | Iteration: 372 | Time elapsed: 0 hours 2 minutes 0 seconds 
Loss: 0.02654272 | Epoch: 0 | Iteration: 373 | Time elapsed: 0 hours 2 minutes 0 seconds 
Loss: 0.02954559 | Epoch: 0 | Iteration: 374 | Time elapsed: 0 hours 2 minutes 0 seconds 
Loss: 0.01086215 | Epoch: 0 | Iteration: 375 | Time elapsed: 0 hours 2 minutes 0 seconds 
Loss: 0.01924834 | Epoch: 0 | Iteration: 376 | Time elapsed: 0 hours 2 minutes 1 seconds 
Loss: 0.02031528 | Epoch: 0 | Iteration: 377 | Time elapsed: 0 hours 2 minutes 1 seconds 
Loss: 0.02597263 | Epoch: 0 | Iteration: 378 | Time elapsed: 0 hours 2 minutes 1 seconds 
Loss: 0.01835566 | Epoch: 0 | Iteration: 379 | Time elapsed: 0 hours 2 minutes 1 seconds 
Loss: 0.02157862 | Epoch: 0 | Iteration: 380 | Time elapsed: 0 hours 2 minutes 1 seconds 
Loss: 0.01

Loss: 0.02933039 | Epoch: 0 | Iteration: 462 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01958766 | Epoch: 0 | Iteration: 463 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01128637 | Epoch: 0 | Iteration: 464 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01879349 | Epoch: 0 | Iteration: 465 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.03057886 | Epoch: 0 | Iteration: 466 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.02209625 | Epoch: 0 | Iteration: 467 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01737340 | Epoch: 0 | Iteration: 468 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01476766 | Epoch: 0 | Iteration: 469 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01303931 | Epoch: 0 | Iteration: 470 | Time elapsed: 0 hours 2 minutes 5 seconds 
Loss: 0.01494474 | Epoch: 0 | Iteration: 471 | Time elapsed: 0 hours 2 minutes 6 seconds 
Loss: 0.02154333 | Epoch: 0 | Iteration: 472 | Time elapsed: 0 hours 2 minutes 6 seconds 
Loss: 0.01

Loss: 0.02572964 | Epoch: 0 | Iteration: 555 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.02675602 | Epoch: 0 | Iteration: 556 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.02592229 | Epoch: 0 | Iteration: 557 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.01210214 | Epoch: 0 | Iteration: 558 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.02180693 | Epoch: 0 | Iteration: 559 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.02675682 | Epoch: 0 | Iteration: 560 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.01915044 | Epoch: 0 | Iteration: 561 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.02147688 | Epoch: 0 | Iteration: 562 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.01213869 | Epoch: 0 | Iteration: 563 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.01131176 | Epoch: 0 | Iteration: 564 | Time elapsed: 0 hours 2 minutes 11 seconds 
Loss: 0.01389537 | Epoch: 0 | Iteration: 565 | Time elapsed: 0 hours 2 minutes 11 seconds 

Loss: 0.00819935 | Epoch: 0 | Iteration: 648 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01607983 | Epoch: 0 | Iteration: 649 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01608893 | Epoch: 0 | Iteration: 650 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01764079 | Epoch: 0 | Iteration: 651 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.02155855 | Epoch: 0 | Iteration: 652 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01258638 | Epoch: 0 | Iteration: 653 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.02225593 | Epoch: 0 | Iteration: 654 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01340318 | Epoch: 0 | Iteration: 655 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01337396 | Epoch: 0 | Iteration: 656 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01523514 | Epoch: 0 | Iteration: 657 | Time elapsed: 0 hours 2 minutes 17 seconds 
Loss: 0.01036998 | Epoch: 0 | Iteration: 658 | Time elapsed: 0 hours 2 minutes 17 seconds 

Loss: 0.01130079 | Epoch: 0 | Iteration: 739 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01289627 | Epoch: 0 | Iteration: 740 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01392669 | Epoch: 0 | Iteration: 741 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01402325 | Epoch: 0 | Iteration: 742 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01832409 | Epoch: 0 | Iteration: 743 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01321732 | Epoch: 0 | Iteration: 744 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01559056 | Epoch: 0 | Iteration: 745 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.02296677 | Epoch: 0 | Iteration: 746 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01627017 | Epoch: 0 | Iteration: 747 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01522683 | Epoch: 0 | Iteration: 748 | Time elapsed: 0 hours 2 minutes 22 seconds 
Loss: 0.01714658 | Epoch: 0 | Iteration: 749 | Time elapsed: 0 hours 2 minutes 22 seconds 

KeyboardInterrupt: 