In [3]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from train import train, train_epoch, train_loader, test_loader, smi_dic, longest_coor,longest_smi, device

In [2]:
class SelfAttention(nn.Module) :
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``num_head`` heads of width
    ``dim_head = dim_model // num_head``, attended per head, concatenated back
    to ``dim_model`` features and passed through the output projection.

    Args:
        dim_model: feature size of the inputs and of the output.
        num_head: number of attention heads; must divide ``dim_model``.

    Raises:
        ValueError: if ``dim_model`` is not divisible by ``num_head``.
    """

    def __init__(self, dim_model, num_head) :
        super(SelfAttention, self).__init__()
        if num_head <= 0 or dim_model % num_head != 0 :
            raise ValueError(
                f"dim_model ({dim_model}) must be divisible by num_head ({num_head})"
            )

        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.Q = nn.Linear(dim_model, dim_model)
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        self.out = nn.Linear(dim_model, dim_model)

    def _split_heads(self, x) :
        """Reshape (B, length, dim_model) -> (B, num_head, length, dim_head)."""
        B, length = x.size(0), x.size(1)
        # Split the feature axis first, THEN move the head axis forward.
        # Reshaping straight to (B, num_head, length, dim_head) would mix
        # sequence positions and features across heads.
        return x.reshape(B, length, self.num_head, self.dim_head).transpose(1, 2)

    def forward(self, Q, K, V) :
        """Compute attention of ``Q`` over ``K``/``V``.

        Args:
            Q: tensor of shape (B, len_Q, dim_model).
            K: tensor of shape (B, len_K, dim_model).
            V: tensor of shape (B, len_K, dim_model).

        Returns:
            A tuple ``(attn, attn_distribution)`` where ``attn`` has shape
            (B, len_Q, dim_model) and ``attn_distribution`` holds the softmax
            weights with shape (B, num_head, len_Q, len_K).
        """
        B = Q.size(0)

        Q, K, V = self.Q(Q), self.K(K), self.V(V)
        len_Q = Q.size(1)

        Q = self._split_heads(Q)
        K = self._split_heads(K)
        V = self._split_heads(V)

        # Scale by sqrt(dim_head). Note that ``x ** 1/2`` would parse as
        # ``(x ** 1) / 2``, so the exponent is written as 0.5.
        attn_score = (Q @ K.transpose(2, 3)) / (self.dim_head ** 0.5)

        attn_distribution = torch.softmax(attn_score, dim = -1)

        attn = attn_distribution @ V  # (B, num_head, len_Q, dim_head)

        # Bring the head axis back next to the feature axis before merging,
        # so each position's output is the concatenation of its own heads.
        attn = attn.transpose(1, 2).reshape(B, len_Q, self.num_head * self.dim_head)

        # Output projection mixes information across heads.
        attn = self.out(attn)

        return attn, attn_distribution

In [None]:
class EncoderBlock(nn.Module) :
    """Encoder block combining multi-head attention with an LSTM branch.

    The attention output and the LSTM hidden states (computed from ``Q``) are
    summed, layer-normalised and passed through a position-wise feed-forward
    network with a second residual connection and layer norm.

    Args:
        dim_model: feature size of the block's inputs and outputs.
        num_head: number of attention heads passed to ``SelfAttention``.
        fe: expansion factor of the feed-forward hidden layer
            (hidden width is ``fe * dim_model``).
        dropout: dropout probability applied after each layer norm.
    """

    def __init__(self, dim_model, num_head, fe, dropout) :
        super(EncoderBlock, self).__init__()
        self.self_attn = SelfAttention(dim_model,num_head)
        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        # The LSTM consumes the same (B, seq_len, dim_model) tensor as the
        # attention layer, so its feature size is dim_model (not the sequence
        # length) and the batch axis comes first.
        self.lstm = nn.LSTM(input_size=dim_model, hidden_size=dim_model,
                            batch_first=True)

        self.feed_forward = nn.Sequential(
            nn.Linear(dim_model, fe * dim_model),
            nn.ReLU(),
            nn.Linear(fe * dim_model, dim_model)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V) :
        """Run the block.

        Args:
            Q, K, V: tensors with ``dim_model`` features on the last axis;
                assumed to be (B, seq_len, dim_model) — the LSTM branch is
                computed from ``Q`` only.

        Returns:
            A tuple ``(out, attn_distribution)``: the block output with the
            same shape as the attention output, and the attention weights
            returned by ``self.self_attn``.
        """
        attn, attn_distribution = self.self_attn(Q, K, V)

        # Per-position hidden states; final (h_n, c_n) are not needed.
        all_state, _ = self.lstm(Q)

        x = self.dropout(self.norm1(attn + all_state))

        ff_out = self.feed_forward(x)

        out = self.dropout(self.norm2(ff_out + x))

        return out, attn_distribution




In [None]:
class Encoder(nn.Module) :
    """Token embedding followed by a stack of ``EncoderBlock`` layers.

    Args:
        dim_model: embedding / feature size used throughout the stack.
        num_block: number of ``EncoderBlock`` layers.
        num_head: number of attention heads in every block.
        len_dic: vocabulary size of the embedding table.
        fe: feed-forward expansion factor forwarded to each block.
        dropout: dropout probability used on the embeddings and in each block.
    """

    def __init__(self, dim_model, num_block, num_head,
                 len_dic, fe = 1, dropout = 0.1) :

        super(Encoder, self).__init__()

        self.dim_model = dim_model
        self.embed = nn.Embedding(len_dic, dim_model)
        self.dropout = nn.Dropout(dropout)

        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(dim_model, num_head, fe, dropout) for _ in range(num_block)]
        )

    def forward(self, x) :
        """Encode a batch of token indices.

        Args:
            x: integer tensor of token indices accepted by ``nn.Embedding``.

        Returns:
            A tuple ``(out, self_attn)``: the output of the last block (or the
            dropped-out embeddings when there are no blocks) and the attention
            weights of the last block (``None`` when there are no blocks).
        """
        out = self.dropout(self.embed(x))

        # Defined up front so an empty stack does not leave it unbound.
        self_attn = None
        for block in self.encoder_blocks :
            out, self_attn = block(out, out, out)
        return out, self_attn

