In [1]:
import jax

In [15]:
from functools import partial

def unjitted_loop_body(prev_i, i):
  return prev_i + 1 + i

In [19]:
def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body, i=9))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x, i=9))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    func = partial(unjitted_loop_body, i=9)
    i = jax.jit(func)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit -n 10 g_inner_jitted_partial(10, 20).block_until_ready()

# print("jit called in a loop with lambdas:")
# %timeit -n 10 g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit -n 10 g_inner_jitted_normal(10, 20).block_until_ready()

jit called in a loop with partials:
79.1 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit called in a loop with lambdas:
79.3 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit called in a loop with caching:
79.1 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
from typing import Optional
import jax.numpy as jnp
import haiku as hk

class MyToyModel(hk.Module):
    def __init__(self, name: str | None = None, is_training: bool = False):
        super().__init__(name=name)
        self.is_training = is_training
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        if self.is_training:
            x = hk.dropout(hk.next_rng_key(), 0.2, x)
            return hk.Linear(100)(x)
        else:
            return hk.Linear(1000)(x)