# Learning Rate Sweep with Early Stopping

This notebook demonstrates how to use `torch-batteries` to perform a learning rate sweep with aggressive early stopping.

**Research Question**: What learning rate achieves the fastest convergence with early stopping?

**Experiment Design**:
- Train models with different learning rates (1e-4, 5e-4, 1e-3, 5e-3, 1e-2)
- Use aggressive early stopping (patience=3, min_delta=0.1) to quickly identify poor LRs
- Log all metrics to Weights & Biases automatically
- Compare convergence speed and final accuracy

**What gets tracked**:
- Training and validation metrics (loss, accuracy)
- Hyperparameters (learning rate, batch size, patience)
- The epoch at which training stopped
- Whether early stopping was triggered
- Model artifacts (checkpoints)

## Setup

In [1]:
# Install dependencies
# !pip install torch-batteries wandb

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import wandb
from IPython.display import clear_output

from torch_batteries import Battery, Event, charge
from torch_batteries.callbacks import EarlyStopping, ExperimentTrackingCallback
from torch_batteries.tracking import WandbTracker, Project, Experiment, Run
from torch_batteries.events.core import EventContext

print(f"Is CUDA available?: {torch.cuda.is_available()}")

Is CUDA available?: True


In [None]:
project_name = input("Enter your wandb project name (default: 'torch-batteries-integration'): ").strip() or "torch-batteries-integration"

# Wandb entity (optional)
wandb_entity = input("Enter your wandb entity (username/team) or press Enter to skip: ").strip() or None

print(f"\nConfiguration:")
print(f"  Project: {project_name}")
print(f"  Entity: {wandb_entity if wandb_entity else 'default'}")


Configuration:
  Project: torch-batteries-integration
  Entity: default


In [4]:
wandb.login()
clear_output()

## 1. Define the Model

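A small CNN for MNIST classification. The `@charge` decorator attaches `training_step` and `validation_step` to `Event.TRAIN_STEP` and `Event.VALIDATION_STEP`; each step reads the current batch from the `EventContext` and returns the cross-entropy loss.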

In [7]:
class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)

    @charge(Event.TRAIN_STEP)
    def training_step(self, context: EventContext):
        x, y = context["batch"]
        return F.cross_entropy(self(x), y)

    @charge(Event.VALIDATION_STEP)
    def validation_step(self, context: EventContext):
        x, y = context["batch"]
        return F.cross_entropy(self(x), y)


print(f"Number of parameters: {sum(p.numel() for p in MNISTNet().parameters())}")

Number of parameters: 19146
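The reported count of 19,146 parameters can be checked by hand. Each `Conv2d(in, out, 3)` has `out * in * 3 * 3` weights plus `out` biases, and the final `Linear(32, 10)` has `32 * 10 + 10`. Because `AdaptiveAvgPool2d((1, 1))` collapses the spatial dimensions, the linear layer's input size is just the channel count; the shape trace shows the last convolution still produces a 3x3 map before pooling. This is a pure-Python sketch (no torch needed), and the helper names are illustrative only.

```python
def conv2d_params(in_ch, out_ch, k):
    # out_ch filters of shape in_ch x k x k, plus one bias per filter
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_features, out_features):
    # weight matrix plus one bias per output unit
    return out_features * in_features + out_features


def conv2d_out(size, k):
    # stride 1, no padding (the nn.Conv2d defaults used in MNISTNet)
    return size - k + 1


total = (
    conv2d_params(1, 32, 3)      # 320
    + conv2d_params(32, 32, 3)   # 9,248
    + conv2d_params(32, 32, 3)   # 9,248
    + linear_params(32, 10)      # 330
)
print(total)  # 19146

# Spatial trace for a 28x28 MNIST image (MaxPool2d(2) floors odd sizes)
size = 28
size = conv2d_out(size, 3) // 2  # conv -> 26, MaxPool2d(2) -> 13
size = conv2d_out(size, 3) // 2  # conv -> 11, MaxPool2d(2) -> 5
size = conv2d_out(size, 3)       # conv -> 3; AdaptiveAvgPool2d then gives 1x1
print(size)  # 3
```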


## 2. Prepare Data

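Load MNIST. The official 60,000-image training split is used for training, and the 10,000-image test split serves as the validation set. Inputs are normalized with the standard MNIST mean (0.1307) and standard deviation (0.3081).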

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(test_dataset)}")

Training samples: 60000
Validation samples: 10000
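The training logs below show 235 training and 40 validation batches per epoch. That follows from `batch_size=256`, since `DataLoader` keeps the final partial batch by default (`drop_last=False`). The sketch reproduces those counts and shows where the `Normalize` constants send the `[0, 1]` pixel range produced by `ToTensor`; the `normalize` helper is illustrative, not part of torchvision.

```python
import math

batch_size = 256
train_batches = math.ceil(60000 / batch_size)  # 234 full batches + 1 batch of 96
val_batches = math.ceil(10000 / batch_size)    # 39 full batches + 1 batch of 16
print(train_batches, val_batches)  # 235 40

mean, std = 0.1307, 0.3081


def normalize(pixel):
    # Same per-channel formula transforms.Normalize applies: (x - mean) / std
    return (pixel - mean) / std


# A black pixel (0.0) and a white pixel (1.0) after normalization
print(round(normalize(0.0), 4), round(normalize(1.0), 4))
```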


## 3. Define Experiment Configuration

Set up the experiment tracking hierarchy:
- **Project**: the wandb project name entered above (default "torch-batteries-integration"), the top-level container
- **Experiment**: "lr-sweep-eager-es", this specific investigation
- **Runs**: one run per learning rate
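Concretely, the sweep creates one run per learning rate, named with the `f"lr-{lr:.0e}"` pattern used in the training loop. How torch-batteries combines an experiment's `base_config` with a run's `config` is not shown in this notebook; the dict merge below (run keys layered over base keys) is an assumption made only to illustrate the hierarchy.

```python
learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]
base_config = {
    "model": "CNN",
    "dataset": "MNIST",
    "batch_size": 256,
    "optimizer": "Adam",
    "max_epochs": 20,
    "early_stopping_patience": 3,
}

# Hypothetical merge rule: each run's own config is layered over the
# experiment's base_config. Not taken from torch-batteries' source.
runs = {
    f"lr-{lr:.0e}": {**base_config, "learning_rate": lr, "patience": 3}
    for lr in learning_rates
}
print(list(runs))
# ['lr-1e-04', 'lr-5e-04', 'lr-1e-03', 'lr-5e-03', 'lr-1e-02']
```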

In [None]:
project = Project(
    name=project_name,
    description="Learning rate sweep research on MNIST classification"
)

experiment = Experiment(
    name="lr-sweep-eager-es",
    description="Finding optimal learning rate with aggressive early stopping (patience=3)",
    base_config={
        "model": "CNN",
        "dataset": "MNIST",
        "batch_size": batch_size,
        "optimizer": "Adam",
        "max_epochs": 20,
        "early_stopping_patience": 3,
    },
    tags=["lr-sweep", "early-stopping", "cnn", "mnist"]
)

print(f"Project: {project.name}")
print(f"Experiment: {experiment.name}")
print(f"Base config: {experiment.base_config}")

Project: torch-batteries-integration
Experiment: lr-sweep-eager-es
Base config: {'model': 'CNN', 'dataset': 'MNIST', 'batch_size': 256, 'optimizer': 'Adam', 'max_epochs': 20, 'early_stopping_patience': 3}


## 4. Run Learning Rate Sweep

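Train one model per learning rate. Each run:
1. Initializes a fresh model and an Adam optimizer with that learning rate
2. Sets up experiment tracking for the run
3. Trains with aggressive early stopping (`patience=3`, `min_delta=0.1`), where a decrease in validation loss smaller than 0.1 does not count as an improvement
4. Logs all metrics to wandb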
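The stopping rule can be replayed on the logged validation losses. The function below is a minimal, library-independent sketch: it assumes the common Keras-style semantics in which an epoch counts as an improvement only if it beats the best loss so far by more than `min_delta`, and `simulate_early_stopping` is a hypothetical helper, not part of torch-batteries (whose implementation may differ in details such as `<` vs `<=`). Replayed on the lr=5e-4 and lr=5e-3 runs logged below, it stops after 5 and 4 epochs, matching the number of validation passes those runs made.

```python
def simulate_early_stopping(val_losses, patience=3, min_delta=0.1):
    """Replay mode='min' early stopping over per-epoch validation losses.

    Returns (stop_epoch, best_epoch), both 1-indexed. stop_epoch is None
    if patience never ran out.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # Improvement: beats the best so far by more than min_delta
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Per-epoch validation losses copied from the training log below
lr_5e4 = [0.2942, 0.1740, 0.1432, 0.1233, 0.1074]
lr_5e3 = [0.1011, 0.0738, 0.0511, 0.0550]
print(simulate_early_stopping(lr_5e4))  # (5, 2)
print(simulate_early_stopping(lr_5e3))  # (4, 1)
```

Under these semantics the lr=5e-4 run stops even though validation loss is still falling every epoch, because each decrease after epoch 2 is smaller than 0.1. If `restore_best_weights=True` tracks the best epoch the same way, the restored weights would come from epoch 2 (lr=5e-4) and epoch 1 (lr=5e-3), not from the final epoch whose metrics the loop prints.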

In [None]:
learning_rates = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]

max_epochs = 20
patience = 3  # Aggressive early stopping

def accuracy(predictions, targets):
    """Calculate accuracy."""
    pred_labels = predictions.argmax(dim=1)
    return (pred_labels == targets).float().mean().item()

metrics = {"accuracy": accuracy}
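`accuracy` takes the argmax over the class dimension and averages the matches, so it returns a per-batch fraction in [0, 1]. The pure-Python analogue below (a hypothetical helper, written without torch so it runs anywhere) makes the computation concrete on four hand-made logit rows. How torch-batteries aggregates these per-batch values into the epoch figure shown in the progress bars is not covered here.

```python
def accuracy_from_lists(predictions, targets):
    """Pure-Python analogue of `accuracy`: argmax per row, then mean match."""
    pred_labels = [max(range(len(row)), key=row.__getitem__) for row in predictions]
    matches = [int(p == t) for p, t in zip(pred_labels, targets)]
    return sum(matches) / len(matches)


logits = [
    [0.1, 2.0, 0.3],  # argmax 1, target 1: correct
    [1.5, 0.2, 0.1],  # argmax 0, target 0: correct
    [0.0, 0.1, 3.0],  # argmax 2, target 1: wrong
    [2.0, 1.0, 0.5],  # argmax 0, target 0: correct
]
targets = [1, 0, 1, 0]
print(accuracy_from_lists(logits, targets))  # 0.75
```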

In [14]:
results = []

for lr in learning_rates:
    print(f"\n{'='*60}")
    print(f"Starting run with learning_rate={lr:.0e}")
    print(f"{'='*60}\n")
    
    model = MNISTNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    run = Run(
        name=f"lr-{lr:.0e}",
        config={
            "learning_rate": lr,
            "patience": patience,
        },
        job_type="train",
    )
    
    tracker = WandbTracker(entity=wandb_entity)
    
    callbacks = [
        ExperimentTrackingCallback(
            tracker=tracker,
            project=project,
            experiment=experiment,
            run=run,
        ),
        EarlyStopping(
            stage="val",
            metric="loss",
            patience=patience,
            min_delta=0.1,
            mode="min",
            verbose=True,
            restore_best_weights=True,
        )
    ]
    
    battery = Battery(
        model=model,
        optimizer=optimizer,
        metrics=metrics,
        callbacks=callbacks,
    )
    
    result = battery.train(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=max_epochs,
    )
    
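    results.append({
        'lr': lr,
        'train_loss': result["train_loss"][-1],
        'val_loss': result["val_loss"][-1],
        'val_accuracy': result["val_metrics"]['accuracy'][-1],
        # Assumes one val_loss entry per epoch, so its length is the number
        # of epochs actually trained (the summary table below needs this key)
        'epochs_trained': len(result["val_loss"]),
    })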
    
    print(f"\nRun completed:")
    print(f"  Learning rate: {lr:.0e}")
    print(f"  Final train loss: {result['train_loss'][-1]:.4f}")
    print(f"  Final val loss: {result['val_loss'][-1]:.4f}")
    print(f"  Final val accuracy: {result['val_metrics']['accuracy'][-1]:.4f}")


Starting run with learning_rate=1e-04



Epoch 1/20 [Train]: 100%|██████████| 235/235 [00:26<00:00,  8.72it/s, Loss=2.0636, Accuracy=0.4098]
Epoch 1/20 [Validation]: 100%|██████████| 40/40 [00:02<00:00, 13.51it/s, Loss=1.5402, Accuracy=0.6696]
Epoch 2/20 [Train]: 100%|██████████| 235/235 [00:26<00:00,  8.95it/s, Loss=1.0243, Accuracy=0.7623]
Epoch 2/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 12.23it/s, Loss=0.6638, Accuracy=0.8298]
Epoch 3/20 [Train]: 100%|██████████| 235/235 [00:27<00:00,  8.50it/s, Loss=0.5452, Accuracy=0.8531]
Epoch 3/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 12.46it/s, Loss=0.4236, Accuracy=0.8876]
Epoch 4/20 [Train]: 100%|██████████| 235/235 [00:29<00:00,  8.09it/s, Loss=0.3910, Accuracy=0.8910]
Epoch 4/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.68it/s, Loss=0.3236, Accuracy=0.9115]
Epoch 5/20 [Train]: 100%|██████████| 235/235 [00:30<00:00,  7.74it/s, Loss=0.3154, Accuracy=0.9110]
Epoch 5/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 12.59it/s, Loss=0.2636, Ac

Run history:
train/accuracy  ▁▁▄▆▅▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██████████████
train/loss      █████▆▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/accuracy    ▁▅▆▇████
val/loss        █▃▂▂▁▁▁▁

Run summary:
total_epochs    7.0
total_steps     1880.0
train/accuracy  0.97917
train/loss      0.08572
val/accuracy    0.9472
val/loss        0.18911



Run completed:
  Learning rate: 1e-04
  Final train loss: 0.2202
  Final val loss: 0.1891
  Final val accuracy: 0.9472

Starting run with learning_rate=5e-04



Epoch 1/20 [Train]: 100%|██████████| 235/235 [00:28<00:00,  8.27it/s, Loss=0.9582, Accuracy=0.7451]
Epoch 1/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 12.05it/s, Loss=0.2942, Accuracy=0.9157]
Epoch 2/20 [Train]: 100%|██████████| 235/235 [00:27<00:00,  8.50it/s, Loss=0.2515, Accuracy=0.9299]
Epoch 2/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.91it/s, Loss=0.1740, Accuracy=0.9516]
Epoch 3/20 [Train]: 100%|██████████| 235/235 [00:28<00:00,  8.24it/s, Loss=0.1805, Accuracy=0.9484]
Epoch 3/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.17it/s, Loss=0.1432, Accuracy=0.9586]
Epoch 4/20 [Train]: 100%|██████████| 235/235 [00:26<00:00,  8.71it/s, Loss=0.1454, Accuracy=0.9576]
Epoch 4/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.44it/s, Loss=0.1233, Accuracy=0.9630]
Epoch 5/20 [Train]: 100%|██████████| 235/235 [00:24<00:00,  9.51it/s, Loss=0.1264, Accuracy=0.9635]
Epoch 5/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 12.89it/s, Loss=0.1074, Ac

Run history:
train/accuracy  ▁▆▆▆▇▇▇▇▇▇▇██▇██████████████████████████
train/loss      ██▅▃▃▂▂▂▂▂▂▂▂▂▁▁▁▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/accuracy    ▁▆▇▇█
val/loss        █▃▂▂▁

Run summary:
total_epochs    4.0
total_steps     1175.0
train/accuracy  0.96875
train/loss      0.09536
val/accuracy    0.968
val/loss        0.10743



Run completed:
  Learning rate: 5e-04
  Final train loss: 0.1264
  Final val loss: 0.1074
  Final val accuracy: 0.9680

Starting run with learning_rate=1e-03



Epoch 1/20 [Train]: 100%|██████████| 235/235 [00:29<00:00,  8.00it/s, Loss=0.7549, Accuracy=0.7771]
Epoch 1/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 12.45it/s, Loss=0.2295, Accuracy=0.9319]
Epoch 2/20 [Train]: 100%|██████████| 235/235 [00:28<00:00,  8.15it/s, Loss=0.2114, Accuracy=0.9394]
Epoch 2/20 [Validation]: 100%|██████████| 40/40 [00:02<00:00, 14.13it/s, Loss=0.1697, Accuracy=0.9486]
Epoch 3/20 [Train]: 100%|██████████| 235/235 [00:27<00:00,  8.61it/s, Loss=0.1540, Accuracy=0.9565]
Epoch 3/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.06it/s, Loss=0.1211, Accuracy=0.9633]
Epoch 4/20 [Train]: 100%|██████████| 235/235 [00:26<00:00,  8.98it/s, Loss=0.1279, Accuracy=0.9636]
Epoch 4/20 [Validation]: 100%|██████████| 40/40 [00:02<00:00, 15.08it/s, Loss=0.1005, Accuracy=0.9703]
Epoch 5/20 [Train]: 100%|██████████| 235/235 [00:25<00:00,  9.26it/s, Loss=0.1109, Accuracy=0.9681]
Epoch 5/20 [Validation]: 100%|██████████| 40/40 [00:02<00:00, 14.17it/s, Loss=0.0988, Ac

Run history:
train/accuracy  ▁▃▄▆▆▇▇▇█▇▇▇▇▇█▇▇████▇██████████▇███████
train/loss      ▆█▇█▇▅▅▄▅▆▄▅▆▄▃▄▃▂▃▃▃▄▃▂▃▃▃▃▂▁▃▂▁▁▃▂▁▂▂▃
val/accuracy    ▁▄▆▇▇█
val/loss        █▅▃▂▂▁

Run summary:
total_epochs    5.0
total_steps     1410.0
train/accuracy  0.95833
train/loss      0.11665
val/accuracy    0.9753
val/loss        0.08345



Run completed:
  Learning rate: 1e-03
  Final train loss: 0.0975
  Final val loss: 0.0834
  Final val accuracy: 0.9753

Starting run with learning_rate=5e-03



Epoch 1/20 [Train]: 100%|██████████| 235/235 [00:28<00:00,  8.39it/s, Loss=0.3785, Accuracy=0.8902]
Epoch 1/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.16it/s, Loss=0.1011, Accuracy=0.9691]
Epoch 2/20 [Train]: 100%|██████████| 235/235 [00:27<00:00,  8.48it/s, Loss=0.0995, Accuracy=0.9726]
Epoch 2/20 [Validation]: 100%|██████████| 40/40 [00:02<00:00, 13.73it/s, Loss=0.0738, Accuracy=0.9777]
Epoch 3/20 [Train]: 100%|██████████| 235/235 [00:28<00:00,  8.33it/s, Loss=0.0714, Accuracy=0.9810]
Epoch 3/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 11.53it/s, Loss=0.0511, Accuracy=0.9839]
Epoch 4/20 [Train]: 100%|██████████| 235/235 [00:33<00:00,  7.03it/s, Loss=0.0578, Accuracy=0.9853]
Epoch 4/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 10.83it/s, Loss=0.0550, Accuracy=0.9832]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Run history:
train/accuracy  ▃▁▄▃▇▅▆█▆▇▇▇▇▇▇██▇▇▇██▇▆▇▆▇█▆▇▇▇▆██▇██▇▇
train/loss      █▅▆▅▃▃▃▃▂▂▂▁▂▂▂▁▁▂▁▂▁▂▂▁▁▂▂▁▁▂▁▁▁▁▂▁▁▁▁▁
val/accuracy    ▁▅██
val/loss        █▄▁▂

Run summary:
total_epochs    3.0
total_steps     940.0
train/accuracy  1.0
train/loss      0.00817
val/accuracy    0.9832
val/loss        0.055



Run completed:
  Learning rate: 5e-03
  Final train loss: 0.0578
  Final val loss: 0.0550
  Final val accuracy: 0.9832

Starting run with learning_rate=1e-02



Epoch 1/20 [Train]: 100%|██████████| 235/235 [00:26<00:00,  8.77it/s, Loss=0.4187, Accuracy=0.8694]
Epoch 1/20 [Validation]: 100%|██████████| 40/40 [00:03<00:00, 13.21it/s, Loss=0.1233, Accuracy=0.9617]
Epoch 2/20 [Train]:  68%|██████▊   | 159/235 [00:18<00:09,  7.63it/s, Loss=0.1078, Accuracy=0.9700]

KeyboardInterrupt: 


## 5. Summary of Results

Tabulate the final-epoch validation metrics of every completed run and pick the learning rate with the highest validation accuracy.

In [None]:
# Print summary table
print("\n" + "="*80)
print("LEARNING RATE SWEEP SUMMARY")
print("="*80)
print(f"{'LR':<12} {'Val Loss':<12} {'Val Acc':<12} {'Epochs':<12} {'Status'}")
print("-"*80)

best_lr = None
best_accuracy = 0

for r in results:
    status = "Early Stopped" if r['epochs_trained'] < max_epochs else "Completed"
    print(f"{r['lr']:<12.0e} {r['val_loss']:<12.4f} {r['val_accuracy']:<12.4f} {r['epochs_trained']:<12} {status}")
    
    if r['val_accuracy'] > best_accuracy:
        best_accuracy = r['val_accuracy']
        best_lr = r['lr']

print("="*80)
print(f"\nBest learning rate: {best_lr:.0e}")
print(f"Best validation accuracy: {best_accuracy:.4f}")
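The manual selection loop can be replaced by `max` with a key function. The sketch uses the final-epoch metrics of the four runs that finished before the 1e-2 run was interrupted, copied from the "Run completed" printouts in Section 4; `logged_results` is a hypothetical stand-in for `results`, named differently so running the cell does not overwrite it.

```python
# Final-epoch metrics of the four completed runs (from the printouts above)
logged_results = [
    {"lr": 1e-4, "val_loss": 0.1891, "val_accuracy": 0.9472},
    {"lr": 5e-4, "val_loss": 0.1074, "val_accuracy": 0.9680},
    {"lr": 1e-3, "val_loss": 0.0834, "val_accuracy": 0.9753},
    {"lr": 5e-3, "val_loss": 0.0550, "val_accuracy": 0.9832},
]

# Same rule as the loop: highest final validation accuracy wins
# (max, like the strict `>` comparison, keeps the first run on ties)
best = max(logged_results, key=lambda r: r["val_accuracy"])
print(f"Best learning rate: {best['lr']:.0e} (val acc {best['val_accuracy']:.4f})")
# Best learning rate: 5e-03 (val acc 0.9832)

# Ranking by another criterion only changes the sort key
by_loss = sorted(logged_results, key=lambda r: r["val_loss"])
print([f"{r['lr']:.0e}" for r in by_loss])
# ['5e-03', '1e-03', '5e-04', '1e-04']
```

Both rankings agree on these four runs, but note that "final" means the last epoch before stopping, not necessarily the best epoch of each run.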

## 6. View Results in Weights & Biases

After running the experiments, you can:

1. **View the dashboard**: Go to your wandb project page
2. **Compare runs**: See all runs grouped by experiment
3. **Analyze metrics**: 
   - Training curves showing convergence speed
   - Comparison of final accuracies across learning rates
   - Early stopping events
4. **Download artifacts**: Access saved model checkpoints

**Key insights to look for**:
- Which learning rates converged quickly?
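- Which learning rates triggered early stopping, and did they stop because validation loss plateaued or only because per-epoch improvements fell below `min_delta=0.1`?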
- What's the optimal learning rate for this problem?
- Trade-off between convergence speed and final accuracy
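The same comparison can be pulled from W&B programmatically instead of through the dashboard. The sketch assumes the W&B public API (`wandb.Api().runs(...)`), whose run objects expose `.name`, a `.config` dict, and `.summary._json_dict`. It also assumes that `ExperimentTrackingCallback` stores the run's `learning_rate` as a top-level config key, which this notebook does not confirm; the `val/accuracy` summary key does appear in the logged run summaries. `summarize_wandb_runs` is a hypothetical helper.

```python
def summarize_wandb_runs(runs, metric="val/accuracy", config_key="learning_rate"):
    """Return (run name, config value, final metric) rows, best metric first.

    `runs` is any iterable of objects exposing `.name`, `.config` (a dict) and
    `.summary._json_dict` (a dict), which is what W&B public-API runs provide.
    """
    rows = []
    for run in runs:
        summary = run.summary._json_dict
        if metric not in summary:
            continue  # skip runs that never logged the metric (e.g. crashed early)
        # Assumption: the tracking callback stores the learning rate under
        # `config_key` at the top level of the W&B run config.
        rows.append((run.name, run.config.get(config_key), summary[metric]))
    return sorted(rows, key=lambda row: row[2], reverse=True)


# Usage in the notebook (needs network access and a logged-in W&B client).
# Replace <your-entity> with your W&B username or team if wandb_entity is None.
# import wandb
# api = wandb.Api()
# path = f"{wandb_entity or '<your-entity>'}/{project_name}"
# for name, lr, acc in summarize_wandb_runs(api.runs(path)):
#     print(f"{name:<10} lr={lr}  val/accuracy={acc:.4f}")
```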