Add schedule possibility to cocob
PiperOrigin-RevId: 627766391
vroulet authored and OptaxDev committed Apr 24, 2024
1 parent 3301688 commit d9c2bbe
Showing 2 changed files with 42 additions and 11 deletions.
3 changes: 3 additions & 0 deletions optax/contrib/__init__.py
@@ -14,8 +14,11 @@
# ==============================================================================
"""Contributed optimizers in Optax."""

# pylint: disable=g-importing-member

from optax.contrib._cocob import cocob
from optax.contrib._cocob import COCOBState
from optax.contrib._cocob import scale_by_cocob
from optax.contrib._complex_valued import split_real_and_imaginary
from optax.contrib._complex_valued import SplitRealAndImaginaryState
from optax.contrib._dadapt_adamw import dadapt_adamw
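With `scale_by_cocob` now exported from `optax.contrib`, the coin-betting transformation can also be composed by hand with other Optax transforms. A minimal sketch under assumed settings (the clipping threshold and weight-decay value are illustrative, not part of this commit):

import optax
from optax.contrib import scale_by_cocob

# Hand-rolled chain: clip gradients, apply weight decay, then the COCOB step,
# mirroring the order used by the cocob() alias shown further below.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.add_decayed_weights(1e-4),
    scale_by_cocob(alpha=100, eps=1e-8),
)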
50 changes: 39 additions & 11 deletions optax/contrib/_cocob.py
@@ -18,11 +18,13 @@
Networks without Learning Rates Through Coin Betting" by Francesco Orabona and
Tatiana Tommasi.
"""
-from typing import NamedTuple
+from typing import Any, Callable, NamedTuple, Optional, Union

import jax.numpy as jnp
import jax.tree_util as jtu
from optax._src import base
+from optax._src import combine
+from optax._src import transform


class COCOBState(NamedTuple):
@@ -35,10 +37,8 @@ class COCOBState(NamedTuple):
  reward: base.Updates


-def cocob(
-    alpha: float = 100,
-    eps: float = 1e-8,
-    weight_decay: float = 0
+def scale_by_cocob(
+    alpha: float = 100, eps: float = 1e-8
) -> base.GradientTransformation:
  """Rescale updates according to the COntinuous COin Betting algorithm.
@@ -47,12 +47,11 @@ def cocob(
  subgradients. All we need is a good gambling strategy. See Algorithm 2 of:

  References:
    [Orabona & Tommasi, 2017](https://arxiv.org/pdf/1705.07795.pdf)

  Args:
    alpha: fraction to bet parameter of the COCOB optimizer
    eps: jitter term to avoid dividing by 0
-    weight_decay: L2 penalty

  Returns:
    A `GradientTransformation` object.
@@ -72,10 +71,6 @@ def init_fn(params):
  def update_fn(updates, state, params):
    init_particles, cumulative_grads, scale, subgradients, reward = state

-    updates = jtu.tree_map(
-        lambda c, p: c + weight_decay * p, updates, params,
-    )

    scale = jtu.tree_map(
        lambda L, c: jnp.maximum(L, jnp.abs(c)), scale, updates
    )
@@ -115,3 +110,36 @@ def update_fn(updates, state, params):
    return new_updates, new_state

  return base.GradientTransformation(init_fn, update_fn)


def cocob(
    learning_rate: base.ScalarOrSchedule = 1.,
    alpha: float = 100,
    eps: float = 1e-8,
    weight_decay: float = 0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
  """Rescale updates according to the COntinuous COin Betting algorithm.

  Algorithm for stochastic subgradient descent. Uses a gambling algorithm to
  find the minimizer of a non-smooth objective function by accessing its
  subgradients. All we need is a good gambling strategy. See Algorithm 2 of:

  References:
    [Orabona & Tommasi, 2017](https://arxiv.org/pdf/1705.07795.pdf)

  Args:
    learning_rate: optional learning rate, e.g. to inject a schedule
    alpha: fraction to bet parameter of the COCOB optimizer
    eps: jitter term to avoid dividing by 0
    weight_decay: L2 penalty
    mask: optional mask selecting the parameters the weight decay applies to

  Returns:
    A `GradientTransformation` object.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate, flip_sign=False),
      scale_by_cocob(alpha, eps),
  )
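The new `learning_rate` argument is what makes a schedule possible: it accepts either a constant or an Optax schedule, which is applied via `scale_by_learning_rate(..., flip_sign=False)` before the coin-betting step in the chain above. A usage sketch with assumed names (the toy parameters, loss function, and schedule settings are illustrative, not part of this commit):

import jax
import jax.numpy as jnp
import optax
from optax import contrib

# Inject a cosine-decay schedule through the new learning_rate argument.
schedule = optax.cosine_decay_schedule(init_value=1.0, decay_steps=1_000)
optimizer = contrib.cocob(learning_rate=schedule, weight_decay=1e-4)

params = {'w': jnp.ones(3)}  # toy parameters
opt_state = optimizer.init(params)

def loss_fn(p):
  return jnp.sum(p['w'] ** 2)  # toy quadratic loss

grads = jax.grad(loss_fn)(params)
# params must be passed to update(): the chained add_decayed_weights needs them.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)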
