diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 2d2ca917b2ad..13302eb32d93 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -45,7 +45,7 @@ from jax.lib import xla_bridge as xb from jax.lib import xla_client from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip, - split_list, cache, extend_name_stack) + split_list, cache, extend_name_stack, overrideable) from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, treedef_children, treedef_tuple, tree_multimap, tree_leaves) @@ -137,6 +137,7 @@ def scanned_fun(loop_carry, _): return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None return scanned_fun +@overrideable('lax.fori_loop') def fori_loop(lower, upper, body_fun, init_val): """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. @@ -203,6 +204,7 @@ def fori_loop(lower, upper, body_fun, init_val): return result +@overrideable('lax.while_loop') def while_loop(cond_fun: Callable[[T], bool], body_fun: Callable[[T], T], init_val: T) -> T: @@ -549,6 +551,7 @@ def _while_transpose_error(*_, **kwargs): ### cond and switch +@overrideable('lax.switch') def switch(index, branches: Sequence[Callable], operand): """Apply exactly one of ``branches`` given by ``index``. 
@@ -696,6 +699,7 @@ def cond(pred, true_fun, false_fun, operand): branches=(false_jaxpr, true_jaxpr), linear=linear) return tree_unflatten(out_tree, out) +@overrideable('lax.cond') @functools.wraps(_cond) def cond(*args, **kwargs): # detect an attempt to call the former, deprecated cond @@ -1132,6 +1136,7 @@ def cond_bind(*args, branches, linear): X = TypeVar('X') Y = TypeVar('Y') +@overrideable('lax.scan') def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry, xs: X, @@ -2377,6 +2382,8 @@ def _interleave(a, b, axis): return lax.add(lax.pad(a, lax._const(a, 0), a_pad), lax.pad(b, lax._const(b, 0), b_pad)) + +@overrideable('lax.associative_scan') def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): """Performs a scan with an associative binary operation, in parallel. diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 42029903497b..0bd94a8ce7d2 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -945,7 +945,7 @@ def _lu_pivots_body_fn(i, permutation_and_swaps): return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps -@partial(api.jit, static_argnums=(1,)) +@partial(api.jit, static_argnums=(1,)) # type: ignore def lu_pivots_to_permutation(swaps, m): """Converts the pivots (row swaps) returned by LU to a permutation. @@ -991,7 +991,7 @@ def _lu_solve_core(lu, permutation, b, trans): return lax.reshape(x, b.shape) -@partial(api.jit, static_argnums=(3,)) +@partial(api.jit, static_argnums=(3,)) # type: ignore def _lu_solve(lu, permutation, b, trans): if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]: raise ValueError("last two dimensions of LU decomposition must be equal, " diff --git a/jax/_src/util.py b/jax/_src/util.py index b11192ebb3a1..35ca26bf6466 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import contextlib
import functools
import itertools as it
import operator
import threading
import types
from typing import Any, Callable, Dict, Mapping, TypeVar

import numpy as np


class ContextVar:
  """Like contextvars.ContextVar, but implemented with threading.local().

  Every thread sees its own value. A thread that has never called ``set``
  sees ``default``. Unlike ``contextvars.ContextVar``, the "token" returned by
  ``set`` is simply the previous value, and ``reset`` assigns it back.
  """
  # TODO(shoyer): remove in favor of contextvars when JAX requires Python 3.7+

  def __init__(self, name, *, default):
    self.name = name
    self.default = default
    # The state is thread-local, so no lock is needed: no other thread can
    # observe or modify the attribute written by ``set``/``reset``. There is
    # also no need to seed ``value`` here, since that would only affect the
    # constructing thread and ``get`` already falls back to ``default``.
    self._state = threading.local()

  def get(self):
    """Returns the calling thread's value, or ``default`` if none was set."""
    return getattr(self._state, 'value', self.default)

  def set(self, value):
    """Sets the calling thread's value and returns the previous one."""
    old_value = self.get()
    self._state.value = value
    return old_value

  def reset(self, token):
    """Restores the value previously returned by ``set``."""
    self._state.value = token


# Registry of overrideable functions: name -> thread-local current
# implementation. Populated at decoration time by ``overrideable``.
_OVERRIDES: Dict[str, ContextVar] = {}


F = TypeVar('F', bound=Callable[..., Any])


def overrideable(name: str) -> Callable[[F], F]:
  """Make an internal JAX function overrideable.

  Args:
    name: key under which the decorated function is registered; this is the
      key that ``override_context`` accepts.

  Returns:
    A decorator. The decorated function dispatches to whatever implementation
    is currently active for ``name`` in the calling thread, which is the
    original function unless ``override_context`` replaced it.
  """
  def decorator(fun):
    _OVERRIDES[name] = ContextVar(name, default=fun)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
      impl = _OVERRIDES[name].get()
      if impl is fun:
        # No override is active in this thread. Installing ``fun`` over
        # itself would be a no-op, so skip the context manager entirely.
        return fun(*args, **kwargs)
      # A typical override will call back into the original JAX function after
      # doing some processing. We remove the override here, because otherwise
      # this would end up in an infinite loop.
      with override_context({name: fun}):
        return impl(*args, **kwargs)
    return wrapper
  return decorator


