In [15]:
import os
import pickle
from tqdm import tqdm
from torchtext import data, datasets
from sklearn.model_selection import train_test_split

BATCH_SIZE = 1

with open("../data/molecule_net/molecule_small.pickle", 'rb') as f:
    data = pickle.load(f)

len(data)

10000

In [16]:
sample = data[0]
sample

'CC1=CC=C(C=C1)N2C(=CC(=C2C)C(=O)CN3CCN(CC3)CC(=O)NC4=C(C=CC=C4C)C)C'

In [17]:
import torch
import torchtext

SRC = torchtext.legacy.data.Field(tokenize=None,
                                init_token='<CLS>',
                                eos_token='<SEP>',
                                pad_token='<PAD>',
                                unk_token='<MASK>',
                                lower=False,
                                batch_first=False,
                                include_lengths=False)

SRC.build_vocab(data, min_freq=1)

In [18]:
SRC.preprocess(sample)

['CC1=CC=C(C=C1)N2C(=CC(=C2C)C(=O)CN3CCN(CC3)CC(=O)NC4=C(C=CC=C4C)C)C']

In [19]:
SRC.vocab.itos

['<MASK>',
 '<PAD>',
 '<CLS>',
 '<SEP>',
 'C',
 '=',
 '(',
 ')',
 'N',
 'O',
 '1',
 '2',
 '3',
 '4',
 'S',
 'F',
 'l',
 '5',
 '[',
 ']',
 '-',
 '+',
 '#',
 'B',
 'r',
 '6',
 'I',
 '7',
 'H',
 '8']

In [48]:
class MolecularLangaugeModelDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15):
        super(MolecularLangaugeModelDataset, self).__init__()

        self.data          = data        
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.seq_len       = seq_len
        self.masking_rate  = masking_rate
        
        self.cls_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.init_token]
        self.sep_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.eos_token]
        self.pad_token_id  = self.tokenizer.vocab.stoi[self.tokenizer.pad_token]
        self.mask_token_id = self.tokenizer.vocab.stoi[self.tokenizer.unk_token]
        
    def __getitem__(self, idx):
        target = self.tokenizer.numericalize(self.data[idx]).squeeze()
        
        if len(target) < self.seq_len - 2:
            pad_length = self.seq_len - len(target) - 2
        else:
            target = target[:self.seq_len-2]
            pad_length = 0
               
        masked_sent, masking_label = self.masking(target)
        
        # MLM

        train = torch.cat([
            torch.tensor([self.cls_token_id]), 
            masked_sent,
            torch.tensor([self.sep_token_id])
        ]).long().contiguous()

        target = torch.cat([
            torch.tensor([self.cls_token_id]), 
            target,
            torch.tensor([self.sep_token_id])
        ]).long().contiguous()

        masking_label = torch.cat([
            torch.zeros(1), 
            masking_label,
            torch.zeros(1)
        ])

        segment_embedding = torch.zeros(target.size(0))
        
        return train, target, segment_embedding, masking_label
        
    
    def __len__(self):
        return len(self.data)
    
    
    def __iter__(self):
        for x in self.data:
            yield x
            
    
    def get_vocab(self):
        return self.vocab

    
    # TODO mask 안에서 random 으로 바꿔주는 것 추가
    def masking(self, x):
        x             = torch.tensor(x).long().contiguous()
        masking_idx   = torch.randperm(x.size()[0])[:round(x.size()[0] * self.masking_rate) + 1]       
        masking_label = torch.zeros(x.size()[0])
        masking_label[masking_idx] = 1
        x             = x.masked_fill(masking_label.bool(), self.mask_token_id)
        
        return x, masking_label
    
    
def collate_fn(batch):
    data = [item[0] for item in batch]
#     data = pack_sequence(data, enforce_sorted=False)
    
#     target = [item[1] for item in batch]
#     segment_embedding = [item[2] for item in batch]
#     masking_label = [item[3] for item in batch]  
    
    return [data, target, segment_embedding, masking_label]

In [53]:
from torch.nn.utils.rnn import pack_sequence

dataset = MolecularLangaugeModelDataset(data, SRC, seq_len=128, masking_rate=0.15)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn, shuffle=True)

for train, target, segment_embedding, masking_label in data_loader:
    print(train)
    print(target)
    print(sengment_embedding)
    print(masking_label)
    break

PackedSequence(data=tensor([ 2,  2,  4,  4,  9,  4,  4,  8, 10,  6,  5,  4,  4,  4,  4,  7,  0, 14,
         4,  6,  0,  0,  4,  9,  0,  7,  4,  0, 10,  5,  7,  0,  4,  0,  8,  0,
        11,  4,  4,  4,  4,  9,  8,  0,  0, 10,  4,  5,  4,  4, 11,  4,  7,  5,
         0,  4,  4,  6,  6,  4,  5,  5,  9,  4,  7, 10,  8,  7,  4,  4,  4, 16,
         0,  3,  5,  4,  4,  6,  5,  0,  6,  4,  6,  5,  4, 12,  0,  9,  4,  7,
         9,  4,  0,  9,  0,  3]), batch_sizes=tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), sorted_indices=tensor([0, 1]), unsorted_indices=tensor([0, 1]))
[tensor([ 2,  4,  9,  4, 10,  5,  4,  4,  5,  4,  6,  4,  5,  4, 10,  7,  4,  8,
        11,  4,  4,  8,  6,  4,  4, 11,  7,  4,  4,  6,  5,  9,  7,  8,  4,  4,
        12,  5,  4,  4,  6,  5,  4,  6,  4,  6,  5,  4, 12,  7,  9,  4,  7,  9,
         4,  7,  9,  4,  3

  x             = torch.tensor(x).long().contiguous()


NameError: name 'sengment_embedding' is not defined

In [22]:
def decode(x, tokenizer):
    results = []
    for line in x:
        decoded = ""
        for s in line:
            decoded += tokenizer.vocab.itos[s]
        results.append(decoded)
        
    return results

decode(target, SRC)

['<CLS>CC1=CC(=C(N1)C)C(=O)CSC2=NN=C(N2C)C3=CC=CC=C3Cl<SEP>',
 '<CLS>CCCCN(C1=C(N(C(=O)NC1=O)CCC)N)C(=O)C2=CN(N=N2)CC3=CC=CC=C3<SEP>',
 '<CLS>CC1=C(SC=N1)CCC(=O)NCC2CCN(C2)C3=CC(=O)N(N=C3)C<SEP>',
 '<CLS>COC(=O)C1=C(OC=C1)COC(=O)CCC2=CNC3=CC=CC=C32<SEP>',
 '<CLS>COC1=C(C=C(C=C1)F)C(=O)NC2=C(C(=C(C=C2)F)F)F<SEP>',
 '<CLS>CC(CC1=NC=CN=C1)NC2CCN(CC2)C3=CC=CC(=C3)C4=CSC=N4<SEP>',
 '<CLS>COC1=CC(=C(C=C1NCC(=O)NCCC2=CC=C(C=C2)F)OC)Cl<SEP>',
 '<CLS>C1CC(CN(C1)C(=O)C2=CC=CC3=CC=CC=C32)C(=O)N4CCNC(=O)C4<SEP>',
 '<CLS>CC1=CC(=O)C(=NN1C2=CC=CC=C2F)C(=O)N3CCNC(=O)C3<SEP>',
 '<CLS>COC1=CC=CC(=C1OC)C(=O)NC2CCCC(C2)C(F)(F)F<SEP>']