First attempt to enable auto-sharding. This CL adds support for GDA (no SDA support yet).

An example of using auto sharding with GDA:

```
f = pjit(lambda x: x, in_axis_resources=pjit.AUTO, out_axis_resources=pjit.AUTO)

sharding_info = pjit.get_sharding_from_xla(f, mesh, [(8, 2)], [np.int32])

inputs = [GlobalDeviceArray.from_callback(shape, mesh, ip, cb) for ip in sharding_info.in_pspec]

# Use the compiled function (which was compiled in get_sharding_from_xla)
out = sharding_info.compiled(*inputs) # Recommended way!
# OR
out = f(*inputs)
```
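
For reference, here is a minimal sketch of the same workflow written against the compile_and_get_sharding helper that this change adds to jax/_src/test_util.py. The 4x2 mesh, axis names, dtype, concrete data, and import paths below are illustrative assumptions rather than part of the change, and the sketch presumes 8 available devices.

```
# Sketch only: mesh shape, axis names, dtype, and data are assumptions.
import numpy as np

import jax
from jax import core
from jax._src import test_util as jtu
from jax.experimental import pjit
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh

mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
f = pjit.pjit(lambda x: x, in_axis_resources=pjit.AUTO,
              out_axis_resources=pjit.AUTO)

# Compile against abstract global shapes and read back the shardings XLA chose.
global_aval = core.ShapedArray((8, 2), np.float32)
sharding_info = jtu.compile_and_get_sharding(f, mesh, [global_aval])

# Build GDAs laid out exactly as the auto spmd partitioner expects.
global_data = np.arange(16, dtype=np.float32).reshape(8, 2)
inputs = [GlobalDeviceArray.from_callback(global_data.shape, mesh, spec,
                                          lambda index: global_data[index])
          for spec in sharding_info.in_pspec]

out = sharding_info.compiled(*inputs)  # Reuses the already-compiled function.
```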
PiperOrigin-RevId: 438708483
yashk2810 authored and jax authors committed Apr 1, 2022
1 parent 3184dd6 commit 8ca8f74
Showing 4 changed files with 284 additions and 51 deletions.
21 changes: 20 additions & 1 deletion jax/_src/test_util.py
@@ -19,7 +19,7 @@
import re
import os
import textwrap
from typing import Dict, List, Generator, Sequence, Tuple, Union
from typing import Dict, List, Generator, Sequence, Tuple, Union, NamedTuple
import unittest
import warnings
import zlib
@@ -30,6 +30,7 @@
import numpy as np
import numpy.random as npr

from jax import stages
from jax._src import api
from jax import core
from jax._src import dtypes as _dtypes
@@ -42,6 +43,8 @@
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import Mesh
from jax.interpreters.sharded_jit import PartitionSpec
from jax.experimental import pjit


FLAGS = flags.FLAGS
@@ -1165,6 +1168,22 @@ def __get__(self, obj, cls):
return self._value


class _XLAShardingInfo(NamedTuple):
  in_pspec: Tuple[PartitionSpec]
  out_pspec: Tuple[PartitionSpec]
  compiled: stages.Compiled


def compile_and_get_sharding(pjitted_fn, mesh, global_inputs):
  # TODO(yashkatariya): Check if the pjitted_fn comes from pjit.
  inputs = [core.ShapedArray(i.shape, i.dtype) for i in global_inputs]
  compiled = pjitted_fn.lower(*inputs, _global_avals=True).compile()
  in_sharding, out_sharding = pjit._get_sharding_from_executable(
      compiled.runtime_executable(), mesh)
  return _XLAShardingInfo(in_pspec=in_sharding, out_pspec=out_sharding,
                          compiled=compiled)


class _LazyDtypes:
"""A class that unifies lists of supported dtypes.
107 changes: 90 additions & 17 deletions jax/experimental/pjit.py
@@ -15,7 +15,7 @@
from enum import IntEnum
import numpy as np
from collections import OrderedDict, Counter
from typing import Callable, Sequence, Tuple, Union, Optional
from typing import Callable, Sequence, Tuple, Union, Optional, cast, List
import itertools as it
from functools import partial

@@ -60,6 +60,10 @@ def _is_from_gda(x):
# doesn't cause an error to avoid user confusion.
return isinstance(x, type(FROM_GDA))

_AUTOAxisResource = pxla._AUTOAxisResource
AUTO = pxla.AUTO
_is_auto = pxla._is_auto


# TODO(yashkatariya): Add pjit microbenchmarks.
def pjit(fun: Callable,
@@ -245,9 +249,11 @@ def infer_params(*args, _global_avals=False, **kwargs):
global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources(
mesh, local_in_avals, hashable_pytree(in_axis_resources), in_tree,
in_positional_semantics, tuple(isinstance(a, GDA) for a in args_flat))

jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, global_in_avals, HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources))

canonicalized_in_axis_resources_flat = tree_map(
_maybe_replace_from_gda_with_pspec,
canonicalized_in_axis_resources_flat, tuple(args_flat))
@@ -379,7 +385,7 @@ def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk,
# Use canonicalized in_axis_resources here because we want to treat P(None)
# and None (for example) as equivalent.
if all(
(not _is_from_gda(p) and p.partitions == ()) or ips == maps._PositionalSemantics.GLOBAL
(not _is_from_gda(p) and not _is_auto(p) and p.partitions == ()) or ips == maps._PositionalSemantics.GLOBAL
for p, ips in safe_zip(canonicalized_in_axis_resources_flat, in_positional_semantics)):
# Shapes should be checked against non canonicalized in_axis_resources.
# For example, partitions of () and ((),) are not equivalent, since the
@@ -539,14 +545,19 @@ def _prepare_axis_resources(axis_resources,
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
entries = [
# TODO(yashkatariya): Allow AUTO and other specification together after
# auto sharding pass supports user specified annotations.
all_auto = pxla._check_if_all_or_none_auto(entries, arg_name)
if not all_auto:
entries = [
entry if _is_from_gda(entry) else ParsedPartitionSpec.from_user_input(
entry, what, allow_unconstrained_dims=allow_unconstrained_dims)
for entry in entries
]
_check_unique_resources(entries, arg_name)
]
_check_unique_resources(entries, arg_name)
return tree_unflatten(treedef, entries), entries, treedef


def _check_resources_mismatch(in_axis_resources_flat, is_gda):
if not is_gda and _is_from_gda(in_axis_resources_flat):
raise ValueError('For a non-GDA input, the corresponding resource in '
@@ -573,6 +584,8 @@ def _check_shapes_against_resources(what: str, is_global_shape: bool,
for aval, aval_axis_resources in zip(flat_avals, flat_axis_resources):
if _is_from_gda(aval_axis_resources):
continue
if _is_auto(aval_axis_resources):
continue
shape = aval.shape
if len(shape) < len(aval_axis_resources):
raise ValueError(f"One of {what} was given the resource assignment "
@@ -608,6 +621,12 @@ def _pjit_call_impl(*args, jaxpr,
compiled = _pjit_lower(
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, in_is_global).compile()
# Check the GDA sharding and the sharding returned by the auto spmd partitioner
# only if auto_spmd_lowering is enabled.
# TODO(yashkatariya): Move this check to `def call()` method of MeshExecutable.
if compiled._auto_spmd_lowering:
in_pspec, _ = _get_sharding_from_executable(compiled.xla_executable, resource_env.physical_mesh)
_check_gda_xla_sharding_match(args, in_pspec)
distributed_debug_log(("Running pjit'd function", name),
("mesh", resource_env.physical_mesh))
return compiled.unsafe_call(*args)
@@ -791,16 +810,14 @@ def keep_where(l, should_keep):
if num_residuals:
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_axis_resources'])
executable = _pjit_lower(
compiled = _pjit_lower(
known_params["jaxpr"], known_params["in_axis_resources"],
known_params["out_axis_resources"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global).compile(_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
output_op_sharding = \
executable.xla_executable.hlo_modules()[0].spmd_output_sharding
output_sharding_specs = parse_op_sharding(output_op_sharding, mesh)
residual_specs = tuple(output_sharding_specs[-num_residuals:])
_, output_ppspec = _get_ppspec_from_executable(compiled.xla_executable, mesh)
residual_specs = tuple(output_ppspec[-num_residuals:])
else:
residual_specs = ()
known_params['out_axis_resources'] = (
@@ -1024,7 +1041,11 @@ def _resource_typing_sharding_constraint(avals, params, source_info, resource_en

# -------------------- helpers --------------------

def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
def get_array_mapping(axis_resources: Union[ParsedPartitionSpec, _AUTOAxisResource]) -> pxla.ArrayMappingOrAuto:
# TODO(yashkatariya): Use `TypeGuard` on `_is_auto` when it is supported.
# Don't use `_is_auto` here to satisfy pytype and mypy.
if isinstance(axis_resources, _AUTOAxisResource):
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)
if axes is not None for axis in axes)
@@ -1072,6 +1093,25 @@ def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
for ips, p in safe_zip(in_positional_semantics, in_axis_resources))

def _check_gda_xla_sharding_match(args, in_pspec):
  for arg, ip in safe_zip(args, in_pspec):
    if not isinstance(arg, GDA):
      continue

    gda_cpspec = CanonicalizedParsedPartitionSpec(
        ParsedPartitionSpec.from_user_input(
            arg.mesh_axes, arg_name="GDA mesh_axes"))
    in_cpspec = CanonicalizedParsedPartitionSpec(
        ParsedPartitionSpec.from_user_input(ip, arg_name="auto sharding pspec"))
    if in_cpspec != gda_cpspec:
      raise ValueError(
          "GDA sharding does not match the sharding returned by auto spmd "
          "partitioner. Did you create the GDA with the input sharding "
          "returned by XLA? If yes, please file a bug. "
          f"Got GDA spec: {gda_cpspec.user_spec} and "
          f"auto sharding spec: {in_cpspec.user_spec} for GDA: {arg}")


def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSemantics:
if isinstance(arg, GDA):
return maps._PositionalSemantics.GLOBAL
@@ -1080,11 +1120,16 @@ def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSem
return maps._positional_semantics.val

def _create_cpspec(x):
return x if _is_from_gda(x) else CanonicalizedParsedPartitionSpec(x)
return x if _is_from_gda(x) or _is_auto(x) else CanonicalizedParsedPartitionSpec(x)

def _maybe_replace_from_gda_with_pspec(
in_axis_resources_flat: CanonicalizedParsedPartitionSpec, arg) -> CanonicalizedParsedPartitionSpec:
in_axis_resources_flat: Union[CanonicalizedParsedPartitionSpec, _AUTOAxisResource],
arg) -> Union[CanonicalizedParsedPartitionSpec, _AUTOAxisResource]:
if isinstance(arg, GDA):
# TODO(yashkatariya): Use `TypeGuard` on `_is_auto` when it is supported.
# Don't use `_is_auto` here to satisfy pytype and mypy.
if isinstance(in_axis_resources_flat, _AUTOAxisResource):
return in_axis_resources_flat
gda_cpspec = CanonicalizedParsedPartitionSpec(
ParsedPartitionSpec.from_user_input(
arg.mesh_axes, arg_name="GDA mesh_axes"))
@@ -1205,11 +1250,15 @@ def explode_superdims(sizes, dims):
final_dims += reversed(new_dims)
return final_dims

def parse_op_sharding(op_sharding, mesh):
def parse_flatten_op_sharding(op_sharding: xc.OpSharding,
mesh: pxla.Mesh) -> Sequence[ParsedPartitionSpec]:
if op_sharding.type == xc.OpSharding.Type.TUPLE:
return [parse_op_sharding(s, mesh) for s in op_sharding.tuple_shardings]
out: List[ParsedPartitionSpec] = []
for s in op_sharding.tuple_shardings:
out.extend(parse_flatten_op_sharding(s, mesh))
return out
elif op_sharding.type == xc.OpSharding.Type.REPLICATED:
return REPLICATED
return [REPLICATED]
elif op_sharding.type == xc.OpSharding.Type.OTHER:
mesh_shape = mesh.shape
mesh_axis_order = unflatten_array(mesh.shape, op_sharding.tile_assignment_devices)
Expand All @@ -1233,6 +1282,30 @@ def parse_op_sharding(op_sharding, mesh):
raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!")
if replicate_on_last_tile_dim:
partitions = partitions[:-1]
return ParsedPartitionSpec('<internally generated spec>', partitions)
return [ParsedPartitionSpec('<internally generated spec>', partitions)]
else:
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")


def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
  return [pxla.array_mapping_to_axis_resources(cast(pxla.ArrayMapping, get_array_mapping(p)))
          for p in ppspec]


def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
  input_op_shardings: Sequence[xc.OpSharding] = executable.hlo_modules()[0].spmd_parameters_shardings
  output_op_sharding: xc.OpSharding = executable.hlo_modules()[0].spmd_output_sharding
  in_ppspec: List[ParsedPartitionSpec] = []
  for sharding in input_op_shardings:
    in_ppspec.extend(parse_flatten_op_sharding(sharding, mesh))
  out_ppspec = parse_flatten_op_sharding(output_op_sharding, mesh)
  return in_ppspec, out_ppspec


def _get_sharding_from_executable(
    executable, mesh: pxla.Mesh
) -> Tuple[Tuple[PartitionSpec, ...], Tuple[PartitionSpec, ...]]:
  in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)
  out_partition_spec = _get_partition_spec(out_ppspec)
  in_partition_spec = _get_partition_spec(in_ppspec)
  return tuple(in_partition_spec), tuple(out_partition_spec)
