Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/howtos/lr_schedule.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
Learning Rate Scheduling
=============================

Note: See "Learning Rate Schedules" in :doc:`optax_update_guide` for
implementing learning rate schedules using ``optax``.


The learning rate is considered one of the most important hyperparameters for
training deep neural networks, but choosing it can be quite hard.
To simplify this, one can use a so-called *cyclic learning rate*, which
Expand Down
281 changes: 281 additions & 0 deletions docs/howtos/optax_update_guide.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
Upgrading my Codebase to Optax
==============================

We have proposed to replace :py:mod:`flax.optim` with `Optax
<https://optax.readthedocs.io>`_ in 2021 with `FLIP #1009
<https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md>`_ and
the Flax optimizers are now *effectively deprecated*. This guide is targeted
towards :py:mod:`flax.optim` users to help them update their code to Optax.

Code samples below are executable in
`Colab <https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/optax_update_guide.ipynb>`_.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the user journey where someone wants to run this Colab? (as opposed to copying code from the RTD page). I ask because I'd rather avoid having to remember to update two places if we change a variable name in this doc, etc.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, another version of the question: Should we have colabs for all of our HOWTOs? If so, maybe we should find a way to automatically convert them. Or is this one special? I'm hesitant to introduce this new expectation by adding just one guide with a Colab, unless we want to go in that direction more broadly.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(But no need to block merging this PR on this discussion, just wanted to bring it up. We can resolve separately)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the user journey is to provide users with a working example that they can use to experiment with different gradient transformations to replicate their original setup. Of course they could do this on their own starting with the code blocks from the transition guide, but there's still some more code involved and it's annoying having to copy this together and a nicer experience to start with a working Colab.

I first tried to have a Colab that is automatically synced, but then realized two problems:

  1. I think it's very useful to have the side-by-side in the update guide. This would require us to create a new Sphinx plugin that knows how to take two cells in the Colab, do some pre-processing, and then put them into a diff html table.
  2. The Colab will necessarily contain some different code. For example, I want to provide a more realistic example that can be used int he Colab to play around with optimizer settings in a meaningful way. In RST on the other hand we want to have a cheap setup that does not load anything from tdfs. This would probably be a simpler Sphinx plugin replacing some code block with another code block when transcribing the Colab into RST.

If we have a Sphinx plugin implementing the 2 features mentioned above then I'll gladly use it. But it's too much work to implement, and I'd rather have to sync the two documents for the added functionality. I don't expect us to update this code frequently, so it's more about remembering (and even if we forgot to update a small thing, it wouldn't be too bad).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My two cents: I think it is great you added a separate Colab, and I think we should do this for all of our HOWTOs. It is confusing that currently many of our HOWTOs have code snippets, but they cannot be executed in /copied from a single place.

Wrt automation, I agree with Andreas' point that having two separate locations for the code is not great, but given our limited time and priorities probably the best we can do for now.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. Leaving comment open for better visiblity.


See also Optax's quick start documentation:
https://optax.readthedocs.io/en/latest/optax-101.html

.. testsetup::

import flax
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Note: this is the minimal code required to make below code run. See in the
# Colab linked above for a more meaningful definition of datasets etc.
batch = {'image': jnp.ones([1, 28, 28, 1]), 'label': jnp.array([0])}
ds_train = [batch]
get_ds_train = lambda: [batch]
model = nn.Dense(1)
variables = model.init(jax.random.PRNGKey(0), batch['image'])
learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1.
loss = lambda params, batch: jnp.array(0.)

Replacing ``flax.optim`` with ``optax``
---------------------------------------

Optax has drop-in replacements for all of Flax's optimizers. Refer to Optax's
documentation `Common Optimizers <https://optax.readthedocs.io/en/latest/api.html>`_
for API details.

The usage is very similar, with the difference that ``optax`` does not keep a
copy of the ``params``, so they need to be passed around separately. Flax
provides the utility :py:class:`~flax.training.train_state.TrainState` to store
optimizer state, parameters, and other associated data in a single dataclass
(not used in code below).

