In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from transformers import BertModel, BertTokenizer

import numpy as np
import pandas as pd
from tqdm import tqdm
from tdc.multi_pred import DTI
from sklearn.metrics import average_precision_score

# Read the candidate molecules, one entry per line of the text file.
# Trailing newlines are kept here; DTIDataset.molecule_encode strips them
# before tokenizing.
with open("data/drug/chem_qed_filtered_09.txt", "r") as f:
    data = f.readlines()

# Tokenizers for the two input modalities (molecule text and protein text).
# The protein tokenizer is loaded with do_lower_case=False.
molecule_tokenizer = BertTokenizer.from_pretrained("data/drug/molecule_tokenizer", model_max_length=128)
protein_tokenizer = BertTokenizer.from_pretrained("data/target/protein_tokenizer", do_lower_case=False)

In [2]:
class DTIDataset(Dataset):
    """Pairs every molecule in ``data`` with one fixed, pre-tokenized protein target.

    The protein sequence is tokenized once in the constructor and the same
    encoding object is returned for every item; molecules are tokenized lazily
    in ``__getitem__``.

    Args:
        data: Indexable, sized collection of molecule strings (trailing
            whitespace/newlines are stripped before tokenizing).
        molecule_tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., truncation=True)`` for molecules.
        protein_tokenizer: Callable invoked the same way for the protein.
        target_seq: Protein sequence string shared by all samples.
        molecule_max_len: ``max_length`` passed to the molecule tokenizer.
            Defaults to 100.
        protein_max_len: ``max_length`` passed to the protein tokenizer.
            Defaults to 512.
    """

    def __init__(self, data, molecule_tokenizer, protein_tokenizer, target_seq,
                 molecule_max_len=100, protein_max_len=512):
        self.data = data

        self.molecule_max_len = molecule_max_len
        self.protein_max_len = protein_max_len

        self.molecule_tokenizer = molecule_tokenizer
        self.protein_tokenizer = protein_tokenizer

        # Tokenize the single target once; every sample reuses this encoding.
        self.target_seq = self.protein_encode(target_seq)

    def molecule_encode(self, molecule_sequence):
        """Tokenize one molecule string.

        Trailing whitespace is removed and the remaining characters are joined
        with single spaces before being handed to the tokenizer, truncated to
        ``self.molecule_max_len``.
        """
        spaced = " ".join(molecule_sequence.rstrip())
        return self.molecule_tokenizer(
            spaced,
            max_length=self.molecule_max_len,
            truncation=True
        )

    def protein_encode(self, protein_sequence):
        """Tokenize one protein string.

        Characters are joined with single spaces before tokenizing, truncated
        to ``self.protein_max_len``.
        """
        spaced = " ".join(protein_sequence)
        return self.protein_tokenizer(
            spaced,
            max_length=self.protein_max_len,
            truncation=True
        )

    def __len__(self):
        """Number of molecules in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(molecule_encoding, shared_protein_encoding, idx)``.

        ``idx`` is passed through so predictions can be matched back to the
        corresponding entry of ``data``.
        """
        molecule_sequence = self.molecule_encode(self.data[idx])

        return molecule_sequence, self.target_seq, idx


def collate_batch(batch):
    """Merge per-sample ``(molecule, protein, idx)`` triples into one batch.

    The molecule and protein encodings are padded with the module-level
    ``molecule_tokenizer`` / ``protein_tokenizer`` and returned as PyTorch
    tensors (``return_tensors="pt"``); the indices are returned as a plain
    list in the same order as the samples.

    Returns:
        ``(padded_molecules, padded_proteins, indices)``
    """
    # Split the triples column-wise, keeping sample order.
    molecule_items = [molecule for molecule, _, _ in batch]
    protein_items = [protein for _, protein, _ in batch]
    indices = [sample_idx for _, _, sample_idx in batch]

    padded_molecules = molecule_tokenizer.pad(molecule_items, return_tensors="pt")
    padded_proteins = protein_tokenizer.pad(protein_items, return_tensors="pt")

    return padded_molecules, padded_proteins, indices


