# Maximal Update Parameterization (μP) for RNNs — Student Notebook

In this notebook, you will learn how to apply μP (Maximal Update Parameterization) to Recurrent Neural Networks. You'll discover:

1. How different parameterization schemes affect training dynamics
2. Why RNNs require special consideration for width scaling
3. The relationship between spectral radius and gradient flow
4. How to implement μP correctly for RNNs

## Instructions

- **Code cells** marked with `# TODO` or `# YOUR CODE HERE` require you to fill in the missing code
- **Question cells** ask you to write your answer in the provided space
- Run the verification cells to check your answers
- Some cells have hints — try without them first!

In [None]:
import torch as th
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

th.manual_seed(42)
np.random.seed(42)

---

# Part 1: Understanding Parameterization Schemes

When scaling neural networks to different widths, initialization and forward pass scaling dramatically affect training. There are two schemes we consider:

## 1.1 Standard Parameterization (SP)

For weight matrix $W \in \mathbb{R}^{n_{\text{out}} \times n_{\text{in}}}$:
- **Initialization**: $W_{ij} \sim \mathcal{N}(0, 1/n_{\text{in}})$ (fan-in scaling, as in LeCun init; Xavier/Glorot and He init differ from it by constant factors)
- **Forward pass**: $y = Wx$ (no additional scaling)

## 1.2 Maximal Update Parameterization (μP)

- **Initialization**: $W_{ij} \sim \mathcal{N}(0, 1)$ (unit variance)
- **Forward pass scaling**: Depends on layer type
- **Learning rate scaling**: Different multipliers for different layers
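Both schemes aim to keep activations O(1) as the fan-in grows; they just put the $1/\sqrt{n_{\text{in}}}$ factor in different places. The sketch below checks this for a single linear map, assuming i.i.d. standard-normal inputs and using NumPy for illustration only (the helper `output_std` is hypothetical, not part of the notebook):

```python
import numpy as np

def output_std(n_in, init_std, forward_mult, n_out=256, seed=0):
    """Empirical std of y = forward_mult * W @ x, with W_ij ~ N(0, init_std^2) and x ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    W = rng.normal(0.0, init_std, size=(n_out, n_in))
    x = rng.normal(0.0, 1.0, size=(n_in, 64))  # 64 input samples stored as columns
    y = forward_mult * (W @ x)
    return y.std()

for n_in in [64, 1024]:
    sp = output_std(n_in, init_std=1 / np.sqrt(n_in), forward_mult=1.0)   # SP: scale at init
    mup = output_std(n_in, init_std=1.0, forward_mult=1 / np.sqrt(n_in))  # unit init, scale in forward
    raw = output_std(n_in, init_std=1.0, forward_mult=1.0)                # unit init, no scaling
    print(f"n_in={n_in:5d}  SP={sp:.3f}  unit+1/sqrt(n)={mup:.3f}  unit, no mult={raw:.2f}")
```

The first two columns stay near 1 at every width, while the unscaled version grows like $\sqrt{n_{\text{in}}}$.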

---

## Question 1.1: μP Forward Scaling

In μP, different layer types use different forward pass multipliers. Fill in the table:

| Layer Type | Forward Multiplier | LR Multiplier |
|------------|-------------------|---------------|
| Input → Hidden | ??? | 1 |
| Hidden → Hidden | ??? | $1/\text{width\_mult}$ |
| Hidden → Output | ??? | $1/\text{width\_mult}$ |

**Your Answer:**

```
Input → Hidden:   Forward multiplier = _______________
Hidden → Hidden:  Forward multiplier = _______________
Hidden → Output:  Forward multiplier = _______________
```

*Hint: The multipliers involve $n$ (hidden width), $d$ (input dimension), or both. Think about what keeps activations O(1).*

---

# Part 2: The RNN Challenge — Spectral Radius

RNNs have a unique challenge: the hidden-to-hidden weight matrix $W_{hh}$ is applied **repeatedly**:

$$h_t = \phi(W_{xh} x_t + W_{hh} h_{t-1} + b)$$

The **spectral radius** $\rho(W_{hh})$ — the largest absolute eigenvalue — controls whether signals grow or shrink over time.

## 2.1 Random Matrix Theory

For a random matrix $W \in \mathbb{R}^{n \times n}$ with i.i.d. entries $W_{ij} \sim \mathcal{N}(0, \sigma^2)$:

$$\rho(W) \approx \sigma \sqrt{n}$$
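This is the circular law: the eigenvalues of such a matrix fill a disk of radius roughly $\sigma\sqrt{n}$, with small finite-$n$ fluctuations. A quick NumPy sanity check, assuming Gaussian entries and a fixed seed (the helper name `spectral_radius` is illustrative):

```python
import numpy as np

def spectral_radius(W):
    """Largest absolute eigenvalue of a square matrix."""
    return np.abs(np.linalg.eigvals(W)).max()

rng = np.random.default_rng(0)
for n in [100, 400]:
    for sigma in [1.0, 1.0 / np.sqrt(n)]:
        W = rng.normal(0.0, sigma, size=(n, n))
        rho = spectral_radius(W)
        print(f"n={n:4d}  sigma={sigma:.3f}  rho={rho:.3f}  sigma*sqrt(n)={sigma * np.sqrt(n):.3f}")
```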

---

## Question 2.1: Spectral Radius under Different Parameterizations

Use the formula $\rho(W) \approx \sigma\sqrt{n}$ to fill in the effective spectral radius:

**Standard Parameterization (SP):**
- Init: $W_{ij} \sim \mathcal{N}(0, 1/n)$, so $\sigma = 1/\sqrt{n}$
- Forward: Use $W$ directly (no scaling)
- Effective spectral radius $\rho(W) = $ _______________

**μP with $1/\sqrt{n}$ scaling:**
- Init: $W_{ij} \sim \mathcal{N}(0, 1)$, so $\sigma = 1$
- Forward: Use $\frac{1}{\sqrt{n}}W$
- Effective spectral radius $\rho(\frac{1}{\sqrt{n}}W) = $ _______________

**WRONG scaling with $1/n$:**
- Init: $W_{ij} \sim \mathcal{N}(0, 1)$, so $\sigma = 1$
- Forward: Use $\frac{1}{n}W$
- Effective spectral radius $\rho(\frac{1}{n}W) = $ _______________

---

## Exercise 2.2: Verify Spectral Radius Empirically

Complete the function to compute the spectral radius of the effective recurrence matrix under different parameterizations.

In [None]:
def compute_spectral_radius(n, parameterization='sp', num_samples=50):
    """
    Compute spectral radius of the effective recurrence matrix.
    
    Args:
        n: Matrix dimension (hidden size)
        parameterization: 'sp', 'mup_correct', or 'mup_wrong'
        num_samples: Number of random matrices to average over
    
    Returns:
        mean_radius, std_radius
    """
    radii = []
    
    for _ in range(num_samples):
        if parameterization == 'sp':
            # SP: W ~ N(0, 1/n), used directly
            # TODO: Initialize W with variance 1/n
            W = th.randn(n, n) / ...  # YOUR CODE HERE
            effective_W = W
            
        elif parameterization == 'mup_correct':
            # μP (CORRECT): W ~ N(0, 1), scaled by 1/√n in forward pass
            W = th.randn(n, n)
            # TODO: Apply correct μP scaling
            effective_W = W / ...  # YOUR CODE HERE
            
        elif parameterization == 'mup_wrong':
            # WRONG: W ~ N(0, 1), scaled by 1/n
            W = th.randn(n, n)
            # TODO: Apply wrong 1/n scaling
            effective_W = W / ...  # YOUR CODE HERE
        
        else:
            raise ValueError(f"Unknown parameterization: {parameterization!r}")

        # Compute spectral radius (max absolute eigenvalue)
        eigs = th.linalg.eigvals(effective_W)
        radii.append(eigs.abs().max().item())
    
    return np.mean(radii), np.std(radii)

In [None]:
# Test your implementation
widths = [32, 64, 128, 256, 512]
print("Effective Spectral Radius:")
print(f"{'Width':>6} | {'SP':>12} | {'μP (1/√n)':>12} | {'WRONG (1/n)':>12} | {'Theory 1/√n':>12}")
print("-" * 70)

for n in widths:
    sp_rho, _ = compute_spectral_radius(n, 'sp')
    mup_correct_rho, _ = compute_spectral_radius(n, 'mup_correct')
    mup_wrong_rho, _ = compute_spectral_radius(n, 'mup_wrong')
    theoretical_wrong = 1 / math.sqrt(n)
    print(f"{n:>6} | {sp_rho:>12.4f} | {mup_correct_rho:>12.4f} | {mup_wrong_rho:>12.4f} | {theoretical_wrong:>12.4f}")

---

## Question 2.3: Interpreting the Results

Based on the table above, answer:

**A)** Which parameterizations give a spectral radius that is approximately constant (≈1) regardless of width?

**Your Answer:** _______________

**B)** For the "WRONG (1/n)" scaling, how does the spectral radius change as width increases?

**Your Answer:** _______________

**C)** If the spectral radius is $\rho$ and we have $T=100$ timesteps, the gradient magnitude scales roughly as $\rho^T$. For width $n=256$ with wrong 1/n scaling, approximately what is the gradient ratio? (Use $\rho \approx 1/\sqrt{256} = 1/16$)

**Your Answer:** $(1/16)^{100} = $ _______________
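Because $\rho^T$ can be astronomically small, it is often easier to work with its base-10 exponent, $T \log_{10}\rho$. The helper below (name chosen for illustration) computes that exponent directly; for large enough $T$, forming $\rho^T$ itself would underflow to `0.0` in float64. It is shown on values different from the question so you can still work that one out yourself:

```python
import math

def log10_decay(rho, T):
    """Base-10 exponent of rho**T, computed without forming rho**T (avoids float underflow)."""
    return T * math.log10(rho)

# Example with values different from the question: rho = 0.5 over T = 10 steps
exponent = log10_decay(0.5, 10)
print(f"0.5**10 = 10^{exponent:.3f} = {10 ** exponent:.6f}")
```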

---

# Part 3: Implementing RNN Layers

Now let's implement RNN layers with different parameterizations and measure their gradient flow.

## Exercise 3.1: Standard Parameterization RNN

Complete the `StandardRNNLayer` class.

In [None]:
class StandardRNNLayer(nn.Module):
    """
    Standard Parameterization RNN.
    
    - Init: Xavier-style fan-in scaling (std 1/√fan_in, i.e. variance 1/fan_in)
    - Forward: No additional scaling
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # TODO: Initialize W_xh with Xavier scaling (divide by √input_size)
        self.W_xh = nn.Parameter(th.randn(hidden_size, input_size) / ...)  # YOUR CODE HERE
        
        # TODO: Initialize W_hh with Xavier scaling (divide by √hidden_size)
        self.W_hh = nn.Parameter(th.randn(hidden_size, hidden_size) / ...)  # YOUR CODE HERE
        
        self.b = nn.Parameter(th.zeros(hidden_size))
    
    def forward(self, x):
        B, T, _ = x.shape
        h = th.zeros(B, self.hidden_size, device=x.device)
        h_list = []
        
        for t in range(T):
            # Standard forward pass: no additional scaling
            h = th.tanh(F.linear(x[:, t], self.W_xh) + F.linear(h, self.W_hh) + self.b)
            h_list.append(h)
            h.retain_grad()
        
        self.h_list = h_list
        return th.stack(h_list, dim=1), h

## Exercise 3.2: Correct μP RNN Layer

This is the key exercise! Complete the `MuPRNNLayer` with **correct** μP scaling.

Remember:
- Input → Hidden: scale by $1/\sqrt{d}$ where $d$ is input dimension
- Hidden → Hidden: scale by $1/\sqrt{n}$ where $n$ is hidden size (**NOT** $1/n$!)
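To see why the recurrence uses $1/\sqrt{n}$ rather than $1/n$, the sketch below measures the typical size of the recurrent term $W_{hh} h / s$ for a hidden state with O(1) entries. It assumes unit-variance $W_{hh}$ and uses $\tanh$ of a standard normal as a stand-in for a real hidden state; the function name is hypothetical:

```python
import numpy as np

def recurrent_contribution_rms(n, scale, seed=0):
    """RMS of (W_hh @ h) / scale for unit-variance W_hh and an O(1) hidden state h."""
    rng = np.random.default_rng(seed)
    W_hh = rng.normal(0.0, 1.0, size=(n, n))    # unit-variance init, as in muP
    h = np.tanh(rng.normal(0.0, 1.0, size=n))   # stand-in hidden state with O(1) entries
    pre = (W_hh @ h) / scale
    return np.sqrt(np.mean(pre ** 2))

for n in [64, 256, 1024]:
    good = recurrent_contribution_rms(n, np.sqrt(n))  # 1/sqrt(n): stays O(1)
    bad = recurrent_contribution_rms(n, n)            # 1/n: shrinks like 1/sqrt(n)
    print(f"n={n:5d}  1/sqrt(n): {good:.3f}   1/n: {bad:.4f}")
```

With $1/n$ scaling, the recurrent term fades as width grows, so a wide network effectively stops using its memory.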

In [None]:
class MuPRNNLayer(nn.Module):
    """
    CORRECT μP RNN Layer.
    
    - Init: Unit variance for all weights
    - Forward: 1/√d for input, 1/√n for recurrence
    """
    def __init__(self, input_size, hidden_size, base_hidden_size=64):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.base_hidden_size = base_hidden_size
        self.width_mult = hidden_size / base_hidden_size
        
        # μP: Unit variance initialization (no scaling here)
        self.W_xh = nn.Parameter(th.randn(hidden_size, input_size))
        self.W_hh = nn.Parameter(th.randn(hidden_size, hidden_size))
        self.b = nn.Parameter(th.zeros(hidden_size))
    
    def forward(self, x):
        B, T, d = x.shape
        n = self.hidden_size
        h = th.zeros(B, n, device=x.device)
        h_list = []
        
        for t in range(T):
            # TODO: Complete the μP forward pass
            # - Input projection should be scaled by 1/√d
            # - Recurrence should be scaled by 1/√n (NOT 1/n!)
            h = th.tanh(
                F.linear(x[:, t], self.W_xh) / ... +  # YOUR CODE HERE: input scaling
                F.linear(h, self.W_hh) / ... +        # YOUR CODE HERE: recurrence scaling
                self.b
            )
            h_list.append(h)
            h.retain_grad()
        
        self.h_list = h_list
        return th.stack(h_list, dim=1), h
    
    def get_lr_multipliers(self):
        """μP learning rate multipliers."""
        return {
            'W_xh': 1.0,
            'W_hh': 1.0 / self.width_mult,  # Key μP LR scaling
            'b': 1.0,
        }

## Exercise 3.3: Wrong Scaling RNN (for comparison)

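This layer is already implemented, **incorrectly on purpose**, with $1/n$ recurrence scaling. There is nothing to fill in: run the cell and compare its behavior against the correct μP layer in the experiments below.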

In [None]:
class WrongScalingRNNLayer(nn.Module):
    """
    INCORRECTLY parameterized RNN — DO NOT USE IN PRACTICE!
    
    This uses 1/n scaling for recurrence, which causes vanishing gradients.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.W_xh = nn.Parameter(th.randn(hidden_size, input_size))
        self.W_hh = nn.Parameter(th.randn(hidden_size, hidden_size))
        self.b = nn.Parameter(th.zeros(hidden_size))
    
    def forward(self, x):
        B, T, d = x.shape
        n = self.hidden_size
        h = th.zeros(B, n, device=x.device)
        h_list = []
        
        for t in range(T):
            # WRONG: 1/n scaling causes vanishing gradients!
            h = th.tanh(
                F.linear(x[:, t], self.W_xh) / math.sqrt(d) +
                F.linear(h, self.W_hh) / n +  # WRONG: should be √n
                self.b
            )
            h_list.append(h)
            h.retain_grad()
        
        self.h_list = h_list
        return th.stack(h_list, dim=1), h

---

## Exercise 3.4: Measuring Gradient Flow

Complete the function to measure how well gradients flow backward through time.
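As a rough mental model: ignoring the $\tanh$ derivative (which lies in $(0, 1]$ and typically shrinks gradients further), backpropagating through one step of the recurrence multiplies the gradient by the transposed recurrence matrix. The sketch below is a linearized approximation, not what `measure_gradient_flow` computes; it tracks $\|(A^\top)^k g\| / \|g\|$ for $A = W/\sqrt{n}$ and $A = W/n$ with unit-variance $W$ (the function name is illustrative):

```python
import numpy as np

def linearized_grad_ratio(n, scale, steps, seed=0):
    """||(A^T)^steps g|| / ||g|| for A = W / scale, W_ij ~ N(0, 1); tanh is ignored."""
    rng = np.random.default_rng(seed)
    A = rng.normal(0.0, 1.0, size=(n, n)) / scale
    g = rng.normal(0.0, 1.0, size=n)  # stand-in for dL/dh at the last timestep
    g0 = np.linalg.norm(g)
    for _ in range(steps):
        g = A.T @ g                   # one backward step through a linear recurrence
    return np.linalg.norm(g) / g0

for n in [64, 256]:
    print(f"n={n:4d}  1/sqrt(n): {linearized_grad_ratio(n, np.sqrt(n), 9):.3g}   "
          f"1/n: {linearized_grad_ratio(n, n, 9):.3g}")
```

The $1/\sqrt{n}$ column stays within a modest factor of 1 (random matrices are not normal, so it can drift somewhat), while the $1/n$ column loses roughly a factor of $\sqrt{n}$ per step.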

In [None]:
def measure_gradient_flow(rnn_class, hidden_size, seq_len=10, num_trials=20, **kwargs):
    """
    Measure gradient flow through time.
    
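    Returns the ratio: ||∇_{h_0} L|| / ||∇_{h_T} L||, where h_0 = rnn.h_list[0] is the
    first computed hidden state and h_T = rnn.h_list[-1] the final one, so the ratio
    spans T-1 recurrent steps.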
    
    A ratio close to 1 means gradients flow well.
    A ratio << 1 means gradients are vanishing.
    """
    grad_ratios = []
    
    for _ in range(num_trials):
        rnn = rnn_class(input_size=8, hidden_size=hidden_size, **kwargs)
        x = th.randn(4, seq_len, 8)
        
        _, last_h = rnn(x)
        
        # TODO: Compute a simple loss at the final timestep
        loss = ...  # YOUR CODE HERE (hint: sum of last_h)
        loss.backward()
        
        # Gradient norms at the first and last timestep
        grad_first = rnn.h_list[0].grad.norm().item()   # first hidden state
        grad_last = rnn.h_list[-1].grad.norm().item()   # final hidden state
        
        # TODO: Compute the ratio (with numerical stability check)
        if grad_last > 1e-10:
            ratio = ...  # YOUR CODE HERE
            grad_ratios.append(ratio)
    
    return np.mean(grad_ratios), np.std(grad_ratios)

In [None]:
# Test gradient flow across different widths
widths = [32, 64, 128, 256]
seq_len = 10

print(f"Gradient Flow Ratio (∇h₀/∇h_T) for T={seq_len}:")
print(f"{'Width':>6} | {'SP':>16} | {'μP (correct)':>16} | {'1/n (WRONG)':>16}")
print("-" * 60)

for n in widths:
    sp_mean, sp_std = measure_gradient_flow(StandardRNNLayer, n, seq_len)
    mup_mean, mup_std = measure_gradient_flow(MuPRNNLayer, n, seq_len)
    wrong_mean, wrong_std = measure_gradient_flow(WrongScalingRNNLayer, n, seq_len)
    
    print(f"{n:>6} | {sp_mean:>7.4f} ± {sp_std:.4f} | {mup_mean:>7.4f} ± {mup_std:.4f} | {wrong_mean:>7.4f} ± {wrong_std:.4f}")

---

## Question 3.5: Analyzing Gradient Flow Results

Based on the gradient flow measurements:

**A)** For SP and correct μP, does the gradient ratio change significantly as width increases?

**Your Answer:** _______________

**B)** For the wrong 1/n scaling, what happens to the gradient ratio as width doubles?

**Your Answer:** _______________

**C)** Why is this problematic for training wide RNNs with wrong scaling?

**Your Answer:** _______________

---

# Part 4: Controlling the Spectral Radius with α

We can introduce a learnable parameter $\alpha$ to control the effective spectral radius:

$$h_t = \phi\left(\frac{1}{\sqrt{d}}W_{xh}x_t + \frac{\alpha}{\sqrt{n}}W_{hh}h_{t-1} + b\right)$$

The effective spectral radius becomes approximately $\alpha$, independent of width, since $\rho(W_{hh}/\sqrt{n}) \approx 1$ for unit-variance $W_{hh}$ (Part 2).
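A quick NumPy check of this claim, assuming unit-variance Gaussian $W_{hh}$ and a fixed seed (the helper `effective_radius` is illustrative and separate from the PyTorch layer below):

```python
import numpy as np

def effective_radius(n, alpha, seed=0):
    """Spectral radius of alpha * W_hh / sqrt(n) with unit-variance W_hh."""
    rng = np.random.default_rng(seed)
    W_hh = rng.normal(0.0, 1.0, size=(n, n))  # unit-variance init, as in the muP layers
    A = alpha * W_hh / np.sqrt(n)             # effective recurrence matrix
    return np.abs(np.linalg.eigvals(A)).max()

for n in [64, 256, 1024]:
    print(f"n={n:5d}  rho={effective_radius(n, alpha=0.9):.3f}  (alpha=0.9)")
```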

## Exercise 4.1: μP RNN with Learnable α

In [None]:
class MuPRNNLayerWithAlpha(nn.Module):
    """
    μP RNN with learnable spectral radius control.
    """
    def __init__(self, input_size, hidden_size, base_hidden_size=64, init_alpha=0.95):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.base_hidden_size = base_hidden_size
        self.width_mult = hidden_size / base_hidden_size
        
        self.W_xh = nn.Parameter(th.randn(hidden_size, input_size))
        self.W_hh = nn.Parameter(th.randn(hidden_size, hidden_size))
        self.b = nn.Parameter(th.zeros(hidden_size))
        
        # TODO: Initialize log_alpha so that alpha = init_alpha
        # Hint: if alpha = exp(log_alpha), what should log_alpha be?
        self.log_alpha = nn.Parameter(th.tensor(...))  # YOUR CODE HERE
    
    @property
    def alpha(self):
        return th.exp(self.log_alpha)
    
    def forward(self, x):
        B, T, d = x.shape
        n = self.hidden_size
        h = th.zeros(B, n, device=x.device)
        h_list = []
        
        for t in range(T):
            # TODO: Include alpha in the recurrence scaling
            # The recurrence should be: alpha * W_hh @ h / √n
            h = th.tanh(
                F.linear(x[:, t], self.W_xh) / math.sqrt(d) +
                ... * F.linear(h, self.W_hh) / math.sqrt(n) +  # YOUR CODE HERE
                self.b
            )
            h_list.append(h)
            h.retain_grad()
        
        self.h_list = h_list
        return th.stack(h_list, dim=1), h
    
    def get_lr_multipliers(self):
        return {
            'W_xh': 1.0,
            'W_hh': 1.0 / self.width_mult,
            'b': 1.0,
            'log_alpha': 1.0,
        }

In [None]:
# Verify that α controls gradient flow independently of width
def analyze_alpha_effect(widths, alphas, seq_len=10):
    results = {}
    
    for alpha in alphas:
        results[alpha] = []
        for n in widths:
            mean, _ = measure_gradient_flow(
                MuPRNNLayerWithAlpha, n, seq_len, 
                num_trials=15, init_alpha=alpha
            )
            results[alpha].append(mean)
    
    return results

widths = [32, 64, 128, 256]
alphas = [0.8, 0.9, 0.95, 1.0]
seq_len = 10

results = analyze_alpha_effect(widths, alphas, seq_len)

print(f"Gradient Flow Ratio for Different α Values (T={seq_len}):")
print(f"{'Width':>6}", end='')
for a in alphas:
    print(f" | {'α='+str(a):>12}", end='')
print()
print("-" * (8 + 15 * len(alphas)))

for i, n in enumerate(widths):
    print(f"{n:>6}", end='')
    for a in alphas:
        print(f" | {results[a][i]:>12.4f}", end='')
    print()

print(f"\nTheoretical α^T:")
for a in alphas:
    print(f"  α={a}: {a**seq_len:.4f}")

---

## Question 4.2: Understanding α

**A)** For a fixed α, does the gradient ratio depend on width?

**Your Answer:** _______________

**B)** How does the empirical gradient ratio compare to the theoretical prediction $\alpha^T$?

**Your Answer:** _______________

**C)** If you need gradients at $t=0$ to be at least 10% of gradients at $t=T$ for a sequence of length $T=50$, what is the minimum α you need?

*Hint: Solve $\alpha^{50} \geq 0.1$*

**Your Answer:** $\alpha \geq $ _______________

---

# Part 5: Width × Time Interaction

With correct μP, width and time become **orthogonal** concerns:
- Width affects optimization (LR transfer)
- Time affects gradient flow (via $\alpha^T$)

## Exercise 5.1: Verify Orthogonality

In [None]:
def analyze_width_time_interaction(widths, seq_lens, alpha=0.95, num_trials=10):
    """
    Measure gradient flow for different (width, seq_len) combinations.
    
    Expected behavior with correct μP:
    - Columns (fixed T, varying n) should be roughly constant (width-independent)
    - Rows (fixed n, varying T) should decay roughly as α^T
    """
    results = np.zeros((len(widths), len(seq_lens)))
    
    for i, n in enumerate(widths):
        for j, T in enumerate(seq_lens):
            mean, _ = measure_gradient_flow(
                MuPRNNLayerWithAlpha, n, T, 
                num_trials=num_trials, init_alpha=alpha
            )
            results[i, j] = mean
    
    return results

widths = [32, 64, 128, 256]
seq_lens = [2, 5, 10, 15, 20]
alpha = 0.95

results = analyze_width_time_interaction(widths, seq_lens, alpha)

print(f"Gradient Flow Matrix (α={alpha}):")
print(f"{'Width':>6}", end='')
for T in seq_lens:
    print(f" | T={T:>3}", end='')
print()
print("-" * (8 + 8 * len(seq_lens)))

for i, n in enumerate(widths):
    print(f"{n:>6}", end='')
    for j in range(len(seq_lens)):
        print(f" | {results[i,j]:>5.3f}", end='')
    print()

In [None]:
# Visualize the orthogonality
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
ax = axes[0]
im = ax.imshow(results, cmap='viridis', aspect='auto')
ax.set_xticks(range(len(seq_lens)))
ax.set_xticklabels(seq_lens)
ax.set_yticks(range(len(widths)))
ax.set_yticklabels(widths)
ax.set_xlabel('Sequence Length T')
ax.set_ylabel('Hidden Width n')
ax.set_title(f'Gradient Flow Ratio (α={alpha})')
plt.colorbar(im, ax=ax, label='Gradient Ratio')

for i in range(len(widths)):
    for j in range(len(seq_lens)):
        ax.text(j, i, f'{results[i,j]:.2f}', ha='center', va='center', 
                color='white', fontsize=10, fontweight='bold')

# Line plot
ax = axes[1]
for i, n in enumerate(widths):
    ax.semilogy(seq_lens, results[i, :], 'o-', label=f'n={n}', linewidth=2, markersize=8)

theoretical = [alpha**T for T in seq_lens]
ax.semilogy(seq_lens, theoretical, 'k--', linewidth=3, label=f'Theory: {alpha}^T')

ax.set_xlabel('Sequence Length T')
ax.set_ylabel('Gradient Flow Ratio (log scale)')
ax.set_title('All widths follow the same α^T decay')
ax.legend()

plt.tight_layout()
plt.show()

---

## Question 5.2: Interpreting the Heatmap

**A)** Looking down each column (fixed T, varying n), what do you observe about the gradient ratios?

**Your Answer:** _______________

**B)** Looking across each row (fixed n, varying T), what pattern do you see?

**Your Answer:** _______________

**C)** Why is this orthogonality useful for hyperparameter tuning?

**Your Answer:** _______________

---

# Part 6: Learning Rate Transfer — The Core Benefit of μP

The fundamental promise of μP is **zero-shot hyperparameter transfer**: tune your learning rate on a small model, then use the *same* LR on a large model.

## Why SP Fails at LR Transfer

With Standard Parameterization:
- Optimal learning rate **changes** with width
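- The optimal LR typically shrinks as width grows (the plot below uses $1/\sqrt{n}$ as a reference trend; the exact exponent depends on the optimizer and architecture)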
- Must re-tune hyperparameters for each model size
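One common explanation, sketched under an idealized assumption: on its first step, Adam's update is essentially $\text{lr} \cdot \operatorname{sign}(\text{gradient})$, so every weight moves by about $\pm\text{lr}$ regardless of width. For a rank-one gradient $u h^\top$, the resulting change in $W h$ adds up $n$ aligned terms, so it grows linearly with $n$ at a fixed LR. The helper name is hypothetical, and real Adam steps after the first also depend on its moment estimates:

```python
import numpy as np

def preactivation_change(n, lr, seed=0):
    """Mean |Delta W @ h| after one idealized sign-like (first-step Adam) update."""
    rng = np.random.default_rng(seed)
    h = np.tanh(rng.normal(0.0, 1.0, size=n))  # O(1) input to the layer
    u = rng.normal(0.0, 1.0, size=n)           # backpropagated error signal
    grad = np.outer(u, h)                      # rank-one gradient dL/dW = u h^T
    delta_W = -lr * np.sign(grad)              # every entry moves by exactly lr
    return np.abs(delta_W @ h).mean()          # typical change in each pre-activation

for n in [64, 256, 1024]:
    print(f"n={n:5d}  change={preactivation_change(n, lr=1e-3):.4f}")
```

At a fixed LR, the change in each pre-activation grows about 16x from $n=64$ to $n=1024$, which is why SP needs a smaller LR at larger widths.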

## Why μP Succeeds

With μP's adaptive learning rates:
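- Hidden→hidden weights get the LR multiplier $1/\text{width\_mult}$
- Wider layers sum more update contributions into each pre-activation, so shrinking the per-entry step keeps the change in each pre-activation roughly width-independent
- Result: **the same base learning rate should work across widths**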

Let's empirically verify this with learning rate sweeps.

First, recall the learning rate multipliers that μP prescribes. Fill in the blanks:

| Layer Type | LR Multiplier |
|------------|-------------------|
| Input → Hidden | TODO |
| Hidden → Hidden | TODO |
| Hidden → Output | TODO |

*Task:* Implement the learning rate multipliers in `get_lr_multipliers`.

In [None]:
def get_lr_multipliers(self):
    """
    μP learning rate multipliers.
    
    For hidden→hidden weights, we scale LR by 1/width_mult
    to ensure relative updates are O(1).
    """
    # TODO: Complete the learning rate multipliers
    return {
        'W_xh': ...,                    # Input weights: no scaling
        'W_hh': ...,                    # Hidden weights: scale down
        'b': ...,                       # Biases: no scaling
    }

MuPRNNLayer.get_lr_multipliers = get_lr_multipliers

In [None]:
def simple_sequence_task(batch_size=32, seq_len=20, input_dim=8):
    """
    Generate a simple sequence task: predict the running average.
    Input: random noise. Target at step t: the mean of the per-step input means over steps 1..t.
    """
    x = th.randn(batch_size, seq_len, input_dim)
    # Target is cumulative mean at each timestep
    cumsum = th.cumsum(x.mean(dim=2, keepdim=True), dim=1)
    targets = cumsum / th.arange(1, seq_len + 1).view(1, -1, 1)
    return x, targets


def train_rnn_with_lr(rnn_class, hidden_size, lr, num_steps=100, 
                      base_hidden_size=64, use_mup_lr_scaling=False):
    """
    Train an RNN with a specific learning rate.
    
    Args:
        rnn_class: RNNLayer class to instantiate
        hidden_size: Hidden dimension
        lr: Base learning rate
        num_steps: Number of training steps
        base_hidden_size: Base width for μP scaling
        use_mup_lr_scaling: If True, apply μP LR multipliers to different params
    
    Returns:
        final_loss: Loss after training
        loss_curve: List of losses during training
    """
    if rnn_class == MuPRNNLayer:
        rnn = rnn_class(input_size=8, hidden_size=hidden_size, 
                       base_hidden_size=base_hidden_size)
    else:
        rnn = rnn_class(input_size=8, hidden_size=hidden_size)
    
    readout = nn.Linear(hidden_size, 1)
    
    # Setup optimizer with μP LR scaling if requested
    if use_mup_lr_scaling and hasattr(rnn, 'get_lr_multipliers'):
        lr_mults = rnn.get_lr_multipliers()
        param_groups = []
        
        for name, param in rnn.named_parameters():
            mult = lr_mults.get(name, 1.0)
            param_groups.append({'params': [param], 'lr': lr * mult})
        
        # Readout layer: μP uses 1/width_mult for output layers
        width_mult = hidden_size / base_hidden_size
        param_groups.append({'params': readout.parameters(), 'lr': lr / width_mult})
        
        optimizer = th.optim.Adam(param_groups)
    else:
        # Standard: same LR for all parameters
        optimizer = th.optim.Adam(list(rnn.parameters()) + list(readout.parameters()), lr=lr)
    
    losses = []
    
    for step in range(num_steps):
        x, y = simple_sequence_task(batch_size=32, seq_len=20, input_dim=8)
        
        h_seq, _ = rnn(x)
        pred = readout(h_seq)
        
        loss = F.mse_loss(pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
    
    return losses[-1], losses


def run_lr_sweep(rnn_class, widths, learning_rates, num_steps=100, 
                 use_mup_lr_scaling=False, base_hidden_size=64):
    """
    Run learning rate sweep for different widths.
    
    Returns:
        results: Dict[width][lr] = (final_loss, loss_curve)
    """
    results = {}
    
    for width in widths:
        print(f"  Width {width}...", end='', flush=True)
        results[width] = {}
        
        for lr in learning_rates:
            th.manual_seed(42)  # Fixed seed for fair comparison
            np.random.seed(42)
            
            final_loss, loss_curve = train_rnn_with_lr(
                rnn_class, width, lr, num_steps,
                base_hidden_size=base_hidden_size,
                use_mup_lr_scaling=use_mup_lr_scaling
            )
            results[width][lr] = (final_loss, loss_curve)
        
        print(" done")
    
    return results


# Run experiments
widths = [32, 64, 128, 256, 512]
learning_rates = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1]
num_steps = 150
base_hidden_size = 64

print("Running LR sweep for Standard Parameterization...")
sp_results = run_lr_sweep(
    StandardRNNLayer, widths, learning_rates, num_steps,
    use_mup_lr_scaling=False
)

print("\nRunning LR sweep for μP (with adaptive LR scaling)...")
mup_results = run_lr_sweep(
    MuPRNNLayer, widths, learning_rates, num_steps,
    use_mup_lr_scaling=True, base_hidden_size=base_hidden_size
)

print("\nDone!")

In [None]:
# Visualize LR sweep results
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

# Color map for widths
colors = plt.cm.viridis(np.linspace(0, 0.8, len(widths)))

# ===== Top Row: Loss vs LR curves =====
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])

for i, width in enumerate(widths):
    sp_losses = [sp_results[width][lr][0] for lr in learning_rates]
    mup_losses = [mup_results[width][lr][0] for lr in learning_rates]
    
    ax1.plot(learning_rates, sp_losses, 'o-', color=colors[i], 
             label=f'n={width}', linewidth=2, markersize=6)
    ax2.plot(learning_rates, mup_losses, 's-', color=colors[i], 
             label=f'n={width}', linewidth=2, markersize=6)

ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel('Learning Rate', fontsize=11)
ax1.set_ylabel('Final Loss', fontsize=11)
ax1.set_title('Standard Parameterization\n(Optimal LR shifts with width)', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.set_xscale('log')
ax2.set_yscale('log')
ax2.set_xlabel('Learning Rate', fontsize=11)
ax2.set_ylabel('Final Loss', fontsize=11)
ax2.set_title('μP with Adaptive LR Scaling\n(Optimal LR constant across width)', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

# ===== Middle Row: Optimal LR vs Width =====
ax3 = fig.add_subplot(gs[1, :])

# Find optimal LR for each width
sp_optimal_lrs = []
mup_optimal_lrs = []

for width in widths:
    sp_losses = [sp_results[width][lr][0] for lr in learning_rates]
    mup_losses = [mup_results[width][lr][0] for lr in learning_rates]
    
    # nanargmin skips runs whose final loss is NaN (e.g. diverged at a large LR)
    sp_optimal_lrs.append(learning_rates[np.nanargmin(sp_losses)])
    mup_optimal_lrs.append(learning_rates[np.nanargmin(mup_losses)])

ax3.plot(widths, sp_optimal_lrs, 'o-', color='#e74c3c', 
         label='Standard Param', linewidth=3, markersize=10)
ax3.plot(widths, mup_optimal_lrs, 's-', color='#27ae60', 
         label='μP (adaptive LR)', linewidth=3, markersize=10)

# Add theoretical SP line (optimal LR ∝ 1/√n)
sp_theory = [sp_optimal_lrs[0] * np.sqrt(widths[0] / w) for w in widths]
ax3.plot(widths, sp_theory, '--', color='#e74c3c', alpha=0.5, 
         linewidth=2, label='SP Theory (∝ 1/√n)')

ax3.axhline(y=mup_optimal_lrs[0], color='#27ae60', linestyle='--', 
            alpha=0.5, linewidth=2, label='μP Theory (constant)')

ax3.set_xscale('log', base=2)
ax3.set_yscale('log')
ax3.set_xlabel('Hidden Width n', fontsize=12)
ax3.set_ylabel('Optimal Learning Rate', fontsize=12)
ax3.set_title('Optimal LR vs Width: SP Requires Retuning, μP Transfers', 
              fontsize=13, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)

# ===== Bottom Row: Training curves for specific widths =====
ax4 = fig.add_subplot(gs[2, 0])
ax5 = fig.add_subplot(gs[2, 1])

# Show training curves for smallest and largest width at their optimal LRs
small_width, large_width = widths[0], widths[-1]
sp_small_lr = sp_optimal_lrs[0]
sp_large_lr = sp_optimal_lrs[-1]
mup_optimal_lr = mup_optimal_lrs[0]  # Smallest width's optimal LR, reused for the largest width to test transfer

_, sp_small_curve = sp_results[small_width][sp_small_lr]
_, sp_large_curve = sp_results[large_width][sp_large_lr]

ax4.semilogy(sp_small_curve, color='#3498db', linewidth=2, 
             label=f'n={small_width}, LR={sp_small_lr:.4f}')
ax4.semilogy(sp_large_curve, color='#e74c3c', linewidth=2, 
             label=f'n={large_width}, LR={sp_large_lr:.4f}')
ax4.set_xlabel('Training Step', fontsize=11)
ax4.set_ylabel('Loss', fontsize=11)
ax4.set_title('Standard Param: Different LRs per Width', fontsize=12, fontweight='bold')
ax4.legend()
ax4.grid(True, alpha=0.3)

_, mup_small_curve = mup_results[small_width][mup_optimal_lr]
_, mup_large_curve = mup_results[large_width][mup_optimal_lr]

ax5.semilogy(mup_small_curve, color='#3498db', linewidth=2, 
             label=f'n={small_width}, LR={mup_optimal_lr:.4f}')
ax5.semilogy(mup_large_curve, color='#27ae60', linewidth=2, 
             label=f'n={large_width}, LR={mup_optimal_lr:.4f}')
ax5.set_xlabel('Training Step', fontsize=11)
ax5.set_ylabel('Loss', fontsize=11)
ax5.set_title('μP: Same LR Works for All Widths!', fontsize=12, fontweight='bold')
ax5.legend()
ax5.grid(True, alpha=0.3)

plt.show()

# Print summary statistics
print("=" * 60)
print("LEARNING RATE TRANSFER SUMMARY")
print("=" * 60)
print(f"\n{'Width':>6} | {'SP Optimal LR':>15} | {'μP Optimal LR':>15}")
print("-" * 60)
for i, width in enumerate(widths):
    print(f"{width:>6} | {sp_optimal_lrs[i]:>15.6f} | {mup_optimal_lrs[i]:>15.6f}")

print(f"\n{'Metric':>30} | {'SP':>12} | {'μP':>12}")
print("-" * 60)
print(f"{'LR range (max/min)':>30} | {max(sp_optimal_lrs)/min(sp_optimal_lrs):>12.2f}x | {max(mup_optimal_lrs)/min(mup_optimal_lrs):>12.2f}x")
print(f"{'LR std deviation':>30} | {np.std(sp_optimal_lrs):>12.6f} | {np.std(mup_optimal_lrs):>12.6f}")

print("\n" + "=" * 60)
print("KEY INSIGHT: μP's adaptive LR scaling enables zero-shot transfer!")
print("The same base LR works across all widths with μP.")
print("=" * 60)

If μP transfers as intended, its optimal learning rate should stay roughly constant across widths, while the SP optimum drifts as width grows.
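Reading the optimum off a single final-batch loss is noisy, and a diverged run can produce NaN. A more robust selector, sketched under the assumption that results use the `results[width][lr] = (final_loss, loss_curve)` layout returned by `run_lr_sweep` above; averaging the last `k` steps and the name `pick_optimal_lr` are our own choices, not part of the notebook:

```python
import numpy as np

def pick_optimal_lr(results_for_width, k=10):
    """LR with the lowest mean loss over the last k steps, skipping NaN/inf runs.

    Returns None if every run diverged.
    """
    best_lr, best_loss = None, np.inf
    for lr, (_, curve) in results_for_width.items():
        score = np.asarray(curve[-k:], dtype=float).mean()
        if np.isfinite(score) and score < best_loss:
            best_lr, best_loss = lr, score
    return best_lr

# Toy input with made-up curves (not real results): the NaN run is ignored
toy = {
    1e-3: (0.5, [1.0, 0.6, 0.5]),
    1e-2: (0.2, [1.0, 0.3, 0.2]),
    1e-1: (float("nan"), [1.0, float("nan"), float("nan")]),
}
print(pick_optimal_lr(toy, k=2))
```

With the sweep above, `[pick_optimal_lr(mup_results[w]) for w in widths]` gives a smoothed version of `mup_optimal_lrs` to compare against the SP values.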