diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 777892535e9b..9384d214554f 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -448,6 +448,7 @@ More operations are partially supported for dimension polynomials: in which case there may be a constant remainder. The need for division in JAX core arises in a couple of specific situations, e.g., `jax.numpy.reshape(-1)` and operations involving striding. + See [below](#division-of-shape-polynomials-is-partially-supported) for a discussion. * equality and disequality are partially supported. They result in a boolean value only when the same result would be obtained for any valuation of the dimension variables. In other situations, an exception `core.InconclusiveDimensionOperation` is raised. @@ -549,18 +550,34 @@ that `v == 4`, the shape checking rules fail with the above error. Since the converted function works only for square matrices, the correct `polymorphic_shapes` is `["(v, v)"]`. -You would also encounter shape errors if the code attempts to use the -dimension variables in unsupported arithmetic operations, such as in the code -below that fails to compute the inferred dimension for a `reshape` operations: + +Certain codes that use shapes in the actual computation may not yet work +if those shapes are polymorphic. In the code below, the expression `x.shape[0]` +will have the value of the dimension variable `v`. This case is not yet implemented: + +``` +jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0], + polymorphic_shapes=["(v, _)"])(np.ones((4, 4))) +``` + +### Division of shape polynomials is partially supported + +Unlike addition and multiplication, which are fully supported on +shape polynomials, division is supported when either (a) there +is no remainder, or (b) the divisor is a constant +in which case there may be a constant remainder. 
+For example, the code below results in a division error when trying to +compute the inferred dimension for a `reshape` operation: ``` jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)), polymorphic_shapes=["(b, ...)"])(np.ones((4, 5, 7))) ``` -In this case you will see the error `Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)`. -This is because the shape of `x` is `(b, 5, 7)`, with a total size represented as the -dimension polynomial `35 b`, which is not divisible by `2`. +In this case you will see the error `Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)`, +with a further `Details: Cannot divide '35*b' by '-2'`. +The polynomial `35*b` represents the total size of the input tensor. + Note that the following will succeed: ``` @@ -573,13 +590,18 @@ jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])), polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6))) ``` -Finally, certain codes that use shapes in the actual computation may not yet work -if those shapes are polymorphic. In the code below, the expression `x.shape[0]` -will have the value of the dimension variable `v`. This case is not yet implemented: +You may also encounter division errors when working with strides, such as +when computing the padding in a strided convolution. + +In some cases you may know that one of the dimension variables +is a multiple of the divisor, +e.g., `b` in the above example of dividing `35*b` by `-2` may +be known to be a multiple of `2`. 
You can specify that by replacing +`b` with `2*b` in the polymorphic shape specification: ``` -jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0], - polymorphic_shapes=["(v, _)"])(np.ones((4, 4))) +jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)), + polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7))) ``` ## Known issues diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index d0d3ed840b16..85395cd607b4 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -297,6 +297,13 @@ def __gt__(self, other: DimSize): def __lt__(self, other: DimSize): return not self.__ge__(other) + def _division_error_msg(self, dividend, divisor, details: str = "") -> str: + msg = f"Cannot divide '{dividend}' by '{divisor}'." + if details: + msg += f"\nDetails: {details}." + msg += "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported." + return msg + def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]: """ Floor division with remainder (divmod) generalized to polynomials. @@ -309,18 +316,19 @@ def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]: divisor = _ensure_poly(divisor) dmon, dcount = divisor.leading_term dividend, quotient = self, 0 - err_msg = f"Dimension polynomial '{self}' is not a multiple of '{divisor}'" # invariant: self = dividend + divisor * quotient # the leading term of dividend decreases through the loop. 
while is_poly_dim(dividend) and not dividend.is_constant: mon, count = dividend.leading_term try: qmon = mon.divide(dmon) - except InconclusiveDimensionOperation: - raise InconclusiveDimensionOperation(err_msg) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + self._division_error_msg(self, divisor, str(e))) qcount, rcount = divmod(count, dcount) if rcount != 0: - raise InconclusiveDimensionOperation(err_msg) + raise InconclusiveDimensionOperation( + self._division_error_msg(self, divisor)) q = _DimPolynomial.from_coeffs({qmon: qcount}) quotient += q @@ -333,7 +341,8 @@ def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]: remainder = r else: if dividend != 0: - raise InconclusiveDimensionOperation(err_msg) + raise InconclusiveDimensionOperation( + self._division_error_msg(self, divisor)) remainder = 0 if config.jax_enable_checks: @@ -351,13 +360,14 @@ def __truediv__(self, divisor: DimSize): q, r = self.divmod(divisor) if r != 0: raise InconclusiveDimensionOperation( - f"Dimension polynomial '{self}' is not a multiple of '{divisor}'") + self._division_error_msg(self, divisor, + f"Remainder is not zero: {r}")) return q def __rtruediv__(self, dividend: DimSize): # Used for "/", when dividend is not a _DimPolynomial raise InconclusiveDimensionOperation( - f"Division of '{dividend}' by dimension polynomial '{self}' is not supported") + self._division_error_msg(dividend, self, "Dividend must be a polynomial")) def __mod__(self, divisor: DimSize) -> int: return self.divmod(divisor)[1] @@ -433,10 +443,10 @@ def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize: err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}" try: q, r = _ensure_poly(sz1).divmod(sz2) - except InconclusiveDimensionOperation: - raise InconclusiveDimensionOperation(err_msg) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation(err_msg + f"\nDetails: {e}") if r != 0: - raise 
InconclusiveDimensionOperation(err_msg) + raise InconclusiveDimensionOperation(err_msg + f"\nRemainder is not zero: {r}") return q # type: ignore[return-value] def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize: @@ -448,7 +458,7 @@ def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> Di except InconclusiveDimensionOperation as e: raise InconclusiveDimensionOperation( f"Cannot compute stride for dimension '{d}', " - f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}.") + f"window_size '{window_size}', stride '{window_stride}'.\nDetails: {e}.") return d def as_value(self, d: DimSize): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index a88ca78167fe..67b3c1761860 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -270,7 +270,7 @@ def test_poly_int_results(self): def test_poly_divmod(self, *, dividend, quotient, divisor, remainder): if quotient is None: with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Dimension polynomial .* is not a multiple of .*"): + "Cannot divide .* by .*"): divmod(dividend, divisor) else: self.assertEqual((quotient, remainder), divmod(dividend, divisor)) @@ -294,7 +294,7 @@ def test_poly_divmod(self, *, dividend, quotient, divisor, remainder): def test_poly_truediv(self, *, dividend, divisor, quotient): if quotient is None: with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Dimension polynomial .* is not a multiple of .*"): + "Cannot divide .* by .*"): dividend / divisor else: self.assertEqual(quotient, dividend / divisor) @@ -302,7 +302,7 @@ def test_poly_truediv(self, *, dividend, divisor, quotient): def test_poly_truediv_error(self): a, = shape_poly._parse_spec("a,", (2,)) with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - "Division of '3' by dimension polynomial .* is not 
supported"): + "Cannot divide .* by .*"): 3 / a def test_dilate_shape(self): @@ -327,7 +327,7 @@ def test_stride_shape(self): with self.assertRaisesRegex( core.InconclusiveDimensionOperation, re.escape( - "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")): + "Cannot compute stride for dimension 'a', window_size '1', stride '2'.\nDetails: Cannot divide 'a + -1' by '2'")): core.stride_shape((a, 20), (1, 3), (2, 2)) @@ -929,7 +929,8 @@ def test_readme_shape_error(self): polymorphic_shapes=["(v, 4)"])(np.ones((4, 4))) with self.assertRaisesRegex(core.InconclusiveDimensionOperation, - re.escape("Cannot divide evenly the sizes of shapes (b, 5, 7) and (2, -1)")): + re.compile("Cannot divide evenly the sizes of shapes \\(b, 5, 7\\) and \\(2, -1\\).*Details: Cannot divide '35\\*b' by '-2'", + re.DOTALL)): jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)), polymorphic_shapes=["(b, _, _)"])(np.ones((4, 5, 7))) @@ -938,9 +939,12 @@ def test_readme_shape_error(self): jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])), polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6))) + jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)), + polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7))) + with self.assertRaisesRegex( core.InconclusiveDimensionOperation, - re.compile("Division of .* by dimension polynomial .* is not supported", + re.compile("Cannot divide .* by 'v'.*Dividend must be a polynomial.", re.DOTALL)): jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0], polymorphic_shapes=["(v, _)"])(np.ones((4, 4))) @@ -1275,20 +1279,40 @@ def _make_harness(group_name: str, name: str, [RandArg((3, 4, 5), _f32)], poly_axes=[(0, 1)]), - # Issue #11402 InconclusiveDimensionOperation: Dimension polynomial '-1*t' is not a multiple of '2' - # TODO(still fails) - # _make_harness("conv_general_dilated", "1d_1", - # lambda lhs, rhs: lax.conv_general_dilated( - # lhs, rhs, - # window_strides=(2,), - # 
padding="SAME", - # rhs_dilation=None, - # dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - # rhs_spec=(2, 1, 0), - # out_spec=(0, 2, 1))), - # [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - # poly_axes=[1, None], - # enable_and_disable_xla=True), + # Issue #11402 + # We play a trick here. Since the stride is 2, when we compute the padding + # for "SAME" we need to divide by 2. We cannot do this in general, so we + # write the test with the assumption that the dimension is a multiple of 2. + # We pass the lhs as (1, b, 2, 16) and then we + # reshape it as (1, 2*b, 16), so that we know that the lhs's dimension 1 + # is a multiple of 2. + _make_harness("conv_general_dilated", "1d_1", + lambda lhs, rhs: lax.conv_general_dilated( + jnp.reshape(lhs, (1, -1, 16)), rhs, + window_strides=(2,), + padding="SAME", + rhs_dilation=None, + dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), + rhs_spec=(2, 1, 0), + out_spec=(0, 2, 1))), + [RandArg((1, 6, 2, 16), _f32), RandArg((4, 16, 16), _f32)], + poly_axes=[1, None], + enable_and_disable_xla=True), + # The same example from above, but without the reshape trick. + _make_harness("conv_general_dilated", "1d_1err", + lambda lhs, rhs: lax.conv_general_dilated( + lhs, rhs, + window_strides=(2,), + padding="SAME", + rhs_dilation=None, + dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), + rhs_spec=(2, 1, 0), + out_spec=(0, 2, 1))), + [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)], + poly_axes=[1, None], + enable_and_disable_xla=True, + expect_error=(core.InconclusiveDimensionOperation, + "Cannot divide .* by '2'")), # Issue #11402 _make_harness("conv_general_dilated", "1d_2", lambda lhs, rhs: lax.conv_transpose(lhs, rhs, @@ -1842,7 +1866,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): # to parameterized below. 
@primitive_harness.parameterized( _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES), - #one_containing="conv_general_dilated_1d_2_noxla_poly_axes=[0, None]" + #one_containing="conv_general_dilated_1d_1err_poly_axes=[1, None]" ) def test_prim(self, harness: Harness): _test_one_harness(self, harness)