Skip to content

Commit

Permalink
Use jax.* APIs rather than api.* names in tests.
Browse files Browse the repository at this point in the history
Tests should use our own public APIs where they exist.
  • Loading branch information
hawkinsp committed Sep 13, 2021
1 parent 3218b06 commit 9f083d1
Show file tree
Hide file tree
Showing 16 changed files with 516 additions and 524 deletions.
231 changes: 115 additions & 116 deletions tests/host_callback_test.py

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions tests/host_callback_to_tf_test.py
Expand Up @@ -24,7 +24,7 @@
from absl.testing import absltest
from absl.testing import parameterized

from jax._src import api
import jax
from jax.config import config
from jax import numpy as jnp
from jax import test_util as jtu
Expand Down Expand Up @@ -64,7 +64,7 @@ def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape):
functions and must be called outside the JAX computation.
"""

@api.custom_vjp
@jax.custom_vjp
def make_call(arg):
"""We wrap it all in `make_call` so that we can attach custom VJP."""
return call_tf_no_ad(tf_fun, arg, result_shape=result_shape)
Expand Down Expand Up @@ -103,7 +103,7 @@ def call_tf_full_ad(tf_fun: Callable, arg, *, result_shape):
Supports higher-order AD and pytree arguments.
"""

@api.custom_vjp
@jax.custom_vjp
def make_call(arg):
"""We wrap it all in `make_call` so that we can attach custom VJP."""
return call_tf_no_ad(tf_fun, arg, result_shape=result_shape)
Expand Down Expand Up @@ -181,7 +181,7 @@ def f_outside(x):

res = f_outside(3.)
self.assertAllClose(f_jax(3.), res)
self.assertAllClose(f_jax(3.), api.jit(f_outside)(3.))
self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.))

@parameterized.named_parameters(
dict(
Expand All @@ -201,8 +201,8 @@ def f_outside(x):
x = 4.
self.assertAllClose(f_jax(x), f_outside(x))

grad_f = api.grad(f_outside)(x)
self.assertAllClose(api.grad(f_jax)(x), grad_f)
grad_f = jax.grad(f_outside)(x)
self.assertAllClose(jax.grad(f_jax)(x), grad_f)

def test_grad_pytree(self):
call_tf = call_tf_full_ad
Expand All @@ -220,8 +220,8 @@ def f_outside(xy):

xy = (5., 6.)
self.assertAllClose(f_jax(xy), f_outside(xy))
res_jax = api.grad(f_jax)(xy)
self.assertAllClose(res_jax, api.grad(f_outside)(xy))
res_jax = jax.grad(f_jax)(xy)
self.assertAllClose(res_jax, jax.grad(f_outside)(xy))

@parameterized.named_parameters(
dict(
Expand All @@ -241,8 +241,8 @@ def f_outside(x):
grad_jax = f_jax
grad_outside = f_outside
for i in range(degree):
grad_jax = api.grad(grad_jax)
grad_outside = api.grad(grad_outside)
grad_jax = jax.grad(grad_jax)
grad_outside = jax.grad(grad_outside)

res_jax = grad_jax(5.)
self.assertAllClose(res_jax, grad_outside(5.))
Expand Down
23 changes: 11 additions & 12 deletions tests/lax_autodiff_test.py
Expand Up @@ -24,7 +24,6 @@
import numpy as np

import jax
from jax._src import api
from jax import dtypes
from jax import lax
from jax import test_util as jtu
Expand Down Expand Up @@ -405,9 +404,9 @@ def testDotGrad(self, lhs_shape, rhs_shape, dtype):
check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
# check that precision config is preserved
result, pullback = api.vjp(dot, lhs, rhs)
result, pullback = jax.vjp(dot, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(api.make_jaxpr(pullback)(gresult))
s = str(jax.make_jaxpr(pullback)(gresult))
assert "Precision.HIGHEST" in s

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -436,9 +435,9 @@ def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
precision=lax.Precision.HIGHEST)
check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
# check that precision config is preserved
result, pullback = api.vjp(dot_general, lhs, rhs)
result, pullback = jax.vjp(dot_general, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(api.make_jaxpr(pullback)(gresult))
s = str(jax.make_jaxpr(pullback)(gresult))
assert "Precision.HIGHEST" in s

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -1016,15 +1015,15 @@ def f2(x, y):
return lax.sin(x) * lax.cos(y)

x = 3.14
ans = api.grad(f)(x)
expected = api.grad(f2)(x, x)
ans = jax.grad(f)(x)
expected = jax.grad(f2)(x, x)
self.assertAllClose(ans, expected)

ans = api.grad(api.grad(f))(x)
expected = api.grad(api.grad(f2))(x, x)
ans = jax.grad(jax.grad(f))(x)
expected = jax.grad(jax.grad(f2))(x, x)
self.assertAllClose(ans, expected)

ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
ans = jax.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
expected = np.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)

Expand Down Expand Up @@ -1058,14 +1057,14 @@ def inv(x):

def test_linear_transpose_real(self):
f = lambda x: x.real
transpose = api.linear_transpose(f, 1.j)
transpose = jax.linear_transpose(f, 1.j)
actual, = transpose(1.)
expected = 1.
self.assertEqual(actual, expected)

def test_linear_transpose_imag(self):
f = lambda x: x.imag
transpose = api.linear_transpose(f, 1.j)
transpose = jax.linear_transpose(f, 1.j)
actual, = transpose(1.)
expected = -1.j
self.assertEqual(actual, expected)
Expand Down

0 comments on commit 9f083d1

Please sign in to comment.