Support overriding implementations of JAX functions within a scope
shoyer committed Feb 7, 2021
1 parent 8bf3f03 commit 5ae5c9b
Showing 6 changed files with 147 additions and 8 deletions.
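
For orientation, here is a minimal usage sketch of the scoped-override mechanism this commit adds (illustrative only, not part of the diff). It assumes the jax.experimental.override_context entry point and the 'lax.scan' override name registered below; logging_scan is a hypothetical override that defers to the original implementation::

    import jax
    import jax.numpy as jnp
    from jax.experimental import override_context

    def logging_scan(*args, **kwargs):
        # Hypothetical override: log the call, then defer to the original
        # lax.scan. The override is removed for the duration of this call,
        # so calling jax.lax.scan here does not recurse.
        print("inside logging_scan")
        return jax.lax.scan(*args, **kwargs)

    with override_context({'lax.scan': logging_scan}):
        # Prints "inside logging_scan"; the results match ordinary lax.scan.
        carry, ys = jax.lax.scan(lambda c, x: (c + x, c), 0.0, jnp.arange(4.0))

    assert carry == 6.0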
9 changes: 8 additions & 1 deletion jax/_src/lax/control_flow.py
@@ -45,7 +45,7 @@
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
split_list, cache, extend_name_stack)
split_list, cache, extend_name_stack, overrideable)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_multimap,
tree_leaves)
@@ -137,6 +137,7 @@ def scanned_fun(loop_carry, _):
return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
return scanned_fun

@overrideable('lax.fori_loop')
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
@@ -203,6 +204,7 @@ def fori_loop(lower, upper, body_fun, init_val):
return result


@overrideable('lax.while_loop')
def while_loop(cond_fun: Callable[[T], bool],
body_fun: Callable[[T], T],
init_val: T) -> T:
@@ -549,6 +551,7 @@ def _while_transpose_error(*_, **kwargs):

### cond and switch

@overrideable('lax.switch')
def switch(index, branches: Sequence[Callable], operand):
"""Apply exactly one of ``branches`` given by ``index``.
@@ -696,6 +699,7 @@ def cond(pred, true_fun, false_fun, operand):
branches=(false_jaxpr, true_jaxpr), linear=linear)
return tree_unflatten(out_tree, out)

@overrideable('lax.cond')
@functools.wraps(_cond)
def cond(*args, **kwargs):
# detect an attempt to call the former, deprecated cond
@@ -1132,6 +1136,7 @@ def cond_bind(*args, branches, linear):
X = TypeVar('X')
Y = TypeVar('Y')

@overrideable('lax.scan')
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
init: Carry,
xs: X,
@@ -2377,6 +2382,8 @@ def _interleave(a, b, axis):
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
lax.pad(b, lax._const(b, 0), b_pad))


@overrideable('lax.associative_scan')
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
"""Performs a scan with an associative binary operation, in parallel.
4 changes: 2 additions & 2 deletions jax/_src/lax/linalg.py
@@ -945,7 +945,7 @@ def _lu_pivots_body_fn(i, permutation_and_swaps):
return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps


@partial(api.jit, static_argnums=(1,))
@partial(api.jit, static_argnums=(1,)) # type: ignore
def lu_pivots_to_permutation(swaps, m):
"""Converts the pivots (row swaps) returned by LU to a permutation.
@@ -991,7 +991,7 @@ def _lu_solve_core(lu, permutation, b, trans):
return lax.reshape(x, b.shape)


@partial(api.jit, static_argnums=(3,))
@partial(api.jit, static_argnums=(3,)) # type: ignore
def _lu_solve(lu, permutation, b, trans):
if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]:
raise ValueError("last two dimensions of LU decomposition must be equal, "
92 changes: 90 additions & 2 deletions jax/_src/util.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import contextlib
import functools
import itertools as it
import operator
import threading
import types
from typing import Any, Callable
from typing import Any, Callable, Mapping, TypeVar

import numpy as np

@@ -371,3 +372,90 @@ def __repr__(self):

def as_hashable_function(closure):
return lambda f: HashableFunction(f, closure)


class ContextVar:
"""Like contextvars.ContextVar, but implemented with threading.local()."""
# TODO(shoyer): remove in favor of contextvars when JAX requires Python 3.7+

def __init__(self, name, *, default):
self.name = name
self.default = default
self._state = threading.local()
self._state.value = default
self._lock = threading.Lock()

def get(self):
return getattr(self._state, 'value', self.default)

def set(self, value):
with self._lock:
old_value = self.get()
self._state.value = value
return old_value

def reset(self, token):
self._state.value = token


_OVERRIDES = {}


F = TypeVar('F', bound=Callable[..., Any])


def overrideable(name: str) -> Callable[[F], F]:
"""Make an internal JAX function overrideable."""
def decorator(fun):
_OVERRIDES[name] = ContextVar(name, default=fun)
@functools.wraps(fun)
def wrapper(*args, **kwargs):
impl = _OVERRIDES[name].get()
# A typical override will call back into the original JAX function after
# doing some processing. We remove the override here, because otherwise
# this would end up in an infinite loop.
with override_context({name: fun}):
return impl(*args, **kwargs)
return wrapper
return decorator


@contextlib.contextmanager
def override_context(implementations: Mapping[str, Callable]):
"""Experimental override for JAX functions within a context.
This context manager allows for overriding the implementation of higher order
JAX functions within a limited scope. It is intended for libraries such as
Haiku and Flax that implement their own versions of these functions that
support mutation, and may be removed in the future if/when JAX has a unified
interface for handling mutable state.
Usage example::
import jax
def my_grad(f):
print("inside my_grad")
return jax.grad(f)
with jax.experimental.override_context({'grad': my_grad}):
# All calls to jax.grad() are replaced by my_grad().
# However, calls to jax.grad() from *inside* the implementation of
# my_grad() will use the original.
y = jax.grad(lambda x: x ** 2)(1.0) # prints "inside my_grad"
assert y == 2
This context manager only overrides implementations within the current thread,
and hence is thread-safe.
Currently supported functions:
checkpoint, grad, hessian, jacfwd, jacrev, jit, jvp, lax.associative_scan,
lax.cond, lax.fori_loop, lax.scan, lax.switch, lax.while_loop,
linear_transpose, linearize, named_call, pmap, value_and_grad, vjp, vmap
"""
tokens = {k: _OVERRIDES[k].set(v) for k, v in implementations.items()}
try:
yield
finally:
for k, v in tokens.items():
_OVERRIDES[k].reset(v)
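
Two properties of the helpers above are worth a small sketch (again illustrative, not part of the diff): an override can call back into the function it replaces, because the wrapper restores the original implementation while the override runs, and overrides live in thread-local state, so other threads keep seeing the original. The counting_jit helper below is hypothetical::

    import threading
    import jax
    from jax.experimental import override_context

    calls = []

    def counting_jit(fun, **kwargs):
        calls.append(fun)
        # The override is removed while this runs, so jax.jit here reaches
        # the original implementation instead of recursing into counting_jit.
        return jax.jit(fun, **kwargs)

    def other_thread():
        # Overrides are thread-local, so this thread still sees the original
        # jax.jit and appends nothing to `calls`.
        jax.jit(lambda x: x + 1)(1)

    with override_context({'jit': counting_jit}):
        jax.jit(lambda x: x)(1)  # routed through counting_jit
        t = threading.Thread(target=other_thread)
        t.start()
        t.join()

    assert len(calls) == 1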
23 changes: 20 additions & 3 deletions jax/api.py
@@ -53,7 +53,7 @@
treedef_is_leaf, treedef_children, Partial)
from ._src.util import (unzip2, curry, partial, safe_map, safe_zip, prod,
split_list, extend_name_stack, wrap_name, cache, wraps,
HashableFunction)
HashableFunction, overrideable)
from .lib import jax_jit
from .lib import version
from .lib import xla_bridge as xb
@@ -86,7 +86,7 @@
# in JIT internals, as Tracer values are passed through the function.
# Should this raise any type errors for the tracing code in future, we can disable
# type checking in parts of the tracing code, or remove these annotations.
F = TypeVar("F", bound=Callable)
F = TypeVar("F", bound=Callable[..., Any])
T = TypeVar("T")
U = TypeVar("U")

@@ -121,6 +121,7 @@ def __init__(self):
_thread_local_state = _ThreadLocalState()


@overrideable('jit')
def jit(
fun: F,
static_argnums: Union[int, Iterable[int]] = (),
@@ -690,6 +691,8 @@ def computation_maker(*args, **kwargs):

return computation_maker


@overrideable('grad')
def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False) -> Callable:
@@ -752,6 +755,8 @@ def grad_f_aux(*args, **kwargs):

return grad_f_aux if has_aux else grad_f


@overrideable('value_and_grad')
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False) -> Callable[..., Tuple[Any, Any]]:
@@ -862,6 +867,7 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")


@overrideable('jacfwd')
def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
holomorphic: bool = False) -> Callable:
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
@@ -928,6 +934,7 @@ def _check_output_dtype_jacfwd(holomorphic, x):
f"but got {aval.dtype.name}.")


@overrideable('jacrev')
def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
holomorphic: bool = False, allow_int: bool = False) -> Callable:
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
@@ -980,6 +987,7 @@ def jacfun(*args, **kwargs):
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")


@overrideable('hessian')
def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
holomorphic: bool = False) -> Callable:
"""Hessian of ``fun`` as a dense array.
@@ -1069,6 +1077,7 @@ def _dtype(x):
return dtypes.canonicalize_dtype(dtypes.result_type(x))


@overrideable('vmap')
def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
@@ -1274,6 +1283,8 @@ def _get_axis_size(name: str, i:int, shape: Tuple[int, ...], axis: int):
sizes = tree_unflatten(tree, sizes)
raise ValueError(msg.format(f"the tree of axis sizes is:\n{sizes}")) from None


@overrideable('pmap')
def pmap(
fun: F,
axis_name: Optional[AxisName] = None,
@@ -1632,6 +1643,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable):
map(tuple, out_shapes), out_tree_thunk())
return fun

@overrideable('jvp')
def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]:
"""Computes a (forward-mode) Jacobian-vector product of ``fun``.
@@ -1697,6 +1709,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents):
return (tree_unflatten(out_tree(), out_primals),
tree_unflatten(out_tree(), out_tangents))

@overrideable('linearize')
def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
"""Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.
@@ -1827,7 +1840,8 @@ def vjp(fun: Callable[..., Any],
Tuple[Any, Callable, Any]]:
...

def vjp( # type: ignore
@overrideable('vjp') # type: ignore
def vjp(
fun: Callable, *primals, has_aux: bool = False,
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.
@@ -1897,6 +1911,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)


@overrideable('linear_transpose')
def linear_transpose(fun: Callable, *primals) -> Callable:
"""Transpose a function that is promised to be linear.
@@ -2328,6 +2343,7 @@ def abstractify(x):
return tree_unflatten(out_tree(), out)


@overrideable('checkpoint')
def checkpoint(fun: Callable, concrete: bool = False) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
@@ -2422,6 +2438,7 @@ def fun_remat(*args, **kwargs):
remat = checkpoint


@overrideable('named_call')
def named_call(
fun: Callable[..., Any],
*,
1 change: 1 addition & 0 deletions jax/experimental/__init__.py
@@ -16,3 +16,4 @@
from ..interpreters.sharded_jit import (sharded_jit, PartitionSpec,
with_sharding_constraint)
from .x64_context import enable_x64, disable_x64
from .._src.util import override_context
26 changes: 26 additions & 0 deletions tests/api_test.py
@@ -5018,5 +5018,31 @@ def test_partial_eval(self):
self.assertEqual(out, 5)


class _CountedCalls:
def __init__(self, fun):
self.fun = fun
self.calls = 0
def __call__(self, *args, **kwargs):
self.calls += 1
return self.fun(*args, **kwargs)


class OverrideTest(jtu.JaxTestCase):

def test(self):

counted_jit = _CountedCalls(jax.jit)

jax.jit(lambda x: x)(1)
self.assertEqual(counted_jit.calls, 0)

with jax.experimental.override_context({'jit': counted_jit}):
jax.jit(lambda x: x)(1)
self.assertEqual(counted_jit.calls, 1)

jax.jit(lambda x: x)(1)
self.assertEqual(counted_jit.calls, 1)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
