[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/optax_update_guide.ipynb)
[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/optax_update_guide.ipynb)

This Colab accompanies the Flax guide on replacing `flax.optim` with Optax:
https://flax.readthedocs.io/en/latest/guides/optax_update_guide.html

### Setup

In [1]:
# flax.optim was deprecated after 0.5.3
!pip install -q --force-reinstall flax==0.5.3 optax

In [2]:
from typing import Sequence

import flax
from  flax.training import train_state
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax.optim
import optax

In [3]:
# Placeholder batch holding one all-ones image of shape (1, 28, 28, 1) and a
# single label 0.
batch = dict(
    image=jnp.ones([1, 28, 28, 1]),
    label=jnp.array([0]),
)



In [4]:
class Perceptron(nn.Module):
  """Multi-layer perceptron applied to flattened inputs.

  Attributes:
    units: Output width of each `nn.Dense` layer, in order. Every layer except
      the last is followed by a ReLU; the output of the last layer is returned
      as-is. Must contain at least one entry.
  """
  units: Sequence[int]

  @nn.compact
  def __call__(self, x):
    """Returns the output of the final dense layer for the batch `x`."""
    # Flatten everything after the leading axis and divide by 255.
    x = x.reshape([x.shape[0], -1]) / 255.
    # The widths used to be hard-coded to 50 and 10 while `units` was ignored.
    # `Perceptron([50, 10])` still builds exactly the same two layers.
    *hidden_units, output_units = self.units
    for width in hidden_units:
      x = nn.relu(nn.Dense(width)(x))
    return nn.Dense(output_units)(x)

def loss(params, batch):
  """Mean softmax cross-entropy of the global `model` on `batch`.

  Args:
    params: Parameters passed to `model.apply` under the 'params' key.
    batch: Mapping with an 'image' entry (model input) and a 'label' entry
      (integer class ids, one-hot encoded over 10 classes here).

  Returns:
    The cross-entropy averaged over the batch.
  """
  targets = jax.nn.one_hot(batch['label'], 10)
  predictions = model.apply({'params': params}, batch['image'])
  per_example = optax.softmax_cross_entropy(logits=predictions, labels=targets)
  return jnp.mean(per_example)

# Build the network and initialize its variables from a fixed PRNG key, using
# the images of the current `batch` as the example input.
model = Perceptron(units=[50, 10])
init_rng = jax.random.PRNGKey(0)
variables = model.init(init_rng, batch['image'])

# Show the shape of every initialized array.
jax.tree_util.tree_map(jnp.shape, variables)

FrozenDict({
    params: {
        Dense_0: {
            bias: (50,),
            kernel: (784, 50),
        },
        Dense_1: {
            bias: (10,),
            kernel: (50, 10),
        },
    },
})

In [5]:
import tensorflow_datasets as tfds

# Prepare MNIST and load the test split (requested with `batch_size=-1`),
# converting every leaf to a jnp array.
builder = tfds.builder('mnist')
builder.download_and_prepare()
ds_test = jax.tree_util.tree_map(
    jnp.array, builder.as_dataset('test', batch_size=-1))


def get_ds_train():
  """Returns a fresh iterator over the train split in batches of 128.

  Each yielded batch has its leaves converted to jnp arrays.
  """
  batched = builder.as_dataset('train').batch(128)
  return (jax.tree_util.tree_map(jnp.array, x) for x in batched)


# Replace the placeholder batch with the first real training batch and show
# its shapes.
batch = next(get_ds_train())
jax.tree_util.tree_map(jnp.shape, batch)

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


{'image': (128, 28, 28, 1), 'label': (128,)}

In [6]:
@jax.jit
def eval(params):
  """Returns the accuracy of the global `model` on `ds_test` for `params`.

  Note: this name shadows Python's built-in `eval` from here on.
  """
  test_logits = model.apply({'params': params}, ds_test['image'])
  predicted = jnp.argmax(test_logits, axis=-1)
  is_correct = predicted == ds_test['label']
  return jnp.mean(is_correct)

# Accuracy of the freshly initialized parameters.
eval(variables['params'])

DeviceArray(0.103, dtype=float32)

In [7]:
# Hyperparameters shared by the optimizer examples that follow.
learning_rate = 0.01
momentum = 0.9

### Replacing `flax.optim` with `optax`

In [8]:
# Legacy `flax.optim` version: the optimizer object exposes the parameters as
# `optimizer.target`, and `apply_gradient` returns an updated optimizer.
momentum_def = flax.optim.Momentum(learning_rate, momentum)
optimizer = momentum_def.create(variables['params'])


