From 71ac0bb4463f8f3f155449b25cadbc8c7508d3d4 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 11 Jul 2023 14:03:52 +0100 Subject: [PATCH] [shape_poly] More cleanup for the internal APIs for shape polymorphism. Previously we had a number of APIs in core.py that operated on dimensions and shapes and delegated to instances of DimensionHandler. We remove most of those APIs because by now they ended up doing very little, e.g., `core.sum_dim` was the same as `operator.add`, and `core.sum_shapes` was the same as `tuple(map(operator.add))`. We also remove the whole `DimensionHandler` machinery because by now the only other use of non-constant dimensions through this mechanism is the symbolic dimensions used for shape polymorphism, and those now support full operator overloading. (When we introduced `DimensionHandler`, the masking transformation was still around and also needed it.) --- jax/_src/core.py | 196 ++++-------------- jax/_src/interpreters/partial_eval.py | 39 ---- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/convolution.py | 11 +- jax/_src/lax/lax.py | 14 +- jax/_src/lax/slicing.py | 30 +-- jax/_src/lax/windowed_reductions.py | 5 +- jax/_src/numpy/lax_numpy.py | 6 +- jax/_src/numpy/reductions.py | 2 +- jax/experimental/jax2tf/shape_poly.py | 74 ++----- .../jax2tf/tests/shape_poly_test.py | 78 ++++--- jax/interpreters/partial_eval.py | 1 - tests/dynamic_api_test.py | 5 - 13 files changed, 130 insertions(+), 333 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index f113f7cc7e9c..5f16d223b5cf 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1241,10 +1241,10 @@ def definitely_equal(x, y): return same_referent(x, y) elif x is y: return True - else: - handler, ds = _dim_handler_and_canonical(x, y) - return handler.symbolic_equal(*ds) - + try: + return x == y + except InconclusiveDimensionOperation: + return False # -------------------- abstract values -------------------- @@ -1837,122 +1837,13 @@ class InconclusiveDimensionOperation(Exception): """Raised when we cannot conclusively compute with symbolic dimensions.""" pass -class DimensionHandler: - """Operations on dimension sizes. - - Dimension sizes are normally integer constants, but can also be symbolic, - e.g., masking.Poly or jax2tf.shape_poly.DimVar. - - The base class works for integers only. Subclasses are invoked when at least - one of the operands has a type registered in _SPECIAL_DIMENSION_HANDLERS. In - that case, all operands are guaranteed to be either the special dimension - type, or Python integer scalars. - - Subclasses should raise InconclusiveDimensionOperation if the result cannot - be computed in some contexts. - """ - def is_constant(self, d: DimSize) -> bool: - """The dimension is a constant.""" - return True - - def symbolic_equal(self, d1: DimSize, d2: DimSize) -> bool: - """True iff the dimension sizes are equal in all contexts; False otherwise. - Unlike `d1 == d2` this never raises InconclusiveDimensionOperation. - """ - return d1 == d2 - - def greater_equal(self, d1: DimSize, d2: DimSize) -> bool: - """Computes `d1 >= d2`. - Raise InconclusiveDimensionOperation if the result is different in - different contexts. - """ - return d1 >= d2 - - def sum(self, *ds: DimSize) -> DimSize: - """Sum of dimensions. - Raises InconclusiveDimensionOperation if the result cannot be represented - by the same DimSize in all contexts. - """ - return sum(ds) - - def diff(self, d1: DimSize, d2: DimSize) -> DimSize: - """Difference of dimensions.
- Raises InconclusiveDimensionOperation if the result cannot be represented - by the same DimSize in all contexts. - """ - return d1 - d2 - - def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize: - """Computes integer "i" such that i * size(s2) == size(s1). - - Raise InconclusiveDimensionOperation if there is no such integer for all - contexts, - """ - sz1 = math.prod(s1) - sz2 = math.prod(s2) - if sz1 == 0 and sz2 == 0: - return 1 - if sz1 % sz2: - raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}") - return sz1 // sz2 - - def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize: - """(d - window_size) // window_stride + 1. - - If d == 0 or window_size > d, returns 0. - """ - if d == 0 or window_size > d: return 0 - return (d - window_size) // window_stride + 1 - - def dilate(self, d: DimSize, dilation: int) -> DimSize: - """Implements `d if dilation == 1 else (0 if d == 0 else 1 + dilation * (d - 1)))`""" - if definitely_equal(dilation, 1): - return d - return 0 if d == 0 else 1 + dilation * (d - 1) - - def as_value(self, d: DimSize): - """Turns a dimension size into a JAX value that we can compute with.""" - return d - -_dimension_handler_int = DimensionHandler() -_SPECIAL_DIMENSION_HANDLERS: dict[type, DimensionHandler] = {} -DArrayDimHandler = type('DArrayDimHandler', (DimensionHandler,), {})() +def is_symbolic_dim(v: Any) -> bool: + """Checks if a value is a symbolic dimension used for shape polymorphism. -def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]: - if isinstance(dim, Tracer) and not config.jax_dynamic_shapes: - return None - if isinstance(dim, DArray) and not dim.shape and type(dim.dtype) is bint: - return DArrayDimHandler - return _SPECIAL_DIMENSION_HANDLERS.get(type(dim)) - -def _dim_handler_and_canonical(*dlist: DimSize) -> tuple[DimensionHandler, tuple[DimSize, ...]]: - """Finds the handler for the given dimensions; also returns the canonical dimensions. - - A dimension is canonical if it is a Python integer scalar, or has a type - registered in _SPECIAL_DIMENSION_HANDLERS. + This should be used very rarely, because symbolic dimensions overload all + operators, and should just work. """ - special_handlers = set() - canonical = [] - for d in dlist: - handler = _get_special_dim_handler(d) - if handler: - special_handlers.add(handler) - canonical.append(d) - else: - try: - canonical.append(operator.index(d)) - except TypeError: - raise _invalid_shape_error(dlist) - - if len(special_handlers) > 1: - msg = (f"Dimension size operation involves multiple special dimension types {dlist}") - raise ValueError(msg) - return next(iter(special_handlers), _dimension_handler_int), tuple(canonical) - -def is_dynamic_dim(v: Any) -> bool: - """Checks if a value is a dynamic DimSize.""" - handler = _get_special_dim_handler(v) - return (handler is not None) + return hasattr(v, "dimension_as_value") def is_constant_dim(d: DimSize) -> bool: # Whether the dimension is a static integer constant. @@ -1963,7 +1854,7 @@ def is_constant_dim(d: DimSize) -> bool: return False def is_dim(v: Any) -> bool: - return is_dynamic_dim(v) or is_constant_dim(v) + return is_symbolic_dim(v) or is_constant_dim(v) def is_constant_shape(s: Shape) -> bool: # Whether the shape is a static constant. 
@@ -1981,39 +1872,17 @@ def definitely_equal_shape(s1: Shape, s2: Shape) -> bool: return (len(s1) == len(s2) and all(unsafe_map(definitely_equal, s1, s2))) -def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool: - handler, ds = _dim_handler_and_canonical(d1, d2) - return handler.symbolic_equal(*ds) or handler.greater_equal(*ds) - -def greater_equal_shape(s1: Shape, s2: Shape) -> bool: - return all(map(greater_equal_dim, s1, s2)) - -def sum_dim(*ds: DimSize) -> DimSize: - handler, ds = _dim_handler_and_canonical(*ds) - return handler.sum(*ds) - -def sum_shapes(*ss: Shape) -> Shape: - return tuple(map(sum_dim, *ss)) - -def diff_dim(d1: DimSize, d2: DimSize) -> DimSize: - handler, ds = _dim_handler_and_canonical(d1, d2) - return handler.diff(*ds) - -def diff_shape(s1: Shape, s2: Shape) -> Shape: - return tuple(map(diff_dim, s1, s2)) - def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize: """Returns an integer "i" s.t., i * size(s2) == size(s1). - Raises if there is no such integer.""" - s1 = s1 or (1,) - s2 = s2 or (1,) - handler, ds = _dim_handler_and_canonical(*s1, *s2) - return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):]) - -def same_shape_sizes(s1: Shape, s2: Shape) -> bool: - maybe_result = cancel_divide_tracers(s1, s2) - if maybe_result is not None: return maybe_result == 1 - return 1 == divide_shape_sizes(s1, s2) + Raises InconclusiveDimensionOperation if there is no such integer.""" + sz1 = math.prod(s1) + sz2 = math.prod(s2) + if definitely_equal(sz1, sz2): # Takes care of sz1 and sz2 being 0 + return 1 + q, r = divmod(sz1, sz2) + if isinstance(r, Tracer) or r != 0: + raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}") + return q def cancel_divide_tracers(num, denom): partition = lambda l: partition_list([isinstance(d, Tracer) for d in l], l) @@ -2041,23 +1910,28 @@ def is_empty_shape(s: Shape) -> bool: return any(definitely_equal(d, 0) for d in s) def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize: - """Implements `0 if d == 0 else 1 + dilation * (d - 1))`""" - handler, ds = _dim_handler_and_canonical(d, dilation) - return handler.dilate(*ds) + """1 + dilation * (d - 1). -def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape: - return tuple(map(dilate_dim, s, dilations)) + if d == 0, returns 0. + """ + if definitely_equal(dilation, 1): # fast path + return d + return 0 if d == 0 else 1 + dilation * (d - 1) def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize: - handler, ds = _dim_handler_and_canonical(d, window_size, window_stride) - return handler.stride(*ds) + """(d - window_size) // window_stride + 1 -def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape: - """(s - window_size) // window_stride + 1""" - return tuple(map(stride_dim, s, window_size, window_stride)) + If d < window_size, returns 0. + We assume window_size >= 1 and window_stride >= 1. + """ + if is_constant_dim(d) and is_constant_dim(window_size): + # TODO(necula): Enable this check for non-constant dimensions + if d < window_size: + return 0 + return (d - window_size) // window_stride + 1 def dimension_as_value(d: DimSize): - """Turns a dimension size into a JAX value that we can compute with. + """Turns a dimension size into a JAX array. This is the identity function for constant dimensions. Has the same abstract value as Python constants. 
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index da64912254fe..84ef1db87439 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,6 @@ from collections import namedtuple from contextlib import contextmanager, AbstractContextManager -import functools from functools import partial import inspect import itertools as it @@ -2452,44 +2451,6 @@ def _substitute_vars_in_type( else: return a - -class DimensionHandlerTracer(core.DimensionHandler): - """See core.DimensionHandler. - - Most methods are inherited. - """ - def is_constant(self, d: core.DimSize) -> bool: - assert isinstance(d, Tracer) - return False - - def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool: - return d1 is d2 - - def greater_equal(self, d1: core.DimSize, d2: core.DimSize): - raise core.InconclusiveDimensionOperation("TODO") - - def divide_shape_sizes(self, s1: core.Shape, s2: core.Shape) -> core.DimSize: - """Computes integer "i" such that i * size(s2) == size(s1). - - Raise InconclusiveDimensionOperation if there is no such integer for all - contexts. - """ - s1_size = functools.reduce(op.mul, s1, 1) - s2_size = functools.reduce(op.mul, s2, 1) - q, r = divmod(s1_size, s2_size) - # TODO(necula): must check that r == 0! - return q - - def stride(self, d: core.DimSize, window_size: core.DimSize, window_stride: core.DimSize) -> core.DimSize: - """Implements `(d - window_size) // window_stride + 1`""" - raise core.InconclusiveDimensionOperation("TODO") - - def as_value(self, d: core.DimSize): - """Turns a dimension size into a Jax value that we can compute with.""" - raise core.InconclusiveDimensionOperation("TODO") - -core._SPECIAL_DIMENSION_HANDLERS[DynamicJaxprTracer] = DimensionHandlerTracer() - Const = Any Val = Any diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index ef81a4a86d30..d23d2fe20749 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -984,7 +984,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, type(linear) is tuple and all(type(x) is bool for x in linear)) tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) - tc(length, 'length', 'non-negative int', core.greater_equal_dim(length, 0)) + tc(length, 'length', 'non-negative int', length >= 0) if len(linear) != len(avals): raise core.JaxprTypeError( diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index e13cbd4fc954..6811b1e8587b 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -797,9 +797,7 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): if np.any(lhs_padded < 0): raise ValueError("Negative padding is larger than the size of the corresponding dimension: " f"got padding={pads} for lhs_shape[2:]={lhs_shape[2:]}") - out_space = core.stride_shape(lhs_padded, rhs_shape[2:], strides) - out_space = [d if core.greater_equal_dim(d, 0) else 0 - for d in out_space] + out_space = tuple(map(core.stride_dim, lhs_padded, rhs_shape[2:], strides)) if batch_group_count > 1: assert lhs_shape[0] % batch_group_count == 0 out_shape_0 = lhs_shape[0] // batch_group_count @@ -930,8 +928,7 @@ def _conv_general_vjp_rhs_padding( rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation) out_dilated_shape = lax._dilate_shape(out_shape, window_strides) pads_lo, _ = util.unzip2(padding) - pads_from_lhs = core.diff_shape(out_dilated_shape, lhs_dilated_shape) - 
pads_from_rhs = core.diff_shape(core.diff_shape(rhs_dilated_shape, pads_lo), - (1,) * len(pads_lo)) - pads_hi = core.sum_shapes(pads_from_lhs, pads_from_rhs) + pads_from_lhs = map(operator.sub, out_dilated_shape, lhs_dilated_shape) + pads_from_rhs = tuple(rd - pd - 1 for rd, pd in zip(rhs_dilated_shape, pads_lo)) + pads_hi = tuple(map(operator.add, pads_from_lhs, pads_from_rhs)) return list(zip(pads_lo, pads_hi)) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c2bc7f299dc6..9fa35fc349b5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3169,9 +3169,9 @@ def _pad_shape_rule(operand, padding_value, *, padding_config): if not all(i >= 0 for _, _, i in padding_config): raise ValueError("interior padding in padding_config must be nonnegative, " f"got padding_config {padding_config}") - result = tuple(core.sum_dim(l, h, core.dilate_dim(d, i + 1)) + result = tuple(l + h + core.dilate_dim(d, i + 1) for (l, h, i), d in zip(padding_config, op_shape)) - if not all(core.greater_equal_dim(d, 0) for d in result): + if not all(d >= 0 for d in result): msg = (f"Dimension size after padding is not at least 0, " f"got result shape {result}, for padding_config {padding_config}" f" and operand shape {op_shape}") @@ -3298,12 +3298,12 @@ def shape_as_value(shape: core.Shape): return concatenate(dims, dimension=0) def _reshape_shape_rule(operand, *, new_sizes, dimensions): - if not all(core.greater_equal_dim(d, 0) for d in new_sizes): + if not all(d >= 0 for d in new_sizes): msg = 'reshape new_sizes must all be positive, got {}.' raise TypeError(msg.format(new_sizes)) # TODO(necula): re-enable this check if (not config.jax_dynamic_shapes and - not core.same_shape_sizes(np.shape(operand), new_sizes)): + not math.prod(np.shape(operand)) == math.prod(new_sizes)): msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.' raise TypeError(msg.format(new_sizes, np.shape(operand))) if dimensions is not None: @@ -3822,7 +3822,7 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype): axis, = axes if not (0 <= axis < len(operand.shape)): raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}") - if not core.greater_equal_dim(operand.shape[axis], 1): + if operand.shape[axis] < 1: raise ValueError("argmin and argmax require non-empty reduced dimension. " f"operand.shape={operand.shape} {axis=}") return tuple(np.delete(operand.shape, axis)) @@ -4661,7 +4661,7 @@ def _dilate_shape(shape, dilation): msg = "All dilations must be positive, got {}." raise TypeError(msg.format(dilation)) dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation) - return core.dilate_shape(shape, dilation) + return tuple(map(core.dilate_dim, shape, dilation)) def _ceil_divide(x1, x2): return -np.floor_divide(np.negative(x1), x2) @@ -4803,7 +4803,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False): raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err lower_bound, bound_error = ( (1, "strictly positive") if non_zero_shape else (0, "nonnegative")) - if not all(core.greater_equal_dim(d, lower_bound) for d in obj_arr): + if not all(d >= lower_bound for d in obj_arr): msg = "{} {} must have every element be {}, got {}." raise TypeError(msg.format(fun_name, arg_name, bound_error, obj)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 6e019853c7f1..01ae535a9aad 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -13,6 +13,7 @@ # limitations under the License. 
import enum +import operator from functools import partial import math from typing import Callable, NamedTuple, Optional, Sequence, Union @@ -1081,34 +1082,33 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): msg = ("slice limit_indices must have the same length as start_indices, " "got start_indices {} and limit_indices {}.") raise TypeError(msg.format(start_indices, limit_indices)) - if not core.greater_equal_shape(operand.shape, limit_indices): + if not all(map(operator.ge, operand.shape, limit_indices)): msg = ("slice limit_indices must be less than or equal to operand shape, " "got limit_indices {} for operand shape {}.") raise TypeError(msg.format(limit_indices, operand.shape)) - if not all(core.greater_equal_dim(si, 0) for si in start_indices): + if not all(si >= 0 for si in start_indices): msg = ("slice start_indices must be greater than or equal to zero, " "got start_indices of {}.") raise TypeError(msg.format(start_indices)) if not jax.config.jax_dynamic_shapes: - if not core.greater_equal_shape(limit_indices, start_indices): + if not all(map(operator.ge, limit_indices, start_indices)): msg = ("slice limit_indices must be greater than or equal to start_indices," " got start_indices {} and limit_indices {}.") raise TypeError(msg.format(start_indices, limit_indices)) + diff = tuple(map(operator.sub, limit_indices, start_indices)) if strides is None or tuple(strides) == (1,) * len(operand.shape): - shape = [limit if type(start) is int and start == 0 else limit - start - for start, limit in zip(start_indices, limit_indices)] - return tuple(shape) + return diff lax._check_shapelike("slice", "strides", strides) if len(strides) != operand.ndim: msg = ("slice strides must have length equal to the number of dimensions " "of the operand, got strides {} for operand shape {}.") raise TypeError(msg.format(strides, operand.shape)) - if not core.greater_equal_shape(strides, (0,) * len(strides)): + if not all(s >= 0 for s in strides): msg = "slice strides must be positive, got {}" raise TypeError(msg.format(strides)) - diff = core.diff_shape(limit_indices, start_indices) - return core.stride_shape(diff, (1,) * len(diff), strides) + return tuple(core.stride_dim(d, window_size=1, window_stride=s) + for d, s in zip(diff, strides)) def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) @@ -1172,11 +1172,11 @@ def _dynamic_slice_shape_rule( msg = ("dynamic_slice slice_sizes must have the same length as " "start_indices, got start_indices length {} and slice_sizes {}.") raise TypeError(msg.format(len(start_indices), slice_sizes)) - if not dyn and not core.greater_equal_shape(operand.shape, slice_sizes): + if not dyn and not all(map(operator.ge, operand.shape, slice_sizes)): msg = ("slice slice_sizes must be less than or equal to operand shape, " "got slice_sizes {} for operand shape {}.") raise TypeError(msg.format(slice_sizes, operand.shape)) - if not dyn and not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes): + if not dyn and not all(ssz >= 0 for ssz in slice_sizes): msg = ("slice slice_sizes must be greater than or equal to zero, " "got slice_sizes of {}.") raise TypeError(msg.format(slice_sizes)) @@ -1321,7 +1321,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): msg = ("dynamic_update_slice start_indices must have length equal to the " "rank of operand, got indices {} for operand shape {}.") raise TypeError(msg.format(start_indices, operand.shape)) - if not 
core.greater_equal_shape(operand.shape, update.shape): + if not all(map(operator.ge, operand.shape, update.shape)): msg = ("dynamic_update_slice update shape must be smaller than operand " "shape, got update shape {} for operand shape {}.") raise TypeError(msg.format(update.shape, operand.shape)) @@ -1522,8 +1522,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_size = slice_sizes[i] corresponding_input_size = operand.shape[i] - if not (core.greater_equal_dim(slice_size, 0) and - core.greater_equal_dim(corresponding_input_size, slice_size)): + if not (slice_size >= 0 and + corresponding_input_size >= slice_size): raise TypeError(f"Slice size at index {i} in gather op is out of range, " f"must be within [0, {corresponding_input_size} + 1), " f"got {slice_size}.") @@ -1917,7 +1917,7 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, for i in range(len(update_window_dims)): update_window_dim = update_window_dims[i] - if not core.greater_equal_dim(max_update_slice_sizes[i], updates.shape[update_window_dim]): + if max_update_slice_sizes[i] < updates.shape[update_window_dim]: raise TypeError(f"Bounds of the window dimensions of updates must not " f"exceed the bounds of the corresponding dimensions of " f"operand. For dimension {update_window_dim}, updates " diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 96b5d2ea34e0..f0bf3246c8d1 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -438,9 +438,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, operand_shape = lax._dilate_shape(operand_shape, base_dilation) if window_dilation is not None: window_dimensions = lax._dilate_shape(window_dimensions, window_dilation) - pads_lo, pads_hi = util.unzip2(padding) - operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi) - return core.stride_shape(operand_padded, window_dimensions, window_strides) + operand_padded = tuple(d + pl + ph for d, (pl, ph) in zip(operand_shape, padding)) + return tuple(map(core.stride_dim, operand_padded, window_dimensions, window_strides)) reduce_window_max_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max') diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b5ad9b16331..2887a4749261 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2311,7 +2311,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None, for name, val in [(start_name, start), ("stop", stop), ("step", step)]: if val is not None and np.ndim(val) != 0: raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}") - if any(core.is_dynamic_dim(v) for v in (start, stop, step)): + if any(core.is_symbolic_dim(v) for v in (start, stop, step)): # Some dynamic shapes if stop is None and step is None: stop = start @@ -2628,7 +2628,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *, axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()") assert isinstance(axis, int) # to appease mypy - if core.is_dynamic_dim(repeats): + if core.is_symbolic_dim(repeats): if total_repeat_length is not None: raise ValueError("jnp.repeat with a non-constant `repeats` is supported only " f"when `total_repeat_length` is None. 
({repeats=} {total_repeat_length=})") @@ -4389,7 +4389,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], if start is None or core.definitely_equal(start, 0): start = None if stop is None or (not isinstance(stop, core.Tracer) and - core.greater_equal_dim(stop, x_shape[x_axis])): + stop >= x_shape[x_axis]): stop = None elif core.definitely_equal(step, -1): step = -1 diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 39a3d4b94448..7a31ac22d477 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -94,7 +94,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: if initial is None and not has_identity: shape = np.shape(a) - if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims): + if not _all(shape[d] >= 1 for d in pos_dims): raise ValueError(f"zero-size array to reduction operation {name} which has no identity") result_dtype = dtype or dtypes.dtype(a) diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index 9b0b17608886..38dbb6a0f78e 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -440,15 +440,18 @@ def eq(self, other: DimSize) -> bool: # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported return False + def inconclusive_comparison(self, operation: str, op: Any) -> Exception: + return InconclusiveDimensionOperation( + f"Symbolic dimension comparison '{self}' {operation} '{op}' is inconclusive.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic0dimensions-is-partially-supported.") + def ge(self, other: DimSize) -> bool: lb, ub = _ensure_poly(self - other, "ge").bounds() if lb >= 0: return True if ub < 0: return False - raise InconclusiveDimensionOperation( - f"Symbolic dimension comparison '{self}' >= '{other}' is inconclusive.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic0dimensions-is-partially-supported.") + raise self.inconclusive_comparison(">=", other) def __hash__(self): return hash(tuple(sorted(self.monomials()))) @@ -571,13 +574,22 @@ def __ne__(self, other: DimSize) -> bool: __ge__ = ge def __le__(self, other: DimSize): - return _ensure_poly(other, "le").__ge__(self) + try: + return _ensure_poly(other, "le").__ge__(self) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison("<=", other) from e def __gt__(self, other: DimSize): - return not _ensure_poly(other, "gt").__ge__(self) + try: + return not _ensure_poly(other, "gt").__ge__(self) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison(">", other) from e def __lt__(self, other: DimSize): - return not self.__ge__(other) + try: + return not self.__ge__(other) + except InconclusiveDimensionOperation as e: + raise self.inconclusive_comparison("<", other) from e def divmod(self, divisor: "_DimExpr") -> tuple[DimSize, int]: """ @@ -735,56 +747,6 @@ def _convertible_to_poly(p: DimSize) -> bool: def is_poly_dim(p: DimSize) -> bool: return isinstance(p, _DimExpr) - -class DimensionHandlerPoly(core.DimensionHandler): - """See core.DimensionHandler. - - Most methods are inherited. 
- """ - def is_constant(self, d: DimSize) -> bool: - assert isinstance(d, _DimExpr) - return False - - def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool: - try: - return _ensure_poly(d1, "equal") == d2 - except InconclusiveDimensionOperation: - return False - - def greater_equal(self, d1: DimSize, d2: DimSize): - return _ensure_poly(d1, "ge") >= d2 - - def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize: - sz1 = math.prod(s1) - sz2 = math.prod(s2) - if core.definitely_equal(sz1, sz2): # Takes care also of sz1 == sz2 == 0 - return 1 - err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}" - try: - q, r = _ensure_poly(sz1, "divide_shape").divmod(_ensure_poly(sz2, "divide_shape")) - except InconclusiveDimensionOperation as e: - raise InconclusiveDimensionOperation(err_msg + f"\nDetails: {e}") - if not core.definitely_equal(r, 0): - raise InconclusiveDimensionOperation(err_msg + f"\nRemainder is not zero: {r}") - return q # type: ignore[return-value] - - def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize: - """Implements `(d - window_size) // window_stride + 1`""" - try: - # TODO(necula): check for d == 0 or window_size > d and return 0. - q, r = _ensure_poly(d - window_size, "stride").divmod(_ensure_poly(window_stride, "stride")) - return q + 1 - except InconclusiveDimensionOperation as e: - raise InconclusiveDimensionOperation( - f"Cannot compute stride for dimension '{d}', " - f"window_size '{window_size}', stride '{window_stride}'.\nDetails: {e}.") - return d - - def as_value(self, d: DimSize): - """Turns a dimension size into a Jax value that we can compute with.""" - return _dim_as_value(d) - -core._SPECIAL_DIMENSION_HANDLERS[_DimExpr] = DimensionHandlerPoly() dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index e5e0fb3e0ad3..2c0c68d82c66 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -40,6 +40,7 @@ from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.lib import version as jaxlib_version from jax._src.lib import xla_client import numpy as np @@ -178,9 +179,7 @@ def test_dim_vars_symbolic_equal(self): self.assertFalse(core.definitely_equal_one_of_dim(3, [])) self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # A DeviceArray - with self.assertRaisesRegex(TypeError, - re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")): - self.assertTrue(core.definitely_equal(1, "a")) + self.assertFalse(core.definitely_equal(1, "a")) def test_poly_bounds(self): a, b = shape_poly._parse_spec("a, b", (2, 3)) @@ -312,12 +311,24 @@ def test_poly_compare(self): def test_poly_compare_overload(self): a, b = shape_poly._parse_spec("a, b", (2, 3)) + self.assertTrue(a >= a) + self.assertTrue(a >= 0) + self.assertTrue(a >= 1) + + with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"): + a >= 2 + poly = 4 * a + b + 3 self.assertTrue(poly >= 0) self.assertTrue(poly >= 8) self.assertTrue(poly > 7) self.assertTrue(poly >= poly) self.assertTrue(poly >= poly - 1) + # LHS is an integer + self.assertTrue(8 <= poly) + self.assertTrue(7 < poly) + self.assertTrue(-8 >= -poly) + self.assertTrue(-7 > -poly) with 
self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"): poly >= 9 @@ -325,22 +336,6 @@ def test_poly_compare_overload(self): with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"): (4 * a - b) >= 0 - def test_core_greater_equal(self): - a, b = shape_poly._parse_spec("a, b", (2, 3)) - self.assertTrue(core.greater_equal_dim(a, a)) - self.assertTrue(core.greater_equal_dim(a, 0)) - self.assertTrue(core.greater_equal_dim(a, 1)) - - self.assertTrue(core.greater_equal_shape((a, 2), (1, 1))) - - with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Symbolic dimension comparison .* is inconclusive"): - core.greater_equal_dim(a, 2) - - with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Symbolic dimension comparison .* is inconclusive"): - core.greater_equal_dim(a, b) - def test_poly_int_results(self): # Whenever the result is an integer, it should be represented as an # Python integer, not a symbolic dimension. @@ -377,28 +372,32 @@ def test_poly_divmod(self, *, dividend, quotient, divisor, remainder): else: self.assertEqual((quotient, remainder), divmod(dividend, divisor)) - def test_dilate_shape(self): + def test_dilate_dim(self): """0 if d == 0 else 1 + dilation * (d - 1))""" a, = shape_poly._parse_spec("a,", (2,)) - self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3))) - self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3))) - self.assertEqual((a, 7), core.dilate_shape((a, 3), (1, 3))) - self.assertEqual((2 * a - 1, 7), core.dilate_shape((a, 3), (2, 3))) + self.assertEqual(4, core.dilate_dim(2, 3)) + self.assertEqual(7, core.dilate_dim(3, 3)) + self.assertEqual(0, core.dilate_dim(0, 3)) + self.assertEqual(a, core.dilate_dim(a, 1)) + self.assertEqual(2 * a - 1, core.dilate_dim(a, 2)) - def test_stride_shape(self): - """(s - window_size) // window_stride + 1""" - a, stride = shape_poly._parse_spec("a, s", (2, 3)) + def test_stride_dim(self): + """(d - window_size) // window_stride + 1 - self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2))) - self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2))) + If d == 0 or window_size > d, returns 0. 
+ """ + a, stride = shape_poly._parse_spec("a, s", (2, 3)) - self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2))) - self.assertEqual((a + 1, 9), core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2))) + self.assertEqual(8, core.stride_dim(10, window_size=3, window_stride=1)) + self.assertEqual(9, core.stride_dim(20, window_size=3, window_stride=2)) + self.assertEqual(9, core.stride_dim(20, window_size=4, window_stride=2)) + self.assertEqual(a, core.stride_dim(a, window_size=1, window_stride=1)) - (stride0, stride1) = core.stride_shape((a, 20), (1, 3), (2, 2)) - self.assertEqual("floordiv(a + -1, 2) + 1", str(stride0)) - self.assertEqual(9, stride1) + self.assertEqual(a - 1, core.stride_dim(a, window_size=2, window_stride=1)) + self.assertEqual(a + 1, core.stride_dim(a * stride + 2, window_size=2, + window_stride=stride)) + self.assertEqual((a - 1) // 2 + 1, core.stride_dim(a, 1, 2)) class PolyHarness(Harness): @@ -2902,6 +2901,17 @@ def test_harness(self, harness: PolyHarness): # https://github.com/openxla/stablehlo/issues/1255: need DynamicTopK raise unittest.SkipTest("native lowering with shape polymorphism not implemented for top_k") + # Some tests need the latest jaxlib + need_new_jaxlib = [] + if jaxlib_version < (0, 4, 13): + need_new_jaxlib.append("fft") + elif jaxlib_version < (0, 4, 14): + need_new_jaxlib.extend(("lu", "vmap_lu", "custom_linear_solve", + "vmap_custom_linear_solve", + "vmap_approx_top_k", "schur")) + if harness.group_name in need_new_jaxlib: + raise unittest.SkipTest("native lowering with shape polymorphism needs newer jaxlib") + if (jtu.device_under_test() in ["cpu", "gpu"] and harness.fullname in [ "cumsum_reduce_axis=poly", "cumprod_reduce_axis=poly", diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 6cb3d9ae4c86..0428a66dcbb4 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -21,7 +21,6 @@ ConstVar as ConstVar, DCERule as DCERule, DebugInfo as DebugInfo, - DimensionHandlerTracer as DimensionHandlerTracer, DynamicJaxprTrace as DynamicJaxprTrace, DynamicJaxprTracer as DynamicJaxprTracer, ForwardingRule as ForwardingRule, diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 71fbcaa91117..91e654b0d9c3 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -621,11 +621,6 @@ def test_flattening_basic(self): jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12), abstracted_axes={0: 'n'})(x) self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) - # do need divide, also shouldn't typecheck - _ = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], x.shape[0], -1), - abstracted_axes={0: 'n'})(x) # don't crash - - @unittest.skip("Test does not work with jax.Array") @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") class DynamicShapeAutodiffTest(jtu.JaxTestCase):
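
Notes for reviewers (illustrative sketches; none of the code below is part of the diffs above).

The cleanup relies on both Python ints and the jax2tf symbolic dimensions (`_DimExpr`) overloading the ordinary arithmetic and comparison operators, so call sites move from the removed `core` helpers to plain expressions. A minimal sketch of the migration pattern, modeled on the `_pad_shape_rule` change in jax/_src/lax/lax.py; the helper name `padded_dim` is invented for the example:

  from jax._src import core

  def padded_dim(lo, hi, d, interior):
    # Old: core.sum_dim(lo, hi, core.dilate_dim(d, interior + 1))
    # New: plain `+` works for ints and symbolic dimensions alike.
    out = lo + hi + core.dilate_dim(d, interior + 1)
    # Old: core.greater_equal_dim(out, 0)
    # New: `>=`; may raise core.InconclusiveDimensionOperation for symbolic dims.
    if not out >= 0:
      raise ValueError(f"negative padded size {out}")
    return out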
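
The dimension-level helpers that survive, `core.dilate_dim` and `core.stride_dim`, replace the removed shape-level `dilate_shape` and `stride_shape`; callers now map them over shapes themselves (see `lax._dilate_shape` and `reduce_window_shape_tuple` above). A short worked sketch, with expected results taken from `test_dilate_dim` and `test_stride_dim`:

  from jax._src import core
  from jax.experimental.jax2tf import shape_poly

  a, s = shape_poly._parse_spec("a, s", (2, 3))

  core.dilate_dim(2, 3)                                # 4, i.e. 1 + 3 * (2 - 1)
  core.dilate_dim(0, 3)                                # 0: zero-sized dims stay zero
  core.dilate_dim(a, 2)                                # 2*a - 1, a symbolic result

  core.stride_dim(10, window_size=3, window_stride=1)  # 8, i.e. (10 - 3) // 1 + 1
  core.stride_dim(a, window_size=2, window_stride=1)   # a - 1
  core.stride_dim(a, 1, 2)                             # (a - 1) // 2 + 1, a floordiv expression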
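
The comparison semantics the shape rules now depend on: symbolic dimensions compare like integers whenever the answer is the same for every possible value of the dimension variables (each variable is assumed to be >= 1), and raise `InconclusiveDimensionOperation` otherwise; equality never raises and simply returns False when the dimensions may differ. A sketch based on `test_poly_compare_overload` and `test_dim_vars_symbolic_equal`:

  from jax._src import core
  from jax.experimental.jax2tf import shape_poly

  a, b = shape_poly._parse_spec("a, b", (2, 3))
  poly = 4 * a + b + 3

  assert a >= 1           # always true: dimension variables are at least 1
  assert poly > 7         # worst case is a == b == 1, giving 4 + 1 + 3 = 8
  assert 8 <= poly        # an integer on the left-hand side also works

  try:
    a >= 2                # holds for a >= 2 but not for a == 1: inconclusive
  except core.InconclusiveDimensionOperation:
    pass

  assert core.definitely_equal(a, a)      # same variable
  assert not core.definitely_equal(a, b)  # may differ, so False rather than an error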