# ResNet18 Finetuning on CIFAR-10

This notebook:
- Takes a ResNet18 network pretrained on ImageNet as base point, then finetune on CIFAR-10
- Uses different finetuning hyperparameters to obtain different model checkpoints
- Follows heDeepResidualLearning2016 training configuration

## Setup Environment

In [1]:
LOCAL = True

# if run locally:
if LOCAL:
    ROOT_DIR = "/Users/Yang/Desktop/research-model-merge/playground/merge_soup-resnet18-cifar10"
    DATA_DIR = "/Users/Yang/Desktop/research-model-merge/datasets"
    PROJECT_ROOT = "/Users/Yang/Desktop/research-model-merge"
else:
    # on Colab
    ROOT_DIR = "/content"
    DATA_DIR = "/content/datasets"
    PROJECT_ROOT = "/content"
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = "drive/MyDrive/research-model_merge"

In [2]:
import os
import sys
import time
from typing import Dict, Any

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights
import numpy as np
from tqdm import tqdm

# Add project root to path
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Add utils to path
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from datasets.cifar10 import CIFAR10

from datasets.cifar10 import CIFAR10

In [3]:
# Check GPU availability and system info
import subprocess

print("🔍 System Information:")
print(f"Python version: {subprocess.check_output(['python', '--version']).decode().strip()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"CUDA version: {torch.version.cuda}")
    DEVICE = torch.device("cuda")
else:
    print("⚠️ No GPU available! Training will be slow on CPU.")
    DEVICE = torch.device("cpu")

🔍 System Information:
Python version: Python 3.11.5
PyTorch version: 2.1.2
CUDA available: False
⚠️ No GPU available! Training will be slow on CPU.


## Dataset Preparation

Using the shared CIFAR10 dataset class from `datasets/cifar10.py`:
- Training: 98% of original training set (49,000 images)
- Validation: 2% of original training set (1,000 images)  
- Test: Official CIFAR-10 test set (10,000 images)
- Persistent indices ensure consistent splits across all experiments

In [4]:
# Create CIFAR-10 dataset using shared dataset class
# This uses persistent indices for reproducible splits
dataset = CIFAR10(
    data_location=DATA_DIR,
    batch_size=256,
    num_workers=2
)

train_loader = dataset.train_loader
val_loader = dataset.val_loader
test_loader = dataset.test_loader

print(f"✅ Dataset loaded:")
print(f"   Train samples: {len(dataset.train_sampler)}")
print(f"   Val samples: {len(dataset.val_sampler)}")
print(f"   Test samples: {len(dataset.test_dataset)}")
print(f"   Classnames: {dataset.classnames}")

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
✅ Dataset loaded:
   Train samples: 49000
   Val samples: 1000
   Test samples: 10000
   Classnames: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


## Finetuning Function

In [5]:
def cosine_lr_schedule(optimizer, epoch, total_epochs, warmup_epochs, base_lr):
    """
    Cosine learning rate schedule with linear warmup.
    Following Git Re-Basin configuration.
    """
    if epoch < warmup_epochs:
        # Linear warmup from 1e-6 to base_lr
        lr = 1e-6 + (base_lr - 1e-6) * epoch / warmup_epochs
    else:
        # Cosine decay from base_lr to 0
        progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        lr = base_lr * 0.5 * (1 + np.cos(np.pi * progress))
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    return lr

In [6]:
def step_lr_schedule(optimizer, epoch, total_epochs, warmup_epochs, base_lr):
    """
    Following heDeepResidualLearning2016. 
    """
    if epoch < warmup_epochs:
        lr = base_lr
    else:
        lr = base_lr * 0.1
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    return lr

In [7]:
def finetune_resnet(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model_save_location: str = '.',
    batch_size: int = 256,
    epochs: int = 10,
    warmup_epochs: int = 5,
    lr: float = 0.1,
    wd: float = 1e-4,
    momentum: float = 0.9,
    name: str = 'config1',
    log_interval: int = 20,
) -> Dict[str, Any]:
    """
    Finetune ResNet18 (pretrained on ImageNet) on CIFAR-10.
    
    Following He 2016 training configuration:
    - SGD optimizer with momentum=0.9
    - Weight decay (default 1e-4)
    - Step LR schedule with 5-epoch warmup
    - Warmup: 1e-6 -> lr over 5 epochs
    - Step decay: lr -> 0.1*lr after warmup
    """
    os.makedirs(model_save_location, exist_ok=True)
    
    # Load pretrained ResNet18 and modify for CIFAR-10
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    
    # Modify first conv layer for 32x32 input (CIFAR-10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    
    # Remove maxpool layer (too aggressive for 32x32 images)
    model.maxpool = nn.Identity()
    
    # Replace final FC layer for CIFAR-10 (10 classes)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)
    
    model = model.to(DEVICE)
    
    # Optimizer: SGD with momentum
    optimizer = optim.SGD(
        model.parameters(), 
        lr=lr,
        momentum=momentum,
        weight_decay=wd
    )
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_acc': [],
        'lr': []
    }
    
    print(f"\n{'='*80}")
    print(f"Starting training: {name}")
    print(f"Config: lr={lr}, wd={wd}, epochs={epochs}, batch_size={batch_size}")
    print(f"{'='*80}\n")
    
    # Training loop
    for epoch in range(epochs):
        # Update learning rate
        current_lr = step_lr_schedule(optimizer, epoch, epochs, warmup_epochs, lr)
        history['lr'].append(current_lr)
        
        # Training phase
        model.train()
        train_loss_accum = 0.0
        train_batches = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
        for i, (inputs, labels) in enumerate(pbar):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss_accum += loss.item()
            train_batches += 1
            
            if i % log_interval == 0:
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{current_lr:.6f}'
                })
        
        train_loss = train_loss_accum / train_batches
        history['train_loss'].append(train_loss)
        
        # Validation phase
        model.eval()
        val_loss_accum = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')
            for inputs, labels in pbar:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss_accum += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100.*correct/total:.2f}%'
                })
        
        val_loss = val_loss_accum / len(val_loader)
        val_acc = correct / total
        
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f"\nEpoch {epoch+1}/{epochs} Summary:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Val Acc:    {100*val_acc:.2f}%")
        print(f"  LR:         {current_lr:.6f}\n")
        
        # Save checkpoint after each epoch
        checkpoint_path = os.path.join(model_save_location, f'{name}_epoch{epoch+1}.pt')
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✅ Saved checkpoint: {checkpoint_path}")
    
    result = {
        'history': history,
        'config': {
            'model_save_location': model_save_location,
            'batch_size': batch_size,
            'epochs': epochs,
            'warmup_epochs': warmup_epochs,
            'lr': lr,
            'wd': wd,
            'momentum': momentum,
            'name': name,
        },
    }
    
    return result

