diff --git a/optax/_src/mechanic.py b/optax/_src/mechanic.py
new file mode 100644
index 000000000..b0e57dcb9
--- /dev/null
+++ b/optax/_src/mechanic.py
@@ -0,0 +1,192 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Mechanic wrapper for automatic black-box learning rate tuning."""
+
+
+import functools
+import operator
+from typing import NamedTuple, Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+import optax
+
+
+_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)
+
+
+def _vdot_safe(a, b):
+  cvdot = _vdot(jnp.asarray(a), jnp.asarray(b))
+  return cvdot
+
+
+@jax.jit
+def tree_vdot(tree_x, tree_y):
+  """Compute the inner product <tree_x, tree_y>."""
+  vdots = jax.tree_util.tree_map(_vdot_safe, tree_x, tree_y)
+  return jax.tree_util.tree_reduce(operator.add, vdots)
+
+
+@jax.jit
+def tree_sum(tree):
+  """Sum all leaves of a pytree into a single scalar."""
+  return jax.tree_util.tree_reduce(lambda x, y: x + y, tree, initializer=0)
+
+
+@jax.jit
+def tree_norm(tree):
+  """Compute the global L2 norm of a pytree."""
+  return jnp.sqrt(tree_sum(jax.tree_util.tree_map(lambda x: jnp.sum(x**2), tree)))
+
+
+class MechanicState(NamedTuple):
+  """State of the `GradientTransformation` returned by `mechanize`."""
+  base_optimizer_state: optax.OptState
+  count: chex.Array  # shape=(), dtype=jnp.int32.
+  r: optax.Updates
+  m: optax.Updates
+  v: optax.Updates
+  s: optax.Updates
+  x0: optax.Updates
+
+
+def mechanize(
+    base_optimizer: optax.GradientTransformation,
+    weight_decay: float = 1e-2,
+    eps: float = 1e-10,
+    s_init: float = 1e-8
+) -> optax.GradientTransformation:
+  """Mechanic - a black-box learning rate tuner/optimizer.
+
+  Accumulates updates returned by the base_optimizer and learns the scale of
+  the updates (also known as learning rate or step size) to apply on a
+  per-iteration basis.
+
+  Note that Mechanic does NOT eschew the need for a learning rate schedule:
+  you are free to apply a learning rate schedule with the base learning rate
+  set to 1.0 (or any other constant) and Mechanic will learn the right scale
+  factor automatically.
+
+  As of June 2023, Mechanic has been tested with SGD, Momentum, Adam and Lion
+  as inner optimizers, but we expect it to work with almost any first-order
+  optimizer.
+
+  References:
+    [Cutkosky et al, 2023](https://arxiv.org/pdf/2306.00144.pdf)
+
+  Args:
+    base_optimizer: Base optimizer to compute updates from.
+    weight_decay: A scalar weight decay rate.
+    eps: Epsilon added for numerical stability.
+    s_init: Initial scale factor. The default should work almost all the time.
+
+  Returns:
+    A `GradientTransformation` with init and update functions.
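+
+  Example:
+    A minimal usage sketch (assumes pre-existing `params` and `grads` pytrees;
+    any first-order optax optimizer can serve as the base)::
+
+      opt = mechanize(optax.adam(learning_rate=1.0))
+      opt_state = opt.init(params)
+      updates, opt_state = opt.update(grads, opt_state, params)
+      new_params = optax.apply_updates(params, updates)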
+  """
+
+  def init_fn(params: optax.Params) -> MechanicState:
+    x0 = jax.tree_util.tree_map(lambda t: t.astype(jnp.float32), params)
+    num_betas = 6
+    r = jnp.zeros([num_betas,], jnp.float32)
+    v = jnp.zeros([num_betas,], jnp.float32)
+    m = jnp.zeros([num_betas,], jnp.float32)
+    s = jnp.ones([num_betas,], jnp.float32) * s_init
+    return MechanicState(
+        base_optimizer_state=base_optimizer.init(params),
+        count=jnp.zeros([], jnp.int32),
+        r=r,
+        m=m,
+        v=v,
+        s=s,
+        x0=x0,
+    )
+
+  def update_fn(
+      updates: optax.Updates, state: MechanicState, params: optax.Params
+  ) -> Tuple[optax.Updates, MechanicState]:
+    count_inc = optax.safe_int32_increment(state.count)
+    new_neg_updates, base_optimizer_state = base_optimizer.update(
+        updates, state.base_optimizer_state, params
+    )
+    # Since a lot of training loops unfreeze weights and replace them with
+    # pre-trained weights, we want to make sure we start from the weights
+    # actually in use instead of the ones they were initialized with.
+    x0 = jax.lax.cond(state.count == 0, lambda: params, lambda: state.x0)
+
+    # Add weight decay to the raw gradients; note that this is orthogonal to
+    # any weight decay applied to the inner optimizer's updates.
+    s_sum = jnp.sum(state.s)
+    grad_norm = tree_norm(updates)
+    param_norm = tree_norm(params)
+
+    def add_weight_decay(gi, pi):
+      return gi + weight_decay * s_sum * grad_norm / (param_norm + eps) * pi
+
+    updates = jax.tree_util.tree_map(
+        add_weight_decay,
+        updates,
+        params,
+    )
+
+    # We use the memory-efficient version of Mechanic where we re-compute
+    # \Delta every iteration.
+    delta_prev = jax.tree_util.tree_map(
+        lambda xti, x0i: (x0i - xti) / (s_sum + eps), params, x0
+    )
+
+    # We actually want to add the updates, but since optax by default flips
+    # signs when applying the learning rate, we subtract instead.
+    delta = jax.tree_util.tree_map(
+        lambda si, ui: si - ui, delta_prev, new_neg_updates
+    )
+
+    # Now we are ready to run the actual Mechanic algorithm: per beta, m
+    # tracks a running maximum of |h|, v accumulates h**2, r accumulates the
+    # reward h * s, and the resulting wealth determines the new scales s.
+    h = tree_vdot(updates, delta_prev)
+    betas = jnp.array([0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999])
+
+    m = jnp.maximum(betas * state.m, jnp.abs(h) + eps)
+    v = (betas**2) * state.v + h**2
+    r = betas * state.r + h * state.s
+    rc = jnp.maximum(0.0, r)
+    wealth = (s_init / jnp.size(betas)) * m + rc
+    s = wealth / (jnp.sqrt(v) + eps)
+
+    # Once we have the scale factor s, we produce new params with it.
+    new_x0 = x0
+    new_params = jax.tree_util.tree_map(
+        lambda x0i, deltai: x0i - jnp.sum(s) * deltai, new_x0, delta
+    )
+    new_neg_updates = jax.tree_util.tree_map(
+        lambda new_p, old_p: new_p - old_p, new_params, params
+    )
+
+    return new_neg_updates, MechanicState(
+        base_optimizer_state=base_optimizer_state,
+        count=count_inc,
+        r=r,
+        m=m,
+        v=v,
+        s=s,
+        x0=new_x0,
+    )
+
+  return optax.GradientTransformation(init_fn, update_fn)
diff --git a/optax/_src/mechanic_test.py b/optax/_src/mechanic_test.py
new file mode 100644
index 000000000..f4fcfa025
--- /dev/null
+++ b/optax/_src/mechanic_test.py
@@ -0,0 +1,54 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `mechanic.py`."""
+
+from absl.testing import absltest
+import chex
+import jax
+import jax.numpy as jnp
+import optax
+
+from optax._src import mechanic
+
+
+class MechanicTest(chex.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.init_params = jnp.array([[0.5, 0.5], [0.5, 0.5]])
+    self.per_step_updates = jnp.array([[0.1, -0.1], [0.01, 0.01]])
+
+  @chex.all_variants(with_pmap=False)
+  def test_mechanized_adamw(self):
+    params = self.init_params
+
+    adamw = optax.adamw(0.1, 0.9, 0.999)
+    optim = mechanic.mechanize(adamw)
+    init_fn = self.variant(optim.init)
+    transform_fn = self.variant(optim.update)
+
+    state = init_fn(params)
+    chex.assert_tree_all_finite(state)
+
+    def _update(unused_batch):
+      return transform_fn(self.per_step_updates, state, params)
+
+    pmap_fn = jax.pmap(_update, axis_name='batch')
+    updates, state = pmap_fn(jnp.array([1.0]))
+    chex.assert_tree_all_finite((params, updates, state))
+
+
+if __name__ == '__main__':
+  absltest.main()