Skip to content

Commit

Permalink
Enable Python callbacks on TFRT TPU backend
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 459415455
  • Loading branch information
sharadmv authored and jax authors committed Jul 7, 2022
1 parent 5d379bb commit 6274b9e
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 120 deletions.
5 changes: 3 additions & 2 deletions jax/_src/api.py
Expand Up @@ -891,7 +891,7 @@ def computation_maker(*args, **kwargs):
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in jaxpr.effects
if eff in core.ordered_effects]
m, _ = mlir.lower_jaxpr_to_module(
lowering_result = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),
unordered_effects=unordered_effects,
Expand All @@ -906,7 +906,8 @@ def computation_maker(*args, **kwargs):
map(xla.sharding_to_proto, out_parts_flat)))
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(m), use_tuple_args=should_tuple,
mlir.module_to_string(lowering_result.module),
use_tuple_args=should_tuple,
return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
Expand Down
31 changes: 12 additions & 19 deletions jax/_src/debugging.py
Expand Up @@ -91,37 +91,30 @@ def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
raise ValueError("Transpose doesn't support debugging callbacks.")
ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule

def _ordered_effect_lowering(ctx, token, *args, **params):
avals_in = [core.abstract_token, *ctx.avals_in]
avals_out = [core.abstract_token, *ctx.avals_out]
args = (token, *args)
def _callback(token, *flat_args):
out = debug_callback_p.impl(*flat_args, **params)
return (token, *out)
(token, *result), keepalive = mlir.emit_python_callback(
ctx.module_context.platform, _callback, list(args), avals_in, avals_out,
True)
return result, keepalive, token

def debug_callback_lowering(ctx, *args, effect, callback, **params):

def _callback(*flat_args):
return tuple(
debug_callback_p.impl(
*flat_args, effect=effect, callback=callback, **params))
if effect in core.ordered_effects:
token = ctx.tokens_in.get(effect)[0]
result, keepalive, token = _ordered_effect_lowering(ctx, token,
*args, effect=effect, callback=callback, **params)
result, token, keepalive = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
else:
def _callback(*flat_args):
return tuple(debug_callback_p.impl(
*flat_args, effect=effect, callback=callback, **params))
result, keepalive = mlir.emit_python_callback(ctx.module_context.platform,
_callback, list(args), ctx.avals_in, ctx.avals_out, True)
result, token, keepalive = mlir.emit_python_callback(
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True)
ctx.module_context.add_keepalive(keepalive)
return result
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
platform="cpu")
if jaxlib.version >= (0, 3, 11):
mlir.register_lowering(
debug_callback_p, debug_callback_lowering, platform="gpu")
if jaxlib.version >= (0, 3, 15):
mlir.register_lowering(
debug_callback_p, debug_callback_lowering, platform="tpu")

def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args,
**kwargs):
Expand Down
77 changes: 43 additions & 34 deletions jax/_src/dispatch.py
Expand Up @@ -333,7 +333,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
name, None, True, None, None, None, jaxpr=jaxpr, consts=consts,
device=device, in_avals=abstract_args, out_avals=out_avals,
has_unordered_effects=False, ordered_effects=[],
kept_var_idx=kept_var_idx, keepalive=None)
kept_var_idx=kept_var_idx, keepalive=None, host_callbacks=[])

if not _on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
Expand Down Expand Up @@ -372,17 +372,20 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
module, keepalive = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, unordered_effects, ordered_effects,
backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack,
donated_invars)
lowering_result = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr,
unordered_effects, ordered_effects, backend.platform,
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
return XlaComputation(
name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals,
has_unordered_effects=bool(unordered_effects),
ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
keepalive=keepalive)
keepalive=keepalive, host_callbacks=host_callbacks)


def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
Expand Down Expand Up @@ -751,7 +754,8 @@ def _execute_replicated(name: str, compiled: XlaExecutable,

def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
has_unordered_effects: bool,
ordered_effects: List[core.Effect], kept_var_idx, *args):
ordered_effects: List[core.Effect], kept_var_idx,
host_callbacks, *args):
env: Dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
Expand Down Expand Up @@ -818,9 +822,15 @@ def compile(self) -> XlaCompiledComputation:
return self._executable

@profiler.annotate_function
def backend_compile(backend, built_c, options):
def backend_compile(backend, built_c, options, host_callbacks):
# we use a separate function call to ensure that XLA compilation appears
# separately in Python profiling results
if host_callbacks:
return backend.compile(built_c, compile_options=options,
host_callbacks=host_callbacks)
# Some backends don't have `host_callbacks` option yet
# TODO(sharadmv): remove this fallback when all backends allow `compile`
# to take in `host_callbacks`
return backend.compile(built_c, compile_options=options)

# TODO(phawkins): update users.
Expand All @@ -838,7 +848,8 @@ def _dump_ir_to_file(name: str, ir: str):
name.write_text(ir)


