# Optax Lookahead Optimizer: Bug Identification and Fix

This notebook demonstrates how to identify, fix, and verify a bug in how the Optax lookahead optimizer is used. We will:

1. Identify the issue in the code.
2. Reproduce the bug.
3. Apply the fix.
4. Verify the fix with unit tests.
5. Check the output.
6. Run the fixed code in the integrated terminal.

## 1. Identify the Issue

Suppose our code misuses the Optax lookahead optimizer, both in how it is constructed and initialized and in how it is applied in the training loop. Below is the problematic code:

```python
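import optax

base_optimizer = optax.sgd(learning_rate=0.1)
# Problem 1: optax.lookahead also requires sync_period and slow_step_size,
# so this call raises a TypeError.
lookahead = optax.lookahead(base_optimizer)
# ...
# Problem 2: the training loop then passes a plain parameter dict to
# lookahead.init/update instead of optax.LookaheadParams (fast + slow weights).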
```

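The issue: the lookahead optimizer is both misconfigured and misused. `optax.lookahead` is called without its required `sync_period` and `slow_step_size` arguments, and the training loop treats it like an ordinary optimizer by passing a plain parameter dict. Lookahead maintains two copies of the weights in an `optax.LookaheadParams`: fast weights, updated every step by the base optimizer, and slow weights, which are pulled toward the fast weights every `sync_period` steps. Without both copies, the optimizer cannot synchronize them.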

## 2. Reproduce the Bug

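Let's reproduce the bug with a minimal MNIST-style training loop that uses the lookahead optimizer incorrectly. The batch is dummy data with MNIST shapes (32 flattened 28x28 images, 10 classes), not real MNIST. Each misuse raises an error; the cell catches and prints it so that the remaining cells, which reuse `x`, `y`, and `loss_fn`, can still run.

```python
# Minimal MNIST-style training loop that reproduces the lookahead misuse
import jax
import jax.numpy as jnp
import optax

# Dummy batch with MNIST shapes (not real data): 32 flattened 28x28 images
x = jnp.ones((32, 784))
y = jnp.zeros((32,), dtype=jnp.int32)  # integer class labels

# Simple linear model: (32, 784) @ (784, 10) + (10,) -> (32, 10) logits
def model(params, x):
    return jnp.dot(x, params['w']) + params['b']

# Cross-entropy on integer labels. (An MSE between (32, 10) logits and
# (32,) labels would fail to broadcast and hide the lookahead bug.)
def loss_fn(params, x, y):
    logits = model(params, x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}
base_optimizer = optax.sgd(learning_rate=0.1)

# Misuse 1: sync_period and slow_step_size are required, so this raises TypeError.
try:
    lookahead = optax.lookahead(base_optimizer)
except TypeError as err:
    print("Misuse 1:", err)

# Misuse 2: even with the required arguments (illustrative values), the loop
# passes a plain dict instead of optax.LookaheadParams. init tolerates this,
# but update needs params.fast and params.slow and raises AttributeError.
lookahead = optax.lookahead(base_optimizer, sync_period=5, slow_step_size=0.5)
opt_state = lookahead.init(params)

@jax.jit
def update(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, new_opt_state = lookahead.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

try:
    for step in range(5):
        params, opt_state = update(params, opt_state, x, y)
        print(f"Step {step}, Loss: {float(loss_fn(params, x, y)):.4f}")
except AttributeError as err:
    print("Misuse 2:", err)
```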

## 3. Apply the Fix

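To fix the bug, construct the lookahead optimizer with its required `sync_period` and `slow_step_size` arguments, wrap the parameters with `optax.LookaheadParams.init_synced` so that the optimizer tracks both fast and slow weights, differentiate the loss with respect to `params.fast`, and pass the full `LookaheadParams` to `lookahead.update` and `optax.apply_updates`. Here is the corrected code:

```python
# Corrected lookahead usage (reuses x, y and loss_fn from the previous cell)
base_optimizer = optax.sgd(learning_rate=0.1)
# Fix 1: pass the required arguments (illustrative values: synchronize every
# 5 steps, moving the slow weights halfway toward the fast weights).
lookahead = optax.lookahead(base_optimizer, sync_period=5, slow_step_size=0.5)

# Fix 2: keep fast and slow copies of the weights in optax.LookaheadParams.
params = optax.LookaheadParams.init_synced(
    {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}
)
opt_state = lookahead.init(params)

@jax.jit
def update(params, opt_state, x, y):
    # Differentiate the loss with respect to the fast weights only.
    grads = jax.grad(loss_fn)(params.fast, x, y)
    # update() needs both copies to compute the slow-weight synchronization.
    updates, new_opt_state = lookahead.update(grads, opt_state, params)
    # updates is itself a LookaheadParams, so fast and slow move together.
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

for step in range(5):
    params, opt_state = update(params, opt_state, x, y)
    # Always carry the returned opt_state forward; report the fast-weight loss.
    print(f"Step {step}, Loss: {float(loss_fn(params.fast, x, y)):.4f}")
```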

## 4. Verify the Fix with Unit Tests

Let's write a simple test to confirm that the lookahead optimizer now updates the parameters and state as expected.

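```python
# Simple test: the loss decreases and the optimizer state advances
params = optax.LookaheadParams.init_synced(
    {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}
)
opt_state = lookahead.init(params)
initial_loss = loss_fn(params.fast, x, y)
for _ in range(3):
    params, opt_state = update(params, opt_state, x, y)
final_loss = loss_fn(params.fast, x, y)
assert final_loss < initial_loss, "Loss did not decrease as expected!"
# Three steps with sync_period=5: no synchronization has happened yet
assert int(opt_state.steps_since_sync) == 3, "Step counter did not advance!"
print(f"Initial loss: {float(initial_loss):.4f}, Final loss: {float(final_loss):.4f}")
```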

## 5. Check Output in Output Pane

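The test cell should raise no `AssertionError` and should print a final loss below the initial loss of about 2.3026 (log 10, since all-zero weights give uniform predictions over the 10 classes). On this trivially easy dummy batch the loss falls close to zero after the first step. At a synchronization step the fast weights are pulled back toward the slow weights, so on harder data the fast-weight loss may briefly rise at those steps.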

## 6. Run in Integrated Terminal

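To validate end-to-end behavior, run the fixed code outside the current notebook session, either as a script in the integrated terminal or after restarting the kernel and running all cells. A fresh process confirms that the fix does not depend on state left over from earlier cells, such as a stale `update` function or `opt_state`.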