In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from transformers import BertModel, BertTokenizer

import pandas as pd
from tqdm import tqdm

# Tokenizers: a locally stored vocabulary for the molecule strings and the
# "Rostlab/prot_bert" vocabulary (case preserved) for protein sequences.
molecule_tokenizer = BertTokenizer.from_pretrained("data/drug/molecule_tokenizer", model_max_length=128)
protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

# One candidate molecule per line; readlines() keeps each line's trailing newline.
with open("data/drug/molecule_qed_filtered.txt", 'r') as f:
    data = f.readlines()

In [11]:
def protein_encode(protein_sequence):
    """Tokenize a protein sequence with the module-level ``protein_tokenizer``.

    Args:
        protein_sequence: Iterable of strings (typically a ``str`` of residue
            letters). Its items are joined with single spaces before being
            handed to the tokenizer.

    Returns:
        Whatever the tokenizer call returns for the space-separated sequence,
        truncated to at most 512 tokens.
    """
    spaced_residues = " ".join(protein_sequence)
    return protein_tokenizer(spaced_residues, max_length=512, truncation=True)

# Residue sequence of the single target protein used for every prediction
# (named `mlkl`; presumably MLKL — TODO confirm the source/accession).
mlkl = "GSPGENLKHIITLGQVIHKRCEEMKYCKKQCRRLGHRVLGLIKPLEMLQDQGKRSVPSEKLTTAMNRFKAALEEANGEIEKFSNRSNICRFLTASQDKILFKDVNRKLSDVWKELSLLLQVEQRMPVSPISQGASWAQEDQQDADEDRRAFQMLRRD"
# Encode once up front; the same encoding is reused for every molecule.
mlkl_ = protein_encode(mlkl)


class DTIDataset(Dataset):
    """Pairs every molecule in ``data`` with one fixed, pre-encoded protein.

    Args:
        data: Indexable collection of molecule strings.
        molecule_tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., truncation=True)``.
        mlkl_: Pre-computed protein encoding returned unchanged with each item.
    """

    def __init__(self, data, molecule_tokenizer, mlkl_):
        self.data = data
        self.molecule_tokenizer = molecule_tokenizer
        self.mlkl_ = mlkl_

        # Upper bound on the number of tokens kept per molecule.
        self.molecule_max_len = 100

    def __len__(self):
        """Number of molecules available."""
        return len(self.data)

    def molecule_encode(self, molecule_sequence):
        """Space-separate the items of ``molecule_sequence`` and tokenize the result."""
        spaced_symbols = " ".join(molecule_sequence)
        return self.molecule_tokenizer(
            spaced_symbols,
            max_length=self.molecule_max_len,
            truncation=True,
        )

    def __getitem__(self, idx):
        """Return ``(encoded molecule at idx, shared protein encoding)``."""
        return self.molecule_encode(self.data[idx]), self.mlkl_

    
def collate_batch(batch):
    """Pad a list of ``(molecule encoding, protein encoding)`` pairs into tensors.

    Molecule encodings are padded with the module-level ``molecule_tokenizer``
    and protein encodings with ``protein_tokenizer``; both are returned as
    PyTorch tensors (``return_tensors="pt"``).

    Returns:
        Tuple ``(padded molecules, padded proteins)``.
    """
    molecule_items = [molecule_item for molecule_item, _ in batch]
    protein_items = [protein_item for _, protein_item in batch]

    padded_molecules = molecule_tokenizer.pad(molecule_items, return_tensors="pt")
    padded_proteins = protein_tokenizer.pad(protein_items, return_tensors="pt")

    return padded_molecules, padded_proteins


# Prediction set: every line of `data` paired with the pre-encoded protein `mlkl_`.
predict_dataset = DTIDataset(data, molecule_tokenizer, mlkl_)
# shuffle=False keeps batches in file order; a later cell indexes the molecule
# file with positions taken from the predictions, so the order matters.
predict_dataloader = DataLoader(predict_dataset, batch_size=3072, num_workers=16, 
                                pin_memory=True, prefetch_factor=10, 
                                collate_fn=collate_batch, shuffle=False)


