In [1]:
import rdkit
import pandas as pd
import numpy as np 
import time
import math
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def get_all_atom(smi_path) :
    smi_dic = {}
    i = 0

    with open(smi_path, 'r') as file :
        for smi in file :
            smi = rdkit.Chem.MolFromSmiles(smi)
            for atom in smi.GetAtoms() :
                atom = atom.GetSymbol()
                if atom not in smi_dic : 
                    smi_dic[atom] = i 
                    i += 1
    return smi_dic

def replace_atom(smi_list) :
    smi_list = [smi.replace('Cl', 'X')
                   .replace('Br', 'Y')
                   .replace('Na', 'Z')
                   .replace('Ba', 'T') for smi in smi_list]
    return smi_list

def get_longest(input_list) :
    longest = 0
    for i in input_list :
        if len(i) > longest :
            longest = len(i)
    return longest

def get_coor(coor_path) :
    coor_list = []
    supplier = rdkit.Chem.SDMolSupplier(coor_path)
    for mol in supplier:
        coor = []
        if mol is not None:
            conformer = mol.GetConformer()
            for atom in mol.GetAtoms():
                atom_idx = atom.GetIdx()
                x, y, z = conformer.GetAtomPosition(atom_idx)
                coor_atom = list((x,y,z))
                coor.append(coor_atom)
        coor_list.append(coor)

    # Replace invalid idx
    for i, coor in enumerate(coor_list):
        
        if len(coor) == 0 :
            if i == 0 :
                coor_list = coor_list[1:]
            coor_list[i] = coor_list[i-1]
    return coor_list

def get_xyz(coor_list) :
    X, Y, Z = [], [], []
    for coor in coor_list :
        x_list, y_list, z_list = [], [], []
        for x, y, z in coor :
            x_list.append(x)
            y_list.append(y)
            z_list.append(z)
        X.append(x_list)
        Y.append(y_list)
        Z.append(z_list)
    return  X, Y, Z

def get_smi(smi_path) :
    smi_list = []
    i = 0
    with open(smi_path, 'r') as file :
        for smi in file :
            if rdkit.Chem.MolFromSmiles(smi) is None :
                if len(smi_list) == 0 :
                    continue 
                smi_list.append(smi_list[i-1])
                i += 1 
                continue 
            smi_list.append(smi)
            i += 1
    
    smi_list = [smi[:-1] for smi in smi_list]
    smi_list = [smi + 'E' for smi in smi_list]
    smi_list = replace_atom(smi_list)
    return smi_list

def get_dic(smi_list) :
    smi_dic = {'x': 0,
               'E': 1}
    i = len(smi_dic)

    for smi in smi_list : 
        for atom in smi :
            if atom not in smi_dic : 
                smi_dic[atom] = i
                i += 1 
    return smi_dic 

def smi2int(smi, smi_dic, longest_smi) :
    smi = list(smi)
    smint = [smi_dic[atom] for atom in smi]
    smint = smint + [0] * (longest_smi - len(smint))
    return smint 

def normalize_coor(coor_list) :
    n_coor_list = []

    for mol_coor in coor_list :
        n_mol_coor = []

        x_origin, y_origin, z_origin = mol_coor[0]

        for atom_coor in mol_coor :
            n_atom_coor = [round(atom_coor[0] - x_origin, 2), 
                        round(atom_coor[1] - y_origin, 2), 
                        round(atom_coor[2] - z_origin, 2)]
            n_mol_coor.append(n_atom_coor)
        n_coor_list.append(n_mol_coor)
    return n_coor_list

def pad_coor(coor_list, longest_coor) :
    p_coor_list = []

    for i in coor_list :
        if len(i) < longest_coor :
            zeros = [[0,0,0]] * (longest_coor - len(i))
            zeros = torch.tensor(zeros)
            i = torch.tensor(i)
            i = torch.cat((i, zeros), dim = 0)
            p_coor_list.append(i)
        else :
            p_coor_list.append(i)
    return p_coor_list

def train_test_split(input, ratio = 0.9) :
    ratio = int(len(input) * ratio) 

    train = input[:ratio]
    test = input[ratio:]
    return train, test

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


