Skip to content

Commit

Permalink
Merge pull request #4342 from google:improve-tracer-error
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 333841912
  • Loading branch information
jax authors committed Sep 26, 2020
2 parents aaa5724 + 23a25da commit 5b3cbc5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 11 deletions.
2 changes: 0 additions & 2 deletions jax/core.py
Expand Up @@ -846,8 +846,6 @@ def raise_concretization_error(val: Tracer, context=""):
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
+ context + "\n\n"
+ val._origin_msg() + "\n\n"
+ "You can use transformation parameters such as `static_argnums` for "
"`jit` to avoid tracing particular arguments of transformed functions.\n\n"
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
f"Encountered tracer value: {val}")
raise ConcretizationTypeError(msg)
Expand Down
33 changes: 24 additions & 9 deletions jax/interpreters/partial_eval.py
Expand Up @@ -800,14 +800,24 @@ def _contents(self):
return ()

def _origin_msg(self):
progenitor_eqns = self._trace.frame.find_progenitors(self)
msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
if msgs:
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
if invar_pos:
origin = (f"While tracing the function {self._trace.main.source_info}, "
"this concrete value was not available in Python because it "
"depends on the value of the arguments to "
f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
"and the computation of these values is being staged out "
"(that is, delayed rather than executed eagerly).\n\n"
"You can use transformation parameters such as `static_argnums` "
"for `jit` to avoid tracing particular arguments of transformed "
"functions, though at the cost of more recompiles.")
elif progenitor_eqns:
msts = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
origin = (f"While tracing the function {self._trace.main.source_info}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msgs))
"\n\n" + "\n\n".join(msts))
else:
origin = ("The error occured while tracing the function "
f"{self._trace.main.source_info}.")
Expand All @@ -820,7 +830,7 @@ def _assert_live(self) -> None:

class JaxprStackFrame:
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
'tracers', 'eqns']
'tracers', 'eqns', 'invars']

def __init__(self):
self.newvar = core.gensym()
Expand All @@ -829,6 +839,7 @@ def __init__(self):
self.constvar_to_val = {}
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
self.eqns = [] # cleared when we pop frame from main
self.invars = []

def to_jaxpr(self, in_tracers, out_tracers):
invars = [self.tracer_to_var[id(t)] for t in in_tracers]
Expand All @@ -850,7 +861,10 @@ def find_progenitors(self, tracer):
if produced:
active_vars.difference_update(produced)
active_vars.update(eqn.invars)
return [eqn for eqn in self.eqns if set(eqn.invars) & active_vars]
invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
constvars = active_vars & set(self.constvar_to_val)
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns

def _inline_literals(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
Expand Down Expand Up @@ -890,7 +904,8 @@ def frame(self):
def new_arg(self, aval):
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(aval)
self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
self.frame.invars.append(var)
return tracer

def new_const(self, val):
Expand Down
27 changes: 27 additions & 0 deletions tests/api_test.py
Expand Up @@ -1797,6 +1797,33 @@ def f():

f() # doesn't crash

def test_concrete_error_because_arg(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")

@jax.jit
def f(x, y):
if x > y:
return x
else:
return y

msg = r"at flattened positions \[0, 1\]"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)

def test_concrete_error_because_const(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")

@jax.jit
def f():
assert jnp.add(1, 1) > 0

msg = "on these lines"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

def test_xla_computation_zeros_doesnt_device_put(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
Expand Down

0 comments on commit 5b3cbc5

Please sign in to comment.