Skip to content

Commit

Permalink
Expose UnexpectedTracerError and add docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
LenaMartens committed Jul 27, 2021
1 parent e7f0307 commit 19ee7b2
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 26 deletions.
3 changes: 2 additions & 1 deletion docs/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ along with representative examples of how one might fix them.
.. autoclass:: ConcretizationTypeError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerIntegerConversionError
.. autoclass:: TracerIntegerConversionError
.. autoclass:: UnexpectedTracerError
3 changes: 2 additions & 1 deletion jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from jax._src.util import cache, safe_zip, safe_map, split_list
from jax.api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
from jax.core import raise_to_shaped
from jax.errors import UnexpectedTracerError
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
Expand Down Expand Up @@ -507,7 +508,7 @@ def _check_for_tracers(x):
"This behavior recently changed in JAX. "
"See https://github.com/google/jax/blob/main/docs/custom_vjp_update.md "
"for more information.")
raise core.UnexpectedTracerError(msg)
raise UnexpectedTracerError(msg)

@lu.transformation_with_aux
def _flatten_fwd(in_tree, *args):
Expand Down
102 changes: 102 additions & 0 deletions jax/_src/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,105 @@ class TracerIntegerConversionError(JAXTypeError):
def __init__(self, tracer: "core.Tracer"):
super().__init__(
f"The __index__() method was called on the JAX Tracer object {tracer}")


class UnexpectedTracerError(JAXTypeError):
"""
This error occurs when you use a JAX value which has leaked out of a function.
What does it mean to leak a value? If you use a JAX transform on a
function which saves a value to an outer scope through a side-effect, this
will leak a `Tracer`. When you then use this leaked value in a different
operation, an `UnexpectedTracerError` will be thrown.
To fix this, you need to return the value out of the transformed function
explictly.
Tracers are created when you transform a function, eg. with `jit`, `pmap`,
`vmap`, `eval_shape`, ... Intermediate values of these transformed values will
be Tracers, and should not escape this function through a side-effect.
Life-cycle of a leaked Tracer
Consider the following example of a transformed function which leaks a value
to an outer scope::
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit # 1
... def side_effecting(x):
... y = x+1 # 3
... outs.append(y) # 4
>>> x = 1
>>> side_effecting(x) # 2
>>> outs[0]+1 # 5 # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
UnexpectedTracerError: Encountered an unexpected tracer.
In this example we leak a Traced value from an inner transformed scope to an
outer scope. We get an UnexpectedTracerError not when the value is leaked
but later, when the leaked value is used.
This example also demonstrates the life-cycle of a leaked Tracer:
1. A function is transformed (in this case, jitted)
2. The transformed function is called (kicking of an abstract trace of the
function and turning `x` into a Tracer)
3. The Tracer which will be leaked is created (an intermediate value of
a traced function is also a Tracer)
4. The Tracer is leaked (appended to a list in an outer scope, escaping
the function through a side-channel)
5. The leaked Tracer is used, and an UnexpectedTracerError is thrown.
The UnexpectedTracerError tries to point to these locations in your code by
including information about each stage. Respectively:
1. The name of the transformed function (`side_effecting`) and which
transform kicked of the trace (`jit`).
2. A reconstructed stack-trace of where the Tracer was created, which
includes where the transformed function was called. (`When the Tracer
was created, the final 5 stack frames (most recent last) excluding
JAX-internal frames were...`).
3. See the reconstructed stack-trace. This will point to the line of code
which created the leaked Tracer.
4. Currently not included in the error message, because this is difficult
to pin down! We can only tell you what the leaked value looks like
(what shape is has and where it was created) and what boundary it was
leaked over (the name of the transform and the name of the transformed
function).
5. The actual error stack-trace will point to where the value is used.
The error can be fixed by the returning the value out of the
transformed function::
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit
... def not_side_effecting(x):
... y = x+1
... return y
>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0]+1 # all good! no longer a leaked value.
DeviceArray(3, dtype=int32)
Leak checker
As discussed in point 2 and 3 above, we show a reconstructed stack-trace
because we only throw an error when the leaked Tracer is used, not when the
Tracer is leaked. We need to know the location where the Tracer was leaked
to fix the error. The leak checker is a debug option you can use to throw an
error as soon as a Tracer is leaked. (To be more exact, it will throw an
error when the transformed function from which the Tracer is leaked returns)
To enable the leak checker you can use the `JAX_CHECK_TRACER_LEAKS`
environment variable or the `with jax.checking_leaks()` context manager.
Note that this util is experimental and may have some false positives. It
works by disabling some JAX caches, so should only be used when debugging
as it will have a negative effect on performance.
"""

def __init__(self, msg: str):
super().__init__(msg)
4 changes: 1 addition & 3 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ._src import config as jax_config
from ._src.config import FLAGS, config
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError)
TracerIntegerConversionError, UnexpectedTracerError)
from . import linear_util as lu

from jax._src import source_info_util
Expand Down Expand Up @@ -460,8 +460,6 @@ def escaped_tracer_error(tracer, detail=None):
'manager.')
return UnexpectedTracerError(msg)

class UnexpectedTracerError(Exception): pass

class Tracer:
__array_priority__ = 1000
__slots__ = ['_trace', '__weakref__', '_line_info']
Expand Down
3 changes: 2 additions & 1 deletion jax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
ConcretizationTypeError,
NonConcreteBooleanIndexError,
TracerArrayConversionError,
TracerIntegerConversionError)
TracerIntegerConversionError,
UnexpectedTracerError)
3 changes: 2 additions & 1 deletion jax/experimental/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def loop_body(i, acc_arr):
from jax._src.lax import control_flow as lax_control_flow
from jax import tree_util
from jax import numpy as jnp
from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax._src.util import safe_map

Expand Down Expand Up @@ -415,7 +416,7 @@ def end_tracing_body(self):
in_tracers=in_tracers,
out_tracers=body_out_tracers,
trace=self.trace)
except core.UnexpectedTracerError as e:
except UnexpectedTracerError as e:
if "Tracer not among input tracers" in str(e):
raise ValueError("Body of cond_range or while_range should not use the "
"index variable returned by iterator.") from e
Expand Down
3 changes: 2 additions & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src.abstract_arrays import (make_shaped_array, array_types)
from ..core import (ConcreteArray, ShapedArray, AbstractToken,
Literal, pp_eqn_compact, raise_to_shaped, abstract_token)
from ..errors import UnexpectedTracerError
from jax._src.pprint_util import pp
from .._src.util import (partial, partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map)
Expand Down Expand Up @@ -682,7 +683,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
raise UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
Expand Down
29 changes: 15 additions & 14 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from jax import core, dtypes, lax
from jax._src import api
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters.sharded_jit import PartitionSpec as P
Expand Down Expand Up @@ -2193,13 +2194,13 @@ def helper_save_tracer(self, x):
def test_escaped_tracers_different_top_level_traces(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError, "Encountered an unexpected tracer"):
UnexpectedTracerError, "Encountered an unexpected tracer"):
api.jit(lambda x: self._saved_tracer)(0.)

def test_escaped_tracers_cant_lift_sublevels(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer",
re.DOTALL)):
Expand All @@ -2208,7 +2209,7 @@ def test_escaped_tracers_cant_lift_sublevels(self):
def test_escaped_tracers_tracer_from_higher_level(self):
api.grad(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer from a higher level",
re.DOTALL)):
Expand All @@ -2220,7 +2221,7 @@ def func1(x):
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
UnexpectedTracerError,
re.compile("Encountered an unexpected tracer",
re.DOTALL)):
api.jit(func1)(2.)
Expand All @@ -2230,7 +2231,7 @@ def func1(x):
api.grad(self.helper_save_tracer)(0.)
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Can't lift",
re.DOTALL)):
api.grad(func1)(2.)
Expand All @@ -2242,7 +2243,7 @@ def func1(x):
return x + self._saved_tracer

with self.assertRaisesRegex(
core.UnexpectedTracerError,
UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer not among input tracers",
re.DOTALL)):
Expand All @@ -2265,7 +2266,7 @@ def f(x, c):
def g():
lax.scan(f, None, None, length=2)

with self.assertRaisesRegex(core.UnexpectedTracerError,
with self.assertRaisesRegex(UnexpectedTracerError,
"was created on line"):
g()

Expand All @@ -2279,24 +2280,24 @@ def f(_, __):

lax.scan(f, None, None, length=2) # leaked a tracer! (of level 1!)

with self.assertRaisesRegex(core.UnexpectedTracerError,
with self.assertRaisesRegex(UnexpectedTracerError,
"was created on line"):
# The following call will try and raise the ones array to the count tracer
# level, which is no longer live.
jax.jit(jnp.add)(jnp.ones(()), count)

def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1

with self.assertRaisesRegex(core.UnexpectedTracerError,
with self.assertRaisesRegex(UnexpectedTracerError,
"for pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1

with self.assertRaisesRegex(core.UnexpectedTracerError,
with self.assertRaisesRegex(UnexpectedTracerError,
"for eval_shape"):
jax.eval_shape(self.helper_save_tracer, 1)
_ = self._saved_tracer+1
Expand Down Expand Up @@ -3278,7 +3279,7 @@ def g():
api.remat(g)()
api.remat(g)()

with self.assertRaisesRegex(core.UnexpectedTracerError, "global state"):
with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
api.jit(f)()

def test_no_cse_widget_on_primals(self):
Expand Down Expand Up @@ -4715,9 +4716,9 @@ def f_rev(x, cos_y, g):
def g(x, y):
return f(x, y)

with self.assertRaisesRegex(core.UnexpectedTracerError, "custom_vjp"):
with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"):
_ = g(2, 3.)
with self.assertRaisesRegex(core.UnexpectedTracerError, "custom_vjp"):
with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"):
_ = api.grad(g, 1)(2., 3.)

def test_vmap_axes(self):
Expand Down
7 changes: 3 additions & 4 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import jax
from jax._src import api
from jax import core
from jax.errors import UnexpectedTracerError
from jax import lax
from jax import random
from jax import test_util as jtu
Expand Down Expand Up @@ -2722,17 +2723,15 @@ def cond_fun(val):
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)

def test_unexpected_tracer_error(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"for while_loop"):
with self.assertRaisesRegex(UnexpectedTracerError, "for while_loop"):
lst = []
def side_effecting_body(val):
lst.append(val)
return val+1
lax.while_loop(lambda x: x < 2, side_effecting_body, 1)
lst[0] += 1

with self.assertRaisesRegex(core.UnexpectedTracerError,
"for scan"):
with self.assertRaisesRegex(UnexpectedTracerError, "for scan"):
lst = []
def side_effecting_scan(carry, val):
lst.append(val)
Expand Down

0 comments on commit 19ee7b2

Please sign in to comment.