From 3eb3e2dcdb895ee88ce84d1688a07c994dfe5178 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 1 Dec 2023 12:24:21 +0200 Subject: [PATCH] [export] Simplify the handling of shardings in Exported. Previously, Exported contained tuples of `XlaCompatibleSharding` for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported. We change Exported to carry `xla_client.HloSharding` instead, which conveniently can be serialized to proto. We use the value `None` to denote an unspecified sharding. We also add `nr_devices` and then for exporting purposes we can construct actual `XlaCompatibleSharding` when we need to. --- .../export_back_compat_test_util.py | 19 ++- jax/experimental/export/export.py | 123 ++++++++++-------- jax/experimental/jax2tf/jax2tf.py | 76 ++++++----- .../jax2tf/tests/back_compat_test.py | 2 +- .../back_compat_testdata/tpu_Sharding.py | 1 + .../jax2tf/tests/back_compat_tf_test.py | 3 +- 6 files changed, 126 insertions(+), 98 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 00e1dc370e7a..99f6df4c2f59 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -87,7 +87,6 @@ def func(...): ... from jax._src import core from jax._src import test_util as jtu -from jax._src.sharding_impls import UNSPECIFIED from jax._src import xla_bridge as xb @@ -104,6 +103,7 @@ class CompatTestData: mlir_module_text: str mlir_module_serialized: bytes xla_call_module_version: int # The version of XlaCallModule to use for testing + nr_devices: int = 1 # The dummy_data is used for getting started for adding a new test and for @@ -187,6 +187,7 @@ def run_one_test(self, func: Callable[..., jax.Array], will fail when the serialization changes. Otherwise, when checking old serializations you can specify what custom calls are expected in the current serialization. + nr_devices: the number of devices for which the data was serialized. """ if not isinstance(data, CompatTestData): raise ValueError(f"Expecting data: CompatTestData but got {data}. " @@ -202,7 +203,7 @@ def run_one_test(self, func: Callable[..., jax.Array], res_run_current = tuple(np.array(a) for a in res_run_current) logging.info("Result of current version run is %s", res_run_current) - serialized, module_str, module_version = self.serialize( + serialized, module_str, module_version, nr_devices = self.serialize( func, data, polymorphic_shapes=polymorphic_shapes, allow_unstable_custom_call_targets=allow_unstable_custom_call_targets) @@ -225,6 +226,7 @@ def run_one_test(self, func: Callable[..., jax.Array], mlir_module_text=r\"\"\"\n{module_str}\"\"\", mlir_module_serialized={serialized!r}, xla_call_module_version={module_version}, + nr_devices={nr_devices}, ) # End paste """ @@ -271,7 +273,7 @@ def serialize(self, func: Callable, data: CompatTestData, *, polymorphic_shapes: Optional[Sequence[str]] = None, allow_unstable_custom_call_targets: Sequence[str] = () - ) -> tuple[bytes, str, int]: + ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: @@ -281,7 +283,8 @@ def serialize(self, custom call targets besides those known as stable. Returns: a tuple with the (a) serialization, (b) the module contents as - a string (for debugging), and (c) the module serialization version. 
+ a string (for debugging), (c) the module serialization version, + (d) the number of devices for which the module was serialized. """ # Use the native exporter, to make sure we get the proper serialization. args_specs = export.args_specs(data.inputs, polymorphic_shapes) @@ -296,7 +299,8 @@ def serialize(self, module_str = str(exported.mlir_module()) serialized = exported.mlir_module_serialized module_version = exported.serialization_version - return serialized, module_str, module_version + nr_devices = exported.nr_devices + return serialized, module_str, module_version, nr_devices def run_serialized(self, data: CompatTestData, polymorphic_shapes: Optional[Sequence[str]] = None): @@ -321,14 +325,15 @@ def _get_vjp(_): in_avals=tuple(in_avals), out_tree=out_tree, out_avals=tuple(out_avals), - in_shardings=(UNSPECIFIED,) * len(in_avals), - out_shardings=(UNSPECIFIED,) * len(out_avals), + in_shardings=(None,) * len(in_avals), + out_shardings=(None,) * len(out_avals), lowering_platforms=(data.platform,), ordered_effects=(), unordered_effects=(), disabled_checks=(), mlir_module_serialized=data.mlir_module_serialized, serialization_version=data.xla_call_module_version, + nr_devices=data.nr_devices, module_kept_var_idx=tuple(range(len(in_avals))), uses_shape_polymorphism=any(not core.is_constant_shape(a.shape) for a in in_avals), diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index 2b4f808083d1..9ac5ec9bcb96 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -122,7 +122,11 @@ def __hash__(self) -> int: _VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 -Sharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] +# The values of input and output sharding from the lowering. +LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] + +# None means unspecified sharding +Sharding = Optional[xla_client.HloSharding] @dataclasses.dataclass(frozen=True) class Exported: @@ -140,9 +144,9 @@ class Exported: out_avals: the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in `in_avals. - in_shardings: the flattened input shardings. Only for the inputs that are - specified in `module_kept_var_idx`. + in_shardings: the flattened input shardings, as long as `in_avals`. out_shardings: the flattened output shardings, as long as `out_avals`. + nr_devices: the number of devices that the module has been lowered for. lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', 'cuda', 'rocm'. See below for the calling convention for when there are multiple lowering platforms. @@ -274,6 +278,7 @@ class Exported: in_shardings: tuple[Sharding, ...] out_shardings: tuple[Sharding, ...] + nr_devices: int lowering_platforms: tuple[str, ...] ordered_effects: tuple[effects.Effect, ...] unordered_effects: tuple[effects.Effect, ...] 
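For illustration only (not part of the diff): a minimal sketch of the round trip that the new `Sharding = Optional[xla_client.HloSharding]` representation enables. The single-device mesh and the variable names below are assumptions for the example; `_to_xla_hlo_sharding` is the same private helper that `export_sharding` uses further down, and `GSPMDSharding` is the concrete sharding that is rebuilt from the stored `HloSharding` plus a device assignment.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec, GSPMDSharding

# Any device assignment works here; only its length (nr_devices) matters for export.
devices = jax.devices()[:1]
mesh = Mesh(devices, ("x",))
lowering_sharding = NamedSharding(mesh, PartitionSpec("x"))  # an XLACompatibleSharding

# What Exported now stores: a device-free HloSharding (or None when unspecified),
# which can be serialized via its proto form.
hlo_sharding = lowering_sharding._to_xla_hlo_sharding(2)  # 2 = ndim of the aval
sharding_proto = hlo_sharding.to_proto()

# Rebuilding a concrete sharding later, when one is needed again (e.g. for pjit):
rebuilt = GSPMDSharding(devices, hlo_sharding)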
@@ -521,14 +526,31 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: ordered_effects = unordered_effects = () + nr_devices = len(lowering.compile_args["device_assignment"]) + def export_sharding(s: LoweringSharding, + aval: core.ShapedArray) -> Sharding: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + + all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], + module_kept_var_idx, + len(args_avals_flat)) + in_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(all_in_shardings, args_avals_flat)) + out_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) return Exported( fun_name=fun_name, in_tree=lowered.in_tree, out_tree=lowered.out_tree, in_avals=tuple(args_avals_flat), out_avals=tuple(out_avals_flat), - in_shardings=tuple(lowering.compile_args["in_shardings"]), - out_shardings=tuple(lowering.compile_args["out_shardings"]), + in_shardings=in_shardings, + out_shardings=out_shardings, + nr_devices=nr_devices, lowering_platforms=actual_lowering_platforms, ordered_effects=ordered_effects, unordered_effects=unordered_effects, @@ -921,16 +943,16 @@ def walk_operations(op): "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") raise ValueError(msg) -def expand_in_shardings(in_shardings: tuple[Sharding, ...], +def expand_in_shardings(in_shardings: Sequence[LoweringSharding], module_kept_var_idx: Sequence[int], - nr_inputs: int) -> tuple[Sharding, ...]: + nr_inputs: int) -> Sequence[LoweringSharding]: """Expands in_shardings with unspecified shardings for inputs not kept. Assumes in_shardings corresponds to module_kept_var_idx. """ assert len(in_shardings) == len(module_kept_var_idx) assert nr_inputs >= len(module_kept_var_idx) - all_in_shardings: list[Sharding] = [sharding_impls.UNSPECIFIED] * nr_inputs + all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings): all_in_shardings[idx] = in_s return tuple(all_in_shardings) @@ -938,37 +960,30 @@ def expand_in_shardings(in_shardings: tuple[Sharding, ...], # TODO(yashkatariya, necula): remove this function once we relax the checks # in the jit front-end. def canonical_shardings( + device_assignment: Sequence[jax.Device], in_shardings: Sequence[Sharding], out_shardings: Sequence[Sharding] ) -> tuple[Union[pxla.UnspecifiedValue, Sequence[sharding.XLACompatibleSharding]], Union[pxla.UnspecifiedValue, Sequence[sharding.XLACompatibleSharding]]]: - """Prepares canonical in_ and out_shardings for a jit invocation. + """Prepares canonical in_ and out_shardings for a pjit invocation. The pjit front-end is picky about what in- and out-shardings it accepts, e.g., if all are unspecified then the whole sharding should be the sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are replaced with the replicated sharding. 
- """ - # Prepare a replicated sharding, search in both the input and output shardings - specified_shardings = [ - s for s in itertools.chain(in_shardings, out_shardings) - if not sharding_impls.is_unspecified(s)] - if specified_shardings: - in_s = specified_shardings[0] # pjit will enforce that all have same devices - assert isinstance(in_s, sharding.XLACompatibleSharding) - replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) - else: - replicated_s = None + Returns: a pair with the canonicalized input and output shardings. + """ + replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment) def canonicalize( ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue, Sequence[sharding.XLACompatibleSharding]]: - if all(sharding_impls.is_unspecified(s) for s in ss): + if all(s is None for s in ss): return sharding_impls.UNSPECIFIED return tuple( - s if not sharding_impls.is_unspecified(s) else replicated_s + sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s for s in ss) return (canonicalize(in_shardings), canonicalize(out_shardings)) @@ -976,9 +991,9 @@ def _get_vjp_fun(primal_fun: Callable, *, in_tree: tree_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], - module_kept_var_idx: tuple[int, ...], in_shardings: tuple[Sharding, ...], out_shardings: tuple[Sharding, ...], + nr_devices: int, apply_jit: bool ) -> tuple[Callable, Sequence[core.AbstractValue]]: # Since jax.vjp does not handle kwargs, it is easier to do all the work @@ -1000,30 +1015,30 @@ def flattened_primal_fun_jax(*args_flat): itertools.chain(in_avals, map(lambda a: a.at_least_vspace(), out_avals))) - all_in_shardings = expand_in_shardings(in_shardings, - module_kept_var_idx, len(in_avals)) - vjp_in_shardings, vjp_out_shardings = canonical_shardings( - tuple(itertools.chain(all_in_shardings, out_shardings)), - all_in_shardings) - if apply_jit: + # Prepare a device assignment. For exporting purposes, all it matters + # is the number of devices. + device_assignment = jax.devices(jax.default_backend())[:nr_devices] + assert len(device_assignment) == nr_devices + vjp_in_shardings, vjp_out_shardings = canonical_shardings( + device_assignment, + tuple(itertools.chain(in_shardings, out_shardings)), + in_shardings) return pjit.pjit(fun_vjp_jax, in_shardings=vjp_in_shardings, out_shardings=vjp_out_shardings), vjp_in_avals else: - assert vjp_in_shardings == sharding_impls.UNSPECIFIED - assert vjp_out_shardings == sharding_impls.UNSPECIFIED return fun_vjp_jax, vjp_in_avals def _export_native_vjp(primal_fun, primal: Exported) -> Exported: # Export the VJP of `primal_fun_jax`. 
See documentation for Exported.vjp fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun, in_tree=primal.in_tree, - module_kept_var_idx=primal.module_kept_var_idx, in_avals=primal.in_avals, in_shardings=primal.in_shardings, out_avals=primal.out_avals, out_shardings=primal.out_shardings, + nr_devices=primal.nr_devices, apply_jit=True) return export(fun_vjp_jax, lowering_platforms=primal.lowering_platforms, @@ -1154,13 +1169,24 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, if exported.uses_shape_polymorphism: ctx.module_context.shape_poly_state.uses_dim_vars = True + axis_context = ctx.module_context.axis_context + if isinstance(axis_context, sharding_impls.ShardingContext): + ctx_device_assignment = axis_context.device_assignment + elif isinstance(axis_context, sharding_impls.SPMDAxisContext): + ctx_device_assignment = list(axis_context.mesh.devices.flat) + else: + raise NotImplementedError(type(axis_context)) + if len(ctx_device_assignment) != exported.nr_devices: + raise NotImplementedError( + f"Exported module {exported.fun_name} was lowered for " + f"{exported.nr_devices} devices and is called in a context with " + f"{len(ctx_device_assignment)} devices" + ) + # Apply in_shardings - all_in_shardings = expand_in_shardings(exported.in_shardings, - exported.module_kept_var_idx, - len(args)) args = tuple( - wrap_with_sharding(ctx, exported, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(args, ctx.avals_in, all_in_shardings)) + wrap_with_sharding(ctx, x, x_aval, x_sharding) + for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings)) submodule = ir.Module.parse(exported.mlir_module()) symtab = ir.SymbolTable(submodule.operation) # The called function may have been exported with polymorphic shapes and called @@ -1251,7 +1277,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra exported.out_avals, ctx.avals_out)) # Apply out_shardings results = tuple( - wrap_with_sharding(ctx, exported, x, x_aval, x_sharding) + wrap_with_sharding(ctx, x, x_aval, x_sharding) for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings) ) return results @@ -1264,27 +1290,10 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra mlir.register_lowering(call_exported_p, _call_exported_lowering) def wrap_with_sharding(ctx: mlir.LoweringRuleContext, - exported: Exported, x: ir.Value, x_aval: core.AbstractValue, x_sharding: Sharding) -> ir.Value: - if sharding_impls.is_unspecified(x_sharding): + if x_sharding is None: return x - axis_context = ctx.module_context.axis_context - if isinstance(axis_context, sharding_impls.ShardingContext): - ctx_device_assignment = axis_context.device_assignment - elif isinstance(axis_context, sharding_impls.SPMDAxisContext): - ctx_device_assignment = list(axis_context.mesh.devices.flat) - else: - raise NotImplementedError(type(axis_context)) - assert isinstance(x_sharding, sharding_impls.XLACompatibleSharding) - sharding_device_assignment = x_sharding._device_assignment - if len(ctx_device_assignment) != len(sharding_device_assignment): - raise NotImplementedError( - f"Exported module {exported.fun_name} was lowered for " - f"{len(sharding_device_assignment)} devices and is called in a context with " - f"{len(ctx_device_assignment)} devices" - ) return mlir.wrap_with_sharding_op( - ctx, x, x_aval, - x_sharding._to_xla_hlo_sharding(x_aval.ndim).to_proto()) + ctx, x, x_aval, x_sharding.to_proto()) diff --git a/jax/experimental/jax2tf/jax2tf.py 
b/jax/experimental/jax2tf/jax2tf.py index 63756c2150c4..5dfad2df0096 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -512,13 +512,13 @@ def run_fun_tf(self, def get_vjp_fun(self) -> tuple[Callable, Sequence[core.AbstractValue]]: return export._get_vjp_fun(self.fun_jax, - in_tree=self.exported.in_tree, - module_kept_var_idx=self.exported.module_kept_var_idx, - in_avals=self.exported.in_avals, - in_shardings=self.exported.in_shardings, - out_avals=self.exported.out_avals, - out_shardings=self.exported.out_shardings, - apply_jit=True) + in_tree=self.exported.in_tree, + in_avals=self.exported.in_avals, + in_shardings=self.exported.in_shardings, + out_avals=self.exported.out_avals, + out_shardings=self.exported.out_shardings, + nr_devices=self.exported.nr_devices, + apply_jit=True) class GraphSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, @@ -584,13 +584,13 @@ def get_vjp_fun(self) -> tuple[Callable, # except we use unspecified shardings, and we do not apply a jit on the # VJP. This matches the older behavior of jax2tf for graph serialization. return export._get_vjp_fun(self.fun_jax, - in_tree=self.in_tree, - module_kept_var_idx=tuple(range(len(self.args_avals_flat))), - in_avals=self.args_avals_flat, - in_shardings=(sharding_impls.UNSPECIFIED,) * len(self.args_avals_flat), - out_avals=self.outs_avals, - out_shardings=(sharding_impls.UNSPECIFIED,) * len(self.outs_avals), - apply_jit=False) + in_tree=self.in_tree, + in_avals=self.args_avals_flat, + in_shardings=(None,) * len(self.args_avals_flat), + out_avals=self.outs_avals, + out_shardings=(None,) * len(self.outs_avals), + nr_devices=1, # Does not matter for unspecified shardings + apply_jit=False) def dtype_of_val(val: TfVal) -> DType: @@ -890,10 +890,13 @@ def _convert_value(val, aval): # Do not apply XlaSharding for REPLICATED, on inputs and outputs. # This is an agreed convention, and also improves usability under TF eager. # See b/255511660. - if exported.in_shardings is not None: - args_flat_tf = tuple( - map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), - kept_args_flat_tf, kept_args_avals, exported.in_shardings)) + kept_in_shardings = [] + for i in exported.module_kept_var_idx: + kept_in_shardings.append(exported.in_shardings[i]) + args_flat_tf = tuple( + map(partial(_shard_value, + skip_replicated_sharding=tf.executing_eagerly()), + kept_args_flat_tf, kept_in_shardings)) res = tfxla.call_module(args_flat_tf, **call_module_attrs) # TODO(b/278940799): Replace the TF v1 API with public TF2 API. 
# Add the custom call tf.function into the default graph, so those functions @@ -904,10 +907,9 @@ def _convert_value(val, aval): concrete_fn._inference_function ) - if exported.out_shardings is not None: - res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), - res, exported.out_avals, exported.out_shardings)) - + res = list(map(partial(_shard_value, + skip_replicated_sharding=tf.executing_eagerly()), + res, exported.out_shardings)) res = tuple(map(_convert_value, res, exported.out_avals)) return res @@ -3405,17 +3407,21 @@ def split_to_logical_devices(tensor: TfVal, return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) +def _xla_compatible_sharding_to_hlo_sharding( + s: sharding.XLACompatibleSharding, + aval: core.ShapedArray) -> Optional[xla_client.HloSharding]: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + def _shard_value(val: TfVal, - aval: core.ShapedArray, - sd: sharding.XLACompatibleSharding, *, + sd: Optional[xla_client.HloSharding], *, skip_replicated_sharding: bool) -> TfVal: """Apply sharding to a TfVal.""" - if sharding_impls.is_unspecified(sd): + if sd is None: return val - sharding_proto: xla_client.OpSharding = cast( - xla_client.OpSharding, sd._to_xla_hlo_sharding(aval.ndim).to_proto()) # type: ignore - + sharding_proto = sd.to_proto() if (skip_replicated_sharding and op_shardings.is_op_sharding_replicated(sharding_proto)): return val @@ -3465,17 +3471,21 @@ def _pjit(*args: TfVal, _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars # Apply sharding annotation to the arguments + in_hlo_shardings: Sequence[Optional[xla_client.HloSharding]] = map( + _xla_compatible_sharding_to_hlo_sharding, in_shardings, _in_avals) sharded_args: Sequence[TfVal] = tuple( map(partial(_shard_value, skip_replicated_sharding=not _thread_local_state.enable_xla), - args, _in_avals, in_shardings)) + args, in_hlo_shardings)) results = _interpret_jaxpr(jaxpr, *sharded_args, extra_name_stack=util.wrap_name(name, "pjit"), fresh_constant_cache=False) + out_hlo_shardings: Sequence[Optional[xla_client.HloSharding]] = map( + _xla_compatible_sharding_to_hlo_sharding, out_shardings, _out_aval) sharded_results: Sequence[TfVal] = tuple( map(partial(_shard_value, skip_replicated_sharding=not _thread_local_state.enable_xla), - results, _out_aval, out_shardings)) + results, out_hlo_shardings)) return tuple(sharded_results) @@ -3483,12 +3493,14 @@ def _pjit(*args: TfVal, def _pjit_sharding_constraint(arg: TfVal, *, - sharding: sharding.NamedSharding, + sharding: sharding.XLACompatibleSharding, resource_env: maps.ResourceEnv, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, **kwargs) -> TfVal: - return _shard_value(arg, _in_avals[0], sharding, skip_replicated_sharding=False) + hlo_sharding = _xla_compatible_sharding_to_hlo_sharding(sharding, _in_avals[0]) + return _shard_value(arg, hlo_sharding, + skip_replicated_sharding=False) tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index d30116b82372..d866e0c2b89e 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -585,7 +585,7 @@ def func(x): def test_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) 
< 2: - self.skipTest("Test runs only on TPU with at least 2 devices") + self.skipTest("Test runs only on TPU with at least 2 devices") # Must use exactly 2 devices for expected outputs from ppermute devices = jax.devices()[:2] diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py b/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py index 69072d23f03b..f2d8be3b958a 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py @@ -45,4 +45,5 @@ """, mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01\x1b\x05\x01\x05\x01\x03\x05\x03\x0b\x07\t\x0b\r\x0f\x03\x9d\x81\r\x01K\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0bS\x0b\x0b\x0b\x0b\x17\x0b\x13\x0b33\x0b\x0bS\x1b\x0b\x0b\x0f\x0b\x17SS\x13\x0b\x037\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x8f\x0b\x03\r\x17\x17\x07\x07\x17\x17\x02\xb6\x04\x1f\x1d1%\x05\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x1d3%\x05#\x03\x13\x05O\x07M\tY\x0bK\rQ\x0f[\x11K\x13K\x15K\x05%\x05'\x05)\x05+\x17'\x02\x02\x01\x05-\x03\x03\x19+\x05/\x03\x0b\x1d_\x1fS!k\x19q#s\x03\x0b\x1dU\x1fS!U\x19W#w\x051\x053\x03\x13\x05O\x07M\ty\x0bK\rQ\x0f]\x11K\x13K\x15K\x03\x059{;}\x055\x057\x1d?A\x059\x17'\x12\x05\x01\x03\x13\x05O\x07M\tY\x0bK\rQ\x0f]\x11K\x13K\x15K\x03\x13\x05O\x07M\t\x7f\x0bK\rQ\x0f[\x11K\x13K\x15K\x03\x03IW\x05;\x03\x01\x1d=\x0b\x03\x05\x01#\t\x03\x03u\x1d?\x1dA\x1dC\x1dE\x03\x03a\r\x05cegi\x1dG\x1dI\x1d\x1b\x1dK\x03\x03m\r\x03oM\x1dM\x1dO\x1dQ\r\x01\x1dS\x1dU\x13\x07\x05\x1f\x0bA\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dW)\x05\t\x11\x05)\x05\x05\x11\x05\t\x1d\x11\x03\x01\x03\x01)\x05\t\t\x07\x04\xd3\x05\x01\x11\x01)\x07\x03\x01\t\x05\x11\x01-\x05\x03\x05\x0b\x03\x01\x01\x0b\x07\x03G\x03\x01\x03\x01\x07\x04\x01\x03\x03\x05\x11\x03/\x05\x03\x11#\x03\x01\x01\x03\x07\x03\x1b\x03\x01\x03\x01\x03\x07\x17\x1b\x03\x01\x03\x03\x03\x07\x175\x03\x03\x03\x05\t\x07=7\x03\x03\x03\x07\x03\x07\x17C\x03\x03\x03\t\x03\x07\x17E\x03\x01\x03\x0b\x03\x07\x03\x1b\x03\x01\x03\r\x07\x04\x03\x03\x0f\x06\x03\x01\x05\x01\x00\x82\x13Y++\x11\x0f\x0b!\x1b\x11\x1b\x13'\x13\x11\x03\x0f\xa3)\x17\x9e\x02\x1e\x06\x19\x83\x1f\x15\x1d\x15\x13\x1f/!\x1d!)#\x1f\x19\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.sharding\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit_wrapped\x00jit(wrapped)/jit(main)/pjit[in_shardings=(GSPMDSharding({devices=[2,1]0,1}),) out_shardings=(GSPMDSharding({devices=[2,1]0,1}),) resource_env=ResourceEnv(Mesh(device_ids=array([0, 1]), axis_names=('a',)), ()) donated_invars=(False,) name=wrapped in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit(wrapped)/jit(main)/pjit(wrapped)/shard_map[mesh=Mesh(device_ids=array([0, 1]), axis_names=('a',)) in_names=({0: ('a',)},) out_names=({0: ('a',)},) check_rep=True]\x00channel_id\x00source_target_pairs\x00jit(wrapped)/jit(main)/pjit(wrapped)/ppermute[axis_name=a perm=((0, 1), (1, 
0))]\x00callee\x00\x00wrapped\x00Sharding\x00{devices=[2,1]0,1}\x00{manual}\x00jax.arg_info\x00args[0]\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00SPMDFullToShardShape\x00SPMDShardToFullShape\x00", xla_call_module_version=4, + nr_devices=2, ) # End paste diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index d054d58060bf..7f7d13f0c92e 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -119,7 +119,8 @@ def serialize( options=tf.saved_model.SaveOptions(experimental_custom_gradients=False), ) serialized = serialize_directory(saved_model_dir) - return serialized, module_str, module_version + nr_devices = 1 + return serialized, module_str, module_version, nr_devices def run_serialized( self,
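As a complement to the conversion sketched earlier: in the other direction, the new `canonical_shardings` (and `_get_vjp_fun`) need only the stored `nr_devices` plus the current backend's devices to rebuild shardings that the pjit front-end accepts. Below is a simplified sketch, not part of the diff; the function name is made up, and unlike the real helper it does not collapse the all-`None` case into `sharding_impls.UNSPECIFIED`.

import jax
from jax.sharding import GSPMDSharding

def rebuild_shardings(hlo_shardings, nr_devices):
  # hlo_shardings: a sequence of xla_client.HloSharding or None (unspecified).
  # For export purposes only the number of devices matters, so any
  # nr_devices-long assignment from the default backend will do.
  device_assignment = jax.devices(jax.default_backend())[:nr_devices]
  assert len(device_assignment) == nr_devices
  replicated = GSPMDSharding.get_replicated(device_assignment)
  return tuple(
      GSPMDSharding(device_assignment, s) if s is not None else replicated
      for s in hlo_shardings)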