From 1ee750e79509942ed06b8f05bbd9eda3c821348e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 1 Mar 2023 10:04:59 -0800 Subject: [PATCH] Pass the `jaxpr` from `pjit` since there is no need to trace it again in lower_sharding_computation. It also helps in preserving debug_info that already exists on the jaxpr to surface it in MHLO eventually. PiperOrigin-RevId: 513268085 --- jax/_src/interpreters/pxla.py | 54 +++++++++++++++++++++++------------ jax/_src/pjit.py | 14 ++------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 5ec8789a4658..4e511056cad3 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2867,7 +2867,7 @@ def _get_and_check_device_assignment( @profiler.annotate_function def lower_sharding_computation( - fun: lu.WrappedFun, + fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr], api_name: str, fun_name: str, in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue]], @@ -2889,11 +2889,19 @@ def lower_sharding_computation( # 1. Trace to jaxpr and preprocess/verify it name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) - with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec", - event=dispatch.JAXPR_TRACE_EVENT): - jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( - fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name)) + if isinstance(fun_or_jaxpr, lu.WrappedFun): + with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " + "in {elapsed_time} sec", + event=dispatch.JAXPR_TRACE_EVENT): + jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( + fun_or_jaxpr, global_in_avals, + debug_info=pe.debug_info_final(fun_or_jaxpr, api_name)) + else: + assert isinstance(fun_or_jaxpr, core.ClosedJaxpr) + jaxpr = fun_or_jaxpr.jaxpr + global_out_avals = fun_or_jaxpr.out_avals + consts = fun_or_jaxpr.consts + kept_outputs = [True] * len(global_out_avals) if _is_unspecified(out_shardings): @@ -2927,10 +2935,9 @@ def lower_sharding_computation( log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logger.log(log_priority, - "Compiling %s (%d) for with global shapes and types %s. " + "Compiling %s for with global shapes and types %s. " "Argument mapping: %s.", - getattr(fun, '__name__', ''), id(fun), - global_in_avals, in_shardings) + fun_name, global_in_avals, in_shardings) if keep_unused: kept_var_idx = set(range(len(global_in_avals))) @@ -3089,7 +3096,7 @@ def lower_sharding_computation( @profiler.annotate_function def lower_mesh_computation( - fun: lu.WrappedFun, + fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr], api_name: str, fun_name: str, mesh: Mesh, @@ -3114,10 +3121,9 @@ def lower_mesh_computation( log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logger.log(log_priority, - "Compiling %s (%d) for %s mesh with global shapes and types %s. " + "Compiling %s for %s mesh with global shapes and types %s. " "Argument mapping: %s.", - getattr(fun, '__name__', ''), id(fun), - tuple(global_axis_sizes.items()), global_in_avals, + fun_name, tuple(global_axis_sizes.items()), global_in_avals, in_shardings) # 1. Trace to jaxpr and preprocess/verify it @@ -3134,10 +3140,11 @@ def lower_mesh_computation( raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}") assert not callable(out_shardings) assert not auto_spmd_lowering + assert isinstance(fun_or_jaxpr, lu.WrappedFun) # This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which # is why `.spec` can be accessed. - fun = tiling_transform( - fun, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore + fun_or_jaxpr = tiling_transform( + fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore [get_array_mapping(o.spec) for o in out_shardings]) # type: ignore in_jaxpr_avals = global_in_avals else: @@ -3148,11 +3155,20 @@ def lower_mesh_computation( in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore for aval, i in safe_zip(global_in_avals, in_shardings)] in_jaxpr_avals = in_tiled_avals + with core.extend_axis_env_nd(mesh.shape.items()): - with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec", - event=dispatch.JAXPR_TRACE_EVENT): - jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals) + if isinstance(fun_or_jaxpr, lu.WrappedFun): + with dispatch.log_elapsed_time( + f"Finished tracing + transforming {name_stack} in " + "{elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT): + jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final( + fun_or_jaxpr, in_jaxpr_avals) + else: + assert isinstance(fun_or_jaxpr, core.ClosedJaxpr) + jaxpr = fun_or_jaxpr.jaxpr + out_jaxpr_avals = fun_or_jaxpr.out_avals + consts = fun_or_jaxpr.consts + assert len(out_shardings) == len(out_jaxpr_avals) if spmd_lowering: global_out_avals = out_jaxpr_avals diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8d481975a9d1..1f277a4d9db6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1393,10 +1393,6 @@ def _pjit_lower_cached( if resource_env is not None: pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") - f = core.jaxpr_as_fun(jaxpr) - f.__name__ = name - fun = lu.wrap_init(f) - if resource_env is not None: mesh = resource_env.physical_mesh api_name = 'pjit' @@ -1427,18 +1423,14 @@ def _pjit_lower_cached( # For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path # because `xmap` only supports SPMDAxisContext right now. - if (any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap')): + if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'): return pxla.lower_mesh_computation( - fun, api_name, name, mesh, + jaxpr, api_name, name, mesh, in_shardings, out_shardings, donated_invars, True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global) else: - # Pass `in_is_global` here because this path is taken by both host local - # avals and global avals. - # TODO(yashkatariya): Don't set committed to True always. Infer that from - # the arguments just like dispatch.py in `sharded_lowering`. return pxla.lower_sharding_computation( - fun, api_name, name, in_shardings, out_shardings, donated_invars, + jaxpr, api_name, name, in_shardings, out_shardings, donated_invars, jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused, always_lower=always_lower, devices_from_context=(