# AirIO Train with WMT Example

This notebook demonstrates training with T5X on the [WMT](https://www.tensorflow.org/datasets/catalog/wmt19_translate) dataset. It performs the following actions:

* Creates an AirIO task that handles:
  * Loading the dataset
  * Mapping raw data to a format suitable for training
  * Tokenizing the text using SeqIO's [`SentencePieceVocabulary`](https://github.com/google/seqio/blob/main/seqio/vocabularies.py)
* Defines a small T5 1.1 model
* Defines a function for training

Finally, the notebook calls the training function to run 3 training steps and visualizes the resulting metrics with TensorBoard.

## Imports and constants

In [None]:
import dataclasses
import functools
import tempfile

import airio
from airio import examples
from seqio import vocabularies
from t5x import adafactor
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import train as train_lib
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network

In [None]:
# Vocabulary: a SentencePiece model file plus a number of extra ids.
_DEFAULT_SPM_PATH = "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model"
_DEFAULT_EXTRA_IDS = 100
_DEFAULT_VOCAB = vocabularies.SentencePieceVocabulary(
    _DEFAULT_SPM_PATH,
    _DEFAULT_EXTRA_IDS,
)

# Feature length used for both "inputs" and "targets" in the dataset configs.
_SOURCE_SEQUENCE_LENGTH = 32

# Training schedule: kept tiny so the notebook finishes quickly.
_TOTAL_STEPS = 3
_EVAL_STEPS = 2

# Fresh temporary directory that receives the training output (model_dir).
_WORKDIR = tempfile.mkdtemp()

## Define a small T5 1.1 model

In [None]:
def get_t5_model(**config_overrides) -> models.EncoderDecoderModel:
  """Builds a small T5 1.1 encoder-decoder model.

  Args:
    **config_overrides: `network.T5Config` field values that replace the tiny
      defaults defined in this function.

  Returns:
    An `EncoderDecoderModel` wrapping a `network.Transformer`, using
    `_DEFAULT_VOCAB` for both inputs and outputs and an Adafactor optimizer.
  """
  base_fields = dict(
      vocab_size=32128,
      dtype="bfloat16",
      emb_dim=8,
      num_heads=4,
      num_encoder_layers=2,
      num_decoder_layers=2,
      head_dim=3,
      mlp_dim=16,
      mlp_activations=("gelu", "linear"),
      dropout_rate=0.0,
      logits_via_embedding=False,
  )
  # Build the base config first, then layer caller overrides on top of it.
  config = dataclasses.replace(
      network.T5Config(**base_fields), **config_overrides
  )

  module = network.Transformer(config)
  optimizer_def = adafactor.Adafactor(
      decay_rate=0.8,
      step_offset=0,
      logical_factor_rules=adafactor.standard_logical_factor_rules(),
  )
  return models.EncoderDecoderModel(
      module=module,
      input_vocabulary=_DEFAULT_VOCAB,
      output_vocabulary=_DEFAULT_VOCAB,
      optimizer_def=optimizer_def,
  )


## Define a function for training

In [None]:
def create_train_fn(task: airio.dataset_providers.Task):
  """Builds a partially-applied `train_lib.train` for the given task.

  Args:
    task: The task used as the data source for both the "train" and the
      "validation" splits.

  Returns:
    A `functools.partial` around `train_lib.train` with the model, dataset
    configs, checkpoint config, partitioner, trainer class and step counts
    already bound. Any remaining arguments (for example `model_dir`) are
    supplied by the caller.
  """

  def make_dataset_cfg(split: str) -> utils.DatasetConfig:
    # The train and eval configs differ only in the split they read; a fresh
    # config (and feature-length dict) is built on every call.
    return utils.DatasetConfig(
        mixture_or_task_name=task,
        task_feature_lengths={
            "inputs": _SOURCE_SEQUENCE_LENGTH,
            "targets": _SOURCE_SEQUENCE_LENGTH,
        },
        split=split,
        batch_size=8,
        shuffle=False,
        pack=False,
        use_cached=False,
        seed=0,
    )

  train_cfg = make_dataset_cfg("train")
  eval_cfg = make_dataset_cfg("validation")

  pjit_partitioner = partitioning.PjitPartitioner(num_partitions=4)

  learning_rate_fn = utils.create_learning_rate_scheduler(
      factors="constant * rsqrt_decay",
      base_learning_rate=1.0,
      warmup_steps=1000,
  )
  bound_trainer_cls = functools.partial(
      trainer.Trainer,
      learning_rate_fn=learning_rate_fn,
      num_microbatches=None,
  )

  # No restore config: nothing is restored from an existing checkpoint.
  save_cfg = utils.SaveCheckpointConfig(
      dtype="float32",
      period=4,
      checkpoint_steps=[0, 1, 2, 3, 4, 80, 97, 100],
  )
  checkpoint_cfg = utils.CheckpointConfig(save=save_cfg, restore=None)

  bound_kwargs = dict(
      model=get_t5_model(),
      train_dataset_cfg=train_cfg,
      train_eval_dataset_cfg=eval_cfg,
      infer_eval_dataset_cfg=None,
      checkpoint_cfg=checkpoint_cfg,
      partitioner=pjit_partitioner,
      trainer_cls=bound_trainer_cls,
      total_steps=_TOTAL_STEPS,
      eval_steps=_EVAL_STEPS,
      eval_period=1000,
      random_seed=0,
      summarize_config_fn=gin_utils.summarize_gin_config,
      use_orbax=False,
      gc_period=4,
  )
  return functools.partial(train_lib.train, **bound_kwargs)


## Create a task

Create an AirIO task that handles dataset loading, raw data mapping, and tokenization.

In [None]:
# Task that is handed to `create_train_fn` as the data source for training.
wmt_task = examples.tasks.get_wmt_19_ende_v003_task()

## Create a training function

In [None]:
# Only binds the task and configuration into a callable; training itself
# starts when `train_fn` is called.
train_fn = create_train_fn(wmt_task)

## Run training

In [None]:
# Run training, writing output under the temporary working directory.
# The second returned value is not needed here. It is bound to a named
# throwaway instead of `_`, because assigning to `_` in IPython overwrites the
# "last output" variable and stops IPython from updating it afterwards.
step, _unused = train_fn(model_dir=_WORKDIR)
print(f"step: {step}")

## Visualize training with TensorBoard

Load the extension.

In [None]:
# Registers the `%tensorboard` magic used in the following cell.
%load_ext tensorboard

Launch the UI. Note the metrics for `train` and `training_eval`.



In [None]:
# `$_WORKDIR` is expanded by IPython to the value of the Python variable, i.e.
# the same directory that was passed to `train_fn` as `model_dir`.
# NOTE(review): `--port=0` presumably asks TensorBoard for any unused port —
# confirm against the TensorBoard CLI help.
%tensorboard --logdir=$_WORKDIR --port=0