In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from transformers import BertModel, BertTokenizer

import pandas as pd
from tqdm import tqdm

# Character-level tokenizer for SMILES strings, loaded from a local directory.
# NOTE(review): model_max_length=128 here, while molecule_encode truncates at 100.
molecule_tokenizer = BertTokenizer.from_pretrained("data/drug/molecule_tokenizer", model_max_length=128)
# ProtBert tokenizer; do_lower_case=False leaves the uppercase amino-acid letters unchanged.
protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)


In [4]:
class CrossAttention(nn.Module):
    """Multi-head attention in which only the first (CLS) token issues a query.

    Keys and values are computed from every token of the input; the single
    query comes from position 0, so the output has sequence length 1.

    Args:
        input_dim: width of the input tokens and of the output.
        intermediate_dim: total q/k/v width across all heads (split evenly
            between ``heads``).
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, input_dim=128, intermediate_dim=512, heads=8, dropout=0.1):
        super().__init__()
        self.heads = heads
        # NOTE(review): standard scaled dot-product attention scales by the
        # per-head width (intermediate_dim / heads); this uses input_dim / heads.
        # Left unchanged because changing it alters outputs of trained weights.
        self.scale = (input_dim / heads) ** -0.5

        # Attribute names form the state_dict keys that checkpoints rely on.
        self.key = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.value = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.query = nn.Linear(input_dim, intermediate_dim, bias=False)
        self.out = nn.Sequential(
            nn.Linear(intermediate_dim, input_dim),
            nn.Dropout(dropout),
        )

    def _split_heads(self, tensor):
        """Reshape (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)."""
        batch, tokens, width = tensor.shape
        return tensor.reshape(batch, tokens, self.heads, width // self.heads).transpose(1, 2)

    def forward(self, data):
        """Attend from ``data[:, 0]`` over all tokens of ``data``.

        Args:
            data: rank-3 tensor (batch, tokens, input_dim); other ranks raise ValueError.

        Returns:
            Tensor of shape (batch, 1, input_dim).
        """
        # Tuple unpacking rejects non rank-3 input before any computation.
        batch, _, _ = data.shape

        keys = self._split_heads(self.key(data))
        values = self._split_heads(self.value(data))
        # Only the CLS position (index 0) produces a query.
        queries = self._split_heads(self.query(data[:, :1]))

        scores = einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, values)

        # (batch, heads, 1, head_dim) -> (batch, 1, heads * head_dim), heads outermost.
        merged = mixed.transpose(1, 2).reshape(batch, mixed.shape[2], -1)
        return self.out(merged)


class PreNorm(nn.Module):
    """Layer-normalise the last dimension, then call a wrapped module/callable.

    Args:
        dim: size of the last (normalised) dimension.
        fn: callable applied to the normalised tensor; extra keyword arguments
            given to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Attribute names ("norm", "fn") appear in state_dict keys.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

    
