Better error message when broadcasting ragged to static shape.
Co-authored-by: Matthew Johnson <mattjj@google.com>
axch and mattjj committed Jul 7, 2023
1 parent 3a0c135 commit 6f09fe8
Showing 4 changed files with 45 additions and 13 deletions.
14 changes: 14 additions & 0 deletions jax/_src/core.py
@@ -2124,6 +2124,20 @@ def _invalid_shape_error(shape: Shape, context: str=""):

  return TypeError(msg)

class SomeTracer(object):
  __slots__ = ()
  def __repr__(self): return "[dynamic]"

def replace_tracer_for_error_message(obj):
  # TODO(mattjj): Many ideas for improving this. Crawl the stack and see if
  # there are user variables whose value is == to this object? Or search
  # parameters of functions being transformed, at least? Or at least assign
  # short unique ids to them?
  if isinstance(obj, Tracer):
    return SomeTracer()
  else:
    return obj

def evaluate_shape(shape: Shape, dim_vars: Sequence[str],
                   *dim_values: Array) -> Sequence[Array]:
  """Evaluates a shape possibly containing non-constants.
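For illustration, a minimal sketch of how the new helper is intended to be used when formatting an error message (the sanitize_shape wrapper and the example shape are hypothetical, not part of this commit):

def sanitize_shape(shape):
  # Replace any Tracer entries with a placeholder whose repr is "[dynamic]",
  # so the message stays readable instead of printing tracer internals.
  return tuple(replace_tracer_for_error_message(d) for d in shape)

# A shape holding one concrete dimension and one tracer would then render as
# (4, [dynamic]) when interpolated into a message string.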
2 changes: 2 additions & 0 deletions jax/_src/interpreters/batching.py
@@ -404,6 +404,8 @@ def get_frame(self, vals, dims) -> core.AxisEnvFrame:
    return frame

  def process_primitive(self, primitive, tracers, params):
    if config.jax_dynamic_shapes:
      primitive.abstract_eval(*(t.aval for t in tracers), **params)
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    is_axis_primitive = primitive in axis_primitive_batchers
    used_names = core.used_axis_names(primitive, params)
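A hedged note on the batching change above: with jax_dynamic_shapes enabled, calling the primitive's abstract_eval on the unbatched avals re-runs its shape rule eagerly, so a malformed ragged broadcast raises the improved TypeError at trace time rather than failing later in batching. Roughly (the helper name below is illustrative, not the real BatchTrace code):

def check_shapes_eagerly(primitive, tracers, params):
  # abstract_eval is invoked purely for its error-checking side effect; the
  # result is discarded and normal batching proceeds afterwards.
  avals = tuple(t.aval for t in tracers)
  primitive.abstract_eval(*avals, **params)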
30 changes: 17 additions & 13 deletions jax/_src/lax/lax.py
@@ -2808,7 +2808,9 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
        "equal to their corresponding dimensions in the target broadcast "
        "shape; got operand of shape {}, target broadcast shape {}, "
        "broadcast_dimensions {} ")
-    raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))
+    raise TypeError(msg.format(
+        tuple([core.replace_tracer_for_error_message(d) for d in operand.shape]),
+        shape, broadcast_dimensions))
  if (len(broadcast_dimensions) != len(set(broadcast_dimensions)) or
      tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
    msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
@@ -2980,17 +2982,16 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings):
  return core._pp_eqn(new_eqn, context, settings)

def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
-  if dyn_shape: raise NotImplementedError
-  assert not any(d is None for d in shape) # not implemented
-  del dyn_shape
-  if not any(isinstance(d, core.DArray) and
-             type(core.get_aval(d).dtype) is core.bint for d in shape):
+  if (not dyn_shape and
+      not any(isinstance(d, core.DArray) and
+              type(core.get_aval(d).dtype) is core.bint for d in shape)):
    shape = _broadcast_in_dim_shape_rule( # error checking
        x, shape=shape, broadcast_dimensions=broadcast_dimensions)
    return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
-  # If any BInts in shape, produce a DShapedArray (even if x is a ShapedArray)
+  # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
+  # (even if x is a ShapedArray)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
-  return core.DShapedArray(shape, x.dtype, x.weak_type)
+  return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)

broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
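The DShapedArray branch above depends on _merge_dyn_shape; a minimal sketch of its assumed behavior (not copied from this commit) is that it splices the traced values in dyn_shape, in order, into the placeholder slots of the static shape:

def merge_dyn_shape_sketch(static_shape, dyn_shape):
  # Assumption: placeholder (None) entries in the static shape stand in for the
  # dynamic dimensions carried positionally in dyn_shape.
  dyn = iter(dyn_shape)
  merged = tuple(next(dyn) if d is None else d for d in static_shape)
  assert next(dyn, None) is None  # every dynamic value must be consumed
  return merged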
@@ -4559,8 +4560,10 @@ def rng_bit_generator(key, shape, dtype=np.uint32,
      key, shape=shape, dtype=dtype, algorithm=algorithm))


-def _iota_abstract_eval(*, dtype, shape, dimension):
-  _check_shapelike("iota", "shape", shape)
+def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension):
+  if not dyn_shape:
+    # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes
+    _check_shapelike("iota", "shape", shape)
  if not any(dtypes.issubdtype(dtype, t) for t in _num):
    msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
    typename = dtype_to_string(dtype)
@@ -4569,11 +4572,12 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
  if not 0 <= dimension < len(shape):
    raise ValueError("iota dimension must be between 0 and len(shape), got "
                     f"{dimension=} for {shape=}")
-  if not any(isinstance(d, core.DArray) and
-             type(core.get_aval(d).dtype) is core.bint for d in shape):
+  if (not dyn_shape and
+      not any(isinstance(d, core.DArray) and
+              type(core.get_aval(d).dtype) is core.bint for d in shape)):
    return ShapedArray(shape, dtype)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
-  return core.DShapedArray(shape, dtype, False)
+  return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)

iota_p = Primitive('iota')
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
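For context, a hedged sketch of the two outcomes of the updated _iota_abstract_eval (the concrete shapes below are illustrative, and the None placeholder convention for dynamic entries is an assumption consistent with the broadcast_in_dim change above):

# Fully static: the shape checks run and a ShapedArray comes back.
#   _iota_abstract_eval(dtype=np.int32, shape=(3, 5), dimension=0)
#     -> ShapedArray((3, 5), int32)
# Dynamic: a traced size arrives positionally in dyn_shape, the static checks
# are skipped, and the merged shape is returned as a DShapedArray.
#   _iota_abstract_eval(size_tracer, dtype=np.int32, shape=(None, 5), dimension=0)
#     -> DShapedArray((size_tracer, 5), int32, weak_type=False)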
12 changes: 12 additions & 0 deletions tests/dynamic_api_test.py
@@ -1570,6 +1570,18 @@ def func(size):
    data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2)
    self.assertAllClose(p.data, data)

  def test_broadcast_in_dim_ragged_to_static_error(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      # Broadcast should error even if the target shape is the same as the
      # underlying data shape, because the semantic size doesn't match.
      two_d = jax.lax.broadcast_in_dim(one_d, (4, 5), (1,))
      return two_d
    msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)"
    with self.assertRaisesRegex(TypeError, msg):
      jax.vmap(func, out_axes=batching.pile_axis)(ins)

  def test_broadcast_in_dim_to_doubly_ragged(self):
    ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    ins2 = lax.convert_element_type(jnp.array([2, 5, 1]), core.bint(6))
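For reference, a hedged rendering of the message tail the new test matches, produced by filling the shape-rule template above with the test's arguments (the clause before the ellipsis is not shown in this diff):

TypeError: ...equal to their corresponding dimensions in the target broadcast
shape; got operand of shape ([dynamic],), target broadcast shape (4, 5),
broadcast_dimensions (1,)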
