In [2]:
import os
import pickle
from tqdm import tqdm
from torchtext import data, datasets
from sklearn.model_selection import train_test_split

BATCH_SIZE = 10

# with open("data/chem_total.pickle", 'rb') as f:
#     data = pickle.load(f)

with open("data/molecule_small.pickle", 'rb') as f:
    data = pickle.load(f)

len(data)

10000

In [4]:
sample = data[0]
sample

'CC1=CC=C(C=C1)N2C(=CC(=C2C)C(=O)CN3CCN(CC3)CC(=O)NC4=C(C=CC=C4C)C)C'

In [5]:
import torch
import torchtext

SRC = torchtext.legacy.data.Field(tokenize=None,
                                init_token='<CLS>',
                                eos_token='<SEP>',
                                pad_token='<PAD>',
                                unk_token='<MASK>',
                                lower=False,
                                batch_first=False,
                                include_lengths=False)

SRC.build_vocab(data, min_freq=1)

In [6]:
SRC.preprocess(sample)

['CC1=CC=C(C=C1)N2C(=CC(=C2C)C(=O)CN3CCN(CC3)CC(=O)NC4=C(C=CC=C4C)C)C']

In [72]:
class MolecularLangaugeModelDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15):
        super(MolecularLangaugeModelDataset, self).__init__()

        self.data          = data        
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.seq_len       = seq_len
        self.masking_rate  = masking_rate
        
        self.cls_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.init_token]
        self.sep_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.eos_token]
        self.pad_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.pad_token]
        self.mask_token_id = self.tokenizer.vocab.stoi[self.tokenizer.unk_token]
        
    def __getitem__(self, idx):
        target = self.tokenizer.numericalize(self.data[idx]).squeeze()
        
        if len(target) < self.seq_len - 2:
            pad_length = self.seq_len - len(target) - 2
        else:
            target = target[:self.seq_len-2]
            pad_length = 0
               
        masked_sent, masking_label = self.masking(target)
        
        # MLM
        train = torch.cat([
            torch.tensor([self.cls_token_id]), 
            masked_sent,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ]).long().contiguous()
        
        target = torch.cat([
            torch.tensor([self.cls_token_id]), 
            target,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ]).long().contiguous()
        
        masking_label = torch.cat([
            torch.zeros(1), 
            masking_label,
            torch.zeros(1),
            torch.zeros(pad_length)
        ])
                
        segment_embedding = torch.zeros(target.size(0))
        
        return train, target, segment_embedding, masking_label
        
    
    def __len__(self):
        return len(self.data)
    
    
    def __iter__(self):
        for x in self.data:
            yield x
            
    
    def get_vocab(self):
        return self.vocab

    
    # TODO mask 안에서 random 으로 바꿔주는 것 추가
    def masking(self, x):
        x             = torch.tensor(x).long().contiguous()
        masking_idx   = torch.randperm(x.size()[0])[:round(x.size()[0] * self.masking_rate) + 1]       
        masking_label = torch.zeros(x.size()[0])
        masking_label[masking_idx] = 1
        x             = x.masked_fill(masking_label.bool(), self.mask_token_id)
        
        return x, masking_label

In [74]:
dataset = MolecularLangaugeModelDataset(data, SRC, seq_len=128, masking_rate=0.15)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

for train, target, sengment_embedding, masking_label in data_loader:
    print(train)
    print(target)
#     print(sengment_embedding)
#     print(masking_label)
    break

tensor([[ 2,  4,  8, 10,  4,  6,  5,  8,  4,  5,  8, 10,  7,  4, 11,  5,  0,  0,
          6,  5,  8,  0,  5,  4, 11,  0,  8,  0,  4, 12,  5,  4,  8,  6,  8,  0,
          4, 12,  7,  4,  4, 13,  0,  4,  4,  0,  4,  4,  5,  0, 13,  3,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1]])
tensor([[ 2,  4,  8, 10,  4,  6,  5,  8,  4,  5,  8, 10,  7,  4, 11,  5,  8,  4,
          6,  5,  8,  4,  5,  4, 11,  7,  8,  4,  4, 12,  5,  4,  8,  6,  8,  5,
          4, 12,  7,  4,  4, 13,  5,  4,  4,  5,  4,  4,  5,  4, 13,  3,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         

  x             = torch.tensor(x).long().contiguous()
