In [34]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import copy 
import math 
from utils import get_smi_list, replace_atom, get_dic, encode_smi, pad_smi, clones, parallel_f, pad, normalize, get_atom_pos, MyDataset, subsequent_mask
from model import Encoder, Decoder, device
from tqdm.auto import tqdm


In [37]:
def get_tgt_mask(target, pad=0) :
    """Build a combined padding + causal (look-ahead) attention mask.

    The LAST dimension of ``target`` is treated as the sequence axis, i.e. this is
    meant for token-id targets shaped ``(batch, seq_len)``.

    Args:
        target: tensor of token ids; entries equal to ``pad`` are treated as padding.
        pad: value that marks padding positions (default 0, the original behaviour).

    Returns:
        Bool tensor obtained by broadcasting ``(..., 1, seq_len)`` against
        ``(1, seq_len, seq_len)``. For a ``(batch, seq_len)`` input this is
        ``(batch, seq_len, seq_len)`` and ``mask[b, i, j]`` is True iff ``j <= i`` and
        ``target[b, j] != pad``.

    NOTE(review): for ``(batch, seq_len, 3)`` coordinate targets this is not a sequence
    mask (the last axis would be the xyz axis). Also note a later cell redefines
    ``get_tgt_mask``; when cells run top to bottom, that later definition wins.
    """
    seq_len = target.size(-1)
    # (..., 1, seq_len): which key positions are real (non-padding) tokens
    keep = (target != pad).unsqueeze(-2)
    # Lower-triangular look-ahead mask, created on target's device so the & below
    # also works when target has been moved to the GPU.
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=target.device)
    ).unsqueeze(0)
    return keep & causal

In [102]:
def get_tgt_mask(target) :
    """Return, for every step along dim 1, the prefix of ``target`` visible at that step.

    For an input shaped ``(batch, seq_len, *rest)`` the result is
    ``(batch, seq_len, seq_len, *rest)`` with
    ``out[b, k, j] = target[b, j]`` if ``j <= k`` else zero (``False`` for bool input).
    ``target`` itself is not modified.

    The result is computed with a single broadcast instead of concatenating in a
    Python loop. It is created on ``target``'s device with ``target``'s dtype, and a
    zero-length sequence yields an empty ``(batch, 0, 0, *rest)`` tensor.

    NOTE(review): this redefines ``get_tgt_mask`` from an earlier cell (that one returns
    a boolean padding + causal mask); whichever cell ran last wins. Despite the name,
    this returns masked copies of the target values, not a boolean mask — consider
    renaming it.
    """
    seq_len = target.size(1)
    # visible[0, k, j, ...] is True iff position j is revealed at step k (j <= k);
    # trailing singleton axes let it broadcast over target's feature dims.
    shape = (1, seq_len, seq_len) + (1,) * (target.dim() - 2)
    visible = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=target.device)
    ).view(shape)
    zero = target.new_zeros(())
    # (batch, 1, seq_len, *rest) broadcast against (1, seq_len, seq_len, 1...)
    return torch.where(visible, target.unsqueeze(1), zero)

In [65]:
# Scratch: tile a single (1, 1, 5) row five times along dim 1 -> shape (1, 5, 5).
target = torch.randn(1, 1, 5)
print(target)
target.repeat((1, 5, 1))


tensor([[[ 0.1734, -1.5940, -0.2606, -0.8086,  0.1455]]])


tensor([[[ 0.1734, -1.5940, -0.2606, -0.8086,  0.1455],
         [ 0.1734, -1.5940, -0.2606, -0.8086,  0.1455],
         [ 0.1734, -1.5940, -0.2606, -0.8086,  0.1455],
         [ 0.1734, -1.5940, -0.2606, -0.8086,  0.1455],
         [ 0.1734, -1.5940, -0.2606, -0.8086,  0.1455]]])

In [52]:
# Scratch: drop the first position along dim 1, so (1, 5, 3) becomes (1, 4, 3).
target = torch.randn(1, 5, 3)
print(f'target: {target}')
target = target[:, 1:]
print(f'target: {target}')



target: tensor([[[-1.1767,  2.7450, -0.8661],
         [-0.1688, -2.3080,  1.0370],
         [-0.9792, -1.4364,  0.3771],
         [ 0.4365,  0.6331, -0.4206],
         [-0.2603,  1.2520, -0.8751]]])
target: tensor([[[-0.1688, -2.3080,  1.0370],
         [-0.9792, -1.4364,  0.3771],
         [ 0.4365,  0.6331, -0.4206],
         [-0.2603,  1.2520, -0.8751]]])


tensor([[[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]],

         [[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]],

         [[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]],

         [[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]]])

In [16]:
smi_list = get_smi_list('data/ADAGRASIB_SMILES.txt')

# Per-molecule atom coordinates, normalised and padded to the longest molecule.
coor_list = parallel_f(get_atom_pos, smi_list)
longest_coor = max(map(len, coor_list))
coor_list = [pad(normalize(coords), longest_coor) for coords in coor_list]