# Protein target sequence used for every prediction (presumably the VEGFR2
# amino-acid sequence in one-letter code - confirm the source of the sequence).
# NOTE(review): DTIDataset tokenizes the target with truncation at its
# protein_max_len (512 by default), and this string is longer than that, so
# likely only the leading part of the sequence is actually encoded.
VEGFR2 = "MQSKVLLAVALWLCVETRAASVGLPSVSLDLPRLSIQKDILTIKANTTLQITCRGQRDLDWLWPNNQSGSEQRVEVTECSDGLFCKTLTIPKVIGNDTGAYKCFYRETDLASVIYVYVQDYRSPFIASVSDQHGVVYITENKNKTVVIPCLGSISNLNVSLCARYPEKRFVPDGNRISWDSKKGFTIPSYMISYAGMVFCEAKINDESYQSIMYIVVVVGYRIYDVVLSPSHGIELSVGEKLVLNCTARTELNVGIDFNWEYPSSKHQHKKLVNRDLKTQSGSEMKKFLSTLTIDGVTRSDQGLYTCAASSGLMTKKNSTFVRVHEKPFVAFGSGMESLVEATVGERVRIPAKYLGYPPPEIKWYKNGIPLESNHTIKAGHVLTIMEVSERDTGNYTVILTNPISKEKQSHVVSLVVYVPPQIGEKSLISPVDSYQYGTTQTLTCTVYAIPPPHHIHWYWQLEEECANEPSQAVSVTNPYPCEEWRSVEDFQGGNKIEVNKNQFALIEGKNKTVSTLVIQAANVSALYKCEAVNKVGRGERVISFHVTRGPEITLQPDMQPTEQESVSLWCTADRSTFENLTWYKLGPQPLPIHVGELPTPVCKNLDTLWKLNATMFSNSTNDILIMELKNASLQDQGDYVCLAQDRKTKKRHCVVRQLTVLERVAPTITGNLENQTTSIGESIEVSCTASGNPPPQIMWFKDNETLVEDSGIVLKDGNRNLTIRRVRKEDEGLYTCQACSVLGCAKVEAFFIIEGAQEKTNLEIIILVGTAVIAMFFWLLLVIILRTVKRANGGELKTGYLSIVMDPDELPLDEHCERLPYDASKWEFPRDRLKLGKPLGRGAFGQVIEADAFGIDKTATCRTVAVKMLKEGATHSEHRALMSELKILIHIGHHLNVVNLLGACTKPGGPLMVIVEFCKFGNLSTYLRSKRNEFVPYKTKGARFRQGKDYVGAIPVDLKRRLDSITSSQSSASSGFVEEKSLSDVEEEEAPEDLYKDFLTLEHLICYSFQVAKGMEFLASRKCIHRDLAARNILLSEKNVVKICDFGLARDIYKDPDYVRKGDARLPLKWMAPETIFDRVYTIQSDVWSFGVLLWEIFSLGASPYPGVKIDEEFCRRLKEGTRMRAPDYTTPEMYQTMLDCWHGEPSQRPTFSELVEHLGNLLQANAQQDGKDYIVLPISETLSMEEDSGLSLPTSPVSCMEEEEVCDPKFHYDNTAGISQYLQNSKRKSRPVSVKTFEDIPLEEPEVKVIPDDNQTDSGMVLASEELKTLEDRTKLSPSFGGMVPSKSRESVASEGSNQTSGYQSGYHSDDTDTTVYSSEEAELLKLIEIGVQTGSTAQILQPDSGTTLSSPPV"

# Every molecule in `data` is paired with the same target sequence.
pred_dataset = DTIDataset(
    data=data,
    molecule_tokenizer=molecule_tokenizer,
    protein_tokenizer=protein_tokenizer,
    target_seq=VEGFR2,
)

# No shuffling is requested, so batches follow the order of `data`.
pred_dataloader = DataLoader(
    pred_dataset,
    batch_size=512,
    collate_fn=collate_batch,
    num_workers=14,
    prefetch_factor=10,
    pin_memory=True,
)

