In [4]:
import os
import pickle
import pandas as pd
from tqdm import tqdm
from torchtext.legacy import data, datasets
from sklearn.model_selection import train_test_split

BATCH_SIZE = 10

with open("data/DTI_train.pickle", "rb") as f:
    df = pickle.load(f)
    
print(df)

           Drug_ID                                               Drug  \
0         11314340  CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...   
1         11314340  CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...   
2         11314340  CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...   
3         11314340  CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...   
4         11314340  CC1=C2C=C(C3=CN=CC(OCC(N)CC4=CC=CC=C4)=C3)C=CC...   
...            ...                                                ...   
138547  53358942.0  COC1=CC(C(=O)O)=CC=C1NC(=O)[C@@H]1N[C@@H](CC(C...   
138548  53476877.0  CC(C)(C)C[C@@H]1N[C@@H](C(=O)N[C@H]2CC[C@H](O)...   
138549  58573469.0  CC(C)[C@@H](CS(=O)(=O)C(C)C)N1C(=O)[C@@](C)(CC...   
138550    113557.0                        CCCCCCCOC1OC(CO)C(O)C(O)C1O   
138551    113557.0                        CCCCCCCOC1OC(CO)C(O)C(O)C1O   

       Target_ID                                             Target         Y  
0          ABL1p  PFWKILNPLLERGTYYYFMGQQPGK

In [7]:
SRC = data.Field(tokenize=None,
                 init_token='<CLS>',
                 eos_token='<SEP>',
                 pad_token='<PAD>',
                 unk_token='<MASK>',
                 lower=False,
                 batch_first=False,
                 include_lengths=False)

SRC.build_vocab(df.Target.values, min_freq=1)

In [11]:
with open("./data/DTI/protein_vocab.pickle", "wb") as f:
    pickle.dump(SRC, f)

In [8]:
SRC.vocab.itos

['<MASK>',
 '<PAD>',
 '<CLS>',
 '<SEP>',
 'L',
 'S',
 'E',
 'G',
 'A',
 'V',
 'K',
 'P',
 'R',
 'D',
 'T',
 'I',
 'Q',
 'N',
 'F',
 'Y',
 'H',
 'M',
 'C',
 'W',
 'X']

In [9]:
class ProteinDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15):
        super(ProteinDataset, self).__init__()

        self.data          = data        
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.seq_len       = seq_len
        self.masking_rate  = masking_rate
        
        self.cls_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.init_token]
        self.sep_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.eos_token]
        self.pad_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.pad_token]
        self.mask_token_id = self.tokenizer.vocab.stoi[self.tokenizer.unk_token]
        
    def __getitem__(self, idx):
        target = self.tokenizer.numericalize(self.data[idx]).squeeze()
        
        if len(target) < self.seq_len - 2:
            pad_length = self.seq_len - len(target) - 2
        else:
            target = target[:self.seq_len-2]
            pad_length = 0
               
        masked_sent, masking_label = self.masking(target)
        
        # MLM
        train = torch.cat([
            torch.tensor([self.cls_token_id]), 
            masked_sent,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ]).long().contiguous()
        
        target = torch.cat([
            torch.tensor([self.cls_token_id]), 
            target,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ]).long().contiguous()
        
        masking_label = torch.cat([
            torch.zeros(1), 
            masking_label,
            torch.zeros(1),
            torch.zeros(pad_length)
        ])
                
        segment_embedding = torch.zeros(target.size(0))
        
        return train, target, segment_embedding, masking_label
        
    
    def __len__(self):
        return len(self.data)
    
    
    def __iter__(self):
        for x in self.data:
            yield x
            
    
    def get_vocab(self):
        return self.vocab

    
    def masking(self, x):
        x             = torch.tensor(x).long().contiguous()
        masking_idx   = torch.randperm(x.size()[0])[:round(x.size()[0] * self.masking_rate) + 1]       
        masking_label = torch.zeros(x.size()[0])
        masking_label[masking_idx] = 1
        x             = x.masked_fill(masking_label.bool(), self.mask_token_id)
        
        return x, masking_label

In [10]:
dataset = ProteinDataset(df.Target.values, SRC, seq_len=256, masking_rate=0.15)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)

for train, target, sengment_embedding, masking_label in data_loader:
    print(train)
    print(target)
#     print(sengment_embedding)
#     print(masking_label)
    break

tensor([[ 2, 11, 18,  ..., 11,  8,  3],
        [ 2, 21,  7,  ...,  0,  0,  3],
        [ 2,  0,  5,  ...,  7,  0,  3],
        ...,
        [ 2,  0,  8,  ..., 14, 18,  3],
        [ 2, 21, 12,  ..., 10, 18,  3],
        [ 2, 21,  5,  ...,  7, 16,  3]])
tensor([[ 2, 11, 18,  ..., 11,  8,  3],
        [ 2, 21,  7,  ..., 19, 13,  3],
        [ 2, 21,  5,  ...,  7,  9,  3],
        ...,
        [ 2, 21,  8,  ..., 14, 18,  3],
        [ 2, 21, 12,  ..., 10, 18,  3],
        [ 2, 21,  5,  ...,  7, 16,  3]])


  x             = torch.tensor(x).long().contiguous()
