JAX release v0.4.13
NOTE: This is the last JAX release that will include Python 3.8 support
Changes
- jax.jit now allows None to be passed to in_shardings and
  out_shardings. The semantics are as follows:
  - For in_shardings, JAX will mark it as replicated, but this behavior
    can change in the future.
  - For out_shardings, we will rely on the XLA GSPMD partitioner to
    determine the output shardings.
- jax.experimental.pjit.pjit also allows None to be passed to
  in_shardings and out_shardings. The semantics are as follows:
  - If the mesh context manager is not provided, JAX has the freedom to
    choose whatever sharding it wants.
    - For in_shardings, JAX will mark it as replicated, but this behavior
      can change in the future.
    - For out_shardings, we will rely on the XLA GSPMD partitioner to
      determine the output shardings.
  - If the mesh context manager is provided, None will imply that the value
    will be replicated on all devices of the mesh.
- Executable.cost_analysis() works on Cloud TPU
- Added a warning if a non-allowlisted jaxlib plugin is in use.
- Added jax.tree_util.tree_leaves_with_path.
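
A minimal sketch of how some of the changes listed above might be exercised together. The function `affine` and the names `params`, `x`, and `named_shapes` are illustrative assumptions, not part of the release. The return type of `cost_analysis()` has differed between JAX versions, so the sketch accepts either a dict or a list of dicts rather than assuming one.

```python
import jax
import jax.numpy as jnp


def affine(params, x):
    # Hypothetical example function; `params` is a pytree (a dict of arrays).
    return x @ params["w"] + params["b"]


params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}
x = jnp.ones((3, 4))

# Passing None for both arguments leaves the sharding decisions to JAX
# (inputs) and to the XLA partitioner (outputs), as described above.
jitted = jax.jit(affine, in_shardings=None, out_shardings=None)
y = jitted(params, x)

# tree_leaves_with_path returns (key path, leaf) pairs, one per leaf.
named_shapes = {
    jax.tree_util.keystr(path): leaf.shape
    for path, leaf in jax.tree_util.tree_leaves_with_path(params)
}

# cost_analysis() is called on the compiled executable. Depending on the
# JAX version it may return a dict or a list of dicts, so normalize it.
compiled = jitted.lower(params, x).compile()
cost = compiled.cost_analysis()
if isinstance(cost, (list, tuple)):
    cost = cost[0] if cost else None

print(y.shape, named_shapes)
```

On a single-device machine the None shardings have no visible effect on the result; the difference only matters when several devices are available.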
Bug fixes
- Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
  is named cudnn89 instead of cudnn88.
Deprecations
- The native_serialization_strict_checks parameter to
  jax.experimental.jax2tf.convert is deprecated in favor of the new
  native_serialization_disabled_checks (#16347).