diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 87f0fc3ee195..d11669e72281 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -71,49 +71,59 @@ def initialize_cache(path): logger.warning("Initialized persistent compilation cache at %s", path) -def get_executable( +def get_executable_and_time( cache_key: str, compile_options, backend -) -> Optional[xla_client.LoadedExecutable]: - """Returns the cached executable if present, or None otherwise.""" - assert ( - _cache is not None - ), "initialize_cache must be called before you can call get_executable()" - serialized_executable = _cache.get(cache_key) - if not serialized_executable: - return None +) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]: + """Returns the cached executable and its compilation time if present, or None + otherwise. + """ + assert _cache is not None, ( + "initialize_cache must be called before you can call" + " get_executable_and_time()" + ) + executable_and_time = _cache.get(cache_key) + if not executable_and_time: + return None, None if zstandard: decompressor = zstandard.ZstdDecompressor() - serialized_executable = decompressor.decompress(serialized_executable) + executable_and_time = decompressor.decompress(executable_and_time) else: - serialized_executable = zlib.decompress(serialized_executable) + executable_and_time = zlib.decompress(executable_and_time) + serialized_executable, compile_time = extract_executable_and_time( + executable_and_time) xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, compile_options - ) - return xla_executable_deserialized + serialized_executable, compile_options) + return xla_executable_deserialized, compile_time -def put_executable( +def put_executable_and_time( cache_key: str, module_name: str, executable: xla_client.LoadedExecutable, backend, + compile_time: int ) -> None: - """Adds 'executable' to the cache, possibly evicting older entries.""" - 
assert ( - _cache is not None - ), "initialize_cache must be called before you can call put_executable()" + """Adds the 'executable' and its compilation time to the cache repository, + possibly evicting older entries. + """ + assert _cache is not None, ( + "initialize_cache must be called before you can call" + " put_executable_and_time()" + ) logger.info( "Writing %s to persistent compilation cache with key %s.", module_name, cache_key, ) serialized_executable = backend.serialize_executable(executable) + executable_and_time = combine_executable_and_time( + serialized_executable, compile_time) if zstandard: compressor = zstandard.ZstdCompressor() - serialized_executable = compressor.compress(serialized_executable) + executable_and_time = compressor.compress(executable_and_time) else: - serialized_executable = zlib.compress(serialized_executable) - _cache.put(cache_key, serialized_executable) + executable_and_time = zlib.compress(executable_and_time) + _cache.put(cache_key, executable_and_time) def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): @@ -375,3 +385,32 @@ def reset_cache(): assert is_initialized() logger.info("Resetting cache at %s.", _cache._path) _cache = None + + +def combine_executable_and_time( + serialized_executable: bytes, compile_time: int +) -> bytes: + """Given the serialized executable and the compilation time, produce a cache + entry in the format shown below. + + The cache entry is of the form: + Byte: 0 1 2 3 4 ... + Content: compilation time serialized executable + (big-endian int) + """ + return int(compile_time).to_bytes(4, byteorder='big') + serialized_executable + + +def extract_executable_and_time( + exectuable_and_time: bytes +) -> tuple[bytes, int]: + """Given the cache entry in the format shown below, extract the serialized + executable and the compilation time. + + The cache entry 'executable_and_time' is of the form: + Byte: 0 1 2 3 4 ... 
+ Content: compilation time serialized executable + (big-endian int) + """ + return exectuable_and_time[4:], int.from_bytes( + exectuable_and_time[:4], byteorder='big') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6a1363c96e3b..fc6f27d7eeef 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -500,11 +500,13 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray, cache_key = compilation_cache.get_cache_key( computation, devices, compile_options, backend) - cached_executable = _cache_read(module_name, cache_key, compile_options, - backend) - if cached_executable is not None: + executable, compile_time_retrieved = _cache_read( + module_name, cache_key, compile_options, backend) + if executable is not None: + # TODO(b/289098047): Will instrument a metric which uses the 'compile_time' + # to measure the savings due to the cache hit. logger.info("Persistent compilation cache hit for '%s'", module_name) - return cached_executable + return executable else: start_time = time.monotonic() executable = backend_compile(backend, computation, @@ -517,17 +519,20 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray, def _cache_read( module_name: str, cache_key: str, compile_options, backend -) -> Optional[xc.LoadedExecutable]: - """Looks up `computation` in the persistent compilation cache.""" +) -> tuple[Optional[xc.LoadedExecutable], Optional[int]]: + """Looks up the `computation` and its compilation time in the persistent + compilation cache repository. 
+ """ try: - return compilation_cache.get_executable(cache_key, compile_options, backend) + return compilation_cache.get_executable_and_time( + cache_key, compile_options, backend) except Exception as ex: if config.jax_raise_persistent_cache_errors: raise warnings.warn( f"Error reading persistent compilation cache entry for " f"'{module_name}': {type(ex).__name__}: {ex}") - return None + return None, None def _cache_write(cache_key: str, @@ -535,7 +540,9 @@ def _cache_write(cache_key: str, module_name: str, backend: Backend, executable: xc.LoadedExecutable, host_callbacks: list[Any]): - """Writes `serialized_computation` to the persistent compilation cache.""" + """Writes the `serialized_computation` and its compilation time to the + persistent compilation cache repository. + """ if host_callbacks: logger.info( "Not writing persistent cache entry for '%s' because it uses host " @@ -557,8 +564,8 @@ def _cache_write(cache_key: str, compile_time_secs) try: - compilation_cache.put_executable(cache_key, module_name, executable, - backend) + compilation_cache.put_executable_and_time( + cache_key, module_name, executable, backend, int(compile_time_secs)) except Exception as ex: if config.jax_raise_persistent_cache_errors: raise diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 25b6e752cf1b..57f109ae3018 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -46,6 +46,8 @@ config.parse_flags_with_absl() FLAGS = config.FLAGS +FAKE_COMPILE_TIME = 10 + @jtu.with_config( jax_raise_persistent_cache_errors=True, @@ -272,9 +274,10 @@ def test_get_no_executable(self): ) backend = xla_bridge.get_backend() key = cc.get_cache_key(computation, devices, compile_options, backend) - self.assertEqual( - cc.get_executable(key, compile_options, backend), None - ) + executable, compile_time = cc.get_executable_and_time( + key, compile_options, backend) + self.assertIsNone(executable) + self.assertIsNone(compile_time) def 
test_diff_executables(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -287,11 +290,13 @@ def test_diff_executables(self): backend = xla_bridge.get_backend() executable1 = backend.compile(computation1, compile_options) executable2 = backend.compile(computation2, compile_options) - cc.put_executable("key1", "computation1", executable1, backend) - cc.put_executable("key2", "computation2", executable2, backend) + cc.put_executable_and_time( + "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) + cc.put_executable_and_time( + "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) self.assertNotEqual( - cc.get_executable("key1", compile_options, backend), - cc.get_executable("key2", compile_options, backend), + cc.get_executable_and_time("key1", compile_options, backend)[0], + cc.get_executable_and_time("key2", compile_options, backend)[0] ) def test_put_executable(self): @@ -309,8 +314,10 @@ def test_put_executable(self): backend = xla_bridge.get_backend() executable = backend.compile(str(computation), compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) - cc.put_executable(key, "alambda", executable, backend) - deserialized_executable = cc.get_executable(key, compile_options, backend) + cc.put_executable_and_time( + key, "alambda", executable, backend, FAKE_COMPILE_TIME) + executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( + key, compile_options, backend) inputs_to_executable = ( np.array(1, dtype=np.int32), np.array(2, dtype=np.int32), @@ -319,9 +326,10 @@ def test_put_executable(self): executable, inputs_to_executable, backend ) actual = xla_client.execute_with_python_values( - deserialized_executable, inputs_to_executable, backend + executable_retrieved, inputs_to_executable, backend ) self.assertEqual(expected, actual) + self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) def test_pmap(self): with tempfile.TemporaryDirectory() as tmpdir: