Skip to content

Commit

Permalink
Avoid retracing when a host_callback.call is called multiple times wi…
Browse files Browse the repository at this point in the history
…th the same function.

If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.

This issue was found in passing while debugging #9970.
  • Loading branch information
hawkinsp committed Apr 1, 2022
1 parent e766b96 commit 208e83c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
45 changes: 33 additions & 12 deletions jax/experimental/host_callback.py
Expand Up @@ -690,6 +690,37 @@ def call(callback_func: Callable, arg, *,
call_with_device=call_with_device, identity=False)


# We need the wrapper function to have hash and equality defined since it is
# used as a primitive keyword argument, and we want a compilation cache hit if
# the user uses the same function twice.
class _CallbackWrapper:
def __init__(self, callback_func, identity, call_with_device):
self.callback_func = callback_func
self.identity = identity
self.call_with_device = call_with_device

def __hash__(self):
return hash((self.callback_func, self.identity, self.call_with_device))

def __eq__(self, other):
return (self.callback_func == other.callback_func and
self.identity == other.identity and
self.call_with_device == other.call_with_device)

def __call__(self, arg, device, transforms):
if self.identity:
# For id_tap, we pass the transforms, for backwards compatibility
if self.call_with_device:
return self.callback_func(arg, transforms, device=device)
else:
return self.callback_func(arg, transforms)
else:
if self.call_with_device:
return self.callback_func(arg, device=device)
else:
return self.callback_func(arg)


# Helper function to implement both `call` and `id_tap`. The two cases are
# differentiated by the `identity` flag.
def _call(callback_func: Callable, arg, *,
Expand All @@ -706,18 +737,8 @@ def _call(callback_func: Callable, arg, *,
# See definition of outside_call_p for what parameters it takes
params: Dict[str, Any] = {}
# TODO: wrap function
if identity:
# For id_tap, we pass the transforms, for backwards compatibility
if call_with_device:
callback = lambda arg, device, transforms: callback_func(arg, transforms, device=device)
else:
callback = lambda arg, device, transforms: callback_func(arg, transforms)
else:
if call_with_device:
callback = lambda arg, device, transforms: callback_func(arg, device=device)
else:
callback = lambda arg, device, transforms: callback_func(arg)
params["callback"] = callback
params["callback"] = _CallbackWrapper(callback_func, identity,
call_with_device)
params["identity"] = identity
params["arg_treedef"] = arg_treedef

Expand Down
16 changes: 15 additions & 1 deletion tests/host_callback_test.py
Expand Up @@ -132,7 +132,7 @@ def repl_floats(match_group):
x = np.around(float(matched), decimals=2)
return f"{x:.2f}"

what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
what = re.sub(r"\-?\d+\.[\-\def]*", repl_floats, what)
what = re.sub(r"output_stream=[^\]\n,]*,?", "", what)
what = re.sub(r"threshold=[^\]\n,]*,?", "", what)
what = re.sub(r"bwd=[^\]\n]*", "", what)
Expand Down Expand Up @@ -2153,6 +2153,20 @@ def fun(x):
arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
self.assertAllClose(3 * (1 + 2 * (arg + 1)), fun(arg))

def test_primitive_compilation(self):

def f_outside(x):
return 2 * x

def fun(x):
return hcb.call(f_outside, x, result_shape=x)

arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
with jtu.count_primitive_compiles() as count:
for _ in range(3):
self.assertAllClose(2 * arg, fun(arg))
self.assertEqual(count[0], 1)

@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{np.dtype(dtype).name}", dtype=dtype)
Expand Down

0 comments on commit 208e83c

Please sign in to comment.