In [1]:
import pickle
import pandas as pd
import numpy as np

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader, Dataset
from torchmetrics.functional import average_precision
from torchmetrics.functional.classification import binary_auroc

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

def scale_y_value(train_data):
    pkd_scaler = MinMaxScaler()
    train_data.loc[train_data.Source_ID.isin([0, 1]), "Y"] = pkd_scaler.fit_transform(train_data.loc[train_data.Source_ID.isin([0, 1]), "Y"].values.reshape(-1, 1))

    aug_scaler = MinMaxScaler()
    train_data.loc[train_data.Source_ID.isin([2]), "Y"] = aug_scaler.fit_transform(train_data.loc[train_data.Source_ID.isin([2]), "Y"].values.reshape(-1, 1))

    train_data.Y = train_data.Y.fillna(0)
    
    return train_data, pkd_scaler, aug_scaler

def apply_scaler(test_data, pkd_scaler, aug_scaler):
    test_data.loc[test_data.Source_ID.isin([0, 1]), "Y"] = pkd_scaler.transform(test_data.loc[test_data.Source_ID.isin([0, 1]), "Y"].values.reshape(-1, 1))
    test_data.loc[test_data.Source_ID.isin([2]), "Y"] = aug_scaler.transform(test_data.loc[test_data.Source_ID.isin([2]), "Y"].values.reshape(-1, 1))
    
    return test_data

with open("data/fold_number_0_train.pkl", "rb") as f:
    train_data = pickle.load(f)

train_data, valid_data = train_test_split(train_data, stratify=train_data["Y_label"], test_size=0.1, random_state=42)
train_data = train_data.reset_index(drop=True)
valid_data = valid_data.reset_index(drop=True)

train_data, pkd_scaler, aug_scaler = scale_y_value(train_data)
valid_data = apply_scaler(valid_data, pkd_scaler, aug_scaler)

with open("data/fold_number_0_test.pkl", "rb") as f:
    test_data = pickle.load(f)

test_data = test_data.reset_index(drop=True)
    
with open("data/mols_cls.pkl", "rb") as f:
    mols_embedding = pickle.load(f)
    
with open("data/prots_cls.pkl", "rb") as f:
    prots_embedding = pickle.load(f)

In [2]:
from typing import Callable

class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices: a list of indices
        num_samples: number of samples to draw
        callback_get_label: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(
        self,
        dataset,
        labels=None,
        indices=None,
        num_samples=None,
        callback_get_label=None,
    ):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # define custom callback
        self.callback_get_label = dataset.data.Y_label

        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # distribution of classes in the dataset
        df = pd.DataFrame()
        df["label"] = self._get_labels(dataset) if labels is None else labels
        df.index = self.indices
        df = df.sort_index()

        label_to_count = df["label"].value_counts()

        weights = 1.0 / label_to_count[df["label"]]

        self.weights = torch.DoubleTensor(weights.to_list())

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
    

class DTIDataset(Dataset):
    def __init__(self, data, mols_embedding, prots_embedding):
        self.data = data
        self.mols_embedding = mols_embedding
        self.prots_embedding = prots_embedding
        
    def get_mol_feature(self, mol_id):
        return self.mols_embedding[mol_id]
    
    def get_prot_feature(self, prot_id):
        return self.prots_embedding[prot_id]
    
    def __len__(self):    
        return len(self.data)
    
    def __getitem__(self, index):
        mol_id = self.data.loc[index, "Drug_ID"]
        mol_feature = self.get_mol_feature(mol_id).squeeze(0)
        
        prot_id = self.data.loc[index, "Target_ID"]
        prot_feature = self.get_prot_feature(prot_id).squeeze(0)
        
        source_id = self.data.loc[index, "Source_ID"]
        y_cls = torch.tensor(self.data.loc[index, "Y_label"]).float()
        
        if source_id in [0, 1]: # davis, bidning_db
            y_pkd = torch.tensor(self.data.loc[index, "Y"]).float()
            y_aug = torch.tensor(1e-10).float()
        elif source_id == 2: # kiba
            y_pkd = torch.tensor(1e-10).float()
            y_aug = torch.tensor(self.data.loc[index, "Y"]).float()
        else:
            y_pkd = torch.tensor(1e-10).float()
            y_aug = torch.tensor(1e-10).float()
        
        return mol_feature, prot_feature, y_cls, y_pkd, y_aug

train_dataset = DTIDataset(train_data, mols_embedding, prots_embedding)
valid_dataset = DTIDataset(valid_data, mols_embedding, prots_embedding)

train_dataloader = DataLoader(train_dataset, batch_size=128, num_workers=16, 
                              pin_memory=True, prefetch_factor=10, drop_last=True, 
                              sampler=ImbalancedDatasetSampler(train_dataset, labels=train_dataset.data.Y_label))

valid_dataloader = DataLoader(valid_dataset, batch_size=128, num_workers=16, 
                              pin_memory=True, prefetch_factor=10)

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class DTI(nn.Module):
    def __init__(self, hidden_dim=512, mol_dim=128, prot_dim=1024):
        super().__init__()
        
        self.molecule_align = nn.Sequential(
            nn.LayerNorm(mol_dim),
            nn.Linear(mol_dim, hidden_dim, bias=False)
        )
        
        self.protein_align = nn.Sequential(
            nn.LayerNorm(prot_dim),
            nn.Linear(prot_dim, hidden_dim, bias=False)
        )
        
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        
        self.cls_out = nn.Linear(hidden_dim, 1)
        self.pkd_out = nn.Linear(hidden_dim, 1)
        self.aug_out = nn.Linear(hidden_dim, 1)
    
    
    def forward(self, mol_feature, prot_feature):
        mol_feature = self.molecule_align(mol_feature)
        prot_feature = self.protein_align(prot_feature)

        x = torch.cat([mol_feature, prot_feature], dim=1)

        x = F.dropout(F.gelu(self.fc1(x)), 0.1)
        x = F.dropout(F.gelu(self.fc2(x)), 0.1)
        x = F.dropout(F.gelu(self.fc3(x)), 0.1)
        
        cls_out = self.cls_out(x).squeeze(-1)
        pkd_out = self.pkd_out(x).squeeze(-1)
        aug_out = self.aug_out(x).squeeze(-1)
        
        return cls_out, pkd_out, aug_out
    
