diff --git a/jax/_src/core.py b/jax/_src/core.py
index f113f7cc7e9c..5f16d223b5cf 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1241,10 +1241,10 @@ def definitely_equal(x, y):
     return same_referent(x, y)
   elif x is y:
     return True
-  else:
-    handler, ds = _dim_handler_and_canonical(x, y)
-    return handler.symbolic_equal(*ds)
-
+  try:
+    return x == y
+  except InconclusiveDimensionOperation:
+    return False
 
 # -------------------- abstract values --------------------
 
@@ -1837,122 +1837,13 @@ class InconclusiveDimensionOperation(Exception):
   """Raised when we cannot conclusively compute with symbolic dimensions."""
   pass
 
-class DimensionHandler:
-  """Operations on dimension sizes.
-
-  Dimension sizes are normally integer constants, but can also be symbolic,
-  e.g., masking.Poly or jax2tf.shape_poly.DimVar.
-
-  The base class works for integers only. Subclasses are invoked when at least
-  one of the operands has a type registered in _SPECIAL_DIMENSION_HANDLERS. In
-  that case, all operands are guaranteed to be either the special dimension
-  type, or Python integer scalars.
-
-  Subclasses should raise InconclusiveDimensionOperation if the result cannot
-  be computed in some contexts.
-  """
-  def is_constant(self, d: DimSize) -> bool:
-    """The dimension is a constant."""
-    return True
-
-  def symbolic_equal(self, d1: DimSize, d2: DimSize) -> bool:
-    """True iff the dimension sizes are equal in all contexts; False otherwise.
-    Unlike `d1 == d2` this never raises InconclusiveDimensionOperation.
-    """
-    return d1 == d2
-
-  def greater_equal(self, d1: DimSize, d2: DimSize) -> bool:
-    """Computes `d1 >= d2`.
-    Raise InconclusiveDimensionOperation if the result is different in
-    different contexts.
-    """
-    return d1 >= d2
-
-  def sum(self, *ds: DimSize) -> DimSize:
-    """Sum of dimensions.
-    Raises InconclusiveDimensionOperation if the result cannot be represented
-    by the same DimSize in all contexts.
-    """
-    return sum(ds)
-
-  def diff(self, d1: DimSize, d2: DimSize) -> DimSize:
-    """Difference of dimensions.
-    Raises InconclusiveDimensionOperation if the result cannot be represented
-    by the same DimSize in all contexts.
-    """
-    return d1 - d2
-
-  def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
-    """Computes integer "i" such that i * size(s2) == size(s1).
-
-    Raise InconclusiveDimensionOperation if there is no such integer for all
-    contexts,
-    """
-    sz1 = math.prod(s1)
-    sz2 = math.prod(s2)
-    if sz1 == 0 and sz2 == 0:
-      return 1
-    if sz1 % sz2:
-      raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
-    return sz1 // sz2
-
-  def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
-    """(d - window_size) // window_stride + 1.
-
-    If d == 0 or window_size > d, returns 0.
-    """
-    if d == 0 or window_size > d: return 0
-    return (d - window_size) // window_stride + 1
-
-  def dilate(self, d: DimSize, dilation: int) -> DimSize:
-    """Implements `d if dilation == 1 else (0 if d == 0 else 1 + dilation * (d - 1)))`"""
-    if definitely_equal(dilation, 1):
-      return d
-    return 0 if d == 0 else 1 + dilation * (d - 1)
-
-  def as_value(self, d: DimSize):
-    """Turns a dimension size into a JAX value that we can compute with."""
-    return d
-
-_dimension_handler_int = DimensionHandler()
-_SPECIAL_DIMENSION_HANDLERS: dict[type, DimensionHandler] = {}
-DArrayDimHandler = type('DArrayDimHandler', (DimensionHandler,), {})()
+def is_symbolic_dim(v: Any) -> bool:
+  """Checks if a value is a symbolic dimension used for shape polymorphism.
 
-def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
-  if isinstance(dim, Tracer) and not config.jax_dynamic_shapes:
-    return None
-  if isinstance(dim, DArray) and not dim.shape and type(dim.dtype) is bint:
-    return DArrayDimHandler
-  return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))
-
-def _dim_handler_and_canonical(*dlist: DimSize) -> tuple[DimensionHandler, tuple[DimSize, ...]]:
-  """Finds the handler for the given dimensions; also returns the canonical dimensions.
-
-  A dimension is canonical if it is a Python integer scalar, or has a type
-  registered in _SPECIAL_DIMENSION_HANDLERS.
+  This should be used very rarely, because symbolic dimensions overload all
+  operators, and should just work.
   """
-  special_handlers = set()
-  canonical = []
-  for d in dlist:
-    handler = _get_special_dim_handler(d)
-    if handler:
-      special_handlers.add(handler)
-      canonical.append(d)
-    else:
-      try:
-        canonical.append(operator.index(d))
-      except TypeError:
-        raise _invalid_shape_error(dlist)
-
-  if len(special_handlers) > 1:
-    msg = (f"Dimension size operation involves multiple special dimension types {dlist}")
-    raise ValueError(msg)
-  return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
-
-def is_dynamic_dim(v: Any) -> bool:
-  """Checks if a value is a dynamic DimSize."""
-  handler = _get_special_dim_handler(v)
-  return (handler is not None)
+  return hasattr(v, "dimension_as_value")
 
 def is_constant_dim(d: DimSize) -> bool:
   # Whether the dimension is a static integer constant.
@@ -1963,7 +1854,7 @@ def is_constant_dim(d: DimSize) -> bool:
     return False
 
 def is_dim(v: Any) -> bool:
-  return is_dynamic_dim(v) or is_constant_dim(v)
+  return is_symbolic_dim(v) or is_constant_dim(v)
 
 def is_constant_shape(s: Shape) -> bool:
   # Whether the shape is a static constant.
@@ -1981,39 +1872,17 @@ def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
   return (len(s1) == len(s2) and
           all(unsafe_map(definitely_equal, s1, s2)))
 
-def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
-  handler, ds = _dim_handler_and_canonical(d1, d2)
-  return handler.symbolic_equal(*ds) or handler.greater_equal(*ds)
-
-def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
-  return all(map(greater_equal_dim, s1, s2))
-
-def sum_dim(*ds: DimSize) -> DimSize:
-  handler, ds = _dim_handler_and_canonical(*ds)
-  return handler.sum(*ds)
-
-def sum_shapes(*ss: Shape) -> Shape:
-  return tuple(map(sum_dim, *ss))
-
-def diff_dim(d1: DimSize, d2: DimSize) -> DimSize:
-  handler, ds = _dim_handler_and_canonical(d1, d2)
-  return handler.diff(*ds)
-
-def diff_shape(s1: Shape, s2: Shape) -> Shape:
-  return tuple(map(diff_dim, s1, s2))
-
 def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
   """Returns an integer "i" s.t., i * size(s2) == size(s1).
 
-  Raises if there is no such integer."""
-  s1 = s1 or (1,)
-  s2 = s2 or (1,)
-  handler, ds = _dim_handler_and_canonical(*s1, *s2)
-  return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
-
-def same_shape_sizes(s1: Shape, s2: Shape) -> bool:
-  maybe_result = cancel_divide_tracers(s1, s2)
-  if maybe_result is not None: return maybe_result == 1
-  return 1 == divide_shape_sizes(s1, s2)
+  Raises InconclusiveDimensionOperation if there is no such integer."""
+  sz1 = math.prod(s1)
+  sz2 = math.prod(s2)
+  if definitely_equal(sz1, sz2):  # Takes care of sz1 and sz2 being 0
+    return 1
+  q, r = divmod(sz1, sz2)
+  if isinstance(r, Tracer) or r != 0:
+    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
+  return q
 
 def cancel_divide_tracers(num, denom):
   partition = lambda l: partition_list([isinstance(d, Tracer) for d in l], l)
@@ -2041,23 +1910,28 @@ def is_empty_shape(s: Shape) -> bool:
   return any(definitely_equal(d, 0) for d in s)
 
 def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
-  """Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
-  handler, ds = _dim_handler_and_canonical(d, dilation)
-  return handler.dilate(*ds)
+  """1 + dilation * (d - 1).
 
-def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape:
-  return tuple(map(dilate_dim, s, dilations))
+  If d == 0, returns 0.
+  """
+  if definitely_equal(dilation, 1):  # fast path
+    return d
+  return 0 if d == 0 else 1 + dilation * (d - 1)
 
 def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
-  handler, ds = _dim_handler_and_canonical(d, window_size, window_stride)
-  return handler.stride(*ds)
+  """(d - window_size) // window_stride + 1.
 
-def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
-  """(s - window_size) // window_stride + 1"""
-  return tuple(map(stride_dim, s, window_size, window_stride))
+  If d < window_size, returns 0.
+  We assume window_size >= 1 and window_stride >= 1.
+  """
+  if is_constant_dim(d) and is_constant_dim(window_size):
+    # TODO(necula): Enable this check for non-constant dimensions
+    if d < window_size:
+      return 0
+  return (d - window_size) // window_stride + 1
 
 def dimension_as_value(d: DimSize):
-  """Turns a dimension size into a JAX value that we can compute with.
+  """Turns a dimension size into a JAX array.
 
   This is the identity function for constant dimensions.
   Has the same abstract value as Python constants.
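# ===========================================================================
# Editorial example (an illustrative sketch, not part of the patch). It shows
# the new `definitely_equal` contract: `==` on symbolic dimensions may raise,
# and the inconclusive case is mapped to False. `SymDim` is a hypothetical
# stand-in for a real symbolic dimension type such as shape_poly._DimExpr;
# only the shape of the logic mirrors the patched core.definitely_equal.

class InconclusiveDimensionOperation(Exception):
  """Stand-in for jax._src.core.InconclusiveDimensionOperation."""

class SymDim:
  """Hypothetical symbolic dimension whose `==` can be inconclusive."""
  def __init__(self, name):
    self.name = name
  def __eq__(self, other):
    if isinstance(other, SymDim) and other.name == self.name:
      return True
    raise InconclusiveDimensionOperation(f"'{self.name}' == '{other}'?")

def definitely_equal(x, y):
  # Mirrors the patched core.definitely_equal (minus the Tracer fast paths).
  if x is y:
    return True
  try:
    return x == y
  except InconclusiveDimensionOperation:
    return False

a = SymDim("a")
assert definitely_equal(a, a)            # same object: True
assert definitely_equal(a, SymDim("a"))  # provably equal: True
assert not definitely_equal(a, 2)        # inconclusive: reported as False
# ===========================================================================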
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index da64912254fe..84ef1db87439 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -15,7 +15,6 @@
 
 from collections import namedtuple
 from contextlib import contextmanager, AbstractContextManager
-import functools
 from functools import partial
 import inspect
 import itertools as it
@@ -2452,44 +2451,6 @@ def _substitute_vars_in_type(
   else:
     return a
 
-
-class DimensionHandlerTracer(core.DimensionHandler):
-  """See core.DimensionHandler.
-
-  Most methods are inherited.
-  """
-  def is_constant(self, d: core.DimSize) -> bool:
-    assert isinstance(d, Tracer)
-    return False
-
-  def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
-    return d1 is d2
-
-  def greater_equal(self, d1: core.DimSize, d2: core.DimSize):
-    raise core.InconclusiveDimensionOperation("TODO")
-
-  def divide_shape_sizes(self, s1: core.Shape, s2: core.Shape) -> core.DimSize:
-    """Computes integer "i" such that i * size(s2) == size(s1).
-
-    Raise InconclusiveDimensionOperation if there is no such integer for all
-    contexts.
-    """
-    s1_size = functools.reduce(op.mul, s1, 1)
-    s2_size = functools.reduce(op.mul, s2, 1)
-    q, r = divmod(s1_size, s2_size)
-    # TODO(necula): must check that r == 0!
-    return q
-
-  def stride(self, d: core.DimSize, window_size: core.DimSize, window_stride: core.DimSize) -> core.DimSize:
-    """Implements `(d - window_size) // window_stride + 1`"""
-    raise core.InconclusiveDimensionOperation("TODO")
-
-  def as_value(self, d: core.DimSize):
-    """Turns a dimension size into a Jax value that we can compute with."""
-    raise core.InconclusiveDimensionOperation("TODO")
-
-core._SPECIAL_DIMENSION_HANDLERS[DynamicJaxprTracer] = DimensionHandlerTracer()
-
 Const = Any
 Val = Any
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index ef81a4a86d30..d23d2fe20749 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -984,7 +984,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
      type(linear) is tuple and all(type(x) is bool for x in linear))
   tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0)
-  tc(length, 'length', 'non-negative int', core.greater_equal_dim(length, 0))
+  tc(length, 'length', 'non-negative int', length >= 0)
 
   if len(linear) != len(avals):
     raise core.JaxprTypeError(
diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py
index e13cbd4fc954..6811b1e8587b 100644
--- a/jax/_src/lax/convolution.py
+++ b/jax/_src/lax/convolution.py
@@ -797,9 +797,7 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
   if np.any(lhs_padded < 0):
     raise ValueError("Negative padding is larger than the size of the corresponding dimension: "
                      f"got padding={pads} for lhs_shape[2:]={lhs_shape[2:]}")
-  out_space = core.stride_shape(lhs_padded, rhs_shape[2:], strides)
-  out_space = [d if core.greater_equal_dim(d, 0) else 0
-               for d in out_space]
+  out_space = tuple(map(core.stride_dim, lhs_padded, rhs_shape[2:], strides))
   if batch_group_count > 1:
     assert lhs_shape[0] % batch_group_count == 0
     out_shape_0 = lhs_shape[0] // batch_group_count
@@ -930,8 +928,7 @@ def _conv_general_vjp_rhs_padding(
   rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
   out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
   pads_lo, _ = util.unzip2(padding)
-  pads_from_lhs = core.diff_shape(out_dilated_shape, lhs_dilated_shape)
-  pads_from_rhs = core.diff_shape(core.diff_shape(rhs_dilated_shape, pads_lo),
-                                  (1,) * len(pads_lo))
-  pads_hi = core.sum_shapes(pads_from_lhs, pads_from_rhs)
+  pads_from_lhs = map(operator.sub, out_dilated_shape, lhs_dilated_shape)
+  pads_from_rhs = tuple(rd - pd - 1 for rd, pd in zip(rhs_dilated_shape, pads_lo))
+  pads_hi = tuple(map(operator.add, pads_from_lhs, pads_from_rhs))
   return list(zip(pads_lo, pads_hi))
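# ===========================================================================
# Editorial example (an illustrative sketch, not part of the patch). The conv
# output-space computation now maps core.stride_dim over the padded spatial
# dims, i.e. (d - window) // stride + 1 per dimension; the separate clamp to 0
# is gone because stride_dim already returns 0 when the window does not fit.
# `stride_dim` below re-states the constant-dimension case of core.stride_dim.

def stride_dim(d, window_size, window_stride):
  if d < window_size:
    return 0
  return (d - window_size) // window_stride + 1

lhs_padded = (10, 7)    # spatial dims after padding
rhs_spatial = (3, 3)    # filter spatial dims
strides = (1, 2)
out_space = tuple(map(stride_dim, lhs_padded, rhs_spatial, strides))
assert out_space == (8, 3)       # (10-3)//1+1 == 8, (7-3)//2+1 == 3
assert stride_dim(2, 3, 1) == 0  # window larger than the input dim
# ===========================================================================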
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index c2bc7f299dc6..9fa35fc349b5 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3169,9 +3169,9 @@ def _pad_shape_rule(operand, padding_value, *, padding_config):
   if not all(i >= 0 for _, _, i in padding_config):
     raise ValueError("interior padding in padding_config must be nonnegative, "
                      f"got padding_config {padding_config}")
-  result = tuple(core.sum_dim(l, h, core.dilate_dim(d, i + 1))
+  result = tuple(l + h + core.dilate_dim(d, i + 1)
                  for (l, h, i), d in zip(padding_config, op_shape))
-  if not all(core.greater_equal_dim(d, 0) for d in result):
+  if not all(d >= 0 for d in result):
     msg = (f"Dimension size after padding is not at least 0, "
            f"got result shape {result}, for padding_config {padding_config}"
            f" and operand shape {op_shape}")
@@ -3298,12 +3298,12 @@ def shape_as_value(shape: core.Shape):
   return concatenate(dims, dimension=0)
 
 def _reshape_shape_rule(operand, *, new_sizes, dimensions):
-  if not all(core.greater_equal_dim(d, 0) for d in new_sizes):
+  if not all(d >= 0 for d in new_sizes):
     msg = 'reshape new_sizes must all be positive, got {}.'
     raise TypeError(msg.format(new_sizes))
   # TODO(necula): re-enable this check
   if (not config.jax_dynamic_shapes and
-      not core.same_shape_sizes(np.shape(operand), new_sizes)):
+      not math.prod(np.shape(operand)) == math.prod(new_sizes)):
     msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
     raise TypeError(msg.format(new_sizes, np.shape(operand)))
   if dimensions is not None:
@@ -3822,7 +3822,7 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype):
   axis, = axes
   if not (0 <= axis < len(operand.shape)):
     raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
-  if not core.greater_equal_dim(operand.shape[axis], 1):
+  if operand.shape[axis] < 1:
     raise ValueError("argmin and argmax require non-empty reduced dimension. "
                      f"operand.shape={operand.shape} {axis=}")
   return tuple(np.delete(operand.shape, axis))
@@ -4661,7 +4661,7 @@ def _dilate_shape(shape, dilation):
     msg = "All dilations must be positive, got {}."
     raise TypeError(msg.format(dilation))
   dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
-  return core.dilate_shape(shape, dilation)
+  return tuple(map(core.dilate_dim, shape, dilation))
 
 def _ceil_divide(x1, x2):
   return -np.floor_divide(np.negative(x1), x2)
@@ -4803,7 +4803,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
     raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
   lower_bound, bound_error = (
       (1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
-  if not all(core.greater_equal_dim(d, lower_bound) for d in obj_arr):
+  if not all(d >= lower_bound for d in obj_arr):
     msg = "{} {} must have every element be {}, got {}."
    raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
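# ===========================================================================
# Editorial example (an illustrative sketch, not part of the patch). The pad
# shape rule now computes each output dim as lo + hi + dilate_dim(d, i + 1),
# with dilate_dim(d, k) == 0 if d == 0 else 1 + k*(d - 1). `dilate_dim` below
# re-states the constant-dimension case of core.dilate_dim.

def dilate_dim(d, dilation):
  if dilation == 1:  # fast path, mirrors the patched core.dilate_dim
    return d
  return 0 if d == 0 else 1 + dilation * (d - 1)

padding_config = [(1, 2, 0), (0, 0, 1)]  # (lo, hi, interior) per dimension
op_shape = (4, 3)
result = tuple(l + h + dilate_dim(d, i + 1)
               for (l, h, i), d in zip(padding_config, op_shape))
assert result == (7, 5)  # 1+2+4 == 7; 0+0+(1 + 2*(3-1)) == 5
# ===========================================================================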
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 6e019853c7f1..01ae535a9aad 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import enum
+import operator
 from functools import partial
 import math
 from typing import Callable, NamedTuple, Optional, Sequence, Union
@@ -1081,34 +1082,33 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
     msg = ("slice limit_indices must have the same length as start_indices, "
            "got start_indices {} and limit_indices {}.")
     raise TypeError(msg.format(start_indices, limit_indices))
-  if not core.greater_equal_shape(operand.shape, limit_indices):
+  if not all(map(operator.ge, operand.shape, limit_indices)):
     msg = ("slice limit_indices must be less than or equal to operand shape, "
            "got limit_indices {} for operand shape {}.")
     raise TypeError(msg.format(limit_indices, operand.shape))
-  if not all(core.greater_equal_dim(si, 0) for si in start_indices):
+  if not all(si >= 0 for si in start_indices):
     msg = ("slice start_indices must be greater than or equal to zero, "
            "got start_indices of {}.")
     raise TypeError(msg.format(start_indices))
   if not jax.config.jax_dynamic_shapes:
-    if not core.greater_equal_shape(limit_indices, start_indices):
+    if not all(map(operator.ge, limit_indices, start_indices)):
       msg = ("slice limit_indices must be greater than or equal to start_indices,"
              " got start_indices {} and limit_indices {}.")
       raise TypeError(msg.format(start_indices, limit_indices))
+  diff = tuple(map(operator.sub, limit_indices, start_indices))
   if strides is None or tuple(strides) == (1,) * len(operand.shape):
-    shape = [limit if type(start) is int and start == 0 else limit - start
-             for start, limit in zip(start_indices, limit_indices)]
-    return tuple(shape)
+    return diff
   lax._check_shapelike("slice", "strides", strides)
   if len(strides) != operand.ndim:
     msg = ("slice strides must have length equal to the number of dimensions "
            "of the operand, got strides {} for operand shape {}.")
     raise TypeError(msg.format(strides, operand.shape))
-  if not core.greater_equal_shape(strides, (0,) * len(strides)):
+  if not all(s >= 0 for s in strides):
     msg = "slice strides must be positive, got {}"
     raise TypeError(msg.format(strides))
-  diff = core.diff_shape(limit_indices, start_indices)
-  return core.stride_shape(diff, (1,) * len(diff), strides)
+  return tuple(core.stride_dim(d, window_size=1, window_stride=s)
+               for d, s in zip(diff, strides))
 
 def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
   assert ad.is_undefined_primal(operand)
@@ -1172,11 +1172,11 @@ def _dynamic_slice_shape_rule(
     msg = ("dynamic_slice slice_sizes must have the same length as "
            "start_indices, got start_indices length {} and slice_sizes {}.")
     raise TypeError(msg.format(len(start_indices), slice_sizes))
-  if not dyn and not core.greater_equal_shape(operand.shape, slice_sizes):
+  if not dyn and not all(map(operator.ge, operand.shape, slice_sizes)):
     msg = ("slice slice_sizes must be less than or equal to operand shape, "
            "got slice_sizes {} for operand shape {}.")
     raise TypeError(msg.format(slice_sizes, operand.shape))
-  if not dyn and not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes):
+  if not dyn and not all(ssz >= 0 for ssz in slice_sizes):
     msg = ("slice slice_sizes must be greater than or equal to zero, "
            "got slice_sizes of {}.")
     raise TypeError(msg.format(slice_sizes))
@@ -1321,7 +1321,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
     msg = ("dynamic_update_slice start_indices must have length equal to the "
            "rank of operand, got indices {} for operand shape {}.")
     raise TypeError(msg.format(start_indices, operand.shape))
-  if not core.greater_equal_shape(operand.shape, update.shape):
+  if not all(map(operator.ge, operand.shape, update.shape)):
     msg = ("dynamic_update_slice update shape must be smaller than operand "
            "shape, got update shape {} for operand shape {}.")
     raise TypeError(msg.format(update.shape, operand.shape))
@@ -1522,8 +1522,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
       slice_size = slice_sizes[i]
       corresponding_input_size = operand.shape[i]
 
-      if not (core.greater_equal_dim(slice_size, 0) and
-              core.greater_equal_dim(corresponding_input_size, slice_size)):
+      if not (slice_size >= 0 and
+              corresponding_input_size >= slice_size):
         raise TypeError(f"Slice size at index {i} in gather op is out of range, "
                         f"must be within [0, {corresponding_input_size} + 1), "
                         f"got {slice_size}.")
@@ -1917,7 +1917,7 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr,
 
   for i in range(len(update_window_dims)):
     update_window_dim = update_window_dims[i]
-    if not core.greater_equal_dim(max_update_slice_sizes[i], updates.shape[update_window_dim]):
+    if max_update_slice_sizes[i] < updates.shape[update_window_dim]:
       raise TypeError(f"Bounds of the window dimensions of updates must not "
                       f"exceed the bounds of the corresponding dimensions of "
                       f"operand. For dimension {update_window_dim}, updates "
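# ===========================================================================
# Editorial example (an illustrative sketch, not part of the patch). With
# window_size == 1, core.stride_dim(d, 1, s) reduces to (d - 1) // s + 1,
# i.e. ceil(d / s), which is exactly the length of a strided slice over
# d = limit - start elements. A quick check against Python's range:

def strided_slice_len(start, limit, stride):
  d = limit - start
  return (d - 1) // stride + 1 if d >= 1 else 0

for start, limit, stride in [(0, 10, 1), (0, 10, 3), (2, 9, 2), (5, 5, 2)]:
  assert strided_slice_len(start, limit, stride) == len(range(start, limit, stride))
# ===========================================================================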
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 96b5d2ea34e0..f0bf3246c8d1 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -438,9 +438,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
     operand_shape = lax._dilate_shape(operand_shape, base_dilation)
   if window_dilation is not None:
     window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
-  pads_lo, pads_hi = util.unzip2(padding)
-  operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
-  return core.stride_shape(operand_padded, window_dimensions, window_strides)
+  operand_padded = tuple(d + pl + ph for d, (pl, ph) in zip(operand_shape, padding))
+  return tuple(map(core.stride_dim, operand_padded, window_dimensions, window_strides))
 
 reduce_window_max_p = lax.standard_primitive(
     _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max')
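# ===========================================================================
# Editorial example (an illustrative sketch, not part of the patch). The
# reduce_window output shape is now assembled inline: pad each operand dim
# with its (lo, hi) pair, then map core.stride_dim over windows and strides.
# `stride_dim` below re-states the constant-dimension case of core.stride_dim.

def stride_dim(d, window_size, window_stride):
  if d < window_size:
    return 0
  return (d - window_size) // window_stride + 1

operand_shape = (8, 6)
padding = ((1, 1), (0, 0))
window_dimensions = (3, 2)
window_strides = (2, 2)

operand_padded = tuple(d + pl + ph for d, (pl, ph) in zip(operand_shape, padding))
out = tuple(map(stride_dim, operand_padded, window_dimensions, window_strides))
assert operand_padded == (10, 6)
assert out == (4, 3)  # (10-3)//2+1 == 4, (6-2)//2+1 == 3
# ===========================================================================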
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 4b5ad9b16331..2887a4749261 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2311,7 +2311,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
   for name, val in [(start_name, start), ("stop", stop), ("step", step)]:
     if val is not None and np.ndim(val) != 0:
       raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}")
-  if any(core.is_dynamic_dim(v) for v in (start, stop, step)):
+  if any(core.is_symbolic_dim(v) for v in (start, stop, step)):
     # Some dynamic shapes
     if stop is None and step is None:
       stop = start
@@ -2628,7 +2628,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
     axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
     assert isinstance(axis, int)  # to appease mypy
 
-  if core.is_dynamic_dim(repeats):
+  if core.is_symbolic_dim(repeats):
     if total_repeat_length is not None:
       raise ValueError("jnp.repeat with a non-constant `repeats` is supported only "
                        f"when `total_repeat_length` is None. ({repeats=} {total_repeat_length=})")
@@ -4389,7 +4389,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
       if start is None or core.definitely_equal(start, 0):
         start = None
       if stop is None or (not isinstance(stop, core.Tracer) and
-                          core.greater_equal_dim(stop, x_shape[x_axis])):
+                          stop >= x_shape[x_axis]):
         stop = None
     elif core.definitely_equal(step, -1):
       step = -1
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index 39a3d4b94448..7a31ac22d477 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -94,7 +94,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
 
   if initial is None and not has_identity:
     shape = np.shape(a)
-    if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
+    if not _all(shape[d] >= 1 for d in pos_dims):
       raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
 
   result_dtype = dtype or dtypes.dtype(a)
diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py
index 9b0b17608886..38dbb6a0f78e 100644
--- a/jax/experimental/jax2tf/shape_poly.py
+++ b/jax/experimental/jax2tf/shape_poly.py
@@ -440,15 +440,18 @@ def eq(self, other: DimSize) -> bool:
       # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
       return False
 
+  def inconclusive_comparison(self, operation: str, op: Any) -> Exception:
+    return InconclusiveDimensionOperation(
+      f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n"
+      "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.")
+
   def ge(self, other: DimSize) -> bool:
     lb, ub = _ensure_poly(self - other, "ge").bounds()
     if lb >= 0:
       return True
     if ub < 0:
       return False
-    raise InconclusiveDimensionOperation(
-      f"Symbolic dimension comparison '{self}' >= '{other}' is inconclusive.\n"
-      "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic0dimensions-is-partially-supported.")
+    raise self.inconclusive_comparison(">=", other)
 
   def __hash__(self):
     return hash(tuple(sorted(self.monomials())))
@@ -571,13 +574,22 @@ def __ne__(self, other: DimSize) -> bool:
   __ge__ = ge
 
   def __le__(self, other: DimSize):
-    return _ensure_poly(other, "le").__ge__(self)
+    try:
+      return _ensure_poly(other, "le").__ge__(self)
+    except InconclusiveDimensionOperation as e:
+      raise self.inconclusive_comparison("<=", other) from e
 
   def __gt__(self, other: DimSize):
-    return not _ensure_poly(other, "gt").__ge__(self)
+    try:
+      return not _ensure_poly(other, "gt").__ge__(self)
+    except InconclusiveDimensionOperation as e:
+      raise self.inconclusive_comparison(">", other) from e
 
   def __lt__(self, other: DimSize):
-    return not self.__ge__(other)
+    try:
+      return not self.__ge__(other)
+    except InconclusiveDimensionOperation as e:
+      raise self.inconclusive_comparison("<", other) from e
 
   def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]:
     """
@@ -735,56 +747,6 @@ def _convertible_to_poly(p: DimSize) -> bool:
 def is_poly_dim(p: DimSize) -> bool:
   return isinstance(p, _DimExpr)
 
-
-class DimensionHandlerPoly(core.DimensionHandler):
-  """See core.DimensionHandler.
-
-  Most methods are inherited.
-  """
-  def is_constant(self, d: DimSize) -> bool:
-    assert isinstance(d, _DimExpr)
-    return False
-
-  def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
-    try:
-      return _ensure_poly(d1, "equal") == d2
-    except InconclusiveDimensionOperation:
-      return False
-
-  def greater_equal(self, d1: DimSize, d2: DimSize):
-    return _ensure_poly(d1, "ge") >= d2
-
-  def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
-    sz1 = math.prod(s1)
-    sz2 = math.prod(s2)
-    if core.definitely_equal(sz1, sz2):  # Takes care also of sz1 == sz2 == 0
-      return 1
-    err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
-    try:
-      q, r = _ensure_poly(sz1, "divide_shape").divmod(_ensure_poly(sz2, "divide_shape"))
-    except InconclusiveDimensionOperation as e:
-      raise InconclusiveDimensionOperation(err_msg + f"\nDetails: {e}")
-    if not core.definitely_equal(r, 0):
-      raise InconclusiveDimensionOperation(err_msg + f"\nRemainder is not zero: {r}")
-    return q  # type: ignore[return-value]
-
-  def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
-    """Implements `(d - window_size) // window_stride + 1`"""
-    try:
-      # TODO(necula): check for d == 0 or window_size > d and return 0.
-      q, r = _ensure_poly(d - window_size, "stride").divmod(_ensure_poly(window_stride, "stride"))
-      return q + 1
-    except InconclusiveDimensionOperation as e:
-      raise InconclusiveDimensionOperation(
-          f"Cannot compute stride for dimension '{d}', "
-          f"window_size '{window_size}', stride '{window_stride}'.\nDetails: {e}.")
-    return d
-
-  def as_value(self, d: DimSize):
-    """Turns a dimension size into a Jax value that we can compute with."""
-    return _dim_as_value(d)
-
-core._SPECIAL_DIMENSION_HANDLERS[_DimExpr] = DimensionHandlerPoly()
 dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]
 
 def _einsum_contract_path(*operands, **kwargs):
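# ===========================================================================
# Editorial example (an illustrative sketch, not part of the patch), mirroring
# the updated tests below. It assumes the jax2tf shape_poly internals as of
# this patch (shape_poly._parse_spec is an internal API). Comparisons on
# symbolic dimensions either decide from the expression's bounds or raise
# InconclusiveDimensionOperation, and <=, <, > now report the operator that
# was actually used.

from jax.experimental.jax2tf import shape_poly
from jax._src import core

a, b = shape_poly._parse_spec("a, b", (2, 3))

assert a >= 1              # dimension variables are at least 1: decidable
assert 8 <= 4 * a + b + 3  # int <= poly reflects into _DimExpr.__ge__

try:
  a >= 2                   # a could be 1 or larger: not decidable
  raise AssertionError("expected InconclusiveDimensionOperation")
except core.InconclusiveDimensionOperation:
  pass  # "Symbolic dimension comparison 'a' >= '2' is inconclusive."
# ===========================================================================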
- """ - def is_constant(self, d: DimSize) -> bool: - assert isinstance(d, _DimExpr) - return False - - def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool: - try: - return _ensure_poly(d1, "equal") == d2 - except InconclusiveDimensionOperation: - return False - - def greater_equal(self, d1: DimSize, d2: DimSize): - return _ensure_poly(d1, "ge") >= d2 - - def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize: - sz1 = math.prod(s1) - sz2 = math.prod(s2) - if core.definitely_equal(sz1, sz2): # Takes care also of sz1 == sz2 == 0 - return 1 - err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}" - try: - q, r = _ensure_poly(sz1, "divide_shape").divmod(_ensure_poly(sz2, "divide_shape")) - except InconclusiveDimensionOperation as e: - raise InconclusiveDimensionOperation(err_msg + f"\nDetails: {e}") - if not core.definitely_equal(r, 0): - raise InconclusiveDimensionOperation(err_msg + f"\nRemainder is not zero: {r}") - return q # type: ignore[return-value] - - def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize: - """Implements `(d - window_size) // window_stride + 1`""" - try: - # TODO(necula): check for d == 0 or window_size > d and return 0. - q, r = _ensure_poly(d - window_size, "stride").divmod(_ensure_poly(window_stride, "stride")) - return q + 1 - except InconclusiveDimensionOperation as e: - raise InconclusiveDimensionOperation( - f"Cannot compute stride for dimension '{d}', " - f"window_size '{window_size}', stride '{window_stride}'.\nDetails: {e}.") - return d - - def as_value(self, d: DimSize): - """Turns a dimension size into a Jax value that we can compute with.""" - return _dim_as_value(d) - -core._SPECIAL_DIMENSION_HANDLERS[_DimExpr] = DimensionHandlerPoly() dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index e5e0fb3e0ad3..2c0c68d82c66 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -40,6 +40,7 @@ from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.lib import version as jaxlib_version from jax._src.lib import xla_client import numpy as np @@ -178,9 +179,7 @@ def test_dim_vars_symbolic_equal(self): self.assertFalse(core.definitely_equal_one_of_dim(3, [])) self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # A DeviceArray - with self.assertRaisesRegex(TypeError, - re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")): - self.assertTrue(core.definitely_equal(1, "a")) + self.assertFalse(core.definitely_equal(1, "a")) def test_poly_bounds(self): a, b = shape_poly._parse_spec("a, b", (2, 3)) @@ -312,12 +311,24 @@ def test_poly_compare(self): def test_poly_compare_overload(self): a, b = shape_poly._parse_spec("a, b", (2, 3)) + self.assertTrue(a >= a) + self.assertTrue(a >= 0) + self.assertTrue(a >= 1) + + with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"): + a >= 2 + poly = 4 * a + b + 3 self.assertTrue(poly >= 0) self.assertTrue(poly >= 8) self.assertTrue(poly > 7) self.assertTrue(poly >= poly) self.assertTrue(poly >= poly - 1) + # LHS is an integer + self.assertTrue(8 <= poly) + self.assertTrue(7 < poly) + self.assertTrue(-8 >= -poly) + self.assertTrue(-7 > -poly) with 
self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"): poly >= 9 @@ -325,22 +336,6 @@ def test_poly_compare_overload(self): with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"): (4 * a - b) >= 0 - def test_core_greater_equal(self): - a, b = shape_poly._parse_spec("a, b", (2, 3)) - self.assertTrue(core.greater_equal_dim(a, a)) - self.assertTrue(core.greater_equal_dim(a, 0)) - self.assertTrue(core.greater_equal_dim(a, 1)) - - self.assertTrue(core.greater_equal_shape((a, 2), (1, 1))) - - with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Symbolic dimension comparison .* is inconclusive"): - core.greater_equal_dim(a, 2) - - with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Symbolic dimension comparison .* is inconclusive"): - core.greater_equal_dim(a, b) - def test_poly_int_results(self): # Whenever the result is an integer, it should be represented as an # Python integer, not a symbolic dimension. @@ -377,28 +372,32 @@ def test_poly_divmod(self, *, dividend, quotient, divisor, remainder): else: self.assertEqual((quotient, remainder), divmod(dividend, divisor)) - def test_dilate_shape(self): + def test_dilate_dim(self): """0 if d == 0 else 1 + dilation * (d - 1))""" a, = shape_poly._parse_spec("a,", (2,)) - self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3))) - self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3))) - self.assertEqual((a, 7), core.dilate_shape((a, 3), (1, 3))) - self.assertEqual((2 * a - 1, 7), core.dilate_shape((a, 3), (2, 3))) + self.assertEqual(4, core.dilate_dim(2, 3)) + self.assertEqual(7, core.dilate_dim(3, 3)) + self.assertEqual(0, core.dilate_dim(0, 3)) + self.assertEqual(a, core.dilate_dim(a, 1)) + self.assertEqual(2 * a - 1, core.dilate_dim(a, 2)) - def test_stride_shape(self): - """(s - window_size) // window_stride + 1""" - a, stride = shape_poly._parse_spec("a, s", (2, 3)) + def test_stride_dim(self): + """(d - window_size) // window_stride + 1 - self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2))) - self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2))) + If d == 0 or window_size > d, returns 0. 
+ """ + a, stride = shape_poly._parse_spec("a, s", (2, 3)) - self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2))) - self.assertEqual((a + 1, 9), core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2))) + self.assertEqual(8, core.stride_dim(10, window_size=3, window_stride=1)) + self.assertEqual(9, core.stride_dim(20, window_size=3, window_stride=2)) + self.assertEqual(9, core.stride_dim(20, window_size=4, window_stride=2)) + self.assertEqual(a, core.stride_dim(a, window_size=1, window_stride=1)) - (stride0, stride1) = core.stride_shape((a, 20), (1, 3), (2, 2)) - self.assertEqual("floordiv(a + -1, 2) + 1", str(stride0)) - self.assertEqual(9, stride1) + self.assertEqual(a - 1, core.stride_dim(a, window_size=2, window_stride=1)) + self.assertEqual(a + 1, core.stride_dim(a * stride + 2, window_size=2, + window_stride=stride)) + self.assertEqual((a - 1) // 2 + 1, core.stride_dim(a, 1, 2)) class PolyHarness(Harness): @@ -2902,6 +2901,17 @@ def test_harness(self, harness: PolyHarness): # https://github.com/openxla/stablehlo/issues/1255: need DynamicTopK raise unittest.SkipTest("native lowering with shape polymorphism not implemented for top_k") + # Some tests need the latest jaxlib + need_new_jaxlib = [] + if jaxlib_version < (0, 4, 13): + need_new_jaxlib.append("fft") + elif jaxlib_version < (0, 4, 14): + need_new_jaxlib.extend(("lu", "vmap_lu", "custom_linear_solve", + "vmap_custom_linear_solve", + "vmap_approx_top_k", "schur")) + if harness.group_name in need_new_jaxlib: + raise unittest.SkipTest("native lowering with shape polymorphism needs newer jaxlib") + if (jtu.device_under_test() in ["cpu", "gpu"] and harness.fullname in [ "cumsum_reduce_axis=poly", "cumprod_reduce_axis=poly", diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 6cb3d9ae4c86..0428a66dcbb4 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -21,7 +21,6 @@ ConstVar as ConstVar, DCERule as DCERule, DebugInfo as DebugInfo, - DimensionHandlerTracer as DimensionHandlerTracer, DynamicJaxprTrace as DynamicJaxprTrace, DynamicJaxprTracer as DynamicJaxprTracer, ForwardingRule as ForwardingRule, diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 71fbcaa91117..91e654b0d9c3 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -621,11 +621,6 @@ def test_flattening_basic(self): jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12), abstracted_axes={0: 'n'})(x) self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) - # do need divide, also shouldn't typecheck - _ = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], x.shape[0], -1), - abstracted_axes={0: 'n'})(x) # don't crash - - @unittest.skip("Test does not work with jax.Array") @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") class DynamicShapeAutodiffTest(jtu.JaxTestCase):