In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import copy 
import math 
from utils import get_smi_list, replace_atom, get_dic, encode_smi, pad_smi, clones, parallel_f, pad, normalize, get_atom_pos, MyDataset
from model import Encoder, Decoder, device


In [2]:
# Raw SMILES strings for the dataset.
smi_list = get_smi_list('data/ADAGRASIB_SMILES.txt')

# Per-molecule atom coordinates, computed in parallel from the SMILES *before*
# `replace_atom` is applied below; each entry is normalized, then padded to the
# length of the longest raw coordinate list.
coor_list = parallel_f(get_atom_pos, smi_list)
longest_coor = max(len(coor) for coor in coor_list)
coor_list = list(map(lambda coor: pad(normalize(coor), longest_coor), coor_list))

# Tokenize: rewrite atoms via `replace_atom`, build the vocabulary from the
# rewritten strings, encode each string with it, then pad every encoding to
# the longest one.
smi_list = list(map(replace_atom, smi_list))
smi_dic = get_dic(smi_list)
smint_list = [encode_smi(s, smi_dic) for s in smi_list]
longest_smint = max(map(len, smint_list))
smint_list = [pad_smi(s, longest_smint, smi_dic) for s in smint_list]

[14:03:06] UFFTYPER: Unrecognized atom type: Ba (0)


In [3]:
BATCH_SIZE = 64
# Seed for the train/val/test partition, so the same molecules land in the
# same split after every kernel restart (keeps val/test metrics comparable).
SPLIT_SEED = 42

dataset = MyDataset(smint_list, coor_list)
# Fractional lengths: 90% train, 5% validation, 5% test.
train_set, val_set, test_set = random_split(
    dataset,
    [0.9, 0.05, 0.05],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; evaluation needs no shuffling, and a fixed
# order makes val/test passes reproducible batch-for-batch.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
DIM_MODEL = 128
NUM_HEAD = 4
NUM_LAYER = 1
DROPOUT = 0.1

# Encoder is sized by the SMILES vocabulary, decoder by the padded coordinate length.
encoder = Encoder(DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT, len(smi_dic)).to(device)
decoder = Decoder(DIM_MODEL, NUM_HEAD, NUM_LAYER, DROPOUT, longest_coor).to(device)

# Shape smoke test: one forward pass on a single training batch.
# next(iter(...)) raises immediately if the loader is empty, instead of
# silently skipping the check.
# NOTE(review): `input` shadows the Python builtin; the name is kept in case
# later cells reference it.
input, target = next(iter(train_loader))
# no_grad(): nothing here is backpropagated, so don't build an autograd graph
# that the globals `memory`/`out` would otherwise keep alive on `device`.
with torch.no_grad():
    input = input.to(device)
    memory = encoder(input)
    print(f'memory: {memory.shape}')
    # NOTE(review): second Decoder argument is passed as None; its meaning is
    # defined in model.Decoder — confirm.
    out = decoder(memory, None)
    print(f'out: {out.shape}')

memory: torch.Size([64, 36, 128])
out: torch.Size([64, 22, 3])
