In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import time
from train import  train_loader, test_loader, smi_dic, longest_coor,longest_smi, device, timeSince

Train data extracted
Size: 4255
Longest SMILES: 36
Longest Coordinate: 22
----------------------------------------
Sample x: [2, 3, 4, 3, 5, 5, 6, 5, 4, 7, 8, 4, 9, 2, 6, 2, 6, 9, 4, 7, 6, 2, 7, 2, 7, 6, 2, 7, 2, 1, 0, 0, 0, 0, 0, 0]
Sample y: [[4.8285, -1.004, 0.2024], [3.5776, -0.2572, 0.0479], [3.4435, 1.1346, 0.1047], [2.1893, 1.445, -0.0747], [1.4645, 0.2475, -0.2554], [2.3676, -0.7919, -0.1777], [-0.0805, 0.1225, -0.5047], [-1.0404, 1.1849, -0.5963], [-2.2048, 0.6858, 0.0949], [-2.0701, -0.8493, 0.0305], [-0.8409, -1.0789, -0.697], [-1.9777, -1.4359, 1.4405], [-3.2539, -1.4571, -0.7246], [-2.2055, 1.1603, 1.5494], [-3.4816, 1.1392, -0.6157]]


In [25]:
class Attention(nn.Module) :
    """Multi-head scaled dot-product attention.

    Args:
        dim_model: width of the input and output representations.
        num_head: number of attention heads; must divide ``dim_model``.

    Raises:
        ValueError: if ``dim_model`` is not divisible by ``num_head``.
    """

    def __init__(self, dim_model, num_head) :
        super(Attention, self).__init__()
        if dim_model % num_head != 0 :
            raise ValueError(
                f"dim_model ({dim_model}) must be divisible by num_head ({num_head})"
            )
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.Q = nn.Linear(dim_model, dim_model)
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        # Final projection applied after the heads are concatenated back together.
        self.out = nn.Linear(dim_model, dim_model)

    def _split_heads(self, x) :
        """Reshape (B, L, dim_model) -> (B, num_head, L, dim_head), keeping each token's features together."""
        B, L, _ = x.shape
        # Split the feature axis into heads first, then move heads in front of the
        # sequence axis; a direct reshape to (B, H, L, dh) would mix token positions.
        return x.reshape(B, L, self.num_head, self.dim_head).transpose(1, 2)

    def forward(self, Q, K, V, mask) :
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: queries of shape (B, len_Q, dim_model).
            K: keys of shape (B, len_K, dim_model).
            V: values of shape (B, len_K, dim_model).
            mask: None, or a tensor broadcastable to (B, num_head, len_Q, len_K);
                scores where ``mask == 0`` are suppressed before the softmax.

        Returns:
            (attn, attn_distribution): the projected output of shape
            (B, len_Q, dim_model) and the attention weights of shape
            (B, num_head, len_Q, len_K).
        """
        B = Q.size(0)
        len_Q = Q.size(1)

        Q = self._split_heads(self.Q(Q))
        K = self._split_heads(self.K(K))
        V = self._split_heads(self.V(V))

        # Scale by sqrt(d_head). Note `d ** 1/2` would parse as `(d ** 1) / 2`.
        attn_score = (Q @ K.transpose(-2, -1)) / (self.dim_head ** 0.5)

        if mask is not None :
            attn_score = attn_score.masked_fill(mask == 0, float("-1e20"))

        attn_distribution = torch.softmax(attn_score, dim = -1)

        # (B, H, len_Q, dh) -> (B, len_Q, H, dh) -> (B, len_Q, dim_model)
        attn = (attn_distribution @ V).transpose(1, 2).reshape(B, len_Q, self.dim_model)

        return self.out(attn), attn_distribution

In [30]:

class EncoderBlock(nn.Module) :
    """Single encoder layer: attention sublayer followed by a feed-forward sublayer.

    Each sublayer is wrapped as ``dropout(LayerNorm(residual + sublayer_output))``.

    Args:
        dim_model: width of the token representations.
        num_head: number of heads given to ``Attention``.
        fe: expansion factor; the feed-forward hidden width is ``fe * dim_model``.
        dropout: dropout probability.
    """

    def __init__(self, dim_model, num_head, fe, dropout) :
        super(EncoderBlock, self).__init__()
        # Submodules are registered in a fixed order so parameter initialisation
        # and state_dict keys stay stable.
        self.self_attn = Attention(dim_model, num_head)
        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        self.dropout = nn.Dropout(dropout)
        ff_width = fe * dim_model
        self.ff = nn.Sequential(
            nn.Linear(dim_model, ff_width),
            nn.ReLU(),
            nn.Linear(ff_width, dim_model),
        )

    def forward(self, Q, K, V, mask) :
        """Run the attention and feed-forward sublayers.

        Args:
            Q, K, V: query/key/value inputs forwarded to ``Attention``; ``Q`` is
                also the residual input.
            mask: forwarded unchanged to ``Attention``.

        Returns:
            (out, self_attn): the block output and the attention weights
            returned by ``Attention``.
        """
        context, self_attn = self.self_attn(Q, K, V, mask)

        # Attention sublayer: residual add, normalise, dropout.
        hidden = self.dropout(self.norm1(context + Q))

        # Feed-forward sublayer: residual add, normalise, dropout.
        ff_out = self.ff(hidden)
        out = self.dropout(self.norm2(ff_out + hidden))

        return out, self_attn

In [31]:
class Encoder(nn.Module) :
    """Token embedding + learned positional embedding, followed by stacked ``EncoderBlock``s.

    Args:
        dim_model: embedding / hidden width.
        num_block: number of stacked ``EncoderBlock`` layers.
        num_head: attention heads per block.
        len_dic: vocabulary size of the token embedding.
        longest_smi: number of positions in the positional embedding table
            (the maximum supported sequence length).
        fe: feed-forward expansion factor for each block.
        dropout: dropout probability.
    """

    def __init__(self, dim_model, num_block, num_head, len_dic, longest_smi, fe, dropout) :
        super(Encoder, self).__init__()

        self.word_embedding = nn.Embedding(len_dic, dim_model)
        self.pos_embedding = nn.Embedding(longest_smi, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(dim_model, num_head, fe, dropout) for _ in range(num_block)]
        )

    def forward(self, x, mask) :
        """Encode a batch of token ids.

        Args:
            x: integer token ids of shape (B, seq_len); ``seq_len`` must not exceed
                the ``longest_smi`` the encoder was built with.
            mask: attention mask forwarded unchanged to every block (or None).

        Returns:
            (out, self_attn): the encoded sequence of shape (B, seq_len, dim_model)
            and the attention weights of the last block (None if there are no blocks).
        """
        seq_len = x.size(1)

        # Positions are created on x's own device rather than a global `device`,
        # so the encoder works wherever the model and its input actually live.
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, seq_len), broadcasts over batch

        out = self.dropout(self.word_embedding(x) + self.pos_embedding(pos))

        # Defined up front so an encoder with zero blocks still returns cleanly.
        self_attn = None
        for block in self.encoder_blocks :
            out, self_attn = block(out, out, out, mask)

        return out, self_attn

In [32]:
# Build the SMILES encoder. The positional table is sized from the dataset's longest
# SMILES (imported from `train`, printed as 36 above) instead of a hard-coded literal.
# NOTE(review): len_dic=34 is still a literal — confirm it covers every token index
# produced via smi_dic (including the 0 padding id seen in the sample above).
encoder = Encoder(
    dim_model=512,
    num_block=2,
    num_head=2,
    len_dic=34,
    longest_smi=longest_smi,
    fe=2,
    dropout=0.1,
)

In [33]:
# Smoke test: 16 random token-id sequences padded to the longest SMILES length.
# Model and input are placed on the same `device` the encoder uses for positions.
encoder = encoder.to(device)
test_input = torch.randint(1, len(smi_dic), (16, longest_smi), device=device)
mask = None

with torch.no_grad():
    out = encoder(test_input, mask)  # tuple: (encoded sequence, last-block attention)

enc_out, enc_attn = out
assert enc_out.shape == (16, longest_smi, encoder.word_embedding.embedding_dim)

In [5]:
class DecoderBlock(nn.Module) :
    """Decoder layer: cross-attention from the target onto encoder outputs, then a feed-forward net.

    Args:
        dim_model: width of the representations.
        num_head: number of heads given to ``Attention``.
        fe: expansion factor; the feed-forward hidden width is ``fe * dim_model``.
        dropout: dropout probability.
    """

    def __init__(self, dim_model, num_head, fe, dropout) :
        super(DecoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        self.norm3 = nn.LayerNorm(dim_model)
        self.cross_attn = Attention(dim_model, num_head)
        self.dropout = nn.Dropout(dropout)
        self.ff = nn.Sequential(
            nn.Linear(dim_model, fe * dim_model),
            nn.ReLU(),
            nn.Linear(fe * dim_model, dim_model)
        )

    def forward(self, target, K, V, input_mask, target_mask) :
        """Attend from the target onto the encoder outputs.

        Args:
            target: decoder-side input used as the attention queries
                (presumably (B, len_target, dim_model) — TODO confirm against Decoder).
            K, V: encoder outputs used as attention keys and values.
            input_mask: mask over the encoder (key) positions, broadcastable to
                (B, num_head, len_target, len_K); scores where it is 0 are suppressed.
            target_mask: accepted for interface compatibility. This block has no
                target self-attention, so it is currently unused.

        Returns:
            (out, cross_attn): the block output and the cross-attention weights.
        """
        Q = self.dropout(self.norm1(target))

        # Cross-attention scores span (len_target, len_K), i.e. the key axis is the
        # encoder sequence, so the mask must describe encoder positions: input_mask.
        context, cross_attn = self.cross_attn(Q, K, V, input_mask)

        # Residual is taken from the normalised query Q, as in the original design.
        x = self.dropout(self.norm2(context + Q))

        ff_out = self.ff(x)

        out = self.dropout(self.norm3(ff_out + x))

        return out, cross_attn

In [None]:
class Decoder(nn.Module) :
    def __init__(self, dim_model, num_block, num_head, longest_coor, fe, dropout) :
        super(Decoder, self).__init__()

        