In [3]:
class CrossAttention(nn.Module):
    """Multi-head attention in which only the first (CLS) token issues a query.

    Keys and values are computed from every token of the input, while the
    query is computed from token 0 alone, so the output holds a single token.

    Args:
        input_dim: Feature width of the incoming tokens and of the output.
        intermediate_dim: Total width of the key/value/query projections
            (split evenly across ``heads``).
        heads: Number of attention heads.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, input_dim=128, intermediate_dim=512, heads=8, dropout=0.1):
        super().__init__()
        project_out = input_dim

        self.heads = heads
        # NOTE(review): the scale is derived from input_dim / heads rather than
        # the per-head projection width intermediate_dim / heads. It is kept
        # as-is because changing it would alter the output of trained weights.
        self.scale = (input_dim / heads) ** -0.5

        self.key = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.value = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.query = nn.Linear(input_dim, intermediate_dim, bias=False)

        self.out = nn.Sequential(
            nn.Linear(intermediate_dim, project_out),
            nn.Dropout(dropout)
        )

    def forward(self, data):
        """Attend from the first token of ``data`` to all of its tokens.

        Args:
            data: Tensor of shape ``(batch, tokens, input_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, input_dim)``.

        Raises:
            ValueError: If ``data`` is not 3-dimensional.
        """
        # Explicit rank check; previously enforced only implicitly by
        # unpacking data.shape into otherwise unused locals.
        if data.dim() != 3:
            raise ValueError(
                "CrossAttention expects a 3-D (batch, tokens, features) tensor, "
                f"got shape {tuple(data.shape)}"
            )
        h = self.heads

        k = self.key(data)
        k = rearrange(k, 'b n (h d) -> b h n d', h=h)

        v = self.value(data)
        v = rearrange(v, 'b n (h d) -> b h n d', h=h)

        # Only the first (CLS) token acts as the query.
        q = self.query(data[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h=h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attention = dots.softmax(dim=-1)

        output = einsum('b h i j, b h j d -> b h i d', attention, v)
        output = rearrange(output, 'b h n d -> b n (h d)')
        output = self.out(output)

        return output


class PreNorm(nn.Module):
    """Layer-normalize the input over its last ``dim`` features, then apply ``fn``.

    Args:
        dim: Size of the last dimension to normalize.
        fn: Callable (usually a module) that receives the normalized tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

    
class CrossAttentionLayer(nn.Module):
    """Let the molecule and protein CLS tokens attend to each other's tokens.

    Each of the ``cross_attn_depth`` blocks holds six modules in a fixed
    order (the order determines the state-dict keys):

    0. Linear  molecule_dim -> protein_dim   (molecule CLS into protein space)
    1. Linear  protein_dim  -> molecule_dim  (back to molecule space)
    2. PreNorm + CrossAttention in protein space
    3. Linear  protein_dim  -> molecule_dim  (protein CLS into molecule space)
    4. Linear  molecule_dim -> protein_dim   (back to protein space)
    5. PreNorm + CrossAttention in molecule space
    """

    def __init__(self,
                 molecule_dim=128, molecule_intermediate_dim=256,
                 protein_dim=1024, protein_intermediate_dim=2048,
                 cross_attn_depth=1, cross_attn_heads=4, dropout=0.1):
        super().__init__()

        def build_block():
            # Creation order is kept identical to the listing in the class docstring.
            return nn.ModuleList([
                nn.Linear(molecule_dim, protein_dim),
                nn.Linear(protein_dim, molecule_dim),
                PreNorm(protein_dim, CrossAttention(
                    protein_dim, protein_intermediate_dim, cross_attn_heads, dropout
                )),
                nn.Linear(protein_dim, molecule_dim),
                nn.Linear(molecule_dim, protein_dim),
                PreNorm(molecule_dim, CrossAttention(
                    molecule_dim, molecule_intermediate_dim, cross_attn_heads, dropout
                )),
            ])

        self.cross_attn_layers = nn.ModuleList(
            [build_block() for _ in range(cross_attn_depth)]
        )

    def forward(self, molecule, protein):
        """Update both CLS tokens; all other tokens pass through unchanged.

        Args:
            molecule: Tensor indexed as ``(batch, tokens, molecule_dim)``.
            protein: Tensor indexed as ``(batch, tokens, protein_dim)``.

        Returns:
            Tuple ``(molecule, protein)`` with token 0 of each replaced.
        """
        for (mol_to_prot, prot_to_mol_out, attend_protein_tokens,
             prot_to_mol, mol_to_prot_out, attend_molecule_tokens) in self.cross_attn_layers:

            mol_cls = molecule[:, 0].unsqueeze(1)
            mol_tokens = molecule[:, 1:]
            prot_cls = protein[:, 0].unsqueeze(1)
            prot_tokens = protein[:, 1:]

            # Protein CLS is projected into molecule space and attends over the molecule tokens.
            query = prot_to_mol(prot_cls)
            attended = query + attend_molecule_tokens(torch.cat((query, mol_tokens), dim=1))
            new_prot_cls = F.gelu(mol_to_prot_out(attended))
            protein = torch.cat((new_prot_cls, prot_tokens), dim=1)

            # Molecule CLS is projected into protein space and attends over the protein tokens.
            query = mol_to_prot(mol_cls)
            attended = query + attend_protein_tokens(torch.cat((query, prot_tokens), dim=1))
            new_mol_cls = F.gelu(prot_to_mol_out(attended))
            molecule = torch.cat((new_mol_cls, mol_tokens), dim=1)

        return molecule, protein
    
    
