filename: refactor_span_level_data_collator_v3.ipynb

In [1]:
%load_ext line_profiler
%load_ext memory_profiler

In [2]:
import sys
sys.path.append('../../thai2transformers')

In [1]:
import torch
import transformers
from transformers import AutoTokenizer , DataCollatorForLanguageModeling
import math
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
import numpy as np
import pandas as pd
import torch
import random
from bisect import bisect
from transformers.data.data_collator import DataCollatorForLanguageModeling, _collate_batch, tolist
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

import glob, os
from torch.utils.data.dataloader import DataLoader

from torch.utils.data.sampler import RandomSampler, SequentialSampler


from thai2transformers.datasets import MLMDataset

In [2]:
transformers.__version__

'4.6.1'

In [3]:
tokenizer = AutoTokenizer.from_pretrained('airesearchth/wangchanberta-base-wiki-20210520-spm')

In [4]:
tokenizer

PreTrainedTokenizerFast(name_or_path='airesearchth/wangchanberta-base-wiki-20210520-spm', vocab_size=24005, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True), 'additional_special_tokens': ['<s>NOTUSED', '</s>NOTUSED', '▁']})

In [5]:
TRAIN_DATA_PATH = '../../dataset/split/thwiki-for-ddp_6.11.2020/train'

In [6]:
TRAIN_LARGE_DATA_PATH = '../../dataset/split/thwiki-for-ddp_concat_12.11.2020/train'

In [7]:
!head -2 $TRAIN_LARGE_DATA_PATH/train.txt


ม้าลายเบอร์เชลล์<_>เป็นม้าลายชนิดย่อยหรือสปีชีส์ย่อยของม้าลายธรรมดา<_>("E.<_>quagga")<_>ชนิดหนึ่ง
ม้าลายเบอร์เชลล์<_>เป็นม้าลายที่แพร่กระจายพันธุ์ในทวีปแอฟริกาตอนใต้<_>เช่น<_>บอตสวานา,<_>สวาซิแลนด์,<_>แอฟริกาใต้<_>เป็นต้น<_>เป็นม้าลายที่มีลายขนาดใหญ่สีดำพาดยาวสลับกับลายสีขาวจากหลังลงไปทั้งสองข้างของลำตัวจนถึงใต้ท้อง<_>มีพฤติกรรมและลักษณะนิสัยคล้ายกับม้าลายธรรมดาทั่วไป


In [8]:
%%time 

train_dataset = MLMDataset(tokenizer,
                           TRAIN_DATA_PATH,
                           510)


[INFO] Build features (parallel).

[INFO] Start groupping results.
[INFO] Done.
CPU times: user 326 ms, sys: 112 ms, total: 437 ms
Wall time: 2.36 s


In [128]:
%%time 

train_dataset_large = MLMDataset(tokenizer,
                           TRAIN_LARGE_DATA_PATH,
                           100)


[INFO] Build features (parallel).


Process ForkPoolWorker-8:
Process ForkPoolWorker-9:
Process ForkPoolWorker-10:
Process ForkPoolWorker-12:
Process ForkPoolWorker-11:


KeyboardInterrupt: 

In [9]:

SPECIAL_TOKEN_NAMES = ['bos_token', 'eos_token', 'sep_token', 'cls_token', 'pad_token']

@dataclass
class DataCollatorForSpanLevelMask(DataCollatorForLanguageModeling):
    """
    Data collator used for span-level masked language modeling
     
    adapted from NGramMaskGenerator class
    
    https://github.com/microsoft/DeBERTa/blob/11fa20141d9700ba2272b38f2d5fce33d981438b/DeBERTa/apps/tasks/mlm_task.py#L36
    and
    https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py

    """
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    max_gram: int = 3
    keep_prob: float = 0.0
    mask_prob: float = 1.0
    max_preds_per_seq: int = None
    max_seq_len: int = 510

    def __init__(self, tokenizer, mlm=True, mlm_probability=0.15, *args, **kwargs):
        super().__init__(tokenizer, mlm=mlm, mlm_probability=mlm_probability)

        assert self.mask_prob + self.keep_prob <= 1, \
            f'The prob of using [MASK]({self.mask_prob}) and the prob of using original token({self.keep_prob}) should between [0,1]'

        if self.max_preds_per_seq is None:
            self.max_preds_per_seq = math.ceil(self.max_seq_len * self.mlm_probability / 10) * 10
            self.mask_window = int(1 / self.mlm_probability) # make ngrams per window sized context
        self.vocab_words = list(self.tokenizer.get_vocab().keys())
        self.vocab_mapping = self.tokenizer.get_vocab()
        
        self.special_tokens = [self.tokenizer.special_tokens_map[name] for name in  SPECIAL_TOKEN_NAMES]
