[shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shapes` was
the same as `tuple(map(operator.add, *shapes))`.

We also remove the whole `DimensionHandler` machinery because by now
the only remaining non-constant dimensions that go through this mechanism
are the symbolic dimensions used for shape polymorphism, and those
now support full operator overloading. (When we introduced `DimensionHandler`,
the masking transformation was still around and also needed it.)
gnecula committed Jul 13, 2023
1 parent 58d6c4c commit 71ac0bb
Showing 13 changed files with 130 additions and 333 deletions.
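
For context, a minimal sketch (not part of the commit) of the pattern this cleanup standardizes on: dimensions are either Python ints or symbolic objects that overload the usual arithmetic and comparison operators, so code that used to go through `DimensionHandler` can operate on dimensions directly and catch `InconclusiveDimensionOperation` where a comparison cannot be decided. The helper names below are hypothetical.

import operator
from jax._src import core  # internal module touched by this commit

def sum_shapes_sketch(*shapes):
  # What core.sum_shapes reduced to: elementwise addition, which works for
  # int and symbolic dimensions alike via operator overloading.
  return tuple(map(operator.add, *shapes))

def definitely_equal_sketch(d1, d2) -> bool:
  # The pattern core.definitely_equal uses after this change: equality of
  # symbolic dimensions may be undecidable, in which case we answer False.
  try:
    return d1 == d2
  except core.InconclusiveDimensionOperation:
    return False

assert sum_shapes_sketch((2, 3), (4, 5)) == (6, 8)
assert definitely_equal_sketch(2, 2) and not definitely_equal_sketch(2, 3)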
196 changes: 35 additions & 161 deletions jax/_src/core.py
@@ -1241,10 +1241,10 @@ def definitely_equal(x, y):
return same_referent(x, y)
elif x is y:
return True
else:
handler, ds = _dim_handler_and_canonical(x, y)
return handler.symbolic_equal(*ds)

try:
return x == y
except InconclusiveDimensionOperation:
return False

# -------------------- abstract values --------------------

@@ -1837,122 +1837,13 @@ class InconclusiveDimensionOperation(Exception):
"""Raised when we cannot conclusively compute with symbolic dimensions."""
pass

class DimensionHandler:
"""Operations on dimension sizes.
Dimension sizes are normally integer constants, but can also be symbolic,
e.g., masking.Poly or jax2tf.shape_poly.DimVar.
The base class works for integers only. Subclasses are invoked when at least
one of the operands has a type registered in _SPECIAL_DIMENSION_HANDLERS. In
that case, all operands are guaranteed to be either the special dimension
type, or Python integer scalars.
Subclasses should raise InconclusiveDimensionOperation if the result cannot
be computed in some contexts.
"""
def is_constant(self, d: DimSize) -> bool:
"""The dimension is a constant."""
return True

def symbolic_equal(self, d1: DimSize, d2: DimSize) -> bool:
"""True iff the dimension sizes are equal in all contexts; False otherwise.
Unlike `d1 == d2` this never raises InconclusiveDimensionOperation.
"""
return d1 == d2

def greater_equal(self, d1: DimSize, d2: DimSize) -> bool:
"""Computes `d1 >= d2`.
Raise InconclusiveDimensionOperation if the result is different in
different contexts.
"""
return d1 >= d2

def sum(self, *ds: DimSize) -> DimSize:
"""Sum of dimensions.
Raises InconclusiveDimensionOperation if the result cannot be represented
by the same DimSize in all contexts.
"""
return sum(ds)

def diff(self, d1: DimSize, d2: DimSize) -> DimSize:
"""Difference of dimensions.
Raises InconclusiveDimensionOperation if the result cannot be represented
by the same DimSize in all contexts.
"""
return d1 - d2

def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
"""Computes integer "i" such that i * size(s2) == size(s1).
Raise InconclusiveDimensionOperation if there is no such integer for all
contexts,
"""
sz1 = math.prod(s1)
sz2 = math.prod(s2)
if sz1 == 0 and sz2 == 0:
return 1
if sz1 % sz2:
raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
return sz1 // sz2

def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
"""(d - window_size) // window_stride + 1.
If d == 0 or window_size > d, returns 0.
"""
if d == 0 or window_size > d: return 0
return (d - window_size) // window_stride + 1

def dilate(self, d: DimSize, dilation: int) -> DimSize:
"""Implements `d if dilation == 1 else (0 if d == 0 else 1 + dilation * (d - 1)))`"""
if definitely_equal(dilation, 1):
return d
return 0 if d == 0 else 1 + dilation * (d - 1)

def as_value(self, d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with."""
return d

_dimension_handler_int = DimensionHandler()
_SPECIAL_DIMENSION_HANDLERS: dict[type, DimensionHandler] = {}
DArrayDimHandler = type('DArrayDimHandler', (DimensionHandler,), {})()
def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
  if isinstance(dim, Tracer) and not config.jax_dynamic_shapes:
    return None
  if isinstance(dim, DArray) and not dim.shape and type(dim.dtype) is bint:
    return DArrayDimHandler
  return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))

def _dim_handler_and_canonical(*dlist: DimSize) -> tuple[DimensionHandler, tuple[DimSize, ...]]:
  """Finds the handler for the given dimensions; also returns the canonical dimensions.
  A dimension is canonical if it is a Python integer scalar, or has a type
  registered in _SPECIAL_DIMENSION_HANDLERS.
  This should be used very rarely, because symbolic dimensions overload all
  operators, and should just work.
  """
  special_handlers = set()
  canonical = []
  for d in dlist:
    handler = _get_special_dim_handler(d)
    if handler:
      special_handlers.add(handler)
      canonical.append(d)
    else:
      try:
        canonical.append(operator.index(d))
      except TypeError:
        raise _invalid_shape_error(dlist)

  if len(special_handlers) > 1:
    msg = (f"Dimension size operation involves multiple special dimension types {dlist}")
    raise ValueError(msg)
  return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)

def is_dynamic_dim(v: Any) -> bool:
  """Checks if a value is a dynamic DimSize."""
  handler = _get_special_dim_handler(v)
  return (handler is not None)

def is_symbolic_dim(v: Any) -> bool:
  """Checks if a value is a symbolic dimension used for shape polymorphism.
  """
  return hasattr(v, "dimension_as_value")

def is_constant_dim(d: DimSize) -> bool:
# Whether the dimension is a static integer constant.
@@ -1963,7 +1854,7 @@ def is_constant_dim(d: DimSize) -> bool:
return False

def is_dim(v: Any) -> bool:
return is_dynamic_dim(v) or is_constant_dim(v)
return is_symbolic_dim(v) or is_constant_dim(v)

def is_constant_shape(s: Shape) -> bool:
# Whether the shape is a static constant.
@@ -1981,39 +1872,17 @@ def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
return (len(s1) == len(s2) and
all(unsafe_map(definitely_equal, s1, s2)))

def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.symbolic_equal(*ds) or handler.greater_equal(*ds)

def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
return all(map(greater_equal_dim, s1, s2))

def sum_dim(*ds: DimSize) -> DimSize:
handler, ds = _dim_handler_and_canonical(*ds)
return handler.sum(*ds)

def sum_shapes(*ss: Shape) -> Shape:
return tuple(map(sum_dim, *ss))

def diff_dim(d1: DimSize, d2: DimSize) -> DimSize:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.diff(*ds)

def diff_shape(s1: Shape, s2: Shape) -> Shape:
return tuple(map(diff_dim, s1, s2))

def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
  """Returns an integer "i" s.t., i * size(s2) == size(s1).
  Raises if there is no such integer."""
  s1 = s1 or (1,)
  s2 = s2 or (1,)
  handler, ds = _dim_handler_and_canonical(*s1, *s2)
  return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])

def same_shape_sizes(s1: Shape, s2: Shape) -> bool:
  maybe_result = cancel_divide_tracers(s1, s2)
  if maybe_result is not None: return maybe_result == 1
  return 1 == divide_shape_sizes(s1, s2)

def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
  """Returns an integer "i" s.t., i * size(s2) == size(s1).
  Raises InconclusiveDimensionOperation if there is no such integer."""
  sz1 = math.prod(s1)
  sz2 = math.prod(s2)
  if definitely_equal(sz1, sz2):  # Takes care of sz1 and sz2 being 0
    return 1
  q, r = divmod(sz1, sz2)
  if isinstance(r, Tracer) or r != 0:
    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
  return q

def cancel_divide_tracers(num, denom):
partition = lambda l: partition_list([isinstance(d, Tracer) for d in l], l)
@@ -2041,23 +1910,28 @@ def is_empty_shape(s: Shape) -> bool:
return any(definitely_equal(d, 0) for d in s)

def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
  """Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
  handler, ds = _dim_handler_and_canonical(d, dilation)
  return handler.dilate(*ds)

def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape:
  return tuple(map(dilate_dim, s, dilations))

def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
  """1 + dilation * (d - 1).

  if d == 0, returns 0.
  """
  if definitely_equal(dilation, 1):  # fast path
    return d
  return 0 if d == 0 else 1 + dilation * (d - 1)

def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
  handler, ds = _dim_handler_and_canonical(d, window_size, window_stride)
  return handler.stride(*ds)

def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
  """(s - window_size) // window_stride + 1"""
  return tuple(map(stride_dim, s, window_size, window_stride))

def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
  """(d - window_size) // window_stride + 1

  If d < window_size, returns 0.
  We assume window_size >= 1 and window_stride >= 1.
  """
  if is_constant_dim(d) and is_constant_dim(window_size):
    # TODO(necula): Enable this check for non-constant dimensions
    if d < window_size:
      return 0
  return (d - window_size) // window_stride + 1

def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with.
"""Turns a dimension size into a JAX array.
This is the identity function for constant dimensions.
Has the same abstract value as Python constants.
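
To make the reimplemented core.py helpers above concrete, a short illustrative check on constant dimensions (not part of the commit, assuming the internal `jax._src.core` module at this revision):

from jax._src import core

# dilate_dim(d, dilation) computes 1 + dilation * (d - 1), and 0 when d == 0:
assert core.dilate_dim(4, 2) == 7   # 1 + 2 * (4 - 1)
assert core.dilate_dim(0, 2) == 0

# divide_shape_sizes(s1, s2) returns i such that i * prod(s2) == prod(s1),
# and raises InconclusiveDimensionOperation when no such integer exists:
assert core.divide_shape_sizes((2, 3, 4), (6,)) == 4
try:
  core.divide_shape_sizes((2, 3), (4,))
except core.InconclusiveDimensionOperation:
  pass  # 6 is not evenly divisible by 4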
39 changes: 0 additions & 39 deletions jax/_src/interpreters/partial_eval.py
@@ -15,7 +15,6 @@

from collections import namedtuple
from contextlib import contextmanager, AbstractContextManager
import functools
from functools import partial
import inspect
import itertools as it
@@ -2452,44 +2451,6 @@ def _substitute_vars_in_type(
else:
return a


class DimensionHandlerTracer(core.DimensionHandler):
"""See core.DimensionHandler.
Most methods are inherited.
"""
def is_constant(self, d: core.DimSize) -> bool:
assert isinstance(d, Tracer)
return False

def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
return d1 is d2

def greater_equal(self, d1: core.DimSize, d2: core.DimSize):
raise core.InconclusiveDimensionOperation("TODO")

def divide_shape_sizes(self, s1: core.Shape, s2: core.Shape) -> core.DimSize:
"""Computes integer "i" such that i * size(s2) == size(s1).
Raise InconclusiveDimensionOperation if there is no such integer for all
contexts.
"""
s1_size = functools.reduce(op.mul, s1, 1)
s2_size = functools.reduce(op.mul, s2, 1)
q, r = divmod(s1_size, s2_size)
# TODO(necula): must check that r == 0!
return q

def stride(self, d: core.DimSize, window_size: core.DimSize, window_stride: core.DimSize) -> core.DimSize:
"""Implements `(d - window_size) // window_stride + 1`"""
raise core.InconclusiveDimensionOperation("TODO")

def as_value(self, d: core.DimSize):
"""Turns a dimension size into a Jax value that we can compute with."""
raise core.InconclusiveDimensionOperation("TODO")

core._SPECIAL_DIMENSION_HANDLERS[DynamicJaxprTracer] = DimensionHandlerTracer()

Const = Any
Val = Any

2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -984,7 +984,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
type(linear) is tuple and all(type(x) is bool for x in linear))
tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0)

tc(length, 'length', 'non-negative int', core.greater_equal_dim(length, 0))
tc(length, 'length', 'non-negative int', length >= 0)

if len(linear) != len(avals):
raise core.JaxprTypeError(
11 changes: 4 additions & 7 deletions jax/_src/lax/convolution.py
@@ -797,9 +797,7 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
if np.any(lhs_padded < 0):
raise ValueError("Negative padding is larger than the size of the corresponding dimension: "
f"got padding={pads} for lhs_shape[2:]={lhs_shape[2:]}")
out_space = core.stride_shape(lhs_padded, rhs_shape[2:], strides)
out_space = [d if core.greater_equal_dim(d, 0) else 0
for d in out_space]
out_space = tuple(map(core.stride_dim, lhs_padded, rhs_shape[2:], strides))
if batch_group_count > 1:
assert lhs_shape[0] % batch_group_count == 0
out_shape_0 = lhs_shape[0] // batch_group_count
@@ -930,8 +928,7 @@ def _conv_general_vjp_rhs_padding(
rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
pads_lo, _ = util.unzip2(padding)
pads_from_lhs = core.diff_shape(out_dilated_shape, lhs_dilated_shape)
pads_from_rhs = core.diff_shape(core.diff_shape(rhs_dilated_shape, pads_lo),
(1,) * len(pads_lo))
pads_hi = core.sum_shapes(pads_from_lhs, pads_from_rhs)
pads_from_lhs = map(operator.sub, out_dilated_shape, lhs_dilated_shape)
pads_from_rhs = tuple(rd - pd - 1 for rd, pd in zip(rhs_dilated_shape, pads_lo))
pads_hi = tuple(map(operator.add, pads_from_lhs, pads_from_rhs))
return list(zip(pads_lo, pads_hi))
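
An illustrative check (not part of the commit, assuming `jax._src.core` at this revision) of the `core.stride_dim` call that now computes the convolution output space directly. Because `stride_dim` already returns 0 when the window is larger than the padded input, the separate clamping with `core.greater_equal_dim` is no longer needed.

from jax._src import core

# One spatial dimension: padded input of size 10, kernel of size 3, stride 2.
assert core.stride_dim(10, 3, 2) == 4   # (10 - 3) // 2 + 1
# Kernel larger than the padded input: the output dimension is 0.
assert core.stride_dim(2, 3, 1) == 0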
14 changes: 7 additions & 7 deletions jax/_src/lax/lax.py
@@ -3169,9 +3169,9 @@ def _pad_shape_rule(operand, padding_value, *, padding_config):
if not all(i >= 0 for _, _, i in padding_config):
raise ValueError("interior padding in padding_config must be nonnegative, "
f"got padding_config {padding_config}")
result = tuple(core.sum_dim(l, h, core.dilate_dim(d, i + 1))
result = tuple(l + h + core.dilate_dim(d, i + 1)
for (l, h, i), d in zip(padding_config, op_shape))
if not all(core.greater_equal_dim(d, 0) for d in result):
if not all(d >= 0 for d in result):
msg = (f"Dimension size after padding is not at least 0, "
f"got result shape {result}, for padding_config {padding_config}"
f" and operand shape {op_shape}")
@@ -3298,12 +3298,12 @@ def shape_as_value(shape: core.Shape):
return concatenate(dims, dimension=0)

def _reshape_shape_rule(operand, *, new_sizes, dimensions):
if not all(core.greater_equal_dim(d, 0) for d in new_sizes):
if not all(d >= 0 for d in new_sizes):
msg = 'reshape new_sizes must all be positive, got {}.'
raise TypeError(msg.format(new_sizes))
# TODO(necula): re-enable this check
if (not config.jax_dynamic_shapes and
not core.same_shape_sizes(np.shape(operand), new_sizes)):
not math.prod(np.shape(operand)) == math.prod(new_sizes)):
msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
raise TypeError(msg.format(new_sizes, np.shape(operand)))
if dimensions is not None:
@@ -3822,7 +3822,7 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype):
axis, = axes
if not (0 <= axis < len(operand.shape)):
raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
if not core.greater_equal_dim(operand.shape[axis], 1):
if operand.shape[axis] < 1:
raise ValueError("argmin and argmax require non-empty reduced dimension. "
f"operand.shape={operand.shape} {axis=}")
return tuple(np.delete(operand.shape, axis))
@@ -4661,7 +4661,7 @@ def _dilate_shape(shape, dilation):
msg = "All dilations must be positive, got {}."
raise TypeError(msg.format(dilation))
dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
return core.dilate_shape(shape, dilation)
return tuple(map(core.dilate_dim, shape, dilation))

def _ceil_divide(x1, x2):
return -np.floor_divide(np.negative(x1), x2)
@@ -4803,7 +4803,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
lower_bound, bound_error = (
(1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
if not all(core.greater_equal_dim(d, lower_bound) for d in obj_arr):
if not all(d >= lower_bound for d in obj_arr):
msg = "{} {} must have every element be {}, got {}."
raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
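
To illustrate the `_pad_shape_rule` change earlier in this file with plain integers (a hedged example, assuming `jax._src.core` at this revision): each operand dimension d with padding config (lo, hi, interior) becomes lo + hi + dilate_dim(d, interior + 1), written now with ordinary `+` instead of `core.sum_dim`.

from jax._src import core

d, lo, hi, interior = 3, 0, 1, 2
# interior padding of 2 inserts two zeros between elements: dilate_dim(3, 3) == 7
assert lo + hi + core.dilate_dim(d, interior + 1) == 8   # 0 + 1 + 7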

