[shape-poly] Improve the error reporting for division
Added a section to README to explain the division errors
and to show a workaround. Changed the division errors
to include more detail as to what the error is,
and to include a link to the new section in the README
gnecula committed Jul 15, 2022
1 parent a35f9ac commit e6f93bc
Showing 3 changed files with 99 additions and 43 deletions.
44 changes: 33 additions & 11 deletions jax/experimental/jax2tf/README.md
@@ -448,6 +448,7 @@ More operations are partially supported for dimension polynomials:
in which case there may be a constant remainder. The need for division in JAX core
arises in a couple of specific situations, e.g.,
`jax.numpy.reshape(-1)` and operations involving striding.
See [below](#division-of-shape-polynomials-is-partially-supported) for a discussion.
* equality and disequality are partially supported. They result in a boolean value only when
the same result would be obtained for any valuation of the dimension variables. In
other situations, an exception `core.InconclusiveDimensionOperation` is raised.
@@ -549,18 +550,34 @@ that `v == 4`, the shape checking rules fail with the above error.
Since the converted function works only for square matrices, the correct
`polymorphic_shapes` is `["(v, v)"]`.

You would also encounter shape errors if the code attempts to use the
dimension variables in unsupported arithmetic operations, such as in the code
below that fails to compute the inferred dimension for a `reshape` operation:

Code that uses shapes in the actual computation may not yet work
if those shapes are polymorphic. In the code below, the expression `x.shape[0]`
will have the value of the dimension variable `v`. This case is not yet implemented:

```
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
```

### Division of shape polynomials is partially supported

Unlike addition and multiplication, which are fully supported on
shape polynomials, division is supported only when either (a) there
is no remainder, or (b) the divisor is a constant,
in which case there may be a constant remainder.
For example, the code below results in a division error when trying to
compute the inferred dimension for a `reshape` operation:

```
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
polymorphic_shapes=["(b, ...)"])(np.ones((4, 5, 7)))
```

In this case you will see the error `Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)`.
This is because the shape of `x` is `(b, 5, 7)`, with a total size represented as the
dimension polynomial `35 b`, which is not divisible by `2`.
In this case you will see the error `Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)`,
with a further `Details: Cannot divide '35*b' by '-2'`.
The polynomial `35*b` represents the total size of the input tensor.
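
The divisibility rule behind this error can be sketched in plain Python. This is only an illustration, not the code in `shape_poly.py`: it assumes a single dimension variable `b`, models the polynomial `a*b + c` by the integer pair `(a, c)`, and uses the hypothetical names `divmod_linear` and `InconclusiveDivision`.

```python
from typing import Tuple


class InconclusiveDivision(Exception):
    """Hypothetical stand-in for core.InconclusiveDimensionOperation."""


def divmod_linear(a: int, c: int, d: int) -> Tuple[Tuple[int, int], int]:
    """Divide the polynomial a*b + c by the non-zero integer constant d.

    Returns ((qa, qc), r) such that a*b + c == d * (qa*b + qc) + r
    for every value of b.
    """
    if d == 0:
        raise ZeroDivisionError("divisor must be non-zero")
    qa, ra = divmod(a, d)
    if ra != 0:
        # The coefficient of b is not a multiple of d, so the remainder
        # would depend on the value of b: the division is inconclusive.
        raise InconclusiveDivision(f"Cannot divide '{a}*b + {c}' by '{d}'")
    # Only the constant term may leave a (constant) remainder.
    qc, r = divmod(c, d)
    return (qa, qc), r
```

With this sketch, `divmod_linear(35, 0, -2)` raises, mirroring the `35*b` by `-2` case above, while `divmod_linear(70, 3, 2)` succeeds with a constant remainder of `1`.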

Note that the following will succeed:

```
@@ -573,13 +590,18 @@ jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
```

Finally, code that uses shapes in the actual computation may not yet work
if those shapes are polymorphic. In the code below, the expression `x.shape[0]`
will have the value of the dimension variable `v`. This case is not yet implemented:
You may also encounter division errors when working with strides, such as
when computing the padding in a strided convolution.
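
To see where the division comes from, here is a rough sketch of the usual `"SAME"` padding arithmetic on concrete integers. It assumes the common convention that the output size is the input size divided by the stride, rounded up; the function name `same_padding` is hypothetical and this is not the JAX implementation.

```python
from typing import Tuple


def same_padding(in_size: int, window: int, stride: int) -> Tuple[int, Tuple[int, int]]:
    """Return (out_size, (pad_lo, pad_hi)) for one spatial dimension."""
    # Ceiling division by the stride. When in_size is a dimension
    # polynomial rather than an int, this is the step that requires
    # dividing a polynomial by the stride.
    out_size = -(-in_size // stride)
    total = max((out_size - 1) * stride + window - in_size, 0)
    pad_lo = total // 2
    return out_size, (pad_lo, total - pad_lo)
```

For concrete sizes the floor division always succeeds. With a symbolic size such as `b` and a stride of `2`, the same step cannot be resolved unless `b` is known to be a multiple of `2`.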

In some cases you may know that one of the dimension variables
is a multiple of the divisor,
e.g., `b` in the above example of dividing `35*b` by `-2` may
be known to be a multiple of `2`. You can specify that by replacing
`b` with `2*b` in the polymorphic shape specification:

```
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
```
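
The arithmetic behind this workaround can be checked with a small standalone sketch. It assumes an input of shape `(multiplier*b, *other_dims)` reshaped to `(fixed_target, -1)`; the helper name `inferred_dim_coefficient` is hypothetical and only models the coefficient of `b`.

```python
import math
from typing import Sequence


def inferred_dim_coefficient(multiplier: int,
                             other_dims: Sequence[int],
                             fixed_target: int) -> int:
    """Return k such that the inferred (-1) dimension is k*b."""
    # The total size of the input is size_coeff * b.
    size_coeff = multiplier * math.prod(other_dims)
    k, rem = divmod(size_coeff, fixed_target)
    if rem != 0:
        raise ValueError(f"Cannot divide '{size_coeff}*b' by '{fixed_target}'")
    return k
```

With the specification `(b, ...)` the total size is `35*b` and the call `inferred_dim_coefficient(1, (5, 7), 2)` fails. With `(2*b, ...)` the total size is `70*b` and the inferred dimension is `35*b`. Note that the actual argument must then have an even leading dimension, as `np.ones((4, 5, 7))` does.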

## Known issues
32 changes: 21 additions & 11 deletions jax/experimental/jax2tf/shape_poly.py
@@ -297,6 +297,13 @@ def __gt__(self, other: DimSize):
  def __lt__(self, other: DimSize):
    return not self.__ge__(other)

  def _division_error_msg(self, dividend, divisor, details: str = "") -> str:
    msg = f"Cannot divide '{dividend}' by '{divisor}'."
    if details:
      msg += f"\nDetails: {details}."
    msg += "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported."
    return msg

  def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
    """
    Floor division with remainder (divmod) generalized to polynomials.
@@ -309,18 +316,19 @@ def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
    divisor = _ensure_poly(divisor)
    dmon, dcount = divisor.leading_term
    dividend, quotient = self, 0
    err_msg = f"Dimension polynomial '{self}' is not a multiple of '{divisor}'"
    # invariant: self = dividend + divisor * quotient
    # the leading term of dividend decreases through the loop.
    while is_poly_dim(dividend) and not dividend.is_constant:
      mon, count = dividend.leading_term
      try:
        qmon = mon.divide(dmon)
      except InconclusiveDimensionOperation:
        raise InconclusiveDimensionOperation(err_msg)
      except InconclusiveDimensionOperation as e:
        raise InconclusiveDimensionOperation(
            self._division_error_msg(self, divisor, str(e)))
      qcount, rcount = divmod(count, dcount)
      if rcount != 0:
        raise InconclusiveDimensionOperation(err_msg)
        raise InconclusiveDimensionOperation(
            self._division_error_msg(self, divisor))

      q = _DimPolynomial.from_coeffs({qmon: qcount})
      quotient += q
@@ -333,7 +341,8 @@ def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
      remainder = r
    else:
      if dividend != 0:
        raise InconclusiveDimensionOperation(err_msg)
        raise InconclusiveDimensionOperation(
            self._division_error_msg(self, divisor))
      remainder = 0

    if config.jax_enable_checks:
@@ -351,13 +360,14 @@ def __truediv__(self, divisor: DimSize):
    q, r = self.divmod(divisor)
    if r != 0:
      raise InconclusiveDimensionOperation(
          f"Dimension polynomial '{self}' is not a multiple of '{divisor}'")
          self._division_error_msg(self, divisor,
                                   f"Remainder is not zero: {r}"))
    return q

  def __rtruediv__(self, dividend: DimSize):
    # Used for "/", when dividend is not a _DimPolynomial
    raise InconclusiveDimensionOperation(
        f"Division of '{dividend}' by dimension polynomial '{self}' is not supported")
        self._division_error_msg(dividend, self, "Dividend must be a polynomial"))

  def __mod__(self, divisor: DimSize) -> int:
    return self.divmod(divisor)[1]
@@ -433,10 +443,10 @@ def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
    err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
    try:
      q, r = _ensure_poly(sz1).divmod(sz2)
    except InconclusiveDimensionOperation:
      raise InconclusiveDimensionOperation(err_msg)
    except InconclusiveDimensionOperation as e:
      raise InconclusiveDimensionOperation(err_msg + f"\nDetails: {e}")
    if r != 0:
      raise InconclusiveDimensionOperation(err_msg)
      raise InconclusiveDimensionOperation(err_msg + f"\nRemainder is not zero: {r}")
    return q  # type: ignore[return-value]

  def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
@@ -448,7 +458,7 @@ def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> Di
    except InconclusiveDimensionOperation as e:
      raise InconclusiveDimensionOperation(
          f"Cannot compute stride for dimension '{d}', "
          f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}.")
          f"window_size '{window_size}', stride '{window_stride}'.\nDetails: {e}.")
    return d

  def as_value(self, d: DimSize):
66 changes: 45 additions & 21 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -270,7 +270,7 @@ def test_poly_int_results(self):
  def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
    if quotient is None:
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "Dimension polynomial .* is not a multiple of .*"):
                                  "Cannot divide .* by .*"):
        divmod(dividend, divisor)
    else:
      self.assertEqual((quotient, remainder), divmod(dividend, divisor))
@@ -294,15 +294,15 @@ def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
  def test_poly_truediv(self, *, dividend, divisor, quotient):
    if quotient is None:
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "Dimension polynomial .* is not a multiple of .*"):
                                  "Cannot divide .* by .*"):
        dividend / divisor
    else:
      self.assertEqual(quotient, dividend / divisor)

  def test_poly_truediv_error(self):
    a, = shape_poly._parse_spec("a,", (2,))
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Division of '3' by dimension polynomial .* is not supported"):
                                "Cannot divide .* by .*"):
      3 / a

  def test_dilate_shape(self):
@@ -327,7 +327,7 @@ def test_stride_shape(self):
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        re.escape(
          "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")):
          "Cannot compute stride for dimension 'a', window_size '1', stride '2'.\nDetails: Cannot divide 'a + -1' by '2'")):
      core.stride_shape((a, 20), (1, 3), (2, 2))


@@ -929,7 +929,8 @@ def test_readme_shape_error(self):
                     polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                re.escape("Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)")):
                                re.compile("Cannot divide evenly the sizes of shapes \\(b, 5, 7\\) and \\(2, -1\\).*Details: Cannot divide '35\\*b' by '-2'",
                                           re.DOTALL)):
      jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
                     polymorphic_shapes=["(b, _, _)"])(np.ones((4, 5, 7)))

@@ -938,9 +939,12 @@ def test_readme_shape_error(self):
    jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
                   polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))

    jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
                   polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))

    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        re.compile("Division of .* by dimension polynomial .* is not supported",
        re.compile("Cannot divide .* by 'v'.*Dividend must be a polynomial.",
                   re.DOTALL)):
      jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
                     polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
@@ -1275,20 +1279,40 @@ def _make_harness(group_name: str, name: str,
                  [RandArg((3, 4, 5), _f32)],
                  poly_axes=[(0, 1)]),

    # Issue #11402 InconclusiveDimensionOperation: Dimension polynomial '-1*t' is not a multiple of '2'
    # TODO(still fails)
    # _make_harness("conv_general_dilated", "1d_1",
    #               lambda lhs, rhs: lax.conv_general_dilated(
    #                   lhs, rhs,
    #                   window_strides=(2,),
    #                   padding="SAME",
    #                   rhs_dilation=None,
    #                   dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
    #                                                              rhs_spec=(2, 1, 0),
    #                                                              out_spec=(0, 2, 1))),
    #               [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
    #               poly_axes=[1, None],
    #               enable_and_disable_xla=True),
    # Issue #11402
    # We play a trick here. Since the stride is 2, when we compute the padding
    # for "SAME" we need to divide by 2. We cannot do this in general, so we
    # write the test with the assumption that the dimension is a multiple of 2.
    # We pass the lhs as (1, b, 2, 16) and then we
    # reshape it as (1, 2*b, 16), so that we know that the lhs's dimension 1
    # is a multiple of 2.
    _make_harness("conv_general_dilated", "1d_1",
                  lambda lhs, rhs: lax.conv_general_dilated(
                      jnp.reshape(lhs, (1, -1, 16)), rhs,
                      window_strides=(2,),
                      padding="SAME",
                      rhs_dilation=None,
                      dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
                                                                 rhs_spec=(2, 1, 0),
                                                                 out_spec=(0, 2, 1))),
                  [RandArg((1, 6, 2, 16), _f32), RandArg((4, 16, 16), _f32)],
                  poly_axes=[1, None],
                  enable_and_disable_xla=True),
    # The same example from above, but without the reshape trick.
    _make_harness("conv_general_dilated", "1d_1err",
                  lambda lhs, rhs: lax.conv_general_dilated(
                      lhs, rhs,
                      window_strides=(2,),
                      padding="SAME",
                      rhs_dilation=None,
                      dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
                                                                 rhs_spec=(2, 1, 0),
                                                                 out_spec=(0, 2, 1))),
                  [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
                  poly_axes=[1, None],
                  enable_and_disable_xla=True,
                  expect_error=(core.InconclusiveDimensionOperation,
                                "Cannot divide .* by '2'")),
    # Issue #11402
    _make_harness("conv_general_dilated", "1d_2",
                  lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
@@ -1842,7 +1866,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
  # to parameterized below.
  @primitive_harness.parameterized(
      _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
      #one_containing="conv_general_dilated_1d_2_noxla_poly_axes=[0, None]"
      #one_containing="conv_general_dilated_1d_1err_poly_axes=[1, None]"
  )
  def test_prim(self, harness: Harness):
    _test_one_harness(self, harness)
