diff --git a/jax/core.py b/jax/core.py index 7104cb342d81..4c755f1ee5b7 100644 --- a/jax/core.py +++ b/jax/core.py @@ -503,7 +503,7 @@ def escaped_tracer_error(tracer, detail=None): f'with shape {tracer.shape} and dtype {tracer.dtype} to escape.\n' 'JAX transformations require that functions explicitly return their ' 'outputs, and disallow saving intermediate values to global state.') - dbg = getattr(tracer._trace.main, 'debug_info', None) + dbg = getattr(tracer, '_debug_info', None) if dbg is not None: msg += ('\nThe function being traced when the value leaked was ' f'{dbg.func_src_info} traced for {dbg.traced_for}.') diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 067a97906782..688fb7efb0a0 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1526,11 +1526,13 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval'] + __slots__ = ['aval', '_debug_info'] def __init__(self, trace, aval, line_info=None): self._trace = trace self._line_info = line_info + # Needed for UnexpectedTracerError. + self._debug_info = self._trace.frame.debug_info self.aval = aval def full_lower(self): @@ -1547,29 +1549,25 @@ def _origin_msg(self): f"{source_info_util.summarize(self._line_info)}") else: invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) - dbg = self._trace.main.debug_info + dbg = self._debug_info if dbg is None: return "" + + origin = (f"The error occurred while tracing the function {dbg.func_src_info} " + f"for {dbg.traced_for}. ") if invar_pos: - origin = (f"While tracing the function {dbg.func_src_info} " - f"for {dbg.traced_for}, " - "this concrete value was not available in Python because it " - f"depends on the value{'s' if len(invar_pos) > 1 else ''} " - f"of {dbg.arg_info(invar_pos)}.") + origin += ("This concrete value was not available in Python because it " + f"depends on the value{'s' if len(invar_pos) > 1 else ''} " + f"of {dbg.arg_info(invar_pos)}.") elif progenitor_eqns: msts = [" operation " f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n" f" from line {source_info_util.summarize(eqn.source_info)}" for eqn in progenitor_eqns[:5]] # show at most 5 - origin = (f"While tracing the function {dbg.func_src_info} " - f"for {dbg.traced_for}, " - "this value became a tracer due to JAX operations on these lines:" - "\n\n" + "\n\n".join(msts)) + origin += ("This value became a tracer due to JAX operations on these lines:" + "\n\n" + "\n\n".join(msts)) if len(progenitor_eqns) > 5: origin += "\n\n(Additional originating lines are not shown.)" - else: - origin = (f"The error occurred while tracing the function {dbg.func_src_info} " - f"for {dbg.traced_for}.") return "\n" + origin def _assert_live(self) -> None: @@ -1591,6 +1589,7 @@ class JaxprStackFrame: eqns: List[JaxprEqn] invars: List[Var] effects: core.Effects + debug_info: Optional[DebugInfo] def __init__(self): self.gensym = core.gensym() @@ -1601,6 +1600,7 @@ def __init__(self): self.eqns = [] # cleared when we pop frame from main self.invars = [] self.effects = set() + self.debug_info = None def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) @@ -1821,10 +1821,13 @@ def process_call(self, call_primitive, f, explicit_tracers, params): # TODO(mattjj): check in_tracers are consistent with f.in_type annotation with core.new_sublevel(): if config.jax_check_tracer_leaks or not config.jax_experimental_subjaxpr_lowering_cache: - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main) + # TODO(lenamartens): Make call_primitive name -> API function name mapping. + # (currently this will display eg. 'xla_call' instead of `jit`) + dbg = debug_info_final(f, call_primitive.name) + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg) else: jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized( - f, self.main).val + f, self.main, call_primitive.name).val if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers) source_info = source_info_util.current() @@ -1862,7 +1865,7 @@ def process_map(self, map_primitive, f, tracers, params): with core.extend_axis_env(axis_name, axis_size, None): # type: ignore with core.new_sublevel(): jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic( - f, self.main, reduced_in_avals) + f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name)) ordered_effects = jaxpr.effects & core.ordered_effects if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2069,19 +2072,20 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, *, keep_inputs: Optional[List[bool]] = None): with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore - main.debug_info = debug_info # type: ignore main.jaxpr_stack = () # type: ignore jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs) + fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) del main, fun return jaxpr, out_avals, consts def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, in_avals: Sequence[AbstractValue], *, - keep_inputs: Optional[Sequence[bool]] = None): + keep_inputs: Optional[Sequence[bool]] = None, + debug_info: Optional[DebugInfo] = None): keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs frame = JaxprStackFrame() + frame.debug_info = debug_info with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): trace = DynamicJaxprTrace(main, core.cur_sublevel()) in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) @@ -2099,17 +2103,18 @@ def trace_to_jaxpr_dynamic2( fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None ) -> Tuple[Jaxpr, OutputType, List[Any]]: with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore - main.debug_info = debug_info # type: ignore main.jaxpr_stack = () # type: ignore - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main) + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) del main, fun return jaxpr, out_type, consts def trace_to_subjaxpr_dynamic2( - fun: lu.WrappedFun, main: core.MainTrace - ) -> Tuple[Jaxpr, OutputType, List[Any]]: + fun: lu.WrappedFun, main: core.MainTrace, + debug_info: Optional[DebugInfo] = None +) -> Tuple[Jaxpr, OutputType, List[Any]]: in_avals, keep_inputs = unzip2(fun.in_type) frame = JaxprStackFrame() + frame.debug_info = debug_info with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): trace = DynamicJaxprTrace(main, core.cur_sublevel()) in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) @@ -2123,8 +2128,10 @@ def trace_to_subjaxpr_dynamic2( @lu.cache def trace_to_subjaxpr_dynamic2_memoized(fun: lu.WrappedFun, - main: core.MainTrace): - return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main)) + main: core.MainTrace, + traced_for: str): + dbg = debug_info_final(fun, traced_for) + return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main, dbg)) class WrapperForWeakRef: @@ -2148,11 +2155,10 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None, keep_inputs: Optional[Sequence[bool]] = None): with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore - main.debug_info = debug_info # type: ignore main.jaxpr_stack = () # type: ignore with core.new_sublevel(): jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs) + fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) del fun, main return jaxpr, out_avals, consts @@ -2161,10 +2167,9 @@ def trace_to_jaxpr_final2( fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None ) -> Tuple[Jaxpr, OutputType, List[Any]]: with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore - main.debug_info = debug_info # type: ignore main.jaxpr_stack = () # type: ignore with core.new_sublevel(): - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main) + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) del fun, main return jaxpr, out_type, consts diff --git a/tests/api_test.py b/tests/api_test.py index 6583da55898e..b7fdb68ad823 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3233,6 +3233,20 @@ def f(): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f() + def test_concrete_error_with_nested_call(self): + @jax.jit + def f(x, y): + if y: + return x + + @jax.jit + def g(x): + return f(x, True) + + msg = r"on the value of the argument 'y'" + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + g(1) + def test_xla_computation_zeros_doesnt_device_put(self): with jtu.count_device_put() as count: api.xla_computation(lambda: jnp.zeros(3))()