Fix race condition for weakref destructor by catching rare exceptions.
pschuh committed Apr 1, 2022
1 parent d9403f6 commit df1c478
Showing 2 changed files with 48 additions and 2 deletions.
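
The two "rare exceptions" this commit catches can be reproduced with a plain dict. The sketch below is an illustration only, not code from this commit (the dict name `cache` and its keys are made up): it shows why `next(iter(cache))` can raise RuntimeError when a weakref callback mutates the dict mid-eviction, and why the callback itself must tolerate a KeyError when eviction wins the race.

    cache = {("key", 1): "a", ("key", 2): "b"}

    # 1. RuntimeError: mutating a dict between iter() and next() invalidates
    #    the iterator, which is what a weakref callback firing mid-eviction does.
    it = iter(cache)
    del cache[("key", 1)]          # stands in for remove_key running during GC
    try:
      next(it)
    except RuntimeError as e:
      print("eviction race:", e)   # "dictionary changed size during iteration"

    # 2. KeyError: the callback can also lose the race and find its entry
    #    already evicted, so the delete must tolerate a missing key.
    try:
      del cache[("key", 1)]
    except KeyError:
      print("callback race: entry already evicted")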
jax/_src/util.py (34 changes: 32 additions & 2 deletions)
@@ -222,12 +222,28 @@ def wrapper(*args, **kwargs):
 CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
 
 def weakref_lru_cache(call: Callable, maxsize=2048):
"""
Least recently used cache decorator with weakref support.
The cache will take a weakref to the first argument of the wrapped function
and strong refs to all subsequent operations. In all other respects it should
behave similar to `functools.lru_cache`.
"""
   cache: Dict[Any, Any] = {}
   hits = misses = 0
   lock = threading.Lock()
 
   def remove_key(tctx, args, kwargs, weak_arg):
-    del cache[(weak_arg, tctx, args, kwargs)]
+    k = (weak_arg, tctx, args, kwargs)
+    try:
+      # This has a chance to race with the iteration in next(iter(cache)),
+      # but we cannot lock because GC can get triggered synchronously inside
+      # a critical section and will not relinquish control until the callback
+      # has finished. This would lead to a deadlock between this weakref
+      # cleanup function and any function below which locks.
+      del cache[k]
+    except KeyError:
+      pass
 
   def wrapped(weak_arg, *args, **kwargs):
     nonlocal hits, misses
@@ -250,8 +266,22 @@ def wrapped(weak_arg, *args, **kwargs):
     result = call(weak_arg, *args, **kwargs)
     with lock:
       cache[k] = result
+      num_errors = 0
       while len(cache) > maxsize:
-        del cache[next(iter(cache))]
+        try:
+          del_k = next(iter(cache))
+        # This happens if a weakref callback happens between iter and
+        # next. Just ignore the error. WeakKeyDictionary handles this
+        # by deferring the deletes, but that has a chance at leaking,
+        # and this solution is easier.
+        except RuntimeError:
+          num_errors += 1
+          if num_errors > len(cache):
+            # This must be some other problem.
+            raise
+          else:
+            continue
+        del cache[del_k]
     return result
 
   def cache_info():
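
For context, a minimal usage sketch of the decorator as the docstring above describes it; the names (`Context`, `expensive_lookup`) are illustrative and not part of JAX, and the behavior shown assumes the documented lru_cache-like semantics:

    from jax._src.util import weakref_lru_cache

    class Context:                    # illustrative key type (weak-referenceable)
      pass

    @weakref_lru_cache
    def expensive_lookup(ctx, flag):  # ctx held via weakref, flag via strong ref
      return object()                 # stands in for an expensive computation

    ctx = Context()
    first = expensive_lookup(ctx, True)
    assert expensive_lookup(ctx, True) is first  # cache hit while ctx is alive
    del ctx  # once ctx is collected, remove_key above drops its cache entries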
tests/util_test.py (16 changes: 16 additions & 0 deletions)
@@ -18,6 +18,7 @@
 from jax._src import test_util as jtu
 
 from jax.config import config
+from jax._src.util import weakref_lru_cache
 config.parse_flags_with_absl()
 FLAGS = config.FLAGS

@@ -63,6 +64,21 @@ def kw_to_positional(factor, *args, **kwargs):
     self.assertEqual(dict(three=6, four=8), scaled_kwargs)
     self.assertEqual(2, out_thunk())
 
+  def test_weakref_lru_cache(self):
+    @weakref_lru_cache
+    def example_cached_fn(key):
+      return object()
+
+    class Key:
+      def __init__(self):
+        # Make a GC loop.
+        self.ref_loop = [self]
+
+    stable_keys = [Key() for _ in range(2049)]
+    for i in range(10000):
+      example_cached_fn(stable_keys[i % len(stable_keys)])
+      example_cached_fn(Key())
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
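
The `ref_loop` cycle is the point of the test: a self-referencing Key is never freed by reference counting alone, so its weakref callback fires only when the cyclic garbage collector runs, which can be in the middle of the cache's own bookkeeping. A tiny sketch of that mechanism, separate from the test above (the names and the print are illustrative):

    import gc
    import weakref

    class Key:
      def __init__(self):
        self.ref_loop = [self]  # reference cycle: refcounting never frees this

    k = Key()
    r = weakref.ref(k, lambda ref: print("weakref callback fired"))
    del k          # the cycle keeps the object alive; no callback yet
    gc.collect()   # the cyclic collector reclaims it and the callback fires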
