Add sophia-h optimizer #979

Open · wants to merge 8 commits into main

Changes from 4 commits
10 changes: 7 additions & 3 deletions docs/api/contrib.rst
@@ -32,11 +32,10 @@ Experimental features and algorithms that don't meet the
     schedule_free
     schedule_free_eval_params
     ScheduleFreeState
+    sophia_h
+    SophiaHState
     split_real_and_imaginary
     SplitRealAndImaginaryState
-    sophia
-    scale_by_sophia
-    SophiaState
 
 Asynchronous-centering-Prop
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -98,3 +97,8 @@ Sharpness aware minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: sam
 .. autoclass:: SAMState
+
+Sophia-H
+~~~~~~~~
+.. autofunction:: sophia_h
+.. autoclass:: SophiaHState
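
(Reviewer aside: taken together with the tests below, the documented API suggests usage along these lines. This is a sketch under assumptions, not code from the PR: only `learning_rate=1e-2` and the `obj_fn` keyword to `update` appear in this diff; the toy loss, step count, and update rule application are illustrative.)

```python
# Usage sketch only: `learning_rate=1e-2` and the `obj_fn` update kwarg are
# taken from this PR's tests; the toy loss and training loop are made up.
import jax
import jax.numpy as jnp
import optax
from optax import contrib

def loss_fn(params):
  return jnp.sum((params - 1.0) ** 2)

opt = contrib.sophia_h(learning_rate=1e-2)
params = jnp.array([-1.0, 10.0, 1.0])
state = opt.init(params)

for _ in range(100):
  grads = jax.grad(loss_fn)(params)
  # Unlike most optax transforms, sophia_h's update also takes the objective
  # itself so the optimizer can estimate curvature, hence the extra kwarg.
  updates, state = opt.update(grads, state, params, obj_fn=loss_fn)
  params = optax.apply_updates(params, updates)
```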
2 changes: 2 additions & 0 deletions optax/contrib/__init__.py
@@ -49,3 +49,5 @@
 from optax.contrib._schedule_free import schedule_free
 from optax.contrib._schedule_free import schedule_free_eval_params
 from optax.contrib._schedule_free import ScheduleFreeState
+from optax.contrib._sophia_h import sophia_h
+from optax.contrib._sophia_h import SophiaHState
33 changes: 24 additions & 9 deletions optax/contrib/_common_test.py
@@ -21,6 +21,7 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 import chex
+from functools import partial
 import jax
 import jax.numpy as jnp
 from optax import contrib
@@ -40,6 +41,7 @@
     dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)),
     dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)),
     dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)),
+    dict(opt_name='sophia_h', opt_kwargs=dict(learning_rate=1e-2)),
 )
 
 
@@ -48,11 +50,12 @@ def _setup_parabola(dtype):
   initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
   final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)
 
-  @jax.value_and_grad
-  def get_updates(params):
+  def loss_fn(params):
     return jnp.sum(numerics.abs_sq(params - final_params))
 
-  return initial_params, final_params, get_updates
+  get_updates = jax.value_and_grad(loss_fn)
+
+  return initial_params, final_params, get_updates, loss_fn
 
 
 def _setup_rosenbrock(dtype):
@@ -63,13 +66,14 @@ def _setup_rosenbrock(dtype):
   initial_params = jnp.array([0.0, 0.0], dtype=dtype)
   final_params = jnp.array([a, a**2], dtype=dtype)
 
-  @jax.value_and_grad
-  def get_updates(params):
+  def loss_fn(params):
     return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
         params[1] - params[0] ** 2
     )
 
-  return initial_params, final_params, get_updates
+  get_updates = jax.value_and_grad(loss_fn)
+
+  return initial_params, final_params, get_updates, loss_fn
 
 
 class ContribTest(chex.TestCase):
@@ -81,13 +85,15 @@ class ContribTest(chex.TestCase):
   )
   def test_optimizers(self, opt_name, opt_kwargs, target, dtype):
     opt = getattr(contrib, opt_name)(**opt_kwargs)
-    initial_params, final_params, get_updates = target(dtype)
+    initial_params, final_params, get_updates, loss_fn = target(dtype)
 
     @jax.jit
     def step(params, state):
       value, updates = get_updates(params)
       if opt_name in ['momo', 'momo_adam']:
         update_kwargs = {'value': value}
+      elif opt_name == 'sophia_h':
+        update_kwargs = {'obj_fn': loss_fn}
       else:
         update_kwargs = {}
       updates, state = opt.update(updates, state, params, **update_kwargs)
@@ -120,18 +126,27 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
     params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))]
     grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))]
 
+    opt_update = opt.update
+    opt_inject_update = opt_inject.update
     if opt_name in ['momo', 'momo_adam']:
       update_kwargs = {'value': jnp.array(1.0)}
+    elif opt_name == 'sophia_h':
+      temp_update_kwargs = {
+          'obj_fn': lambda ps: sum(jnp.sum(p) for p in jax.tree.leaves(ps))
+      }
+      opt_update = partial(opt.update, **temp_update_kwargs)
+      opt_inject_update = partial(opt_inject.update, **temp_update_kwargs)
+      update_kwargs = {}
     else:
       update_kwargs = {}
 
     state = self.variant(opt.init)(params)
-    updates, new_state = self.variant(opt.update)(
+    updates, new_state = self.variant(opt_update)(
         grads, state, params, **update_kwargs
     )
 
     state_inject = self.variant(opt_inject.init)(params)
-    updates_inject, new_state_inject = self.variant(opt_inject.update)(
+    updates_inject, new_state_inject = self.variant(opt_inject_update)(
         grads, state_inject, params, **update_kwargs
     )
 
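(Reviewer aside: the reason `update` needs `obj_fn` at all, unlike most optax transforms, is that Sophia-H estimates the diagonal of the Hessian with Hutchinson's estimator, which requires Hessian-vector products of the objective rather than just its gradients. A minimal standalone sketch of that estimator, following the Sophia paper (Liu et al., 2023) and not necessarily this PR's exact implementation:)

```python
# Sketch of Hutchinson's diagonal-Hessian estimator as used by Sophia-H
# (Liu et al., 2023). Follows the paper, not necessarily this PR's code.
import jax
import jax.numpy as jnp

def hutchinson_diag_hessian(obj_fn, params, key):
  # Rademacher probe vector z with entries in {-1, +1}.
  z = 2.0 * jax.random.bernoulli(key, 0.5, params.shape).astype(params.dtype) - 1.0
  # Hessian-vector product via forward-over-reverse differentiation:
  # jvp of the gradient function gives (grad, H @ z) without forming H.
  _, hvp = jax.jvp(jax.grad(obj_fn), (params,), (z,))
  # E[z * (H z)] equals diag(H); one sample gives an unbiased estimate.
  return z * hvp

key = jax.random.PRNGKey(0)
params = jnp.array([-1.0, 10.0, 1.0])
loss = lambda p: jnp.sum((p - 1.0) ** 2)  # Hessian is 2*I
print(hutchinson_diag_hessian(loss, params, key))  # ≈ [2., 2., 2.]
```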