filename: refactor_span_level_data_collator_v2.ipynb

In [1]:
%load_ext line_profiler
%load_ext memory_profiler

In [2]:
import sys
sys.path.append('../../thai2transformers')

In [3]:
import torch
import transformers
from transformers import AutoTokenizer , DataCollatorForLanguageModeling
import math
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
import numpy as np
import pandas as pd
import torch
import random
from bisect import bisect
from transformers.data.data_collator import DataCollatorForLanguageModeling, _collate_batch, tolist
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

import glob, os
from torch.utils.data.dataloader import DataLoader

from torch.utils.data.sampler import RandomSampler, SequentialSampler


from thai2transformers.datasets import MLMDataset

In [4]:
transformers.__version__

'4.6.1'

In [5]:
tokenizer = AutoTokenizer.from_pretrained('airesearchth/wangchanberta-base-wiki-20210520-spm')

In [63]:
tokenizer

PreTrainedTokenizerFast(name_or_path='airesearchth/wangchanberta-base-wiki-20210520-spm', vocab_size=24005, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True), 'additional_special_tokens': ['<s>NOTUSED', '</s>NOTUSED', '▁']})

In [6]:
TRAIN_DATA_PATH = '../../dataset/split/thwiki-for-ddp_6.11.2020/train'

In [58]:
TRAIN_LARGE_DATA_PATH = '../../dataset/split/thwiki-for-ddp_concat_12.11.2020/train'

In [61]:
!head -2 $TRAIN_LARGE_DATA_PATH/train.txt


ม้าลายเบอร์เชลล์<_>เป็นม้าลายชนิดย่อยหรือสปีชีส์ย่อยของม้าลายธรรมดา<_>("E.<_>quagga")<_>ชนิดหนึ่ง
ม้าลายเบอร์เชลล์<_>เป็นม้าลายที่แพร่กระจายพันธุ์ในทวีปแอฟริกาตอนใต้<_>เช่น<_>บอตสวานา,<_>สวาซิแลนด์,<_>แอฟริกาใต้<_>เป็นต้น<_>เป็นม้าลายที่มีลายขนาดใหญ่สีดำพาดยาวสลับกับลายสีขาวจากหลังลงไปทั้งสองข้างของลำตัวจนถึงใต้ท้อง<_>มีพฤติกรรมและลักษณะนิสัยคล้ายกับม้าลายธรรมดาทั่วไป


In [7]:
%%time 

train_dataset = MLMDataset(tokenizer,
                           TRAIN_DATA_PATH,
                           510)


[INFO] Build features (parallel).

[INFO] Start groupping results.
[INFO] Done.
CPU times: user 316 ms, sys: 89.7 ms, total: 406 ms
Wall time: 2.31 s


In [709]:
%%time 

train_dataset_large = MLMDataset(tokenizer,
                           TRAIN_LARGE_DATA_PATH,
                           100)


[INFO] Build features (parallel).

[INFO] Start groupping results.
[INFO] Done.
CPU times: user 8.76 s, sys: 4.16 s, total: 12.9 s
Wall time: 6min 21s


In [8]:

SPECIAL_TOKEN_NAMES = ['bos_token', 'eos_token', 'sep_token', 'cls_token', 'pad_token']

@dataclass
class DataCollatorForSpanLevelMask(DataCollatorForLanguageModeling):
    """
    Data collator used for span-level masked language modeling
     
    adapted from NGramMaskGenerator class
    
    https://github.com/microsoft/DeBERTa/blob/11fa20141d9700ba2272b38f2d5fce33d981438b/DeBERTa/apps/tasks/mlm_task.py#L36
    and
    https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py

    """
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    max_gram: int = 3
    keep_prob: float = 0.0
    mask_prob: float = 1.0
    max_preds_per_seq: int = None
    max_seq_len: int = 510

    def __init__(self, tokenizer, mlm=True, mlm_probability=0.15, *args, **kwargs):
        super().__init__(tokenizer, mlm=mlm, mlm_probability=mlm_probability)

        assert self.mask_prob + self.keep_prob <= 1, \
            f'The prob of using [MASK]({self.mask_prob}) and the prob of using original token({self.keep_prob}) should between [0,1]'

        if self.max_preds_per_seq is None:
            self.max_preds_per_seq = math.ceil(self.max_seq_len * self.mlm_probability / 10) * 10
            self.mask_window = int(1 / self.mlm_probability) # make ngrams per window sized context
        self.vocab_words = list(self.tokenizer.get_vocab().keys())
        self.vocab_mapping = self.tokenizer.get_vocab()
        
        self.special_tokens = [self.tokenizer.special_tokens_map[name] for name in  SPECIAL_TOKEN_NAMES]
