In [1]:
PROJECT_NAME = "MTDTI_MSE"
LEARNING_RATE = 5e-5
PROT_MAX_LEN = 1024

import pickle
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

import torch
from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader, Dataset
from torchmetrics.functional import average_precision
from torchmetrics.functional.classification import binary_auroc

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name=f'{PROJECT_NAME}_lr-{LEARNING_RATE}_prot_{PROT_MAX_LEN}',
                           project='DLM_DTI')

from transformers import BertTokenizer, AutoModel

train_data = pd.read_csv("data/mol_trans/train_dataset.csv")
valid_data = pd.read_csv("data/mol_trans/valid_dataset.csv")
test_data = pd.read_csv("data/mol_trans/test_dataset.csv")
    
mol_tokenizer = BertTokenizer.from_pretrained("jonghyunlee/DrugLikeMoleculeBERT")
mol_encoder = AutoModel.from_pretrained("jonghyunlee/DrugLikeMoleculeBERT")

prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjonghyunlee1993[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from typing import Callable

class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices: a list of indices
        num_samples: number of samples to draw
        callback_get_label: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(
        self,
        dataset,
        labels=None,
        indices=None,
        num_samples=None,
        callback_get_label=None,
    ):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # define custom callback
        self.callback_get_label = dataset.data.Label

        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # distribution of classes in the dataset
        df = pd.DataFrame()
        df["Label"] = self._get_labels(dataset) if labels is None else labels
        df.index = self.indices
        df = df.sort_index()

        label_to_count = df["Label"].value_counts()

        weights = 1.0 / label_to_count[df["Label"]]

        self.weights = torch.DoubleTensor(weights.to_list())

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
    
class DTIDataset(Dataset):
    def __init__(self, data, mol_tokenizer, prot_tokenizer):
        self.data = data
        self.mol_tokenizer = mol_tokenizer
        self.prot_tokenizer = prot_tokenizer
        
    def get_mol_feature(self, smiles):
        return self.mol_tokenizer(" ".join(smiles), max_length=128, truncation=True)
    
    def get_prot_feature(self, fasta):
        return self.prot_tokenizer(" ".join(fasta), max_length=PROT_MAX_LEN, truncation=True)
    
    def __len__(self):    
        return len(self.data)
    
    def __getitem__(self, index):
        smiles = self.data.loc[index, "SMILES"]
        mol_feature = self.get_mol_feature(smiles)
        
        fasta = self.data.loc[index, "Target Sequence"]
        prot_feature = self.get_prot_feature(fasta)
        
        y = self.data.loc[index, "Label"]
        source = self.data.loc[index, "Source"]
                
        return mol_feature, prot_feature, y, source
    
def collate_batch(batch):
    mol_features, prot_features, y, source = [], [], [], []
    
    for (mol_seq, prot_seq, y_, source_) in batch:
        mol_features.append(mol_seq)
        prot_features.append(prot_seq)
        y.append(y_)
        source.append(source_)
        
    mol_features = mol_tokenizer.pad(mol_features, return_tensors="pt")
    prot_features = prot_tokenizer.pad(prot_features, return_tensors="pt")
    y = torch.tensor(y).float()
    source = torch.tensor(source)
    
    return mol_features, prot_features['input_ids'], y, source


train_dataset = DTIDataset(train_data, mol_tokenizer, prot_tokenizer)
valid_dataset = DTIDataset(valid_data, mol_tokenizer, prot_tokenizer)
test_dataset = DTIDataset(test_data, mol_tokenizer, prot_tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=128, num_workers=16, 
                              pin_memory=True, prefetch_factor=10, drop_last=True, 
                              sampler=ImbalancedDatasetSampler(train_dataset, labels=train_dataset.data.Label),
                              collate_fn=collate_batch)

valid_dataloader = DataLoader(valid_dataset, batch_size=512, num_workers=16, 
                              pin_memory=True, prefetch_factor=10,
                              collate_fn=collate_batch)

test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=16, 
                             pin_memory=True, prefetch_factor=10,
                             collate_fn=collate_batch)

In [3]:
import torch.nn as nn
import torch.nn.functional as F


