
JAX release v0.2.26

Released by @yashk2810 on 08 Dec 19:20
  • Bug fixes:

  • Out-of-bounds indices passed to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative: gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).

  • jax2tf will now force the converted code to use XLA for code fragments under jax.jit (e.g., most jax.numpy functions) (#7839).
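The segment_sum fix can be illustrated with a minimal sketch. It assumes a recent JAX release in which jax.ops.segment_sum drops out-of-bounds segment ids by default. The arrays and the helper `total` are illustrative only and are not part of the release. The element whose segment id (5) falls outside `num_segments=3` contributes nothing to the forward sums and receives a zero gradient in reverse mode.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
# Illustrative ids: 5 is out of bounds for num_segments=3, so it is dropped.
segment_ids = jnp.array([0, 1, 1, 5])

def total(x):
    # Hypothetical helper: reduce the per-segment sums to a scalar so jax.grad applies.
    return jax.ops.segment_sum(x, segment_ids, num_segments=3).sum()

# Forward pass: the element with segment id 5 does not contribute to any segment.
sums = jax.ops.segment_sum(data, segment_ids, num_segments=3)

# Reverse mode: the gradient for the out-of-bounds element is 0.
grads = jax.grad(total)(data)

print(sums)   # [1. 5. 0.]
print(grads)  # [1. 1. 1. 0.]
```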