Skip to content

Commit

Permalink
Refactoring of jax_to_tf tests: (google#3262) (google#3308)
Browse files Browse the repository at this point in the history
* Moved control-flow tests into their own file
* Added a helper module tf_test_util, with a helper function ConvertAndCompare
* Used self.assertAllClose instead of numpy.testing.assert_all_close because
  the former iterates over lists and tuples (and is standard in other JAX tests)
* Used @parameterized.named_parameters for parameterized tests, for nicer test
 names.
  • Loading branch information
gnecula committed Jun 4, 2020
1 parent afa9276 commit 71f1c5c
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 231 deletions.
13 changes: 13 additions & 0 deletions jax/experimental/jax_to_tf/tests/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
152 changes: 152 additions & 0 deletions jax/experimental/jax_to_tf/tests/control_flow_ops_test.py
@@ -0,0 +1,152 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax_to_tf conversion for control-flow primitives."""

from absl.testing import absltest
from absl.testing import parameterized
from typing import Any, Callable, Sequence, Tuple

import jax
import jax.lax as lax
import jax.numpy as jnp
from jax import test_util as jtu
import numpy as np

from jax.experimental import jax_to_tf
from jax.experimental.jax_to_tf.tests import tf_test_util

from jax.config import config
config.parse_flags_with_absl()


class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_cond(self, with_function=False):
def f_jax(pred, x):
return lax.cond(pred, lambda t: t + 1., lambda f: f, x)

with jax_to_tf.enable_jit():
self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_cond_multiple_results(self, with_function=False):
def f_jax(pred, x):
return lax.cond(pred, lambda t: (t + 1., 1.), lambda f: (f + 2., 2.), x)

with jax_to_tf.enable_jit():
self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_while_single_carry(self, with_function=False):
"""A while with a single carry"""
def func(x):
# Equivalent to:
# for(i=x; i < 4; i++);
return lax.while_loop(lambda c: c < 4, lambda c: c + 1, x)

with jax_to_tf.enable_jit():
self.ConvertAndCompare(func, 0, with_function=with_function)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_while(self, with_function=False):
# Some constants to capture in the conditional branches
cond_const = np.ones(3, dtype=np.float32)
body_const1 = np.full_like(cond_const, 1.)
body_const2 = np.full_like(cond_const, 2.)

def func(x):
# Equivalent to:
# c = [1, 1, 1]
# for(i=0; i < 3; i++)
# c += [1, 1, 1] + [2, 2, 2]
#
# The function is set-up so that it captures constants in the
# body of the functionals. This covers some cases in the representation
# of the lax.while primitive.
def cond(idx_carry):
i, c = idx_carry
return i < jnp.sum(lax.tie_in(i, cond_const)) # Capture cond_const

def body(idx_carry):
i, c = idx_carry
return (i + 1, c + body_const1 + body_const2)

return lax.while_loop(cond, body, (0, x))

with jax_to_tf.enable_jit():
self.ConvertAndCompare(func, cond_const, with_function=with_function)


@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_while_batched(self, with_function=True):
"""A while with a single carry"""
def product(x, y):
# Equivalent to "x * y" implemented as:
# res = 0.
# for(i=0; i < y; i++)
# res += x
return lax.while_loop(lambda idx_carry: idx_carry[0] < y,
lambda idx_carry: (idx_carry[0] + 1,
idx_carry[1] + x),
(0, 0.))

# We use vmap to compute result[i, j] = i * j
xs = np.arange(4, dtype=np.int32)
ys = np.arange(5, dtype=np.int32)

def product_xs_y(xs, y):
return jax.vmap(product, in_axes=(0, None))(xs, y)
def product_xs_ys(xs, ys):
return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)

with jax_to_tf.enable_jit():
self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_scan(self, with_function=False):
def f_jax(xs, ys):
body_const = np.ones((2, ), dtype=np.float32) # Test constant capture
def body(res0, inputs):
x, y = inputs
return res0 + x * y, body_const
return lax.scan(body, 0., (xs, ys))

arg = np.arange(10, dtype=np.float32)
with jax_to_tf.enable_jit():
self.ConvertAndCompare(f_jax, arg, arg, with_function=with_function)


if __name__ == "__main__":
absltest.main()
17 changes: 6 additions & 11 deletions jax/experimental/jax_to_tf/tests/savedmodel_test.py
Expand Up @@ -13,41 +13,36 @@
# limitations under the License.

import os
import unittest

import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
import jax.numpy as jnp

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow as tf # type: ignore[import]

from jax.experimental import jax_to_tf
from jax.experimental.jax_to_tf.tests import tf_test_util

from jax.config import config
config.parse_flags_with_absl()


class SavedModelTest(jtu.JaxTestCase):
class SavedModelTest(tf_test_util.JaxToTfTestCase):

def testSavedModel(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
model = tf.Module()
model.f = tf.function(jax_to_tf.convert(f_jax),
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7)
np.testing.assert_allclose(model.f(x), f_jax(x))
self.assertAllClose(model.f(x), f_jax(x))
# Roundtrip through saved model on disk.
model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
tf.saved_model.save(model, model_dir)
restored_model = tf.saved_model.load(model_dir)
np.testing.assert_allclose(restored_model.f(x), f_jax(x))
self.assertAllClose(restored_model.f(x), f_jax(x))


if __name__ == "__main__":
absltest.main()
9 changes: 4 additions & 5 deletions jax/experimental/jax_to_tf/tests/stax_test.py
Expand Up @@ -20,7 +20,7 @@
import os
import sys

from jax.experimental import jax_to_tf
from jax.experimental.jax_to_tf.tests import tf_test_util

# Import ../../../../examples/resnet50.py
def from_examples_import_resnet50():
Expand All @@ -30,7 +30,7 @@ def from_examples_import_resnet50():
assert os.path.isfile(os.path.join(examples_dir, "resnet50.py"))
try:
sys.path.append(examples_dir)
import resnet50 # type: ignore[import-error]
import resnet50 # type: ignore
return resnet50
finally:
sys.path.pop()
Expand All @@ -44,7 +44,7 @@ def from_examples_import_resnet50():
config.parse_flags_with_absl()


class StaxTest(jtu.JaxTestCase):
class StaxTest(tf_test_util.JaxToTfTestCase):

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_res_net(self):
Expand All @@ -54,8 +54,7 @@ def test_res_net(self):
_, params = init_fn(key, shape)
infer = functools.partial(apply_fn, params)
images = np.array(jax.random.normal(key, shape))
np.testing.assert_allclose(infer(images), jax_to_tf.convert(infer)(images),
rtol=0.5)
self.ConvertAndCompare(infer, images, rtol=0.5)


if __name__ == "__main__":
Expand Down

0 comments on commit 71f1c5c

Please sign in to comment.