class AttentionalDTI(nn.Module):
    """Scores a molecule/protein pair from their cross-attended CLS tokens.

    Args:
        molecule_encoder: Encoder exposing ``.encoder.layer`` and returning an
            object with ``last_hidden_state`` when called.
        protein_encoder: Same contract as ``molecule_encoder``.
        cross_attention_layer: Callable taking the two hidden-state tensors and
            returning an updated ``(molecule, protein)`` pair.
        molecule_input_dim: Feature width of the molecule hidden states.
        protein_input_dim: Feature width of the protein hidden states.
        hidden_dim: Width of the shared projection space.
        **kwargs: Accepted but not used.
    """

    def __init__(self,
                 molecule_encoder, protein_encoder, cross_attention_layer,
                 molecule_input_dim=128, protein_input_dim=1024, hidden_dim=512, **kwargs):
        super().__init__()
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze every encoder layer except the last one.
        for pretrained in (self.molecule_encoder, self.protein_encoder):
            for weight in pretrained.encoder.layer[0:-1].parameters():
                weight.requires_grad = False

        self.cross_attention_layer = cross_attention_layer

        self.molecule_mlp = nn.Sequential(
            nn.LayerNorm(molecule_input_dim),
            nn.Linear(molecule_input_dim, hidden_dim),
        )
        self.protein_mlp = nn.Sequential(
            nn.LayerNorm(protein_input_dim),
            nn.Linear(protein_input_dim, hidden_dim),
        )
        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, molecule_seq, protein_seq):
        """Return one score per pair, shaped ``(batch, 1)``.

        Both arguments are mappings unpacked as keyword arguments into their
        respective encoder.
        """
        molecule_hidden = self.molecule_encoder(**molecule_seq).last_hidden_state
        protein_hidden = self.protein_encoder(**protein_seq).last_hidden_state

        molecule_out, protein_out = self.cross_attention_layer(molecule_hidden, protein_hidden)

        # Only the CLS token (position 0) of each sequence feeds the head.
        molecule_projected = self.molecule_mlp(molecule_out[:, 0])
        protein_projected = self.protein_mlp(protein_out[:, 0])

        return self.fc_out(molecule_projected + protein_projected)

molecule_bert = BertModel.from_pretrained("weights/molecule_bert")
protein_bert = BertModel.from_pretrained("weights/protein_bert")
# The cross-attention depth is set here, on CrossAttentionLayer. A
# `cross_attn_depth` keyword passed to AttentionalDTI is swallowed by its
# **kwargs and never read, so it was dropped from the call below. Depth 1 is
# the CrossAttentionLayer default, i.e. exactly what this code always built.
# NOTE(review): changing the depth alters the parameter set, so it must agree
# with the checkpoint loaded further down.
cross_attention_layer = CrossAttentionLayer(cross_attn_depth=1)
attentional_dti = AttentionalDTI(molecule_bert, protein_bert, cross_attention_layer)


Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at weights/protein_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
class DTI_prediction(pl.LightningModule):
    """Lightning wrapper that exposes the DTI network for batched prediction.

    Args:
        attentional_dti: The network to wrap; stored as ``self.model``.
    """

    def __init__(self, attentional_dti):
        super().__init__()
        self.model = attentional_dti

    def forward(self, molecule_sequence, protein_sequence):
        """Delegate to the wrapped network and return its raw output."""
        scores = self.model(molecule_sequence, protein_sequence)
        return scores

    def predict_step(self, batch, batch_idx):
        """Score one ``(molecules, proteins)`` batch and drop the trailing axis."""
        molecule_batch, protein_batch = batch
        return self(molecule_batch, protein_batch).squeeze(-1)

    
# Wrap the network so a Lightning Trainer can drive inference.
model = DTI_prediction(attentional_dti)
# NOTE(review): this notebook only calls trainer.predict, so max_epochs=50
# presumably has no effect here — confirm before relying on it.
trainer = pl.Trainer(max_epochs=50, gpus=1, enable_progress_bar=True)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [5]:
ckpt_fname = "attentional_dti-epoch=049-valid_loss=0.1788-valid_mae=0.2497.ckpt"

# load_from_checkpoint is a classmethod that builds and returns a new module,
# so it is called on the class rather than on the existing instance (newer
# Lightning releases reject the instance call). `attentional_dti` is forwarded
# to DTI_prediction.__init__.
model = DTI_prediction.load_from_checkpoint("weights/Attentional_DTI_cross_attention_kiba/" + ckpt_fname, attentional_dti=attentional_dti)

