From 7cd81ca1a8118a821036c5daad073066da94515d Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 12 Aug 2022 19:48:40 -0700 Subject: [PATCH] Allow debug prints in staged out custom derivative functions PiperOrigin-RevId: 467344265 --- jax/_src/custom_derivatives.py | 20 +++++++---- jax/_src/debugging.py | 3 ++ jax/interpreters/partial_eval.py | 4 --- tests/debugging_primitives_test.py | 58 ++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 10 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index c02ab22f93c4..9f70b56b5ac3 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -15,8 +15,8 @@ from functools import update_wrapper, reduce, partial import inspect import operator as op -from typing import (Callable, Generic, Optional, Sequence, Tuple, List, TypeVar, - Any) +from typing import (Any, Callable, Generic, List, Optional, Sequence, Set, + Tuple, TypeVar) from jax import core from jax import linear_util as lu @@ -320,6 +320,8 @@ def _apply_todos(todos, outs): outs = map(core.full_lower, todos_list.pop()(outs)) return outs + +allowed_effects: Set[core.Effect] = set() custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') @@ -329,8 +331,11 @@ def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params): def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params): del args, params - if fun_jaxpr.effects: - raise NotImplementedError('Effects not supported in `custom_jvp`.') + disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in + allowed_effects} + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `custom_jvp`: {disallowed_effects}') return fun_jaxpr.out_avals, fun_jaxpr.effects custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr') @@ -690,8 +695,11 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): return core.jaxpr_as_fun(fun_jaxpr)(*args) def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): - if fun_jaxpr.effects: - raise NotImplementedError('Effects not supported in `custom_vjp`.') + disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in + allowed_effects} + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `custom_vjp`: {disallowed_effects}') return fun_jaxpr.out_avals, fun_jaxpr.effects custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr') diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 39f30b00e04d..69057c11055d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -23,6 +23,7 @@ from jax import tree_util from jax import lax from jax._src import ad_checkpoint +from jax._src import custom_derivatives from jax._src import lib as jaxlib from jax._src import util from jax.interpreters import ad @@ -42,6 +43,8 @@ lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT) ad_checkpoint.remat_allowed_effects.add(DebugEffect.PRINT) ad_checkpoint.remat_allowed_effects.add(DebugEffect.ORDERED_PRINT) +custom_derivatives.allowed_effects.add(DebugEffect.PRINT) +custom_derivatives.allowed_effects.add(DebugEffect.ORDERED_PRINT) # `debug_callback_p` is the main primitive for staging out Python callbacks. debug_callback_p = core.Primitive('debug_callback') diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 807bc6229e24..695dc4f4e745 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1914,8 +1914,6 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers): in_avals = [t.aval for t in tracers] with core.new_sublevel(): fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) - if fun_jaxpr.effects: - raise NotImplementedError('Effects not supported in `custom_jvp`.') closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) main_ = ref(self.main) jvp_jaxpr_thunk = _memoize( @@ -1940,8 +1938,6 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): in_avals = [t.aval for t in tracers] with core.new_sublevel(): fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) - if fun_jaxpr.effects: - raise NotImplementedError('Effects not supported in `custom_vjp`.') closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) main_ = ref(self.main) fwd_jaxpr_thunk = _memoize( diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index cc8c5dc1e0f5..bc36cdb1cb1b 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -401,6 +401,64 @@ def policy(prim, *_, **params): # print. self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2) + @jtu.skip_on_devices(*disabled_backends) + def test_debug_print_in_staged_out_custom_jvp(self): + + @jax.jit + def f(x): + @jax.custom_jvp + def g(x): + debug_print("hello: {x}", x=x) + return x + def g_jvp(primals, tangents): + (x,), (t,) = primals, tangents + debug_print("goodbye: {x} {t}", x=x, t=t) + return x, t + g.defjvp(g_jvp) + return g(x) + + with capture_stdout() as output: + f(2.) + jax.effects_barrier() + self.assertEqual(output(), "hello: 2.0\n") + + with capture_stdout() as output: + jax.jvp(f, (2.,), (3.,)) + jax.effects_barrier() + self.assertEqual(output(), "goodbye: 2.0 3.0\n") + + @jtu.skip_on_devices(*disabled_backends) + def test_debug_print_in_staged_out_custom_vjp(self): + + @jax.jit + def f(x): + @jax.custom_vjp + def g(x): + debug_print("hello: {x}", x=x) + return x + def g_fwd(x): + debug_print("hello fwd: {x}", x=x) + return x, x + def g_bwd(x, g): + debug_print("hello bwd: {x} {g}", x=x, g=g) + return (g,) + g.defvjp(fwd=g_fwd, bwd=g_bwd) + return g(x) + + with capture_stdout() as output: + f(2.) + jax.effects_barrier() + self.assertEqual(output(), "hello: 2.0\n") + + with capture_stdout() as output: + _, f_vjp = jax.vjp(f, 2.) + jax.effects_barrier() + self.assertEqual(output(), "hello fwd: 2.0\n") + + with capture_stdout() as output: + f_vjp(3.0) + jax.effects_barrier() + self.assertEqual(output(), "hello bwd: 2.0 3.0\n") class DebugPrintControlFlowTest(jtu.JaxTestCase):