In [1]:
from transformers import AutoTokenizer
import random
import torch
from torch.utils.data import DataLoader, Dataset
import os
import re
from tqdm.notebook import tqdm
import copy



class Batch:
    """Hold a collated batch as keyword arguments moved to the current CUDA device.

    Entries whose value is ``None`` or a ``bool`` are dropped; every other
    value has ``.cuda()`` called on it. Calling the instance returns the
    resulting ``{name: value}`` dict.
    """

    def __init__(self, **kwargs):
        # Skip absent (None) and flag-like (bool) entries; move the rest.
        self.kwargs = {
            name: value.cuda()
            for name, value in kwargs.items()
            if not (value is None or type(value) is bool)
        }
    # end

    def __call__(self):
        return self.kwargs
    # end
# end



class Collator_Base:
    """Shared padding and masking helpers for BERT-style collators.

    Args:
        tokenizer: object exposing ``all_special_tokens`` / ``all_special_ids``
            that include ``[PAD]``, ``[MASK]``, ``[CLS]``, ``[SEP]`` and ``[UNK]``
            (e.g. the Hugging Face BERT tokenizer used in this notebook). If it
            also has a truthy ``vocab_size``, ``vocab_size - 1`` is the highest
            id used for random-token replacement.
        size_seq_max: padded sequence length used by subclasses.
        need_masked: fraction of inner tokens corrupted for the masked-LM
            input; ``0`` disables corruption.
    """

    # Lowest id drawn for random-token replacement.
    # NOTE(review): presumably chosen to skip the [PAD]/[unused*]/special ids of
    # the bert-base-uncased vocab — confirm before using another vocabulary.
    INDEX_RANDTOKEN_START_DEFAULT = 999
    # Highest id drawn when the tokenizer does not report ``vocab_size``
    # (the value previously hard-coded here).
    INDEX_RANDTOKEN_END_DEFAULT = 30521

    def __init__(self, tokenizer, size_seq_max, need_masked=0.3):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked

        index_special_token_2_id = dict(zip(tokenizer.all_special_tokens, tokenizer.all_special_ids))

        self.id_pad = index_special_token_2_id['[PAD]']
        self.id_mask = index_special_token_2_id['[MASK]']
        self.id_cls = index_special_token_2_id['[CLS]']
        self.id_sep = index_special_token_2_id['[SEP]']
        self.id_unk = index_special_token_2_id['[UNK]']

        self.regex_special_token = re.compile(r'\[(PAD|MASK|CLS|SEP|EOL|UNK)\]')

        # random.randint is inclusive on both ends, so the end is the last valid id.
        self.index_randtoken_start = self.INDEX_RANDTOKEN_START_DEFAULT
        vocab_size = getattr(tokenizer, 'vocab_size', None)
        self.index_randtoken_end = vocab_size - 1 if vocab_size else self.INDEX_RANDTOKEN_END_DEFAULT
    # end

    def _preprocess(self, line):
        """Normalise one raw text line.

        Literal special-token strings such as ``[MASK]`` become ``<MASK>``, so
        they cannot be tokenized as real special tokens. Any two consecutive
        quote/backtick characters and runs of two or three dots are removed.
        Repeated spaces collapse to one, and the result is stripped.
        """
        line = re.sub(self.regex_special_token, r'<\1>', line)
        line = re.sub(r'''('|"|`){2}''', '', line)
        line = re.sub(r'\.{2,3}', '', line)
        line = re.sub(r' {2,}', ' ', line)
        return line.strip()
    # end

    def _get_random_tokens(self):
        """Return one random token id in ``[index_randtoken_start, index_randtoken_end]``."""
        return random.randint(self.index_randtoken_start, self.index_randtoken_end)
    # end

    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False,
                      need_masked=0):
        """Pad token-id lists to ``size_seq_max`` and build attention masks.

        ``need_diagonal`` (causal mask, for decoder inputs) and ``need_masked``
        (masked-LM corruption, for encoder inputs) are not meant to be combined.

        With ``need_masked`` non-zero, each sequence gets
        ``round(need_masked * (len - 2))`` corrupted positions (at least 1).
        The positions are chosen at random among the inner ones; the first and
        last positions are never picked, presumably [CLS]/[SEP]. The first half
        of the picks (rounded down) get a random token id and the rest get
        [MASK]. Randomness comes from the global ``random`` module.

        Args:
            sequences: list of lists of token ids; the caller's lists are not modified.
            size_seq_max: padded length; no sequence may be longer.
            need_diagonal: combine the padding mask with a causal mask.
            need_masked: fraction of inner positions to corrupt; ``0`` disables it.

        Returns:
            ``(inputs, masks_attention, masks_segment, labels)``:
              * ``inputs``: long tensor ``(batch, size_seq_max)``. These are the
                corrupted ids when ``need_masked``, otherwise the padded ids.
              * ``masks_attention``: bool tensor. Shape ``(batch, 1, size_seq_max)``,
                or ``(batch, size_seq_max, size_seq_max)`` with ``need_diagonal``.
                With ``need_masked`` it is also False at [MASK] positions.
              * ``masks_segment``: bool ``(batch, 1, size_seq_max)``, True where
                the uncorrupted id is not [PAD].
              * ``labels``: the uncorrupted padded ids when ``need_masked``,
                else ``None``.

        Raises:
            ValueError: if a sequence is longer than ``size_seq_max``. Without
                this check it would silently produce over-wide tensors or fail
                inside ``torch.stack``.
        """
        # Random-token replacement writes into the per-sequence lists.
        sequences = copy.deepcopy(sequences)

        id_pad = self.id_pad
        id_mask = self.id_mask

        sequences_masked_padded = []
        labels_padded = []

        for sequence in sequences:
            len_seq = len(sequence)
            if len_seq > size_seq_max:
                raise ValueError('sequence of length {} exceeds size_seq_max={}'.format(len_seq, size_seq_max))
            # end

            label = list(sequence)  # snapshot taken before any corruption
            pads = [id_pad] * (size_seq_max - len_seq)

            if need_masked:
                indexs_masked = list(range(1, len_seq - 1))  # 0 = cls, -1 = sep
                random.shuffle(indexs_masked)
                # `or 1`: corrupt at least one position even for very short sequences.
                anchor_mask_all = round(need_masked * (len_seq - 2)) or 1
                anchor_mask_replace = int(anchor_mask_all / 2)

                for index_replaced in indexs_masked[:anchor_mask_replace]:
                    sequence[index_replaced] = self._get_random_tokens()
                # end

                indexs_masked = indexs_masked[anchor_mask_replace:anchor_mask_all]

                sequence_masked = torch.LongTensor(sequence + pads)
                sequence_masked.index_fill_(0, torch.LongTensor(indexs_masked), id_mask)
                sequences_masked_padded.append(sequence_masked)
            # end

            labels_padded.append(torch.LongTensor(label + pads))
        # end for

        inputs = torch.stack(labels_padded)  # (batch, size_seq_max), uncorrupted
        masks_segment = (inputs != id_pad).unsqueeze(-2)  # (batch, 1, size_seq_max)
        masks_attention = self.make_std_mask(inputs, id_pad) if need_diagonal else masks_segment

        if not need_masked:
            return inputs, masks_attention, masks_segment, None
        # end

        inputs_masked_padded = torch.stack(sequences_masked_padded)
        # NOTE(review): this also hides [MASK] positions as attention keys; kept as in the original.
        masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2)
        masks_attention = masks_attention & masks_masked
        return inputs_masked_padded, masks_attention, masks_segment, inputs  # (inputs, masks_attention, masks_segment, labels)
    # end

    def subsequent_mask(self, size):
        """Return a ``(1, size, size)`` bool tensor, True where column <= row (mask out subsequent positions)."""
        attn_shape = (1, size, size)
        upper = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
        return upper == 0
    # end

    def make_std_mask(self, tgt, pad):
        """Create a mask that hides padding and future positions.

        Args:
            tgt: token-id tensor of shape ``(batch, seq)``.
            pad: padding id.

        Returns:
            Bool tensor ``(batch, seq, seq)``. Entry ``[b, i, j]`` is True when
            ``j <= i`` and ``tgt[b, j] != pad``.
        """
        tgt_mask = (tgt != pad).unsqueeze(-2)
        # type_as matches tgt_mask's dtype and device.
        return tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(tgt_mask)
    # end
