In [1]:
import rdkit
import pandas as pd
import numpy as np 
import time
import math
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader 

In [2]:
DIM_MODEL = 128
NUM_BLOCK = 1
NUM_HEAD = 4
DROPOUT = 0.5
FE = 1

BATCH_SIZE = 16
NUM_EPOCHS = 50
TEACHER_FORCING_RATE = 0.0
LEARNING_RATE = 0.001
PATIENCE_THRESHOLD = 4
ATTENTION_IMAGE_OUTPUT_PATH = 'image'

SMILES_PATH = '../data/ADAGRASIB_UNIQUE_SMILES.txt'
COORDINATE_PATH = '../data/ADAGRASIB_COOR.sdf'

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_all_atom(smi_path) :
    smi_dic = {}
    i = 0

    with open(smi_path, 'r') as file :
        for smi in file :
            smi = rdkit.Chem.MolFromSmiles(smi)
            for atom in smi.GetAtoms() :
                atom = atom.GetSymbol()
                if atom not in smi_dic : 
                    smi_dic[atom] = i 
                    i += 1
    return smi_dic

def replace_atom(input, mode = 'train') :
    if mode == 'train' :
        smi_list = [smi.replace('Cl', 'X')
                    .replace('Br', 'Y')
                    .replace('Na', 'Z')
                    .replace('Ba', 'T') for smi in input]
        return smi_list
        
    if mode == 'eval' :
        smi = input.replace('X', 'Cl').replace('Y', 'Br').replace('Z', 'Na').replace('T', 'Ba').replace('x','')
        return smi

def get_longest(input_list) :
    longest = 0
    for i in input_list :
        if len(i) > longest :
            longest = len(i)
    return longest

def get_coor(coor_path) :
    coor_list = []
    supplier = rdkit.Chem.SDMolSupplier(coor_path)
    for mol in supplier:
        coor = []
        if mol is not None:
            conformer = mol.GetConformer()
            for atom in mol.GetAtoms():
                atom_idx = atom.GetIdx()
                x, y, z = conformer.GetAtomPosition(atom_idx)
                coor_atom = list((x,y,z))
                coor.append(coor_atom)
        coor_list.append(coor)

    # Replace invalid idx
    for i, coor in enumerate(coor_list):
        
        if len(coor) == 0 :
            if i == 0 :
                coor_list = coor_list[1:]
            coor_list[i] = coor_list[i-1]
    return coor_list

def get_smi(smi_path) :
    smi_list = []
    i = 0
    with open(smi_path, 'r') as file :
        for smi in file :
            if rdkit.Chem.MolFromSmiles(smi) is None :
                if len(smi_list) == 0 :
                    continue 
                smi_list.append(smi_list[i-1])
                i += 1 
                continue 
            smi_list.append(smi)
            i += 1
    
    smi_list = [smi[:-1] for smi in smi_list]
    smi_list = [smi + 'E' for smi in smi_list]
    smi_list = replace_atom(smi_list)
    return smi_list

def get_dic(smi_list) :
    smi_dic = {'x': 0,
               'E': 1}
    i = len(smi_dic)

    for smi in smi_list : 
        for atom in smi :
            if atom not in smi_dic : 
                smi_dic[atom] = i
                i += 1 
    return smi_dic 

def count_atoms(smi):
    mol = rdkit.Chem.MolFromSmiles(smi)
    if mol is not None:
        num_atoms = mol.GetNumAtoms()
        return num_atoms
    else:
        print("Error: Unable to parse SMILES string.")
        return None
    
def smi2int(smi, smi_dic, longest_smi, mode = 'data') :
    if mode == 'eval' :
        # smi += 'E'
        smi = smi + 'E' if smi[-1] != 'E' else smi
        smi = list(smi)
        smint = [smi_dic[atom] for atom in smi]
        smint = smint + [0] * (longest_smi - len(smint))
        smint = torch.tensor(smint, dtype=torch.long, device = device)
        smint = smint.unsqueeze(0)
        return smint
    smi = list(smi)
    smint = [smi_dic[atom] for atom in smi]
    smint = smint + [0] * (longest_smi - len(smint))
    return smint 

def int2smi(smint, inv_smi_dic) :
    smint = smint.cpu().numpy()
    
    smi = [inv_smi_dic[i] for i in smint] 
    smi = ''.join(smi)
    smi = replace_atom(smi, mode = 'eval')
    return smi


def normalize_coor(coor_list) :
    n_coor_list = []

    for mol_coor in coor_list :
        n_mol_coor = []

        x_origin, y_origin, z_origin = mol_coor[0]

        for atom_coor in mol_coor :
            n_atom_coor = [round(atom_coor[0] - x_origin, 2), 
                        round(atom_coor[1] - y_origin, 2), 
                        round(atom_coor[2] - z_origin, 2)]
            n_mol_coor.append(n_atom_coor)
        n_coor_list.append(n_mol_coor)
    return n_coor_list

