diff --git a/jax/_src/util.py b/jax/_src/util.py index d4dff69a52d4..ae11cd9b4c80 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -222,12 +222,28 @@ def wrapper(*args, **kwargs): CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"]) def weakref_lru_cache(call: Callable, maxsize=2048): + """ + Least recently used cache decorator with weakref support. + + The cache will take a weakref to the first argument of the wrapped function + and strong refs to all subsequent arguments. In all other respects it should + behave similarly to `functools.lru_cache`. + """ cache: Dict[Any, Any] = {} hits = misses = 0 lock = threading.Lock() def remove_key(tctx, args, kwargs, weak_arg): - del cache[(weak_arg, tctx, args, kwargs)] + k = (weak_arg, tctx, args, kwargs) + try: + # This has a chance to race with the iteration in next(iter(cache)), + # but we cannot lock because GC can get triggered synchronously inside + # a critical section and will not relinquish control until the callback + # has finished. This would lead to a deadlock between this weakref + # cleanup function and any function below which locks. + del cache[k] + except KeyError: + pass def wrapped(weak_arg, *args, **kwargs): nonlocal hits, misses @@ -250,8 +266,22 @@ def wrapped(weak_arg, *args, **kwargs): result = call(weak_arg, *args, **kwargs) with lock: cache[k] = result + num_errors = 0 while len(cache) > maxsize: - del cache[next(iter(cache))] + try: + del_k = next(iter(cache)) + # This happens if a weakref callback happens between iter and + # next. Just ignore the error. WeakKeyDictionary handles this + # by deferring the deletes, but that has a chance at leaking, + # and this solution is easier. + except RuntimeError: + num_errors += 1 + if num_errors > len(cache): + # This must be some other problem. 
+ raise + else: + continue + del cache[del_k] return result def cache_info(): diff --git a/tests/util_test.py b/tests/util_test.py index a639f83b33ee..83fe55559652 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -18,6 +18,7 @@ from jax._src import test_util as jtu from jax.config import config +from jax._src.util import weakref_lru_cache config.parse_flags_with_absl() FLAGS = config.FLAGS @@ -63,6 +64,21 @@ def kw_to_positional(factor, *args, **kwargs): self.assertEqual(dict(three=6, four=8), scaled_kwargs) self.assertEqual(2, out_thunk()) + def test_weakref_lru_cache(self): + @weakref_lru_cache + def example_cached_fn(key): + return object() + + class Key: + def __init__(self): + # Make a GC loop. + self.ref_loop = [self] + + stable_keys = [Key() for _ in range(2049)] + for i in range(10000): + example_cached_fn(stable_keys[i % len(stable_keys)]) + example_cached_fn(Key()) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())