## Run Training with Multiple Configurations

We train 5 different configurations with varying learning rates and weight decay values:

1. **Config 1**: lr=0.1, wd=1e-4 (He 2016 baseline)
2. **Config 2**: lr=0.05, wd=1e-4
3. **Config 3**: lr=0.01, wd=1e-4
4. **Config 4**: lr=0.1, wd=1e-3
5. **Config 5**: lr=0.1, wd=1e-5

In [8]:
# Checkpoint directory
if LOCAL:
    checkpoint_dir = f"{ROOT_DIR}/checkpoints"
else:
    checkpoint_dir = f"{DRIVE_DIR}/checkpoints"

os.makedirs(checkpoint_dir, exist_ok=True)

# Define configurations
configs = [
    dict(lr=0.1, wd=1e-4, name='config1'),
    dict(lr=0.05, wd=1e-4, name='config2'),
    dict(lr=0.01, wd=1e-4, name='config3'),
    dict(lr=0.1, wd=1e-3, name='config4'),
    dict(lr=0.1, wd=1e-5, name='config5'),
]

# Common parameters
common = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    model_save_location=checkpoint_dir,
    batch_size=258,
    epochs=10,
    warmup_epochs=5,
    momentum=0.9,
)

In [None]:
# Run all configurations
results = []

for config in configs:
    run_config = {**common, **config}
    print(f"\n{'#'*80}")
    print(f"Running configuration: {config['name']}")
    print(f"  LR: {config['lr']}, WD: {config['wd']}")
    print(f"{'#'*80}\n")
    
    result = finetune_resnet(**run_config)
    results.append(result)
    
    print(f"\n✅ {config['name']} completed!\n")

print("\n" + "="*80)
print("All configurations completed!")
print("="*80)


################################################################################
Running configuration: config1
  LR: 0.1, WD: 0.0001
################################################################################


Starting training: config1
Config: lr=0.1, wd=0.0001, epochs=10, batch_size=258



Epoch 1/10 [Train]:   2%|▏         | 3/192 [00:21<21:21,  6.78s/it, loss=2.5593, lr=0.100000]

## Summary of Results

In [None]:
import pandas as pd

# Create summary table
summary = []
for r in results:
    cfg = r['config']
    hist = r['history']
    summary.append({
        'name': cfg['name'],
        'lr': cfg['lr'],
        'wd': cfg['wd'],
        'final_train_loss': hist['train_loss'][-1],
        'final_val_loss': hist['val_loss'][-1],
        'final_val_acc': f"{100*hist['val_acc'][-1]:.2f}%",
        'best_val_acc': f"{100*max(hist['val_acc']):.2f}%",
    })

df = pd.DataFrame(summary)
print("\n" + "="*80)
print("Training Summary")
print("="*80)
print(df.to_string(index=False))
print("="*80)

# Save summary
df.to_csv(f"{checkpoint_dir}/training_summary.csv", index=False)
print(f"\n✅ Summary saved to {checkpoint_dir}/training_summary.csv")

## Plot Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Training Loss
for r in results:
    axes[0, 0].plot(r['history']['train_loss'], label=r['config']['name'])
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Train Loss')
axes[0, 0].set_title('Training Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Plot 2: Validation Loss
for r in results:
    axes[0, 1].plot(r['history']['val_loss'], label=r['config']['name'])
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Validation Loss')
axes[0, 1].set_title('Validation Loss')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Plot 3: Validation Accuracy
for r in results:
    axes[1, 0].plot([100*x for x in r['history']['val_acc']], label=r['config']['name'])
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Validation Accuracy (%)')
axes[1, 0].set_title('Validation Accuracy')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Plot 4: Learning Rate
for r in results:
    axes[1, 1].plot(r['history']['lr'], label=r['config']['name'])
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('Learning Rate Schedule')
axes[1, 1].legend()
axes[1, 1].grid(True)
axes[1, 1].set_yscale('log')

plt.tight_layout()
plt.savefig(f"{checkpoint_dir}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to {checkpoint_dir}/training_curves.png")