In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from transformers import BertModel, BertTokenizer

import numpy as np
import pandas as pd
from tqdm import tqdm
from tdc.multi_pred import DTI
from sklearn.metrics import average_precision_score

# Tokenizers are loaded from local directories under data/.
# model_max_length=128 sets the molecule tokenizer's default maximum sequence length.
molecule_tokenizer = BertTokenizer.from_pretrained("data/drug/molecule_tokenizer", model_max_length=128)
# do_lower_case=False: the protein tokenizer must not lowercase its input.
protein_tokenizer = BertTokenizer.from_pretrained("data/target/protein_tokenizer", do_lower_case=False)

In [2]:
class DualLanguageModelDTI(nn.Module):
    """Two-tower regressor that scores a (molecule, protein) pair with one scalar.

    Each input is passed through its own encoder, the encoder's
    ``pooler_output`` is layer-normalised and linearly projected to
    ``hidden_dim``, the two projections are concatenated, and a GELU MLP
    followed by a linear layer maps the result to a single value per sample.

    Args:
        molecule_encoder: Module called as ``encoder(**molecule_seq)`` whose
            result exposes ``pooler_output``; it must also expose
            ``.encoder.layer`` (sliceable, with ``.parameters()``).
        protein_encoder: Same contract as ``molecule_encoder``, for proteins.
        hidden_dim: Width of each aligned embedding (the MLP works on twice this).
        molecule_input_dim: Feature size of the molecule ``pooler_output``.
        protein_input_dim: Feature size of the protein ``pooler_output``.
    """

    def __init__(self,
                 molecule_encoder, protein_encoder, hidden_dim=512,
                 molecule_input_dim=128, protein_input_dim=1024):
        super().__init__()
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze every entry of `encoder.layer` except the last one, in both towers.
        # Anything outside `encoder.layer` is left untouched by this loop.
        for tower in (self.molecule_encoder, self.protein_encoder):
            for weight in tower.encoder.layer[0:-1].parameters():
                weight.requires_grad = False

        # Projection heads bringing both modalities to the same width.
        self.molecule_align = self._make_align(molecule_input_dim, hidden_dim)
        self.protein_align = self._make_align(protein_input_dim, hidden_dim)

        # MLP over the concatenated pair embedding.
        joint_dim = hidden_dim * 2
        self.fc1 = nn.Linear(joint_dim, joint_dim)
        self.fc2 = nn.Linear(joint_dim, joint_dim)
        self.fc3 = nn.Linear(joint_dim, hidden_dim)

        self.fc_out = nn.Linear(hidden_dim, 1)

    @staticmethod
    def _make_align(input_dim, hidden_dim):
        """Build a LayerNorm -> Linear head mapping ``input_dim`` to ``hidden_dim``."""
        return nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
        )

    def forward(self, molecule_seq, protein_seq):
        """Return the model output for a batch of molecule/protein pairs.

        Args:
            molecule_seq: Mapping unpacked as keyword arguments into the molecule encoder.
            protein_seq: Mapping unpacked as keyword arguments into the protein encoder.

        Returns:
            Tensor produced by the final ``Linear(hidden_dim, 1)`` layer
            (last dimension of size 1).
        """
        # Run both encoders first, then project their pooled outputs.
        molecule_out = self.molecule_encoder(**molecule_seq)
        protein_out = self.protein_encoder(**protein_seq)

        molecule_vec = self.molecule_align(molecule_out.pooler_output)
        protein_vec = self.protein_align(protein_out.pooler_output)

        hidden = torch.cat([molecule_vec, protein_vec], dim=1)
        for dense in (self.fc1, self.fc2, self.fc3):
            hidden = F.gelu(dense(hidden))

        return self.fc_out(hidden)

