In [2]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import time
import random
from utils.data_preprocess import train_loader, test_loader, smi_list, smint_list, smi_dic, coor_list, np_coor_list, longest_coor, longest_smi, device
from utils.helper import visualize, timeSince

Data successfully extracted
----------------------------------------
Size: 4255
Longest SMILES: 36
Longest Coordinate: 22
----------------------------------------

Below is one example for each variable: 


smi_dic (SMILES Dictionary):
	 {'x': 0, 'E': 1, 'C': 2, 'n': 3, '1': 4, 'c': 5, '(': 6, ')': 7, 'B': 8, 'O': 9, 'F': 10, 'S': 11, '=': 12, 'o': 13, 'N': 14, '#': 15, '/': 16, 'l': 17, '\\': 18, '2': 19, '[': 20, 'H': 21, ']': 22, '+': 23, '-': 24, 's': 25, '.': 26, 'K': 27, '3': 28, 'r': 29, 'P': 30, 'a': 31}


smi_list (SMILES List):
	 Cn1ncc(c1)B1OC(C(O1)(C)C)(C)CE


smint_list (SMILES Integer List):
	 [2, 3, 4, 3, 5, 5, 6, 5, 4, 7, 8, 4, 9, 2, 6, 2, 6, 9, 4, 7, 6, 2, 7, 2, 7, 6, 2, 7, 2, 1, 0, 0, 0, 0, 0, 0]


