[export] Simplify the handling of shardings in Exported.
Previously, Exported contained tuples of `XLACompatibleSharding`
for the input and output shardings. These shardings contain references
to JAX devices, which is more than exporting needs and in fact
gets in the way when we want to serialize the Exported.

We change Exported to carry `xla_client.HloSharding` instead, which
conveniently can be serialized to proto. We use the value `None` to
denote an unspecified sharding. We also add `nr_devices`, so that
for exporting purposes we can construct an actual
`XLACompatibleSharding` when we need to.
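
As an illustration of the new representation, here is a minimal sketch of the round-trip this enables, assuming a single-device NamedSharding (`_to_xla_hlo_sharding` is the private helper the exporting code itself uses):

import jax
from jax.sharding import GSPMDSharding, Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices()[:1], ("x",))        # one device, for illustration
s = NamedSharding(mesh, PartitionSpec("x"))   # an XLACompatibleSharding

# Exported now stores a device-free xla_client.HloSharding (or None).
hlo_s = s._to_xla_hlo_sharding(1)   # 1 = ndim of the value being sharded
proto = hlo_s.to_proto()            # conveniently serializable

# With nr_devices recorded, a concrete sharding can be rebuilt later on
# any device assignment of the right size.
rebuilt = GSPMDSharding(jax.devices()[:1], hlo_s)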
gnecula committed Dec 2, 2023
1 parent b822801 commit 3eb3e2d
Showing 6 changed files with 126 additions and 98 deletions.
19 changes: 12 additions & 7 deletions jax/_src/internal_test_util/export_back_compat_test_util.py
@@ -87,7 +87,6 @@ def func(...): ...

from jax._src import core
from jax._src import test_util as jtu
from jax._src.sharding_impls import UNSPECIFIED
from jax._src import xla_bridge as xb


@@ -104,6 +103,7 @@ class CompatTestData:
mlir_module_text: str
mlir_module_serialized: bytes
xla_call_module_version: int # The version of XlaCallModule to use for testing
nr_devices: int = 1


# The dummy_data is used for getting started for adding a new test and for
@@ -187,6 +187,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
will fail when the serialization changes. Otherwise, when checking old
serializations you can specify what custom calls are expected in the
current serialization.
nr_devices: the number of devices for which the data was serialized.
"""
if not isinstance(data, CompatTestData):
raise ValueError(f"Expecting data: CompatTestData but got {data}. "
@@ -202,7 +203,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
res_run_current = tuple(np.array(a) for a in res_run_current)
logging.info("Result of current version run is %s", res_run_current)

serialized, module_str, module_version = self.serialize(
serialized, module_str, module_version, nr_devices = self.serialize(
func, data,
polymorphic_shapes=polymorphic_shapes,
allow_unstable_custom_call_targets=allow_unstable_custom_call_targets)
@@ -225,6 +226,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
mlir_module_text=r\"\"\"\n{module_str}\"\"\",
mlir_module_serialized={serialized!r},
xla_call_module_version={module_version},
nr_devices={nr_devices},
) # End paste
"""
@@ -271,7 +273,7 @@ def serialize(self,
func: Callable, data: CompatTestData, *,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_unstable_custom_call_targets: Sequence[str] = ()
) -> tuple[bytes, str, int]:
) -> tuple[bytes, str, int, int]:
"""Serializes the test function.
Args:
@@ -281,7 +283,8 @@ def serialize(self,
custom call targets besides those known as stable.
Returns: a tuple with the (a) serialization, (b) the module contents as
a string (for debugging), and (c) the module serialization version.
a string (for debugging), (c) the module serialization version,
(d) the number of devices for which the module was serialized.
"""
# Use the native exporter, to make sure we get the proper serialization.
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
@@ -296,7 +299,8 @@ def serialize(self,
module_str = str(exported.mlir_module())
serialized = exported.mlir_module_serialized
module_version = exported.serialization_version
return serialized, module_str, module_version
nr_devices = exported.nr_devices
return serialized, module_str, module_version, nr_devices

def run_serialized(self, data: CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None):
@@ -321,14 +325,15 @@ def _get_vjp(_):
in_avals=tuple(in_avals),
out_tree=out_tree,
out_avals=tuple(out_avals),
in_shardings=(UNSPECIFIED,) * len(in_avals),
out_shardings=(UNSPECIFIED,) * len(out_avals),
in_shardings=(None,) * len(in_avals),
out_shardings=(None,) * len(out_avals),
lowering_platforms=(data.platform,),
ordered_effects=(),
unordered_effects=(),
disabled_checks=(),
mlir_module_serialized=data.mlir_module_serialized,
serialization_version=data.xla_call_module_version,
nr_devices=data.nr_devices,
module_kept_var_idx=tuple(range(len(in_avals))),
uses_shape_polymorphism=any(not core.is_constant_shape(a.shape)
for a in in_avals),
123 changes: 66 additions & 57 deletions jax/experimental/export/export.py
@@ -122,7 +122,11 @@ def __hash__(self) -> int:
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9

Sharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]

# None means unspecified sharding
Sharding = Optional[xla_client.HloSharding]
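
A value of this alias is thus either `None` or an `xla_client.HloSharding`; as a minimal, hypothetical sketch, one can be built from a hand-written replicated proto:

from jax._src.lib import xla_client

op = xla_client.OpSharding()                     # an empty sharding proto
op.type = xla_client.OpSharding.Type.REPLICATED  # mark it replicated
hlo_s = xla_client.HloSharding.from_proto(op)    # the value Exported carries
assert hlo_s.is_replicated()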

@dataclasses.dataclass(frozen=True)
class Exported:
@@ -140,9 +144,9 @@ class Exported:
out_avals: the flat tuple of output abstract values. May contain dimension
expressions in the shapes, with dimension variables among those in
`in_avals`.
in_shardings: the flattened input shardings. Only for the inputs that are
specified in `module_kept_var_idx`.
in_shardings: the flattened input shardings, as long as `in_avals`.
out_shardings: the flattened output shardings, as long as `out_avals`.
nr_devices: the number of devices that the module has been lowered for.
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
'cuda', 'rocm'. See below for the calling convention for when
there are multiple lowering platforms.
@@ -274,6 +278,7 @@ class Exported:

in_shardings: tuple[Sharding, ...]
out_shardings: tuple[Sharding, ...]
nr_devices: int
lowering_platforms: tuple[str, ...]
ordered_effects: tuple[effects.Effect, ...]
unordered_effects: tuple[effects.Effect, ...]
@@ -521,14 +526,31 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
ordered_effects = unordered_effects = ()

nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> Sharding:
if sharding_impls.is_unspecified(s):
return None
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]

all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
module_kept_var_idx,
len(args_avals_flat))
in_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(all_in_shardings, args_avals_flat))
out_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))
return Exported(
fun_name=fun_name,
in_tree=lowered.in_tree,
out_tree=lowered.out_tree,
in_avals=tuple(args_avals_flat),
out_avals=tuple(out_avals_flat),
in_shardings=tuple(lowering.compile_args["in_shardings"]),
out_shardings=tuple(lowering.compile_args["out_shardings"]),
in_shardings=in_shardings,
out_shardings=out_shardings,
nr_devices=nr_devices,
lowering_platforms=actual_lowering_platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
@@ -921,64 +943,57 @@ def walk_operations(op):
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
raise ValueError(msg)

def expand_in_shardings(in_shardings: tuple[Sharding, ...],
def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
module_kept_var_idx: Sequence[int],
nr_inputs: int) -> tuple[Sharding, ...]:
nr_inputs: int) -> Sequence[LoweringSharding]:
"""Expands in_shardings with unspecified shardings for inputs not kept.
Assumes in_shardings corresponds to module_kept_var_idx.
"""
assert len(in_shardings) == len(module_kept_var_idx)
assert nr_inputs >= len(module_kept_var_idx)
all_in_shardings: list[Sharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
all_in_shardings[idx] = in_s
return tuple(all_in_shardings)
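
For illustration, a hypothetical call for a module that kept only inputs 0 and 2 out of three (s0 and s2 stand for arbitrary lowering shardings):

# expand_in_shardings((s0, s2), module_kept_var_idx=(0, 2), nr_inputs=3)
#   -> (s0, sharding_impls.UNSPECIFIED, s2)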

# TODO(yashkatariya, necula): remove this function once we relax the checks
# in the jit front-end.
def canonical_shardings(
device_assignment: Sequence[jax.Device],
in_shardings: Sequence[Sharding],
out_shardings: Sequence[Sharding]
) -> tuple[Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]],
Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]]]:
"""Prepares canonical in_ and out_shardings for a jit invocation.
"""Prepares canonical in_ and out_shardings for a pjit invocation.
The pjit front-end is picky about what in- and out-shardings it accepts,
e.g., if all are unspecified then the whole sharding should be the
sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
replaced with the replicated sharding.
"""
# Prepare a replicated sharding, search in both the input and output shardings
specified_shardings = [
s for s in itertools.chain(in_shardings, out_shardings)
if not sharding_impls.is_unspecified(s)]
if specified_shardings:
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
else:
replicated_s = None
Returns: a pair with the canonicalized input and output shardings.
"""
replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment)
def canonicalize(
ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]]:
if all(sharding_impls.is_unspecified(s) for s in ss):
if all(s is None for s in ss):
return sharding_impls.UNSPECIFIED
return tuple(
s if not sharding_impls.is_unspecified(s) else replicated_s
sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s
for s in ss)
return (canonicalize(in_shardings), canonicalize(out_shardings))
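
A sketch of the resulting behavior, assuming a hypothetical device assignment devs and an HloSharding hlo_s:

# canonical_shardings(devs, (None, None), (None,))
#   -> (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
# canonical_shardings(devs, (hlo_s, None), (None,))
#   -> ((sharding.GSPMDSharding(devs, hlo_s),
#        sharding.GSPMDSharding.get_replicated(devs)),
#       sharding_impls.UNSPECIFIED)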

def _get_vjp_fun(primal_fun: Callable, *,
in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
module_kept_var_idx: tuple[int, ...],
in_shardings: tuple[Sharding, ...],
out_shardings: tuple[Sharding, ...],
nr_devices: int,
apply_jit: bool
) -> tuple[Callable, Sequence[core.AbstractValue]]:
# Since jax.vjp does not handle kwargs, it is easier to do all the work
@@ -1000,30 +1015,30 @@ def flattened_primal_fun_jax(*args_flat):
itertools.chain(in_avals,
map(lambda a: a.at_least_vspace(), out_avals)))

all_in_shardings = expand_in_shardings(in_shardings,
module_kept_var_idx, len(in_avals))
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
tuple(itertools.chain(all_in_shardings, out_shardings)),
all_in_shardings)

if apply_jit:
# Prepare a device assignment. For exporting purposes, all that matters
# is the number of devices.
device_assignment = jax.devices(jax.default_backend())[:nr_devices]
assert len(device_assignment) == nr_devices
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
device_assignment,
tuple(itertools.chain(in_shardings, out_shardings)),
in_shardings)
return pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_out_shardings), vjp_in_avals
else:
assert vjp_in_shardings == sharding_impls.UNSPECIFIED
assert vjp_out_shardings == sharding_impls.UNSPECIFIED
return fun_vjp_jax, vjp_in_avals

def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
# Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
in_tree=primal.in_tree,
module_kept_var_idx=primal.module_kept_var_idx,
in_avals=primal.in_avals,
in_shardings=primal.in_shardings,
out_avals=primal.out_avals,
out_shardings=primal.out_shardings,
nr_devices=primal.nr_devices,
apply_jit=True)
return export(fun_vjp_jax,
lowering_platforms=primal.lowering_platforms,
@@ -1154,13 +1169,24 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
if exported.uses_shape_polymorphism:
ctx.module_context.shape_poly_state.uses_dim_vars = True

axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
ctx_device_assignment = axis_context.device_assignment
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
ctx_device_assignment = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))
if len(ctx_device_assignment) != exported.nr_devices:
raise NotImplementedError(
f"Exported module {exported.fun_name} was lowered for "
f"{exported.nr_devices} devices and is called in a context with "
f"{len(ctx_device_assignment)} devices"
)

# Apply in_shardings
all_in_shardings = expand_in_shardings(exported.in_shardings,
exported.module_kept_var_idx,
len(args))
args = tuple(
wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, all_in_shardings))
wrap_with_sharding(ctx, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
submodule = ir.Module.parse(exported.mlir_module())
symtab = ir.SymbolTable(submodule.operation)
# The called function may have been exported with polymorphic shapes and called
@@ -1251,7 +1277,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
exported.out_avals, ctx.avals_out))
# Apply out_shardings
results = tuple(
wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
wrap_with_sharding(ctx, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
)
return results
@@ -1264,27 +1290,10 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
mlir.register_lowering(call_exported_p, _call_exported_lowering)

def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
exported: Exported,
x: ir.Value,
x_aval: core.AbstractValue,
x_sharding: Sharding) -> ir.Value:
if sharding_impls.is_unspecified(x_sharding):
if x_sharding is None:
return x
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
ctx_device_assignment = axis_context.device_assignment
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
ctx_device_assignment = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))
assert isinstance(x_sharding, sharding_impls.XLACompatibleSharding)
sharding_device_assignment = x_sharding._device_assignment
if len(ctx_device_assignment) != len(sharding_device_assignment):
raise NotImplementedError(
f"Exported module {exported.fun_name} was lowered for "
f"{len(sharding_device_assignment)} devices and is called in a context with "
f"{len(ctx_device_assignment)} devices"
)
return mlir.wrap_with_sharding_op(
ctx, x, x_aval,
x_sharding._to_xla_hlo_sharding(x_aval.ndim).to_proto())
ctx, x, x_aval, x_sharding.to_proto())