#         print(' self.special_tokens', self.special_tokens)
        self.ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
        _pvals = 1. / np.arange(1, self.max_gram + 1)
        self.pvals = _pvals / _pvals.sum(keepdims=True)
        print('max_gram', self.max_gram)
    def _choice(self, rng, data, p):
        cul = np.cumsum(p)
        x = rng.random()*cul[-1]
        id = bisect(cul, x)
        return data[id]

    def _per_token_mask(self, idx, tokens, rng, mask_prob, keep_prob):
        label = tokens[idx]
        mask = self.tokenizer.mask_token
        rand = rng.random()
        if rand < mask_prob:
            new_label = mask
        elif rand < mask_prob + keep_prob:
            new_label = label
        else:
            new_label = rng.choice(self.vocab_words)

        tokens[idx] = new_label

        return label

    def _mask_tokens(self, tokens: List[str], rng=random, **kwargs):

        indices = [i for i in range(len(tokens)) if tokens[i] not in self.special_tokens]
#         print('debug: indices to be able to be masked', indices)
        
        unigrams = [ [idx] for idx in indices ]
        num_to_predict = min(self.max_preds_per_seq, max(1, int(round(len(tokens) * self.mlm_probability))))
           
        offset = 0
        mask_grams = np.array([False]*len(unigrams))
        while offset < len(unigrams):
            n = self._choice(rng, self.ngrams, p=self.pvals)
            ctx_size = min(n * self.mask_window, len(unigrams)-offset)
            m = rng.randint(0, ctx_size-1)
            s = offset + m
            e = min(offset + m + n, len(unigrams))
            offset = max(offset+ctx_size, e)
            mask_grams[s:e] = True

        target_labels = [None]*len(tokens)
        w_cnt = 0
        for m,word in zip(mask_grams, unigrams):
            if m:
                for idx in word:
                    label = self._per_token_mask(idx, tokens, rng, self.mask_prob, self.keep_prob)
                    target_labels[idx] = label
                    w_cnt += 1
                if w_cnt >= num_to_predict:
                    break

        target_labels = [self.vocab_mapping[x] if x else -100 for x in target_labels]
        return tokens, target_labels    


    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = []
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probabilityability`)
        # probability_matrix = torch.full(labels.shape, self.mlm_probabilityability)
        # if special_tokens_mask is None:
        #     special_tokens_mask = [
        #         self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        #     ]
        #     special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        # else:
        #     special_tokens_mask = special_tokens_mask.bool()

#         print('inputs', inputs.shape, inputs)
        inputs_masked = []
        
        for i, input in enumerate(inputs):
#             print('input',input)
            input_tokens = self.tokenizer.convert_ids_to_tokens(input)
            

            input_masked, _labels = self._mask_tokens(input_tokens)
#             print('DEBUG: input_masked', input_masked)
            input_masked_ids = self.tokenizer.convert_tokens_to_ids(input_masked)
            inputs_masked.append(input_masked_ids)
#             print('_labels, ', _labels)
#             print('inputs_masked, ', input_masked_ids)
            labels.append(_labels)
      
        return inputs_masked, labels


In [9]:
train_sampler = SequentialSampler(train_dataset)

In [34]:
BZ=32

data_collator_span_mlm =  DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              keep_prob=0.0,
                                              mask_prob=1.0,
                                              max_seq_len=510,
                                              pad_to_multiple_of=8)

data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

data_collator_subword_mlm =  DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              pad_to_multiple_of=8)

data_loader_subword_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=data_collator_subword_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

max_gram 3


In [11]:
%prun next(iter(data_loader_span_mlm))

 

         14407 function calls (14404 primitive calls) in 0.030 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.004    0.004    0.008    0.008 data_collator.py:195(_collate_batch)
        1    0.004    0.004    0.004    0.004 {method 'random_' of 'torch._C._TensorBase' objects}
     1320    0.003    0.000    0.003    0.000 tensor.py:468(<lambda>)
        8    0.003    0.000    0.008    0.001 tokenization_utils_fast.py:275(convert_ids_to_tokens)
        1    0.002    0.002    0.002    0.002 {method 'new_full' of 'torch._C._TensorBase' objects}
     1312    0.002    0.000    0.002    0.000 {method 'id_to_token' of 'tokenizers.Tokenizer' objects}
        8    0.002    0.000    0.006    0.001 <ipython-input-8-1161ad0796fc>:63(_mask_tokens)
        1    0.002    0.002    0.002    0.002 {method 'item' of 'torch._C._TensorBase' objects}
        1    0.001    0.001    0.001    0.001 {built-in method empty}
       20  

### Improved V2 Data Collator

In [12]:
.15*30

4.5

### Test lab zone 👩🏻‍🔬 

<br>

Idea 3: k-time masking projection


``` 
matrix = [[ . . . . . . ]
          [ . . . . . . ]
          [ . . . . . . ]
          [ . . . . . . ]
          [ . . . . . . ]]
