In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import pickle
import re
import model.base
from model.base import Transformer
import utils 
from utils import MyDataset
import rdkit
from rdkit.Chem import rdDistGeom
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def subsequent_mask( size):
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )

    return subsequent_mask == 0


def get_mask(target, pad_value = -1) :
    mask = (target != torch.tensor(pad_value)).unsqueeze(-2)
    return mask & subsequent_mask(target.size(-1)).type_as(mask.data)

In [3]:
dataset = MyDataset('data/ADAGRASIB_SMILES.txt', 30)
vocab = dataset.vocab

[23:53:25] UFFTYPER: Unrecognized atom type: Ba (0)


In [4]:
train_loader = DataLoader(dataset, batch_size=64)

In [5]:
model = Transformer(256, 512, 8, 2, 0.5, vocab).to(device)
loss_fn = nn.L1Loss()


In [7]:
for src, x, y, z, tgt in train_loader : 
    src, x, y, z, tgt = src.to(device), x.to(device), y.to(device), z.to(device), tgt.to(device)
    src_mask = (src != vocab['<PAD>']).unsqueeze(-2) 
    x_mask = get_mask(x.squeeze(-1)[:, :-1])
    y_mask = get_mask(y.squeeze(-1)[:, :-1])
    z_mask = get_mask(z.squeeze(-1)[:, :-1])

    
    out = model(src, x[:, :-1], y[:, :-1], z[:, :-1], src_mask, x_mask, y_mask, z_mask)
    loss = loss_fn(out, tgt[:, 1:, :])
    print(loss)

tensor(1.9658, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.9990, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.9615, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.0108, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.9976, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.7509, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.8995, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.4141, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.1243, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.9428, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.1350, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.2438, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.3611, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.1680, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.1561, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(2.1116, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.7408, device='cuda:0', grad_fn=<MeanBackward0>)