class CrossAttentionLayer(nn.Module):
    """Stack of CLS-token exchanges between a molecule and a protein sequence.

    In each layer, each sequence's first (CLS) token is projected into the other
    sequence's width. It then attends over itself plus the other sequence's
    non-CLS tokens and is projected back through a GELU. Non-CLS tokens pass
    through unchanged.

    Args:
        molecule_dim, protein_dim: token widths of the two inputs.
        molecule_intermediate_dim, protein_intermediate_dim: q/k/v widths of the
            attention run in molecule width and protein width respectively.
        cross_attn_depth: number of stacked exchange layers.
        cross_attn_heads: attention heads per CrossAttention.
        dropout: dropout inside each CrossAttention output projection.
    """

    def __init__(self,
                 molecule_dim=128, molecule_intermediate_dim=256,
                 protein_dim=1024, protein_intermediate_dim=2048,
                 cross_attn_depth=1, cross_attn_heads=4, dropout=0.1):
        super().__init__()

        def build_layer():
            # Submodule order fixes the state_dict keys (cross_attn_layers.<i>.<j>.*)
            # and the unpacking order in forward(); do not reorder.
            return nn.ModuleList([
                # 0: molecule CLS, molecule width -> protein width
                nn.Linear(molecule_dim, protein_dim),
                # 1: molecule CLS back, protein width -> molecule width
                nn.Linear(protein_dim, molecule_dim),
                # 2: attention carried out in protein width
                PreNorm(protein_dim, CrossAttention(
                    protein_dim, protein_intermediate_dim, cross_attn_heads, dropout
                )),
                # 3: protein CLS, protein width -> molecule width
                nn.Linear(protein_dim, molecule_dim),
                # 4: protein CLS back, molecule width -> protein width
                nn.Linear(molecule_dim, protein_dim),
                # 5: attention carried out in molecule width
                PreNorm(molecule_dim, CrossAttention(
                    molecule_dim, molecule_intermediate_dim, cross_attn_heads, dropout
                )),
            ])

        self.cross_attn_layers = nn.ModuleList(build_layer() for _ in range(cross_attn_depth))

    @staticmethod
    def _exchange(cls_token, project_in, attend, project_out, context_tokens):
        """Project one CLS token, attend over [it, context_tokens], add residual, project back."""
        query = project_in(cls_token.unsqueeze(1))
        attended = query + attend(torch.cat((query, context_tokens), dim=1))
        return F.gelu(project_out(attended))

    def forward(self, molecule, protein):
        """Apply every exchange layer; returns updated (molecule, protein) of unchanged shape."""
        for layer in self.cross_attn_layers:
            (mol_cls_in, mol_cls_out, attn_in_prot_space,
             prot_cls_in, prot_cls_out, attn_in_mol_space) = layer

            # Slice both inputs before either is updated.
            mol_cls, mol_tokens = molecule[:, 0], molecule[:, 1:]
            prot_cls, prot_tokens = protein[:, 0], protein[:, 1:]

            # Protein CLS attends over the molecule tokens (in molecule width).
            new_prot_cls = self._exchange(prot_cls, prot_cls_in, attn_in_mol_space, prot_cls_out, mol_tokens)
            protein = torch.cat((new_prot_cls, prot_tokens), dim=1)

            # Molecule CLS attends over the protein tokens (in protein width).
            new_mol_cls = self._exchange(mol_cls, mol_cls_in, attn_in_prot_space, mol_cls_out, prot_tokens)
            molecule = torch.cat((new_mol_cls, mol_tokens), dim=1)

        return molecule, protein
    
    
class AttentionalDTI(nn.Module):
    """Drug-target interaction regressor: two encoders plus a cross-attention layer.

    The molecule and protein inputs are encoded separately, then passed through
    ``cross_attention_layer``. The CLS (position 0) vector of each side is
    LayerNorm'ed and projected to ``hidden_dim``; the two are summed and mapped
    to a single output.

    Args:
        molecule_encoder, protein_encoder: encoders exposing ``encoder.layer``
            and returning an object with ``last_hidden_state`` (e.g. BertModel).
        cross_attention_layer: module mapping (molecule_states, protein_states)
            to updated (molecule_states, protein_states).
        molecule_input_dim, protein_input_dim: CLS widths fed to the MLP heads.
        hidden_dim: width of the shared projection before ``fc_out``.
        **kwargs: accepted for backward compatibility but unused. A warning is
            emitted when any are given, because options such as
            ``cross_attn_depth`` belong on ``CrossAttentionLayer`` and would
            otherwise be silently dropped.
    """

    def __init__(self,
                 molecule_encoder, protein_encoder, cross_attention_layer,
                 molecule_input_dim=128, protein_input_dim=1024, hidden_dim=512, **kwargs):
        super().__init__()
        if kwargs:
            import warnings
            warnings.warn(
                f"AttentionalDTI ignores unexpected keyword arguments {sorted(kwargs)}; "
                "configure them on the sub-modules instead "
                "(e.g. CrossAttentionLayer(cross_attn_depth=...)).",
                stacklevel=2,
            )

        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze every transformer layer except the last one in each encoder.
        # NOTE(review): only ``encoder.layer`` is frozen; embeddings (and any
        # pooler) remain trainable. Confirm this is intended.
        for param in self.molecule_encoder.encoder.layer[0:-1].parameters():
            param.requires_grad = False
        for param in self.protein_encoder.encoder.layer[0:-1].parameters():
            param.requires_grad = False

        self.cross_attention_layer = cross_attention_layer

        self.molecule_mlp = nn.Sequential(
            nn.LayerNorm(molecule_input_dim),
            nn.Linear(molecule_input_dim, hidden_dim)
        )

        self.protein_mlp = nn.Sequential(
            nn.LayerNorm(protein_input_dim),
            nn.Linear(protein_input_dim, hidden_dim)
        )

        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, molecule_seq, protein_seq):
        """Predict one value per (molecule, protein) pair.

        Args:
            molecule_seq, protein_seq: mappings (e.g. tokenizer outputs) unpacked
                as keyword arguments into the respective encoder.

        Returns:
            Tensor of shape (batch, 1).
        """
        encoded_molecule = self.molecule_encoder(**molecule_seq)
        encoded_protein = self.protein_encoder(**protein_seq)

        # NOTE(review): attention masks reach the encoders but not the
        # cross-attention layer, so padded positions take part in it when batching.
        molecule_out, protein_out = self.cross_attention_layer(
            encoded_molecule.last_hidden_state, encoded_protein.last_hidden_state
        )

        # Keep only the CLS token (position 0) of each sequence.
        molecule_projected = self.molecule_mlp(molecule_out[:, 0])
        protein_projected = self.protein_mlp(protein_out[:, 0])

        return self.fc_out(molecule_projected + protein_projected)