# end


class Collator_BERT_Encoded_254(Collator_Base):
    """Collate pre-tokenized id lists into encoder/decoder tensors.

    Each input list is truncated to ``size_seq_max - 2`` ids. Three views of it
    are then built:
      * encoder input: ``[CLS] + ids + [SEP]``, corrupted per ``need_masked``;
      * decoder input: ``[CLS] + ids``, with a causal attention mask;
      * decoder label: ``ids + [SEP]``.
    All three are padded to ``size_seq_max`` and returned wrapped in a ``Batch``.
    """

    def __call__(self, list_tokenized_merged):
        # Leave room for the [CLS] and [SEP] ids added below.
        size_content_max = self.size_seq_max - 2
        truncated = [ids[:size_content_max] for ids in list_tokenized_merged]

        # Process III. Add begin and stop special token, same as jinyuj_transformers_quora.ipynb
        tokens_input_encoder = [[self.id_cls] + ids + [self.id_sep] for ids in truncated]
        tokens_input_decoder = [[self.id_cls] + ids for ids in truncated]
        tokens_label_decoder = [ids + [self.id_sep] for ids in truncated]

        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(
            tokens_input_encoder, self.size_seq_max, need_masked=self.need_masked
        )
        inputs_decoder, masks_decoder, _, _ = self.pad_sequences(
            tokens_input_decoder, self.size_seq_max, need_diagonal=True
        )
        labels_decoder, _, segments_label, _ = self.pad_sequences(
            tokens_label_decoder, self.size_seq_max
        )

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            segments_encoder=segments_encoder,
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label,
        )
    # end
