In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

import torchtext

import pickle

import glob
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from tqdm import tqdm

# Load the pre-split DTI tables (train / validation / test) from their pickles.
# NOTE: pickle.load can execute arbitrary code; only point these at trusted files.
with open("./data/DTI/DTI_train.pickle", "rb") as train_file:
    train_data = pickle.load(train_file)

with open("./data/DTI/DTI_valid.pickle", "rb") as valid_file:
    valid_data = pickle.load(valid_file)

with open("./data/DTI/DTI_test.pickle", "rb") as test_file:
    test_data = pickle.load(test_file)

# Bare last expression so the notebook renders the training table.
train_data

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,11314340.0,CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...,ABL1p,PFWKILNPLLERGTYYYFMGQQPGKVLGDQRRPSLPALHFIKGAGK...,4.999996
1,11314340.0,CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...,ABL2,MVLGTVLLPPNSYGRDQDTSLCCLCTEASESALPDLTDHFASCVED...,4.999996
2,11314340.0,CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...,ACVR1B,MAESAGASSFFPLVVLLLAGSGGSGPRGVQALLCACTSCLQANYTC...,4.999996
3,11314340.0,CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...,ACVRL1,MTLGSPRKGLLMLLMALVTQGDPVKPSRGPLVTCTCESPHCKGPTC...,4.999996
4,11314340.0,CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...,ADCK3,MAAILGDTIMVAKGLVKLTQAAVETHLQHLGIGGELIMAARALQST...,4.999996
...,...,...,...,...,...
55922,53358942.0,COC1=CC(C(=O)O)=CC=C1NC(=O)C1NC(CC(C)(C)C)C(C#...,,MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...,9.602060
55923,53476877.0,CC(C)(C)CC1NC(C(=O)NC2CCC(O)CC2)C(C2=CC=CC(Cl)...,,MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...,8.552842
55924,58573469.0,CC(C)C(CS(=O)(=O)C(C)C)N1C(=O)C(C)(CC(=O)O)CC(...,,MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...,9.838632
55925,113557.0,CCCCCCCOC1OC(CO)C(O)C(O)C1O,P08191,MKRVITLFAVLLMGWSVNAWSFACKTANGTAIPIGGGSANVYVNLA...,7.767004


In [2]:
# Restore the two pre-built tokenizers: one for the drug strings, one for the
# protein sequences.
# NOTE: pickle.load can execute arbitrary code; only point these at trusted files.
with open("./data/molecule_net/MoleculeNet_tokenizer.pickle", "rb") as molecule_tokenizer_file:
    molecule_tokenizer = pickle.load(molecule_tokenizer_file)

with open("./data/DTI/protein_tokenizer.pickle", "rb") as protein_tokenizer_file:
    protein_tokenizer = pickle.load(protein_tokenizer_file)

In [4]:
# --- Molecule (drug) input settings ---
# Vocabulary size is taken from the tokenizer's index-to-string table.
molecule_vocab_dim = len(molecule_tokenizer.vocab.itos)
molecule_seq_len = 256
molecule_embedding_dim = 512

# --- Protein (target) input settings ---
protein_vocab_dim = len(protein_tokenizer.vocab.itos)
protein_seq_len = 1024
protein_embedding_dim = 512

# --- Runtime settings ---
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
batch_size = 128

In [5]:
class DTIDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding fixed-length (molecule, protein) token-id pairs.

    Each sample is read from one row of ``data`` (columns ``'Drug'``,
    ``'Target'`` and ``'Y'``). Both strings are numericalized with their
    tokenizer, truncated to ``seq_len - 2`` tokens, wrapped in the init/eos
    token ids and right-padded with the pad token id up to ``seq_len``.

    Args:
        data: pandas DataFrame holding the ``'Drug'``, ``'Target'`` and ``'Y'``
            columns. Rows are addressed by position, so any index is accepted.
        molecule_tokenizer: tokenizer for the ``'Drug'`` strings. Must expose
            ``vocab.stoi``, ``init_token``, ``eos_token``, ``pad_token``,
            ``unk_token`` and ``numericalize``.
        molecule_seq_len: total length of every returned molecule tensor.
        protein_tokenizer: tokenizer for the ``'Target'`` strings (``vocab``
            and ``numericalize`` are used).
        protein_seq_len: total length of every returned protein tensor.
        target_dtype: dtype the ``'Y'`` value is cast to. Defaults to
            ``torch.long`` (the historical behaviour), which truncates
            fractional values such as 4.999996 to 4; pass ``torch.float32``
            to keep them.
    """

    def __init__(self, data, molecule_tokenizer, molecule_seq_len, protein_tokenizer, protein_seq_len,
                 target_dtype=torch.long):
        super().__init__()

        self.data = data

        self.molecule_tokenizer = molecule_tokenizer
        self.molecule_vocab = molecule_tokenizer.vocab
        self.molecule_seq_len = molecule_seq_len

        self.protein_tokenizer = protein_tokenizer
        self.protein_vocab = protein_tokenizer.vocab
        self.protein_seq_len = protein_seq_len

        self.target_dtype = target_dtype

        # NOTE(review): the special-token ids are looked up in the molecule
        # vocabulary only and are reused for the protein sequence as well.
        # That is only right if both vocabularies map their special tokens to
        # the same ids -- confirm against the two tokenizers.
        self.cls_token_id  = self.molecule_vocab.stoi[self.molecule_tokenizer.init_token]
        self.sep_token_id  = self.molecule_vocab.stoi[self.molecule_tokenizer.eos_token]
        self.pad_token_id  = self.molecule_vocab.stoi[self.molecule_tokenizer.pad_token]
        # The "mask" id is the id of the tokenizer's unk token.
        self.mask_token_id = self.molecule_vocab.stoi[self.molecule_tokenizer.unk_token]

    @staticmethod
    def _numericalize(tokenizer, text, max_tokens):
        """Return the token ids of ``text`` as a 1-D tensor of at most ``max_tokens`` ids."""
        # reshape(-1) instead of squeeze(): squeeze() collapses a one-token
        # sequence to a 0-d tensor, on which len() raises.
        return tokenizer.numericalize(text).reshape(-1)[:max_tokens]

    def _wrap_and_pad(self, token_ids, seq_len):
        """Wrap already-truncated ids in CLS/SEP and right-pad to ``seq_len``."""
        pad_length = seq_len - 2 - len(token_ids)
        # torch.full with an explicit dtype keeps the padding integer-typed
        # even when pad_length is 0; torch.tensor([]) would be an empty
        # float32 tensor and could promote the concatenated ids to float.
        padding = torch.full((pad_length,), self.pad_token_id, dtype=torch.long)
        return torch.cat([torch.tensor([self.cls_token_id]),
                          token_ids,
                          torch.tensor([self.sep_token_id]),
                          padding])

    def __getitem__(self, idx):
        """Return ``([molecule_ids, protein_ids], target, segment_embedding)`` for row ``idx``."""
        # Positional lookup: samplers hand out positions 0..len-1, which only
        # coincide with index labels when the frame has a default RangeIndex.
        current_data = self.data.iloc[idx]

        molecule = self._numericalize(self.molecule_tokenizer, current_data['Drug'], self.molecule_seq_len - 2)
        protein = self._numericalize(self.protein_tokenizer, current_data['Target'], self.protein_seq_len - 2)

        train = [self._wrap_and_pad(molecule, self.molecule_seq_len),
                 self._wrap_and_pad(protein, self.protein_seq_len)]

        target = torch.tensor(current_data['Y']).to(self.target_dtype).contiguous()

        # Zeros covering CLS + molecule tokens + SEP (padding excluded).
        # NOTE(review): its length therefore varies per sample, so batches
        # larger than 1 presumably cannot be stacked by default_collate
        # unless all molecules have equal length -- confirm intended use.
        segment_embedding = torch.zeros(molecule.size(0) + 2)

        return train, target, segment_embedding

    def __len__(self):
        """Number of rows in ``data``."""
        return len(self.data)

    def __iter__(self):
        """Yield every sample in positional order."""
        for idx in range(len(self)):
            yield self[idx]

    def get_vocab(self, kind="molecule"):
        """Return the molecule vocabulary (default) or, with ``kind="protein"``, the protein one.

        Raises:
            ValueError: if ``kind`` is neither ``"molecule"`` nor ``"protein"``.
        """
        if kind == "molecule":
            return self.molecule_vocab
        if kind == "protein":
            return self.protein_vocab
        raise ValueError(f"kind must be 'molecule' or 'protein', got {kind!r}")

    
def collate_fn(batch):
    """Collate ``batch`` with PyTorch's default collation after dropping ``None`` samples.

    Args:
        batch: list of samples; entries that are ``None`` are skipped.

    Returns:
        Whatever ``default_collate`` produces for the remaining samples.
    """
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)

In [6]:
def generate_dataloader(data, molecule_tokenizer, molecule_seq_len, protein_tokenizer, protein_seq_len, batch_size, collate_fn, shuffle=True, num_workers=6):
    """Build a ``DataLoader`` over a ``DTIDataset`` created from ``data``.

    Args:
        data: table passed straight to ``DTIDataset``.
        molecule_tokenizer, molecule_seq_len: molecule tokenizer and target length.
        protein_tokenizer, protein_seq_len: protein tokenizer and target length.
        batch_size: number of samples per batch.
        collate_fn: callable used by the loader to merge samples into a batch.
        shuffle: forwarded to the ``DataLoader`` (default ``True``).
        num_workers: forwarded to the ``DataLoader`` (default 6).

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    dti_dataset = DTIDataset(
        data=data,
        molecule_tokenizer=molecule_tokenizer,
        molecule_seq_len=molecule_seq_len,
        protein_tokenizer=protein_tokenizer,
        protein_seq_len=protein_seq_len,
    )
    return torch.utils.data.DataLoader(
        dti_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

# Overrides the batch_size of 128 set in the configuration cell: from here on
# the loader serves one sample per batch.
batch_size = 1

data_loader = generate_dataloader(
    data=train_data,
    molecule_tokenizer=molecule_tokenizer,
    molecule_seq_len=molecule_seq_len,
    protein_tokenizer=protein_tokenizer,
    protein_seq_len=protein_seq_len,
    batch_size=batch_size,
    collate_fn=collate_fn,
    num_workers=1,
)
data_loader

<torch.utils.data.dataloader.DataLoader at 0x7ff8d4151fa0>

In [7]:
# Peek at the first batch only: print the token-id tensors and the target,
# then stop iterating.
# (`sengment_embedding` keeps its original spelling in case other cells use it.)
for first_batch in data_loader:
    train, target, sengment_embedding = first_batch
    print(train)
    print(target)
    break

[tensor([[ 2,  4,  9,  4, 10,  5,  4,  4,  5,  4,  6,  4,  9,  4, 11,  5,  4,  4,
          5,  4,  6,  4,  4, 12,  5,  4,  8,  5,  4,  6,  8,  7,  8,  5,  4, 12,
          8,  7,  4,  5,  4, 11,  9,  4,  7,  4,  5,  4, 10,  3,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1