In [3]:
class DualLanguageModelDTI(nn.Module):
    """Two-encoder regressor producing one score per (molecule, protein) pair.

    Each encoder's ``pooler_output`` is layer-normalised and linearly projected
    to ``hidden_dim``; the two projections are concatenated and passed through
    three GELU-activated linear layers and a final linear layer with a single
    output unit.

    Args:
        molecule_encoder: Module called as ``encoder(**molecule_seq)`` whose
            result exposes ``pooler_output`` and which has ``encoder.layer``.
        protein_encoder: Same contract, for the protein input.
        hidden_dim: Width of each aligned embedding (default 512).
        molecule_input_dim: Feature size of the molecule ``pooler_output``
            (default 128).
        protein_input_dim: Feature size of the protein ``pooler_output``
            (default 1024).
    """

    def __init__(self,
                 molecule_encoder, protein_encoder, hidden_dim=512,
                 molecule_input_dim=128, protein_input_dim=1024):
        super().__init__()
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze every entry of `encoder.layer` except the last one, in both
        # encoders. Parameters outside `encoder.layer` are left untouched.
        for backbone in (self.molecule_encoder, self.protein_encoder):
            backbone.encoder.layer[:-1].requires_grad_(False)

        # Projections that bring both pooled embeddings to a common width.
        self.molecule_align = nn.Sequential(
            nn.LayerNorm(molecule_input_dim),
            nn.Linear(molecule_input_dim, hidden_dim),
        )
        self.protein_align = nn.Sequential(
            nn.LayerNorm(protein_input_dim),
            nn.Linear(protein_input_dim, hidden_dim),
        )

        # Regression head operating on the concatenated pair embedding.
        joint_dim = hidden_dim * 2
        self.fc1 = nn.Linear(joint_dim, joint_dim)
        self.fc2 = nn.Linear(joint_dim, joint_dim)
        self.fc3 = nn.Linear(joint_dim, hidden_dim)

        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, molecule_seq, protein_seq):
        """Score a batch of pairs.

        Args:
            molecule_seq: Mapping unpacked as keyword arguments into the
                molecule encoder.
            protein_seq: Mapping unpacked as keyword arguments into the
                protein encoder.

        Returns:
            Tensor produced by ``fc_out`` (last dimension of size 1).
        """
        molecule_pooled = self.molecule_encoder(**molecule_seq).pooler_output
        protein_pooled = self.protein_encoder(**protein_seq).pooler_output

        pair_embedding = torch.cat(
            [self.molecule_align(molecule_pooled), self.protein_align(protein_pooled)],
            dim=1,
        )

        hidden = pair_embedding
        for dense in (self.fc1, self.fc2, self.fc3):
            hidden = F.gelu(dense(hidden))

        return self.fc_out(hidden)

# Load both encoders from local weight directories, then wrap them in the
# dual-encoder interaction model (remaining constructor arguments keep their
# defaults).
molecule_bert = BertModel.from_pretrained("weights/molecule_bert")
protein_bert = BertModel.from_pretrained("weights/protein_bert")
dlm_dti = DualLanguageModelDTI(
    molecule_encoder=molecule_bert,
    protein_encoder=protein_bert,
)

Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at weights/protein_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
class DTI_prediction(pl.LightningModule):
    """Lightning wrapper around a drug-target interaction model for prediction.

    Only ``forward`` and ``predict_step`` do any work; the training,
    validation, test and optimizer hooks are no-op stubs that return ``None``.
    """

    def __init__(self, dlm_dti):
        super().__init__()
        # Stored as `self.model`, so checkpoint keys are prefixed with "model.".
        self.model = dlm_dti

    def forward(self, molecule_sequence, protein_sequence):
        """Delegate to the wrapped model."""
        return self.model(molecule_sequence, protein_sequence)

    def training_step(self, batch, batch_idx):
        """No-op stub; returns ``None``."""
        return None

    def validation_step(self, batch, batch_idx):
        """No-op stub; returns ``None``."""
        return None

    def test_step(self, batch, batch_idx):
        """No-op stub; returns ``None``."""
        return None

    def predict_step(self, batch, batch_idx):
        """Score one batch.

        ``batch`` is unpacked as ``(molecule_sequence, protein_sequence, idx)``.
        Returns ``(y_hat, idx)`` where ``y_hat`` is the model output with its
        last dimension squeezed and ``idx`` is passed through unchanged.
        """
        molecules, proteins, sample_ids = batch
        scores = self(molecules, proteins).squeeze(-1)

        return scores, sample_ids

    def configure_optimizers(self):
        """No-op stub; returns ``None`` (no optimizer is configured)."""
        return None

    
# Restore the trained regressor from its checkpoint. `load_from_checkpoint`
# is a classmethod, so it is called on the class directly instead of on a
# throwaway instance; `dlm_dti` is supplied as the constructor argument.
ckpt_fname = "dlm_dti-epoch=194-valid_loss=0.2036-valid_mae=0.2368.ckpt"
dti_regressor = DTI_prediction.load_from_checkpoint("weights/dlm_dti_davis_512/" + ckpt_fname, dlm_dti=dlm_dti)

# NOTE(review): newer Lightning releases replace `gpus=` with
# `accelerator=`/`devices=` - confirm against the installed version before
# upgrading.
trainer = pl.Trainer(max_epochs=1, gpus=[0], enable_progress_bar=True)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Run prediction over `pred_dataloader`. `DTI_prediction.predict_step` returns
# `(y_hat, idx)` for each batch; presumably `pred` collects those per-batch
# results in order - verify against the installed Lightning version.
pred = trainer.predict(dti_regressor, pred_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]