Adds optax update guide. #1774
Upgrading my Codebase to Optax
==============================

We proposed replacing :py:mod:`flax.optim` with `Optax
<https://optax.readthedocs.io>`_ in 2021 in `FLIP #1009
<https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md>`_,
and the Flax optimizers are now *effectively deprecated*. This guide helps
:py:mod:`flax.optim` users update their code to Optax.

The code samples below are executable in
`Colab <https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/optax_update_guide.ipynb>`_.

See also Optax's quick start documentation:
https://optax.readthedocs.io/en/latest/optax-101.html
.. testsetup::

  import flax
  import flax.linen as nn
  import jax
  import jax.numpy as jnp
  import optax

  # Note: this is the minimal code required to make the code below run. See
  # the Colab linked above for a more meaningful definition of datasets etc.
  batch = {'image': jnp.ones([1, 28, 28, 1]), 'label': jnp.array([0])}
  ds_train = [batch]
  get_ds_train = lambda: [batch]

  model = nn.Dense(1)
  variables = model.init(jax.random.PRNGKey(0), batch['image'])
  learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1.
  loss = lambda params, batch: jnp.array(0.)
Replacing ``flax.optim`` with ``optax``
---------------------------------------

Optax has drop-in replacements for all of Flax's optimizers. Refer to Optax's
`Common Optimizers <https://optax.readthedocs.io/en/latest/api.html>`_
documentation for API details.

Usage is very similar, with the difference that ``optax`` does not keep a
copy of the ``params``, so they need to be passed around separately. Flax
provides the utility :py:class:`~flax.training.train_state.TrainState` to store
optimizer state, parameters, and other associated data in a single dataclass
(not used in the code below).
.. codediff::
  :title_left: ``flax.optim``
  :title_right: ``optax``

  @jax.jit
  def train_step(optimizer, batch):
    grads = jax.grad(loss)(optimizer.target, batch)


    return optimizer.apply_gradient(grads)

  optimizer_def = flax.optim.Momentum(
      learning_rate, momentum)
  optimizer = optimizer_def.create(variables['params'])

  for batch in get_ds_train():
    optimizer = train_step(optimizer, batch)

  ---

  @jax.jit
  def train_step(params, opt_state, batch):
    grads = jax.grad(loss)(params, batch)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  tx = optax.sgd(learning_rate, momentum)
  params = variables['params']
  opt_state = tx.init(params)

  for batch in ds_train:
    params, opt_state = train_step(params, opt_state, batch)
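To make the ``init``/``update``/``apply_updates`` pattern concrete, here is a toy, pure-Python sketch of the protocol for a single scalar parameter. The names (``sgd_transform``, ``apply_updates``) and the quadratic loss are made up for illustration; real Optax transformations operate on arbitrary pytrees of arrays.

```python
# Toy sketch of Optax's gradient-transformation protocol: a pair of
# (init, update) functions plus an apply_updates step. `sgd_transform` is a
# made-up name; real Optax works on pytrees, not plain floats.

def sgd_transform(learning_rate):
    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def update(grads, opt_state, params=None):
        updates = -learning_rate * grads  # updates get *added* to params
        return updates, opt_state

    return init, update

def apply_updates(params, updates):
    # Scalar counterpart of `optax.apply_updates`: just add the updates.
    return params + updates

# Minimize loss(p) = (p - 3)^2, whose gradient is 2 * (p - 3).
grad = lambda p: 2.0 * (p - 3.0)

init, update = sgd_transform(learning_rate=0.1)
params = 0.0
opt_state = init(params)
for _ in range(100):
    updates, opt_state = update(grad(params), opt_state)
    params = apply_updates(params, updates)
# params has converged close to the minimum at 3.0
```

Note how ``params`` and ``opt_state`` are threaded through the loop explicitly, exactly as in the ``optax`` column above.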
Composable Gradient Transformations
-----------------------------------

The function |optax.sgd()|_ used in the code snippet above is simply a wrapper
for the sequential application of two gradient transformations. Instead of
using this alias, it is common to use |optax.chain()|_ to combine multiple of
these generic building blocks.

.. |optax.sgd()| replace:: ``optax.sgd()``
.. _optax.sgd(): https://optax.readthedocs.io/en/latest/api.html#optax.sgd
.. |optax.chain()| replace:: ``optax.chain()``
.. _optax.chain(): https://optax.readthedocs.io/en/latest/api.html#chain

.. codediff::
  :title_left: Pre-defined alias
  :title_right: Combining transformations

  # Note that the aliases follow the convention of using positive
  # values for the learning rate by default.
  tx = optax.sgd(learning_rate, momentum)

  ---

  #

  tx = optax.chain(
      # Step 1: keep a trace of past updates and add them to the gradients.
      optax.trace(decay=momentum),
      # Step 2: multiply the result from step 1 with the negative learning
      # rate. Note that `optax.apply_updates()` simply adds the final updates
      # to the parameters, so we must flip the sign here for gradient descent.
      optax.scale(-learning_rate),
  )
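As a sanity check of how the chain composes, here is a toy scalar version of ``trace`` followed by ``scale``, compared against the classic SGD-with-momentum update rule. The helper name ``chained_update`` is made up; the arithmetic mirrors what ``optax.trace(decay)`` (new trace = gradient + decay * old trace) and ``optax.scale(-lr)`` compute.

```python
# Toy scalar version of optax.chain(optax.trace(decay), optax.scale(-lr)).
# `chained_update` is a made-up helper name for this sketch.

def chained_update(g, trace, decay, learning_rate):
    trace = g + decay * trace        # optax.trace(decay=decay)
    update = -learning_rate * trace  # optax.scale(-learning_rate)
    return update, trace

decay, lr = 0.9, 0.1
grads = [1.0, 0.5, -0.25]

# Run the chained version and classic momentum (v = decay*v + g; p -= lr*v)
# side by side: they produce the same parameter trajectory.
p_chain, trace = 2.0, 0.0
p_ref, v = 2.0, 0.0
for g in grads:
    update, trace = chained_update(g, trace, decay, lr)
    p_chain += update  # optax.apply_updates just adds the updates
    v = decay * v + g
    p_ref -= lr * v

assert abs(p_chain - p_ref) < 1e-12
```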
Weight Decay
------------

Some of Flax's optimizers also include weight decay. In Optax, some optimizers
have a weight decay parameter (such as |optax.adamw()|_); for others, weight
decay can be added as another "gradient transformation",
|optax.add_decayed_weights()|_, which adds an update derived from the
parameters.

.. |optax.adamw()| replace:: ``optax.adamw()``
.. _optax.adamw(): https://optax.readthedocs.io/en/latest/api.html#optax.adamw
.. |optax.add_decayed_weights()| replace:: ``optax.add_decayed_weights()``
.. _optax.add_decayed_weights(): https://optax.readthedocs.io/en/latest/api.html#optax.add_decayed_weights

.. codediff::
  :title_left: ``flax.optim``
  :title_right: ``optax``

  optimizer_def = flax.optim.Adam(
      learning_rate, weight_decay=weight_decay)
  optimizer = optimizer_def.create(variables['params'])

  ---

  # (Note that you could also use `optax.adamw()` in this case.)
  tx = optax.chain(
      optax.scale_by_adam(),
      optax.add_decayed_weights(weight_decay),
      # params -= learning_rate * (adam(grads) + params * weight_decay)
      optax.scale(-learning_rate),
  )

  # Note that you'll need to specify `params` when computing the updates:
  # tx.update(grads, opt_state, params)
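The following toy scalar sketch shows the decayed-weights step composed with the final scaling (the Adam scaling step is omitted for clarity; the helper names are made up). It also illustrates why ``tx.update(grads, opt_state, params)`` needs ``params`` as a third argument: the decay term is derived from the parameters themselves.

```python
# Toy scalar sketch of optax.add_decayed_weights(weight_decay) followed by
# optax.scale(-learning_rate). Helper names are made up for this sketch.

def add_decayed_weights(updates, params, weight_decay):
    # Adds weight_decay * params to the incoming updates (here: raw grads).
    # This is the step that requires access to `params`.
    return updates + weight_decay * params

def scale(updates, factor):
    return factor * updates

learning_rate, weight_decay = 0.1, 1e-3
params, grads = 2.0, 0.5

updates = add_decayed_weights(grads, params, weight_decay)
updates = scale(updates, -learning_rate)
new_params = params + updates

# Equivalent closed form:
# params -= learning_rate * (grads + params * weight_decay)
expected = params - learning_rate * (grads + params * weight_decay)
assert abs(new_params - expected) < 1e-12
```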
Gradient Clipping
-----------------

Training can be stabilized by clipping gradients to a global norm (`Pascanu et
al., 2012 <https://arxiv.org/abs/1211.5063>`_). In Flax this is often done by
processing the gradients before passing them to the optimizer. With Optax this
becomes just another gradient transformation, |optax.clip_by_global_norm()|_.

.. |optax.clip_by_global_norm()| replace:: ``optax.clip_by_global_norm()``
.. _optax.clip_by_global_norm(): https://optax.readthedocs.io/en/latest/api.html#optax.clip_by_global_norm

.. codediff::
  :title_left: ``flax.optim``
  :title_right: ``optax``

  def train_step(optimizer, batch):
    grads = jax.grad(loss)(optimizer.target, batch)
    grads_flat, _ = jax.tree_flatten(grads)
    global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
    g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
    grads = jax.tree_map(lambda g: g * g_factor, grads)
    return optimizer.apply_gradient(grads)

  ---

  tx = optax.chain(
      optax.clip_by_global_norm(grad_clip_norm),
      optax.trace(decay=momentum),
      optax.scale(-learning_rate),
  )
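The clipping math itself is small enough to sketch in pure Python: scale all gradients by ``min(1, max_norm / global_l2)``, which is what the manual ``flax.optim``-side code above computes (the function name here is made up for the sketch).

```python
import math

def clip_by_global_norm(grads, max_norm):
    # Scale every gradient by min(1, max_norm / global_l2): gradients keep
    # their direction, but the overall L2 norm never exceeds max_norm.
    global_l2 = math.sqrt(sum(g * g for g in grads))
    factor = min(1.0, max_norm / global_l2)
    return [g * factor for g in grads]

grads = [3.0, 4.0]  # global L2 norm = 5.0
clipped = clip_by_global_norm(grads, max_norm=1.0)
norm = math.sqrt(sum(g * g for g in clipped))
assert abs(norm - 1.0) < 1e-12  # clipped down to the max norm

# Gradients already within the norm pass through unchanged.
assert clip_by_global_norm([0.3, 0.4], max_norm=1.0) == [0.3, 0.4]
```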
Learning Rate Schedules
-----------------------

For learning rate schedules, Flax allows overwriting hyperparameters when
applying the gradients. Optax maintains a step counter and provides it as an
argument to a function for scaling the updates added with
|optax.scale_by_schedule()|_. Optax also allows specifying a function to
inject arbitrary scalar values for other gradient updates via
|optax.inject_hyperparams()|_.

Read more about learning rate schedules in the :doc:`lr_schedule` guide.

Read more about schedules defined in Optax under `Optimizer Schedules
<https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules>`_. The
standard optimizers (like ``optax.adam()``, ``optax.sgd()`` etc.) also accept a
learning rate schedule as the ``learning_rate`` parameter.

.. |optax.scale_by_schedule()| replace:: ``optax.scale_by_schedule()``
.. _optax.scale_by_schedule(): https://optax.readthedocs.io/en/latest/api.html#optax.scale_by_schedule
.. |optax.inject_hyperparams()| replace:: ``optax.inject_hyperparams()``
.. _optax.inject_hyperparams(): https://optax.readthedocs.io/en/latest/api.html#optax.inject_hyperparams

.. codediff::
  :title_left: ``flax.optim``
  :title_right: ``optax``

  def train_step(step, optimizer, batch):
    grads = jax.grad(loss)(optimizer.target, batch)
    return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))

  ---

  tx = optax.chain(
      optax.trace(decay=momentum),
      # Note that we still want a negative value for scaling the updates!
      optax.scale_by_schedule(lambda step: -schedule(step)),
  )
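A toy pure-Python sketch of the ``scale_by_schedule`` idea: the transformation's state is just a step counter, and each update is multiplied by the schedule evaluated at the current step. The schedule and helper names below are made up for illustration.

```python
# Toy sketch of optax.scale_by_schedule(): keep a step counter in the
# optimizer state and multiply the updates by step_size_fn(step).
# `schedule` is a made-up example (simple inverse decay).

def schedule(step):
    return 0.1 / (1.0 + step)  # hypothetical learning rate schedule

def scale_by_schedule_init():
    return 0  # the optimizer state is just the step counter

def scale_by_schedule_update(grads, step, step_size_fn):
    updates = step_size_fn(step) * grads
    return updates, step + 1  # increment the counter on every update

params, step_state = 1.0, scale_by_schedule_init()
for g in [1.0, 1.0, 1.0]:
    # As in the guide, pass a *negative* scale for gradient descent.
    updates, step_state = scale_by_schedule_update(
        g, step_state, lambda step: -schedule(step))
    params += updates

# The three steps used learning rates 0.1, 0.05 and 0.1/3 respectively.
expected = 1.0 - (0.1 + 0.05 + 0.1 / 3.0)
assert abs(params - expected) < 1e-12
```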
Multiple Optimizers / Updating a Subset of Parameters
-----------------------------------------------------

In Flax, traversals are used to specify which parameters should be updated by
an optimizer, and you can combine traversals using
:py:class:`flax.optim.MultiOptimizer` to apply different optimizers to
different parameters. The equivalent in Optax is |optax.masked()|_ combined
with |optax.chain()|_.

Note that the example below uses :py:mod:`flax.traverse_util` to create the
boolean masks required by |optax.masked()|_ - alternatively you could also
create them manually, or use |optax.multi_transform()|_, which takes a
multivalent pytree to specify gradient transformations.

Beware that |optax.masked()|_ flattens the pytree internally, and the inner
gradient transformations are only called with that partial flattened view of
the params/gradients. This is usually not a problem, but it makes it hard to
nest multiple levels of masked gradient transformations (because the inner
masks would have to be defined in terms of the partial flattened view, which
is not readily available outside the outer mask).

.. |optax.masked()| replace:: ``optax.masked()``
.. _optax.masked(): https://optax.readthedocs.io/en/latest/api.html#optax.masked
.. |optax.multi_transform()| replace:: ``optax.multi_transform()``
.. _optax.multi_transform(): https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform

.. codediff::
  :title_left: ``flax.optim``
  :title_right: ``optax``

  kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
  biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

  kernel_opt = flax.optim.Momentum(learning_rate, momentum)
  bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)


  optimizer = flax.optim.MultiOptimizer(
      (kernels, kernel_opt),
      (biases, bias_opt)
  ).create(variables['params'])

  ---

  kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
  biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

  all_false = jax.tree_map(lambda _: False, params)
  kernels_mask = kernels.update(lambda _: True, all_false)
  biases_mask = biases.update(lambda _: True, all_false)

  tx = optax.chain(
      optax.trace(decay=momentum),
      optax.masked(optax.scale(-learning_rate), kernels_mask),
      optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
  )
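The masking idea can be sketched in pure Python on a flat dict: apply an inner scaling only where the mask is ``True`` and pass everything else through unchanged. This is a deliberate simplification (real ``optax.masked()`` wraps a full gradient transformation and works on pytrees); ``masked_scale`` is a made-up helper name.

```python
# Toy sketch of the masking pattern: scale only the entries where the mask
# is True, leave the rest untouched. `masked_scale` is a made-up name; real
# optax.masked() wraps an arbitrary gradient transformation.

def masked_scale(updates, mask, factor):
    return {k: (v * factor if mask[k] else v) for k, v in updates.items()}

grads = {'kernel': 1.0, 'bias': 1.0}
kernels_mask = {'kernel': True, 'bias': False}
biases_mask = {'kernel': False, 'bias': True}

learning_rate = 0.1
# Chain the two masked scalings, mirroring the last two steps of the
# codediff above: kernels get -learning_rate, biases -learning_rate * 0.1.
updates = masked_scale(grads, kernels_mask, -learning_rate)
updates = masked_scale(updates, biases_mask, -learning_rate * 0.1)

assert abs(updates['kernel'] + 0.1) < 1e-12   # scaled by -0.1
assert abs(updates['bias'] + 0.01) < 1e-12    # scaled by -0.01
```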
Final Words
-----------

All the patterns above can of course be mixed, and Optax makes it possible to
encapsulate all these transformations in a single place outside the main
training loop, which makes testing much easier.