Skip to content


Separate jax.test_util implementations into public and private sources.
Browse files Browse the repository at this point in the history
Eventually the private functionality will no longer be exported via the jax.test_util submodule.

PiperOrigin-RevId: 439415485
  • Loading branch information
Jake VanderPlas authored and jax authors committed Apr 4, 2022
1 parent 71a5eb2 commit 1246b6f
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 265 deletions.
270 changes: 270 additions & 0 deletions jax/_src/
@@ -0,0 +1,270 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from jax import config
from jax.tree_util import tree_map, tree_reduce
from jax._src import api
from jax._src import dtypes as _dtypes
from jax._src.config import flags
from jax._src.lib import xla_bridge
import numpy as np

# The only functions intended to be exported are these; they should be used via
# jax.test_util. All other functionality appearing here is for internal use only,
# and may be changed or removed at any time and without any deprecation cycle.
__all__ = ['check_grads', 'check_jvp', 'check_vjp']


# Default finite-difference step size: the default `eps` of numerical_jvp,
# check_jvp and check_vjp, and the fallback used by check_grads when no `eps`
# is supplied.
EPS = 1e-4

def _dtype(x):
if hasattr(x, 'dtype'):
return x.dtype
elif type(x) in _dtypes.python_scalar_dtypes:
return np.dtype(_dtypes.python_scalar_dtypes[type(x)])
return np.asarray(x).dtype

_default_tolerance = {
_dtypes.float0: 0,
np.dtype(np.bool_): 0,
np.dtype(np.int8): 0,
np.dtype(np.int16): 0,
np.dtype(np.int32): 0,
np.dtype(np.int64): 0,
np.dtype(np.uint8): 0,
np.dtype(np.uint16): 0,
np.dtype(np.uint32): 0,
np.dtype(np.uint64): 0,
np.dtype(_dtypes.bfloat16): 1e-2,
np.dtype(np.float16): 1e-3,
np.dtype(np.float32): 1e-6,
np.dtype(np.float64): 1e-15,
np.dtype(np.complex64): 1e-6,
np.dtype(np.complex128): 1e-15,

def default_tolerance():
  """Return the per-dtype tolerance table for the device under test.

  On every platform except TPU this is the shared ``_default_tolerance`` dict
  itself (not a copy). On TPU a copy is returned in which the float32 and
  complex64 entries are loosened to 1e-3.
  """
  if device_under_test() == "tpu":
    tpu_tolerance = dict(_default_tolerance)
    tpu_tolerance[np.dtype(np.float32)] = 1e-3
    tpu_tolerance[np.dtype(np.complex64)] = 1e-3
    return tpu_tolerance
  return _default_tolerance

default_gradient_tolerance = {
np.dtype(_dtypes.bfloat16): 1e-1,
np.dtype(np.float16): 1e-2,
np.dtype(np.float32): 2e-3,
np.dtype(np.float64): 1e-5,
np.dtype(np.complex64): 1e-3,
np.dtype(np.complex128): 1e-5,

def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
if a.dtype == b.dtype == _dtypes.float0:
np.testing.assert_array_equal(a, b, err_msg=err_msg)
a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
kw = {}
if atol: kw["atol"] = atol
if rtol: kw["rtol"] = rtol
with np.errstate(invalid='ignore'):
# TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
# value errors. It should not do that.
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)

def tolerance(dtype, tol=None):
  """Resolve the tolerance to use for ``dtype``.

  Args:
    dtype: dtype (or anything ``np.dtype`` accepts) being compared.
    tol: ``None``, a single tolerance value, or a dict mapping dtype-like keys
      to tolerance values.

  Returns:
    ``tol`` itself when it is neither ``None`` nor a dict. Otherwise the entry
    of ``tol`` for the canonicalized ``dtype``, falling back to the matching
    entry of ``default_tolerance()``.
  """
  if tol is None:
    tol = {}
  elif not isinstance(tol, dict):
    return tol
  overrides = {np.dtype(key): value for key, value in tol.items()}
  canonical = _dtypes.canonicalize_dtype(np.dtype(dtype))
  # The default is looked up unconditionally, so a dtype missing from the
  # default table raises KeyError even when ``tol`` has an entry for it.
  fallback = default_tolerance()[canonical]
  return overrides.get(canonical, fallback)

def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
  """Assert that ``a`` and ``b`` have equal shapes and close values.

  Args:
    a: first value; converted with ``np.asarray``.
    b: second value; converted with ``np.asarray``.
    atol: absolute tolerance specification accepted by ``tolerance``.
    rtol: relative tolerance specification accepted by ``tolerance``.
    err_msg: message forwarded to the underlying NumPy assertion.

  Raises:
    AssertionError: if the shapes differ or the values are not close.
  """
  a, b = np.asarray(a), np.asarray(b)
  assert a.shape == b.shape
  # When the operands have different dtypes, use the looser tolerance.
  atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
  rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))
  # The resolved tolerances are scaled by the number of elements.
  _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size,
                         err_msg=err_msg)

