From 3c17027072a681dca092f6bb7b72ffaadc12e2f8 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 17 Dec 2022 05:56:48 +0200 Subject: [PATCH] [jax2tf] Refactoring of shape_poly_test. This all started because I noticed that the old self.CheckShapePolymorphism was not running the converted function and would only do the conversion in TF graph mode. Then I realized that there were multiple ways of specifying and running the tests: _make_harness, vmap harnesses, self.CheckShapePolymorphism. This PR unifies all test harnesses under a new PolyHarness class, with new documentation. There is a helper function check_shape_poly that simply wraps PolyHarness. Since the new tests exercise jax2tf more deeply, especially in TF eager mode, I have found 3 bugs. One is fixed here, in jax2tf._assert_matching_abstract_shape. Two others are deferred (and a couple of tests are skipped here). --- jax/experimental/jax2tf/jax2tf.py | 9 +- .../jax2tf/tests/shape_poly_test.py | 2050 +++++++++-------- jax/experimental/jax2tf/tests/tf_test_util.py | 37 - 3 files changed, 1052 insertions(+), 1044 deletions(-) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 144cf7d2f40d..dac1ec88fd57 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -974,9 +974,14 @@ def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize """Asserts that shape matches x.shape in the known dimensions and has dimension polynomials elsewhere.""" # Ensures that the shape does not contain None; it should contain polynomials + def check_one(xd: Optional[int], sd: Any): + if core.is_constant_dim(sd): + return xd == sd + else: + assert isinstance(sd, shape_poly._DimPolynomial) + return True assert (len(x.shape) == len(shape) and - all((xd is None and isinstance(sd, shape_poly._DimPolynomial) or - core.is_constant_dim(sd) and xd == sd) + all(check_one(xd, sd) for xd, sd in zip(x.shape, shape))), \ f"Shape {shape} does not match x.shape {x.shape}" diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index a3bc6f0ef0d3..ee0456461450 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -15,7 +15,7 @@ import unittest from absl.testing import absltest, parameterized -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import collections import functools @@ -31,6 +31,7 @@ import jax.numpy as jnp from jax import random from jax._src import test_util as jtu +from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow import numpy as np @@ -49,7 +50,7 @@ from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation PS = jax2tf.PolyShape - +_f32 = np.float32 class DimPolynomialTest(tf_test_util.JaxToTfTestCase): @@ -331,38 +332,230 @@ def test_stride_shape(self): core.stride_shape((a, 20), (1, 3), (2, 2)) -class ShapePolyTest(tf_test_util.JaxToTfTestCase): +class PolyHarness(Harness): + """Tests a function with shape polymorphism. + + Converts `fun` with shape polymorphism, creates a `tf.ConcreteFunction` + given `input_signature` and checks the inferred output shapes to match + `expected_output_signature`, then checks that the JAX and the TF functions + produce the same results.
+ """ + def __init__(self, + group_name: str, name: str, + fun: Callable, + *, + arg_descriptors: Sequence[primitive_harness.ArgDescriptor] = (), + polymorphic_shapes: Optional[Sequence[Any]] = None, + input_signature: Optional[Sequence[tf.TensorSpec]] = None, + poly_axes: Optional[Sequence[Optional[Union[int, Sequence[int]]]]] = None, + expected_output_signature: Optional[tf.TensorSpec] = None, + enable_xla: bool = True, + expect_error: Tuple[Optional[Any], Optional[str]] = (None, None), + skip_jax_run: bool = False, + check_result: bool = True, + tol: Optional[float] = None): + """Args: + + group_name, name: The name for the harness. See `Harness.__init__`. + fun: the function to be converted, possbily after partial application to + static arguments from `arg_descriptors`. See `Harness.__init__`. + arg_descriptors: The argument descriptors. See `Harness.__init__`. May + be missing, in which case `skip_jax_run` should be `True` and + `poly_axes` cannot be used. + polymorphic_shapes: For `jax2tf.convert`. If missing, generated from + `poly_axes`. + input_signature: For `tf.function.get_concrete_function`. If missing, + generated from `poly_axes`. + poly_axes: If present, used to generate `polymorphic_shapes` and + `input_signature`. Must correspond to the non-static arguments, and for + each one it must specify which axes are polymorphic: None, or an int + (for the index of the polymorphic axis), or a tuple of ints + (for multiple polymorphic axes). For each argument, we use its + `poly_axes` entry to generate the polymorphic_shapes specification, + creating dimension variables `b0`, `b1, ..., for each of its polymorphic + axes. This means that separate arguments will share the same dimension + variable names, in the order in which the axes are listed in + `poly_axes`. We also generate the input_signature from `poly_axes`. + expected_output_signature: the expected inferred output shape. + enable_xla: For `jax2tf.convert`. + expect_error: a pair of an Exception type and a regular expression to + match the expected exception string. + skip_jax_run: If True, then neither the JAX nor the TF functions are + executed. + check_result: specifies if we want to check that the result of the shape + polymorphic conversion produces the same result and the JAX function. + tol: the tolerance to use for checking results. + """ + super().__init__(group_name, name, fun, arg_descriptors, + dtype=np.float32) + self.poly_axes = poly_axes + self.polymorphic_shapes = polymorphic_shapes + self.input_signature = input_signature + self.expected_output_signature = expected_output_signature + self.skip_jax_run = skip_jax_run + self.expect_error = expect_error + self.enable_xla = enable_xla + self.tol = tol + self.check_result = check_result + + # Replicate the harness for both enable and disable xla + def both_enable_and_disable_xla(self) -> Tuple["PolyHarness", "PolyHarness"]: + assert self.enable_xla + other = PolyHarness(self.group_name, + f"{self.name}_enable_xla=False", + self.fun, + arg_descriptors=self.arg_descriptors, + poly_axes=self.poly_axes, + polymorphic_shapes=self.polymorphic_shapes, + input_signature=self.input_signature, + expected_output_signature=self.expected_output_signature, + expect_error=self.expect_error, + tol=self.tol, + enable_xla=False) + self.name = f"{self.name}_enable_xla=True" + return (self, other) + + def run_test(self, tst: tf_test_util.JaxToTfTestCase): + # Make polymorphic_shapes and input_signature from poly_axes. 
+ if self.poly_axes is None: + polymorphic_shapes = self.polymorphic_shapes + input_signature = self.input_signature + assert input_signature is not None + if not self.skip_jax_run: + args = self.dyn_args_maker(tst.rng()) + + else: + assert isinstance(self.poly_axes, Sequence) + # Make poly_axes: Sequence[Sequence[int]], one top-level element for each argument + poly_axes = tuple(map(lambda pa: pa if isinstance(pa, Sequence) or pa is None else (pa,), + self.poly_axes)) + args = self.dyn_args_maker(tst.rng()) + + assert self.polymorphic_shapes is None + assert self.input_signature is None + assert args is not None and len(args) == len(poly_axes) + # Make the polymorphic_shapes and input_signature + polymorphic_shapes = [] + input_signature = [] + for arg, poly_axis in zip(args, poly_axes): + if poly_axis is None: + polymorphic_shapes.append(None) + input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype)) + else: + def make_arg_polymorphic_shapes(poly_axis: Sequence[int]) -> Tuple[str, tf.TensorSpec]: + idx = -1 + dims = [] + tensorspec_dims: List[Optional[int]] = [] + for i, d in enumerate(arg.shape): + if i in poly_axis: + idx += 1 + dims.append(f"b{idx}") + tensorspec_dims.append(None) + else: + dims.append(str(d)) + tensorspec_dims.append(d) + return ", ".join(dims), tf.TensorSpec(tensorspec_dims, arg.dtype) + + arg_polymorphic_shapes, arg_tensorspec = make_arg_polymorphic_shapes(poly_axis) + polymorphic_shapes.append(arg_polymorphic_shapes) + input_signature.append(arg_tensorspec) + + expect_error_type, expect_error_regex = self.expect_error + if self.skip_jax_run and self.arg_descriptors == (): + f_jax = self.fun + else: + f_jax = self.dyn_fun + if expect_error_type is not None: + with tst.assertRaisesRegex(expect_error_type, expect_error_regex): + f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes, + enable_xla=self.enable_xla) + f_tf_func = tf.function( + f_tf, autograph=False, input_signature=input_signature) + # Create tf.ConcreteFunction and check inferred output signature + f_tf_func.get_concrete_function(*input_signature) + + return + else: + f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes, + enable_xla=self.enable_xla) + f_tf_func = tf.function( + f_tf, autograph=False, input_signature=input_signature) + # Create tf.ConcreteFunction and check inferred output signature + concrete_f_tf = f_tf_func.get_concrete_function(*input_signature) + + if self.expected_output_signature: + # Strangely, output_shapes can be a single shape for a function with a + # single result, or a list/tuple of shapes. 
+ expected_output_signature = self.expected_output_signature + concrete_output_tf_shape = concrete_f_tf.output_shapes + if not isinstance(concrete_output_tf_shape, (tuple, list)): # Single result + assert not isinstance(self.expected_output_signature, (tuple, list)) + expected_output_signature = [self.expected_output_signature] + concrete_output_tf_shape = [concrete_output_tf_shape] + for expected, found in util.safe_zip(expected_output_signature, + concrete_output_tf_shape): + tst.assertEqual(tuple(expected.shape), tuple(found)) + + # Run the JAX and the TF functions and compare the results + if not self.skip_jax_run: + res_jax = f_jax(*args) + res_tf = f_tf(*args) + if self.check_result: + tst.assertAllClose(res_jax, res_tf, atol=self.tol, rtol=self.tol) + + +def check_shape_poly(tst, f_jax: Callable, *, + arg_descriptors: Sequence[primitive_harness.ArgDescriptor] = (), + skip_jax_run: bool = False, + poly_axes = None, + polymorphic_shapes: Optional[Sequence[Any]] = None, + input_signature: Optional[Sequence[tf.TensorSpec]] = None, + expected_output_signature: Optional[tf.TensorSpec] = None): + # Makes and tests a harness. See PolyHarness documentation. + h = PolyHarness("", "", f_jax, + arg_descriptors=arg_descriptors, + skip_jax_run=skip_jax_run, poly_axes=poly_axes, + polymorphic_shapes=polymorphic_shapes, + input_signature=input_signature, + expected_output_signature=expected_output_signature) + h.run_test(tst) +class ShapePolyTest(tf_test_util.JaxToTfTestCase): + def test_simple_unary(self): """Test shape polymorphism for a simple case, unary function.""" def f_jax(x): return x + jnp.sin(x) - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([2, 3])], - polymorphic_shapes=None, - expected_output_signature=tf.TensorSpec([2, 3])) - - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([2, None])], - polymorphic_shapes=["_, h"], - expected_output_signature=tf.TensorSpec([2, None])) - - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None, None])], - polymorphic_shapes=["h, h"], - expected_output_signature=tf.TensorSpec([None, None])) - - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None, None])], - polymorphic_shapes="h, h", - expected_output_signature=tf.TensorSpec([None, None])) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((2, 3), _f32)], + input_signature=[tf.TensorSpec([2, 3])], + polymorphic_shapes=None, + expected_output_signature=tf.TensorSpec([2, 3])) + + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((2, 3), _f32)], + input_signature=[tf.TensorSpec([2, None])], + polymorphic_shapes=["_, h"], + expected_output_signature=tf.TensorSpec([2, None])) + + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3, 3), _f32)], + input_signature=[tf.TensorSpec([None, None])], + polymorphic_shapes=["h, h"], + expected_output_signature=tf.TensorSpec([None, None])) + + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3, 3), _f32)], + input_signature=[tf.TensorSpec([None, None])], + polymorphic_shapes="h, h", + expected_output_signature=tf.TensorSpec([None, None])) def test_simple_binary(self): """Test shape polymorphism for a simple case, binary function.""" @@ -370,27 +563,30 @@ def test_simple_binary(self): def f_jax(x, y): return x + jnp.sin(y) - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([2, 3]), tf.TensorSpec([2, 3])], - polymorphic_shapes=None, - expected_output_signature=tf.TensorSpec([2, 3])) - - 
self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([2, None]), tf.TensorSpec([2, 3])], - polymorphic_shapes="2, h", - expected_output_signature=( - # for native lowering we cannot refine the inferred shape of the - # output if the input is more specific than polymorphic_shapes. - tf.TensorSpec([2, 3]) if not config.jax2tf_default_experimental_native_lowering - else tf.TensorSpec([2, None]))) - - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None, None]), tf.TensorSpec([None, None])], - polymorphic_shapes=PS("h", "h"), - expected_output_signature=tf.TensorSpec([None, None])) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)], + input_signature=[tf.TensorSpec([2, 3]), tf.TensorSpec([2, 3])], + polymorphic_shapes=None, + expected_output_signature=tf.TensorSpec([2, 3])) + + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)], + input_signature=[tf.TensorSpec([2, None]), tf.TensorSpec([2, 3])], + polymorphic_shapes="_, h", + expected_output_signature=( + # for native lowering we cannot refine the inferred shape of the + # output if the input is more specific than polymorphic_shapes. + tf.TensorSpec([2, 3]) if not config.jax2tf_default_experimental_native_lowering + else tf.TensorSpec([2, None]))) + + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3, 3), _f32), RandArg((3, 3), _f32)], + input_signature=[tf.TensorSpec([None, None]), tf.TensorSpec([None, None])], + polymorphic_shapes=PS("h", "h"), + expected_output_signature=tf.TensorSpec([None, None])) def test_arange(self): def f_jax(x): @@ -412,25 +608,28 @@ def test_static_shape_result(self): def f_jax(x): return jnp.sum(x + jnp.sin(x), axis=0) - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([2, 3])], - polymorphic_shapes=None, - expected_output_signature=tf.TensorSpec([3])) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((2, 3), _f32)], + input_signature=[tf.TensorSpec([2, 3])], + polymorphic_shapes=None, + expected_output_signature=tf.TensorSpec([3])) - self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None, 3])], - polymorphic_shapes="b, _", - expected_output_signature=tf.TensorSpec([3])) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((2, 3), _f32)], + input_signature=[tf.TensorSpec([None, 3])], + polymorphic_shapes="b, _", + expected_output_signature=tf.TensorSpec([3])) def test_forgot_polymorphic_shapes_error(self): msg_re = "polymorphic shape None in axis .* must contain a dimension variable for unknown dimension in argument shape .*. 
Perhaps you forgot to add the polymorphic_shapes" with self.assertRaisesRegex(ValueError, msg_re): - self.CheckShapePolymorphism( - jnp.sin, - input_signature=[tf.TensorSpec([1, None])], - polymorphic_shapes=None) + check_shape_poly(self, + jnp.sin, + arg_descriptors=[RandArg((1, 3,), _f32)], + input_signature=[tf.TensorSpec([1, None])], + polymorphic_shapes=None) def test_kwargs(self): """Test shape polymorphism for a function with kwargs.""" @@ -440,8 +639,8 @@ def test_kwargs(self): def f_jax(x, *, y): return x + jnp.sin(y) - f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."]) - f_tf(x, y=y) + f_tf: Callable[..., Any] = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."]) + self.assertAllClose(f_jax(x, y=y), f_tf(x, y=y)) def test_arg_avals(self): """Test conversion of actual arguments to abstract values.""" @@ -693,86 +892,91 @@ def add_all_jax(x_pair_of_list, y_dict): return functools.reduce(operator.add, x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) - self.CheckShapePolymorphism( - add_all_jax, - input_signature=[([tf.TensorSpec([None]), - tf.TensorSpec([None])], [tf.TensorSpec([None])]), - dict(a=tf.TensorSpec([None]), - b=tf.TensorSpec([None]))], - polymorphic_shapes=[(["v", "v"], [("v")]), - dict(a="v", b="v")], - expected_output_signature=tf.TensorSpec([None])) + check_shape_poly(self, + add_all_jax, + skip_jax_run=True, + input_signature=[([tf.TensorSpec([None]), tf.TensorSpec([None])], + [tf.TensorSpec([None])]), + dict(a=tf.TensorSpec([None]), + b=tf.TensorSpec([None]))], + polymorphic_shapes=[(["v", "v"], ["v"]), + dict(a="v", b="v")], + expected_output_signature=tf.TensorSpec([None])) # Now partial polymorphic_shapes; the parts of the polymorphic_shapes that # are not specified must have full input_signatures. - self.CheckShapePolymorphism( - add_all_jax, - input_signature=[([tf.TensorSpec([4]), - tf.TensorSpec([4])], [tf.TensorSpec([4])]), - dict(a=tf.TensorSpec([4]), b=tf.TensorSpec([4]))], - polymorphic_shapes=[(["(4,)", "(_,)"], [("4,")]), - dict(a="(_,)", b="(4,)")], - expected_output_signature=tf.TensorSpec([4])) + check_shape_poly(self, + add_all_jax, + skip_jax_run=True, + input_signature=[([tf.TensorSpec([4]), tf.TensorSpec([4])], [tf.TensorSpec([4])]), + dict(a=tf.TensorSpec([4]), b=tf.TensorSpec([4]))], + polymorphic_shapes=[(["(4,)", "(_,)"], [("4,")]), + dict(a="(_,)", b="(4,)")], + expected_output_signature=tf.TensorSpec([4])) def test_with_nested_jit(self): - x = np.ones((3, 4), dtype=np.float32) - # We implement the following computation - _ = x + (np.sin(x) + np.broadcast_to(np.arange(x.shape[1]), x.shape)) def f_jax(x): # x: f32[w, h] + # x + (np.sin(x) + np.broadcast_to(np.arange(x.shape[1]), x.shape)) return jnp.sin(x) + jnp.arange(x.shape[1], dtype=x.dtype) - self.CheckShapePolymorphism( - lambda x: x + jax.jit(f_jax)(x), - input_signature=[tf.TensorSpec([None, None])], - polymorphic_shapes=["w, h"]) + check_shape_poly(self, + lambda x: x + jax.jit(f_jax)(x), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[(0, 1)]) def test_non_trivial_polynomials(self): if config.jax_dynamic_shapes: raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials") # We can handle non-trivial polynomials in the input shape, # as long as all variables also occur in trivial polynoamials - self.CheckShapePolymorphism( - lambda x, y: x + y.reshape((-1,)), - input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])], - polymorphic_shapes=["b * b", "b, b"]) + check_shape_poly(self, + lambda x, y: x + y.reshape((-1,)), + 
arg_descriptors=[RandArg((9,), _f32), RandArg((3, 3), _f32)], + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])], + polymorphic_shapes=["b * b", "b, b"]) def test_unused_args(self): # Tests with functions that do not use their inputs. # First arg unused, not polymorphic - self.CheckShapePolymorphism( - lambda x_unused, y: y * 2.0, - input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])], - polymorphic_shapes=[None, "b"]) + check_shape_poly(self, + lambda x_unused, y: y * 2.0, + arg_descriptors=[RandArg((2, 3), _f32), RandArg((3,), _f32)], + input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])], + polymorphic_shapes=[None, "b"]) # Some args unused, not polymorphic - self.CheckShapePolymorphism( - lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]), - input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]), + check_shape_poly(self, + lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]), + arg_descriptors=[RandArg((3,), _f32), RandArg((4,), _f32), + RandArg((5,), _f32), RandArg((6,), _f32)], + input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]), tf.TensorSpec([]), tf.TensorSpec([None])], - polymorphic_shapes=[None, "b1", None, "b2"]) + polymorphic_shapes=[None, "b1", None, "b2"]) # A polymorphic arg is not used, but the dimension var appears # in a used arg also - self.CheckShapePolymorphism( - lambda x_unused, y: y * 2.0, - input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], - polymorphic_shapes=["b", "b"]) + check_shape_poly(self, + lambda x_unused, y: y * 2.0, + arg_descriptors=[RandArg((3,), _f32), RandArg((3,), _f32)], + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b", "b"]) # A polymorphic arg is not used, and the dimension var does not appear # elsewhere. - self.CheckShapePolymorphism( + check_shape_poly(self, lambda x_unused, y: y * 2.0, + arg_descriptors=[RandArg((4,), _f32), RandArg((3,), _f32)], input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], polymorphic_shapes=["b1", "b2"]) # A polymorphic arg is not used, and the dimension var does appear # elsewhere but not as a trivial monomial. - self.CheckShapePolymorphism( + check_shape_poly(self, lambda x_unused, y: y * 2.0, + arg_descriptors=[RandArg((3,), _f32), RandArg((9,), _f32)], input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], polymorphic_shapes=["b1", "b1 * b1"]) - def test_with_custom_vjp(self): """Shape-polymorphic custom VJP.""" @@ -802,12 +1006,7 @@ def f_bwd(residual, ct_b): res_jax = f(x) res_jax_grad = jax.grad(lambda x: jnp.sum(f(x)))(x) - f_tf = self.CheckShapePolymorphism( - f, - input_signature=[tf.TensorSpec([None, None, None, None])], - polymorphic_shapes=["(batch1, batch2, d1, d2)"], - expected_output_signature=tf.TensorSpec([None, None, None, None])) - + f_tf = jax2tf.convert(f, polymorphic_shapes=["(batch1, batch2, d1, d2)"]) self.assertAllClose(res_jax, f_tf(x)) xv = tf.Variable(x, dtype=np.float32) @@ -845,12 +1044,13 @@ def f(x): # res: dict(res=[b, 3, 4]) return dict(res=x["x"] * 2.) 
- f_tf = self.CheckShapePolymorphism( - f, - input_signature=[dict(x=tf.TensorSpec([None, 3, 4]))], - polymorphic_shapes=[dict(x=("b, 3, 4"))], - expected_output_signature=None) + check_shape_poly(self, + f, + skip_jax_run=True, + input_signature=[dict(x=tf.TensorSpec([None, 3, 4]))], + polymorphic_shapes=[dict(x=("b, 3, 4"))]) + f_tf = jax2tf.convert(f, polymorphic_shapes=[dict(x=("b, 3, 4"))]) x = dict(x=np.ones((2, 3, 4), dtype=np.float32)) xv = tf.Variable(x["x"], dtype=np.float32) @@ -951,6 +1151,7 @@ def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2] self.assertAllClose(g_tf[1], np.zeros_like(zb)) def test_prng(self): + raise unittest.SkipTest("TODO(necula): investigate") # The PRNG implementation uses opaque types, test shape polymorphism try: prev_custom_prng = config.jax_enable_custom_prng @@ -967,9 +1168,10 @@ def f_jax(x): # x: f32[b1, b2] _ = lax.dynamic_update_slice(upd1, gather_keys, start_indices=(0, 0)) return x - self.CheckShapePolymorphism(f_jax, - input_signature=[tf.TensorSpec([None, None], dtype=tf.float32)], - polymorphic_shapes=["b1, b2"]) + check_shape_poly(self, f_jax, + arg_descriptors=[RandArg((3, 4), _f32)], + input_signature=[tf.TensorSpec([None, None], dtype=tf.float32)], + polymorphic_shapes=["b1, b2"]) finally: config.update("jax_enable_custom_prng", prev_custom_prng) @@ -1039,17 +1241,13 @@ def test_readme_examples(self): jax2tf.convert(lambda x: jnp.prod(jnp.array(x.shape)), polymorphic_shapes=["(b, 4)"])(np.ones((3, 4))) + four_ones = np.ones((4,)) with self.assertRaisesRegex( TypeError, re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")): - self.CheckShapePolymorphism( - lambda x, y: x + y, - input_signature=[tf.TensorSpec([None]), - tf.TensorSpec([4])], - polymorphic_shapes=["(v,)", "(4,)"], - expected_output_signature=tf.TensorSpec([None])) + jax2tf.convert(lambda x, y: x + y, + polymorphic_shapes=["(v,)", "(4,)"])(four_ones, four_ones) - four_ones = np.ones((4,)) # We get the error even if we use correct actual arguments with self.assertRaisesRegex( TypeError, @@ -1127,14 +1325,6 @@ def f_jax(x): "Cannot solve for values of dimension variables"): jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x) - -class DimAsValueTest(tf_test_util.JaxToTfTestCase): - """Dimension polynomials used as values in the JAX computation.""" - def setUp(self): - super().setUp() - if config.jax2tf_default_experimental_native_lowering and not config.jax_dynamic_shapes: - self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") - def test_dynamic_shapes(self): # Test dim_as_value with dynamic shapes. 
def f(x): @@ -1173,13 +1363,11 @@ def test_mean0(self): def f_jax(x): return jnp.sum(x, axis=0) / x.shape[0] - x = np.arange(12.).reshape((3, 4)) - f_tf = self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None, 4], dtype=x.dtype)], - polymorphic_shapes=[("b, _")], - expected_output_signature=tf.TensorSpec([4])) - self.assertAllClose(np.array([4., 5., 6., 7.]), f_tf(x)) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0], + expected_output_signature=tf.TensorSpec([4])) def test_dimension_used_as_value(self): def f_jax(x): @@ -1189,209 +1377,189 @@ def f_jax(x): x3 = x2 + jnp.sin(poly) # In jnp operations # A list or tuple of poly in jnp operations return x3.astype(np.float32) - x = np.arange(3, dtype=np.int32) - f_tf = self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None], dtype=x.dtype)], - polymorphic_shapes=["b"], - expected_output_signature=tf.TensorSpec([])) - self.assertAllClose(np.float32(3 + 3 + np.sin(3)), f_tf(x)) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3,), np.int32)], + poly_axes=[0], + expected_output_signature=tf.TensorSpec([])) def test_dimension_used_as_result(self): + if config.jax_enable_x64: + raise unittest.SkipTest("TODO(necula): dim_as_value in x64 mode") def f_jax(x): return 2 * x.shape[0] - x = np.arange(3, dtype=np.int32) - f_tf = self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None], dtype=x.dtype)], - polymorphic_shapes=["b"], - expected_output_signature=tf.TensorSpec([])) - self.assertAllClose(np.array(2 * 3, dtype=x.dtype), f_tf(x)) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3,), np.int32)], + poly_axes=[0], + expected_output_signature=tf.TensorSpec([])) def test_shape_as_array(self): def f_jax(x): # The entire x.shape is passed to jnp.array return jnp.sum(jnp.array(x.shape)).astype(np.int32) - x = np.arange(12, dtype=np.int32).reshape((3, 4)) - f_tf = self.CheckShapePolymorphism( - f_jax, - input_signature=[tf.TensorSpec([None, 4], dtype=x.dtype)], - polymorphic_shapes=["b, _"], - expected_output_signature=tf.TensorSpec([])) - self.assertAllClose(np.int32(3 + 4), f_tf(x)) - -### -### We define primitive harnesses for which we will test shape-polymorphic -### conversion. -def _make_harness(group_name: str, name: str, - func: Callable, - args: primitive_harness.ArgDescriptor, - *, - poly_axes: Sequence[Optional[Union[int, Sequence[int]]]], - check_result=True, - skip_jax_run=True, - tol=None, - enable_and_disable_xla=False, - expect_error=(None, None), - **params) -> Union[Harness, Sequence[Harness]]: - """The `poly_axes` must correspond to the non-static arguments, and for each - one it must specify which axes are: None, or an int (for the index of the - polymorphic axis), or a tuple of ints (for multiple polymorphic axes). - - For each argument, we use its `poly_axes` entry to generate the polymorphic_shapes - specification, creating dimension variables `b0`, `b1, ..., for each of its - polymorphic axes. This means that separate arguments will share the same - dimension variable names, in the order in which the axes are listed in - poly_axes. - - The name of the harness within the group will include `poly_axes`. - You can add an additional `name`. - - `check_result` specifies if we want to check that the result of the shape - polymorphic conversion produces the same result and the JAX function. - - `expect_error` is a pair of an Exception type and a regular expression to - match the expected exception string. 
- - enable_and_disable_xla=True means that we generate two harnesses, - one with enable_xla=False and one with enable_xal=True. Otherwise we create - only one harness with enable_xla=True. - """ - if enable_and_disable_xla: - return [ - _make_harness(group_name, name + ("" if enable_xla else "_noxla"), # type: ignore - func, args, poly_axes=poly_axes, - check_result=check_result, tol=tol, enable_xla=enable_xla, - enable_and_disable_xla=False, skip_jax_run=skip_jax_run, - expect_error=expect_error) - for enable_xla in [True, False] - ] - poly_axes_name = f"poly_axes={repr(poly_axes)}" - assert isinstance(poly_axes, Sequence) - # Make poly_axes: Sequence[Sequence[int]] - poly_axes = tuple(map(lambda pa: pa if isinstance(pa, Sequence) or pa is None else (pa,), - poly_axes)) - if name: - name = f"{name}_{poly_axes_name}" - else: - name = poly_axes_name - return Harness(group_name, - name, - func, args, - dtype=np.float32, - poly_axes=poly_axes, check_result=check_result, - skip_jax_run=skip_jax_run, expect_error=expect_error, - tol=tol, - **params) + check_shape_poly(self, + f_jax, + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]) + def test_vmap_while(self): + def cond_func(x): # x: f32[3] + return jnp.sum(x) >= 0. + def body_func(x): # x: f32[3] + return x - 1. + def f_jax(x): + return lax.while_loop(cond_func, body_func, x) + check_shape_poly(self, + jax.vmap(f_jax), + arg_descriptors=[RandArg((3,), _f32)], + input_signature=[tf.TensorSpec((None, 3), dtype=tf.float32)], + polymorphic_shapes=["b, ..."], + expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32) + ) + + def test_reshape_compiled(self): + # We compile the result of conversion for two shapes, hence we need to + # involve the TF compiler twice, but we trace only once with shape polymorphism + traced = False + + def f_jax(x): + nonlocal traced + traced = True + y = jnp.sin(x) + return y.reshape([x.shape[0], -1]) + + x = self.rng().rand(4, 2, 3) + res_jax = f_jax(x) + + traced = False + # If we get_concrete_function we trace once + f_tf = tf.function( + jax2tf.convert(f_jax, polymorphic_shapes=[PS("b", ...)]), + autograph=False, + jit_compile=True).get_concrete_function( + tf.TensorSpec([None, 2, 3], x.dtype)) + self.assertTrue(traced) + traced = False + self.assertAllClose(res_jax, f_tf(x)) + self.assertFalse(traced) # We are not tracing again + + x = self.rng().rand(6, 2, 3) + res_jax = f_jax(x) + traced = False + + self.assertAllClose(res_jax, f_tf(x)) + self.assertFalse(traced) # We are not tracing again -_f32 = np.float32 # List containing either harnesses, or lists of harnesses _POLY_SHAPE_TEST_HARNESSES = [ - _make_harness("add", "", - jnp.add, - [RandArg((3, 4), _f32), RandArg((2, 3, 4), _f32)], - poly_axes=[0, 1]), - _make_harness("add_transpose", "", - jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + x)), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("arange", "start", - lambda op: jnp.arange(2 * op.shape[0], dtype=_f32), - [RandArg((3,), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), - _make_harness("arange", "start_no_dtype", - lambda op: jnp.arange(op.shape[0]), - [RandArg((3,), _f32)], - poly_axes=[0]), - _make_harness("arange", "error1", - lambda op: jnp.arange(op.shape[0], 10), - [RandArg((3,), _f32)], - poly_axes=[0], - expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")), - _make_harness("arange", "error2", - lambda op: jnp.arange(1, op.shape[0]), - [RandArg((3,), _f32)], - poly_axes=[0], - 
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")), - _make_harness("arange", "error3", - lambda op: jnp.arange(1, 5, op.shape[0]), - [RandArg((3,), _f32)], - poly_axes=[0], - expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")), + PolyHarness("simple_binary", "", + lambda x, y: x + jnp.sin(y), + arg_descriptors=[RandArg((3, 4), _f32), RandArg((3, 4), _f32)], + input_signature=[tf.TensorSpec([2, None]), tf.TensorSpec([2, 3])], + polymorphic_shapes="_, h", + expected_output_signature=tf.TensorSpec([2, 3])), + PolyHarness("add", "", + jnp.add, + arg_descriptors=[RandArg((3, 4), _f32), RandArg((2, 3, 4), _f32)], + poly_axes=[0, 1]), + PolyHarness("add_transpose", "", + jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + x)), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("arange", "start", + lambda op: jnp.arange(2 * op.shape[0], dtype=_f32), + arg_descriptors=[RandArg((3,), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), + PolyHarness("arange", "start_no_dtype", + lambda op: jnp.arange(op.shape[0]), + arg_descriptors=[RandArg((3,), _f32)], + poly_axes=[0]), + PolyHarness("arange", "error1", + lambda op: jnp.arange(op.shape[0], 10), + arg_descriptors=[RandArg((3,), _f32)], + poly_axes=[0], + expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")), + PolyHarness("arange", "error2", + lambda op: jnp.arange(1, op.shape[0]), + arg_descriptors=[RandArg((3,), _f32)], + poly_axes=[0], + expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")), + PolyHarness("arange", "error3", + lambda op: jnp.arange(1, 5, op.shape[0]), + arg_descriptors=[RandArg((3,), _f32)], + poly_axes=[0], + expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")), # Reduce the poly dimension - _make_harness("argmax", "0", - lambda op: lax.argmax(op, axis=0, index_dtype=np.int32), - [RandArg((3, 4, 5), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), + PolyHarness("argmax", "0", + lambda op: lax.argmax(op, axis=0, index_dtype=np.int32), + arg_descriptors=[RandArg((3, 4, 5), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), # Reduce the non-poly dimension - _make_harness("argmax", "1", - lambda op: lax.argmax(op, axis=1, index_dtype=np.int32), - [RandArg((3, 4, 5), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), + PolyHarness("argmax", "1", + lambda op: lax.argmax(op, axis=1, index_dtype=np.int32), + arg_descriptors=[RandArg((3, 4, 5), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), [ - _make_harness("average", - f"{axis=}_weights=None", - lambda x, axis: jnp.average(x, axis=axis, returned=False, weights=None), - [RandArg((7, 8, 4), _f32), StaticArg(axis)], - poly_axes=[0]) + PolyHarness("average", + f"{axis=}_weights=None", + lambda x, axis: jnp.average(x, axis=axis, returned=False, weights=None), + arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis)], + poly_axes=[0]) for axis in [None, 0, 1] ], [ - _make_harness("average", - f"{axis=}_weights=Some", - lambda x, weights, axis: jnp.average(x, axis=axis, returned=False, weights=weights), - [RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), _f32), StaticArg(axis)], - poly_axes=[0, 0]) + PolyHarness("average", + f"{axis=}_weights=Some", + lambda x, weights, axis: jnp.average(x, axis=axis, returned=False, weights=weights), + 
arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), _f32), StaticArg(axis)], + poly_axes=[0, 0]) for axis in [None, 0, 1] ], - _make_harness("broadcast_to", "", - lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("broadcast_in_dim", "0", - lambda x: lax.broadcast_in_dim(x, [x.shape[0], 4, 5, 6], - broadcast_dimensions=(0, 2, 3)), - [RandArg((3, 1, 6), _f32)], - poly_axes=[0]), - _make_harness("broadcast_in_dim", "poly", - lambda x: lax.broadcast_in_dim(x, [x.shape[0], x.shape[0] + x.shape[0], 4], - broadcast_dimensions=(0, 1, 2)), - [RandArg((3, 1, 4), _f32)], - poly_axes=[0]), - _make_harness("broadcast_in_dim", "poly2", - lambda x: lax.broadcast_in_dim(x, [x.shape[0], 5, 6, x.shape[2], 4], - broadcast_dimensions=(0, 2, 3)), - [RandArg((3, 1, 4), _f32)], - poly_axes=[(0, 2)]), - _make_harness("broadcast_in_dim", "transpose", - jax.grad(lambda x: jnp.sum(lax.broadcast_in_dim(x, [2, x.shape[0], 5, x.shape[2], 4], - broadcast_dimensions=(1, 2, 3)))), - [RandArg((3, 1, 4), _f32)], - poly_axes=[(0, 2)]), - _make_harness("clamp", "", - lax.clamp, - [RandArg((3, 4, 5), _f32), RandArg((3, 4, 5), _f32), - RandArg((3, 4, 5), _f32)], - poly_axes=[0, 0, 0]), - _make_harness("collapse", "", - lambda x: lax.collapse(x, 1, 4), - [RandArg((3, 4, 5, 6, 7), _f32)], - poly_axes=[(0, 1, 3)]), - _make_harness("concatenate", "", - lambda x: jnp.concatenate([x, x], axis=0), - [RandArg((3, 4, 5), _f32)], - poly_axes=[(0, 1)]), - _make_harness("concatenate", "grad", - jax.grad(lambda x: jnp.sum(jnp.concatenate([x, x], axis=0))), - [RandArg((3, 4, 5), _f32)], - poly_axes=[(0, 1)]), + PolyHarness("broadcast_to", "", + lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("broadcast_in_dim", "0", + lambda x: lax.broadcast_in_dim(x, [x.shape[0], 4, 5, 6], + broadcast_dimensions=(0, 2, 3)), + arg_descriptors=[RandArg((3, 1, 6), _f32)], + poly_axes=[0]), + PolyHarness("broadcast_in_dim", "poly", + lambda x: lax.broadcast_in_dim(x, [x.shape[0], x.shape[0] + x.shape[0], 4], + broadcast_dimensions=(0, 1, 2)), + arg_descriptors=[RandArg((3, 1, 4), _f32)], + poly_axes=[0]), + PolyHarness("broadcast_in_dim", "poly2", + lambda x: lax.broadcast_in_dim(x, [x.shape[0], 5, 6, x.shape[2], 4], + broadcast_dimensions=(0, 2, 3)), + arg_descriptors=[RandArg((3, 1, 4), _f32)], + poly_axes=[(0, 2)]), + PolyHarness("broadcast_in_dim", "transpose", + jax.grad(lambda x: jnp.sum(lax.broadcast_in_dim(x, [2, x.shape[0], 5, x.shape[2], 4], + broadcast_dimensions=(1, 2, 3)))), + arg_descriptors=[RandArg((3, 1, 4), _f32)], + poly_axes=[(0, 2)]), + PolyHarness("clamp", "", + lax.clamp, + arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 4, 5), _f32), + RandArg((3, 4, 5), _f32)], + poly_axes=[0, 0, 0]), + PolyHarness("collapse", "", + lambda x: lax.collapse(x, 1, 4), + arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)], + poly_axes=[(0, 1, 3)]), + PolyHarness("concatenate", "", + lambda x: jnp.concatenate([x, x], axis=0), + arg_descriptors=[RandArg((3, 4, 5), _f32)], + poly_axes=[(0, 1)]), + PolyHarness("concatenate", "grad", + jax.grad(lambda x: jnp.sum(jnp.concatenate([x, x], axis=0))), + arg_descriptors=[RandArg((3, 4, 5), _f32)], + poly_axes=[(0, 1)]), # Issue #11402 # We play a trick here. 
Since the stride is 2, when we compute the padding @@ -1400,612 +1568,544 @@ def _make_harness(group_name: str, name: str, # We pass the lhs as (1, b, 2, 16) and then we # reshape it as (1, 2*b, 16), so that we know that the lhs's dimension 1 # is a multiple of 2. - _make_harness("conv_general_dilated", "1d_1", - lambda lhs, rhs: lax.conv_general_dilated( - jnp.reshape(lhs, (1, -1, 16)), rhs, - window_strides=(2,), - padding="SAME", - rhs_dilation=None, - dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - rhs_spec=(2, 1, 0), - out_spec=(0, 2, 1))), - [RandArg((1, 6, 2, 16), _f32), RandArg((4, 16, 16), _f32)], - poly_axes=[1, None], - enable_and_disable_xla=True), + PolyHarness("conv_general_dilated", "1d_1", + lambda lhs, rhs: lax.conv_general_dilated( + jnp.reshape(lhs, (1, -1, 16)), rhs, + window_strides=(2,), + padding="SAME", + rhs_dilation=None, + dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), + rhs_spec=(2, 1, 0), + out_spec=(0, 2, 1))), + arg_descriptors=[RandArg((1, 6, 2, 16), _f32), RandArg((4, 16, 16), _f32)], + poly_axes=[1, None]).both_enable_and_disable_xla(), # The same example from above, but without the reshape trick. - _make_harness("conv_general_dilated", "1d_1err", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(2,), - padding="SAME", - rhs_dilation=None, - dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - rhs_spec=(2, 1, 0), - out_spec=(0, 2, 1))), - [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - poly_axes=[1, None], - enable_and_disable_xla=True, - expect_error=(core.InconclusiveDimensionOperation, - "Cannot divide .* by '2'")), + PolyHarness("conv_general_dilated", "1d_1err", + lambda lhs, rhs: lax.conv_general_dilated( + lhs, rhs, + window_strides=(2,), + padding="SAME", + rhs_dilation=None, + dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), + rhs_spec=(2, 1, 0), + out_spec=(0, 2, 1))), + arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)], + poly_axes=[1, None], + expect_error=(core.InconclusiveDimensionOperation, + "Cannot divide .* by '2'") + ).both_enable_and_disable_xla(), # Issue #11402 - _make_harness("conv_general_dilated", "1d_2", - lambda lhs, rhs: lax.conv_transpose(lhs, rhs, - strides=(2,), - padding="SAME", - rhs_dilation=None, - transpose_kernel=False), - [RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - poly_axes=[0, None], - enable_and_disable_xla=True), + PolyHarness("conv_general_dilated", "1d_2", + lambda lhs, rhs: lax.conv_transpose(lhs, rhs, + strides=(2,), + padding="SAME", + rhs_dilation=None, + transpose_kernel=False), + arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)], + poly_axes=[0, None], + tol=1e-5).both_enable_and_disable_xla(), # Issue #11402 - _make_harness("conv_general_dilated", "1d_3", - lambda lhs, rhs: lax.conv_transpose(lhs, rhs, - strides=(2,), - padding="SAME", - rhs_dilation=None, - transpose_kernel=False), - [RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - poly_axes=[1, None], - enable_and_disable_xla=True), - _make_harness("conv_general_dilated", "", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(2, 3), - padding=((0, 0), (0, 0)), - lhs_dilation=(1, 1), - rhs_dilation=(1, 2), - dimension_numbers=("NCHW", "OIHW", "NCHW"), - feature_group_count=1, - batch_group_count=1, - precision=None), - [RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)], - poly_axes=[0, None], - enable_and_disable_xla=True), - _make_harness("cummax", "", - 
lambda x: lax_control_flow.cummax(x, axis=1, reverse=False), - [RandArg((3, 4, 5), _f32)], - poly_axes=[0]), - _make_harness("delta", "0", - lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("dot_general", "", - lambda lhs, rhs: lax.dot_general(lhs, rhs, - dimension_numbers=(((2,), (1,)), ((0,), (0,)))), - [RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)], - poly_axes=[0, 0]), - _make_harness("dynamic_slice", "idx=tuple_int", - # x:shape: (b, 4) - lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)), - [RandArg((3, 4), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), - _make_harness("dynamic_slice", "idx=tuple_arg", - # x:shape: (b, 4) - lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)), - [RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)], - poly_axes=[0, None], - enable_and_disable_xla=True), - _make_harness("dynamic_slice", "idx=array", - # x:shape: (b, 4) - lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)), - [RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)], - poly_axes=[0, None], - enable_and_disable_xla=True), - _make_harness("dynamic_slice_in_dim", "idx=0", - # x:shape: (b, 4) - lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0), - [RandArg((3, 4), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), - _make_harness("dynamic_update_slice", "idx=tuple_int", - # x:shape: (b, 4) - lambda x: lax.dynamic_update_slice(x, x, (0, 0)), - [RandArg((3, 4), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), - _make_harness("dynamic_update_slice", "idx=tuple_arg", - # x:shape: (b, 4) - lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))), - [RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)], - poly_axes=[0, None], - enable_and_disable_xla=True), - _make_harness("dynamic_update_slice", "idx=array", - # x:shape: (b, 4) - lambda x, idx: lax.dynamic_update_slice(x, x, idx), - [RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)], - poly_axes=[0, None], - enable_and_disable_xla=True), - _make_harness("einsum", "0", - lambda x: jnp.einsum("...i->...", x), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("einsum", "0_alt", - lambda x: jnp.einsum(x, (..., 1), [...]), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("einsum", "1", - lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y), - [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], - poly_axes=[0, 0]), - _make_harness("einsum", "1_alt", - lambda x, y: jnp.einsum(x, [..., 0, 1], y, (..., 1, 2), [..., 0, 2]), - [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], - poly_axes=[0, 0]), - _make_harness("einsum", "2", - lambda x, y: jnp.einsum("...ij,jk->...ik", x, y), - [RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], - poly_axes=[0, None]), - _make_harness("einsum", "2_alt", - lambda x, y: jnp.einsum(x, [..., 0, 1], y, [1, 2], [..., 0, 2]), - [RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], - poly_axes=[0, None]), - _make_harness("einsum", "3", - # Reduced dimension is polymorphic - lambda x, y: jnp.einsum("ij,jk->ik", x, y), - [RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - poly_axes=[1, 0]), - _make_harness("einsum", "3_alt", - # Reduced dimension is polymorphic - lambda x, y: jnp.einsum(x, [0, 1], y, [1, 2], [0, 2]), - [RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - poly_axes=[1, 0]), - _make_harness("einsum", "4", - # Reduced dimension is polymorphic, and is 2*b - lambda x, y: jnp.einsum("ij,jk->ik", - jnp.concatenate([x, x], axis=1), - 
jnp.concatenate([y, y], axis=0)), - [RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - poly_axes=[1, 0]), - _make_harness("einsum", "4_alt", - # Reduced dimension is polymorphic, and is 2*b - lambda x, y: jnp.einsum(jnp.concatenate([x, x], axis=1), [0, 1], - jnp.concatenate([y, y], axis=0), [1, 2], - [0, 2]), - [RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - poly_axes=[1, 0]), - _make_harness("einsum", "multiple_contractions", - lambda x, y, z: jnp.einsum("ab,bc,cd->ad", x, y, z), - [RandArg((3, 2), _f32), RandArg((2, 3), _f32), RandArg((3, 4), _f32),], - poly_axes=[0, None, None]), - _make_harness("einsum", "incompatible_contractions_error", - lambda x, y: jnp.einsum("ab,cb->ac", x, y), - [RandArg((2, 3), _f32), RandArg((2, 3), _f32)], - poly_axes=[1, (0, 1)], - expect_error=(core.InconclusiveDimensionOperation, - "Dimension polynomial comparison 'b1' == 'b0' is inconclusive")), - _make_harness("eye", "N=poly_M=None", - lambda x: jnp.eye(x.shape[0]), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("eye", "N=poly_M=poly", - lambda x: jnp.eye(x.shape[0], M=x.shape[0] + 2), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("full", "", - lambda x: lax.full((x.shape[0], 2), 3.), - [RandArg((3, 4), _f32)], - poly_axes=[0]), + PolyHarness("conv_general_dilated", "1d_3", + lambda lhs, rhs: lax.conv_transpose(lhs, rhs, + strides=(2,), + padding="SAME", + rhs_dilation=None, + transpose_kernel=False), + arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)], + poly_axes=[1, None], + tol=1e-5).both_enable_and_disable_xla(), + PolyHarness("conv_general_dilated", "", + lambda lhs, rhs: lax.conv_general_dilated( + lhs, rhs, + window_strides=(2, 3), + padding=((0, 0), (0, 0)), + lhs_dilation=(1, 1), + rhs_dilation=(1, 2), + dimension_numbers=("NCHW", "OIHW", "NCHW"), + feature_group_count=1, + batch_group_count=1, + precision=None), + arg_descriptors=[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)], + poly_axes=[0, None]).both_enable_and_disable_xla(), + PolyHarness("cummax", "", + lambda x: lax_control_flow.cummax(x, axis=1, reverse=False), + arg_descriptors=[RandArg((3, 4, 5), _f32)], + poly_axes=[0]), + PolyHarness("delta", "0", + lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("dot_general", "", + lambda lhs, rhs: lax.dot_general(lhs, rhs, + dimension_numbers=(((2,), (1,)), ((0,), (0,)))), + arg_descriptors=[RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)], + poly_axes=[0, 0]), + PolyHarness("dynamic_slice", "idx=tuple_int", + # x:shape: (b, 4) + lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), + PolyHarness("dynamic_slice", "idx=tuple_arg", + # x:shape: (b, 4) + lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)), + arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)], + poly_axes=[0, None]).both_enable_and_disable_xla(), + PolyHarness("dynamic_slice", "idx=array", + # x:shape: (b, 4) + lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)), + arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)], + poly_axes=[0, None]).both_enable_and_disable_xla(), + PolyHarness("dynamic_slice_in_dim", "idx=0", + # x:shape: (b, 4) + lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), + PolyHarness("dynamic_update_slice", 
"idx=tuple_int", + # x:shape: (b, 4) + lambda x: lax.dynamic_update_slice(x, x, (0, 0)), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), + PolyHarness("dynamic_update_slice", "idx=tuple_arg", + # x:shape: (b, 4) + lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))), + arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)], + poly_axes=[0, None]).both_enable_and_disable_xla(), + PolyHarness("dynamic_update_slice", "idx=array", + # x:shape: (b, 4) + lambda x, idx: lax.dynamic_update_slice(x, x, idx), + arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)], + poly_axes=[0, None]).both_enable_and_disable_xla(), + PolyHarness("einsum", "0", + lambda x: jnp.einsum("...i->...", x), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("einsum", "0_alt", + lambda x: jnp.einsum(x, (..., 1), [...]), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("einsum", "1", + lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y), + arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], + poly_axes=[0, 0]), + PolyHarness("einsum", "1_alt", + lambda x, y: jnp.einsum(x, [..., 0, 1], y, (..., 1, 2), [..., 0, 2]), + arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], + poly_axes=[0, 0]), + PolyHarness("einsum", "2", + lambda x, y: jnp.einsum("...ij,jk->...ik", x, y), + arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], + poly_axes=[0, None]), + PolyHarness("einsum", "2_alt", + lambda x, y: jnp.einsum(x, [..., 0, 1], y, [1, 2], [..., 0, 2]), + arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], + poly_axes=[0, None]), + PolyHarness("einsum", "3", + # Reduced dimension is polymorphic + lambda x, y: jnp.einsum("ij,jk->ik", x, y), + arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], + poly_axes=[1, 0]), + PolyHarness("einsum", "3_alt", + # Reduced dimension is polymorphic + lambda x, y: jnp.einsum(x, [0, 1], y, [1, 2], [0, 2]), + arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], + poly_axes=[1, 0]), + PolyHarness("einsum", "4", + # Reduced dimension is polymorphic, and is 2*b + lambda x, y: jnp.einsum("ij,jk->ik", + jnp.concatenate([x, x], axis=1), + jnp.concatenate([y, y], axis=0)), + arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], + poly_axes=[1, 0]), + PolyHarness("einsum", "4_alt", + # Reduced dimension is polymorphic, and is 2*b + lambda x, y: jnp.einsum(jnp.concatenate([x, x], axis=1), [0, 1], + jnp.concatenate([y, y], axis=0), [1, 2], + [0, 2]), + arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], + poly_axes=[1, 0]), + PolyHarness("einsum", "multiple_contractions", + lambda x, y, z: jnp.einsum("ab,bc,cd->ad", x, y, z), + arg_descriptors=[RandArg((3, 2), _f32), RandArg((2, 3), _f32), RandArg((3, 4), _f32),], + poly_axes=[0, None, None]), + PolyHarness("einsum", "incompatible_contractions_error", + lambda x, y: jnp.einsum("ab,cb->ac", x, y), + arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)], + poly_axes=[1, (0, 1)], + expect_error=(core.InconclusiveDimensionOperation, + "Dimension polynomial comparison 'b1' == 'b0' is inconclusive")), + PolyHarness("eye", "N=poly_M=None", + lambda x: jnp.eye(x.shape[0]), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("eye", "N=poly_M=poly", + lambda x: jnp.eye(x.shape[0], M=x.shape[0] + 2), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("full", "", + lambda 
x: lax.full((x.shape[0], 2), 3.),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]),
     # operand is non-poly, index is poly
-    _make_harness("getitem", "op=static_idx=poly",
-                  lambda a, i: a[i],
-                  [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
-                  poly_axes=[None, 0], enable_and_disable_xla=True),
+    PolyHarness("getitem", "op=static_idx=poly",
+                lambda a, i: a[i],
+                arg_descriptors=[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
+                poly_axes=[None, 0]).both_enable_and_disable_xla(),
     # operand is poly, index is integer
-    _make_harness("getitem", "op=poly_idx=const",
-                  lambda a: a[1],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True),
+    PolyHarness("getitem", "op=poly_idx=const",
+                lambda a: a[1],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
     # operand is poly, index is dim poly
-    _make_harness("getitem", "op=poly_idx=dim",
-                  lambda a: a[jnp.array(a.shape[0] - 2)],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True),
+    PolyHarness("getitem", "op=poly_idx=dim",
+                lambda a: a[jnp.array(a.shape[0] - 2)],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
     # Both the operand and the index are poly
-    _make_harness("getitem", "op=poly_idx=poly",
-                  lambda a, i: a[i],
-                  [RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
-                  poly_axes=[0, 0], enable_and_disable_xla=True),
+    PolyHarness("getitem", "op=poly_idx=poly",
+                lambda a, i: a[i],
+                arg_descriptors=[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
+                poly_axes=[0, 0]).both_enable_and_disable_xla(),
     # op is poly and index is an entire slice
-    _make_harness("getitem", "op=poly_idx=slice-all",
-                  lambda a: a[:],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True),
+    PolyHarness("getitem", "op=poly_idx=slice-all",
+                lambda a: a[:],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
     # op is poly and index is a partial slice
-    _make_harness("getitem", "op=poly_idx=slice-ct-1",
-                  lambda a: a[:2],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True,
-                  expect_error=(IndexError, "Cannot use NumPy slice indexing on an array dimension")),
-    _make_harness("getitem", "op=poly_idx=slice-ct-2",
-                  lambda a: a[:, :2],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True),
-    _make_harness("getitem", "op=poly_idx=slice-None-1",
-                  lambda a: a[:a.shape[0]],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True),
-    _make_harness("getitem", "op=poly_idx=slice-poly",
-                  lambda a: a[:a.shape[0] - 1],
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0], enable_and_disable_xla=True,
-                  expect_error=(IndexError, "Array slice indices must have static")),
-    _make_harness("image_resize", "linear_0",
-                  lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
-                                             method="linear"),
-                  [RandArg((3, 16, 32, 3), _f32)],
-                  poly_axes=[(1, 2)]),
-    _make_harness("image_resize", "linear_to_fixed_dim",
-                  lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]),
-                                             method="linear"),
-                  [RandArg((3, 16, 32, 3), _f32)],
-                  poly_axes=[(1, 2)]),
-    _make_harness("image_resize", "nearest_0",
-                  lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
-                                             method="nearest"),
-                  [RandArg((3, 5, 7, 3), _f32)],
-                  poly_axes=[(1, 2)]),
-    _make_harness("index_in_dim", "0",
-                  lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0]),
-    _make_harness("index_in_dim", "idx=neg",
-                  lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0]),
-    _make_harness("index_in_dim", "idx=last",
-                  lambda x: lax.index_in_dim(x, x.shape[0] - 1, axis=0, keepdims=False),
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0]),
-    _make_harness("iota", "",
-                  lambda x: x + lax.iota(_f32, x.shape[0]),
-                  [RandArg((3,), _f32)],
-                  poly_axes=[0]),
-    _make_harness("matmul", "0",
-                  jnp.matmul,
-                  [RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
-                  poly_axes=[0, 0],
-                  tol=1e-5),
-    _make_harness("matmul", "1",
-                  jnp.matmul,
-                  [RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
-                  poly_axes=[0, None],
-                  tol=1e-5),
+    PolyHarness("getitem", "op=poly_idx=slice-ct-1",
+                lambda a: a[:2],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0],
+                expect_error=(IndexError, "Cannot use NumPy slice indexing on an array dimension")
+                ).both_enable_and_disable_xla(),
+    PolyHarness("getitem", "op=poly_idx=slice-ct-2",
+                lambda a: a[:, :2],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
+    PolyHarness("getitem", "op=poly_idx=slice-None-1",
+                lambda a: a[:a.shape[0]],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]).both_enable_and_disable_xla(),
+    PolyHarness("getitem", "op=poly_idx=slice-poly",
+                lambda a: a[:a.shape[0] - 1],
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0],
+                expect_error=(IndexError, "Array slice indices must have static")).both_enable_and_disable_xla(),
+    PolyHarness("image_resize", "linear_0",
+                lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
+                                           method="linear"),
+                arg_descriptors=[RandArg((3, 16, 32, 3), _f32)],
+                poly_axes=[(1, 2)]),
+    PolyHarness("image_resize", "linear_to_fixed_dim",
+                lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]),
+                                           method="linear"),
+                arg_descriptors=[RandArg((3, 16, 32, 3), _f32)],
+                poly_axes=[(1, 2)]),
+    PolyHarness("image_resize", "nearest_0",
+                lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
+                                           method="nearest"),
+                arg_descriptors=[RandArg((3, 5, 7, 3), _f32)],
+                poly_axes=[(1, 2)]),
+    PolyHarness("index_in_dim", "0",
+                lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]),
+    PolyHarness("index_in_dim", "idx=neg",
+                lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]),
+    PolyHarness("index_in_dim", "idx=last",
+                lambda x: lax.index_in_dim(x, x.shape[0] - 1, axis=0, keepdims=False),
+                arg_descriptors=[RandArg((3, 4), _f32)],
+                poly_axes=[0]),
+    PolyHarness("iota", "",
+                lambda x: x + lax.iota(_f32, x.shape[0]),
+                arg_descriptors=[RandArg((3,), _f32)],
+                poly_axes=[0]),
+    PolyHarness("matmul", "0",
+                jnp.matmul,
+                arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
+                poly_axes=[0, 0],
+                tol=1e-5),
+    PolyHarness("matmul", "1",
+                jnp.matmul,
+                arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
+                poly_axes=[0, None],
+                tol=1e-5),
     [
-      _make_harness("mean",
-                    f"{axis=}_{keepdims=}_where=None",
-                    lambda x, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=None),
-                    [RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
-                    poly_axes=[0])
+      PolyHarness("mean",
+                  f"{axis=}_{keepdims=}_where=None",
+                  lambda x, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=None),
+                  arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis),
+                                   StaticArg(keepdims)],
+                  poly_axes=[0])
      for keepdims in [False, True]
      for axis in [None, (0,), (0, 1), (1,)]
     ],
     [
-      _make_harness("mean",
-                    f"{axis=}_{keepdims=}_where=Some",
-                    lambda x, where, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=where),
-                    [RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)],
-                    poly_axes=[0, 0])
+      PolyHarness("mean",
+                  f"{axis=}_{keepdims=}_where=Some",
+                  lambda x, where, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=where),
+                  arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_),
+                                   StaticArg(axis), StaticArg(keepdims)],
+                  poly_axes=[0, 0])
      for keepdims in [False, True]
      for axis in [None, (0,), (0, 1), (1,)]
     ],
-    _make_harness("ones", "",
-                  lambda x: jnp.ones(x.shape, dtype=_f32),
-                  [RandArg((3, 2, 4), _f32)],
-                  poly_axes=[0]),
-    _make_harness("pad", "",
-                  lax.pad,
-                  [RandArg((3, 2, 5), _f32), np.float32(5.),
-                   StaticArg(((0, 0, 0), (0, 0, 0), (1, 1, 1)))],
-                  poly_axes=[0, None]),
-    _make_harness("random_gamma", "",
-                  lambda key, a: jax.random.gamma(key, a),
-                  [RandArg((3, 2), np.uint32), RandArg((3, 3), _f32)],
-                  poly_axes=[0, 0]),
+    PolyHarness("ones", "",
+                lambda x: jnp.ones(x.shape, dtype=_f32),
+                arg_descriptors=[RandArg((3, 2, 4), _f32)],
+                poly_axes=[0]),
+    PolyHarness("pad", "",
+                lax.pad,
+                arg_descriptors=[RandArg((3, 2, 5), _f32), np.float32(5.),
+                                 StaticArg(((0, 0, 0), (0, 0, 0), (1, 1, 1)))],
+                poly_axes=[0, None]),
+    PolyHarness("random_gamma", "",
+                lambda key, a: jax.random.gamma(key, a),
+                arg_descriptors=[RandArg((3, 2), np.uint32), RandArg((3, 3), _f32)],
+                poly_axes=[0, 0]),
     # The known dimensions product must be even.
-    _make_harness("random_categorical", "axis=0",
-                  lambda key, a: jax.random.categorical(key, a, axis=0),
-                  [RandArg((2,), np.uint32), RandArg((3, 8), _f32)],
-                  poly_axes=[None, 0]),
-    _make_harness("random_categorical", "axis=1",
-                  lambda key, a: jax.random.categorical(key, a, axis=1),
-                  [RandArg((2,), np.uint32), RandArg((3, 8), _f32)],
-                  poly_axes=[None, 0]),
+    PolyHarness("random_categorical", "axis=0",
+                lambda key, a: jax.random.categorical(key, a, axis=0),
+                arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 8), _f32)],
+                poly_axes=[None, 0]),
+    PolyHarness("random_categorical", "axis=1",
+                lambda key, a: jax.random.categorical(key, a, axis=1),
+                arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 8), _f32)],
+                poly_axes=[None, 0]),
     # Works when the known dimensions are known to be even or odd.
- _make_harness("random_uniform", "even_1", - lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), - [RandArg((2,), np.uint32), RandArg((3, 4), _f32)], - poly_axes=[None, 0]), - _make_harness("random_uniform", "even_2", - lambda key, a: jax.random.uniform(key, (2 * a.shape[0], a.shape[1]), - dtype=_f32), - [RandArg((2,), np.uint32), RandArg((3, 5), _f32)], - poly_axes=[None, 0]), - _make_harness("random_uniform", "error_not_even", - lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), - [RandArg((2,), np.uint32), RandArg((3, 5), _f32)], - poly_axes=[None, 0], - expect_error=(core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even")), - _make_harness("reduce_window", "min", - # x.shape = (b, 8) - lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min, - (2, 2), (1, 1), "VALID"), - [RandArg((3, 8), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), - _make_harness("reduce_window", "add", - # x.shape = (b, 8) - lambda x: lax.reduce_window(x, 0, lax.add, (2, 2), (1, 1), - "VALID"), - [RandArg((3, 8), _f32)], - poly_axes=[0], - enable_and_disable_xla=True), + PolyHarness("random_uniform", "even_1", + lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), + arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 4), _f32)], + poly_axes=[None, 0]), + PolyHarness("random_uniform", "even_2", + lambda key, a: jax.random.uniform(key, (2 * a.shape[0], a.shape[1]), + dtype=_f32), + arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5), _f32)], + poly_axes=[None, 0]), + PolyHarness("random_uniform", "error_not_even", + lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), + arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5), _f32)], + poly_axes=[None, 0], + expect_error=(core.InconclusiveDimensionOperation, + "the product of the known dimensions must be even")), + PolyHarness("reduce_window", "min", + # x.shape = (b, 8) + lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min, + (2, 2), (1, 1), "VALID"), + arg_descriptors=[RandArg((3, 8), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), + PolyHarness("reduce_window", "add_0", + # x.shape = (b, 8) + lambda x: lax.reduce_window(x, 0, lax.add, (2, 2), (1, 1), + "VALID"), + arg_descriptors=[RandArg((3, 8), _f32)], + poly_axes=[0]).both_enable_and_disable_xla(), # https://github.com/google/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) - _make_harness("reduce_window", "add", - # x.shape = (1, 16*b, 1) - lambda x: lax.reduce_window( - jnp.reshape(x, (1, -1, 1)), - 0., lax.add, (1, 4, 1), (1, 2, 1), "SAME"), - [RandArg((1, 128, 16), _f32)], - poly_axes=[1], - enable_and_disable_xla=True), + PolyHarness("reduce_window", "add_1", + # x.shape = (1, 16*b, 1) + lambda x: lax.reduce_window( + jnp.reshape(x, (1, -1, 1)), + 0., lax.add, (1, 4, 1), (1, 2, 1), "SAME"), + arg_descriptors=[RandArg((1, 128, 16), _f32)], + poly_axes=[1]).both_enable_and_disable_xla(), # TODO(necula): not yet supported, but also unlikely to come up. 
- # _make_harness("random_uniform", "odd", + # PolyHarness("random_uniform", "odd", # lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]), # dtype=_f32), # [RandArg((2,), np.uint32), RandArg((3, 5), _f32)], # poly_axes=[None, 0]), [ - _make_harness("reduce", reduce_op.__name__, - lambda x: reduce_op(x, axis=-1, keepdims=True), # type: ignore - [RandArg((3, 5), _f32)], - poly_axes=[0]) + PolyHarness("reduce", reduce_op.__name__, + lambda x: reduce_op(x, axis=-1, keepdims=True), # type: ignore + arg_descriptors=[RandArg((3, 5), _f32)], + poly_axes=[0]) for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum] ], # Repeat f32[b, 2] * 3 - _make_harness("repeat", "repeats=int_axis=0", - lambda x: jnp.repeat(x, repeats=3, axis=0), - [RandArg((3, 2), _f32)], - poly_axes=[0]), + PolyHarness("repeat", "repeats=int_axis=0", + lambda x: jnp.repeat(x, repeats=3, axis=0), + arg_descriptors=[RandArg((3, 2), _f32)], + poly_axes=[0]), # Repeat f32[b, 2] * b - _make_harness("repeat", "repeats=poly_axis=0", - lambda x: jnp.repeat(x, repeats=x.shape[0], axis=0), - [RandArg((3, 2), _f32)], - poly_axes=[0]), + PolyHarness("repeat", "repeats=poly_axis=0", + lambda x: jnp.repeat(x, repeats=x.shape[0], axis=0), + arg_descriptors=[RandArg((3, 2), _f32)], + poly_axes=[0]), # Repeat f32[b, 2] * b - _make_harness("repeat", "repeats=poly_axis=None", - lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None), - [RandArg((3, 2), _f32)], - poly_axes=[0]), + PolyHarness("repeat", "repeats=poly_axis=None", + lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None), + arg_descriptors=[RandArg((3, 2), _f32)], + poly_axes=[0]), # Repeat f32 * b - _make_harness("repeat", "repeats=poly_axis=None_scalar", - lambda x, y: jnp.repeat(x, repeats=y.shape[0], axis=None), - [RandArg((), _f32), RandArg((3, 2), _f32)], - poly_axes=[None, 0]), - _make_harness("repeat", "repeats=poly_axis=None_total_repeat_length1", - lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None, total_repeat_length=8), - [RandArg((3, 2), _f32)], - poly_axes=[0], - expect_error=(ValueError, "jnp.repeat with a DimPolynomial `repeats` is supported only .*")), - _make_harness("reshape", "0", - lambda x: x.reshape([x.shape[0], -1]), - [RandArg((3, 2, 3), _f32)], - poly_axes=[0]), - _make_harness("reshape", "1", - lambda x: x.reshape([x.shape[0], -1]), - [RandArg((3, 2, 3), _f32)], - poly_axes=[(0, 1)]), - _make_harness("reshape", "2", - lambda x: x.reshape([x.shape[0], -1, x.shape[3], x.shape[2]]), - [RandArg((3, 4, 5, 6, 7), _f32)], - poly_axes=[(0, 2, 3)]), - _make_harness("reshape", "3", - lambda x: jnp.reshape(x, [2, -1]), - [RandArg((3, 4, 5, 6, 7), _f32)], - poly_axes=[(0, 2)]), - _make_harness("reshape", "_issue_9975", - # The newshape is a scalar - lambda x: jnp.reshape(x, x.shape[0] * x.shape[1]), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("reshape", "error", - lambda x: x.reshape([x.shape[0], -1, 3]), - [RandArg((3, 2, 4), _f32)], - poly_axes=[0], - skip_jax_run=True, - expect_error=(core.InconclusiveDimensionOperation, - re.escape( - "Cannot divide evenly the sizes of shapes (b0, 2, 4) and (b0, -1, 3)"))), - _make_harness("roll", "axis=0", - lambda x: jnp.roll(x, 2, axis=0), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("roll", "axis=None", - lambda x: jnp.roll(x, 2), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("scatter_add", "", - partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True), - [RandArg((7, 4), _f32), - np.array([[1], [2]], np.int32), # indices: [2, 
1] - RandArg((7, 2), _f32), # updates: [7, 2] - StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))], - poly_axes=[0, None, 0]), - _make_harness("scatter_add", "clip0", - partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP), - [RandArg((7, 4), _f32), # [b, 4] - np.array([[1], [2]], np.int32), # indices: [2, 1] - RandArg((7, 2), _f32), # updates: [b, 2] - StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))], - poly_axes=[0, None, 0]), - _make_harness("scatter_add", "clip1", - partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP), - [RandArg((7, 4), _f32), # [b, 4] - np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32), # indices: [b, 2] - RandArg((7, 1), _f32), # updates: [b, 1] - StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))], - poly_axes=[0, 0, 0]), - _make_harness("select", "0", - # x.shape = (b, 3) - lambda x: lax.select(x > 5., x, x), - [RandArg((7, 3), _f32)], - poly_axes=[0]), - _make_harness("select", "1", - # x.shape = (b, 3); y.shape = (3,) - jax.vmap(lambda x, y: lax.select(x > 5., x, y), in_axes=[0, None]), - [RandArg((7, 3), _f32), RandArg((3,), _f32)], - poly_axes=[0, None]), - _make_harness("slice", "entire_axis", - lambda x: lax.slice(x, start_indices=(0, 1), limit_indices=(x.shape[0], 3)), - [RandArg((7, 3), _f32)], - poly_axes=[0]), - _make_harness("slice_in_dim", "entire_axis", - lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1, axis=0), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("slice_in_dim", "start=neg", - lambda x: lax.slice_in_dim(x, -1, x.shape[0], stride=1, axis=0), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("slice_in_dim", "limit=neg", - lambda x: lax.slice_in_dim(x, 0, -1, stride=1, axis=0), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("squeeze", "axis=None", - jnp.squeeze, - [RandArg((5,), _f32), StaticArg(())], - poly_axes=[0]), - _make_harness("squeeze", "axis=1", - jnp.squeeze, - [RandArg((4, 1), _f32), StaticArg((1,))], - poly_axes=[0]), - _make_harness("squeeze", "axis=1_2", - jnp.squeeze, - [RandArg((4, 1, 1), _f32), StaticArg((1, 2))], - poly_axes=[0]), - _make_harness("squeeze", "error", - jnp.squeeze, - [RandArg((3, 33), _f32), StaticArg(-1)], - poly_axes=[(0, 1)], - skip_jax_run=True, - expect_error=(ValueError, - re.escape( - "cannot select an axis to squeeze out which has size not equal to one, got shape=(b0, b1) and dimensions=(1,)")) - ), - _make_harness("take", "", - lambda a, i: jnp.take(a, i, axis=1), - [RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)], - poly_axes=[0, None], enable_and_disable_xla=True), - _make_harness("take_along_axis", "0", - lambda x, y: jnp.take_along_axis(x, y, axis=0), - [RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], - poly_axes=[0, 0]), - _make_harness("take_along_axis", "1", - lambda x, y: jnp.take_along_axis(x, y, axis=1), - [RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], - poly_axes=[0, 0]), - _make_harness("tile", "0", - lambda x: jnp.tile(x, (1, 2)), - [RandArg((4, 3), _f32)], - poly_axes=[0]), - _make_harness("tile", "1", - # The repetitions are polys - lambda x: jnp.tile(x, (1, x.shape[0])), - [RandArg((4, 2), _f32)], - poly_axes=[0]), - _make_harness("tri", "N=poly_M=None", - lambda x: jnp.tri(x.shape[0]), - [RandArg((3, 4), _f32)], - poly_axes=[0]), - _make_harness("tri", "N=poly_M=poly", - lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2), - [RandArg((3, 4), _f32)], - poly_axes=[0]), + 
PolyHarness("repeat", "repeats=poly_axis=None_scalar", + lambda x, y: jnp.repeat(x, repeats=y.shape[0], axis=None), + arg_descriptors=[RandArg((), _f32), RandArg((3, 2), _f32)], + poly_axes=[None, 0]), + PolyHarness("repeat", "repeats=poly_axis=None_total_repeat_length1", + lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None, total_repeat_length=8), + arg_descriptors=[RandArg((3, 2), _f32)], + poly_axes=[0], + expect_error=(ValueError, "jnp.repeat with a DimPolynomial `repeats` is supported only .*")), + PolyHarness("reshape", "0", + lambda x: x.reshape([x.shape[0], -1]), + arg_descriptors=[RandArg((3, 2, 3), _f32)], + poly_axes=[0]), + PolyHarness("reshape", "1", + lambda x: x.reshape([x.shape[0], -1]), + arg_descriptors=[RandArg((3, 2, 3), _f32)], + poly_axes=[(0, 1)]), + PolyHarness("reshape", "2", + lambda x: x.reshape([x.shape[0], -1, x.shape[3], x.shape[2]]), + arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)], + poly_axes=[(0, 2, 3)]), + PolyHarness("reshape", "3", + lambda x: jnp.reshape(x, [2, -1]), + arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)], + poly_axes=[(0, 2)]), + PolyHarness("reshape", "_issue_9975", + # The newshape is a scalar + lambda x: jnp.reshape(x, x.shape[0] * x.shape[1]), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("reshape", "error", + lambda x: x.reshape([x.shape[0], -1, 3]), + arg_descriptors=[RandArg((3, 2, 4), _f32)], + poly_axes=[0], + skip_jax_run=True, + expect_error=(core.InconclusiveDimensionOperation, + re.escape( + "Cannot divide evenly the sizes of shapes (b0, 2, 4) and (b0, -1, 3)"))), + PolyHarness("roll", "axis=0", + lambda x: jnp.roll(x, 2, axis=0), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("roll", "axis=None", + lambda x: jnp.roll(x, 2), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("scatter_add", "", + partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True), + arg_descriptors=[RandArg((7, 4), _f32), + np.array([[1], [2]], np.int32), # indices: [2, 1] + RandArg((7, 2), _f32), # updates: [7, 2] + StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))], + poly_axes=[0, None, 0]), + PolyHarness("scatter_add", "clip0", + partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP), + arg_descriptors=[RandArg((7, 4), _f32), # [b, 4] + np.array([[1], [2]], np.int32), # indices: [2, 1] + RandArg((7, 2), _f32), # updates: [b, 2] + StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))], + poly_axes=[0, None, 0]), + PolyHarness("scatter_add", "clip1", + partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP), + arg_descriptors=[RandArg((7, 4), _f32), # [b, 4] + np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32), # indices: [b, 2] + RandArg((7, 1), _f32), # updates: [b, 1] + StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))], + poly_axes=[0, 0, 0]), + PolyHarness("select", "0", + # x.shape = (b, 3) + lambda x: lax.select(x > 5., x, x), + arg_descriptors=[RandArg((7, 3), _f32)], + poly_axes=[0]), + PolyHarness("select", "1", + # x.shape = (b, 3); y.shape = (3,) + jax.vmap(lambda x, y: lax.select(x > 5., x, y), in_axes=[0, None]), + arg_descriptors=[RandArg((7, 3), _f32), RandArg((3,), _f32)], + poly_axes=[0, None]), + PolyHarness("slice", "entire_axis", + lambda x: lax.slice(x, start_indices=(0, 1), limit_indices=(x.shape[0], 3)), + arg_descriptors=[RandArg((7, 3), _f32)], + poly_axes=[0]), + 
PolyHarness("slice_in_dim", "entire_axis", + lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1, axis=0), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("slice_in_dim", "start=neg", + lambda x: lax.slice_in_dim(x, -1, x.shape[0], stride=1, axis=0), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("slice_in_dim", "limit=neg", + lambda x: lax.slice_in_dim(x, 0, -1, stride=1, axis=0), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("squeeze", "axis=None", + jnp.squeeze, + arg_descriptors=[RandArg((5,), _f32), StaticArg(())], + poly_axes=[0]), + PolyHarness("squeeze", "axis=1", + jnp.squeeze, + arg_descriptors=[RandArg((4, 1), _f32), StaticArg((1,))], + poly_axes=[0]), + PolyHarness("squeeze", "axis=1_2", + jnp.squeeze, + arg_descriptors=[RandArg((4, 1, 1), _f32), StaticArg((1, 2))], + poly_axes=[0]), + PolyHarness("squeeze", "error", + jnp.squeeze, + arg_descriptors=[RandArg((3, 33), _f32), StaticArg(-1)], + poly_axes=[(0, 1)], + skip_jax_run=True, + expect_error=(ValueError, + re.escape( + "cannot select an axis to squeeze out which has size not equal to one, got shape=(b0, b1) and dimensions=(1,)")) + ), + PolyHarness("take", "", + lambda a, i: jnp.take(a, i, axis=1), + arg_descriptors=[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)], + poly_axes=[0, None]).both_enable_and_disable_xla(), + PolyHarness("take_along_axis", "0", + lambda x, y: jnp.take_along_axis(x, y, axis=0), + arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], + poly_axes=[0, 0]), + PolyHarness("take_along_axis", "1", + lambda x, y: jnp.take_along_axis(x, y, axis=1), + arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], + poly_axes=[0, 0]), + PolyHarness("tile", "0", + lambda x: jnp.tile(x, (1, 2)), + arg_descriptors=[RandArg((4, 3), _f32)], + poly_axes=[0]), + PolyHarness("tile", "1", + # The repetitions are polys + lambda x: jnp.tile(x, (1, x.shape[0])), + arg_descriptors=[RandArg((4, 2), _f32)], + poly_axes=[0]), + PolyHarness("tri", "N=poly_M=None", + lambda x: jnp.tri(x.shape[0]), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), + PolyHarness("tri", "N=poly_M=poly", + lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2), + arg_descriptors=[RandArg((3, 4), _f32)], + poly_axes=[0]), [ - _make_harness("var", - f"{axis=}_{keepdims=}_where=None", - lambda x, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=None), - [RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)], - poly_axes=[0]) + PolyHarness("var", + f"{axis=}_{keepdims=}_where=None", + lambda x, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=None), + arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)], + poly_axes=[0]) for keepdims in [False, True] for axis in [None, (0,), (0, 1), (1,)] ], [ - _make_harness("var", - f"{axis=}_{keepdims=}_where=Some", - lambda x, where, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=where), - [RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)], - poly_axes=[0, 0]) + PolyHarness("var", + f"{axis=}_{keepdims=}_where=Some", + lambda x, where, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=where), + arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)], + poly_axes=[0, 0]) for keepdims in [False, True] for axis in [None, (0,), (0, 1), (1,)] ], - _make_harness("where", "", - jnp.where, - [RandArg((2,), np.bool_), 
RandArg((), _f32), RandArg((2,), _f32)], - poly_axes=[0, None, 0]), + PolyHarness("where", "", + jnp.where, + arg_descriptors=[RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)], + poly_axes=[0, None, 0]), ] - -def _test_one_harness(tst: tf_test_util.JaxToTfTestCase, harness: Harness): - args = harness.dyn_args_maker(tst.rng()) - poly_axes = harness.params["poly_axes"] # type: Sequence[Sequence[int]] - assert len(args) == len(poly_axes) - # Make the polymorphic_shapes and input_signature - polymorphic_shapes: List[Optional[str]] = [] - input_signature: List[tf.TensorSpec] = [] - for arg, poly_axis in zip(args, poly_axes): - if poly_axis is None: - polymorphic_shapes.append(None) - input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype)) - else: - def make_arg_polymorphic_shapes(poly_axis: Sequence[int]) -> Tuple[str, tf.TensorSpec]: - idx = -1 - dims = [] - tensorspec_dims: List[Optional[int]] = [] - for i, d in enumerate(arg.shape): - if i in poly_axis: - idx += 1 - dims.append(f"b{idx}") - tensorspec_dims.append(None) - else: - dims.append(str(d)) - tensorspec_dims.append(d) - return ", ".join(dims), tf.TensorSpec(tensorspec_dims, arg.dtype) - - arg_polymorphic_shapes, arg_tensorspec = make_arg_polymorphic_shapes(poly_axis) - polymorphic_shapes.append(arg_polymorphic_shapes) - input_signature.append(arg_tensorspec) - - skip_jax_run = harness.params["skip_jax_run"] - if not skip_jax_run: - res_jax = harness.dyn_fun(*args) - - enable_xla = harness.params.get("enable_xla", True) - expect_error_type, expect_error_regex = harness.params["expect_error"] - if expect_error_type is not None: - with tst.assertRaisesRegex(expect_error_type, expect_error_regex): - f_tf = tst.CheckShapePolymorphism( - harness.dyn_fun, - input_signature=input_signature, - polymorphic_shapes=polymorphic_shapes, - expected_output_signature=None, - enable_xla=enable_xla) - else: - f_tf = tst.CheckShapePolymorphism( - harness.dyn_fun, - input_signature=input_signature, - polymorphic_shapes=polymorphic_shapes, - expected_output_signature=None, - enable_xla=enable_xla) - - if not skip_jax_run and expect_error_type is None and harness.params["check_result"]: - tol = harness.params["tol"] - tst.assertAllClose(res_jax, f_tf(*args), atol=tol, rtol=tol) - - def _get_jax2tf_limitations( device, h: primitive_harness.Harness) -> Sequence[Jax2TfLimitation]: # And the jax2tf limitations @@ -2016,117 +2116,15 @@ def applicable_jax2tf_limitation(l: Jax2TfLimitation) -> bool: limitations = Jax2TfLimitation.limitations_for_harness(h) return tuple(filter(applicable_jax2tf_limitation, limitations)) - -def _flatten_harnesses(harnesses): - res = [] - for h in harnesses: - if isinstance(h, Sequence): - res.extend(h) - else: - res.append(h) - return res - -# Set of harness.group_name:platform that are implemented with custom call -custom_call_harnesses = { - "cholesky:cpu", "cholesky:gpu", "eig:cpu", - "eigh:cpu", "eigh:gpu", "fft:cpu", - "householder_product:cpu", "householder_product:gpu", - "geqrf:cpu", "geqrf:gpu", "lu:cpu", "lu:gpu", "qr:cpu", "qr:gpu", - "random_gamma:gpu", "random_categorical:gpu", - "random_randint:gpu", "random_uniform:gpu", "random_split:gpu", - "svd:cpu", "svd:gpu"} - -# Set of harness.group_name or harness.group_name:platform that are implemented with HLO fallback lowering rules -fallback_lowering_harnesses = { - "approx_top_k:cpu", "bessel_i0e", "eigh:tpu", - "erf_inv", "igamma", "igammac", "lu", - "regularized_incomplete_beta", "qr:tpu", - "random_gamma:cpu", "random_gamma:tpu", "svd:tpu"} - 
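[Editorial note, not part of the patch] The removed `_test_one_harness` above is where `poly_axes` used to be turned into a `polymorphic_shapes` entry and a TF `input_signature`; the new `PolyHarness.run_test` performs the same mapping internally. A condensed, self-contained sketch of the per-argument translation, with the helper name `arg_poly_spec` made up for illustration:

    import tensorflow as tf

    def arg_poly_spec(shape, dtype, poly_axis):
      # poly_axis: None (fully static argument), an int, or a tuple of ints.
      if poly_axis is None:
        return None, tf.TensorSpec(shape, dtype)
      axes = (poly_axis,) if isinstance(poly_axis, int) else tuple(poly_axis)
      dims, spec_dims = [], []
      next_var = 0
      for i, d in enumerate(shape):
        if i in axes:
          dims.append(f"b{next_var}")  # fresh dimension variable b0, b1, ...
          spec_dims.append(None)       # dimension left unknown in the TF signature
          next_var += 1
        else:
          dims.append(str(d))
          spec_dims.append(d)
      return ", ".join(dims), tf.TensorSpec(spec_dims, dtype)

    # e.g. arg_poly_spec((3, 4), tf.float32, 0)
    #   -> ("b0, 4", tf.TensorSpec([None, 4], tf.float32))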
-def _exclude_native_lowering_harnesses(harness: Harness):
-  if config.jax2tf_default_experimental_native_lowering and not harness.params.get("enable_xla", True):
-    raise unittest.SkipTest("disabled for experimental_native_lowering and enable_xla=False")
-  if config.jax2tf_default_experimental_native_lowering:
-    if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
-      raise unittest.SkipTest("native lowering with shape polymorphism not implemented for custom calls; b/261671778")
-  if (config.jax2tf_default_experimental_native_lowering and
-      (harness.group_name in fallback_lowering_harnesses or
-       f"{harness.group_name}:{jtu.device_under_test()}" in fallback_lowering_harnesses)):
-    raise unittest.SkipTest(
-      "native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
-
-class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
-  """Tests for primitives that take shape values as parameters."""
-
-  # This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES.
-
-  # For each primitive "xxx" the test will be called "test_prim_xxx_...".
-  # If you want to run this test for only one harness that includes "foo"
-  # in the name (after test_prim), add parameter `one_containing="foo"`
-  # to parameterized below.
-  @primitive_harness.parameterized(
-      _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
-      #one_containing="",
-  )
-  def test_prim(self, harness: Harness):
-    _exclude_native_lowering_harnesses(harness)
-    _test_one_harness(self, harness)
-
-  def test_vmap_while(self):
-    def cond_func(x):  # x: f32[3]
-      return jnp.sum(x) >= 0.
-    def body_func(x):  # x: f32[3]
-      return x - 1.
-    def f_jax(x):
-      return lax.while_loop(cond_func, body_func, x)
-
-    self.CheckShapePolymorphism(
-        jax.vmap(f_jax),
-        input_signature=[tf.TensorSpec((None, 3), dtype=tf.float32)],
-        polymorphic_shapes=["b, ..."],
-        expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32)
-    )
-
-  def test_reshape_compiled(self):
-    # We compile the result of conversion for two shapes, hence we need to
-    # involve the TF compiler twice, but we trace only once with shape polymorphism
-    traced = False
-
-    def f_jax(x):
-      nonlocal traced
-      traced = True
-      y = jnp.sin(x)
-      return y.reshape([x.shape[0], -1])
-
-    x = self.rng().rand(4, 2, 3)
-    res_jax = f_jax(x)
-
-    traced = False
-    # If we get_concrete_function we trace once
-    f_tf = tf.function(
-        jax2tf.convert(f_jax, polymorphic_shapes=[PS("b", ...)]),
-        autograph=False,
-        jit_compile=True).get_concrete_function(
-            tf.TensorSpec([None, 2, 3], x.dtype))
-    self.assertTrue(traced)
-    traced = False
-    self.assertAllClose(res_jax, f_tf(x))
-    self.assertFalse(traced)  # We are not tracing again
-
-    x = self.rng().rand(6, 2, 3)
-    res_jax = f_jax(x)
-    traced = False
-
-    self.assertAllClose(res_jax, f_tf(x))
-    self.assertFalse(traced)  # We are not tracing again
-
 ### We add to the test harnesses some that are obtained from the
 ### primitive harnesses by applying vmap to the function and then asserting
 ### that we can convert shape polymorphically the result.
-def _make_vmap_primitive_harnesses():
+def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]:
   """For each harness group, pick a single dtype.
 
+  See PolyHarness for documentation.
+
   Ignore harnesses that fail in graph mode in jax2tf.
   """
   all_h = primitive_harness.all_harnesses
@@ -2203,34 +2201,76 @@ def wrap_custom(rng):
   # We do not check the result of harnesses that require custom assertions.
   check_result = all(not l.custom_assert and not l.skip_comparison and l.tol is None
                      for l in _get_jax2tf_limitations(device, h))
-  vmap_harness = _make_harness(h.group_name, h.name,
-                               jax.vmap(h.dyn_fun, in_axes=0, out_axes=0),
-                               new_args,
-                               poly_axes=[0] * len(new_args),
-                               check_result=check_result,
-                               **h.params)
+  if h.group_name == "cumsum":
+    # TODO(necula): why do we need to adjust the cumsum tolerance?
+    tol = 1e-5
+  else:
+    tol = None
+  vmap_harness = PolyHarness("vmap_" + h.group_name, h.name,
+                             jax.vmap(h.dyn_fun, in_axes=0, out_axes=0),
+                             arg_descriptors=new_args,
+                             poly_axes=[0] * len(new_args),
+                             check_result=check_result,
+                             tol=tol)
+  vmap_harness.original_harness = h
   res.append(vmap_harness)
   return res
 
-_POLY_SHAPE_VMAP_TEST_HARNESSES = _make_vmap_primitive_harnesses()
+_POLY_SHAPE_TEST_HARNESSES.append(_make_vmap_primitive_harnesses())
+
+def _flatten_harnesses(harnesses):
+  res = []
+  for h in harnesses:
+    if isinstance(h, Sequence):
+      res.extend(h)
+    else:
+      res.append(h)
+  return res
+
+# Set of harness.group_name:platform that are implemented with custom call
+custom_call_harnesses = {
+    "cholesky:cpu", "cholesky:gpu", "eig:cpu",
+    "eigh:cpu", "eigh:gpu", "fft:cpu",
+    "householder_product:cpu", "householder_product:gpu",
+    "geqrf:cpu", "geqrf:gpu", "lu:cpu", "lu:gpu", "qr:cpu", "qr:gpu",
+    "random_gamma:gpu", "random_categorical:gpu",
+    "random_randint:gpu", "random_uniform:gpu", "random_split:gpu",
+    "svd:cpu", "svd:gpu"}
+# Set of harness.group_name or harness.group_name:platform that are implemented with HLO fallback lowering rules
+fallback_lowering_harnesses = {
+    "approx_top_k:cpu", "bessel_i0e", "eigh:tpu",
+    "erf_inv", "igamma", "igammac", "lu",
+    "regularized_incomplete_beta", "qr:tpu",
+    "random_gamma:cpu", "random_gamma:tpu", "svd:tpu"}
 
-class ShapePolyVmapPrimitivesTest(tf_test_util.JaxToTfTestCase):
-  """Tests that we can handle batch polymorphism for vmapped primitives."""
+class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
+  """Tests for primitives that take shape values as parameters."""
 
-  # This test runs for all _POLY_SHAPE_VMAP_PRIMITIVE_HARNESSES.
+  # This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES.
 
-  # For each primitive "xxx" the test will be called "test_vmap_prim_xxx_...".
+  # For each primitive "xxx" the test will be called "test_harness_xxx_...".
   # If you want to run this test for only one harness that includes "foo"
-  # in the name (after test_vmap_prim), add parameter `one_containing="foo"`
+  # in the name (after test_harness), add parameter `one_containing="foo"`
   # to parameterized below.
   @primitive_harness.parameterized(
-      _flatten_harnesses(_POLY_SHAPE_VMAP_TEST_HARNESSES),
-      one_containing=""
+      _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
+      #one_containing="",
   )
-  def test_vmap_prim(self, harness: Harness):
-    _exclude_native_lowering_harnesses(harness)
-    return _test_one_harness(self, harness)
+  def test_harness(self, harness: PolyHarness):
+    # Exclude some harnesses that are known to fail
+    if config.jax2tf_default_experimental_native_lowering and not harness.params.get("enable_xla", True):
+      raise unittest.SkipTest("disabled for experimental_native_lowering and enable_xla=False")
+    if config.jax2tf_default_experimental_native_lowering:
+      if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
+        raise unittest.SkipTest("native lowering with shape polymorphism not implemented for custom calls; b/261671778")
+    if (config.jax2tf_default_experimental_native_lowering and
+        (harness.group_name in fallback_lowering_harnesses or
+         f"{harness.group_name}:{jtu.device_under_test()}" in fallback_lowering_harnesses)):
+      raise unittest.SkipTest(
+        "native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
+
+    harness.run_test(self)

 if __name__ == "__main__":
diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py
index 861b768d4c31..b49ceb6c5487 100644
--- a/jax/experimental/jax2tf/tests/tf_test_util.py
+++ b/jax/experimental/jax2tf/tests/tf_test_util.py
@@ -30,7 +30,6 @@
 from jax.config import config
 from jax.experimental import jax2tf
-from jax._src import util
 from jax._src.lib import xla_bridge
 import numpy as np
 import tensorflow as tf  # type: ignore[import]
@@ -367,42 +366,6 @@ def TransformConvertAndCompare(self, func: Callable, arg,
       return self.ConvertAndCompare(grad_func, t_arg)
     assert False, transform
-
-  def CheckShapePolymorphism(self, f_jax: Callable, *,
-                             input_signature: Sequence[tf.TensorSpec],
-                             polymorphic_shapes: Optional[Sequence[Any]],
-                             expected_output_signature: Optional[tf.TensorSpec] = None,
-                             enable_xla: bool = True):
-    """Converts a function using polymorphic shapes.
-
-    Args:
-      f_jax: a JAX function of `n` arguments
-      input_signature: used as the input signature for the tf.function.
-      polymorphic_shapes: Specifies input shapes to be treated polymorphically
-        during conversion.
-      expected_output_signature: if given, this function tests whether the
-        actual output signature is equal to this one.
-      enable_xla: Whether to enable XLA conversion for jax2tf.convert.
-    """
-    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
-                          enable_xla=enable_xla)
-    f_tf_func = tf.function(
-        f_tf, autograph=False, input_signature=input_signature)
-    concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
-    if expected_output_signature:
-      # Strangely, output_shapes can be a single shape for a function with a
-      # single result, or a list/tuple of shapes.
-      concrete_output_tf_shape = concrete_f_tf.output_shapes
-      if not isinstance(concrete_output_tf_shape, (tuple, list)):  # Single result
-        assert not isinstance(expected_output_signature, (tuple, list))
-        expected_output_signature = [expected_output_signature]
-        concrete_output_tf_shape = [concrete_output_tf_shape]
-
-      for expected, found in util.safe_zip(expected_output_signature,
-                                           concrete_output_tf_shape):
-        self.assertEqual(tuple(expected.shape), tuple(found))
-    return f_tf
-
   def TfToHlo(self, tf_fun: Callable, *args):
     # Converts a tf.function to HLO text which we can inspect for occurrence of
     # substrings. This works whether we use native lowering or not.
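[Editorial note, not part of the patch] For readers following the removed `CheckShapePolymorphism` helper above: its convert-then-trace flow now lives in `PolyHarness.run_test`. A minimal sketch of the same flow using only the public APIs that appear in this patch; `f_jax` and the concrete shapes below are made-up examples under that assumption, not part of the change:

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    def f_jax(x):  # x: f32[b, 4], with b a polymorphic batch dimension
      return jnp.sin(x).reshape([x.shape[0], -1])

    # Convert with a polymorphic batch dimension "b" and trace once.
    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])
    f_tf_func = tf.function(f_tf, autograph=False,
                            input_signature=[tf.TensorSpec([None, 4], tf.float32)])
    concrete = f_tf_func.get_concrete_function()  # output shape inferred with b unknown

    # The same traced function then accepts any batch size.
    print(f_tf_func(tf.ones((3, 4), tf.float32)).shape)  # (3, 4)
    print(f_tf_func(tf.ones((7, 4), tf.float32)).shape)  # (7, 4)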