Skip to content

Commit

Permalink
Make checkify tests pass with Array and add methods on Array that are…
Browse files Browse the repository at this point in the history
… present on DA.

PiperOrigin-RevId: 468058909
  • Loading branch information
yashk2810 authored and jax authors committed Aug 16, 2022
1 parent 9040d5c commit 4fc3518
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 38 deletions.
23 changes: 15 additions & 8 deletions jax/_src/checkify.py
Expand Up @@ -30,11 +30,11 @@
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.experimental.sharding import OpShardingSharding
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax._src import source_info_util, traceback_util
from jax._src.lax import control_flow as cf
from jax._src.config import config
from jax import lax
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
Expand Down Expand Up @@ -688,18 +688,25 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_positional_semantics, out_positional_semantics):
checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors)
new_vals_in = [error.err, error.code, error.payload, *vals_in]

sharding = OpShardingSharding.get_replicated(
list(pxla.thread_resources.env.physical_mesh.devices.flat))
pos_sem = maps._positional_semantics.val
new_in_shardings = (*[sharding]*3, *in_shardings)
new_out_shardings = (*[sharding]*3, *out_shardings)
list(resource_env.physical_mesh.devices.flat))
new_in_shardings = (*[sharding] * 3, *in_shardings)
new_out_shardings = (*[sharding] * 3, *out_shardings)

if config.jax_array:
pos_sem = maps._PositionalSemantics.GLOBAL
else:
pos_sem = maps._positional_semantics.val

if not isinstance(in_positional_semantics, Iterable):
in_positional_semantics = (in_positional_semantics,)
if not isinstance(out_positional_semantics, Iterable):
out_positional_semantics = (out_positional_semantics,)
new_positional_sems_in = (*[pos_sem]*3, *in_positional_semantics)
new_positional_sems_out = (*[pos_sem]*3, *out_positional_semantics)
new_donated_invars = (*[False]*3, *donated_invars)
new_positional_sems_in = (*[pos_sem] * 3, *in_positional_semantics)
new_positional_sems_out = (*[pos_sem] * 3, *out_positional_semantics)
new_donated_invars = (*[False] * 3, *donated_invars)

