In [1]:
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem.QED import qed
from tdc.multi_pred import DTI

from rdkit import RDLogger 
# Turn off RDKit's 'rdApp.*' log output so parsing messages do not flood the notebook.
RDLogger.DisableLog('rdApp.*')

import warnings
# NOTE(review): this suppresses every Python warning for the rest of the session,
# including ones that may point at real problems.
warnings.filterwarnings(action='ignore')

# Load the Davis drug-target interaction dataset through TDC.
davis = DTI(name="Davis")
# Convert the labels to log space (the call reports "To log space..." when run).
davis.convert_to_log(form='binding')
# Split with the library's default settings; the result is looked up by
# 'train' / 'valid' / 'test' in the next cell.
davis_split = davis.get_split()

Found local copy...
Loading...
Done!
To log space...


In [2]:
def canonicalize_smiles(df):
    """Replace the ``Drug`` column of ``df`` with RDKit canonical SMILES.

    The frame is modified in place and also returned, so the function can be
    used either as a statement or as an expression.

    Args:
        df: pandas DataFrame with a ``Drug`` column holding SMILES strings.
            Any index is accepted; rows are never added, dropped or reordered.

    Returns:
        The same ``df`` object, with ``Drug`` rewritten.

    Raises:
        ValueError: if RDKit cannot parse one of the SMILES strings.
    """
    def _canonical(smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail
            # with the offending string instead of an opaque RDKit error.
            raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
        return Chem.MolToSmiles(mol)

    # The same drug appears on many rows, so canonicalise each distinct string
    # once and map the result back onto the column. Assigning the whole column
    # (rather than df.loc[i, ...] with a positional counter) keeps the frame's
    # index untouched whatever its labels are.
    canonical_by_smiles = {smiles: _canonical(smiles) for smiles in df['Drug'].unique()}
    df['Drug'] = df['Drug'].map(canonical_by_smiles)

    return df

# Canonicalise the drug SMILES of each split, in train / valid / test order.
train_df, valid_df, test_df = (
    canonicalize_smiles(davis_split[split_name])
    for split_name in ('train', 'valid', 'test')
)

In [3]:
from transformers import BertModel, BertTokenizer
from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer

# Molecule side: a fast tokenizer built from a local tokenizer file, with the
# BERT-style special tokens declared explicitly and a 128-token length limit.
molecule_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="data/drug/tokenizer_model/vocab.json",
    pad_token="[PAD]",
    mask_token="[MASK]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    unk_token="[UNK]",
    model_max_length=128
)
# Molecule encoder weights come from a local directory only (no hub download).
# The load log reports the pooler weights as newly initialised.
molecule_bert = BertModel.from_pretrained("weights/molecule_bert_pretrained-masking_rate_30", local_files_only=True)

# Protein side: the "Rostlab/prot_bert" tokenizer (case preserved, 2048-token
# length limit) and the matching encoder.
protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, model_max_length=2048)
protein_bert = BertModel.from_pretrained("Rostlab/prot_bert")

Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert_pretrained-masking_rate_30 and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initia

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import DataCollatorWithPadding


class DTIDataset(Dataset):
    """Single-column view over a drug-target interaction DataFrame.

    Depending on ``mode`` an item is:

    * ``"Drug"``   - the tokenised ``Drug`` string (truncated to 128 tokens),
    * ``"Target"`` - the tokenised ``Target`` string (truncated to 2048 tokens),
    * ``"Y"``      - the ``Y`` label as a float32 scalar tensor.

    Args:
        data: DataFrame with ``Drug``, ``Target`` and ``Y`` columns.
        tokenizer: callable accepting ``(text, max_length=..., truncation=...)``;
            only used in the ``"Drug"`` and ``"Target"`` modes.
        mode: one of ``"Drug"``, ``"Target"`` or ``"Y"``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    # Token truncation length for each tokenised column.
    _MAX_SEQ_LEN = {"Drug": 128, "Target": 2048}
    _MODE_ERROR = "please choose between ['Drug' / 'Target' / 'Y']"

    def __init__(self, data, tokenizer=None, mode="Drug"):
        # Fail fast on a bad mode instead of on the first item access.
        if mode not in ("Drug", "Target", "Y"):
            raise ValueError(self._MODE_ERROR)

        self.data = data
        self.mode = mode
        # None in "Y" mode, where nothing is tokenised.
        self.max_seq_len = self._MAX_SEQ_LEN.get(mode)
        self.tokenizer = tokenizer

    def encode(self, sequnce):
        """Tokenise ``sequnce`` after putting a space between its characters."""
        # The parameter name keeps its historical spelling so keyword callers
        # are not broken.
        return self.tokenizer(" ".join(sequnce), max_length=self.max_seq_len, truncation=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item at *position* ``idx`` for the configured mode."""
        # Samplers hand out positions in [0, len), so index positionally;
        # label-based .loc only works when the frame has a 0..n-1 index.
        if self.mode in self._MAX_SEQ_LEN:
            return self.encode(self.data[self.mode].iloc[idx])
        if self.mode == "Y":
            return torch.tensor(self.data["Y"].iloc[idx]).float()

        # Only reachable if ``mode`` was changed after construction.
        raise ValueError(self._MODE_ERROR)
    

