Trax: Train Neural Nets with JAX


Trax: T2T Radically Simpler with JAX

Why? Because T2T has gotten too complex. We are simplifying the main T2T code too, but we wanted to try a more radical step. With Trax you can write code as in pure NumPy and debug it directly, so you can easily pinpoint the line where each thing happens and understand every function. But we also want it to run fast on accelerators, and JAX makes that possible.
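The sketch below illustrates that workflow using core JAX only (it is not Trax code). A loss function written with `jax.numpy` runs eagerly like ordinary NumPy, so you can print intermediate values or step through it in a debugger. The same function can then be differentiated with `jax.grad` and compiled for accelerators with `jax.jit`. The function `mse_loss` and the toy data are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def mse_loss(w, x, y):
    # Plain NumPy-style code: a linear prediction followed by mean squared error.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
y = np.array([1.0, 2.0], dtype=np.float32)
w = np.zeros(2, dtype=np.float32)

# Eager execution: runs line by line, so print() or a debugger works inside mse_loss.
eager_loss = mse_loss(w, x, y)

# The same Python function, compiled with XLA (CPU/GPU/TPU) and differentiated.
fast_loss = jax.jit(mse_loss)(w, x, y)
grad_fn = jax.jit(jax.grad(mse_loss))
g = grad_fn(w, x, y)
print(eager_loss, fast_loss, g)
```

Nothing in `mse_loss` is specific to compilation. The speed comes from wrapping the unchanged function, which is what keeps the code debuggable.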

Status: preview. Things work: models train, checkpoints are saved, TensorBoard has summaries, and you can decode. But we are changing a lot every day for now, so please let us know what we should add, delete, keep, or change. We plan to move the best parts into core JAX.


  • Script:
  • Main library entrypoint: trax.train


Example Colab

See our example of constructing language models from scratch in a GPU-backed Colab notebook: Trax Demo.

MLP on MNIST

python -m trax.trainer \
  --dataset=mnist \
  --model=MLP \

ResNet-50 on ImageNet

python -m trax.trainer \

TransformerDecoder on LM1B

python -m trax.trainer \

How Trax differs from T2T

  • Configuration is done with gin. The trainer takes --config_file as well as --config for file overrides.
  • Models are defined with stax in models/. They are made gin-configurable in models/.
  • Datasets are simple iterators over batches. Datasets from tensorflow/datasets and tensor2tensor are built-in and can be addressed by name.
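Taken together, these points mean a model is just a pair of stax functions (`init_fn`, `apply_fn`) and a dataset is just a Python iterator of batches. The sketch below shows that shape using core JAX only. It does not use Trax's own training loop. The synthetic `batch_iterator`, the layer sizes, and the plain SGD update are illustrative assumptions, not Trax defaults. Note that stax lived in `jax.experimental` when this README was written and later moved to `jax.example_libraries`.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np

try:  # Newer JAX releases ship stax here.
    from jax.example_libraries import stax
except ImportError:  # Older JAX releases.
    from jax.experimental import stax

LEARNING_RATE = 0.1  # Hypothetical value chosen for this toy problem.


def batch_iterator(batch_size, num_features=4, num_classes=3, seed=0):
    # Hypothetical synthetic dataset: an endless iterator of (inputs, labels).
    # The label is the index of the largest of the first num_classes features.
    rng = np.random.default_rng(seed)
    while True:
        x = rng.normal(size=(batch_size, num_features)).astype(np.float32)
        labels = np.argmax(x[:, :num_classes], axis=1).astype(np.int32)
        yield x, labels


# A small MLP built from stax combinators: init_fn creates parameters,
# apply_fn runs the forward pass and returns log-probabilities.
init_fn, apply_fn = stax.serial(
    stax.Dense(16), stax.Relu,
    stax.Dense(3), stax.LogSoftmax,
)


def loss_fn(params, batch):
    x, labels = batch
    log_probs = apply_fn(params, x)
    # Mean negative log-likelihood of the correct class.
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


@jax.jit
def sgd_step(params, batch):
    # One step of plain gradient descent over the whole parameter pytree.
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)


# Usage: initialize for inputs of shape (batch, 4), then consume the iterator.
_, params = init_fn(jax.random.PRNGKey(0), (-1, 4))
batches = batch_iterator(batch_size=32)
first_batch = next(batches)
loss_before = loss_fn(params, first_batch)
for batch in itertools.islice(batches, 200):
    params = sgd_step(params, batch)
loss_after = loss_fn(params, first_batch)
print(float(loss_before), float(loss_after))
```

Because the dataset is only an iterator, swapping in a different source (for example one built from tensorflow/datasets) changes nothing in the model or the update step.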
