In [1]:
import torch
from transformers import BertTokenizerFast

# Directory holding the saved tokenizer files, loaded with from_pretrained.
fpath = 'data/tokenizer_model'
# Keep case and accents. BertTokenizerFast's lowercasing switch is `do_lower_case`
# (default True); a `lowercase=` kwarg is not one of its parameters, so it would
# leave lowercasing enabled.
tokenizer = BertTokenizerFast.from_pretrained(fpath,
                                              strip_accents=False,
                                              do_lower_case=False)

In [2]:
class BERTLangaugeModelDataset(torch.utils.data.Dataset):
    """Map-style dataset that builds BERT pre-training samples (MLM + NSP).

    Each item pairs sentence ``data[idx]`` with either the following entry
    (``is_next == 1``) or a randomly drawn one (``is_next == 0``), packs them as
    ``[CLS] sent_1 [SEP] sent_2 [SEP] [PAD] ...`` to exactly ``max_seq_len``
    token ids, and replaces a random subset of each sentence's tokens with
    ``[MASK]`` to form the MLM input.

    Args:
        data: Indexable sequence of raw sentences; adjacent entries are treated
            as consecutive sentences for NSP.
        tokenizer: Tokenizer exposing ``encode``, ``batch_decode``, ``vocab`` and
            the ``cls/sep/pad/mask_token_id`` attributes (e.g. BertTokenizerFast).
        max_seq_len: Length every sample is truncated / padded to.
        masking_ratio: Fraction of each sentence's tokens replaced by [MASK]
            (at least one token for any non-empty sentence).
        NSP_ratio: Probability of using a random sentence as the second segment
            (the last sentence, having no successor, is always paired randomly).
    """

    def __init__(self, data, tokenizer, max_seq_len=256, masking_ratio=0.15, NSP_ratio=0.5):
        super().__init__()

        self.data          = data
        self.tokenizer     = tokenizer
        self.vocab         = tokenizer.vocab
        self.max_seq_len   = max_seq_len
        self.masking_ratio = masking_ratio
        self.NSP_ratio     = NSP_ratio

        self.cls_token_id  = self.tokenizer.cls_token_id
        self.sep_token_id  = self.tokenizer.sep_token_id
        self.pad_token_id  = self.tokenizer.pad_token_id
        self.mask_token_id = self.tokenizer.mask_token_id

    def __getitem__(self, sent_1_idx):
        """Build one ``(train, target, sengment_embedding, is_next)`` sample.

        Returns:
            train: LongTensor of length ``max_seq_len`` -- ``target`` with MLM
                masking applied to the tokens of both sentences.
            target: LongTensor of length ``max_seq_len`` -- the unmasked sequence.
            sengment_embedding: LongTensor of length ``max_seq_len`` -- segment
                ids: 1 for ``sent_2`` and its trailing [SEP], 0 elsewhere
                (including padding).
            is_next: 0-dim tensor, 1 if the second sentence is ``data[idx + 1]``.
        """
        sent_1 = self._encode(sent_1_idx)

        # NSP: a true "next" pair only exists when sent_1 is not the last entry.
        next_idx = sent_1_idx + 1
        if next_idx < len(self.data) and torch.rand(1).item() >= self.NSP_ratio:
            sent_2_idx = next_idx
            is_next = torch.tensor(1)
        else:
            # Only the true successor is excluded, so sent_1 itself may be drawn.
            sent_2_idx = next_idx
            while sent_2_idx == next_idx:
                sent_2_idx = int(torch.randint(0, len(self.data), (1,)).item())
            is_next = torch.tensor(0)

        sent_2 = self._encode(sent_2_idx)

        # Room left for sentence tokens after [CLS], [SEP], [SEP].
        budget = max(self.max_seq_len - 3, 0)
        if len(sent_1) + len(sent_2) >= budget:
            if len(sent_1) >= budget:
                # Give sent_1 about half the sequence, never more than the budget.
                sent_1 = sent_1[:min(self.max_seq_len // 2, budget)]
            sent_2 = sent_2[:max(budget - len(sent_1), 0)]

        pad_length = budget - len(sent_1) - len(sent_2)
        padding = [self.pad_token_id] * pad_length
        target = torch.tensor(
            [self.cls_token_id] + sent_1 + [self.sep_token_id]
            + sent_2 + [self.sep_token_id] + padding,
            dtype=torch.long,
        )

        # Segment ids must be integer indices; padding stays in segment 0.
        sengment_embedding = torch.zeros(target.size(0), dtype=torch.long)
        seg_2_start = len(sent_1) + 2  # after [CLS] + sent_1 + [SEP]
        sengment_embedding[seg_2_start:seg_2_start + len(sent_2) + 1] = 1

        # MLM input: mask each sentence independently, keep special tokens intact.
        train = torch.cat([
            torch.tensor([self.cls_token_id], dtype=torch.long),
            self.masking(sent_1),
            torch.tensor([self.sep_token_id], dtype=torch.long),
            self.masking(sent_2),
            torch.tensor([self.sep_token_id], dtype=torch.long),
            torch.tensor(padding, dtype=torch.long),
        ])

        return train, target, sengment_embedding, is_next

    def _encode(self, idx):
        """Token ids of ``data[idx]`` without the [CLS]/[SEP] special tokens."""
        return list(self.tokenizer.encode(self.data[idx], add_special_tokens=False))

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        """Yield the raw entries of ``data`` (not encoded samples)."""
        for x in self.data:
            yield x

    def get_vocab(self):
        return self.vocab

    def decode(self, x):
        """Decode a batch of id sequences via ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(x)

    # TODO: within the masked positions, also replace some tokens with random
    # tokens instead of always using [MASK].
    def masking(self, x):
        """Return ``x`` as a LongTensor with a random subset set to [MASK].

        ``round(len(x) * masking_ratio)`` positions are masked, with a minimum
        of one for any non-empty ``x``; an empty ``x`` is returned unchanged.
        """
        x = torch.tensor(x, dtype=torch.long)
        n = x.size(0)
        if n == 0:
            return x
        num_masked = max(1, round(n * self.masking_ratio))
        masked_idx = torch.randperm(n)[:num_masked]
        x[masked_idx] = self.mask_token_id
        return x

In [3]:
# Read the raw corpus with an explicit encoding so the result does not depend on
# the platform's locale default. Adjacent lines become adjacent NSP sentences.
with open("data/petitions.txt", 'r', encoding="utf-8") as f:
    data = f.readlines()

# Drop the newline kept by readlines(); blank lines are kept as empty strings.
proced_data = [line.replace("\n", "") for line in data]

In [6]:
from tqdm.notebook import tqdm

# Seed so shuffling, NSP negative sampling and masking are reproducible.
torch.manual_seed(0)

dataset = BERTLangaugeModelDataset(data=proced_data, tokenizer=tokenizer)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True)

# Sanity check: pull a single batch and confirm the per-token tensors line up.
mlm_train, mlm_target, sengment_embedding, is_next = next(iter(data_loader))
assert mlm_train.shape == mlm_target.shape == sengment_embedding.shape
mlm_train.shape

  0%|          | 0/36320 [00:00<?, ?it/s]

torch.Size([5, 256])