def define_dataloaders(df, molecule_tokenizer, protein_tokenizer, batch_size=256, num_workers=16, prefetch_factor=10, mode="train", seed=None):
    """Build the drug, target and label DataLoaders for one data split.

    The three loaders are meant to be iterated in lockstep (for example through
    a CombinedLoader), so the i-th batch of each must describe the same rows.

    Args:
        df: DataFrame with ``Drug``, ``Target`` and ``Y`` columns.
        molecule_tokenizer: tokenizer used to encode and pad ``Drug`` strings.
        protein_tokenizer: tokenizer used to encode and pad ``Target`` strings.
        batch_size: number of rows per batch.
        num_workers: worker processes per DataLoader.
        prefetch_factor: batches prefetched per worker; ignored when
            ``num_workers`` is 0, where DataLoader does not accept it.
        mode: ``"Train"`` (5000 rows drawn with replacement per epoch, last
            partial batch dropped) or ``"Valid"`` (every row once, in order).
            Matching is case-insensitive.
        seed: optional integer seed for the training sampler. When omitted a
            seed is drawn from torch's global RNG.

    Returns:
        Tuple ``(drug_dataloader, target_dataloader, y_dataloader)``.

    Raises:
        ValueError: if ``mode`` is neither "Train" nor "Valid".
    """
    mode_key = str(mode).strip().lower()
    if mode_key not in ("train", "valid"):
        raise ValueError("please specify mode between ['Train' / 'Valid']")

    drug_collator = DataCollatorWithPadding(molecule_tokenizer, padding=True, return_tensors="pt")
    target_collator = DataCollatorWithPadding(protein_tokenizer, padding=True, return_tensors="pt")

    drug_dataset = DTIDataset(df, molecule_tokenizer, mode="Drug")
    target_dataset = DTIDataset(df, protein_tokenizer, mode="Target")
    y_dataset = DTIDataset(df, None, mode="Y")

    if mode_key == "train":
        drop_last = True
        if seed is None:
            seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

        def make_sampler():
            # An unseeded RandomSampler shared by three loaders yields a
            # different random order to each of them, which pairs drugs with
            # unrelated targets and labels. Giving every loader its own sampler
            # driven by an identically seeded generator keeps the three index
            # streams equal, epoch after epoch, provided the loaders are always
            # iterated together.
            return RandomSampler(df, replacement=True, num_samples=5000,
                                 generator=torch.Generator().manual_seed(seed))
    else:
        drop_last = False

        def make_sampler():
            # No sampler: DataLoader walks the rows sequentially (shuffle=False).
            return None

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         shuffle=False, pin_memory=True, drop_last=drop_last)
    if num_workers > 0:
        # DataLoader rejects a custom prefetch_factor without worker processes.
        loader_kwargs["prefetch_factor"] = prefetch_factor

    drug_dataloader = DataLoader(drug_dataset, collate_fn=drug_collator,
                                 sampler=make_sampler(), **loader_kwargs)
    # Proteins must be padded with the protein tokenizer's collator.
    target_dataloader = DataLoader(target_dataset, collate_fn=target_collator,
                                   sampler=make_sampler(), **loader_kwargs)
    y_dataloader = DataLoader(y_dataset, sampler=make_sampler(), **loader_kwargs)

    return drug_dataloader, target_dataloader, y_dataloader


