JAX release v0.3.8

@mattjj mattjj released this 30 Apr 03:09
  • GitHub commits.
  • Changes
    • {func}jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • {func}jax.numpy.linalg.cond on TPUs now accepts complex input.
    • {func}jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • {func}jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • {func}jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted.
      Please use jax.experimental.maps.Mesh instead. See https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
      for more information.
    • {func}jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ({jax-issue}#10452).
    • {func}jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • {func}jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • {func}jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of {func}numpy.take_along_axis. Previously, non-integer indices were silently cast to integers.
    • {func}jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of {func}numpy.ravel_multi_index. Previously, non-integer dims were silently cast to integers.
    • {func}jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of {func}numpy.split. Previously, a non-integer axis was silently cast to an integer.
    • {func}jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of {func}numpy.indices. Previously, non-integer dimensions were silently cast to integers.
    • {func}jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of {func}numpy.diag. Previously, a non-integer k was silently cast to an integer.
    • Added {func}jax.random.orthogonal.
  • Deprecations
    • Many functions and objects available in {mod}jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance ({jax-issue}#10389). These, along with the previously deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in e.g. {mod}unittest, {mod}absl.testing, {mod}numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}jax.devices. Many of the deprecated utilities will still exist in {mod}jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
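
The out-of-bounds indexing changes listed under Changes can be illustrated with a minimal sketch. The array values and the index 5 are made up for illustration, and the checks use {mod}numpy.testing in place of the deprecated {mod}jax.test_util helpers. This assumes a JAX version at or after this release and has only been reasoned about for the default CPU backend.

```python
import jax.numpy as jnp
import numpy as np

# Illustrative data (not from the release notes).
x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # 5 is out of bounds for an array of length 3

# Default mode ("fill"): out-of-bounds reads yield an invalid value (NaN for floats).
filled = jnp.take(x, idx)

# mode="clip" restores the earlier behavior of clamping indices into range.
clipped = jnp.take(x, idx, mode="clip")

# take_along_axis accepts the same optional mode parameter.
along_filled = jnp.take_along_axis(x, idx, axis=0)
along_clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")

# Scatter with an out-of-bounds index: the update is dropped and x is unchanged.
scattered = x.at[jnp.array([5])].set(99.0)

# Checks written with numpy.testing rather than jax.test_util.check_close.
np.testing.assert_allclose(clipped, [10.0, 30.0])
np.testing.assert_allclose(along_clipped, [10.0, 30.0])
np.testing.assert_allclose(scattered, x)
assert np.isnan(np.asarray(filled))[1]
assert np.isnan(np.asarray(along_filled))[1]
```

Code that relied on the old clamping behavior would need to pass mode="clip" explicitly; the in-bounds element (index 0) is returned unchanged under both modes.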