forked from HIPS/autograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_util.py
104 lines (89 loc) · 3.58 KB
/
test_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import itertools as it
from .vspace import vspace
from .core import make_vjp, make_jvp
from .util import subvals
TOL = 1e-6   # absolute tolerance used by scalar_close
RTOL = 1e-6  # relative tolerance used by scalar_close


def scalar_close(a, b):
    """Return True if scalars ``a`` and ``b`` agree to within TOL or RTOL.

    ``a`` and ``b`` are considered close when ``|a - b| < TOL`` (absolute
    test) or ``|a - b| < RTOL * |a + b|`` (relative test).

    The relative test is written as a multiplication rather than a division
    so that ``a + b == 0`` (e.g. ``a == -b``, both nonzero) simply yields
    False instead of raising ``ZeroDivisionError``.
    """
    diff = abs(a - b)
    return diff < TOL or diff < RTOL * abs(a + b)
EPS = 1e-6  # finite-difference step size used by make_numerical_jvp


def make_numerical_jvp(f, x):
    """Return a central finite-difference approximation of the JVP of ``f`` at ``x``.

    The returned function maps a tangent ``v`` (an element of ``vspace(x)``)
    to ``(f(x + v*EPS/2) - f(x - v*EPS/2)) / EPS``. All arithmetic goes
    through the vspace ``add`` / ``scalar_mul`` operations. ``f(x)`` is
    evaluated eagerly so the output vspace is known.
    """
    y = f(x)
    x_vs = vspace(x)
    y_vs = vspace(y)

    def _f_along(v, step):
        # Evaluate f at x displaced by `step` along direction v.
        return f(x_vs.add(x, x_vs.scalar_mul(v, step)))

    def jvp(v):
        upper = _f_along(v, EPS / 2)
        lower = _f_along(v, -EPS / 2)
        difference = y_vs.add(upper, y_vs.scalar_mul(lower, -1.0))
        return y_vs.scalar_mul(difference, 1.0 / EPS)

    return jvp
def check_vjp_unary(f, x):
    """Check the reverse-mode derivative of unary ``f`` at ``x`` numerically.

    Draws a random input-space vector ``x_v`` and a random output-space
    vector ``y_v``. It then compares two scalars, which should both equal
    ``y_v . (J x_v)``:

    * analytic: ``<x_v, vjp(y_v)>``, using the VJP from ``make_vjp``;
    * numeric:  ``<y_v, jvp(x_v)>``, using a finite-difference JVP.

    Raises:
        AssertionError: if the two scalars are not ``scalar_close``.
    """
    vjp, y = make_vjp(f, x)
    numerical_jvp = make_numerical_jvp(f, x)
    x_vs, y_vs = vspace(x), vspace(y)
    x_v, y_v = x_vs.randn(), y_vs.randn()
    vjp_y = x_vs.covector(vjp(y_vs.covector(y_v)))
    # Names now match their sources. The analytic value comes from the VJP
    # and the numeric value from finite differences.
    vjv_analytic = x_vs.inner_prod(x_v, vjp_y)
    vjv_numeric = y_vs.inner_prod(y_v, numerical_jvp(x_v))
    assert scalar_close(vjv_analytic, vjv_numeric), \
        "Derivative check failed with arg {}:\nanalytic: {}\nnumeric: {}".format(
            x, vjv_analytic, vjv_numeric)
def check_vjp(f, argnums=None, order=2):
    """Return a checker for the reverse-mode derivatives of ``f`` up to ``order``.

    The returned callable accepts the same ``*args, **kwargs`` as ``f``. The
    positional arguments at ``argnums`` (all of them when ``argnums`` is
    falsy) are bundled into one tuple and differentiated jointly. Everything
    else is held fixed.

    After the first-order check, the checker recurses on the map
    ``(x, v) -> vjp(v)`` with ``order - 1`` to check higher derivatives.
    """
    def _check_vjp(*args, **kwargs):
        if not order:
            return
        positions = argnums or range(len(args))
        x = tuple(args[i] for i in positions)

        def f_unary(xs):
            # f as a function of the selected arguments only.
            return f(*subvals(args, zip(positions, xs)), **kwargs)

        def f_unary_vjp(xs, v):
            return make_vjp(f_unary, xs)[0](v)

        check_vjp_unary(f_unary, x)
        out_v = vspace(f_unary(x)).randn()
        check_vjp(f_unary_vjp, order=order - 1)(x, out_v)

    return _check_vjp
