In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import copy 
import math 
from utils import get_smi_list, replace_atom, get_dic, encode_smi, pad_smi, clones, parallel_f, pad, normalize, get_atom_pos, MyDataset
from model import SourceAttention, PositionalEncoding, Encoder, TargetAttention

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return subsequent_mask == 0

In [3]:
smi_list = get_smi_list('data/ADAGRASIB_SMILES.txt')

coor_list = parallel_f(get_atom_pos, smi_list)
longest_coor = len(max(coor_list, key = len))
coor_list = [pad(normalize(c), longest_coor) for c in coor_list]

smi_list = [replace_atom(smi) for smi in smi_list]
smi_dic = get_dic(smi_list)
smint_list = [encode_smi(smi, smi_dic) for smi in smi_list]
longest_smint = len(max(smint_list, key = len))
smint_list = [pad_smi(smint, longest_smint, smi_dic) for smint in smint_list]

[13:33:53] UFFTYPER: Unrecognized atom type: Ba (0)


In [4]:
BATCH_SIZE = 64
dataset = MyDataset(smint_list, coor_list)
train_set, val_set, test_set = random_split(dataset, [0.9, 0.05, 0.05])

train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, shuffle = True)
test_loader = DataLoader(test_set, batch_size = BATCH_SIZE, shuffle = True)

In [9]:
class TargetAttention(nn.Module) :
    def __init__(self, dim_model, num_head, longest_coor) : 
        super(TargetAttention, self).__init__() 
        
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head  
    
        self.Q = nn.Linear(dim_model, dim_model)   
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        self.out = nn.Linear(dim_model, dim_model)
    
    def forward(self, Q, K, V, mask = None) :
        B = Q.size(0) 

        Q, K, V = self.Q(Q), self.K(K), self.V(V)

        len_Q, len_K, len_V = Q.size(1), K.size(1), V.size(1)

        Q = Q.reshape(B, self.num_head, len_Q, self.dim_head)
        K = K.reshape(B, self.num_head, len_K, self.dim_head)
        V = V.reshape(B, self.num_head, len_V, self.dim_head)

        
        K_T = K.transpose(2,3).contiguous()

        attn_score = Q @ K_T

        attn_score = attn_score / (self.dim_head ** 1/2)
        print(f'attn_score: {attn_score.shape}')
        if mask is not None :
            attn_score = attn_score.masked_fill(mask == 0, -1e9)
        
        attn_distribution = torch.softmax(attn_score, dim = -1)

        attn = attn_distribution @ V

        attn = attn.reshape(B, len_Q, self.num_head * self.dim_head)
        
        attn = self.out(attn)

        return attn, attn_distribution

In [22]:
class Decoder(nn.Module) :
    def __init__(self, dim_model, num_head, num_layer, dropout) : 
        super(Decoder, self).__init__()
        self.layers = clones(DecoderLayer(dim_model, num_head, dropout), num_layer)
        self.norm = nn.LayerNorm(dim_model)
        self.out = nn.Linear(dim_model, 3)
    def forward(self, x, target = None) :
        for layer in self.layers : 
            x = layer(x, target) 
        out = self.out(x)
        return out
    