.. codediff::
:title_left: ``flax.optim``
:title_right: ``optax``

@jax.jit
def train_step(optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)


return optimizer.apply_gradient(grads)

optimizer_def = flax.optim.Momentum(
learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])

for batch in get_ds_train():
optimizer = train_step(optimizer, batch)

---

@jax.jit
def train_step(params, opt_state, batch):
grads = jax.grad(loss)(params, batch)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in ds_train:
params, opt_state = train_step(params, opt_state, batch)


Composable Gradient Transformations
-----------------------------------

The function |optax.sgd()|_ used in the code snippet above is simply a wrapper
for the sequential application of two gradient transformations. Instead of using
this alias, it is common to use |optax.chain()|_ to combine multiple of these
generic building blocks.

.. |optax.sgd()| replace:: ``optax.sgd()``
.. _optax.sgd(): https://optax.readthedocs.io/en/latest/api.html#optax.sgd
.. |optax.chain()| replace:: ``optax.chain()``
.. _optax.chain(): https://optax.readthedocs.io/en/latest/api.html#chain

.. codediff::
:title_left: Pre-defined alias
:title_right: Combining transformations

# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)

---

#

tx = optax.chain(
# 1. Step: keep a trace of past updates and add to gradients.
optax.trace(decay=momentum),
# 2. Step: multiply result from step 1 with negative learning rate.
# Note that `optax.apply_updates()` simply adds the final updates to the
# parameters, so we must make sure to flip the sign here for gradient
# descent.
optax.scale(-learning_rate),
)

Weight Decay
------------

Some of Flax's optimizers also include a weight decay. In Optax, some optimizers
also have a weight decay parameter (such as |optax.adamw()|_), and to others the
weight decay can be added as another "gradient transformation"
|optax.add_decayed_weights()|_ that adds an update derived from the parameters.

.. |optax.adamw()| replace:: ``optax.adamw()``
.. _optax.adamw(): https://optax.readthedocs.io/en/latest/api.html#optax.adamw
.. |optax.add_decayed_weights()| replace:: ``optax.add_decayed_weights()``
.. _optax.add_decayed_weights(): https://optax.readthedocs.io/en/latest/api.html#optax.add_decayed_weights

.. codediff::
:title_left: ``flax.optim``
:title_right: ``optax``

optimizer_def = flax.optim.Adam(
learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])

---

# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
optax.scale_by_adam(),
optax.add_decayed_weights(weight_decay),
# params -= learning_rate * (adam(grads) + params * weight_decay)
optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the udpates:
# tx.update(grads, opt_state, params)

Gradient Clipping
-----------------

Training can be stabilized by clipping gradients to a global norm (`Pascanu et
al, 2012 <https://arxiv.org/abs/1211.5063>`_). In Flax this is often done by
processing the gradients before passing them to the optimizer. With Optax this
becomes just another gradient transformation |optax.clip_by_global_norm()|_.

.. |optax.clip_by_global_norm()| replace:: ``optax.clip_by_global_norm()``
.. _optax.clip_by_global_norm(): https://optax.readthedocs.io/en/latest/api.html#optax.clip_by_global_norm

.. codediff::
:title_left: ``flax.optim``
:title_right: ``optax``

def train_step(optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)
grads_flat, _ = jax.tree_flatten(grads)
global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
grads = jax.tree_map(lambda g: g * g_factor, grads)
return optimizer.apply_gradient(grads)

---

tx = optax.chain(
optax.clip_by_global_norm(grad_clip_norm),
optax.trace(decay=momentum),
optax.scale(-learning_rate),
)

Learning Rate Schedules
-----------------------

For learning rate schedules, Flax allows overwriting hyper parameters when
applying the gradients. Optax maintains a step counter and provides this as an
argument to a function for scaling the updates added with
|optax.scale_by_schedule()|_. Optax also allows specifying a functions to
inject arbitrary scalar values for other gradient updates via
|optax.inject_hyperparams()|_.

