Make solver.update jittable and ensure output states are consistent. #106
Conversation
- return AndersonState(iter_num=0,
-                      error=jnp.inf,
+ return AndersonState(iter_num=jnp.asarray(0),
+                      error=jnp.asarray(jnp.inf),
It seems worthwhile to document, in a comment or even in this PR's description, the reason for sending these scalars through jnp.asarray. My understanding is that we would like to return a state struct that is the same whether or not the update took place on device (via jit). Is that correct?
In general, we might want to avoid expressions like jnp.asarray(1.) whenever possible: if we're not under a jit, the expression allocates a scalar on the default device, which could be an accelerator. Given that cost, what is this degree of consistency buying us, and is it worth it?
The goal was to avoid this:
jitted_update = jax.jit(solver.update)
state = solver.init_state(params)
# here state contains floats
params, state = jitted_update(params, state)
# here state contains arrays due to jit
params, state = jitted_update(params, state) # recompilation occurs
I agree it could be nice to add a comment, but I'm not sure we should repeat the same comment in every solver.
I believe that if we use onp.asarray instead of jnp.asarray, where onp is plain NumPy, we would on the one hand keep these scalars in host memory, and on the other hand remain consistent for jit (in the sense of shape/dtype), so that it won't recompile.
@fllinares and I would like to understand this better. Since this PR is already an improvement, let's merge and explore onp.asarray separately (potentially with measurements).
I agree. Sounds good!
On my side, everything LGTM! Thanks a lot Mathieu!