class DecoderLayer(nn.Module) :
    def __init__(self, dim_model, num_head, dropout) :
        super (DecoderLayer, self).__init__()
        self.dim_model = dim_model
        self.norm1 = nn.LayerNorm(dim_model) 
        self.self_attn = TargetAttention(dim_model, num_head, longest_coor)
        self.drop1 = nn.Dropout(dropout) 

        self.norm2 = nn.LayerNorm(dim_model)
        self.cross_attn = SourceAttention(dim_model, num_head)
        self.drop2 = nn.Dropout(dropout)
        
        self.norm3 = nn.LayerNorm(dim_model)
        self.feed_foward = nn.Sequential(
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.Linear(dim_model, dim_model)
        )
        self.drop3 = nn.Dropout(dropout) 


    def forward(self, memory, target) : 
        x = torch.zeros(memory.size(0), 1, self.dim_model).to(device)
        for i in range(1, longest_coor + 1) :
            mask = subsequent_mask(i)
            mask = mask.unsqueeze(1).to(device)

            y = self.norm1(x) 

            attn, _ = self.self_attn(y, y, y, mask)
            y = y + self.drop1(attn) 

            y = self.norm2(y) 
            attn, _ = self.cross_attn(y, memory, memory) 
            y = y + self.drop2(attn) 

            y = self.norm3(y)
            y = y + self.drop3(self.feed_foward(y))
            x = torch.cat((x, y[:, -1, :].unsqueeze(1)), dim = 1)
        
        return y 
        
            

In [24]:
encoder = Encoder(128, 2, 1, 0.1, len(smi_dic)).to(device)
decoder = DecoderLayer(128, 2, 0.1).to(device)

for input, target in train_loader :
    input = input.to(device)
    memory = encoder(input) 
    print(f'memory: {memory.shape}')
    out = decoder(memory, None)
    print(f'out: {out.shape}')
    break

memory: torch.Size([64, 36, 128])
attn_score: torch.Size([64, 2, 1, 1])
attn_score: torch.Size([64, 2, 2, 2])
attn_score: torch.Size([64, 2, 3, 3])
attn_score: torch.Size([64, 2, 4, 4])
attn_score: torch.Size([64, 2, 5, 5])
attn_score: torch.Size([64, 2, 6, 6])
attn_score: torch.Size([64, 2, 7, 7])
attn_score: torch.Size([64, 2, 8, 8])
attn_score: torch.Size([64, 2, 9, 9])
attn_score: torch.Size([64, 2, 10, 10])
attn_score: torch.Size([64, 2, 11, 11])
attn_score: torch.Size([64, 2, 12, 12])
attn_score: torch.Size([64, 2, 13, 13])
attn_score: torch.Size([64, 2, 14, 14])
attn_score: torch.Size([64, 2, 15, 15])
attn_score: torch.Size([64, 2, 16, 16])
attn_score: torch.Size([64, 2, 17, 17])
attn_score: torch.Size([64, 2, 18, 18])
attn_score: torch.Size([64, 2, 19, 19])
attn_score: torch.Size([64, 2, 20, 20])
attn_score: torch.Size([64, 2, 21, 21])
attn_score: torch.Size([64, 2, 22, 22])
out: torch.Size([64, 22, 128])


In [13]:
class Decoder(nn.Module) :
    def __init__(self, layer, N) :
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)
    def forward(self, x, memory, src_mask, tgt_mask) : 
        for layer in self.layers :
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
    

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

In [20]:
c = copy.deepcopy
src_vocab = len(smi_dic)

class Model(nn.Module) :
    def __init__(self,
                 dim_model,
                 dim_ff,
                 num_head,
                 dropout,
                 N,
                 encoder,
                 decoder,
                 src_embed,
                 tgt_embed) :
        super(Model, self).__init__()
        self.attn = MultiHeadedAttention(num_head, dim_model)
        self.ff = PositionwiseFeedForward(dim_model, dim_ff, dropout) 
        self.position = PositionalEncoding(dim_model, dropout)
        self.encoder = Encoder(
            EncoderLayer(dim_model, c(self.attn), c(self.ff), dropout), N)
        self.decoder = Decoder(
            DecoderLayer(dim_model, c(self.attn), c(self.attn), c(self.ff), dropout), N)
        
        self.src_embed = nn.Sequential(Embeddings(dim_model, src_vocab), c(self.position))
        # self.generator = Generator(dim_model) 
    def forward(self, x) :
        src_mask = (x != -2).unsqueeze(-2) 
        print(f'src_mask : {src_mask.shape}')
        x = self.src_embed(x)
        print(f'x : {x.shape}')

                