def check_jvp_unary(f, x):
    """Check the forward-mode derivative of unary ``f`` at ``x`` numerically.

    Pushes one random tangent through the ``make_jvp`` JVP and through a
    finite-difference JVP, then asserts agreement via ``check_equivalent``.
    """
    # The former `y = f(x)` was unused; it only re-evaluated f.
    jvp = make_jvp(f, x)
    jvp_numeric = make_numerical_jvp(f, x)
    x_v = vspace(x).randn()
    check_equivalent(jvp(x_v), jvp_numeric(x_v))
def check_jvp(f, argnums=None, order=2):
    """Return a checker for the forward-mode derivatives of ``f`` up to ``order``.

    The returned callable accepts the same ``*args, **kwargs`` as ``f``. The
    positional arguments at ``argnums`` (all of them when ``argnums`` is
    falsy) are differentiated jointly as a tuple.

    After the first-order JVP check, it recurses with ``order - 1`` in two
    ways. The VJP map ``(x, v) -> vjp(v)`` is checked in forward mode, and
    the JVP map ``(x, v) -> jvp(v)`` is checked in reverse mode.
    """
    def _check_jvp(*args, **kwargs):
        if not order:
            return
        positions = argnums or range(len(args))
        x = tuple(args[i] for i in positions)

        def f_unary(xs):
            # f as a function of the selected arguments only.
            return f(*subvals(args, zip(positions, xs)), **kwargs)

        def f_unary_vjp(xs, v):
            return make_vjp(f_unary, xs)[0](v)

        def f_unary_jvp(xs, v):
            return make_jvp(f_unary, xs)(v)

        check_jvp_unary(f_unary, x)
        # The cotangent lives in the output space.
        out_v = vspace(f_unary(x)).randn()
        check_jvp(f_unary_vjp, order=order - 1)(x, out_v)
        # The tangent lives in the input space.
        in_v = vspace(x).randn()
        check_vjp(f_unary_jvp, order=order - 1)(x, in_v)

    return _check_jvp
# backwards compatibility
def check_grads(f, *args, **kwargs):
    """Legacy entry point: first-order VJP check, then optionally a JVP check.

    Args:
        f: function under test.
        *args: positional arguments at which to check ``f``; all of them are
            differentiated.
        fwd: keyword flag (default True); when false, the forward-mode
            check is skipped.
        **kwargs: any other keyword arguments are passed through to ``f``
            (held fixed, not differentiated).
    """
    fwd = kwargs.pop('fwd', True)
    # Forward the remaining kwargs. Dropping them would silently check f at
    # its default keyword values instead of the ones the caller asked for.
    check_vjp(f, order=1)(*args, **kwargs)
    if fwd:
        check_jvp(f, order=1)(*args, **kwargs)
def nd(f, *args):
    """Numerically differentiate ``f`` with respect to all positional args.

    Returns a list containing one finite-difference directional derivative
    for each vector in ``vspace(args).standard_basis()``.
    """
    # Build the numerical JVP once. make_numerical_jvp evaluates f at args
    # and builds vspaces, and none of that depends on the basis direction.
    jvp = make_numerical_jvp(lambda packed: f(*packed), args)
    return [jvp(v) for v in vspace(args).standard_basis()]
def check_equivalent(x, y):
    """Assert that ``x`` and ``y`` share a vspace and agree numerically.

    Values are compared by projecting both onto one shared random vector
    from ``vspace(x)`` and testing the two inner products with
    ``scalar_close``.
    """
    x_vs = vspace(x)
    y_vs = vspace(y)
    assert x_vs == y_vs, "VSpace mismatch:\nx: {}\ny: {}".format(x_vs, y_vs)
    probe = x_vs.randn()
    lhs = x_vs.inner_prod(x, probe)
    rhs = x_vs.inner_prod(y, probe)
    assert scalar_close(lhs, rhs), \
        "Value mismatch:\nx: {}\ny: {}".format(x, y)
def combo_check(fun, argnums, *args, **kwargs):
    """Run the derivative checks on every combination of candidate inputs.

    Each positional ``args`` entry and each keyword value is an iterable of
    candidate values. ``check_vjp`` (and, unless ``fwd=False``,
    ``check_jvp``) runs once per element of their Cartesian product,
    differentiating the positions listed in ``argnums``.
    """
    fwd = kwargs.pop('fwd', True)
    positional_options = list(args)
    # Turn each keyword's candidates into (name, value) pairs so they can
    # join the product and be rebuilt into a dict afterwards.
    keyword_options = [[(name, value) for value in values]
                       for name, values in kwargs.items()]
    split = len(positional_options)
    for combo in it.product(*(positional_options + keyword_options)):
        cur_args = combo[:split]
        cur_kwargs = dict(combo[split:])
        check_vjp(fun, argnums)(*cur_args, **cur_kwargs)
        if fwd:
            check_jvp(fun, argnums)(*cur_args, **cur_kwargs)