Commit 2d72dc8
Enable receives in TPU callbacks and add tests
PiperOrigin-RevId: 466103306
sharadmv authored and jax authors committed Aug 8, 2022
1 parent 25dd62c commit 2d72dc8
Showing 4 changed files with 587 additions and 53 deletions. (Only three diffs are rendered below; the fourth file, the new tests/python_callback_test.py, accounts for the remaining additions.)
jax/_src/debugging.py: 28 changes (15 additions, 13 deletions)
@@ -50,17 +50,15 @@
 map, unsafe_map = util.safe_map, map
 
 @debug_callback_p.def_impl
-def debug_callback_impl(*flat_args, callback: Callable[..., Any],
-                        effect: DebugEffect, in_tree: tree_util.PyTreeDef):
+def debug_callback_impl(*args, callback: Callable[..., Any],
+                        effect: DebugEffect):
   del effect
-  args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
-  out = callback(*args, **kwargs)
-  return tree_util.tree_leaves(out)
+  return callback(*args)
 
 @debug_callback_p.def_effectful_abstract_eval
 def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
-                                 effect: DebugEffect, in_tree: tree_util.PyTreeDef):
-  del flat_avals, callback, in_tree
+                                 effect: DebugEffect):
+  del flat_avals, callback
   return [], {effect}
 
 def debug_callback_batching_rule(args, dims, **params):
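
Since flattening now happens inside the closure built by debug_callback (see the hunks below), the impl rule reduces to a plain call. A minimal sketch of the new control flow, with illustrative setup values (flat_cb and the direct call are hypothetical, not part of the diff):

    # Illustrative sketch only: the primitive's stored callback already takes
    # flat arguments, so the impl rule can invoke it directly.
    def debug_callback_impl(*args, callback, effect):
      del effect  # ordering metadata; irrelevant at execution time
      return callback(*args)

    flat_cb = lambda x, y: print(x + y) or []  # accepts flat args, returns []
    debug_callback_impl(1.0, 2.0, callback=flat_cb, effect=None)  # prints 3.0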
@@ -87,8 +85,8 @@ def debug_callback_jvp_rule(primals, tangents, **params):
 ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule
 
 def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
-                                  effect: DebugEffect, in_tree: tree_util.PyTreeDef):
-  del flat_args, callback, effect, in_tree
+                                  effect: DebugEffect):
+  del flat_args, callback, effect
   raise ValueError("Transpose doesn't support debugging callbacks.")
 ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule

@@ -175,19 +173,23 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
     of the computation are duplicated or dropped.
 
   Args:
-    callback: A Python callable.
+    callback: A Python callable. Its return value will be ignored.
     *args: The positional arguments to the callback.
     ordered: A keyword only argument used to indicate whether or not the
       staged out computation will enforce ordering of this callback w.r.t.
       other ordered callbacks.
-    **kwargs: The positional arguments to the callback.
+    **kwargs: The keyword arguments to the callback.
 
   Returns:
     The value of `callback(*args, **kwargs)`.
   """
   flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
   effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
-  return debug_callback_p.bind(*flat_args, callback=callback, effect=effect,
-                               in_tree=in_tree)
+  def _flat_callback(*flat_args):
+    args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
+    callback(*args, **kwargs)
+    return []
+  return debug_callback_p.bind(*flat_args, callback=_flat_callback,
+                               effect=effect)
 
 class _DebugPrintFormatChecker(string.Formatter):
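Dropping the in_tree parameter from the primitive is invisible to users. A hedged usage sketch, assuming the public jax.debug.callback entry point that forwards to debug_callback:

    import jax
    import jax.numpy as jnp

    def log_stats(x, *, label):
      # Runs on the host; any return value is ignored.
      print(f"{label}: mean={x.mean():.3f}")

    @jax.jit
    def f(x):
      # `ordered` is consumed by debug_callback; `label` is forwarded to the
      # callback via the kwargs that _flat_callback closes over.
      jax.debug.callback(log_stats, x, ordered=True, label="x")
      return x * 2

    f(jnp.arange(4.0))  # prints "x: mean=1.500"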
jax/interpreters/mlir.py: 147 changes (107 additions, 40 deletions)
@@ -1419,13 +1419,91 @@ def receive_from_host(channel: int, token: mhlo.TokenType,
   return token, result
 
 
+def _emit_tpu_python_callback(backend: xb.XlaBackend, ctx: LoweringRuleContext,
+    callback, token: Optional[Any], operands: List[ir.Value],
+    operand_avals: List[core.ShapedArray],
+    operand_shapes: List[xc.Shape],
+    result_avals: List[core.ShapedArray],
+    result_shapes: List[xc.Shape], *,
+    sharding: Optional[xc.OpSharding] = None
+    ) -> Tuple[List[ir.Value], Any, Any]:
+  token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
+  _wrapped_callback = callback
+
+  send_channels = []
+  if not operand_avals:
+    # If there are no operands to the callback, we need to insert a dummy send
+    # op or the callback will never be triggered!
+    # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
+    # MHLO builder.
+    callback_without_args = _wrapped_callback
+    def _wrapped_callback(*args):  # pylint: disable=function-redefined
+      del args
+      return callback_without_args()
+    send_channel = ctx.module_context.new_channel()
+    dummy_send_aval = core.ShapedArray((1,), np.float32)
+    dummy_send_val = ir_constant(np.zeros(1, np.float32))
+    operand_shapes = [*operand_shapes,
+                      xla.aval_to_xla_shapes(dummy_send_aval)[0]]
+    token = send_to_host(send_channel, token, dummy_send_val, dummy_send_aval,
+                         callback.__name__, sharding=sharding)
+    send_channels.append(send_channel)
+  else:
+    for operand, operand_aval in zip(operands, operand_avals):
+      if any(s == 0 for s in operand_aval.shape):
+        raise NotImplementedError(
+            "Callbacks with zero-dimensional values not supported on TPU.")
+      channel = ctx.module_context.new_channel()
+      token = send_to_host(channel, token, operand, operand_aval,
+                           callback.__name__, sharding=sharding)
+      send_channels.append(channel)
+
+  recv_channels = []
+  outputs = []
+  # `send-to-host`s can be interleaved by the transfer manager so we add in a
+  # dummy recv to sequence them (the recv can only happen after all the sends
+  # are done). We'd like to send back a 0-shaped array to avoid unnecessary
+  # copies but that currently doesn't work with the transfer
+  # manager as well.
+  # TODO(b/238239458): enable sending back a 0-dim array
+  # TODO(b/238239928): avoid interleaving sends in the transfer manager
+  if not result_avals:
+    callback_without_return_values = _wrapped_callback
+    def _wrapped_callback(*args):  # pylint: disable=function-redefined
+      callback_without_return_values(*args)
+      return (np.zeros(1, np.float32),)
+    recv_channel = ctx.module_context.new_channel()
+    dummy_recv_aval = core.ShapedArray((1,), np.float32)
+    result_shapes = [*result_shapes,
+                     xla.aval_to_xla_shapes(dummy_recv_aval)[0]]
+    token, _ = receive_from_host(recv_channel, token, dummy_recv_aval,
+                                 callback.__name__, sharding=sharding)
+    recv_channels.append(recv_channel)
+  else:
+    for result_aval in result_avals:
+      if any(s == 0 for s in result_aval.shape):
+        raise NotImplementedError(
+            "Callbacks with zero-dimensional values not supported on TPU.")
+      channel = ctx.module_context.new_channel()
+      assert isinstance(result_aval, core.ShapedArray)
+      token, out = receive_from_host(channel, token, result_aval,
+                                     callback.__name__, sharding=sharding)
+      outputs.append(out)
+      recv_channels.append(channel)
+  opaque = backend.make_python_callback_from_host_send_and_recv(
+      _wrapped_callback, operand_shapes, result_shapes, send_channels,
+      recv_channels)
+  ctx.module_context.add_host_callback(opaque)
+  return outputs, token, opaque
+
+
 def emit_python_callback(
     ctx: LoweringRuleContext, callback, token: Optional[Any],
-    operands: List[ir.Value], operand_avals: List[core.AbstractValue],
-    result_avals: List[core.AbstractValue],
+    operands: List[ir.Value], operand_avals: List[core.ShapedArray],
+    result_avals: List[core.ShapedArray],
     has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None
     ) -> Tuple[List[ir.Value], Any, Any]:
-  """Creates an MHLO `CustomCallOp` that calls back to the provided function."""
+  """Emits MHLO that calls back to a provided Python function."""
   platform = ctx.module_context.platform
   if platform in {"tpu"} and jax._src.lib.version < (0, 3, 15):
     raise ValueError(
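
The callback_without_args = _wrapped_callback dance in the new function is a deliberate idiom: each stage captures the previous binding under a fresh name before shadowing it, otherwise the redefined function would recurse into itself. A standalone sketch (names illustrative):

    # Sketch of the wrapper-chaining idiom used in _emit_tpu_python_callback.
    _wrapped_callback = lambda: print("hello from host")

    callback_without_args = _wrapped_callback
    def _wrapped_callback(*args):  # pylint: disable=function-redefined
      del args  # discard the dummy operand added to trigger the send
      return callback_without_args()

    _wrapped_callback(0.0)  # prints "hello from host"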
@@ -1438,47 +1516,36 @@ def emit_python_callback(
       [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
   operand_shapes = util.flatten(
       [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
-  if platform == "tpu":
-    if result_avals:
-      raise NotImplementedError(
-          "Callback with return values not supported on TPU.")
-    token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
-    send_channels = []
-    for operand, operand_aval in zip(operands, operand_avals):
-      channel = ctx.module_context.new_channel()
-      token = send_to_host(channel, token, operand, operand_aval,
-                           callback.__name__, sharding=sharding)
-      send_channels.append(channel)
-    recv_channels = []
-    recv_channel = ctx.module_context.new_channel()
-
-    # `send-to-host`s can be interleaved by the transfer manager so we add in a
-    # dummy recv to sequence them (the recv can only happen after all the sends
-    # are done). We'd like to send back a 0-shaped array to avoid unnecessary
-    # copies but that currently doesn't work with the transfer
-    # manager as well.
-    # TODO(b/238239458): enable sending back a 0-dim array
-    # TODO(b/238239928): avoid interleaving sends in the transfer manager
-    def _wrapped_callback(*args, **kwargs):
-      callback(*args, **kwargs)
-      return (np.zeros(1, np.float32),)
-
-    dummy_recv_aval = core.ShapedArray((1,), np.float32)
-    result_shapes = [*result_shapes, xla.aval_to_xla_shapes(dummy_recv_aval)[0]]
-    token, _ = receive_from_host(recv_channel, token, dummy_recv_aval,
-                                 callback.__name__, sharding=sharding)
-    recv_channels.append(recv_channel)
-    opaque = backend.make_python_callback_from_host_send_and_recv(
-        _wrapped_callback, operand_shapes, result_shapes, send_channels,
-        recv_channels)
-    ctx.module_context.add_host_callback(opaque)
-    return [], token, opaque
+  # First we apply checks to ensure output shapes and dtypes match the expected
+  # ones.
+  def _wrapped_callback(*args):
+    out_vals = callback(*args)
+    if len(out_vals) != len(result_avals):
+      raise RuntimeError(
+          "Mismatched number of outputs from callback. "
+          "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
+    for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
+      if out_val.shape != out_aval.shape:
+        raise RuntimeError(
+            f"Incorrect output shape for return value {i}: "
+            "Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape))
+      if out_val.dtype != out_aval.dtype:
+        raise RuntimeError(
+            f"Incorrect output dtype for return value {i}: "
+            "Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype))
+    return out_vals
+
+  if platform == "tpu":
+    return _emit_tpu_python_callback(backend, ctx, _wrapped_callback, token,
+        operands, operand_avals, operand_shapes, result_avals, result_shapes,
+        sharding=sharding)
   result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals])
-  wrapped_callback = callback
   if token:
-
-    def wrapped_callback(token, *args):  # type: ignore
-      return tuple((token, *callback(*args)))
+    callback_without_token = _wrapped_callback
+    def _wrapped_callback(token, *args):  # type: ignore  # pylint: disable=function-redefined
+      return (token, *callback_without_token(*args))
 
     operand_shapes = [
         xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
@@ -1489,7 +1556,7 @@ def wrapped_callback(token, *args):  # type: ignore
     ]
     operands = [token, *operands]
     result_types = [token_type()[0], *result_types]
   callback_descriptor, keepalive = (
-      backend.get_emit_python_callback_descriptor(wrapped_callback,
+      backend.get_emit_python_callback_descriptor(_wrapped_callback,
                                                   operand_shapes,
                                                   result_shapes))
   descriptor_operand = ir_constant(
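With the validation wrapper in place, a callback that returns the wrong shape or dtype now fails loudly at call time instead of corrupting the host transfer. A standalone sketch of the check, using hypothetical expected results:

    import numpy as np

    expected = [((2,), np.dtype(np.float32))]  # hypothetical (shape, dtype) per result

    def checked(callback):
      def wrapper(*args):
        out_vals = callback(*args)
        if len(out_vals) != len(expected):
          raise RuntimeError("Mismatched number of outputs from callback.")
        for i, (out, (shape, dtype)) in enumerate(zip(out_vals, expected)):
          if out.shape != shape:
            raise RuntimeError(f"Incorrect output shape for return value {i}")
          if out.dtype != dtype:
            raise RuntimeError(f"Incorrect output dtype for return value {i}")
        return out_vals
      return wrapper

    bad = checked(lambda: (np.zeros(3, np.float32),))
    # bad()  # raises RuntimeError: shape (3,) does not match expected (2,)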
tests/BUILD: 5 changes (5 additions, 0 deletions)
@@ -873,6 +873,11 @@ jax_test(
     ],
 )
 
+jax_test(
+    name = "python_callback_test",
+    srcs = ["python_callback_test.py"],
+)
+
 jax_test(
     name = "debugger_test",
     srcs = ["debugger_test.py"],
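A hedged sketch of the kind of round-trip coverage such a target can hold (illustrative only; the actual contents of python_callback_test.py are not rendered on this page, and jax.effects_barrier is assumed available for flushing pending callbacks):

    import numpy as np
    import jax
    import jax.numpy as jnp

    def test_callback_roundtrip():
      received = []

      @jax.jit
      def f(x):
        jax.debug.callback(lambda v: received.append(np.asarray(v)), x)
        return x + 1

      out = f(jnp.float32(1.0))
      jax.effects_barrier()  # assumed: waits for outstanding callbacks
      assert received and received[0] == 1.0
      assert out == 2.0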