```

Project 1-subword mask

Suppose that MLM probability is 15% and the total tokens is 30, then project 1-subword mask 15% (Bernoulli distribution) or 4.5 tokens will be masked.

``` 
matrix = [[ . 1 . . . . ]
          [ . . 1 . . . ]
          [ . . . . . 1 ]
          [ . 1 . . . . ]
          [ . . . 1 . . ]]
```


Project 2-subword mask

Suppose that the 2-subword mask have 27.27 % from the masked tokens (from the following distribtuion [0.5455, 0.2727, 0.1818]), 1.3635 tokens will be assigned as 2-subword mask.


``` 
matrix = [[ . 1 . . . . ]
          [ . . 2 . . . ]
          [ . . . . . 1 ]
          [ . 1 . . . . ]
          [ . . . 2 . . ]]
```

Project 3-subword mask

Suppose that the 3-subword mask have 18.18 % from the masked tokens (from the following distribtuion [0.5455, 0.2727, 0.1818]), 0.9089 tokens will be assigned as 3-subword mask. This will assign masking only on 1-subword masked tokens. 


``` 
matrix = [[ . 1 . . . . ]
          [ . . 2 . . . ]
          [ . . . . . 1 ]
          [ . 3 . . . . ]
          [ . . . 2 . . ]]
