frankroeder/tiny_diffusion_stories

Tiny Diffusion Stories

This project is part of the Google #TPUSprint program and showcases a JAX/Flax/NNX implementation of Masked Diffusion Language Models. This codebase borrows structure and snippets from the PyTorch implementation tiny-diffusion.

Why diffusion language models? Instead of generating text autoregressively, one token at a time from left to right, the model learns to recover masked tokens, denoising whole blocks of masked positions in parallel.
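The masking-and-denoising idea can be illustrated with a minimal, framework-free sketch. It is not the project's JAX/Flax NNX code. `MASK`, `mask_tokens`, `denoise`, and the toy `oracle` predictor are hypothetical names, and the unmasking schedule (commit the most confident predictions first, with all masks filled by the last step) is an assumed, commonly used choice; the repository may use a different schedule.

```python
import random

MASK = "_"  # hypothetical mask symbol; a real model uses a dedicated [MASK] token id


def mask_tokens(tokens, t, rng):
    """Forward (noising) process: mask each token independently with probability t.

    In masked diffusion training, t is typically sampled per sequence and the
    model learns to predict the original tokens at the masked positions.
    """
    return [MASK if rng.random() < t else tok for tok in tokens]


def denoise(tokens, predict, num_steps):
    """Reverse process: fill in masked positions in parallel over num_steps steps.

    `predict` stands in for the network (assumption): given the current sequence,
    it returns a (token, confidence) pair for every position. Each step commits
    the most confident predictions so every mask is filled by the final step.
    """
    tokens = list(tokens)
    for step in range(num_steps):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        preds = predict(tokens)  # one forward pass predicts every position at once
        remaining = num_steps - step
        k = -(-len(masked) // remaining)  # ceiling division: unmask this many now
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:k]:
            tokens[i] = preds[i][0]
    return tokens


# Toy stand-in for a trained model: always predicts the target character,
# with higher confidence for earlier positions.
target = list("once upon a time")


def oracle(tokens):
    return [(ch, 1.0 / (i + 1)) for i, ch in enumerate(target)]


noisy = mask_tokens(target, t=0.7, rng=random.Random(0))
restored = denoise(noisy, oracle, num_steps=4)
print("".join(noisy))
print("".join(restored))  # once upon a time
```

Because several positions are committed per step, generating a block takes `num_steps` model calls rather than one call per token, which is the main practical contrast with autoregressive decoding.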

Setup

uv sync

Train

uv run main.py --train

The trained model is saved to:

weights/diffusion_checkpoint.pkl

Generate

uv run main.py --prompt "Once upon a time"

Generation uses the checkpoint saved at the last training step, if one is present.
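The save path and the "use it if present" behaviour can be sketched with the standard `pickle` module. This README does not document what `main.py` stores in `weights/diffusion_checkpoint.pkl`, so the sketch assumes a plain picklable dict; `save_checkpoint` and `load_checkpoint` are hypothetical helpers, not functions from the repository.

```python
import os
import pickle
import tempfile

# Path taken from this README; the checkpoint's contents are an assumption.
CHECKPOINT_PATH = os.path.join("weights", "diffusion_checkpoint.pkl")


def save_checkpoint(state, path=CHECKPOINT_PATH):
    """Pickle `state` (assumed: a plain dict of training step and parameters) to `path`."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(path=CHECKPOINT_PATH):
    """Return the pickled state, or None if no checkpoint has been written yet."""
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


# Usage in a temporary directory so a real checkpoint is never overwritten.
tmp_path = os.path.join(tempfile.mkdtemp(), "weights", "diffusion_checkpoint.pkl")
print(load_checkpoint(tmp_path))  # None
save_checkpoint({"step": 1000, "params": {"w": [0.1, 0.2]}}, tmp_path)
print(load_checkpoint(tmp_path)["step"])  # 1000
```

Only unpickle checkpoints you trust: `pickle.load` can execute arbitrary code embedded in the file.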

Google Colab TPU Training

Open the Colab notebook.