def compile_or_get_cached(backend, computation, compile_options):
def compile_or_get_cached(backend, computation, compile_options,
host_callbacks):
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc

Expand All @@ -861,7 +872,8 @@ def compile_or_get_cached(backend, computation, compile_options):
logging.info('Persistent compilation cache hit for %s.', module_name)
return cached_executable
else:
compiled = backend_compile(backend, computation, compile_options)
compiled = backend_compile(backend, computation, compile_options,
host_callbacks)
cc.put_executable(module_name, computation, compile_options, compiled,
backend)
return compiled
Expand All @@ -875,7 +887,7 @@ def compile_or_get_cached(backend, computation, compile_options):
assert isinstance(computation, str)
ir_str = computation
_dump_ir_to_file(module_name, ir_str)
return backend_compile(backend, computation, compile_options)
return backend_compile(backend, computation, compile_options, host_callbacks)


class XlaCompiledComputation(stages.XlaExecutable):
Expand All @@ -890,21 +902,17 @@ def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
self.unsafe_call.keepalive = keepalive

@staticmethod
def from_xla_computation(
name: str,
xla_computation: Optional[ir.Module],
in_type: Optional[pe.InputType],
out_type: Optional[pe.OutputType],
nreps: int,
device: Optional[Device],
backend: Backend,
tuple_args: bool,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx: Set[int],
keepalive: Optional[Any]) -> XlaCompiledComputation:
def from_xla_computation(name: str, xla_computation: Optional[ir.Module],
in_type: Optional[pe.InputType],
out_type: Optional[pe.OutputType], nreps: int,
device: Optional[Device], backend: Backend,
tuple_args: bool,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx: Set[int], keepalive: Optional[Any],
host_callbacks: List[Any]) -> XlaCompiledComputation:
sticky_device = device
input_handler = _input_handler(backend, in_type, out_type)
result_handler = _result_handler(backend, sticky_device, out_type)
Expand All @@ -914,7 +922,8 @@ def from_xla_computation(
options.parameter_is_tupled_arguments = tuple_args
with log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
compiled = compile_or_get_cached(backend, xla_computation, options)
compiled = compile_or_get_cached(backend, xla_computation, options,
host_callbacks)
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
if ordered_effects or has_unordered_effects:
num_output_tokens = len(ordered_effects) + has_unordered_effects
Expand All @@ -937,15 +946,15 @@ def xla_executable(self):
return self._xla_executable

@staticmethod
def from_trivial_jaxpr(
jaxpr, consts, device, in_avals, out_avals, has_unordered_effects,
ordered_effects, kept_var_idx, keepalive: Optional[Any]
) -> XlaCompiledComputation:
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
has_unordered_effects, ordered_effects, kept_var_idx,
keepalive: Optional[Any],
host_callbacks: List[Any]) -> XlaCompiledComputation:
assert keepalive is None
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, has_unordered_effects,
ordered_effects, kept_var_idx)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
result_handlers, has_unordered_effects,
ordered_effects, kept_var_idx, host_callbacks)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
keepalive)

Expand Down
18 changes: 9 additions & 9 deletions jax/experimental/host_callback.py
Expand Up @@ -1159,16 +1159,16 @@ def _outside_call_lowering(
# inside pmap, but does not work when we just execute on a single device,
# because in such executions we always get replica_id == 0.
replica_id = mhlo.ReplicaIdOp()
callback_operands = [current_token, replica_id, *args_to_outfeed]
callback_operands = [replica_id, *args_to_outfeed]
callback_operand_avals = [
core.abstract_token, core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]]
core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]]
if identity:
callback_flat_results_aval = [core.abstract_token]
callback_flat_results_aval = []
else:
callback_flat_results_aval = [core.abstract_token, *flat_results_aval]
callback_flat_results_aval = [*flat_results_aval]

def wrapped_callback(*args):
token, replica_id, *arrays = args
replica_id, *arrays = args
result_arrays = _outside_call_run_callback(
arrays,
xb.local_devices()[replica_id],
Expand All @@ -1180,13 +1180,13 @@ def wrapped_callback(*args):
if identity:
# For identity, we do not pass the any results back to the device
result_arrays = ()
return (token,) + result_arrays
return result_arrays

results, keep_alive = mlir.emit_python_callback(platform, wrapped_callback,
callback_operands, callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type]
results, next_token, keep_alive = mlir.emit_python_callback(ctx,
wrapped_callback, current_token, callback_operands,
callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type]
has_side_effect=True)
_callback_handler_data.keep_alives.append(keep_alive)
next_token, *results = results
# We must put the two tokens at the end
if identity:
results = list(args_to_outfeed)
Expand Down

0 comments on commit 6274b9e

Please sign in to comment.