Skip to content

Commit

Permalink
[jax2tf] Bump the default JAX serialization version to 7.
Browse files Browse the repository at this point in the history
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
  • Loading branch information
gnecula authored and jax authors committed Aug 12, 2023
1 parent 580b860 commit cf4e1d4
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).

* Breaking changes:
* jax2tf now uses native serialization by default. See
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,8 @@ def update_thread_local_jit_state(**kw):
# Note: bump the default serialization version at least one month after
# we update XlaCallModule to support the new version, so that serialized
# modules are forward compatible with deployed versions of XlaCallModule.
# Version 6 of XlaCallModule is supported since June 7th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 6),
# Version 7 of XlaCallModule is supported since July 12th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 7),
help=(
'The version number to use for native serialization. This must be '
'within the range of versions supported by the tf.XlaCallModule '
Expand Down
10 changes: 6 additions & 4 deletions jax/experimental/jax2tf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,14 @@ We list here a history of the serialization version numbers:
for some specialized use cases. Used in JAX from May 3rd, 2023
(cl/529106145).
* Version 6 adds support for the `disabled_checks` attribute. This version
mandates a non-empty `platforms` attribute.
Used in JAX since June 13th, 2023 (JAX 0.4.13).
mandates a non-empty `platforms` attribute. Supported by XlaCallModule
since June 7th, 2023 and available in JAX since
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
Available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522) and
available in JAX serialization since July 20th, 2023 (JAX 0.4.14).
* Version 8 adds support for the `jax.uses_shape_polymorphism` module
attribute and enables the shape refinement pass only when the
attribute is present. Supported by XlaCallModule since July 21st, 2023
Expand Down
4 changes: 3 additions & 1 deletion jax/experimental/jax2tf/tests/back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,9 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol):

self.run_one_test(func, data,
polymorphic_shapes=("_, b",),
check_results=check_top_k_results)
check_results=check_top_k_results,
# TODO(necula): now also includes shape_assertion
compare_with_current=False)


if __name__ == "__main__":
Expand Down

0 comments on commit cf4e1d4

Please sign in to comment.