Skip to content

Commit

Permalink
[jax2tf] Change the InconclusiveDimensionOperation error to include l…
Browse files Browse the repository at this point in the history
…ink to documentation
  • Loading branch information
gnecula committed Jun 13, 2021
1 parent 5e3be94 commit 07cc581
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 30 deletions.
35 changes: 21 additions & 14 deletions jax/experimental/jax2tf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ A few examples of shape specifications and uses:

### Computing with dimension variables

JAX keeps track of the shape of all intermediate results. When those shapes contain
dimension variables JAX computes intermediate shapes as multi-variate polynomials
JAX keeps track of the shape of all intermediate results. When those shapes depend
on dimension variables JAX computes them as multi-variate polynomials
involving dimension variables, which are assumed to range over strictly positive
integers.
The dimension polynomials have the following behavior for arithmetic operations:
Expand All @@ -325,6 +325,22 @@ The dimension polynomials have the following behavior for arithmetic operations:
integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`,
`a >= b`, `a - b >= 0` are inconclusive and result in an exception.

For example, the following code raises the exception
`core.InconclusiveDimensionOperation` with the message
`Dimension polynomial comparison 'a + 1' == 'b' is inconclusive`.

```
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 == x.shape[1] else 1,
polymorphic_shapes=["(a, b)"])(np.ones((3, 4))
```

Note that it would be unsound for JAX to compute `x.shape[0] + 1 == x.shape[1]`
as `False` and produce a converted function that returns `1` just because the dimension polynomials
are not identical: there are some concrete input shapes for which the function
should return `0`.

### Dimension variables appearing in the numeric computation

There are some situations when dimension variables arise in the staged computation itself.
You can see in the following example how elements from the input shapes
`(1024, 28, 28)` and `(28, 28)` appear in the computation and specifically
Expand Down Expand Up @@ -369,6 +385,9 @@ using `tf.shape` on the input parameters.

### Errors in presence of shape polymorphism

In addition to the `InconclusiveDimensionOperation` error discussed above,
one may encounter other kinds of errors.

When tracing with shape polymorphism we can encounter shape errors:

```
Expand Down Expand Up @@ -420,18 +439,6 @@ jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
```

If the user code happens to perform computations directly on dimension polynomials,
it can expect it to work as described above for addition, subtraction, and multiplication,
and partially for comparisons.

```
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 == x.shape[1] else 1,
polymorphic_shapes=["(a, b)"])(np.ones((3, 4))
```

will raise the exception `core.InconclusiveDimensionOperation` with the message
`Dimension polynomial comparison 'a + 1' == 'b' is inconclusive`.

Finally, certain codes that use shapes in the actual computation may not yet work
if those shapes are polymorphic. In the code below, the expression `x.shape[0]`
will have the value of the shape variable `v`. This case is not yet implemented:
Expand Down
51 changes: 35 additions & 16 deletions jax/experimental/jax2tf/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@
DimSize = core.DimSize
Shape = core.Shape


class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
"""Raised when we cannot conclusively compute with symbolic dimensions."""

_help_msg = """
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
Please see https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
"""

def __init__(self, message: str):
error_msg = f"{message}\n{InconclusiveDimensionOperation._help_msg}"
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore


class _DimMon(dict):
"""Represents a multivariate monomial, such as n^3 * m.
Expand Down Expand Up @@ -87,15 +106,15 @@ def mul(self, other: '_DimMon') -> '_DimMon':

def divide(self, divisor: '_DimMon') -> '_DimMon':
"""
Divides by another monomial. Raises a core.InconclusiveDimensionOperation
Divides by another monomial. Raises a InconclusiveDimensionOperation
if the result is not a monomial.
For example, (n^3 * m) // n == n^2*m, but n // m fails.
"""
d = collections.Counter(self)
for key, exponent in divisor.items():
diff = self.get(key, 0) - exponent
if diff < 0:
raise core.InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
elif diff == 0: del d[key]
elif diff > 0: d[key] = diff
return _DimMon(d)
Expand All @@ -107,7 +126,7 @@ class _DimPolynomial(dict):
The shape variables are assumed to range over integers >= 1.
We overload integer operations, but we do that soundly, raising
:class:`core.InconclusiveDimensionOperation` when the result is not
:class:`InconclusiveDimensionOperation` when the result is not
representable as a polynomial.
The representation of a polynomial is as a dictionary mapping _DimMonomial to
Expand Down Expand Up @@ -209,7 +228,7 @@ def eq(self, other: DimSize) -> bool:
return False
if ub is not None and ub < 0:
return False
raise core.InconclusiveDimensionOperation(f"Dimension polynomial comparison '{self}' == '{other}' is inconclusive")
raise InconclusiveDimensionOperation(f"Dimension polynomial comparison '{self}' == '{other}' is inconclusive")

# We must overload __eq__ and __ne__, or else we get unsound defaults.
__eq__ = eq
Expand All @@ -222,7 +241,7 @@ def ge(self, other: DimSize) -> bool:
return True
if ub is not None and ub < 0:
return False
raise core.InconclusiveDimensionOperation(f"Dimension polynomial comparison '{self}' >= '{other}' is inconclusive")
raise InconclusiveDimensionOperation(f"Dimension polynomial comparison '{self}' >= '{other}' is inconclusive")
__ge__ = ge

def __le__(self, other: DimSize):
Expand Down Expand Up @@ -253,11 +272,11 @@ def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
mon, count = dividend.leading_term
try:
qmon = mon.divide(dmon)
except core.InconclusiveDimensionOperation:
raise core.InconclusiveDimensionOperation(err_msg)
except InconclusiveDimensionOperation:
raise InconclusiveDimensionOperation(err_msg)
qcount, rcount = divmod(count, dcount)
if rcount != 0:
raise core.InconclusiveDimensionOperation(err_msg)
raise InconclusiveDimensionOperation(err_msg)

q = _DimPolynomial.from_coeffs({qmon: qcount})
quotient += q
Expand All @@ -270,7 +289,7 @@ def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
remainder = r
else:
if dividend != 0:
raise core.InconclusiveDimensionOperation(err_msg)
raise InconclusiveDimensionOperation(err_msg)
remainder = 0

if config.jax_enable_checks:
Expand All @@ -295,7 +314,7 @@ def __int__(self):
if self.is_constant:
return op.index(next(iter(self.values())))
else:
raise core.InconclusiveDimensionOperation(f"Dimension polynomial '{self}' is not constant")
raise InconclusiveDimensionOperation(f"Dimension polynomial '{self}' is not constant")

def bounds(self) -> Tuple[Optional[int], Optional[int]]:
"""Returns the lower and upper bounds, if defined."""
Expand Down Expand Up @@ -353,7 +372,7 @@ def is_constant(self, d: DimSize) -> bool:
def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
try:
return _ensure_poly(d1) == d2
except core.InconclusiveDimensionOperation:
except InconclusiveDimensionOperation:
return False

def greater_equal(self, d1: DimSize, d2: DimSize):
Expand All @@ -367,19 +386,19 @@ def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
try:
q, r = _ensure_poly(sz1).divmod(sz2)
except core.InconclusiveDimensionOperation:
raise core.InconclusiveDimensionOperation(err_msg)
except InconclusiveDimensionOperation:
raise InconclusiveDimensionOperation(err_msg)
if r != 0:
raise core.InconclusiveDimensionOperation(err_msg)
raise InconclusiveDimensionOperation(err_msg)
return q # type: ignore[return-value]

def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
"""Implements `(d - window_size) // window_stride + 1`"""
try:
q, r = _ensure_poly(d - window_size).divmod(window_stride)
return q + 1
except core.InconclusiveDimensionOperation as e:
raise core.InconclusiveDimensionOperation(
except InconclusiveDimensionOperation as e:
raise InconclusiveDimensionOperation(
f"Cannot compute stride for dimension '{d}', "
f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}.")
return d
Expand Down

0 comments on commit 07cc581

Please sign in to comment.