Skip to content

Commit

Permalink
[jax2tf] More cleanup for shape polymorphism testing
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula committed Jul 30, 2021
1 parent b6e25fa commit 64d2e5c
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 149 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,8 @@ def need_dim_var_msg():
dim_poly = _parse_dim(dim_spec)
if not is_poly_dim(dim_poly):
if dim_poly != dim_size:
msg = (f"PolyShape {repr(spec)} in axis {i} must contain a constant or '_' "
f"for known dimension in argument shape {arg_shape}")
msg = (f"PolyShape {repr(spec)} in axis {i} must match the "
f"known dimension size {dim_size} for argument shape {arg_shape}")
raise ValueError(msg)
return dim_size
return dim_poly
Expand Down
18 changes: 18 additions & 0 deletions jax/experimental/jax2tf/tests/call_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,24 @@ def fun_tf_outer_2(x):
_ = tf.function(fun_tf_outer_2)(x)
_ = tf.function(fun_tf_outer_2, jit_compile=True)(x)

def test_repro_193754660(self):
# Try to reproduce b/193754660. I can't.
# We have to have tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
# The get_compiler_ir will indeed fail for f_tf. Then we try to use
# shape inference for f_tf.
# I thought to use a f_tf that uses an op without shape inference, e.g.,
# tfxla.gather. If we wash it through a saved_model I expect that shape
# inference would not work on it. Instead, shape inference works!!!
x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)
def f_jax(x):
return x[1]
f_tf = jax2tf.convert(f_jax)
f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
f_jax2 = jax2tf.call_tf(f_tf_rt)
f_tf2 = jax2tf.convert(f_jax2)
res = tf.function(f_tf2, autograph=False)(x)
self.assertAllClose(res.numpy(), f_jax(x))

def test_module_documentation(self):
def cos_tf(x):
return tf.math.cos(x)
Expand Down

0 comments on commit 64d2e5c

Please sign in to comment.