err, code, payload, *vals_out = pjit.pjit_p.bind(
*new_vals_in,
jaxpr=checked_jaxpr,
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -4774,6 +4774,8 @@ def _multi_slice(arr,
def _unstack(x):
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(device_array.DeviceArray, "_unstack", _unstack)
setattr(Array, '_unstack', _unstack)

def _chunk_iter(x, size):
if size > x.shape[0]:
yield x
Expand All @@ -4784,6 +4786,7 @@ def _chunk_iter(x, size):
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter)
setattr(Array, '_chunk_iter', _chunk_iter)

# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
Expand Down
69 changes: 69 additions & 0 deletions jax/experimental/array.py
Expand Up @@ -14,13 +14,15 @@

from __future__ import annotations

import operator
import numpy as np
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List

from jax import core
from jax._src import ad_util
from jax._src import api_util
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.config import config
from jax._src.util import prod, safe_zip
Expand Down Expand Up @@ -139,6 +141,73 @@ def size(self):
def sharding(self):
return self._sharding

def __str__(self):
return str(self._value)

def __len__(self):
try:
return self.shape[0]
except IndexError as err:
raise TypeError("len() of unsized object") from err # same as numpy error

def __bool__(self):
return bool(self._value)

def __nonzero__(self):
return bool(self._value)

def __float__(self):
return self._value.__float__()

def __int__(self):
return self._value.__int__()

def __complex__(self):
return self._value.__complex__()

def __hex__(self):
assert self.ndim == 0, 'hex only works on scalar values'
return hex(self._value) # type: ignore

def __oct__(self):
assert self.ndim == 0, 'oct only works on scalar values'
return oct(self._value) # type: ignore

def __index__(self):
return operator.index(self._value)

def to_bytes(self, order="C"):
return self._value.tobytes(order)

def tolist(self):
return self._value.tolist()

def __format__(self, format_spec):
# Simulates behavior of https://github.com/numpy/numpy/pull/9883
if self.ndim == 0:
return format(self._value[()], format_spec)
else:
return format(self._value, format_spec)

def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
# chunk_iter is added to Array in lax_numpy.py similar to DA.
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # type: ignore

def item(self):
if dtypes.issubdtype(self.dtype, np.complexfloating):
return complex(self)
elif dtypes.issubdtype(self.dtype, np.floating):
return float(self)
elif dtypes.issubdtype(self.dtype, np.integer):
return int(self)
elif dtypes.issubdtype(self.dtype, np.bool_):
return bool(self)
else:
raise TypeError(self.dtype)

def __repr__(self):
prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
# TODO(yashkatariya): Add weak_type to the repr and handle weak_type
Expand Down
40 changes: 25 additions & 15 deletions tests/api_test.py
Expand Up @@ -1175,22 +1175,32 @@ def use_cpp_jit(self) -> bool:

class APITest(jtu.JaxTestCase):

def test_grad_item(self):
def f(x):
if x.astype(bool).item():
return x ** 2
else:
return x
out = jax.grad(f)(2.0)
self.assertEqual(out, 4)
@parameterized.named_parameters(
('array', True),
('no_array', False)
)
def test_grad_item(self, array_enabled):
with jax._src.config.jax_array(array_enabled):
def f(x):
if x.astype(bool).item():
return x ** 2
else:
return x
out = jax.grad(f)(2.0)
self.assertEqual(out, 4)

def test_jit_item(self):
def f(x):
return x.item()
x = jnp.array(1.0)
self.assertEqual(f(x), x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"):
jax.jit(f)(x)
@parameterized.named_parameters(
('array', True),
('no_array', False)
)
def test_jit_item(self, array_enabled):
with jax._src.config.jax_array(array_enabled):
def f(x):
return x.item()
x = jnp.array(1.0)
self.assertEqual(f(x), x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"):
jax.jit(f)(x)

def test_grad_bad_input(self):
def f(x):
Expand Down
35 changes: 32 additions & 3 deletions tests/checkify_test.py
Expand Up @@ -26,6 +26,8 @@
from jax.experimental import checkify
from jax.experimental import pjit
from jax.experimental import maps
from jax.experimental.sharding import MeshPspecSharding
from jax.experimental import array
from jax._src.checkify import CheckEffect
import jax.numpy as jnp

Expand Down Expand Up @@ -421,13 +423,20 @@ def g(x, y):
# binary func
return x / y

ps = pjit.PartitionSpec("dev")
mesh = maps.Mesh(np.array(jax.devices()), ["dev"])
if config.jax_array:
ps = MeshPspecSharding(mesh, pjit.PartitionSpec("dev"))
inp = np.arange(8)
x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])
else:
ps = pjit.PartitionSpec("dev")
x = jnp.arange(8)

f = pjit.pjit(f, in_axis_resources=ps, out_axis_resources=ps)
f = checkify.checkify(f, errors=checkify.float_checks)
g = pjit.pjit(g, in_axis_resources=ps, out_axis_resources=ps)
g = checkify.checkify(g, errors=checkify.float_checks)
with maps.Mesh(np.array(jax.devices()), ["dev"]):
x = jnp.arange(8)
with mesh:
u_err, _ = f(x)
b_err, _ = g(x, x)

Expand Down Expand Up @@ -852,5 +861,25 @@ def g(x):

checkify.checkify(g)(0.) # does not crash


class CheckifyWithArray:

def setUp(self):
super().setUp()
self.array_enabled = config.jax_array
config.update('jax_array', True)

def tearDown(self):
config.update('jax_array', self.array_enabled)
super().tearDown()


class ArrayCheckifyTransformTests(CheckifyWithArray, CheckifyTransformTests):
pass

class ArrayAssertPrimitiveTests(CheckifyWithArray, AssertPrimitiveTests):
pass


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
12 changes: 0 additions & 12 deletions tests/pmap_test.py
Expand Up @@ -365,8 +365,6 @@ def testGatherTiled(self):
self.assertAllClose(ans, expected, check_dtypes=False)

def testReduceScatter(self):
if config.jax_array:
raise unittest.SkipTest('psum_scatter gives wrong answer with Array.')
f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')

device_count = jax.device_count()
Expand All @@ -378,8 +376,6 @@ def testReduceScatter(self):
self.assertAllClose(actual, expected[i])

def testReduceScatterTiled(self):
if config.jax_array:
raise unittest.SkipTest('psum_scatter gives wrong answer with Array.')
f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')

device_count = jax.device_count()
Expand All @@ -393,8 +389,6 @@ def testReduceScatterTiled(self):
expected[i * scatter_len:(i + 1) * scatter_len])

def testReduceScatterReplicaGroupsTiled(self):
if config.jax_array:
raise unittest.SkipTest('psum_scatter gives wrong answer with Array.')
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest
Expand Down Expand Up @@ -1063,8 +1057,6 @@ def testPpermuteWithZipObject(self):
self.assertAllClose(result, expected)

def testRule30(self):
if config.jax_array:
raise unittest.SkipTest('times out when Array is enabled.')
# This is a test of collective_permute implementing a simple halo exchange
# to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30
# Halo exchange should be useful in spatially-sharded convolutions and in
Expand Down Expand Up @@ -1829,8 +1821,6 @@ def f(key):
self.pmap(remat(f), axis_name='i')(keys)

def testPmapMapVmapCombinations(self):
if config.jax_array:
raise unittest.SkipTest('times out when Array is enabled.')
# https://github.com/google/jax/issues/2822
def vv(x, y):
"""Vector-vector multiply"""
Expand Down Expand Up @@ -1873,8 +1863,6 @@ def test(x):
self.pmap(test)(a)

def testPsumOnBooleanDtype(self):
if config.jax_array:
raise unittest.SkipTest('times out when Array is enabled.')
# https://github.com/google/jax/issues/3123
n = jax.device_count()
if n > 1:
Expand Down

0 comments on commit 4fc3518

Please sign in to comment.