# Load the two pretrained BERT encoders from local weight directories.
# NOTE(review): the loading log reports `pooler.dense.*` as newly initialized for both
# encoders, while DualLanguageModelDTI.forward reads `pooler_output`. Meaningful pooler
# weights presumably come from the fine-tuned checkpoint restored later - confirm.
molecule_bert = BertModel.from_pretrained("weights/molecule_bert")
protein_bert = BertModel.from_pretrained("weights/protein_bert")
# Wrap both encoders using the class defaults
# (hidden_dim=512, molecule_input_dim=128, protein_input_dim=1024).
dlm_dti = DualLanguageModelDTI(molecule_bert, protein_bert)

Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at weights/protein_bert and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
class DTI_prediction(pl.LightningModule):
    """Lightning wrapper that trains/evaluates a DTI model as a regressor.

    The wrapped model is called as ``model(molecule_sequence, protein_sequence)``
    and its output is squeezed on the last dimension before being compared
    to the target with MSE (the loss) and MAE (logged only).

    Args:
        dlm_dti: The model to wrap; stored as ``self.model``.
    """

    def __init__(self, dlm_dti):
        super().__init__()
        self.model = dlm_dti

    def forward(self, molecule_sequence, protein_sequence):
        """Delegate to the wrapped model and return its raw output."""
        return self.model(molecule_sequence, protein_sequence)

    def _shared_step(self, batch, stage):
        """Compute the MSE loss for one batch and log ``{stage}_loss`` / ``{stage}_mae``.

        Args:
            batch: ``(molecule_sequence, protein_sequence, y)``.
            stage: Metric-name prefix: ``"train"``, ``"valid"`` or ``"test"``.

        Returns:
            The MSE loss tensor.
        """
        molecule_sequence, protein_sequence, y = batch

        y_hat = self(molecule_sequence, protein_sequence).squeeze(-1)
        loss = F.mse_loss(y_hat, y)

        # Metrics are aggregated per epoch rather than reported per step.
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_mae", mean_absolute_error(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logs ``train_loss`` / ``train_mae``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``valid_loss`` / ``valid_mae`` for ``batch``; returns nothing."""
        self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_mae`` for ``batch``; returns nothing."""
        self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return squeezed predictions for ``batch``.

        Only the first two batch elements are used, so both labelled
        ``(molecule, protein, y)`` and unlabelled ``(molecule, protein)``
        batches are accepted.
        """
        molecule_sequence, protein_sequence, *_ = batch

        return self(molecule_sequence, protein_sequence).squeeze(-1)

    def configure_optimizers(self):
        """Return AdamW (lr=1e-4) with a cosine-annealing schedule (T_max=30)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    
ckpt_fname = "dlm_dti-epoch=194-valid_loss=0.2036-valid_mae=0.2368.ckpt"
# load_from_checkpoint is a classmethod: call it on the class, not on an instance
# (newer Lightning releases reject the instance call). It builds the module itself from
# the `dlm_dti` keyword, so no throwaway DTI_prediction instance is needed beforehand.
# map_location="cpu" keeps the restored weights on the CPU, matching the CPU trainer
# below and the CPU tensors produced by the tokenizers.
dti_regressor = DTI_prediction.load_from_checkpoint(
    "weights/dlm_dti_davis_512/" + ckpt_fname,
    map_location="cpu",
    dlm_dti=dlm_dti,
)

trainer = pl.Trainer(max_epochs=1, accelerator="cpu", enable_progress_bar=True)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [53]:
from rdkit import Chem

def encode_sequences(molecule, protein, molecule_max_length=100, protein_max_length=512):
    """Tokenize one molecule/protein pair for the DTI model.

    The SMILES string is canonicalised with RDKit first. Both sequences are
    then split into single characters separated by spaces before being
    handed to their tokenizer (so a two-letter atom symbol such as ``Cl``
    is fed as two separate characters).

    Args:
        molecule: SMILES string of the molecule.
        protein: Protein sequence string.
        molecule_max_length: Truncation length passed to the molecule tokenizer.
        protein_max_length: Truncation length passed to the protein tokenizer.

    Returns:
        ``(molecule_tokens, protein_tokens)``: the two tokenizer outputs,
        requested as PyTorch tensors (``return_tensors="pt"``).

    Raises:
        ValueError: If RDKit cannot parse ``molecule`` as SMILES.
    """
    parsed = Chem.MolFromSmiles(molecule)
    if parsed is None:
        # MolFromSmiles signals a parse failure by returning None; without this check
        # MolToSmiles would fail with an opaque argument-type error.
        raise ValueError(f"Could not parse SMILES string: {molecule!r}")
    canonical_smiles = Chem.MolToSmiles(parsed)

    molecule_tokens = molecule_tokenizer(" ".join(canonical_smiles),
                                         max_length=molecule_max_length, truncation=True,
                                         return_tensors="pt")

    protein_tokens = protein_tokenizer(" ".join(protein),
                                       max_length=protein_max_length, truncation=True,
                                       return_tensors="pt")

    return molecule_tokens, protein_tokens


def predict(molecule, protein, dti_regressor):
    """Print and return the model's prediction for one molecule/protein pair.

    Args:
        molecule: SMILES string of the molecule.
        protein: Protein sequence string.
        dti_regressor: Object whose ``.model`` attribute is the DTI model to run.

    Returns:
        The prediction tensor produced by ``dti_regressor.model`` (also printed).
    """
    molecule_tokens, protein_tokens = encode_sequences(molecule, protein)

    # eval() switches the wrapped model to inference behaviour; no_grad() avoids
    # building an autograd graph that pure inference never uses.
    dti_regressor.model.eval()
    with torch.no_grad():
        pred = dti_regressor.model(molecule_tokens, protein_tokens)
    print(pred)

    return pred

In [37]:
# Load the DAVIS drug-target interaction dataset through TDC.
davis = DTI(name="davis")
# Convert the labels to log space (the run log below prints "To log space...").
davis.convert_to_log(form = 'binding')
# Split with the library's default settings; the result is indexed by 'train'/'valid'/'test'.
davis_split = davis.get_split()

train_df = davis_split['train']
valid_df = davis_split['valid']
test_df = davis_split['test']
# Disabled exploration snippet: lists up to 30 distinct (Target_ID, Target) pairs
# among test rows with Y >= 7.
# test_df[test_df.Y >= 7].loc[:, ["Target_ID", "Target"]].drop_duplicates().head(30).values

Found local copy...
Loading...
Done!
To log space...


In [57]:
# VEGFR2, https://go.drugbank.com/bio_entities/BE0000369
# Protein sequence used as the `protein` argument of predict() for every drug below.
# NOTE(review): encode_sequences tokenizes proteins with max_length=512 and truncation=True;
# this sequence is longer than that, so presumably only its leading part reaches the
# model - confirm this is intended.
VEGFR2 = "MQSKVLLAVALWLCVETRAASVGLPSVSLDLPRLSIQKDILTIKANTTLQITCRGQRDLDWLWPNNQSGSEQRVEVTECSDGLFCKTLTIPKVIGNDTGAYKCFYRETDLASVIYVYVQDYRSPFIASVSDQHGVVYITENKNKTVVIPCLGSISNLNVSLCARYPEKRFVPDGNRISWDSKKGFTIPSYMISYAGMVFCEAKINDESYQSIMYIVVVVGYRIYDVVLSPSHGIELSVGEKLVLNCTARTELNVGIDFNWEYPSSKHQHKKLVNRDLKTQSGSEMKKFLSTLTIDGVTRSDQGLYTCAASSGLMTKKNSTFVRVHEKPFVAFGSGMESLVEATVGERVRIPAKYLGYPPPEIKWYKNGIPLESNHTIKAGHVLTIMEVSERDTGNYTVILTNPISKEKQSHVVSLVVYVPPQIGEKSLISPVDSYQYGTTQTLTCTVYAIPPPHHIHWYWQLEEECANEPSQAVSVTNPYPCEEWRSVEDFQGGNKIEVNKNQFALIEGKNKTVSTLVIQAANVSALYKCEAVNKVGRGERVISFHVTRGPEITLQPDMQPTEQESVSLWCTADRSTFENLTWYKLGPQPLPIHVGELPTPVCKNLDTLWKLNATMFSNSTNDILIMELKNASLQDQGDYVCLAQDRKTKKRHCVVRQLTVLERVAPTITGNLENQTTSIGESIEVSCTASGNPPPQIMWFKDNETLVEDSGIVLKDGNRNLTIRRVRKEDEGLYTCQACSVLGCAKVEAFFIIEGAQEKTNLEIIILVGTAVIAMFFWLLLVIILRTVKRANGGELKTGYLSIVMDPDELPLDEHCERLPYDASKWEFPRDRLKLGKPLGRGAFGQVIEADAFGIDKTATCRTVAVKMLKEGATHSEHRALMSELKILIHIGHHLNVVNLLGACTKPGGPLMVIVEFCKFGNLSTYLRSKRNEFVPYKTKGARFRQGKDYVGAIPVDLKRRLDSITSSQSSASSGFVEEKSLSDVEEEEAPEDLYKDFLTLEHLICYSFQVAKGMEFLASRKCIHRDLAARNILLSEKNVVKICDFGLARDIYKDPDYVRKGDARLPLKWMAPETIFDRVYTIQSDVWSFGVLLWEIFSLGASPYPGVKIDEEFCRRLKEGTRMRAPDYTTPEMYQTMLDCWHGEPSQRPTFSELVEHLGNLLQANAQQDGKDYIVLPISETLSMEEDSGLSLPTSPVSCMEEEEVCDPKFHYDNTAGISQYLQNSKRKSRPVSVKTFEDIPLEEPEVKVIPDDNQTDSGMVLASEELKTLEDRTKLSPSFGGMVPSKSRESVASEGSNQTSGYQSGYHSDDTDTTVYSSEEAELLKLIEIGVQTGSTAQILQPDSGTTLSSPPV"
# SMILES strings for the drugs scored against VEGFR2.
# They are raw literals because SMILES uses "\" as a bond-direction marker and "\C" is
# not a valid Python escape sequence (a non-raw literal triggers an invalid-escape
# warning and is slated to become an error). The string values are unchanged.
sunitinib = r"CCN(CC)CCNC(=O)C1=C(C)NC(\C=C2/C(=O)NC3=C2C=C(F)C=C3)=C1C"
midostaurin = r"CO[C@@H]1[C@@H](C[C@H]2O[C@]1(C)N1C3=C(C=CC=C3)C3=C1C1=C(C4=C(C=CC=C4)N21)C1=C3CNC1=O)N(C)C(=O)C1=CC=CC=C1"
axitinib = r"CNC(=O)C1=C(SC2=CC=C3C(NN=C3\C=C\C3=CC=CC=N3)=C2)C=CC=C1"
cabozantinib = r"COC1=CC2=C(C=C1OC)C(OC1=CC=C(NC(=O)C3(CC3)C(=O)NC3=CC=C(F)C=C3)C=C1)=CC=N2"
regorafenib = r"CNC(=O)C1=CC(OC2=CC(F)=C(NC(=O)NC3=CC=C(Cl)C(=C3)C(F)(F)F)C=C2)=CC=N1"
lenvatinib = r"COC1=C(C=C2C(OC3=CC=C(NC(=O)NC4CC4)C(Cl)=C3)=CC=NC2=C1)C(N)=O"
# NOTE: the variable name is kept as-is for compatibility; the drug is usually spelled "nintedanib".
nintendanib = r"COC(=O)C1=CC=C2C(NC(=O)\C2=C(/NC2=CC=C(C=C2)N(C)C(=O)CN2CCN(C)CC2)C2=CC=CC=C2)=C1"
ripretinib = r"CCN1C(=O)C(=CC2=C1C=C(NC)N=C2)C1=C(Br)C=C(F)C(NC(=O)NC2=CC=CC=C2)=C1"
tivozanib = r"COC1=C(OC)C=C2C(OC3=CC(Cl)=C(NC(=O)NC4=NOC(C)=C4)C=C3)=CC=NC2=C1"

# Score every drug against VEGFR2, in the same order as the assignments above.
# predict() prints one prediction per call.
vegfr2_inhibitors = (
    sunitinib,
    midostaurin,
    axitinib,
    cabozantinib,
    regorafenib,
    lenvatinib,
    nintendanib,
    ripretinib,
    tivozanib,
)
for smiles in vegfr2_inhibitors:
    predict(smiles, VEGFR2, dti_regressor)


tensor([[8.6080]], grad_fn=<AddmmBackward0>)
tensor([[5.1738]], grad_fn=<AddmmBackward0>)
tensor([[7.1522]], grad_fn=<AddmmBackward0>)
tensor([[5.0940]], grad_fn=<AddmmBackward0>)
tensor([[7.1733]], grad_fn=<AddmmBackward0>)
tensor([[5.0296]], grad_fn=<AddmmBackward0>)
tensor([[8.4587]], grad_fn=<AddmmBackward0>)
tensor([[5.0094]], grad_fn=<AddmmBackward0>)
tensor([[5.0660]], grad_fn=<AddmmBackward0>)
