Skip to content

Commit

Permalink
Make eval_shape and invertible overridable
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Feb 8, 2021
1 parent 06987b4 commit 314f302
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
7 changes: 4 additions & 3 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,10 @@ def my_grad(f):
and hence is thread-safe.
Currently supported functions:
checkpoint, grad, hessian, jacfwd, jacrev, jit, jvp, lax.associative_scan,
lax.cond, lax.fori_loop, lax.scan, lax.switch, lax.while_loop,
linear_transpose, linearize, named_call, pmap, value_and_grad, vjp, vmap
checkpoint, eval_shape, grad, hessian, invertible, jacfwd, jacrev, jit, jvp,
lax.associative_scan, lax.cond, lax.fori_loop, lax.scan, lax.switch,
lax.while_loop, linear_transpose, linearize, named_call, pmap, value_and_grad,
vjp, vmap
"""
tokens = {k: _OVERRIDES[k].set(v) for k, v in implementations.items()}
try:
Expand Down
3 changes: 3 additions & 0 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2273,6 +2273,8 @@ def __eq__(self, other):
def __hash__(self):
return hash((self.shape, self.dtype))


@overrideable('eval_shape')
def eval_shape(fun: Callable, *args, **kwargs):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
Expand Down Expand Up @@ -2623,6 +2625,7 @@ def vjpfun(ct):
return ans, vjpfun
defvjp_all(fun, custom_vjp)

@overrideable('invertible')
def invertible(fun: Callable) -> Callable:
"""Asserts that the decorated function is invertible.
Expand Down

0 comments on commit 314f302

Please sign in to comment.