molecule_bert = BertModel.from_pretrained("weights/molecule_bert")
protein_bert = BertModel.from_pretrained("weights/protein_bert")
# Cross-attention depth is configured here: AttentionalDTI does not accept it
# (a cross_attn_depth passed there only lands in **kwargs and has no effect).
# Depth 1 is the default that the checkpoint loaded below is restored into.
cross_attention_layer = CrossAttentionLayer(cross_attn_depth=1)
attentional_dti = AttentionalDTI(molecule_bert, protein_bert, cross_attention_layer)


Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at weights/protein_bert and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
class DTI_prediction(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a DTI regressor with MSE loss.

    Training/validation/test batches are ``(molecule_sequence, protein_sequence, y)``;
    ``predict_step`` also accepts batches without ``y``.
    """

    def __init__(self, attentional_dti):
        super().__init__()
        self.model = attentional_dti

    def forward(self, molecule_sequence, protein_sequence):
        return self.model(molecule_sequence, protein_sequence)

    def _shared_step(self, batch, stage):
        """Compute MSE loss and MAE for ``stage``, log them as ``{stage}_loss`` / ``{stage}_mae``, return the loss."""
        molecule_sequence, protein_sequence, y = batch

        # The model returns (batch, 1); drop the trailing dim (presumably y is 1-D).
        y_hat = self(molecule_sequence, protein_sequence).squeeze(-1)
        loss = F.mse_loss(y_hat, y)

        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_mae", mean_absolute_error(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        # "valid_loss"/"valid_mae" also appear in the saved checkpoint filenames.
        self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return predictions of shape (batch,); a label at ``batch[2]``, if present, is ignored."""
        molecule_sequence, protein_sequence = batch[0], batch[1]
        return self(molecule_sequence, protein_sequence).squeeze(-1)

# Lightning wrapper around attentional_dti before any checkpoint is loaded.
model = DTI_prediction(attentional_dti=attentional_dti)

In [6]:
CKPT_DIR = "weights/Attentional_DTI_cross_attention_kiba"
ckpt_fname = "attentional_dti-epoch=049-valid_loss=0.1788-valid_mae=0.2497.ckpt"

# load_from_checkpoint is a classmethod that returns a new module, so call it on
# the class (Lightning 2.x raises when it is called on an instance).
# attentional_dti must be passed back in because DTI_prediction does not save
# its constructor arguments as hyperparameters; the checkpoint weights are
# loaded into that object.
model = DTI_prediction.load_from_checkpoint(f"{CKPT_DIR}/{ckpt_fname}", attentional_dti=attentional_dti)
model.eval()

In [33]:
def _encode(tokenizer, sequence, max_length, label):
    """Tokenize ``sequence`` one character per word, truncating to ``max_length`` tokens.

    Characters are space-separated before tokenization, so each SMILES symbol or
    amino-acid letter is presented to the tokenizer as its own word. Emits a
    warning when the sequence is likely to be truncated.

    Returns:
        The tokenizer output with PyTorch tensors (``return_tensors="pt"``).
    """
    # NOTE(review): estimate assumes one token per character plus two special
    # tokens; confirm against the tokenizer vocabularies.
    if len(sequence) + 2 > max_length:
        import warnings
        warnings.warn(
            f"{label} sequence of length {len(sequence)} exceeds max_length={max_length} "
            "and will be truncated.",
            stacklevel=3,
        )
    return tokenizer(
        " ".join(sequence),
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )


def molecule_encode(molecule_sequence, max_length=100, tokenizer=None):
    """Tokenize a SMILES string for the molecule encoder.

    Args:
        molecule_sequence: SMILES string.
        max_length: maximum number of tokens kept (default 100).
        tokenizer: tokenizer to use; defaults to the module-level ``molecule_tokenizer``.
    """
    if tokenizer is None:
        tokenizer = molecule_tokenizer
    return _encode(tokenizer, molecule_sequence, max_length, "Molecule")


def protein_encode(protein_sequence, max_length=512, tokenizer=None):
    """Tokenize an amino-acid sequence for the protein encoder.

    Args:
        protein_sequence: one-letter amino-acid string.
        max_length: maximum number of tokens kept (default 512).
        tokenizer: tokenizer to use; defaults to the module-level ``protein_tokenizer``.
    """
    if tokenizer is None:
        tokenizer = protein_tokenizer
    return _encode(tokenizer, protein_sequence, max_length, "Protein")

# PDE5

In [9]:
# PDE5 protein sequence. It has 949 residues (see the len() cell), more than
# protein_encode's max_length=512 tokens, so the model only sees a truncated prefix.
pde5 = "MATALNHVSREEVEKYLEANHDVATDIFVTKATPDMIDQWLSKHANSLHKHGEGSPQDVSSWPDVSMKLTEKGVFQSIRKSFNISGTKSLRNLLSPRRRKSTLKRNKSALRQLDEKELFMELIRDIADELDLNTLCHKILMNVSILTNGDRCSLFLARGTKDRRFLVSKLFDVNENSTVEDSLHSEEEEIHIPFGQGIAGHVAQTKETVNIKNAYEDKRFNPEVDKITGYKTHSIMCMPICNHDGEVVGVAQVINKITGSHEFAAKDEEAQVELRRIVSHEFNPADEEVFKNYLTFCGIGIMNAQLFEMSVNEYKRNQMLLQLARGIFEEQTSLDNVVHKIMRQAVSLLKCQRCMVFILETTEESYLPAQLRMAEGKRHSIAYQSSFDAPLNDVKNISFLKGFELTDEDTEKLKTIPHEMLKNSINATIARHVADSGETTNIADFTVQKQFKEISDVDPEFRIRSVLCQPIYNSEQKIIGVAQMINKACKQTFTDQDEHLFEAFAIFCGLGIHNTQMFENAMRLMAKQQVALDVLSYHATAQPDEVSKLKKSCVPSARELKLYEFSFSDFDLTEDQTLQGTLRMFIECNLIEKYHIPYDVLCRWTLSVRKNYRPVIYHNWRHAFNVAQTMFSIVMTGKLRKLLTDLEIFALIVACLCHDLDHRGTNNTFQVKTSSPLSLLYGTSTMEHHHFDHCIMILNSEGNNIFEFMSPDDYREAIRMLESAILSTDLAIYFKKRADFFKLVEKGEHTWDNEEKKGLLRGMLMTACDVSAIAKPWLVQQKVAELVFSEFFQQGDLEREKLKEEPMAMMDRKKKDELPKMQVGFIDGICMPVYKMFAELWPDLKPLESGTQLNRDNWQALSEGKEPNDWGSSPPSLQTSKQMESTILQNDRTQLDTLDEKPSLECIQKQEGSRSTGGGEPKKRGSQMSQQCKEALAAKKNKSSLCSVI"

# Ligands (SMILES) scored against the PDE5 sequence below.
sildenafil = "CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C"
tadalafil = "CN1CC(=O)N2C(C1=O)CC3=C(C2C4=CC5=C(C=C4)OCO5)NC6=CC=CC=C36"
vardenafil = "CCCC1=NC(=C2N1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)CC)OCC)C"


In [43]:
# Residue count; protein_encode keeps at most max_length=512 tokens (special tokens included).
len(pde5)

949

In [11]:
# Tokenize the target once, then each ligand (same call order as before).
pde5_ = protein_encode(pde5)

sildenafil_, tadalafil_, vardenafil_ = (
    molecule_encode(smiles) for smiles in (sildenafil, tadalafil, vardenafil)
)

In [22]:
mol_ = ["sildenafil_", "tadalafil_", "vardenafil_"]

model.model.eval()
# Pure inference: no_grad avoids building an autograd graph for each prediction.
with torch.no_grad():
    for name, mol in zip(mol_, [sildenafil_, tadalafil_, vardenafil_]):
        res = model.model(mol, pde5_)
        print(name, res)

sildenafil_ tensor([[12.8052]], grad_fn=<AddmmBackward0>)
tadalafil_ tensor([[13.5117]], grad_fn=<AddmmBackward0>)
vardenafil_ tensor([[12.4900]], grad_fn=<AddmmBackward0>)


# HMG-CoA reductase (HMGCR)

In [41]:
# Protein sequence for this section. It has 888 residues (see the len() cell),
# more than protein_encode's max_length=512 tokens, so the model sees a truncated prefix.
hmg_coa = "MLSRLFRMHGLFVASHPWEVIVGTVTLTICMMSMNMFTGNNKICGWNYECPKFEEDVLSSDIIILTITRCIAILYIYFQFQNLRQLGSKYILGIAGLFTIFSSFVFSTVVIHFLDKELTGLNEALPFFLLLIDLSRASTLAKFALSSNSQDEVRENIARGMAILGPTFTLDALVECLVIGVGTMSGVRQLEIMCCFGCMSVLANYFVFMTFFPACVSLVLELSRESREGRPIWQLSHFARVLEEEENKPNPVTQRVKMIMSLGLVLVHAHSRWIADPSPQNSTADTSKVSLGLDENVSKRIEPSVSLWQFYLSKMISMDIEQVITLSLALLLAVKYIFFEQTETESTLSLKNPITSPVVTQKKVPDNCCRREPMLVRNNQKCDSVEEETGINRERKVEVIKPLVAETDTPNRATFVVGNSSLLDTSSVLVTQEPEIELPREPRPNEECLQILGNAEKGAKFLSDAEIIQLVNAKHIPAYKLETLMETHERGVSIRRQLLSKKLSEPSSLQYLPYRDYNYSLVMGACCENVIGYMPIPVGVAGPLCLDEKEFQVPMATTEGCLVASTNRGCRAIGLGGGASSRVLADGMTRGPVVRLPRACDSAEVKAWLETSEGFAVIKEAFDSTSRFARLQKLHTSIAGRNLYIRFQSRSGDAMGMNMISKGTEKALSKLHEYFPEMQILAVSGNYCTDKKPAAINWIEGRGKSVVCEAVIPAKVVREVLKTTTEAMIEVNINKNLVGSAMAGSIGGYNAHAANIVTAIYIACGQDAAQNVGSSNCITLMEASGPTNEDLYISCTMPSIEIGTVGGGTNLLPQQACLQMLGVQGACKDNPGENARQLARIVCGTVMAGELSLMAALAAGHLVKSHMIHNRSKINLQDLQGACTKKTA"

# Statin ligands (SMILES) scored against hmg_coa below.
atorvastatin = "CC(C)C1=C(C(=C(N1CCC(CC(CC(=O)O)O)O)C2=CC=C(C=C2)F)C3=CC=CC=C3)C(=O)NC4=CC=CC=C4"
lovastatin = "CCC(C)C(=O)OC1CC(C=C2C1C(C(C=C2)C)CCC3CC(CC(=O)O3)O)C"
pravastatin = "CCC(C)C(=O)OC1CC(C=C2C1C(C(C=C2)C)CCC(CC(CC(=O)O)O)O)O"
# NOTE(review): unlike the other entries, this is a salt SMILES (two "."-separated
# anions plus [Ca+2]). At ~148 characters it exceeds molecule_encode's max_length=100,
# so the model sees only part of it. Consider using the single neutral molecule instead.
rosuvastatin ="CC(C)C1=NC(=NC(=C1C=CC(CC(CC(=O)[O-])O)O)C2=CC=C(C=C2)F)N(C)S(=O)(=O)C.CC(C)C1=NC(=NC(=C1C=CC(CC(CC(=O)[O-])O)O)C2=CC=C(C=C2)F)N(C)S(=O)(=O)C.[Ca+2]"
simvastatin = "CCC(C)(C)C(=O)OC1CC(C=C2C1C(C(C=C2)C)CCC3CC(CC(=O)O3)O)C"
fluvastatin = "CC(C)N1C2=CC=CC=C2C(=C1C=CC(CC(CC(=O)O)O)O)C3=CC=C(C=C3)F"
pitavastatin = "C1CC1C2=NC3=CC=CC=C3C(=C2C=CC(CC(CC(=O)O)O)O)C4=CC=C(C=C4)F"

In [42]:
# Residue count; protein_encode keeps at most max_length=512 tokens (special tokens included).
len(hmg_coa)

888

In [24]:
# Tokenize every statin (in the original order), then the target protein.
(atorvastatin_, lovastatin_, pravastatin_, rosuvastatin_,
 simvastatin_, fluvastatin_, pitavastatin_) = (
    molecule_encode(smiles)
    for smiles in (atorvastatin, lovastatin, pravastatin, rosuvastatin,
                   simvastatin, fluvastatin, pitavastatin)
)

hmg_coa_ = protein_encode(hmg_coa)

In [26]:
mol_ = ["atorvastatin_", "lovastatin_", "pravastatin_", "rosuvastatin_", "simvastatin_", "fluvastatin_", "pitavastatin_"]

model.model.eval()
# Pure inference: no_grad avoids building an autograd graph for each prediction.
with torch.no_grad():
    for name, mol in zip(mol_, [atorvastatin_, lovastatin_, pravastatin_, rosuvastatin_, simvastatin_, fluvastatin_, pitavastatin_]):
        res = model.model(mol, hmg_coa_)
        print(name, res)

atorvastatin_ tensor([[12.6351]], grad_fn=<AddmmBackward0>)
lovastatin_ tensor([[10.8080]], grad_fn=<AddmmBackward0>)
pravastatin_ tensor([[10.6947]], grad_fn=<AddmmBackward0>)
rosuvastatin_ tensor([[12.2706]], grad_fn=<AddmmBackward0>)
simvastatin_ tensor([[10.7043]], grad_fn=<AddmmBackward0>)
fluvastatin_ tensor([[11.6072]], grad_fn=<AddmmBackward0>)
pitavastatin_ tensor([[11.5024]], grad_fn=<AddmmBackward0>)


# MLKL

In [29]:
# MLKL protein sequence. It has 157 residues (see the len() cell), within protein_encode's 512-token limit.
mlkl = "GSPGENLKHIITLGQVIHKRCEEMKYCKKQCRRLGHRVLGLIKPLEMLQDQGKRSVPSEKLTTAMNRFKAALEEANGEIEKFSNRSNICRFLTASQDKILFKDVNRKLSDVWKELSLLLQVEQRMPVSPISQGASWAQEDQQDADEDRRAFQMLRRD"

# Ligands (SMILES) scored against the MLKL sequence below.
necrosulfonamide = "COC1=NC=CN=C1NS(=O)(=O)C2=CC=C(C=C2)NC(=O)C=CC3=CC=C(S3)[N+](=O)[O-]"
TC13172 = "CN1C2=C(N=C1S(=O)(=O)C)N(C(=O)N(C2=O)C)CC#CC3=CC(=CC=C3)O"
GW806742X = "CN(C1=CC=C(C=C1)NC(=O)NC2=CC=C(C=C2)OC(F)(F)F)C3=NC(=NC=C3)NC4=CC(=CC=C4)S(=O)(=O)N"

In [39]:
# Residue count; protein_encode keeps at most max_length=512 tokens (special tokens included).
len(mlkl)

157

In [30]:
# Tokenize the target once, then each ligand (same call order as before).
mlkl_ = protein_encode(mlkl)

necrosulfonamide_, TC13172_, GW806742X_ = (
    molecule_encode(smiles) for smiles in (necrosulfonamide, TC13172, GW806742X)
)

In [31]:
mol_ = ["necrosulfonamide_", "TC13172_", "GW806742X_"]

model.model.eval()
# Pure inference: no_grad avoids building an autograd graph for each prediction.
with torch.no_grad():
    for name, mol in zip(mol_, [necrosulfonamide_, TC13172_, GW806742X_]):
        res = model.model(mol, mlkl_)
        print(name, res)

necrosulfonamide_ tensor([[13.0803]], grad_fn=<AddmmBackward0>)
TC13172_ tensor([[12.8356]], grad_fn=<AddmmBackward0>)
GW806742X_ tensor([[12.6528]], grad_fn=<AddmmBackward0>)


# Cyclooxygenase-2 (COX-2)

In [34]:
# Cyclooxygenase-2 protein sequence. It has 339 residues (see the len() cell), within protein_encode's 512-token limit.
cyclooxygenase2 = "MLARALLLCAVLALSHTANPCCSHPCQNRGVCMSVGFDQYKCDCTRTGFYGENCSTPEFLTRIKLFLKPTPNTVHYILTHFKGFWNVVNNIPFLRNAIMSYVLTSRSHLIDSPPTYNADYGYKSWEAFSNLSYYTRALPPVPDDCPTPLGVKGKKQLPDSNEIVGKLLLRRKFIPDPQGSNMMFAFFAQHFTHQFFKTDHKRGPAFTNGLGHGVDLNHIYGETLARQRKLRLFKDGKMKYQIIDGEMYPPTVKDTQAEMIYPPQVPEHLRFAVGQEVFGLVPGLMMYATIWLREHNRVCDVLKQEHPEWGDEQLFQTSRLILIGKQENDLYKTLFPREN"

# Ligands (SMILES) scored against the cyclooxygenase-2 sequence below.
celecoxib = "CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F"
valdecoxib = "CC1=C(C(=NO1)C2=CC=CC=C2)C3=CC=C(C=C3)S(=O)(=O)N"
rofecoxib = "CS(=O)(=O)C1=CC=C(C=C1)C2=C(C(=O)OC2)C3=CC=CC=C3"
etoricoxib = "CC1=NC=C(C=C1)C2=C(C=C(C=N2)Cl)C3=CC=C(C=C3)S(=O)(=O)C"

In [38]:
# Residue count; protein_encode keeps at most max_length=512 tokens (special tokens included).
len(cyclooxygenase2)

339

In [35]:
# Tokenize the target once, then each ligand (same call order as before).
cyclooxygenase2_ = protein_encode(cyclooxygenase2)

celecoxib_, valdecoxib_, rofecoxib_, etoricoxib_ = (
    molecule_encode(smiles) for smiles in (celecoxib, valdecoxib, rofecoxib, etoricoxib)
)

In [37]:
mol_ = ["celecoxib_", "valdecoxib_", "rofecoxib_", "etoricoxib_"]

model.model.eval()
# Pure inference: no_grad avoids building an autograd graph for each prediction.
with torch.no_grad():
    for name, mol in zip(mol_, [celecoxib_, valdecoxib_, rofecoxib_, etoricoxib_]):
        res = model.model(mol, cyclooxygenase2_)
        print(name, res)

celecoxib_ tensor([[12.6371]], grad_fn=<AddmmBackward0>)
valdecoxib_ tensor([[12.3927]], grad_fn=<AddmmBackward0>)
rofecoxib_ tensor([[11.3644]], grad_fn=<AddmmBackward0>)
etoricoxib_ tensor([[12.1555]], grad_fn=<AddmmBackward0>)
