In [1]:
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem.QED import qed
from tdc.multi_pred import DTI

from rdkit import RDLogger 
RDLogger.DisableLog('rdApp.*')

import warnings
warnings.filterwarnings(action='ignore')

# davis = DTI(name="Davis")
# davis.convert_to_log(form='binding')
# davis_split = davis.get_split()

# Load the KIBA drug-target interaction data and take the default split
# returned by get_split().
kiba = DTI(name="KIBA")
kiba_split = kiba.get_split()

# Unpack the three partitions of the split into their own frames.
train_df, valid_df, test_df = (
    kiba_split[partition] for partition in ("train", "valid", "test")
)

Found local copy...
Loading...
Done!


In [2]:
from transformers import BertModel, BertTokenizer
from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer

# Special tokens handed to the molecule tokenizer.
_molecule_special_tokens = {
    "pad_token": "[PAD]",
    "mask_token": "[MASK]",
    "cls_token": "[CLS]",
    "sep_token": "[SEP]",
    "unk_token": "[UNK]",
}

# Molecule side: tokenizer loaded from a local tokenizer file, encoder loaded
# from a local checkpoint directory (no download attempted).
molecule_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="data/drug/tokenizer_model/vocab.json",
    model_max_length=128,
    **_molecule_special_tokens,
)
molecule_bert = BertModel.from_pretrained(
    "weights/molecule_bert_pretrained-masking_rate_30",
    local_files_only=True,
)

# Protein side: tokenizer and encoder both come from the Rostlab/prot_bert checkpoint.
protein_tokenizer = BertTokenizer.from_pretrained(
    "Rostlab/prot_bert",
    do_lower_case=False,
    model_max_length=512,
)
protein_bert = BertModel.from_pretrained("Rostlab/prot_bert")

Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert_pretrained-masking_rate_30 and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initia

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import DataCollatorWithPadding


