Skip to content

Commit

Permalink
Include compile time along with executable in cache entry.
Browse files Browse the repository at this point in the history
In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit.

Testing: updated tests.
PiperOrigin-RevId: 549439628
  • Loading branch information
jax authors committed Jul 19, 2023
1 parent 5ae3ac2 commit 0c4c020
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 43 deletions.
83 changes: 61 additions & 22 deletions jax/_src/compilation_cache.py
Expand Up @@ -71,49 +71,59 @@ def initialize_cache(path):
logger.warning("Initialized persistent compilation cache at %s", path)


def get_executable(
def get_executable_and_time(
cache_key: str, compile_options, backend
) -> Optional[xla_client.LoadedExecutable]:
"""Returns the cached executable if present, or None otherwise."""
assert (
_cache is not None
), "initialize_cache must be called before you can call get_executable()"
serialized_executable = _cache.get(cache_key)
if not serialized_executable:
return None
) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]:
"""Returns the cached executable and its compilation time if present, or None
otherwise.
"""
assert _cache is not None, (
"initialize_cache must be called before you can call"
" get_executable_and_time()"
)
executable_and_time = _cache.get(cache_key)
if not executable_and_time:
return None, None
if zstandard:
decompressor = zstandard.ZstdDecompressor()
serialized_executable = decompressor.decompress(serialized_executable)
executable_and_time = decompressor.decompress(executable_and_time)
else:
serialized_executable = zlib.decompress(serialized_executable)
executable_and_time = zlib.decompress(executable_and_time)
serialized_executable, compile_time = extract_executable_and_time(
executable_and_time)
xla_executable_deserialized = backend.deserialize_executable(
serialized_executable, compile_options
)
return xla_executable_deserialized
serialized_executable, compile_options)
return xla_executable_deserialized, compile_time


def put_executable(
def put_executable_and_time(
cache_key: str,
module_name: str,
executable: xla_client.LoadedExecutable,
backend,
compile_time: int
) -> None:
"""Adds 'executable' to the cache, possibly evicting older entries."""
assert (
_cache is not None
), "initialize_cache must be called before you can call put_executable()"
"""Adds the 'executable' and its compilation time to the cache repository,
possibly evicting older entries.
"""
assert _cache is not None, (
"initialize_cache must be called before you can call"
" put_executable_and_time()"
)
logger.info(
"Writing %s to persistent compilation cache with key %s.",
module_name,
cache_key,
)
serialized_executable = backend.serialize_executable(executable)
executable_and_time = combine_executable_and_time(
serialized_executable, compile_time)
if zstandard:
compressor = zstandard.ZstdCompressor()
serialized_executable = compressor.compress(serialized_executable)
executable_and_time = compressor.compress(executable_and_time)
else:
serialized_executable = zlib.compress(serialized_executable)
_cache.put(cache_key, serialized_executable)
executable_and_time = zlib.compress(executable_and_time)
_cache.put(cache_key, executable_and_time)


def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
Expand Down Expand Up @@ -375,3 +385,32 @@ def reset_cache():
assert is_initialized()
logger.info("Resetting cache at %s.", _cache._path)
_cache = None


def combine_executable_and_time(
    serialized_executable: bytes, compile_time: int
) -> bytes:
    """Build a cache entry from a serialized executable and its compile time.

    The cache entry has the following layout:

        bytes 0..3   compilation time, unsigned big-endian integer
        bytes 4..    serialized executable, unmodified

    Args:
      serialized_executable: the bytes of the serialized executable.
      compile_time: the compilation time to store. It is coerced with
        ``int()``, so a float is truncated toward zero.

    Returns:
      The 4-byte time header followed by ``serialized_executable``.

    Raises:
      OverflowError: if the coerced time is negative or does not fit in
        4 bytes (i.e. is greater than 2**32 - 1).
    """
    # Fixed-width header so the reader can split the entry without a delimiter.
    time_header = int(compile_time).to_bytes(4, byteorder='big')
    return time_header + serialized_executable


def extract_executable_and_time(
    exectuable_and_time: bytes
) -> tuple[bytes, int]:
    """Split a cache entry into the serialized executable and its compile time.

    The cache entry is expected to have the following layout:

        bytes 0..3   compilation time, unsigned big-endian integer
        bytes 4..    serialized executable

    Args:
      exectuable_and_time: the (already decompressed) cache entry. The
        parameter name keeps its original spelling so existing keyword
        callers continue to work.

    Returns:
      A ``(serialized_executable, compile_time)`` tuple.

    Raises:
      ValueError: if the entry is shorter than the 4-byte time header, which
        means it is truncated or was not written in this format.
    """
    header_size = 4
    # Slicing a too-short entry would silently yield a bogus time and an empty
    # executable; fail loudly so the caller can treat it as a bad cache entry.
    if len(exectuable_and_time) < header_size:
        raise ValueError(
            "Cache entry is too short to contain a compilation time header: "
            f"expected at least {header_size} bytes, got "
            f"{len(exectuable_and_time)}."
        )
    compile_time = int.from_bytes(
        exectuable_and_time[:header_size], byteorder='big')
    return exectuable_and_time[header_size:], compile_time
29 changes: 18 additions & 11 deletions jax/_src/dispatch.py
Expand Up @@ -500,11 +500,13 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend)

