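# Model training and evaluation

Train an RNN language model (`flax_extra.model.RNNLM`) on randomly generated token sequences, evaluate every checkpoint, and track the checkpoint with the lowest evaluation cross-entropy (`lnpp`).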

In [1]:
import os

# Ask XLA to expose two host (CPU) devices. This must run before `jax` is
# imported in the next cell, because XLA reads XLA_FLAGS when its backends
# are first initialised. Any previously set XLA_FLAGS value is replaced.
os.environ.update(XLA_FLAGS="--xla_force_host_platform_device_count=2")

In [2]:
import jax
from jax import numpy as jnp
import optax
from redex import combinator as cb
from flax_extra import random
from flax_extra.training import TrainLoop, TrainTask
from flax_extra.evaluation import EvalLoop, EvalTask
from flax_extra.checkpoint import (
    SummaryLogger,
    SummaryWriter,
    CheckpointFileReader,
    CheckpointFileWriter,
    LowestCheckpointFileReader,
    LowestCheckpointFileWriter,
)
from flax_extra.model import RNNLM

In [3]:
# Sizes used for the synthetic data and the model.
MAX_LENGTH = 256   # sequence length (second axis of each data batch)
BATCH_SIZE = 32    # sequences per batch (first axis of each data batch)
VOCAB_SIZE = 256   # 2 ** 8 token ids
D_MODEL = 128      # passed to RNNLM as `d_model`

model = RNNLM

# Collection names handed to the train/eval loops for `init` and `apply`.
# NOTE(review): presumably Flax variable/RNG collections ("dropout" looks like
# an RNG stream) — confirm against flax_extra's TrainLoop/EvalLoop.
# Two separate lists on purpose, so mutating one never affects the other.
collections = {
    "init": ["params", "carry", "dropout"],
    "apply": ["params", "carry", "dropout"],
}

# Keyword arguments shared by every RNNLM instantiation.
config = {
    "vocab_size": VOCAB_SIZE,
    "d_model": D_MODEL,
    "n_layers": 2,
}

def presudo_data_steam(shape, bounds, rnkey):
    """Yield an endless stream of random integer token batches.

    Each batch is an ``int32`` array of ``shape`` whose values are drawn
    uniformly from the half-open range ``[bounds[0], bounds[1])``. The same
    array is yielded as both inputs and targets.

    Args:
        shape: Shape of every batch, e.g. ``(batch_size, length)``.
        bounds: ``(minval, maxval)`` pair; ``maxval`` is exclusive.
        rnkey: JAX PRNG key that seeds the whole stream.

    Yields:
        ``(x, x)`` tuples, where ``x`` is a fresh ``int32`` array per batch.
    """
    minval, maxval = bounds
    while True:
        # Derive a new subkey for every batch. Reusing ``rnkey`` directly
        # would make every yielded batch identical.
        rnkey, batch_key = jax.random.split(rnkey)
        x = jax.random.randint(
            batch_key,
            shape=shape,
            minval=minval,
            maxval=maxval,
            dtype=jnp.int32,
        )
        yield x, x

def categorical_cross_entropy(outputs, targets):
    """Mean softmax cross-entropy between logits and integer class targets.

    ``targets`` are one-hot encoded over the last axis of ``outputs`` (whose
    size is the number of categories). The per-position cross-entropy
    ``-sum(one_hot * log_softmax(outputs))`` is averaged over all remaining
    positions.
    """
    log_probs = jax.nn.log_softmax(outputs, axis=-1)
    one_hot_targets = jax.nn.one_hot(targets, log_probs.shape[-1])
    per_position = -jnp.sum(one_hot_targets * log_probs, axis=-1)
    return per_position.mean()

# A single seeded key sequence feeds every random consumer in this notebook.
# Keys are drawn in a fixed order: train data, eval data, then the training
# loop and the eval loop in the next cell. Keep that order for reproducibility.
rnkeyg = random.sequence(seed=0)

DATA_SHAPE = (BATCH_SIZE, MAX_LENGTH)
# Token id 0 is never sampled (lower bound is 1).
# NOTE(review): presumably reserved, e.g. for padding — confirm.
TOKEN_BOUNDS = (1, VOCAB_SIZE)

train_datag = presudo_data_steam(shape=DATA_SHAPE, bounds=TOKEN_BOUNDS, rnkey=next(rnkeyg))
eval_datag = presudo_data_steam(shape=DATA_SHAPE, bounds=TOKEN_BOUNDS, rnkey=next(rnkeyg))



In [4]:
# Filesystem locations shared by the checkpoint reader and writers. Defining
# them once guarantees the reader looks exactly where the writers write.
CHECKPOINT_DIR = "/tmp/checkpoints"
TENSORBOARD_DIR = "/tmp/tensorboard"
# Evaluation metric name. It keys the eval metrics dict and selects the
# "lowest" checkpoint for both reading and writing.
EVAL_METRIC = "lnpp"

# Training starts from the checkpoint in CHECKPOINT_DIR with the lowest
# EVAL_METRIC. Other initialisation options:
#   * `model(**config).init` gives fresh parameters;
#   * `CheckpointFileReader(dir=CHECKPOINT_DIR, target=model(**config).init)`
#     reads a checkpoint without metric-based selection.
# NOTE(review): CHECKPOINT_DIR persists across kernel restarts. In the
# recorded output below, a checkpoint was restored and training stopped
# immediately because the 6-step budget had already been reached. Clear the
# directory for a reproducible fresh run.
train_loop = TrainLoop(
    init=LowestCheckpointFileReader(
        dir=CHECKPOINT_DIR,
        target=model(**config).init,
        metric=EVAL_METRIC,
    ),
    task=TrainTask(
        # deterministic=False for training; the eval model below uses
        # RNNLM's default for `deterministic`.
        apply=model(**config | dict(deterministic=False)).apply,
        optimizer=optax.sgd(learning_rate=0.1, momentum=0.9),
        loss=categorical_cross_entropy,
        data=train_datag,
    ),
    collections=collections,
    mutable_collections=True,
    # Every checkpoint the loop yields is processed by `process_checkpoint`.
    n_steps_per_checkpoint=3,
    # Drawn before the eval loop's key below. Keep this order.
    rnkey=next(rnkeyg),
    n_steps=6,
)

# Per-checkpoint pipeline, run in order:
#   1. evaluate;
#   2. log summaries;
#   3. write TensorBoard summaries;
#   4. save the checkpoint;
#   5. save it as the lowest-EVAL_METRIC checkpoint (presumably only when it
#      improves on the best so far — confirm).
process_checkpoint = cb.serial(
    EvalLoop(
        task=EvalTask(
            apply=model(**config).apply,
            metrics={EVAL_METRIC: categorical_cross_entropy},
            data=eval_datag,
        ),
        collections=collections,
        rnkey=next(rnkeyg),
        n_steps=2,
    ),
    SummaryLogger(),
    SummaryWriter(output_dir=TENSORBOARD_DIR),
    CheckpointFileWriter(output_dir=CHECKPOINT_DIR),
    LowestCheckpointFileWriter(output_dir=CHECKPOINT_DIR, metric=EVAL_METRIC),
)

# The pipeline's return value is not needed. Do not assign it to `_`, which
# would shadow IPython's last-output variable.
for checkpoint in train_loop:
    process_checkpoint(checkpoint)

A checkpoint was loaded in 0.01 seconds.
Total model initialization time is 10.42 seconds.
The lowest value for the metric eval/lnpp is set to 5.54797554.
Stop training, already reached the total training steps 6.