class DTIDataset(Dataset):
    """Single-column view over a drug-target interaction DataFrame.

    Depending on ``mode`` each item is the tokenized ``"Drug"`` string, the
    tokenized ``"Target"`` string, or the ``"Y"`` value as a float32 tensor.

    Args:
        data: pandas DataFrame; only the column selected by ``mode`` is read.
        tokenizer: Callable tokenizer used in ``"Drug"`` / ``"Target"`` mode.
            May be ``None`` in ``"Y"`` mode.
        mode: One of ``"Drug"``, ``"Target"`` or ``"Y"``.

    Raises:
        ValueError: If ``mode`` is not supported, or if a sequence mode is
            requested without a tokenizer.
    """

    # Truncation length applied to each sequence mode ("Y" is not tokenized).
    _MAX_SEQ_LEN = {"Drug": 128, "Target": 512}
    _MODES = ("Drug", "Target", "Y")
    _MODE_ERROR = "please choose between ['Drug' / 'Target' / 'Y']"

    def __init__(self, data, tokenizer=None, mode="Drug"):
        # Fail fast here rather than inside a DataLoader worker on first access.
        if mode not in self._MODES:
            raise ValueError(self._MODE_ERROR)
        if mode != "Y" and tokenizer is None:
            raise ValueError(f"a tokenizer is required when mode is {mode!r}")

        self.data = data
        self.mode = mode
        self.max_seq_len = self._MAX_SEQ_LEN.get(mode)
        self.tokenizer = tokenizer

    def encode(self, sequnce):
        """Tokenize one raw sequence string.

        The characters of ``sequnce`` are joined with single spaces before
        being passed to the tokenizer, and the result is truncated to
        ``self.max_seq_len`` tokens.
        """
        spaced = " ".join(sequnce)
        return self.tokenizer(spaced, max_length=self.max_seq_len, truncation=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item at *position* ``idx`` for the configured mode."""
        # Guard again in case ``mode`` was reassigned after construction.
        if self.mode not in self._MODES:
            raise ValueError(self._MODE_ERROR)

        # Samplers yield positions in [0, len), so index positionally; a
        # label-based lookup breaks when the frame's index is not 0..n-1.
        value = self.data[self.mode].iloc[idx]

        if self.mode == "Y":
            return torch.tensor(value, dtype=torch.float32)
        return self.encode(value)
    

def define_dataloaders(df, molecule_tokenizer, protein_tokenizer, batch_size=256, num_workers=16, prefetch_factor=10, mode="train", seed=42):
    """Build three row-aligned DataLoaders (drug, target, label) over ``df``.

    The loaders are meant to be iterated together (e.g. through a combined
    loader), so batch ``i`` of each must refer to the same rows of ``df``.
    In training mode every loader therefore gets its own ``RandomSampler``
    driven by a generator seeded with the same ``seed``: all three draw the
    same permutation each epoch, provided they are always iterated in
    lockstep.

    Args:
        df: DataFrame passed to ``DTIDataset`` for each of the three views.
        molecule_tokenizer: Tokenizer used for the drug dataset and to pad
            drug batches.
        protein_tokenizer: Tokenizer used for the target dataset and to pad
            target batches.
        batch_size: Batch size shared by the three loaders.
        num_workers: Worker processes per loader.
        prefetch_factor: Per-worker prefetch depth; only forwarded when
            ``num_workers > 0``.
        mode: ``"Train"`` (shuffle, drop last incomplete batch) or ``"Valid"``
            (sequential, keep all rows). Matching is case-insensitive.
        seed: Seed shared by the three shuffling generators in training mode.

    Returns:
        Tuple ``(drug_dataloader, target_dataloader, y_dataloader)``.

    Raises:
        ValueError: If ``mode`` is neither train nor valid.
    """
    normalized_mode = str(mode).lower()
    if normalized_mode == "train":
        drop_last, shuffle = True, True
    elif normalized_mode == "valid":
        drop_last, shuffle = False, False
    else:
        raise ValueError("please specify mode between ['Train' / 'Valid']")

    drug_collator = DataCollatorWithPadding(molecule_tokenizer, padding=True, return_tensors="pt")
    target_collator = DataCollatorWithPadding(protein_tokenizer, padding=True, return_tensors="pt")

    drug_dataset = DTIDataset(df, molecule_tokenizer, mode="Drug")
    target_dataset = DTIDataset(df, protein_tokenizer, mode="Target")
    y_dataset = DTIDataset(df, None, mode="Y")

    def _aligned_sampler(dataset):
        # Independent shuffling per loader would pair drugs with the wrong
        # targets and labels; identically seeded generators keep them in sync.
        if not shuffle:
            return None  # DataLoader falls back to sequential order
        generator = torch.Generator()
        generator.manual_seed(seed)
        return RandomSampler(dataset, generator=generator)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         pin_memory=True, drop_last=drop_last)
    if num_workers > 0:
        # prefetch_factor is only meaningful with worker processes.
        loader_kwargs["prefetch_factor"] = prefetch_factor

    drug_dataloader = DataLoader(drug_dataset, sampler=_aligned_sampler(drug_dataset),
                                 collate_fn=drug_collator, **loader_kwargs)
    # Target batches must be padded by the protein tokenizer's collator.
    target_dataloader = DataLoader(target_dataset, sampler=_aligned_sampler(target_dataset),
                                   collate_fn=target_collator, **loader_kwargs)
    y_dataloader = DataLoader(y_dataset, sampler=_aligned_sampler(y_dataset), **loader_kwargs)

    return drug_dataloader, target_dataloader, y_dataloader


# Build the (drug, target, y) loader triples for each split; the validation
# and test splits both use the "Valid" loader configuration.
unified_train_dataloader, unified_valid_dataloader, unified_test_dataloader = (
    define_dataloaders(split_df, molecule_tokenizer, protein_tokenizer, batch_size=16, mode=split_mode)
    for split_df, split_mode in ((train_df, "Train"), (valid_df, "Valid"), (test_df, "Valid"))
)


In [4]:
class DTIPredictionHead(nn.Module):
    """Regression head over the pooled outputs of a molecule and a protein encoder.

    The two ``pooler_output`` vectors are optionally projected to a common
    ``inner_dim``, concatenated, and passed through a small MLP that emits a
    single value per pair.

    Args:
        molecule_encoder: Encoder whose output exposes ``pooler_output`` and
            whose transformer blocks live in ``encoder.layer``.
        protein_encoder: Same contract as ``molecule_encoder``.
        molecule_dim: Width of the molecule encoder's pooled output.
        protein_dim: Width of the protein encoder's pooled output.
        inner_dim: Hidden width of the head (and of both projections).
        projection: If True, linearly project both pooled vectors to
            ``inner_dim`` before concatenation.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, molecule_encoder, protein_encoder,
                 molecule_dim=128, protein_dim=1024, inner_dim=256, projection=True,
                 dropout=0.1):
        super().__init__()
        self.is_projection = projection
        self.dropout_rate = dropout
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze every transformer block except the last one.
        # NOTE(review): only ``encoder.layer[:-1]`` is frozen; embeddings and
        # the pooler stay trainable - confirm that this is intended.
        for encoder in (self.molecule_encoder, self.protein_encoder):
            for param in encoder.encoder.layer[0:-1].parameters():
                param.requires_grad = False

        if self.is_projection:
            self.mol_proj = nn.Linear(molecule_dim, inner_dim, bias=False)
            self.prot_proj = nn.Linear(protein_dim, inner_dim, bias=False)
            self.fc_1 = nn.Linear(inner_dim * 2, inner_dim)
        else:
            self.fc_1 = nn.Linear(molecule_dim + protein_dim, inner_dim)

        self.fc_2 = nn.Linear(inner_dim, inner_dim // 2)
        self.fc_out = nn.Linear(inner_dim // 2, 1)

    @staticmethod
    def _pooled(encoder, input_ids, attention_mask=None):
        """Run ``encoder`` and return its ``pooler_output``."""
        if attention_mask is None:
            # Same call as before the mask arguments existed.
            return encoder(input_ids).pooler_output
        return encoder(input_ids, attention_mask=attention_mask).pooler_output

    def forward(self, molecule, protein, molecule_attention_mask=None, protein_attention_mask=None):
        """Predict one value per (molecule, protein) pair.

        Args:
            molecule: Token ids for the molecule encoder.
            protein: Token ids for the protein encoder.
            molecule_attention_mask: Optional mask forwarded to the molecule
                encoder so padded positions can be ignored.
            protein_attention_mask: Optional mask forwarded to the protein
                encoder.

        Returns:
            Tensor whose last dimension has size 1.
        """
        molecule = self._pooled(self.molecule_encoder, molecule, molecule_attention_mask)
        protein = self._pooled(self.protein_encoder, protein, protein_attention_mask)

        if self.is_projection:
            molecule = self.mol_proj(molecule)
            protein = self.prot_proj(protein)

        x = torch.cat((molecule, protein), -1)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # evaluation and prediction are deterministic.
        x = F.dropout(F.gelu(self.fc_1(x)), self.dropout_rate, training=self.training)
        x = F.dropout(F.gelu(self.fc_2(x)), self.dropout_rate, training=self.training)
        x = self.fc_out(x)

        return x


# Wrap both pretrained encoders in the prediction head (projection enabled)
# and display the resulting module tree as the cell output.
dti_prediction_head = DTIPredictionHead(
    molecule_encoder=molecule_bert,
    protein_encoder=protein_bert,
    projection=True,
)
dti_prediction_head

DTIPredictionHead(
  (molecule_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(69, 128, padding_idx=0)
      (position_embeddings): Embedding(128, 128)
      (token_type_embeddings): Embedding(1, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwis

In [5]:
import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchmetrics.functional import mean_squared_error, mean_absolute_error


class DTI_prediction(pl.LightningModule):
    """LightningModule that trains/evaluates a DTI prediction head with MSE loss.

    Each ``unified_*_dataloader`` argument is a 3-tuple
    ``(drug_dataloader, target_dataloader, y_dataloader)``; the loaders are
    combined into one ``CombinedLoader`` yielding dict batches with the keys
    ``"drug"``, ``"target"`` and ``"y"``.

    Args:
        dti_prediction_head: Model called as ``model(drug_input_ids, target_input_ids)``.
        unified_train_dataloader: Loader triple used for training.
        unified_valid_dataloader: Loader triple used for validation.
        unified_test_dataloader: Loader triple used for testing.
    """

    def __init__(self, dti_prediction_head,
                 unified_train_dataloader,
                 unified_valid_dataloader,
                 unified_test_dataloader):
        super().__init__()
        self.model = dti_prediction_head

        self.unified_train_dataloader = unified_train_dataloader
        self.unified_valid_dataloader = unified_valid_dataloader
        self.unified_test_dataloader = unified_test_dataloader

    def forward(self, smiles, fasta):
        return self.model(smiles, fasta)

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _combine(unified_dataloader):
        """Wrap a (drug, target, y) loader triple into a single CombinedLoader."""
        drug_dataloader, target_dataloader, y_dataloader = unified_dataloader

        return CombinedLoader({"drug": drug_dataloader, "target": target_dataloader, "y": y_dataloader})

    def _predict_batch(self, batch):
        """Run the model on one combined batch and drop the trailing unit dimension."""
        molecule = batch["drug"]
        protein = batch["target"]

        # NOTE(review): only ``input_ids`` are forwarded; any attention mask
        # present in the batch is not used by this call.
        return self(molecule['input_ids'], protein['input_ids']).squeeze(-1)

    def _shared_step(self, batch, stage):
        """Compute the MSE loss for ``batch`` and log ``{stage}_loss`` / ``{stage}_mae``."""
        y = batch["y"]
        y_hat = self._predict_batch(batch)
        loss = F.mse_loss(y_hat, y)

        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_mae", mean_absolute_error(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)

        return loss

    # ----------------------------------------------------------------- training
    def train_dataloader(self):
        return self._combine(self.unified_train_dataloader)

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    # --------------------------------------------------------------- validation
    def val_dataloader(self):
        return self._combine(self.unified_valid_dataloader)

    def validation_step(self, batch, batch_idx):
        # Metrics are logged inside the helper; nothing is returned.
        self._shared_step(batch, "valid")

    # ------------------------------------------------------------------ testing
    def test_dataloader(self):
        return self._combine(self.unified_test_dataloader)

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    # --------------------------------------------------------------- prediction
    def predict_step(self, batch, batch_idx):
        # Labels are not needed for inference, so batch["y"] is not read.
        return self._predict_batch(batch)

    # ------------------------------------------------------------- optimization
    def configure_optimizers(self):
        """AdamW (lr=1e-4) with cosine annealing warm restarts (T_0=200)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=200)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
# Keep the 30 checkpoints with the lowest monitored 'valid_loss'.
_checkpoint_callback = ModelCheckpoint(
    monitor='valid_loss',
    save_top_k=30,
    dirpath='weights/DTI_prediction_CLS_concatenate_with_projection',
    filename='attentional_dti-{epoch:03d}-{valid_loss:.4f}-{valid_mae:.4f}',
)
callbacks = [_checkpoint_callback]

# Lightning wrapper around the prediction head and the three loader triples.
model = DTI_prediction(
    dti_prediction_head=dti_prediction_head,
    unified_train_dataloader=unified_train_dataloader,
    unified_valid_dataloader=unified_valid_dataloader,
    unified_test_dataloader=unified_test_dataloader,
)

# Single-GPU trainer with 16-bit precision.
trainer = pl.Trainer(
    callbacks=callbacks,
    enable_progress_bar=True,
    gpus=1,
    max_epochs=1000,
    precision=16,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Start training; the dataloaders are supplied by the module's own
# train_dataloader / val_dataloader hooks.
trainer.fit(model=model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | model | DTIPredictionHead | 422 M 
--------------------------------------------
55.3 M    Trainable params
366 M     Non-trainable params
422 M     Total params
844.037   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [None]:
# checkpoint_file = ""
# model.load_from_checkpoint(dti_prediction_head=dti_prediction_head, checkpoint_path="weights/DTI_prediction_CLS_token_concatenate/" + checkpoint_file)

# trainer.test(model, test_dataloader)

In [None]:
# drug_dataloader, target_dataloader, y_dataloader = unified_test_dataloader
# trainer.predict(model, CombinedLoader({"drug": drug_dataloader, "target": target_dataloader, "y": y_dataloader}))