# end

In [2]:
class SimpleEncodedDataset(torch.utils.data.Dataset):
    """Map-style dataset over files of ``', '``-separated token ids, one sample per line.

    For each file, a random ``split`` fraction of its line indices is held out
    for evaluation and the rest are used for training. ``train()`` / ``eval()``
    select which subset ``__getitem__`` and ``__len__`` expose.

    Only one file's lines are kept in memory (``rows_cached``) at a time.
    Eval samples start as ``(line_index, filename)`` placeholders. They are
    replaced by their token-id lists the first time their file is read, whether
    that read is triggered from training or from evaluation.

    Args:
        folder_dataset_base: directory containing the files named in ``info_file_rows``.
        info_file_rows: mapping ``{filename: number_of_lines_to_use}``.
        split: fraction of each file's lines held out for evaluation.
    """

    # info_file_rows = {'path_file': 1,000,000,...}
    def __init__(self, folder_dataset_base, info_file_rows, split=0.001):
        self.folder_dataset_base = folder_dataset_base
        self.list_tokenized_eval = []
        self.dict_filename_loaded = {filename: False for filename in info_file_rows}
        self.list_corpus_idx_filename_train = []

        for filename, num_lines in info_file_rows.items():
            idxs_eval = list(range(num_lines))
            random.shuffle(idxs_eval)
            idxs_eval = idxs_eval[:round(len(idxs_eval) * split)]

            self.list_tokenized_eval.extend((idx_eval, filename) for idx_eval in idxs_eval)

            set_idxs_eval = set(idxs_eval)
            self.list_corpus_idx_filename_train.extend(
                (idx_train, filename) for idx_train in range(num_lines) if idx_train not in set_idxs_eval
            )
        # end

        self.is_train = True
        self.rows_cached = []  # lines of the file currently held in memory
        self.filename_cached = None  # name of that file
    # end

    def _load_file(self, filename):
        """Read ``filename`` into ``rows_cached`` and resolve its pending eval placeholders."""
        print('switch from {} to {}'.format(self.filename_cached, filename))
        path_file = os.path.join(self.folder_dataset_base, filename)
        with open(path_file, 'r') as file:
            self.rows_cached = file.read().splitlines()
        # end
        self.filename_cached = filename

        if not self.dict_filename_loaded[filename]:
            for id_list_eval, tokenized_eval in enumerate(self.list_tokenized_eval):
                if type(tokenized_eval) is tuple and tokenized_eval[1] == filename:
                    self.list_tokenized_eval[id_list_eval] = self._fransfer_one_line_to_tokenized(
                        self.rows_cached[tokenized_eval[0]]
                    )
                # end
            # end
            self.dict_filename_loaded[filename] = True
        # end
    # end

    def __getitem__(self, idx):
        """Return the token-id list of sample ``idx`` from the active split."""
        if not self.is_train:
            item = self.list_tokenized_eval[idx]
            if type(item) is tuple:
                # Placeholder: its file has not been read yet (e.g. eval before any
                # training pass). Load it now, which resolves this entry in place.
                self._load_file(item[1])
                item = self.list_tokenized_eval[idx]
            # end
            return item
        # end

        idx_in_file, filename_current = self.list_corpus_idx_filename_train[idx]
        if filename_current != self.filename_cached:
            self._load_file(filename_current)
        # end
        return self._fransfer_one_line_to_tokenized(self.rows_cached[idx_in_file])
    # end

    def __len__(self):
        if self.is_train:
            return len(self.list_corpus_idx_filename_train)
        # end
        return len(self.list_tokenized_eval)
    # end

    def _fransfer_one_line_to_tokenized(self, str_line):
        """Parse one line of ``', '``-separated integers into a list of ints.

        The name, including its "fransfer" typo, is kept so existing callers still work.
        """
        return [int(t) for t in str_line.split(', ')]
    # end

    def train(self):
        """Expose the training split."""
        self.is_train = True
    # end

    def eval(self):
        """Expose the held-out evaluation split."""
        self.is_train = False
    # end
