In [1]:
import os 
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def load_dataset(path):
    def get_split_dataset(mode):
        dataset = {}
        for f in ["molecule", "protein", "y"]:
            f_path = os.path.join(path, mode + "_" + f + ".npy")
            print(f_path)
            data = np.load(f_path, allow_pickle=True)
            try:
                data = torch.tensor([d.squeeze(0).numpy() for d in data])
            except:
                data = torch.tensor(data)
            dataset[f] = data
            
        return dataset
            
    train_data = get_split_dataset("train")
    valid_data = get_split_dataset("valid")
    test_data = get_split_dataset("test")
    
    return train_data, valid_data, test_data
    
train_data, valid_data, test_data = load_dataset("data/interaction/kiba")

data/interaction/kiba/train_molecule.npy


  data = torch.tensor([d.squeeze(0).numpy() for d in data])


data/interaction/kiba/train_protein.npy
data/interaction/kiba/train_y.npy
data/interaction/kiba/valid_molecule.npy
data/interaction/kiba/valid_protein.npy
data/interaction/kiba/valid_y.npy
data/interaction/kiba/test_molecule.npy
data/interaction/kiba/test_protein.npy
data/interaction/kiba/test_y.npy


In [2]:
train_dataset = TensorDataset(train_data['molecule'], train_data['protein'], train_data['y'])
valid_dataset = TensorDataset(valid_data['molecule'], valid_data['protein'], valid_data['y'])
test_dataset = TensorDataset(test_data['molecule'], test_data['protein'], test_data['y'])
    
train_dataloader = DataLoader(train_dataset, batch_size=512, num_workers=16, 
                              shuffle=True, pin_memory=True, prefetch_factor=10, 
                              drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=512, num_workers=16, 
                              shuffle=False, pin_memory=True, prefetch_factor=10, 
                              drop_last=False)
test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=16, 
                             shuffle=False, pin_memory=True, prefetch_factor=10, 
                             drop_last=False)

In [3]:
class ConcatenateDTI(nn.Module):
    def __init__(self, molecule_dim=128, protein_dim=1024, inner_dim=512, projection=True):
        super().__init__()
        self.is_projection = projection

        if self.is_projection:
            self.mol_proj = nn.Linear(molecule_dim, inner_dim)        
            self.prot_proj = nn.Linear(protein_dim, inner_dim)            
            self.fc_1 = nn.Linear(inner_dim * 2, inner_dim)
        else:
            self.fc_1 = nn.Linear(molecule_dim + protein_dim, inner_dim)
        
        self.fc_2 = nn.Linear(inner_dim, int(inner_dim / 2))
        self.fc_out = nn.Linear(int(inner_dim / 2), 1)
   

    def forward(self, molecule, protein):
        if self.is_projection:
            molecule = self.mol_proj(molecule)
            protein = self.prot_proj(protein)
            
        x = torch.cat((molecule, protein), -1)
        x = F.dropout(F.gelu(self.fc_1(x)), 0.1)
        x = F.dropout(F.gelu(self.fc_2(x)), 0.1)
        x = self.fc_out(x)
        
        return x
        
        
concatenate_dti = ConcatenateDTI()
concatenate_dti

ConcatenateDTI(
  (mol_proj): Linear(in_features=128, out_features=512, bias=True)
  (prot_proj): Linear(in_features=1024, out_features=512, bias=True)
  (fc_1): Linear(in_features=1024, out_features=512, bias=True)
  (fc_2): Linear(in_features=512, out_features=256, bias=True)
  (fc_out): Linear(in_features=256, out_features=1, bias=True)
)

In [4]:
for batch in train_dataloader:
    y_hat = concatenate_dti(batch[0], batch[1])
    print(y_hat)
    break

tensor([[0.0136],
        [0.0205],
        [0.0256],
        [0.0158],
        [0.0172],
        [0.0151],
        [0.0186],
        [0.0225],
        [0.0138],
        [0.0127],
        [0.0152],
        [0.0235],
        [0.0188],
        [0.0077],
        [0.0251],
        [0.0230],
        [0.0149],
        [0.0187],
        [0.0127],
        [0.0133],
        [0.0125],
        [0.0238],
        [0.0221],
        [0.0179],
        [0.0170],
        [0.0096],
        [0.0205],
        [0.0196],
        [0.0221],
        [0.0146],
        [0.0198],
        [0.0214],
        [0.0227],
        [0.0151],
        [0.0235],
        [0.0135],
        [0.0177],
        [0.0206],
        [0.0177],
        [0.0152],
        [0.0139],
        [0.0239],
        [0.0178],
        [0.0151],
        [0.0162],
        [0.0163],
        [0.0112],
        [0.0142],
        [0.0130],
        [0.0222],
        [0.0216],
        [0.0251],
        [0.0133],
        [0.0142],
        [0.0149],
        [0