Skip to content

Commit

Permalink
Pass the jaxpr from pjit since there is no need to trace it again in lower_sharding_computation.
Browse files Browse the repository at this point in the history
This also helps preserve the debug_info that already exists on the jaxpr so that it can eventually be surfaced in MHLO.

PiperOrigin-RevId: 513268085
  • Loading branch information
yashk2810 authored and jax authors committed Mar 1, 2023
1 parent ed491b3 commit 1ee750e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
54 changes: 35 additions & 19 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -2867,7 +2867,7 @@ def _get_and_check_device_assignment(

@profiler.annotate_function
def lower_sharding_computation(
fun: lu.WrappedFun,
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue]],
Expand All @@ -2889,11 +2889,19 @@ def lower_sharding_computation(
# 1. Trace to jaxpr and preprocess/verify it
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
if isinstance(fun_or_jaxpr, lu.WrappedFun):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun_or_jaxpr, global_in_avals,
debug_info=pe.debug_info_final(fun_or_jaxpr, api_name))
else:
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
jaxpr = fun_or_jaxpr.jaxpr
global_out_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts

kept_outputs = [True] * len(global_out_avals)

if _is_unspecified(out_shardings):
Expand Down Expand Up @@ -2927,10 +2935,9 @@ def lower_sharding_computation(

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s (%d) for with global shapes and types %s. "
"Compiling %s for with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),
global_in_avals, in_shardings)
fun_name, global_in_avals, in_shardings)

if keep_unused:
kept_var_idx = set(range(len(global_in_avals)))
Expand Down Expand Up @@ -3089,7 +3096,7 @@ def lower_sharding_computation(

@profiler.annotate_function
def lower_mesh_computation(
fun: lu.WrappedFun,
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
mesh: Mesh,
Expand All @@ -3114,10 +3121,9 @@ def lower_mesh_computation(

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s (%d) for %s mesh with global shapes and types %s. "
"Compiling %s for %s mesh with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),
tuple(global_axis_sizes.items()), global_in_avals,
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
in_shardings)

# 1. Trace to jaxpr and preprocess/verify it
Expand All @@ -3134,10 +3140,11 @@ def lower_mesh_computation(
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_shardings)
assert not auto_spmd_lowering
assert isinstance(fun_or_jaxpr, lu.WrappedFun)
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
# is why `.spec` can be accessed.
fun = tiling_transform(
fun, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
fun_or_jaxpr = tiling_transform(
fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
[get_array_mapping(o.spec) for o in out_shardings]) # type: ignore
in_jaxpr_avals = global_in_avals
else:
Expand All @@ -3148,11 +3155,20 @@ def lower_mesh_computation(
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore
for aval, i in safe_zip(global_in_avals, in_shardings)]
in_jaxpr_avals = in_tiled_avals

with core.extend_axis_env_nd(mesh.shape.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
if isinstance(fun_or_jaxpr, lu.WrappedFun):
with dispatch.log_elapsed_time(
f"Finished tracing + transforming {name_stack} in "
"{elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun_or_jaxpr, in_jaxpr_avals)
else:
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
jaxpr = fun_or_jaxpr.jaxpr
out_jaxpr_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts

assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_avals = out_jaxpr_avals
Expand Down
14 changes: 3 additions & 11 deletions jax/_src/pjit.py
Expand Up @@ -1393,10 +1393,6 @@ def _pjit_lower_cached(
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")

f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)

if resource_env is not None:
mesh = resource_env.physical_mesh
api_name = 'pjit'
Expand Down Expand Up @@ -1427,18 +1423,14 @@ def _pjit_lower_cached(

# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
# because `xmap` only supports SPMDAxisContext right now.
if (any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap')):
if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
return pxla.lower_mesh_computation(
fun, api_name, name, mesh,
jaxpr, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
else:
# Pass `in_is_global` here because this path is taken by both host local
# avals and global avals.
# TODO(yashkatariya): Don't set committed to True always. Infer that from
# the arguments just like dispatch.py in `sharded_lowering`.
return pxla.lower_sharding_computation(
fun, api_name, name, in_shardings, out_shardings, donated_invars,
jaxpr, api_name, name, in_shardings, out_shardings, donated_invars,
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
always_lower=always_lower,
devices_from_context=(
Expand Down

0 comments on commit 1ee750e

Please sign in to comment.