Skip to content

Commit

Permalink
[shape_poly] Refactor shape_poly_test in preparation for moving out o…
Browse files Browse the repository at this point in the history
…f jax2tf.

The shape polymotphism is now independent of jax2tf and the code is actually
out of jax2tf. Here we refactor shape_poly_test to prepare for moving most of
out of jax2tf.

The main change is that we replace `jax2tf.convert(f_jax)(*args)` with
a call to `check_shape_poly` which now still uses `jax2tf` but in the
future will use JAX native mechanisms.
  • Loading branch information
gnecula committed Nov 13, 2023
1 parent c8f3e23 commit f9474b2
Showing 1 changed file with 94 additions and 91 deletions.
185 changes: 94 additions & 91 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Expand Up @@ -560,7 +560,7 @@ def both_enable_and_disable_xla(self) -> tuple["PolyHarness", "PolyHarness"]:
self.name = f"{self.name}_enable_xla_True"
return (self, other)

def run_test(self, tst: tf_test_util.JaxToTfTestCase):
def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> Optional[jax.Array]:
def log_message(extra: str):
return f"[{tst._testMethodName}]: {extra}"

Expand Down Expand Up @@ -609,7 +609,7 @@ def log_message(extra: str):
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)

if expect_error_type is not None:
return
return None

if self.expected_output_signature:
# Strangely, output_shapes can be a single shape for a function with a
Expand Down Expand Up @@ -649,6 +649,11 @@ def log_message(extra: str):
f"to {custom_assert_lims[0]}"))
custom_assert_lims[0].custom_assert(tst, res_jax, res_tf, args=args, # type: ignore
tol=tol, err_msg=None)
return res_tf
else:
return None
else:
return None


def check_shape_poly(tst, f_jax: Callable, *,
Expand All @@ -657,7 +662,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
polymorphic_shapes: Sequence[Optional[str]] = (),
input_signature: Optional[Sequence[tf.TensorSpec]] = None,
expected_output_signature: Optional[tf.TensorSpec] = None,
expect_error=(None, None)):
expect_error=(None, None)) -> Optional[jax.Array]:
# Makes and tests a harness. See PolyHarness documentation.
h = PolyHarness("", "", f_jax,
arg_descriptors=arg_descriptors,
Expand All @@ -666,7 +671,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
input_signature=input_signature,
expected_output_signature=expected_output_signature,
expect_error=expect_error)
h.run_test(tst)
return h.run_test(tst)


class ShapePolyTest(tf_test_util.JaxToTfTestCase):
Expand Down Expand Up @@ -730,43 +735,6 @@ def f_jax(x, y):
polymorphic_shapes=["h, h", "h, h"],
expected_output_signature=tf.TensorSpec([None, None]))

@jtu.parameterized_filterable(
# make_args invoked with op.shape[0]: start, stop, step, dtype
# b == 6
kwargs=[
# Positive step
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)), # Cannot tell if size >= 0
# Negative step
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
dict(testcase_name="5b+1_0_-2", make_args=lambda b: (5 * b + 1, 0, -2, None)),
dict(testcase_name="5b+2_0_-2", make_args=lambda b: (5 * b + 2, 0, -2, None)),
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)), # Cannot tell if size >= 0
# Symbolic step
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
# Float return type
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
])
def test_arange(self, make_args):
def f_jax(x): # x: i32[b]
return x[0] + jnp.arange(*(make_args(x.shape[0])))
x = np.ones((6,), dtype=np.int32)
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="b")(x),
f_jax(x))

@jtu.parameterized_filterable(
# make_args invoked with op.shape[0]: start, stop, step, dtype
kwargs=[
Expand All @@ -792,14 +760,9 @@ def f_jax(x): # x: i32[b]
return x[0] + jnp.arange(*(make_args(x.shape[0])))
x = np.ones((3,), dtype=np.int32)
with self.assertRaisesRegex(expect_error, expect_msg):
jax2tf.convert(f_jax, polymorphic_shapes="b")(x)
check_shape_poly(self, f_jax, arg_descriptors=[x],
polymorphic_shapes=["b"])

def test_argmax(self):
def f_jax(x): # x: f32[b, 4, 5]
return lax.argmax(x, axis=1, index_dtype=np.int32)
x = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
f_jax(x))

@jtu.parameterized_filterable(
kwargs=[
Expand Down Expand Up @@ -996,11 +959,13 @@ def shaped_array(shape_spec: str, actual_shape: core.Shape):
expected_shapeenv=dict(a=2, b=3, c=4))

def test_arg_avals_errors(self):
"""Test error reporting for shape polymorpish."""
"""Test error reporting for shape polymorphism."""
def conv_and_run(*, arg_shape: core.Shape,
polymorphic_shape: str):
arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)
check_shape_poly(self, lambda x: x,
arg_descriptors=[arg],
polymorphic_shapes=[polymorphic_shape])

with self.assertRaisesRegex(ValueError,
re.escape("polymorphic shape spec should be")):
Expand Down Expand Up @@ -1094,7 +1059,9 @@ def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
_ = jax2tf.convert(f_jax, polymorphic_shapes=[poly_spec])(x)
_ = check_shape_poly(self, f_jax,
arg_descriptors=[x],
polymorphic_shapes=[poly_spec])

def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""
Expand Down Expand Up @@ -1372,7 +1339,8 @@ def f(x, y):
res_jax = f(x, y)
self.assertAllClose(
res_jax,
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
check_shape_poly(self, f, arg_descriptors=[x, y],
polymorphic_shapes=["(b, h)", "h"]))

def test_while(self):
def f(x):
Expand All @@ -1382,7 +1350,8 @@ def f(x):
(x, 0))

