Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jax2tf] Refactoring of shape_poly_test.
This all started because I noticed that the old self.CheckShapePolymorphism was not running the converted function and would only do the conversion in TF graph mode. Then I realized that there were multiple ways of specifying and running the tests: _make_harness, vmap harnesses, self.CheckShapePolymorphism. This PR unifies all test harnesses under a new PolyHarness class, with new documentation. There is a helper function check_shape_poly that simply wraps PolyHarness. Since the new tests exercise the jax2tf more deeply, especially in TF eager model, I have found 3 bugs. One is fixed here, in the jax2tf._assert_matching_abstract_shape. Two others are deferred (and a couple or tests are skipped here).
- Loading branch information