Skip to content

Commit

Permalink
Merge pull request #9576 from nicholasjng:broadcast-validation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 432531230
  • Loading branch information
jax authors committed Mar 4, 2022
2 parents 21c9d73 + 56546d3 commit 2a3f936
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
17 changes: 16 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,23 @@

T = TypeVar("T")

def _validate_shapes(shapes: Sequence[Shape]):
def _check_static_shape(shape: Shape):
checked = canonicalize_shape(shape)
if not all(idx >= 0 for idx in checked):
msg = f"Only non-negative indices are allowed when broadcasting" \
f" static shapes, but got shape {shape!r}."
raise TypeError(msg)

assert shapes
if config.jax_dynamic_shapes:
# pass dynamic shapes through unchecked
return
else:
_ = tuple(map(_check_static_shape, shapes))

def _try_broadcast_shapes(
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
assert shapes
if len(shapes) == 1: return shapes[0]
rank, *others = {len(shape) for shape in shapes}
if others: return None # must have consistent rank
Expand Down Expand Up @@ -113,6 +127,7 @@ def _broadcast_shapes_cached(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
return _broadcast_shapes_uncached(*shapes)

def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst

Expand Down
9 changes: 9 additions & 0 deletions tests/lax_vmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,15 @@ def testBroadcastShapesReturnsPythonInts(self):
out_shape = lax.broadcast_shapes(shape1, shape2)
self.assertTrue(all(type(s) is int for s in out_shape))

def testBroadcastShapesFaultyInputs(self):
err_shape1, err_shape2 = (-1,), "hello"
# negative inputs should fail while informing about illegal negative indices...
with self.assertRaisesRegex(TypeError, "Only non-negative indices are allowed.*"):
lax.broadcast_shapes(err_shape1)
# ... while non-integers should error earlier, in the canonicalize_shape machinery.
with self.assertRaisesRegex(TypeError, "Shapes must be 1D sequences.*"):
lax.broadcast_shapes(err_shape2) # pytype: disable=wrong-arg-types

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}_bdims={}".format(
jtu.format_shape_dtype_string(shape, dtype), k, bdims),
Expand Down

0 comments on commit 2a3f936

Please sign in to comment.