Skip to content

Commit

Permalink
Merge mesh and Mesh. Make Mesh a context manager + class so tha…
Browse files Browse the repository at this point in the history
…t it can be used in the following ways:

```
global_mesh = Mesh(devices, axis_names)
with global_mesh:
  ...

OR

with Mesh(devices, axis_names) as global_mesh:
  ...

OR

global_mesh = Mesh(devices, axis_names)
with global_mesh as m:
  ...
```
PiperOrigin-RevId: 429201126
  • Loading branch information
yashk2810 authored and jax authors committed Feb 17, 2022
1 parent e25259e commit a83695a
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 69 deletions.
74 changes: 9 additions & 65 deletions jax/experimental/maps.py
Expand Up @@ -105,64 +105,9 @@ def __repr__(self):
AxisName = core.AxisName
ResourceAxisName = AxisName # Different name just for documentation purposes
Mesh = pxla.Mesh

class _Loop(NamedTuple):
name: ResourceAxisName
length: int

class ResourceEnv(NamedTuple):
physical_mesh: Mesh
loops: Tuple[_Loop, ...]

def with_mesh(self, mesh: Mesh):
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
if overlap:
raise ValueError(f"Cannot update the mesh of the current resource "
f"environment. The new mesh shadows already defined axes "
f"{show_axes(overlap)}")
return self._replace(physical_mesh=mesh)

def with_extra_loop(self, loop: _Loop):
if loop.name in self.resource_axes:
raise ValueError(f"Cannot extend the resource environment with loop named "
f"`{loop.name}`. An axis of this name is already defined!")
return self._replace(loops=self.loops + (loop,))

@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
return set(self.physical_mesh.axis_names)

@property
def loop_resource_axes(self) -> Set[ResourceAxisName]:
return set(loop.name for loop in self.loops)

@property
def resource_axes(self) -> Set[ResourceAxisName]:
return self.physical_resource_axes | self.loop_resource_axes

@property
def shape(self):
shape = self.physical_mesh.shape
shape.update(self.loops)
return shape

@property
def local_shape(self):
shape = self.physical_mesh.local_mesh.shape
shape.update(self.loops)
return shape

def __repr__(self):
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"

EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())

class _ThreadResourcesLocalState(threading.local):

def __init__(self):
self.env = EMPTY_ENV

thread_resources = _ThreadResourcesLocalState()
ResourceEnv = pxla.ResourceEnv
EMPTY_ENV = pxla.EMPTY_ENV
thread_resources = pxla.thread_resources


class SerialLoop:
Expand Down Expand Up @@ -232,7 +177,7 @@ def serial_loop(name: ResourceAxisName, length: int):
axis_resources={'i': 'l'})(x)
"""
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_extra_loop(_Loop(name, length))
thread_resources.env = old_env.with_extra_loop(pxla._Loop(name, length))
try:
yield
finally:
Expand Down Expand Up @@ -268,6 +213,7 @@ def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
out_axes=['left', 'right', ...],
axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
"""
# TODO(yashkatariya): Deprecate this context manager.
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_mesh(Mesh(np.asarray(devices, dtype=object), axis_names))
try:
Expand Down Expand Up @@ -998,8 +944,6 @@ def _typecheck_xmap(
return out_avals
core.custom_typechecks[xmap_p] = _typecheck_xmap

def show_axes(axes):
return ", ".join(sorted([f"`{a}`" for a in axes]))

def _resource_typing_xmap(avals,
params,
Expand All @@ -1014,7 +958,7 @@ def _resource_typing_xmap(avals,
raise JAXTypeError(
f"Detected disallowed xmap axis name shadowing at "
f"{source_info_util.summarize(source_info)} "
f"(shadowed axes: {show_axes(overlap)})")
f"(shadowed axes: {pxla.show_axes(overlap)})")

if resource_env.physical_mesh != params['resource_env'].physical_mesh:
raise RuntimeError("Changing the physical mesh is not allowed inside xmap.")
Expand Down Expand Up @@ -1042,9 +986,9 @@ def _resource_typing_xmap(avals,
raise JAXTypeError(
f"One of xmapped function ({params['name']}) outputs is broadcast "
f"along axis `{baxis}` which is assigned to resources "
f"{show_axes(baxis_resources)}, but the output is already "
f"partitioned along {show_axes(overlap)}, because its "
f"named shape contains {show_axes(partitioning_axes)}")
f"{pxla.show_axes(baxis_resources)}, but the output is already "
f"partitioned along {pxla.show_axes(overlap)}, because its "
f"named shape contains {pxla.show_axes(partitioning_axes)}")
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap


Expand Down
8 changes: 4 additions & 4 deletions jax/experimental/pjit.py
Expand Up @@ -209,7 +209,7 @@ def infer_params(*args, **kwargs):
f"was called with only {len(args)} positional arguments.")

# Putting this outside of wrapped would make resources lexically scoped
resource_env = maps.thread_resources.env
resource_env = pxla.thread_resources.env
mesh = resource_env.physical_mesh
if mesh.empty:
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
Expand Down Expand Up @@ -551,7 +551,7 @@ def _check_unique_resources(axis_resources, arg_name):
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {maps.show_axes(multiple_uses)}")
f"has duplicate entries for {pxla.show_axes(multiple_uses)}")

def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape,
flat_avals, flat_axis_resources):
Expand Down Expand Up @@ -897,7 +897,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} "
f"that uses one or more mesh axes already used by xmap to partition "
f"a named axis appearing in its named_shape (both use mesh axes "
f"{maps.show_axes(overlap)})")
f"{pxla.show_axes(overlap)})")

