In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import copy 
import math 
from utils import get_smi_list, replace_atom, get_dic, encode_smi, pad_smi, clones, parallel_f, pad, normalize, get_atom_pos, MyDataset, subsequent_mask
from model import Encoder, Decoder, device
from tqdm.auto import tqdm


In [2]:
# --- Load the raw SMILES strings -------------------------------------------
smi_list = get_smi_list('data/ADAGRASIB_SMILES.txt')

# --- Coordinates ------------------------------------------------------------
# Atom positions are computed from the SMILES exactly as loaded, i.e. before
# replace_atom is applied further down.
coor_list = parallel_f(get_atom_pos, smi_list)
longest_coor = max(map(len, coor_list))
# Normalise every entry, then pad it to the length of the longest one so all
# entries share a common length.
coor_list = [pad(normalize(coords), longest_coor) for coords in coor_list]

# --- SMILES sequences -------------------------------------------------------
# The dictionary is built from the strings *after* replace_atom, and the same
# dictionary is used for both encoding and padding.
smi_list = list(map(replace_atom, smi_list))
smi_dic = get_dic(smi_list)
smint_list = [encode_smi(smiles, smi_dic) for smiles in smi_list]
longest_smint = max(map(len, smint_list))
smint_list = [pad_smi(encoded, longest_smint, smi_dic) for encoded in smint_list]

[17:58:03] UFFTYPER: Unrecognized atom type: Ba (0)


In [7]:
BATCH_SIZE = 1
SPLIT_SEED = 0  # fixed so the train/val/test membership survives a kernel restart

dataset = MyDataset(smint_list, coor_list)

# 90 % / 5 % / 5 % split. Without an explicit generator the split is drawn from
# the global RNG, so every restart would train and validate on different
# molecules and the logged losses would not be comparable between runs.
train_set, val_set, test_set = random_split(
    dataset,
    [0.9, 0.05, 0.05],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training data is reshuffled each epoch; the evaluation loaders keep
# a fixed order so that repeated evaluations iterate the samples identically.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Sanity check: show the shape and values of the target tensor of one training
# batch. The first element of the batch is not needed here, so it is bound to
# `_` instead of `input`, which would shadow the built-in of the same name.
for _, target in train_loader:
    print(f'target: {target.shape} {target}')
    break  # a single batch is enough

target: torch.Size([1, 22, 3]) tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.9806, -0.4204,  0.4636],
         [ 2.1963, -0.9230,  1.0381],
         [ 3.4562, -0.3958,  0.7856],
         [ 4.2967, -1.1415,  1.5060],
         [ 3.6827, -2.0901,  2.1890],
         [ 2.3903, -1.9700,  1.9128],
         [ 1.3577, -2.7630,  2.4119],
         [ 1.4367, -4.0102,  2.6923],
         [ 0.4177, -4.8152,  3.1904],
         [ 0.6199, -5.5436,  4.4101],
         [-0.8342, -4.9258,  2.4995],
         [ 4.3631, -3.0116,  3.0299],
         [ 5.3108, -3.8997,  2.5681],
         [ 5.9518, -4.7824,  3.4141],
         [ 5.6229, -4.7625,  4.7586],
         [ 4.7032, -3.8921,  5.1723],
         [ 4.0970, -3.0520,  4.3501],
         [ 6.3576, -5.8246,  5.9342],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]])


In [4]:
# Hyper-parameters shared by the encoder and the decoder.
DIM_MODEL = 256
NUM_HEAD = 8
NUM_LAYER = 2
DROPOUT = 0.5

# Both networks receive the same four leading constructor arguments and differ
# only in the last one: the number of entries in smi_dic for the encoder, the
# padded coordinate length for the decoder.
_shared_args = (DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT)
encoder = Encoder(*_shared_args, len(smi_dic)).to(device)
decoder = Decoder(*_shared_args, longest_coor).to(device)

# Mean absolute error between predictions and targets.
loss_fn = nn.L1Loss()

# One independent Adam optimiser per network, both with the same learning rate.
e_optim, d_optim = (
    torch.optim.Adam(net.parameters(), lr=1e-3) for net in (encoder, decoder)
)

