In [1]:
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem.QED import qed
from tdc.multi_pred import DTI

from rdkit import RDLogger 
RDLogger.DisableLog('rdApp.*')

import warnings
warnings.filterwarnings(action='ignore')

davis = DTI(name="Davis")
davis.convert_to_log(form='binding')
davis_split = davis.get_split()

Found local copy...
Loading...
Done!
To log space...


In [2]:
def canonicalize_smiles(df):
    for i, smiles in enumerate(df['Drug']):
        mol = Chem.MolFromSmiles(smiles)
        canonical_smiles = Chem.MolToSmiles(mol)
        
        df.loc[i, 'Drug'] = canonical_smiles
        
    return df

train_df = canonicalize_smiles(davis_split['train'])
valid_df = canonicalize_smiles(davis_split['valid'])
test_df = canonicalize_smiles(davis_split['test'])

In [3]:
from transformers import BertModel, BertTokenizer
from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer

molecule_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="data/drug/tokenizer_model/vocab.json",
    pad_token="[PAD]",
    mask_token="[MASK]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    unk_token="[UNK]"
)
molecule_bert = BertModel.from_pretrained("weights/molecule_bert_pretrained", local_files_only=False)

protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protein_bert = BertModel.from_pretrained("Rostlab/prot_bert")

Some weights of the model checkpoint at weights/molecule_bert_pretrained were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert_pretrained and are newly initialized: ['bert.pooler.dense.weight', 'be

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import DataCollatorWithPadding

class DTIDataset(Dataset):
    def __init__(self, data, molecule_tokenizer, protein_tokenizer):
        self.data = data
        self.molecule_tokenizer = molecule_tokenizer
        self.protein_tokenizer = protein_tokenizer
        self.molecule_max_length = 128
        self.protein_max_length = 2048
        
    def encode(self, smiles, fasta):
        pad_len_smiles = self.molecule_max_length - len(smiles)
        pad_len_fasta = self.protein_max_length - len(fasta)
        
        smiles = smiles + '#' * pad_len_smiles
        smiles = " ".join(smiles).replace("#", "[PAD]")
        
        fasta = fasta + '#' * pad_len_fasta
        fasta = " ".join(fasta).replace("#", "[PAD]")
        
        molecule_seq = self.molecule_tokenizer.encode(smiles, max_length=self.molecule_max_length, truncation=True, return_tensors='pt')
        protein_seq = self.protein_tokenizer.encode(fasta, max_length=self.protein_max_length, truncation=True, return_tensors='pt')
        
        return molecule_seq, protein_seq
        
        
    def __len__(self):
        return len(self.data)

    
    def __getitem__(self, idx):
        smiles = self.data.loc[idx, "Drug"]
        fasta = self.data.loc[idx, "Target"]
        y = torch.tensor(self.data.loc[idx, "Y"]).float()
        
        molecule_seq, protein_seq = self.encode(smiles, fasta)
        
        return molecule_seq, protein_seq, y


train_dataset = DTIDataset(train_df, molecule_tokenizer, protein_tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=32, num_workers=16, pin_memory=True, prefetch_factor=10, drop_last=True)

valid_dataset = DTIDataset(valid_df, molecule_tokenizer, protein_tokenizer)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, num_workers=16, pin_memory=True, prefetch_factor=10)

test_dataset = DTIDataset(test_df, molecule_tokenizer, protein_tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=32, num_workers=16, pin_memory=True, prefetch_factor=10)


In [5]:
class DTIPredictionHead(nn.Module):
    def __init__(self, molecule_encoder, protein_encoder):
        super().__init__()
        self.molecule_encoder = molecule_encoder
        self.protein_encoder = protein_encoder
        
        # model freezing
        for param in self.molecule_encoder.parameters():
            param.requires_grad = False
        
        for param in self.protein_encoder.parameters():
            param.requires_grad = False
        
        molecule_out = 128
        protein_out = 1024
        self.fc_1 = nn.Linear(molecule_out + protein_out, 512)        
        self.fc_2 = nn.Linear(512, 256)
        self.fc_out = nn.Linear(256, 1)
        
        
    def forward(self, smiles, fasta):
        smiles = smiles.squeeze(1)
        fasta = fasta.squeeze(1)
        
        smiles_vec = self.molecule_encoder(smiles).pooler_output
        fasta_vec = self.protein_encoder(fasta).pooler_output
        
        x = torch.cat((smiles_vec, fasta_vec), -1)
        x = F.dropout(F.gelu(self.fc_1(x)), 0.2)
        x = F.dropout(F.gelu(self.fc_2(x)), 0.2)
        x = self.fc_out(x)
        
        return x
        
        
dti_prediction_head = DTIPredictionHead(molecule_bert, protein_bert)

In [6]:
# for batch in valid_dataloader:
#     smiles, fasta, y = batch
    
