
Add sophia-h optimizer #979

Open · wants to merge 8 commits into main

Conversation

@evanatyourservice (Author)

PR to add the sophia optimizer. It's mostly based on levanter's implementation, with some changes and added features here and there.

One note: I had to change the contrib common test file a couple of times, once to pass the loss_fn out of the parabola and rosenbrock functions (this could be useful later for other optimizers that need the loss function), and a second time to bypass the check that update arguments are values (the loss function is not). Please advise if these changes are not OK or if there is a more correct approach.

@evanatyourservice (Author)

fixes #968

@vroulet (Collaborator) left a comment


Thank you very much @evanatyourservice! And sorry for the delay.
I left you some comments that we can discuss.

optax/contrib/_sophia_h.py: three review threads (outdated, resolved)

def update_hessian(key, count, nu, params, obj_fn):
  def _do_update(key):
    if pmap_axis_name is not None and jax.local_device_count() > 1:
@vroulet (Collaborator)

I am not sure I understand the logic here.
I suppose the idea is that, by using pmap, we can compute an average of $m$ estimates, where $m$ is the number of local devices?
But (i) we need an example of how to use it, in particular how to set pmap_axis_name, and (ii) we need to test it.
Concerning the second point, pmap is free computationally but a priori not free in terms of memory. We need to be able to hold $m$ random copies of the parameters, which may blow up the memory. Moreover, with an implementation using hvps as done right now, the computation graph of the function would be replicated $m$ times too, which, I would believe, is infeasible for large models.
Also, I don't know how this logic will interact with shard_map or automatic parallelism/sharding with jit.
So either we keep this option but document and test it, or we leave it out for now. As you pointed out, we may keep the code close to the levanter implementation, which does not incorporate such logic.

@evanatyourservice (Author)

Yes, exactly: it performs a separate Monte Carlo sample on each device by using a device-specific rng_key instead of a replicated one, essentially giving free estimates. Without pmap_axis_name, if using pmap, the exact same sample would be carried out on each device because the same rng_key would be replicated on every device. This could be viewed as wasteful, because the only thing each device needs in order to perform its own unique sample is a unique rng_key. The params or random vectors do not have to be duplicated in memory on each device to do this, only the rng_keys.

For example, if we are using pmap with 4 devices and provide pmap_axis_name="batch", 4 unique rng_keys are made in init and all 4 are held on every device. When it comes time to update the hessian diagonal, each device slices its own key out of the 4 rng_keys using these lines:

        idx = jax.lax.axis_index(pmap_axis_name)
        key = jax.lax.dynamic_index_in_dim(key, idx, keepdims=False)

Then the Hutchinson sample is carried out with the device's single unique rng_key, the samples are averaged using pmean, and the rng_keys are all_gathered back together. The costs are: holding 4 rng_keys on each device instead of 1, the pmean for averaging the hessian diagonal estimates across devices, and the all_gather for the 4 rng_keys.

It shouldn't interfere with jit at all: if one isn't using pmap, the user can leave the argument as None, it will not be used, and the optimizer behaves as normal. I think it's a very useful feature because, when using pmap, performing the exact same Monte Carlo sample on each device could be seen as wasteful when the only thing needed for the samples to be unique is a unique rng_key per device for the random vector creation. I originally got this idea from distributed shampoo, where in pmap mode computations are split across devices using jax.lax.axis_index, slicing, and all_gathering. I'll document the feature better so users can understand what's going on and make it easier to use, and I can add a test for it using xla_force_host_platform_device_count.
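
For reference, here is a minimal sketch of the scheme described above. It is illustrative only, not the PR's code: obj_fn, the helper names, and the axis name "batch" are assumptions. Each device slices out its own key, computes one Hutchinson estimate of the Hessian diagonal via an hvp, and the estimates are averaged with pmean.

    import jax

    AXIS = "batch"  # assumed pmap axis name, i.e. pmap_axis_name="batch"

    def hutchinson_diag(key, params, obj_fn):
      # One Hutchinson estimate of diag(H): v * (H v) with v ~ Rademacher.
      leaves, treedef = jax.tree_util.tree_flatten(params)
      keys = jax.random.split(key, len(leaves))
      v = jax.tree_util.tree_unflatten(
          treedef,
          [jax.random.rademacher(k, leaf.shape).astype(leaf.dtype)
           for k, leaf in zip(keys, leaves)])
      # Forward-over-reverse Hessian-vector product of obj_fn at params along v.
      _, hv = jax.jvp(jax.grad(obj_fn), (params,), (v,))
      return jax.tree_util.tree_map(lambda h, vi: h * vi, hv, v)

    def per_device_hessian_diag(keys, params, obj_fn):
      # Runs inside pmap(..., axis_name=AXIS); `keys` holds one key per device
      # and is replicated on every device.
      idx = jax.lax.axis_index(AXIS)
      key = jax.lax.dynamic_index_in_dim(keys, idx, keepdims=False)
      est = hutchinson_diag(key, params, obj_fn)
      # Average the independent per-device estimates.
      return jax.lax.pmean(est, axis_name=AXIS)

With 4 devices this averages 4 independent estimates while each device only ever materializes one probe vector, which matches the costs listed above.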

@vroulet (Collaborator)

Ok, thanks for the explanations!
We would mostly need one example to show how to use it.
I would still believe that at some point the $m$ samples need to be created in memory before being fed to the hvp, so a priori we would incur a memory cost of $m \cdot d$ (for $m$ the number of devices and $d$ the size of the parameters). I don't see how this can be avoided. In terms of computation, it should indeed be no more costly than not using pmap, so as long as the memory allows it, it's a great feature. I would just mention the memory cost (if you agree with me).

@evanatyourservice (Author)

OK, actually, for the sake of simplicity I think I will just remove this feature for now, haha. But what would be the most concise way to average the hessian diagonals across devices or a sharded data axis when using either pmap or jit with sharding, akin to averaging gradients across devices? Would an argument asking for a named axis suffice?
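
As a rough illustration of the named-axis question (hypothetical, not part of the PR): with an optional axis_name argument, the averaging could mirror how gradients are usually averaged over a data-parallel axis.

    import jax

    def maybe_average_hessian_diag(hess_diag, axis_name=None):
      if axis_name is None:
        # No parallel axis bound: single-device behavior, nothing to do.
        return hess_diag
      # Works under pmap(..., axis_name=axis_name), and under shard_map or
      # jit-based sharding wherever that named axis is in scope.
      return jax.lax.pmean(hess_diag, axis_name=axis_name)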

@vroulet (Collaborator)

I think the Hutchinson estimators (trace or diag) should be part of another functionality in optax.
As an example, matfree has a rather neat API that we could either use directly or take inspiration from.
Concretely, I'm wondering if it would be possible for sophia_h to take as an argument a function hutch_hess_diag_fn obtained from something like hutch_hess_diag_fn = hutch_hess_diag(fun, **extra_args), with a signature hess_diag = hutch_hess_diag_fn(params) (semantics akin to jax.grad). I'd suppose that such a function also needs an inner state (to store the key), so maybe more something like hess_diag, hutch_state = hutch_hess_diag_fn(params, hutch_state).
That way we can postpone the implementation of a clever Hutchinson estimator for now, and such an estimator could be used for other purposes. The extra_args above could incorporate the pmap axis, etc. Also, that way it would be easier to write unit tests for the pmap/shard_map behaviors.
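
A sketch of what that contract could look like; hutch_hess_diag and HutchState are hypothetical names used only to illustrate the hess_diag, hutch_state = hutch_hess_diag_fn(params, hutch_state) signature, not an existing optax or matfree API.

    from typing import NamedTuple

    import jax

    class HutchState(NamedTuple):
      key: jax.Array  # PRNG key threaded between calls

    def hutch_hess_diag(fun, num_samples=1):
      # Returns a stateful Hutchinson estimator of diag(Hessian(fun)).

      def estimate(params, state):
        next_key, sample_key = jax.random.split(state.key)
        leaves, treedef = jax.tree_util.tree_flatten(params)

        def one_sample(k):
          keys = jax.random.split(k, len(leaves))
          v = jax.tree_util.tree_unflatten(
              treedef,
              [jax.random.rademacher(ki, leaf.shape).astype(leaf.dtype)
               for ki, leaf in zip(keys, leaves)])
          _, hv = jax.jvp(jax.grad(fun), (params,), (v,))
          return jax.tree_util.tree_map(lambda h, vi: h * vi, hv, v)

        samples = [one_sample(k)
                   for k in jax.random.split(sample_key, num_samples)]
        hess_diag = jax.tree_util.tree_map(
            lambda *xs: sum(xs) / num_samples, *samples)
        return hess_diag, HutchState(key=next_key)

      return estimate

The optimizer would then only see the hess_diag, hutch_state = fn(params, hutch_state) contract, so a pmap- or shard_map-aware estimator could be swapped in later without touching the update rule.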

@evanatyourservice (Author)

Excellent, yeah, that sounds like a great idea for the long run, and matfree's API does look really nice. I'll add an argument for a custom hvp fn and include a default one in the same file using the hess_diag, hutch_state = hutch_hess_diag_fn(params, hutch_state) format, just so people can get started using the optimizer and have a good example of how to write their own fn if they want. Later, optax could add hvp fns to its main API if desired. I'll also rename the optimizer to sophia instead of sophia-h, so a user can write their own fn for sophia-g or any other kind of preconditioner they'd like; it won't be hardcoded to sophia-h.

@vroulet (Collaborator)

Sounds perfect! Thanks for all the work @evanatyourservice!

@evanatyourservice (Author)

I think I've got a good enough setup with this, at least for contrib, but let me know your thoughts! I'm not sure whether I've documented it well enough either.

optax/contrib/_sophia_h.py: four review threads (outdated, resolved)
@evanatyourservice (Author)

Hi Vincent, thank you for the notes! They all make perfect sense to me, and I'll get to updating the code and answering them tomorrow.

@fabianp (Member)

fabianp commented Jun 27, 2024

@evanatyourservice please ping us whenever you're ready for another round of reviews :-)

@evanatyourservice (Author)

@fabianp Will do! Sorry, I've been moving, but I'll try to get this going ASAP.

@fabianp (Member)

fabianp commented Jun 28, 2024

there's no rush, just wanted to make sure you were not waiting on us :-)

@evanatyourservice (Author)

@vroulet @fabianp Got some updates pushed, let me know if anything needs to be changed! Thanks
