Skip to content

Commit

Permalink
[jax2tf] Refactoring of shape_poly_test.
Browse files Browse the repository at this point in the history
This all started because I noticed that the old
self.CheckShapePolymorphism was not running the converted
function and would only do the conversion in TF graph mode.
Then I realized that there were multiple ways of specifying
and running the tests: _make_harness, vmap harnesses,
self.CheckShapePolymorphism.

This PR unifies all test harnesses under a new PolyHarness class,
with new documentation. There is a helper function check_shape_poly
that simply wraps PolyHarness.

Since the new tests exercise the jax2tf more deeply, especially
in TF eager model, I have found 3 bugs. One is fixed here, in the
jax2tf._assert_matching_abstract_shape.
Two others are deferred (and a couple or tests are skipped here).
  • Loading branch information
gnecula committed Dec 19, 2022
1 parent b1415bb commit 3c17027
Show file tree
Hide file tree
Showing 3 changed files with 1,052 additions and 1,044 deletions.
9 changes: 7 additions & 2 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -974,9 +974,14 @@ def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize
"""Asserts that shape matches x.shape in the known dimensions and has
dimension polynomials elsewhere."""
# Ensures that the shape does not contain None; it should contain polynomials
def check_one(xd: Optional[int], sd: Any):
if core.is_constant_dim(sd):
return xd == sd
else:
assert isinstance(sd, shape_poly._DimPolynomial)
return True
assert (len(x.shape) == len(shape) and
all((xd is None and isinstance(sd, shape_poly._DimPolynomial) or
core.is_constant_dim(sd) and xd == sd)
all(check_one(xd, sd)
for xd, sd in zip(x.shape, shape))), \
f"Shape {shape} does not match x.shape {x.shape}"

Expand Down

0 comments on commit 3c17027

Please sign in to comment.