Skip to content

Commit

Permalink
Add error checking that arguments of jvp are tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula committed Nov 27, 2019
1 parent ec79adc commit c1d8d3f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
5 changes: 5 additions & 0 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,11 @@ def jvp(fun, primals, tangents):
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)

if not isinstance(primals, tuple) or not isinstance(tangents, tuple):
msg = ("primal and tangent arguments to jax.jvp must be tuples; "
"found {} and {}.")
raise TypeError(msg.format(type(primals).__name__, type(tangents).__name__))

ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:
Expand Down
12 changes: 12 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,18 @@ def test_jvp_mismatched_arguments(self):
"primal and tangent arguments to jax.jvp must have equal types",
lambda: api.jvp(lambda x: -x, (onp.float16(2),), (onp.float32(4),)))


def test_jvp_non_tuple_arguments(self):
def f(x, y): return x + y
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples; found float and tuple.",
lambda: partial(api.jvp(f, 0., (1.,))))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples; found tuple and ndarray.",
lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.]))))

def test_vjp_mismatched_arguments(self):
_, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4))
self.assertRaisesRegex(
Expand Down
2 changes: 1 addition & 1 deletion tests/generated_fun_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def testJVPMatchesFD(self, fun):
tangents = [tangents[i] for i in dyn_argnums]
fun, vals = partial_argnums(fun, vals, dyn_argnums)
ans1, deriv1 = jvp_fd(fun, vals, tangents)
ans2, deriv2 = jvp(fun, vals, tangents)
ans2, deriv2 = jvp(fun, tuple(vals), tuple(tangents))
check_all_close(ans1, ans2)
check_all_close(deriv1, deriv2)

Expand Down

0 comments on commit c1d8d3f

Please sign in to comment.