def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
  """Assert that corresponding leaves of pytrees ``xs`` and ``ys`` are close.

  Args:
    xs: pytree of values.
    ys: pytree of values, mapped leaf by leaf against ``xs`` with ``tree_map``.
    atol: absolute tolerance specification applied to every leaf pair.
    rtol: relative tolerance specification applied to every leaf pair.
    err_msg: message attached to any assertion failure.

  Raises:
    AssertionError: if any pair of leaves is not close.
  """
  assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
                         err_msg=err_msg)
  tree_map(assert_close, xs, ys)

def _check_dtypes_match(xs, ys):
  """Assert that corresponding leaves of ``xs`` and ``ys`` have matching dtypes.

  With ``config.x64_enabled`` the dtypes must be equal. Otherwise they are
  compared after being passed through ``_dtypes.canonicalize_dtype``.

  Raises:
    AssertionError: if any pair of leaves has mismatched dtypes.
  """
  def _assert_dtypes_match(x, y):
    if config.x64_enabled:
      assert _dtype(x) == _dtype(y)
    else:
      assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
              _dtypes.canonicalize_dtype(_dtype(y)))
  tree_map(_assert_dtypes_match, xs, ys)

def inner_prod(xs, ys):
  """Return the real part of the inner product of two pytrees.

  Each pair of leaves is flattened and contracted as ``dot(conj(x), y)``. The
  real parts of the per-leaf results are then summed over the whole tree.

  Args:
    xs: pytree of arrays.
    ys: pytree of arrays, mapped leaf by leaf against ``xs``.

  Returns:
    The summed real inner product.
  """
  def contract(x, y):
    return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
  return tree_reduce(np.add, tree_map(contract, xs, ys))

def _safe_subtract(x, y, *, dtype):
"""Subtraction that with `inf - inf == 0` semantics."""
with np.errstate(invalid='ignore'):
return np.where(np.equal(x, y), np.array(0, dtype),
np.subtract(x, y, dtype=dtype))

def _add_leaves(x, y):
  """Add two leaves, computing in the dtype of the first."""
  return np.add(x, y, dtype=_dtype(x))


def _subtract_leaves(x, y):
  """Subtract two leaves, computing in the dtype of the first."""
  return np.subtract(x, y, dtype=_dtype(x))


def _safe_subtract_leaves(x, y):
  """Subtract two leaves via _safe_subtract, in the dtype of the first."""
  return _safe_subtract(x, y, dtype=_dtype(x))


def _conj_leaf(x):
  """Complex-conjugate a leaf, computing in its own dtype."""
  return np.conj(x, dtype=_dtype(x))


# Pytree-wise arithmetic: each helper is applied leaf by leaf with tree_map.
add = partial(tree_map, _add_leaves)
sub = partial(tree_map, _subtract_leaves)
safe_sub = partial(tree_map, _safe_subtract_leaves)
conj = partial(tree_map, _conj_leaf)

