In [9]:
import json
import os
import random
import re

import datasets
import torch


In [2]:
# bookcorpus = datasets.load_dataset("bookcorpus")


In [7]:
# bookcorpus_2000 = bookcorpus['train']['text'][:2000]
# print()

In [11]:
# with open('bookcorpus_2000.json', 'w+') as file:
#     file.write(json.dumps(bookcorpus_2000))
# # end

In [11]:
# Load the cached 2,000-sentence BookCorpus sample. On a fresh checkout the cache
# is rebuilt from the Hugging Face dataset, so the notebook survives Restart & Run All.
PATH_BOOKCORPUS_2000 = 'bookcorpus_2000.json'

if os.path.exists(PATH_BOOKCORPUS_2000):
    with open(PATH_BOOKCORPUS_2000, 'r', encoding='utf-8') as file:
        bookcorpus_2000 = json.load(file)
    # end
else:
    bookcorpus = datasets.load_dataset("bookcorpus")
    bookcorpus_2000 = bookcorpus['train']['text'][:2000]
    with open(PATH_BOOKCORPUS_2000, 'w', encoding='utf-8') as file:
        json.dump(bookcorpus_2000, file)
    # end
# end

In [42]:
class Batch:
    """Bundle of named tensors moved to ``Batch.DEVICE``; calling the instance returns them as a dict.

    Keyword arguments whose value is ``None`` or a plain ``bool`` are dropped.
    Every other value must provide ``.to(device)`` (e.g. a ``torch.Tensor``).
    """
    # Fall back to CPU so the pipeline also runs on machines without a GPU;
    # a hard-coded 'cuda' makes every `.to(...)` call raise there.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, **kwargs):
        self.kwargs = {}
        for k, v in kwargs.items():
            # skip optional outputs that were not produced (None) and bool flags
            if v is not None and type(v) is not bool:
                self.kwargs[k] = v.to(Batch.DEVICE)
            # end
        # end
    # end

    def __call__(self):
        """Return the dict of device-resident tensors, keyed by the original kwarg names."""
        return self.kwargs
    # end
# end