#     print(dti_prediction_head(smiles, fasta))
#     break

In [7]:
# sample_smiles = train_df.loc[0, "Drug"]
# sample_fasta = " ".join(train_df.loc[0, "Target"])

# encoded_smiles = molecule_tokenizer.encode(sample_smiles, return_tensors='pt')
# encoded_fasta = protein_tokenizer.encode(sample_fasta, return_tensors='pt')

# dti_prediction_head(encoded_smiles, encoded_fasta)

In [8]:
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


class DTI_prediction(pl.LightningModule):
    def __init__(self, dti_prediction_head):
        super().__init__()
        self.model = dti_prediction_head
        
        self.train_accuracy = torchmetrics.MeanAbsoluteError()
        self.valid_accuracy = torchmetrics.MeanAbsoluteError()
        self.test_accuracy = torchmetrics.MeanAbsoluteError()
        
        
    def forward(self, smiles, fasta):
        return self.model(smiles, fasta)

       
    def training_step(self, batch, batch_idx):
        smiles, fasta, y = batch
        y_hat = self(smiles, fasta).squeeze(-1)
        
        loss = F.mse_loss(y_hat, y)
        
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", self.train_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)
        
        return loss

    
    def validation_step(self, batch, batch_idx):
        smiles, fasta, y = batch
        y_hat = self(smiles, fasta).squeeze(-1)        
        
        loss = F.mse_loss(y_hat, y)
        
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("valid_accuracy", self.valid_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)
    
    
    def test_step(self, batch, batch_idx):
        smiles, fasta, y = batch
        y_hat = self(smiles, fasta).squeeze(-1)

        loss = F.mse_loss(y_hat, y)
        
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_accuracy", self.valid_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True)
    
    
    def predict_step(self, batch, batch_idx):
        smiles, fasta, y = batch
        y_hat = self(smiles, fasta).squeeze(-1)
        
        return y_hat
    
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
    
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=10, dirpath='weights/DTI_prediction_CLS_token_concatenate', filename='molecule_bert-{epoch:02d}-{valid_loss:.4f}'),
]

model = DTI_prediction(dti_prediction_head)
trainer = pl.Trainer(max_epochs=100, gpus=1, enable_progress_bar=True, callbacks=callbacks, precision=16)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [9]:
# trainer.fit(model, train_dataloader, valid_dataloader)

In [13]:


checkpoint_file = "molecule_bert-epoch=08-valid_loss=0.7020.ckpt"
model.load_from_checkpoint(dti_prediction_head=dti_prediction_head, checkpoint_path="weights/DTI_prediction_CLS_token_concatenate/" + checkpoint_file)

pred = trainer.predict(model, test_dataloader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

AttributeError: 'DataLoader' object has no attribute 'data'

In [29]:
res = []

for p in pred:
    res += p.tolist()

res

[5.53125,
 5.50390625,
 5.85546875,
 4.58984375,
 5.5546875,
 5.46484375,
 5.40234375,
 5.484375,
 4.93359375,
 5.16796875,
 5.75,
 4.98828125,
 5.30078125,
 5.296875,
 5.5546875,
 5.38671875,
 5.4453125,
 4.8984375,
 4.84375,
 5.234375,
 5.77734375,
 5.6015625,
 5.22265625,
 5.5625,
 5.90625,
 5.10546875,
 5.23046875,
 4.91015625,
 5.6171875,
 5.359375,
 5.5546875,
 5.23046875,
 5.3515625,
 5.7109375,
 5.4921875,
 5.36328125,
 5.4453125,
 5.390625,
 5.6953125,
 5.671875,
 5.1875,
 5.6796875,
 5.8828125,
 5.28515625,
 5.3359375,
 5.48046875,
 5.234375,
 5.03125,
 5.80859375,
 5.671875,
 5.35546875,
 5.90234375,
 5.1015625,
 5.59765625,
 5.55859375,
 5.74609375,
 5.828125,
 5.5546875,
 4.88671875,
 5.00390625,
 5.23828125,
 5.20703125,
 5.796875,
 5.6484375,
 5.52734375,
 5.12890625,
 5.140625,
 5.59765625,
 5.53125,
 5.97265625,
 5.0703125,
 5.5234375,
 5.43359375,
 4.97265625,
 5.0546875,
 5.38671875,
 5.79296875,
 5.56640625,
 5.6015625,
 5.234375,
 5.734375,
 5.15234375,
 5.48046875

In [31]:
from sklearn import metrics as metrics

true = test_dataloader.dataset.data.Y.values

mae = metrics.mean_absolute_error(true, res).round(4)
mse = metrics.mean_squared_error(true, res).round(4)
print(f"MAE: {mae}\tMSE: {mse}")

MAE: 0.5773	MSE: 0.6834