Read more about learning rate schedules in the :doc:`lr_schedule` guide.

Read more about schedules defined in Optax under `Optimizer Schedules
<https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules>`_. the
standard optimizers (like ``optax.adam()``, ``optax.sgd()`` etc.) also accept a
learning rate schedule as a parameter for ``learning_rate``.


.. |optax.scale_by_schedule()| replace:: ``optax.scale_by_schedule()``
.. _optax.scale_by_schedule(): https://optax.readthedocs.io/en/latest/api.html#optax.scale_by_schedule
.. |optax.inject_hyperparams()| replace:: ``optax.inject_hyperparams()``
.. _optax.inject_hyperparams(): https://optax.readthedocs.io/en/latest/api.html#optax.inject_hyperparams

.. codediff::
:title_left: ``flax.optim``
:title_right: ``optax``

def train_step(step, optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)
return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could consider making this line less wide since right now the side by side has a scrollbar on laptops, like this?

Suggested change
return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))
grad = optimizer.apply_gradient(grads, learning_rate=schedule(step))
return step+1, grad

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that it's the optimizer.

Also, we already use the same pattern in the other snippets, so I'd leave it as is for consistency.


---

tx = optax.chain(
optax.trace(decay=momentum),
# Note that we still want a negative value for scaling the updates!
optax.scale_by_schedule(lambda step: -schedule(step)),
)

Multiple Optimizers / Updating a Subset of Parameters
-----------------------------------------------------

In Flax, traversals are used to specify which parameters should be updated by an
optimizer. And you can combine traversals using
:py:class:`flax.optim.MultiOptimizer` to apply different optimizers on different
parameters. The equivalent in Optax is |optax.masked()|_ and |optax.chain()|_.

Note that the example below is using :py:mod:`flax.traverse_util` to create the
boolean masks required by |optax.masked()|_ - alternatively you could also
create them manually, or use |optax.multi_transform()|_ that takes a
multivalent pytree to specify gradient transformations.

Beware that |optax.masked()|_ flattens the pytree internally and the inner
gradient transformations will only be called with that partial flattened view of
the params/gradients. This is not a problem usually, but it makes it hard to
nest multiple levels of masked gradient transformations (because the inner
masks will expect the mask to be defined in terms of the partial flattened view
that is not readily available outside the outer mask).

.. |optax.masked()| replace:: ``optax.masked()``
.. _optax.masked(): https://optax.readthedocs.io/en/latest/api.html#optax.masked
.. |optax.multi_transform()| replace:: ``optax.multi_transform()``
.. _optax.multi_transform(): https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform

.. codediff::
:title_left: ``flax.optim``
:title_right: ``optax``

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)


optimizer = flax.optim.MultiOptimizer(
(kernels, kernel_opt),
(biases, bias_opt)
).create(variables['params'])

---

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I'm sorry to have missed this in the first round of review but I'm curious -- why do we need to be using ModelParamTraversal at all? Presumably other users of Optax, e.g. those using Haiku, have some other pattern? I'd love to eventually deprecate ModelParamTraversal as I don't think it's needed anymore, and if needed replace it with something pretty much like a tree_map that gives you the full path of subkeys rather than just the values.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's plenty of different ways how you can create those masks.

I took ModelParamTraversal because it's the most straightforward transition from existing code, but feel free to update the example to use traverse_utils.flatten_dict() instead.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I'll give it a shot.

biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
optax.trace(decay=momentum),
optax.masked(optax.scale(-learning_rate), kernels_mask),
optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)

Final Words
-----------

All above patterns can of course also be mixed and Optax makes it possible to
encapsulate all these transformations into a single place outside the main
training loop, which makes testing much easier.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ For a quick introduction and short example snippets, see our `README
howtos/lr_schedule
howtos/extracting_intermediates
howtos/model_surgery
howtos/optax_update_guide

.. toctree::
:maxdepth: 1
Expand Down
Loading