In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import time
import random
from utils.data_preprocess import train_loader, test_loader, smi_list, smint_list, smi_dic, coor_list, np_coor_list, longest_coor, longest_smi, device
from utils.helper import visualize, timeSince

Data successfully extracted
----------------------------------------
Size: 4255
Longest SMILES: 36
Longest Coordinate: 22
----------------------------------------

Below is one example for each variable: 


smi_dic (SMILES Dictionary):
	 {'x': 0, 'E': 1, 'C': 2, 'n': 3, '1': 4, 'c': 5, '(': 6, ')': 7, 'B': 8, 'O': 9, 'F': 10, 'S': 11, '=': 12, 'o': 13, 'N': 14, '#': 15, '/': 16, 'l': 17, '\\': 18, '2': 19, '[': 20, 'H': 21, ']': 22, '+': 23, '-': 24, 's': 25, '.': 26, 'K': 27, '3': 28, 'r': 29, 'P': 30, 'a': 31}


smi_list (SMILES List):
	 Cn1ncc(c1)B1OC(C(O1)(C)C)(C)CE


smint_list (SMILES Integer List):
	 [2, 3, 4, 3, 5, 5, 6, 5, 4, 7, 8, 4, 9, 2, 6, 2, 6, 9, 4, 7, 6, 2, 7, 2, 7, 6, 2, 7, 2, 1, 0, 0, 0, 0, 0, 0]


coor_list (Coordinate List):
	 [[4.8285, -1.004, 0.2024], [3.5776, -0.2572, 0.0479], [3.4435, 1.1346, 0.1047], [2.1893, 1.445, -0.0747], [1.4645, 0.2475, -0.2554], [2.3676, -0.7919, -0.1777], [-0.0805, 0.1225, -0.5047], [-1.0404, 1.1849, -0.5963], [-2.2048, 0.6858, 0.0949],

In [53]:
class NNAtention(nn.Module):
    """Single-head additive attention.

    Each key is scored against the query with
    ``Va(tanh(Wa(query) + Ua(keys)))``; the scores are normalised with a
    softmax over the key axis and used to average the keys.
    """

    def __init__(self, dim_model):
        """Create the three projections used by the additive score.

        Args:
            dim_model: size of the last dimension of both query and keys.
        """
        super(NNAtention, self).__init__()
        self.Wa = nn.Linear(dim_model, dim_model)
        self.Ua = nn.Linear(dim_model, dim_model)
        self.Va = nn.Linear(dim_model, 1)

    def forward(self, query, keys):
        """Attend over ``keys`` with ``query``.

        Args:
            query: 3-D tensor whose projection can be added (with broadcasting)
                to the projection of ``keys``, e.g. ``(B, 1, dim_model)``.
            keys: 3-D tensor ``(B, L, dim_model)`` (``torch.bmm`` requires 3-D).

        Returns:
            ``(context, weights)`` where ``weights`` is the softmax
            distribution over the key axis, shaped ``(B, 1, L)``, and
            ``context = bmm(weights, keys)`` is shaped ``(B, 1, dim_model)``.
        """
        # No debug printing here: forward runs once per decoding step, so any
        # print floods the notebook output.
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))  # (B, L, 1)
        scores = scores.squeeze(2).unsqueeze(1)  # (B, 1, L)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights # context : attention, weights : distribution
    

