diff --git a/jax/_src/util.py b/jax/_src/util.py index 35ca26bf6466..1eb1ed4d1ad0 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -449,9 +449,10 @@ def my_grad(f): and hence is thread-safe. Currently supported functions: - checkpoint, grad, hessian, jacfwd, jacrev, jit, jvp, lax.associative_scan, - lax.cond, lax.fori_loop, lax.scan, lax.switch, lax.while_loop, - linear_transpose, linearize, named_call, pmap, value_and_grad, vjp, vmap + checkpoint, eval_shape, grad, hessian, invertible, jacfwd, jacrev, jit, jvp, + lax.associative_scan, lax.cond, lax.fori_loop, lax.scan, lax.switch, + lax.while_loop, linear_transpose, linearize, named_call, pmap, value_and_grad, + vjp, vmap """ tokens = {k: _OVERRIDES[k].set(v) for k, v in implementations.items()} try: diff --git a/jax/api.py b/jax/api.py index 1113c4278013..c300f23bc7c5 100644 --- a/jax/api.py +++ b/jax/api.py @@ -2273,6 +2273,8 @@ def __eq__(self, other): def __hash__(self): return hash((self.shape, self.dtype)) + +@overrideable('eval_shape') def eval_shape(fun: Callable, *args, **kwargs): """Compute the shape/dtype of ``fun`` without any FLOPs. @@ -2623,6 +2625,7 @@ def vjpfun(ct): return ans, vjpfun defvjp_all(fun, custom_vjp) +@overrideable('invertible') def invertible(fun: Callable) -> Callable: """Asserts that the decorated function is invertible.