Commit
Separate jax.test_util implementations into public and private sources.
Eventually the private functionality will no longer be exported via the
jax.test_util submodule.

PiperOrigin-RevId: 439415485
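
After this change, downstream tests should rely only on the three public
helpers via jax.test_util. A minimal usage sketch (hypothetical example
function, not part of the commit):

  import numpy as np
  import jax.numpy as jnp
  import jax.test_util as jtu

  def f(x):
    return jnp.sin(x) ** 2

  # Checks 1st- and 2nd-order JVPs and VJPs against central differences.
  jtu.check_grads(f, (np.float32(0.5),), order=2, modes=('fwd', 'rev'))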
Jake VanderPlas authored and jax authors committed Apr 4, 2022
1 parent 71a5eb2, commit 1246b6f
Showing 3 changed files with 332 additions and 265 deletions.
@@ -0,0 +1,270 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from jax import config
from jax.tree_util import tree_map, tree_reduce
from jax._src import api
from jax._src import dtypes as _dtypes
from jax._src.config import flags
from jax._src.lib import xla_bridge
import numpy as np

# The only functions intended to be exported are these; they should be used via
# jax.test_util. All other functionality appearing here is for internal use only,
# and may be changed or removed at any time and without any deprecation cycle.
__all__ = ['check_grads', 'check_jvp', 'check_vjp']


FLAGS = flags.FLAGS

EPS = 1e-4


def _dtype(x):
  if hasattr(x, 'dtype'):
    return x.dtype
  elif type(x) in _dtypes.python_scalar_dtypes:
    return np.dtype(_dtypes.python_scalar_dtypes[type(x)])
  else:
    return np.asarray(x).dtype
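
# Illustrative behavior (added commentary, not in the original file): _dtype
# resolves a dtype for arrays, Python scalars, and anything np.asarray accepts.
#   _dtype(np.float32(1))  # -> dtype('float32'), via the .dtype attribute
#   _dtype(True)           # -> dtype('bool'), via python_scalar_dtypes
#   _dtype([1, 2, 3])      # -> dtype('int64') on 64-bit platforms, via asarray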


_default_tolerance = {
  _dtypes.float0: 0,
  np.dtype(np.bool_): 0,
  np.dtype(np.int8): 0,
  np.dtype(np.int16): 0,
  np.dtype(np.int32): 0,
  np.dtype(np.int64): 0,
  np.dtype(np.uint8): 0,
  np.dtype(np.uint16): 0,
  np.dtype(np.uint32): 0,
  np.dtype(np.uint64): 0,
  np.dtype(_dtypes.bfloat16): 1e-2,
  np.dtype(np.float16): 1e-3,
  np.dtype(np.float32): 1e-6,
  np.dtype(np.float64): 1e-15,
  np.dtype(np.complex64): 1e-6,
  np.dtype(np.complex128): 1e-15,
}


def default_tolerance():
  if device_under_test() != "tpu":
    return _default_tolerance
  tol = _default_tolerance.copy()
  tol[np.dtype(np.float32)] = 1e-3
  tol[np.dtype(np.complex64)] = 1e-3
  return tol


default_gradient_tolerance = {
  np.dtype(_dtypes.bfloat16): 1e-1,
  np.dtype(np.float16): 1e-2,
  np.dtype(np.float32): 2e-3,
  np.dtype(np.float64): 1e-5,
  np.dtype(np.complex64): 1e-3,
  np.dtype(np.complex128): 1e-5,
}


def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
  if a.dtype == b.dtype == _dtypes.float0:
    np.testing.assert_array_equal(a, b, err_msg=err_msg)
    return
  a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b
  kw = {}
  if atol: kw["atol"] = atol
  if rtol: kw["rtol"] = rtol
  with np.errstate(invalid='ignore'):
    # TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
    # value errors. It should not do that.
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)


def tolerance(dtype, tol=None):
  tol = {} if tol is None else tol
  if not isinstance(tol, dict):
    return tol
  tol = {np.dtype(key): value for key, value in tol.items()}
  dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
  return tol.get(dtype, default_tolerance()[dtype])
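
# Example lookups (added commentary; values assume a CPU backend):
#   tolerance(np.float32)                      # -> 1e-6, the per-dtype default
#   tolerance(np.float32, 1e-2)                # -> 1e-2, a scalar overrides all
#   tolerance(np.float32, {np.float64: 1e-9})  # -> 1e-6, no float32 key given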


def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
  a, b = np.asarray(a), np.asarray(b)
  assert a.shape == b.shape
  atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
  rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))
  _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size,
                         err_msg=err_msg)
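
# Note (added commentary): atol and rtol are scaled by the element count, so
# the allowed total error grows with array size; e.g. a float32 array with
# 1000 elements is compared with an effective atol of 1e-6 * 1000 = 1e-3.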


def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
  assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
                         err_msg=err_msg)
  tree_map(assert_close, xs, ys)


def _check_dtypes_match(xs, ys):
  def _assert_dtypes_match(x, y):
    if config.x64_enabled:
      assert _dtype(x) == _dtype(y)
    else:
      assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
              _dtypes.canonicalize_dtype(_dtype(y)))
  tree_map(_assert_dtypes_match, xs, ys)


def inner_prod(xs, ys):
  def contract(x, y):
    return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
  return tree_reduce(np.add, tree_map(contract, xs, ys))
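
# Example (added commentary): inner_prod flattens each leaf and sums the real
# part of <conj(x), y> across the whole pytree:
#   xs = {'a': np.array([1.0, 2.0]), 'b': np.array(3.0)}
#   inner_prod(xs, xs)  # -> 1.0 + 4.0 + 9.0 = 14.0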


def _safe_subtract(x, y, *, dtype):
  """Subtraction with `inf - inf == 0` semantics."""
  with np.errstate(invalid='ignore'):
    return np.where(np.equal(x, y), np.array(0, dtype),
                    np.subtract(x, y, dtype=dtype))


add = partial(tree_map, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_map, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_map,
                   lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))
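
# Example (added commentary): safe_sub matches sub except where the operands
# compare equal (including inf == inf), where it yields 0 rather than nan:
#   x = np.array([np.inf, 1.0]); y = np.array([np.inf, 3.0])
#   sub(x, y)       # -> array([nan, -2.0])
#   safe_sub(x, y)  # -> array([0.0, -2.0])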


def scalar_mul(xs, a):
  def mul(x):
    dtype = _dtype(x)
    return np.multiply(x, np.array(a, dtype=dtype), dtype=dtype)
  return tree_map(mul, xs)


def rand_like(rng, x):
  shape = np.shape(x)
  dtype = _dtype(x)
  randn = lambda: np.asarray(rng.randn(*shape), dtype=dtype)
  if _dtypes.issubdtype(dtype, np.complexfloating):
    return randn() + dtype.type(1.0j) * randn()
  else:
    return randn()


def numerical_jvp(f, primals, tangents, eps=EPS):
  delta = scalar_mul(tangents, eps)
  f_pos = f(*add(primals, delta))
  f_neg = f(*sub(primals, delta))
  return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)
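
# Example (added commentary): numerical_jvp approximates the JVP with a
# central difference, accurate to O(eps**2):
#   f = lambda x: x ** 3
#   numerical_jvp(f, (2.0,), (1.0,))
#   # -> ((2 + 1e-4)**3 - (2 - 1e-4)**3) / (2 * 1e-4) ~= 12.0, i.e. f'(2)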


def _merge_tolerance(tol, default):
  if tol is None:
    return default
  if not isinstance(tol, dict):
    return tol
  out = default.copy()
  for k, v in tol.items():
    out[np.dtype(k)] = v
  return out


def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  rng = np.random.RandomState(0)
  tangent = tree_map(partial(rand_like, rng), args)
  v_out, t_out = f_jvp(args, tangent)
  _check_dtypes_match(v_out, t_out)
  v_out_expected = f(*args)
  _check_dtypes_match(v_out, v_out_expected)
  t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
  # In principle we should expect exact equality of v_out and v_out_expected,
  # but due to nondeterminism especially on GPU (e.g., due to convolution
  # autotuning) we only require "close".
  check_close(v_out, v_out_expected, atol=atol, rtol=rtol,
              err_msg=f'{err_msg} primal' if err_msg else 'primal')
  check_close(t_out, t_out_expected, atol=atol, rtol=rtol,
              err_msg=f'{err_msg} tangent' if err_msg else 'tangent')
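
# Usage sketch (added commentary; `jax` import assumed): compare a JVP rule
# against the finite-difference estimate in a random direction:
#   import jax
#   f = lambda x: x ** 3
#   check_jvp(f, partial(jax.jvp, f), (np.float32(0.5),))
#   # A JVP that disagrees with finite differences raises an AssertionError.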


def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
  atol = _merge_tolerance(atol, default_gradient_tolerance)
  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
  _rand_like = partial(rand_like, np.random.RandomState(0))
  v_out, vjpfun = f_vjp(*args)
  v_out_expected = f(*args)
  check_close(v_out, v_out_expected, atol=atol, rtol=rtol,
              err_msg=f'{err_msg} primal' if err_msg else 'primal')
  tangent = tree_map(_rand_like, args)
  tangent_out = numerical_jvp(f, args, tangent, eps=eps)
  cotangent = tree_map(_rand_like, v_out)
  cotangent_out = conj(vjpfun(conj(cotangent)))
  ip = inner_prod(tangent, cotangent_out)
  ip_expected = inner_prod(tangent_out, cotangent)
  check_close(ip, ip_expected, atol=atol, rtol=rtol,
              err_msg=(f'{err_msg} cotangent projection'
                       if err_msg else 'cotangent projection'))
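
# What this verifies (added commentary): for a random tangent t and cotangent
# c, the transpose identity <J t, c> == <t, J^T c> must hold, where J t is
# estimated by numerical_jvp and J^T c comes from the candidate vjpfun; the
# conj() wrapping extends the identity to complex outputs. Usage sketch:
#   import jax
#   f = lambda x: x ** 3
#   check_vjp(f, partial(jax.vjp, f), (np.float32(0.5),))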


def check_grads(f, args, order,
                modes=("fwd", "rev"), atol=None, rtol=None, eps=None):
  """Check gradients from automatic differentiation against finite differences.

  Gradients are only checked in a single randomly chosen direction, which
  ensures that the finite difference calculation does not become prohibitively
  expensive even for large input/output spaces.

  Args:
    f: function to check at ``f(*args)``.
    args: tuple of argument values.
    order: forward and backward gradients up to this order are checked.
    modes: list of gradient modes to check ('fwd' and/or 'rev').
    atol: absolute tolerance for gradient equality.
    rtol: relative tolerance for gradient equality.
    eps: step size used for finite differences.

  Raises:
    AssertionError: if gradients do not match.
  """
  args = tuple(args)
  eps = eps or EPS

  _check_jvp = partial(check_jvp, atol=atol, rtol=rtol, eps=eps)
  _check_vjp = partial(check_vjp, atol=atol, rtol=rtol, eps=eps)

  def _check_grads(f, args, order, err_msg=''):
    if "fwd" in modes:
      fwd_msg = f'JVP of {err_msg}' if err_msg else 'JVP'
      _check_jvp(f, partial(api.jvp, f), args, err_msg=fwd_msg)
      if order > 1:
        _check_grads(partial(api.jvp, f), (args, args), order - 1, fwd_msg)

    if "rev" in modes:
      rev_msg = f'VJP of {err_msg}' if err_msg else 'VJP'
      _check_vjp(f, partial(api.vjp, f), args, err_msg=rev_msg)
      if order > 1:
        def f_vjp(*args):
          out_primal_py, vjp_py = api.vjp(f, *args)
          return vjp_py(out_primal_py)
        _check_grads(f_vjp, args, order - 1, rev_msg)

  _check_grads(f, args, order)
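
# Usage sketch (added commentary): higher orders recurse, so 'fwd' re-checks
# the JVP of the JVP, and 'rev' re-checks gradients of the map
# x -> vjp(f, x)(f(x)):
#   check_grads(lambda x: x ** 3, (np.float32(0.5),), order=2)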


def device_under_test():
  return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform