Skip to content

Commit

Permalink
Enable batching rule for debug_print
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
sharadmv and mattjj committed May 5, 2022
1 parent 24eb7d8 commit c1a8d7f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
27 changes: 22 additions & 5 deletions jax/_src/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@

from jax import core
from jax import tree_util
from jax import lax
from jax._src import util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.lax import control_flow as lcf
import jax.numpy as jnp

DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])

Expand All @@ -37,6 +40,8 @@
debug_callback_p = core.Primitive('debug_callback')
debug_callback_p.multiple_results = True

map, unsafe_map = util.safe_map, map

@debug_callback_p.def_impl
def debug_callback_impl(*flat_args, callback: Callable[..., Any],
effect: DebugEffect, in_tree: tree_util.PyTreeDef):
Expand All @@ -51,11 +56,23 @@ def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
del flat_avals, callback, in_tree
return [], {effect}

def debug_callback_batching_rule(*flat_args, callback: Callable[..., Any],
effect: DebugEffect, in_tree: tree_util.PyTreeDef):
del flat_args, callback, effect, in_tree
# TODO(sharadmv): implement batching rule
raise NotImplementedError('Batching not supported for `debug_callback`.')
def debug_callback_batching_rule(args, dims, **params):
"""Unrolls the debug callback across the mapped axis."""
axis_size = next(x.shape[i] for x, i in zip(args, dims)
if i is not None)
# TODO(sharadmv): implement in terms of rolled loop unstead of
# unrolled.
def get_arg_at_dim(i, dim, arg):
if dim is batching.not_mapped:
# Broadcast unmapped argument
return arg
return lax.index_in_dim(arg, i, axis=dim, keepdims=False)
outs = []
for i in range(axis_size):
args_idx = map(functools.partial(get_arg_at_dim, i), dims, args)
outs.append(debug_callback_p.bind(*args_idx, **params))
outs = [jnp.stack(xs) for xs in zip(*outs)]
return outs, (0,) * len(outs)
batching.primitive_batchers[debug_callback_p] = debug_callback_batching_rule

def debug_callback_jvp_rule(*flat_args, callback: Callable[..., Any],
Expand Down
49 changes: 49 additions & 0 deletions tests/debugging_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import io
import textwrap
from unittest import mock
Expand All @@ -27,6 +28,7 @@
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
import jax.numpy as jnp
import numpy as np

config.parse_flags_with_absl()

Expand Down Expand Up @@ -87,6 +89,53 @@ def f(x):
f(2)
self.assertEqual(output(), "x: 2\n")

@jtu.skip_on_devices("tpu", "gpu")
def test_can_stage_out_ordered_print_with_pytree(self):
@jax.jit
def f(x):
struct = dict(foo=x)
debug_print('x: {}', struct, ordered=True)
with capture_stdout() as output:
f(np.array(2, np.int32))
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")

class DebugPrintTransformationTest(jtu.JaxTestCase):

def test_debug_print_batching(self):
@jax.vmap
def f(x):
debug_print('hello: {}', x)
with capture_stdout() as output:
f(jnp.arange(2))
self.assertEqual(output(), "hello: 0\nhello: 1\n")

def test_debug_print_batching_with_diff_axes(self):
@functools.partial(jax.vmap, in_axes=(0, 1))
def f(x, y):
debug_print('hello: {} {}', x, y)
with capture_stdout() as output:
f(jnp.arange(2), jnp.arange(2)[None])
self.assertEqual(output(), "hello: 0 [0]\nhello: 1 [1]\n")

def tested_debug_print_with_nested_vmap(self):
def f(x):
debug_print('hello: {}', x)
# Call with
# [[0, 1],
# [2, 3],
# [4, 5]]
with capture_stdout() as output:
# Should print over 0-axis then 1-axis
jax.vmap(jax.vmap(f))(jnp.arange(6).reshape((3, 2)))
self.assertEqual(
output(),
"hello: 0\nhello: 2\nhello: 4\nhello: 1\nhello: 3\nhello: 5\n")
with capture_stdout() as output:
# Should print over 1-axis then 0-axis
jax.vmap(jax.vmap(f, in_axes=0), in_axes=1)(jnp.arange(6).reshape((3, 2)))
self.assertEqual(
output(),
"hello: 0\nhello: 1\nhello: 2\nhello: 3\nhello: 4\nhello: 5\n")

class DebugPrintControlFlowTest(jtu.JaxTestCase):

Expand Down

0 comments on commit c1a8d7f

Please sign in to comment.