Commit

Don't try to use persistent compilation cache when running CPU computations.
skye committed Aug 6, 2021
1 parent 897b920 commit 8190286
Showing 3 changed files with 29 additions and 24 deletions.
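
For orientation, a minimal usage sketch (not part of this commit; the cache directory path below is a placeholder) showing where the persistent compilation cache is enabled from user code. With this change, the cache is consulted only on TPU backends, so on CPU the same code compiles without touching the cache.

# Illustrative only: enabling the persistent compilation cache from user code.
# The directory "/tmp/jax_cache" is a placeholder, not taken from this commit.
import jax
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/tmp/jax_cache")

@jax.jit
def add(x, y):
  return x + y

# On a TPU backend this compilation may be served from (or written to) the
# cache; on CPU and other non-TPU backends the cache is now skipped entirely.
add(jnp.ones(8), jnp.ones(8))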
jax/experimental/compilation_cache/compilation_cache.py: 15 changes (7 additions & 8 deletions)
@@ -33,14 +33,13 @@ def initialize_cache(path, max_cache_size_bytes=32 * 2**30):
  _cache = FileSystemCache(path, max_cache_size_bytes)
  logging.warning(f"Initialized persistent compilation cache at {path}")

-def get_executable(xla_computation, compile_options) -> Optional[xla_client.Executable]:
+def get_executable(xla_computation, compile_options, backend) -> Optional[xla_client.Executable]:
  """Returns the cached executable if present, or None otherwise."""
  assert _cache is not None, "initialize_cache must be called before you can call get_executable()"
-  cache_key = get_cache_key(xla_computation, compile_options)
+  cache_key = get_cache_key(xla_computation, compile_options, backend)
  xla_executable_serialized = _cache.get(cache_key)
  if not xla_executable_serialized:
    return None
-  backend = jax.lib.xla_bridge.get_backend()
  # TODO(skye): xla_computation.get_hlo_module() is the unoptimized HLO but it should
  #be optimized
  xla_executable_deserialized = backend.deserialize_executable(
@@ -49,15 +48,15 @@ def get_executable(xla_computation, compile_options) -> Optional[xla_client.Executable]:
      compile_options)
  return xla_executable_deserialized

-def put_executable(xla_computation, compile_options, executable: xla_client.Executable):
+def put_executable(xla_computation, compile_options, executable: xla_client.Executable,
+                   backend):
  """Adds 'executable' to the cache, possibly evicting older entries."""
  assert _cache is not None, "initialize_cache must be called before you can call put_executable()"
-  cache_key = get_cache_key(xla_computation, compile_options)
-  backend = jax.lib.xla_bridge.get_backend()
+  cache_key = get_cache_key(xla_computation, compile_options, backend)
  serialized_executable = backend.serialize_executable(executable)
  _cache.put(cache_key, serialized_executable)

-def get_cache_key(xla_computation, compile_options) -> str:
+def get_cache_key(xla_computation, compile_options, backend) -> str:
  """Creates a hashed string to use as a key to the compilation cache.

  get_cache_key takes in the xla_computation and compile_options of a program and hashes
@@ -89,7 +88,7 @@ def get_cache_key(xla_computation, compile_options) -> str:
  hash_obj.update(bytes(jax.lib.version))
  if logging.vlog_is_on(1):
    logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}")
-  _hash_platform(hash_obj, jax.lib.xla_bridge.get_backend())
+  _hash_platform(hash_obj, backend)
  if logging.vlog_is_on(1):
    logging.vlog(1, f"get_cache_key hash after serializing the backend: {hash_obj.digest().hex()}")
  return hash_obj.digest().hex()
jax/interpreters/xla.py: 8 changes (5 additions & 3 deletions)
@@ -67,14 +67,16 @@
def compile_or_get_cached(backend, computation, compile_options):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc
-  if cc.is_initialized():
-    cached_executable = cc.get_executable(computation, compile_options)
+  # Persistent compilation cache only implemented on TPU.
+  # TODO(skye): add warning when initializing cache on unsupported default platform
+  if cc.is_initialized() and backend.platform == 'tpu':
+    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit')
      return cached_executable
    else:
      compiled = backend_compile(backend, computation, compile_options)
-      cc.put_executable(computation, compile_options, compiled)
+      cc.put_executable(computation, compile_options, compiled, backend)
      return compiled
  return backend_compile(backend, computation, compile_options)
tests/compilation_cache_test.py: 30 changes (17 additions & 13 deletions)
@@ -117,32 +117,36 @@ def test_same_hash_key(self):
    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    compile_options = jax.lib.xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
-    self.assertEqual(cc.get_cache_key(computation, compile_options),
-                     cc.get_cache_key(computation, compile_options))
+    backend = jax.lib.xla_bridge.get_backend()
+    self.assertEqual(cc.get_cache_key(computation, compile_options, backend),
+                     cc.get_cache_key(computation, compile_options, backend))

  def test_different_hash_key(self):
    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    compile_options_filled = self.filled_compile_options()
-    self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled),
-                        cc.get_cache_key(computation, compile_options_filled))
+    backend = jax.lib.xla_bridge.get_backend()
+    self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled, backend),
+                        cc.get_cache_key(computation, compile_options_filled, backend))

  def test_different_computations(self):
    computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
    computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
    compile_options = jax.lib.xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
-    self.assertNotEqual(cc.get_cache_key(computation1, compile_options),
-                        cc.get_cache_key(computation2, compile_options))
+    backend = jax.lib.xla_bridge.get_backend()
+    self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
+                        cc.get_cache_key(computation2, compile_options, backend))

  def test_get_no_executable(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cc.initialize_cache(tmpdir)
      computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
      compile_options = jax.lib.xla_bridge.get_compile_options(
          num_replicas=1, num_partitions=1)
-      self.assertEqual(cc.get_executable(computation, compile_options), None)
+      backend = jax.lib.xla_bridge.get_backend()
+      self.assertEqual(cc.get_executable(computation, compile_options, backend), None)

  def test_diff_executables(self):
    with tempfile.TemporaryDirectory() as tmpdir:
@@ -154,10 +158,10 @@ def test_diff_executables(self):
      backend = jax.lib.xla_bridge.get_backend()
      executable1 = backend.compile(computation1, compile_options)
      executable2 = backend.compile(computation2, compile_options)
-      cc.put_executable(computation1, compile_options, executable1)
-      cc.put_executable(computation2, compile_options, executable2)
-      self.assertNotEqual(cc.get_executable(computation1, compile_options),
-                          cc.get_executable(computation2, compile_options))
+      cc.put_executable(computation1, compile_options, executable1, backend)
+      cc.put_executable(computation2, compile_options, executable2, backend)
+      self.assertNotEqual(cc.get_executable(computation1, compile_options, backend),
+                          cc.get_executable(computation2, compile_options, backend))

  def test_put_executable(self):
    with tempfile.TemporaryDirectory() as tmpdir:
@@ -167,8 +171,8 @@ def test_put_executable(self):
          num_replicas=1, num_partitions=1)
      backend = jax.lib.xla_bridge.get_backend()
      executable = backend.compile(computation, compile_options)
-      cc.put_executable(computation, compile_options, executable)
-      deserialized_executable = cc.get_executable(computation, compile_options)
+      cc.put_executable(computation, compile_options, executable, backend)
+      deserialized_executable = cc.get_executable(computation, compile_options, backend)
      inputs_to_executable = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
      expected = jax.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
      actual = jax.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)