Skip to content

Commit

Permalink
fix typo in #9923
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 18, 2022
1 parent e9f59ae commit d60d5d7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def _maybe_perturbed(x: Any) -> bool:
# happen later, but some types always have trivial tangents.
vspace = x.aval.at_least_vspace()
return not (vspace is core.abstract_unit or vspace is core.abstract_token or
vspace is dtypes.float0)
getattr(vspace, 'dtype', None) is dtypes.float0)
elif not isinstance(x, ad.JVPTracer):
# If x is not a JVPTracer, recursively check its contents.
return any(_maybe_perturbed(attr) for name, attr in x._contents())
Expand Down
11 changes: 11 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5435,6 +5435,17 @@ def g(y, _):

jax.jvp(f, (1.0,), (1.0,)) # assertions inside f

def test_maybe_perturbed_int_regression(self):
# see https://github.com/google/jax/discussions/9951
from jax._src.custom_derivatives import closure_convert

@jax.jit
def f():
x = jnp.array(1)
_, aux_args = closure_convert(lambda: x)
self.assertEmpty(aux_args)
f()


class CustomVJPTest(jtu.JaxTestCase):

Expand Down

0 comments on commit d60d5d7

Please sign in to comment.