D-Adaptation and Prodigy implementation
Showing 6 changed files with 438 additions and 0 deletions.
@@ -0,0 +1,123 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""D-Adaptation (AdamW variant).

A contributed implementation of the method from "Learning-Rate-Free Learning
by D-Adaptation" (https://arxiv.org/abs/2301.07733) by Aaron Defazio and
Konstantin Mishchenko (ICML 2023 Outstanding Paper award).
"""
import jax.numpy as jnp
from jax.tree_util import tree_map

from typing import NamedTuple, Optional, Tuple

from optax import tree_utils
from optax._src import base
from optax._src import utils


class DAdaptAdamWState(NamedTuple):
  """State for the D-Adaptation AdamW optimizer."""
  exp_avg: base.Updates
  exp_avg_sq: base.Updates
  grad_sum: base.Updates  # Exponential moving average of the sum of gradients.
  estim_lr: float  # Distance-to-solution estimate.
  numerator_weighted: float
  count: float


def dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 estim_lr0=1e-6,
                 weight_decay=0) -> base.GradientTransformation:
  """Learning-rate-free AdamW by D-Adaptation.

  Adapts the baseline learning rate of AdamW automatically by estimating the
  initial distance to solution in the infinity norm.

  This method works best when combined with a learning rate schedule that
  treats 1.0 as the base (usually max) value.

  References:
    [Defazio & Mishchenko, 2023](https://arxiv.org/abs/2301.07733)

  Args:
    learning_rate: Learning rate scheduling parameter. The recommended schedule
      is a linear_schedule with init_value=1.0 and end_value=0, combined with a
      0-20% learning rate warmup.
    betas: Betas for the underlying AdamW optimizer.
    eps: eps for the underlying AdamW optimizer.
    estim_lr0: Initial (under-)estimate of the learning rate.
    weight_decay: AdamW-style weight decay. To use regular Adam weight decay,
      chain with add_decayed_weights.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params: base.Params) -> DAdaptAdamWState:
    exp_avg = tree_map(lambda p: jnp.zeros(p.shape), params)
    exp_avg_sq = tree_map(lambda p: jnp.zeros(p.shape), params)
    grad_sum = tree_map(lambda p: jnp.zeros(p.shape), params)
    estim_lr = estim_lr0
    numerator_weighted = 0
    count = 0
    return DAdaptAdamWState(exp_avg, exp_avg_sq, grad_sum, estim_lr,
                            numerator_weighted, count)

  def update_fn(
      updates: base.Updates,
      state: DAdaptAdamWState,
      params: Optional[base.Params] = None,
  ) -> Tuple[base.Updates, DAdaptAdamWState]:
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    count = state.count
    beta1, beta2 = betas
    sb2 = beta2**(0.5)
    sched = learning_rate(count) if callable(learning_rate) else learning_rate

    grad_sum = state.grad_sum
    numerator_weighted = state.numerator_weighted

    # Effective step size: D estimate x schedule x Adam bias correction.
    bc = ((1-beta2**(count+1))**0.5)/(1-beta1**(count+1))
    dlr = state.estim_lr * sched * bc

    # Inner product between the current gradient and the preconditioned
    # gradient sum; used as the numerator of the D estimate.
    s_weighted = tree_map(lambda sk, eas: sk/(jnp.sqrt(eas)+eps),
                          grad_sum, state.exp_avg_sq)
    numerator_acum = tree_utils.tree_vdot(updates, s_weighted)

    exp_avg = tree_map(lambda ea, g: beta1*ea + (1-beta1)*dlr*g,
                       state.exp_avg, updates)
    exp_avg_sq = tree_map(lambda eas, g: beta2*eas + (1-beta2)*g*g,
                          state.exp_avg_sq, updates)

    grad_sum = tree_map(lambda sk, g: sb2*sk + (1-sb2)*dlr*g,
                        grad_sum, updates)

    grad_sum_l1 = tree_utils.tree_sum(tree_map(jnp.abs, grad_sum))

    numerator_weighted = sb2*numerator_weighted + (1-sb2)*dlr*numerator_acum

    # The distance-to-solution estimate never decreases.
    d_estimate = numerator_weighted/((1-sb2)*grad_sum_l1)
    estim_lr = jnp.maximum(state.estim_lr, d_estimate)

    # AdamW-style update; exp_avg already carries the dlr factor.
    p_update = tree_map(lambda ea, eas, p:
                        -weight_decay*dlr*p
                        - ea/(jnp.sqrt(eas)+eps),
                        exp_avg, exp_avg_sq, params)

    new_state = DAdaptAdamWState(exp_avg, exp_avg_sq, grad_sum, estim_lr,
                                 numerator_weighted,
                                 utils.safe_int32_increment(count))
    return p_update, new_state

  return base.GradientTransformation(init_fn, update_fn)
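
The docstring above recommends driving `learning_rate` with a schedule that warms up to and then decays from 1.0. The sketch below shows one way to do that with standard optax schedules and the `contrib.dadapt_adamw` export exercised by the test file; the warmup/decay step counts, toy parameters, and quadratic loss are illustrative assumptions, not part of the commit.

# Usage sketch (not part of this commit): dadapt_adamw driven by a
# 1.0-peaked warmup + linear-decay schedule, as suggested in the docstring.
import jax
import jax.numpy as jnp
import optax
from optax import contrib

warmup_steps, total_steps = 100, 1000  # hypothetical step counts
schedule = optax.join_schedules(
    [optax.linear_schedule(0.0, 1.0, warmup_steps),
     optax.linear_schedule(1.0, 0.0, total_steps - warmup_steps)],
    boundaries=[warmup_steps])

opt = contrib.dadapt_adamw(learning_rate=schedule, weight_decay=1e-4)

target = jnp.array([1.0, -2.0, 0.5])
params = {'w': jnp.zeros(3)}
loss = lambda p: jnp.sum((p['w'] - target)**2)  # toy quadratic objective

state = opt.init(params)

@jax.jit
def train_step(params, state):
  grads = jax.grad(loss)(params)
  # params must be passed to update(): the transformation raises otherwise
  # and uses them for the weight-decay term.
  updates, state = opt.update(grads, state, params)
  return optax.apply_updates(params, updates), state

for _ in range(total_steps):
  params, state = train_step(params, state)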
@@ -0,0 +1,86 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `dadapt_adamw.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax import contrib
from optax._src import numerics
from optax._src import update
from optax.tree_utils import _state_utils


def _setup_parabola(dtype):
  """Quadratic function as an optimization target."""
  initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
  final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

  @jax.grad
  def get_updates(params):
    return jnp.sum(numerics.abs_sq(params - final_params))

  return initial_params, final_params, get_updates


def _setup_rosenbrock(dtype):
  """Rosenbrock function as an optimization target."""
  a = 1.0
  b = 100.0

  initial_params = jnp.array([0.0, 0.0], dtype=dtype)
  final_params = jnp.array([a, a**2], dtype=dtype)

  @jax.grad
  def get_updates(params):
    return (numerics.abs_sq(a - params[0]) +
            b * numerics.abs_sq(params[1] - params[0]**2))

  return initial_params, final_params, get_updates


class DAdaptAdamWTest(chex.TestCase):

  @parameterized.product(
      opt_name=('dadapt_adamw',),
      target=(_setup_parabola, _setup_rosenbrock),
      dtype=(jnp.float32,),
  )
  def test_optimization(self, opt_name, target, dtype):

    opt = getattr(contrib, opt_name)()
    initial_params, final_params, get_updates = target(dtype)

    @jax.jit
    def step(params, state):
      updates = get_updates(params)
      updates, state = opt.update(updates, state, params)
      params = update.apply_updates(params, updates)
      return params, state

    params = initial_params
    state = opt.init(params)
    # A no-op change, to verify that tree map works.
    state = _state_utils.tree_map_params(opt, lambda v: v, state)

    for _ in range(15000):
      params, state = step(params, state)

    chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)


if __name__ == '__main__':
  absltest.main()
@@ -0,0 +1,137 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prodigy optimizer.

A contributed implementation of the method from "Prodigy: An Expeditiously
Adaptive Parameter-Free Learner" (https://arxiv.org/abs/2306.06101) by
Konstantin Mishchenko and Aaron Defazio. A new variant of D-Adapt Adam that
adapts the learning rate faster.
"""
import jax.numpy as jnp
from jax.tree_util import tree_map

from typing import NamedTuple, Optional, Tuple

from optax import tree_utils
from optax._src import base
from optax._src import utils


class ProdigyState(NamedTuple):
  """State for the Prodigy optimizer."""
  exp_avg: base.Updates
  exp_avg_sq: base.Updates
  grad_sum: base.Updates  # Exponential moving average of the sum of gradients.
  params0: base.Updates  # Initial point.
  estim_lr: float  # Distance-to-solution estimate.
  numerator_weighted: float
  count: float


def prodigy(learning_rate: base.ScalarOrSchedule = 0.1,
            betas=(0.9, 0.999),
            beta3=None,
            eps=1e-8,
            estim_lr0=1e-6,
            estim_lr_coef=1.0,
            weight_decay=0) -> base.GradientTransformation:
  """Learning-rate-free AdamW with Prodigy.

  Implementation of the Prodigy method from "Prodigy: An Expeditiously
  Adaptive Parameter-Free Learner", a version of D-Adapt AdamW that adapts the
  baseline learning rate faster by using a weighting of the gradients that
  places higher weights on more recent gradients.

  This method works best when combined with a learning rate schedule that
  treats 1.0 as the base (usually max) value.

  References:
    [Mishchenko & Defazio, 2023](https://arxiv.org/abs/2306.06101)

  Args:
    learning_rate: Learning rate scheduling parameter. The recommended schedule
      is a linear_schedule with init_value=1.0 and end_value=0, combined with a
      0-20% learning rate warmup.
    betas: Betas for the underlying AdamW optimizer.
    beta3: Optional momentum parameter for the estimation of D.
    eps: eps for the underlying AdamW optimizer.
    estim_lr0: Initial (under-)estimate of the learning rate.
    estim_lr_coef: Learning rate estimates are multiplied by this parameter.
    weight_decay: AdamW-style weight decay. To use regular Adam weight decay,
      chain with add_decayed_weights.

  Returns:
    A `GradientTransformation` object.
  """
  beta1, beta2 = betas
  if beta3 is None:
    beta3 = beta2**0.5

  def init_fn(params: base.Params) -> ProdigyState:
    exp_avg = tree_map(lambda p: jnp.zeros(p.shape), params)
    exp_avg_sq = tree_map(lambda p: jnp.zeros(p.shape), params)
    grad_sum = tree_map(lambda p: jnp.zeros(p.shape), params)
    params0 = params
    estim_lr = estim_lr0
    numerator_weighted = 0
    count = 0
    return ProdigyState(exp_avg, exp_avg_sq, grad_sum, params0, estim_lr,
                        numerator_weighted, count)

  def update_fn(updates: base.Updates,
                state: ProdigyState,
                params: Optional[base.Params] = None,
                ) -> Tuple[base.Updates, ProdigyState]:
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    count = state.count
    sched = learning_rate(count) if callable(learning_rate) else learning_rate
    grad_sum = state.grad_sum
    params0 = state.params0
    estim_lr = state.estim_lr
    numerator_weighted = state.numerator_weighted

    # Effective step size: D estimate x schedule x Adam bias correction.
    bc = ((1-beta2**(count+1))**0.5)/(1-beta1**(count+1))
    dlr = estim_lr * sched * bc
    dg = tree_map(lambda g: estim_lr * g, updates)

    # Prodigy's numerator uses the correlation between the current gradient
    # and the displacement from the initial point.
    param_diff = tree_map(lambda p0, p: p0-p, params0, params)
    numerator_acum = tree_utils.tree_vdot(updates, param_diff)

    exp_avg = tree_map(lambda ea, dgk: beta1*ea + (1-beta1)*dgk,
                       state.exp_avg, dg)
    exp_avg_sq = tree_map(lambda eas, dgk: beta2*eas + (1-beta2)*dgk*dgk,
                          state.exp_avg_sq, dg)

    grad_sum = tree_map(lambda sk, dgk: beta3*sk + dlr*dgk/estim_lr0,
                        grad_sum, dg)

    numerator_weighted = beta3*numerator_weighted
    numerator_weighted += (estim_lr/estim_lr0)*dlr*numerator_acum

    denominator = tree_utils.tree_sum(tree_map(jnp.abs, grad_sum))

    # The distance-to-solution estimate never decreases.
    lr_estimate = estim_lr_coef*numerator_weighted/denominator
    estim_lr = jnp.maximum(state.estim_lr, lr_estimate)

    # AdamW-style update with the Prodigy step size.
    p_update = tree_map(lambda ea, eas, p:
                        -weight_decay*dlr*p
                        - dlr*ea/(jnp.sqrt(eas) + estim_lr*eps),
                        exp_avg, exp_avg_sq, params)

    new_state = ProdigyState(exp_avg, exp_avg_sq, grad_sum, params0,
                             estim_lr, numerator_weighted,
                             utils.safe_int32_increment(count))
    return p_update, new_state

  return base.GradientTransformation(init_fn, update_fn)
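
Prodigy exposes the same `GradientTransformation` interface as `dadapt_adamw`, and its state carries the running learning-rate estimate, which is convenient to monitor. A minimal sketch follows, assuming prodigy is exported from optax.contrib alongside dadapt_adamw (only three of the commit's six files are shown here); the toy objective and step count are illustrative only.

# Usage sketch (not part of this commit): prodigy on a toy problem, reading
# the adapted learning-rate estimate out of the optimizer state.
import jax
import jax.numpy as jnp
import optax
from optax import contrib  # assumes prodigy is exported like dadapt_adamw

opt = contrib.prodigy()  # defaults: learning_rate=0.1, beta3=sqrt(beta2)

target = jnp.array([1.0, -1.0])
loss = lambda p: jnp.sum((p - target)**2)  # toy quadratic objective

params = jnp.zeros(2)
state = opt.init(params)
for step in range(500):
  grads = jax.grad(loss)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)
  if step % 100 == 0:
    # ProdigyState keeps estim_lr, the non-decreasing D / learning-rate
    # estimate that starts at estim_lr0.
    print(step, float(state.estim_lr))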