In [1]:
import os 
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def load_dataset(path):
    """Load the train/valid/test splits stored as ``.npy`` files under ``path``.

    For every split name in ("train", "valid", "test") and every field in
    ("molecule", "protein", "y"), the file ``<path>/<split>_<field>.npy`` is
    read and converted to a float32 ``torch.Tensor``. Each file path is
    printed before it is loaded.

    Args:
        path: Directory that contains the nine ``.npy`` files.

    Returns:
        A tuple ``(train_data, valid_data, test_data)``. Each element is a
        dict with the keys "molecule", "protein" and "y" mapping to float
        tensors.
    """
    def to_float_tensor(data):
        """Convert one loaded array into a float32 tensor."""
        try:
            # Elements exposing ``.squeeze(0).numpy()`` (e.g. tensors with a
            # leading singleton dimension) are stacked into ONE ndarray first.
            # Handing torch a Python list of ndarrays is the slow path that
            # emits a UserWarning.
            stacked = np.stack([d.squeeze(0).numpy() for d in data])
            tensor = torch.from_numpy(stacked)
        except (AttributeError, TypeError, ValueError, RuntimeError):
            # Elements are not squeezable tensors (for instance a plain
            # numeric array): convert the array as a whole instead.
            # Only conversion errors are caught, so KeyboardInterrupt and
            # unrelated failures still propagate.
            tensor = torch.tensor(data)
        return tensor.float()

    def get_split_dataset(mode):
        """Load the three fields of one split into a dict of float tensors."""
        dataset = {}
        for f in ["molecule", "protein", "y"]:
            f_path = os.path.join(path, mode + "_" + f + ".npy")
            print(f_path)
            # SECURITY NOTE: allow_pickle=True executes arbitrary pickled code
            # on load; only use this with files from a trusted source.
            data = np.load(f_path, allow_pickle=True)
            dataset[f] = to_float_tensor(data)

        return dataset

    train_data = get_split_dataset("train")
    valid_data = get_split_dataset("valid")
    test_data = get_split_dataset("test")

    return train_data, valid_data, test_data
    
# Load all three splits from data/interaction/kiba (path is relative to the
# notebook's working directory).
train_data, valid_data, test_data = load_dataset("data/interaction/kiba")

data/interaction/kiba/train_molecule.npy


  data = torch.tensor([d.squeeze(0).numpy() for d in data])


data/interaction/kiba/train_protein.npy
data/interaction/kiba/train_y.npy
data/interaction/kiba/valid_molecule.npy
data/interaction/kiba/valid_protein.npy
data/interaction/kiba/valid_y.npy
data/interaction/kiba/test_molecule.npy
data/interaction/kiba/test_protein.npy
data/interaction/kiba/test_y.npy


In [6]:
# Wrap each split as (molecule, protein, y) triples.
train_dataset = TensorDataset(train_data['molecule'], train_data['protein'], train_data['y'])
valid_dataset = TensorDataset(valid_data['molecule'], valid_data['protein'], valid_data['y'])
test_dataset = TensorDataset(test_data['molecule'], test_data['protein'], test_data['y'])

batch_size = 64
num_workers = 16
prefetch_factor = 10

# Settings shared by all three loaders; only shuffling and drop_last differ.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=prefetch_factor,
)

# Training: shuffled, incomplete final batch dropped.
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
# Evaluation: fixed order, every sample kept.
valid_dataloader = DataLoader(valid_dataset, shuffle=False, drop_last=False, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, drop_last=False, **loader_kwargs)

In [17]:
class ConcatenateDTI(nn.Module):
    """MLP regressor over concatenated molecule and protein feature vectors.

    The two inputs are optionally projected to ``inner_dim`` each, concatenated
    along the last dimension, and passed through three GELU + dropout layers
    followed by a linear output layer with a single output feature.

    Args:
        molecule_dim: Size of the last dimension of the molecule input.
        protein_dim: Size of the last dimension of the protein input.
        inner_dim: Width of the first hidden layer; the following hidden
            layers use ``inner_dim // 2`` and ``inner_dim // 4``.
        projection: If True, each input is first mapped to ``inner_dim`` by its
            own linear layer before concatenation.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, molecule_dim=128, protein_dim=1024, inner_dim=1024, projection=True,
                 dropout=0.1):
        super().__init__()
        self.is_projection = projection
        self.dropout_p = dropout

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed is unchanged.
        if self.is_projection:
            self.mol_proj = nn.Linear(molecule_dim, inner_dim)
            self.prot_proj = nn.Linear(protein_dim, inner_dim)
            self.fc_1 = nn.Linear(inner_dim * 2, inner_dim)
        else:
            self.fc_1 = nn.Linear(molecule_dim + protein_dim, inner_dim)

        half_dim = inner_dim // 2
        quarter_dim = inner_dim // 4
        self.fc_2 = nn.Linear(inner_dim, half_dim)
        self.fc_3 = nn.Linear(half_dim, quarter_dim)
        self.fc_out = nn.Linear(quarter_dim, 1)

    def forward(self, molecule, protein):
        """Return the prediction for each (molecule, protein) pair.

        Args:
            molecule: Tensor whose last dimension is ``molecule_dim``.
            protein: Tensor whose last dimension is ``protein_dim``.

        Returns:
            Tensor with a last dimension of size 1.
        """
        if self.is_projection:
            molecule = self.mol_proj(molecule)
            protein = self.prot_proj(protein)

        x = torch.cat((molecule, protein), -1)
        for layer in (self.fc_1, self.fc_2, self.fc_3):
            # F.dropout defaults to training=True; passing self.training makes
            # dropout a no-op after model.eval(), so validation, test and
            # predict outputs are deterministic.
            x = F.dropout(F.gelu(layer(x)), p=self.dropout_p, training=self.training)

        return self.fc_out(x)
        
        
# Build the model with the per-input projection layers enabled; the bare name
# on the last line displays the module structure as the cell output.
concatenate_dti = ConcatenateDTI(projection=True)
concatenate_dti

ConcatenateDTI(
  (mol_proj): Linear(in_features=128, out_features=1024, bias=True)
  (prot_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (fc_1): Linear(in_features=2048, out_features=1024, bias=True)
  (fc_2): Linear(in_features=1024, out_features=512, bias=True)
  (fc_3): Linear(in_features=512, out_features=256, bias=True)
  (fc_out): Linear(in_features=256, out_features=1, bias=True)
)

In [19]:
from torchmetrics.functional import mean_squared_error, mean_absolute_error

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.trainer.supporters import CombinedLoader


class DTI_prediction(pl.LightningModule):
    """Lightning wrapper that trains ``model`` as a regressor with MSE loss.

    Batches are indexable sequences: ``batch[0]`` is the molecule input,
    ``batch[1]`` the protein input and ``batch[2]`` the regression target.
    ``predict_step`` only needs the first two entries.

    Args:
        model: Module called as ``model(molecule, protein)``; its output is
            squeezed on the last dimension before being compared to the target.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, molecule, protein):
        """Delegate to the wrapped model."""
        return self.model(molecule, protein)

    def _predict(self, batch):
        """Run the model on a batch and drop the trailing output dimension."""
        molecule, protein = batch[0], batch[1]
        return self(molecule, protein).squeeze(-1)

    def _shared_step(self, batch, stage):
        """Compute the MSE loss and log ``<stage>_loss/_rmse/_mae`` per epoch.

        NOTE: with ``on_epoch=True`` the logged value is an aggregate of the
        per-batch values, so ``<stage>_rmse`` is a mean of batch RMSEs rather
        than the RMSE over the whole epoch.
        """
        y = batch[2]
        y_hat = self._predict(batch)
        loss = F.mse_loss(y_hat, y)

        log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_loss", loss, **log_kwargs)
        self.log(f"{stage}_rmse", mean_squared_error(y_hat, y, squared=False), **log_kwargs)
        self.log(f"{stage}_mae", mean_absolute_error(y_hat, y), **log_kwargs)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged under ``train_*``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log the ``valid_*`` metrics for one batch; returns nothing."""
        self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        """Log the ``test_*`` metrics for one batch; returns nothing."""
        self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return predictions for one batch.

        The target entry is not read, so batches without labels are accepted.
        """
        return self._predict(batch)

    def configure_optimizers(self):
        """AdamW (lr=1e-4) with cosine annealing warm restarts (T_0=10)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
# Checkpointing is driven by the epoch-level 'valid_loss'; up to 20 checkpoints
# are retained (save_top_k=20).
checkpoint_callback = ModelCheckpoint(
    monitor='valid_loss',
    save_top_k=20,
    dirpath='weights/DTI_prediction_CLS_token_concatenate_with_projection',
    filename='attentional_dti-{epoch:03d}-{valid_loss:.4f}-{valid_rmse:.4f}-{valid_mae:.4f}',
)
callbacks = [checkpoint_callback]

model = DTI_prediction(model=concatenate_dti)

# precision=16 is intentionally not set: per the original author's note,
# ProtBert was trained in full precision.
trainer = pl.Trainer(
    max_epochs=500,
    gpus=1,
    enable_progress_bar=True,
    callbacks=callbacks,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Train on train_dataloader and validate on valid_dataloader; checkpoints are
# written by the ModelCheckpoint callback passed to the Trainer.
trainer.fit(model, train_dataloader, valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type           | Params
-----------------------------------------
0 | model | ConcatenateDTI | 3.9 M 
-----------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.745    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4504fbff70>
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
Exception ignored in:     if w.is_alive():<function _MultiProcessingDataLoaderIter.__del__ at 0x7f4504fbff70>
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    
assert self._parent_pid == os.getpid(), 'can only test a child process': Traceback (most recent call last):

  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
AssertionErrorcan only test a child process    
self._shutdown_workers()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/p

  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'assert self._parent_pid == os.getpid(), 'can only test a child process'

AssertionErrorAssertionError: : can only test a child processcan only test a child process

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4504fbff70>
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can onl

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [6]:
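# Evaluation on the test split (disabled). To use it, set `checkpoint_file` to
# a checkpoint file name.
# NOTE: `load_from_checkpoint` is a classmethod that returns a new model, so
# its result must be assigned; calling it on `model` and discarding the return
# value leaves `model` unchanged.
# NOTE(review): the directory below lacks the "_with_projection" suffix used
# by the ModelCheckpoint `dirpath` in this notebook; confirm which directory
# holds the weights.
# checkpoint_file = ""
# model = DTI_prediction.load_from_checkpoint(model=concatenate_dti, checkpoint_path="weights/DTI_prediction_CLS_token_concatenate/" + checkpoint_file)

# trainer.test(model, test_dataloader)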