# Learning Rate Sweep with Early Stopping

This notebook demonstrates how to use `torch-batteries` to run a learning rate sweep with aggressive early stopping.

**Research Question**: What learning rate achieves the fastest convergence with early stopping?

**Experiment Design**:
- Train models with different learning rates (1e-4, 5e-4, 1e-3, 5e-3, 1e-2)
- Use aggressive early stopping (patience=3) to quickly identify poor LRs
- Track all metrics automatically to Weights & Biases
- Compare convergence speed and final accuracy

**What gets tracked**:
- Training and validation metrics (loss, accuracy)
- Hyperparameters (learning rate, batch size, patience)
- The epoch at which training stopped
- Whether early stopping was triggered

## Setup

In [2]:
# Install dependencies (only when running on Google Colab)
try:
    import google.colab  # type: ignore
    !pip install "torch-batteries[example,wandb]"
except ImportError:
    # Not running on Colab: dependencies are assumed to be installed already
    pass

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import wandb
from IPython.display import clear_output
from pprint import pprint
from copy import deepcopy

import torch_batteries
from torch_batteries.utils.device import get_device
from torch_batteries import Battery, Event, charge
from torch_batteries.callbacks import EarlyStopping, ExperimentTrackingCallback
from torch_batteries.tracking.wandb import WandbTracker
from torch_batteries.tracking import Run
from torch_batteries.events.core import EventContext

print(f"torch-batteries version: {torch_batteries.__version__}")
print(f"PyTorch version: {torch.__version__}")

print(f"Device: {get_device()}")

torch-batteries version: 0.5.0
PyTorch version: 2.9.1+cu128
Device: cuda


In [14]:
wandb_project_name = input("Enter your wandb project name (default: 'torch-batteries-integration'): ").strip() or "torch-batteries-integration"

# Wandb entity (optional)
wandb_entity = input("Enter your wandb entity (username/team) or press Enter to skip: ").strip() or None

print(f"\nW&B Settings:")
print(f"  Project: {wandb_project_name}")
print(f"  Entity: {wandb_entity if wandb_entity else 'default'}")


W&B Settings:
  Project: torch-batteries-integration
  Entity: default


In [15]:
wandb.login()
clear_output()

## 1. Prepare Data

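Load the MNIST training set and split it 90/10 into training and validation subsets. The validation subset drives early stopping. The split is not seeded, so it changes each time the cell runs.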

In [16]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

full_train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_dataset, val_dataset = torch.utils.data.random_split(full_train_dataset, [0.9, 0.1])

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 54000
Validation samples: 6000
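`ToTensor()` scales pixel values to [0, 1], and `Normalize((0.1307,), (0.3081,))` then maps each value x to (x - mean) / std; 0.1307 and 0.3081 are the commonly used mean and standard deviation of the MNIST training pixels. A small pure-Python sketch of that arithmetic (the helper `normalize` is hypothetical, not part of torchvision):

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(x: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Apply the same per-channel formula as transforms.Normalize to one value."""
    return (x - mean) / std


# ToTensor() maps raw pixels 0..255 to 0.0..1.0, so these are the two extremes
black = normalize(0.0)  # background pixel
white = normalize(1.0)  # full-intensity stroke
print(f"{black:.4f} {white:.4f}")  # -0.4242 2.8215
```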


## 2. Define the Model

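A small CNN for MNIST classification. The methods decorated with `@charge` handle the training-step and validation-step events; each returns the cross-entropy loss for its batch.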

In [17]:

class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)

    @charge(Event.TRAIN_STEP)
    def training_step(self, context: EventContext):
        x, y = context["batch"]
        return F.cross_entropy(self(x), y)

    @charge(Event.VALIDATION_STEP)
    def validation_step(self, context: EventContext):
        x, y = context["batch"]
        return F.cross_entropy(self(x), y)


print(f"Number of parameters: {sum(p.numel() for p in MNISTNet().parameters())}")

Number of parameters: 19146
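The parameter count printed above can be checked by hand. The sketch below is pure Python (`conv_out` and `conv_params` are hypothetical helpers): it traces a 28x28 input through the layers of `MNISTNet` and adds up the weights and biases.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a Conv2d or MaxPool2d layer (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus biases of a Conv2d(c_in, c_out, k)."""
    return c_in * c_out * k * k + c_out


size = 28
size = conv_out(size, 3)     # Conv2d(1, 32, 3)  -> 26
size = conv_out(size, 2, 2)  # MaxPool2d(2)      -> 13
size = conv_out(size, 3)     # Conv2d(32, 32, 3) -> 11
size = conv_out(size, 2, 2)  # MaxPool2d(2)      -> 5
size = conv_out(size, 3)     # Conv2d(32, 32, 3) -> 3
# AdaptiveAvgPool2d((1, 1)) then reduces 3x3 to 1x1, leaving 32 features

total = conv_params(1, 32, 3) + 2 * conv_params(32, 32, 3) + (32 * 10 + 10)
print(size, total)  # 3 19146
```

Because of the adaptive pooling, the final linear layer only needs 32 inputs regardless of the last feature-map size, which keeps the model under 20k parameters.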


## 3. Define Base Configuration

In [18]:
BASE_CONFIG = {
    "model": "CNN",
    "tags": ["lr-sweep", "cnn", "mnist"],
    "dataset": "MNIST",
    "batch_size": batch_size,
    "optimizer": "Adam",
    "max_epochs": 20,
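    # Aggressive early stopping for quick experiments
    "early_stopping_patience": 3,
    "early_stopping_monitor": "val_accuracy",
    # Assuming min_delta is an absolute change in the monitored metric,
    # 0.1 on accuracy means a gain of 10 percentage points counts as improvement
    "early_stopping_min_delta": 0.1,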
    "learning_rate": None,  # To be set per run
}

print("Base Configuration:")
pprint(BASE_CONFIG, sort_dicts=False)

Base Configuration:
{'model': 'CNN',
 'tags': ['lr-sweep', 'cnn', 'mnist'],
 'dataset': 'MNIST',
 'batch_size': 256,
 'optimizer': 'Adam',
 'max_epochs': 20,
 'early_stopping_patience': 3,
 'early_stopping_monitor': 'val_accuracy',
 'early_stopping_min_delta': 0.1,
 'learning_rate': None}
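How `patience` and `min_delta` interact is easiest to see with a toy re-implementation. The function below is a hypothetical sketch of the common early-stopping rule (a new score counts as an improvement only if it beats the best score so far by more than `min_delta`); it is not torch-batteries' actual code, it assumes `min_delta` is an absolute change in the monitored metric, and the accuracy values are made up.

```python
def epochs_until_stop(scores, patience, min_delta, mode="max"):
    """Return how many epochs run before patience-based early stopping fires.

    Hypothetical sketch: an epoch improves only if it beats the best score
    so far by more than min_delta; `patience` non-improving epochs stop training.
    """
    best = None
    wait = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None:
            improved = True
        elif mode == "max":
            improved = score > best + min_delta
        else:
            improved = score < best - min_delta
        if improved:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(scores)


# Made-up validation accuracies that keep improving by about half a point per epoch
val_acc = [0.95, 0.96, 0.965, 0.97, 0.975, 0.978]
print(epochs_until_stop(val_acc, patience=3, min_delta=0.1))    # 4
print(epochs_until_stop(val_acc, patience=3, min_delta=0.001))  # 6
```

With `min_delta=0.1`, steady half-point gains never register, so training stops once `patience` epochs pass without an "improvement". This is consistent with every run in the sweep below stopping after 4 or 5 epochs.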


## 4. Run Learning Rate Sweep

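For each learning rate, the loop below:
1. Initializes a fresh model and an Adam optimizer with that learning rate
2. Creates a `Run` and a `WandbTracker` to describe and track the experiment
3. Trains with early stopping while `ExperimentTrackingCallback` logs all metrics to W&B
4. Stores the final-epoch metrics in `results` for the summary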
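Each run starts from `deepcopy(BASE_CONFIG)` rather than a shallow copy because the config contains a mutable list (`tags`), which a shallow copy would share with `BASE_CONFIG` and every other run. The pure-Python snippet below uses a cut-down, hypothetical `base` dict to show the difference, along with the run names produced by `f"lr-{lr:.0e}"`:

```python
from copy import copy, deepcopy

base = {"tags": ["lr-sweep"], "learning_rate": None}

deep = deepcopy(base)  # gets its own copy of the tags list
shallow = copy(base)   # shares the tags list with base

deep["tags"].append("private")
shallow["tags"].append("shared")  # also mutates base["tags"]

print(base["tags"])  # ['lr-sweep', 'shared']
print(deep["tags"])  # ['lr-sweep', 'private']

# Run names as built in the sweep loop
run_names = [f"lr-{lr:.0e}" for lr in [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]]
print(run_names)  # ['lr-1e-04', 'lr-5e-04', 'lr-1e-03', 'lr-5e-03', 'lr-1e-02']
```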

In [19]:
learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]

def accuracy(predictions, targets):
    """Fraction of samples whose argmax prediction matches the target label."""
    pred_labels = predictions.argmax(dim=1)
    return (pred_labels == targets).float().mean().item()

metrics = {"accuracy": accuracy}

In [20]:
results = []

for lr in learning_rates:
    print(f"\n{'='*60}")
    print(f"Starting run with learning_rate={lr:.0e}")
    print(f"{'='*60}\n")

    config = deepcopy(BASE_CONFIG)
    config["learning_rate"] = lr
    
    model = MNISTNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    run = Run(
        name=f"lr-{lr:.0e}",
        group="lr-sweep",
        description="Learning rate sweep with early stopping on MNIST dataset using a CNN model",
        config=config,
        job_type="training",
    )
    
    tracker = WandbTracker(project=wandb_project_name, entity=wandb_entity)
    
    # "val_accuracy" -> stage "val", metric "accuracy"; maxsplit=1 keeps
    # metric names that themselves contain underscores intact
    stage, metric = config["early_stopping_monitor"].split("_", 1)

    callbacks = [
        ExperimentTrackingCallback(
            tracker=tracker,
            run=run,
        ),
        EarlyStopping(
            stage=stage,
            metric=metric,
            patience=config["early_stopping_patience"],
            min_delta=config["early_stopping_min_delta"],
            mode="max",  # higher accuracy is better
            verbose=True,
            restore_best_weights=True,
        ),
    ]
    
    battery = Battery(
        model=model,
        optimizer=optimizer,
        metrics=metrics,
        callbacks=callbacks,
    )
    
    result = battery.train(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=config["max_epochs"],
        verbose=0,  # We can see progress in W&B
    )
    
    results.append({
        'lr': lr,
        'train_loss': result["train_loss"][-1],
        'val_loss': result["val_loss"][-1],
        'val_accuracy': result["val_metrics"]['accuracy'][-1],
        'epochs_trained': len(result["train_loss"]),
    })
    
    print(f"\nRun completed:")
    print(f"  Epochs trained: {len(result['train_loss'])}")
    print(f"  Learning rate: {lr:.0e}")
    print(f"  Final train loss: {result['train_loss'][-1]:.4f}")
    print(f"  Final val loss: {result['val_loss'][-1]:.4f}")
    print(f"  Final val accuracy: {result['val_metrics']['accuracy'][-1]:.4f}")


Starting run with learning_rate=1e-04

Run history:
  train/accuracy  ▁▁▁▁▁▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████▇███████
  train/epoch     ▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆██████
  train/loss      █▇▇▇▆▆▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁
  val/accuracy    ▁▆▇██
  val/epoch       ▁▃▅▆█
  val/loss        █▃▂▁▁

Run summary:
  total_epochs    4.0
  total_steps     1055.0
  train/accuracy  0.9
  train/epoch     4.0
  train/loss      0.37273
  val/accuracy    0.908
  val/epoch       4.0
  val/loss        0.31756

Run completed:
  Epochs trained: 5
  Learning rate: 1e-04
  Final train loss: 0.3505
  Final val loss: 0.3176
  Final val accuracy: 0.9080

Starting run with learning_rate=5e-04

Run history:
  train/accuracy  ▁▆▅▆▇▇▇▇▇▇█▇█▇▇▇▇▇████████▇███████████▇█
  train/epoch     ▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆▆▆███████████
  train/loss      █▇▆▄▃▂▂▂▂▁▂▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
  val/accuracy    ▁▆██
  val/epoch       ▁▃▆█
  val/loss        █▃▂▁

Run summary:
  total_epochs    3.0
  total_steps     844.0
  train/accuracy  0.925
  train/epoch     3.0
  train/loss      0.23917
  val/accuracy    0.9475
  val/epoch       3.0
  val/loss        0.16722

Run completed:
  Epochs trained: 4
  Learning rate: 5e-04
  Final train loss: 0.1785
  Final val loss: 0.1672
  Final val accuracy: 0.9475

Starting run with learning_rate=1e-03

Run history:
  train/accuracy  ▁▄▅▇▇▇▇▇▇████████████████████████████▇██
  train/epoch     ▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆███████
  train/loss      █▇▆▂▂▂▂▂▂▂▁▂▁▂▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
  val/accuracy    ▁▃▆█
  val/epoch       ▁▃▆█
  val/loss        █▅▃▁

Run summary:
  total_epochs    3.0
  total_steps     844.0
  train/accuracy  0.97083
  train/epoch     3.0
  train/loss      0.1021
  val/accuracy    0.96617
  val/epoch       3.0
  val/loss        0.10686

Run completed:
  Epochs trained: 4
  Learning rate: 1e-03
  Final train loss: 0.1340
  Final val loss: 0.1069
  Final val accuracy: 0.9662

Starting run with learning_rate=5e-03

Run history:
  train/accuracy  ▁▅▇▇▇▇▇█▇███████████████████████████████
  train/epoch     ▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆▆▆█████
  train/loss      █▄▃▅▃▃▄▂▃▂▂▂▃▂▃▃▃▂▃▁▂▂▂▂▂▁▂▂▂▁▂▁▁▂▂▂▂▂▁▂
  val/accuracy    ▁▇▇█
  val/epoch       ▁▃▆█
  val/loss        █▂▃▁

Run summary:
  total_epochs    3.0
  total_steps     844.0
  train/accuracy  0.99583
  train/epoch     3.0
  train/loss      0.03016
  val/accuracy    0.96967
  val/epoch       3.0
  val/loss        0.0885

Run completed:
  Epochs trained: 4
  Learning rate: 5e-03
  Final train loss: 0.0836
  Final val loss: 0.0885
  Final val accuracy: 0.9697

Starting run with learning_rate=1e-02

Run history:
  train/accuracy  ▁▃▃▆▇▇█████▇████████████████████████████
  train/epoch     ▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▃▆▆█████████████
  train/loss      ██▄▂▂▁▂▁▁▁▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
  val/accuracy    ▁▇██
  val/epoch       ▁▃▆█
  val/loss        █▂▁▁

Run summary:
  total_epochs    3.0
  total_steps     844.0
  train/accuracy  0.99167
  train/epoch     3.0
  train/loss      0.05733
  val/accuracy    0.96683
  val/epoch       3.0
  val/loss        0.09838

Run completed:
  Epochs trained: 4
  Learning rate: 1e-02
  Final train loss: 0.1077
  Final val loss: 0.0984
  Final val accuracy: 0.9668
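The W&B summaries above can be cross-checked against the data loader size. With 54,000 training samples, `batch_size=256`, and `DataLoader`'s default `drop_last=False`, one epoch has ceil(54000 / 256) = 211 batches. The sketch below assumes `total_steps` counts training batches and uses the epoch counts and `total_steps` values copied from the outputs above. W&B's `total_epochs` (4.0 or 3.0) is one less than the "Epochs trained" printed by the loop, which suggests it records a zero-based epoch index; that reading is an assumption.

```python
import math

n_train, batch_size = 54_000, 256
steps_per_epoch = math.ceil(n_train / batch_size)  # partial last batch is kept

# (learning rate, epochs trained, W&B total_steps) copied from the run outputs
runs = [
    (1e-4, 5, 1055),
    (5e-4, 4, 844),
    (1e-3, 4, 844),
    (5e-3, 4, 844),
    (1e-2, 4, 844),
]

for lr, epochs, total_steps in runs:
    expected = epochs * steps_per_epoch
    print(f"lr={lr:.0e}: {epochs} x {steps_per_epoch} = {expected} (W&B: {total_steps})")
    assert expected == total_steps
```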


## 5. Summary of Results

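A quick overview of all runs. The table reports each run's last-epoch metrics from the training history (not necessarily the epoch restored by `restore_best_weights=True`), and the best learning rate is the one with the highest final validation accuracy.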

In [21]:
print("\n" + "="*80)
print("LEARNING RATE SWEEP SUMMARY")
print("="*80)
print(f"{'LR':<12} {'Val Loss':<12} {'Val Acc':<12} {'Epochs':<12} {'Status'}")
print("-"*80)

best_lr = None
best_accuracy = 0

for r in results:
    status = "Early Stopped" if r['epochs_trained'] < BASE_CONFIG["max_epochs"] else "Completed"
    print(f"{r['lr']:<12.0e} {r['val_loss']:<12.4f} {r['val_accuracy']:<12.4f} {r['epochs_trained']:<12} {status}")
    
    if r['val_accuracy'] > best_accuracy:
        best_accuracy = r['val_accuracy']
        best_lr = r['lr']

print("="*80)
print(f"\nBest learning rate: {best_lr:.0e}")
print(f"Best validation accuracy: {best_accuracy:.4f}")


LEARNING RATE SWEEP SUMMARY
LR           Val Loss     Val Acc      Epochs       Status
--------------------------------------------------------------------------------
1e-04        0.3176       0.9080       5            Early Stopped
5e-04        0.1672       0.9475       4            Early Stopped
1e-03        0.1069       0.9662       4            Early Stopped
5e-03        0.0885       0.9697       4            Early Stopped
1e-02        0.0984       0.9668       4            Early Stopped

Best learning rate: 5e-03
Best validation accuracy: 0.9697
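The manual loop in the summary cell is equivalent to a `max` over `results`. The sketch below uses the final-epoch numbers from the table above, also picks the lowest validation loss as a second opinion, and measures how far ahead the winner is:

```python
# Final-epoch numbers copied from the summary table above
results = [
    {"lr": 1e-4, "val_loss": 0.3176, "val_accuracy": 0.9080, "epochs_trained": 5},
    {"lr": 5e-4, "val_loss": 0.1672, "val_accuracy": 0.9475, "epochs_trained": 4},
    {"lr": 1e-3, "val_loss": 0.1069, "val_accuracy": 0.9662, "epochs_trained": 4},
    {"lr": 5e-3, "val_loss": 0.0885, "val_accuracy": 0.9697, "epochs_trained": 4},
    {"lr": 1e-2, "val_loss": 0.0984, "val_accuracy": 0.9668, "epochs_trained": 4},
]

best_by_acc = max(results, key=lambda r: r["val_accuracy"])
best_by_loss = min(results, key=lambda r: r["val_loss"])

# Gap between the best and second-best accuracy, in percentage points
ranked = sorted(results, key=lambda r: r["val_accuracy"], reverse=True)
margin_pp = 100 * (ranked[0]["val_accuracy"] - ranked[1]["val_accuracy"])

print(f"{best_by_acc['lr']:.0e} {best_by_loss['lr']:.0e} {margin_pp:.2f}")  # 5e-03 5e-03 0.29
```

Both criteria agree on 5e-3, but with one seed and 6,000 validation images a 0.29-point gap is roughly 17 images, which may well be within run-to-run noise. Repeating the sweep with several seeds would be needed before preferring 5e-3 over 1e-3 or 1e-2.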


## 6. View Results in Weights & Biases

After running the experiments, you can:

1. **View the dashboard**: Go to your wandb project page
2. **Compare runs**: Group or filter the runs table by the `lr-sweep` group to compare all learning rates side by side
3. **Analyze metrics**:
   - Training curves showing convergence speed
   - Comparison of final accuracies across learning rates

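**Key insights to look for**:
- Which learning rates converged quickly?
- Which learning rates stopped earliest because validation accuracy stopped improving by at least `min_delta`?
- Whether `min_delta` is too strict: in this sweep every run was early-stopped after 4 or 5 epochs
- Trade-off between convergence speed and final accuracy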