From f9474b221c15da64f4ff5afdf08778bde6fe0444 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Sun, 12 Nov 2023 18:17:07 +0100
Subject: [PATCH] [shape_poly] Refactor shape_poly_test in preparation for moving out of jax2tf.

The shape polymorphism is now independent of jax2tf and the code is
actually out of jax2tf. Here we refactor shape_poly_test to prepare for
moving most of it out of jax2tf. The main change is that we replace
`jax2tf.convert(f_jax)(*args)` with a call to `check_shape_poly`, which
for now still uses `jax2tf` but in the future will use JAX-native
mechanisms.
---
 .../jax2tf/tests/shape_poly_test.py | 185 +++++++++---------
 1 file changed, 94 insertions(+), 91 deletions(-)

diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 316b03b48908..2b367c44999e 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -560,7 +560,7 @@ def both_enable_and_disable_xla(self) -> tuple["PolyHarness", "PolyHarness"]:
     self.name = f"{self.name}_enable_xla_True"
     return (self, other)
 
-  def run_test(self, tst: tf_test_util.JaxToTfTestCase):
+  def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> Optional[jax.Array]:
     def log_message(extra: str):
       return f"[{tst._testMethodName}]: {extra}"
 
@@ -609,7 +609,7 @@ def log_message(extra: str):
       concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
 
     if expect_error_type is not None:
-      return
+      return None
 
     if self.expected_output_signature:
       # Strangely, output_shapes can be a single shape for a function with a
@@ -649,6 +649,11 @@ def log_message(extra: str):
                                           f"to {custom_assert_lims[0]}"))
         custom_assert_lims[0].custom_assert(tst, res_jax, res_tf, args=args,  # type: ignore
                                             tol=tol, err_msg=None)
+        return res_tf
+      else:
+        return None
+    else:
+      return None
 
 
 def check_shape_poly(tst, f_jax: Callable, *,
@@ -657,7 +662,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
                      arg_descriptors: Sequence = (),
                      skip_jax_run: bool = False,
                      polymorphic_shapes: Sequence[Optional[str]] = (),
                      input_signature: Optional[Sequence[tf.TensorSpec]] = None,
                      expected_output_signature: Optional[tf.TensorSpec] = None,
-                     expect_error=(None, None)):
+                     expect_error=(None, None)) -> Optional[jax.Array]:
   # Makes and tests a harness. See PolyHarness documentation.
   h = PolyHarness("", "", f_jax,
                   arg_descriptors=arg_descriptors,
@@ -666,7 +671,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
                   input_signature=input_signature,
                   expected_output_signature=expected_output_signature,
                   expect_error=expect_error)
-  h.run_test(tst)
+  return h.run_test(tst)
 
 
 class ShapePolyTest(tf_test_util.JaxToTfTestCase):
@@ -730,43 +735,6 @@ def f_jax(x, y):
                      polymorphic_shapes=["h, h", "h, h"],
                      expected_output_signature=tf.TensorSpec([None, None]))
 
-  @jtu.parameterized_filterable(
-    # make_args invoked with op.shape[0]: start, stop, step, dtype
-    # b == 6
-    kwargs=[
-      # Positive step
-      dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
-      dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
-      dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
-      dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
-      dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
-      dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
-      dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
-      dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
-      dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
-      dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),  # Cannot tell if size >= 0
-      # Negative step
-      dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
-      dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
-      dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
-      dict(testcase_name="5b+1_0_-2", make_args=lambda b: (5 * b + 1, 0, -2, None)),
-      dict(testcase_name="5b+2_0_-2", make_args=lambda b: (5 * b + 2, 0, -2, None)),
-      dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),  # Cannot tell if size >= 0
-      # Symbolic step
-      dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
-      dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
-      dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
-      dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
-      # Float return type
-      dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
-  ])
-  def test_arange(self, make_args):
-    def f_jax(x):  # x: i32[b]
-      return x[0] + jnp.arange(*(make_args(x.shape[0])))
-    x = np.ones((6,), dtype=np.int32)
-    self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="b")(x),
-                        f_jax(x))
-
   @jtu.parameterized_filterable(
     # make_args invoked with op.shape[0]: start, stop, step, dtype
     kwargs=[
@@ -792,14 +760,9 @@ def f_jax(x):  # x: i32[b]
       return x[0] + jnp.arange(*(make_args(x.shape[0])))
     x = np.ones((3,), dtype=np.int32)
     with self.assertRaisesRegex(expect_error, expect_msg):
-      jax2tf.convert(f_jax, polymorphic_shapes="b")(x)
+      check_shape_poly(self, f_jax, arg_descriptors=[x],
+                       polymorphic_shapes=["b"])
 
-  def test_argmax(self):
-    def f_jax(x):  # x: f32[b, 4, 5]
-      return lax.argmax(x, axis=1, index_dtype=np.int32)
-    x = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
-    self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
-                        f_jax(x))
 
   @jtu.parameterized_filterable(
     kwargs=[
@@ -996,11 +959,13 @@ def shaped_array(shape_spec: str, actual_shape: core.Shape):
                      expected_shapeenv=dict(a=2, b=3, c=4))
 
   def test_arg_avals_errors(self):
-    """Test error reporting for shape polymorpish."""
+    """Test error reporting for shape polymorphism."""
     def conv_and_run(*, arg_shape: core.Shape, polymorphic_shape: str):
       arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
-      jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)
+      check_shape_poly(self, lambda x: x,
+                       arg_descriptors=[arg],
+                       polymorphic_shapes=[polymorphic_shape])
 
     with self.assertRaisesRegex(ValueError,
                                 re.escape("polymorphic shape spec should be")):
@@ -1094,7 +1059,9 @@ def f_jax(x):  # x: f32[a + 2*b, a, a + b + c]
     with contextlib.ExitStack() as stack:
       if expect_error is not None:
         stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
-      _ = jax2tf.convert(f_jax, polymorphic_shapes=[poly_spec])(x)
+      _ = check_shape_poly(self, f_jax,
+                           arg_descriptors=[x],
+                           polymorphic_shapes=[poly_spec])
 
   def test_pytree(self):
     """Arguments and polymorphic_shapes are pytrees."""
@@ -1372,7 +1339,8 @@ def f(x, y):
     res_jax = f(x, y)
     self.assertAllClose(
         res_jax,
-        jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
+        check_shape_poly(self, f, arg_descriptors=[x, y],
+                         polymorphic_shapes=["(b, h)", "h"]))
 
   def test_while(self):
     def f(x):
@@ -1382,7 +1350,8 @@ def f(x):
                            (x, 0))
 
     x = np.ones((3,), dtype=np.float32)
-    res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
+    res_tf = check_shape_poly(self, f, arg_descriptors=[x],
+                              polymorphic_shapes=["(b,)"])
     self.assertAllClose(f(x), res_tf)
 
   @jtu.parameterized_filterable(
@@ -1671,22 +1640,26 @@ def f(x):
       return jnp.sum(x, axis=0) * x.shape[0]
 
     x = np.arange(3.)
-    self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
-    self.assertAllClose(
-        9.,
-        jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
+    self.assertAllClose(9.,
+                        check_shape_poly(self, f,
+                                         arg_descriptors=[x],
+                                         polymorphic_shapes=["(b,)"]))
     self.assertAllClose(
         9.,
-        tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
+        check_shape_poly(self, jax.jit(f),
+                         arg_descriptors=[x], polymorphic_shapes=["(b,)"]))
 
-    res_primal, res_tangent = jax2tf.convert(
+    res_primal, res_tangent = check_shape_poly(self,
         lambda x, xt: jax.jvp(f, (x,), (xt,)),
-        polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
+        arg_descriptors=[x, np.array([0.1, 0.2, 0.3])],
+        polymorphic_shapes=["b", "b"])
     self.assertAllClose((9., 1.8), (res_primal, res_tangent))
 
     self.assertAllClose(
         np.array([3., 3., 3.]),
-        jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
+        check_shape_poly(self, jax.grad(f),
+                         arg_descriptors=[x],
+                         polymorphic_shapes=["b"]))
 
     xv = np.arange(24.).reshape((2, 3, 4))
     res_vmap = jax.vmap(f, in_axes=1)(xv)
@@ -1694,9 +1667,10 @@ def f(x):
     res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
     self.assertAllClose(res_iter, res_vmap)
 
-    res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
-                                 polymorphic_shapes=["b1, b2, ..."])(xv)
-    self.assertAllClose(res_iter, res_vmap_tf.numpy())
+    res_vmap_tf = check_shape_poly(self, jax.vmap(f, in_axes=1),
+                                   arg_descriptors=[xv],
+                                   polymorphic_shapes=["b1, b2, ..."])
+    self.assertAllClose(res_iter, res_vmap_tf)
 
   def test_with_hash_collision_vmap(self):
     # Batching caches based on Jaxpr, and Jaxpr include _DimExpr. If we have
@@ -1948,33 +1922,6 @@ def f2(z, w):  # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
     res = jax2tf.convert(f2, polymorphic_shapes=zw_polymorphic_shapes)(z, w)
     self.assertAllClose(f2(* f1(x, y)), res)
 
-  def test_gather_1d(self):
-    operand = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], np.float32)
-    rand_idxs = np.random.randint(0, high=max(operand.shape), size=(3, 1), dtype=np.int32)
-    slice_x = np.zeros((10,), dtype=jnp.float32)
-    dnums = lax.GatherDimensionNumbers(
-        offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)
-    )
-
-    @jax.jit
-    def f_jax(operand, start_indices, x):
-      return lax.gather(
-          operand,
-          start_indices,
-          dimension_numbers=dnums,
-          slice_sizes=x.shape,
-          mode="promise_in_bounds",
-      )
-
-    res = f_jax(operand, rand_idxs, slice_x)
-    f_tf = jax2tf.convert(
-        f_jax,
-        native_serialization=True,
-        polymorphic_shapes=["(t, )", "(3, 1)", "(t)"],
-    )
-    res_tf = f_tf(operand, rand_idxs, slice_x)
-    self.assertAllClose(res, res_tf)
-
 
 # List containing either harnesses, or lists of harnesses
 _POLY_SHAPE_TEST_HARNESSES = [
@@ -1986,6 +1933,45 @@ def f_jax(operand, start_indices, x):
               jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
               arg_descriptors=[RandArg((3, 4), _f32)],
               polymorphic_shapes=["b, ..."]),
+  [
+      # make_args invoked with op.shape[0] and produces the arange args:
+      # start, stop, step, dtype
+      PolyHarness("arange", kwargs["testcase_name"],  # type: ignore
+                  lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))),  # type: ignore
+                  arg_descriptors=[RandArg((6,), np.float32)],
+                  polymorphic_shapes=["b"])
+      for kwargs in [
+          # Positive step
+          dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
+          dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
+          dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
+          dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
+          dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
+          dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
+          dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
+          dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
+          dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
+          dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),
+          # Cannot tell if size >= 0
+          # Negative step
+          dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
+          dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
+          dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
+          dict(testcase_name="5b+1_0_-2",
+               make_args=lambda b: (5 * b + 1, 0, -2, None)),
+          dict(testcase_name="5b+2_0_-2",
+               make_args=lambda b: (5 * b + 2, 0, -2, None)),
+          dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),
+          # Cannot tell if size >= 0
+          # Symbolic step
+          dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
+          dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
+          dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
+          dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
+          # Float return type
+          dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
+      ]
+  ],
   # Reduce the poly dimension
   PolyHarness("argmax", "0",
               lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
@@ -2328,6 +2314,23 @@ def f_jax(operand, start_indices, x):
               lambda x: lax.full((x.shape[0], 2), 3.) + x,
               arg_descriptors=[RandArg((3, 1), _f32)],
               polymorphic_shapes=["b, ..."]),
+  PolyHarness("gather", "1d",
+              lambda operand, start_indices, x: lax.gather(
+                  operand,
+                  start_indices,
+                  dimension_numbers=lax.GatherDimensionNumbers(
+                      offset_dims=(1,),
+                      collapsed_slice_dims=(),
+                      start_index_map=(0,)),
+                  slice_sizes=x.shape,
+                  mode="promise_in_bounds"),
+              arg_descriptors=[
+                  RandArg((10,), np.float32),
+                  np.random.randint(0, high=10, size=(3, 1),
+                                    dtype=np.int32),
+                  np.zeros((10,), dtype=jnp.int32),
+              ],
+              polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]),
   # operand is non-poly, index is poly
   PolyHarness("getitem", "op=static_idx=poly",
               lambda a, i: a[i],
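
Below is a minimal sketch (not part of the patch) of the migration pattern
the commit message describes, as applied throughout the test file. The test
class `ExampleShapePolyTest` and its test body are illustrative stand-ins;
`check_shape_poly` and `tf_test_util` are assumed to come from the file
context shown above:

    import numpy as np
    import jax.numpy as jnp

    class ExampleShapePolyTest(tf_test_util.JaxToTfTestCase):  # hypothetical

      def test_sum_poly(self):
        def f(x):  # x: f32[b]
          return jnp.sum(x, axis=0) * x.shape[0]

        x = np.arange(3., dtype=np.float32)
        # Old style: each test called jax2tf.convert directly, coupling it to TF:
        #   res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
        # New style: one indirection point, so only check_shape_poly (via
        # PolyHarness.run_test, which now returns the result) has to change
        # when the tests move to JAX-native mechanisms.
        res_tf = check_shape_poly(self, f,
                                  arg_descriptors=[x],
                                  polymorphic_shapes=["(b,)"])
        self.assertAllClose(f(x), res_tf)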
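The refactoring also turns standalone single-op tests (test_arange,
test_argmax, test_gather_1d above) into entries of the
`_POLY_SHAPE_TEST_HARNESSES` table. A minimal sketch of such an entry
follows; the "sin" group name and test function are illustrative, while the
constructor arguments mirror the PolyHarness entries in the hunks above:

    PolyHarness("sin", "basic",
                lambda x: jnp.sin(x),
                arg_descriptors=[RandArg((3, 4), _f32)],
                polymorphic_shapes=["b, ..."])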