class Embedding(nn.Module):
    def __init__(self, vocab_dim, embedding_dim, dropout_rate):
        super().__init__()
        self.vocab_dim = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate = dropout_rate
        
        self.embedding = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        
    def forward(self, x):
        embedding = self.embedding(x.long())
        embedding = self.dropout(embedding)
        
        return embedding
    

class Encoder(nn.Module):
    def __init__(self, vocab_dim, embedding_dim, kernel_size, dropout_rate=0.1):
        super(Encoder, self).__init__()
        self.embedding = Embedding(vocab_dim, embedding_dim, dropout_rate)
        
        self.conv1 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=kernel_size[0], padding=1)       
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=kernel_size[1], padding=1)        
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=kernel_size[2], padding=1)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.embedding(x)
        x = x.moveaxis(1, 2)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.adaptive_max_pool1d(x, output_size=1)

        x = x.view(batch_size, -1)

        return x


class DTI(nn.Module):
    def __init__(self, mol_encoder, prot_vocab_dim, embedding_dim, 
                 hidden_dim=512, dropout_rate=0.1):
        super(DTI, self).__init__()
        prot_encoder_kernel_size = [4, 8, 12]
        self.dropout_rate = dropout_rate
        
        self.mol_encoder = mol_encoder
        self.prot_encoder = Encoder(prot_vocab_dim, embedding_dim, prot_encoder_kernel_size, dropout_rate=dropout_rate)

        self.mol_align = nn.Sequential(
            nn.LayerNorm(128),
            nn.Linear(128, hidden_dim, bias=False)
        )
        
        self.prot_align = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, hidden_dim, bias=False)
        )       
        
        self.fc1 = nn.Linear(2*hidden_dim, 2*hidden_dim)
        self.fc2 = nn.Linear(2*hidden_dim, 2*hidden_dim)
        self.fc3 = nn.Linear(2*hidden_dim, hidden_dim)

        self.out = nn.Linear(hidden_dim, 1)


    def forward(self, SMILES, target):
        mol_feature = self.mol_encoder(**SMILES).pooler_output
        mol_feature = self.mol_align(mol_feature)
        
        prot_feature = self.prot_encoder(target)
        prot_feature = self.prot_align(prot_feature)
        
        x = torch.cat((mol_feature, prot_feature), axis=1)

        x = F.dropout(F.relu(self.fc1(x)), self.dropout_rate)
        x = F.dropout(F.relu(self.fc2(x)), self.dropout_rate)
        x = F.dropout(F.relu(self.fc3(x)), self.dropout_rate)
        
        out = self.out(x)
        
        return F.tanh(out)
    
    
model = DTI(mol_encoder, prot_tokenizer.vocab_size, 128)

