This project is part of the Google #TPUSprint program and showcases a JAX/Flax/NNX implementation of Masked Diffusion Language Models. This codebase borrows structure and snippets from the PyTorch implementation tiny-diffusion.
Why diffusion language models? Instead of generating tokens one at a time autoregressively, the model learns to recover masked tokens by denoising masked blocks in parallel.
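The masking and parallel-denoising idea can be sketched as follows. This is a minimal illustration, not the project's actual code: the `MASK_ID` value, the greedy argmax fill, and the stand-in random logits are all assumptions, and numpy is used in place of JAX for brevity.

```python
import numpy as np

MASK_ID = 0  # assumed mask-token id; the real vocabulary is project-specific

def mask_tokens(tokens, mask_prob, rng):
    """Forward process: independently replace each token with MASK_ID."""
    mask = rng.random(tokens.shape) < mask_prob
    return np.where(mask, MASK_ID, tokens), mask

def denoise_step(tokens, mask, logits):
    """Reverse step: fill every masked position in parallel with the
    model's argmax prediction (greedy sketch; real samplers may iterate)."""
    preds = logits.argmax(axis=-1)
    return np.where(mask, preds, tokens)

rng = np.random.default_rng(0)
tokens = np.array([5, 7, 3, 9, 2])
noisy, mask = mask_tokens(tokens, mask_prob=0.5, rng=rng)

# Stand-in for model output; a trained model would produce these logits.
vocab_size = 10
logits = rng.standard_normal((tokens.shape[0], vocab_size))
denoised = denoise_step(noisy, mask, logits)

# Unmasked positions pass through unchanged.
assert (denoised[~mask] == tokens[~mask]).all()
```

The key point is that `denoise_step` fills all masked positions in a single vectorized operation, rather than predicting them left to right.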
```shell
uv sync
uv run main.py --train
```

The trained model will be saved to `weights/diffusion_checkpoint.pkl`.
```shell
uv run main.py --prompt "Once upon a time"
```

Uses the checkpoint from the last training step if present.
Open the Colab notebook.