@jax.jit
def train_step(optimizer, batch):
  """Takes one gradient step on `batch` and returns the updated optimizer."""
  gradients = jax.grad(loss)(optimizer.target, batch)
  return optimizer.apply_gradient(gradients)


# One pass over the training data.
for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)

eval(optimizer.target)

DeviceArray(0.9165, dtype=float32)

In [9]:
# Optax version: the gradient transformation `tx`, the parameters and the
# optimizer state are three separate values.
params = variables['params']
tx = optax.sgd(learning_rate, momentum)
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  gradients = jax.grad(loss)(params, batch)
  deltas, next_opt_state = tx.update(gradients, opt_state)
  next_params = optax.apply_updates(params, deltas)
  return next_params, next_opt_state


# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.9165, dtype=float32)

In [10]:
@jax.jit
def train_step(state, batch):
  """Takes one gradient step on `batch` and returns the updated `TrainState`."""
  def loss(params):
    # Same cross-entropy as the module-level `loss`, but routed through
    # `state.apply_fn` so the step depends only on what `state` carries.
    logits = state.apply_fn({'params': params}, batch['image'])
    one_hot = jax.nn.one_hot(batch['label'], 10)
    return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
  grads = jax.grad(loss)(state.params)
  return state.apply_gradients(grads=grads)

tx = optax.sgd(learning_rate, momentum)
state = train_state.TrainState.create(
    apply_fn=model.apply, tx=tx, params=variables['params'],
)

# One pass over the training data.
for batch in get_ds_train():
  state = train_step(state, batch)

# Evaluate the parameters held by `state`. This cell used to evaluate the
# global `params` left over from the preceding cell, which the loop above never
# touches, and it also created an `opt_state` that was never used (`TrainState`
# manages the optimizer state itself).
eval(state.params)

DeviceArray(0.9165, dtype=float32)

### Composable Gradient Transformations

In [11]:
# Momentum SGD written as an explicit chain: accumulate a momentum trace, then
# scale the result by the negative learning rate.
tx = optax.chain(
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)
params = variables['params']
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  gradients = jax.grad(loss)(params, batch)
  deltas, next_opt_state = tx.update(gradients, opt_state)
  next_params = optax.apply_updates(params, deltas)
  return next_params, next_opt_state


# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.9165, dtype=float32)

### Weight Decay

In [12]:
# Weight-decay coefficient used by the two Adam examples in this section.
weight_decay = 1e-5

In [13]:
# Legacy `flax.optim` version: Adam receives the weight decay as a constructor
# argument.
adam_def = flax.optim.Adam(learning_rate, weight_decay=weight_decay)
optimizer = adam_def.create(variables['params'])


@jax.jit
def train_step(optimizer, batch):
  """Takes one gradient step on `batch` and returns the updated optimizer."""
  gradients = jax.grad(loss)(optimizer.target, batch)
  return optimizer.apply_gradient(gradients)


# One pass over the training data.
for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)

eval(optimizer.target)

DeviceArray(0.95129997, dtype=float32)

In [14]:
# Optax version: Adam scaling, then decayed weights added to the update, then
# scaling by the negative learning rate.
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    optax.scale(-learning_rate),
)
params = variables['params']
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  gradients = jax.grad(loss)(params, batch)
  # Unlike the earlier optax cells, the current `params` are handed to
  # `tx.update`; the weight-decay transformation needs them.
  deltas, next_opt_state = tx.update(gradients, opt_state, params)
  next_params = optax.apply_updates(params, deltas)
  return next_params, next_opt_state


# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.9517, dtype=float32)

### Gradient Clipping

In [15]:
# Threshold on the global gradient norm used by the clipping examples in this
# section.
grad_clip_norm = 1.0

In [16]:
@jax.jit
def train_step(optimizer, batch):
  """Takes one step with gradients rescaled by their global L2 norm.

  The gradients are multiplied by `min(1, grad_clip_norm / global_l2)`, where
  `global_l2` is the L2 norm taken over all gradient leaves together.
  """
  raw_grads = jax.grad(loss)(optimizer.target, batch)
  squared_norms = [jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(raw_grads)]
  global_l2 = jnp.sqrt(sum(squared_norms))
  # The factor is capped at 1, so gradients are only ever shrunk.
  g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
  clipped_grads = jax.tree_util.tree_map(lambda g: g * g_factor, raw_grads)
  return optimizer.apply_gradient(clipped_grads)


momentum_def = flax.optim.Momentum(learning_rate, momentum)
optimizer = momentum_def.create(variables['params'])

# One pass over the training data.
for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)

eval(optimizer.target)

DeviceArray(0.91679996, dtype=float32)

