
Commit

Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.

The semantics are as follows:

* If the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings.

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This makes sure we don't break existing code that depends on None meaning fully replicated, while also starting the transition to None meaning UNSPECIFIED for jit and pjit.
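A minimal sketch of the two behaviors described above (not part of the commit; it assumes JAX 0.4.13+ and builds a one-axis mesh over whatever devices are available):

```python
import numpy as np
import jax
from jax.experimental import pjit
from jax.sharding import Mesh

ndev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('x',))
x = np.arange(4.0 * ndev).reshape(ndev, 4)

# No mesh context manager: None behaves like UNSPECIFIED, so the compiler is
# free to pick the shardings (a plain numpy input stays on a single device).
out = jax.jit(lambda a: a * 2, in_shardings=None, out_shardings=None)(x)
print(out.sharding)

# With a mesh context manager: None keeps the old pjit meaning, i.e. the value
# is fully replicated across every device of the mesh.
with mesh:
  out = pjit.pjit(lambda a: a * 2, in_shardings=None, out_shardings=None)(x)
print(out.sharding)  # a fully replicated NamedSharding
```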

PiperOrigin-RevId: 540705660
yashk2810 authored and jax authors committed Jun 15, 2023
1 parent 904b46a commit 6007698
Showing 6 changed files with 106 additions and 36 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,25 @@ Remember to align the itemized text with the first line of an item within a list
-->

## jax 0.4.13

* Changes
* `jax.jit` now allows `None` to be passed to `in_shardings` and
`out_shardings`. The semantics are as follows:
* For in_shardings, JAX will mark it as replicated but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* `jax.experimental.pjit.pjit` also allows `None` to be passed to
`in_shardings` and `out_shardings`. The semantics are as follows:
* If the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
* For in_shardings, JAX will mark it as replicated but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* If the mesh context manager is provided, None will imply that the value
will be replicated on all devices of the mesh.

* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
is named `cudnn89` instead of `cudnn88`.
6 changes: 6 additions & 0 deletions jax/_src/api.py
@@ -187,6 +187,12 @@ def jit(
- :py:class:`XLACompatibleSharding`, which will decide how the value
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None`, which gives JAX the freedom to choose whatever sharding
it wants.
For in_shardings, JAX will mark it as replicated but this behavior
can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
The size of every dimension has to be a multiple of the total number of
resources assigned to it. This is similar to pjit's in_shardings.
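As an illustration of the docstring above, a sketch (not part of the diff) of passing `None` to a `jax.jit`-compiled function whose input was committed with an explicit sharding:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devs = np.array(jax.devices())
mesh = Mesh(devs, ('x',))
s = NamedSharding(mesh, P('x'))
x = jax.device_put(np.arange(4.0 * len(devs)).reshape(len(devs), 4), s)

# in_shardings=None does not force a resharding of the committed input, and
# out_shardings=None leaves the output sharding to the XLA GSPMD partitioner;
# for this elementwise function it typically matches the input's sharding.
out = jax.jit(lambda a: a * 2, in_shardings=None, out_shardings=None)(x)
print(out.sharding)
```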
57 changes: 33 additions & 24 deletions jax/_src/pjit.py
@@ -266,27 +266,29 @@ def cache_miss(*args, **kwargs):

def _resolve_axis_resources_and_shardings_arg(
in_shardings, out_shardings, in_axis_resources, out_axis_resources):
if not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources):
if (in_shardings is not None and in_axis_resources is not None and
not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)):
raise ValueError(
'Setting both in_shardings and in_axis_resources is not '
'allowed. in_axis_resources is deprecated. Please use in_shardings.')
if not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources):
if (out_shardings is not None and out_axis_resources is not None and
not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)):
raise ValueError(
'Setting both out_shardings and out_axis_resources is not '
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
if (not is_unspecified(in_axis_resources) or
not is_unspecified(out_axis_resources)):
if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or
(out_axis_resources is not None and not is_unspecified(out_axis_resources))):
warnings.warn(
'in_axis_resources and out_axis_resources are deprecated. Please use '
'in_shardings and out_shardings as their replacement.',
DeprecationWarning)

if not is_unspecified(in_axis_resources):
if in_axis_resources is not None and not is_unspecified(in_axis_resources):
final_in_shardings = in_axis_resources
else:
final_in_shardings = in_shardings

if not is_unspecified(out_axis_resources):
if out_axis_resources is not None and not is_unspecified(out_axis_resources):
final_out_shardings = out_axis_resources
else:
final_out_shardings = out_shardings
@@ -311,10 +313,10 @@ def pre_infer_params(fun, in_shardings, out_shardings,
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if not is_unspecified(in_shardings):
if in_shardings is not None and not is_unspecified(in_shardings):
raise ValueError('If backend or device is specified on jit, then '
'in_shardings should not be specified.')
if not is_unspecified(out_shardings):
if out_shardings is not None and not is_unspecified(out_shardings):
raise ValueError('If backend or device is specified on jit, then '
'out_shardings should not be specified.')

@@ -413,7 +415,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
donate_argnums, device, backend, keep_unused, inline,
resource_env, abstracted_axes) = pjit_info_args

if kwargs and not is_unspecified(user_in_shardings):
if (kwargs and user_in_shardings is not None and
not is_unspecified(user_in_shardings)):
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")

@@ -467,14 +470,17 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
in_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x, 'in_shardings',
jit_name),
user_in_shardings)
user_in_shardings, is_leaf=lambda x: x is None)
out_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x, 'out_shardings',
jit_name),
user_out_shardings)
user_out_shardings, is_leaf=lambda x: x is None)

del user_in_shardings, user_out_shardings

assert in_shardings is not None or all(i is not None for i in in_shardings)
assert out_shardings is not None or all(o is not None for o in out_shardings)

if config.jax_dynamic_shapes:
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e)
@@ -661,11 +667,18 @@ def pjit(
- :py:class:`XLACompatibleSharding`, which will decide how the value
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None` is a special case whose semantics are:
- if the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
For in_shardings, JAX will mark it as replicated but this behavior
can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
- If the mesh context manager is provided, None will imply that the
value will be replicated on all devices of the mesh.
- For backwards compatibility, in_shardings still supports ingesting
:py:class:`PartitionSpec` and :py:obj:`None`. These 2 options can
*only* be used with the mesh context manager.
- :py:obj:`None`, in which case the value will be replicated on all devices
:py:class:`PartitionSpec`. This option can *only* be used with the
mesh context manager.
- :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
of the partitioned value. Each element can be a :py:obj:`None`, a mesh
axis or a tuple of mesh axes, and specifies the set of resources assigned
@@ -774,14 +787,9 @@ def hashable_pytree(pytree):
closure=(treedef, vals))


@lru_cache(maxsize=4096)
def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
if is_unspecified_or_auto(x):
return x
return pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)


def _create_sharding_for_array(mesh, x, name, api_name):
if x is None and (mesh is None or mesh.empty):
return UNSPECIFIED
if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
return x
if mesh is None:
@@ -804,8 +812,9 @@ def _create_sharding_for_array(mesh, x, name, api_name):
f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and'
' then the mesh context manager is not required.')
# A nice user error is raised in prepare_axis_resources.
assert isinstance(x, ParsedPartitionSpec), x
return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
assert x is None or isinstance(x, ParsedPartitionSpec), x
return (pxla.create_mesh_pspec_sharding(mesh, x)
if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x))


def _create_sharding_with_device_backend(device, backend):
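To summarize the `None` handling added in `_create_sharding_for_array` above, a simplified standalone sketch (the helper name below is hypothetical, and the real code also handles `AUTO`, `PartitionSpec` parsing, and error reporting):

```python
from jax.sharding import NamedSharding, PartitionSpec

UNSPECIFIED = object()  # stand-in for jax's internal UNSPECIFIED sentinel


def resolve_none_sharding(mesh, entry):
  """Sketch of how a None entry in in_shardings/out_shardings is resolved."""
  if entry is None and (mesh is None or mesh.empty):
    # No mesh context manager: treat None as UNSPECIFIED and let the
    # compiler choose the sharding.
    return UNSPECIFIED
  if entry is None:
    # Mesh context manager present: keep the old meaning of None, i.e.
    # fully replicated over every device of the mesh.
    return NamedSharding(mesh, PartitionSpec())
  return entry  # already a concrete sharding (or UNSPECIFIED/AUTO)
```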
5 changes: 3 additions & 2 deletions jax/_src/sharding_impls.py
@@ -217,7 +217,8 @@ def _preprocess(self):
# representation of Parsed Pspec
if self._parsed_pspec is None:
self._parsed_pspec, _, _ = prepare_axis_resources(
self.spec, "NamedSharding spec", allow_unconstrained_dims=True)
PartitionSpec() if self.spec is None else self.spec,
"NamedSharding spec", allow_unconstrained_dims=True)

_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
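The practical effect of defaulting a missing spec to an empty `PartitionSpec` is full replication; a quick sketch (not part of the diff, assuming at least one device):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
# An empty PartitionSpec assigns no mesh axis to any array dimension, so the
# value is replicated on every device of the mesh.
s = NamedSharding(mesh, P())
print(s.is_fully_replicated)  # True
```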

Expand Down Expand Up @@ -956,7 +957,7 @@ def prepare_axis_resources(axis_resources,

new_entries = []
for entry in entries:
if is_unspecified_or_auto(entry):
if is_unspecified_or_auto(entry) or entry is None:
new_entries.append(entry)
elif isinstance(entry, sharding.Sharding):
if isinstance(entry, PmapSharding):
3 changes: 1 addition & 2 deletions tests/array_test.py
@@ -958,8 +958,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
self.assertEqual(mps.is_fully_replicated, ops_ifr)

ps = _op_sharding_to_pos_sharding(mps_op_sharding,
mps._device_assignment)
ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment)
self.assertEqual(ps.is_fully_replicated,
op_shardings.is_op_sharding_replicated(
ps._to_xla_hlo_sharding(len(shape))))
52 changes: 44 additions & 8 deletions tests/pjit_test.py
@@ -2573,7 +2573,8 @@ def test_pjit_kwargs_axis_resources_error(self):
with self.assertRaisesRegex(
ValueError,
"pjit does not support kwargs when in_shardings is specified."):
pjit(lambda x: x, in_shardings=None)(x=jnp.arange(8.))
pjit(lambda x: x,
in_shardings=SingleDeviceSharding(jax.devices()[0]))(x=jnp.arange(8.))

def test_pjit_keep_unused_true(self):
@partial(pjit, keep_unused=True)
@@ -2693,17 +2694,18 @@ def test_autodiff_with_device_arg(self):
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)

def test_pjit_device_backend_axis_resources_error(self):
s = SingleDeviceSharding(jax.devices()[0])
with self.assertRaisesRegex(
ValueError,
'If backend or device is specified on jit, then '
'in_shardings should not be specified.'):
pjit(lambda x: x, in_shardings=None, backend='cpu')
pjit(lambda x: x, in_shardings=s, backend='cpu')

with self.assertRaisesRegex(
ValueError,
'If backend or device is specified on jit, then '
'out_shardings should not be specified.'):
pjit(lambda x: x, out_shardings=None, device=jax.devices()[0])
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])

def test_pjit_device_backend_both_error(self):
with self.assertRaisesRegex(
@@ -3468,6 +3470,43 @@ def test_shape_dtype_struct_as_const_error(self):
r"Argument.*is not a valid JAX type"):
jax.jit(lambda x: (x, const))(jnp.arange(8))

def test_jit_out_shardings_none(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
inp = jax.device_put(np_inp, s)
out = jax.jit(lambda x: x * 2, out_shardings=None)(inp)
self.assertArraysEqual(out, np_inp * 2)
self.assertEqual(out.sharding, s)

def test_jit_in_shardings_none(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
inp = jax.device_put(np_inp, s)

out = jax.jit(lambda x: x * 2, in_shardings=None)(inp)
self.assertArraysEqual(out, np_inp * 2)
self.assertEqual(out.sharding, s)

out2 = jax.jit(lambda x: x * 2, in_shardings=None)(np_inp)
self.assertArraysEqual(out2, np_inp * 2)
self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))

def test_jit_both_shardings_none(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
inp = jax.device_put(np_inp, s)

out = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(inp)
self.assertArraysEqual(out, np_inp * 2)
self.assertEqual(out.sharding, s)

out2 = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(np_inp)
self.assertArraysEqual(out2, np_inp * 2)
self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))


class TempSharding(Sharding):

@@ -3683,11 +3722,8 @@ def testCatchesInnerXMapErrors(self):
f(x, x)

def testEmptyMesh(self):
error = (
r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
r' `None` to in_shardings.*')
with self.assertRaisesRegex(RuntimeError, error):
pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0]))

def test_pspec_to_wsc_without_mesh(self):
error = (
