Skip to content

Commit

Permalink
Improve concretization errors for jnp indexing routines
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjng committed Jun 17, 2021
1 parent 6d805ca commit cba5b13
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
43 changes: 21 additions & 22 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,22 +3331,29 @@ def meshgrid(*args, **kwargs):
return output


def _make_1d_grid_from_slice(s: slice, op_name: str):
start =core.concrete_or_error(None, s.start,
f"slice start of jnp.{op_name}") or 0
stop = core.concrete_or_error(None, s.stop,
f"slice stop of jnp.{op_name}")
step = core.concrete_or_error(None, s.step,
f"slice step of jnp.{op_name}") or 1
if np.iscomplex(step):
newobj = linspace(start, stop, int(_abs(step)))
else:
newobj = arange(start, stop, step)

return newobj


class _IndexGrid:
def __getitem__(self, key):
single_slice = isinstance(key, slice)
if single_slice:
key = (key,)
output = []
for k in key:
start = core.concrete_or_error(None, k.start,
"slice start of jnp.mgrid") or 0
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.mgrid")
step = core.concrete_or_error(None, k.step,
"slice step of jnp.mgrid") or 1
if np.iscomplex(step):
output.append(linspace(start, stop, int(_abs(step))))
else:
output.append(arange(start, stop, step))
output.append(_make_1d_grid_from_slice(k, op_name=self.op_name))
if single_slice:
return output[0]
output = meshgrid(*output, indexing='ij', sparse=self.sparse)
Expand Down Expand Up @@ -3382,6 +3389,7 @@ class _Mgrid(_IndexGrid):
[0, 1, 2]]], dtype=int32)
"""
sparse = False
op_name = "mgrid"

mgrid = _Mgrid()

Expand Down Expand Up @@ -3414,23 +3422,12 @@ class _Ogrid(_IndexGrid):
DeviceArray([[0, 1, 2]], dtype=int32)]
"""
sparse = True
op_name = "ogrid"


ogrid = _Ogrid()


def _make_1d_grid_from_slice(s: slice):
start = s.start or 0
stop = s.stop
step = s.step or 1
if np.iscomplex(step):
newobj = linspace(start, stop, int(_abs(step)))
else:
newobj = arange(start, stop, step)

return newobj


class _AxisConcat:
"""Concatenates slices, scalars and array-like objects along a given axis."""
def __getitem__(self, key):
Expand Down Expand Up @@ -3467,7 +3464,7 @@ def __getitem__(self, key):
output = []
for item in key:
if isinstance(item, slice):
newobj = _make_1d_grid_from_slice(item)
newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
elif isinstance(item, str):
raise ValueError("string directive must be placed at the beginning")
else:
Expand Down Expand Up @@ -3566,6 +3563,7 @@ class RClass(_AxisConcat):
axis = 0
ndmin = 1
trans1d = -1
op_name = "r_"


r_ = RClass()
Expand Down Expand Up @@ -3613,6 +3611,7 @@ class CClass(_AxisConcat):
axis = -1
ndmin = 2
trans1d = 0
op_name = "c_"


c_ = CClass()
Expand Down
16 changes: 16 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4811,6 +4811,10 @@ def testMgrid(self):
jnp.mgrid[1.3:4.2:0.3],
atol=atol,
rtol=rtol)
# abstract tracer value for jnp.mgrid slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
"slice start of jnp.mgrid"):
jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)

def testOgrid(self):
def assertListOfArraysEqual(xs, ys):
Expand Down Expand Up @@ -4846,6 +4850,10 @@ def assertListOfArraysEqual(xs, ys):
jnp.ogrid[1.2:4.8:0.24],
atol=atol,
rtol=rtol)
# abstract tracer value for ogrid slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
"slice start of jnp.ogrid"):
jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2)

def testR_(self):
a = np.arange(6).reshape((2,3))
Expand All @@ -4868,6 +4876,10 @@ def testR_(self):
# bad directive
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
jnp.r_["asdfgh",[1,2,3]]
# abstract tracer value for r_ slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
"slice start of jnp.r_"):
jax.jit(lambda a, b: jnp.r_[a:b])(0, 2)

# Complex number steps
atol = 1e-6
Expand Down Expand Up @@ -4908,6 +4920,10 @@ def testC_(self):
# bad directive
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
jnp.c_["asdfgh",[1,2,3]]
# abstract tracer value for c_ slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
"slice start of jnp.c_"):
jax.jit(lambda a, b: jnp.c_[a:b])(0, 2)

# Complex number steps
atol = 1e-6
Expand Down

0 comments on commit cba5b13

Please sign in to comment.