# Rewrite atoms, build the vocabulary, encode SMILES to ids and pad to a common length.
smi_list = list(map(replace_atom, smi_list))
smi_dic = get_dic(smi_list)
smint_list = [encode_smi(s, smi_dic) for s in smi_list]
longest_smint = max(len(tokens) for tokens in smint_list)
smint_list = [pad_smi(tokens, longest_smint, smi_dic) for tokens in smint_list]

[19:29:17] UFFTYPER: Unrecognized atom type: Ba (0)


In [27]:
BATCH_SIZE = 1
SPLIT_SEED = 42  # fixes the train/val/test partition and the training shuffle order

dataset = MyDataset(smint_list, coor_list)
# A seeded generator keeps the same molecules in each split on every run; without it
# validation/test numbers are not comparable across kernel restarts.
train_set, val_set, test_set = random_split(
    dataset, [0.9, 0.05, 0.05], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True,
                          generator=torch.Generator().manual_seed(SPLIT_SEED))
# Evaluation sets are not shuffled: a fixed order only removes nondeterminism.
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, shuffle = False)
test_loader = DataLoader(test_set, batch_size = BATCH_SIZE, shuffle = False)

In [None]:
class DecoderLayer(nn.Module) :
    """One pre-norm decoder layer.

    Each of the three sub-layers (masked self-attention over the target,
    cross-attention into the encoder memory, feed-forward) is applied as
    ``x + dropout(sublayer(norm(x)))``.

    NOTE(review): ``TargetAttention`` and ``SourceAttention`` are neither defined nor
    imported in this notebook (only ``Encoder``, ``Decoder`` and ``device`` come from
    ``model``). They must be brought into scope before this class can be instantiated.
    Their call signatures used in ``forward`` are assumptions — TODO confirm.
    """

    def __init__(self, dim_model, num_head, dropout, longest_coor) :
        super().__init__()
        self.dim_model = dim_model
        self.longest_coor = longest_coor

        self.norm1 = nn.LayerNorm(dim_model)
        self.self_attn = TargetAttention(dim_model, num_head, longest_coor)
        self.drop1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(dim_model)
        self.cross_attn = SourceAttention(dim_model, num_head)
        self.drop2 = nn.Dropout(dropout)

        self.norm3 = nn.LayerNorm(dim_model)
        # Attribute name (sic) kept so parameter/state_dict keys do not change.
        self.feed_foward = nn.Sequential(
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU(),
            # Previously nn.Dropout(), which ignored `dropout` and used the default p=0.5.
            nn.Dropout(dropout),
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU()
        )
        self.drop3 = nn.Dropout(dropout)

    def forward(self, memory, target) :
        """Run the three residual sub-layers over ``target`` and return the result.

        Args:
            memory: encoder output consumed by the cross-attention sub-layer.
            target: decoder input. Its last dimension must equal ``dim_model`` for
                the LayerNorms and residual adds (assumes it is already embedded —
                TODO confirm).

        The teacher-forcing shift (feed ``target[:, :-1]``, score against
        ``target[:, 1:]``) has to be done once by the caller. Doing it here would drop
        one position per stacked layer.
        """
        # Presumably get_tgt_mask yields (batch, seq, seq) and unsqueeze(1) adds a
        # broadcast head axis — TODO confirm which get_tgt_mask definition is intended.
        mask = get_tgt_mask(target).unsqueeze(1)

        x = target
        # assumes TargetAttention is called as (x, mask) — TODO confirm
        x = x + self.drop1(self.self_attn(self.norm1(x), mask))
        # assumes SourceAttention is called as (x, memory) — TODO confirm
        x = x + self.drop2(self.cross_attn(self.norm2(x), memory))
        y = x + self.drop3(self.feed_foward(self.norm3(x)))
        return y

In [36]:
target = torch.randn(2, 10)
print(f'target: {target}')
target_y = target[:, 1:]

# Padding mask of shape (batch, 1, seq): True wherever the entry differs from 2.
# NOTE(review): target holds random floats, so this is all True here and the padding
# branch is never exercised by this demo.
target_mask = torch.ne(target, 2).unsqueeze(-2)
print(f'mask: {target_mask.shape} \n {target_mask}')

# AND with the look-ahead mask from subsequent_mask -> (batch, seq, seq).
target_mask = target_mask & subsequent_mask(target.size(-1))
print(f'mask: {target_mask.shape} \n {target_mask}')

target: tensor([[-0.6757, -0.6612,  0.5823, -0.2785, -1.0716, -0.0079,  0.8986,  0.6482,
         -2.0828,  0.2412],
        [-0.4615,  0.5928,  1.1354, -0.5337,  0.2759, -0.5715,  1.7895,  0.1297,
          0.6090, -0.1425]])
mask: torch.Size([2, 1, 10]) 
 tensor([[[True, True, True, True, True, True, True, True, True, True]],

        [[True, True, True, True, True, True, True, True, True, True]]])
mask: torch.Size([2, 10, 10]) 
 tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False

In [29]:
# Scratch: inspect one training batch with its first position dropped.
# Despite the name, `mask` holds the shifted coordinates, not a boolean mask.
for input, target in train_loader :
    mask = target[:, 1:]
    print(f'mask: {mask.shape}')
    print(mask)
    break

