Trax: Train Neural Nets with JAX
Trax is T2T, radically simpler, with JAX.
Why? Because T2T has gotten too complex. We are simplifying the main T2T code too, but we also wanted to try a more radical step: with Trax you write code as in pure NumPy and debug it directly, so you can pinpoint the exact line where things happen and understand each function. Thanks to JAX, the same code also runs fast on accelerators.
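A hedged sketch of that workflow (not Trax's own API): `jax.numpy` code reads like plain NumPy, and `jax.grad` / `jax.jit` make the same code differentiable and fast on accelerators.

```python
# Sketch only: plain NumPy-style code made differentiable and
# accelerator-ready with JAX. The function names here are illustrative.
import jax
import jax.numpy as jnp

def mse(w, x, y):
    # Reads like pure NumPy: linear predictions and mean squared error.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates the loss w.r.t. w; jax.jit compiles the result.
grad_fn = jax.jit(jax.grad(mse))

x = jnp.ones((4, 3))
y = jnp.zeros((4,))
w = jnp.zeros((3,))
g = grad_fn(w, x, y)
print(g.shape)  # (3,)
```

Because the code is ordinary Python until you wrap it in `jit`, you can step through `mse` in a debugger line by line.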
Status: preview. Things work: models train, checkpoints are saved, TensorBoard has summaries, and you can decode. But we are changing a lot every day for now, so please let us know what we should add, delete, keep, or change. We plan to move the best parts into core JAX.
- Main library entrypoint:
See our example constructing language models from scratch in a GPU-backed Colab notebook: Trax Demo.
MLP on MNIST
```shell
python -m trax.trainer \
  --dataset=mnist \
  --model=MLP \
  --config="train.train_steps=1000"
```
Resnet50 on Imagenet
```shell
python -m trax.trainer \
  --config_file=$PWD/trax/configs/resnet50_imagenet_8gb.gin
```
TransformerDecoder on LM1B
```shell
python -m trax.trainer \
  --config_file=transformer_lm1b_8gb.gin
```
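The `.gin` files referenced above hold bindings like the one passed via `--config` in the MNIST example. A minimal illustrative fragment, assuming only the `train.train_steps` parameter shown there (any other binding names would be hypothetical):

```
# Illustrative gin fragment -- a sketch of the format, not a shipped config.
train.train_steps = 1000
```

Individual bindings in a config file can then be overridden on the command line with `--config`.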
How Trax differs from T2T
- Configuration is done with `--config_file` as well as `--config` for file overrides.
- Models are defined in `models/` and are made gin-configurable.
- Datasets are simple iterators over batches. Datasets from `tensor2tensor` are built-in and can be addressed by name.
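The "simple iterators over batches" contract can be sketched as follows: any Python generator yielding `(inputs, targets)` batches fits. This is an illustrative stand-in, not a Trax API; the function name and shapes are assumptions.

```python
# Sketch of a dataset as a plain batch iterator. A real dataset (e.g. one
# from tensor2tensor) would yield the same (inputs, targets) structure.
import numpy as np

def random_batches(batch_size=8, num_batches=3, seed=0):
    """Yield (inputs, targets) NumPy batches of MNIST-like shapes."""
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        inputs = rng.normal(size=(batch_size, 28 * 28)).astype(np.float32)
        targets = rng.integers(0, 10, size=(batch_size,))
        yield inputs, targets

batches = list(random_batches())
print(len(batches), batches[0][0].shape)  # 3 (8, 784)
```

Because the contract is just iteration, a training loop can consume any such generator without knowing where the data came from.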