In [3]:
class MultiEncoderDataset(Dataset) :
    def __init__(self, x, y, z, smint) :
        self.x = torch.tensor(x, device=device)
        self.y = torch.tensor(y, device=device)
        self.z = torch.tensor(z, device=device)
        self.smint = torch.tensor(smint, dtype=torch.long, device=device)
    
    def __len__(self) :
        return len(self.smint)

    def __getitem__(self, idx) :
        return self.x[idx], self.y[idx], self.z[idx], self.smint[idx]

In [4]:
smi_list = get_smi('../data/ADAGRASIB_UNIQUE_SMILES.txt')
coor_list = get_coor('../data/ADAGRASIB_COOR.sdf')
smi_dic = get_dic(smi_list)
longest_smi, longest_coor = get_longest(smi_list), get_longest(coor_list)
smint_list = [smi2int(smi, smi_dic, longest_smi) for smi in smi_list]
np_coor_list = pad_coor(normalize_coor(coor_list), longest_coor)
X, Y, Z = get_xyz(np_coor_list)
train_X, test_X = train_test_split(X)
train_Y, test_Y = train_test_split(Y)
train_Z, test_Z = train_test_split(Z)
train_smint, test_smint = train_test_split(smint_list)

In [5]:
B = 16
train_set = MultiEncoderDataset(train_X, train_Y, train_Z, train_smint)
test_set = MultiEncoderDataset(test_X, test_Y, test_Z, test_smint)

train_loader = DataLoader(train_set, batch_size=B, shuffle=True)
test_loader = DataLoader(test_set, batch_size=B, shuffle=True)

In [6]:
class NN_Attention(nn.Module): # Neural Network Attention 
    def __init__(self, dim_model):
        super(NN_Attention, self).__init__()
        self.Wa = nn.Linear(dim_model, dim_model)
        self.Ua = nn.Linear(dim_model, dim_model)
        self.Va = nn.Linear(dim_model, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)

        context = torch.bmm(weights, keys)

        return context, weights 
    

