diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 7929be95ce64..d76bc8727d95 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -105,64 +105,9 @@ def __repr__(self): AxisName = core.AxisName ResourceAxisName = AxisName # Different name just for documentation purposes Mesh = pxla.Mesh - -class _Loop(NamedTuple): - name: ResourceAxisName - length: int - -class ResourceEnv(NamedTuple): - physical_mesh: Mesh - loops: Tuple[_Loop, ...] - - def with_mesh(self, mesh: Mesh): - overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names)) - if overlap: - raise ValueError(f"Cannot update the mesh of the current resource " - f"environment. The new mesh shadows already defined axes " - f"{show_axes(overlap)}") - return self._replace(physical_mesh=mesh) - - def with_extra_loop(self, loop: _Loop): - if loop.name in self.resource_axes: - raise ValueError(f"Cannot extend the resource environment with loop named " - f"`{loop.name}`. An axis of this name is already defined!") - return self._replace(loops=self.loops + (loop,)) - - @property - def physical_resource_axes(self) -> Set[ResourceAxisName]: - return set(self.physical_mesh.axis_names) - - @property - def loop_resource_axes(self) -> Set[ResourceAxisName]: - return set(loop.name for loop in self.loops) - - @property - def resource_axes(self) -> Set[ResourceAxisName]: - return self.physical_resource_axes | self.loop_resource_axes - - @property - def shape(self): - shape = self.physical_mesh.shape - shape.update(self.loops) - return shape - - @property - def local_shape(self): - shape = self.physical_mesh.local_mesh.shape - shape.update(self.loops) - return shape - - def __repr__(self): - return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})" - -EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ()) - -class _ThreadResourcesLocalState(threading.local): - - def __init__(self): - self.env = EMPTY_ENV - -thread_resources = _ThreadResourcesLocalState() +ResourceEnv = pxla.ResourceEnv +EMPTY_ENV = pxla.EMPTY_ENV +thread_resources = pxla.thread_resources class SerialLoop: @@ -232,7 +177,7 @@ def serial_loop(name: ResourceAxisName, length: int): axis_resources={'i': 'l'})(x) """ old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) - thread_resources.env = old_env.with_extra_loop(_Loop(name, length)) + thread_resources.env = old_env.with_extra_loop(pxla._Loop(name, length)) try: yield finally: @@ -268,6 +213,7 @@ def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]): out_axes=['left', 'right', ...], axis_resources={'left': 'x', 'right': 'y'})(x, x.T) """ + # TODO(yashkatariya): Deprecate this context manager. old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) thread_resources.env = old_env.with_mesh(Mesh(np.asarray(devices, dtype=object), axis_names)) try: @@ -998,8 +944,6 @@ def _typecheck_xmap( return out_avals core.custom_typechecks[xmap_p] = _typecheck_xmap -def show_axes(axes): - return ", ".join(sorted([f"`{a}`" for a in axes])) def _resource_typing_xmap(avals, params, @@ -1014,7 +958,7 @@ def _resource_typing_xmap(avals, raise JAXTypeError( f"Detected disallowed xmap axis name shadowing at " f"{source_info_util.summarize(source_info)} " - f"(shadowed axes: {show_axes(overlap)})") + f"(shadowed axes: {pxla.show_axes(overlap)})") if resource_env.physical_mesh != params['resource_env'].physical_mesh: raise RuntimeError("Changing the physical mesh is not allowed inside xmap.") @@ -1042,9 +986,9 @@ def _resource_typing_xmap(avals, raise JAXTypeError( f"One of xmapped function ({params['name']}) outputs is broadcast " f"along axis `{baxis}` which is assigned to resources " - f"{show_axes(baxis_resources)}, but the output is already " - f"partitioned along {show_axes(overlap)}, because its " - f"named shape contains {show_axes(partitioning_axes)}") + f"{pxla.show_axes(baxis_resources)}, but the output is already " + f"partitioned along {pxla.show_axes(overlap)}, because its " + f"named shape contains {pxla.show_axes(partitioning_axes)}") pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index b40dc869cadf..53e6de385779 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -209,7 +209,7 @@ def infer_params(*args, **kwargs): f"was called with only {len(args)} positional arguments.") # Putting this outside of wrapped would make resources lexically scoped - resource_env = maps.thread_resources.env + resource_env = pxla.thread_resources.env mesh = resource_env.physical_mesh if mesh.empty: raise RuntimeError("pjit requires a non-empty mesh! Are you sure that " @@ -551,7 +551,7 @@ def _check_unique_resources(axis_resources, arg_name): if multiple_uses: raise ValueError(f"A single {arg_name} specification can map every mesh axis " f"to at most one positional dimension, but {arg_axis_resources.user_spec} " - f"has duplicate entries for {maps.show_axes(multiple_uses)}") + f"has duplicate entries for {pxla.show_axes(multiple_uses)}") def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape, flat_avals, flat_axis_resources): @@ -897,7 +897,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} " f"that uses one or more mesh axes already used by xmap to partition " f"a named axis appearing in its named_shape (both use mesh axes " - f"{maps.show_axes(overlap)})") + f"{pxla.show_axes(overlap)})") def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources): jaxpr = params["jaxpr"] @@ -925,7 +925,7 @@ def with_sharding_constraint(x, axis_resources): axis_resources_flat = tuple( flatten_axes("with_sharding_constraint axis_resources", tree, parsed_axis_resources)) - resource_env = maps.thread_resources.env + resource_env = pxla.thread_resources.env mesh = resource_env.physical_mesh _check_shapes_against_resources( "with_sharding_constraint arguments", diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 8835c85f463f..3d6f0881fda6 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -28,6 +28,8 @@ # This encoding is assumed by various parts of the system, e.g. generating # replica groups for collective operations. +from __future__ import annotations + from contextlib import contextmanager from collections import defaultdict, OrderedDict import dataclasses @@ -1787,6 +1789,9 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, # ------------------- xmap ------------------- class Mesh: + devices: np.ndarray + axis_names: Tuple[MeshAxisName, ...] + _old_env: ResourceEnv def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]): assert devices.ndim == len(axis_names) @@ -1814,6 +1819,16 @@ def __setattr__(self, name, value): raise RuntimeError("Cannot reassign attributes of immutable mesh objects") super().__setattr__(name, value) + def __enter__(self): + self._old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) + thread_resources.env = self._old_env.with_mesh( + Mesh(self.devices, self.axis_names)) + return thread_resources.env.physical_mesh + + def __exit__(self, exc_type, exc_value, traceback): + thread_resources.env = self._old_env + return False + @property def shape(self): return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape)) @@ -1885,6 +1900,72 @@ def global_to_local(self, axes: ArrayMapping, aval): tile_aval_nd(self.shape, axes, aval)) +ResourceAxisName = core.AxisName + +class _Loop(NamedTuple): + name: ResourceAxisName + length: int + + +def show_axes(axes): + return ", ".join(sorted([f"`{a}`" for a in axes])) + + +class ResourceEnv(NamedTuple): + physical_mesh: Mesh + loops: Tuple[_Loop, ...] + + def with_mesh(self, mesh: Mesh): + overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names)) + if overlap: + raise ValueError(f"Cannot update the mesh of the current resource " + f"environment. The new mesh shadows already defined axes " + f"{show_axes(overlap)}") + return self._replace(physical_mesh=mesh) + + def with_extra_loop(self, loop: _Loop): + if loop.name in self.resource_axes: + raise ValueError(f"Cannot extend the resource environment with loop named " + f"`{loop.name}`. An axis of this name is already defined!") + return self._replace(loops=self.loops + (loop,)) + + @property + def physical_resource_axes(self) -> Set[ResourceAxisName]: + return set(self.physical_mesh.axis_names) + + @property + def loop_resource_axes(self) -> Set[ResourceAxisName]: + return set(loop.name for loop in self.loops) + + @property + def resource_axes(self) -> Set[ResourceAxisName]: + return self.physical_resource_axes | self.loop_resource_axes + + @property + def shape(self): + shape = self.physical_mesh.shape + shape.update(self.loops) + return shape + + @property + def local_shape(self): + shape = self.physical_mesh.local_mesh.shape + shape.update(self.loops) + return shape + + def __repr__(self): + return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})" + +EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ()) + +class _ThreadResourcesLocalState(threading.local): + + def __init__(self): + self.env = EMPTY_ENV + +thread_resources = _ThreadResourcesLocalState() + + def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval): if aval is core.abstract_unit: return aval diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3ac811f271d2..7e4baae08c85 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -113,6 +113,25 @@ def f(x, y): self.assertAllClose(actual.device_buffers[0].to_py(), expected, check_dtypes=False) + def testBasic1DWithMeshContextManager(self): + @partial(pjit, + in_axis_resources=(P('x'), P('x')), + out_axis_resources=None) + def f(x, y): + return x + y + + shape = (8, 8) + x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + with jtu.create_global_mesh((2,), ('x')) as mesh: + actual = f(x, x + 1) + expected = x + (x + 1) + self.assertEqual(mesh, jtu.create_global_mesh((2,), ('x'))) + self.assertAllClose(actual, expected, check_dtypes=False) + self.assertIsInstance(actual, pxla.ShardedDeviceArray) + self.assertLen(actual.device_buffers, 2) + self.assertAllClose(actual.device_buffers[0].to_py(), expected, + check_dtypes=False) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testBasic2D(self): @partial(pjit, @@ -141,6 +160,35 @@ def f(x, y): self.assertAllClose(actual.device_buffers[3].to_py(), split1, check_dtypes=False) + def testBasic2DWithMeshContextManager(self): + @partial(pjit, + in_axis_resources=(P(None, 'x', 'y'), P('y')), + out_axis_resources=P('x')) + def f(x, y): + return x @ y + + x_shape = (8, 6, 4) + y_shape = (4, 2) + x = jnp.arange(np.prod(x_shape)).reshape(x_shape) + y = jnp.arange(np.prod(y_shape)).reshape(y_shape) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + with mesh: + actual = f(x, y) + expected = x @ y + self.assertAllClose(actual, expected, check_dtypes=False) + self.assertIsInstance(actual, pxla.ShardedDeviceArray) + self.assertLen(actual.device_buffers, 4) + + split0, split1 = np.split(expected, 2) + self.assertAllClose(actual.device_buffers[0].to_py(), split0, + check_dtypes=False) + self.assertAllClose(actual.device_buffers[1].to_py(), split0, + check_dtypes=False) + self.assertAllClose(actual.device_buffers[2].to_py(), split1, + check_dtypes=False) + self.assertAllClose(actual.device_buffers[3].to_py(), split1, + check_dtypes=False) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testTwoMeshAxisSharding(self): @partial(pjit, @@ -671,6 +719,41 @@ def f(x): 'in_axis_resources cannot be `pjit.FROM_GDA`.')): f(input_data) + def test_pjit_gda_single_output_with_mesh_context_manager(self): + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_input_shape = (8, 2) + mesh_axes = P('x', 'y') + input_data = np.arange( + prod(global_input_shape)).reshape(global_input_shape) + def cb(index): + return input_data[index] + + gda_obj = global_device_array.GlobalDeviceArray.from_callback( + global_input_shape, global_mesh, mesh_axes, cb) + + with jax._src.config.parallel_functions_output_gda(True): + with global_mesh: + @partial(pjit, in_axis_resources=FROM_GDA, out_axis_resources=P('x', 'y')) + def f(x): + return x @ x.T + expected_matrix_mul = input_data @ input_data.T + + out = f(gda_obj) + self.assertIsInstance(out, global_device_array.GlobalDeviceArray) + self.assertEqual(out.shape, (8, 8)) + self.assertEqual(out.local_shards[0].data.shape, (2, 4)) + self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2}) + for s in out.local_shards: + self.assertArraysEqual(s.data, expected_matrix_mul[s.index]) + + out2 = f(out) + self.assertIsInstance(out2, global_device_array.GlobalDeviceArray) + + with self.assertRaisesRegex( + ValueError, ('For a non-GDA input, the corresponding resource in ' + 'in_axis_resources cannot be `pjit.FROM_GDA`.')): + f(input_data) + @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_multi_input_multi_output(self): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))