
JAX release v0.3.17

@jakevdp released this 31 Aug 18:23
  • GitHub commits.
  • Bugs
    • Fixed a corner-case issue in the gradient of lax.pow with an exponent of zero (#12041).
  • Breaking changes
    • jax.checkpoint, also known as jax.remat, no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
  • Changes
    • Added jax.pure_callback, which enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
  • Deprecations
    • The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile (#11944).
    • DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.
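On the jax.checkpoint change: code that used concrete=True so that Python control flow could inspect argument values needs another route. One possible migration, sketched here under the assumption that those values can be treated as compile-time constants, is to mark them with jax.checkpoint's static_argnums parameter; JEP 11830 has the full discussion. The function apply_layers and its num_layers argument are hypothetical names used only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# num_layers is marked static, so it is an ordinary Python int at trace time
# and the Python loop below is unrolled; no `concrete=True` is required.
@partial(jax.checkpoint, static_argnums=(1,))
def apply_layers(x, num_layers):
    for _ in range(num_layers):
        x = jnp.sin(x)
    return x


# Differentiate through the rematerialized function with a fixed layer count.
g = jax.grad(lambda x: apply_layers(x, 3))(1.0)
print(g)
```

Because num_layers is static, each distinct value triggers a separate trace, which is the usual trade-off of static arguments in JAX.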
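A minimal sketch of the new jax.pure_callback from inside a jax.jit-compiled function. The callback host_sin (a hypothetical name) runs as ordinary NumPy code on the host, and the caller describes the shape and dtype of its result with jax.ShapeDtypeStruct. As the name suggests, the callback is assumed to be free of side effects, since the compiler may elide or repeat calls to it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Plain NumPy code executed on the host, outside the compiled computation.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)


@jax.jit
def f(x):
    # Declare the shape/dtype of the callback's result so JAX can trace it.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)


x = jnp.arange(3.0)
out = f(x)
print(out)
```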
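For the two DeviceArray changes, a short migration sketch: jax.numpy.tile takes the place of the removed tile() method, and np.asarray produces a host NumPy array in place of the deprecated to_py(). The array values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Replacement for the removed DeviceArray.tile() method.
tiled = jnp.tile(x, 2)          # [1, 2, 3, 1, 2, 3]
tiled_2d = jnp.tile(x, (2, 1))  # shape (2, 3): the row repeated twice

# Replacement for the deprecated DeviceArray.to_py(): convert to NumPy.
host = np.asarray(x)
print(tiled, tiled_2d.shape, type(host))
```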