class DP_Attention(nn.Module) : # Dot Product Attention
    def __init__(self, dim_model, num_head) :
        super(DP_Attention, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.Q = nn.Linear(dim_model, dim_model)
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        self.out = nn.Linear(dim_model, dim_model)

    def forward(self, Q, K, V) :
        B = Q.size(0) # Shape Q, K, V: (B, longest_smi, dim_model)

        Q, K, V = self.Q(Q), self.K(K), self.V(V)

        len_Q, len_K, len_V = Q.size(1), K.size(1), V.size(1)

        Q = Q.reshape(B, self.num_head, len_Q, self.dim_head)
        K = K.reshape(B, self.num_head, len_K, self.dim_head)
        V = V.reshape(B, self.num_head, len_V, self.dim_head)
        
        K_T = K.transpose(2,3).contiguous()

        attn_score = Q @ K_T

        attn_score = attn_score / (self.dim_head ** 1/2)

        attn_distribution = torch.softmax(attn_score, dim = -1)

        attn = attn_distribution @ V

        attn = attn.reshape(B, len_Q, self.num_head * self.dim_head)
        
        attn = self.out(attn)

        return attn, attn_distribution

In [7]:
class Encoder(nn.Module) :
    def __init__(self, dim_model, num_head, dropout) :
        super(Encoder, self).__init__()

        self.Self_Attention = DP_Attention(dim_model, num_head)
        self.GRU = nn.GRU(2 * dim_model, dim_model, batch_first=True)
        self.Up_Size = nn.Linear(1, dim_model)
        self.Dropout = nn.Dropout(dropout)
        self.LayerNorm = nn.LayerNorm(dim_model)

    def forward(self, x) :
        x = x.unsqueeze(-1)

        x = self.Dropout(self.Up_Size(x))

        x = self.LayerNorm(x) 

        attn, self_attn = self.Self_Attention(x, x, x) 

        input_gru = torch.cat((attn, x), dim = -1) 

        e_all, e_last = self.GRU(input_gru)
        
        return e_all, e_last, self_attn


In [8]:
class Decoder(nn.Module) : 
    def __init__(self, dim_model, output_size, longest_smi, dropout) :
        super(Decoder, self).__init__()
        self.longest_smi = longest_smi
        self.Embedding = nn.Embedding(longest_smi, dim_model)

        self.xCross_Attention = NN_Attention(dim_model) 
        self.yCross_Attention = NN_Attention(dim_model) 
        self.zCross_Attention = NN_Attention(dim_model) 

        self.Dropout = nn.Dropout(dropout)
        self.GRU = nn.GRU(dim_model,dim_model, batch_first=True)
        self.Linear = nn.Linear(dim_model, output_size)

        self.LayerNorm1 = nn.LayerNorm(dim_model)
        self.LayerNorm2 = nn.LayerNorm(dim_model)


    def forward(self, xe_all, ye_all, ze_all, xe_last, ye_last, ze_last, target = None) :
        B = xe_all.size(0)

        d_input = torch.zeros(B, 1, dtype = torch.long, device = device)

        xdh = xe_last
        ydh = ye_last
        zdh = ze_last
        dh = None 

        outputs, cross_attn = [], [] 
        for i in range(self.longest_smi) :
            output, dh, step_attn = self.forward_step(d_input, xdh,ydh,zdh, xe_all, ye_all, ze_all, dh)

            outputs.append(output), cross_attn.append(step_attn)

            if target is not None :
                d_input = target[:, i].unsqueeze(1)
            else :
                _, topi = output.topk(1)
                d_input = topi.squeeze(-1).detach()

        outputs = torch.cat(outputs, dim = 1)
        outputs = F.log_softmax(outputs, dim = -1)

        # cross_attn = torch.cat(cross_attn, dim = 1)

        return outputs, cross_attn
    
    def forward_step(self, d_input, xdh, ydh, zdh, xe_all, ye_all, ze_all, dh) :
        embedded = self.Dropout(self.Embedding(d_input))

        x_Q = xdh.permute(1, 0, 2)
        y_Q = ydh.permute(1, 0, 2)
        z_Q = zdh.permute(1, 0, 2)

        x_attn, x_cross_attn = self.xCross_Attention(x_Q, xe_all)
        y_attn, y_cross_attn = self.yCross_Attention(y_Q, ye_all)
        z_attn, z_cross_attn = self.zCross_Attention(z_Q, ze_all)

        input_gru = embedded + x_attn + y_attn + z_attn 

        # input_gru = self.LayerNorm(input_gru)

        if dh is None :
            output, dh = self.GRU(input_gru, xdh + ydh + zdh)
        else : 
            output, dh = self.GRU(input_gru, dh) 

        output = self.Linear(output) 

        return output, dh, (x_cross_attn, y_cross_attn, z_cross_attn)


In [9]:
def train_epoch(train_loader,test_loader,
                xencoder, yencoder, zencoder, decoder,
                xencoder_optimizer, yencoder_optimizer, zencoder_optimizer, decoder_optimizer,
                criterion, tf):

    total_loss = 0
    total_test_loss = 0

    for x, y, z, target in train_loader:
        xencoder_optimizer.zero_grad(), yencoder_optimizer.zero_grad(), zencoder_optimizer.zero_grad(), decoder_optimizer.zero_grad()
        
        xe_all, xe_last, _ = xencoder(x)
        ye_all, ye_last, _ = yencoder(y)
        ze_all, ze_last, _ = zencoder(z)

        # Teacher Forcing
        if tf :
          prediction, _ = decoder(xe_all, ye_all, ze_all, xe_last, ye_last, ze_last, target)
        else :
          prediction, _ = decoder(xe_all, ye_all, ze_all, xe_last, ye_last, ze_last)

        print(f'prediction: {prediction.shape}')
        print(f'target: {target.shape}')
        loss = criterion(
           prediction.view(-1, prediction.size(-1)),
           target.view(-1)
        )

        loss.backward()

        xencoder_optimizer.step(), yencoder_optimizer.step(), zencoder_optimizer.step(), decoder_optimizer.step()
        
        total_loss += loss.item()


    xencoder.eval(), yencoder.eval(), zencoder.eval(), decoder.eval()

    with torch.no_grad() :
      for x, y , z, target in test_loader :
        xe_all, xe_last, _ = xencoder(x)
        ye_all, ye_last, _ = yencoder(y)
        ze_all, ze_last, _ = zencoder(z)
        prediction, cross_attn = decoder(xe_all, ye_all, ze_all, xe_last, ye_last, ze_last)

        test_loss = criterion(
           prediction.view(-1, prediction.size(-1)),
           target.view(-1)
        )
        total_test_loss += test_loss.item()

    return total_loss / len(train_loader), total_test_loss / len(test_loader)

def train(train_loader, test_loader, xencoder, yencoder, zencoder, decoder, n_epochs, learning_rate=0.001,
               print_every=1, visual_path= "", tf_rate = 1):
    start = time.time()

    train_loss_total = 0  
    test_loss_total = 0

    xencoder_optimizer = torch.optim.Adam(xencoder.parameters(), lr=learning_rate)
    yencoder_optimizer = torch.optim.Adam(yencoder.parameters(), lr=learning_rate)
    zencoder_optimizer = torch.optim.Adam(zencoder.parameters(), lr=learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.NLLLoss()

    tf = True

    for epoch in range(1, n_epochs + 1):
      if epoch > (tf_rate * n_epochs) :
        tf = False
      xencoder.train()
      yencoder.train()
      zencoder.train()
      decoder.train()

      train_loss, test_loss = train_epoch(train_loader, test_loader,
                                          xencoder, yencoder, zencoder, decoder,
                                          xencoder_optimizer, yencoder_optimizer,zencoder_optimizer, decoder_optimizer,
                                          criterion, tf)
      train_loss_total += train_loss
      test_loss_total += test_loss

      if epoch % print_every == 0:
          train_loss_avg = train_loss_total / print_every
          test_loss_avg = test_loss_total / print_every
          train_loss_total = 0
          test_loss_total = 0
          print('%s (%d %d%%) /// Train loss: %.4f - Test loss: %.4f' % (timeSince(start, epoch / n_epochs),
                                      epoch, epoch / n_epochs * 100, train_loss_avg, test_loss_avg))


In [10]:
DIM_MODEL = 256
NUM_HEAD = 4 
DROPOUT = 0.5 
NUM_EPOCHS = 100
LEARNING_RATE = 0.001
TF_RATE = 0.0

xencoder = Encoder(dim_model=DIM_MODEL, num_head=NUM_HEAD, dropout=DROPOUT)
yencoder = Encoder(dim_model=DIM_MODEL, num_head=NUM_HEAD, dropout=DROPOUT)
zencoder = Encoder(dim_model=DIM_MODEL, num_head=NUM_HEAD, dropout=DROPOUT)

decoder = Decoder(dim_model=DIM_MODEL, output_size=len(smi_dic), longest_smi=longest_smi, dropout=DROPOUT)

In [34]:
for x, y, z, smint in train_loader :
    print(smint)
    break

tensor([[18,  2,  6,  7,  2,  2,  7,  2,  8,  2,  7,  2,  6,  9,  3, 13,  5,  8,
          7, 12,  9,  8,  7, 12,  9,  4,  2,  2,  7,  2,  1,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  2,  4,  8,  2,  2,  9,  2,  2,  4,  2,  8,  7, 12,  9,  2,  8,  7,
         12,  9,  4,  2,  6,  7,  2,  8, 18,  9,  2,  7,  2,  2,  7,  2,  6,  1,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  2, 12,  2,  6,  7,  2,  2,  7,  2,  8,  2,  7,  2,  6,  9,  2, 11,
          7,  4,  4,  7,  2, 24,  2,  7,  2,  2,  8,  7,  4,  3,  4,  5, 11, 24,
          9, 13,  1,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  2,  6,  7,  2,  2,  7,  2,  8,  2,  7,  2,  6,  9,  3,  4,  5, 11,
          4,  7,  2,  2, 24,  7,  2, 11,  4,  7,  2,  8, 13,  9,  4,  2, 24,  7,
         12,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  2,  8, 12,  9,  2,  4,  2,  8,  7, 12,  9,  2,  8,  7, 12,  9,  4,
          2,  2,  6,  7,  2,  2,  7,  2,  8, 14,  9,  2

In [11]:
train(train_loader, test_loader,
      xencoder, yencoder, zencoder, decoder,
      n_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE, tf_rate=TF_RATE)

prediction: torch.Size([16, 47, 30])
target: torch.Size([16, 47])
prediction: torch.Size([16, 47, 30])
target: torch.Size([16, 47])
prediction: torch.Size([16, 47, 30])
target: torch.Size([16, 47])
prediction: torch.Size([16, 47, 30])
target: torch.Size([16, 47])
prediction: torch.Size([16, 47, 30])
target: torch.Size([16, 47])
prediction: torch.Size([16, 47, 30])
target: torch.Size([16, 47])


KeyboardInterrupt: 