Skip to content

Commit

Permalink
Remove axis_resources from with_sharding_constraint since it has been…
Browse files Browse the repository at this point in the history
… 3 months since the deprecation as per the API deprecation policy.

PiperOrigin-RevId: 535687618
  • Loading branch information
yashk2810 authored and jax authors committed May 26, 2023
1 parent 25a9a97 commit fe3fed3
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 52 deletions.
20 changes: 11 additions & 9 deletions CHANGELOG.md
Expand Up @@ -11,16 +11,18 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
- `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
- `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
- `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
* `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
as input and remove the optional `in_shardings` argument to `pjit`.
- `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
- `jax.interpreters.xla.Device`: use `jax.Device`.
- `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead,
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
* `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead
* `axis_resources` argument of `with_sharding_constraint` is removed. Please
use `shardings` instead.


## jaxlib 0.4.11
Expand Down
33 changes: 3 additions & 30 deletions jax/_src/pjit.py
Expand Up @@ -1786,32 +1786,7 @@ def _pjit_pp_rule(eqn, context, settings):

# -------------------- with_sharding_constraint --------------------

def _resolve_wsc_args(axis_resources, shardings):
if not is_unspecified(axis_resources) and not is_unspecified(shardings):
raise ValueError(
'Setting both axis_resources and shardings is not '
'allowed. axis_resources is deprecated. Please use shardings.')
if is_unspecified(axis_resources) and is_unspecified(shardings):
raise ValueError(
'Not specifying shardings to `with_sharding_constraint` is not allowed. '
'Please specify the shardings argument with a concrete sharding. Note '
'that axis_resources is deprecated, so use the shardings argument.')

if not is_unspecified(axis_resources):
warnings.warn(
'axis_resources is deprecated. Please use shardings argument instead.',
DeprecationWarning)
final_shardings = axis_resources
else:
final_shardings = shardings
return final_shardings


# TODO(yashkatariya): Remove the axis_resources argument and make the signature
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
def with_sharding_constraint(x, shardings=UNSPECIFIED,
axis_resources=UNSPECIFIED):
def with_sharding_constraint(x, shardings):
"""Mechanism to constrain the sharding of an Array inside a jitted computation
This is a strict constraint for the GSPMD partitioner and not a hint. For examples
Expand All @@ -1821,17 +1796,15 @@ def with_sharding_constraint(x, shardings=UNSPECIFIED,
x: PyTree of jax.Arrays which will have their shardings constrainted
shardings: PyTree of sharding specifications. Valid values are the same as for
the ``in_shardings`` argument of :func:`jax.experimental.pjit`.
axis_resources: (deprecated) use shardings instead.
Returns:
x_with_shardings: PyTree of jax.Arrays with specified sharding constraints.
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
"""
final_shardings = _resolve_wsc_args(axis_resources, shardings)
x_flat, tree = tree_flatten(x)
user_shardings, _, _ = prepare_axis_resources(
final_shardings, "shardings", allow_unconstrained_dims=True)
del final_shardings
shardings, "shardings", allow_unconstrained_dims=True)
del shardings

user_shardings_flat = tuple(
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
Expand Down
13 changes: 0 additions & 13 deletions tests/pjit_test.py
Expand Up @@ -2953,19 +2953,6 @@ def test_set_both_axis_resources_and_shardings(self):
"Setting both out_shardings and out_axis_resources is not allowed"):
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))

def test_set_none_wsc_axis_resources_and_shardings(self):
with self.assertRaisesRegex(
ValueError,
"Not specifying shardings to `with_sharding_constraint` is not allowed."):
pjit(jax.lax.with_sharding_constraint(jnp.arange(8)))

def test_set_both_wsc_axis_resources_and_shardings(self):
with self.assertRaisesRegex(
ValueError,
"Setting both axis_resources and shardings is not allowed"):
pjit(jax.lax.with_sharding_constraint(
jnp.arange(8), axis_resources=P('x'), shardings=P('x')))

def test_with_sharding_constraint_spmd_axis_name(self):
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
shape = (8, 4, 2, 2)
Expand Down

0 comments on commit fe3fed3

Please sign in to comment.