# One (drug, target, y) loader triple per split. The validation and test
# splits both use the "Valid" loader configuration.
unified_train_dataloader, unified_valid_dataloader, unified_test_dataloader = (
    define_dataloaders(split_df, molecule_tokenizer, protein_tokenizer, batch_size=32, mode=split_mode)
    for split_df, split_mode in ((train_df, "Train"), (valid_df, "Valid"), (test_df, "Valid"))
)


In [5]:
class DTIPredictionHead(nn.Module):
    """Regression head over two frozen sequence encoders.

    The pooled outputs of the molecule and protein encoders are concatenated
    and passed through a three-layer MLP that emits one value per pair.

    Args:
        molecule_encoder: module whose call result exposes ``pooler_output``
            for SMILES token ids.
        protein_encoder: module whose call result exposes ``pooler_output``
            for protein token ids.
    """

    def __init__(self, molecule_encoder, protein_encoder):
        super().__init__()
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder

        # Freeze both encoders; only the MLP below is trained.
        for encoder in (self.molecule_encoder, self.protein_encoder):
            for param in encoder.parameters():
                param.requires_grad = False

        # Take the pooled width from the encoder config when it has one and
        # fall back to the sizes this notebook was written for.
        molecule_out = self._pooled_width(molecule_encoder, 128)
        protein_out = self._pooled_width(protein_encoder, 1024)
        self.fc_1 = nn.Linear(molecule_out + protein_out, 1024)
        self.fc_2 = nn.Linear(1024, 512)
        self.fc_out = nn.Linear(512, 1)

    @staticmethod
    def _pooled_width(encoder, default):
        """Return ``encoder.config.hidden_size`` if present, else ``default``."""
        config = getattr(encoder, "config", None)
        return getattr(config, "hidden_size", None) or default

    @staticmethod
    def _pooled(encoder, input_ids, attention_mask):
        """Run ``encoder`` and return its pooled output."""
        if attention_mask is None:
            # Same call as before masks were supported.
            return encoder(input_ids).pooler_output
        return encoder(input_ids, attention_mask=attention_mask).pooler_output

    def forward(self, smiles, fasta, smiles_mask=None, fasta_mask=None):
        """Predict one value per (molecule, protein) pair.

        Args:
            smiles: token ids for the molecule encoder.
            fasta: token ids for the protein encoder.
            smiles_mask: optional attention mask for ``smiles``; lets the
                encoder ignore padding positions.
            fasta_mask: optional attention mask for ``fasta``.

        Returns:
            Tensor whose last dimension has size 1.
        """
        smiles_vec = self._pooled(self.molecule_encoder, smiles, smiles_mask)
        fasta_vec = self._pooled(self.protein_encoder, fasta, fasta_mask)

        x = torch.cat((smiles_vec, fasta_vec), -1)
        # F.dropout defaults to training=True; tie it to the module mode so
        # evaluation and prediction are deterministic.
        x = F.dropout(F.gelu(self.fc_1(x)), p=0.1, training=self.training)
        x = F.dropout(F.gelu(self.fc_2(x)), p=0.1, training=self.training)
        x = self.fc_out(x)

        return x
        
        
# Wire the pretrained molecule and protein encoders into the regression head.
dti_prediction_head = DTIPredictionHead(
    molecule_encoder=molecule_bert,
    protein_encoder=protein_bert,
)

In [6]:
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.trainer.supporters import CombinedLoader


