From c5f9f8636daf5ef6d4246d02fdbd2a397924cbad Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 7 Feb 2021 21:47:06 -0800 Subject: [PATCH] Make eval_shape and invertible overridable --- jax/api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/api.py b/jax/api.py index 1113c4278013..c300f23bc7c5 100644 --- a/jax/api.py +++ b/jax/api.py @@ -2273,6 +2273,8 @@ def __eq__(self, other): def __hash__(self): return hash((self.shape, self.dtype)) + +@overrideable('eval_shape') def eval_shape(fun: Callable, *args, **kwargs): """Compute the shape/dtype of ``fun`` without any FLOPs. @@ -2623,6 +2625,7 @@ def vjpfun(ct): return ans, vjpfun defvjp_all(fun, custom_vjp) +@overrideable('invertible') def invertible(fun: Callable) -> Callable: """Asserts that the decorated function is invertible.