# Upgrading my codebase to Optax from `flax.optim`

In 2021, [FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md) proposed replacing `flax.optim` with [Optax](https://optax.readthedocs.io). Since Flax v0.6.0, Optax has been the default Flax optimizer library. This guide shows how to update your `flax.optim` code to Optax.

You can also refer to the [Optax 101 (quick start)](https://optax.readthedocs.io/en/latest/optax-101.html).

## Setup and imports

Install/upgrade Flax in Colab (the Flax package comes with Optax):

In [None]:
!pip install --upgrade -q pip jax jaxlib flax

In [None]:
# Import the necessary libraries

import flax
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Note: this is the minimal setup required to make the code below run.

batch = {'image': jnp.ones([1, 28, 28, 1]), 'label': jnp.array([0])}
ds_train = [batch]
get_ds_train = lambda: [batch]
model = nn.Dense(1)
variables = model.init(jax.random.key(0), batch['image'])
learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1.
loss = lambda params, batch: jnp.array(0.)


## Replacing `flax.optim` with `optax`

Optax has drop-in replacements for all of Flax's optimizers. Refer to the Optax [Common Optimizers API docs](https://optax.readthedocs.io/en/latest/api.html) for more details.

The usage is very similar, with some differences that include:

- Optax (`optax`) does not keep a copy of the `params`, so they need to be passed around separately.
- Flax provides the utility `flax.training.train_state.TrainState` to store the optimizer state, parameters, and other associated data in a single dataclass (not used in the code example below).

In [None]:
# # Code with `flax.optim`
# 
# @jax.jit
# def train_step(optimizer, batch):
#   grads = jax.grad(loss)(optimizer.target, batch)
#   return optimizer.apply_gradient(grads)
# 
# optimizer_def = flax.optim.Momentum(
#     learning_rate, momentum)
# optimizer = optimizer_def.create(variables['params'])
# 
# for batch in get_ds_train():
#   optimizer = train_step(optimizer, batch)

In [None]:
# Code with Optax

@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in ds_train:
  params, opt_state = train_step(params, opt_state, batch)

## Composable gradient transformations

The function `optax.sgd()` used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. Instead of using this alias, it is common to use `optax.chain()` to combine several of these generic building blocks.

In [None]:
# # Code with Optax (pre-defined alias)
#
# # Note that the aliases follow the convention to use positive
# # values for the learning rate by default.
# tx = optax.sgd(learning_rate, momentum)

In [None]:
# Code with Optax (combining transformations)

tx = optax.chain(
    # 1. Step: keep a trace of past updates and add to gradients.
    optax.trace(decay=momentum),
    # 2. Step: multiply result from step 1 with negative learning rate.
    # Note that `optax.apply_updates()` simply adds the final updates to the
    # parameters, so we must make sure to flip the sign here for gradient
    # descent.
    optax.scale(-learning_rate),
)

## Weight decay

- Some `flax.optim` optimizers include a weight decay parameter.
- In Optax, some optimizers also have a weight decay parameter (such as `optax.adamw()`). For other optimizers that don't have it by default, the weight decay can be added as another "gradient transformation" `optax.add_decayed_weights()` that adds an update derived from the parameters.

In [None]:
# # Code with `flax.optim`
#
# optimizer_def = flax.optim.Adam(
#     learning_rate, weight_decay=weight_decay)
# optimizer = optimizer_def.create(variables['params'])

In [None]:
# Code with Optax

# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    # params -= learning_rate * (adam(grads) + params * weight_decay)
    optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the updates:
# tx.update(grads, opt_state, params)

## Gradient clipping

Training can be stabilized by clipping gradients to a global norm ([Pascanu et al., 2012](https://arxiv.org/abs/1211.5063)).

- In Flax this is often done by processing the gradients before passing them to the optimizer.
- In Optax this becomes just another gradient transformation `optax.clip_by_global_norm()`.

In [None]:
# # Code with `flax.optim`
#
# def train_step(optimizer, batch):
#   grads = jax.grad(loss)(optimizer.target, batch)
#   grads_flat, _ = jax.tree_util.tree_flatten(grads)
#   global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
#   g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
#   grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
#   return optimizer.apply_gradient(grads)

In [None]:
# Code with Optax

tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)

## Learning rate schedules

For learning rate schedules:

- Flax allows overwriting hyperparameters when applying the gradients.
- Optax maintains a step counter and passes it to a schedule function whose result scales the updates; this transformation is added with `optax.scale_by_schedule()`. Optax also allows specifying functions to inject arbitrary scalar values for other gradient updates via `optax.inject_hyperparams()`.

You can learn more in the [Learning rate scheduling](https://flax.readthedocs.io/en/latest/guides/training_techniques/lr_schedule.html) guide and the Optax [Optimizer schedules](https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules) API docs. Note that the standard optimizers (like `optax.adam()`, `optax.sgd()` etc.) also accept a learning rate schedule as a parameter for `learning_rate`.

In [None]:
# # Code with `flax.optim`
#
# def train_step(step, optimizer, batch):
#   grads = jax.grad(loss)(optimizer.target, batch)
#   return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))

In [None]:
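# Code with Optax

# Illustrative schedule (an assumption, not part of the original setup): any
# function that maps the step count to a learning rate works here.
schedule = optax.exponential_decay(
    init_value=learning_rate, transition_steps=1000, decay_rate=0.99)

tx = optax.chain(
    optax.trace(decay=momentum),
    # Note that we still want a negative value for scaling the updates!
    optax.scale_by_schedule(lambda step: -schedule(step)),
)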

## Multiple optimizers / Updating a subset of parameters

In Flax, traversals are used to specify which parameters should be updated by an optimizer.

- With `flax.optim`, traversals were combined using `flax.optim.MultiOptimizer` to apply different optimizers to different parameters.
- In Optax, the equivalent methods are `optax.masked()` and `optax.chain()`.

Note that the example below uses `flax.traverse_util` to create the boolean masks required by `optax.masked()`. Alternatively, you could create the masks manually, or use `optax.multi_transform()`, which takes a pytree of labels that selects the gradient transformation for each parameter.

Beware that `optax.masked()` flattens the pytree internally, and the inner
gradient transformations will only be called with that partial flattened view of
the params/gradients. This is not a problem usually, but it makes it hard to
nest multiple levels of masked gradient transformations (because the inner
masks will expect the mask to be defined in terms of the partial flattened view
that is not readily available outside the outer mask).

In [None]:
# # Code with `flax.optim`
# 
# kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
# biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)
# 
# kernel_opt = flax.optim.Momentum(learning_rate, momentum)
# bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)
# 
# optimizer = flax.optim.MultiOptimizer(
#     (kernels, kernel_opt),
#     (biases, bias_opt)
# ).create(variables['params'])

In [None]:
# Code with Optax

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)

## Final words

The patterns described in this guide can be mixed together, and Optax makes it possible to
encapsulate the transformations mentioned here into a single place outside of the main
training loop, which makes testing much easier.