Skip to content

Commit

Permalink
Fix ConcretizationError in nested calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
LenaMartens committed Jul 26, 2022
1 parent ec435c7 commit 53dfe35
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
2 changes: 1 addition & 1 deletion jax/core.py
Expand Up @@ -503,7 +503,7 @@ def escaped_tracer_error(tracer, detail=None):
f'with shape {tracer.shape} and dtype {tracer.dtype} to escape.\n'
'JAX transformations require that functions explicitly return their '
'outputs, and disallow saving intermediate values to global state.')
dbg = getattr(tracer._trace.main, 'debug_info', None)
dbg = getattr(tracer, '_debug_info', None)
if dbg is not None:
msg += ('\nThe function being traced when the value leaked was '
f'{dbg.func_src_info} traced for {dbg.traced_for}.')
Expand Down
65 changes: 35 additions & 30 deletions jax/interpreters/partial_eval.py
Expand Up @@ -1526,11 +1526,13 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
return move_binders_to_front(closed_jaxpr, map(op.not_, to_move))

class DynamicJaxprTracer(core.Tracer):
__slots__ = ['aval']
__slots__ = ['aval', '_debug_info']

def __init__(self, trace, aval, line_info=None):
self._trace = trace
self._line_info = line_info
# Needed for UnexpectedTracerError.
self._debug_info = self._trace.frame.debug_info
self.aval = aval

def full_lower(self):
Expand All @@ -1547,29 +1549,25 @@ def _origin_msg(self):
f"{source_info_util.summarize(self._line_info)}")
else:
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
dbg = self._trace.main.debug_info
dbg = self._debug_info
if dbg is None:
return ""

origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}. ")
if invar_pos:
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this concrete value was not available in Python because it "
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
origin += ("This concrete value was not available in Python because it "
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
elif progenitor_eqns:
msts = [" operation "
f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns[:5]] # show at most 5
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msts))
origin += ("This value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msts))
if len(progenitor_eqns) > 5:
origin += "\n\n(Additional originating lines are not shown.)"
else:
origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}.")
return "\n" + origin

def _assert_live(self) -> None:
Expand All @@ -1591,6 +1589,7 @@ class JaxprStackFrame:
eqns: List[JaxprEqn]
invars: List[Var]
effects: core.Effects
debug_info: Optional[DebugInfo]

def __init__(self):
self.gensym = core.gensym()
Expand All @@ -1601,6 +1600,7 @@ def __init__(self):
self.eqns = [] # cleared when we pop frame from main
self.invars = []
self.effects = set()
self.debug_info = None

def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
Expand Down Expand Up @@ -1821,10 +1821,13 @@ def process_call(self, call_primitive, f, explicit_tracers, params):
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
with core.new_sublevel():
if config.jax_check_tracer_leaks or not config.jax_experimental_subjaxpr_lowering_cache:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
# TODO(lenamartens): Make call_primitive name -> API function name mapping.
# (currently this will display eg. 'xla_call' instead of `jit`)
dbg = debug_info_final(f, call_primitive.name)
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
else:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized(
f, self.main).val
f, self.main, call_primitive.name).val
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *in_tracers)
source_info = source_info_util.current()
Expand Down Expand Up @@ -1862,7 +1865,7 @@ def process_map(self, map_primitive, f, tracers, params):
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals)
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = jaxpr.effects & core.ordered_effects
if ordered_effects:
raise ValueError("Ordered effects not supported for "
Expand Down Expand Up @@ -2069,19 +2072,20 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
*,
keep_inputs: Optional[List[bool]] = None):
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs)
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
del main, fun
return jaxpr, out_avals, consts

def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[Sequence[bool]] = None):
keep_inputs: Optional[Sequence[bool]] = None,
debug_info: Optional[DebugInfo] = None):
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

frame = JaxprStackFrame()
frame.debug_info = debug_info
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
Expand All @@ -2099,17 +2103,18 @@ def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
del main, fun
return jaxpr, out_type, consts

def trace_to_subjaxpr_dynamic2(
fun: lu.WrappedFun, main: core.MainTrace
) -> Tuple[Jaxpr, OutputType, List[Any]]:
fun: lu.WrappedFun, main: core.MainTrace,
debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
in_avals, keep_inputs = unzip2(fun.in_type)
frame = JaxprStackFrame()
frame.debug_info = debug_info
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
Expand All @@ -2123,8 +2128,10 @@ def trace_to_subjaxpr_dynamic2(

@lu.cache
def trace_to_subjaxpr_dynamic2_memoized(fun: lu.WrappedFun,
main: core.MainTrace):
return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main))
main: core.MainTrace,
traced_for: str):
dbg = debug_info_final(fun, traced_for)
return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main, dbg))


class WrapperForWeakRef:
Expand All @@ -2148,11 +2155,10 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
debug_info: Optional[DebugInfo] = None,
keep_inputs: Optional[Sequence[bool]] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs)
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
del fun, main
return jaxpr, out_avals, consts

Expand All @@ -2161,10 +2167,9 @@ def trace_to_jaxpr_final2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
) -> Tuple[Jaxpr, OutputType, List[Any]]:
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
del fun, main
return jaxpr, out_type, consts

Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Expand Up @@ -3233,6 +3233,20 @@ def f():
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

def test_concrete_error_with_nested_call(self):
@jax.jit
def f(x, y):
if y:
return x

@jax.jit
def g(x):
return f(x, True)

msg = r"on the value of the argument 'y'"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)

def test_xla_computation_zeros_doesnt_device_put(self):
with jtu.count_device_put() as count:
api.xla_computation(lambda: jnp.zeros(3))()
Expand Down

0 comments on commit 53dfe35

Please sign in to comment.