model = DTI(hidden_dim=1024, mol_dim=128, prot_dim=1024)

In [4]:
class DTI_prediction(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.lambda_cls = 0.8
        self.lambda_pkd = 0.1
        self.lambda_aug = 0.1
    
    
    def step(self, batch):
        mol_feature, prot_feature, y_cls, y_pkd, y_aug = batch
        pred_cls, pred_pkd, pred_aug = self.model(mol_feature, prot_feature)
        
        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, y_cls)
        
        valid_pkd = y_pkd.gt(1e-2)
        loss_pkd = F.smooth_l1_loss(torch.masked_select(pred_pkd, valid_pkd), torch.masked_select(y_pkd, valid_pkd))

        valid_aug = y_aug.gt(1e-2)
        loss_aug = F.smooth_l1_loss(torch.masked_select(pred_aug, valid_aug), torch.masked_select(y_aug, valid_aug))
        
        total_loss = self.lambda_cls * loss_cls + self.lambda_pkd * loss_pkd + self.lambda_aug * loss_aug
    
        auroc = binary_auroc(pred_cls, y_cls)
        auprc = average_precision(pred_cls, y_cls)
        
        return pred_cls, pred_pkd, pred_aug, total_loss, auroc, auprc
    
    
    def training_step(self, batch, batch_idx):
        _, _, _, total_loss, auroc, auprc = self.step(batch)
        
        self.log('train_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True)
        
        return total_loss
    
    
    def validation_step(self, batch, batch_idx):
        _, _, _, total_loss, auroc, auprc = self.step(batch)
        
        self.log('valid_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def test_step(self, batch, batch_idx):
        _, _, _, total_loss, auroc, auprc = self.step(batch)
        
        self.log('test_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def predict_step(self, batch, batch_idx):
        pred_cls, pred_pkd, pred_aug, _, _ = self.step(batch)
        
        return pred_cls, pred_pkd, pred_aug

    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=5e-5)
    
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=1, dirpath='weights/cls_concat', filename='DTI-{epoch:03d}-{valid_loss:.4f}-{valid_auroc:.4f}-{valid_auprc:.4f}'),
]

predictor = DTI_prediction(model)
trainer = pl.Trainer(max_epochs=300, gpus=[0], enable_progress_bar=True, callbacks=callbacks)

  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
trainer.fit(predictor, train_dataloader, valid_dataloader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | DTI  | 20.1 M
-------------------------------
20.1 M    Trainable params
0         Non-trainable params
20.1 M    Total params
80.266    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [8]:
test_data_davis = test_data[test_data.Source_ID==0].reset_index(drop=True)
test_dataset_davis = DTIDataset(test_data_davis, mols_embedding, prots_embedding)
test_dataloader_davis = DataLoader(test_dataset_davis, batch_size=512, num_workers=16, 
                             pin_memory=True, prefetch_factor=10)

trainer.test(predictor, test_dataloader_davis)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc                   nan
       test_auroc           0.8010944724082947
        test_loss                   nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_auroc': 0.8010944724082947, 'test_auprc': nan, 'test_loss': nan}]

In [9]:
test_data_binding = test_data[test_data.Source_ID==1].reset_index(drop=True)
test_dataset_binding = DTIDataset(test_data_binding, mols_embedding, prots_embedding)
test_dataloader_binding = DataLoader(test_dataset_binding, batch_size=512, num_workers=16, 
                             pin_memory=True, prefetch_factor=10)

trainer.test(predictor, test_dataloader_binding)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc           0.6288596391677856
       test_auroc           0.8236401081085205
        test_loss                   nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_auroc': 0.8236401081085205,
  'test_auprc': 0.6288596391677856,
  'test_loss': nan}]

In [10]:
test_data_kiba = test_data[test_data.Source_ID==2].reset_index(drop=True)
test_dataset_kiba = DTIDataset(test_data_kiba, mols_embedding, prots_embedding)
test_dataloader_kiba = DataLoader(test_dataset_kiba, batch_size=512, num_workers=16, 
                             pin_memory=True, prefetch_factor=10)

trainer.test(predictor, test_dataloader_kiba)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc           0.5820268392562866
       test_auroc           0.8010764718055725
        test_loss                   nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_auroc': 0.8010764718055725,
  'test_auprc': 0.5820268392562866,
  'test_loss': nan}]

In [11]:
test_data_biosnap = test_data[test_data.Source_ID==2].reset_index(drop=True)
test_dataset_biosnap = DTIDataset(test_data_biosnap, mols_embedding, prots_embedding)
test_dataloader_biosnap = DataLoader(test_dataset_biosnap, batch_size=512, num_workers=16, 
                             pin_memory=True, prefetch_factor=10)

trainer.test(predictor, test_dataloader_biosnap)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc           0.5816890597343445
       test_auroc           0.8016542196273804
        test_loss                   nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_auroc': 0.8016542196273804,
  'test_auprc': 0.5816890597343445,
  'test_loss': nan}]