Skip to content

Commit

Permalink
Improve shape validation when jax_dynamic_shapes=True
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 12, 2023
1 parent 2fa90e1 commit a1ee8c1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
14 changes: 11 additions & 3 deletions jax/_src/core.py
Expand Up @@ -2102,6 +2102,9 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
except TypeError as e:
type_error = e
if isinstance(dim, Tracer) and config.dynamic_shapes.value:
if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
or isinstance(dim.dtype, bint))):
raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
return dim
elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
Expand Down Expand Up @@ -2138,11 +2141,16 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
return canonicalize_shape((d,), context)[0]

def _invalid_shape_error(shape: Shape, context: str=""):
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
f"got {shape}.")
if config.dynamic_shapes.value:
msg = ("Shapes must be 1D sequences of integer scalars, "
f"got {shape}")
else:
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
f"got {shape}.")
if context:
msg += f" {context}."
if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
if not config.dynamic_shapes.value and any(
isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
Expand Down
8 changes: 8 additions & 0 deletions tests/dynamic_api_test.py
Expand Up @@ -621,6 +621,14 @@ def test_flattening_basic(self):
jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12), abstracted_axes={0: 'n'})(x)
self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)

def test_shape_validation(self):
# Regression test for https://github.com/google/jax/issues/18937
msg = r"Shapes must be 1D sequences of integer scalars, got .+"
with self.assertRaisesRegex(TypeError, msg):
jax.make_jaxpr(jnp.ones)(5.0)
with self.assertRaisesRegex(TypeError, msg):
jax.make_jaxpr(jnp.ones)(jnp.ones((2, 2)))

@unittest.skip("Test does not work with jax.Array")
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeAutodiffTest(jtu.JaxTestCase):
Expand Down

0 comments on commit a1ee8c1

Please sign in to comment.