In [17]:
# Optax version: clipping is the first link of the chain, followed by the
# momentum trace and scaling by the negative learning rate.
tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)
params = variables['params']
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  gradients = jax.grad(loss)(params, batch)
  deltas, next_opt_state = tx.update(gradients, opt_state)
  next_params = optax.apply_updates(params, deltas)
  return next_params, next_opt_state


# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.91679996, dtype=float32)

### Learning Rate Schedules

In [18]:
def schedule(step):
  """Returns the learning rate for `step`: `learning_rate * exp(step * 1e-3)`.

  The exponent is positive, so the rate grows with `step`.
  """
  growth = jnp.exp(step * 1e-3)
  return learning_rate * growth

In [19]:
@jax.jit
def train_step(step, optimizer, batch):
  """Takes one step at learning rate `schedule(step)`.

  Returns:
    `(step + 1, updated_optimizer)`.
  """
  gradients = jax.grad(loss)(optimizer.target, batch)
  current_lr = schedule(step)
  updated = optimizer.apply_gradient(gradients, learning_rate=current_lr)
  return step + 1, updated


# Legacy `flax.optim` version: the step counter is threaded through the loop by
# hand and the scheduled rate is passed to `apply_gradient` on every call.
momentum_def = flax.optim.Momentum(learning_rate, momentum)
optimizer = momentum_def.create(variables['params'])
step = jnp.array(0)
for batch in get_ds_train():
  step, optimizer = train_step(step, optimizer, batch)

eval(optimizer.target)

DeviceArray(0.9201, dtype=float32)

In [20]:
def negative_schedule(step):
  """Returns `-schedule(step)`, the factor applied to the traced gradients."""
  return -schedule(step)


# Optax version: the schedule is part of the chain, so `train_step` does not
# take a step counter.
tx = optax.chain(
    optax.trace(decay=momentum),
    optax.scale_by_schedule(negative_schedule),
)
params = variables['params']
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  gradients = jax.grad(loss)(params, batch)
  deltas, next_opt_state = tx.update(gradients, opt_state)
  next_params = optax.apply_updates(params, deltas)
  return next_params, next_opt_state


# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.9201, dtype=float32)

### Multiple Optimizers

In [21]:
# Traversals that pick parameters by path: one for paths containing 'kernel',
# one for paths containing 'bias'.
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

# Legacy `flax.optim` version: one Momentum optimizer per traversal, with the
# bias optimizer using a tenth of the learning rate.
kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)
multi_def = flax.optim.MultiOptimizer((kernels, kernel_opt), (biases, bias_opt))
optimizer = multi_def.create(variables['params'])


@jax.jit
def train_step(optimizer, batch):
  """Takes one gradient step on `batch` and returns the updated optimizer."""
  gradients = jax.grad(loss)(optimizer.target, batch)
  return optimizer.apply_gradient(gradients)


# One pass over the training data.
for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)

eval(optimizer.target)

DeviceArray(0.91679996, dtype=float32)

In [22]:
@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

# Reset the parameters first, so the masks below are derived from this cell's
# own starting point rather than from whatever `params` an earlier cell left
# behind.
params = variables['params']

# Traversals that pick parameters by path.
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

# Boolean trees with the structure of `params`: True marks the leaves selected
# by the corresponding traversal, every other leaf stays False.
all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

# Shared momentum trace, then a per-group scale applied through the masks
# (the bias group uses a tenth of the learning rate).
tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)
opt_state = tx.init(params)

# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.91679996, dtype=float32)

In [23]:
@jax.jit
def train_step(params, opt_state, batch):
  """Applies one `tx` update; returns `(new_params, new_opt_state)`."""
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

# Reset the parameters first, so the label tree below is derived from this
# cell's own starting point rather than from whatever `params` an earlier cell
# left behind.
params = variables['params']

# Traversals that pick parameters by path.
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

# Tree with the structure of `params` in which the leaves selected by each
# traversal are replaced by that group's label. The boolean masks that an
# earlier version of this cell also built were never used with
# `multi_transform` and have been dropped.
param_labels = kernels.update(
    lambda _: 'kernels', biases.update(lambda _: 'biases', params))

# Shared momentum trace, then one scale per label (the bias group uses a tenth
# of the learning rate).
tx = optax.chain(
    optax.trace(decay=momentum),
    optax.multi_transform(
        {
            'kernels': optax.scale(-learning_rate),
            'biases': optax.scale(-learning_rate * 0.1),
        },
        param_labels,
    ),
)
opt_state = tx.init(params)

# One pass over the training data.
for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

eval(params)

DeviceArray(0.91679996, dtype=float32)