In [53]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn 
import pickle
import re
import model.base
from model.base import Transformer
import utils 
from utils import *
import rdkit
from rdkit.Chem import rdDistGeom
import numpy as np
from torch.nn.utils import clip_grad_norm_
import os 
import datetime 
from tqdm import tqdm 

# --- Runtime setup ---------------------------------------------------------
# Silence every RDKit logger under the 'rdApp' namespace.
rdkit.rdBase.DisableLog('rdApp.*')
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Data ------------------------------------------------------------------
MAX_LEN = 30      # forwarded to MyDataset when the dataset is built
BATCH_SIZE = 64   # batch size shared by the train and validation loaders

# --- Optimisation ----------------------------------------------------------
EPOCHS, LR = 100, 0.0003   # upper bound on epochs; Adam learning rate
PATIENCE_THRESHOLD = 4     # non-improving epochs tolerated by early stopping

# --- Transformer architecture (positional args of Transformer(...)) --------
D_MODEL, D_FF = 256, 512
N_LAYERS, N_HEADS = 4, 8
DROPOUT = 0.5


In [54]:
# Fixed seed for the train/validation partition so that validation losses are
# comparable across kernel restarts (an unseeded split changes the held-out set
# on every run).
SPLIT_SEED = 42

# NOTE(review): MyDataset arrives through `from utils import *`; the '.pickle'
# path suggests it unpickles the file — only point this at trusted data.
dataset = MyDataset('data/chembl24_canon_train.pickle', MAX_LEN)

# 90/10 split, made reproducible with a dedicated generator.
train_set, val_set = random_split(
    dataset,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Reshuffle the training subset every epoch; without shuffle=True the model
# sees the exact same batches in the exact same order each epoch.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Validation order does not affect the summed loss, so it is left unshuffled.
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

# Vocabulary and length statistics exposed by the dataset object.
vocab, inv_vocab, max_coor_len, max_token_len = dataset.vocab, dataset.inv_vocab, dataset.max_coor_len, dataset.max_token_len

print(f'Number of data: {len(dataset)}')
print(f'Number of unique tokens: {len(vocab)}')
print(f'Maximum number of tokens: {max_token_len}')
print(f'Maximum number of coordinates: {max_coor_len}')

# Mean absolute error between the model output and the target tensor.
loss_fn = nn.L1Loss()
model = Transformer(D_MODEL, D_FF, N_HEADS, N_LAYERS, DROPOUT, vocab).to(device)
optim = torch.optim.Adam(model.parameters(), lr = LR, weight_decay=1e-6)

Error in embedding molecule and will be removed:  CCCCCCNC(=O)NC1=NC(=O)CN1C
Error in embedding molecule and will be removed:  Clc1ccc(C2CC3NCC32)cn1
Error in embedding molecule and will be removed:  N#CC12CCC1C1CCCCC12N1CCSCC1
Error in embedding molecule and will be removed:  CNC(=O)NC(=N)NCCN=[N+]=[N-]
Error in embedding molecule and will be removed:  COC(=O)NC1=NNC(=O)c2ccccc2N1
Error in embedding molecule and will be removed:  O=C1C(c2ccccc2)N2CCCCN1C2=O
Error in embedding molecule and will be removed:  Clc1ccc(OCC2NC3C=CC32)cn1
Error in embedding molecule and will be removed:  COC(=O)NC1=NCC(c2ccccc2)CN1C
Error in embedding molecule and will be removed:  Clc1ccc(C2CC3CNC32)cn1
Error in embedding molecule and will be removed:  CC(=O)NC(=N)NCCC(=O)O
Error in embedding molecule and will be removed:  Cc1cccnc1NC(=O)NC1=NC(=O)CN1C
Error in embedding molecule and will be removed:  c1cc2ccc1CCc1ccc(cc1)CC2
Number of data: 89959
Number of unique tokens: 64
Maximum number of tokens: 31
Max

In [55]:
def _batch_loss(batch):
    """Run one forward pass on `batch` and return the (un-detached) L1 loss.

    `batch` is the 5-item collection yielded by the loaders:
    (src, x, y, z, tgt). The coordinate inputs are fed with their last
    position dropped while the target has its first position dropped, i.e.
    the model is trained to predict the next step of the sequence.
    """
    src, x, y, z, tgt = (t.to(device) for t in batch)

    # True wherever the source token is not padding; extra dim added at -2.
    src_mask = (src != vocab['<PAD>']).unsqueeze(-2)
    # NOTE(review): get_mask comes from `utils`; presumably it builds the
    # decoder-side mask for the shifted inputs — confirm in utils.
    x_mask = get_mask(x.squeeze(-1)[:, :-1])
    y_mask = get_mask(y.squeeze(-1)[:, :-1])
    z_mask = get_mask(z.squeeze(-1)[:, :-1])

    pred = model(src, x[:, :-1], y[:, :-1], z[:, :-1], src_mask, x_mask, y_mask, z_mask)
    return loss_fn(pred, tgt[:, 1:, :])


best_loss = float('inf')  # lowest summed validation loss seen so far
patience = 0              # consecutive epochs without validation improvement
for epoch in range(EPOCHS) :
    train_loss, val_loss = 0, 0

    # ---- training pass ----
    model.train()
    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}') :
        optim.zero_grad()
        loss = _batch_loss(batch)
        train_loss += loss.item()
        loss.backward()
        # Clip BEFORE the optimizer step: the original clipped after
        # step() and zero_grad(), at which point it could not influence
        # the update at all.
        clip_grad_norm_(model.parameters(), 5)
        optim.step()

    # ---- validation pass (no gradients, dropout disabled) ----
    model.eval()
    with torch.no_grad() :
        for batch in val_loader :
            val_loss += _batch_loss(batch).item()

    # ---- early stopping on the summed validation loss ----
    if val_loss < best_loss :
        patience = 0
        best_loss = val_loss
    else :
        patience += 1
        if patience > PATIENCE_THRESHOLD :
            print(f'\n\nEarly Stopping at Epoch {epoch+1}')
            print(f'Best Loss: {best_loss / len(val_loader):.3f}')
            break

    # Per-batch averages; val loss now uses '.3f' (was the typo ':2f', which
    # means "minimum width 2" and printed six decimals).
    print(f'\nEpoch {epoch+1} | Train Loss: {train_loss / len(train_loader):.3f} | Val Loss: {val_loss / len(val_loader):.3f}\n')

Epoch 1:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊  | 1250/1266 [01:36<00:01, 12.94it/s]


KeyboardInterrupt: 