class NN_Multihead_Attention(nn.Module) :
    """Multi-head additive attention that returns one context vector per batch item.

    The projected query and keys are split into ``num_head`` slices of width
    ``dim_head``; every head scores each position with
    ``V(tanh(W(Q) + U(K)))``, the scores are soft-maxed over positions, the
    un-projected ``K`` is averaged with those weights, and the per-head
    results are averaged.
    """
    def __init__(self, dim_model, num_head) :
        """
        Args:
            dim_model: feature size of Q and K; expected to be divisible by
                ``num_head`` (otherwise the head split in ``forward`` fails).
            num_head: number of attention heads.
        """
        super(NN_Multihead_Attention, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.W = nn.Linear(dim_model, dim_model)
        self.U = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model // num_head, 1)

    def _split_heads(self, x) :
        """Turn ``(B, len, dim_model)`` into ``(B, num_head, len, dim_head)``."""
        B, length, _ = x.size()
        # Split the feature axis first, then move heads in front of the sequence
        # axis. Reshaping straight to (B, H, len, dh) would mix sequence
        # positions with feature chunks, so the weights would no longer refer
        # to rows of K.
        return x.reshape(B, length, self.num_head, self.dim_head).transpose(1, 2)

    def forward(self, Q, K) :
        """Compute the attention context of ``K`` given ``Q``.

        Args:
            Q: ``(B, len_Q, dim_model)``. ``len_Q`` must be 1 or equal to
                ``len_K`` because the projections are added with broadcasting;
                when the lengths are equal the sum is position-wise, not
                pairwise.
            K: ``(B, len_K, dim_model)``.

        Returns:
            ``(attn, attn_distribution)``: ``attn`` is ``(B, 1, dim_model)``
            (mean over heads of the weighted sum of ``K``) and
            ``attn_distribution`` is ``(B, num_head, len_K)``.
        """
        W = self._split_heads(self.W(Q))  # (B, H, len_Q, dh)
        U = self._split_heads(self.U(K))  # (B, H, len_K, dh)

        # The tanh is essential: without it the score is V.W(q) + V.U(k), the
        # query term is identical for every key and cancels in the softmax, so
        # the query would have no influence on the weights.
        attn_score = self.V(torch.tanh(W + U))  # (B, H, len, 1)
        attn_score = attn_score.squeeze(-1).unsqueeze(2)  # (B, H, 1, len)

        attn_distribution = torch.softmax(attn_score, dim = -1)
        attn_distribution = attn_distribution.squeeze(2)  # (B, H, len)

        attn = torch.bmm(attn_distribution, K)  # (B, H, dim_model)
        attn = torch.mean(attn, 1).unsqueeze(1)  # (B, 1, dim_model)

        return attn, attn_distribution

In [57]:
class Attention(nn.Module) :
    """Multi-head scaled dot-product attention with an output projection."""
    def __init__(self, dim_model, num_head) :
        """
        Args:
            dim_model: feature size of the inputs and of the output.
            num_head: number of heads; ``dim_model`` is expected to be
                divisible by it (otherwise the head split in ``forward`` fails).
        """
        super(Attention, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.Q = nn.Linear(dim_model, dim_model)
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        self.out = nn.Linear(dim_model, dim_model)

    def _split_heads(self, x) :
        """Turn ``(B, len, dim_model)`` into ``(B, num_head, len, dim_head)``."""
        B, length, _ = x.size()
        # Split the feature axis, then swap it with the sequence axis. A direct
        # reshape to (B, H, len, dh) would interleave positions and features.
        return x.reshape(B, length, self.num_head, self.dim_head).transpose(1, 2)

    def forward(self, Q, K, V) :
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: ``(B, len_Q, dim_model)``.
            K: ``(B, len_K, dim_model)``.
            V: ``(B, len_K, dim_model)``.

        Returns:
            ``(attn, attn_distribution)``: ``attn`` is ``(B, len_Q, dim_model)``
            after the output projection; ``attn_distribution`` is the softmax
            over keys, ``(B, num_head, len_Q, len_K)``.
        """
        B = Q.size(0) # Shape Q, K, V: (B, longest_smi, dim_model)

        Q, K, V = self.Q(Q), self.K(K), self.V(V)

        len_Q = Q.size(1)

        Q = self._split_heads(Q)
        K = self._split_heads(K)
        V = self._split_heads(V)

        attn_score = Q @ K.transpose(2, 3)

        # Scale by sqrt(dim_head). `dim_head ** 1/2` would parse as
        # (dim_head ** 1) / 2, i.e. half of dim_head.
        attn_score = attn_score / (self.dim_head ** 0.5)

        attn_distribution = torch.softmax(attn_score, dim = -1)

        attn = attn_distribution @ V  # (B, H, len_Q, dh)

        # Undo the head split: bring the sequence axis back in front of the
        # heads before concatenating the head slices.
        attn = attn.transpose(1, 2).reshape(B, len_Q, self.num_head * self.dim_head)

        attn = self.out(attn)

        return attn, attn_distribution

class LSTM(nn.Module) :
    """Step-by-step coordinate decoder: additive cross-attention feeding a GRU.

    Despite the class and attribute names, the recurrent cell is ``nn.GRU``.
    At every step the current hidden state attends over the encoder states,
    the resulting context is concatenated with the previous coordinate and
    fed to the GRU, and a linear layer maps the GRU output to ``output_size``
    values.
    """
    def __init__(self, dim_model, longest_coor, num_head = 1, output_size = 3) :
        """
        Args:
            dim_model: hidden size of the GRU and feature size of the encoder states.
            longest_coor: number of decoding steps performed by ``forward``.
            num_head: number of heads of the attention modules.
            output_size: size of each emitted vector. The GRU input is built as
                ``3 + dim_model``, so feeding predictions back only works
                when this stays 3.
        """
        super(LSTM, self).__init__()

        self.longest_coor = longest_coor

        # Both attention variants are constructed (and their parameters
        # registered); forward_step currently uses only the additive one.
        self.cross_attn = Attention(dim_model, num_head)
        self.cross_attn_nn = NN_Multihead_Attention(dim_model, num_head)

        self.lstm = nn.GRU(3 + dim_model, dim_model, batch_first=True)

        self.out = nn.Linear(dim_model, output_size)

        self.dropout = nn.Dropout(0.1)

    def forward(self, e_all, e_last, target = None) :
        """Decode ``longest_coor`` steps.

        Args:
            e_all: encoder states, ``(B, L, dim_model)``.
            e_last: initial hidden state for the GRU, ``(1, B, dim_model)``.
            target: optional ``(B, longest_coor, 3)`` tensor. When given, step
                ``i + 1`` receives ``target[:, i]`` as input (teacher forcing);
                otherwise the previous prediction is fed back.

        Returns:
            ``(d_outputs, d_hidden, cross_attn)``: the predictions concatenated
            along dim 1, the final GRU hidden state, and the per-step attention
            distributions concatenated along dim 2.
        """
        B = e_all.size(0)

        # All-zero start token, created on the same device/dtype as the encoder
        # states so the module does not depend on a global `device`.
        d_input = e_all.new_zeros(B, 1, 3)

        d_hidden = e_last

        d_outputs, cross_attn = [], []

        for i in range(self.longest_coor) :
            d_output, d_hidden, step_attn = self.forward_step(d_input, d_hidden, e_all)

            d_outputs.append(d_output)
            cross_attn.append(step_attn)

            if target is not None :
                d_input = target[:, i, :].unsqueeze(1)
            else :
                d_input = d_output

        d_outputs = torch.cat(d_outputs, dim = 1)

        # NOTE(review): with NN_Multihead_Attention each step_attn is
        # (B, num_head, L), so this concatenation yields (B, num_head, steps * L)
        # rather than a separate step axis - confirm this is what visualize() expects.
        cross_attn = torch.cat(cross_attn, dim = 2)

        return d_outputs, d_hidden, cross_attn


    def forward_step(self, d_input, d_hidden, e_all) :
        """Run one decoding step.

        Args:
            d_input: previous coordinate, ``(B, 1, 3)``.
            d_hidden: current GRU hidden state, ``(1, B, dim_model)``.
            e_all: encoder states attended over.

        Returns:
            ``(output, d_hidden, attn_distribution)`` for this step.
        """
        # (1, B, D) -> (B, 1, D): the hidden state acts as a length-1 query.
        Q = d_hidden.permute(1,0,2)

        d_input = self.dropout(d_input)

        # attn, attn_distribution = self.cross_attn(Q, e_all, e_all)
        attn, attn_distribution = self.cross_attn_nn(Q, e_all)

        input_lstm = torch.cat((attn, d_input), dim = 2)

        # The second argument of nn.GRU is the initial hidden state h_0, shaped
        # (num_layers, B, hidden_size); d_hidden already has that layout.
        output, d_hidden = self.lstm(input_lstm, d_hidden)

        output = self.out(output)

        return output, d_hidden, attn_distribution

In [58]:
class EncoderBlock(nn.Module):
    """One encoder layer: additive self-attention context added to the input, then an LSTM.

    ``norm1``, ``norm2``, ``dropout``, ``feed_forward`` and the dot-product
    ``self_attn`` are constructed but not used by ``forward``.
    """

    def __init__(self, dim_model, num_head, fe, dropout):
        """
        Args:
            dim_model: feature size of the input and of the LSTM hidden state.
            num_head: number of heads of the attention modules.
            fe: multiplier for the hidden width of ``feed_forward``.
            dropout: probability passed to ``nn.Dropout``.
        """
        super().__init__()
        self.self_attn = Attention(dim_model, num_head)
        self.self_attn_nn = NN_Multihead_Attention(dim_model, num_head)

        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        self.lstm = nn.LSTM(input_size=dim_model, hidden_size=dim_model, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        hidden_width = fe * dim_model
        self.feed_forward = nn.Sequential(
            nn.Linear(dim_model, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, dim_model),
        )

    def forward(self, Q, K, V):
        """Encode ``Q``; ``K`` and ``V`` are accepted but ignored.

        Returns:
            ``(all_state, last_state, self_attn)``: every LSTM output, the
            final LSTM hidden state, and the attention distribution.
        """
        # The additive attention attends from Q to itself.
        context, attn_weights = self.self_attn_nn(Q, Q)

        # The context is added to Q before the recurrent layer.
        hidden_states, (final_hidden, _final_cell) = self.lstm(Q + context)

        return hidden_states, final_hidden, attn_weights

In [59]:
class Encoder(nn.Module):
    """Token embedding followed by a stack of ``EncoderBlock`` layers."""

    def __init__(self, dim_model, num_block, num_head,
                 len_dic, fe=1, dropout=0.1):
        """
        Args:
            dim_model: embedding size and feature size of every block.
            num_block: number of stacked ``EncoderBlock`` layers.
            num_head: number of attention heads per block.
            len_dic: vocabulary size of the embedding table.
            fe: feed-forward width multiplier forwarded to each block.
            dropout: dropout probability for the embedding and the blocks.
        """
        super().__init__()

        self.dim_model = dim_model
        self.embed = nn.Embedding(len_dic, dim_model)
        self.dropout = nn.Dropout(dropout)

        blocks = [EncoderBlock(dim_model, num_head, fe, dropout) for _ in range(num_block)]
        self.encoder_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Embed the integer tokens ``x`` and run them through every block.

        Returns:
            The ``(all_state, last_state, self_attn)`` triple produced by the
            last block.
        """
        hidden = self.dropout(self.embed(x))

        # Each block receives the previous block's full output sequence.
        for layer in self.encoder_blocks:
            hidden, last_state, self_attn = layer(hidden, hidden, hidden)

        return hidden, last_state, self_attn

In [60]:
class DecoderBlock(nn.Module):
    """Thin wrapper around the recurrent ``LSTM`` decoder.

    ``norm1``, ``norm2``, ``feed_forward`` and ``dropout`` are constructed but
    not applied in ``forward``; the decoder output is returned as is.
    """

    def __init__(self, dim_model, num_head, longest_coor, fe, dropout):
        """
        Args:
            dim_model: hidden size of the wrapped decoder.
            num_head: number of attention heads of the wrapped decoder.
            longest_coor: number of decoding steps.
            fe: multiplier for the hidden width of ``feed_forward``.
            dropout: probability passed to ``nn.Dropout``.
        """
        super().__init__()

        self.lstm = LSTM(dim_model, longest_coor, num_head)

        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(3)

        hidden_width = fe * dim_model
        self.feed_forward = nn.Sequential(
            nn.Linear(3, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 3),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, e_all, e_last, target=None):
        """Decode from the encoder states.

        Returns:
            ``(output, cross_attn)``: the decoder predictions and its attention
            distributions; the final hidden state is discarded.
        """
        coords, _last_hidden, cross_attn = self.lstm(e_all, e_last, target)

        # A residual feed-forward / LayerNorm stage was sketched here by the
        # author but is disabled.
        return coords, cross_attn

In [61]:
class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers applied one after another."""

    def __init__(self, dim_model, num_block, num_head, longest_coor, fe=1, dropout=0.1):
        """
        Args:
            dim_model: hidden size of every block.
            num_block: number of stacked blocks.
            num_head: number of attention heads per block.
            longest_coor: number of decoding steps per block.
            fe: feed-forward width multiplier forwarded to each block.
            dropout: dropout probability forwarded to each block.
        """
        super().__init__()

        blocks = [
            DecoderBlock(dim_model, num_head, longest_coor, fe, dropout)
            for _ in range(num_block)
        ]
        self.decoder_blocks = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)

    def forward(self, e_all, e_last, target=None):
        """Run every block on the same encoder states.

        The first block receives ``target`` as passed in; every later block
        receives the previous block's output in the ``target`` slot.
        NOTE(review): this means later blocks are always teacher-forced on the
        previous block's predictions - confirm this cascade is intended.

        Returns:
            ``(prediction, cross_attn)`` from the last block.
        """
        prediction = target
        for layer in self.decoder_blocks:
            prediction, cross_attn = layer(e_all, e_last, prediction)

        return prediction, cross_attn

In [62]:
import random

# Index of the molecule rendered by visualize() during training.
# random.randint is inclusive on both ends, so randint(1, len(smi_list)) could
# return len(smi_list) (IndexError on smi_list[r]) and could never return 0.
# randrange draws from exactly the valid indices 0 .. len(smi_list) - 1.
r = random.randrange(len(smi_list))

def train_epoch(train_loader,test_loader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, tf):
    """Train for one pass over ``train_loader``, then evaluate on ``test_loader``.

    Args:
        train_loader: iterable of ``(input, target)`` training batches.
        test_loader: iterable of ``(input, target)`` evaluation batches.
        encoder: module returning ``(e_all, e_last, self_attn)`` for a batch.
        decoder: module returning ``(prediction, cross_attn)``.
        encoder_optimizer: optimizer stepping the encoder parameters.
        decoder_optimizer: optimizer stepping the decoder parameters.
        criterion: loss called as ``criterion(prediction, target)``.
        tf: when truthy, the target is passed to the decoder (teacher forcing).

    Returns:
        ``(mean_train_loss, mean_test_loss)``, each averaged over the number of
        batches of the respective loader.

    The models are left in eval mode when the function returns.
    """
    # Set train mode here rather than relying on the caller: the evaluation
    # phase below switches to eval mode, so a second direct call would
    # otherwise train with dropout disabled.
    encoder.train()
    decoder.train()

    total_loss = 0
    total_test_loss = 0

    for inputs, target in train_loader:
        inputs, target = inputs.to(device), target.to(device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        e_all, e_last, self_attn = encoder(inputs)

        # Teacher Forcing
        if tf :
            prediction, cross_attn = decoder(e_all, e_last, target)
        else :
            prediction, cross_attn = decoder(e_all, e_last)

        loss = criterion(prediction, target)
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    encoder.eval()
    decoder.eval()

    # Evaluation never uses teacher forcing.
    with torch.no_grad() :
        for inputs, target in test_loader :
            inputs, target = inputs.to(device), target.to(device)

            e_all, e_last, self_attn = encoder(inputs)
            prediction, cross_attn = decoder(e_all, e_last)

            test_loss = criterion(prediction, target)
            total_test_loss += test_loss.item()

    return total_loss / len(train_loader), total_test_loss / len(test_loader)


def train(train_loader, test_loader, encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=1, visual_path="", tf_rate=1):
    """Run the full training loop with periodic logging and attention plots.

    Args:
        train_loader: training batches, forwarded to ``train_epoch``.
        test_loader: evaluation batches, forwarded to ``train_epoch``.
        encoder: encoder module.
        decoder: decoder module.
        n_epochs: number of epochs to run.
        learning_rate: learning rate of both Adam optimizers.
        print_every: log the averaged losses every this many epochs.
        visual_path: forwarded to ``visualize`` as ``path``.
        tf_rate: fraction of the epochs that use teacher forcing; epochs
            beyond ``tf_rate * n_epochs`` run without it.
    """
    start = time.time()

    running_train_loss = 0
    running_test_loss = 0

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.L1Loss()

    for epoch in range(1, n_epochs + 1):
        # Teacher forcing is on until the epoch index passes the cut-off.
        use_tf = not (epoch > (tf_rate * n_epochs))

        encoder.train()
        decoder.train()

        train_loss, test_loss = train_epoch(
            train_loader, test_loader, encoder, decoder,
            encoder_optimizer, decoder_optimizer, criterion, use_tf,
        )
        running_train_loss += train_loss
        running_test_loss += test_loss

        # Cross- and self-attention plots for the molecule at index `r`.
        for i in range(5):
            visualize(encoder, decoder, smi_list[r], smi_dic, longest_smi, mode="cross", path=f"{visual_path}", name=f"R{i}-CROSS-E{epoch}")
            visualize(encoder, decoder, smi_list[r], smi_dic, longest_smi, mode="self", path=f"{visual_path}", name=f"R{i}-SELF-E{epoch}")

        if epoch % print_every != 0:
            continue

        train_loss_avg = running_train_loss / print_every
        test_loss_avg = running_test_loss / print_every
        running_train_loss = 0
        running_test_loss = 0
        progress = epoch / n_epochs
        print('%s (%d %d%%) /// Train loss: %.4f - Test loss: %.4f' % (timeSince(start, progress),
                                    epoch, progress * 100, train_loss_avg, test_loss_avg))


In [63]:
# Hyperparameters shared by the encoder and the decoder.
DIM_MODEL = 128  # embedding / hidden feature size
NUM_BLOCK = 4    # stacked blocks in each of encoder and decoder
NUM_HEAD = 4     # attention heads per block
DROPOUT = 0.5    # dropout probability handed to the modules
FE = 2           # feed-forward hidden width = FE * DIM_MODEL

# The encoder is built first, then the decoder; both are moved to `device`.
encoder = Encoder(
    dim_model=DIM_MODEL,
    num_block=NUM_BLOCK,
    num_head=NUM_HEAD,
    len_dic=len(smi_dic),
    fe=FE,
    dropout=DROPOUT,
).to(device)

decoder = Decoder(
    dim_model=DIM_MODEL,
    num_block=NUM_BLOCK,
    num_head=NUM_HEAD,
    longest_coor=longest_coor,
    fe=FE,
    dropout=DROPOUT,
).to(device)

In [64]:
# Launch training. With tf_rate=0.0 the teacher-forcing cut-off is epoch 0, so
# every epoch calls the decoder without a target.
train(
    train_loader,
    test_loader,
    encoder,
    decoder,
    n_epochs=50,
    learning_rate=0.001,
    visual_path="attention image",
    tf_rate=0.0,
)

  return torch.tensor(self.x[idx], dtype = torch.long, device=device), torch.tensor(self.y[idx], device = device)


W: torch.Size([16, 36, 128])
U: torch.Size([16, 36, 128])
Reshaped W: torch.Size([16, 4, 36, 32])
Reshaped U: torch.Size([16, 4, 36, 32])
attn_score: torch.Size([16, 4, 36, 1])
attn_score: torch.Size([16, 4, 1, 36])
attn_distribution: torch.Size([16, 4, 1, 36])
attn_distribution: torch.Size([16, 4, 36])
attn: torch.Size([16, 1, 128])
W: torch.Size([16, 36, 128])
U: torch.Size([16, 36, 128])
Reshaped W: torch.Size([16, 4, 36, 32])
Reshaped U: torch.Size([16, 4, 36, 32])
attn_score: torch.Size([16, 4, 36, 1])
attn_score: torch.Size([16, 4, 1, 36])
attn_distribution: torch.Size([16, 4, 1, 36])
attn_distribution: torch.Size([16, 4, 36])
attn: torch.Size([16, 1, 128])
W: torch.Size([16, 36, 128])
U: torch.Size([16, 36, 128])
Reshaped W: torch.Size([16, 4, 36, 32])
Reshaped U: torch.Size([16, 4, 36, 32])
attn_score: torch.Size([16, 4, 36, 1])
attn_score: torch.Size([16, 4, 1, 36])
attn_distribution: torch.Size([16, 4, 1, 36])
attn_distribution: torch.Size([16, 4, 36])
attn: torch.Size([16, 1

KeyboardInterrupt: 