Skip to content

Commit

Permalink
Add Array support to xmap. Just using the GDA path.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 454604138
  • Loading branch information
yashk2810 authored and jax authors committed Jun 13, 2022
1 parent fe0c921 commit 1089c79
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 19 deletions.
48 changes: 29 additions & 19 deletions jax/experimental/maps.py
Expand Up @@ -39,6 +39,7 @@
from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
from jax.experimental.array import Array
from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
Expand Down Expand Up @@ -543,10 +544,16 @@ def infer_params(*args):
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
closure=(out_axes_entries, out_axes_treedef))

in_positional_semantics = tuple(
_PositionalSemantics.GLOBAL
if isinstance(a, GlobalDeviceArray) else _positional_semantics.val
for a in args_flat)
if config.jax_array:
if any(not isinstance(a, Array) for a in args_flat):
raise ValueError('All arguments to pjit when `config.jax_array` is '
'enabled should be `Array`s.')
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
else:
in_positional_semantics = tuple(
_PositionalSemantics.GLOBAL
if isinstance(a, GlobalDeviceArray) else _positional_semantics.val
for a in args_flat)
out_positional_semantics = _positional_semantics.val

axis_resource_count = _get_axis_resource_count(
Expand Down Expand Up @@ -577,7 +584,7 @@ def infer_params(*args):
f"which asserts that it should be of rank {spec.expected_rank}, "
f"but the argument has rank {arg.ndim} (and shape {arg.shape})")

_check_gda_xmap_partitioning(frozen_axis_resources, resource_env,
_check_gda_or_array_xmap_partitioning(frozen_axis_resources, resource_env,
frozen_global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat)

Expand Down Expand Up @@ -1785,27 +1792,30 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
f"defined by this xmap call: {', '.join(undeclared_axes_str)}")


def _check_gda_xmap_partitioning(axis_resources, resource_env,
global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat):
def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat):
mesh_in_axes = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes,
in_positional_semantics).to_mesh_axes(in_axes_flat)
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
if isinstance(arg, GlobalDeviceArray):
if arg.mesh != resource_env.physical_mesh:
raise ValueError("xmap's mesh and GDA's mesh should be equal. Got Xmap "
f"mesh: {resource_env.physical_mesh},\n"
f"GDA mesh: {arg.mesh}")

gda_array_mapping = _get_array_mapping(arg.mesh_axes)
if gda_array_mapping != xmap_array_mapping:
if isinstance(arg, (GlobalDeviceArray, Array)):
arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
mesh = arg.mesh if arr_flavor == 'GDA' else arg.sharding.mesh
if mesh != resource_env.physical_mesh:
raise ValueError(f"xmap's mesh and {arr_flavor}'s mesh should be equal. "
f"Got xmap mesh: {resource_env.physical_mesh},\n"
f"{arr_flavor} mesh: {mesh}")

array_mapping = _get_array_mapping(
arg.mesh_axes if arr_flavor == 'GDA' else arg.sharding.spec)
if array_mapping != xmap_array_mapping:
raise ValueError(
"Got an input GDA to xmap with different partitioning than "
f"Got an input {arr_flavor} to xmap with different partitioning than "
"specified in xmap. The partitioning must match. "
f"Got GDA spec: {pxla.array_mapping_to_axis_resources(gda_array_mapping)} and "
f"Got {arr_flavor} spec: {pxla.array_mapping_to_axis_resources(array_mapping)} and "
f"xmap spec: {pxla.array_mapping_to_axis_resources(xmap_array_mapping)} "
f"for GDA: {arg}")
f"for {arr_flavor}: {arg}")


# TODO: We should relax this at least for "constructor primitives"
Expand Down
109 changes: 109 additions & 0 deletions tests/xmap_test.py
Expand Up @@ -35,6 +35,8 @@
from jax.core import NamedShape
from jax.experimental import maps
from jax.experimental import global_device_array
from jax.experimental import array
from jax.experimental.sharding import MeshPspecSharding
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.maps import xmap, serial_loop, SerialLoop
Expand Down Expand Up @@ -71,6 +73,17 @@ def tearDownModule():
xla_bridge.get_backend.cache_clear()


def create_array(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=np.float32).reshape(global_shape)

sharding = MeshPspecSharding(global_mesh, mesh_axes)

return array.make_array_from_callback(
global_shape, sharding, lambda idx: global_data[idx]), global_data


# -------------------- Itertools helpers --------------------

def partitions(s, k):
Expand Down Expand Up @@ -1010,6 +1023,102 @@ def cb(index):
f(gda_obj)


class XMapArrayTest(XMapTestCase):

def test_basic(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_array, input_data = create_array(global_input_shape, global_mesh,
mesh_axes)

with jax._src.config.jax_array(True):
with global_mesh:
f = maps.xmap(
lambda x: x,
in_axes=({0: "a", 1: "b"}),
out_axes=({0: "a", 1: "b"}),
axis_resources={"a": "x", "b": "y"})

out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
self.assertDictEqual(out.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

def test_xmap_array_mixed_inputs(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x')
input_array, input_data = create_array(global_input_shape, global_mesh,
mesh_axes)

with jax._src.config.jax_array(True):
with global_mesh:
f = maps.xmap(
lambda x, y: (x @ x.T, y @ y.T),
in_axes=({0: "a"}, ["c", ...]),
out_axes=({0: "a"}, ["c", ...]),
axis_resources={"a": "x", "c": "x"})
with self.assertRaisesRegex(
ValueError, ('All arguments to pjit when `config.jax_array` is '
'enabled should be `Array`s.')):
f(input_array, input_data)

def test_xmap_array_double_input(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
a1, input_data = create_array(global_input_shape, global_mesh, P('x'))
a2, _ = create_array(global_input_shape, global_mesh, P('y'))

with jax._src.config.jax_array(True):
with global_mesh:
f = maps.xmap(
lambda x, y: (x @ x.T, y @ y.T),
in_axes=({0: "a"}, ["c", ...]),
out_axes=({0: "a"}, ["c", ...]),
axis_resources={"a": "x", "c": "y"})

expected_matrix_mul = np.diagonal(input_data @ input_data.T)
out1, out2 = f(a1, a2)

self.assertIsInstance(out1, array.Array)
self.assertEqual(out1.shape, (8,))
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
self.assertDictEqual(out1.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])

self.assertIsInstance(out2, array.Array)
self.assertEqual(out2.shape, (8,))
self.assertEqual(out2.addressable_shards[0].data.shape, (4,))
self.assertDictEqual(out2.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])

def test_xmap_array_sharding_mismatch(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_array, _ = create_array(global_input_shape, global_mesh,
mesh_axes)

with jax._src.config.jax_array(True):
with global_mesh:
f = maps.xmap(
lambda x: x @ x.T,
in_axes=({0: "a"}),
out_axes=({0: "a"}),
axis_resources={"a": "x"})
with self.assertRaisesRegex(
ValueError,
('Got an input Array to xmap with different partitioning than '
'specified in xmap. The partitioning must match.')):
f(input_array)


class NewPrimitiveTest(XMapTestCase):

def testGatherPositional(self):
Expand Down

0 comments on commit 1089c79

Please sign in to comment.