x = np.ones((3,), dtype=np.float32)
res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
res_tf = check_shape_poly(self, f, arg_descriptors=[x],
polymorphic_shapes=["(b,)"])
self.assertAllClose(f(x), res_tf)

@jtu.parameterized_filterable(
Expand Down Expand Up @@ -1671,32 +1640,37 @@ def f(x):
return jnp.sum(x, axis=0) * x.shape[0]

x = np.arange(3.)
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(
9.,
jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(9.,
check_shape_poly(self, f,
arg_descriptors=[x],
polymorphic_shapes=["(b,)"]))
self.assertAllClose(
9.,
tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
check_shape_poly(self, jax.jit(f),
arg_descriptors=[x], polymorphic_shapes=["(b,)"]))

res_primal, res_tangent = jax2tf.convert(
res_primal, res_tangent = check_shape_poly(self,
lambda x, xt: jax.jvp(f, (x,), (xt,)),
polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
arg_descriptors=[x, np.array([0.1, 0.2, 0.3])],
polymorphic_shapes=["b", "b"])
self.assertAllClose((9., 1.8), (res_primal, res_tangent))

self.assertAllClose(
np.array([3., 3., 3.]),
jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
check_shape_poly(self, jax.grad(f),
arg_descriptors=[x],
polymorphic_shapes=["b"]))

xv = np.arange(24.).reshape((2, 3, 4))
res_vmap = jax.vmap(f, in_axes=1)(xv)
# Implement by iteration
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)

res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
polymorphic_shapes=["b1, b2, ..."])(xv)
self.assertAllClose(res_iter, res_vmap_tf.numpy())
res_vmap_tf = check_shape_poly(self, jax.vmap(f, in_axes=1),
arg_descriptors=[xv],
polymorphic_shapes=["b1, b2, ..."])
self.assertAllClose(res_iter, res_vmap_tf)

def test_with_hash_collision_vmap(self):
# Batching caches based on Jaxpr, and Jaxpr include _DimExpr. If we have
Expand Down Expand Up @@ -1948,33 +1922,6 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
res = jax2tf.convert(f2, polymorphic_shapes=zw_polymorphic_shapes)(z, w)
self.assertAllClose(f2(* f1(x, y)), res)

def test_gather_1d(self):
operand = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], np.float32)
rand_idxs = np.random.randint(0, high=max(operand.shape), size=(3, 1), dtype=np.int32)
slice_x = np.zeros((10,), dtype=jnp.float32)
dnums = lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)
)

@jax.jit
def f_jax(operand, start_indices, x):
return lax.gather(
operand,
start_indices,
dimension_numbers=dnums,
slice_sizes=x.shape,
mode="promise_in_bounds",
)

res = f_jax(operand, rand_idxs, slice_x)
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"],
)
res_tf = f_tf(operand, rand_idxs, slice_x)
self.assertAllClose(res, res_tf)


# List containing either harnesses, or lists of harnesses
_POLY_SHAPE_TEST_HARNESSES = [
Expand All @@ -1986,6 +1933,45 @@ def f_jax(operand, start_indices, x):
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]),
[
# make_args invoked with op.shape[0] and produces the arange args:
# start, stop, step, dtype
PolyHarness("arange", kwargs["testcase_name"], # type: ignore
lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))), # type: ignore
arg_descriptors=[RandArg((6,), np.float32)],
polymorphic_shapes=["b"])
for kwargs in [
# Positive step
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),
# Cannot tell if size >= 0
# Negative step
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
dict(testcase_name="5b+1_0_-2",
make_args=lambda b: (5 * b + 1, 0, -2, None)),
dict(testcase_name="5b+2_0_-2",
make_args=lambda b: (5 * b + 2, 0, -2, None)),
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),
# Cannot tell if size >= 0
# Symbolic step
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
# Float return type
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
]
],
# Reduce the poly dimension
PolyHarness("argmax", "0",
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
Expand Down Expand Up @@ -2328,6 +2314,23 @@ def f_jax(operand, start_indices, x):
lambda x: lax.full((x.shape[0], 2), 3.) + x,
arg_descriptors=[RandArg((3, 1), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("gather", "1d",
lambda operand, start_indices, x: lax.gather(
operand,
start_indices,
dimension_numbers=lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(),
start_index_map=(0,)),
slice_sizes=x.shape,
mode="promise_in_bounds"),
arg_descriptors=[
RandArg((10,), np.float32),
np.random.randint(0, high=10, size=(3, 1),
dtype=np.int32),
np.zeros((10,), dtype=jnp.int32),
],
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]),
# operand is non-poly, index is poly
PolyHarness("getitem", "op=static_idx=poly",
lambda a, i: a[i],
Expand Down

0 comments on commit f9474b2

Please sign in to comment.