cached_executable = _cache_read(module_name, cache_key, compile_options,
backend)
if cached_executable is not None:
executable, compile_time_retrieved = _cache_read(
module_name, cache_key, compile_options, backend)
if executable is not None:
# TODO(b/289098047): Will instrument a metric which uses the 'compile_time'
# to measure the savings due to the cache hit.
logger.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable
return executable
else:
start_time = time.monotonic()
executable = backend_compile(backend, computation,
Expand All @@ -517,25 +519,30 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,

def _cache_read(
module_name: str, cache_key: str, compile_options, backend
) -> Optional[xc.LoadedExecutable]:
"""Looks up `computation` in the persistent compilation cache."""
) -> tuple[Optional[xc.LoadedExecutable], Optional[int]]:
"""Looks up the `computation` and its compilation time in the persistent
compilation cache repository.
"""
try:
return compilation_cache.get_executable(cache_key, compile_options, backend)
return compilation_cache.get_executable_and_time(
cache_key, compile_options, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
return None
return None, None


def _cache_write(cache_key: str,
compile_time_secs: float,
module_name: str,
backend: Backend, executable: xc.LoadedExecutable,
host_callbacks: list[Any]):
"""Writes `serialized_computation` to the persistent compilation cache."""
"""Writes the `serialized_computation` and its compilation time to the
persistent compilation cache repository.
"""
if host_callbacks:
logger.info(
"Not writing persistent cache entry for '%s' because it uses host "
Expand All @@ -557,8 +564,8 @@ def _cache_write(cache_key: str,
compile_time_secs)

try:
compilation_cache.put_executable(cache_key, module_name, executable,
backend)
compilation_cache.put_executable_and_time(
cache_key, module_name, executable, backend, int(compile_time_secs))
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
Expand Down
28 changes: 18 additions & 10 deletions tests/compilation_cache_test.py
Expand Up @@ -46,6 +46,8 @@
config.parse_flags_with_absl()
FLAGS = config.FLAGS

FAKE_COMPILE_TIME = 10


@jtu.with_config(
jax_raise_persistent_cache_errors=True,
Expand Down Expand Up @@ -272,9 +274,10 @@ def test_get_no_executable(self):
)
backend = xla_bridge.get_backend()
key = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertEqual(
cc.get_executable(key, compile_options, backend), None
)
executable, compile_time = cc.get_executable_and_time(
key, compile_options, backend)
self.assertIsNone(executable)
self.assertIsNone(compile_time)

def test_diff_executables(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -287,11 +290,13 @@ def test_diff_executables(self):
backend = xla_bridge.get_backend()
executable1 = backend.compile(computation1, compile_options)
executable2 = backend.compile(computation2, compile_options)
cc.put_executable("key1", "computation1", executable1, backend)
cc.put_executable("key2", "computation2", executable2, backend)
cc.put_executable_and_time(
"key1", "computation1", executable1, backend, FAKE_COMPILE_TIME)
cc.put_executable_and_time(
"key2", "computation2", executable2, backend, FAKE_COMPILE_TIME)
self.assertNotEqual(
cc.get_executable("key1", compile_options, backend),
cc.get_executable("key2", compile_options, backend),
cc.get_executable_and_time("key1", compile_options, backend)[0],
cc.get_executable_and_time("key2", compile_options, backend)[0]
)

def test_put_executable(self):
Expand All @@ -309,8 +314,10 @@ def test_put_executable(self):
backend = xla_bridge.get_backend()
executable = backend.compile(str(computation), compile_options)
key = cc.get_cache_key(computation, devices, compile_options, backend)
cc.put_executable(key, "alambda", executable, backend)
deserialized_executable = cc.get_executable(key, compile_options, backend)
cc.put_executable_and_time(
key, "alambda", executable, backend, FAKE_COMPILE_TIME)
executable_retrieved, compile_time_retrieved = cc.get_executable_and_time(
key, compile_options, backend)
inputs_to_executable = (
np.array(1, dtype=np.int32),
np.array(2, dtype=np.int32),
Expand All @@ -319,9 +326,10 @@ def test_put_executable(self):
executable, inputs_to_executable, backend
)
actual = xla_client.execute_with_python_values(
deserialized_executable, inputs_to_executable, backend
executable_retrieved, inputs_to_executable, backend
)
self.assertEqual(expected, actual)
self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved)

def test_pmap(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit 0c4c020

Please sign in to comment.