class DTI_prediction(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a DTI regression head.

    Each ``unified_*_dataloader`` argument is a
    ``(drug_dataloader, target_dataloader, y_dataloader)`` triple. The three
    loaders are combined so every batch is a dict with ``"drug"``,
    ``"target"`` and ``"y"`` entries.

    The metrics logged as ``*_accuracy`` are mean absolute errors; the names
    are kept because the checkpoint filename template refers to them.

    Args:
        dti_prediction_head: module called as ``head(smiles_ids, fasta_ids)``.
        unified_train_dataloader: loader triple for training.
        unified_valid_dataloader: loader triple for validation.
        unified_test_dataloader: loader triple for testing.
    """

    def __init__(self, dti_prediction_head,
                 unified_train_dataloader,
                 unified_valid_dataloader,
                 unified_test_dataloader):
        super().__init__()
        self.model = dti_prediction_head

        self.unified_train_dataloader = unified_train_dataloader
        self.unified_valid_dataloader = unified_valid_dataloader
        self.unified_test_dataloader = unified_test_dataloader

        self.train_accuracy = torchmetrics.MeanAbsoluteError()
        self.valid_accuracy = torchmetrics.MeanAbsoluteError()
        self.test_accuracy = torchmetrics.MeanAbsoluteError()

    def forward(self, smiles, fasta):
        return self.model(smiles, fasta)

    @staticmethod
    def _combined_loader(unified_dataloader):
        """Wrap a (drug, target, y) loader triple in a CombinedLoader."""
        drug_dataloader, target_dataloader, y_dataloader = unified_dataloader

        return CombinedLoader({"drug": drug_dataloader, "target": target_dataloader, "y": y_dataloader})

    def _predict_affinity(self, batch):
        """Run the head on a combined batch and drop the trailing size-1 dim."""
        smiles = batch["drug"]
        fasta = batch["target"]

        return self(smiles['input_ids'], fasta['input_ids']).squeeze(-1)

    def _shared_step(self, batch, stage, metric):
        """Compute the MSE loss and log ``{stage}_loss`` / ``{stage}_accuracy``."""
        y = batch["y"]
        y_hat = self._predict_affinity(batch)
        loss = F.mse_loss(y_hat, y)

        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_accuracy", metric(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def train_dataloader(self):
        return self._combined_loader(self.unified_train_dataloader)

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_accuracy)

    def val_dataloader(self):
        return self._combined_loader(self.unified_valid_dataloader)

    def validation_step(self, batch, batch_idx):
        # Logging only; nothing is returned, as before.
        self._shared_step(batch, "valid", self.valid_accuracy)

    def test_dataloader(self):
        return self._combined_loader(self.unified_test_dataloader)

    def test_step(self, batch, batch_idx):
        # Use the dedicated test metric so test batches are not accumulated
        # into the validation metric's state.
        self._shared_step(batch, "test", self.test_accuracy)

    def predict_step(self, batch, batch_idx):
        # Labels are not needed for prediction, so batch["y"] is not read.
        return self._predict_affinity(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
# Checkpointing is driven by the logged 'valid_loss'; up to 20 checkpoints are kept.
callbacks = [
    ModelCheckpoint(
        monitor='valid_loss',
        save_top_k=20,
        dirpath='weights/DTI_prediction_CLS_token_concatenate',
        filename='attentional_dti-{epoch:03d}-{valid_loss:.4f}-{valid_accuracy:.4f}',
    ),
]

# The Lightning module owns the head and the loader triple of every split.
model = DTI_prediction(
    dti_prediction_head=dti_prediction_head,
    unified_train_dataloader=unified_train_dataloader,
    unified_valid_dataloader=unified_valid_dataloader,
    unified_test_dataloader=unified_test_dataloader,
)

# Single-GPU run with 16-bit precision, capped at 1000 epochs.
trainer = pl.Trainer(
    max_epochs=1000,
    gpus=1,
    enable_progress_bar=True,
    callbacks=callbacks,
    precision=16,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [7]:
# Start training. The module supplies its own train and validation loaders,
# so no dataloaders are passed here.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params
-----------------------------------------------------
0 | model          | DTIPredictionHead | 423 M 
1 | train_accuracy | MeanAbsoluteError | 0     
2 | valid_accuracy | MeanAbsoluteError | 0     
3 | test_accuracy  | MeanAbsoluteError | 0     
-----------------------------------------------------
1.7 M     Trainable params
421 M     Non-trainable params
423 M     Total params
846.531   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# checkpoint_file = ""
# model.load_from_checkpoint(dti_prediction_head=dti_prediction_head, checkpoint_path="weights/DTI_prediction_CLS_token_concatenate/" + checkpoint_file)

# trainer.test(model, test_dataloader)

In [None]:
# drug_dataloader, target_dataloader, y_dataloader = self.unified_test_dataloader
# trainer.predict(model, CombinedLoader({"drug": drug_dataloader, "target": target_dataloader, "y": y_dataloader}))