jax-ai-stack packages:
jax==0.8.0
chex==0.1.91
grain==0.2.13
flax==0.12.0
ml_dtypes==0.5.3
optax==0.2.6
orbax-checkpoint==0.11.26
orbax-export==0.0.8

jax-ai-stack[tfds] packages:
tensorflow==2.20.0
tensorflow_datasets==4.9.9
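One way to check that an environment matches these pins is a short Python script built on the standard library's `importlib.metadata`. This is a minimal sketch under stated assumptions. The `PINS` and `TFDS_PINS` dicts are copied from the lists above. The helper names `installed_versions` and `mismatches` are hypothetical and are not part of jax-ai-stack. Versions are compared by exact string equality, which suits `==` pins but is not full PEP 440 version matching.

```python
"""Compare installed package versions against the jax-ai-stack pins.

Sketch only: helper names are hypothetical, and versions are compared by
exact string equality rather than full PEP 440 semantics.
"""
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Mapping, Optional, Tuple

# Pins copied from the jax-ai-stack package list.
PINS: Dict[str, str] = {
    "jax": "0.8.0",
    "chex": "0.1.91",
    "grain": "0.2.13",
    "flax": "0.12.0",
    "ml_dtypes": "0.5.3",
    "optax": "0.2.6",
    "orbax-checkpoint": "0.11.26",
    "orbax-export": "0.0.8",
}

# Additional pins listed for the [tfds] extra.
TFDS_PINS: Dict[str, str] = {
    "tensorflow": "2.20.0",
    "tensorflow_datasets": "4.9.9",
}


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return each distribution's installed version, or None if it is not installed."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(
    pins: Mapping[str, str], installed: Mapping[str, Optional[str]]
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin to (pinned, installed)."""
    return {
        name: (want, installed.get(name))
        for name, want in pins.items()
        if installed.get(name) != want
    }


if __name__ == "__main__":
    all_pins = {**PINS, **TFDS_PINS}
    report = mismatches(all_pins, installed_versions(all_pins))
    for name, (want, have) in sorted(report.items()):
        print(f"{name}: pinned {want}, installed {have or 'missing'}")
```

An empty report means every listed package is installed at its pinned version. If you did not install the `[tfds]` extra, drop `TFDS_PINS` from the check so that `tensorflow` and `tensorflow_datasets` are not reported as missing.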