In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
from sklearn.model_selection import train_test_split

import os
import pickle

from model import SelfiesTransformer
from custom_selfies_dataset import CustomSELFIESDataset

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # for debugging

In [4]:
PICKLED_DATASET_DIR = "./dataset_pickles"

## BBBP
binary classification

In [5]:
DATASET = 'bbbp'

In [6]:
data_path = os.path.join(PICKLED_DATASET_DIR, DATASET+"_data.pickle")
label_path= os.path.join(PICKLED_DATASET_DIR, DATASET+"_label.pickle")
sym2idx_path = os.path.join(PICKLED_DATASET_DIR, "symbol2idx_"+DATASET+".pickle")

In [7]:
with open(data_path, "rb") as f:
    X = pickle.load(f)
with open(label_path, "rb") as f:
    y = pickle.load(f)
with open(sym2idx_path, "rb") as f:
    symbol2idx = pickle.load(f)

In [8]:
print(len(X), len(y))

1996 1996


In [9]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)

In [10]:
train_dataset = CustomSELFIESDataset(X_train, y_train)
val_dataset = CustomSELFIESDataset(X_val, y_val)
test_dataset = CustomSELFIESDataset(X_test, y_test)

train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=10, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=True)

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
len(X_train[0])

100

In [13]:
config = {
    "vocab_dict": symbol2idx,
    "max_length": len(X_train[0]),
    "dim": 32,
    "n_classes": 1, # binary classification
    "heads": 2,
    "mlp_dim": 16,
    "depth": 2,
    "dim_head": 32,
    "dropout": 0.1,
    "emb_dropout": 0.1
}

In [14]:
model = SelfiesTransformer(**config)

In [15]:
criterion = nn.BCEWithLogitsLoss().to(device)

In [16]:
lr = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [17]:
softmax = nn.Softmax(dim=-1)

In [18]:
model.to(device)
for epoch in range(30):
    train_loss = []
    model.train()
    train_correct = 0
    # temp = 0
    for i, data in enumerate(train_dataloader, 0):
        inputs, labels = data
        # temp += len(inputs)
        inputs = inputs.to(device)
        labels = labels.unsqueeze(-1).to(device)
        #print(inputs.get_device())
        #print(labels.get_device())
        optimizer.zero_grad()
        outputs = model(inputs)
        
        labels_pred = softmax(outputs).argmax(1)
        
        # print(labels)
        # print(outputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_correct += (labels == labels_pred.unsqueeze(-1)).float().sum() # ???
        train_loss += [loss.item()]
    
    with torch.no_grad():
        model.eval()
        val_loss = []
        val_correct = 0
        for _, val_data in enumerate(val_dataloader, 0):
            v_inputs, v_labels = val_data
            v_inputs = v_inputs.to(device)
            v_labels = v_labels.unsqueeze(-1).to(device)
            
            v_outputs = model(v_inputs)
            v_labels_pred = softmax(v_outputs).argmax(1)
            
            v_loss = criterion(v_outputs, v_labels)
            val_loss += [v_loss.item()]
            val_correct += (v_labels == v_labels_pred.unsqueeze(-1)).float().sum()
    
    accuracy_train = train_correct / len(X_train)
    accuracy_val = val_correct / len(X_val)
    print("epoch: %04d | train loss: %.5f | train accuracy: %.4f | valid loss: %.5f | valid accuracy: %.4f" %
         (epoch + 1, np.mean(train_loss), accuracy_train, np.mean(val_loss), accuracy_val))
print("Finished Training")

epoch: 0001 | train loss: 0.61484 | train accuracy: 0.2222 | valid loss: 0.59523 | valid accuracy: 0.2180
epoch: 0002 | train loss: 0.60101 | train accuracy: 0.2222 | valid loss: 0.58254 | valid accuracy: 0.2180
epoch: 0003 | train loss: 0.58884 | train accuracy: 0.2222 | valid loss: 0.57228 | valid accuracy: 0.2180
epoch: 0004 | train loss: 0.57640 | train accuracy: 0.2222 | valid loss: 0.56382 | valid accuracy: 0.2180
epoch: 0005 | train loss: 0.57784 | train accuracy: 0.2222 | valid loss: 0.55709 | valid accuracy: 0.2180
epoch: 0006 | train loss: 0.56692 | train accuracy: 0.2222 | valid loss: 0.55075 | valid accuracy: 0.2180
epoch: 0007 | train loss: 0.56085 | train accuracy: 0.2222 | valid loss: 0.54621 | valid accuracy: 0.2180
epoch: 0008 | train loss: 0.55795 | train accuracy: 0.2222 | valid loss: 0.54306 | valid accuracy: 0.2180
epoch: 0009 | train loss: 0.55449 | train accuracy: 0.2222 | valid loss: 0.53975 | valid accuracy: 0.2180
epoch: 0010 | train loss: 0.54923 | train accu