# end

In [3]:
# Make the first listed GPU the current CUDA device; Batch moves tensors there via .cuda().
GPUS = [1]
torch.cuda.set_device(GPUS[0])

# source
seq_max = 256
batch_size = 2
# NOTE(review): not used by the visible cells; equals sum([10000, 10000, 2345]),
# not the smaller sizes currently selected below.
len_dataset = 22345


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = Collator_BERT_Encoded_254(tokenizer, seq_max)


folder_dataset = 'bookcorpus_merged_254_20'
# Files are expected to be named "<int>.<ext>"; sort numerically so "10.x" comes after "2.x".
filenames_dataset = sorted(
    (f for f in os.listdir(folder_dataset) if not f.startswith('.')),
    key=lambda name: int(name.split('.')[0]),
)
# Number of lines to use from each file, in the numeric filename order above.
# list_size_per_file = [10000, 10000, 2345]
list_size_per_file = [20, 20, 13]

# zip() would silently drop the surplus files (or sizes) on a length mismatch.
assert len(filenames_dataset) == len(list_size_per_file), (
    'found {} files in {!r} but {} sizes in list_size_per_file'.format(
        len(filenames_dataset), folder_dataset, len(list_size_per_file)))
info_filename_rows = dict(zip(filenames_dataset, list_size_per_file))

In [4]:
source = SimpleEncodedDataset(folder_dataset, info_filename_rows)
# Both loaders wrap the same dataset object: call source.train() / source.eval()
# before iterating the matching loader to pick the split.
dataloader_train = DataLoader(source, batch_size=batch_size * len(GPUS), shuffle=False, collate_fn=collator)
dataloader_eval = DataLoader(source, batch_size=1, shuffle=False, collate_fn=collator)

In [12]:
# Iterate once over the training split. The loop body only unpacks each batch;
# info_batch ends up holding the kwargs dict of the last batch, for inspection below.
source.train()
for batch in tqdm(dataloader_train):
    info_batch = batch()
# end


  0%|          | 0/11162 [00:00<?, ?it/s]

switch from 2.encode to 0.encode
switch from 0.encode to 1.encode
switch from 1.encode to 2.encode


In [13]:
# Encoder input ids of the last batch above, including [MASK] ids and random-token replacements.
info_batch['ids_encoder']

tensor([[  101,  2002,   103, 17733,  1012,  1045,  2074,  2729,  2055, 16214,
          1012,  1045,  4737,  2055,  2017,  1012,  2823, 22225, 14901,  2007,
          2129,  2000, 19155,  2009, 14061, 23979,  2032,  1012,  2823,  1045,
          2514, 26440,  2017,  2079,  1050,   103,  1056, 21793,   103,   103,
          2115,  2925,  1012,   103,  4426,  1012,  2823,  1045,  2079,  1050,
          1005, 12593,  1012,  1045,  3112,  2222,  2196,  3815,  2000,  2505,
           103, 10272,   103,  2002,  5015,  2061,  3043, 24454,  2755,  1010,
           103,  3480,  1012,  2079,  1050,  1005,   103, 27521,  1210,  1010,
          3841,   103,  2009,   103,  1055,  2025,  2995,  6227,  2054,   103,
          1045,  2204,   103,  1010,   103, 18890,  2054, 24908,  1045,   103,
          2008,  1005,  1055,  6783, 19927, 16652,  1045,   103,  1050,  1005,
          1056,   103,  2019,   103,  1012,  2002,  2106,  1050,  1005, 17595,
          2428,  8225,   103,  7570, 27982,  2008,  