Skip to content

Commit

Permalink
separate out hessian diagonal fn, make customizable
Browse files Browse the repository at this point in the history
  • Loading branch information
evanatyourservice committed Jun 28, 2024
1 parent 33f4525 commit 9d73e9f
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 104 deletions.
10 changes: 5 additions & 5 deletions docs/api/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ Experimental features and algorithms that don't meet the
schedule_free
schedule_free_eval_params
ScheduleFreeState
sophia_h
SophiaHState
sophia
SophiaState
split_real_and_imaginary
SplitRealAndImaginaryState

Expand Down Expand Up @@ -98,7 +98,7 @@ Sharpness aware minimization
.. autofunction:: sam
.. autoclass:: SAMState

Sophia-H
Sophia
~~~~~~~~
.. autofunction:: sophia_h
.. autoclass:: SophiaHState
.. autofunction:: sophia
.. autoclass:: SophiaState
4 changes: 2 additions & 2 deletions optax/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@
from optax.contrib._schedule_free import schedule_free
from optax.contrib._schedule_free import schedule_free_eval_params
from optax.contrib._schedule_free import ScheduleFreeState
from optax.contrib._sophia_h import sophia_h
from optax.contrib._sophia_h import SophiaHState
from optax.contrib._sophia import sophia
from optax.contrib._sophia import SophiaState
6 changes: 3 additions & 3 deletions optax/contrib/_common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='sophia_h', opt_kwargs=dict(learning_rate=1e-2)),
dict(opt_name='sophia', opt_kwargs=dict(learning_rate=1e-2)),
)


Expand Down Expand Up @@ -92,7 +92,7 @@ def step(params, state):
value, updates = get_updates(params)
if opt_name in ['momo', 'momo_adam']:
update_kwargs = {'value': value}
elif opt_name == 'sophia_h':
elif opt_name == 'sophia':
update_kwargs = {'obj_fn': loss_fn}
else:
update_kwargs = {}
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
opt_inject_update = opt_inject.update
if opt_name in ['momo', 'momo_adam']:
update_kwargs = {'value': jnp.array(1.0)}
elif opt_name == 'sophia_h':
elif opt_name == 'sophia':
temp_update_kwargs = {
'obj_fn': lambda ps: sum(jnp.sum(p) for p in jax.tree.leaves(ps))
}
Expand Down
Loading

0 comments on commit 9d73e9f

Please sign in to comment.