A toy example of training MNIST with JAX in a nix shell.
I'm dreaming of scientifically reproducible model development between multiple past/present/future collaborators in long-running R&D projects. Historically, I've been bitten by past colleagues asking why they're getting NaN when running my old training code on a new system, and learning that simply pinning Python packages in a requirements.txt with pip freeze is not enough, because ML dependencies usually link dynamically to system-wide software such as CUDA/cuDNN, and different versions of those can result in different ML models.
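To make the gap concrete, here is a minimal diagnostic sketch (not part of this repository) that records what pip freeze would see, plus a few system-level hints, and hashes the result so two machines can be compared quickly. The function name environment_fingerprint and the environment variable names are assumptions chosen for illustration. Even this fingerprint cannot see which CUDA/cuDNN binaries actually get loaded at runtime, which is the part nix pins.

```python
import hashlib
import json
import os
import platform
import sys
from importlib import metadata


def environment_fingerprint(env_vars=("CUDA_HOME", "CUDA_PATH", "LD_LIBRARY_PATH")):
    """Collect environment details and a SHA-256 digest summarizing them.

    The env_vars default is an assumption: these are conventional names,
    not something this project is known to rely on.
    """
    packages = []
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name:  # skip broken installs that lack a name
            packages.append(f"{name}=={dist.version}")

    info = {
        "python": sys.version,
        "platform": platform.platform(),
        "machine": platform.machine(),
        "packages": sorted(packages),
        # Unset variables are recorded as None so the difference is visible.
        "env": {name: os.environ.get(name) for name in env_vars},
    }
    # Hash a canonical JSON dump so two environments can be compared at a glance.
    digest = hashlib.sha256(json.dumps(info, sort_keys=True).encode()).hexdigest()
    return info, digest


if __name__ == "__main__":
    info, digest = environment_fingerprint()
    print(json.dumps(info, indent=2))
    print("fingerprint:", digest)
```

Logging the digest next to each training run gives a cheap signal that two runs came from different environments, but identical digests do not prove that the underlying system libraries match.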
Note: Every instruction that follows assumes nix is installed on your system. If it isn't, go and do that first at https://nixos.org/download.html
Install the required Python dependencies into a nix package by running nix-build in the project root.
After nix-build, the program is available in result/ and can be started as:
./result/bin/train
Tip: nix run github:carlthome/nix-jax-mnist will build and run the program directly.
Start a development shell with nix-shell, within which python has the packages specified in shell.nix as well as a virtual environment for pip installing additional dependencies when needed.
Thus you can simply run python train.py as usual.
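Because the development shell mixes nix-provided packages with a virtual environment, it can help to check where a given module is actually loaded from. The following is a small sketch under stated assumptions: the helper names in_virtualenv and package_origin are hypothetical, and the check relies on the default nix store location /nix/store.

```python
import importlib.util
import sys


def in_virtualenv():
    """True when the interpreter runs inside a venv-style virtual environment."""
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


def package_origin(name):
    """Classify where a top-level module would be imported from.

    Returns one of "missing", "nix store", "virtualenv", or "other".
    """
    spec = importlib.util.find_spec(name)
    if spec is None:
        return "missing"
    # Built-in modules and namespace packages have no file path as origin.
    origin = spec.origin or ""
    if origin.startswith("/nix/store/"):  # assumes the default store path
        return "nix store"
    if in_virtualenv() and origin.startswith(sys.prefix):
        return "virtualenv"
    return "other"


if __name__ == "__main__":
    for module in ("jax", "json"):
        print(module, "->", package_origin(module))
```

If a package pinned in shell.nix reports "virtualenv", a pip-installed copy is shadowing the nix one, which would undermine the reproducibility the shell is meant to provide.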