Initial GSDA-pjit integration.
* GSDA input is handled properly
* Mismatch between the pjit mesh and the GSDA mesh is checked
* GSDA dtype and input dtype are checked
* If GSDA is an input, then other input types are not allowed for now.

PiperOrigin-RevId: 409299559
yashk2810 authored and jax authors committed Nov 12, 2021
1 parent 41ecf71 commit 8381a4a
Showing 4 changed files with 121 additions and 30 deletions.
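
The sketch below is not part of the diff; it is a minimal end-to-end example of the workflow this commit enables, adapted from the new test in tests/pjit_test.py. It assumes a single process with at least 8 local devices arranged as a 4x2 mesh; names like input_data are illustrative.

from functools import partial

import numpy as np
import jax
from jax.experimental import PartitionSpec as P
from jax.experimental import gsda
from jax.experimental.maps import Mesh, mesh
from jax.experimental.pjit import pjit

# Build a 4x2 mesh over the first 8 local devices.
devices = np.array(jax.devices()[:8]).reshape((4, 2))
global_mesh = Mesh(devices, ('x', 'y'))

global_shape = (8, 2)
input_data = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

# The callback returns just the slice of the global array that each local
# device needs; the GSDA's dtype is taken from these buffers.
gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
    global_shape, global_mesh, P('x', 'y'), lambda index: input_data[index])

@partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
def f(x):
  return x @ x.T

# The call must run under the same mesh the GSDA was built with.
with mesh(global_mesh.devices, global_mesh.axis_names):
  out = f(gsda_obj)  # currently still a ShardedDeviceArray, not a GSDA
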
45 changes: 35 additions & 10 deletions jax/experimental/gsda.py
@@ -17,14 +17,15 @@
import dataclasses
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict

from . import maps
from .. import core
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from ..interpreters import pxla
from ..interpreters import pxla, xla
from .._src.util import prod, safe_zip
from .._src.api import device_put
from ..interpreters.sharded_jit import PartitionSpec
from .pjit import get_array_mapping, _prepare_axis_resources

Shape = Tuple[int, ...]
MeshAxes = Sequence[Union[str, Tuple[str], None]]
@@ -47,6 +48,9 @@ def __eq__(self, other):

def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
# Import here to avoid cyclic import error when importing gsda in pjit.py.
from .pjit import get_array_mapping, _prepare_axis_resources

if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
@@ -94,10 +98,9 @@ class Shard:

class GlobalShardedDeviceArray:

def __init__(self, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
def __init__(self, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray]):
self._global_shape = global_shape
self._dtype = dtype
self._global_mesh = global_mesh
self._mesh_axes = mesh_axes
assert len(device_buffers) == len(self._global_mesh.local_devices)
@@ -109,6 +112,12 @@ def __init__(self, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape {device_buffers[0].shape}")

dtype = device_buffers[0].dtype
assert all(db.dtype == dtype for db in device_buffers), (
"Input arrays to GlobalShardedDeviceArray must have matching dtypes, "
f"got: {[db.dtype for db in device_buffers]}")
self.dtype = dtype

@property
def shape(self) -> Shape:
return self._global_shape
@@ -130,6 +139,8 @@ def _create_shards(
sh = Shard(device, index, replica_id, buf)
gs.append(sh)
if local_shard:
if sh.data is None:
raise ValueError("Local shard's data field should not be None.")
ls.append(sh)
return gs, ls

@@ -142,30 +153,30 @@ def global_shards(self) -> Sequence[Shard]:
return self._global_shards

@classmethod
def from_callback(cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, data_callback: Callable[[Index],
ArrayLike]):
indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
dbs = [
device_put(data_callback(indices[device]), device)
for device in global_mesh.local_devices
]
return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
return cls(global_shape, global_mesh, mesh_axes, dbs)

@classmethod
def from_batched_callback(cls, global_shape: Shape, dtype,
def from_batched_callback(cls, global_shape: Shape,
global_mesh: pxla.Mesh, mesh_axes: MeshAxes,
data_callback: Callable[[Sequence[Index]],
Sequence[ArrayLike]]):
indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
local_indices = [indices[d] for d in global_mesh.local_devices]
local_arrays = data_callback(local_indices)
dbs = pxla.device_put(local_arrays, global_mesh.local_devices)
return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
return cls(global_shape, global_mesh, mesh_axes, dbs)

@classmethod
def from_batched_callback_with_devices(
cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes,
data_callback: Callable[[Sequence[Tuple[Index, Tuple[Device, ...]]]],
Sequence[DeviceArray]]):
@@ -180,4 +191,18 @@ def from_batched_callback_with_devices(
(index.val, tuple(device)) for index, device in index_to_device.items()
]
dbs = data_callback(cb_inp)
return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
return cls(global_shape, global_mesh, mesh_axes, dbs)


core.pytype_aval_mappings[GlobalShardedDeviceArray] = lambda x: core.ShapedArray(x.shape, x.dtype)
xla.pytype_aval_mappings[GlobalShardedDeviceArray] = lambda x: core.ShapedArray(x.shape, x.dtype)
xla.canonicalize_dtype_handlers[GlobalShardedDeviceArray] = pxla.identity

def _gsda_shard_arg(x, devices, indices):
pjit_mesh = maps.thread_resources.env.physical_mesh
if x._global_mesh != pjit_mesh:
raise ValueError("Pjit's mesh and GSDA's mesh should be equal. Got Pjit "
f"mesh: {pjit_mesh},\n GSDA mesh: {x._global_mesh}")
return [s.data for s in x.local_shards]

pxla.shard_arg_handlers[GlobalShardedDeviceArray] = _gsda_shard_arg
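
With the explicit dtype argument gone, a GlobalShardedDeviceArray's dtype is whatever its device buffers carry, and the new assertion in __init__ rejects buffers with mismatched dtypes. A minimal sketch of the new construction signature (not part of the diff; it assumes at least two local devices):

import numpy as np
import jax
from jax.experimental import gsda
from jax.experimental.maps import Mesh

# A 1D mesh over the first two local devices.
global_mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
global_shape = (4, 2)
data = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

# No dtype argument anymore: it is inferred from the shard buffers.
gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
    global_shape, global_mesh, ['x'], lambda index: data[index])
assert gsda_obj.dtype == np.float32
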
24 changes: 15 additions & 9 deletions jax/experimental/pjit.py
@@ -21,6 +21,7 @@
from functools import partial

from . import maps
from .gsda import GlobalShardedDeviceArray as GSDA
from .. import core
from .. import linear_util as lu
from .._src.api import _check_callable, _check_arg, Lowered
@@ -169,10 +170,10 @@ def pjit(fun: Callable,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axis_resources = tuple(in_axis_resources)

in_axis_resources, _, _ = \
_prepare_axis_resources(in_axis_resources, "in_axis_resources")
out_axis_resources, _, _ = \
_prepare_axis_resources(out_axis_resources, "out_axis_resources")
in_axis_resources, _, _ = _prepare_axis_resources(
in_axis_resources, "in_axis_resources")
out_axis_resources, _, _ = _prepare_axis_resources(
out_axis_resources, "out_axis_resources")

static_argnums = _ensure_index_tuple(static_argnums)
donate_argnums = _ensure_index_tuple(donate_argnums)
@@ -210,11 +211,16 @@ def infer_params(*args, **kwargs):
donated_invars = (False,) * len(args_flat)

local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
jaxpr, in_axis_resources_flat, out_axis_resources_flat = \
_pjit_jaxpr(flat_fun, mesh, local_in_avals,
in_tree, hashable_pytree(in_axis_resources),
HashableFunction(out_tree, closure=()), hashable_pytree(out_axis_resources),
maps._positional_semantics)
# TODO(yashkatariya): Remove `is_gsda` check when special value for in_axis_resources
# is added for GSDA.
is_gsda = all(isinstance(a, GSDA) for a in args_flat)
jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, local_in_avals, in_tree,
hashable_pytree(in_axis_resources),
HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources),
(maps._PositionalSemantics.GLOBAL
if is_gsda else maps._positional_semantics))
params = dict(
jaxpr=jaxpr,
in_axis_resources=in_axis_resources_flat,
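
A worked note on the semantics switch above (illustrative, not from the diff): a GSDA's aval already carries the global shape, so when every flattened argument is a GSDA the positional semantics flip to GLOBAL and in_axis_resources partitions that global shape directly.

# A (8, 2) global array on a ('x', 'y') = (4, 2) mesh with P('x', 'y')
# leaves each device holding a (2, 1) piece of the global array.
global_shape = (8, 2)
mesh_sizes = {'x': 4, 'y': 2}
spec = ('x', 'y')  # P('x', 'y'): dim 0 sharded over 'x', dim 1 over 'y'

shard_shape = tuple(dim // mesh_sizes[axis] for dim, axis in zip(global_shape, spec))
assert shard_shape == (2, 1)

Under the default per-process semantics, an (8, 2) argument would instead be read as each process's local piece of a larger global array, which is why all-GSDA inputs need the GLOBAL setting.
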
20 changes: 10 additions & 10 deletions tests/gsda_test.py
@@ -19,7 +19,6 @@
import numpy as np

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip

@@ -79,7 +78,7 @@ def test_gsda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
def cb(index):
return global_input_data[index]
gsda = GlobalShardedDeviceArray.from_callback(global_input_shape,
jnp.float32, global_mesh,
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
@@ -123,7 +122,7 @@ def test_gsda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
def cb(index):
return global_input_data[index]
gsda = GlobalShardedDeviceArray.from_callback(global_input_shape,
jnp.float32, global_mesh,
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
@@ -156,7 +155,7 @@ def test_gsda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
def cb(index):
return global_input_data[index]
gsda = GlobalShardedDeviceArray.from_callback(global_input_shape,
jnp.float32, global_mesh,
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
@@ -185,7 +184,7 @@ def test_gsda_subset_devices(self, mesh_axes, expected_index,
def cb(index):
return global_input_data[index]
gsda = GlobalShardedDeviceArray.from_callback(global_input_shape,
jnp.float32, global_mesh,
global_mesh,
mesh_axes, cb)
self.assertEqual(gsda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gsda.local_shards[0].data,
@@ -214,7 +213,7 @@ def cb(indices):
return [global_input_data[index] for index in indices]

gsda = GlobalShardedDeviceArray.from_batched_callback(
global_input_shape, jnp.float32, global_mesh, mesh_axes, cb)
global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1]])
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
expected_first_shard_value)
@@ -227,7 +226,7 @@ def test_gsda_batched_callback_with_devices(self):
global_input_shape = (8, 2)
mesh_axes = ['x']
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

def cb(cb_inp):
self.assertLen(cb_inp, 4)
@@ -240,13 +239,14 @@ def cb(cb_inp):
return dbs

gsda = GlobalShardedDeviceArray.from_batched_callback_with_devices(
global_input_shape, jnp.float32, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1], [2, 3]])
global_input_shape, global_mesh, mesh_axes, cb)
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
expected_first_shard_value)
expected_second_shard_value = np.array([[0, 1], [2, 3]])
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
expected_second_shard_value)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
62 changes: 61 additions & 1 deletion tests/pjit_test.py
@@ -30,7 +30,8 @@
from jax import lax
# TODO(skye): do we still wanna call this PartitionSpec?
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap, mesh
from jax.experimental.maps import xmap, mesh, Mesh
from jax.experimental import gsda
import jax.experimental.pjit as pjit_lib
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
from jax.interpreters import pxla
@@ -62,6 +63,15 @@ def check_1d_2d_mesh(f, set_mesh):
))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)


def create_global_mesh(mesh_shape, axis_names):
size = prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
global_mesh = Mesh(mesh_devices, axis_names)
return global_mesh


# TODO(skye): make the buffer donation utils part of JaxTestCase
class PJitTest(jtu.BufferDonationTestCase):

@@ -571,6 +581,56 @@ def f(x, y):
"called with:\n.*int32.*",
lambda: exe(x_i32, x_i32))


class GSDAPjitTest(jtu.JaxTestCase):

@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_single_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T

out = f(gsda_obj)
# TODO(yashkatariya): Enable the gsda_out flag and check for GSDA as the
# output.
self.assertIsInstance(out, pxla.ShardedDeviceArray)
self.assertLen(out.device_buffers, 8)
self.assertEqual(out.device_buffers[0].shape, (2, 4))

@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gsda_mesh_mismatch(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

with self.assertRaisesRegex(
ValueError,
"Pjit's mesh and GSDA's mesh should be equal."):
@partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
def f(x):
return x
f(gsda_obj)



def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
