From 231495166929be4a6ee3a0fd843858abeeca3694 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 7 Jul 2022 10:41:27 -0700 Subject: [PATCH] Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974 --- jax/experimental/jax2tf/jax2tf.py | 27 +- jax/experimental/pjit.py | 482 ++++++++++++++++++------------ jax/experimental/sharding.py | 93 +++++- tests/pjit_test.py | 59 ++-- 4 files changed, 418 insertions(+), 243 deletions(-) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 6ad70807d0be..0df4e839a8b2 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -17,7 +17,7 @@ import os import re import threading -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast import jax from jax import lax @@ -28,6 +28,7 @@ from jax import numpy as jnp from jax.experimental import maps from jax.experimental import pjit +from jax.experimental import sharding from jax.interpreters import ad from jax.interpreters import partial_eval from jax.interpreters import pxla @@ -2670,13 +2671,13 @@ def split_to_logical_devices(tensor: TfVal, return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) -def _shard_value(mesh: maps.Mesh, - val: TfVal, +def _shard_value(val: TfVal, aval: core.ShapedArray, - axis_resources: pjit.ParsedPartitionSpec) -> TfVal: + sd: sharding.XLACompatibleSharding) -> TfVal: """Apply sharding to a TfVal.""" - sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto( - aval, axis_resources, mesh) + sharding_proto: xla_client.OpSharding = cast( + xla_client.OpSharding, sd._to_xla_op_sharding(aval.ndim)) + # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding. xla_sharding_proto: xla_data_pb2.OpSharding = ( xla_data_pb2.OpSharding( @@ -2691,8 +2692,8 @@ def _shard_value(mesh: maps.Mesh, def _pjit(*args: TfVal, jaxpr: core.ClosedJaxpr, - in_axis_resources: Sequence[pjit.ParsedPartitionSpec], - out_axis_resources: Sequence[pjit.ParsedPartitionSpec], + in_shardings: Sequence[sharding.XLACompatibleSharding], + out_shardings: Sequence[sharding.XLACompatibleSharding], resource_env: maps.ResourceEnv, donated_invars, name: str, @@ -2704,15 +2705,13 @@ def _pjit(*args: TfVal, if resource_env.physical_mesh.is_multi_process: raise NotImplementedError("jax2tf translation for pjit over multi-process " "meshes is not supported yet") - # TODO: add `name` to the name stack - shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh) # Apply sharding annotation to the arguments sharded_args: Sequence[TfVal] = tuple( - map(shard_value_for_mesh, args, _in_avals, in_axis_resources)) + map(_shard_value, args, _in_avals, in_shardings)) results = _interpret_jaxpr(jaxpr, *sharded_args, extra_name_stack=util.wrap_name(name, "pjit")) sharded_results: Sequence[TfVal] = tuple( - map(shard_value_for_mesh, results, _out_aval, out_axis_resources)) + map(_shard_value, results, _out_aval, out_shardings)) return tuple(sharded_results) @@ -2725,7 +2724,9 @@ def _pjit_sharding_constraint(arg: TfVal, *, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, **kwargs) -> TfVal: - return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources) + ms = sharding.MeshPspecSharding._from_parsed_pspec( + resource_env.physical_mesh, axis_resources) + return _shard_value(arg, _in_avals[0], ms) tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 3148a9f73f8e..86263e6214a3 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -15,13 +15,14 @@ from enum import IntEnum import numpy as np from collections import OrderedDict, Counter -from typing import Callable, Sequence, Tuple, Union, Optional, cast, List +from typing import Callable, Sequence, Tuple, Union, Optional, cast, List, Iterable import itertools as it from functools import partial from jax.experimental import maps from jax.experimental.global_device_array import GlobalDeviceArray as GDA from jax.experimental.array import Array +from jax.experimental import sharding from jax import core from jax import linear_util as lu from jax import stages @@ -257,14 +258,27 @@ def pjit(fun: Callable, # rather than raising an error. https://github.com/google/jax/issues/2367 in_axis_resources = tuple(in_axis_resources) - in_axis_resources, _, _, in_any_auto = _prepare_axis_resources( - in_axis_resources, "in_axis_resources") - out_axis_resources, _, _, _ = _prepare_axis_resources( - out_axis_resources, "out_axis_resources") - - # Duck type `UNSPECIFIED` with `FROM_GDA` to use that codepath for `Array`. - if config.jax_array and _is_unspecified(in_axis_resources): - in_axis_resources = FROM_GDA + in_any_auto: bool + if not config.jax_array: + in_axis_resources, _, _, in_any_auto = _prepare_axis_resources( + in_axis_resources, "in_axis_resources") + out_axis_resources, _, _, _ = _prepare_axis_resources( + out_axis_resources, "out_axis_resources") + else: + # TODO(yashkatariya): Relax this restriction once the transition to + # sharding instances finishes. + if not _is_unspecified(in_axis_resources): + raise ValueError('in_axis_resources should be empty for Array. The sharding ' + 'should be specified on the arguments as pjit follows ' + 'computation follows data semantics.') + # `out_axis_resources` should be instances of `Sharding` if it's not + # unspecified. For `AUTO` sharding, it can only be used with + # MeshPspecSharding. + if not _is_unspecified(out_axis_resources): + if not all(isinstance(s, sharding.Sharding) for s in tree_flatten(out_axis_resources)[0]): + raise ValueError('When `config.jax_array` flag is enabled, ' + 'out_axis_resources should contain instances of ' + '`Sharding`.') static_argnums = _ensure_index_tuple(static_argnums) donate_argnums = _ensure_index_tuple(donate_argnums) @@ -280,8 +294,8 @@ def infer_params(*args, _global_avals=False, **kwargs): # Putting this outside of wrapped would make resources lexically scoped resource_env = pxla.thread_resources.env - mesh = resource_env.physical_mesh - if mesh.empty: + pjit_mesh = resource_env.physical_mesh + if pjit_mesh.empty: raise RuntimeError("pjit requires a non-empty mesh! Are you sure that " "it's defined at the call site?") @@ -300,13 +314,19 @@ def infer_params(*args, _global_avals=False, **kwargs): if any(not isinstance(a, Array) for a in args_flat): raise ValueError('All arguments to pjit when `config.jax_array` is ' 'enabled should be `Array`s.') + # tree_map over `dyn_args` to preserve the pytree structure of args. + in_shardings = tree_map(lambda x: x.sharding, dyn_args) + out_shardings = out_axis_resources + # This function is cached which is an improvement over the old codepath. + _check_array_device_assignment(pjit_mesh, tuple(tree_flatten(in_shardings)[0])) + else: + in_shardings = tree_map(lambda x: _create_mesh_pspec_sharding(pjit_mesh, x), + in_axis_resources) + out_shardings = tree_map(lambda x: x if _is_unspecified(x) else + _create_mesh_pspec_sharding(pjit_mesh, x), out_axis_resources) + _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh) - # TODO(yashkatariya): Check device_set for Array instead of the mesh. - _maybe_check_pjit_gda_or_array_mesh(args_flat, mesh) - - # TODO(yashkatariya): Make sure you are not checking explicitly for `ShapedArray`. - # One possibility, is to only allow GDA and fully replicated inputs for AUTO. - if in_any_auto and not _global_avals: + if not config.jax_array and in_any_auto and not _global_avals: raise ValueError('Auto sharding is only enabled for global inputs. ' 'Please set `_global_avals=True` during `.lower()`. ' 'Use the compiled object to call the inputs.') @@ -323,22 +343,23 @@ def infer_params(*args, _global_avals=False, **kwargs): if config.jax_parallel_functions_output_gda or config.jax_array else maps._positional_semantics.val) - global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources( - mesh, local_in_avals, hashable_pytree(in_axis_resources), in_tree, - in_positional_semantics, tuple(isinstance(a, (GDA, Array)) for a in args_flat)) + global_in_avals, normalized_in_shardings_flat = _process_in_axis_resources( + hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics, + tuple(isinstance(a, GDA) for a in args_flat)) - jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr( - flat_fun, mesh, global_in_avals, HashableFunction(out_tree, closure=()), - hashable_pytree(out_axis_resources)) + jaxpr, normalized_out_shardings_flat = _pjit_jaxpr( + flat_fun, hashable_pytree(out_shardings), global_in_avals, + HashableFunction(out_tree, closure=())) - canonicalized_in_axis_resources_flat = tree_map( - _maybe_replace_from_gda_with_pspec, - canonicalized_in_axis_resources_flat, tuple(args_flat)) + if not config.jax_array: + normalized_in_shardings_flat = tree_map( + _maybe_replace_from_gda_with_pspec, normalized_in_shardings_flat, + tuple(args_flat)) params = dict( jaxpr=jaxpr, - in_axis_resources=canonicalized_in_axis_resources_flat, - out_axis_resources=canonicalized_out_axis_resources_flat, + in_shardings=normalized_in_shardings_flat, + out_shardings=normalized_out_shardings_flat, resource_env=resource_env, donated_invars=donated_invars, name=getattr(flat_fun, '__name__', ''), @@ -356,13 +377,13 @@ def wrapped(*args, **kwargs): return tree_unflatten(out_tree, out) def lower(*args, _global_avals=False, **kwargs): - (args_flat, flat_local_in_avals, params, in_tree, out_tree, + (_, flat_local_in_avals, params, in_tree, out_tree, donate_argnums) = infer_params(*args, _global_avals=_global_avals, **kwargs) in_is_global = _calc_is_global_sequence( - params['in_positional_semantics'], params['in_axis_resources']) + params['in_positional_semantics'], params['in_shardings']) lowering = _pjit_lower( - params['jaxpr'], params['in_axis_resources'], - params['out_axis_resources'], params['resource_env'], + params['jaxpr'], params['in_shardings'], + params['out_shardings'], params['resource_env'], params['donated_invars'], params['name'], in_is_global) @@ -384,19 +405,28 @@ def hashable_pytree(pytree): return HashableFunction(lambda: tree_unflatten(treedef, vals), closure=(treedef, vals)) -def flatten_axis_resources(what, tree, axis_resources, tupled_args): + +def _create_mesh_pspec_sharding(mesh, x): + if _is_unspecified(x): + return x + if _is_from_gda(x): + return x + return sharding.MeshPspecSharding._from_parsed_pspec(mesh, x) + + +def flatten_axis_resources(what, tree, shardings, tupled_args): try: - return tuple(flatten_axes(what, tree, axis_resources, tupled_args=tupled_args)) + return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args)) except ValueError: pass # Raise a tree prefix error below # Tree leaves are always valid prefixes, so if there was a prefix error as # assumed here, axis_resources must not be a leaf. - assert not treedef_is_leaf(tree_structure(axis_resources)) + assert not treedef_is_leaf(tree_structure(shardings)) # Check the type directly rather than using isinstance because of namedtuples. - if tupled_args and (type(axis_resources) is not tuple or - len(axis_resources) != len(tree.children())): + if tupled_args and (type(shardings) is not tuple or + len(shardings) != len(tree.children())): # We know axis_resources is meant to be a tuple corresponding to the args # tuple, but while it is a non-leaf pytree, either it wasn't a tuple or it # wasn't the right length. @@ -406,17 +436,17 @@ def flatten_axis_resources(what, tree, axis_resources, tupled_args): f"a tuple of length equal to the number of positional arguments.") # If `tree` represents an args tuple, then `axis_resources` must be a tuple. # TODO(mattjj,apaszke): disable implicit list casts, remove 'or list' below - if type(axis_resources) is not tuple: - msg += f" But {what} is not a tuple: got {type(axis_resources)} instead." - elif len(axis_resources) != len(tree.children()): + if type(shardings) is not tuple: + msg += f" But {what} is not a tuple: got {type(shardings)} instead." + elif len(shardings) != len(tree.children()): msg += (f" But {what} is the wrong length: got a tuple or list of length " - f"{len(axis_resources)} for an args tuple of length " + f"{len(shardings)} for an args tuple of length " f"{len(tree.children())}.") # As an extra hint, let's check if the user just forgot to wrap - # in_axis_resources in a singleton tuple. + # shardings in a singleton tuple. if len(tree.children()) == 1: - try: flatten_axes(what, tree, (axis_resources,)) + try: flatten_axes(what, tree, (shardings,)) except ValueError: pass # That's not the issue. else: msg += (f" Given the corresponding argument being " @@ -425,8 +455,11 @@ def flatten_axis_resources(what, tree, axis_resources, tupled_args): raise ValueError(msg) - # Replace axis_resources with unparsed versions to avoid revealing internal details - axis_tree = tree_map(lambda parsed: parsed.user_spec, axis_resources) + if config.jax_array: + axis_tree = shardings + else: + # Replace axis_resources with unparsed versions to avoid revealing internal details + axis_tree = tree_map(lambda parsed: parsed.spec, shardings) # Because ecause we only have the `tree` treedef and not the full pytree here, # we construct a dummy tree to compare against. Revise this in callers? @@ -444,12 +477,27 @@ def __repr__(self): return "pytree leaf" @cache() -def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk, +def _process_in_axis_resources(in_shardings_thunk, local_in_avals, in_tree, in_positional_semantics, is_gda): - in_axis_resources_flat = flatten_axis_resources( - "pjit in_axis_resources", in_tree, - in_axis_resources_thunk(), tupled_args=True) - canonicalized_in_axis_resources_flat = tree_map(_create_cpspec, in_axis_resources_flat) + in_shardings_flat = flatten_axis_resources( + "pjit in_axis_resources", in_tree, in_shardings_thunk(), tupled_args=True) + + # Fork here because the `Array` path is very simple and doesn't need all the + # complexity below. + if config.jax_array: + for aval, i in safe_zip(local_in_avals, in_shardings_flat): + pjit_check_aval_sharding(i, aval, "pjit arguments", + allow_uneven_sharding=False) + global_in_avals = local_in_avals + return tuple(global_in_avals), tuple(i.normalize() for i in in_shardings_flat) + + if not local_in_avals: + assert not in_shardings_flat + return (), () + + in_axis_resources_flat = tuple(i if _is_from_gda(i) else i._parsed_pspec + for i in in_shardings_flat) + # This check should be above local_to_global call below otherwise if # `FROM_GDA` is passed to any input other than GDA, a ugly error message # will be raised because get_array_mapping (in local_to_global) of a @@ -462,26 +510,33 @@ def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk, # Use canonicalized in_axis_resources here because we want to treat P(None) # and None (for example) as equivalent. if all( - (not _is_from_gda(p) and not _is_auto(p) and p.partitions == ()) or ips == maps._PositionalSemantics.GLOBAL - for p, ips in safe_zip(canonicalized_in_axis_resources_flat, in_positional_semantics)): + (not _is_from_gda(p) and not _is_auto(p) and + CanonicalizedParsedPartitionSpec(p).partitions == ()) or + ips == maps._PositionalSemantics.GLOBAL + for p, ips in safe_zip(in_axis_resources_flat, in_positional_semantics)): # Shapes should be checked against non canonicalized in_axis_resources. # For example, partitions of () and ((),) are not equivalent, since the # first one is a valid spec for a scalar value, while the second is not! - _check_shapes_against_resources( - "pjit arguments", mesh.is_multi_process, mesh.shape, local_in_avals, - in_axis_resources_flat, allow_uneven_sharding=False) + for aval, i in safe_zip(local_in_avals, in_shardings_flat): + if _is_from_gda(i): continue + pjit_check_aval_sharding(i, aval, "pjit arguments", + allow_uneven_sharding=False) else: - _check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape, - local_in_avals, in_axis_resources_flat, - allow_uneven_sharding=False) + for aval, i in safe_zip(local_in_avals, in_shardings_flat): + if _is_from_gda(i): continue + pjit_check_aval_sharding(i, aval, "pjit arguments", + allow_uneven_sharding=False, local=True) + + normalized_in_shardings_flat = tuple( + i if _is_from_gda(i) else i.normalize() for i in in_shardings_flat) - global_in_avals = local_to_global(in_positional_semantics, mesh, - local_in_avals, canonicalized_in_axis_resources_flat) - return tuple(global_in_avals), canonicalized_in_axis_resources_flat + global_in_avals = local_to_global(in_positional_semantics, + local_in_avals, normalized_in_shardings_flat) + return tuple(global_in_avals), normalized_in_shardings_flat @lu.cache -def _pjit_jaxpr(fun, mesh, global_in_avals, out_tree, out_axis_resources_thunk): +def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree): prev_positional_val = maps._positional_semantics.val try: maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL @@ -492,15 +547,34 @@ def _pjit_jaxpr(fun, mesh, global_in_avals, out_tree, out_axis_resources_thunk): maps._positional_semantics.val = prev_positional_val jaxpr = core.ClosedJaxpr(jaxpr, consts) - out_axis_resources_flat = flatten_axis_resources( - "pjit out_axis_resources", out_tree(), - out_axis_resources_thunk(), tupled_args=False) - _check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape, - global_out_avals, out_axis_resources_flat, - allow_uneven_sharding=False) - canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat) + out_shardings_flat = flatten_axis_resources( + "pjit out_axis_resources", out_tree(), out_shardings_thunk(), tupled_args=False) + + for aval, o in safe_zip(global_out_avals, out_shardings_flat): + if _is_unspecified(o): continue + pjit_check_aval_sharding(o, aval, "pjit outputs", allow_uneven_sharding=False) + + normalized_out_shardings_flat = tuple( + o if _is_unspecified(o) else o.normalize() + for o in out_shardings_flat + ) + # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple - return _ListWithW([jaxpr, canonicalized_out_axis_resources_flat]) + return _ListWithW([jaxpr, normalized_out_shardings_flat]) + + +# TODO(yashkatariya): Replace this with shape check against sharding which +# uses OpSharding.tile_assignment_dimension. +def pjit_check_aval_sharding( + sharding: sharding.MeshPspecSharding, aval, what_aval: str, + allow_uneven_sharding: bool, local: bool = False): + if local: + m = sharding.mesh.local_mesh + else: + m = sharding.mesh + _check_shapes_against_resources( + what_aval, m.is_multi_process, m.shape, [aval], + [sharding._parsed_pspec], allow_uneven_sharding) class SpecSync(IntEnum): @@ -690,17 +764,17 @@ def _check_shapes_against_resources(what: str, is_global_shape: bool, def _pjit_call_impl(*args, jaxpr, - in_axis_resources, out_axis_resources, - resource_env, donated_invars, name, + in_shardings, out_shardings, resource_env, + donated_invars, name, in_positional_semantics, out_positional_semantics): - in_is_global = _calc_is_global_sequence(in_positional_semantics, in_axis_resources) - if config.jax_array and all(_is_unspecified(o) for o in out_axis_resources): + in_is_global = _calc_is_global_sequence(in_positional_semantics, in_shardings) + if config.jax_array and all(_is_unspecified(o) for o in out_shardings): _allow_propagation_to_outputs = True else: _allow_propagation_to_outputs = False compiled = _pjit_lower( - jaxpr, in_axis_resources, out_axis_resources, - resource_env, donated_invars, name, in_is_global).compile( + jaxpr, in_shardings, out_shardings, resource_env, + donated_invars, name, in_is_global).compile( _allow_propagation_to_outputs=_allow_propagation_to_outputs) # This check is expensive so only do it if enable_checks is on. if compiled._auto_spmd_lowering and config.jax_enable_checks: @@ -714,9 +788,8 @@ def _pjit_call_impl(*args, jaxpr, if fingerprint is not None: fingerprint = fingerprint.hex() distributed_debug_log(("Running pjit'd function", name), - ("mesh", resource_env.physical_mesh), - ("in_axis_resources", in_axis_resources), - ("out_axis_resources", out_axis_resources), + ("in_shardings", in_shardings), + ("out_shardings", out_shardings), ("abstract args", list(map(xla.abstractify, args))), ("fingerprint", fingerprint)) return compiled.unsafe_call(*args) @@ -725,57 +798,53 @@ def _pjit_call_impl(*args, jaxpr, @weakref_lru_cache def _pjit_lower( jaxpr: core.ClosedJaxpr, - in_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...], - out_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...], + in_shardings, + out_shardings, resource_env, donated_invars, name: str, in_is_global: Sequence[bool]): - # in_axis_resources and out_axis_resources are canonicalized to avoid + # in_shardings and out_shardings are canonicalized to avoid # recompilation (since pjit_lower is cached) if its compiled with `None` but # in the next call `P(None)` is passed. Those are the same thing so should be # treat as equivalent and pjit_lower's cache shouldn't be invalidated. - in_axes = [get_array_mapping(axes) for axes in in_axis_resources] - out_axes = [get_array_mapping(axes) for axes in out_axis_resources] + in_axes = [get_array_mapping(i._parsed_pspec) for i in in_shardings] + out_axes = [get_array_mapping(o if _is_unspecified(o) else o._parsed_pspec) + for o in out_shardings] pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") f = core.jaxpr_as_fun(jaxpr) f.__name__ = name fun = lu.wrap_init(f) + # TODO(yashkatariya): Move `lower_mesh_computation` to use the `Sharding` + # interface too. return pxla.lower_mesh_computation( fun, 'pjit', name, resource_env.physical_mesh, in_axes, out_axes, donated_invars, True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global) -def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env, +def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, out_positional_semantics, **_): if jaxpr.effects: raise NotImplementedError('Effects not supported in `pjit`.') - return global_to_local(out_positional_semantics, resource_env.physical_mesh, - jaxpr.out_avals, out_axis_resources), jaxpr.effects + return global_to_local(out_positional_semantics, jaxpr.out_avals, + out_shardings), jaxpr.effects pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) -def _pjit_lowering(ctx, *args, name, jaxpr, in_axis_resources, - out_axis_resources, resource_env, donated_invars, +def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, + out_shardings, resource_env, donated_invars, in_positional_semantics, out_positional_semantics): if not isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext): raise RuntimeError("Nesting pjit() inside jit() is not allowed.") - # TODO: use manual_axes! - mesh = resource_env.physical_mesh + output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out) flat_output_types = util.flatten(output_types) - arg_shardings = [] - for i, (aval, axis_resources) in enumerate( - safe_zip(ctx.avals_in, in_axis_resources)): - arg_shardings.append(get_aval_sharding_proto(aval, axis_resources, mesh)) - - result_shardings = [ - get_aval_sharding_proto(aval, axis_resources, mesh) - for aval, axis_resources in safe_zip( - ctx.avals_out, out_axis_resources) - ] + arg_shardings = [i._to_xla_op_sharding(aval.ndim) + for aval, i in safe_zip(ctx.avals_in, in_shardings)] + result_shardings = [o._to_xla_op_sharding(aval.ndim) + for aval, o in safe_zip(ctx.avals_out, out_shardings)] sub_ctx = ctx.module_context.replace( name_stack=xla.extend_name_stack(ctx.module_context.name_stack, @@ -798,7 +867,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_axis_resources, def _pjit_batcher(insert_axis, axis_size, axis_name, main_type, vals_in, dims_in, - jaxpr, in_axis_resources, out_axis_resources, + jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, in_positional_semantics, out_positional_semantics): # batch_jaxpr expects all batching dimensions to be equal to 0 @@ -810,17 +879,17 @@ def _pjit_batcher(insert_axis, instantiate=False, axis_name=axis_name, main_type=main_type) new_parts = (axis_name,) if insert_axis else () - in_axis_resources = tuple( - spec.insert_axis_partitions(0, new_parts) if is_mapped else spec - for is_mapped, spec in zip(is_mapped_in, in_axis_resources)) - out_axis_resources = tuple( - spec.insert_axis_partitions(0, new_parts) if is_mapped else spec - for is_mapped, spec in zip(is_mapped_out, out_axis_resources)) + in_shardings = tuple( + _pjit_batcher_for_sharding(i, 0, new_parts) if is_mapped else i + for is_mapped, i in zip(is_mapped_in, in_shardings)) + out_shardings = tuple( + _pjit_batcher_for_sharding(o, 0, new_parts) if is_mapped else o + for is_mapped, o in zip(is_mapped_out, out_shardings)) vals_out = pjit_p.bind( *vals_in, jaxpr=new_jaxpr, - in_axis_resources=in_axis_resources, - out_axis_resources=out_axis_resources, + in_shardings=in_shardings, + out_shardings=out_shardings, resource_env=resource_env, donated_invars=donated_invars, name=name, @@ -831,9 +900,18 @@ def _pjit_batcher(insert_axis, batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False) pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True) +# TODO(yashkatariya, apaszke): Remove this and replace this with `VmapSharding`. +def _pjit_batcher_for_sharding(s, dim, val): + parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, val) + # Use `_from_parsed_pspec` because after you `insert_axis_partitions` the + # `sync` attribute changes. To make sure we preserve that, we need to pass + # that parsed partition spec when created the sharding instance. + # Inferring the `PartitiionSpec` from that is easy as done in the classmethod. + return sharding.MeshPspecSharding._from_parsed_pspec(s.mesh, parsed_pspec).normalize() + def _pjit_jvp(primals_in, tangents_in, - jaxpr, in_axis_resources, out_axis_resources, + jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, in_positional_semantics, out_positional_semantics): is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in] @@ -847,8 +925,8 @@ def _filter_zeros(is_nz_l, l): outputs = pjit_p.bind( *primals_in, *_filter_zeros_in(tangents_in), jaxpr=jaxpr_jvp, - in_axis_resources=(*in_axis_resources, *_filter_zeros_in(in_axis_resources)), - out_axis_resources=(*out_axis_resources, *_filter_zeros_out(out_axis_resources)), + in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)), + out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)), resource_env=resource_env, donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), name=wrap_name(name, 'jvp'), @@ -864,7 +942,7 @@ def _filter_zeros(is_nz_l, l): def _pjit_partial_eval(trace, *in_tracers, - jaxpr, in_axis_resources, out_axis_resources, + jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, in_positional_semantics, out_positional_semantics): # XXX: At the moment all residuals get fully replicated, which is extremely @@ -886,9 +964,13 @@ def keep_where(l, should_keep): # Compute the known outputs known_params = dict( jaxpr=known_jaxpr, - in_axis_resources=keep_where(in_axis_resources, known_ins), - out_axis_resources=(keep_where(out_axis_resources, known_outs) + - (REPLICATED,) * num_residuals), + in_shardings=keep_where(in_shardings, known_ins), + # TODO(yashkatariya): Remove the `MeshPspecSharding` creation here. + # This is done like this because all output + # shardings are that even in the `Array` codepath. + out_shardings=( + keep_where(out_shardings, known_outs) + + (sharding.MeshPspecSharding(mesh, pxla.PartitionSpec(None)),) * num_residuals), resource_env=resource_env, donated_invars=keep_where(donated_invars, known_ins), name=name, @@ -897,10 +979,10 @@ def keep_where(l, should_keep): if num_residuals: in_is_global = _calc_is_global_sequence( - known_params['in_positional_semantics'], known_params['in_axis_resources']) + known_params['in_positional_semantics'], known_params['in_shardings']) compiled = _pjit_lower( - known_params["jaxpr"], known_params["in_axis_resources"], - known_params["out_axis_resources"], known_params["resource_env"], + known_params["jaxpr"], known_params["in_shardings"], + known_params["out_shardings"], known_params["resource_env"], known_params["donated_invars"], known_params["name"], in_is_global).compile(_allow_propagation_to_outputs=True, _allow_compile_replicated=False) @@ -908,8 +990,10 @@ def keep_where(l, should_keep): residual_specs = tuple(output_ppspec[-num_residuals:]) else: residual_specs = () - known_params['out_axis_resources'] = ( - keep_where(out_axis_resources, known_outs) + residual_specs) + residual_sharding = tuple(sharding.MeshPspecSharding._from_parsed_pspec(mesh, r) + for r in residual_specs) + known_params['out_shardings'] = ( + keep_where(out_shardings, known_outs) + residual_sharding) all_known_outs = pjit_p.bind( *(pv.get_known() for pv in in_pvals if pv.is_known()), @@ -930,8 +1014,8 @@ def keep_where(l, should_keep): # Prepare unknown tracers unknown_params = dict( jaxpr=unknown_jaxpr, - in_axis_resources=(keep_where(in_axis_resources, unknown_ins) + residual_specs), - out_axis_resources=keep_where(out_axis_resources, unknown_outs), + in_shardings=(keep_where(in_shardings, unknown_ins) + residual_sharding), + out_shardings=keep_where(out_shardings, unknown_outs), resource_env=resource_env, donated_invars=(keep_where(donated_invars, unknown_ins) + (False,) * num_residuals), @@ -943,8 +1027,8 @@ def keep_where(l, should_keep): unknown_tracers_out = [ pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in global_to_local(unknown_params["out_positional_semantics"], - mesh, unknown_jaxpr.out_avals, - unknown_params["out_axis_resources"]) + unknown_jaxpr.out_avals, + unknown_params["out_shardings"]) ] eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers), unknown_tracers_out, @@ -958,48 +1042,48 @@ def keep_where(l, should_keep): def _pjit_transpose(reduce_axes, cts_in, *primals_in, - jaxpr, in_axis_resources, out_axis_resources, + jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, in_positional_semantics, out_positional_semantics): - mesh = resource_env.physical_mesh - def prune_type(ty, xs, maybe_zeros): - return tuple(x for x, mz in zip(xs, maybe_zeros) if not type(mz) is ty) + return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) body = lu.wrap_init(ad.closed_backward_pass) body = lu.hashable_partial(body, jaxpr, reduce_axes, False) primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) - transpose_in_axis_resources = ( - *prune_type(ad.UndefinedPrimal, in_axis_resources, primals_in), - *prune_type(ad.Zero, out_axis_resources, cts_in) + transpose_in_shardings = ( + *prune_type(ad.UndefinedPrimal, in_shardings, primals_in), + *prune_type(ad.Zero, out_shardings, cts_in) ) transpose_in_positional_semantics = ( *prune_type(ad.UndefinedPrimal, in_positional_semantics, primals_in), *prune_type(ad.Zero, (out_positional_semantics,) * len(cts_in), cts_in) ) - global_cts_in_avals = local_to_global( - transpose_in_positional_semantics, - mesh, - [core.raise_to_shaped(core.get_aval(ct)) for ct in primals_and_nz_cts_in], - transpose_in_axis_resources) + global_cts_in_avals = [core.raise_to_shaped(core.get_aval(ct)) + for ct in primals_and_nz_cts_in] + if not config.jax_array: + global_cts_in_avals = local_to_global( + transpose_in_positional_semantics, global_cts_in_avals, + transpose_in_shardings) + transpose_jaxpr, global_cts_out_avals, consts = pe.trace_to_jaxpr_dynamic( body, global_cts_in_avals) # TODO(apaszke): Creating ClosedJaxpr by hand will break compilation cache! transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts) del consts cts_out_treedef = cts_out_treedef_thunk() - transpose_out_axis_resources = prune_type( + transpose_out_shardings = prune_type( ad.Zero, - in_axis_resources, + in_shardings, tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) nz_cts_out = pjit_p.bind( *primals_and_nz_cts_in, jaxpr=transpose_jaxpr, - in_axis_resources=transpose_in_axis_resources, - out_axis_resources=transpose_out_axis_resources, + in_shardings=transpose_in_shardings, + out_shardings=transpose_out_shardings, resource_env=resource_env, donated_invars=(False,) * len(primals_and_nz_cts_in), name=name, @@ -1028,15 +1112,15 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r what = "pjit input" if resource_env.physical_mesh != params['resource_env'].physical_mesh: raise RuntimeError("Changing the physical mesh is not allowed inside pjit.") - for aval, pos_axis_resources in zip(jaxpr.in_avals, params['in_axis_resources']): - _check_resources_against_named_axes(what, aval, pos_axis_resources, named_axis_resources) + for aval, s in zip(jaxpr.in_avals, params['in_shardings']): + _check_resources_against_named_axes(what, aval, s._parsed_pspec, named_axis_resources) pxla.resource_typecheck( jaxpr.jaxpr, resource_env, named_axis_resources, lambda: (f"a pjit'ed function {params['name']} " f"(pjit called at {source_info_util.summarize(source_info)})")) what = "pjit output" - for aval, pos_axis_resources in zip(jaxpr.out_avals, params['out_axis_resources']): - _check_resources_against_named_axes(what, aval, pos_axis_resources, named_axis_resources) + for aval, s in zip(jaxpr.out_avals, params['out_shardings']): + _check_resources_against_named_axes(what, aval, s._parsed_pspec, named_axis_resources) pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit @@ -1148,73 +1232,88 @@ def get_unconstrained_dims(axis_resources: ParsedPartitionSpec): return {i for i, axes in enumerate(axis_resources) if axes is None} -def global_to_local(positional_semantics, mesh, avals, axes): +def global_to_local(positional_semantics, avals, shardings): + if config.jax_array: + return avals if isinstance(positional_semantics, maps._PositionalSemantics): - positional_semantics = [positional_semantics] * len(axes) + positional_semantics = [positional_semantics] * len(shardings) return [ - aval if ps == maps._PositionalSemantics.GLOBAL or aval_axes.partitions == () else mesh._global_to_local( - get_array_mapping(aval_axes), aval) - for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics) + aval if ps == maps._PositionalSemantics.GLOBAL or + s._parsed_pspec.partitions == () else s.mesh._global_to_local( + get_array_mapping(s._parsed_pspec), aval) + for aval, s, ps in safe_zip(avals, shardings, positional_semantics) ] -def local_to_global(positional_semantics, mesh, avals, axes): +def local_to_global(positional_semantics, avals, shardings): + if config.jax_array: + return avals return [ - aval if ps == maps._PositionalSemantics.GLOBAL or aval_axes.partitions == () else mesh._local_to_global( - get_array_mapping(aval_axes), aval) - for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics) + aval if ps == maps._PositionalSemantics.GLOBAL or + s._parsed_pspec.partitions == () else s.mesh._local_to_global( + get_array_mapping(s._parsed_pspec), aval) + for aval, s, ps in safe_zip(avals, shardings, positional_semantics) ] -def _calc_is_global_sequence(in_positional_semantics, in_axis_resources): +def _calc_is_global_sequence(in_positional_semantics, in_shardings): + if config.jax_array: + return (True,) * len(in_positional_semantics) return tuple( - ips == maps._PositionalSemantics.GLOBAL or p.partitions == () - for ips, p in safe_zip(in_positional_semantics, in_axis_resources)) + ips == maps._PositionalSemantics.GLOBAL or i._parsed_pspec.partitions == () + for ips, i in safe_zip(in_positional_semantics, in_shardings)) def _get_in_positional_semantics(arg) -> maps._PositionalSemantics: if isinstance(arg, GDA): return maps._PositionalSemantics.GLOBAL return maps._positional_semantics.val -def _create_cpspec(x): - return x if _is_unspecified_or_from_gda_or_auto(x) else CanonicalizedParsedPartitionSpec(x) def _maybe_replace_from_gda_with_pspec( - in_axis_resources_flat: Union[CanonicalizedParsedPartitionSpec, _AUTOAxisResource], - arg) -> Union[CanonicalizedParsedPartitionSpec, _AUTOAxisResource]: - if isinstance(arg, (GDA, Array)): - # TODO(yashkatariya): Use `TypeGuard` on `_is_auto` when it is supported. - # Don't use `_is_auto` here to satisfy pytype and mypy. - if isinstance(in_axis_resources_flat, _AUTOAxisResource): - return in_axis_resources_flat - - arr_flavor = 'GDA' if isinstance(arg, GDA) else 'Array' - # TODO(yashkatariya): Don't use `spec` from `MeshPspecSharding`. Write a - # sharding inference handler that will work with any sharding. - gda_or_array_cpspec = CanonicalizedParsedPartitionSpec( - ParsedPartitionSpec.from_user_input( - arg.mesh_axes if arr_flavor == 'GDA' else arg.sharding.spec, # type: ignore # union-attr - arg_name=f"{arr_flavor} spec")) - if (not _is_from_gda(in_axis_resources_flat) and - in_axis_resources_flat != gda_or_array_cpspec): + in_sharding_flat, arg) -> sharding.MeshPspecSharding: + if isinstance(arg, GDA): + gda_cpspec = CanonicalizedParsedPartitionSpec( + ParsedPartitionSpec.from_user_input(arg.mesh_axes, arg_name="GDA spec")) + if (not _is_from_gda(in_sharding_flat) and + in_sharding_flat._parsed_pspec != gda_cpspec): raise ValueError( - f"Got an input {arr_flavor} to pjit with different partitioning than specified in " + f"Got an input GDA to pjit with different partitioning than specified in " "the in_axis_resources argument to pjit. The partitioning must match, or " "use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. " - "Leave in_axis_resources empty for Array. " - f"Got {arr_flavor} spec: {gda_or_array_cpspec.user_spec} and " - f"pjit spec: {in_axis_resources_flat.user_spec} for {arr_flavor}: {arg}") - return gda_or_array_cpspec - return in_axis_resources_flat + f"Got GDA spec: {gda_cpspec.user_spec} and " + f"pjit spec: {in_sharding_flat.spec} for GDA: {arg}") + return sharding.MeshPspecSharding._from_parsed_pspec(arg.mesh, gda_cpspec) + return in_sharding_flat + + +@cache() +def _check_array_device_assignment(pjit_mesh, in_shardings): + if not in_shardings: + return + first_device_assignment = in_shardings[0]._device_assignment() + mesh_devices = list(pjit_mesh.devices.flat) + for i in in_shardings: + arr_device_assignment = i._device_assignment() + if pjit_mesh.empty: + # If mesh is empty, then check if all devices across shardings are + # equal + if first_device_assignment != arr_device_assignment: + raise ValueError("Devices of all `Array` inputs should be the same. " + f"Got array devices: {first_device_assignment},\n " + f"another array devices: {arr_device_assignment}") + else: + # If mesh is not empty, then check devices of all shardings against the + # mesh devices. + if mesh_devices != arr_device_assignment: + raise ValueError("Pjit's devices and Array's devices should be equal. " + f"Got Pjit devices: {list(pjit_mesh.devices.flat)},\n " + f"Array devices: {arr_device_assignment}") -def _maybe_check_pjit_gda_or_array_mesh(args, mesh): +def _maybe_check_pjit_gda_mesh(args, mesh): for x in args: if isinstance(x, GDA) and x.mesh != mesh: raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit " f"mesh: {mesh},\n GDA mesh: {x.mesh}") - if isinstance(x, Array) and x.sharding.mesh != mesh: - raise ValueError("Pjit's mesh and Array's mesh should be equal. Got Pjit " - f"mesh: {mesh},\n Array mesh: {x.sharding.mesh}") # -------------------- XLA OpSharding to PartitionSpec -------------------- # Note that OpSharding is more expressive than PartitionSpecs, so it's not @@ -1350,18 +1449,19 @@ def parse_flatten_op_sharding(op_sharding: xc.OpSharding, else: raise AssertionError("Unhandled OpSharding type. Please open a bug report!") +_get_single_pspec = lambda p: pxla.array_mapping_to_axis_resources( + cast(pxla.ArrayMapping, get_array_mapping(p))) def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]: - return [pxla.array_mapping_to_axis_resources(cast(pxla.ArrayMapping, get_array_mapping(p))) - for p in ppspec] + return [_get_single_pspec(p) for p in ppspec] def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]: input_op_shardings: Sequence[xc.OpSharding] = executable.hlo_modules()[0].spmd_parameters_shardings output_op_sharding: xc.OpSharding = executable.hlo_modules()[0].spmd_output_sharding in_ppspec: List[ParsedPartitionSpec] = [] - for sharding in input_op_shardings: - in_ppspec.extend(parse_flatten_op_sharding(sharding, mesh)) + for s in input_op_shardings: + in_ppspec.extend(parse_flatten_op_sharding(s, mesh)) out_ppspec = parse_flatten_op_sharding(output_op_sharding, mesh) return in_ppspec, out_ppspec diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py index 438fcf4237c9..459fc95b90c2 100644 --- a/jax/experimental/sharding.py +++ b/jax/experimental/sharding.py @@ -14,12 +14,12 @@ import abc from collections import Counter -from typing import Sequence, Tuple, Optional, Mapping, Dict, Set +from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union from jax._src.util import cache, safe_zip from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax.interpreters import pxla +from jax.interpreters import pxla, mlir import numpy as np @@ -62,6 +62,10 @@ class XLACompatibleSharding(Sharding): def _device_assignment(self) -> XLADeviceAssignment: raise NotImplementedError('Subclasses should implement this method.') + @abc.abstractmethod + def normalize(self): + raise NotImplementedError('Subclasses should implement this method.') + @abc.abstractmethod def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]: raise NotImplementedError('Subclasses should implement this method.') @@ -72,16 +76,57 @@ def _addressable_device_assignment(self) -> XLADeviceAssignment: return [d for d in self._device_assignment() if d.process_index == process_index] @abc.abstractmethod - def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding: + def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]: raise NotImplementedError('Subclasses should implement this method.') class MeshPspecSharding(XLACompatibleSharding): - def __init__(self, mesh: pxla.Mesh, spec: pxla.PartitionSpec): + def __init__( + self, mesh: pxla.Mesh, + spec: Union[pxla.PartitionSpec, pxla._AUTOAxisResource], + _parsed_pspec = None): + self.mesh = mesh self.spec = spec + # This split exists because you can pass `_parsed_pspec` that has been + # modified from the original. For example: Adding extra dimension to + # axis_resources for vmap handlers. In such cases you need to preserve the + # `sync` attribute of parsed pspecs. + # PartitionSpec is inferred from the parsed pspec in this case. + # TODO(yaskatariya): Remove this and replace this with a normalized + # representation of Parsed Pspec + if _parsed_pspec is None: + from jax.experimental import pjit + self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources( + self.spec, "MeshPspecSharding spec") + else: + self._parsed_pspec = _parsed_pspec + + def __repr__(self): + return f'MeshPspecSharding(\n mesh={self.mesh},\n partition_spec={self.spec})' + + def __hash__(self): + return hash((self.mesh, self.spec)) + + def __eq__(self, other): + return self.mesh == other.mesh and self.spec == other.spec + + def normalize(self): + from jax.experimental import pjit + cp = (self._parsed_pspec if pjit._is_auto(self._parsed_pspec) else + pjit.CanonicalizedParsedPartitionSpec(self._parsed_pspec)) + return MeshPspecSharding._from_parsed_pspec(self.mesh, cp) + + @classmethod + def _from_parsed_pspec(cls, mesh, parsed_pspec): + from jax.experimental import pjit + parsed_pspec, spec = ((parsed_pspec, parsed_pspec) + if pjit._is_auto(parsed_pspec) else + (parsed_pspec, pjit._get_single_pspec(parsed_pspec))) + return cls(mesh, spec, parsed_pspec) + @pxla.maybe_cached_property def device_set(self) -> Set[Device]: return set(self.mesh.devices.flat) @@ -94,6 +139,11 @@ def devices_indices_map( # TODO(yashkatariya): Remove this when utilities are moved to pxla.py. from jax.experimental import global_device_array + if pxla._is_auto(self.spec): + raise ValueError('Getting indices when the sharding is not known is not ' + 'possible. Please get the sharding from XLA and then ' + 'create the Array with that sharding.') + # `get_shard_indices` is cached. return global_device_array.get_shard_indices(global_shape, self.mesh, self.spec) @@ -114,16 +164,27 @@ def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]: def _device_assignment(self) -> XLADeviceAssignment: return list(self.mesh.devices.flat) - def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding: + @cache() + def _to_xla_op_sharding( + self, num_dimensions: int, + axis_ctx: Optional[mlir.SPMDAxisContext] = None) -> Optional[xc.OpSharding]: from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources + assert not pxla._is_auto(self.spec) + parsed_spec, _, _, _ = _prepare_axis_resources(self.spec, "spec") array_mapping = get_array_mapping(parsed_spec) # TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding # since we don't really need sharding spec. sharding_spec = pxla.new_mesh_sharding_specs( self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping) - return sharding_spec.sharding_proto() + # Used in `with_sharding_constraint`. + special_axes = {} + if axis_ctx is not None: + axis_names = self.mesh.axis_names + for manual_axis in axis_ctx.manual_axes: + special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL + return sharding_spec.sharding_proto(special_axes=special_axes) class SingleDeviceSharding(XLACompatibleSharding): @@ -131,6 +192,20 @@ class SingleDeviceSharding(XLACompatibleSharding): def __init__(self, device: Device): self._device = device + def __repr__(self): + return f"SingleDeviceSharding(device={self._device})" + + def __hash__(self): + return hash(self._device) + + def __eq__(self, other): + if not isinstance(other, SingleDeviceSharding): + return False + return self._device == other._device + + def normalize(self): + return SingleDeviceSharding(self._device) + @pxla.maybe_cached_property def device_set(self) -> Set[Device]: return {self._device} @@ -150,7 +225,8 @@ def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]: def _device_assignment(self) -> XLADeviceAssignment: return [self._device] - def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding: + @cache() + def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]: proto = xc.OpSharding() proto.type = xc.OpSharding.Type.REPLICATED return proto @@ -163,6 +239,9 @@ def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec): # The sharding spec should be pmap's sharding spec. self.sharding_spec = sharding_spec + def normalize(self): + return PmapSharding(self.devices, self.sharding_spec) + @pxla.maybe_cached_property def device_set(self) -> Set[Device]: return set(self.devices.flat) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b0129ed4a8f0..1a95a53a72ff 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -505,10 +505,10 @@ def testVmapModifiesAxisResources(self): jaxpr = jax.make_jaxpr(jax.vmap(h, in_axes=(None, 0)))(x, y).jaxpr eqn = jaxpr.eqns[0] self.assertIs(eqn.primitive, pjit_p) - x_sync, y_sync = (spec.sync for spec in eqn.params['in_axis_resources']) + x_sync, y_sync = (s._parsed_pspec.sync for s in eqn.params['in_shardings']) self.assertEqual(x_sync, SpecSync.IN_SYNC) self.assertEqual(y_sync, SpecSync.DIM_PERMUTE) - x_sync, y_sync, z_sync = (spec.sync for spec in eqn.params['out_axis_resources']) + x_sync, y_sync, z_sync = (s._parsed_pspec.sync for s in eqn.params['out_shardings']) self.assertEqual(x_sync, SpecSync.DIM_PERMUTE) self.assertEqual(y_sync, SpecSync.IN_SYNC) self.assertEqual(z_sync, SpecSync.DIM_PERMUTE) @@ -546,9 +546,9 @@ def testShardingInXMap(self): def _test_rule(*args, **kwargs): nonlocal test_rule_called test_rule_called = True - in_axis_resources = kwargs['in_axis_resources'] - self.assertEqual(len(in_axis_resources), 1) - self.assertIn(('y',), in_axis_resources[0].partitions) + in_shardings = kwargs['in_shardings'] + self.assertEqual(len(in_shardings), 1) + self.assertIn(('y',), in_shardings[0]._parsed_pspec.partitions) return rule(*args, **kwargs) try: mlir._lowerings[pjit_p] = _test_rule @@ -1116,9 +1116,8 @@ def cb(index): "Got an input GDA to pjit with different partitioning than specified " 'in the in_axis_resources argument to pjit. The partitioning must match, or ' 'use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. ' - 'Leave in_axis_resources empty for Array. ' "Got GDA spec: PartitionSpec('x',) and " - "pjit spec: PartitionSpec('x', 'y') " + "pjit spec: PartitionSpec(('x',), ('y',)) " 'for GDA: GlobalDeviceArray(shape=(8, 2), dtype=float32)'): @partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y')) def f(x): @@ -1378,7 +1377,8 @@ def test_pjit_array_single_output(self, out_axis_resources, shard_shape): with jax._src.config.jax_array(True): with global_mesh: - f = pjit(lambda x: x @ x.T, out_axis_resources=out_axis_resources) + f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding( + global_mesh, out_axis_resources)) expected_matrix_mul = input_data @ input_data.T out = f(input_array) @@ -1398,7 +1398,8 @@ def test_non_array_input_error(self): with jax._src.config.jax_array(True): with global_mesh: f = pjit(lambda x: x, - out_axis_resources=P('x', 'y')) + out_axis_resources=MeshPspecSharding( + global_mesh, P('x', 'y'))) with self.assertRaisesRegex( ValueError, ('All arguments to pjit when `config.jax_array` is ' 'enabled should be `Array`s.')): @@ -1480,34 +1481,28 @@ def f(tree): for s in out4.addressable_shards: self.assertArraysEqual(s.data._arrays[0], input_data) - def test_in_axis_resources_mismatch_error(self): - global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - mesh_axes = P('x', 'y') - - input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes) - + def test_in_axis_resources_error(self): with jax._src.config.jax_array(True): - with global_mesh: - f = pjit(lambda x: x, in_axis_resources=P('x')) - with self.assertRaisesRegex( + with self.assertRaisesRegex( ValueError, - ('Got an input Array to pjit with different partitioning ' - 'than specified in the in_axis_resources argument to pjit')): - f(input_array) - - def test_in_axis_resources_same_as_array_sharding(self): - global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - mesh_axes = P('x', 'y') - - input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes) + ('in_axis_resources should be empty for Array. The sharding ' + 'should be specified on the arguments as pjit follows ' + 'computation follows data semantics.')): + pjit(lambda x: x, in_axis_resources=P('x')) + def test_out_axis_resources_error(self): with jax._src.config.jax_array(True): - with global_mesh: - out = pjit(lambda x: x, in_axis_resources=P('x' ,'y'))(input_array) - self.assertIsInstance(out, array.Array) + with self.assertRaisesRegex( + ValueError, + ('When `config.jax_array` flag is enabled, ' + 'out_axis_resources should contain instances of `Sharding`.')): + pjit(lambda x: x, out_axis_resources=P('x')) + def test_no_input_output(self): + with jax._src.config.jax_array(True): + def f(): + pass + pjit(f) def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)")