```

In [13]:
pvals =  [0.5455, 0.2727, 0.1818]

mlm_probability = .2
probability_matrix = torch.full((5,8), mlm_probability)
print('\nprobability_matrix:\n', probability_matrix)
_masked_indices = torch.bernoulli(probability_matrix).bool()
print('\masked_indices:\n', masked_indices)

base_indices = (_masked_indices == True).nonzero(as_tuple=False)


probability_matrix:
 tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])


NameError: name 'masked_indices' is not defined

In [None]:

print('base_indices', base_indices)
# 2-subword
second_probabilty_matrix = torch.full((1, base_indices.shape[0]), pvals[1])
print('second_probabilty_matrix', second_probabilty_matrix)

second_masked_indices = torch.bernoulli(second_probabilty_matrix).bool()
print('second_masked_indices', second_masked_indices)
second_indices = (second_masked_indices == True).nonzero(as_tuple=False)

print('second_indices', second_indices)


In [None]:
print('base_indices', base_indices)
print('second_indices', second_indices)

def filter_indices(base_indices, to_be_filtered_indices):      

    to_filter_indices = to_be_filtered_indices[:,1].tolist()
    keep_indices = list(set(range(base_indices.shape[0])).difference(set(to_filter_indices)))
#     print('\n  debug:filter_indices:to_filter_indices', to_filter_indices)
#     print('\n  debug:filter_indices:keep_indices', keep_indices)
#     base_indices_filtered = [ item for i, item in enumerate(base_indices) if i not in to_filter_indices ]
    base_indices_filtered = torch.index_select(base_indices, dim=0, index=torch.LongTensor(keep_indices))

#     print('\n  debug:filter_indices:base_indices_selected', base_indices_selected)

    if len(to_filter_indices) == 0:
        return base_indices_filtered, None

    base_indices_selected = torch.index_select(base_indices, dim=0, index=torch.LongTensor(to_filter_indices))
    return base_indices_filtered, base_indices_selected

base_indices_filtered, selected_indices = filter_indices(base_indices, second_indices)
print('\n\nbase_indices_filtered', base_indices_filtered)
print('selected_indices', selected_indices)


In [None]:
print('base_indices_filtered', base_indices_filtered)

third_probabilty_matrix = torch.full((1, len(base_indices_filtered), pvals[2])
print('third_probabilty_matrix', third_probabilty_matrix)

third_masked_indices = torch.bernoulli(third_probabilty_matrix).bool()
print('third_masked_indices', third_masked_indices)
third_indices = (third_masked_indices == True).nonzero(as_tuple=False)

print('third_indices', third_indices)


__Wrapping up in 1 for loop__

In [None]:
def filter_indices(base_indices, to_be_filtered_indices):      
    indices = to_be_filtered_indices[:,1].tolist()
#     print('to be filtered', indices)
    base_indices_filtered = [ item for i, item in enumerate(base_indices) if i not in indices ]
    base_indices_selected = [ item for i, item in enumerate(base_indices) if i in indices ]
    if len(base_indices_selected) == 0:
        return torch.stack(base_indices_filtered, dim=0), None
    return torch.stack(base_indices_filtered, dim=0), torch.stack(base_indices_selected, dim=0)

def mask_tokens(inputs: torch.Tensor,
                special_tokens_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    
    labels = inputs.clone()
    labels_to_be_mask = torch.full(inputs.shape, 0.).bool()
    pvals =  [0.5455, 0.2727, 0.1818]
    K = len(pvals)
    mask_indices_by_span_len = [[] for i in range(K)]
    
    mlm_probability = .2
    probability_matrix = torch.full(inputs.shape, mlm_probability)
    base_masked_indices = torch.bernoulli(probability_matrix).bool()
    print('\nprobability_matrix:\n', masked_indices)
    print('\nbase_masked_indices:\n', base_masked_indices)
    
    base_indices = (base_masked_indices == True).nonzero(as_tuple=False)
    print('\nbase_indices:\n', base_indices)
    
    _filter_base_indices = base_indices.clone()
    for k in range(1, K):
       
            # 2-subword
        print(f'\n\n{k+1}-subword masking')
        
        _probabilty_matrix = torch.full((1, _filter_base_indices.shape[0]), pvals[k])
        print(f'\n{k+1}-subword masking probabilty_matrix', _probabilty_matrix)

        _masked_indices = torch.bernoulli(second_probabilty_matrix).bool()
        print(f'\n{k+1}-subword masked_indices', _masked_indices)
        to_be_filtered_indices = (_masked_indices == True).nonzero(as_tuple=False)

        _filter_base_indices, _indices_selected = filter_indices(_filter_base_indices, to_be_filtered_indices)
        
        print(f'\n{k+1}-subword _indices_selected: ', _indices_selected)
        if _indices_selected == None:
            mask_indices_by_span_len[k] = []
        else:
            mask_indices_by_span_len[k] = _indices_selected

    mask_indices_by_span_len[0] = _filter_base_indices
    
    # Applying label
    for k in range(0, K):
        mask_indices_by_span_len_at_k = mask_indices_by_span_len[k]
        print(f'\n\n Perform masking, k={k}')
        for indices in mask_indices_by_span_len_at_k:
            print('indices', indices)
            labels_to_be_mask[indices[0], indices[1]:indices[1]+k+1] = True
        
        print('labels_to_be_mask', labels_to_be_mask)
        print('')
    
    inputs[labels_to_be_mask] = 24000 # mask token id

    labels[~labels_to_be_mask] = -100  # We only compute loss on masked toke
    return inputs, labels

In [None]:
inputs = (torch.rand((8,10))*100).long() 
inputs

In [None]:
mask_tokens(inputs)

### Garage Zone 🗺

In [1321]:
import math
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
import numpy as np
import torch
from transformers.data.data_collator import DataCollatorForLanguageModeling
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

SPECIAL_TOKEN_NAMES = ['bos_token', 'eos_token',
                       'sep_token', 'cls_token', 'pad_token']


@dataclass
class ImprovedV2DataCollatorForSpanLevelMask(DataCollatorForLanguageModeling):
    """
    Data collator used for span-level masked language modeling
     
    adapted from NGramMaskGenerator class
    
    https://github.com/microsoft/DeBERTa/blob/11fa20141d9700ba2272b38f2d5fce33d981438b/DeBERTa/apps/tasks/mlm_task.py#L36
    and
    https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py

    """
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    max_gram: int = 3
    pad_to_multiple_of: Optional[int] = None

    def __new__(cls, tokenizer, mlm, mlm_probability, pad_to_multiple_of, *args, **kwargs):
    
        obj = object.__new__(cls)
        DataCollatorForLanguageModeling.__init__(obj, tokenizer=tokenizer, mlm=mlm,
                                                 mlm_probability=mlm_probability,
                                                 pad_to_multiple_of=pad_to_multiple_of)
        return obj
    

    def __post_init__(self, *args, **kwargs):
        
        self.vocab_words = list(self.tokenizer.get_vocab().keys())
        self.vocab_mapping = self.tokenizer.get_vocab()
        
        self.special_token_ids = [ self.vocab_mapping[self.tokenizer.special_tokens_map[name]] for name in  SPECIAL_TOKEN_NAMES]
        self.ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
        _pvals = 1. / np.arange(1, self.max_gram + 1)
        _pvals_np =  _pvals / _pvals.sum(keepdims=True)
        self.pvals = torch.Tensor(_pvals / _pvals.sum(keepdims=True))
        
        pvals_adjusted = [_pvals_np[0]]
        for k in range(1, len(_pvals)):
            pvals_adjusted.append(_pvals_np[k] / sum(_pvals_np[k:]))

        self.pvals_adjusted = torch.Tensor(pvals_adjusted)
        print('pvals', self.pvals)
        print('pvals_adjusted', self.pvals_adjusted)
#         mlm_probability_at_k = []
#         remaining_mlm_probability = self.mlm_probability
#         print('_pvals', _pvals,     self.pvals)
#         for i in range(0, len(_pvals)):
#             _mlm_probability = remaining_mlm_probability * self.pvals[i].detach().cpu().numpy()
#             mlm_probability_at_k.append(_mlm_probability)
#             remaining_mlm_probability -=  _mlm_probability
            
# #         self.base_mlm_probability = self.mlm_probability + self.mlm_probability * (1/sum(range(1, self.max_gram + 1))/self.max_gram)
#         self.base_mlm_probability = sum(mlm_probability_at_k)
#         print('base_mlm_probability', self.base_mlm_probability)


    def filter_indices(self, base_indices, to_be_filtered_indices):      

        to_filter_indices = to_be_filtered_indices[:,1].tolist()
        
        keep_indices = list(set(range(base_indices.shape[0])).difference(set(to_filter_indices)))

        base_indices_filtered = torch.index_select(base_indices, dim=0, index=torch.LongTensor(keep_indices))

        if len(to_filter_indices) == 0:
            return base_indices_filtered, None
        
        base_indices_selected = torch.index_select(base_indices, dim=0, index=torch.LongTensor(to_filter_indices))

        return base_indices_filtered, base_indices_selected


    def mask_tokens(self, inputs: torch.Tensor,
                    special_tokens_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:

        labels = inputs.clone()
        labels_to_be_mask = torch.full(inputs.shape, 0., dtype=torch.bool)
           
        if special_tokens_mask is None:
            special_tokens_mask = sum(inputs==i for i in self.special_token_ids).bool()
        else:
            special_tokens_mask = special_tokens_mask.bool()

        K = len(self.pvals)
        mask_indices_by_span_len = [[] for i in range(K)]

#         probability_matrix = torch.full(inputs.shape, self.base_mlm_probability, dtype=torch.float)
        
        probability_matrix = torch.full(inputs.shape, self.mlm_probability, dtype=torch.float)

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        base_masked_indices = torch.bernoulli(probability_matrix).bool()

        base_indices = (base_masked_indices == True).nonzero(as_tuple=False)

#         print('self.pvals[k]', self.pvals)
        _filter_base_indices = base_indices.clone()
        for k in range(1, K):
         
            _probabilty_matrix = torch.full((1, _filter_base_indices.shape[0]), self.mlm_probability * self.pvals_adjusted[k], dtype=torch.float)

            _masked_indices = torch.bernoulli(_probabilty_matrix).bool()

            to_be_filtered_indices = (_masked_indices == True).nonzero(as_tuple=False)

            _filter_base_indices, _indices_selected = self.filter_indices(_filter_base_indices, to_be_filtered_indices)

            if _indices_selected == None:
                mask_indices_by_span_len[k] = torch.LongTensor([])
            else:
                mask_indices_by_span_len[k] = _indices_selected
        
        mask_indices_by_span_len[0] = _filter_base_indices
        # Applying span-level masking
        accum_indices = [[],[]]
        max_seq_len = inputs.shape[1] - 1

        for k in range(0, K):

            list_of_indices = mask_indices_by_span_len[k]
            
            if list_of_indices.shape == (0,):
                continue
            else:
                for j in range(k+1):
                    max_indices = torch.full((list_of_indices.shape[0],), max_seq_len, dtype=torch.long)
                    left, right = (list_of_indices[:, 0], \
                                   torch.min(list_of_indices[:, 1] + j, max_indices))

                    accum_indices[0].append(left)
                    accum_indices[1].append(right)
        accum_indices[0] = list(filter(lambda x: x.shape != (0,), accum_indices[0]))
        accum_indices[1] = list(filter(lambda x: x.shape != (0,), accum_indices[1]))
        
        if len(accum_indices[0]) != 0: 
            accum_indices_flatten  = (torch.cat(accum_indices[0]), torch.cat(accum_indices[1]))
            labels_to_be_mask.index_put_(accum_indices_flatten, torch.tensor([1.]).bool())
            labels_to_be_mask.masked_fill_(special_tokens_mask, value=0.0).bool()

        inputs[labels_to_be_mask] = self.tokenizer.mask_token_id
        labels[~labels_to_be_mask] = -100  # We only compute loss on masked token

        return inputs, labels

In [1322]:
pvals=[0.5455, 0.2727, 0.1818]
base_prob =.15

b1 = base_prob * pvals[0]
print('b1', b1)
b2 = (base_prob - (b1)) * (pvals[1])
print('b2', b2)
b3 = (base_prob - (b1) - b2) * (pvals[2])
print('b3', b3)
# b4 = (base_prob - (b1) - b2- b3) * (pvals[3])
# print('b4', b4)
print('\nsum:', sum([b1,b2,b3]))

# print(b1 )
# print(b2 )
# print(b3 )

b1 0.081825
b2 0.0185913225
b3 0.0090143125695

sum: 0.10943063506949999


In [1323]:
base_prob + (base_prob * 1 / (sum(range(1, 3+1))) / 3 )

0.15833333333333333

------

In [1324]:
imp_data_collator_span_mlm =  ImprovedV2DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              pad_to_multiple_of=1)


text1 = """แผ่นดินไหวตาม หรือทับศัพท์ว่า อาฟเตอร์ช็อก (อังกฤษ: aftershock) เป็นแผ่นดินไหวขนาดเล็กที่เกิดขึ้นหลังจากแผ่นดินไหวขนาดใหญ่ที่มีก่อนหน้า"""
text2 = """อาฟเตอร์ยู (After You Dessert Café) เป็นร้านขนมหวานรสชาติละมุนที่ครองใจลูกค้าเป็นอย่างดี โดยมีเมนูของหวานอร่อยให้เลือกหลากหลายเมนู """
tokens = tokenizer.tokenize(text1)
print('text 1:', '|'.join(tokens), len(tokens))
tokens = tokenizer.tokenize(text2)
print('\ntext 2:','|'.join(tokens), len(tokens))


inputs_1 = tokenizer.encode_plus(text1, return_tensors='pt')['input_ids'].squeeze(0)
inputs_2 = tokenizer.encode_plus(text2, return_tensors='pt')['input_ids'].squeeze(0)

res = imp_data_collator_span_mlm((inputs_1, inputs_2))
print('\n\n')
# print('inputs 1 (before)', inputs_1)
# print('inputs 2 (before)', inputs_2)
print('\n\n\nres', res)

print('\ninputs 1 (after):\n', '|'.join(tokenizer.convert_ids_to_tokens(res['input_ids'][0])))
print('\n\ninputs 2 (after):\n','|'.join(tokenizer.convert_ids_to_tokens(res['input_ids'][1])))

pvals tensor([0.5455, 0.2727, 0.1818])
pvals_adjusted tensor([0.5455, 0.6000, 1.0000])
pvals tensor([0.5455, 0.2727, 0.1818])
pvals_adjusted tensor([0.5455, 0.6000, 1.0000])
text 1: ▁|แผ่นดินไหว|ตาม|▁|หรือ|ทับศัพท์|ว่า|▁|อ|า|ฟ|เตอร์|ช็อก|▁|(|อังกฤษ|:|▁|af|ter|sh|ock|)|▁|เป็น|แผ่นดินไหว|ขนาดเล็ก|ที่|เกิดขึ้น|หลังจาก|แผ่นดินไหว|ขนาดใหญ่|ที่มี|ก่อนหน้า 34

text 2: ▁|อ|า|ฟ|เตอร์|ยู|▁|(|After|▁|You|▁|Des|s|ert|▁|Ca|f|é|)|▁|เป็น|ร้าน|ขนมหวาน|รสชาติ|ละ|มุน|ที่|ครอง|ใจ|ลูกค้า|เป็น|อย่างดี|▁|โดยมี|เมนู|ของ|หวาน|อร่อย|ให้เลือก|หลากหลาย|เมนู|▁ 43






res {'input_ids': tensor([[    5,     8, 24004, 24004,     8,    28,  9962,    38,     8,    52,
             9,   265,   712, 13225,     8, 24004,   236,    94,     8, 14051,
          2414,  2711, 11661,    18,     8,    14,  2827,   897,    12,   319,
           134,  2827,   604,    93,   862,     6,     1,     1,     1,     1,
             1,     1,     1,     1,     1],
        [    5,     8,    52, 24004, 24004, 24004,   406,     8, 24004, 2

In [None]:
# [0.5455, 0.2727, 0.1818]

# 15

0.15 * (0.5455 + (0.2727 /2) + (0.1813/3) )

In [905]:
BZ=32

# data_collator_span_mlm =  DataCollatorForSpanLevelMask(tokenizer=tokenizer,
#                                               mlm=True,
#                                               mlm_probability=0.15,
#                                               max_gram=3,
#                                               keep_prob=0.0,
#                                               mask_prob=1.0,

#                                               pad_to_multiple_of=8)

# data_loader_span_mlm = DataLoader(
#             train_dataset,
#             batch_size=BZ,
#             sampler=train_sampler,
#             collate_fn=data_collator_span_mlm,
#             drop_last=False,
#             num_workers=0,
#             pin_memory=True,
#         )

# data_collator_subword_mlm =  DataCollatorForLanguageModeling(tokenizer=tokenizer,
#                                               mlm=True,
#                                               mlm_probability=0.15,
#                                               pad_to_multiple_of=8)

# data_loader_subword_mlm = DataLoader(
#             train_dataset,
#             batch_size=BZ,
#             sampler=train_sampler,
#             collate_fn=data_collator_subword_mlm,
#             drop_last=False,
#             num_workers=0,
#             pin_memory=True,
#         )

imp_data_collator_span_mlm =  ImprovedV2DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              pad_to_multiple_of=8)

imp_data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=imp_data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

In [703]:
%%timeit
next(iter(data_loader_subword_mlm))

2.48 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [479]:
%%timeit
list(data_loader_subword_mlm)
print('.', end='')

.

KeyboardInterrupt: 

In [None]:
%%timeit
next(iter(data_loader_span_mlm))

In [None]:
%%timeit
list(data_loader_span_mlm)
print('.', end='')

In [None]:
%prun next(iter(data_loader_span_mlm))

In [810]:
%prun next(iter(imp_data_loader_span_mlm))

 

         334 function calls (331 primitive calls) in 0.004 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       12    0.001    0.000    0.001    0.000 tensor.py:25(wrapped)
        1    0.001    0.001    0.003    0.003 <ipython-input-807-b34b762e1e90>:66(mask_tokens)
        1    0.000    0.000    0.001    0.001 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 data_collator.py:195(_collate_batch)
        3    0.000    0.000    0.001    0.000 <ipython-input-807-b34b762e1e90>:50(filter_indices)
        6    0.000    0.000    0.000    0.000 {built-in method index_select}
        4    0.000    0.000    0.000    0.000 {built-in method bernoulli}
       11    0.000    0.000    0.000    0.000 {built-in method full}
        1    0.000    0.000    0.000    0.000 {method 'item' of 'torch._C._TensorBase' objects}
        4    0.000    0.000    0.000    0.000 {method 'nonzero' of 'torch._C._TensorBase' obj

In [705]:
%prun next(iter(data_loader_subword_mlm))

 

         1873 function calls (1870 primitive calls) in 0.005 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       32    0.001    0.000    0.001    0.000 tokenization_utils_base.py:3128(<listcomp>)
        1    0.000    0.000    0.004    0.004 data_collator.py:361(mask_tokens)
        1    0.000    0.000    0.001    0.001 data_collator.py:195(_collate_batch)
        1    0.000    0.000    0.000    0.000 {method 'clone' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {built-in method bernoulli}
       32    0.000    0.000    0.000    0.000 tokenization_utils_base.py:1225(all_special_tokens_extended)
        1    0.000    0.000    0.000    0.000 {method 'tolist' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'new_full' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'random_' of 'torch._C._TensorBase' objects}


In [706]:
%%timeit
next(iter(imp_data_loader_span_mlm))

1.22 ms ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
%%timeit
list(imp_data_loader_span_mlm)
print('.', end='')

In [None]:
res = imp_data_collator_span_mlm((inputs_1, inputs_2))
print(res)

print(sum(torch.sum(res['input_ids'].eq(1), dim=1).detach().cpu().numpy()))

In [273]:
count_mask([res])

([14], [81])

In [190]:
res['input_ids'].size()

torch.Size([2, 48])

### Quality Assurance 🥽

In [1325]:
def count_mask(results):
    mask_counts = []
    token_counts = []
    mask_hits = []
    for item in results:
#         print(item['labels'])
        mask_count = sum(torch.sum(~(item['labels'].eq(-100)), dim=1).detach().cpu().numpy())
    
    
        mask_hit = (item['labels'] != -100).nonzero(as_tuple=False).detach().cpu().numpy().tolist()    
                    
        special_tokens_count = sum(torch.sum(( item['input_ids'].eq(1) | item['input_ids'].eq(5)  |item['input_ids'].eq(6) ), dim=1).detach().cpu().numpy())

        _token_count = item['input_ids'].shape[0] * item['input_ids'].shape[1]
#         print('_token_count', _token_count)
        token_count = _token_count - special_tokens_count
#         print('token_count', token_count)
        token_counts.append(token_count)

        mask_hits.extend(mask_hit)
        mask_counts.append(mask_count)
    return mask_counts, token_counts, mask_hits


def run_exp_masking_percentage(data_loader):
    
    result = list(data_loader)
    
    mask_counts, token_counts, mask_hits = count_mask(result)
#     print('mask_counts', mask_counts)
#     print('token_counts', token_counts)
#     print('mask_hits', mask_hits)
    total_mask_tokens = sum(mask_counts)
    total_tokens = sum(token_counts)
    percentage = total_mask_tokens / total_tokens * 100
    print(f'total_mask_tokens: {total_mask_tokens}')
    print(f'total_tokens: {total_tokens}')
    print(f'masking percentage: {percentage:.3f}')
    return percentage , mask_counts, token_counts, mask_hits

In [1326]:
BZ=32
imp_data_collator_span_mlm =  ImprovedV2DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              pad_to_multiple_of=8)

imp_data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=imp_data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

pvals tensor([0.5455, 0.2727, 0.1818])
pvals_adjusted tensor([0.5455, 0.6000, 1.0000])
pvals tensor([0.5455, 0.2727, 0.1818])
pvals_adjusted tensor([0.5455, 0.6000, 1.0000])


In [1327]:
percentages = []
for i in range(10):
    print(f'\n\nexp {i+1}')
    percentages.append(run_exp_masking_percentage(data_loader=imp_data_loader_span_mlm))
    
percentages_df = pd.Series(percentages)
percentages_df.describe()



exp 1
total_mask_tokens: 314803
total_tokens: 1616963
masking percentage: 19.469


exp 2
total_mask_tokens: 315103
total_tokens: 1616963
masking percentage: 19.487


exp 3
total_mask_tokens: 315982
total_tokens: 1616963
masking percentage: 19.542


exp 4
total_mask_tokens: 315273
total_tokens: 1616963
masking percentage: 19.498


exp 5
total_mask_tokens: 315064
total_tokens: 1616963
masking percentage: 19.485


exp 6
total_mask_tokens: 315770
total_tokens: 1616963
masking percentage: 19.529


exp 7
total_mask_tokens: 316043
total_tokens: 1616963
masking percentage: 19.545


exp 8
total_mask_tokens: 314959
total_tokens: 1616963
masking percentage: 19.478


exp 9
total_mask_tokens: 314729
total_tokens: 1616963
masking percentage: 19.464


exp 10
total_mask_tokens: 314649
total_tokens: 1616963
masking percentage: 19.459


count    10.000000
mean     19.495653
std       0.031985
min      19.459258
25%      19.471194
50%      19.486129
75%      19.520901
max      19.545469
dtype: float64

In [1296]:
BZ = 32
imp_data_collator_span_mlm =  ImprovedV2DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=.2,
                                              max_gram=5,

                                              pad_to_multiple_of=8)

imp_data_loader_span_mlm_large = DataLoader(
            train_dataset_large,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=imp_data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

pvals tensor([0.5455, 0.2727, 0.1818])
pvals_adjusted tensor([0.5455, 0.6000, 1.0000])
pvals tensor([0.4380, 0.2190, 0.1460, 0.1095, 0.0876])
pvals_adjusted tensor([0.4380, 0.3896, 0.4255, 0.5556, 1.0000])


In [1297]:
percentages = []
for i in range(10):
    print(f'\n\nexp {i+1}')
    percentages.append(run_exp_masking_percentage(data_loader=imp_data_loader_span_mlm_large))
    
percentages_df = pd.Series(percentages)
percentages_df.describe()



exp 1
total_mask_tokens: 214482
total_tokens: 983751
masking percentage: 21.802


exp 2
total_mask_tokens: 214163
total_tokens: 983751
masking percentage: 21.770


exp 3
total_mask_tokens: 214427
total_tokens: 983751
masking percentage: 21.797


exp 4
total_mask_tokens: 214047
total_tokens: 983751
masking percentage: 21.758


exp 5
total_mask_tokens: 214765
total_tokens: 983751
masking percentage: 21.831


exp 6
total_mask_tokens: 216120
total_tokens: 983751
masking percentage: 21.969


exp 7
total_mask_tokens: 215169
total_tokens: 983751
masking percentage: 21.872


exp 8
total_mask_tokens: 213513
total_tokens: 983751
masking percentage: 21.704


exp 9
total_mask_tokens: 214384
total_tokens: 983751
masking percentage: 21.793


exp 10
total_mask_tokens: 214852
total_tokens: 983751
masking percentage: 21.840


count    10.000000
mean     21.813670
std       0.071824
min      21.703968
25%      21.775658
50%      21.799673
75%      21.837869
max      21.968974
dtype: float64