Skip to content

Commit

Permalink
jax.mask and jax.shapecheck are being deprecated
Browse files Browse the repository at this point in the history
Issue: #11557
PiperOrigin-RevId: 462315754
  • Loading branch information
gnecula authored and jax authors committed Jul 21, 2022
1 parent ba7ded4 commit 07fcf79
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 896 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -28,6 +28,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
following a similar deprecation in {func}`scipy.linalg.solve`.
* Deprecations:
* {func}`jax.mask` {func}`jax.shapecheck` are being deprecated.
See {jax-issue}`#11557`.

## jaxlib 0.3.15 (Unreleased)

Expand Down
4 changes: 4 additions & 0 deletions jax/_src/api.py
Expand Up @@ -2249,6 +2249,8 @@ def lower(*args, **kwargs) -> stages.Lowered:


def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
warn("`jax.mask` is deprecated and will be removed soon. ",
DeprecationWarning)
_check_callable(fun)
unique_ids = masking.UniqueIds()

Expand Down Expand Up @@ -2292,6 +2294,8 @@ def padded_spec(shape_spec):

@curry
def shapecheck(in_shapes, out_shape, fun: Callable):
warn("`jax.shapecheck` is deprecated and will be removed soon. ",
DeprecationWarning)
_check_callable(fun)
in_shapes, in_tree = tree_flatten(in_shapes)
in_shapes = map(masking.parse_spec, in_shapes)
Expand Down
2 changes: 2 additions & 0 deletions jax/interpreters/masking.py
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masking is **DEPRECATED** and is being removed."""

from contextlib import contextmanager
from collections import Counter, namedtuple
from functools import partial, reduce
Expand Down
5 changes: 0 additions & 5 deletions tests/BUILD
Expand Up @@ -464,11 +464,6 @@ jax_test(
},
)

jax_test(
name = "masking_test",
srcs = ["masking_test.py"],
)

jax_test(
name = "metadata_test",
srcs = ["metadata_test.py"],
Expand Down
61 changes: 0 additions & 61 deletions tests/host_callback_test.py
Expand Up @@ -1859,67 +1859,6 @@ def g(x):
what: ct_b
1.""", testing_stream.output)

def test_tap_mask(self):

@partial(jax.mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
three_x = hcb.id_print((x, 2 * x), result=3 * x, what="x",
output_stream=testing_stream)
return jnp.sum(three_x)

x = np.arange(5.)

self.assertAllClose(9., padded_sum([x], dict(n=3)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()

# With VMAP
xv = np.arange(10.).reshape((2, 5)) # logical_shape = 5
self.assertAllClose(
np.array([9., 78.]),
# batch_size = 2, n=3 and 4 for the two elements
jax.vmap(padded_sum)([xv],
dict(n=np.array([3., 4.]))))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), ('batch', {'batch_dims': (0, 0, 0, 0)})] what: x
( ( [[0. 1. 2. 3. 4.]
[5. 6. 7. 8. 9.]]
[[ 0. 2. 4. 6. 8.]
[10. 12. 14. 16. 18.]] )
( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output)
testing_stream.reset()

# With JVP
self.assertAllClose((9., 0.9),
jax.jvp(lambda arg: padded_sum([arg], dict(n=3)),
(x,), (x * 0.1,)))
hcb.barrier_wait()
if FLAGS.jax_host_callback_ad_transforms:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
testing_stream.output)
else:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()

# Now with JIT
self.assertAllClose(9., jax.jit(padded_sum)([x], dict(n=3)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)

def test_tap_callback_delay(self):
hcb.callback_extra = lambda dev: time.sleep(1)

Expand Down

0 comments on commit 07fcf79

Please sign in to comment.