JAX v0.4.29
Changes
- We anticipate that this will be the last release of JAX and jaxlib
  supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
  plugin jaxlib (e.g. `pip install jax[cuda12]`).
- JAX now requires ml_dtypes version 0.4.0 or newer.
- Removed backwards-compatibility support for old usage of the
  `jax.experimental.export` API. It is no longer possible to use
  `from jax.experimental.export import export`; use
  `from jax.experimental import export` instead.
  The removed functionality had been deprecated since 0.4.24.
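Because JAX now requires ml_dtypes 0.4.0 or newer, it can help to check an environment before upgrading. The sketch below reads the installed version from package metadata using the standard-library `importlib.metadata.version`. The helper `meets_minimum` is hypothetical and not part of JAX. It compares only the numeric `major.minor.patch` prefix, which is a simplifying assumption rather than full version-specifier semantics.

```python
from importlib.metadata import PackageNotFoundError, version


def meets_minimum(installed: str, minimum: tuple = (0, 4, 0)) -> bool:
    """Hypothetical helper: True if `installed` is at least `minimum`.

    Only the leading numeric part of the first three dot-separated
    components is compared, so "0.4.0rc1" is treated like "0.4.0".
    """
    parts = []
    for piece in installed.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as "rc1" or "+local"
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (3 - len(parts))  # pad "0.4" to (0, 4, 0)
    return tuple(parts) >= tuple(minimum)


if __name__ == "__main__":
    try:
        installed = version("ml_dtypes")
    except PackageNotFoundError:
        print("ml_dtypes is not installed")
    else:
        status = "OK" if meets_minimum(installed) else "too old for JAX 0.4.29"
        print(f"ml_dtypes {installed}: {status}")
```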
Deprecations
- `jax.sharding.XLACompatibleSharding` is deprecated. Please use
  `jax.sharding.Sharding`.
- `jax.experimental.Exported.in_shardings` has been renamed to
  `jax.experimental.Exported.in_shardings_hlo`. Likewise, `out_shardings`
  has been renamed to `out_shardings_hlo`.
  The old names will be removed after 3 months.
- Removed a number of previously-deprecated APIs:
  - from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
  - from {mod}`jax.lax`: `tie_in`
  - from {mod}`jax.nn`: `normalize`
  - from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
    `translations`, `register_translation`, `xla_destructure`,
    `TranslationRule`, `TranslationContext`, `XlaOp`.
- The `tol` argument of {func}`jax.numpy.linalg.matrix_rank` is being
  deprecated and will soon be removed. Use `rtol` instead.
- The `rcond` argument of {func}`jax.numpy.linalg.pinv` is being
  deprecated and will soon be removed. Use `rtol` instead.
- The deprecated `jax.config` submodule (as in `from jax.config import config`)
  has been removed. To configure JAX, use `import jax` and then reference
  the config object via `jax.config`.
- {mod}`jax.random` APIs no longer accept batched keys, which some of them
  previously accepted unintentionally. Going forward, we recommend explicit
  use of {func}`jax.vmap` in such cases.
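The migrations above can be sketched together. This is a minimal sketch that assumes JAX 0.4.29 or newer on any backend. The helper names `is_sharding`, `rank_and_pinv`, and `uniform_per_key` are hypothetical. `SingleDeviceSharding`, `jax.random.key`, and the `jax_enable_x64` flag are used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Sharding, SingleDeviceSharding

# Configure JAX through the jax.config object; the removed submodule
# pattern `from jax.config import config` no longer works.
jax.config.update("jax_enable_x64", True)


def is_sharding(obj) -> bool:
    # Check against jax.sharding.Sharding instead of the deprecated
    # jax.sharding.XLACompatibleSharding.
    return isinstance(obj, Sharding)


def rank_and_pinv(m, rtol):
    # Pass `rtol` where older code passed `tol` (matrix_rank)
    # or `rcond` (pinv).
    return jnp.linalg.matrix_rank(m, rtol=rtol), jnp.linalg.pinv(m, rtol=rtol)


def uniform_per_key(keys):
    # Each jax.random call receives a single key; the batch of keys is
    # mapped over explicitly with jax.vmap.
    return jax.vmap(lambda k: jax.random.uniform(k, (3,)))(keys)


if __name__ == "__main__":
    print(is_sharding(SingleDeviceSharding(jax.devices()[0])))  # True
    m = jnp.diag(jnp.array([2.0, 1e-3]))
    rank, m_pinv = rank_and_pinv(m, rtol=1e-2)
    print(rank)    # 1: the singular value 1e-3 falls below the cutoff
    print(m_pinv)  # approximately diag(0.5, 0.0)
    keys = jax.random.split(jax.random.key(0), 4)
    print(uniform_per_key(keys).shape)  # (4, 3)
```

Mapping with `jax.vmap` makes the batch dimension of the keys explicit at the call site. The code therefore does not depend on any `jax.random` function accepting a batch of keys.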
New Functionality
- Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
  shardings that can be used with the JAX APIs from the HloShardings
  that are stored in the `Exported` objects.