When using an adaptive integrator with `dtype=jnp.float32` for the parameters, samples, etc., the integrator raises errors. They arise for the following reasons:
the two `dt`s in the `jax.lax.cond` for the accepted case (`next_dt` vs `rk_state.dt`) can have different dtypes;
the error norm inherits its dtype from the variational state, so in the accepted case the replaced `last_norm` and `last_scaled_error` can have a different dtype from the values they replace;
the latter happens because we initialize `last_norm` with, e.g., `0.` in the adaptive case, while the error norm that later replaces it can be float32.
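A minimal reproduction of the first point, as a sketch: it assumes x64 mode is enabled (otherwise float64 silently becomes float32 and no mismatch appears), and `accept_or_reject` is a hypothetical stand-in for the integrator's step logic, not its actual code. `jax.lax.cond` requires both branches to return outputs of identical type, so a float32 `next_dt` against a float64 `dt` is rejected, while casting `next_dt` to `dt.dtype` makes the branches agree.

```python
import jax
import jax.numpy as jnp

# Assumption: x64 mode is enabled, so float64 arrays really are float64.
jax.config.update("jax_enable_x64", True)


def accept_or_reject(accepted, dt, next_dt):
    # Hypothetical step logic: both branches of lax.cond must return
    # outputs with identical shape AND dtype.
    return jax.lax.cond(accepted, lambda: next_dt, lambda: dt)


dt = jnp.asarray(0.1, dtype=jnp.float64)       # plays the role of rk_state.dt
next_dt = jnp.asarray(0.2, dtype=jnp.float32)  # e.g. derived from a float32 error norm

try:
    accept_or_reject(True, dt, next_dt)
    mismatch_error = False
except TypeError as err:
    # JAX reports that the branch outputs must have identical types.
    mismatch_error = True
    print("lax.cond rejected the branches:", err)

# Casting the replacement to the dtype of the existing field fixes the mismatch.
fixed = accept_or_reject(True, dt, next_dt.astype(dt.dtype))
print(fixed.dtype)  # float64
```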
I'm not sure what the best solution is (I wasn't even expecting errors of this kind tbh), but some possibilities are:
make sure all replaced `dt`s have the same dtype as `rk_state.dt`;
initialize the `last_norm` and `last_scaled_error` fields with a `jnp.array` of a predefined, fixed dtype (it is not clear how to determine this in general, but float64 would make sense);
initialize and convert everything to float64 in the `RKState`.
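The first and third options can be combined in a small sketch. `ToyRKState`, `make_state` and `step` are hypothetical names, not the integrator's actual `RKState` or API, and x64 mode is assumed to be enabled. The idea is to fix one dtype when the state is built and to cast every value written into it back to that dtype, so both `lax.cond` branches always agree even when the error norm and `next_dt` arrive as float32.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Assumption: x64 mode is enabled, so the state can really be float64.
jax.config.update("jax_enable_x64", True)


class ToyRKState(NamedTuple):
    # Hypothetical stand-in for the integrator state (a NamedTuple is a pytree).
    t: jax.Array
    dt: jax.Array
    last_norm: jax.Array
    last_scaled_error: jax.Array


def make_state(t, dt, dtype=jnp.float64):
    # Option 3: convert every scalar field to one fixed dtype at construction.
    return ToyRKState(
        t=jnp.asarray(t, dtype=dtype),
        dt=jnp.asarray(dt, dtype=dtype),
        last_norm=jnp.zeros((), dtype=dtype),
        last_scaled_error=jnp.zeros((), dtype=dtype),
    )


def step(state, accepted, norm, scaled_error, next_dt):
    dtype = state.dt.dtype

    def accept():
        # Option 1: cast every replaced value back to the state's fixed dtype.
        return state._replace(
            t=state.t + state.dt,
            dt=jnp.asarray(next_dt, dtype=dtype),
            last_norm=jnp.asarray(norm, dtype=dtype),
            last_scaled_error=jnp.asarray(scaled_error, dtype=dtype),
        )

    def reject():
        # Only the step size changes on rejection.
        return state._replace(dt=jnp.asarray(next_dt, dtype=dtype))

    return jax.lax.cond(accepted, accept, reject)


f32 = jnp.float32
state = make_state(0.0, 0.1)
new_state = step(state, True, jnp.asarray(0.5, f32), jnp.asarray(0.4, f32), jnp.asarray(0.2, f32))
print({name: leaf.dtype for name, leaf in new_state._asdict().items()})
```

Casting at the write site keeps the state's dtype as the single source of truth, at the cost of a possible float32 to float64 conversion per step.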
Thanks, it all makes sense.
Getting things to work with non-standard dtypes is always a mess...
In particular:
> make sure all replaced `dt`s have the same dtype as `rk_state.dt`

I agree.
> initialize `last_norm` and `last_scaled_error` fields with a `jnp.array` with a predefined fixed dtype (not clear how to determine this in general, but float64 would make sense)

I would initialise them with the dtype of the output of `error_norm`.
You can get it via abstract interpretation, aka `jax.eval_shape`.
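A sketch of that suggestion: `jax.eval_shape` traces a function abstractly and returns a `jax.ShapeDtypeStruct` (shape and dtype only, nothing is computed), which can then be used to allocate a correctly typed initial value. The `error_norm` and `zeros_like_output` below are hypothetical, not the integrator's actual functions; with float32 or complex64 inputs this norm returns float32, so `last_norm` starts out as float32 instead of float64.

```python
import jax
import jax.numpy as jnp

# Assumption: x64 mode is enabled, as when the rest of the state is kept in float64.
jax.config.update("jax_enable_x64", True)


def error_norm(x):
    # Hypothetical RMS error norm; abs() maps complex inputs to real outputs.
    return jnp.sqrt(jnp.mean(jnp.abs(x) ** 2))


def zeros_like_output(fn, *args):
    # Abstract evaluation: only shapes and dtypes are propagated.
    out = jax.eval_shape(fn, *args)
    return jnp.zeros(out.shape, dtype=out.dtype)


y_real = jnp.ones(4, dtype=jnp.float32)    # e.g. float32 parameters
y_cplx = jnp.ones(4, dtype=jnp.complex64)  # e.g. complex64 parameters

last_norm = zeros_like_output(error_norm, y_real)
last_norm_cplx = zeros_like_output(error_norm, y_cplx)
print(last_norm.dtype, last_norm_cplx.dtype)  # float32 float32
```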
Those two changes should be enough, I think? I don't think there are other fields whose dtype might change... We already forcibly enforce that the dtype of the parameters does not change (for the same reason).
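For reference, this kind of enforcement can be written generically with `jax.tree_util.tree_map`, casting each leaf of an updated pytree back to the dtype of the matching leaf in a reference pytree. `match_dtypes` is a hypothetical helper for illustration, not the library's own code, and the example assumes x64 mode is enabled.

```python
import jax
import jax.numpy as jnp

# Assumption: x64 mode is enabled, so mixing in float64 values promotes float32 leaves.
jax.config.update("jax_enable_x64", True)


def match_dtypes(new_tree, reference_tree):
    # Cast each leaf of new_tree to the dtype of the matching leaf in reference_tree.
    # Both trees must have the same structure.
    return jax.tree_util.tree_map(
        lambda new, ref: jnp.asarray(new, dtype=ref.dtype), new_tree, reference_tree
    )


params = {"w": jnp.ones(3, dtype=jnp.float32), "b": jnp.zeros((), dtype=jnp.complex64)}
# An update that accidentally promotes "w" to float64.
update = {"w": jnp.full(3, 0.5, dtype=jnp.float64), "b": jnp.asarray(1j, dtype=jnp.complex64)}
updated = jax.tree_util.tree_map(lambda p, u: p + u, params, update)
print(updated["w"].dtype)  # float64

fixed = match_dtypes(updated, params)
print(fixed["w"].dtype, fixed["b"].dtype)  # float32 complex64
```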