def scalar_mul(xs, a):
  """Multiply every leaf of pytree ``xs`` by the scalar ``a``.

  The scalar is cast to each leaf's dtype and the product is computed in that
  dtype.
  """
  def _scale(leaf):
    leaf_dtype = _dtype(leaf)
    factor = np.array(a, dtype=leaf_dtype)
    return np.multiply(leaf, factor, dtype=leaf_dtype)

  return tree_map(_scale, xs)

def rand_like(rng, x):
  """Draw a random array with the shape and dtype of ``x``.

  Args:
    rng: object providing ``randn(*shape)``, e.g. ``np.random.RandomState``.
    x: value whose shape and dtype the sample should match.

  Returns:
    A sample from ``rng.randn`` cast to the dtype of ``x``. For complex dtypes
    the real and imaginary parts are two separate draws, real part first.
  """
  shape = np.shape(x)
  dtype = _dtype(x)

  def _draw():
    return np.asarray(rng.randn(*shape), dtype=dtype)

  if not _dtypes.issubdtype(dtype, np.complexfloating):
    return _draw()
  return _draw() + dtype.type(1.0j) * _draw()

def numerical_jvp(f, primals, tangents, eps=EPS):
  """Estimate the JVP of ``f`` with a central finite difference.

  Args:
    f: function called as ``f(*primals)``.
    primals: pytree (sequence) of input values.
    tangents: pytree of directions, matching ``primals``.
    eps: finite-difference step size.

  Returns:
    ``(f(primals + eps * tangents) - f(primals - eps * tangents)) / (2 * eps)``,
    evaluated leaf by leaf with the pytree helpers of this module.
  """
  step = scalar_mul(tangents, eps)
  forward_value = f(*add(primals, step))
  backward_value = f(*sub(primals, step))
  difference = safe_sub(forward_value, backward_value)
  return scalar_mul(difference, 0.5 / eps)

def _merge_tolerance(tol, default):
if tol is None:
return default
if not isinstance(tol, dict):
return tol
out = default.copy()
for k, v in tol.items():
out[np.dtype(k)] = v
return out

def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
  """Check a JVP function against central finite differences.

  A tangent is drawn for every argument from a ``RandomState`` seeded with 0.
  ``f_jvp`` is evaluated at ``(args, tangent)``. Its primal output is compared
  with ``f(*args)`` and its tangent output with ``numerical_jvp``.

  Args:
    f: function evaluated as ``f(*args)``.
    f_jvp: callable taking ``(primals, tangents)`` and returning
      ``(primal_out, tangent_out)``.
    args: tuple of argument values.
    atol: absolute tolerance; merged with ``default_gradient_tolerance``.
    rtol: relative tolerance; merged with ``default_gradient_tolerance``.
    eps: step size used for finite differences.
    err_msg: prefix for the messages attached to assertion failures.

  Raises:
    AssertionError: if dtypes or values do not match.
  """
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  sample_like = partial(rand_like, np.random.RandomState(0))
  tangent = tree_map(sample_like, args)

  primal_out, tangent_out = f_jvp(args, tangent)
  _check_dtypes_match(primal_out, tangent_out)
  primal_expected = f(*args)
  _check_dtypes_match(primal_out, primal_expected)
  tangent_expected = numerical_jvp(f, args, tangent, eps=eps)

  if err_msg:
    primal_msg, tangent_msg = f'{err_msg} primal', f'{err_msg} tangent'
  else:
    primal_msg, tangent_msg = 'primal', 'tangent'

  # In principle we should expect exact equality of the two primal outputs,
  # but due to nondeterminism especially on GPU (e.g., due to convolution
  # autotuning) we only require "close".
  check_close(primal_out, primal_expected, atol=atol, rtol=rtol,
              err_msg=primal_msg)
  check_close(tangent_out, tangent_expected, atol=atol, rtol=rtol,
              err_msg=tangent_msg)