In [5]:
NUM_EPOCHS = 1000

# One training pass followed by one validation pass per epoch. The reported
# numbers are the per-batch losses averaged over the respective loader.
for epoch in tqdm(range(1, NUM_EPOCHS + 1), total=NUM_EPOCHS):
    train_loss = 0
    val_loss = 0

    # ---- training pass ----
    encoder.train()
    decoder.train()
    for src, target in train_loader:
        src, target = src.to(device), target.to(device)

        # Clear gradients *before* computing new ones. If the cell is
        # interrupted between backward() and a trailing zero_grad() (e.g. by a
        # KeyboardInterrupt), the parameters keep their .grad buffers, and a
        # re-run would silently add them to the first batch's gradients.
        e_optim.zero_grad()
        d_optim.zero_grad()

        memory = encoder(src)
        # The decoder is fed the target without its last position and is scored
        # against the target without its first position (shifted by one).
        prediction, cross_attn = decoder(memory, target[:, :-1, :])

        # NOTE(review): coordinates are padded in the data-preparation cell and
        # loss_fn averages over every position it is given, so padded rows
        # presumably contribute to the loss -- confirm this is intended.
        loss = loss_fn(prediction, target[:, 1:, :])
        train_loss += loss.item()
        loss.backward()
        e_optim.step()
        d_optim.step()

    # ---- validation pass ----
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for src, target in val_loader:
            src, target = src.to(device), target.to(device)
            memory = encoder(src)
            # NOTE(review): unlike training, the decoder gets no target here;
            # presumably it generates the sequence on its own -- confirm that
            # its output length matches target[:, 1:, :].
            prediction, _ = decoder(memory)

            loss = loss_fn(prediction, target[:, 1:, :])
            val_loss += loss.item()
    print(f'Epoch {epoch} -- Train Loss: {train_loss / len(train_loader):.4f} -- Val Loss: {val_loss / len(val_loader):.4f}')

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1 -- Train Loss: 2.0469 -- Val Loss: 2.5649
Epoch 2 -- Train Loss: 1.9935 -- Val Loss: 2.5419
Epoch 3 -- Train Loss: 1.9762 -- Val Loss: 2.5204
Epoch 4 -- Train Loss: 1.9601 -- Val Loss: 2.4994
Epoch 5 -- Train Loss: 1.9455 -- Val Loss: 2.4745
Epoch 6 -- Train Loss: 1.9310 -- Val Loss: 2.4629
Epoch 7 -- Train Loss: 1.9169 -- Val Loss: 2.4419
Epoch 8 -- Train Loss: 1.9027 -- Val Loss: 2.4288
Epoch 9 -- Train Loss: 1.8891 -- Val Loss: 2.4180
Epoch 10 -- Train Loss: 1.8760 -- Val Loss: 2.3994
Epoch 11 -- Train Loss: 1.8626 -- Val Loss: 2.3850
Epoch 12 -- Train Loss: 1.8501 -- Val Loss: 2.3712
Epoch 13 -- Train Loss: 1.8405 -- Val Loss: 2.3738
Epoch 14 -- Train Loss: 1.8267 -- Val Loss: 2.3576
Epoch 15 -- Train Loss: 1.8146 -- Val Loss: 2.3517
Epoch 16 -- Train Loss: 1.8033 -- Val Loss: 2.3221
Epoch 17 -- Train Loss: 1.7933 -- Val Loss: 2.3321
Epoch 18 -- Train Loss: 1.7819 -- Val Loss: 2.3313
Epoch 19 -- Train Loss: 1.7718 -- Val Loss: 2.3228
Epoch 20 -- Train Loss: 1.7618 -- Val Lo

KeyboardInterrupt: 

In [None]:
# Scratch check: concatenating two (1, 1) tensors along dim 1 yields (1, 2).
# `a` is left uninitialised by torch.empty; only the resulting shape matters.
a, b = torch.empty((1, 1)), torch.randn((1, 1))
c = torch.cat([a, b], dim=1)
c.shape

torch.Size([1, 2])