target: torch.Size([1, 22, 3])
tensor([[[ 0.0000,  0.0000,  0.0000],
         [-1.0063, -0.1920, -0.9023],
         [-0.7568, -0.4628, -2.0892],
         [-2.4237, -0.0856, -0.5034],
         [-2.7379,  0.1860,  0.6837],
         [-3.4523, -0.2776, -1.4067],
         [ 2.5194,  0.4745,  0.7282],
         [ 1.7493,  0.4355, -0.4975],
         [ 0.2908,  0.1984, -0.2122],
         [-0.4396,  0.1725, -1.5517],
         [-1.8369, -0.0568, -1.2456],
         [-2.6100,  1.1284, -1.0383],
         [-4.0329,  1.0158, -1.5066],
         [-4.9130,  0.1504, -0.6993],
         [-4.2698, -1.0023, -0.0136],
         [-3.6355, -1.8918, -1.0662],
         [-2.2077, -1.4289, -1.2024],
         [-1.3145, -2.3092, -1.2802],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]])
mask: torch.Size([1, 21, 3])
tensor([[[-1.0063, -0.1920, -0.9023],
         [-0.7568, -0.4628, -2.0892],
         [-2.4237, -0.0856

In [18]:
DIM_MODEL = 128
NUM_HEAD = 4
NUM_LAYER = 1
DROPOUT = 0.5
LEARNING_RATE = 1e-3
INIT_SEED = 42  # seeds torch's RNG so random weight initialisation is reproducible

torch.manual_seed(INIT_SEED)
encoder = Encoder(DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT, len(smi_dic)).to(device)
decoder = Decoder(DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT, longest_coor).to(device)

loss_fn = nn.L1Loss()
e_optim = torch.optim.Adam(encoder.parameters(), lr = LEARNING_RATE)
d_optim = torch.optim.Adam(decoder.parameters(), lr = LEARNING_RATE)

In [23]:
NUM_EPOCHS = 30


def run_epoch(loader, train):
    """Make one pass over ``loader`` and return the mean per-batch L1 loss.

    With ``train=True`` the encoder/decoder are put in training mode and both
    optimizers are stepped after every batch. With ``train=False`` they are put in
    eval mode and gradients are disabled. Returns 0.0 for an empty loader instead of
    raising ZeroDivisionError.
    """
    encoder.train(train)
    decoder.train(train)
    total = 0.0
    with torch.set_grad_enabled(train):
        for src, tgt in loader:
            src, tgt = src.to(device), tgt.to(device)
            # NOTE(review): the decoder receives no target (None), so it never sees
            # ground-truth coordinates; confirm this matches model.Decoder's design.
            prediction = decoder(encoder(src), None)
            loss = loss_fn(prediction, tgt)
            if train:
                # Zero first so gradients left over from an interrupted run are not
                # applied on the first step.
                e_optim.zero_grad()
                d_optim.zero_grad()
                loss.backward()
                e_optim.step()
                d_optim.step()
            total += loss.item()
    return total / max(len(loader), 1)


for epoch in tqdm(range(1, NUM_EPOCHS + 1)) :
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    print(f'Epoch {epoch} -- Train Loss: {train_loss:.4f} -- Val Loss: {val_loss:.4f}')

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1 -- Train Loss: 2.0436 -- Val Loss: 2.0880
Epoch 2 -- Train Loss: 2.0373 -- Val Loss: 2.0287
Epoch 3 -- Train Loss: 2.0409 -- Val Loss: 2.0511
Epoch 4 -- Train Loss: 2.0313 -- Val Loss: 1.9991
Epoch 5 -- Train Loss: 2.0358 -- Val Loss: 2.0412
Epoch 6 -- Train Loss: 2.0342 -- Val Loss: 2.0823
Epoch 7 -- Train Loss: 2.0382 -- Val Loss: 2.0728
Epoch 8 -- Train Loss: 2.0333 -- Val Loss: 2.0106
Epoch 9 -- Train Loss: 2.0365 -- Val Loss: 2.0457
Epoch 10 -- Train Loss: 2.0330 -- Val Loss: 2.0265
Epoch 11 -- Train Loss: 2.0350 -- Val Loss: 2.0276
Epoch 12 -- Train Loss: 2.0320 -- Val Loss: 2.0140
Epoch 13 -- Train Loss: 2.0287 -- Val Loss: 2.0551
Epoch 14 -- Train Loss: 2.0217 -- Val Loss: 2.0372
Epoch 15 -- Train Loss: 2.0318 -- Val Loss: 2.0029
Epoch 16 -- Train Loss: 2.0290 -- Val Loss: 2.0256
Epoch 17 -- Train Loss: 2.0340 -- Val Loss: 2.0390
Epoch 18 -- Train Loss: 2.0254 -- Val Loss: 2.0685
Epoch 19 -- Train Loss: 2.0273 -- Val Loss: 2.0392
Epoch 20 -- Train Loss: 2.0228 -- Val Lo

KeyboardInterrupt: 