## Lipophilicity
regression

In [19]:
DATASET = 'lipo'

In [20]:
data_path = os.path.join(PICKLED_DATASET_DIR, DATASET+"_data.pickle")
label_path= os.path.join(PICKLED_DATASET_DIR, DATASET+"_label.pickle")
sym2idx_path = os.path.join(PICKLED_DATASET_DIR, "symbol2idx_"+DATASET+".pickle")

In [21]:
with open(data_path, "rb") as f:
    X = pickle.load(f)
with open(label_path, "rb") as f:
    y = pickle.load(f)
with open(sym2idx_path, "rb") as f:
    symbol2idx = pickle.load(f)

In [22]:
print(len(X), len(y))

4194 4194


In [23]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)

In [24]:
print(len(X_train), len(X_val), len(X_test))

2516 839 839


In [25]:
train_dataset = CustomSELFIESDataset(X_train, y_train)
val_dataset = CustomSELFIESDataset(X_val, y_val)
test_dataset = CustomSELFIESDataset(X_test, y_test)

train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=10, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=True)

In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [27]:
with open("symbol2idx_lipo.pickle", "rb") as f:
    symbol2idx_lipo = pickle.load(f)

In [28]:
config = {
    "vocab_dict": symbol2idx,
    "max_length": len(X_train[0]),
    "dim": 32,
    "n_classes": 1, # regression
    "heads": 2,
    "mlp_dim": 16,
    "depth": 2,
    "dim_head": 32,
    "dropout": 0.1,
    "emb_dropout": 0.1
}

In [29]:
model = SelfiesTransformer(**config)

In [30]:
criterion = nn.MSELoss()
lr = 0.001 # 0.0001
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [31]:
model.to(device)
for epoch in range(30):
    train_loss = []
    model.train()
    for i, data in enumerate(train_dataloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.unsqueeze(-1).to(device)
        #print(inputs.get_device())
        #print(labels.get_device())
        
        optimizer.zero_grad()
        outputs = model(inputs)
        
        # labels_pred = softmax(outputs).argmax(1)
        
        # print(labels)
        # print(outputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # train_correct += (labels == labels_pred.unsqueeze(-1)).float().sum() # ???
        train_loss += [loss.item()]
    
    with torch.no_grad():
        model.eval()
        val_loss = []
        for _, val_data in enumerate(val_dataloader, 0):
            v_inputs, v_labels = val_data
            v_inputs = v_inputs.to(device)
            v_labels = v_labels.unsqueeze(-1).to(device)
            v_outputs = model(v_inputs)
            v_loss = criterion(v_outputs, v_labels)
            val_loss += [v_loss.item()]
    # accuracy_train = train_correct / len(X_train)
    print("epoch: %04d | train loss: %.5f | valid loss: %.5f" %
         (epoch + 1, np.mean(train_loss), np.mean(val_loss)))
print("Finished Training")

epoch: 0001 | train loss: 1.62360 | valid loss: 1.38040
epoch: 0002 | train loss: 1.34807 | valid loss: 1.23110
epoch: 0003 | train loss: 1.24456 | valid loss: 1.41055
epoch: 0004 | train loss: 1.20306 | valid loss: 1.13739
epoch: 0005 | train loss: 1.18320 | valid loss: 1.15619
epoch: 0006 | train loss: 1.17312 | valid loss: 1.14489
epoch: 0007 | train loss: 1.16015 | valid loss: 1.15669
epoch: 0008 | train loss: 1.13301 | valid loss: 1.13610
epoch: 0009 | train loss: 1.10797 | valid loss: 1.13433
epoch: 0010 | train loss: 1.11741 | valid loss: 1.11912
epoch: 0011 | train loss: 1.09624 | valid loss: 1.12030
epoch: 0012 | train loss: 1.09728 | valid loss: 1.12407
epoch: 0013 | train loss: 1.06992 | valid loss: 1.16924
epoch: 0014 | train loss: 1.06500 | valid loss: 1.12765
epoch: 0015 | train loss: 1.05079 | valid loss: 1.09100
epoch: 0016 | train loss: 1.03690 | valid loss: 1.13246
epoch: 0017 | train loss: 1.07511 | valid loss: 1.06088
epoch: 0018 | train loss: 1.03054 | valid loss: 