def pad_coor(coor_list, longest_coor) :
    p_coor_list = []

    for i in coor_list :
        if len(i) < longest_coor :
            zeros = [[0,0,0]] * (longest_coor - len(i))
            zeros = torch.tensor(zeros)
            i = torch.tensor(i)
            i = torch.cat((i, zeros), dim = 0)
            p_coor_list.append(i)
        else :
            p_coor_list.append(i)
    return p_coor_list

def split_data(input, ratio = [0.9,0.05,0.05]) :
    assert sum(ratio) == 1, "Ratio does not add up to 1"  
    stop1 = int(len(input) * ratio[0])
    stop2 = int(len(input) * (ratio[0] + ratio[1])) 

    train = input[:stop1]
    val = input[stop1:stop2]
    test = input[stop2:]

    return train, val, test
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


def plot_attn(matrix, smi, mode, output_path = "", output_name = "", output_type = "show") :

    if mode == "cross" :
        fig = plt.figure()
        ax = fig.add_subplot(111)
        cax = ax.matshow(matrix, cmap = "viridis")
        fig.colorbar(cax)
        ax.set_xticklabels([''] + list(smi))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))


    if mode == "self" :
        num_head = matrix.shape[0]
        
        fig, ax = plt.subplots(1, num_head, figsize=(num_head*10,30))
        for i, head in enumerate(matrix) :
            cax = ax[i].matshow(head, cmap='viridis')
            # fig.colorbar(cax)
            ax[i].set_xticklabels([''] + list(smi), fontsize="xx-large")
            ax[i].set_yticklabels([''] + list(smi), fontsize='xx-large')
            
            ax[i].xaxis.set_major_locator(ticker.MultipleLocator(1))
            ax[i].yaxis.set_major_locator(ticker.MultipleLocator(1))
    if output_type == "show" :
        plt.show()
    if output_type == "save" :
        plt.savefig(f"{output_path}/{output_name}")
    plt.close()

def evaluate(idx,
            encoder, decoder,
            inv_smi_dic, smi_list, np_coor_list) :

    target_smi = smi_list[idx]
    input_coor = np_coor_list[idx]

    print(f'Target SMILES: {target_smi[:-1]}')

    input_coor = torch.tensor(input_coor, device=device).unsqueeze(0)
    with torch.no_grad() : 
        e_all, h, c, self_attn = encoder(input_coor)
        prediction, cross_attn = decoder(e_all, h, c)
        _, idx = torch.topk(prediction, 1) 
        
    idx = idx.squeeze(-1).squeeze(0).cpu().numpy()

    pred_smi = ''.join([inv_smi_dic[i] for i in idx]).replace('x','')

    print(f'Predicted SMILES: {pred_smi}')
    return pred_smi, self_attn.squeeze(0), cross_attn.squeeze(0)

# def visualize(idx,
#               encoder, decoder,
#               smi_list, np_coor_list, inv_smi_dic) : 
#     target_smi = smi_list[idx]
#     input_coor = np_coor_list[idx]

#     target_smi = target_smi[:-1] if target_smi[-1] == 'E' else target_smi
    
#     _, self_attn, cross_attn = evaluate(input_coor,
#                                         encoder, decoder,
#                                         inv_smi_dic)

#     smi_len = len(target_smi)
#     coor_len = len(input_coor)




In [4]:
smi_list = get_smi(SMILES_PATH)
smi_dic = get_dic(smi_list)
inv_smi_dic = {value:key for key, value in smi_dic.items()}
longest_smi = get_longest(smi_list)
smint_list = [smi2int(smi, smi_dic, longest_smi) for smi in smi_list]


coor_list = get_coor(COORDINATE_PATH)
longest_coor = get_longest(coor_list)
np_coor_list = pad_coor(normalize_coor(coor_list), longest_coor)


train_smint, val_smint, test_smint = split_data(smint_list)
train_coor, val_coor, test_coor = split_data(np_coor_list)

In [5]:
class Coor2SmiDataset(Dataset) :
    def __init__(self, coor_list, smint_list) :
        self.smint_list = torch.tensor(smint_list, dtype = torch.long, device=device)
        self.coor_list = [torch.tensor(coor, device=device) for coor in coor_list]

    def __len__(self) :
        return len(self.smint_list)
    
    def __getitem__(self, idx) :
        return self.coor_list[idx], self.smint_list[idx]

In [6]:
B = 16


