Add sophia-h optimizer #979
Conversation
fixes #968
Thank you very much @evanatyourservice! And sorry for the delay.
I left you some comments we can discuss.
optax/contrib/_sophia_h.py (outdated)

```python
def update_hessian(key, count, nu, params, obj_fn):
    def _do_update(key):
        if pmap_axis_name is not None and jax.local_device_count() > 1:
```
I am not sure I understand the logic here.
I suppose the idea is that, by using pmap, we can compute an average of the Hutchinson estimates across devices.
But (i) we need an example of how to use it, in particular how to set pmap_axis_name, and (ii) we need to test it.
Concerning the second point, pmap is free computationally but a priori not free in terms of memory, and we need to be able to quantify that memory cost.
Moreover, I don't know how this logic will interact with shard_map or automatic parallelism/sharding with jit.
So either we keep this option but document and test it, or we leave it out for now. As you pointed out, we may keep the code close to the levanter implementation, which does not incorporate such logic.
Yes exactly, it performs a separate Monte Carlo sample on each device by using a device-specific rng_key instead of a replicated one, essentially giving extra estimates for free. Without pmap_axis_name, when using pmap, the exact same sample would be carried out on each device because the same rng_key would be replicated everywhere. This is wasteful because the only thing each device needs to perform its own unique sample is a unique rng_key; the params and random vectors do not have to be duplicated in memory across devices, only the rng_keys.
For example, if we are using pmap with 4 devices and provide pmap_axis_name="batch", 4 unique rng_keys are made in init and all 4 are held on every device. When it comes time to update the Hessian diagonal, each device slices out its own of the 4 rng_keys using these lines:
```python
idx = jax.lax.axis_index(pmap_axis_name)
key = jax.lax.dynamic_index_in_dim(key, idx, keepdims=False)
```
Then the Hutchinson sample is carried out with the device's single unique rng_key, the samples are averaged using pmean, and the rng_keys are all_gathered back together. The costs are the memory of holding 4 rng_keys on each device instead of 1, the pmean for averaging the Hessian-diagonal estimates across devices, and the all_gather for the 4 rng_keys.
It shouldn't interfere with jit at all: if one isn't using pmap, the user can leave the argument as None and the optimizer behaves normally. I think it's a very useful feature because, when using pmap, performing the exact same Monte Carlo sample on each device is wasteful, given that the only thing needed to make the samples unique is a unique rng_key for the random-vector creation. I originally got this idea from distributed Shampoo, where in pmap mode computations are split across devices using jax.lax.axis_index, slicing, and all_gather. I'll document the feature better so users can understand what's going on, and I can add a test for it using xla_force_host_platform_device_count.
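To make the mechanism concrete, here is a minimal, self-contained sketch of the key-slicing trick described above. It is illustrative only: the axis name, shapes, and the trivial "Hessian" (identity, so each Rademacher probe gives exactly ones) are assumptions, and it uses the same xla_force_host_platform_device_count flag mentioned above to get 4 CPU devices on one machine.

```python
# Hypothetical sketch, not the PR's actual code: per-device rng_key
# slicing + pmean averaging of Hutchinson estimates.
import os
# Must be set before importing jax; gives us 4 host "devices" for pmap.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

AXIS = "batch"
n_dev = jax.local_device_count()  # 4 with the flag above

def hutch_diag_estimate(keys, params):
    # Each device slices its own key out of the replicated stack of keys.
    idx = jax.lax.axis_index(AXIS)
    key = jax.lax.dynamic_index_in_dim(keys, idx, keepdims=False)
    # One Rademacher probe per device: diag(H) ~ E[z * (H z)].
    # Here H is taken to be the identity for illustration, so H z = z.
    z = jax.random.rademacher(key, params.shape).astype(params.dtype)
    per_device_estimate = z * z  # stand-in for z * hvp(params, z)
    # Average the independent per-device estimates across the pmap axis.
    return jax.lax.pmean(per_device_estimate, axis_name=AXIS)

params = jnp.ones((n_dev, 3))  # toy replicated params
unique_keys = jax.random.split(jax.random.PRNGKey(0), n_dev)  # (4, 2)
# All 4 keys held on every device, as described in the comment above.
keys = jnp.broadcast_to(unique_keys, (n_dev,) + unique_keys.shape)
diag = jax.pmap(hutch_diag_estimate, axis_name=AXIS)(keys, params)
# With H = I, every probe yields z * z = 1, so diag is all ones.
```

Because z * z is exactly one for Rademacher vectors, the identity-Hessian toy makes the averaging visible without obscuring it behind a real hvp.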
Ok, thanks for the explanations!
We would mostly need one example showing how to use it.
I would still believe that at some point the m samples need to be created in memory before being fed to the hvp, so a priori we would incur a memory cost of m·d (for m the number of devices and d the size of the parameters). I don't see how this can be avoided. In terms of computation, it should indeed be no more costly than not using pmap, so as long as the memory allows it, it's a great feature. I would just mention the memory cost (if you agree with me).
ok, actually, for the sake of simplicity I think I will just remove this feature for now haha. But what would be the most concise way to average the Hessian diagonals across devices/a sharded data axis when using either pmap or jit with sharding, akin to averaging gradients across devices? Would an argument asking for a named axis suffice?
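One way the named-axis argument could look, sketched under assumptions (the function name and default are hypothetical, not from the PR): a pmean that is a no-op when no axis name is given, mirroring how gradients are typically averaged.

```python
# Hypothetical sketch of a named-axis argument for averaging the
# Hessian-diagonal estimate, analogous to gradient averaging.
import os
# Two host "devices" so pmap runs on a single machine.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

def average_hess_diag(hess_diag, axis_name=None):
    # With axis_name=None this is a no-op, so jit-only users are unaffected.
    if axis_name is None:
        return hess_diag
    # Under pmap (or shard_map) with a matching axis name, this averages
    # the per-device estimates, exactly like pmean on gradients.
    return jax.lax.pmean(hess_diag, axis_name=axis_name)

# Device 0 holds 0.0, device 1 holds 1.0; the pmean yields 0.5 everywhere.
per_device = jnp.arange(2.0).reshape(2, 1)
avg = jax.pmap(lambda h: average_hess_diag(h, axis_name="i"),
               axis_name="i")(per_device)
```

A single optional axis-name argument like this is also how pmean-based gradient averaging is usually exposed, which suggests it would suffice here too.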
I think the Hutchinson estimators (trace or diag) should be part of another functionality in optax.
As an example, matfree has a rather neat API that we could either use directly or take inspiration from.
Concretely, I'm wondering if it would be possible for sophia_h to take as an argument a function hutch_hess_diag_fn, obtained from something like hutch_hess_diag_fn = hutch_hess_diag(fun, **extra_args), with a signature hess_diag = hutch_hess_diag_fn(params) (semantics akin to jax.grad). I'd suppose that such a function also needs an inner state (to store the key), so maybe more something like hess_diag, hutch_state = hutch_hess_diag_fn(params, hutch_state).
That way we can postpone for now the implementation of a clever Hutchinson estimator, and such an estimator could be used for other purposes. The extra_args above could incorporate the pmap axis, etc. It would also make it easier to write unit tests for the pmap/shard_map behaviors.
Excellent, yeah, that sounds like a great idea for the long run, and matfree's API does look really nice. I'll add the argument for a custom hvp fn, and include a default one in the same file using the hess_diag, hutch_state = hutch_hess_diag_fn(params, hutch_state) format, just so people can get started with the optimizer and have a good example of how to write their own fn if they want. Later, optax could add hvp fns to its main API if desired. I'll also rename the optimizer to sophia instead of sophia-h, because then a user can write their own fn for sophia-g, or any other kind of preconditioner they'd like; it won't be hardcoded to sophia-h.
Sounds perfect! Thanks for all the work @evanatyourservice!
I think I've got a good enough setup with this, at least for contrib, but let me know your thoughts! I'm also not sure if I've documented it well enough.
Hi Vincent, thank you for the notes! They all make perfect sense to me and I'll get to updating the code/answering them tomorrow.
@evanatyourservice please ping us whenever you're ready for another round of reviews :-)
@fabianp Will do! Sorry, I've been moving, but I'll try to get this going ASAP.
There's no rush, just wanted to make sure you weren't waiting on us :-)
PR to add the Sophia optimizer. It's mostly based on levanter's implementation, with some changes and added features here and there.
One note: I had to change the contrib common test file a couple of times, once to pass the loss_fn out of the parabola and rosenbrock functions (which could be useful later for other optimizers that need the loss function), and a second time to bypass the check that update arguments are values (the loss function is not). Please advise if these changes are not OK or if there is a more correct approach.