diff --git a/jax/_src/api.py b/jax/_src/api.py index d941df54e8dd..049d2f65cf6e 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -891,7 +891,7 @@ def computation_maker(*args, **kwargs): if eff not in core.ordered_effects] ordered_effects = [eff for eff in jaxpr.effects if eff in core.ordered_effects] - m, _ = mlir.lower_jaxpr_to_module( + lowering_result = mlir.lower_jaxpr_to_module( f"xla_computation_{fun_name}", core.ClosedJaxpr(jaxpr, consts), unordered_effects=unordered_effects, @@ -906,7 +906,8 @@ def computation_maker(*args, **kwargs): map(xla.sharding_to_proto, out_parts_flat))) should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100) built = xc._xla.mlir.mlir_module_to_xla_computation( - mlir.module_to_string(m), use_tuple_args=should_tuple, + mlir.module_to_string(lowering_result.module), + use_tuple_args=should_tuple, return_tuple=True) out_shapes_flat = [ ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index f907dd66e9ee..3fbd02b7a60e 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -91,30 +91,20 @@ def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any], raise ValueError("Transpose doesn't support debugging callbacks.") ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule -def _ordered_effect_lowering(ctx, token, *args, **params): - avals_in = [core.abstract_token, *ctx.avals_in] - avals_out = [core.abstract_token, *ctx.avals_out] - args = (token, *args) - def _callback(token, *flat_args): - out = debug_callback_p.impl(*flat_args, **params) - return (token, *out) - (token, *result), keepalive = mlir.emit_python_callback( - ctx.module_context.platform, _callback, list(args), avals_in, avals_out, - True) - return result, keepalive, token - def debug_callback_lowering(ctx, *args, effect, callback, **params): + + def _callback(*flat_args): + return tuple( + debug_callback_p.impl( + *flat_args, effect=effect, callback=callback, **params)) if effect in core.ordered_effects: token = ctx.tokens_in.get(effect)[0] - result, keepalive, token = _ordered_effect_lowering(ctx, token, - *args, effect=effect, callback=callback, **params) + result, token, keepalive = mlir.emit_python_callback( + ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True) ctx.set_tokens_out(mlir.TokenSet({effect: (token,)})) else: - def _callback(*flat_args): - return tuple(debug_callback_p.impl( - *flat_args, effect=effect, callback=callback, **params)) - result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, - _callback, list(args), ctx.avals_in, ctx.avals_out, True) + result, token, keepalive = mlir.emit_python_callback( + ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True) ctx.module_context.add_keepalive(keepalive) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, @@ -122,6 +112,9 @@ def _callback(*flat_args): if jaxlib.version >= (0, 3, 11): mlir.register_lowering( debug_callback_p, debug_callback_lowering, platform="gpu") +if jaxlib.version >= (0, 3, 15): + mlir.register_lowering( + debug_callback_p, debug_callback_lowering, platform="tpu") def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args, **kwargs): diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 5e5d2fc1031b..0a876eddcefb 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -333,7 +333,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, name, None, True, None, None, None, jaxpr=jaxpr, consts=consts, device=device, in_avals=abstract_args, out_avals=out_avals, has_unordered_effects=False, ordered_effects=[], - kept_var_idx=kept_var_idx, keepalive=None) + kept_var_idx=kept_var_idx, keepalive=None, host_callbacks=[]) if not _on_exit: log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG @@ -372,17 +372,20 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, if eff not in core.ordered_effects] ordered_effects = [eff for eff in closed_jaxpr.effects if eff in core.ordered_effects] - module, keepalive = mlir.lower_jaxpr_to_module( - module_name, closed_jaxpr, unordered_effects, ordered_effects, - backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack, - donated_invars) + lowering_result = mlir.lower_jaxpr_to_module( + module_name, closed_jaxpr, + unordered_effects, ordered_effects, backend.platform, + mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars) + module, keepalive, host_callbacks = ( + lowering_result.module, lowering_result.keepalive, + lowering_result.host_callbacks) return XlaComputation( name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps, device=device, backend=backend, tuple_args=tuple_args, in_avals=abstract_args, out_avals=out_avals, has_unordered_effects=bool(unordered_effects), ordered_effects=ordered_effects, kept_var_idx=kept_var_idx, - keepalive=keepalive) + keepalive=keepalive, host_callbacks=host_callbacks) def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool: @@ -751,7 +754,8 @@ def _execute_replicated(name: str, compiled: XlaExecutable, def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, has_unordered_effects: bool, - ordered_effects: List[core.Effect], kept_var_idx, *args): + ordered_effects: List[core.Effect], kept_var_idx, + host_callbacks, *args): env: Dict[core.Var, Any] = {} pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx) map(env.setdefault, jaxpr.invars, pruned_args) @@ -818,9 +822,15 @@ def compile(self) -> XlaCompiledComputation: return self._executable @profiler.annotate_function -def backend_compile(backend, built_c, options): +def backend_compile(backend, built_c, options, host_callbacks): # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results + if host_callbacks: + return backend.compile(built_c, compile_options=options, + host_callbacks=host_callbacks) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` return backend.compile(built_c, compile_options=options) # TODO(phawkins): update users. @@ -838,7 +848,8 @@ def _dump_ir_to_file(name: str, ir: str): name.write_text(ir) -def compile_or_get_cached(backend, computation, compile_options): +def compile_or_get_cached(backend, computation, compile_options, + host_callbacks): # Avoid import cycle between jax and jax.experimental from jax.experimental.compilation_cache import compilation_cache as cc @@ -861,7 +872,8 @@ def compile_or_get_cached(backend, computation, compile_options): logging.info('Persistent compilation cache hit for %s.', module_name) return cached_executable else: - compiled = backend_compile(backend, computation, compile_options) + compiled = backend_compile(backend, computation, compile_options, + host_callbacks) cc.put_executable(module_name, computation, compile_options, compiled, backend) return compiled @@ -875,7 +887,7 @@ def compile_or_get_cached(backend, computation, compile_options): assert isinstance(computation, str) ir_str = computation _dump_ir_to_file(module_name, ir_str) - return backend_compile(backend, computation, compile_options) + return backend_compile(backend, computation, compile_options, host_callbacks) class XlaCompiledComputation(stages.XlaExecutable): @@ -890,21 +902,17 @@ def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call, self.unsafe_call.keepalive = keepalive @staticmethod - def from_xla_computation( - name: str, - xla_computation: Optional[ir.Module], - in_type: Optional[pe.InputType], - out_type: Optional[pe.OutputType], - nreps: int, - device: Optional[Device], - backend: Backend, - tuple_args: bool, - in_avals: Sequence[core.AbstractValue], - out_avals: Sequence[core.AbstractValue], - has_unordered_effects: bool, - ordered_effects: List[core.Effect], - kept_var_idx: Set[int], - keepalive: Optional[Any]) -> XlaCompiledComputation: + def from_xla_computation(name: str, xla_computation: Optional[ir.Module], + in_type: Optional[pe.InputType], + out_type: Optional[pe.OutputType], nreps: int, + device: Optional[Device], backend: Backend, + tuple_args: bool, + in_avals: Sequence[core.AbstractValue], + out_avals: Sequence[core.AbstractValue], + has_unordered_effects: bool, + ordered_effects: List[core.Effect], + kept_var_idx: Set[int], keepalive: Optional[Any], + host_callbacks: List[Any]) -> XlaCompiledComputation: sticky_device = device input_handler = _input_handler(backend, in_type, out_type) result_handler = _result_handler(backend, sticky_device, out_type) @@ -914,7 +922,8 @@ def from_xla_computation( options.parameter_is_tupled_arguments = tuple_args with log_elapsed_time(f"Finished XLA compilation of {name} " "in {elapsed_time} sec"): - compiled = compile_or_get_cached(backend, xla_computation, options) + compiled = compile_or_get_cached(backend, xla_computation, options, + host_callbacks) buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals] if ordered_effects or has_unordered_effects: num_output_tokens = len(ordered_effects) + has_unordered_effects @@ -937,15 +946,15 @@ def xla_executable(self): return self._xla_executable @staticmethod - def from_trivial_jaxpr( - jaxpr, consts, device, in_avals, out_avals, has_unordered_effects, - ordered_effects, kept_var_idx, keepalive: Optional[Any] - ) -> XlaCompiledComputation: + def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals, + has_unordered_effects, ordered_effects, kept_var_idx, + keepalive: Optional[Any], + host_callbacks: List[Any]) -> XlaCompiledComputation: assert keepalive is None result_handlers = map(partial(aval_to_result_handler, device), out_avals) - unsafe_call = partial(_execute_trivial, jaxpr, device, consts, - out_avals, result_handlers, has_unordered_effects, - ordered_effects, kept_var_idx) + unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals, + result_handlers, has_unordered_effects, + ordered_effects, kept_var_idx, host_callbacks) return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call, keepalive) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b6b27751f21f..9e36cbeb8686 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1159,16 +1159,16 @@ def _outside_call_lowering( # inside pmap, but does not work when we just execute on a single device, # because in such executions we always get replica_id == 0. replica_id = mhlo.ReplicaIdOp() - callback_operands = [current_token, replica_id, *args_to_outfeed] + callback_operands = [replica_id, *args_to_outfeed] callback_operand_avals = [ - core.abstract_token, core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]] + core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]] if identity: - callback_flat_results_aval = [core.abstract_token] + callback_flat_results_aval = [] else: - callback_flat_results_aval = [core.abstract_token, *flat_results_aval] + callback_flat_results_aval = [*flat_results_aval] def wrapped_callback(*args): - token, replica_id, *arrays = args + replica_id, *arrays = args result_arrays = _outside_call_run_callback( arrays, xb.local_devices()[replica_id], @@ -1180,13 +1180,13 @@ def wrapped_callback(*args): if identity: # For identity, we do not pass the any results back to the device result_arrays = () - return (token,) + result_arrays + return result_arrays - results, keep_alive = mlir.emit_python_callback(platform, wrapped_callback, - callback_operands, callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type] + results, next_token, keep_alive = mlir.emit_python_callback(ctx, + wrapped_callback, current_token, callback_operands, + callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type] has_side_effect=True) _callback_handler_data.keep_alives.append(keep_alive) - next_token, *results = results # We must put the two tokens at the end if identity: results = list(args_to_outfeed) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 3e57b82eb5cc..d064d1640628 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -24,8 +24,8 @@ import itertools import re import typing -from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, - Type, Union, FrozenSet) +from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional, + Sequence, Set, Tuple, Type, Union, FrozenSet) from typing_extensions import Protocol import warnings @@ -372,6 +372,8 @@ class ModuleContext: axis_context: AxisContext name_stack: NameStack keepalives: List[Any] + channel_iterator: Iterator[int] + host_callbacks: List[Any] # Cached primitive lowerings. cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp] @@ -386,11 +388,14 @@ def __init__( axis_context: AxisContext, name_stack: NameStack, keepalives: List[Any], + channel_iterator: Iterator[int], + host_callbacks: List[Any], context: Optional[ir.Context] = None, module: Optional[ir.Module] = None, ip: Optional[ir.InsertionPoint] = None, symbol_table: Optional[ir.SymbolTable] = None, - cached_primitive_lowerings: Optional[Dict[Any, func_dialect.FuncOp]] = None): + cached_primitive_lowerings: Optional[Dict[Any, + func_dialect.FuncOp]] = None): assert platform is not None self.context = context or make_ir_context() self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) @@ -401,7 +406,15 @@ def __init__( self.name_stack = name_stack self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None else cached_primitive_lowerings) + self.channel_iterator = channel_iterator self.keepalives = keepalives + self.host_callbacks = host_callbacks + + def new_channel(self) -> int: + return next(self.channel_iterator) + + def add_host_callback(self, host_callback: Any) -> None: + self.host_callbacks.append(host_callback) def add_keepalive(self, keepalive: Any) -> None: self.keepalives.append(keepalive) @@ -493,17 +506,25 @@ def sharded_aval(aval: core.ShapedArray, return aval.update(tuple(sharded_shape)) +class LoweringResult(NamedTuple): + module: ir.Module + keepalive: Optional[Any] + host_callbacks: List[Any] + + def lower_jaxpr_to_module( - module_name: str, jaxpr: core.ClosedJaxpr, + module_name: str, + jaxpr: core.ClosedJaxpr, unordered_effects: List[core.Effect], ordered_effects: List[core.Effect], platform: str, axis_context: AxisContext, - name_stack: NameStack, donated_args: Sequence[bool], + name_stack: NameStack, + donated_args: Sequence[bool], replicated_args: Optional[Sequence[bool]] = None, arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None, result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None - ) -> Tuple[ir.Module, Optional[Any]]: +) -> LoweringResult: """Lowers a top-level jaxpr to an MHLO module. Handles the quirks of the argument/return value passing conventions of the @@ -540,9 +561,13 @@ def lower_jaxpr_to_module( msg = f"Donation is not implemented for {platform}.\n{msg}" warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}") + # MHLO channels need to start at 1 + channel_iter = itertools.count(1) # Create a keepalives list that will be mutated during the lowering. keepalives: List[Any] = [] - ctx = ModuleContext(platform, axis_context, name_stack, keepalives) + host_callbacks: List[Any] = [] + ctx = ModuleContext(platform, axis_context, name_stack, keepalives, + channel_iter, host_callbacks) with ctx.context, ir.Location.unknown(ctx.context): # Remove module name characters that XLA would alter. This ensures that # XLA computation preserves the module name. @@ -565,7 +590,7 @@ def lower_jaxpr_to_module( input_output_aliases=input_output_aliases) ctx.module.operation.verify() - return ctx.module, ctx.keepalives + return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks) def module_to_string(module: ir.Module) -> str: output = io.StringIO() @@ -600,7 +625,7 @@ def _set_up_aliases(avals_in, avals_out, donated_args): def token_type() -> Sequence[ir.Type]: return [mhlo.TokenType.get()] -def token() -> Token: +def create_token() -> Token: return wrap_singleton_ir_values( mhlo.CreateTokenOp(mhlo.TokenType.get()).result) @@ -625,7 +650,7 @@ def get(self, effect: core.Effect) -> Token: @classmethod def create(cls, effects: Sequence[core.Effect]) -> TokenSet: """Creates a `TokenSet` corresponding to a list of `core.Effect`s.""" - tokens = [token() for _ in effects] + tokens = [create_token() for _ in effects] return TokenSet(zip(effects, tokens)) def items(self) -> Sequence[Tuple[core.Effect, Token]]: @@ -1274,30 +1299,145 @@ def fallback(ctx: LoweringRuleContext, *args, **params): register_lowering(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp) -def emit_python_callback(platform, callback, - operands: List[ir.Value], - operand_avals: List[core.AbstractValue], - result_avals: List[core.AbstractValue], - has_side_effect: bool) -> Tuple[List[ir.Value], Any]: +SEND_TO_HOST_TYPE = 2 +RECV_FROM_HOST_TYPE = 3 + +_dtype_to_xla_type_string_map = { + np.dtype("bool"): "pred", + np.dtype("float16"): "f16", + np.dtype("float32"): "f32", + np.dtype("float64"): "f64", + np.dtype("int8"): "s8", + np.dtype("uint8"): "u8", + np.dtype("int16"): "s16", + np.dtype("uint16"): "u16", + np.dtype("int32"): "s32", + np.dtype("uint32"): "u32", + np.dtype("int64"): "s64", + np.dtype("uint64"): "u64", + dtypes._bfloat16_dtype: "bf16", + np.dtype("complex64"): "c64", + np.dtype("complex128"): "c128", +} + +def _dtype_to_xla_type_string(dtype: np.dtype) -> str: + if dtype not in _dtype_to_xla_type_string_map: + raise NotImplementedError(dtype) + return _dtype_to_xla_type_string_map[dtype] + +def send_to_host(channel: int, token: mhlo.TokenType, operand: Any, + aval: core.ShapedArray, name: str) -> ir.Value: + channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE) + send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) + dtype_str = _dtype_to_xla_type_string(aval.dtype) + if dtype_str in {"f64", "s64", "u64", "c64", "c128"}: + raise NotImplementedError("64-bit types not supported.") + send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), + _xla_host_transfer_original_type=ir.StringAttr.get(dtype_str), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) + return send_op.result + + +def receive_from_host(channel: int, token: mhlo.TokenType, + out_aval: core.ShapedArray, name: str) -> ir.Value: + channel_handle = mhlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE) + recv_op = mhlo.RecvOp([aval_to_ir_type(out_aval), + mhlo.TokenType.get()], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) + dtype_str = _dtype_to_xla_type_string(out_aval.dtype) + if dtype_str in {"f64", "s64", "u64", "c64", "c128"}: + raise NotImplementedError("64-bit types not supported.") + recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), + _xla_host_transfer_original_type=ir.StringAttr.get(dtype_str), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) + # Token should be at the end of the results + result, token = recv_op.results + return token, result + + +def emit_python_callback( + ctx: LoweringRuleContext, callback, token: Optional[Any], + operands: List[ir.Value], operand_avals: List[core.AbstractValue], + result_avals: List[core.AbstractValue], + has_side_effect: bool) -> Tuple[List[ir.Value], Any, Any]: """Creates an MHLO `CustomCallOp` that calls back to the provided function.""" + platform = ctx.module_context.platform if platform in {"cuda", "rocm"} and jax._src.lib.version < (0, 3, 11): raise ValueError( "`EmitPythonCallback` on CUDA only supported on jaxlib >= 0.3.11") - if platform not in {"cpu", "cuda", "rocm"}: + if platform in {"tpu"} and jax._src.lib.version < (0, 3, 15): + raise ValueError( + "`EmitPythonCallback` on TPU only supported on jaxlib >= 0.3.15") + if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( - "`EmitPythonCallback` only supported on CPU, CUDA, and ROCM backends.") + f"`EmitPythonCallback` not supported on {platform} backend.") backend = xb.get_backend(platform) result_shapes = util.flatten( [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) operand_shapes = util.flatten( [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) - callback_descriptor, keepalive = backend.get_emit_python_callback_descriptor( - callback, operand_shapes, result_shapes) + if platform == "tpu": + if result_avals: + raise NotImplementedError( + "Callback with return values not supported on TPU.") + token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result + send_channels = [] + for operand, operand_aval in zip(operands, operand_avals): + channel = ctx.module_context.new_channel() + token = send_to_host(channel, token, operand, operand_aval, + callback.__name__) + send_channels.append(channel) + recv_channels = [] + recv_channel = ctx.module_context.new_channel() + + # `send-to-host`s can be interleaved by the transfer manager so we add in a + # dummy recv to sequence them (the recv can only happen after all the sends + # are done). We'd like to send back a 0-shaped array to avoid unnecessary + # copies but that currently doesn't work with the transfer + # manager as well. + # TODO(b/238239458): enable sending back a 0-dim array + # TODO(b/238239928): avoid interleaving sends in the transfer manager + def _wrapped_callback(*args, **kwargs): + callback(*args, **kwargs) + return (np.zeros(1, np.float32),) + + dummy_recv_aval = core.ShapedArray((1,), np.float32) + result_shapes = [*result_shapes, xla.aval_to_xla_shapes(dummy_recv_aval)[0]] + token, _ = receive_from_host(recv_channel, token, dummy_recv_aval, + callback.__name__) + recv_channels.append(recv_channel) + opaque = backend.make_python_callback_from_host_send_and_recv( + _wrapped_callback, operand_shapes, result_shapes, send_channels, + recv_channels) + ctx.module_context.add_host_callback(opaque) + return [], token, opaque + result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals]) + wrapped_callback = callback + if token: + + def wrapped_callback(token, *args): # type: ignore + return tuple((token, *callback(*args))) + + operand_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes + ] + result_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes + ] + operands = [token, *operands] + result_types = [token_type()[0], *result_types] + callback_descriptor, keepalive = ( + backend.get_emit_python_callback_descriptor(wrapped_callback, + operand_shapes, + result_shapes)) descriptor_operand = ir_constant( callback_descriptor, canonicalize_types=False) callback_operands = [descriptor_operand, *operands] - result_types = util.flatten( - [aval_to_ir_types(aval) for aval in result_avals]) result_type = ir.TupleType.get_tuple(result_types) call_target_name = ("xla_python_gpu_callback" if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") @@ -1315,7 +1455,9 @@ def emit_python_callback(platform, callback, mhlo.GetTupleElementOp(result, i32_attr(i)).result for i in range(len(result_types)) ] - return results, keepalive + if token: + token, *results = results + return results, token, keepalive # Lax ops missing MLIR lowerings. # # TODO(b/203775215): these are missing from the cHLO dialect. Either add diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 77adf452895c..99011e91fad1 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -1089,16 +1089,19 @@ def lower_parallel_callable( raise ValueError("Ordered effects not supported in `pmap`.") unordered_effects = [eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects] - module, keepalive = mlir.lower_jaxpr_to_module( + lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, unordered_effects, [], backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars, replicated_args=replicated_args, arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts), result_shardings=_shardings_to_mlir_shardings(parts.out_parts)) + module, keepalive, host_callbacks = ( + lowering_result.module, lowering_result.keepalive, + lowering_result.host_callbacks) return PmapComputation(module, pci=pci, replicas=replicas, parts=parts, shards=shards, tuple_args=tuple_args, unordered_effects=unordered_effects, - keepalive=keepalive) + keepalive=keepalive, host_callbacks=host_callbacks) class PmapComputation(stages.XlaLowering): @@ -1152,6 +1155,7 @@ def from_hlo(xla_computation, shards: ShardInfo, tuple_args: bool, unordered_effects: List[core.Effect], + host_callbacks: List[Any], keepalive: Any): devices = pci.devices if devices is None: @@ -1270,7 +1274,7 @@ def from_hlo(xla_computation, with dispatch.log_elapsed_time( f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec"): compiled = dispatch.compile_or_get_cached( - pci.backend, xla_computation, compile_options) + pci.backend, xla_computation, compile_options, host_callbacks) handle_args = InputsHandler( compiled.local_devices(), input_sharding_specs, input_indices) execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args, @@ -2350,17 +2354,19 @@ def lower_mesh_computation( raise ValueError("Ordered effects not supported in mesh computations.") unordered_effects = [eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects] - module, keepalive = mlir.lower_jaxpr_to_module( + lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, unordered_effects, [], backend.platform, axis_ctx, name_stack, donated_invars, replicated_args=replicated_args, arg_shardings=in_partitions, result_shardings=out_partitions) - + module, keepalive, host_callbacks = ( + lowering_result.module, lowering_result.keepalive, + lowering_result.host_callbacks) return MeshComputation( str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes, spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_global=in_is_global, auto_spmd_lowering=auto_spmd_lowering, - unordered_effects=unordered_effects, + unordered_effects=unordered_effects, host_callbacks=host_callbacks, keepalive=keepalive) @@ -2471,6 +2477,7 @@ def from_hlo(name: str, _allow_propagation_to_outputs: bool, _allow_compile_replicated: bool, unordered_effects: List[core.Effect], + host_callbacks: List[Any], keepalive: Any) -> MeshExecutable: assert not mesh.empty backend = xb.get_device_backend(mesh.devices.flat[0]) @@ -2510,7 +2517,8 @@ def from_hlo(name: str, else: with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} " "in {elapsed_time} sec"): - xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options) + xla_executable = dispatch.compile_or_get_cached( + backend, computation, compile_options, host_callbacks) if auto_spmd_lowering or (out_axes and all(_is_unspecified(o) for o in out_axes)): in_axes, out_axes = _get_array_mapping_from_executable(xla_executable, mesh) diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py index 21494fc5c0fd..fc97970403e2 100644 --- a/jax/interpreters/sharded_jit.py +++ b/jax/interpreters/sharded_jit.py @@ -141,7 +141,7 @@ def _sharded_callable( if eff not in core.ordered_effects] ordered_effects = [eff for eff in jaxpr.effects if eff in core.ordered_effects] - module, _ = mlir.lower_jaxpr_to_module( + lowering_result = mlir.lower_jaxpr_to_module( f"spjit_{fun.__name__}", core.ClosedJaxpr(jaxpr, consts), unordered_effects, ordered_effects, @@ -151,6 +151,8 @@ def _sharded_callable( donated_args=[False]*len(in_parts), arg_shardings=safe_map(xla.sharding_to_proto, in_parts), result_shardings=safe_map(xla.sharding_to_proto, out_parts)) + module, host_callbacks = (lowering_result.module, + lowering_result.host_callbacks) built = xc._xla.mlir.mlir_module_to_xla_computation( mlir.module_to_string(module), use_tuple_args=False, return_tuple=True) @@ -166,7 +168,8 @@ def _sharded_callable( compiled = dispatch.backend_compile( xb.get_backend(), built, - xb.get_compile_options(nrep, nparts, device_assignment)) + xb.get_compile_options(nrep, nparts, device_assignment), + host_callbacks) input_specs = [ pxla.partitioned_sharding_spec(local_nparts, parts, aval) diff --git a/tests/BUILD b/tests/BUILD index 6b38531b986c..fad640a0482b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -843,16 +843,28 @@ jax_test( jax_test( name = "jaxpr_effects_test", srcs = ["jaxpr_effects_test.py"], + enable_configs = [ + "gpu", + "cpu", + ], ) jax_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], + enable_configs = [ + "gpu", + "cpu", + ], ) jax_test( name = "debugger_test", srcs = ["debugger_test.py"], + enable_configs = [ + "gpu", + "cpu", + ], ) exports_files( diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 9349d0421567..26925d539b38 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -51,9 +51,13 @@ def tearDownModule(): # TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum # version is >= 0.3.11 -disabled_backends = ["tpu"] +# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum +# version is >= 0.3.15 +disabled_backends = [] if jaxlib.version < (0, 3, 11): disabled_backends.append("gpu") +if jaxlib.version < (0, 3, 15): + disabled_backends.append("tpu") class CliDebuggerTest(jtu.JaxTestCase): @@ -67,6 +71,7 @@ def f(x): return y with self.assertRaises(SystemExit): f(2.) + jax.effects_barrier() @jtu.skip_on_devices(*disabled_backends) def test_debugger_can_continue(self): @@ -77,6 +82,7 @@ def f(x): debugger.breakpoint(stdin=stdin, stdout=stdout) return y f(2.) + jax.effects_barrier() expected = _format_multiline(r""" Entering jaxdb: (jaxdb) """) @@ -95,6 +101,7 @@ def f(x): (jaxdb) DeviceArray(2., dtype=float32) (jaxdb) """) f(jnp.array(2., jnp.float32)) + jax.effects_barrier() self.assertEqual(stdout.getvalue(), expected) @jtu.skip_on_devices(*disabled_backends) @@ -111,6 +118,7 @@ def f(x): (jaxdb) array(2., dtype=float32) (jaxdb) """) f(jnp.array(2., jnp.float32)) + jax.effects_barrier() self.assertEqual(stdout.getvalue(), expected) @jtu.skip_on_devices(*disabled_backends) @@ -127,6 +135,7 @@ def f(x): (jaxdb) (array(2., dtype=float32), array(3., dtype=float32)) (jaxdb) """) f(jnp.array(2., jnp.float32)) + jax.effects_barrier() self.assertEqual(stdout.getvalue(), expected) @jtu.skip_on_devices(*disabled_backends) @@ -139,6 +148,7 @@ def f(x): debugger.breakpoint(stdin=stdin, stdout=stdout) return y f(2.) + jax.effects_barrier() expected = _format_multiline(r""" Entering jaxdb: \(jaxdb\) > .*debugger_test\.py\([0-9]+\) @@ -165,6 +175,7 @@ def f(x): \(jaxdb\) Traceback:.* """) f(2.) + jax.effects_barrier() self.assertRegex(stdout.getvalue(), expected) @jtu.skip_on_devices(*disabled_backends) @@ -203,6 +214,7 @@ def f\(x\): .* \(jaxdb\) """) g(jnp.array(2., jnp.float32)) + jax.effects_barrier() self.assertRegex(stdout.getvalue(), expected) @jtu.skip_on_devices(*disabled_backends) @@ -232,10 +244,14 @@ def g(x): @jtu.skip_on_devices(*disabled_backends) def test_debugger_works_with_vmap(self): stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"]) + # On TPU, the breakpoints can be reordered inside of vmap but can be fixed + # by ordering sends. + # TODO(sharadmv): change back to ordered = False when sends are ordered + ordered = jax.default_backend() == "tpu" def f(x): y = x + 1. - debugger.breakpoint(stdin=stdin, stdout=stdout) + debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered) return 2. * y @jax.jit @@ -250,6 +266,7 @@ def g(x): (jaxdb) array(2., dtype=float32) (jaxdb) """) g(jnp.arange(2., dtype=jnp.float32)) + jax.effects_barrier() self.assertEqual(stdout.getvalue(), expected) @jtu.skip_on_devices(*disabled_backends) diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 909a4e7fbcd3..292682566f80 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +import collections import functools import io import textwrap @@ -59,9 +60,13 @@ def tearDownModule(): # TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum # version is >= 0.3.11 -disabled_backends = ["tpu"] +# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum +# version is >= 0.3.15 +disabled_backends = [] if jaxlib.version < (0, 3, 11): disabled_backends.append("gpu") +if jaxlib.version < (0, 3, 15): + disabled_backends.append("tpu") class DebugPrintTest(jtu.JaxTestCase): @@ -193,6 +198,13 @@ def f(x): class DebugPrintControlFlowTest(jtu.JaxTestCase): + def _assertLinesEqual(self, text1, text2): + + def _count(lines): + return collections.Counter(lines) + + self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n"))) + @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name="_ordered" if ordered else "", ordered=ordered) for ordered in [False, True])) @@ -220,19 +232,29 @@ def _body(carry, x): def test_can_print_inside_for_loop(self, ordered): def f(x): def _body(i, x): + debug_print("i: {i}", i=i, ordered=ordered) debug_print("x: {x}", x=x, ordered=ordered) return x + 1 return lax.fori_loop(0, 5, _body, x) with capture_stdout() as output: f(2) jax.effects_barrier() - self.assertEqual(output(), _format_multiline(""" + expected = _format_multiline(""" + i: 0 x: 2 + i: 1 x: 3 + i: 2 x: 4 + i: 3 x: 5 + i: 4 x: 6 - """)) + """) + if ordered: + self.assertEqual(output(), expected) + else: + self._assertLinesEqual(output(), expected) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name="_ordered" if ordered else "", ordered=ordered) @@ -333,6 +355,7 @@ def b3(x): return lax.switch(x, (b1, b2, b3), x) with capture_stdout() as output: f(0) + jax.effects_barrier() self.assertEqual(output(), _format_multiline(""" b1: 0 """)) @@ -352,7 +375,11 @@ def b3(x): class DebugPrintParallelTest(jtu.JaxTestCase): def _assertLinesEqual(self, text1, text2): - self.assertSetEqual(set(text1.split("\n")), set(text2.split("\n"))) + + def _count(lines): + return collections.Counter(lines) + + self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n"))) @jtu.skip_on_devices(*disabled_backends) def test_ordered_print_not_supported_in_pmap(self): @@ -394,12 +421,35 @@ def f(x): debug_print("{}", x, ordered=False) f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu', axis_resources={'a': 'dev'}) - with maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev']): + with maps.Mesh(np.array(jax.devices()), ['dev']): with capture_stdout() as output: f(jnp.arange(40)) jax.effects_barrier() lines = [f"{i}\n" for i in range(40)] self._assertLinesEqual(output(), "".join(lines)) + @jtu.skip_on_devices(*disabled_backends) + def test_unordered_print_works_in_pmap_of_while(self): + + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices.") + + @jax.pmap + def f(x): + def cond(x): + return x < 3 + def body(x): + debug_print("hello: {}", x, ordered=False) + return x + 1 + return lax.while_loop(cond, body, x) + + with capture_stdout() as output: + f(jnp.arange(2)) + jax.effects_barrier() + + self._assertLinesEqual( + output(), "hello: 0\nhello: 1\nhello: 2\n" + "hello: 1\nhello: 2\n") + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 53f751d95f2b..139bb83172bf 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -61,9 +61,13 @@ def _(*, effect): # TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum # version is >= 0.3.11 -disabled_backends = ['tpu'] +# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum +# version is >= 0.3.15 +disabled_backends = [] if jaxlib.version < (0, 3, 11): disabled_backends.append('gpu') +if jaxlib.version < (0, 3, 15): + disabled_backends.append('tpu') def trivial_effect_lowering(ctx, *, effect): @@ -108,23 +112,16 @@ def _(*avals, callback, out_avals, effect): def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect): del out_avals + token_in = None if effect in core.ordered_effects: - def _token_callback(token, *args): - out = callback(*args) - flat_out = jax.tree_util.tree_leaves(out) - return (token, *flat_out) token_in = ctx.tokens_in.get(effect)[0] - (token_out, *out_op), keep_alive = mlir.emit_python_callback( - ctx.module_context.platform, _token_callback, - [token_in, *args], [core.abstract_token, *ctx.avals_in], - [core.abstract_token, *ctx.avals_out], True) + + out_op, token_out, keep_alive = mlir.emit_python_callback( + ctx, callback, token_in, list(args), list(ctx.avals_in), + list(ctx.avals_out), True) + if token_out: ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect: token_out}))) - else: - out_op, keep_alive = mlir.emit_python_callback( - ctx.module_context.platform, callback, - list(args), list(ctx.avals_in), - list(ctx.avals_out), True) ctx.module_context.add_keepalive(keep_alive) return out_op @@ -297,6 +294,7 @@ def _effect_lowering(ctx, *, effect): ctx.set_tokens_out(ctx.tokens_in) return [] mlir.register_lowering(effect_p, _effect_lowering) + jax.effects_barrier() dispatch.runtime_tokens.clear() def tearDown(self): @@ -590,10 +588,11 @@ def f(x): return callback_p.bind(x, callback=log_value, effect='log', out_avals=[]) f(2.) + jax.effects_barrier() self.assertListEqual(log, [2.]) f(3.) + jax.effects_barrier() self.assertListEqual(log, [2., 3.]) - dispatch.runtime_tokens.block_until_ready() @jtu.skip_on_devices(*disabled_backends) def test_ordered_effect_remains_ordered_across_multiple_devices(self): @@ -623,15 +622,14 @@ def g(x): g(3.) f(jnp.ones((500, 500))) g(3.) - dispatch.runtime_tokens.block_until_ready() + jax.effects_barrier() x_, y_ = float(jnp.log(1.25e8)), 3. expected_log = [x_, y_, x_, y_, x_, y_] self.assertListEqual(log, expected_log) + @jtu.skip_on_devices("tpu") @jtu.skip_on_devices(*disabled_backends) def test_different_threads_get_different_tokens(self): - # TODO(sharadmv): enable this test on GPU and TPU when backends are - # supported if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") tokens = []