train_set = Coor2SmiDataset(train_coor, train_smint)
val_set = Coor2SmiDataset(val_coor, val_smint)
test_set = Coor2SmiDataset(test_coor, test_smint)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

  self.coor_list = [torch.tensor(coor, device=device) for coor in coor_list]


In [7]:
class NN_Attention(nn.Module): # Neural Network Attention 
    def __init__(self, dim_model):
        super(NN_Attention, self).__init__()
        self.Wa = nn.Linear(dim_model, dim_model)
        self.Ua = nn.Linear(dim_model, dim_model)
        self.Va = nn.Linear(dim_model, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)

        context = torch.bmm(weights, keys)

        return context, weights 
    

class DP_Attention(nn.Module) : # Dot Product Attention
    def __init__(self, dim_model, num_head) :
        super(DP_Attention, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.Q = nn.Linear(dim_model, dim_model)
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        self.out = nn.Linear(dim_model, dim_model)

    def forward(self, Q, K, V) :
        B = Q.size(0) # Shape Q, K, V: (B, longest_smi, dim_model)

        Q, K, V = self.Q(Q), self.K(K), self.V(V)

        len_Q, len_K, len_V = Q.size(1), K.size(1), V.size(1)

        Q = Q.reshape(B, self.num_head, len_Q, self.dim_head)
        K = K.reshape(B, self.num_head, len_K, self.dim_head)
        V = V.reshape(B, self.num_head, len_V, self.dim_head)
        
        K_T = K.transpose(2,3).contiguous()

        attn_score = Q @ K_T

        attn_score = attn_score / (self.dim_head ** 1/2)

        attn_distribution = torch.softmax(attn_score, dim = -1)

        attn = attn_distribution @ V

        attn = attn.reshape(B, len_Q, self.num_head * self.dim_head)
        
        attn = self.out(attn)

        return attn, attn_distribution

In [8]:
class Encoder(nn.Module) :
    def __init__(self, dim_model, num_head, dropout) :
        super(Encoder, self).__init__()
        self.Self_Attention = DP_Attention(dim_model, num_head) 
        self.LSTM = nn.LSTM(2 * dim_model, dim_model, batch_first=True)
        self.Up_Size = nn.Linear(3, dim_model)
        self.Dropout = nn.Dropout(dropout)
    
    def forward(self, x) :
        x = self.Dropout(self.Up_Size(x))

        attn, self_attn = self.Self_Attention(x, x, x) 

        input_lstm = torch.cat((attn, x), dim = -1)

        e_all, (h, c) = self.LSTM(input_lstm)

        return e_all, h, c, self_attn

In [9]:
class Decoder(nn.Module) :
    def __init__(self, dim_model, num_head, output_size, longest_smi, dropout) :
        super(Decoder, self).__init__()
        self.longest_smi = longest_smi
        self.Embedding = nn.Embedding(longest_smi, dim_model)
        self.Cross_Attention = NN_Attention(dim_model) 
        self.Dropout = nn.Dropout(dropout)
        self.LSTM = nn.LSTM(2 * dim_model, dim_model, batch_first=True)
        self.Linear = nn.Linear(dim_model, output_size)
    def forward(self, e_all, e_h, e_c, target = None) :
        B = e_all.size(0)

        d_input = torch.zeros(B, 1, dtype=torch.long, device = device)

        d_h, d_c = e_h, e_c 

        outputs, cross_attn = [], [] 

        for i in range(self.longest_smi) : 
            output, d_h, d_c, step_attn = self.forward_step(d_input, d_h, d_c, e_all)

            outputs.append(output), cross_attn.append(step_attn)

            if target is not None :
                d_input = target[:, i].unsqueeze(1)
            else : 
                _, topi = output.topk(1)
                d_input = topi.squeeze(-1).detach()

        
        outputs = torch.cat(outputs, dim = 1)
        outputs = F.log_softmax(outputs, dim = -1) 

        cross_attn = torch.cat(cross_attn, dim = 1)

        return outputs, cross_attn

    def forward_step(self, d_input, d_h, d_c, e_all) :
        embedded = self.Dropout(self.Embedding(d_input))
        
        query = d_h.permute(1, 0, 2) + d_c.permute(1, 0, 2)

        attn, cross_attn = self.Cross_Attention(query, e_all)

        input_gru = torch.cat((embedded, attn), dim = 2)

        output, (d_h, d_c) = self.LSTM(input_gru, (d_h, d_c)) 

        output = self.Linear(output) 

        return output, d_h, d_c, cross_attn

In [10]:
def train_epoch(train_loader, val_loader, test_loader,
                encoder, decoder,
                encoder_optimizer, decoder_optimizer,
                criterion, tf):

    epoch_train_loss = 0
    epoch_test_loss = 0

    for input, target in train_loader:
        encoder_optimizer.zero_grad(), decoder_optimizer.zero_grad()
        
        e_all, h, c, self_attn = encoder(input)

        # Teacher Forcing
        if tf :
          prediction, cross_attn = decoder(e_all, h, c, target)
        else :
          prediction, cross_attn = decoder(e_all, h, c)


        loss = criterion(
           prediction.view(-1, prediction.size(-1)),
           target.view(-1)
        )

        loss.backward()

        encoder_optimizer.step(), decoder_optimizer.step()
        
        epoch_train_loss += loss.item()


    encoder.eval(), decoder.eval()

    with torch.no_grad() :
      for input, target in val_loader :
        e_all, h, c, self_attn = encoder(input)
        prediction, cross_attn = decoder(e_all, h, c)

        test_loss = criterion(
           prediction.view(-1, prediction.size(-1)),
           target.view(-1)
        )
        epoch_test_loss += test_loss.item()

    return epoch_train_loss / len(train_loader), epoch_test_loss / len(val_loader)

def train(train_loader, val_loader, test_loader,
          encoder, decoder, 
          patience_threshold, num_epoch=50, learning_rate=0.001, tf_rate = 0):
    start = time.time()
    
    best_val = float('inf')
    patience = 0

    train_plot, val_plot = [], []

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.NLLLoss()

    tf = True

    for epoch in range(1, num_epoch + 1):
      if epoch > (tf_rate * num_epoch) :
        tf = False
      encoder.train()
      decoder.train()

      train_loss, val_loss = train_epoch(train_loader,val_loader, test_loader,
                                         encoder, decoder,
                                         encoder_optimizer, decoder_optimizer,
                                         criterion, tf)

      print('%s (%d %d%%) /// Train Loss: %.4f - Validation Loss: %.4f' % (timeSince(start, epoch / num_epoch),
                                      epoch, epoch / num_epoch * 100, train_loss, val_loss))

      if val_loss < best_val :
        best_val = val_loss
        patience = 0
      else :
        patience += 1 
      
      if patience > patience_threshold : 
        print("EARLY STOPPING !!!")
        plt.plot(x, train_plot, color = 'blue', label = 'Train Loss')
        plt.plot(x, val_plot, color = 'red', label = 'Validation Loss')
        plt.title("Final Plot Before Loss")
        plt.legend()
        plt.show()
        break

      train_plot.append(train_loss), val_plot.append(val_loss)
      x = np.linspace(0, num_epoch, epoch)

      if epoch == 1 : 
        continue
      if epoch % 5 == 0 :
        plt.plot(x, train_plot, color = 'blue', label = 'Train Loss')
        plt.plot(x, val_plot, color = 'red', label = 'Validation Loss')
        plt.title(f'Epoch {epoch}')
        plt.legend()
        plt.show()


In [11]:
encoder = Encoder(DIM_MODEL, NUM_HEAD, DROPOUT).to(device)
decoder = Decoder(DIM_MODEL, NUM_HEAD, len(smi_dic),longest_smi, DROPOUT).to(device)

In [12]:
train(train_loader, val_loader, test_loader,
      encoder, decoder,
      num_epoch=NUM_EPOCHS,
      learning_rate=LEARNING_RATE,
      patience_threshold=PATIENCE_THRESHOLD,
      tf_rate = TEACHER_FORCING_RATE)

TypeError: train() missing 1 required positional argument: 'patience_threshold'

In [19]:
def evaluate(idx,
            encoder, decoder,
            inv_smi_dic, smi_list, np_coor_list) :

    target_smi = smi_list[idx]
    input_coor = np_coor_list[idx]


    input_coor = torch.tensor(input_coor, device=device).unsqueeze(0)
    with torch.no_grad() : 
        e_all, h, c, self_attn = encoder(input_coor)
        prediction, cross_attn = decoder(e_all, h, c)
        _, idx = torch.topk(prediction, 1) 
        
    idx = idx.squeeze(-1).squeeze(0).cpu().numpy()

    pred_smi = ''.join([inv_smi_dic[i] for i in idx]).replace('x','')

    
    print(f'Target SMILES: {target_smi[:-1]}')
    print(f'Predicted SMILES: {pred_smi}')

    
    return pred_smi, self_attn.squeeze(0), cross_attn.squeeze(0)

In [18]:
r = random.randint(0, len(smi_list))
out = evaluate(r, encoder, decoder, inv_smi_dic, smi_list,np_coor_list)

Target SMILES: CC(C)(CCC=C)C1NC(=O)NC1=O
Predicted SMILES: \TNN---.-3#KPP13NNNNNK/S--#++KKS/S-33NKK#B1+KST


  input_coor = torch.tensor(input_coor, device=device).unsqueeze(0)