#         print(' self.special_tokens', self.special_tokens)
        self.ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
        _pvals = 1. / np.arange(1, self.max_gram + 1)
        self.pvals = _pvals / _pvals.sum(keepdims=True)
        print('max_gram', self.max_gram)
    def _choice(self, rng, data, p):
        cul = np.cumsum(p)
        x = rng.random()*cul[-1]
        id = bisect(cul, x)
        return data[id]

    def _per_token_mask(self, idx, tokens, rng, mask_prob, keep_prob):
        label = tokens[idx]
        mask = self.tokenizer.mask_token
        rand = rng.random()
        if rand < mask_prob:
            new_label = mask
        elif rand < mask_prob + keep_prob:
            new_label = label
        else:
            new_label = rng.choice(self.vocab_words)

        tokens[idx] = new_label

        return label

    def _mask_tokens(self, tokens: List[str], rng=random, **kwargs):

        indices = [i for i in range(len(tokens)) if tokens[i] not in self.special_tokens]
#         print('debug: indices to be able to be masked', indices)
        
        unigrams = [ [idx] for idx in indices ]
        num_to_predict = min(self.max_preds_per_seq, max(1, int(round(len(tokens) * self.mlm_probability))))
           
        offset = 0
        mask_grams = np.array([False]*len(unigrams))
        while offset < len(unigrams):
            n = self._choice(rng, self.ngrams, p=self.pvals)
            ctx_size = min(n * self.mask_window, len(unigrams)-offset)
            m = rng.randint(0, ctx_size-1)
            s = offset + m
            e = min(offset + m + n, len(unigrams))
            offset = max(offset+ctx_size, e)
            mask_grams[s:e] = True

        target_labels = [None]*len(tokens)
        w_cnt = 0
        for m,word in zip(mask_grams, unigrams):
            if m:
                for idx in word:
                    label = self._per_token_mask(idx, tokens, rng, self.mask_prob, self.keep_prob)
                    target_labels[idx] = label
                    w_cnt += 1
                if w_cnt >= num_to_predict:
                    break

        target_labels = [self.vocab_mapping[x] if x else -100 for x in target_labels]
        return tokens, target_labels    


    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = []
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probabilityability`)
        # probability_matrix = torch.full(labels.shape, self.mlm_probabilityability)
        # if special_tokens_mask is None:
        #     special_tokens_mask = [
        #         self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        #     ]
        #     special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        # else:
        #     special_tokens_mask = special_tokens_mask.bool()

#         print('inputs', inputs.shape, inputs)
        inputs_masked = []
        
        for i, input in enumerate(inputs):
#             print('input',input)
            input_tokens = self.tokenizer.convert_ids_to_tokens(input)
            

            input_masked, _labels = self._mask_tokens(input_tokens)
#             print('DEBUG: input_masked', input_masked)
            input_masked_ids = self.tokenizer.convert_tokens_to_ids(input_masked)
            inputs_masked.append(input_masked_ids)
#             print('_labels, ', _labels)
#             print('inputs_masked, ', input_masked_ids)
            labels.append(_labels)
      
        return inputs_masked, labels


In [10]:
train_sampler = SequentialSampler(train_dataset)

In [11]:
BZ=32

data_collator_span_mlm =  DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              keep_prob=0.0,
                                              mask_prob=1.0,
                                              max_seq_len=510,
                                              pad_to_multiple_of=8)

data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

data_collator_subword_mlm =  DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              pad_to_multiple_of=8)

data_loader_subword_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=data_collator_subword_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

max_gram 3


In [12]:
%prun next(iter(data_loader_span_mlm))

 

         63089 function calls (63086 primitive calls) in 0.073 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     8960    0.019    0.000    0.019    0.000 tensor.py:468(<lambda>)
       32    0.018    0.001    0.052    0.002 tokenization_utils_fast.py:275(convert_ids_to_tokens)
     8928    0.014    0.000    0.014    0.000 {method 'id_to_token' of 'tokenizers.Tokenizer' objects}
       33    0.003    0.000    0.009    0.000 tokenization_utils_fast.py:220(convert_tokens_to_ids)
     8929    0.003    0.000    0.003    0.000 {method 'token_to_id' of 'tokenizers.Tokenizer' objects}
       32    0.003    0.000    0.010    0.000 <ipython-input-9-1161ad0796fc>:63(_mask_tokens)
     8929    0.002    0.000    0.005    0.000 tokenization_utils_fast.py:242(_convert_token_to_id_with_added_voc)
    17952    0.002    0.000    0.002    0.000 {method 'append' of 'list' objects}
       32    0.001    0.000    0.001    0.000 <ipython-input

### Improved V3 Data Collator

### Test lab zone 👩🏻‍🔬 

<br>

Idea 3: k-time masking projection


``` 
matrix = [[ . . . . . . ]
          [ . . . . . . ]
          [ . . . . . . ]
          [ . . . . . . ]
          [ . . . . . . ]]