def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
  """Check a VJP function against central finite differences.

  The primal output of ``f_vjp`` is compared with ``f(*args)``. Then a random
  tangent and a random cotangent are drawn from one ``RandomState`` seeded with
  0. The inner product of the tangent with the pulled-back cotangent must be
  close to the inner product of the numerical JVP with the cotangent.

  Args:
    f: function evaluated as ``f(*args)``.
    f_vjp: callable taking ``*args`` and returning ``(primal_out, vjp_fun)``.
    args: tuple of argument values.
    atol: absolute tolerance; merged with ``default_gradient_tolerance``.
    rtol: relative tolerance; merged with ``default_gradient_tolerance``.
    eps: step size used for finite differences.
    err_msg: prefix for the messages attached to assertion failures.

  Raises:
    AssertionError: if values do not match.
  """
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  sample_like = partial(rand_like, np.random.RandomState(0))

  primal_out, pullback = f_vjp(*args)
  primal_expected = f(*args)
  primal_msg = f'{err_msg} primal' if err_msg else 'primal'
  check_close(primal_out, primal_expected, atol=atol, rtol=rtol,
              err_msg=primal_msg)

  # Draw order matters: the tangent is sampled before the cotangent because
  # both come from the same random stream.
  tangent = tree_map(sample_like, args)
  tangent_out = numerical_jvp(f, args, tangent, eps=eps)
  cotangent = tree_map(sample_like, primal_out)
  cotangent_out = conj(pullback(conj(cotangent)))

  projection = inner_prod(tangent, cotangent_out)
  projection_expected = inner_prod(tangent_out, cotangent)
  if err_msg:
    projection_msg = f'{err_msg} cotangent projection'
  else:
    projection_msg = 'cotangent projection'
  check_close(projection, projection_expected, atol=atol, rtol=rtol,
              err_msg=projection_msg)

def check_grads(f, args, order,
                modes=("fwd", "rev"), atol=None, rtol=None, eps=None):
  """Check gradients from automatic differentiation against finite differences.

  Gradients are only checked in a single randomly chosen direction, which
  ensures that the finite difference calculation does not become prohibitively
  expensive even for large input/output spaces.

  Args:
    f: function to check at ``f(*args)``.
    args: tuple of argument values.
    order: forward and backwards gradients up to this order are checked.
    modes: lists of gradient modes to check ('fwd' and/or 'rev').
    atol: absolute tolerance for gradient equality.
    rtol: relative tolerance for gradient equality.
    eps: step size used for finite differences.

  Raises:
    AssertionError: if gradients do not match.
  """
  args = tuple(args)
  # A falsy ``eps`` (None or 0) falls back to the module default step size.
  eps = eps or EPS

  _check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
  _check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)

  def _check_grads(f, args, order, err_msg=''):
    if "fwd" in modes:
      fwd_msg = f'JVP of {err_msg}' if err_msg else 'JVP'
      _check_jvp(f, partial(api.jvp, f), args, err_msg=fwd_msg)
      if order > 1:
        # Higher orders: check the JVP function itself, evaluated with
        # (primals, tangents) = (args, args).
        _check_grads(partial(api.jvp, f), (args, args), order - 1, fwd_msg)

    if "rev" in modes:
      rev_msg = f'VJP of {err_msg}' if err_msg else 'VJP'
      _check_vjp(f, partial(api.vjp, f), args, err_msg=rev_msg)
      if order > 1:
        # Higher orders: check the function that pulls the primal output back
        # through its own VJP.
        def f_vjp(*args):
          out_primal_py, vjp_py = api.vjp(f, *args)
          return vjp_py(out_primal_py)
        _check_grads(f_vjp, args, order - 1, rev_msg)

  _check_grads(f, args, order)

def device_under_test():
  """Return the name of the platform the tests are running against.

  A non-empty ``jax_test_dut`` flag takes precedence. Otherwise the
  ``platform`` of ``xla_bridge.get_backend()`` is returned.
  """
  # Read the flag holder from the imported ``flags`` object; this file defines
  # no module-level ``FLAGS`` name.
  return (getattr(flags.FLAGS, 'jax_test_dut', None)
          or xla_bridge.get_backend().platform)

0 comments on commit 1246b6f

Please sign in to comment.