Skip to content

Commit

Permalink
Allow debug prints in staged out custom derivative functions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 467344265
  • Loading branch information
sharadmv authored and jax authors committed Aug 13, 2022
1 parent d42e1b8 commit 7cd81ca
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 10 deletions.
20 changes: 14 additions & 6 deletions jax/_src/custom_derivatives.py
Expand Up @@ -15,8 +15,8 @@
from functools import update_wrapper, reduce, partial
import inspect
import operator as op
from typing import (Callable, Generic, Optional, Sequence, Tuple, List, TypeVar,
Any)
from typing import (Any, Callable, Generic, List, Optional, Sequence, Set,
Tuple, TypeVar)

from jax import core
from jax import linear_util as lu
Expand Down Expand Up @@ -320,6 +320,8 @@ def _apply_todos(todos, outs):
outs = map(core.full_lower, todos_list.pop()(outs))
return outs


allowed_effects: Set[core.Effect] = set()
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')


Expand All @@ -329,8 +331,11 @@ def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):

def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
del args, params
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_jvp`.')
disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in
allowed_effects}
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_jvp`: {disallowed_effects}')
return fun_jaxpr.out_avals, fun_jaxpr.effects

custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
Expand Down Expand Up @@ -690,8 +695,11 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
return core.jaxpr_as_fun(fun_jaxpr)(*args)

def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_vjp`.')
disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in
allowed_effects}
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_vjp`: {disallowed_effects}')
return fun_jaxpr.out_avals, fun_jaxpr.effects

custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr')
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/debugging.py
Expand Up @@ -23,6 +23,7 @@
from jax import tree_util
from jax import lax
from jax._src import ad_checkpoint
from jax._src import custom_derivatives
from jax._src import lib as jaxlib
from jax._src import util
from jax.interpreters import ad
Expand All @@ -42,6 +43,8 @@
lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT)
ad_checkpoint.remat_allowed_effects.add(DebugEffect.PRINT)
ad_checkpoint.remat_allowed_effects.add(DebugEffect.ORDERED_PRINT)
custom_derivatives.allowed_effects.add(DebugEffect.PRINT)
custom_derivatives.allowed_effects.add(DebugEffect.ORDERED_PRINT)

# `debug_callback_p` is the main primitive for staging out Python callbacks.
debug_callback_p = core.Primitive('debug_callback')
Expand Down
4 changes: 0 additions & 4 deletions jax/interpreters/partial_eval.py
Expand Up @@ -1914,8 +1914,6 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_jvp`.')
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
jvp_jaxpr_thunk = _memoize(
Expand All @@ -1940,8 +1938,6 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
if fun_jaxpr.effects:
raise NotImplementedError('Effects not supported in `custom_vjp`.')
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
fwd_jaxpr_thunk = _memoize(
Expand Down
58 changes: 58 additions & 0 deletions tests/debugging_primitives_test.py
Expand Up @@ -401,6 +401,64 @@ def policy(prim, *_, **params):
# print.
self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)

@jtu.skip_on_devices(*disabled_backends)
def test_debug_print_in_staged_out_custom_jvp(self):

@jax.jit
def f(x):
@jax.custom_jvp
def g(x):
debug_print("hello: {x}", x=x)
return x
def g_jvp(primals, tangents):
(x,), (t,) = primals, tangents
debug_print("goodbye: {x} {t}", x=x, t=t)
return x, t
g.defjvp(g_jvp)
return g(x)

with capture_stdout() as output:
f(2.)
jax.effects_barrier()
self.assertEqual(output(), "hello: 2.0\n")

with capture_stdout() as output:
jax.jvp(f, (2.,), (3.,))
jax.effects_barrier()
self.assertEqual(output(), "goodbye: 2.0 3.0\n")

@jtu.skip_on_devices(*disabled_backends)
def test_debug_print_in_staged_out_custom_vjp(self):

@jax.jit
def f(x):
@jax.custom_vjp
def g(x):
debug_print("hello: {x}", x=x)
return x
def g_fwd(x):
debug_print("hello fwd: {x}", x=x)
return x, x
def g_bwd(x, g):
debug_print("hello bwd: {x} {g}", x=x, g=g)
return (g,)
g.defvjp(fwd=g_fwd, bwd=g_bwd)
return g(x)

with capture_stdout() as output:
f(2.)
jax.effects_barrier()
self.assertEqual(output(), "hello: 2.0\n")

with capture_stdout() as output:
_, f_vjp = jax.vjp(f, 2.)
jax.effects_barrier()
self.assertEqual(output(), "hello fwd: 2.0\n")

with capture_stdout() as output:
f_vjp(3.0)
jax.effects_barrier()
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")

class DebugPrintControlFlowTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 7cd81ca

Please sign in to comment.