def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
jaxpr = params["jaxpr"]
Expand Down Expand Up @@ -925,7 +925,7 @@ def with_sharding_constraint(x, axis_resources):
axis_resources_flat = tuple(
flatten_axes("with_sharding_constraint axis_resources",
tree, parsed_axis_resources))
resource_env = maps.thread_resources.env
resource_env = pxla.thread_resources.env
mesh = resource_env.physical_mesh
_check_shapes_against_resources(
"with_sharding_constraint arguments",
Expand Down
81 changes: 81 additions & 0 deletions jax/interpreters/pxla.py
Expand Up @@ -28,6 +28,8 @@
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.

from __future__ import annotations

from contextlib import contextmanager
from collections import defaultdict, OrderedDict
import dataclasses
Expand Down Expand Up @@ -1787,6 +1789,9 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
# ------------------- xmap -------------------

class Mesh:
devices: np.ndarray
axis_names: Tuple[MeshAxisName, ...]
_old_env: ResourceEnv

def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
assert devices.ndim == len(axis_names)
Expand Down Expand Up @@ -1814,6 +1819,16 @@ def __setattr__(self, name, value):
raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
super().__setattr__(name, value)

def __enter__(self):
self._old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = self._old_env.with_mesh(
Mesh(self.devices, self.axis_names))
return thread_resources.env.physical_mesh

def __exit__(self, exc_type, exc_value, traceback):
thread_resources.env = self._old_env
return False

@property
def shape(self):
return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))
Expand Down Expand Up @@ -1885,6 +1900,72 @@ def global_to_local(self, axes: ArrayMapping, aval):
tile_aval_nd(self.shape, axes, aval))


ResourceAxisName = core.AxisName

class _Loop(NamedTuple):
name: ResourceAxisName
length: int


def show_axes(axes):
return ", ".join(sorted([f"`{a}`" for a in axes]))


class ResourceEnv(NamedTuple):
physical_mesh: Mesh
loops: Tuple[_Loop, ...]

def with_mesh(self, mesh: Mesh):
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
if overlap:
raise ValueError(f"Cannot update the mesh of the current resource "
f"environment. The new mesh shadows already defined axes "
f"{show_axes(overlap)}")
return self._replace(physical_mesh=mesh)

def with_extra_loop(self, loop: _Loop):
if loop.name in self.resource_axes:
raise ValueError(f"Cannot extend the resource environment with loop named "
f"`{loop.name}`. An axis of this name is already defined!")
return self._replace(loops=self.loops + (loop,))

@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
return set(self.physical_mesh.axis_names)

@property
def loop_resource_axes(self) -> Set[ResourceAxisName]:
return set(loop.name for loop in self.loops)

@property
def resource_axes(self) -> Set[ResourceAxisName]:
return self.physical_resource_axes | self.loop_resource_axes

@property
def shape(self):
shape = self.physical_mesh.shape
shape.update(self.loops)
return shape

@property
def local_shape(self):
shape = self.physical_mesh.local_mesh.shape
shape.update(self.loops)
return shape

def __repr__(self):
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"

EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())

class _ThreadResourcesLocalState(threading.local):

def __init__(self):
self.env = EMPTY_ENV

thread_resources = _ThreadResourcesLocalState()


def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
if aval is core.abstract_unit:
return aval
Expand Down
83 changes: 83 additions & 0 deletions tests/pjit_test.py
Expand Up @@ -113,6 +113,25 @@ def f(x, y):
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
check_dtypes=False)

def testBasic1DWithMeshContextManager(self):
@partial(pjit,
in_axis_resources=(P('x'), P('x')),
out_axis_resources=None)
def f(x, y):
return x + y

shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
with jtu.create_global_mesh((2,), ('x')) as mesh:
actual = f(x, x + 1)
expected = x + (x + 1)
self.assertEqual(mesh, jtu.create_global_mesh((2,), ('x')))
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
check_dtypes=False)

@jtu.with_mesh([('x', 2), ('y', 2)])
def testBasic2D(self):
@partial(pjit,
Expand Down Expand Up @@ -141,6 +160,35 @@ def f(x, y):
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
check_dtypes=False)

def testBasic2DWithMeshContextManager(self):
@partial(pjit,
in_axis_resources=(P(None, 'x', 'y'), P('y')),
out_axis_resources=P('x'))
def f(x, y):
return x @ y

x_shape = (8, 6, 4)
y_shape = (4, 2)
x = jnp.arange(np.prod(x_shape)).reshape(x_shape)
y = jnp.arange(np.prod(y_shape)).reshape(y_shape)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
with mesh:
actual = f(x, y)
expected = x @ y
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, pxla.ShardedDeviceArray)
self.assertLen(actual.device_buffers, 4)

split0, split1 = np.split(expected, 2)
self.assertAllClose(actual.device_buffers[0].to_py(), split0,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[1].to_py(), split0,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[2].to_py(), split1,
check_dtypes=False)
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
check_dtypes=False)

@jtu.with_mesh([('x', 2), ('y', 2)])
def testTwoMeshAxisSharding(self):
@partial(pjit,
Expand Down Expand Up @@ -671,6 +719,41 @@ def f(x):
'in_axis_resources cannot be `pjit.FROM_GDA`.')):
f(input_data)

def test_pjit_gda_single_output_with_mesh_context_manager(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

gda_obj = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

with jax._src.config.parallel_functions_output_gda(True):
with global_mesh:
@partial(pjit, in_axis_resources=FROM_GDA, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
expected_matrix_mul = input_data @ input_data.T

out = f(gda_obj)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

out2 = f(out)
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)

with self.assertRaisesRegex(
ValueError, ('For a non-GDA input, the corresponding resource in '
'in_axis_resources cannot be `pjit.FROM_GDA`.')):
f(input_data)

@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_multi_input_multi_output(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
Expand Down

0 comments on commit a83695a

Please sign in to comment.