In [None]:
# Run DTI_prediction.predict_step over every batch of predict_dataloader.
# NOTE(review): `pred` is never written to disk in this notebook, yet a later
# cell loads "data/interaction/prediction.pkl" — presumably these predictions
# were saved elsewhere; confirm so a fresh Run All is reproducible.
pred = trainer.predict(model, predict_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [26]:
import numpy as np
import pickle 

# Re-read the molecule file (one entry per line, trailing newlines kept) and
# turn it into an array so it can be indexed with an array of positions.
with open("data/drug/molecule_qed_filtered.txt", 'r') as f:
    mols = f.readlines()
mols = np.array(mols)
    
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
# NOTE(review): this rebinds `data`, which earlier held the molecule lines.
with open("data/interaction/prediction.pkl", 'rb') as f:
    data = pickle.load(f)

In [2]:
from tqdm import tqdm

# Flatten the nested predictions in `data` into one flat array, keeping the
# original order (outer items first to last, inner items first to last).
results = np.array([
    score
    for batch_scores in tqdm(data)
    for score in batch_scores
])

100%|████████████████████████████████████| 11885/11885 [01:18<00:00, 151.24it/s]


In [42]:
# argsort is ascending; reversing it ranks positions from highest to lowest score.
idx = results.argsort()[::-1]
# Keep the 20 highest-scoring positions.
sample_idx = idx[:20]

In [43]:
# Scores of the selected entries, rounded to four decimals.
for score in results[sample_idx]:
    print(score.round(4))

# The matching molecule lines, with their newline characters removed.
for smiles_line in mols[sample_idx]:
    print(smiles_line.replace("\n", ""))

15.0683
15.0277
15.005
14.9328
14.9293
14.9052
14.8696
14.8621
14.8332
14.8302
14.8173
14.7982
14.7962
14.7945
14.7942
14.7875
14.7864
14.7821
14.7805
14.7757
C[N+]1(C)CCC23c4c5ccc(C(N)=O)c4OC2C(=O)CCC3(O)C1C5
CCC[N+]1(C)CCC23c4c5ccc(C(N)=O)c4OC2C(=O)CCC3(O)C1C5
C[N+]1(C)CCC23c4c5ccc(C(N)=O)c4OC2C(O)CCC3C1C5
C#[N+]C(=O)c1nnc2[nH]ccc2c1NC12CC3CC4(O)CC(C1)C32C4
CC1CC(C#N)N(C(=O)C(N)C2=C3C4CC(C2)CC3(O)C4)C1
CC1CN(c2n[nH]c(C3C4C5CCC(C5)C34)n2)CCN1
COC1=C(C)C(=O)OC1=C1OC23OC4CC(C2C1C)N1CCC3C41CNC(C)C
CNCc1c(S(=O)(=O)NC2C3C4CCC(C4)C23)n[nH]c1C
NC1CCN(c2n[nH]c(C3C4C5CCC(C5)C34)n2)CC1
CCn1c(C2C3C4CCC(C4)C23)nnc1S(N)(=O)=O
CC1=C(C)c2c(C)c(C)c3c4c(c(C)c(C)c1c24)C([O-])=C3[O-]
CC.CC1=CC(C)(C)Nc2ccc3c4c(oc(=O)c3c21)=CCC=4.[HH]
CC1CN(c2n[nH]c(C3C4C5CCC(C5)C34)n2)CC(C)N1
COC1=C(C)C(=O)OC1=C1OC23OC4CC(C2C1C)N1CCC3C41CNC1CC1
C1CNCCN(c2n[nH]c(C3C4C5CCC(C5)C34)n2)C1
COC1C(C(C)=O)CC2C3[NH+](C)CC34CC23c2c4ccc(O)c2OC13
C[N+]1(CC2CC2)CC23c4c5ccc(O)c4OC2C(=O)CCC3(O)C1C5
NCC1CCN(c2n[nH]c(C3C4C5CCC(C5)C34)n2)C

In [41]:
from rdkit import Chem

# Parse one of the top-ranked SMILES strings and re-emit it as a canonical,
# kekulized SMILES; the last expression is displayed as the cell's output.
m = Chem.MolFromSmiles('C[N+]1(C)CCC23c4c5ccc(C(N)=O)c4OC2C(O)CCC3C1C5')
Chem.MolToSmiles(m, canonical=True, kekuleSmiles=True)

'C[N+]1(C)CCC23C4=C5C=CC(C(N)=O)=C4OC2C(O)CCC3C1C5'