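Optax doesn't treat its optimisers as pytrees, so an optimiser passed to a jitted function has to be marked static. This introduces spurious recompilation: two Adam optimisers that differ only in their learning rate each trigger a fresh compile.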
```python
import functools as ft
import jax
import jax.numpy as jnp
import optax

optim1 = optax.adam(jnp.array(3e-3))
optim2 = optax.adam(jnp.array(3e-4))

@ft.partial(jax.jit, static_argnums=0)  # post-pytree-isation, this would probably just be `jax.jit`
def evaluate(optim, params):
    print("Compiling!")
    state = optim.init(params)
    grads = jnp.zeros_like(params)
    updates, new_state = optim.update(grads, state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state

evaluate(optim1, jnp.array(0))  # Compiling!
evaluate(optim2, jnp.array(0))  # Compiling!
```
Can Optax optimisers be treated as pytrees with respect to their input arguments (e.g. the learning rate)?