From 4fc3518e5f9ae004697bf57f96cd7b1c56bf68d3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Aug 2022 16:51:26 -0700 Subject: [PATCH] Make checkify tests pass with Array and add methods on Array that are present on DA. PiperOrigin-RevId: 468058909 --- jax/_src/checkify.py | 23 ++++++++----- jax/_src/numpy/lax_numpy.py | 3 ++ jax/experimental/array.py | 69 +++++++++++++++++++++++++++++++++++++ tests/api_test.py | 40 +++++++++++++-------- tests/checkify_test.py | 35 +++++++++++++++++-- tests/pmap_test.py | 12 ------- 6 files changed, 144 insertions(+), 38 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index e011f35edcde..9c9d5e16dafe 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -30,11 +30,11 @@ from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe -from jax.interpreters import pxla from jax.experimental.sharding import OpShardingSharding from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node from jax._src import source_info_util, traceback_util from jax._src.lax import control_flow as cf +from jax._src.config import config from jax import lax from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map, safe_zip) @@ -688,18 +688,25 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_positional_semantics, out_positional_semantics): checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors) new_vals_in = [error.err, error.code, error.payload, *vals_in] + sharding = OpShardingSharding.get_replicated( - list(pxla.thread_resources.env.physical_mesh.devices.flat)) - pos_sem = maps._positional_semantics.val - new_in_shardings = (*[sharding]*3, *in_shardings) - new_out_shardings = (*[sharding]*3, *out_shardings) + list(resource_env.physical_mesh.devices.flat)) + new_in_shardings = (*[sharding] * 3, *in_shardings) + new_out_shardings = (*[sharding] * 3, *out_shardings) + + if config.jax_array: + pos_sem = maps._PositionalSemantics.GLOBAL + else: + pos_sem = maps._positional_semantics.val + if not isinstance(in_positional_semantics, Iterable): in_positional_semantics = (in_positional_semantics,) if not isinstance(out_positional_semantics, Iterable): out_positional_semantics = (out_positional_semantics,) - new_positional_sems_in = (*[pos_sem]*3, *in_positional_semantics) - new_positional_sems_out = (*[pos_sem]*3, *out_positional_semantics) - new_donated_invars = (*[False]*3, *donated_invars) + new_positional_sems_in = (*[pos_sem] * 3, *in_positional_semantics) + new_positional_sems_out = (*[pos_sem] * 3, *out_positional_semantics) + new_donated_invars = (*[False] * 3, *donated_invars) + err, code, payload, *vals_out = pjit.pjit_p.bind( *new_vals_in, jaxpr=checked_jaxpr, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 00a8c6b8211a..76c4e5e4d765 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4774,6 +4774,8 @@ def _multi_slice(arr, def _unstack(x): return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] setattr(device_array.DeviceArray, "_unstack", _unstack) +setattr(Array, '_unstack', _unstack) + def _chunk_iter(x, size): if size > x.shape[0]: yield x @@ -4784,6 +4786,7 @@ def _chunk_iter(x, size): if tail: yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail) setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter) +setattr(Array, '_chunk_iter', _chunk_iter) # Syntactic sugar for scatter operations. class _IndexUpdateHelper: diff --git a/jax/experimental/array.py b/jax/experimental/array.py index 54081cd29035..d0473e04cd93 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -14,6 +14,7 @@ from __future__ import annotations +import operator import numpy as np from typing import Sequence, Tuple, Callable, Union, Optional, cast, List @@ -21,6 +22,7 @@ from jax._src import ad_util from jax._src import api_util from jax._src import dispatch +from jax._src import dtypes from jax._src.lax import lax as lax_internal from jax._src.config import config from jax._src.util import prod, safe_zip @@ -139,6 +141,73 @@ def size(self): def sharding(self): return self._sharding + def __str__(self): + return str(self._value) + + def __len__(self): + try: + return self.shape[0] + except IndexError as err: + raise TypeError("len() of unsized object") from err # same as numpy error + + def __bool__(self): + return bool(self._value) + + def __nonzero__(self): + return bool(self._value) + + def __float__(self): + return self._value.__float__() + + def __int__(self): + return self._value.__int__() + + def __complex__(self): + return self._value.__complex__() + + def __hex__(self): + assert self.ndim == 0, 'hex only works on scalar values' + return hex(self._value) # type: ignore + + def __oct__(self): + assert self.ndim == 0, 'oct only works on scalar values' + return oct(self._value) # type: ignore + + def __index__(self): + return operator.index(self._value) + + def to_bytes(self, order="C"): + return self._value.tobytes(order) + + def tolist(self): + return self._value.tolist() + + def __format__(self, format_spec): + # Simulates behavior of https://github.com/numpy/numpy/pull/9883 + if self.ndim == 0: + return format(self._value[()], format_spec) + else: + return format(self._value, format_spec) + + def __iter__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") # same as numpy error + else: + # chunk_iter is added to Array in lax_numpy.py similar to DA. + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore + + def item(self): + if dtypes.issubdtype(self.dtype, np.complexfloating): + return complex(self) + elif dtypes.issubdtype(self.dtype, np.floating): + return float(self) + elif dtypes.issubdtype(self.dtype, np.integer): + return int(self) + elif dtypes.issubdtype(self.dtype, np.bool_): + return bool(self) + else: + raise TypeError(self.dtype) + def __repr__(self): prefix = '{}('.format(self.__class__.__name__.lstrip('_')) # TODO(yashkatariya): Add weak_type to the repr and handle weak_type diff --git a/tests/api_test.py b/tests/api_test.py index c3901b7b6766..b1a0737a5a93 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1175,22 +1175,32 @@ def use_cpp_jit(self) -> bool: class APITest(jtu.JaxTestCase): - def test_grad_item(self): - def f(x): - if x.astype(bool).item(): - return x ** 2 - else: - return x - out = jax.grad(f)(2.0) - self.assertEqual(out, 4) + @parameterized.named_parameters( + ('array', True), + ('no_array', False) + ) + def test_grad_item(self, array_enabled): + with jax._src.config.jax_array(array_enabled): + def f(x): + if x.astype(bool).item(): + return x ** 2 + else: + return x + out = jax.grad(f)(2.0) + self.assertEqual(out, 4) - def test_jit_item(self): - def f(x): - return x.item() - x = jnp.array(1.0) - self.assertEqual(f(x), x) - with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"): - jax.jit(f)(x) + @parameterized.named_parameters( + ('array', True), + ('no_array', False) + ) + def test_jit_item(self, array_enabled): + with jax._src.config.jax_array(array_enabled): + def f(x): + return x.item() + x = jnp.array(1.0) + self.assertEqual(f(x), x) + with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"): + jax.jit(f)(x) def test_grad_bad_input(self): def f(x): diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 7a97e859de0a..97143d363764 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -26,6 +26,8 @@ from jax.experimental import checkify from jax.experimental import pjit from jax.experimental import maps +from jax.experimental.sharding import MeshPspecSharding +from jax.experimental import array from jax._src.checkify import CheckEffect import jax.numpy as jnp @@ -421,13 +423,20 @@ def g(x, y): # binary func return x / y - ps = pjit.PartitionSpec("dev") + mesh = maps.Mesh(np.array(jax.devices()), ["dev"]) + if config.jax_array: + ps = MeshPspecSharding(mesh, pjit.PartitionSpec("dev")) + inp = np.arange(8) + x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) + else: + ps = pjit.PartitionSpec("dev") + x = jnp.arange(8) + f = pjit.pjit(f, in_axis_resources=ps, out_axis_resources=ps) f = checkify.checkify(f, errors=checkify.float_checks) g = pjit.pjit(g, in_axis_resources=ps, out_axis_resources=ps) g = checkify.checkify(g, errors=checkify.float_checks) - with maps.Mesh(np.array(jax.devices()), ["dev"]): - x = jnp.arange(8) + with mesh: u_err, _ = f(x) b_err, _ = g(x, x) @@ -852,5 +861,25 @@ def g(x): checkify.checkify(g)(0.) # does not crash + +class CheckifyWithArray: + + def setUp(self): + super().setUp() + self.array_enabled = config.jax_array + config.update('jax_array', True) + + def tearDown(self): + config.update('jax_array', self.array_enabled) + super().tearDown() + + +class ArrayCheckifyTransformTests(CheckifyWithArray, CheckifyTransformTests): + pass + +class ArrayAssertPrimitiveTests(CheckifyWithArray, AssertPrimitiveTests): + pass + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 593c2131c1a4..aff581c4fb1c 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -365,8 +365,6 @@ def testGatherTiled(self): self.assertAllClose(ans, expected, check_dtypes=False) def testReduceScatter(self): - if config.jax_array: - raise unittest.SkipTest('psum_scatter gives wrong answer with Array.') f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i') device_count = jax.device_count() @@ -378,8 +376,6 @@ def testReduceScatter(self): self.assertAllClose(actual, expected[i]) def testReduceScatterTiled(self): - if config.jax_array: - raise unittest.SkipTest('psum_scatter gives wrong answer with Array.') f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i') device_count = jax.device_count() @@ -393,8 +389,6 @@ def testReduceScatterTiled(self): expected[i * scatter_len:(i + 1) * scatter_len]) def testReduceScatterReplicaGroupsTiled(self): - if config.jax_array: - raise unittest.SkipTest('psum_scatter gives wrong answer with Array.') replicas = jax.device_count() if replicas % 2 != 0: raise SkipTest @@ -1063,8 +1057,6 @@ def testPpermuteWithZipObject(self): self.assertAllClose(result, expected) def testRule30(self): - if config.jax_array: - raise unittest.SkipTest('times out when Array is enabled.') # This is a test of collective_permute implementing a simple halo exchange # to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30 # Halo exchange should be useful in spatially-sharded convolutions and in @@ -1829,8 +1821,6 @@ def f(key): self.pmap(remat(f), axis_name='i')(keys) def testPmapMapVmapCombinations(self): - if config.jax_array: - raise unittest.SkipTest('times out when Array is enabled.') # https://github.com/google/jax/issues/2822 def vv(x, y): """Vector-vector multiply""" @@ -1873,8 +1863,6 @@ def test(x): self.pmap(test)(a) def testPsumOnBooleanDtype(self): - if config.jax_array: - raise unittest.SkipTest('times out when Array is enabled.') # https://github.com/google/jax/issues/3123 n = jax.device_count() if n > 1: