In [24]:
import datasets
from t5_tokenizer_model import SentencePieceUnigramTokenizer

vocab_size = 32000

# size of the input data
input_sentence_size = None

cache_dir="/workdir/.cache/huggingface/datasets"

# using a tiny bit of the dataset to train a tokenizer
ds = datasets.load_dataset("oscar",  
    name="unshuffled_deduplicated_no", 
    cache_dir=cache_dir, 
    split="train[:100]"
    )

len(ds)

Reusing dataset oscar (/workdir/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_no/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2)


100

In [47]:
import os
tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")

def batch_iterator(ds, input_sentence_size=None):
    if input_sentence_size is None:
        input_sentence_size = len(ds)
    batch_size = 100
    for i in range(0, input_sentence_size, batch_size):
        yield ds[i: i+batch_size]["text"]

tokenizer.train_from_iterator(
    iterator=batch_iterator(ds),
    vocab_size=vocab_size,
    show_progress=True
)

t5_config_dir = "/workdir/norwegian-t5-base/"
os.makedirs(t5_config_dir, exist_ok=True)

tokenizer.save(os.path.join(t5_config_dir, "tokenizer.json"))





In [48]:
tokenizer.get_vocab_size()

6152

In [33]:

from transformers import T5Config

config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(t5_config_dir)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Replicating the T5 experiments
I should answer the following questions when looking at the implementation:

- how span-masking is implemented
- exactly how the loss is computed
    - implementation details of the decoder
    - [do the simplest possible example](Dummy Examples)
- inputs and labels of the span-masked pre-training data
- modeling details / optimizer / learning rate scheduling
- how metrics logging is implemented


In [2]:
# set up
from transformers import AutoTokenizer
from transformers import T5Config
# use rust based tokenizer
cache_dir="/workdir/norwegian-t5-base"
tokenizer = AutoTokenizer.from_pretrained(
    cache_dir, 
    cache_dir=cache_dir,
    use_fast=True,
    use_auth_token=None
)

config = T5Config.from_pretrained(
    "/workdir/norwegian-t5-base/",
    cache_dir=cache_dir,
    vocab_size = len(tokenizer),
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Span-masked language modeling

Spans of the input sequence are replaced by so-called sentinel tokens (unique mask tokens), and the output sequence
is formed as a concatenation of the same sentinel tokens and the real tokens that were masked out. For example, if the input sequence is 

`The bad dog ruined my sleep`

We can mask out `bad dog` and ask the model to predict it. The sequence we will feed to the encoder is 

`The <extra_id_0> ruined my sleep`

The label we use to compute the loss is 

`<extra_id_0> bad dog <extra_id_1>`

T5-like span-masked language models fuse consecutively masked tokens into a single sentinel token.
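The example above can be reproduced with a few lines of plain Python. This is only a minimal sketch on whitespace-split words: `span_mask` is a hypothetical helper (not part of any library), and it closes the target with one extra sentinel to match the label shown above.

```python
def span_mask(tokens, is_noise):
    """Replace each run of masked tokens by one sentinel (hypothetical helper)."""
    inputs, targets = [], []
    sentinel = 0
    prev_noise = False
    for tok, noise in zip(tokens, is_noise):
        if noise:
            if not prev_noise:
                # first token of a new masked span: emit a fresh sentinel on both sides
                name = f"<extra_id_{sentinel}>"
                inputs.append(name)
                targets.append(name)
                sentinel += 1
            targets.append(tok)
        else:
            inputs.append(tok)
        prev_noise = noise
    # closing sentinel, following the label convention used in the example above
    targets.append(f"<extra_id_{sentinel}>")
    return " ".join(inputs), " ".join(targets)


tokens = "The bad dog ruined my sleep".split()
is_noise = [False, True, True, False, False, False]
model_input, target = span_mask(tokens, is_noise)
print(model_input)  # The <extra_id_0> ruined my sleep
print(target)       # <extra_id_0> bad dog <extra_id_1>
```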


In [3]:
tokenizer("<extra_id_0>", add_special_tokens=False)

{'input_ids': [6152], 'attention_mask': [1]}

In [4]:
print("pad token id: ", config.pad_token_id)
print("decoder start token id: ", config.decoder_start_token_id)


pad token id:  0
decoder start token id:  0
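Both ids matter when the decoder input is built. As far as I understand the training script, the decoder input is the label sequence shifted one position to the right, with `decoder_start_token_id` in front. The NumPy sketch below illustrates that idea; the label ids are made up for illustration.

```python
import numpy as np


def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    """Shift labels one position to the right to form decoder inputs (sketch)."""
    shifted = np.zeros_like(labels)
    shifted[:, 1:] = labels[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    # positions marked -100 (ignored by the loss) are turned into padding
    return np.where(shifted == -100, pad_token_id, shifted)


labels = np.array([[99, 11, 12, 98, 1]])  # made-up label ids
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
print(decoder_input_ids)  # [[ 0 99 11 12 98]]
```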


The following parameters are needed to create span-masked examples:

- tokenizer: PreTrainedTokenizerBase
A pretrained tokenizer that includes the `<extra_id_*>` sentinel tokens.

- noise_density: float = 0.15 (data_args.mlm_probability)
The probability of masking out a token.

- mean_noise_span_length: float (data_args.mean_noise_span_length)
Average length of a masked span.

- input_length: int
Maximum input sequence length, defined as
```
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
```

- target_length: int
Target sequence length. This quantity depends on `max_seq_length` and is computed
by `compute_input_and_target_lengths`.


- pad_token_id: int (model.config.pad_token_id)
Id of the pad token; 0 for T5.

- decoder_start_token_id: int (model.config.decoder_start_token_id)
Start token of the sequence fed into the decoder; 0 for T5.
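To see how `target_length` follows from `max_seq_length`, here is a sketch of the length computation as I understand it. `compute_lengths_sketch` is my own name, and it omits a small special case that the script has for `noise_density == 0.5`. The idea is to find the longest raw token sequence whose masked version (non-noise tokens, one sentinel per span, one EOS) still fits into the input length.

```python
def compute_lengths_sketch(inputs_length, noise_density, mean_noise_span_length):
    """Return (raw_tokens_length, targets_length) for a desired encoder input length."""

    def lengths_after_masking(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # encoder input: non-noise tokens + one sentinel per noise span + EOS
        masked_input_length = num_nonnoise_tokens + num_noise_spans + 1
        # target: noise tokens + one sentinel per noise span + EOS
        target_length = num_noise_tokens + num_noise_spans + 1
        return masked_input_length, target_length

    tokens_length = inputs_length
    # grow the raw sequence while its masked version still fits
    while lengths_after_masking(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    _, targets_length = lengths_after_masking(tokens_length)
    return tokens_length, targets_length


expanded_inputs_length, target_length = compute_lengths_sketch(100, 0.15, 3)
print(expanded_inputs_length, target_length)  # 110 22
```

With `max_seq_length = 100`, 110 raw tokens are needed: masking removes 16 tokens and adds 5 sentinels plus EOS, which gives exactly 100 input tokens and 22 target tokens.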


### Input type to `FlaxDataCollatorForT5MLM`
Span masking is implemented in `FlaxDataCollatorForT5MLM`. The input to its `__call__` method is
a `List[Dict[str, np.ndarray]]`, i.e. a batch of examples, each with the signature:
```
{
    "input_ids": [..., token ids,...],
    "masks": np.array,
}
```
The instance of `FlaxDataCollatorForT5MLM` is referred to as
`data_collator` in the code. I can check what kind of input is fed into `data_collator` at line 895.
The `tokenized_dataset` object from which the batches are generated is a standard Hugging Face interface.

### Hugging Face Dataset
The `tokenized_dataset` is defined as follows. The same pattern appears in many Hugging Face
language-model training scripts.
```python
# line 557
datasets = load_dataset(
    data_args.dataset_name,
    data_args.dataset_config_name,
    cache_dir=model_args.cache_dir,
    use_auth_token=True if model_args.use_auth_token else None,
)

# line 667
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)

# line 706
tokenized_datasets = tokenized_datasets.map(
    group_texts, # concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    load_from_cache_file=not data_args.overwrite_cache,
)
```
Now that we know what the input to the `__call__` method of `FlaxDataCollatorForT5MLM` looks like, let's dive into
its implementation details:

```python
# list of dicts to dict of batched tensors of the same key, 
# BatchEncoding: https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.BatchEncoding
# a wrapper of input data
batch = BatchEncoding(
    {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
)

input_ids = batch["input_ids"]
batch_size, expandend_input_length = input_ids.shape
```
Something interesting happens next: given the length of an input sequence, we decide the indices
of the tokens to be masked. Note that for span-masked language modeling, the masked tokens need to form
locally contiguous spans.

```python
mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
labels_mask = ~mask_indices
```

Let's look at how `random_spans_noise_mask` is implemented. Note that the implementation is a port of
[Google's original implementation](https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682).


### Some auxiliary functions I will need to know

#### np.stack
If we have a list of $m$ arrays $A_0, \cdots, A_{m-1}$, each of shape $(x_0, x_1, \cdots, x_n)$, then the result of the stack is an array 
with $n+2$ dimensions. If we stack them along axis 0, the result is an array $X$ such that
$$
X[i,...] = A_i
$$
If we stack along axis 1, then the result array $X$ would satisfy:
$$
X[:,i,...] = A_i
$$
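A small concrete case makes the two identities easier to check. The arrays below are arbitrary; the last line shows that stacking along axis 1 and then flattening interleaves the two arrays, which is how the mask code alternates non-noise and noise span lengths.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([10, 20, 30])

s0 = np.stack([a, b], axis=0)  # shape (2, 3): s0[i] is the i-th input array
s1 = np.stack([a, b], axis=1)  # shape (3, 2): s1[:, i] is the i-th input array

print(s0.shape, s1.shape)      # (2, 3) (3, 2)
print(np.array_equal(s0[1], b), np.array_equal(s1[:, 1], b))  # True True

# flattening the axis-1 stack interleaves the two arrays element by element
interleaved = s1.reshape(-1)
print(interleaved)             # [ 1 10  2 20  3 30]
```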


In [39]:
import numpy as np

class DataConfig:
    noise_density : float = 0.15 # masked language modeling probability
    mean_noise_span_length : float = 3.0 # mean length of a noise span

def random_span_noise_mask(data_config: DataConfig, length:int):
    """
    Noise mask consisting of random spans of noise tokens.
    The number of noise tokens and the number of noise spans and non-noise spans
    are determined deterministically as follows:
    num_noise_tokens = round(length * noise_density)
    num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
    Spans alternate between non-noise and noise, beginning with non-noise.
    Subject to the above restrictions, all masks are equally likely.

    Args:
        length: an int32 scalar (length of the incoming token sequence)
        noise_density: a float - approximate density of output mask
        mean_noise_span_length: a number

    Returns:
        a boolean tensor with shape [length] 
    """
    assert length > 1
    orig_length = length 

    num_noise_tokens = int(np.round(length * data_config.noise_density))

    # we want the number of noise and non-noise token to be positive
    num_noise_tokens = min(max(num_noise_tokens, 1), length-1)

    # np.round(0.5) = 0; np.round(0.51) = 1.0
    num_noise_spans = int(np.round(num_noise_tokens / data_config.mean_noise_span_length))

    # want positive number of spans
    num_noise_spans = max(num_noise_spans, 1)
    num_nonnoise_tokens = length - num_noise_tokens

    # pick length of noise and non-noise spans
    def _random_segmentation(num_items: int, num_segments: int):
        mask_indices = np.arange(num_items - 1) < (num_segments - 1)
        np.random.shuffle(mask_indices)
        first_in_segment = np.pad(mask_indices, [[1, 0]])
        segment_id = np.cumsum(first_in_segment)
        # count length of sub segments assuming that list is sorted
        _, segment_length = np.unique(segment_id, return_counts=True)
        return segment_length
    
    noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
    nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

num_items = 10; num_segments = 4
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
np.random.shuffle(mask_indices)

print("mask indices: ", mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]]) # before_1 = 1, after_1 = 0

print("first_in_segment: ", first_in_segment)

segment_id = np.cumsum(first_in_segment) # cumulative sum
print("segment id: ", segment_id)

# count the length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
print("segment_length: ", segment_length)


# pick length of noise and non-noise spans
def _random_segmentation(num_items: int, num_segments: int):
    mask_indices = np.arange(num_items - 1) < (num_segments - 1)
    np.random.shuffle(mask_indices)
    first_in_segment = np.pad(mask_indices, [[1, 0]])
    segment_id = np.cumsum(first_in_segment)
    # count length of sub segments assuming that list is sorted
    _, segment_length = np.unique(segment_id, return_counts=True)
    return segment_length

np.random.seed(10)

length=100; num_noise_tokens = 15; num_nonnoise_tokens = 85; num_noise_spans = 5
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
print("noise_span_lengths: ", noise_span_lengths)
print("nonnoise span length: ", nonnoise_span_lengths)

# interleave the lengths as: non-noise, noise, non-noise, noise, ...
interleaved_span_lengths = np.reshape(
    np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
)
print("interleved_span_lengths: ", interleaved_span_lengths)

span_starts = np.cumsum(interleaved_span_lengths)[:-1]
print("span_starts", span_starts)

span_start_indicator = np.zeros((length, ), dtype=np.int8)
span_start_indicator[span_starts] = True

span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
is_noise = is_noise[:length]
print("is_noise: ", is_noise)
# np.stack stacks arrays along a prescribed (new) axis
# np.stack([[0, 1], [0, 1]], axis=0) = [[0, 1], [0, 1]]


mask indices:  [False  True False  True False  True False False False]
first_in_segment:  [False False  True False  True False  True False False False]
segment id:  [0 0 1 1 2 2 3 3 3 3]
segment_length:  [2 2 2 4]
noise_span_lengths:  [1 5 5 1 3]
nonnoise span length:  [ 3 10  4 11 57]
interleved_span_lengths:  [ 3  1 10  5  4  5 11  1 57  3]
span_starts [ 3  4 14 19 23 28 39 40 97]
is_noise:  [False False False  True False False False False False False False False
 False False  True  True  True  True  True False False False False  True
  True  True  True  True False False False False False False False False
 False False False  True False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False  True  True 

I am confused here: isn't noise supposed to mean the masked tokens? Then why are so many tokens in the sequence marked as noise? I guess something that happens later corrects it?
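One way to sanity-check the NumPy version is to rebuild the mask from the span lengths printed above and count the `True` entries. This is only a sketch that hard-codes those printed lengths. If the logic is right, the count should equal `num_noise_tokens = 15`, i.e. 15% of the 100 positions.

```python
import numpy as np

length = 100
# span lengths copied from the printed output above
noise_span_lengths = np.array([1, 5, 5, 1, 3])
nonnoise_span_lengths = np.array([3, 10, 4, 11, 57])

# alternate non-noise and noise lengths, starting with non-noise
interleaved = np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1).reshape(-1)
span_starts = np.cumsum(interleaved)[:-1]

span_start_indicator = np.zeros(length, dtype=np.int8)
span_start_indicator[span_starts] = 1
span_num = np.cumsum(span_start_indicator)  # index of the span each position belongs to
is_noise = span_num % 2 == 1                # odd-numbered spans are noise

noise_positions = np.flatnonzero(is_noise)
print(is_noise.sum())   # 15
print(noise_positions)  # 3, 14..18, 23..27, 39, 97..99
```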

In [40]:
import tensorflow as tf

span_start_indicator = tf.math.unsorted_segment_sum(
    tf.ones_like(span_starts), span_starts, length
)
span_num = tf.cumsum(span_start_indicator)
is_noise = tf.equal(span_num % 2, 1)
print(is_noise[:length])

tf.Tensor(
[False  True  True  True False False False False False  True  True  True
  True  True  True  True  True  True  True False False False False False
  True  True  True  True False  True  True  True  True  True  True  True
  True  True  True  True False False False  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True], shape=(100,), dtype=bool)


## Inspect the output from the `FlaxDataCollatorForT5MLM`. 
One thing I can do is to simply look at the output from this data collator given 1 paragraph of English text.
I can use the `t5-base` tokenizer and write a paragraph myself and directly feed it into the collator.



expanded_inputs_length:  10
target_length: 4


In [35]:
from datasets import load_dataset
import datasets
from itertools import chain
#datasets.list_datasets()

from transformers import AutoTokenizer
import run_t5_mlm_flax
import importlib  # `imp` is deprecated; importlib.reload does the same job
importlib.reload(run_t5_mlm_flax)

from run_t5_mlm_flax import (
    FlaxDataCollatorForT5MLM,
    compute_input_and_target_lengths
)

tokenizer = AutoTokenizer.from_pretrained(
    "t5-small", max_seq_length=512, model_max_length=512)

max_seq_length = 100

expanded_inputs_length, target_length=compute_input_and_target_lengths(max_seq_length, 0.15, 3)

print("expanded_inputs_length: ", expanded_inputs_length)
print("target_length:", target_length)
data_collator = FlaxDataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3,
    input_length=max_seq_length,
    target_length=target_length,
    pad_token_id=0,
    decoder_start_token_id=0
)


ds = load_dataset(
    "wikitext",
    "wikitext-2-v1"
)

text_column_name=ds["train"].column_names[0]

def tokenize_function(examples):
    return tokenizer(examples[text_column_name], return_attention_mask=False)

tokenized_ds = ds.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=ds["train"].column_names[0],
    load_from_cache_file=True
)

expanded_inputs_length, targets_length = compute_input_and_target_lengths(
    inputs_length=max_seq_length,
    noise_density=0.15,
    mean_noise_span_length=3
)


# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= expanded_inputs_length:
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated_examples.items()
    }
    return result

tokenized_ds = tokenized_ds.map(
    group_texts,
    batched=True,
    num_proc=4,
    load_from_cache_file=True
)


expanded_inputs_length:  110
target_length: 22


100%|██████████| 3/3 [00:00<00:00, 1147.45it/s]
#0:   0%|          | 0/2 [00:00<?, ?ba/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (557 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (515 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (607 > 512). Running this sequence through the model will result in indexing errors


#0: 100%|██████████| 2/2 [00:00<00:00,  8.28ba/s]
#1: 100%|██████████| 2/2 [00:00<00:00,  8.67ba/s]
#3: 100%|██████████| 2/2 [00:00<00:00,  9.34ba/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (559 > 512). Running this sequence through the model will result in indexing errors

#2: 100%|█████████

In [21]:
for i in range(10, 20):
    print("index:%s\t %s" % (i,  tokenizer.decode(tokenized_ds["train"][i]["input_ids"])))

index:10	 .</s> It met with positive sales in Japan, and was praised by both Japanese and western critics. After release, it received 
index:11	 downloadable content, along with an expanded edition in November of that year. It was also adapted into manga and an original video animation series. Due
index:12	 to low sales of Valkyria Chronicles II, Valkyria Chronicles III was not localized, but a fan
index:13	 translation compatible with the game's expanded edition was released in 2014. Media.Vision would return to the franchise with the development of Valkyri
index:14	 a : Azure Revolution for the PlayStation 4.</s></s> = = Gameplay = =</s></s> As with previous<unk> Chronicles games, Valky
index:15	 ria Chronicles III is a tactical role @-@ playing game where players take control of a military unit and take part in missions against enemy
index:16	 forces. Stories are told through comic book @-@ like panels with animated character portraits, with characters speaking partially through voi

In [36]:
examples = [tokenized_ds["train"][i] for i in [17, 18, 19]]
masked_examples = data_collator(examples)

for i in range(3):
    m_input = tokenizer.decode(masked_examples["input_ids"][i])
    m_label = tokenizer.decode(masked_examples["labels"][i])
    print("masked input: ", m_input)
    print("masked label: ", m_label)
    print()


noise_span_lengths:  [2 4 1 5 4]
nonnoise span length:  [ 3  6 22 55  8]
interleved_span_lengths:  [ 3  2  6  4 22  1 55  5  8  4]
span_starts [  3   5  11  15  37  38  93  98 106]
span_start_indicator:  [0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0]
span_num:  [0 0 0 1 1 2 2 2 2 2 2 3 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 5 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6
 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 7 7 7 7 7 8 8 8 8 8 8 8 8 9 9 9 9]
is_noise,  [False False False  True  True False False False False False False  True
  True  True  True False False False False False False False False False
 False False False False False False False False False False False False
 False  True False False False False False False False False False False
 False False False False False Fal

The masked inputs and labels do seem to be what I expected. To make the logic of span-mask creation more transparent, I can execute the 
logic in `__call__` line by line and print the shape of each tensor.

In [37]:
from transformers import BatchEncoding
import numpy as np

batch = BatchEncoding(
    {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
)


input_ids = batch["input_ids"]
batch_size, expandend_input_length = input_ids.shape

mask_indices = np.asarray([data_collator.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
labels_mask = ~mask_indices
print("mask indices: ", mask_indices)

noise_span_lengths:  [3 2 2 3 6]
nonnoise span length:  [43  5 15 10 21]
interleved_span_lengths:  [43  3  5  2 15  2 10  3 21  6]
span_starts [ 43  46  51  53  68  70  80  83 104]
span_start_indicator:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 1 0 0 1 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0
 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0]
span_num:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 1 1 1 2 2 2 2 2 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5 5 6 6 6 6
 6 6 6 6 6 6 7 7 7 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 9 9 9 9 9 9]
is_noise,  [False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False  True  True  True False False
 False False False  True  True Fal
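Once `mask_indices` exists, the remaining steps turn it into sentinel ids and drop the masked tokens. The following is a NumPy sketch based on my reading of the collator, not a verified copy. `VOCAB_SIZE` and `EOS_ID` are made-up values, and sentinel ids are assumed to count down from the top of the vocabulary.

```python
import numpy as np

VOCAB_SIZE = 100  # hypothetical vocabulary size
EOS_ID = 1        # hypothetical EOS id


def create_sentinel_ids(mask, vocab_size):
    """Span starts get a sentinel id, the rest of each span gets -1, other positions 0."""
    mask = mask.astype(np.int8)
    start = mask - np.roll(mask, 1, axis=-1) * mask  # 1 only at the first token of a span
    start[:, 0] = mask[:, 0]
    sentinel = np.where(start != 0, np.cumsum(start, axis=-1), start)
    sentinel = np.where(sentinel != 0, vocab_size - sentinel, 0)
    sentinel -= mask - start  # mark non-first tokens of a span with -1
    return sentinel


def filter_ids(input_ids, sentinel, eos_id):
    """Insert sentinels, drop the -1 positions, append EOS."""
    batch_size = input_ids.shape[0]
    full = np.where(sentinel != 0, sentinel, input_ids)
    kept = full[full >= 0].reshape(batch_size, -1)
    eos = np.full((batch_size, 1), eos_id, dtype=kept.dtype)
    return np.concatenate([kept, eos], axis=-1)


input_ids = np.array([[10, 11, 12, 13, 14, 15]])
mask = np.array([[False, True, True, False, True, False]])

model_inputs = filter_ids(input_ids, create_sentinel_ids(mask, VOCAB_SIZE), EOS_ID)
labels = filter_ids(input_ids, create_sentinel_ids(~mask, VOCAB_SIZE), EOS_ID)
print(model_inputs)  # [[10 99 13 98 15  1]]
print(labels)        # [[99 11 12 98 14 97  1]]
```

The labels use the inverted mask, so every non-noise span collapses into a sentinel and only the masked tokens survive. Note that the reshape assumes every row of the batch keeps the same number of tokens.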

In [38]:
data_collator.random_spans_noise_mask(100)

noise_span_lengths:  [6 1 3 3 2]
nonnoise span length:  [ 8 22 16 29 10]
interleved_span_lengths:  [ 8  6 22  1 16  3 29  3 10  2]
span_starts [ 8 14 36 37 53 56 85 88 98]
span_start_indicator:  [0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 1 0]
span_num:  [0 0 0 0 0 0 0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5 5 5 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6
 6 6 6 6 6 6 6 6 6 6 6 7 7 7 8 8 8 8 8 8 8 8 8 8 9 9]
is_noise,  [False False False False False False False False  True  True  True  True
  True  True False False False False False False False False False False
 False False False False False False False False False False False False
  True False False False False False False False False False False False
 False False False False False  True  True  True False False False False
 False Fal

array([False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
        True, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False,  True,  True,  True, False, False,
       False, False, False, False, False, False, False, False,  True,
        True])