Make solver.update jittable and ensure output states are consistent. #106

Merged
merged 1 commit into from Dec 3, 2021


mblondel
Collaborator

@mblondel mblondel commented Dec 1, 2021

No description provided.

@google-cla google-cla bot added the cla: yes label Dec 1, 2021
jaxopt/_src/test_util.py (outdated, resolved)
- return AndersonState(iter_num=0,
-                      error=jnp.inf,
+ return AndersonState(iter_num=jnp.asarray(0),
+                      error=jnp.asarray(jnp.inf),
Member


It seems worthwhile to document—in a comment or even in this PR's description—the reason for sending these scalars through jnp.asarray. My understanding is that we would like to return a state struct that is the same whether or not the update took place on device (via jit). Is that correct?

In some situations, we may want to avoid expressions like jnp.asarray(1.). Outside of a jit, such an expression allocates a scalar on the default device, which could be an accelerator. Given that, what does this degree of consistency buy us, and is it worth the cost?

Collaborator Author


The goal was to avoid this:

jitted_update = jax.jit(solver.update)
state = solver.init_state(params)
# here state contains floats
params, state = jitted_update(params, state)
# here state contains arrays due to jit
params, state = jitted_update(params, state)  # recompilation occurs

I agree it would be nice to add a comment, but I'm not sure we should repeat the same comment in every solver.
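
The retracing described in this comment can be reproduced outside jaxopt. The sketch below is an illustration only, not the library's code: `ToyState`, `make_update` and `run` are hypothetical names, and the update rule is a placeholder. It counts traces with a Python side effect, which executes only while `jax.jit` traces the function. The array-based initial state passes an explicit `dtype`, which is an assumption added here (the PR itself calls `jnp.asarray` without one) so that the dtype and weak-type flag of the inputs match what the jitted update returns.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
  """Hypothetical stand-in for a solver state such as AndersonState."""
  iter_num: Any
  error: Any


def make_update(counter):
  """Returns a jitted toy update that records each trace in `counter`."""
  def update(params, state):
    counter.append(1)  # Python side effect: executes only during tracing.
    new_params = 0.5 * params  # Placeholder update rule.
    error = jnp.linalg.norm(new_params - params)
    return new_params, ToyState(iter_num=state.iter_num + 1, error=error)
  return jax.jit(update)


def run(init_state, num_steps=3):
  """Runs `num_steps` jitted updates and returns the number of traces."""
  counter = []
  update = make_update(counter)
  params, state = jnp.ones(3), init_state
  for _ in range(num_steps):
    params, state = update(params, state)
  return len(counter)


dtype = jnp.ones(3).dtype

# Python scalars: after the first call `error` comes back as a strongly
# typed array, so the second call sees a different input type.
traces_with_scalars = run(ToyState(iter_num=0, error=float("inf")))

# Arrays with explicit dtypes: the input types are the same on every call.
traces_with_arrays = run(
    ToyState(iter_num=jnp.asarray(0, dtype=jnp.int32),
             error=jnp.asarray(jnp.inf, dtype=dtype)))

print(traces_with_scalars, traces_with_arrays)
```

The first count should be larger than the second: the scalar-initialized state is traced again once the jitted update has replaced its fields with arrays, while the array-initialized state is traced once.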

Member

@froystig froystig Dec 2, 2021


I believe that if we use onp.asarray instead of jnp.asarray, where onp is plain numpy, we will on the one hand keep these scalars in host memory, and on the other hand be consistent for jit (in the sense of shape/dtype), such that it won't recompile.
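
A minimal sketch of this suggestion, under stated assumptions: `ToyState` and the placeholder `update` are hypothetical, not jaxopt code, and the explicit dtypes are chosen here so that the NumPy scalars match the types the jitted update returns. It checks only that the initial state lives in plain NumPy arrays and that the update is traced once; it does not measure whether this is cheaper than `jnp.asarray`.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as onp  # Plain NumPy, named as in the comment above.


class ToyState(NamedTuple):
  """Hypothetical stand-in for a solver state."""
  iter_num: Any
  error: Any


num_traces = []


@jax.jit
def update(params, state):
  num_traces.append(1)  # Python side effect: executes only during tracing.
  new_params = 0.5 * params  # Placeholder update rule.
  error = jnp.linalg.norm(new_params - params)
  return new_params, ToyState(iter_num=state.iter_num + 1, error=error)


params = jnp.ones(3)
# Host-side NumPy scalars with explicit dtypes (an assumption of this sketch).
state = ToyState(iter_num=onp.asarray(0, dtype=onp.int32),
                 error=onp.asarray(onp.inf, dtype=params.dtype))
init_is_host_numpy = all(isinstance(leaf, onp.ndarray) for leaf in state)

for _ in range(3):
  params, state = update(params, state)

print(init_is_host_numpy, len(num_traces), int(state.iter_num))
```

After the first call the state fields are JAX arrays returned by the jitted function, so only the initial state stays on the host.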

Collaborator Author


@fllinares and I would like to understand this better. Since this PR is already an improvement, let's merge and explore onp.asarray separately (potentially we could make measurements).

Member


I agree. Sounds good!

Collaborator

@fllinares fllinares left a comment


On my side, everything LGTM! Thanks a lot Mathieu!

3 participants