# Pre-Training a 🤗 Transformers model on TPU with **Flax/JAX**

In this notebook, we will see how to pretrain one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on TPU using [**Flax**](https://flax.readthedocs.io/en/latest/index.html).

The popular masked language modeling (MLM) objective, *cf.* [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805), will be used as the pre-training objective.
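To make the objective concrete, here is a minimal NumPy sketch of the masking rule described in the BERT paper. About 15% of the non-special tokens are selected as prediction targets; of these, 80% are replaced by the mask token, 10% by a random token, and 10% are left unchanged. The function name `mask_tokens`, the toy token ids, and the use of `-100` as the "ignore" label for unselected positions are illustrative assumptions, not necessarily the data collator used in the actual training code.

```python
import numpy as np


def mask_tokens(input_ids, special_tokens_mask, mask_token_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Return (masked_input_ids, labels) following the BERT 80/10/10 masking rule."""
    rng = np.random.default_rng() if rng is None else rng
    input_ids = np.array(input_ids, copy=True)
    labels = input_ids.copy()

    # Select ~mlm_probability of the non-special tokens as prediction targets.
    probs = np.full(input_ids.shape, mlm_probability)
    probs[np.asarray(special_tokens_mask, dtype=bool)] = 0.0
    selected = rng.random(input_ids.shape) < probs
    labels[~selected] = -100  # assumed "ignore" label for positions without a target

    # 80% of the selected tokens are replaced by the mask token.
    to_mask = selected & (rng.random(input_ids.shape) < 0.8)
    input_ids[to_mask] = mask_token_id

    # Half of the remaining 20% (i.e. 10% overall) get a random token id.
    to_randomize = selected & ~to_mask & (rng.random(input_ids.shape) < 0.5)
    random_ids = rng.integers(0, vocab_size, size=input_ids.shape)
    input_ids[to_randomize] = random_ids[to_randomize]

    # The last 10% of the selected tokens keep their original id.
    return input_ids, labels


# Hypothetical ids: 0 and 2 stand in for <s> and </s>, 4 for <mask>.
ids = np.array([[0, 11, 12, 13, 14, 15, 16, 17, 2]])
special = np.array([[1, 0, 0, 0, 0, 0, 0, 0, 1]])
masked, labels = mask_tokens(ids, special, mask_token_id=4, vocab_size=50,
                             rng=np.random.default_rng(0))
```

The model only receives a loss signal at the selected positions, and keeping some selected tokens unchanged (or randomized) prevents it from relying on the mask token being present at every target.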

As can be seen on [this benchmark](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation), using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.

[**Flax**](https://flax.readthedocs.io/en/latest/index.html) is a high-performance neural network library built on top of JAX (see below) and designed for flexibility. It aims to give users full control over their training code and is carefully designed to work well with JAX transformations such as `grad` and `pmap` (see the [Flax philosophy](https://flax.readthedocs.io/en/latest/philosophy.html)). For an introduction to Flax, see the [Flax Basics Colab](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) or the curated list of [Flax examples](https://flax.readthedocs.io/en/latest/examples.html).

[**JAX**](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html).
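As a small illustration of these composable transformations, the sketch below differentiates a toy squared-error loss with `jax.grad`, compiles the result with `jax.jit`, and vectorizes it over a batch with `jax.vmap`. The loss function and input values are hypothetical examples.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a one-parameter linear model.
    return (w * x - y) ** 2


# grad: differentiate w.r.t. the first argument, d/dw = 2 * x * (w * x - y).
dloss_dw = jax.grad(loss)

# jit: compile the gradient function with XLA.
fast_dloss_dw = jax.jit(dloss_dw)

# vmap: vectorize over a batch of (x, y) pairs, keeping w shared (in_axes=None).
batched_grads = jax.vmap(fast_dloss_dw, in_axes=(None, 0, 0))

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
grads = batched_grads(1.0, xs, ys)  # 2 * x * (x - y) gives -2, -8 and -18
```

Each transformation returns an ordinary function, which is why they can be stacked in any order.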

In [1]:
import jax

jax.local_devices()

[StreamExecutorGpuDevice(id=0, process_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0),
 StreamExecutorGpuDevice(id=3, process_index=0)]
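Each entry above is one local device. A minimal data-parallel sketch using only core JAX: `jax.pmap` runs one copy of a function per local device over the leading axis of its input, and `jax.lax.psum` sums a value across those copies, the same kind of cross-device reduction that data-parallel training uses to combine gradients. The batch shape and the function name `global_sum` are hypothetical; on a machine with a single device the sketch still runs with `n_devices = 1`.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
per_device_batch = 2
num_features = 3

# A global batch whose leading dimension is a multiple of the device count.
batch = jnp.arange(n_devices * per_device_batch * num_features, dtype=jnp.float32)
# pmap maps over the leading axis, so reshape to (n_devices, per_device_batch, features).
sharded = batch.reshape(n_devices, per_device_batch, num_features)


@partial(jax.pmap, axis_name="devices")
def global_sum(x):
    # Each device sums its own shard, then psum adds the partial sums across devices.
    return jax.lax.psum(x.sum(), axis_name="devices")


totals = global_sum(sharded)  # shape (n_devices,), every entry equals batch.sum()
```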

In this notebook, we will pre-train an [autoencoding model](https://huggingface.co/transformers/model_summary.html#autoencoding-models) on one of the languages of the OSCAR corpus. [OSCAR](https://oscar-corpus.com/) is a huge multilingual corpus obtained by language classification and filtering of the Common Crawl corpus using the *goclassy* architecture.

Let's first select the language that our model should learn.
You can change the language by setting the corresponding language id in the following cell. The language ids can be found under the "*deduplicated*" column on the official [OSCAR](https://oscar-corpus.com/) website.

Beware that many languages have huge datasets that might break this demonstration notebook 💥. For experiments with larger datasets and models, we recommend running the official `run_mlm_flax.py` script, which can be found [here](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#masked-language-modeling), outside of this notebook.

Here we select `ko` for Korean 🇰🇷.

In [2]:
language = "ko"

Next, we select the model architecture to be trained from scratch.
Here we choose [**`roberta-base`**](https://huggingface.co/roberta-base), but essentially any auto-encoding model that is available on the [**🤗 hub**](https://huggingface.co/models?filter=masked-lm,jax) in JAX/Flax can be used.

In [3]:
model_config = "roberta-base"

## 1. Defining the model configuration

To begin with, we create a directory to save all relevant files of our model including the model's configuration file, the tokenizer's JSON file, and the model weights. We call the directory `"roberta-base-pretrained-ko"`:

In [4]:
model_dir = model_config + f"-pretrained-{language}"

and create it:

In [5]:
from pathlib import Path

Path(model_dir).mkdir(parents=True, exist_ok=True)

Next, we'll download the model configuration:

In [6]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_config)

and save it to the directory:

In [7]:
config.save_pretrained(f"{model_dir}")

## 2. Training a tokenizer from scratch

The raw text data has to be pre-processed into a format the model can understand. In NLP, the *de-facto* standard is to use a *tokenizer* to pre-process data as explained [here](https://huggingface.co/transformers/preprocessing.html).

We can leverage the blazing-fast 🤗 Tokenizers library to train a [**ByteLevelBPETokenizer**](https://medium.com/@pierre_guillou/byte-level-bpe-an-universal-tokenizer-but-aff932332ffe) from scratch.
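Byte-level BPE starts from the 256 possible byte values, so any UTF-8 text (such as Korean, where each Hangul syllable takes three bytes) can be represented without unknown tokens, and then repeatedly merges the most frequent adjacent pair of symbols. The toy sketch below, built around the hypothetical function `learn_byte_level_bpe`, shows only that core merge loop; it omits details of the real 🤗 Tokenizers implementation such as pre-tokenization and the byte-to-character mapping, and it breaks frequency ties arbitrarily.

```python
from collections import Counter


def learn_byte_level_bpe(texts, num_merges):
    """Toy byte-level BPE: start from UTF-8 bytes and greedily merge frequent pairs."""
    # Every text becomes a list of single-byte symbols, so nothing is out of vocabulary.
    seqs = [[bytes([b]) for b in text.encode("utf-8")] for text in texts]
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs over the whole corpus.
        pair_counts = Counter()
        for seq in seqs:
            pair_counts.update(zip(seq, seq[1:]))
        if not pair_counts:
            break
        (left, right), _ = pair_counts.most_common(1)[0]
        merges.append((left, right))
        # Replace each non-overlapping occurrence of the pair with the merged symbol.
        merged_seqs = []
        for seq in seqs:
            out, i = [], 0
            while i < len(seq):
                if i + 1 < len(seq) and seq[i] == left and seq[i + 1] == right:
                    out.append(left + right)
                    i += 2
                else:
                    out.append(seq[i])
                    i += 1
            merged_seqs.append(out)
        seqs = merged_seqs
    return merges, seqs


sample_texts = ["한국어 한국", "한국어"]
print(len("한".encode("utf-8")))  # 3: each Hangul syllable is three UTF-8 bytes
merges, seqs = learn_byte_level_bpe(sample_texts, num_merges=3)
```

Nothing is lost along the way: joining the byte symbols of each merged sequence gives back the original UTF-8 text.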

Let's import the necessary building blocks from `tokenizers` and the `load_dataset` function.

In [8]:
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

We will store our tokenizer files and model files in the directory `model_dir`. We can conveniently load our chosen dataset using the [**`load_dataset`**](https://huggingface.co/docs/datasets/package_reference/loading_methods.html?highlight=load_dataset#datasets.load_dataset) function.

In [9]:
raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_{language}")


In [10]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'text'],
        num_rows: 3675420
    })
})

In [11]:
raw_dataset.column_names, len(raw_dataset['train'])

({'train': ['id', 'text']}, 3675420)

Having imported the `ByteLevelBPETokenizer`, we instantiate it,

In [12]:
tokenizer = ByteLevelBPETokenizer()

define a training iterator,

In [13]:
def batch_iterator(batch_size=1000):
    # Iterate over the rows of the "train" split; len(raw_dataset) would count the splits (1)
    for i in range(0, len(raw_dataset["train"]), batch_size):
        yield raw_dataset["train"][i: i + batch_size]["text"]

and train the tokenizer, setting `vocab_size` to the value from our model's configuration and specifying `min_frequency` as well as the `special_tokens`:
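A minimal, self-contained sketch of this training step uses the `train_from_iterator` method of `ByteLevelBPETokenizer` on a small hypothetical Korean corpus. The five special tokens are the ones that appear in the tokenizer summary printed further below, listed in RoBERTa's order so that `<s>`, `<pad>`, `</s>` and `<unk>` receive ids 0 to 3. `min_frequency=2` is an assumed value, and in the notebook `vocab_size` would be taken from `config.vocab_size`. The `toy_` prefix keeps the sketch from overwriting the notebook's own `tokenizer`.

```python
from tokenizers import ByteLevelBPETokenizer

# Hypothetical toy corpus standing in for the OSCAR texts.
toy_texts = ["한국어 문장입니다.", "한국어 모델을 학습합니다.", "토크나이저를 학습합니다."] * 10


def toy_batch_iterator(batch_size=4):
    # Yield lists of raw strings, one batch at a time.
    for i in range(0, len(toy_texts), batch_size):
        yield toy_texts[i : i + batch_size]


toy_tokenizer = ByteLevelBPETokenizer()
toy_tokenizer.train_from_iterator(
    toy_batch_iterator(),
    vocab_size=1_000,  # the notebook would use config.vocab_size (50265 for roberta-base)
    min_frequency=2,  # assumed value; not stated in the text
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    show_progress=False,
)
print(toy_tokenizer.get_vocab_size())
```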

Finally, we save the trained tokenizer in the model folder.

In [14]:
tokenizer.save(f"{model_dir}/tokenizer.json")

For more information on training tokenizers, see [this](https://huggingface.co/docs/tokenizers/python/v0.10.0/tutorials/python/training_from_memory.html) document.

## 3. Pre-processing the dataset

The trained tokenizer can now be used to pre-process the raw text data. Most auto-encoding models, such as [*BERT*](https://arxiv.org/abs/1810.04805) and [*RoBERTa*](https://arxiv.org/abs/1907.11692), are trained to handle sequences of up to `512` tokens. However, natural language understanding (NLU) tasks often require the model to process inputs of only up to 128 tokens, *cf.* [How to Train BERT with an Academic Budget](https://arxiv.org/abs/2104.07705).

Since the memory required by self-attention in Transformer models scales quadratically with the sequence length, we cap the maximum input length at 128 tokens here and pre-process the raw text data accordingly.
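To put the quadratic scaling into numbers, the sketch below computes how much memory the attention-probability matrices of a single sequence take in float32, assuming the 12 layers and 12 attention heads of `roberta-base`. It deliberately ignores all other activations, so it is a lower-bound illustration rather than a measurement.

```python
NUM_LAYERS = 12  # roberta-base
NUM_HEADS = 12  # roberta-base
BYTES_PER_FLOAT32 = 4


def attention_probs_mib(seq_len: int) -> float:
    """MiB needed to store the seq_len x seq_len attention probabilities of one
    sequence in float32, summed over all heads and layers."""
    num_bytes = NUM_LAYERS * NUM_HEADS * seq_len * seq_len * BYTES_PER_FLOAT32
    return num_bytes / 2**20


for seq_len in (512, 128):
    print(seq_len, attention_probs_mib(seq_len))
# 512 144.0
# 128 9.0
```

Shortening the sequence by a factor of 4 shrinks these matrices by a factor of 16.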

In [15]:
max_seq_length = 128

To cross-validate the model's performance during pre-training, we hold out 5% of the data as the validation set.

Since the loaded dataset is cached, the convenient `split="train[:X%]"` slicing syntax can be used to split the dataset without downloading or processing the data again.

The last 95% of the rows will be used as the training data:

In [16]:
raw_dataset["train"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[5%:]")



and the first 5% as the validation data.

In [17]:
raw_dataset["validation"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[:5%]")



In [18]:
raw_dataset.column_names, len(raw_dataset['train']), len(raw_dataset['validation'])

({'train': ['id', 'text'], 'validation': ['id', 'text']}, 3491649, 183771)
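These sizes follow directly from the percentage slices. With 3,675,420 rows, 5% is exactly 183,771 rows, so plain integer arithmetic reproduces the printed split sizes. For row counts that are not divisible by 20, 🤗 Datasets applies its own rounding to the slice boundaries, which this sketch does not model.

```python
total_rows = 3_675_420  # rows in the original "train" split, printed earlier

# "train[:5%]" covers the first 5% of rows, "train[5%:]" the remaining 95%.
num_validation = total_rows * 5 // 100
num_train = total_rows - num_validation
validation_rows = range(0, num_validation)
train_rows = range(num_validation, total_rows)

print(num_train, num_validation)  # 3491649 183771
```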

For demonstration purposes, we use only the first 10,000 samples of the training data and the first 1,000 samples of the validation data so that each cell finishes quickly.

If you want to run the colab on the **full** dataset, please comment out the lines in the following cell. On the full dataset, the notebook runs for *ca.* 12 hours until the loss converges and reaches a final accuracy of around *50%*. Running the colab *as is* takes less than 15 minutes, but the loss will not converge well.

In [19]:
# comment out these lines to run on the full dataset
raw_dataset["train"] = raw_dataset["train"].select(range(10000))
raw_dataset["validation"] = raw_dataset["validation"].select(range(1000))

Next, we load the previously trained byte-level BPE tokenizer with `AutoTokenizer` to pre-process the raw text data:

In [20]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [21]:
tokenizer

PreTrainedTokenizerFast(name_or_path='roberta-base-pretrained-ko', vocab_size=0, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})

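We can then write the function that pre-processes the raw text data. We feed the text samples, stored in the `"text"` column, to the tokenizer and request the special-tokens mask (`return_special_tokens_mask=True`), which marks tokens such as `<s>` and `</s>` so that they can be excluded from masking later: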

In [22]:
def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True)

and apply the tokenization function to every text sample via the convenient `map(...)` function of Datasets. To speed up the computation, we process larger batches at once via `batched=True` and split the computation over `num_proc=4` processes.

**Note**: Running this command on the whole dataset might take up to 10 minutes ☕.

In [23]:
tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)


Following [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692), our model is pre-trained with the masked language modeling (MLM) objective only, without BERT's next sentence prediction objective, so it does not matter whether an input sequence ends with a complete or an incomplete sentence.

The model can process the training data most efficiently if all data samples have the same length. We therefore concatenate the text samples and split the result into chunks of exactly `max_seq_length=128` tokens. This way, no computation is wasted on padding tokens, and short texts are packed together into fewer, fully used samples. Because `map(...)` passes batches of samples to the function, the concatenation happens within each batch, and the leftover tokens at the end of a batch that do not fill a complete chunk are dropped.

Let's define such a function to group the dataset into equally sized data samples:

In [24]:
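from itertools import chain

def group_texts(examples):
    # Concatenate all token lists in the batch (chain avoids the quadratic cost of sum(lists, []))
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the remainder so that every chunk has exactly max_seq_length tokens
    total_length = (total_length // max_seq_length) * max_seq_length
    result = {
        k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result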
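To see what `group_texts` does, here is a self-contained toy run with a hypothetical `group_texts_demo` function (the same logic, with the chunk size passed as an argument instead of read from `max_seq_length`) and made-up token ids, where 0 and 2 play the roles of `<s>` and `</s>`.

```python
from itertools import chain


def group_texts_demo(examples, chunk_size):
    """Same logic as group_texts, with the chunk size passed in explicitly."""
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    total_length = len(next(iter(concatenated.values())))
    total_length = (total_length // chunk_size) * chunk_size  # drop the incomplete tail
    return {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }


toy_batch = {
    "input_ids": [[0, 5, 6, 2], [0, 7, 2], [0, 8, 9, 10, 11, 2]],
    "special_tokens_mask": [[1, 0, 0, 1], [1, 0, 1], [1, 0, 0, 0, 0, 1]],
}
grouped = group_texts_demo(toy_batch, chunk_size=4)
print(grouped["input_ids"])  # [[0, 5, 6, 2], [0, 7, 2, 0], [8, 9, 10, 11]]
```

The second chunk mixes the end of one text with the start of the next, and the final `2` is dropped because it does not fill a complete chunk of 4. With the MLM-only objective described above, such chunks are fine for training.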

We pass `group_texts` to the `map(...)` function and set `batched=True` to make sure that the function is applied to a large batch of data samples.

**Note**: Running this function on the whole dataset might take up to 50 minutes 🕒.

In [25]:
tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)


Awesome, the data is now fully pre-processed and ready to be used for training 😎.