coor_list (Coordinate List):
	 [[4.8285, -1.004, 0.2024], [3.5776, -0.2572, 0.0479], [3.4435, 1.1346, 0.1047], [2.1893, 1.445, -0.0747], [1.4645, 0.2475, -0.2554], [2.3676, -0.7919, -0.1777], [-0.0805, 0.1225, -0.5047], [-1.0404, 1.1849, -0.5963], [-2.2048, 0.6858, 0.0949],

In [3]:
class Attention(nn.Module):
    def __init__(self, dim_model):
        super(Attention, self).__init__()
        self.Wa = nn.Linear(dim_model, dim_model)
        self.Ua = nn.Linear(dim_model, dim_model)
        self.Va = nn.Linear(dim_model, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights # context : attention, weights : distribution

In [12]:
class Encoder(nn.Module) :
  def __init__(self, input_size, dim_model, dropout_p = 0.1) :
    super(Encoder, self).__init__()

    self.self_attn = Attention(dim_model)

    self.embedding = nn.Embedding(input_size, dim_model) # input_size : num words in input language dictionary - dim_model : dimension that we want to map to

    self.gru = nn.GRU(dim_model, dim_model, batch_first = True)

    self.dropout = nn.Dropout(dropout_p)

    self.out = nn.Linear(dim_model, dim_model)

  def forward(self, input) :
    embedded = self.dropout(self.embedding(input))

    attn, self_attn = self.self_attn(embedded, embedded)
    # print(f'self_attn: {self_attn.shape}')
    all_states, last_state = self.gru(embedded + attn)

    return all_states, last_state, self_attn

In [13]:
class Decoder(nn.Module):
  def __init__(self, dim_model, longest_coordinate, output_size = 3) :
    super(Decoder, self).__init__()

    self.attention = Attention(dim_model)
    self.gru = nn.GRU(3 + dim_model, dim_model, batch_first = True)
    self.out = nn.Linear(dim_model, output_size)
    self.longest_coor = longest_coordinate
    self.dropout = nn.Dropout(p=0.1)

  def forward(self, e_all, e_last, target_tensor = None) :
    batch_size = e_all.size(0)

    d_input = torch.zeros(batch_size, 1, 3).to(device)

    d_hidden = e_last

    d_outputs = []
    attentions = []

    for i in range(self.longest_coor) :
      d_output, d_hidden, a_weights = self.forward_step(d_input, d_hidden, e_all)

      d_outputs.append(d_output)
      attentions.append(a_weights)

      if target_tensor is not None :
        d_input = target_tensor[:, i, :].unsqueeze(1)
      else :
        d_input = d_output


    d_outputs = torch.cat(d_outputs, dim = 1)
    attentions = torch.cat(attentions, dim = 1)

    return d_outputs, attentions

  def forward_step(self, d_input, d_hidden, e_all) : # d_hidden : last encoder state
    query = d_hidden.permute(1,0,2)

    d_input = self.dropout(d_input) 

    attn, cross_attn = self.attention(query, e_all)
    
    input_gru = torch.cat((d_input, attn), dim = 2)

    output, d_hidden = self.gru(input_gru, d_hidden)

    output = self.out(output) 
    return output, d_hidden, cross_attn

In [14]:
import random
r = random.randint(1, len(smi_list))

def train_epoch(train_loader,test_loader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, tf):

    total_loss = 0
    total_test_loss = 0

    for input, target in train_loader:
        input, target = input.to(device), target.to(device)

        encoder_optimizer.zero_grad(), decoder_optimizer.zero_grad()
        
        e_all, e_last, self_attn = encoder(input)

        # Teacher Forcing
        if tf :
          prediction, cross_attn = decoder(e_all, e_last, target)
        else :
          prediction, cross_attn = decoder(e_all, e_last)
        print(f'self_attn: {self_attn.shape}')
        print(f'cross_attn: {cross_attn.shape}')

        loss = criterion(prediction, target)
        loss.backward()

        encoder_optimizer.step(), decoder_optimizer.step()
        
        total_loss += loss.item()


    encoder.eval(), decoder.eval()
    


    with torch.no_grad() :
      for input, target in test_loader :
        input, target = input.to(device), target.to(device)
        
        e_all, e_last, self_attn = encoder(input)
        prediction, cross_attn = decoder(e_all, e_last)

        test_loss = criterion(prediction, target)
        total_test_loss += test_loss.item()

    return total_loss / len(train_loader), total_test_loss / len(test_loader)


def train(train_loader, test_loader, encoder, decoder, n_epochs, learning_rate=0.001,
               print_every=1, visual_path= "", tf_rate = 1):
    start = time.time()

    train_loss_total = 0  
    test_loss_total = 0

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.L1Loss()

    tf = True

    for epoch in range(1, n_epochs + 1):
      if epoch > (tf_rate * n_epochs) :
        tf = False
      encoder.train()
      decoder.train()

      train_loss, test_loss = train_epoch(train_loader, test_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, tf)
      train_loss_total += train_loss
      test_loss_total += test_loss

      # for i in range(5) :
      #    visualize(encoder, decoder, smi_list[r], smi_dic, longest_smi, mode="cross", path=f"{visual_path}", name=f"R{i}-CROSS-E{epoch}")
      #    visualize(encoder, decoder, smi_list[r], smi_dic, longest_smi, mode="self", path=f"{visual_path}", name=f"R{i}-SELF-E{epoch}")

      if epoch % print_every == 0:
          train_loss_avg = train_loss_total / print_every
          test_loss_avg = test_loss_total / print_every
          train_loss_total = 0
          test_loss_total = 0
          print('%s (%d %d%%) /// Train loss: %.4f - Test loss: %.4f' % (timeSince(start, epoch / n_epochs),
                                      epoch, epoch / n_epochs * 100, train_loss_avg, test_loss_avg))


In [15]:
encoder = Encoder(input_size=len(smi_dic),
                  dim_model=256,
                  dropout_p=0.1)

decoder = Decoder(dim_model=256,
                  longest_coordinate=longest_coor)

In [16]:
train(train_loader, test_loader, encoder, decoder,
      n_epochs=100,
      learning_rate=0.001,
      tf_rate = 0.0,
      visual_path="attention image",
      )

self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 22, 36])
self_attn: torch.Size([16, 1, 36])
cross_attn: torch.Size([16, 2

KeyboardInterrupt: 