From 744f6b4ee802c7475285d5fa07666847a01b8705 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 16 May 2022 12:07:15 -0700 Subject: [PATCH] Update xla_client._version and add missing version checks to JAX PiperOrigin-RevId: 449021408 --- jax/experimental/compilation_cache/compilation_cache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 62b9828dcfd4..3077de6bd8ac 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -113,7 +113,11 @@ def _hash_computation(hash_obj, xla_computation): hash_obj.update(scrubbed_hlo) def _hash_compile_options(hash_obj, compile_options_obj): - assert len(dir(compile_options_obj)) == 32, ( + if xla_client._version >= 68: # Remove when minimum jaxlib version >= 0.3.11 + expected_num_compile_options = 32 + else: + expected_num_compile_options = 31 + assert len(dir(compile_options_obj)) == expected_num_compile_options, ( f"Unexpected number of CompileOption fields: " f"{len(dir(compile_options_obj))}. This likely: means that an extra " f"field was added, and this function needs to be updated.") @@ -126,7 +130,8 @@ def _hash_compile_options(hash_obj, compile_options_obj): _hash_bool(hash_obj, compile_options_obj.tuple_arguments) _hash_int(hash_obj, compile_options_obj.num_replicas) _hash_int(hash_obj, compile_options_obj.num_partitions) - _hash_int(hash_obj, compile_options_obj.profile_version) + if xla_client._version >= 68: # Remove when minimum jaxlib version >= 0.3.11 + _hash_int(hash_obj, compile_options_obj.profile_version) if compile_options_obj.device_assignment is not None: hash_obj.update(compile_options_obj.device_assignment.serialize())