Add support for doing grad of pjit (similar to what jit supports). Resolve in_shardings in `_pjit_call_impl` (that were UNSPECIFIED) before lowering to XLA. Then check if the device assignment is the same across shardings in `lower_sharding_computation`.

PiperOrigin-RevId: 468251065
yashk2810 authored and jax authors committed Aug 17, 2022
1 parent 539f13b commit c55e4dc
Showing 4 changed files with 125 additions and 130 deletions.
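For context, here is a rough sketch of what this change enables: taking jax.grad through a pjit-decorated function whose in_axis_resources are left unspecified. It is not part of the commit; the flag-setting call, the shapes, and the bare-decorator form are assumptions about the 2022-era API and may differ in other JAX versions.

```python
# Minimal sketch of grad-of-pjit with unspecified shardings (illustrative only).
import jax
import jax.numpy as jnp
from jax.config import config
from jax.experimental.pjit import pjit

config.update('jax_array', True)  # the resolution path added here is gated on config.jax_array

@pjit  # in_axis_resources / out_axis_resources intentionally left unspecified
def loss(w, x):
  return jnp.sum(jnp.tanh(x @ w))

w = jnp.ones((8, 8))
x = jnp.ones((4, 8))

# With this commit, the UNSPECIFIED in_shardings are resolved from the arguments
# (or default to replicated) inside _pjit_call_impl, so grad of pjit behaves
# much like grad of jit.
grads = jax.grad(loss)(w, x)
```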
171 changes: 81 additions & 90 deletions jax/experimental/pjit.py
@@ -28,7 +28,7 @@
from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src.api import _check_callable, _check_arg
from jax._src.api import _check_callable, _check_arg, devices
from jax._src.config import config
from jax._src import dispatch
from jax._src import source_info_util
@@ -323,8 +323,7 @@ def infer_params(*args, _global_avals=False, **kwargs):
donated_invars = (False,) * len(args_flat)

if config.jax_array:
in_shardings, out_shardings = _get_and_check_in_and_out_shardings(
args_flat, in_axis_resources, out_axis_resources, pjit_mesh, in_tree)
in_shardings, out_shardings = in_axis_resources, out_axis_resources
else:
in_shardings = tree_map(
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
@@ -389,6 +388,8 @@ def wrapped(*args, **kwargs):
def lower(*args, _global_avals=False, **kwargs):
(_, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, _global_avals=_global_avals, **kwargs)
if any(_is_unspecified(i) for i in params['in_shardings']):
raise ValueError("Please specify sharding on pjit's in_axis_resources.")
in_is_global = _calc_is_global_sequence(
params['in_positional_semantics'], params['in_shardings'])
lowering = _pjit_lower(
@@ -416,34 +417,6 @@ def hashable_pytree(pytree):
closure=(treedef, vals))


def _get_and_check_in_and_out_shardings(args_flat, pjit_in_shardings, out_shardings,
pjit_mesh, in_tree):
arg_in_shardings_flat = tuple(a.sharding if hasattr(a, 'sharding') else _UNSPECIFIED
for a in args_flat)
arg_ndims = tuple(a.ndim for a in args_flat)

if _is_unspecified(pjit_in_shardings):
# If pjit_in_shardings is unspecified, then arg_in_shardings cannot have
# unspecified in them.
for a in arg_in_shardings_flat:
if _is_unspecified(a):
raise ValueError('Please specify sharding either on the arg or on '
f'pjit. Found sharding {a} which is invalid.')
in_shardings_flat = arg_in_shardings_flat
else:
# This function is cached.
in_shardings_flat = _get_and_check_pjit_arg_shardings(
hashable_pytree(pjit_in_shardings), arg_in_shardings_flat, arg_ndims,
in_tree)

out_shardings_flat = tuple(tree_flatten(out_shardings)[0])
# Check if the device assignment is the same across inputs and outputs.
# This function is cached.
_check_array_device_assignment(pjit_mesh, in_shardings_flat + out_shardings_flat)

return tree_unflatten(in_tree, in_shardings_flat), out_shardings


@lru_cache(maxsize=4096)
def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
if _is_unspecified_or_from_gda_or_auto(x):
@@ -527,7 +500,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
allow_uneven_sharding=False)
global_in_avals = local_in_avals
canonicalized_shardings = tuple(
i if _is_auto(i) else to_op_sharding_sharding(i, aval.ndim)
i if _is_auto(i) or _is_unspecified(i) else to_op_sharding_sharding(i, aval.ndim)
for i, aval in safe_zip(in_shardings_flat, global_in_avals))
return tuple(global_in_avals), canonicalized_shardings

@@ -798,10 +771,42 @@ def _check_unique_resources(axis_resources, arg_name):
pjit_p.multiple_results = True


def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
arg_shardings = tuple(a.sharding if hasattr(a, 'sharding') else _UNSPECIFIED
for a in args)
arg_ndims = tuple(a.ndim if hasattr(a, 'ndim') else 0 for a in args)
da = _get_and_check_device_assignment(
it.chain(arg_shardings, pjit_in_shardings, out_shardings), pjit_mesh)

resolved_in_shardings = []
for arg_s, pjit_in_s, ndim in safe_zip(arg_shardings, pjit_in_shardings, arg_ndims):
if _is_unspecified(pjit_in_s):
if _is_unspecified(arg_s):
resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
else:
resolved_in_shardings.append(to_op_sharding_sharding(arg_s, ndim))
else:
if not _is_unspecified(arg_s):
if not pxla.are_op_shardings_equal(
pjit_in_s._to_xla_op_sharding(ndim),
arg_s._to_xla_op_sharding(ndim)):
raise ValueError('Sharding passed to pjit does not match the sharding '
'on the respective arg. '
f'Got pjit sharding: {pjit_in_s},\n'
f'arg sharding: {arg_s}')
resolved_in_shardings.append(pjit_in_s)

return tuple(resolved_in_shardings)


def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics):
if config.jax_array:
in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
resource_env.physical_mesh)

in_is_global = _calc_is_global_sequence(in_positional_semantics, in_shardings)
if config.jax_array and all(_is_unspecified(o) for o in out_shardings):
_allow_propagation_to_outputs = True
@@ -862,7 +867,7 @@ def _pjit_lower(
in_shardings,
out_shardings,
*args, **kwargs):
da = _get_device_assignment(it.chain(in_shardings, out_shardings))
da = _fast_path_get_device_assignment(it.chain(in_shardings, out_shardings))
in_shardings = SameDeviceAssignmentTuple(in_shardings, da)
out_shardings = SameDeviceAssignmentTuple(out_shardings, da)
return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs)
@@ -941,9 +946,9 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)

arg_shardings = [i._to_xla_op_sharding(aval.ndim)
arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(ctx.avals_in, in_shardings)]
result_shardings = [o._to_xla_op_sharding(aval.ndim)
result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(ctx.avals_out, out_shardings)]

sub_ctx = ctx.module_context.replace(
@@ -1056,7 +1061,6 @@ def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics):
da = _get_device_assignment(it.chain(in_shardings, out_shardings))
in_pvals = [t.pval for t in in_tracers]

known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1070,35 +1074,44 @@
def keep_where(l, should_keep):
return tuple(x for x, keep in zip(l, should_keep) if keep)

if config.jax_array:
residual_shardings = (_UNSPECIFIED,) * num_residuals
else:
# Using fast path to get the device assignment because mesh is used to
# create the shardings. So the device assignment is always the same.
da = _fast_path_get_device_assignment(it.chain(in_shardings, out_shardings))
residual_shardings = (OpShardingSharding.get_replicated(da),) * num_residuals
# Compute the known outputs
known_params = dict(
jaxpr=known_jaxpr,
in_shardings=keep_where(in_shardings, known_ins),
out_shardings=(
keep_where(out_shardings, known_outs) +
(OpShardingSharding.get_replicated(da),) * num_residuals),
keep_where(out_shardings, known_outs) + residual_shardings),
resource_env=resource_env,
donated_invars=keep_where(donated_invars, known_ins),
name=name,
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
out_positional_semantics=out_positional_semantics)

if num_residuals:
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_shardings'])
compiled = _pjit_lower(
known_params["jaxpr"], known_params["in_shardings"],
known_params["out_shardings"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global).compile(_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
_, out_op_shardings = _get_op_sharding_from_executable(compiled.xla_executable)
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
else:
residual_op_shardings = ()
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
known_params['out_shardings'] = (
keep_where(out_shardings, known_outs) + residual_shardings)
# Skip this for Arrays because Arrays support UNSPECIFIED in out_shardings. So
# there is no need to find out the residual shardings from XLA here.
if not config.jax_array:
if num_residuals:
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_shardings'])
compiled = _pjit_lower(
known_params["jaxpr"], known_params["in_shardings"],
known_params["out_shardings"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global).compile(_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
_, out_op_shardings = _get_op_sharding_from_executable(compiled.xla_executable)
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
else:
residual_op_shardings = ()
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
known_params['out_shardings'] = (
keep_where(out_shardings, known_outs) + residual_shardings)

all_known_outs = pjit_p.bind(
*(pv.get_known() for pv in in_pvals if pv.is_known()),
@@ -1220,7 +1233,9 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")

for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
if hasattr(s, '_original_sharding'):
if _is_unspecified(s) or _is_auto(s):
continue
elif hasattr(s, '_original_sharding'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
parsed_pspec = parse_flatten_op_sharding(
@@ -1234,7 +1249,9 @@

what = "pjit output"
for aval, s in zip(jaxpr.out_avals, params['out_shardings']):
if hasattr(s, '_original_sharding'):
if _is_unspecified(s) or _is_auto(s):
continue
elif hasattr(s, '_original_sharding'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
parsed_pspec = parse_flatten_op_sharding(
@@ -1432,7 +1449,8 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:
return maps._positional_semantics.val


def _get_device_assignment(shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
def _fast_path_get_device_assignment(
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
da = None
for i in shardings:
if _is_auto(i) or _is_unspecified(i):
@@ -1472,9 +1490,7 @@ def _gda_check_and_get_sharding(


@lru_cache(maxsize=4096)
def _check_array_device_assignment(pjit_mesh, shardings):
if not shardings:
return
def _get_and_check_device_assignment(shardings, pjit_mesh):
first_device_assignment = None
mesh_devices = list(pjit_mesh.devices.flat)
for i in shardings:
@@ -1500,36 +1516,11 @@ def _check_array_device_assignment(pjit_mesh, shardings):
raise ValueError("Pjit's devices and Array's devices should be equal. "
f"Got Pjit devices: {list(pjit_mesh.devices.flat)},\n "
f"Array devices: {arr_device_assignment}")

@lru_cache(maxsize=4096)
def _get_and_check_pjit_arg_shardings(pjit_in_shardings, arg_in_shardings_flat,
arg_ndims, in_tree):
pjit_in_shardings_flat = flatten_axis_resources(
"pjit in_shardings", in_tree, pjit_in_shardings(), tupled_args=True)

out = []
for pjit_sharding, arg_sharding, ndim in safe_zip(
pjit_in_shardings_flat, arg_in_shardings_flat, arg_ndims):
# If the sharding of the arg is not known, replace it with the sharding on
# pjit.
if _is_unspecified(arg_sharding):
out.append(pjit_sharding)
elif _is_auto(pjit_sharding):
raise ValueError('Passing sharding on pjit and on args while using the '
'auto spmd partitioner is not allowed. Please call the '
'compiled object on the inputs.')
else:
if not pxla.are_op_shardings_equal(
pjit_sharding._to_xla_op_sharding(ndim),
arg_sharding._to_xla_op_sharding(ndim)):
raise ValueError('Sharding passed to pjit does not match the sharding '
'on the respective arg. '
f'Got pjit sharding: {pjit_sharding},\n'
f'arg sharding: {arg_sharding}')
out.append(pjit_sharding)

assert not any(_is_unspecified(o) for o in out)
return tuple(out)
if first_device_assignment is None and not pjit_mesh.empty:
return mesh_devices
if first_device_assignment is None:
return [config.jax_default_device or devices()[0]]
return first_device_assignment


def _maybe_check_pjit_gda_mesh(args, mesh):
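The resolution rule implemented by the new `_resolve_in_shardings` above can be summarized with a small standalone sketch. This is a toy model only: the real helper works with `OpShardingSharding` objects, compares them via `pxla.are_op_shardings_equal`, and uses pjit's `_UNSPECIFIED` sentinel rather than the stand-in below.

```python
# Toy restatement of the in_shardings resolution added in this commit (illustrative only).
UNSPECIFIED = object()  # stand-in for pjit's _UNSPECIFIED sentinel

def resolve_in_shardings(arg_shardings, pjit_in_shardings, replicated):
  """Pick the sharding pjit should lower each argument with.

  - pjit sharding unspecified, arg sharding unspecified -> replicate
  - pjit sharding unspecified, arg sharding known       -> take the arg's sharding
  - both specified                                      -> they must agree
  """
  resolved = []
  for arg_s, pjit_s in zip(arg_shardings, pjit_in_shardings):
    if pjit_s is UNSPECIFIED:
      resolved.append(replicated if arg_s is UNSPECIFIED else arg_s)
    else:
      if arg_s is not UNSPECIFIED and arg_s != pjit_s:
        raise ValueError('Sharding passed to pjit does not match the sharding '
                         'on the respective arg.')
      resolved.append(pjit_s)
  return tuple(resolved)

# e.g. resolve_in_shardings(('x-sharded', UNSPECIFIED), (UNSPECIFIED, UNSPECIFIED), 'replicated')
# returns ('x-sharded', 'replicated').
```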
7 changes: 4 additions & 3 deletions jax/interpreters/mlir.py
@@ -889,7 +889,8 @@ def aval_to_types(aval):
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)
flat_args = [a if s is None else wrap_with_sharding_op(a, s)
for a, s in zip(flat_args, ir_arg_shardings)]

unflattened_args = util.unflatten(flat_args, map(len, input_types))
# We separate out the token inputs and the usual inputs. The token inputs
@@ -926,8 +927,8 @@ def aval_to_types(aval):
outs.append(out)
flat_outputs = util.flatten(outs)
if not use_sharding_annotations and ir_result_shardings is not None:
flat_outputs = map(wrap_with_sharding_op, flat_outputs,
ir_result_shardings)
flat_outputs = [o if s is None else wrap_with_sharding_op(o, s)
for o, s in zip(flat_outputs, ir_result_shardings)]

func_dialect.ReturnOp(flat_outputs)

23 changes: 15 additions & 8 deletions jax/interpreters/pxla.py
@@ -2555,17 +2555,22 @@ def __reduce__(self):


def _get_backend_from_shardings(
shardings: Iterable[XLACompatibleSharding]) -> Tuple[xb.XlaBackend, XLACompatibleSharding]:
da = None
shardings: Iterable[Union[XLACompatibleSharding, _UnspecifiedValue]]
) -> Tuple[xb.XlaBackend, XLACompatibleSharding]:
from jax.experimental.sharding import XLACompatibleSharding

da: Optional[Sequence[xc.Device]] = None
first_sharding = None
for s in shardings:
if _is_unspecified(s):
continue
da = s._device_assignment
# pytype does not understand that _UNSPECIFIED is being skipped above.
da = s._device_assignment # type: ignore
first_sharding = s
break
assert len(da) > 0 # type: ignore
return xb.get_device_backend(da[0]), first_sharding # type: ignore
da = cast(Sequence[xc.Device], da)
assert len(da) > 0
return xb.get_device_backend(da[0]), cast(XLACompatibleSharding, first_sharding)


@profiler.annotate_function
@@ -2597,7 +2602,7 @@ def lower_sharding_computation(
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
assert len(out_shardings) == len(out_jaxpr_avals)
assert len(out_shardings) == len(out_jaxpr_avals), (len(out_shardings), len(out_jaxpr_avals))

global_out_avals = out_jaxpr_avals

@@ -3004,11 +3009,13 @@ def from_hlo(name: str,
assert mesh is not None
in_shardings, out_shardings = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
assert mesh is None
in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
_, out_shardings_xla = _get_op_sharding_shardings_from_executable(
xla_executable, first_sharding._device_assignment,
len(global_in_avals), len(global_out_avals))
out_shardings = [x if _is_unspecified(o) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]

in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
