# Nested Learning: Continuum Memory System (CMS) Demo

This notebook demonstrates the **Tier 1** implementation of Nested Learning from the Google paper.

## Key Concepts

The Continuum Memory System implements **multi-timescale learning** where different parameter groups update at different frequencies:

- **Level 1 (Fast)**: Updates every step → captures short-term patterns
- **Level 2 (Medium)**: Updates every 16 steps → captures mid-term patterns  
- **Level 3 (Slow)**: Updates every 256 steps → captures long-term structure

This is achieved through:
1. **Step-aligned gradient accumulation**: Gradients sum in `.grad` across backward passes (native PyTorch behavior) until a level's next update step
2. **Selective gradient zeroing**: Only zero gradients for levels that just updated
3. **Learning rate scaling**: Scale LR by 1/chunk_size to compensate for accumulated gradients
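
The three mechanisms above can be sketched without any framework. The following is a minimal illustration, not the project's implementation: it assumes one scalar parameter per level, a plain SGD step, and a made-up constant gradient. The function name `run_schedule` and the level names `fast` and `medium` are hypothetical.

```python
# Framework-free sketch of step-aligned accumulation with selective zeroing.
# All names and values here are illustrative assumptions.

def run_schedule(chunk_sizes, base_lr, grads_per_step):
    """Simulate one scalar parameter per level under a step-aligned schedule."""
    params = {name: 0.0 for name in chunk_sizes}
    grad_buf = {name: 0.0 for name in chunk_sizes}  # plays the role of .grad
    updates = {name: 0 for name in chunk_sizes}

    for step, g in enumerate(grads_per_step, start=1):
        # "Backward pass": every level's buffer accumulates the new gradient.
        for name in chunk_sizes:
            grad_buf[name] += g

        for name, chunk in chunk_sizes.items():
            if step % chunk == 0:                    # step-aligned update
                lr = base_lr / chunk                 # LR scaled by 1/chunk_size
                params[name] -= lr * grad_buf[name]  # plain SGD step
                grad_buf[name] = 0.0                 # zero only this level
                updates[name] += 1
    return params, grad_buf, updates


chunk_sizes = {"fast": 1, "medium": 4}
params, grad_buf, updates = run_schedule(
    chunk_sizes, base_lr=0.1, grads_per_step=[1.0] * 10
)
print(updates)   # {'fast': 10, 'medium': 2}
print(grad_buf)  # {'fast': 0.0, 'medium': 2.0}
```

After 10 steps the medium level has updated at steps 4 and 8 and still holds two steps of accumulated gradient, while the fast level's buffer is empty.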

In [None]:
import sys
sys.path.insert(0, '../src')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

from model import NestedModel
from scheduler import ChunkedUpdateScheduler
from utils import setup_optimizers, set_seed, create_dummy_data

# Set random seed
set_seed(42)

print("✓ Imports successful")

## 1. Initialize the Model

The `NestedModel` is a standard PyTorch module with parameters organized into three levels.

In [None]:
# Create model
model = NestedModel(input_size=768, hidden_size=3072)
model.print_model_info()

# Verify levels are accessible
print("\nLevel names:", model.get_level_names())

## 2. Initialize the Scheduler

The `ChunkedUpdateScheduler` orchestrates when each level updates using **step-aligned logic**.
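
The source of `ChunkedUpdateScheduler` is not shown in this notebook. As a rough sketch, assuming "step-aligned" means a level is due whenever the 1-indexed global step is a multiple of its chunk size, a minimal stand-in could look like the following. The class name `StepAlignedScheduler` is hypothetical; only the method names mirror the calls used in this notebook.

```python
class StepAlignedScheduler:
    """Hypothetical stand-in, not the project's ChunkedUpdateScheduler."""

    def __init__(self, chunk_sizes):
        self.chunk_sizes = dict(chunk_sizes)
        self.update_counts = {name: 0 for name in chunk_sizes}
        self.last_update = {name: 0 for name in chunk_sizes}

    def should_update(self, level_name, step):
        # A level is due when the global step is a multiple of its chunk size.
        return step % self.chunk_sizes[level_name] == 0

    def mark_updated(self, level_name, step):
        # Record that the optimizer step for this level actually ran.
        self.update_counts[level_name] += 1
        self.last_update[level_name] = step

    def get_update_count(self, level_name):
        return self.update_counts[level_name]


sched = StepAlignedScheduler(
    {"level1_fast": 1, "level2_medium": 16, "level3_slow": 256}
)
for step in range(1, 301):
    for name in sched.chunk_sizes:
        if sched.should_update(name, step):
            sched.mark_updated(name, step)

counts = {name: sched.get_update_count(name) for name in sched.chunk_sizes}
print(counts)  # {'level1_fast': 300, 'level2_medium': 18, 'level3_slow': 1}
```

Keeping `should_update` free of side effects lets it be called for visualization without disturbing the update counts.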

In [None]:
# Define chunk sizes (update frequencies)
chunk_sizes = {
    "level1_fast": 1,      # Updates every step
    "level2_medium": 16,   # Updates every 16 steps
    "level3_slow": 256,    # Updates every 256 steps
}

scheduler = ChunkedUpdateScheduler(chunk_sizes)

### Visualize Update Schedule

Let's see when each level updates over the first 300 steps:

In [None]:
# Simulate which levels update at each step
steps = range(1, 301)
update_pattern = {level: [] for level in chunk_sizes.keys()}

for step in steps:
    for level_name in chunk_sizes.keys():
        if scheduler.should_update(level_name, step):
            update_pattern[level_name].append(step)

# Visualize
fig, ax = plt.subplots(figsize=(14, 6))

colors = {'level1_fast': 'red', 'level2_medium': 'blue', 'level3_slow': 'green'}
y_positions = {'level1_fast': 3, 'level2_medium': 2, 'level3_slow': 1}

for level_name, steps_updated in update_pattern.items():
    y = [y_positions[level_name]] * len(steps_updated)
    ax.scatter(steps_updated, y, c=colors[level_name], s=10, alpha=0.6, label=level_name)

ax.set_yticks([1, 2, 3])
ax.set_yticklabels(['Slow (256)', 'Medium (16)', 'Fast (1)'])
ax.set_xlabel('Training Step')
ax.set_ylabel('Learning Level')
ax.set_title('Update Schedule: When Each Level Updates')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Updates in first 300 steps:")
for level_name, steps_updated in update_pattern.items():
    print(f"  {level_name:20s}: {len(steps_updated):3d} updates")

## 3. Setup Optimizers with Scaled Learning Rates

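**CRITICAL**: Learning rates must be scaled by `1/chunk_size` to compensate for gradient accumulation.

Why? Gradients accumulate (sum) over `chunk_size` backward passes before a level updates, so the accumulated gradient can be up to `chunk_size` times larger than a single-step gradient. With plain SGD, dividing the LR by `chunk_size` turns the update into `base_lr` times the mean gradient over the chunk; without it, slower levels would take massive steps. Adaptive optimizers such as Adam (used below) normalize by gradient magnitude, so there the scaling mainly acts as a smaller step size for slower levels.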
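
The arithmetic behind the scaling can be checked with a few lines. This sketch assumes a plain SGD update and uses made-up scalar gradients; it is only meant to show that a summed gradient with a scaled LR gives the same step as a mean gradient with the base LR.

```python
import math

# Assumes plain SGD; the gradient values are made up for illustration.
base_lr = 1e-4
chunk_size = 16
grads = [0.5 + 0.01 * i for i in range(chunk_size)]  # one scalar gradient per step

summed = sum(grads)         # what accumulates in .grad over the chunk
mean = summed / chunk_size  # what one large-batch step would average

step_scaled = (base_lr / chunk_size) * summed  # scaled LR, summed gradient
step_mean = base_lr * mean                     # base LR, mean gradient
step_unscaled = base_lr * summed               # no scaling at all

print(math.isclose(step_scaled, step_mean))                  # True
print(math.isclose(step_unscaled / step_mean, chunk_size))   # True
```

Without the scaling, the step is `chunk_size` times larger than the mean-gradient step.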

In [None]:
base_lr = 1e-4

optimizers = setup_optimizers(
    model=model,
    chunk_sizes=chunk_sizes,
    base_lr=base_lr,
    optimizer_type="adam",
    weight_decay=0.0
)

### Effective Batch Size

The learning rate scaling makes each level operate as if it's using a different batch size:

| Level | Chunk Size | Scaled LR | Effective Batch Size |
|-------|------------|-----------|----------------------|
| Fast  | 1          | 1e-4      | 32 (base)            |
| Medium| 16         | 6.25e-6   | 512 (32 × 16)        |
| Slow  | 256        | 3.9e-7    | 8192 (32 × 256)      |

In [None]:
# Visualize effective batch sizes
base_batch_size = 32

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Learning rates
levels = list(chunk_sizes.keys())
lrs = [base_lr / chunk_sizes[level] for level in levels]
ax1.bar(range(len(levels)), lrs, color=['red', 'blue', 'green'], alpha=0.7)
ax1.set_xticks(range(len(levels)))
ax1.set_xticklabels(['Fast', 'Medium', 'Slow'])
ax1.set_ylabel('Learning Rate')
ax1.set_title('Scaled Learning Rates')
ax1.set_yscale('log')
ax1.grid(True, alpha=0.3)

# Effective batch sizes
eff_batch_sizes = [base_batch_size * chunk_sizes[level] for level in levels]
ax2.bar(range(len(levels)), eff_batch_sizes, color=['red', 'blue', 'green'], alpha=0.7)
ax2.set_xticks(range(len(levels)))
ax2.set_xticklabels(['Fast', 'Medium', 'Slow'])
ax2.set_ylabel('Effective Batch Size')
ax2.set_title('Effective Batch Sizes (via Gradient Accumulation)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Training Loop Demo

Let's run a small training loop to see the system in action.

In [None]:
# Setup
device = torch.device('cpu')
model.to(device)
criterion = nn.MSELoss()

# Track gradient norms and updates
grad_norms = defaultdict(list)
update_steps = defaultdict(list)

num_steps = 300
print(f"Running {num_steps} training steps...\n")

for global_step in range(1, num_steps + 1):
    # Generate dummy data
    data, targets = create_dummy_data(
        batch_size=32,
        seq_length=128,
        input_size=768,
        device=device
    )
    
    # Forward & Backward
    output = model(data)
    loss = criterion(output, targets)  # compare against the generated targets
    loss.backward()
    
    # Track gradient norms BEFORE potential zeroing
    for level_name, module in model.levels.items():
        grad_norm = 0.0
        for p in module.parameters():
            if p.grad is not None:
                grad_norm += p.grad.norm().item() ** 2
        grad_norms[level_name].append(np.sqrt(grad_norm))
    
    # Selective Update & Gradient Zeroing
    for level_name, module in model.levels.items():
        if scheduler.should_update(level_name, global_step):
            optimizers[level_name].step()
            scheduler.mark_updated(level_name, global_step)
            update_steps[level_name].append(global_step)
            
            # Zero only this level's gradients
            for p in module.parameters():
                if p.grad is not None:
                    p.grad.zero_()
    
    if global_step % 100 == 0:
        print(f"Step {global_step}/{num_steps} - Loss: {loss.item():.4f}")

print("\n✓ Training complete!")

## 5. Visualize Gradient Accumulation

The key insight: gradient norms grow for slower levels as they accumulate, then reset after updates.
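
The sawtooth shape this produces can be previewed with a small framework-free sketch. It assumes a constant, made-up gradient vector, so the accumulated norm grows linearly within a chunk; with real gradients that partly cancel, growth would be slower. The function name `accumulated_norms` is hypothetical.

```python
import math

def accumulated_norms(per_step_grads, chunk_size):
    """Norm of the running gradient sum, reset after every chunk_size-th step."""
    buf = [0.0] * len(per_step_grads[0])
    norms = []
    for step, g in enumerate(per_step_grads, start=1):
        buf = [b + x for b, x in zip(buf, g)]              # accumulate
        norms.append(math.sqrt(sum(b * b for b in buf)))   # recorded before zeroing
        if step % chunk_size == 0:
            buf = [0.0] * len(buf)                         # update happened: reset
    return norms

grads = [[3.0, 4.0]] * 8  # constant gradient with norm 5, made up for illustration
norms = accumulated_norms(grads, chunk_size=4)
print(norms)  # [5.0, 10.0, 15.0, 20.0, 5.0, 10.0, 15.0, 20.0]
```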

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)

level_info = [
    ('level1_fast', 'Fast (updates every step)', 'red', axes[0]),
    ('level2_medium', 'Medium (updates every 16 steps)', 'blue', axes[1]),
    ('level3_slow', 'Slow (updates every 256 steps)', 'green', axes[2]),
]

for level_name, title, color, ax in level_info:
    steps = range(1, len(grad_norms[level_name]) + 1)
    
    # Plot gradient norms
    ax.plot(steps, grad_norms[level_name], color=color, alpha=0.7, linewidth=1)
    
    # Mark update points
    if level_name in update_steps:
        update_y = [grad_norms[level_name][s-1] for s in update_steps[level_name]]
        ax.scatter(update_steps[level_name], update_y, 
                  c='black', s=30, marker='v', zorder=5, 
                  label='Update & Zero Grad')
    
    ax.set_ylabel('Gradient Norm')
    ax.set_title(f'{title}')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Training Step')
fig.suptitle('Gradient Accumulation Patterns Across Levels', fontsize=14, y=1.00)
plt.tight_layout()
plt.show()

print("Notice how:")
print("- Fast level: Gradients stay small (zeroed every step)")
print("- Medium level: Gradients accumulate for 16 steps, then reset")
print("- Slow level: Gradients accumulate for 256 steps, then reset")

## 6. Update Statistics

Let's verify the scheduler tracked updates correctly.

In [None]:
scheduler.print_stats(num_steps)

# Verify the math
print("\nVerification:")
for level_name, chunk_size in chunk_sizes.items():
    expected = num_steps // chunk_size
    actual = scheduler.get_update_count(level_name)
    status = "✓" if expected == actual else "✗"
    print(f"{status} {level_name:20s}: Expected {expected}, Got {actual}")

## 7. Memory Efficiency

**Key Advantage**: We don't store separate gradient buffers for each level.

- Single forward/backward pass computes gradients for ALL parameters
- Gradients naturally accumulate in `.grad` attributes  
- Selective zeroing enables different timescales
- Gradient memory overhead: **~0%** compared to standard training, since accumulation reuses the existing `.grad` tensors

This is why CMS is practical even for large models!

## Summary: Key Concepts

### ✓ What We Implemented (Tier 1: CMS)

1. **NestedModel**: Standard PyTorch module with parameters grouped into levels
2. **ChunkedUpdateScheduler**: Step-aligned update logic
3. **Scaled Learning Rates**: LR × (1/chunk_size) for each level
4. **Selective Gradient Zeroing**: Only zero gradients for levels that updated
5. **Memory Efficient**: Uses native PyTorch gradient accumulation

### 🎯 Core Principles

- **Multi-timescale learning**: Different parameters learn at different speeds
- **Step-aligned updates**: Each level updates when the global step is a multiple of its chunk size (e.g., steps 16, 32, 48, ... for the medium level)
- **No computational overhead**: Single forward/backward pass
- **Faithful to paper**: Mathematically correct gradient accumulation

### 🚀 Next Steps: Tier 2 & 3

- **Tier 2**: Add auxiliary losses on intermediate activations
- **Tier 3**: Implement GABAL (learned learning rates)

The current architecture is designed for easy extension!