@contextlib.contextmanager
def override_context(implementations: Mapping[str, Callable]):
  """Experimental override for JAX functions within a context.

  This context manager allows for overriding the implementation of higher order
  JAX functions within a limited scope. It is intended for libraries such as
  Haiku and Flax that implement their own versions of these functions that
  support mutation, and may be removed in the future if/when JAX has a unified
  interface for handling mutable state.

  Usage example::

    import jax

    def my_grad(f):
      print("inside my_grad")
      return jax.grad(f)

    with jax.experimental.override_context({'grad': my_grad}):
      # All calls to jax.grad() are replaced by my_grad().
      # However, calls to jax.grad() from *inside* the implementation of
      # my_grad() will use the original.
      y = jax.grad(lambda x: x ** 2)(1.0)  # prints "inside my_grad"
      assert y == 2

  This context manager only overrides implementations within the current thread,
  and hence is thread-safe.

  Currently supported functions:
    checkpoint, grad, hessian, jacfwd, jacrev, jit, jvp, lax.associative_scan,
    lax.cond, lax.fori_loop, lax.scan, lax.switch, lax.while_loop,
    linear_transpose, linearize, named_call, pmap, value_and_grad, vjp, vmap

  Args:
    implementations: mapping from registered function name to the replacement
      callable to use inside the ``with`` block.

  Raises:
    KeyError: if any key of ``implementations`` is not a registered
      overrideable name. No override is installed in that case.
  """
  # Validate every name before touching any state, so that a typo in one key
  # cannot leave the other overrides permanently installed.
  unknown = [k for k in implementations if k not in _OVERRIDES]
  if unknown:
    raise KeyError(f"Unknown overrideable function name(s): {unknown!r}")

  tokens = {}
  try:
    # Install inside the ``try`` so anything already installed is rolled back
    # if we fail part-way through.
    for k, v in implementations.items():
      tokens[k] = _OVERRIDES[k].set(v)
    yield
  finally:
    for k, token in tokens.items():
      _OVERRIDES[k].reset(token)
-F = TypeVar("F", bound=Callable) +F = TypeVar("F", bound=Callable[..., Any]) T = TypeVar("T") U = TypeVar("U") @@ -121,6 +121,7 @@ def __init__(self): _thread_local_state = _ThreadLocalState() +@overrideable('jit') def jit( fun: F, static_argnums: Union[int, Iterable[int]] = (), @@ -690,6 +691,8 @@ def computation_maker(*args, **kwargs): return computation_maker + +@overrideable('grad') def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: @@ -752,6 +755,8 @@ def grad_f_aux(*args, **kwargs): return grad_f_aux if has_aux else grad_f + +@overrideable('value_and_grad') def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable[..., Tuple[Any, Any]]: @@ -862,6 +867,7 @@ def _check_output_dtype_revderiv(name, holomorphic, x): _check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad") +@overrideable('jacfwd') def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0, holomorphic: bool = False) -> Callable: """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. @@ -928,6 +934,7 @@ def _check_output_dtype_jacfwd(holomorphic, x): f"but got {aval.dtype.name}.") +@overrideable('jacrev') def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0, holomorphic: bool = False, allow_int: bool = False) -> Callable: """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. @@ -980,6 +987,7 @@ def jacfun(*args, **kwargs): _check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev") +@overrideable('hessian') def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0, holomorphic: bool = False) -> Callable: """Hessian of ``fun`` as a dense array. 
@@ -1069,6 +1077,7 @@ def _dtype(x): return dtypes.canonicalize_dtype(dtypes.result_type(x)) +@overrideable('vmap') def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F: """Vectorizing map. Creates a function which maps ``fun`` over argument axes. @@ -1274,6 +1283,8 @@ def _get_axis_size(name: str, i:int, shape: Tuple[int, ...], axis: int): sizes = tree_unflatten(tree, sizes) raise ValueError(msg.format(f"the tree of axis sizes is:\n{sizes}")) from None + +@overrideable('pmap') def pmap( fun: F, axis_name: Optional[AxisName] = None, @@ -1632,6 +1643,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable): map(tuple, out_shapes), out_tree_thunk()) return fun +@overrideable('jvp') def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]: """Computes a (forward-mode) Jacobian-vector product of ``fun``. @@ -1697,6 +1709,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents): return (tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_tangents)) +@overrideable('linearize') def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]: """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval. @@ -1827,7 +1840,8 @@ def vjp(fun: Callable[..., Any], Tuple[Any, Callable, Any]]: ... -def vjp( # type: ignore +@overrideable('vjp') # type: ignore +def vjp( fun: Callable, *primals, has_aux: bool = False, ) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]: """Compute a (reverse-mode) vector-Jacobian product of ``fun``. @@ -1897,6 +1911,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) +@overrideable('linear_transpose') def linear_transpose(fun: Callable, *primals) -> Callable: """Transpose a function that is promised to be linear. 
class _CountedCalls:
  """Callable wrapper that counts how many times it has been invoked.

  The counter is bumped before delegating, so a call that raises is still
  counted.
  """

  def __init__(self, fun):
    self.calls = 0
    self.fun = fun

  def __call__(self, *args, **kwargs):
    self.calls = self.calls + 1
    result = self.fun(*args, **kwargs)
    return result


class OverrideTest(jtu.JaxTestCase):
  """Checks that ``override_context`` swaps ``jax.jit`` only inside its scope."""

  def test(self):
    counted_jit = _CountedCalls(jax.jit)

    # Before entering the context the replacement must not be used.
    jax.jit(lambda x: x)(1)
    self.assertEqual(counted_jit.calls, 0)

    # Inside the context jax.jit dispatches to the counting wrapper once.
    overrides = {'jit': counted_jit}
    with jax.experimental.override_context(overrides):
      jax.jit(lambda x: x)(1)
      self.assertEqual(counted_jit.calls, 1)

    # After leaving the context the original jax.jit is back in effect.
    jax.jit(lambda x: x)(1)
    self.assertEqual(counted_jit.calls, 1)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())