--- .venv/lib/python3.10/site-packages/kfac_jax/_src/optimizer.py 2022-07-14 12:48:51.387431181 +0100
+++ .venv/lib/python3.10/site-packages/kfac_jax/_src/optimizer_new.py 2022-07-14 12:48:39.959253562 +0100
@@ -304,7 +304,8 @@
     self._value_func_has_rng = value_func_has_rng
     self._value_func: ValueFunc = convert_value_and_grad_to_value_func(
         value_and_grad_func,
-        has_aux=value_func_has_aux,
+        # CHANGE: also pass has_aux=True when the value function returns state.
+        has_aux=value_func_has_aux or value_func_has_state,
     )
     self._l2_reg = jnp.asarray(l2_reg)
     self._use_adaptive_learning_rate = use_adaptive_learning_rate