In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import copy 
import math 
from utils import get_smi_list, replace_atom, get_dic, encode_smi, pad_smi, clones, parallel_f, pad, normalize, get_atom_pos, MyDataset, subsequent_mask
from model import Encoder, Decoder, device
from tqdm.auto import tqdm


In [2]:
smi_list = get_smi_list('data/ADAGRASIB_SMILES.txt')

coor_list = parallel_f(get_atom_pos, smi_list)
longest_coor = len(max(coor_list, key = len))
coor_list = [pad(normalize(c), longest_coor) for c in coor_list]

smi_list = [replace_atom(smi) for smi in smi_list]
smi_dic = get_dic(smi_list)
smint_list = [encode_smi(smi, smi_dic) for smi in smi_list]
longest_smint = len(max(smint_list, key = len))
smint_list = [pad_smi(smint, longest_smint, smi_dic) for smint in smint_list]

[22:57:50] UFFTYPER: Unrecognized atom type: Ba (0)


In [3]:
BATCH_SIZE = 64
dataset = MyDataset(smint_list, coor_list)
train_set, val_set, test_set = random_split(dataset, [0.9, 0.05, 0.05])

train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, shuffle = True)
test_loader = DataLoader(test_set, batch_size = BATCH_SIZE, shuffle = True)

In [4]:
class DecoderLayer(nn.Module) :
    def __init__(self, dim_model, num_head, dropout, longest_coor) :
        super (DecoderLayer, self).__init__()
        self.dim_model = dim_model
        self.longest_coor = longest_coor

        self.seq1 = nn.Sequential(
            nn.Linear(3, dim_model),
            nn.LeakyReLU()
        )

        self.norm1 = nn.LayerNorm(dim_model) 
        self.self_attn = TargetAttention(dim_model, num_head, longest_coor)
        self.drop1 = nn.Dropout(dropout) 

        self.norm2 = nn.LayerNorm(dim_model)
        self.cross_attn = SourceAttention(dim_model, num_head)
        self.drop2 = nn.Dropout(dropout)
        
        self.norm3 = nn.LayerNorm(dim_model)
        self.feed_foward = nn.Sequential(
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU()
        )
        self.drop3 = nn.Dropout(dropout) 


    def forward(self, memory, target) : 
        target = target[:, :-1, :]
        mask = subsequent_mask(self.longest_coor - 1)
        mask = mask.unsqueeze(1)

        target = self.seq1(target) 
        
        target = self.norm1(target) 
        attn, _ = self.self_attn(target, target, target, mask) 
        target = target + self.drop1(attn) 

        target = self.norm2(target) 
        attn, _ = self.cross_attn(target, memory, memory)
        target = target + self.drop2(attn) 

        target = self.norm3(target) 
        target = target + self.drop3(self.feed_foward(target)) 
        
        return target

In [8]:
DIM_MODEL = 256 
NUM_HEAD = 4
NUM_LAYER = 1
DROPOUT = 0.5

encoder = Encoder(DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT, len(smi_dic)).to(device)
decoder = Decoder(DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT, longest_coor).to(device)

loss_fn = nn.L1Loss() 
e_optim = torch.optim.Adam(encoder.parameters(), lr = 0.001)
d_optim = torch.optim.Adam(decoder.parameters(), lr = 0.001)

In [9]:
NUM_EPOCHS = 30 

for epoch in tqdm(range(1, NUM_EPOCHS +1), total=NUM_EPOCHS) : 
    train_loss = 0 
    val_loss = 0
    encoder.train(), decoder.train()
    for input, target in train_loader :
        input, target = input.to(device), target.to(device) 
        memory = encoder(input) 
        prediction = decoder(memory, target)

        loss = loss_fn(prediction, target[:, 1:, :]) 
        train_loss += loss.item()
        loss.backward()
        e_optim.step(), d_optim.step()
        e_optim.zero_grad(), d_optim.zero_grad() 

    # encoder.eval(), decoder.eval()
    # with torch.no_grad() :
    #     for input, target in val_loader :
    #         input, target = input.to(device), target.to(device) 
    #         memory = encoder(input) 
    #         prediction = decoder(memory, None)

    #         loss = loss_fn(prediction, target) 
    #         val_loss += loss.item()
    print(f'Epoch {epoch} -- Train Loss: {train_loss / len(train_loader):.4f} -- Val Loss: {val_loss / len(val_loader):.4f}')

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1 -- Train Loss: 1.1044 -- Val Loss: 0.0000
Epoch 2 -- Train Loss: 0.8625 -- Val Loss: 0.0000
Epoch 3 -- Train Loss: 0.8240 -- Val Loss: 0.0000
Epoch 4 -- Train Loss: 0.8144 -- Val Loss: 0.0000
Epoch 5 -- Train Loss: 0.8013 -- Val Loss: 0.0000
Epoch 6 -- Train Loss: 0.7933 -- Val Loss: 0.0000
Epoch 7 -- Train Loss: 0.7854 -- Val Loss: 0.0000
Epoch 8 -- Train Loss: 0.7801 -- Val Loss: 0.0000
Epoch 9 -- Train Loss: 0.7775 -- Val Loss: 0.0000
Epoch 10 -- Train Loss: 0.7711 -- Val Loss: 0.0000
Epoch 11 -- Train Loss: 0.7677 -- Val Loss: 0.0000
Epoch 12 -- Train Loss: 0.7678 -- Val Loss: 0.0000
Epoch 13 -- Train Loss: 0.7634 -- Val Loss: 0.0000
Epoch 14 -- Train Loss: 0.7627 -- Val Loss: 0.0000
Epoch 15 -- Train Loss: 0.7586 -- Val Loss: 0.0000
Epoch 16 -- Train Loss: 0.7515 -- Val Loss: 0.0000
Epoch 17 -- Train Loss: 0.7560 -- Val Loss: 0.0000
Epoch 18 -- Train Loss: 0.7495 -- Val Loss: 0.0000
Epoch 19 -- Train Loss: 0.7488 -- Val Loss: 0.0000
Epoch 20 -- Train Loss: 0.7444 -- Val Lo