In [43]:
class Collator_BERT:
    """Collate raw sentences into packed, fixed-length BERT encoder / seq2seq decoder batches.

    Consecutive sentences of one DataLoader batch are greedily packed into rows of at
    most ``size_seq_max`` tokens (including [CLS] and the closing [SEP]); sentences
    inside a row are separated by the [EOL] token.

    Args:
        tokenizer: tokenizer exposing ``all_special_tokens``, ``all_special_ids`` and
            ``batch_encode_plus``; [EOL] must already be registered as a special token.
        size_seq_max: length every returned row is padded to.
        need_masked: fraction of encoder tokens (excluding [CLS]/[SEP]) replaced by
            [MASK]; 0 disables masking.

    Raises:
        KeyError: if any of ``SPECIAL_TOKENS`` is missing from the tokenizer.
    """

    # Special tokens looked up in the tokenizer and matched by regex_special_token.
    SPECIAL_TOKENS = ('[PAD]', '[MASK]', '[CLS]', '[SEP]', '[EOL]', '[UNK]')

    def __init__(self, tokenizer, size_seq_max, need_masked):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked

        index_special_token_2_id = dict(zip(tokenizer.all_special_tokens, tokenizer.all_special_ids))

        self.id_pad = index_special_token_2_id['[PAD]']
        self.id_mask = index_special_token_2_id['[MASK]']
        self.id_cls = index_special_token_2_id['[CLS]']
        self.id_sep = index_special_token_2_id['[SEP]']
        self.id_eol = index_special_token_2_id['[EOL]']
        self.id_unk = index_special_token_2_id['[UNK]']

        # Brackets must be escaped: unescaped, '[PAD]' is a character class that
        # matches a single 'P', 'A' or 'D' rather than the literal token.
        self.regex_special_token = re.compile(
            '(' + '|'.join(re.escape(token) for token in Collator_BERT.SPECIAL_TOKENS) + ')'
        )
    # end


    def __call__(self, list_sequence_batch):
        """Tokenize, pack, pad and (optionally) mask a list of raw sentences.

        Returns:
            Batch with ids_encoder, masks_encoder, labels_encoder (only present when
            need_masked is truthy), segments_encoder, ids_decoder, masks_decoder,
            labels_decoder and segments_label; every id tensor has size_seq_max columns.
        """
        list_sequence_tokenized = self.tokenizer.batch_encode_plus(list_sequence_batch, add_special_tokens=False)['input_ids']

        # Process I. Greedily pack consecutive sentences into rows fitting size_seq_max.
        # One sentence can use at most size_seq_max - 2 tokens ([CLS] + sentence + [SEP]).
        # Longer ones are truncated; otherwise they would yield an empty row plus a row
        # longer than size_seq_max, and torch.stack in pad_sequences would fail.
        len_sentence_max = max(0, self.size_seq_max - 2)

        list_list_tokenized = []
        list_tokenized_cache = []
        len_tokenized_accumulated = 1  # [CLS]; the [EOL]/[SEP] after each sentence is counted with it

        for tokenized in list_sequence_tokenized:
            tokenized = tokenized[:len_sentence_max]
            len_tokenized_current = len(tokenized) + 1  # sentence + its trailing [EOL] (or final [SEP])

            if list_tokenized_cache and len_tokenized_accumulated + len_tokenized_current > self.size_seq_max:
                list_list_tokenized.append(list_tokenized_cache)
                list_tokenized_cache = []
                len_tokenized_accumulated = 1
            # end

            len_tokenized_accumulated += len_tokenized_current
            list_tokenized_cache.append(tokenized)
        # end

        list_list_tokenized.append(list_tokenized_cache)


        # Process II. Join the sentences of each row with [EOL] (no trailing [EOL]).
        list_tokenized_merged = []

        for list_tokenized in list_list_tokenized:
            tokenized_merged = []
            for index_sentence, tokenized in enumerate(list_tokenized):
                if index_sentence:
                    tokenized_merged.append(self.id_eol)
                # end
                tokenized_merged.extend(tokenized)
            # end
            list_tokenized_merged.append(tokenized_merged)
        # end


        # Process III. Add begin and stop special tokens, same as jinyuj_transformers_quora.ipynb
        tokens_input_encoder = []
        tokens_input_decoder = []
        tokens_label_decoder = []

        for tokenized_merged in list_tokenized_merged:
            tokens_input_encoder.append([self.id_cls] + tokenized_merged + [self.id_sep])
            tokens_input_decoder.append([self.id_cls] + tokenized_merged)
            tokens_label_decoder.append(tokenized_merged + [self.id_sep])
        # end


        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(tokens_input_encoder, self.size_seq_max, need_masked=self.need_masked)
        inputs_decoder, masks_decoder, segments_decoder, _ = self.pad_sequences(tokens_input_decoder, self.size_seq_max, need_diagonal=True)
        labels_decoder, masks_label, segments_label, _ = self.pad_sequences(tokens_label_decoder, self.size_seq_max)

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            segments_encoder=segments_encoder,
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label
        )
    # end


    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0):
        """Pad ``sequences`` with [PAD] to ``size_seq_max`` and build their masks.

        Args:
            sequences: list of token-id lists, each expected to be at most size_seq_max
                long (a longer row is left unpadded and will not stack with shorter ones).
            size_seq_max: target row length.
            need_diagonal: build a causal mask of shape (batch, size_seq_max, size_seq_max)
                instead of the padding mask of shape (batch, 1, size_seq_max).
            need_masked: fraction of interior tokens (positions 1..len-2) replaced with
                [MASK]; those positions are also hidden in the attention mask.
                Not meant to be combined with need_diagonal.

        Returns:
            (inputs, masks_attention, masks_segment, labels). masks_segment is True on
            non-[PAD] positions. labels is the unmasked padded input when need_masked
            is truthy, otherwise None.
        """
        id_pad = self.id_pad
        id_mask = self.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)

            count_pad = size_seq_max - len_seq

            sequence = torch.LongTensor(sequence)
            sequence_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
            sequences_padded.append(sequence_padded)

            if need_masked:
                # never mask the first ([CLS]) and last ([SEP]) positions
                index_masked = list(range(1, len_seq-1))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * (len_seq-2))])

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequence_masked_padded = torch.cat((sequence_masked, torch.LongTensor([id_pad] * count_pad)))

                sequences_masked_padded.append(sequence_masked_padded)
            # end
        # end

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        if need_masked:
            inputs_masked_padded = torch.stack(sequences_masked_padded)
        # end

        masks_segment = (inputs != self.id_pad).unsqueeze(-2)    #(nbatch, 1, seq)
        masks_attention = self.make_std_mask(inputs, self.id_pad) if need_diagonal else masks_segment

        if need_masked:
            masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2)
            masks_attention = masks_attention & masks_masked
            return inputs_masked_padded, masks_attention, masks_segment, inputs # (inputs, masks_attention, masks_segment, labels)
        else:
            return inputs, masks_attention, masks_segment, None
        # end
    # end


    def subsequent_mask(self, size):
        "Mask out subsequent positions: bool tensor (1, size, size), True on and below the diagonal."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
            torch.uint8
        )
        return subsequent_mask == 0
    # end


    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words: (batch, seq, seq) for tgt of shape (batch, seq)."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
    # end
# end

In [66]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset

# Tokenizer: stock BERT vocabulary extended with the custom [EOL] sentence separator.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({'additional_special_tokens': ['[EOL]']})

# Loader over the cached raw sentences. Packing happens inside the collator, so one
# batch of 32 sentences may yield several rows of 128 tokens each.
train_source = bookcorpus_2000
batch_size = 32
collator = Collator_BERT(tokenizer=tokenizer, size_seq_max=128, need_masked=0)
dataloader_train = DataLoader(
    train_source,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collator,
)

In [73]:
from tqdm import tqdm

In [75]:
# Smoke test: run the collator over the whole training set once. Afterwards `batch`
# holds the last collated Batch, which the decode cell inspects.
for batch in tqdm(dataloader_train):
    continue
# end

100%|██████████| 63/63 [00:01<00:00, 38.84it/s]


In [68]:
# Decode the encoder ids of the last batch back to text to eyeball the [EOL] packing.
tokenizer.batch_decode(
    batch()['ids_encoder'].cpu().tolist()
)

["[CLS] usually, he would be tearing around the living room, playing with his toys. [EOL] but just one look at a minion sent him practically catatonic. [EOL] that had been megan's plan when she got him dressed earlier. [EOL] he'd seen the movie almost by mistake, considering he was a little young for the pg cartoon, but with older cousins, along with her brothers, mason was often exposed to things that were older. [EOL] she liked to think being surrounded by adults and older kids was one reason why he was a such a good talker for his age. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]",
 "[CLS] ` ` aren't you being a good boy?'' [EOL] she said. [EOL] mason barely acknowledged her. [EOL] instead, his baby blues remained focused on the television. [EOL] since the movie was almost over, megan knew she better slip into the bedroom and finish getting ready. [EOL] each time she looked into mason's face, she was grateful that he looked nothing like his father. [EOL] his platinum 