In [4]:
class DTI_prediction(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    
    def step(self, batch):
        mol_feature, prot_feature, y, source = batch
        pred = self.model(mol_feature, prot_feature).squeeze(-1)
        
#         loss = F.binary_cross_entropy_with_logits(pred, y)
        loss = F.smooth_l1_loss(pred, y)
    
        auroc = binary_auroc(pred, y)
        auprc = average_precision(pred, y)
        
        return pred, source, loss, auroc, auprc, 
    
    
    def training_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc = self.step(batch)
        
        self.log('train_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc = self.step(batch)
        
        self.log('valid_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def test_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc = self.step(batch)
        
        self.log('test_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def predict_step(self, batch, batch_idx):
        pred, source, _, _, _ = self.step(batch)
        
        return pred, batch[2], source

    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
callbacks = [
    ModelCheckpoint(monitor='valid_auroc', mode="max",
                    save_top_k=5, dirpath=f'weights/{PROJECT_NAME}', filename='DTI-{epoch:03d}-{valid_loss:.4f}-{valid_auroc:.4f}-{valid_auprc:.4f}'),
]

predictor = DTI_prediction(model)
trainer = pl.Trainer(max_epochs=100, gpus=[0], enable_progress_bar=True, 
                     callbacks=callbacks, logger=wandb_logger, precision=16)

  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# trainer.fit(predictor, train_dataloader, valid_dataloader)

In [6]:
predictor = predictor.load_from_checkpoint(
    "weights/MTDTI_MSE/DTI-epoch=099-valid_loss=0.0724-valid_auroc=0.9168-valid_auprc=0.6605.ckpt",
    model=model
)

pred_out = trainer.predict(predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

In [7]:
results = np.array([]).reshape(0, 3)

for pred, label, source in pred_out:
    line_ = np.array([pred.detach().numpy(), 
         label.detach().numpy(), 
         source.detach().numpy()]).T
    results = np.vstack([results, line_])

results = pd.DataFrame(results, columns=["pred", "y", "source"])
results.y = results.y.astype(int)
results.source = results.source.astype(int)
results.head()

Unnamed: 0,pred,y,source
0,0.012749,0,0
1,0.181885,0,0
2,0.906738,0,0
3,0.662109,0,0
4,1.0,1,0


In [8]:
def get_evaluation_metrics(df, source=0):
    from sklearn.metrics import average_precision_score, roc_auc_score
    
    def get_cindex(Y, P):
        summ = 0
        pair = 0

        for i in range(1, len(Y)):
            for j in range(0, i):
                if i is not j:
                    if(Y[i] > Y[j]):
                        pair +=1
                        summ +=  1* (P[i] > P[j]) + 0.5 * (P[i] == P[j])

        if pair is not 0:
            return summ/pair
        else:
            return 0

    def r_squared_error(y_obs,y_pred):
        y_obs = np.array(y_obs)
        y_pred = np.array(y_pred)
        y_obs_mean = [np.mean(y_obs) for y in y_obs]
        y_pred_mean = [np.mean(y_pred) for y in y_pred]

        mult = sum((y_pred - y_pred_mean) * (y_obs - y_obs_mean))
        mult = mult * mult

        y_obs_sq = sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean))
        y_pred_sq = sum((y_pred - y_pred_mean) * (y_pred - y_pred_mean) )

        return mult / float(y_obs_sq * y_pred_sq)

    def get_k(y_obs,y_pred):
        y_obs = np.array(y_obs)
        y_pred = np.array(y_pred)

        return sum(y_obs*y_pred) / float(sum(y_pred*y_pred))

    def squared_error_zero(y_obs,y_pred):
        k = get_k(y_obs,y_pred)

        y_obs = np.array(y_obs)
        y_pred = np.array(y_pred)
        y_obs_mean = [np.mean(y_obs) for y in y_obs]
        upp = sum((y_obs - (k*y_pred)) * (y_obs - (k* y_pred)))
        down= sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean))

        return 1 - (upp / float(down))


    def get_rm2(ys_orig, ys_line):
        r2 = r_squared_error(ys_orig, ys_line)
        r02 = squared_error_zero(ys_orig, ys_line)

        return r2 * (1 - np.sqrt(np.absolute((r2*r2)-(r02*r02))))
    
    
    source_df = df[df.source == source].reset_index(drop=True)
    auroc = roc_auc_score(source_df.y, source_df.pred)
    auprc = average_precision_score(source_df.y, source_df.pred)
    
    cindex = get_cindex(source_df.y, source_df.pred)
    rm2 = get_rm2(source_df.y, source_df.pred)

    
    if source == 0:
        dataset = "Davis"
    elif source == 1:
        dataset = "BindingDB"
    elif source == 2:
        dataset = "BIOSNAP"
        
    print(f"Dataset: {dataset}")
    print("AUROC\tAUPRC\tCindex\trm2")
    print(f"{auroc.round(4)}\t{auprc.round(4)}\t{cindex.round(4)}\t{rm2.round(4)}")
    print()
    
get_evaluation_metrics(results, source=0)
get_evaluation_metrics(results, source=1)
get_evaluation_metrics(results, source=2)

Dataset: Davis
AUROC	AUPRC	Cindex	rm2
0.9276	0.3886	0.9212	0.1898

Dataset: BindingDB
AUROC	AUPRC	Cindex	rm2
0.8831	0.8807	0.8869	0.3718

Dataset: BIOSNAP
AUROC	AUPRC	Cindex	rm2
0.9134	0.5933	0.913	0.3417