```

Project 1-subword mask

Suppose that MLM probability is 15% and the total tokens is 30, then project 1-subword mask 15% (Bernoulli distribution) or 4.5 tokens will be masked.

``` 
matrix = [[ . 1 . . . . ]
          [ . . 1 . . . ]
          [ . . . . . 1 ]
          [ . 1 . . . . ]
          [ . . . 1 . . ]]
```


In [16]:
pvals =  [0.5455, 0.2727, 0.1818]

mlm_probability = .2
probability_matrix = torch.full((5,8), mlm_probability * pvals[0])
print('\nprobability_matrix:\n', probability_matrix)
_masked_indices = torch.bernoulli(probability_matrix).bool()
print('\masked_indices:\n', _masked_indices)

base_indices = (_masked_indices == True).nonzero(as_tuple=True)
print('base_indices', base_indices)


probability_matrix:
 tensor([[0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091],
        [0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091],
        [0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091],
        [0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091],
        [0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091, 0.1091]])
\masked_indices:
 tensor([[ True, False, False, False, False,  True, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False,  True, False, False,  True],
        [False, False, False, False, False, False, False, False]])
base_indices (tensor([0, 0, 3, 3]), tensor([0, 5, 4, 7]))


In [17]:

# 2-subword
second_probabilty_matrix = torch.full((5,8), mlm_probability * pvals[1])
print('second_probabilty_matrix', second_probabilty_matrix)

second_masked_indices = torch.bernoulli(second_probabilty_matrix).bool()
print('second_masked_indices', second_masked_indices)
second_indices = (second_masked_indices == True).nonzero(as_tuple=True)

print('second_indices', second_indices)


second_probabilty_matrix tensor([[0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545]])
second_masked_indices tensor([[False,  True, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False,  True, False],
        [False, False, False, False, False, False, False, False]])
second_indices (tensor([0, 3]), tensor([1, 6]))


In [21]:

third_probabilty_matrix = torch.full((5,8), mlm_probability * pvals[2])
print('third_probabilty_matrix', third_probabilty_matrix)

third_masked_indices = torch.bernoulli(third_probabilty_matrix).bool()
print('third_masked_indices', third_masked_indices)
third_indices = (third_masked_indices == True).nonzero(as_tuple=False)

print('third_indices', third_indices)


third_probabilty_matrix tensor([[0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364],
        [0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364],
        [0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364],
        [0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364],
        [0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364, 0.0364]])
third_masked_indices tensor([[False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True, False, False],
        [ True, False, False, False, False, False, False, False],
        [False, False, False, False, False, False,  True, False],
        [False, False, False, False, False, False, False, False]])
third_indices tensor([[1, 5],
        [2, 0],
        [3, 6]])


__Wrapping up in 1 for loop__

In [101]:

def mask_tokens(inputs: torch.Tensor,
                special_tokens_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    
    labels = inputs.clone()
    labels_to_be_mask = torch.full(inputs.shape, 0.).bool()
    pvals =  [0.5455, 0.2727, 0.1818]
    K = len(pvals)
    mask_indices_by_span_len = [[] for i in range(K)]
    
    mlm_probability = .1
    
    for k in range(0, K):
       
            # 2-subword
        print(f'\n\n{k+1}-subword masking')
        
        probability_matrix = torch.full(inputs.shape, mlm_probability * pvals[k] / (k+1))
        
        
        print(f'\n{k+1}-subword masking probabilty_matrix', probability_matrix)

        _masked_indices = torch.bernoulli(probability_matrix).bool()
        print(f'\n{k+1}-subword masked_indices', _masked_indices)
        _indices_selected = (_masked_indices == True).nonzero(as_tuple=False)

        
        print(f'\n{k+1}-subword _indices_selected: ', _indices_selected)
        if _indices_selected == None:
            mask_indices_by_span_len[k] = []
        else:
            mask_indices_by_span_len[k] = _indices_selected
        
        mask_indices_by_span_len[k] = _indices_selected
    
    accum_indices = [[],[]]
    max_seq_len = inputs.shape[1] - 1

    for k in range(0, K):

        list_of_indices = mask_indices_by_span_len[k]

        if list_of_indices.shape == (0,):
            continue
        else:
            for j in range(k+1):
                max_indices = torch.full((list_of_indices.shape[0],), max_seq_len, dtype=torch.long)
                left, right = (list_of_indices[:, 0], \
                               torch.min(list_of_indices[:, 1] + j, max_indices))

                accum_indices[0].append(left)
                accum_indices[1].append(right)
    accum_indices[0] = list(filter(lambda x: x.shape != (0,), accum_indices[0]))
    accum_indices[1] = list(filter(lambda x: x.shape != (0,), accum_indices[1]))
    print('accum_indices', accum_indices)
    if len(accum_indices[0]) != 0: 
        accum_indices_flatten  = (torch.cat(accum_indices[0]), torch.cat(accum_indices[1]))
        labels_to_be_mask.index_put_(accum_indices_flatten, torch.tensor([1.]).bool())
#         labels_to_be_mask.masked_fill_(special_tokens_mask, value=0.0).bool()

    inputs[labels_to_be_mask] = 24004
    labels[~labels_to_be_mask] = -100  # We only compute loss on masked token

    return inputs, labels

In [102]:
inputs = (torch.rand((8,10))*100).long() 
inputs

tensor([[ 2,  8, 77, 62, 38,  7,  6,  4, 19, 17],
        [98, 15, 16, 36, 94, 51, 94, 59, 45, 92],
        [21, 49, 29, 29, 74, 17, 42, 75, 79, 97],
        [30, 46, 91,  1, 79, 92, 28, 69, 88, 85],
        [53, 41, 35, 86, 29, 93, 46, 15, 56, 90],
        [74,  0,  5, 41,  5, 57, 77, 24, 74, 91],
        [46, 41,  4, 94, 12, 43,  0, 14, 62, 63],
        [92, 91, 17, 54, 46, 58, 95, 39, 37, 92]])

In [104]:
mask_tokens(inputs)



1-subword masking

1-subword masking probabilty_matrix tensor([[0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545],
        [0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545, 0.0545,
         0.0545]])

1-subword masked_indices tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, 

(tensor([[    2,     8,    77,    62, 24004, 24004,     6,     4,    19,    17],
         [   98,    15,    16,    36,    94,    51,    94,    59,    45,    92],
         [   21,    49,    29,    29,    74,    17,    42,    75,    79,    97],
         [   30,    46,    91,     1, 24004,    92,    28,    69,    88, 24004],
         [24004, 24004, 24004,    86,    29,    93,    46,    15, 24004,    90],
         [   74,     0,     5,    41,     5,    57,    77,    24,    74, 24004],
         [   46,    41,     4,    94,    12,    43,     0,    14, 24004,    63],
         [   92,    91,    17, 24004, 24004,    58,    95,    39, 24004,    92]]),
 tensor([[-100, -100, -100, -100,   38,    7, -100, -100, -100, -100],
         [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
         [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
         [-100, -100, -100, -100,   79, -100, -100, -100, -100,   85],
         [-100, -100, -100, -100, -100, -100, -100, -100,   56, -1

### Garage Zone 🗺

In [136]:
import math
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
import numpy as np
import torch
from transformers.data.data_collator import DataCollatorForLanguageModeling
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

SPECIAL_TOKEN_NAMES = ['bos_token', 'eos_token',
                       'sep_token', 'cls_token', 'pad_token']


@dataclass
class ImprovedV3DataCollatorForSpanLevelMask(DataCollatorForLanguageModeling):
    """
    Data collator used for span-level masked language modeling
     
    adapted from NGramMaskGenerator class
    
    https://github.com/microsoft/DeBERTa/blob/11fa20141d9700ba2272b38f2d5fce33d981438b/DeBERTa/apps/tasks/mlm_task.py#L36
    and
    https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py

    """
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    max_gram: int = 3
    pad_to_multiple_of: Optional[int] = None

    def __new__(cls, tokenizer, mlm, mlm_probability, pad_to_multiple_of, *args, **kwargs):
    
        obj = object.__new__(cls)
        DataCollatorForLanguageModeling.__init__(obj, tokenizer=tokenizer, mlm=mlm,
                                                 mlm_probability=mlm_probability,
                                                 pad_to_multiple_of=pad_to_multiple_of)
        return obj
    

    def __post_init__(self, *args, **kwargs):
        
        self.vocab_words = list(self.tokenizer.get_vocab().keys())
        self.vocab_mapping = self.tokenizer.get_vocab()
        
        self.special_token_ids = [ self.vocab_mapping[self.tokenizer.special_tokens_map[name]] for name in  SPECIAL_TOKEN_NAMES]
        self.ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
        _pvals = 1. / np.arange(1, self.max_gram + 1)
        self.pvals = torch.Tensor(_pvals / _pvals.sum(keepdims=True))

    def mask_tokens(self,inputs: torch.Tensor,
                special_tokens_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    
        labels = inputs.clone()
        labels_to_be_mask = torch.full(inputs.shape, 0.).bool()

        if special_tokens_mask is None:
            special_tokens_mask = sum(inputs==i for i in self.special_token_ids).bool()
        else:
            special_tokens_mask = special_tokens_mask.bool()
            
        K = len(pvals)
        mask_indices_by_span_len = [[] for i in range(K)]



        for k in range(0, K):

            probability_matrix = torch.full(inputs.shape, self.mlm_probability * self.pvals[k] )


            _masked_indices = torch.bernoulli(probability_matrix).bool()

            _indices_selected = (_masked_indices == True).nonzero(as_tuple=False)



            if _indices_selected == None:
                mask_indices_by_span_len[k] = []
            else:
                mask_indices_by_span_len[k] = _indices_selected

            mask_indices_by_span_len[k] = _indices_selected

        accum_indices = [[],[]]
        max_seq_len = inputs.shape[1] - 1

        for k in range(0, K):

            list_of_indices = mask_indices_by_span_len[k]

            if list_of_indices.shape == (0,):
                continue
            else:
                for j in range(k+1):
                    max_indices = torch.full((list_of_indices.shape[0],), max_seq_len, dtype=torch.long)
                    left, right = (list_of_indices[:, 0], \
                                   torch.min(list_of_indices[:, 1] + j, max_indices))

                    accum_indices[0].append(left)
                    accum_indices[1].append(right)
        accum_indices[0] = list(filter(lambda x: x.shape != (0,), accum_indices[0]))
        accum_indices[1] = list(filter(lambda x: x.shape != (0,), accum_indices[1]))
#         print('accum_indices', accum_indices)
        if len(accum_indices[0]) != 0: 
            accum_indices_flatten  = (torch.cat(accum_indices[0]), torch.cat(accum_indices[1]))
            labels_to_be_mask.index_put_(accum_indices_flatten, torch.tensor([1.]).bool())
            labels_to_be_mask.masked_fill_(special_tokens_mask, value=0.0).bool()

        inputs[labels_to_be_mask] = self.tokenizer.mask_token_id
        labels[~labels_to_be_mask] = -100  # We only compute loss on masked token

        return inputs, labels



------

In [137]:
imp_data_collator_span_mlm =  ImprovedV3DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.5,
                                              max_gram=3,
                                              pad_to_multiple_of=1)


text1 = """แผ่นดินไหวตาม หรือทับศัพท์ว่า อาฟเตอร์ช็อก (อังกฤษ: aftershock) เป็นแผ่นดินไหวขนาดเล็กที่เกิดขึ้นหลังจากแผ่นดินไหวขนาดใหญ่ที่มีก่อนหน้า"""
text2 = """อาฟเตอร์ยู (After You Dessert Café) เป็นร้านขนมหวานรสชาติละมุนที่ครองใจลูกค้าเป็นอย่างดี โดยมีเมนูของหวานอร่อยให้เลือกหลากหลายเมนู """
tokens = tokenizer.tokenize(text1)
print('text 1:', '|'.join(tokens), len(tokens))
tokens = tokenizer.tokenize(text2)
print('\ntext 2:','|'.join(tokens), len(tokens))


inputs_1 = tokenizer.encode_plus(text1, return_tensors='pt')['input_ids'].squeeze(0)
inputs_2 = tokenizer.encode_plus(text2, return_tensors='pt')['input_ids'].squeeze(0)

res = imp_data_collator_span_mlm((inputs_1, inputs_2))
print('\n\n')
# print('inputs 1 (before)', inputs_1)
# print('inputs 2 (before)', inputs_2)
print('\n\n\nres', res)

print('\ninputs 1 (after):\n', '|'.join(tokenizer.convert_ids_to_tokens(res['input_ids'][0])))
print('\n\ninputs 2 (after):\n','|'.join(tokenizer.convert_ids_to_tokens(res['input_ids'][1])))

text 1: ▁|แผ่นดินไหว|ตาม|▁|หรือ|ทับศัพท์|ว่า|▁|อ|า|ฟ|เตอร์|ช็อก|▁|(|อังกฤษ|:|▁|af|ter|sh|ock|)|▁|เป็น|แผ่นดินไหว|ขนาดเล็ก|ที่|เกิดขึ้น|หลังจาก|แผ่นดินไหว|ขนาดใหญ่|ที่มี|ก่อนหน้า 34

text 2: ▁|อ|า|ฟ|เตอร์|ยู|▁|(|After|▁|You|▁|Des|s|ert|▁|Ca|f|é|)|▁|เป็น|ร้าน|ขนมหวาน|รสชาติ|ละ|มุน|ที่|ครอง|ใจ|ลูกค้า|เป็น|อย่างดี|▁|โดยมี|เมนู|ของ|หวาน|อร่อย|ให้เลือก|หลากหลาย|เมนู|▁ 43






res {'input_ids': tensor([[    5, 24004,  2827,    83, 24004, 24004,  9962, 24004,     8, 24004,
             9,   265,   712, 13225, 24004, 24004, 24004, 24004, 24004, 24004,
         24004, 24004, 11661, 24004, 24004, 24004, 24004,   897,    12, 24004,
           134,  2827, 24004, 24004, 24004,     6,     1,     1,     1,     1,
             1,     1,     1,     1,     1],
        [    5, 24004, 24004, 24004, 24004, 24004, 24004, 24004,    15, 23639,
         24004, 24004,     8, 14929, 24004, 24004, 24004, 24004,   737,  2912,
         24004,     8,    14,  1665, 20046,  6886,   674, 24004, 24004, 24004,
          

In [119]:
BZ=32

# data_collator_span_mlm =  DataCollatorForSpanLevelMask(tokenizer=tokenizer,
#                                               mlm=True,
#                                               mlm_probability=0.15,
#                                               max_gram=3,
#                                               keep_prob=0.0,
#                                               mask_prob=1.0,

#                                               pad_to_multiple_of=8)

# data_loader_span_mlm = DataLoader(
#             train_dataset,
#             batch_size=BZ,
#             sampler=train_sampler,
#             collate_fn=data_collator_span_mlm,
#             drop_last=False,
#             num_workers=0,
#             pin_memory=True,
#         )

# data_collator_subword_mlm =  DataCollatorForLanguageModeling(tokenizer=tokenizer,
#                                               mlm=True,
#                                               mlm_probability=0.15,
#                                               pad_to_multiple_of=8)

# data_loader_subword_mlm = DataLoader(
#             train_dataset,
#             batch_size=BZ,
#             sampler=train_sampler,
#             collate_fn=data_collator_subword_mlm,
#             drop_last=False,
#             num_workers=0,
#             pin_memory=True,
#         )

imp_data_collator_span_mlm =  ImprovedV3DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              pad_to_multiple_of=8)

imp_data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=imp_data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

In [703]:
%%timeit
next(iter(data_loader_subword_mlm))

2.48 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [479]:
%%timeit
list(data_loader_subword_mlm)
print('.', end='')

.

KeyboardInterrupt: 

In [None]:
%%timeit
next(iter(data_loader_span_mlm))

In [None]:
%%timeit
list(data_loader_span_mlm)
print('.', end='')

In [None]:
%prun next(iter(data_loader_span_mlm))

In [120]:
%prun next(iter(imp_data_loader_span_mlm))

 

         304 function calls (301 primitive calls) in 0.003 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.002    0.002 <ipython-input-118-2a5455979ea3>:50(mask_tokens)
       11    0.001    0.000    0.001    0.000 tensor.py:25(wrapped)
        1    0.000    0.000    0.000    0.000 data_collator.py:195(_collate_batch)
        1    0.000    0.000    0.001    0.001 {built-in method builtins.sum}
        3    0.000    0.000    0.000    0.000 {built-in method bernoulli}
        7    0.000    0.000    0.000    0.000 {method 'bool' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'random_' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {method 'nonzero' of 'torch._C._TensorBase' objects}
       10    0.000    0.000    0.000    0.000 {built-in method full}
        1    0.000    0.000    0.000    0.000 {method 'item' of 'torch

In [121]:
# %prun next(iter(data_loader_subword_mlm))

In [122]:
%%timeit
next(iter(imp_data_loader_span_mlm))

1.38 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
%%timeit
list(imp_data_loader_span_mlm)
print('.', end='')

In [None]:
res = imp_data_collator_span_mlm((inputs_1, inputs_2))
print(res)

print(sum(torch.sum(res['input_ids'].eq(1), dim=1).detach().cpu().numpy()))

In [273]:
count_mask([res])

([14], [81])

In [190]:
res['input_ids'].size()

torch.Size([2, 48])

### Quality Assurance 🥽

In [138]:
def count_mask(results):
    mask_counts = []
    token_counts = []
    for item in results:
#         print(item['labels'])
        mask_count = sum(torch.sum(~(item['labels'].eq(-100)), dim=1).detach().cpu().numpy())
        special_tokens_count = sum(torch.sum(( item['input_ids'].eq(1) | item['input_ids'].eq(5)  |item['input_ids'].eq(6) ), dim=1).detach().cpu().numpy())

        _token_count = item['input_ids'].shape[0] * item['input_ids'].shape[1]
#         print('_token_count', _token_count)
        token_count = _token_count - special_tokens_count
#         print('token_count', token_count)
        token_counts.append(token_count)
        mask_counts.append(mask_count)
    return mask_counts, token_counts


def run_exp_masking_percentage(data_loader):
    
    result = list(data_loader)
    
    mask_counts, token_counts = count_mask(result)
#     print('mask_counts', mask_counts)
#     print('token_counts', token_counts)
    total_mask_tokens = sum(mask_counts)
    total_tokens = sum(token_counts)
    percentage = total_mask_tokens / total_tokens * 100
    print(f'total_mask_tokens: {total_mask_tokens}')
    print(f'total_tokens: {total_tokens}')
    print(f'masking percentage: {percentage:.3f}')
    return percentage

In [139]:
BZ=32
imp_data_collator_span_mlm =  ImprovedV3DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.3,
                                              max_gram=3,
                                              pad_to_multiple_of=8)

imp_data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=imp_data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

In [140]:
percentages = []
for i in range(10):
    print(f'\n\nexp {i+1}')
    percentages.append(run_exp_masking_percentage(data_loader=imp_data_loader_span_mlm))
    
percentages_df = pd.Series(percentages)
percentages_df.describe()



exp 1
total_mask_tokens: 652566
total_tokens: 1616963
masking percentage: 40.358


exp 2
total_mask_tokens: 654079
total_tokens: 1616963
masking percentage: 40.451


exp 3
total_mask_tokens: 651763
total_tokens: 1616963
masking percentage: 40.308


exp 4
total_mask_tokens: 651896
total_tokens: 1616963
masking percentage: 40.316


exp 5
total_mask_tokens: 651896
total_tokens: 1616963
masking percentage: 40.316


exp 6
total_mask_tokens: 651645
total_tokens: 1616963
masking percentage: 40.301


exp 7
total_mask_tokens: 653940
total_tokens: 1616963
masking percentage: 40.442


exp 8
total_mask_tokens: 652761
total_tokens: 1616963
masking percentage: 40.370


exp 9
total_mask_tokens: 653821
total_tokens: 1616963
masking percentage: 40.435


exp 10
total_mask_tokens: 650904
total_tokens: 1616963
masking percentage: 40.255


count    10.000000
mean     40.355104
std       0.068148
min      40.254724
25%      40.309905
50%      40.336792
75%      40.418736
max      40.451080
dtype: float64

In [127]:
BZ = 32
imp_data_collator_span_mlm =  ImprovedV3DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=.2,
                                              max_gram=5,

                                              pad_to_multiple_of=8)

imp_data_loader_span_mlm_large = DataLoader(
            train_dataset_large,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=imp_data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

NameError: name 'train_dataset_large' is not defined

In [1261]:
percentages = []
for i in range(10):
    print(f'\n\nexp {i+1}')
    percentages.append(run_exp_masking_percentage(data_loader=imp_data_loader_span_mlm_large))
    
percentages_df = pd.Series(percentages)
percentages_df.describe()



exp 1
total_mask_tokens: 217076
total_tokens: 983751
masking percentage: 22.066


exp 2
total_mask_tokens: 217684
total_tokens: 983751
masking percentage: 22.128


exp 3
total_mask_tokens: 217728
total_tokens: 983751
masking percentage: 22.132


exp 4
total_mask_tokens: 217454
total_tokens: 983751
masking percentage: 22.105


exp 5
total_mask_tokens: 216778
total_tokens: 983751
masking percentage: 22.036


exp 6
total_mask_tokens: 217589
total_tokens: 983751
masking percentage: 22.118


exp 7
total_mask_tokens: 217448
total_tokens: 983751
masking percentage: 22.104


exp 8
total_mask_tokens: 217073
total_tokens: 983751
masking percentage: 22.066


exp 9
total_mask_tokens: 216243
total_tokens: 983751
masking percentage: 21.981


exp 10
total_mask_tokens: 215945
total_tokens: 983751
masking percentage: 21.951


count    10.000000
mean     22.068776
std       0.062427
min      21.951185
25%      22.043358
50%      22.085060
75%      22.114870
max      22.132430
dtype: float64