Commit
Some miscellaneous changes to make tests pass when jax.Array is enabled by default.

1. Add `device_buffer` and `device_buffers` properties to Array as a backwards-compatible change for DA and SDA.
2. Support PartitionSpecs as inputs to in_axis_resources and out_axis_resources when jax_array is enabled, as a backwards-compatible change, since this is what all current user code passes. A MeshPspecSharding is created internally.
3. Some test changes to make them pass.

PiperOrigin-RevId: 474642889
yashk2810 authored and jax authors committed Sep 15, 2022
1 parent 311f85e commit 28741b8
Showing 9 changed files with 102 additions and 79 deletions.
3 changes: 2 additions & 1 deletion jax/_src/lax/lax.py
@@ -1331,7 +1331,8 @@ def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None
# (so it works in staged-out code as well as 'eager' code). Related to
# equi-sharding.
if (config.jax_array and hasattr(x, 'sharding') and
not dispatch.is_single_device_sharding(x.sharding)):
not dispatch.is_single_device_sharding(x.sharding) and
not isinstance(x.sharding, sharding.PmapSharding)):
return array.make_array_from_callback(
fill_shape, x.sharding, lambda idx: val[idx]) # type: ignore[arg-type]
return val
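To illustrate the full_like change above, here is a sketch that is not part of the commit; it assumes 4 CPU devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4), config.jax_array enabled, and this era's jax.experimental module layout, which has since moved:

import numpy as np
import jax
from jax import lax
from jax.experimental import PartitionSpec as P
from jax.experimental import array, maps, sharding

# Hypothetical 2x2 mesh over the first 4 devices.
mesh = maps.Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = sharding.MeshPspecSharding(mesh, P('x', 'y'))

inp = np.arange(16.).reshape(4, 4)
x = array.make_array_from_callback(inp.shape, s, lambda idx: inp[idx])

# x.sharding is neither a single-device sharding nor a PmapSharding, so
# full_like builds the output shard-by-shard via make_array_from_callback
# and the result stays sharded like x.
y = lax.full_like(x, 7.0)
assert y.sharding.device_set == x.sharding.device_set
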
6 changes: 3 additions & 3 deletions jax/_src/test_util.py
@@ -899,11 +899,11 @@ class BufferDonationTestCase(JaxTestCase):
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)

def _assertDeleted(self, x, deleted):
if hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
elif hasattr(x, "_arrays"):
if hasattr(x, "_arrays"):
for buffer in x._arrays:
self.assertEqual(buffer.is_deleted(), deleted)
elif hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
else:
for buffer in x.device_buffers:
self.assertEqual(buffer.is_deleted(), deleted)
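A note on the reordering above (my reading, not stated in the diff): the new Array type stores its per-shard buffers in `_arrays` and, with this commit, also exposes a `device_buffer` property that raises once there is more than one shard, so the helper has to look for `_arrays` first. A hypothetical standalone sketch of the dispatch order:

def _buffers_of(x):
    # jax.Array: per-shard buffers live in `_arrays`; check this first, because
    # Array now also has a `device_buffer` property, which raises when the
    # array spans more than one device.
    if hasattr(x, "_arrays"):
        return list(x._arrays)
    # DeviceArray: a single underlying buffer.
    if hasattr(x, "device_buffer"):
        return [x.device_buffer]
    # ShardedDeviceArray: one buffer per device.
    return list(x.device_buffers)
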
19 changes: 19 additions & 0 deletions jax/experimental/array.py
@@ -371,6 +371,23 @@ def devices(self) -> List[Device]:
self._check_if_deleted()
return list(self.sharding.device_set)

# TODO(https://github.com/google/jax/issues/12380): Remove this when DA is
# deleted.
@property
def device_buffer(self) -> DeviceArray:
self._check_if_deleted()
if len(self._arrays) == 1:
return self._arrays[0]
raise ValueError('Length of buffers is greater than 1. Please use '
'`.device_buffers` instead.')

# TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is
# deleted.
@property
def device_buffers(self) -> Sequence[DeviceArray]:
self._check_if_deleted()
return self._arrays

@pxla.maybe_cached_property
def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
@@ -433,6 +450,7 @@ def _value(self) -> np.ndarray:
if self._npy_value is None:
if self.is_fully_replicated():
self._npy_value = np.asarray(self._arrays[0]) # type: ignore
self._npy_value.flags.writeable = False
return cast(np.ndarray, self._npy_value)

if not self.is_fully_addressable():
@@ -454,6 +472,7 @@ def _value(self) -> np.ndarray:
if not replica_id_exists or s.replica_id == 0:
npy_value[s.index] = np.asarray(s.data._arrays[0]) # type: ignore # [union-attr]
self._npy_value = npy_value # type: ignore
self._npy_value.flags.writeable = False
# https://docs.python.org/3/library/typing.html#typing.cast
return cast(np.ndarray, self._npy_value)

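A small usage sketch of the two compatibility properties and the now read-only cached numpy value (not from the commit; it assumes a single CPU device and config.jax_array enabled):

import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)            # a jax.Array when config.jax_array is enabled

buf = x.device_buffer          # DeviceArray-style accessor; raises if the
                               # array has more than one shard
bufs = x.device_buffers        # ShardedDeviceArray-style accessor; always a sequence
assert buf is bufs[0]

host = np.asarray(x)           # materializes and caches _npy_value
try:
    host[0] = 42.0             # the cached value is now marked read-only
except ValueError:
    pass                       # "assignment destination is read-only"
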
2 changes: 2 additions & 0 deletions jax/experimental/maps.py
@@ -1817,6 +1817,8 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
if isinstance(arg, (GlobalDeviceArray, Array)):
arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
if arr_flavor == 'Array' and not isinstance(arg.sharding, MeshPspecSharding):
continue
mesh = arg.mesh if arr_flavor == 'GDA' else arg.sharding.mesh
if mesh != resource_env.physical_mesh:
raise ValueError(f"xmap's mesh and {arr_flavor}'s mesh should be equal. "
94 changes: 50 additions & 44 deletions jax/experimental/pjit.py
@@ -261,31 +261,10 @@ def pjit(fun: Callable,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axis_resources = tuple(in_axis_resources)

in_any_auto: bool
if not config.jax_array:
in_axis_resources, _, _, in_any_auto = _prepare_axis_resources(
in_axis_resources, "in_axis_resources")
out_axis_resources, _, _, _ = _prepare_axis_resources(
out_axis_resources, "out_axis_resources")
else:
if not _is_unspecified(in_axis_resources):
# `pjit.AUTO` is allowed partially in `in_axis_resources` i.e. you can
# put sharding instances and `pjit.AUTO` together.
if not all(isinstance(s, Sharding) or _is_auto(s)
for s in tree_flatten(in_axis_resources)[0]):
raise ValueError('When `config.jax_array` flag is enabled, '
'in_axis_resources should contain instances of '
'`Sharding` or `pjit.AUTO`.')

# `out_axis_resources` should be instances of `Sharding` if it's not
# unspecified. For `AUTO` sharding, it can only be used with
# MeshPspecSharding.
if not _is_unspecified(out_axis_resources):
if not all(isinstance(s, Sharding) or _is_auto(s)
for s in tree_flatten(out_axis_resources)[0]):
raise ValueError('When `config.jax_array` flag is enabled, '
'out_axis_resources should contain instances of '
'`Sharding` or `pjit.AUTO`.')
in_axis_resources, _, _, in_any_auto = _prepare_axis_resources(
in_axis_resources, "in_axis_resources")
out_axis_resources, _, _, _ = _prepare_axis_resources(
out_axis_resources, "out_axis_resources")

static_argnums = _ensure_index_tuple(static_argnums)
donate_argnums = _ensure_index_tuple(donate_argnums)
@@ -323,7 +302,10 @@ def infer_params(*args, _global_avals=False, **kwargs):
donated_invars = (False,) * len(args_flat)

if config.jax_array:
in_shardings, out_shardings = in_axis_resources, out_axis_resources
in_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
out_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
else:
in_shardings = tree_map(
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
@@ -424,6 +406,19 @@ def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
return pxla._create_mesh_pspec_sharding(mesh, x.user_spec, x)


def _create_sharding_for_array(mesh, x):
if isinstance(x, XLACompatibleSharding) or _is_auto(x) or _is_unspecified(x):
return x
if mesh.empty:
raise RuntimeError("pjit requires a non-empty mesh! Is a mesh defined at "
"the call site? Alternatively, provide a "
"XLACompatibleSharding to pjit and then the "
"mesh context manager is not required.")
# A nice user error is raised in _prepare_axis_resources.
assert isinstance(x, ParsedPartitionSpec)
return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)


def flatten_axis_resources(what, tree, shardings, tupled_args):
try:
return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args))
@@ -585,9 +580,6 @@ def pjit_check_aval_sharding(
for aval, s in zip(flat_avals, shardings):
if _is_unspecified_or_from_gda_or_auto(s):
continue
if not isinstance(s, XLACompatibleSharding):
raise ValueError(f'One of {what_aval} got sharding {s} which is not a '
'subclass of XLACompatibleSharding.')
global_str = "" if s.is_fully_addressable else " global"
shape = aval.shape
try:
@@ -747,14 +739,25 @@ def _prepare_axis_resources(axis_resources,
# be 1 entry for that since _UNSPECIFIED is a private API.
_check_all_or_none_unspecified(entries, arg_name)
any_auto = pxla._check_if_any_auto(entries)
entries = [
(entry if _is_unspecified_or_from_gda_or_auto(entry)
else ParsedPartitionSpec.from_user_input(
new_entries = []
for entry in entries:
if _is_unspecified_or_from_gda_or_auto(entry):
if config.jax_array and _is_from_gda(entry):
raise ValueError('`FROM_GDA` cannot be set when config.jax_array is '
'enabled. Leave in_axis_resources empty or populate '
'it with shardings.')
new_entries.append(entry)
elif isinstance(entry, Sharding):
if not isinstance(entry, XLACompatibleSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not a '
'subclass of XLACompatibleSharding.')
new_entries.append(entry)
else:
new_entries.append(ParsedPartitionSpec.from_user_input(
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
for entry in entries
]
_check_unique_resources(entries, arg_name)
return tree_unflatten(treedef, entries), entries, treedef, any_auto

_check_unique_resources(new_entries, arg_name)
return tree_unflatten(treedef, new_entries), new_entries, treedef, any_auto


def _check_resources_mismatch(in_axis_resources_flat, is_gda):
@@ -765,7 +768,9 @@ def _check_resources_mismatch(in_axis_resources_flat, is_gda):
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if _is_unspecified_or_from_gda_or_auto(arg_axis_resources): continue
if (_is_unspecified_or_from_gda_or_auto(arg_axis_resources) or
isinstance(arg_axis_resources, XLACompatibleSharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = Counter(it.chain.from_iterable(constrained_dims))
if not resource_counts: continue
@@ -806,7 +811,8 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
if _is_unspecified(arg_s):
resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
else:
resolved_in_shardings.append(to_op_sharding_sharding(arg_s, arg.ndim))
resolved_in_shardings.append(to_op_sharding_sharding(
cast(XLACompatibleSharding, arg_s), arg.ndim))
else:
if not _is_unspecified(arg_s):
if committed and not pxla.are_op_shardings_equal(
@@ -1290,16 +1296,16 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r

def with_sharding_constraint(x, axis_resources):
x_flat, tree = tree_flatten(x)
if not config.jax_array:
axis_resources, _, _, _ = _prepare_axis_resources(
axis_resources, "axis_resources", allow_unconstrained_dims=True)
axis_resources, _, _, _ = _prepare_axis_resources(
axis_resources, "axis_resources", allow_unconstrained_dims=True)
axis_resources_flat = tuple(
flatten_axes("with_sharding_constraint sharding", tree, axis_resources))
resource_env = pxla.thread_resources.env
mesh = resource_env.physical_mesh

if config.jax_array:
sharding_flat = axis_resources_flat
sharding_flat = [_create_sharding_for_array(mesh, a)
for a in axis_resources_flat]
unconstrained_dims = [
get_unconstrained_dims(s) if isinstance(s, MeshPspecSharding) else {}
for s in sharding_flat
@@ -1410,11 +1416,11 @@ def get_array_mapping(
if axes is not None for axis in axes)


def to_op_sharding_sharding(s, ndim):
def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> OpShardingSharding:
if isinstance(s, OpShardingSharding):
return s
op_sharding_sharding = OpShardingSharding(
s._device_assignment, s._to_xla_op_sharding(ndim)) # type: ignore
s._device_assignment, s._to_xla_op_sharding(ndim))
op_sharding_sharding._original_sharding = s
return op_sharding_sharding

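Putting the pjit changes together, a hedged usage sketch (not from the commit; it assumes 4 CPU devices, config.jax_array enabled, and this era's jax.experimental import paths): PartitionSpec leaves in in_axis_resources/out_axis_resources are now converted to MeshPspecSharding via _create_sharding_for_array, so they still need a mesh context manager, while passing an XLACompatibleSharding directly does not.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental.pjit import pjit
from jax.experimental.sharding import MeshPspecSharding

mesh = maps.Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

# PartitionSpec leaves are turned into MeshPspecSharding internally, which
# requires a non-empty mesh, hence the mesh context manager:
with mesh:
    f = pjit(lambda a: a + 1,
             in_axis_resources=P('x', 'y'),
             out_axis_resources=P('x', 'y'))
    out = f(jnp.arange(16.).reshape(4, 4))

# An XLACompatibleSharding is passed through as-is, so no mesh context
# manager is required:
g = pjit(lambda a: a * 2, in_axis_resources=MeshPspecSharding(mesh, P('x', 'y')))
out2 = g(out)
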
1 change: 1 addition & 0 deletions tests/BUILD
@@ -416,6 +416,7 @@ jax_test(
jax_test(
name = "lax_test",
srcs = ["lax_test.py"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 40,
"gpu": 40,
6 changes: 4 additions & 2 deletions tests/lax_test.py
@@ -2938,9 +2938,11 @@ def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
x_buf = x.device_buffer
y_buf = y.device_buffer
if np.dtype(dtype_in) == np.dtype(dtype_out):
self.assertIs(x_buf, y_buf)
self.assertEqual(x_buf.unsafe_buffer_pointer(),
y_buf.unsafe_buffer_pointer())
else:
self.assertIsNot(x_buf, y_buf)
self.assertNotEqual(x_buf.unsafe_buffer_pointer(),
y_buf.unsafe_buffer_pointer())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fn={}_indexdtype={}"
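About the test change above: with jax.Array the two results are apparently no longer the same Python buffer object even when they alias the same device memory, so the test now compares unsafe_buffer_pointer() rather than object identity. A minimal sketch of the idea (single device, same caveats as the earlier sketches):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)
y = lax.convert_element_type(x, jnp.float32)   # same dtype: expected to avoid a copy

# x and y may be distinct wrapper objects, but they should point at the same
# device memory, which is what the test now asserts.
same_memory = (x.device_buffer.unsafe_buffer_pointer()
               == y.device_buffer.unsafe_buffer_pointer())
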
45 changes: 20 additions & 25 deletions tests/pjit_test.py
@@ -1816,27 +1816,6 @@ def test_in_axis_resources_same_as_array_sharding(self):
in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))(input_array)
self.assertIsInstance(out, array.Array)

def test_in_axis_resources_error(self):
mesh = jtu.create_global_mesh((2,), ('x'))
with jax_array(True):
with self.assertRaisesRegex(
ValueError,
('When `config.jax_array` flag is enabled, '
'in_axis_resources should contain instances of `Sharding` '
'or `pjit.AUTO`.')):
pjit(lambda x: x,
in_axis_resources=(MeshPspecSharding(mesh, P('x')),
pjit_lib._UNSPECIFIED))

def test_out_axis_resources_error(self):
with jax_array(True):
with self.assertRaisesRegex(
ValueError,
('When `config.jax_array` flag is enabled, '
'out_axis_resources should contain instances of `Sharding` '
'or `pjit.AUTO`.')):
pjit(lambda x: x, out_axis_resources=P('x'))

def test_no_input_output(self):
with jax_array(True):
def f():
@@ -2118,16 +2097,32 @@ def test_not_xlacompatible_sharding_error(self):

with self.assertRaisesRegex(
ValueError,
'One of pjit arguments got sharding.*which is not a subclass of '
'XLACompatibleSharding.'):
'One of in_axis_resources leaf specifications got sharding.*which is '
'not a subclass of XLACompatibleSharding.'):
pjit(lambda x: x, in_axis_resources=ts)(arr)

with self.assertRaisesRegex(
ValueError,
'One of pjit outputs got sharding.*which is not a subclass of '
'XLACompatibleSharding.'):
'One of out_axis_resources leaf specifications got sharding.*which is '
'not a subclass of XLACompatibleSharding.'):
pjit(lambda x: x, out_axis_resources=ts)(arr)

@jax_array(True)
def test_array_enabled_non_empty_mesh_with_pspec(self):
arr = jnp.array([1, 2, 3])
with self.assertRaisesRegex(
RuntimeError,
"pjit requires a non-empty mesh!.*Alternatively, provide a "
"XLACompatibleSharding to pjit and then the mesh context manager is "
"not required."):
pjit(lambda x: x, in_axis_resources=P('x'))(arr)

with self.assertRaisesRegex(
TypeError,
"in_axis_resources leaf specifications are expected to be PartitionSpec "
"instances or None, but got x"):
pjit(lambda x: x, in_axis_resources='x')


class TempSharding(Sharding):

5 changes: 1 addition & 4 deletions tests/random_test.py
@@ -1451,10 +1451,7 @@ def test_random_split_doesnt_device_put_during_tracing(self):
key = self.seed_prng(1).block_until_ready()
with jtu.count_device_put() as count:
jax.jit(random.split)(key)
if config.jax_array:
self.assertEqual(count[0], 0)
else:
self.assertEqual(count[0], 1) # 1 for the argument device_put
self.assertLessEqual(count[0], 1) # 1 for the argument device_put

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}
