D-Adaptation and Prodigy implementation
adefazio committed Dec 7, 2023
1 parent bf987e1 commit 8cb18fa
Showing 6 changed files with 438 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
@@ -776,6 +776,10 @@ scale_gradient
SAMState
cocob
COCOBState
  dadapt_adamw
  DAdaptAdamWState
  prodigy
  ProdigyState


Complex-Valued Optimization
2 changes: 2 additions & 0 deletions optax/contrib/__init__.py
@@ -27,3 +27,5 @@
from optax.contrib.sam import NormalizeState
from optax.contrib.sam import sam
from optax.contrib.sam import SAMState
from optax.contrib.dadapt_adamw import dadapt_adamw
from optax.contrib.prodigy import prodigy
123 changes: 123 additions & 0 deletions optax/contrib/dadapt_adamw.py
@@ -0,0 +1,123 @@

# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""D-Adatation (AdamW variant).
A contributed implementation of the method from "Learning-Rate-Free Learning
by D-Adaptation" (https://arxiv.org/abs/2301.07733) by Aaron Defazio and
Konstantin Mishchenko (ICML 2023 Outstanding Paper award).
"""
import jax.numpy as jnp
from jax.tree_util import tree_map

from typing import NamedTuple, Optional, Tuple
from optax import tree_utils
from optax._src import base
from optax._src import utils

class DAdaptAdamWState(NamedTuple):
  exp_avg: base.Updates
  exp_avg_sq: base.Updates
  grad_sum: base.Updates  # exponential moving average of the sum of gradients
  estim_lr: float  # Distance to solution estimate
  numerator_weighted: float
  count: float

def dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 estim_lr0=1e-6,
                 weight_decay=0) -> base.GradientTransformation:
  """Learning rate free AdamW by D-Adaptation.

  Adapts the baseline learning rate of AdamW automatically by estimating the
  initial distance to solution in the infinity norm.
  This method works best when combined with a learning rate schedule that
  treats 1.0 as the base (usually max) value.

  References:
    [Defazio & Mishchenko, 2023](https://arxiv.org/abs/2301.07733)

  Args:
    learning_rate: Learning rate scheduling parameter. The recommended schedule
      is a linear_schedule with init_value=1.0 and end_value=0, combined with a
      0-20% learning rate warmup.
    betas: Betas for the underlying AdamW optimizer.
    eps: eps for the underlying AdamW optimizer.
    estim_lr0: Initial (under-)estimate of the learning rate.
    weight_decay: AdamW-style weight decay. For regular Adam-style decay,
      chain with add_decayed_weights instead.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params: base.Params) -> DAdaptAdamWState:
    exp_avg = tree_map(lambda p: jnp.zeros(p.shape), params)
    exp_avg_sq = tree_map(lambda p: jnp.zeros(p.shape), params)
    grad_sum = tree_map(lambda p: jnp.zeros(p.shape), params)
    estim_lr = estim_lr0
    numerator_weighted = 0
    count = 0
    return DAdaptAdamWState(exp_avg, exp_avg_sq,
                            grad_sum, estim_lr, numerator_weighted, count)

  def update_fn(
      updates: base.Updates,
      state: DAdaptAdamWState,
      params: Optional[base.Params] = None,
  ) -> Tuple[base.Updates, DAdaptAdamWState]:
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    count = state.count
    beta1, beta2 = betas
    sb2 = beta2**0.5  # sqrt(beta2), used to average the weighted gradient sum.
    sched = learning_rate(count) if callable(learning_rate) else learning_rate

    grad_sum = state.grad_sum
    numerator_weighted = state.numerator_weighted

    # Adam bias correction and the D-adapted step size for this iteration.
    bc = ((1-beta2**(count+1))**0.5)/(1-beta1**(count+1))
    dlr = state.estim_lr * sched * bc

    s_weighted = tree_map(lambda sk, eas: sk/(jnp.sqrt(eas)+eps),
                          grad_sum, state.exp_avg_sq)
    numerator_acum = tree_utils.tree_vdot(updates, s_weighted)

    exp_avg = tree_map(lambda ea, g: beta1*ea + (1-beta1)*dlr*g,
                       state.exp_avg, updates)
    exp_avg_sq = tree_map(lambda eas, g: beta2*eas + (1-beta2)*g*g,
                          state.exp_avg_sq, updates)

    grad_sum = tree_map(lambda sk, g: sb2*sk + (1-sb2)*dlr*g, grad_sum, updates)

    grad_sum_l1 = tree_utils.tree_sum(tree_map(jnp.abs, grad_sum))

    numerator_weighted = sb2*numerator_weighted + (1-sb2)*dlr*numerator_acum

    # New distance-to-solution estimate; it is kept non-decreasing.
    d_estimate = numerator_weighted/((1-sb2)*grad_sum_l1)
    estim_lr = jnp.maximum(state.estim_lr, d_estimate)

    p_update = tree_map(lambda ea, eas, p:
                        -weight_decay*dlr*p
                        - ea/(jnp.sqrt(eas)+eps),
                        exp_avg, exp_avg_sq, params)

    new_state = DAdaptAdamWState(exp_avg, exp_avg_sq, grad_sum,
                                 estim_lr, numerator_weighted,
                                 utils.safe_int32_increment(count))
    return p_update, new_state

  return base.GradientTransformation(init_fn, update_fn)
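A minimal usage sketch (not part of this commit) of how dadapt_adamw might be paired with the schedule recommended in its docstring: a short warmup to 1.0 followed by a linear decay to 0. The step counts, parameter shape, and toy quadratic objective below are illustrative assumptions.

import jax
import jax.numpy as jnp
import optax
from optax import contrib

total_steps = 1_000
warmup_steps = 100  # ~10% warmup, within the suggested 0-20% range.
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, 1.0, transition_steps=warmup_steps),
        optax.linear_schedule(1.0, 0.0, transition_steps=total_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)
opt = contrib.dadapt_adamw(learning_rate=schedule)

params = jnp.array([1.0, 2.0, 3.0])
state = opt.init(params)

def loss(p):
  return jnp.sum(p ** 2)  # Toy objective for illustration.

for _ in range(total_steps):
  grads = jax.grad(loss)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)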
86 changes: 86 additions & 0 deletions optax/contrib/dadapt_adamw_test.py
@@ -0,0 +1,86 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `dadapt_adamw.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax import contrib
from optax._src import numerics
from optax._src import update
from optax.tree_utils import _state_utils


def _setup_parabola(dtype):
  """Quadratic function as an optimization target."""
  initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
  final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

  @jax.grad
  def get_updates(params):
    return jnp.sum(numerics.abs_sq(params - final_params))

  return initial_params, final_params, get_updates


def _setup_rosenbrock(dtype):
  """Rosenbrock function as an optimization target."""
  a = 1.0
  b = 100.0

  initial_params = jnp.array([0.0, 0.0], dtype=dtype)
  final_params = jnp.array([a, a**2], dtype=dtype)

  @jax.grad
  def get_updates(params):
    return (numerics.abs_sq(a - params[0]) +
            b * numerics.abs_sq(params[1] - params[0]**2))

  return initial_params, final_params, get_updates


class DAdaptAdamWTest(chex.TestCase):

  @parameterized.product(
      opt_name=('dadapt_adamw',),
      target=(_setup_parabola, _setup_rosenbrock),
      dtype=(jnp.float32,),
  )
  def test_optimization(self, opt_name, target, dtype):

    opt = getattr(contrib, opt_name)()
    initial_params, final_params, get_updates = target(dtype)

    @jax.jit
    def step(params, state):
      updates = get_updates(params)
      updates, state = opt.update(updates, state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    # A no-op change, to verify that tree map works.
    state = _state_utils.tree_map_params(opt, lambda v: v, state)

    for _ in range(15000):
      params, state = step(params, state)

    chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)


if __name__ == '__main__':
  absltest.main()
137 changes: 137 additions & 0 deletions optax/contrib/prodigy.py
@@ -0,0 +1,137 @@

# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prodigy Optimizer.
A contributed implementation of the method from "Prodigy: An Expeditiously
Adaptive Parameter-Free Learner" (https://arxiv.org/abs/2306.06101) by
Konstantin Mishchenko and Aaron Defazio. A new variant of D-Adapt Adam that
adapts the learning rate faster.
"""
import jax.numpy as jnp
from jax.tree_util import tree_map

from typing import NamedTuple, Optional, Tuple
from optax import tree_utils
from optax._src import base
from optax._src import utils

class ProdigyState(NamedTuple):
  exp_avg: base.Updates
  exp_avg_sq: base.Updates
  grad_sum: base.Updates  # exponential moving average of the sum of gradients
  params0: base.Updates  # Initial point
  estim_lr: float  # Distance to solution estimate
  numerator_weighted: float
  count: float

def prodigy(learning_rate: base.ScalarOrSchedule = 0.1,
            betas=(0.9, 0.999),
            beta3=None,
            eps=1e-8,
            estim_lr0=1e-6,
            estim_lr_coef=1.0,
            weight_decay=0) -> base.GradientTransformation:
  """Learning rate free AdamW with Prodigy.

  Implementation of the Prodigy method from "Prodigy: An Expeditiously
  Adaptive Parameter-Free Learner", a version of D-Adapt AdamW that adapts the
  baseline learning rate faster by using a weighting of the gradients that
  places higher weights on more recent gradients.
  This method works best when combined with a learning rate schedule that
  treats 1.0 as the base (usually max) value.

  References:
    [Mishchenko & Defazio, 2023](https://arxiv.org/abs/2306.06101)

  Args:
    learning_rate: Learning rate scheduling parameter. The recommended schedule
      is a linear_schedule with init_value=1.0 and end_value=0, combined with a
      0-20% learning rate warmup.
    betas: Betas for the underlying AdamW optimizer.
    beta3: Optional momentum parameter for the estimation of D.
    eps: eps for the underlying AdamW optimizer.
    estim_lr0: Initial (under-)estimate of the learning rate.
    estim_lr_coef: Learning rate estimates are multiplied by this parameter.
    weight_decay: AdamW-style weight decay. For regular Adam-style decay,
      chain with add_decayed_weights instead.

  Returns:
    A `GradientTransformation` object.
  """
  beta1, beta2 = betas
  if beta3 is None:
    beta3 = beta2**0.5

  def init_fn(params: base.Params) -> ProdigyState:
    exp_avg = tree_map(lambda p: jnp.zeros(p.shape), params)
    exp_avg_sq = tree_map(lambda p: jnp.zeros(p.shape), params)
    grad_sum = tree_map(lambda p: jnp.zeros(p.shape), params)
    params0 = params
    estim_lr = estim_lr0
    numerator_weighted = 0
    count = 0
    return ProdigyState(exp_avg, exp_avg_sq,
                        grad_sum, params0, estim_lr,
                        numerator_weighted, count)

  def update_fn(updates: base.Updates,
                state: ProdigyState,
                params: Optional[base.Params] = None,
                ) -> Tuple[base.Updates, ProdigyState]:
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    count = state.count
    sched = learning_rate(count) if callable(learning_rate) else learning_rate
    grad_sum = state.grad_sum
    params0 = state.params0
    estim_lr = state.estim_lr
    numerator_weighted = state.numerator_weighted

    # Adam bias correction and the D-adapted step size for this iteration.
    bc = ((1-beta2**(count+1))**0.5)/(1-beta1**(count+1))
    dlr = estim_lr * sched * bc
    dg = tree_map(lambda g: estim_lr * g, updates)

    param_diff = tree_map(lambda p0, p: p0-p, params0, params)
    numerator_acum = tree_utils.tree_vdot(updates, param_diff)

    exp_avg = tree_map(lambda ea, dgk: beta1*ea + (1-beta1)*dgk,
                       state.exp_avg, dg)
    exp_avg_sq = tree_map(lambda eas, dgk: beta2*eas + (1-beta2)*dgk*dgk,
                          state.exp_avg_sq, dg)

    grad_sum = tree_map(lambda sk, dgk: beta3*sk + dlr*dgk/estim_lr0,
                        grad_sum, dg)

    numerator_weighted = beta3*numerator_weighted
    numerator_weighted += (estim_lr/estim_lr0)*dlr*numerator_acum

    denominator = tree_utils.tree_sum(tree_map(jnp.abs, grad_sum))

    # New learning rate estimate; it is kept non-decreasing.
    lr_estimate = estim_lr_coef*numerator_weighted/denominator
    estim_lr = jnp.maximum(state.estim_lr, lr_estimate)

    p_update = tree_map(lambda ea, eas, p:
                        -weight_decay*dlr*p
                        - dlr*ea/(jnp.sqrt(eas) + estim_lr*eps),
                        exp_avg, exp_avg_sq, params)

    new_state = ProdigyState(exp_avg, exp_avg_sq, grad_sum, params0,
                             estim_lr, numerator_weighted,
                             utils.safe_int32_increment(count))
    return p_update, new_state

  return base.GradientTransformation(init_fn, update_fn)
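A similarly hedged sketch (not part of this commit) of prodigy in a plain training loop; the adapted step-size estimate is carried in the state as estim_lr and can be read out for monitoring. The constant learning_rate, toy loss, and step count are illustrative assumptions.

import jax
import jax.numpy as jnp
import optax
from optax import contrib

opt = contrib.prodigy(learning_rate=1.0)  # Or a schedule that peaks at 1.0, per the docstring.
params = jnp.zeros(10)
state = opt.init(params)

def loss(p):
  return jnp.sum((p - 1.0) ** 2)  # Toy objective for illustration.

for step in range(200):
  grads = jax.grad(loss)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)
  if step % 50 == 0:
    # The current adapted learning-rate estimate maintained by Prodigy.
    print(step, float(state.estim_lr))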
