In [12]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn 
import pickle
import re
import model.base
from model.base import Transformer
import utils 
from utils import *
import rdkit
from rdkit.Chem import rdDistGeom
import numpy as np
from torch.nn.utils import clip_grad_norm_
import os 
import datetime 
from tqdm import tqdm 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rdkit.rdBase.DisableLog('rdApp.*') # Disable rdkit warnings

In [7]:
dataset = MyDataset('data/MOSES_SMILES.txt', 30)
train_set, val_set = random_split(dataset, [0.9, 0.1])
train_loader = DataLoader(train_set, batch_size=128)
val_loader = DataLoader(val_set, batch_size=128)
vocab, inv_vocab, max_coor_len, max_token_len = dataset.vocab, dataset.inv_vocab, dataset.max_coor_len, dataset.max_token_len

print(f'Number of data: {len(dataset)}')
print(f'Number of unique tokens: {len(vocab)}')
print(f'Maximum number of tokens: {max_token_len}')
print(f'Maximum number of coordinates: {max_coor_len}')

Number of data: 88192
Number of unique tokens: 24
Maximum number of tokens: 30
Maximum number of coordinates: 23


In [10]:
model = Transformer(256, 512, 8, 2, 0.5, vocab).to(device)
loss_fn = nn.L1Loss()
optim = torch.optim.Adam(model.parameters(), lr = 0.0003)

In [11]:
for i in range(20) : 
    train_loss, val_loss = 0, 0
    model.train()
    for src, x, y, z, tgt in tqdm(train_loader, desc=f'Epoch {i+1}') :
    # for src, x, y, z, tgt in train_loader : 
        src, x, y, z, tgt = src.to(device), x.to(device), y.to(device), z.to(device), tgt.to(device)
        src_mask = (src != vocab['<PAD>']).unsqueeze(-2) 
        x_mask = get_mask(x.squeeze(-1)[:, :-1])
        y_mask = get_mask(y.squeeze(-1)[:, :-1])
        z_mask = get_mask(z.squeeze(-1)[:, :-1])

        
        out = model(src, x[:, :-1], y[:, :-1], z[:, :-1], src_mask, x_mask, y_mask, z_mask)
        loss = loss_fn(out, tgt[:, 1:, :])
        train_loss += loss.item() 
        loss.backward(), optim.step(), optim.zero_grad(), clip_grad_norm_(model.parameters(), 5)

    model.eval()
    for src, x, y, z, tgt in val_loader : 
        src, x, y, z, tgt = src.to(device), x.to(device), y.to(device), z.to(device), tgt.to(device)
        src_mask = (src != vocab['<PAD>']).unsqueeze(-2) 
        x_mask = get_mask(x.squeeze(-1)[:, :-1])
        y_mask = get_mask(y.squeeze(-1)[:, :-1])
        z_mask = get_mask(z.squeeze(-1)[:, :-1])

        
        out = model(src, x[:, :-1], y[:, :-1], z[:, :-1], src_mask, x_mask, y_mask, z_mask)
        loss = loss_fn(out, tgt[:, 1:, :])
        val_loss += loss.item() 
    print(f'Epoch {i+1} - Train Loss: {train_loss / len(train_loader):.3f} - Val Loss: {val_loss / len(val_loader):.3f}')

Epoch 1 - Train Loss: 0.800 - Val Loss: 0.678
Epoch 2 - Train Loss: 0.685 - Val Loss: 0.656
Epoch 3 - Train Loss: 0.662 - Val Loss: 0.638
Epoch 4 - Train Loss: 0.645 - Val Loss: 0.649
Epoch 5 - Train Loss: 0.633 - Val Loss: 0.634
Epoch 6 - Train Loss: 0.623 - Val Loss: 0.630
Epoch 7 - Train Loss: 0.617 - Val Loss: 0.631


KeyboardInterrupt: 