
v0.5.0: JAX/Flax and TPU support

Released by @anton-l on 13 Oct 17:54 (commit 0679d09)

🌾 JAX/Flax integration for super fast Stable Diffusion on TPUs.

We added JAX support for Stable Diffusion! You can now run Stable Diffusion on Colab TPUs (and GPUs too!) for faster inference.

Check out this TPU-ready Colab notebook for a Stable Diffusion pipeline: Open In Colab
For more detail, see our blog post on Stable Diffusion and parallelism in JAX/Flax 🤗: https://huggingface.co/blog/stable_diffusion_jax
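Parallel inference in JAX follows a single-program, multiple-data pattern: replicate the parameters on every device, split the batch across devices, and run one compiled copy of the function per device with `jax.pmap`. The sketch below shows that pattern in plain JAX. It assumes a hypothetical toy update function (`toy_denoise_step`) in place of the real UNet and is not the `FlaxStableDiffusionPipeline` API. On a Colab TPU runtime, `jax.local_device_count()` reports the number of TPU cores; on a CPU it reports 1, so the sketch still runs.

```python
import jax
import jax.numpy as jnp


def shard(batch):
    """Reshape (N, ...) to (num_devices, N // num_devices, ...) so pmap can split it."""
    n = jax.local_device_count()
    return batch.reshape((n, batch.shape[0] // n) + batch.shape[1:])


def toy_denoise_step(params, latents):
    # Hypothetical stand-in for a UNet call: a single scaled update.
    return latents - params["scale"] * latents


# pmap compiles the function once and runs one copy per local device.
p_step = jax.pmap(toy_denoise_step)

num_devices = jax.local_device_count()
params = {"scale": jnp.float32(0.1)}
# Replicate params by adding a leading device axis of size num_devices.
replicated = jax.tree_util.tree_map(lambda x: jnp.stack([x] * num_devices), params)

# Two latents per device; each device receives its own slice of the batch.
latents = jnp.ones((num_devices * 2, 4, 8, 8))
out = p_step(replicated, shard(latents))
print(out.shape)
```

The output keeps the leading device axis, so results from all devices can be gathered with a single reshape back to `(N, ...)`.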

The most commonly used models, schedulers, and pipelines have been ported to JAX/Flax:

  • Models: FlaxAutoencoderKL, FlaxUNet2DConditionModel
  • Schedulers: FlaxDDIMScheduler, FlaxPNDMScheduler
  • Pipelines: FlaxStableDiffusionPipeline
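In JAX, a scheduler is easiest to compile when its state (for example, the noise schedule) lives in an explicit, immutable object passed into pure functions, instead of in mutable attributes. The sketch below illustrates that pattern with a deterministic DDIM update (eta = 0) over a linear beta schedule. The names `SchedulerState`, `create_state`, and `ddim_step` are hypothetical and are not the `FlaxDDIMScheduler` API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerState(NamedTuple):
    # Cumulative product of (1 - beta_t), precomputed once and never mutated.
    alphas_cumprod: jnp.ndarray


def create_state(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
    return SchedulerState(alphas_cumprod=jnp.cumprod(1.0 - betas))


@jax.jit
def ddim_step(state, model_output, t, prev_t, sample):
    """Deterministic DDIM update (eta = 0) from timestep t to prev_t."""
    alpha_t = state.alphas_cumprod[t]
    # prev_t < 0 means "step to the clean sample", where alpha_cumprod is 1.
    alpha_prev = jnp.where(prev_t >= 0, state.alphas_cumprod[prev_t], 1.0)
    # Predict the clean sample from the noisy one and the predicted noise.
    pred_x0 = (sample - jnp.sqrt(1.0 - alpha_t) * model_output) / jnp.sqrt(alpha_t)
    # Re-noise the prediction to the level of prev_t.
    return jnp.sqrt(alpha_prev) * pred_x0 + jnp.sqrt(1.0 - alpha_prev) * model_output


# Usage: noise a clean sample, then step back with the exact noise as "model output".
state = create_state()
x0 = jnp.ones((2, 2))
noise = jnp.full((2, 2), 0.5)
t = 500
a_t = state.alphas_cumprod[t]
noisy = jnp.sqrt(a_t) * x0 + jnp.sqrt(1.0 - a_t) * noise
recovered = ddim_step(state, noise, t, -1, noisy)
```

Because the state is a pytree passed as an argument, the same compiled `ddim_step` can be reused for every timestep, and it also works under `jax.pmap`.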

Changelog:

🔥 DeepSpeed low-memory training

Thanks to the 🤗 Accelerate integration with DeepSpeed, several of our training examples are now even more optimized for VRAM usage and speed:

✏️ Changelog