Plan for moving JAX AI Stack examples to the NNX docs:

- [ ] (Low prio) Part 2: Debug a variational autoencoder (VAE)
- [x] Part 3: Train a diffusion model for image generation (@samanklesaria, https://github.com/google/flax/pull/5403)
- [x] Visualize JAX model metrics with TensorBoard (@samanklesaria, https://github.com/google/flax/pull/5425)
- [x] Introduction to Data Loaders on CPU with JAX / Introduction to Data Loaders on GPU with JAX (@samanklesaria, #5454)
- [x] JAX for PyTorch users / Porting a PyTorch model to JAX (@vfdev-5, https://github.com/google/flax/pull/5408)
- [x] Train a miniGPT language model with JAX (@samanklesaria, https://github.com/google/flax/pull/5405)
- [ ] Text classification with a transformer language model using JAX
- [x] Machine Translation with encoder-decoder transformer model (@samanklesaria, https://github.com/google/flax/pull/5431)
- [ ] Image segmentation with UNETR model
- [ ] Image Captioning with Vision Transformer (ViT) model (@vfdev-5, ...)
- [ ] Train a Vision Transformer (ViT) for image classification with JAX (@vfdev-5, https://github.com/google/flax/pull/5455)
- [ ] Time series classification with CNN