
Merge pull request #16601 from gnecula:clean_api
PiperOrigin-RevId: 545137395
jax authors committed Jul 3, 2023
2 parents 404e306 + 9261eda commit 658e8ff
Showing 21 changed files with 105 additions and 110 deletions.
2 changes: 1 addition & 1 deletion jax/_src/api.py
@@ -1300,7 +1300,7 @@ def _get_axis_size(name: str, shape: tuple[core.AxisSize, ...], axis: int
(sz, ct), *other_counts = counts = size_counts.most_common()
def _all_sizes_index(sz):
for i, isz in enumerate(all_sizes):
if core.symbolic_equal_dim(isz, sz): return i
if core.definitely_equal(isz, sz): return i
assert False, (sz, all_sizes)

ex, *examples = [key_paths[_all_sizes_index(sz)] for sz, _ in counts]
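
For reference, the hunk above is in the error-reporting path that fires when vmap receives arrays whose mapped axes disagree. A minimal sketch of how that surfaces to users (illustrative only, not part of this commit):

# Illustrative sketch, not part of this diff: mismatched mapped axis sizes
# exercise the _get_axis_size error path edited above.
import jax
import jax.numpy as jnp

try:
  jax.vmap(lambda a, b: a + b)(jnp.ones((3,)), jnp.ones((4,)))
except ValueError as e:
  print(e)  # vmap reports inconsistent sizes (3 vs 4) for the mapped axes
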
47 changes: 25 additions & 22 deletions jax/_src/core.py
@@ -1239,8 +1239,11 @@ def dedup_referents(itr: Iterable[Any]) -> list[Any]:
def definitely_equal(x, y):
if isinstance(x, Tracer) or isinstance(y, Tracer):
return same_referent(x, y)
elif x is y:
return True
else:
return symbolic_equal_dim(x, y)
handler, ds = _dim_handler_and_canonical(x, y)
return handler.symbolic_equal(*ds)


# -------------------- abstract values --------------------
@@ -1388,6 +1391,12 @@ def concrete_or_error(force: Any, val: Any, context=""):
else:
return force(val)

def concrete_dim_or_error(val: Any, context=""):
"""Like concrete_or_error(operator.index)."""
if is_dim(val):
return val
else:
return concrete_or_error(operator.index, val, context=context)

### Opaque dtypes
#
@@ -1555,7 +1564,7 @@ def at_least_vspace(self):
self.weak_type, self.named_shape)

def join(self, other):
if symbolic_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return self.update(weak_type=weak_type, named_shape=named_shape)
@@ -1709,7 +1718,7 @@ def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))

def join(self, other):
if (symbolic_equal_shape(self.shape, other.shape) and
if (definitely_equal_shape(self.shape, other.shape) and
self.dtype == other.dtype):
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
@@ -1897,7 +1906,7 @@ def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> Di

def dilate(self, d: DimSize, dilation: int) -> DimSize:
"""Implements `d if dilation == 1 else (0 if d == 0 else 1 + dilation * (d - 1)))`"""
if symbolic_equal_dim(dilation, 1):
if definitely_equal(dilation, 1):
return d
return 0 if d == 0 else 1 + dilation * (d - 1)

@@ -1940,8 +1949,8 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> tuple[DimensionHandler, tuple
raise ValueError(msg)
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)

def is_special_dim_size(v: Any) -> bool:
"""Checks if a value is a special DimSize."""
def is_dynamic_dim(v: Any) -> bool:
"""Checks if a value is a dynamic DimSize."""
handler = _get_special_dim_handler(v)
return (handler is not None)

Expand All @@ -1954,27 +1963,21 @@ def is_constant_dim(d: DimSize) -> bool:
return False

def is_dim(v: Any) -> bool:
return is_special_dim_size(v) or is_constant_dim(v)
return is_dynamic_dim(v) or is_constant_dim(v)

def is_constant_shape(s: Shape) -> bool:
# Whether the shape is a static constant.
return all(is_constant_dim(d) for d in s)

def symbolic_equal_dim(d1: DimSize, d2: DimSize) -> bool:
if d1 is d2 or same_referent(d1, d2): return True
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.symbolic_equal(*ds)

def symbolic_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
if any(d1 is d or same_referent(d1, d) for d in dlist): return True # identical always implies equal
handler, ds = _dim_handler_and_canonical(d1, *dlist)
return any([handler.symbolic_equal(ds[0], d) for d in ds[1:]])

def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
return (len(s1) == len(s2) and
all(unsafe_map(symbolic_equal_dim, s1, s2)))
def definitely_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
return any(definitely_equal(d1, d) for d in dlist)

def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
"""Check that two shapes are guaranteed to be element-wise equal.
In presence of dynamic shapes may return False even when the shapes may
be equal at runtime.
"""
return (len(s1) == len(s2) and
all(unsafe_map(definitely_equal, s1, s2)))

@@ -2035,7 +2038,7 @@ def _cancel_divide(num, denom):
return math.prod(num)

def is_empty_shape(s: Shape) -> bool:
return any(symbolic_equal_dim(d, 0) for d in s)
return any(definitely_equal(d, 0) for d in s)

def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
@@ -2075,7 +2078,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
elif (config.jax_dynamic_shapes and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
return dim
elif is_special_dim_size(dim):
elif is_dim(dim):
return dim
else:
raise type_error
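
The core.py hunks above fold symbolic_equal_dim/symbolic_equal_shape into definitely_equal/definitely_equal_shape and rename is_special_dim_size to is_dynamic_dim. A minimal sketch of the renamed helpers on concrete dimensions (illustrative only; jax._src.core is an internal module, so its exact surface may differ across versions):

# Illustrative sketch, not part of this diff; assumes the helpers shown in the
# hunks above are importable from the internal module jax._src.core.
from jax._src import core

assert core.definitely_equal(3, 3)                  # concrete dims compare by value
assert not core.definitely_equal(3, 4)
assert core.definitely_equal_one_of_dim(3, [1, 3])
assert core.definitely_equal_shape((2, 3), (2, 3))
assert core.is_constant_dim(5) and core.is_dim(5)
assert not core.is_dynamic_dim(5)                   # a Python int is not a dynamic DimSize
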
6 changes: 3 additions & 3 deletions jax/_src/image/scale.py
@@ -253,7 +253,7 @@ def _resize_nearest(x, output_shape: core.Shape):
input_shape = x.shape
assert len(input_shape) == len(output_shape)
spatial_dims = tuple(i for i in range(len(input_shape))
if not core.symbolic_equal_dim(input_shape[i], output_shape[i]))
if not core.definitely_equal(input_shape[i], output_shape[i]))
for d in spatial_dims:
m = input_shape[d]
n = output_shape[d]
@@ -286,8 +286,8 @@ def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
# since all of the current resize methods (kernels) are interpolating, so the
# output = input under an identity warp.
spatial_dims = tuple(i for i in range(len(shape))
if not core.symbolic_equal_dim(image.shape[i], shape[i]))
scale = [1.0 if core.symbolic_equal_dim(shape[d], 0) else core.dimension_as_value(shape[d]) / core.dimension_as_value(image.shape[d])
if not core.definitely_equal(image.shape[i], shape[i]))
scale = [1.0 if core.definitely_equal(shape[d], 0) else core.dimension_as_value(shape[d]) / core.dimension_as_value(image.shape[d])
for d in spatial_dims]
return _scale_and_translate(image, shape, spatial_dims,
scale, [0.] * len(spatial_dims), kernel,
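
The resize helpers above now use definitely_equal to skip axes whose input and output sizes already match. Illustrative use of the public entry point (not part of this diff):

# Illustrative only: the leading axis (size 4) matches the output shape and is
# therefore excluded from spatial_dims by the checks above.
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)
y = jax.image.resize(x, shape=(4, 8), method="nearest")
print(y.shape)  # (4, 8)
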
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
@@ -477,7 +477,7 @@ def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
assert core.symbolic_equal_shape(primal_aval.shape, tangent_aval.shape)
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)

2 changes: 1 addition & 1 deletion jax/_src/interpreters/batching.py
@@ -889,7 +889,7 @@ def broadcast_batcher(prim, args, dims, **params):
assert len(args) > 1
shape, dim = next((x.shape, d) for x, d in zip(args, dims)
if d is not not_mapped)
if all(core.symbolic_equal_shape(shape, x.shape) and d == dim
if all(core.definitely_equal_shape(shape, x.shape) and d == dim
for x, d in zip(args, dims) if np.ndim(x)):
# if there's only agreeing batch dims and scalars, just call the primitive
out = prim.bind(*args, **params)
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
@@ -1344,7 +1344,7 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext,
out = []
for op, op_aval in zip(ops, ops_avals):
op_aval_shape = op_aval.shape # type: ignore
if core.symbolic_equal_shape(op_aval_shape, out_shape): # type: ignore
if core.definitely_equal_shape(op_aval_shape, out_shape): # type: ignore
out.append(op)
else:
assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
3 changes: 1 addition & 2 deletions jax/_src/lax/control_flow/loops.py
@@ -1906,7 +1906,7 @@ def combine(a_flat, b_flat):
# Check that all inputs have a consistent leading dimension `num_elems`.
axis = util.canonicalize_axis(axis, elems_flat[0].ndim)

if core.is_special_dim_size(elems_flat[0].shape[axis]):
if not core.is_constant_dim(elems_flat[0].shape[axis]):
raise NotImplementedError("associative scan over axis "
f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
"able to avoid this on TPU. See b/274176030.")
@@ -2016,7 +2016,6 @@ def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
return [cumsum(t, axis=axis, reverse=not reverse)]



def cumred_reduce_window_impl(window_reduce: Callable, x, *, axis: int,
reverse: bool):
n = x.shape[axis]
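
The loops.py change above replaces the is_special_dim_size test with is_constant_dim: associative_scan still requires the scanned axis to have a known constant size. Illustrative usage (not part of this diff):

# Illustrative only: scanning over a constant-sized axis; a non-constant
# (shape-polymorphic) axis size would hit the NotImplementedError above.
import jax.numpy as jnp
from jax import lax

print(lax.associative_scan(jnp.add, jnp.arange(5)))  # [ 0  1  3  6 10]
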
2 changes: 1 addition & 1 deletion jax/_src/lax/convolution.py
@@ -367,7 +367,7 @@ def _conv_general_dilated_shape_rule(
msg = ("conv_general_dilated feature_group_count must divide lhs feature "
"dimension size, but {} does not divide {}.")
raise ValueError(msg.format(feature_group_count, lhs_feature_count))
if not core.symbolic_equal_dim(quot, rhs.shape[dimension_numbers.rhs_spec[1]]):
if not core.definitely_equal(quot, rhs.shape[dimension_numbers.rhs_spec[1]]):
msg = ("conv_general_dilated lhs feature dimension size divided by "
"feature_group_count must equal the rhs input feature dimension "
"size, but {} // {} != {}.")
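
The convolution.py check above verifies, via definitely_equal, that the lhs feature size divided by feature_group_count equals the rhs input-feature size. A hedged sketch of a grouped convolution that satisfies it (illustrative only; shapes chosen for this example):

# Illustrative only: with feature_group_count=2, lhs features (8) / 2 must
# equal the rhs per-group input features (4), the equality checked above.
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 8, 10))   # (batch, features, spatial)
rhs = jnp.ones((16, 4, 3))   # (out features, in features per group, spatial)
out = lax.conv_general_dilated(lhs, rhs, window_strides=(1,), padding="SAME",
                               feature_group_count=2)
print(out.shape)  # (1, 16, 10)
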
34 changes: 17 additions & 17 deletions jax/_src/lax/lax.py
@@ -690,7 +690,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
Returns:
An array containing the product.
"""
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.symbolic_equal_dim(lhs.shape[-1], rhs.shape[0]):
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision,
preferred_element_type=preferred_element_type)
@@ -843,7 +843,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
"""
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
same_shape = core.definitely_equal_shape(np.shape(operand), new_sizes)
if dimensions is None:
same_dims = True
dims = None
@@ -1175,7 +1175,7 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
- :func:`jax.lax.approx_max_k`
- :func:`jax.lax.approx_min_k`
"""
if not core.is_special_dim_size(k):
if core.is_constant_dim(k):
k = int(k)
if k < 0:
raise ValueError(f"k argument to top_k must be nonnegative, got {k}")
@@ -1574,10 +1574,10 @@ def broadcasting_shape_rule(name, *avals):
result_shape.append(ds[0])
else:
# if all dims are equal (or 1), the result is the non-1 size
non_1s = [d for d in ds if not core.symbolic_equal_dim(d, 1)]
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
if not non_1s:
result_shape.append(1)
elif all(core.symbolic_equal_dim(non_1s[0], d) for d in non_1s[1:]):
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
result_shape.append(non_1s[0])
else:
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
@@ -1632,25 +1632,25 @@ def _unbroadcast(aval, x):
if not isinstance(aval, (core.DShapedArray, ShapedArray)):
raise TypeError("transpose with implicit broadcasting of unshaped values")
x_shape = np.shape(x)
if core.symbolic_equal_shape(aval.shape, x_shape):
if core.definitely_equal_shape(aval.shape, x_shape):
return x
assert not aval.shape or len(x_shape) == len(aval.shape)
if not aval.shape:
return _reduce_sum(x, list(range(len(x_shape))))
else:
dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.symbolic_equal_dim(a, b)]
dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
if config.jax_enable_checks: assert all(aval.shape[i] == 1 for i in dims)
return reshape(_reduce_sum(x, dims), aval.shape)

def _maybe_broadcast(target_shape, x):
x_shape = np.shape(x)
if core.symbolic_equal_shape(x_shape, target_shape):
if core.definitely_equal_shape(x_shape, target_shape):
return x
elif not x_shape:
return broadcast_in_dim(x, target_shape, ())
else:
dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape))
if core.symbolic_equal_dim(a, b)]
if core.definitely_equal(a, b)]
squeeze_shape = [x_shape[i] for i in dims]
return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)

@@ -2492,13 +2492,13 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
raise TypeError(msg.format(rhs_batch, rhs_contracting))
lhs_batch_shape = tuple(lhs.shape[i] for i in lhs_batch)
rhs_batch_shape = tuple(rhs.shape[i] for i in rhs_batch)
if not core.symbolic_equal_shape(lhs_batch_shape, rhs_batch_shape):
if not core.definitely_equal_shape(lhs_batch_shape, rhs_batch_shape):
msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
"to have the same shape, got {} and {}.")
raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
lhs_contracting_shape = tuple(lhs.shape[i] for i in lhs_contracting)
rhs_contracting_shape = tuple(rhs.shape[i] for i in rhs_contracting)
if not core.symbolic_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
if not core.definitely_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
msg = ("dot_general requires contracting dimensions to have the same "
"shape, got {} and {}.")
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
@@ -2800,8 +2800,8 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
'dimensions, got {} for operand ndim {} and shape {}.')
raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))
if not all(core.symbolic_equal_one_of_dim(operand.shape[i],
[1, shape[broadcast_dimensions[i]]])
if not all(core.definitely_equal_one_of_dim(operand.shape[i],
[1, shape[broadcast_dimensions[i]]])
for i in range(operand_ndim)):
msg = (
"broadcast_in_dim operand dimension sizes must either be 1, or be "
@@ -2836,7 +2836,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
if type(ct) is ad_util.Zero:
return [ad_util.Zero(operand.aval)]
unit_dims = [i for i, s in enumerate(operand.aval.shape)
if core.symbolic_equal_dim(s, 1)]
if core.definitely_equal(s, 1)]
bdims = tuple(np.delete(broadcast_dimensions, unit_dims))
axes = tuple(np.delete(range(len(shape)), bdims))
return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
@@ -2897,7 +2897,7 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,

def _broadcast_in_dim_fwd_rule(eqn):
v, *dyn = eqn.invars
if not dyn and core.symbolic_equal_shape(eqn.params['shape'], v.aval.shape):
if not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape):
return [v], None
else:
return [None], eqn
@@ -3251,7 +3251,7 @@ def _compute_squeeze_shape(shape, dimensions):
raise ValueError(f"dimensions are not unique: {dimensions}")
if not all(0 <= d < len(shape) for d in dims_set):
raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
if any(not core.symbolic_equal_dim(shape[d], 1) for d in dimensions):
if any(not core.definitely_equal(shape[d], 1) for d in dimensions):
raise ValueError(
"cannot select an axis to squeeze out which has size not equal to "
f"one, got {shape=} and {dimensions=}")
@@ -4176,7 +4176,7 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
def _top_k_lower(ctx, operand, k):
if core.is_special_dim_size(k):
if not core.is_constant_dim(k):
# TODO: https://github.com/openxla/stablehlo/issues/1396
raise ValueError("native serialization with shape polymorphism not implemented for top_k")
return chlo.TopKOp(operand, mlir.i64_attr(k)).results
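
Among the lax.py changes above, top_k now uses is_constant_dim to decide whether k is a static value. Illustrative usage with a constant k (not part of this diff):

# Illustrative only: with a constant k, top_k returns the k largest values and
# their indices along the last axis.
import jax.numpy as jnp
from jax import lax

values, indices = lax.top_k(jnp.array([1.0, 5.0, 3.0, 2.0]), k=2)
print(values)   # [5. 3.]
print(indices)  # [1 2]
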
4 changes: 2 additions & 2 deletions jax/_src/lax/slicing.py
@@ -1315,7 +1315,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
# the number of slices is zero. Likely the best fix would be to change the
# definition of gather() so it can be batched without the construction of
# an explicit iota of size-1 slices.
if core.symbolic_equal_dim(operand.shape[0], 0):
if core.definitely_equal(operand.shape[0], 0):
output_shape = _gather_shape_rule(
core.ShapedArray(operand.shape[1:], operand.dtype),
core.ShapedArray(indices.shape[1:],
@@ -1557,7 +1557,7 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr,
for i in update_scatter_dims:
if scatter_dims_seen == index_vector_dim:
scatter_dims_seen += 1
if not core.symbolic_equal_dim(updates.shape[i], expanded_indices_shape[scatter_dims_seen]):
if not core.definitely_equal(updates.shape[i], expanded_indices_shape[scatter_dims_seen]):
raise TypeError(f"Bounds of the scatter dimensions of updates must be "
f"the same as the bounds of the corresponding dimensions "
f"of scatter indices. For scatter dimension {i}, updates "
14 changes: 6 additions & 8 deletions jax/_src/nn/functions.py
@@ -412,10 +412,9 @@ def normalize(x: Array,
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Array, num_classes: int, *,
dtype: Any, axis: Union[int, AxisName]) -> Array:
if not core.is_special_dim_size(num_classes):
num_classes = core.concrete_or_error(
int, num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
dtype = dtypes.canonicalize_dtype(dtype)
x = jnp.asarray(x)
try:
@@ -459,10 +458,9 @@ def one_hot(x: Array, num_classes: int, *,
axis: the axis or axes along which the function should be
computed.
"""
if not core.is_special_dim_size(num_classes):
num_classes = core.concrete_or_error(
int, num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
return _one_hot(x, num_classes, dtype=dtype, axis=axis)


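
The nn/functions.py hunks above route num_classes validation through the new core.concrete_dim_or_error helper. Illustrative usage of the public API (not part of this diff):

# Illustrative only: num_classes must be concrete (or a dimension expression);
# a traced value would raise the concretization error referenced above.
import jax.numpy as jnp
from jax import nn

print(nn.one_hot(jnp.array([0, 2, 1]), num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
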
2 changes: 1 addition & 1 deletion jax/_src/numpy/array_methods.py
@@ -130,7 +130,7 @@ def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
if sz is not None:
return (*newshape[:i], sz, *newshape[i+1:])
return tuple(-core.divide_shape_sizes(np.shape(a), newshape)
if core.symbolic_equal_dim(d, -1) else d
if core.definitely_equal(d, -1) else d
for d in newshape)


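
The array_methods.py change above resolves a -1 entry in a requested shape by dividing out the known sizes, now via definitely_equal. Illustrative usage (not part of this diff):

# Illustrative only: a single -1 in the new shape is inferred from the
# remaining dimensions, which is what _compute_newshape implements.
import jax.numpy as jnp

x = jnp.arange(12)
print(x.reshape(3, -1).shape)  # (3, 4)
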
