Skip to content

Commit

Permalink
Default-enable the Jax persistent compilation cache.
Browse files Browse the repository at this point in the history
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
  • Loading branch information
jax authors committed Nov 27, 2023
1 parent 5274ca9 commit b9b5410
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 38 deletions.
108 changes: 72 additions & 36 deletions jax/_src/compilation_cache.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
import threading
from typing import Optional
import zlib

Expand All @@ -24,9 +25,9 @@
except ImportError:
zstandard = None

from jax._src import path as pathlib
from jax._src import cache_key
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.config import config
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
Expand All @@ -36,30 +37,59 @@

_cache: Optional[CacheInterface] = None

_cache_initialized: bool = False

def initialize_cache(path):
"""Creates a global cache object.
_cache_initialized_mutex = threading.Lock()

Should only be called once per process.

Will throw an assertion error if called a second time with a different path.
def get_file_cache(path: str) -> CacheInterface:
return GFileCache(path)

Only works for GPU and TPU backend as the CPU backend don't
implement yet the serialization API.

Args:
path: path for the cache directory.
def initialize_cache(path) -> None:
"""
global _cache
if _cache is not None and _cache._path == pathlib.Path(path):
logger.warning("Cache already previously initialized at %s", _cache._path)
return
Set the path. To take effect, should be called prior to any calls to
get_executable_and_time() and put_executable_and_time().
"""
config.update("jax_compilation_cache_dir", path)


def _is_cache_enabled() -> bool:
return config.jax_enable_compilation_cache


def _initialize_cache() -> None:
# Attempt to initialize the cache at most once.
global _cache_initialized
with _cache_initialized_mutex:
if _cache_initialized:
logger.info("_initialize_cache: cache has already been initialized!")
return
_cache_initialized = True

# Nothing to do if the cache is disabled.
if not _is_cache_enabled():
logger.warning("_initialize_cache: cache is disabled!")
return

assert (
_cache is None
), f"The cache path has already been initialized to {_cache._path}"
_cache = GFileCache(path)
logger.warning("Initialized persistent compilation cache at %s", path)
global _cache
assert _cache is None, "The cache has already been initialized!"
path: str = config.jax_compilation_cache_dir
# If the path is not set, the cache will not be enabled.
if not path:
return

_cache = get_file_cache(path)
logger.warning("Initialized persistent compilation cache at %s", path)


def _get_cache() -> Optional[CacheInterface]:
# TODO(b/289098047): consider making this an API and changing the callers of
# get_executable_and_time() and put_executable_and_time() to call get_cache()
# and passing the result to them.
if _cache is None:
_initialize_cache() # initialization is done at most once; see above
return _cache


def get_executable_and_time(
Expand All @@ -68,11 +98,11 @@ def get_executable_and_time(
"""Returns the cached executable and its compilation time if present, or None
otherwise.
"""
assert _cache is not None, (
"initialize_cache must be called before you can call"
" get_executable_and_time()"
)
executable_and_time = _cache.get(cache_key)
cache = _get_cache()
if cache is None:
logger.info("get_executable_and_time: cache is disabled/not initialized")
return None, None
executable_and_time = cache.get(cache_key)
if not executable_and_time:
return None, None
if zstandard:
Expand All @@ -94,13 +124,13 @@ def put_executable_and_time(
backend,
compile_time: int
) -> None:
"""Adds the 'executable' and its compilation time to the cache repository,
possibly evicting older entries.
"""Adds the 'executable' and its compilation time to the cache, possibly
evicting older entries.
"""
assert _cache is not None, (
"initialize_cache must be called before you can call"
"put_executable_and_time()"
)
cache = _get_cache()
if cache is None:
logger.info("put_executable_and_time: cache is disabled/not initialized")
return
logger.info(
"Writing %s to persistent compilation cache with key %s.",
module_name,
Expand All @@ -114,7 +144,7 @@ def put_executable_and_time(
executable_and_time = compressor.compress(executable_and_time)
else:
executable_and_time = zlib.compress(executable_and_time)
_cache.put(cache_key, executable_and_time)
cache.put(cache_key, executable_and_time)


def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
Expand All @@ -124,17 +154,23 @@ def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
produce_original_cache_key)


def is_initialized():
"""Return True is there is a cache initialized.
def is_initialized() -> bool:
"""
Return whether the cache is enabled. Initialization can be deferred, so
initialized status is not checked. The name is retained for backwards
compatibility.
"""
return _cache is not None
return _is_cache_enabled()


def reset_cache():
def reset_cache() -> None:
"""Get back to pristine, uninitialized state."""
global _cache
assert is_initialized()
logger.info("Resetting cache at %s.", _cache._path)
global _cache_initialized
logger.info("Resetting cache at %s.",
_cache._path if _cache is not None else "<empty>")
_cache = None
_cache_initialized = False


def combine_executable_and_time(
Expand Down
18 changes: 18 additions & 0 deletions jax/_src/config.py
Expand Up @@ -975,6 +975,24 @@ def _update_jax_memories_thread_local(val):
"deployed, this flag and the original cache-key generation algorithm "
"will be removed.")

enable_compilation_cache = define_bool_state(
name='jax_enable_compilation_cache',
default=True,
help=('If set to False, the compilation cache will be disabled regardless '
'of whether initialize_cache() was called. If set to True, the '
'path could be set to a default value or via a call to '
'initialize_cache().'),
)

compilation_cache_dir = define_string_state(
name='jax_compilation_cache_dir',
default=None,
help=('Path for the cache. '
'Precedence: '
'1. A call to compilation_cache.initialize_cache(). '
'2. The value of this flag set in the command line or by default.'),
)

default_dtype_bits = define_enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
Expand Down
1 change: 1 addition & 0 deletions jax/_src/test_util.py
Expand Up @@ -940,6 +940,7 @@ def setUpClass(cls):
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
cls._compilation_cache_exit_stack = ExitStack()
stack = cls._compilation_cache_exit_stack
stack.enter_context(config.enable_compilation_cache(True))
stack.enter_context(config.raise_persistent_cache_errors(True))
stack.enter_context(config.persistent_cache_min_compile_time_secs(0))

Expand Down
43 changes: 41 additions & 2 deletions tests/compilation_cache_test.py
Expand Up @@ -59,6 +59,7 @@ def increment_event_count(event):


@jtu.with_config(
jax_enable_compilation_cache=True,
jax_raise_persistent_cache_errors=True,
jax_persistent_cache_min_compile_time_secs=0,
)
Expand Down Expand Up @@ -245,7 +246,7 @@ def test_cache_write_warning(self):

with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._cache.__class__, "put") as mock_put,
mock.patch.object(cc._get_cache().__class__, "put") as mock_put,
warnings.catch_warnings(record=True) as w,
):
mock_put.side_effect = RuntimeError("test error")
Expand All @@ -266,7 +267,7 @@ def test_cache_read_warning(self):

with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._cache.__class__, "get") as mock_get,
mock.patch.object(cc._get_cache().__class__, "get") as mock_get,
warnings.catch_warnings(record=True) as w,
):
mock_get.side_effect = RuntimeError("test error")
Expand Down Expand Up @@ -429,5 +430,43 @@ def test_cache_hits_metric(self, use_original):
- previous_counts["/jax/compilation_cache/cache_hits"],
1)


@jtu.with_config(
jax_enable_compilation_cache=False,
jax_persistent_cache_min_compile_time_secs=0,
)
class CompilationCacheDisabledTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()

# Reset cache if already initialized by JaxTestCase
if cc.is_initialized():
cc.reset_cache()

def tearDown(self):
if cc.is_initialized():
cc.reset_cache()
super().tearDown()

# If the cache is disabled, there should be no files in the cache directory.
# A call to initialize_cache() does not affect this.
def test_jit(self):
# Sequence of flag settings for config.jax_enable_compilation_cache:
# 1. Flag is disabled by @jtu.with_config() above.
# 2. Flag is enabled by JaxTestCase for some test configs
# (see test_util.py).
# We need the flag disabled for this test, so disable it below.
with (
tempfile.TemporaryDirectory() as tmpdir,
config.enable_compilation_cache(False),
):
cc.initialize_cache(tmpdir)
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 0)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b9b5410

Please sign in to comment.