# Model training and evaluation

A complete example of model training and evaluation.

In [1]:
# Uncomment to emulate an arbitrary number of devices running on CPU.

# import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

In [2]:
import jax
from jax import numpy as jnp
import optax
from redex import combinator as cb
from flax_extra import random
from flax_extra.training import TrainLoop, TrainTask
from flax_extra.evaluation import EvalLoop, EvalTask
from flax_extra.checkpoint import (
    SummaryLogger,
    SummaryWriter,
    CheckpointFileReader,
    CheckpointFileWriter,
    LowestCheckpointFileReader,
    LowestCheckpointFileWriter,
)
from flax_extra.model import RNNLM

In [3]:
MAX_LENGTH = 256
BATCH_SIZE = 32
VOCAB_SIZE = 2 ** 8
D_MODEL = 128

model = RNNLM

collections = dict(
    init=["params","carry","dropout"],
    apply=["params","carry","dropout"],
)
config = dict(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_layers=2,
)

def presudo_data_steam(shape, bounds, rnkey):
    minval, maxval = bounds
    while True:
        x = jax.random.uniform(
            key=rnkey,
            shape=shape,
            minval=minval,
            maxval=maxval,
        ).astype(jnp.int32)
        yield x, x

def categorical_cross_entropy(outputs, targets):
    n_categories = outputs.shape[-1]
    loss = optax.softmax_cross_entropy(
        outputs,
        jax.nn.one_hot(targets, n_categories),
    )
    return jnp.mean(loss)

rnkeyg = random.sequence(seed=0)
train_datag = presudo_data_steam(shape=(BATCH_SIZE,MAX_LENGTH), bounds=(1,VOCAB_SIZE), rnkey=next(rnkeyg))
eval_datag = presudo_data_steam(shape=(BATCH_SIZE,MAX_LENGTH), bounds=(1,VOCAB_SIZE), rnkey=next(rnkeyg))



In [4]:
train_loop = TrainLoop(
    # init=model(**config).init,
    init=CheckpointFileReader(dir="/tmp/checkpoints", target=model(**config).init),
    # init=LowestCheckpointFileReader(dir="/tmp/checkpoints", target=model(**config).init, metric="lnpp"),
    task=TrainTask(
        apply=model(**config | dict(deterministic=False)).apply,
        optimizer=optax.sgd(learning_rate=0.1, momentum=0.9),
        loss=categorical_cross_entropy,
        data=train_datag,
    ),
    collections=collections,
    mutable_collections=True,
    n_steps_per_checkpoint=5,
    rnkey=next(rnkeyg),
    n_steps=10,
)

process_checkpoint = cb.serial(
    EvalLoop(
        task=EvalTask(
            apply=model(**config).apply,
            metrics=dict(lnpp=categorical_cross_entropy),
            data=eval_datag,
        ),
        collections=collections,
        rnkey=next(rnkeyg),
        n_steps=1,
    ),
    SummaryLogger(),
    SummaryWriter(output_dir="/tmp/tensorboard"),
    CheckpointFileWriter(output_dir="/tmp/checkpoints"),
    LowestCheckpointFileWriter(output_dir="/tmp/checkpoints", metric="lnpp"),
)

for checkpoint in train_loop:
    _ = process_checkpoint(checkpoint)

Total model initialization time is 8.74 seconds.
The lowest value for the metric eval/lnpp is set to inf.
Total number of trainable weights: 328960 ~ 1.2 MB.

Step      1: Ran 1 train steps in 6.15 seconds
Step      1: train seconds_per_step | 6.14884496
Step      1: train gradients_l2norm | 0.00296191
Step      1: train   weights_l2norm | 14.01843071
Step      1: train             loss | 5.54912329
Step      1: eval              lnpp | 5.54803228

Step      6: Ran 5 train steps in 7.33 seconds
Step      6: train seconds_per_step | 1.46684761
Step      6: train gradients_l2norm | 0.00295720
Step      6: train   weights_l2norm | 14.01983833
Step      6: train             loss | 5.54826736
Step      6: eval              lnpp | 5.54797220

Step     10: Ran 4 train steps in 1.34 seconds
Step     10: train seconds_per_step | 0.33537143
Step     10: train gradients_l2norm | 0.00292075
Step     10: train   weights_l2norm | 14.02179623
Step     10: train             loss | 5.54690218
Step     