Skip to content

Commit

Permalink
Add tracers to LeakChecker error, and filter out false positives this…
Browse files Browse the repository at this point in the history
… way.

If we can't find any hanging tracers in the gc.get_referrers chain, is it
really a leak? Probably not!
  • Loading branch information
LenaMartens committed Jul 29, 2021
1 parent b6e25fa commit 2190734
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
24 changes: 21 additions & 3 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering
import gc
import itertools as it
from weakref import ref
import threading
Expand Down Expand Up @@ -736,6 +737,17 @@ def reset_trace_state() -> bool:
def cur_sublevel() -> Sublevel:
return thread_local_state.trace_state.substack[-1]

def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
It's possible there's none! eg. there's some cases where JAX itself holds a
reference to `x` inside of a lambda closure, and no tracers were leaked
by the user. In this case an empty list is returned.
"""
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
return tracers

@contextmanager
def new_main(trace_type: Type[Trace],
dynamic: bool = False,
Expand All @@ -761,7 +773,9 @@ def new_main(trace_type: Type[Trace],
t = ref(main)
del main
if t() is not None:
raise Exception(f'Leaked trace {t()}')
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')

@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
Expand All @@ -782,7 +796,9 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
t = ref(main)
del main
if t() is not None:
raise Exception('Leaked trace {}'.format(t()))
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')

@contextmanager
def eval_context():
Expand All @@ -802,7 +818,9 @@ def new_sublevel() -> Generator[None, None, None]:
t = ref(sublevel)
del sublevel
if t() is not None:
raise Exception(f'Leaked sublevel {t()}.')
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked sublevel {t()}. Leaked tracer(s): {leaked_tracers}.')

def full_lower(val):
if isinstance(val, Tracer):
Expand Down
17 changes: 17 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,23 @@ def g(x):
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
f(3)

def test_leak_checker_avoids_false_positive_custom_jvp(self):
# see https://github.com/google/jax/issues/5636
with jax.checking_leaks():
@api.custom_jvp
def t(y):
return y

def t_jvp(p, t):
pass

t.defjvp(t_jvp)

@jit
def s(y):
return t(y)
s(3) # doesn't crash

def test_default_backend(self):
first_local_device = api.local_devices()[0]
self.assertEqual(first_local_device.platform, api.default_backend())
Expand Down

0 comments on commit 2190734

Please sign in to comment.