# Lookahead Optimizer on MNIST

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/lookahead_mnist.ipynb)

This notebook trains a simple Convolutional Neural Network (CNN) for handwritten digit recognition (MNIST dataset) using {py:func}`optax.lookahead`.

To run the colab locally, you need to install the `tensorflow` and `tensorflow-datasets` packages via `pip`.

In [None]:
from flax import linen as nn
import jax
import jax.numpy as jnp
import optax

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# @markdown The learning rate for the fast optimizer:
FAST_LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown The step size (interpolation coefficient) for the slow parameters:
SLOW_LEARNING_RATE = 0.1 # @param{type:"number"}
# @markdown Number of fast optimizer steps to take before synchronizing parameters:
SYNC_PERIOD = 5 # @param{type:"integer"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 256 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 1 # @param{type:"integer"}

MNIST is a dataset of 28x28 grayscale images (1 channel). We now load the dataset using `tensorflow_datasets`, apply min-max normalization to the images (scaling pixel values from [0, 255] to [0, 1]), shuffle the training set, and create batches of size `BATCH_SIZE`.
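For orientation, the arithmetic behind this step can be reproduced in plain NumPy. This sketch assumes the standard MNIST split sizes of 60,000 training and 10,000 test images. Because `drop_remainder=True` is used for both splits, the last partial batch of the *test* set is dropped as well, so test metrics are computed on slightly fewer than 10,000 images.

```python
import numpy as np

batch_size = 256                      # same value as the default BATCH_SIZE
num_train, num_test = 60_000, 10_000  # standard MNIST split sizes

# Min-max scaling: uint8 pixels in [0, 255] become float32 values in [0, 1].
fake_pixels = np.array([[0, 128, 255]], dtype=np.uint8)  # hypothetical pixels
scaled = fake_pixels.astype(np.float32) / 255.0

# With drop_remainder=True only full batches are kept.
num_train_batches = num_train // batch_size
num_test_batches = num_test // batch_size
dropped_test_images = num_test - num_test_batches * batch_size

print(num_train_batches, num_test_batches, dropped_test_images)  # prints: 234 39 16
```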


In [None]:
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)
NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape

min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)

The data is ready! Next let's define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple CNN.

In [None]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
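
The convolutional stack is flattened before the dense layers, so it helps to know the size of that flattened vector. The sketch below does the shape and parameter arithmetic by hand, assuming Flax's defaults of `'SAME'` padding for `nn.Conv` and `'VALID'` padding for `nn.avg_pool`; the helper names are hypothetical and exist only for this illustration.

```python
# Hand-computed shapes for the CNN above, assuming Flax defaults:
# nn.Conv uses 'SAME' padding (spatial size preserved) and nn.avg_pool
# uses 'VALID' padding, so a 2x2 window with stride 2 halves the size.
height = width = 28
in_channels = 1


def conv_params(kernel, c_in, c_out):
  """Weights plus biases of a 2D convolution with a square kernel."""
  return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in, n_out):
  """Weights plus biases of a dense layer."""
  return n_in * n_out + n_out


# Conv(32) keeps 28x28; avg_pool halves it to 14x14.
p_conv1 = conv_params(3, in_channels, 32)
height, width = height // 2, width // 2
# Conv(64) keeps 14x14; avg_pool halves it to 7x7.
p_conv2 = conv_params(3, 32, 64)
height, width = height // 2, width // 2

flat_features = height * width * 64  # input size of the first Dense layer
p_dense1 = dense_params(flat_features, 256)
p_dense2 = dense_params(256, 10)

total = p_conv1 + p_conv2 + p_dense1 + p_dense2
print(flat_features, total)  # prints: 3136 824458
```

Roughly 97% of these parameters sit in the first `Dense` layer, which is typical for small CNNs that flatten early.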

In [None]:
net = CNN()

@jax.jit
def predict(params, inputs):
  return net.apply({'params': params}, inputs)


@jax.jit
def loss_accuracy(params, data):
  """Computes loss and accuracy over a mini-batch.

  Args:
    params: parameters of the model.
    data: tuple of (inputs, labels).

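  Returns:
    A tuple `(loss, aux)` where `loss` is the mean cross-entropy over the
    mini-batch and `aux` is a dict holding the mini-batch accuracy under
    the key "accuracy".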
  """
  inputs, labels = data
  logits = predict(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
  return loss, {"accuracy": accuracy}
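
Per example, `optax.softmax_cross_entropy_with_integer_labels` equals the negative log-softmax probability of the true class. The NumPy sketch below recomputes loss and accuracy for two made-up examples with three classes; it illustrates the formula and is not a replacement for the Optax function.

```python
import numpy as np


def log_softmax(logits):
  # Subtract the row max for numerical stability before exponentiating.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


# Hypothetical logits for 2 examples and 3 classes, with integer labels.
toy_logits = np.array([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.0]])
toy_labels = np.array([0, 2])

# Integer-label cross-entropy: negative log-probability of the true class.
per_example = -log_softmax(toy_logits)[np.arange(len(toy_labels)), toy_labels]
toy_loss = per_example.mean()

# Accuracy: fraction of examples whose arg-max class equals the label.
toy_accuracy = np.mean(np.argmax(toy_logits, axis=-1) == toy_labels)
print(toy_loss, toy_accuracy)
```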

Next we need to initialize the CNN parameters and the solver state. We also define a convenience function `dataset_stats` that we'll call once per epoch to collect the loss and accuracy of the model over the test set. We will be using the Lookahead optimizer, whose wrapper keeps a pair of slow and fast parameters. To initialize them, we create a pair of synchronized (identical) copies of the initial model parameters.


## Understanding the Lookahead Optimizer

The lookahead optimizer is a wrapper that maintains two sets of parameters:
- **Fast parameters**: Updated at every step by an inner "fast" optimizer (e.g., Adam, SGD)
- **Slow parameters**: Updated only once every `sync_period` steps, by interpolating toward the fast parameters

### Key Concepts:

1. **`LookaheadParams`**: A special container holding both fast and slow parameters
   - `params.fast`: The fast-changing parameters used for gradient computation
   - `params.slow`: The slow-changing parameters used for validation/inference

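2. **Synchronization Period**: Every `k` steps (`sync_period`, e.g., 5 steps), the slow parameters are moved toward the fast parameters, and the fast parameters are then reset to the new slow parameters:
   ```
   slow_params = slow_params + slow_step_size * (fast_params - slow_params)
   fast_params = slow_params
   ```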

3. **Initialization**: Always use `optax.LookaheadParams.init_synced(params)` to wrap your initial parameters

4. **Gradient Computation**: Always compute gradients with respect to `params.fast`

5. **Validation/Inference**: Use `params.slow` for final predictions and validation
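
To make the synchronization rule in item 2 concrete (including the reset of the fast parameters after each sync), here is a from-scratch simulation with plain gradient descent as the fast optimizer on the one-dimensional loss `f(x) = x**2 / 2`, whose gradient is `x`. It is a hand-written sketch of the Lookahead rule, not Optax's implementation. With a fast learning rate of 0.1 each fast step multiplies `x` by 0.9, so with `k = 2` and `slow_step_size = 0.5` every sync maps the slow value `s` to `s + 0.5 * (0.81 * s - s) = 0.905 * s`.

```python
def lookahead_gd(x0, fast_lr, sync_period, slow_step_size, num_syncs):
  """Lookahead around plain gradient descent on f(x) = x**2 / 2 (gradient = x)."""
  slow = fast = x0
  for _ in range(num_syncs):
    # Inner loop: `sync_period` fast gradient-descent steps.
    for _ in range(sync_period):
      fast = fast - fast_lr * fast
    # Synchronization: move slow toward fast, then reset fast to slow.
    slow = slow + slow_step_size * (fast - slow)
    fast = slow
  return slow, fast


slow, fast = lookahead_gd(
    1.0, fast_lr=0.1, sync_period=2, slow_step_size=0.5, num_syncs=3
)
print(slow, fast)  # both equal 0.905 ** 3 up to floating-point rounding
```

Setting `slow_step_size=1.0` makes the slow parameters copy the fast ones at every sync, which reduces the scheme to the underlying optimizer; smaller values pull each period's end point back toward where the period started.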

Let's see this in action!

### Step 1: Initialize the Lookahead Optimizer

First, we create the fast optimizer (Adam in this case) and wrap it with lookahead:

In [None]:
# Step 1: Create the fast optimizer (e.g., Adam)
fast_solver = optax.adam(FAST_LEARNING_RATE)

# Step 2: Wrap it with lookahead
# - sync_period: how often to synchronize fast and slow params (every k steps)
# - slow_step_size: interpolation coefficient for slow parameter updates
solver = optax.lookahead(fast_solver, SYNC_PERIOD, SLOW_LEARNING_RATE)

# Step 3: Initialize model parameters (standard Flax initialization)
rng = jax.random.PRNGKey(0)
dummy_data = jnp.ones((1,) + IMG_SIZE, dtype=jnp.float32)
params = net.init({"params": rng}, dummy_data)["params"]

# Step 4: CRITICAL - Wrap parameters in LookaheadParams
# This creates both fast and slow params, initially synchronized
params = optax.LookaheadParams.init_synced(params)

# Step 5: Initialize optimizer state
solver_state = solver.init(params)

print("✓ Lookahead optimizer initialized!")
print(f"  Fast learning rate: {FAST_LEARNING_RATE}")
print(f"  Slow learning rate: {SLOW_LEARNING_RATE}")
print(f"  Sync period: {SYNC_PERIOD} steps")
print("\nParameters structure:")
print(f"  Type: {type(params)}")
print(f"  Has 'fast' attribute: {hasattr(params, 'fast')}")
print(f"  Has 'slow' attribute: {hasattr(params, 'slow')}")


def dataset_stats(params, data_loader):
  """Computes loss and accuracy over the dataset `data_loader`."""
  all_accuracy = []
  all_loss = []
  for batch in data_loader.as_numpy_iterator():
    batch_loss, batch_aux = loss_accuracy(params, batch)
    all_loss.append(batch_loss)
    all_accuracy.append(batch_aux["accuracy"])
  return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}
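
`dataset_stats` averages per-batch means with equal weight. That equals the mean over all evaluated examples only because every batch has the same size, which `drop_remainder=True` guarantees. The sketch below uses made-up per-example losses to show the discrepancy that a smaller last batch would introduce; weighting each batch mean by its batch size is an alternative not used in this notebook.

```python
import numpy as np

# Hypothetical per-example losses: 5 examples split into batches of size 2.
per_example_loss = np.array([1.0, 1.0, 1.0, 1.0, 4.0])
toy_batch_size = 2

batches = [per_example_loss[i:i + toy_batch_size]
           for i in range(0, len(per_example_loss), toy_batch_size)]
# Keeping the ragged last batch gives [1, 1], [1, 1], [4].
mean_of_batch_means = np.mean([b.mean() for b in batches])  # (1 + 1 + 4) / 3 = 2.0
true_mean = per_example_loss.mean()                         # 8 / 5 = 1.6

# Dropping the remainder keeps only full batches, so equal weighting is exact.
full = [b for b in batches if len(b) == toy_batch_size]
print(mean_of_batch_means, true_mean, np.mean([b.mean() for b in full]))
```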

### Step 2: Training Loop

Now we train the model. The key points to remember:

1. **Compute gradients on `params.fast`**: The fast parameters are updated every step
2. **Pass full `params` to `solver.update()`**: The optimizer needs both fast and slow params
3. **Use `params.slow` for validation**: The slow parameters typically generalize better

Finally, we do the actual training. The next cell trains the model for `N_EPOCHS` epochs. Within each epoch we iterate over the batched loader `train_loader_batched`, and once per epoch we also compute the test set accuracy and loss.

In [None]:
train_accuracy = []
train_losses = []

# Computes test set accuracy at initialization.
# NOTE: We use params.slow for validation!
test_stats = dataset_stats(params.slow, test_loader_batched)
test_accuracy = [test_stats["accuracy"]]
test_losses = [test_stats["loss"]]


@jax.jit
def train_step(params, solver_state, batch):
  """Performs a single training step with lookahead optimizer.
  
  Key points:
  1. Compute loss and gradients on params.fast (the fast-changing parameters)
  2. Pass gradients and full params to solver.update()
  3. Apply updates to get new params (the fast params change every step,
     the slow params only on synchronization steps)
  """
  # IMPORTANT: Compute gradients with respect to params.fast
  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(
      params.fast, batch
  )
  # Update both fast and slow params (slow updates happen every sync_period steps)
  updates, solver_state = solver.update(grad, solver_state, params)
  # Apply updates; fast and slow params coincide only right after a sync step
  params = optax.apply_updates(params, updates)
  return params, solver_state, loss, aux


for epoch in range(N_EPOCHS):
  train_accuracy_epoch = []
  train_losses_epoch = []

  for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):
    params, solver_state, train_loss, train_aux = train_step(
        params, solver_state, train_batch
    )
    train_accuracy_epoch.append(train_aux["accuracy"])
    train_losses_epoch.append(train_loss)
    if step % 20 == 0:
      print(
          f"step {step}, train loss: {train_loss:.2e}, train accuracy:"
          f" {train_aux['accuracy']:.2f}"
      )

  # IMPORTANT: Validation is done on the slow lookahead parameters
  # The slow parameters typically provide better generalization
  test_stats = dataset_stats(params.slow, test_loader_batched)
  test_accuracy.append(test_stats["accuracy"])
  test_losses.append(test_stats["loss"])
  train_accuracy.append(np.mean(train_accuracy_epoch))
  train_losses.append(np.mean(train_losses_epoch))

In [None]:
f"Test set accuracy went from {test_accuracy[0]:.4f} to {test_accuracy[-1]:.4f}"

## Summary: Correct Lookahead Usage Pattern

Here's the essential pattern for using the lookahead optimizer:

```python
import optax

# Placeholders for your own code: `learning_rate`, `initialize_your_model`,
# `data_loader`, `compute_gradients`, `model` and `test_data`.

# 1. Create fast optimizer and wrap with lookahead
fast_optimizer = optax.adam(learning_rate)
optimizer = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)

# 2. Initialize your model parameters (any framework)
params = initialize_your_model()

# 3. CRITICAL: Wrap params in LookaheadParams
params = optax.LookaheadParams.init_synced(params)

# 4. Initialize optimizer state
opt_state = optimizer.init(params)

# 5. Training loop
for batch in data_loader:
    # Compute gradients on params.fast (the fast parameters)
    grads = compute_gradients(params.fast, batch)

    # Update optimizer (pass full params object)
    updates, opt_state = optimizer.update(grads, opt_state, params)

    # Apply updates (fast params change every step, slow params only at sync steps)
    params = optax.apply_updates(params, updates)

# 6. For inference/validation, use params.slow
final_predictions = model(params.slow, test_data)
```

### Common Mistakes to Avoid:

❌ **Don't** forget to wrap params: `params = optax.LookaheadParams.init_synced(params)`  
❌ **Don't** compute gradients on `params` directly  
❌ **Don't** use `params.fast` for final validation/inference  

✅ **Do** compute gradients on `params.fast`  
✅ **Do** pass the full `params` object to `optimizer.update()`  
✅ **Do** use `params.slow` for validation and inference  

Apart from wrapping the parameters in `LookaheadParams` and choosing between `params.fast` (for gradients) and `params.slow` (for evaluation), the training loop is the same as with any other Optax optimizer.