Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).

`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
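
A minimal usage sketch of the change described above (not taken from this commit). It assumes AUTO is exported from jax.experimental.pjit, that jax.jit accepts in_shardings/out_shardings, and that Compiled.input_shardings yields an (args, kwargs) pair of the compiler-chosen shardings; shapes and the axis name are illustrative.

import numpy as np
import jax
from jax.sharding import Mesh
from jax.experimental.pjit import AUTO  # assumed export location

# Build a 1-D mesh over all available devices; the axis name is illustrative.
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

# The mesh now goes into AUTO(mesh) directly; no `with mesh:` block is needed.
f = jax.jit(lambda a: a @ a.T,
            in_shardings=AUTO(mesh), out_shardings=AUTO(mesh))

x = np.ones((8, 8), dtype=np.float32)
compiled = f.lower(x).compile()  # the auto spmd pass of GSPMD picks the shardings
arg_shardings, _ = compiled.input_shardings  # assumed (args, kwargs) structure
out = compiled(jax.device_put(x, arg_shardings[0]))

Previously AUTO was a mesh-free singleton, so the mesh had to come from an enclosing `with mesh:` context; attaching the mesh to the AUTO object itself is what lets plain jax.jit (and pjit without a mesh context) drive GSPMD's auto spmd partitioner.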
yashk2810 authored and jax authors committed May 21, 2023
1 parent e0b5003 commit b71829f
Showing 4 changed files with 126 additions and 129 deletions.
100 changes: 47 additions & 53 deletions jax/_src/interpreters/pxla.py
@@ -62,7 +62,7 @@
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTOAxisResource, UnspecifiedValue, UNSPECIFIED,
AUTO, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
)
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
@@ -1693,7 +1693,7 @@ class TileManual:

def check_if_any_auto(
shardings: Iterable[Union[sharding_impls.XLACompatibleSharding,
AUTOAxisResource, UnspecifiedValue]]) -> bool:
AUTO, UnspecifiedValue]]) -> bool:
for s in shardings:
if is_auto(s):
return True
@@ -1755,8 +1755,7 @@ class DeviceAssignmentMismatchError(Exception):


ShardingInfo = Tuple[
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue,
AUTOAxisResource],
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports


@@ -1775,13 +1774,14 @@ def _get_and_check_device_assignment(
devices = tuple(devices)

for i, s_type, source_info in shardings:
if is_auto(i) or is_unspecified(i):
if is_unspecified(i):
continue
# Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
# skipped.

if first_sharding_info is None:
first_sharding_info = (i._device_assignment, s_type, source_info) # type: ignore
arr_device_assignment = i._device_assignment # type: ignore
first_sharding_info = (
(i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore
else (i._device_assignment, s_type, source_info)) # type: ignore
arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
@@ -1815,7 +1815,7 @@ def wrapped(f, *args, **kwargs):

@cache_wrap
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
keep_unused, donated_invars):
keep_unused, donated_invars, auto_spmd_lowering):
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

if isinstance(fun_or_jaxpr, lu.WrappedFun):
@@ -1830,7 +1830,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
global_out_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts

if (keep_unused or
if (keep_unused or auto_spmd_lowering or
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals)):
kept_var_idx = set(range(len(global_in_avals)))
@@ -2006,10 +2006,14 @@ def lower_sharding_computation(
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
auto_spmd_lowering = (
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore

(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
donated_invars)
donated_invars, auto_spmd_lowering)
jaxpr = closed_jaxpr.jaxpr
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)

@@ -2091,14 +2095,13 @@ def lower_sharding_computation(
module,
False,
donated_invars,
mesh=None,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
out_shardings=out_shardings,
spmd_lowering=True,
tuple_args=tuple_args,
auto_spmd_lowering=False,
auto_spmd_lowering=auto_spmd_lowering,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
host_callbacks=host_callbacks,
@@ -2112,7 +2115,7 @@


def _to_logical_sharding(
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTO]
) -> Optional[sharding_impls.XLACompatibleSharding]:
if is_unspecified(sharding) or is_auto(sharding):
return None
@@ -2131,9 +2134,9 @@ def lower_mesh_computation(
api_name: str,
fun_name: str,
mesh: Mesh,
in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource]],
out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource,
UnspecifiedValue]],
in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO]],
out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO,
UnspecifiedValue]],
donated_invars: Sequence[bool],
spmd_lowering: bool,
global_in_avals: Sequence[core.ShapedArray],
@@ -2143,11 +2146,6 @@
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

auto_spmd_lowering = check_if_any_auto((*in_shardings, *out_shardings))

if auto_spmd_lowering and not spmd_lowering:
raise ValueError('Enable spmd_lowering to use auto spmd lowering.')

global_axis_sizes = mesh.shape

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -2171,7 +2169,6 @@
else:
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_shardings)
assert not auto_spmd_lowering
assert isinstance(fun_or_jaxpr, lu.WrappedFun)
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
# is why `.spec` can be accessed.
@@ -2181,7 +2178,6 @@
in_jaxpr_avals = global_in_avals
else:
assert isinstance(tiling_method, TileVectorize)
assert not auto_spmd_lowering
# In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
# why `.spec` can be accessed.
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore
@@ -2274,14 +2270,13 @@ def lower_mesh_computation(
lowering_result.module,
False,
donated_invars,
mesh=mesh,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
out_shardings=out_shardings,
spmd_lowering=spmd_lowering,
tuple_args=tuple_args,
auto_spmd_lowering=auto_spmd_lowering,
auto_spmd_lowering=False,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
host_callbacks=lowering_result.host_callbacks,
@@ -2501,26 +2496,20 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
device_assignment = da.device_assignment if isinstance(
da, _DeviceAssignment) else da

dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
# If we were to optimize __getattr__ on xc.Device we might not need this
# workaround.
dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
np.arange(len(device_assignment))
)
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
else:
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
# If we were to optimize __getattr__ on xc.Device we might not need this
# workaround.
dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
np.arange(len(device_assignment))
)
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
else:
num_replicas, num_partitions = dev.size, 1
num_replicas, num_partitions = dev.size, 1

if pmap_nreps > 1:
# In `jit` device_assignment is set to None when num_replicas > 1. Do
@@ -2610,14 +2599,11 @@ def load(self) -> MeshExecutable:
@staticmethod
def from_hlo(name: str,
hlo: ir.Module,
# TODO(yashkatariya): Remove `mesh` from here once AUTO can work
# without mesh.
mesh: Optional[Mesh],
global_in_avals: Sequence[ShapedArray],
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource]],
out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource,
UnspecifiedValue]],
in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO]],
out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO,
UnspecifiedValue]],
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
@@ -2641,6 +2627,14 @@ def from_hlo(name: str,
device_assignment, _DeviceAssignment) else tuple(device_assignment)
del device_assignment
allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)

mesh = None
if auto_spmd_lowering:
for i in it.chain.from_iterable([in_shardings, out_shardings]):
if is_auto(i):
mesh = i.mesh # type: ignore
break

xla_executable, compile_options = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
@@ -2661,7 +2655,7 @@ def from_hlo(name: str,
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else i
in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if is_auto(o) else (o, False)
62 changes: 26 additions & 36 deletions jax/_src/pjit.py
@@ -54,7 +54,7 @@
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
AUTOAxisResource, UNSPECIFIED, UnspecifiedValue,
AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.traceback_util import api_boundary
Expand All @@ -72,10 +72,10 @@

traceback_util.register_exclusion(__file__)

PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTO]

logger = logging.getLogger(__name__)

@@ -342,13 +342,22 @@ def lower(*args, **kwargs):
donate_argnums) = infer_params_fn(*args, **kwargs)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'], always_lower=True,
lowering_platform=_experimental_lowering_platform)
try:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'], always_lower=True,
lowering_platform=_experimental_lowering_platform)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
arg_names = _get_arg_names(fun, in_tree, args_flat)
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
raise ValueError(msg) from None

if kwargs:
args_kwargs_in_tree = in_tree
@@ -1210,29 +1219,9 @@ def _pjit_lower_cached(
mesh = None
api_name = 'jit'

# Convert to `NamedSharding` when `jax_array` is not enabled. This is
# because GDA/SDA/DA are dependent on mesh for generating outputs.
# NamedSharding is required for host-local inputs too.
any_auto = pxla.check_if_any_auto(it.chain(in_shardings, out_shardings))
if any_auto:
in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast( # type:ignore[no-redef]
Tuple[MeshShardingMinusUnspecified, ...], tuple(
NamedSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0]) # type: ignore
if isinstance(i, GSPMDSharding) else i
for i in in_shardings
))
out_shardings: Tuple[MeshSharding, ...] = cast( # type: ignore[no-redef]
Tuple[MeshSharding, ...], tuple(
NamedSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0]) # type: ignore
if isinstance(o, GSPMDSharding) else o
for o in out_shardings
))

# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
# because `xmap` only supports SPMDAxisContext right now.
if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
if dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
return pxla.lower_mesh_computation(
jaxpr, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
@@ -1929,10 +1918,11 @@ def _fast_path_get_device_assignment(
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
da = None
for i in shardings:
if is_auto(i) or is_unspecified(i):
if is_unspecified(i):
continue
da = i._device_assignment # type: ignore
break
if is_auto(i):
return i.mesh._flat_devices_tuple # type: ignore
return i._device_assignment # type: ignore
return da


Expand Down
17 changes: 9 additions & 8 deletions jax/_src/sharding_impls.py
@@ -719,12 +719,14 @@ def get_replicated(cls, device_assignment):
return cls(tuple(device_assignment), proto)


class AUTOAxisResource:
pass
AUTO = AUTOAxisResource()
class AUTO:

def __init__(self, mesh: mesh_lib.Mesh):
self.mesh = mesh


def is_auto(x):
return isinstance(x, AUTOAxisResource)
return isinstance(x, AUTO)


class UnspecifiedValue:
@@ -757,8 +759,7 @@ def is_unspecified_or_auto(x):
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
UnspecifiedValue]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]

def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
if not array_mapping:
@@ -779,11 +780,11 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
return PartitionSpec(*partitions)

def get_array_mapping(
axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue]
axis_resources: Union[ParsedPartitionSpec, AUTO, UnspecifiedValue]
) -> ArrayMappingOrAutoOrUnspecified:
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
# Don't use `is_auto` here to satisfy pytype and mypy.
if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)):
if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)
