In [None]:
# Cell 1: Imports and Setup
import logging
from pathlib import Path
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from config.experiment_config import create_experiment_config
from experiments.traditional import TraditionalExperiment
from models.data import get_dataset
from models.factory import get_model
from utils.logging import setup_logging, get_logger

logger = get_logger(__name__)
setup_logging()

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# Cell 2: Configuration
# Where the best-performing weights are written (and read back on resume).
CHECKPOINT_PATH = "checkpoints/wideresnet/wideresnet_best.pt"
EPOCHS = 200        # intended number of training epochs
BATCH_SIZE = 128    # training batch size (validation loader uses twice this)
NUM_WORKERS = 8     # DataLoader worker processes

# Make sure the checkpoint directory exists before training tries to save.
Path(CHECKPOINT_PATH).parent.mkdir(parents=True, exist_ok=True)

# Per-channel (RGB) normalisation statistics, defined once so the train and
# test pipelines can never drift apart.
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]

# Training pipeline: random 32x32 crop from a 4-pixel padded image plus a
# random horizontal flip for augmentation, then tensor conversion + normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Evaluation pipeline: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

In [None]:
# Cell 3: Training Visualizer
class TrainingVisualizer:
    """Three-panel figure (loss, accuracy, learning rate) refreshed once per epoch.

    Call ``update`` after every epoch, ``save`` to write the figure to disk and
    ``close`` to release it. Drawing, saving and closing all target the figure
    created in ``__init__`` rather than pyplot's "current figure", so another
    figure becoming current in the meantime cannot be saved or closed by mistake.
    """

    def __init__(self, epochs):
        """Create the figure and label its panels.

        Args:
            epochs: Planned total number of epochs (stored for reference only).
        """
        self.epochs = epochs
        self.epoch_numbers = []  # x-axis values, exactly as passed to update()
        self.train_losses = []
        self.train_accs = []
        self.val_accs = []
        self.lrs = []

        # Interactive mode so redraws do not block the training loop.
        # NOTE(review): under Jupyter's default inline backend figures are
        # rendered as static images, so live redraws may not show up; the
        # `%matplotlib widget` backend is the usual remedy — confirm.
        plt.ion()
        self.fig, (self.ax1, self.ax2, self.ax3) = plt.subplots(1, 3, figsize=(15, 5))
        self.setup_plots()

    @staticmethod
    def _style_axis(ax, title, ylabel):
        """Apply the title, axis labels and grid shared by every panel."""
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True)

    def setup_plots(self):
        """Label the (still empty) panels and lay out the figure."""
        self._style_axis(self.ax1, 'Training Loss', 'Loss')
        self._style_axis(self.ax2, 'Accuracy', 'Accuracy (%)')
        self._style_axis(self.ax3, 'Learning Rate', 'Learning Rate')
        self.fig.tight_layout()

    def update(self, epoch, train_loss, train_acc, val_acc, lr):
        """Record one epoch's metrics and redraw all three panels.

        Args:
            epoch: Epoch number used as the x coordinate. Using the caller's
                value (instead of counting calls) keeps the axis correct when
                training resumes part-way through from a checkpoint.
            train_loss: Training loss for the epoch.
            train_acc: Training accuracy in percent.
            val_acc: Validation accuracy in percent.
            lr: Learning rate for the epoch; plotted on a log scale, so it is
                expected to be positive.
        """
        self.epoch_numbers.append(epoch)
        self.train_losses.append(train_loss)
        self.train_accs.append(train_acc)
        self.val_accs.append(val_acc)
        self.lrs.append(lr)
        xs = self.epoch_numbers

        # Loss panel
        self.ax1.clear()
        self.ax1.plot(xs, self.train_losses)
        self._style_axis(self.ax1, 'Training Loss', 'Loss')

        # Accuracy panel (train vs. validation)
        self.ax2.clear()
        self.ax2.plot(xs, self.train_accs, label='Train')
        self.ax2.plot(xs, self.val_accs, label='Validation')
        self._style_axis(self.ax2, 'Accuracy', 'Accuracy (%)')
        self.ax2.legend()

        # Learning-rate panel; log scale makes the step decays visible.
        self.ax3.clear()
        self.ax3.plot(xs, self.lrs)
        self.ax3.set_yscale('log')
        self._style_axis(self.ax3, 'Learning Rate', 'Learning Rate')

        self.fig.tight_layout()
        self.fig.canvas.draw_idle()
        # Give the GUI event loop (if any) a moment to process the redraw.
        plt.pause(0.1)

    def save(self, path):
        """Write this visualizer's figure to ``path``."""
        self.fig.savefig(path)

    def close(self):
        """Close this visualizer's figure (not whatever figure is current)."""
        plt.close(self.fig)


In [None]:
# Cell 4: Training Function
def _load_training_checkpoint(model, optimizer, scheduler):
    """Restore training state from ``CHECKPOINT_PATH``.

    Accepts the dict layout written by ``train_model_interactive`` (keys
    'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'acc',
    'epoch'), the alternative 'state_dict' key, or a bare model state dict.

    Returns:
        (start_epoch, best_acc): the epoch index to continue from and the best
        validation accuracy stored in the checkpoint. ``(0, 0.0)`` is returned
        when the file has no epoch counter (only the weights are restored) or
        is not a dict at all (nothing is restored).
    """
    logger.info(f"Resuming from checkpoint: {CHECKPOINT_PATH}")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    if not isinstance(checkpoint, dict):
        logger.warning(
            f"Unrecognised checkpoint format ({type(checkpoint).__name__}); training from scratch"
        )
        return 0, 0.0

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # No wrapper keys: treat the whole mapping as a bare model state dict.
        model.load_state_dict(checkpoint)

    if 'epoch' not in checkpoint:
        logger.info("Checkpoint has no epoch counter; loaded weights only, starting at epoch 0")
        return 0, 0.0

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_acc = checkpoint.get('acc', 0.0)
    logger.info(f"Resuming from epoch {start_epoch} with best accuracy: {best_acc:.2f}%")
    return start_epoch, best_acc


def _train_one_epoch(model, train_loader, criterion, optimizer, scaler, use_amp, desc):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        (mean_loss, accuracy): mean of the per-batch losses and top-1 training
        accuracy in percent.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    # The scheduler only steps between epochs, so the LR is fixed for this pass.
    lr = optimizer.param_groups[0]['lr']

    pbar = tqdm(train_loader, desc=desc)
    for inputs, targets in pbar:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        n_batches += 1
        loss_sum += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # loss.item() is already a per-batch mean (CrossEntropyLoss default
        # reduction), so average over batches, not over samples.
        pbar.set_postfix({
            'loss': f'{loss_sum / n_batches:.3f}',
            'acc': f'{100. * correct / total:.2f}%',
            'lr': f'{lr:.3e}',
        })

    if n_batches == 0:
        raise ValueError(
            "train_loader yielded no batches (dataset smaller than the batch "
            "size with drop_last=True?)"
        )
    return loss_sum / n_batches, 100. * correct / total


def _evaluate(model, val_loader, use_amp):
    """Return top-1 accuracy (percent) of ``model`` on ``val_loader``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
        for inputs, targets in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return 100. * correct / total


def train_model_interactive(model, train_loader, val_loader, epochs=200, resume=True):
    """Train model with real-time visualization.

    Uses SGD (Nesterov momentum, weight decay) with a MultiStepLR schedule,
    mixed precision on CUDA, and saves a checkpoint to ``CHECKPOINT_PATH``
    whenever validation accuracy improves.

    Args:
        model: Network to train; moved to ``device`` in place.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        epochs: Total epoch count (end of the schedule), not the number of
            additional epochs when resuming.
        resume: If True and ``CHECKPOINT_PATH`` exists, restore
            model/optimizer/scheduler state and continue after the saved epoch.

    Returns:
        True if the best validation accuracy exceeds 50%.

    Raises:
        ValueError: if either loader yields no data.
    """
    logger.info("Starting interactive training...")

    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    cudnn.benchmark = True
    # AMP is only used on CUDA (the GradScaler below is the CUDA scaler).
    use_amp = device.type == 'cuda'

    # Move the model BEFORE building the optimizer and loading its state:
    # Optimizer.load_state_dict casts state tensors (e.g. momentum buffers) to
    # the device of the parameters at load time, so loading while the model is
    # still on the CPU would leave them there and break the first CUDA step.
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        nesterov=True
    )
    # LR is multiplied by 0.2 at epochs 60, 120 and 160.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[60, 120, 160],
        gamma=0.2
    )

    start_epoch = 0
    best_acc = 0.0
    if resume and Path(CHECKPOINT_PATH).exists():
        start_epoch, best_acc = _load_training_checkpoint(model, optimizer, scheduler)

    # DataParallel only helps with several GPUs; otherwise train the model directly.
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        train_net = nn.DataParallel(model)
    else:
        train_net = model

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    visualizer = TrainingVisualizer(epochs)

    try:
        for epoch in range(start_epoch, epochs):
            current_lr = optimizer.param_groups[0]['lr']
            train_loss, train_acc = _train_one_epoch(
                train_net, train_loader, criterion, optimizer, scaler, use_amp,
                desc=f'Epoch {epoch+1}/{epochs}',
            )
            val_acc = _evaluate(train_net, val_loader, use_amp)
            logger.info(f'Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}% (lr={current_lr:.3e})')

            visualizer.update(epoch+1, train_loss, train_acc, val_acc, current_lr)

            # NOTE: only improving epochs are checkpointed, so a resume restarts
            # from the best epoch rather than the most recent one.
            if val_acc > best_acc:
                logger.info(f'Saving checkpoint... ({val_acc:.2f}%)')
                best_acc = val_acc
                state = {
                    'model_state_dict': model.state_dict(),  # unwrapped: no 'module.' prefix
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'acc': val_acc,
                    'epoch': epoch,
                }
                torch.save(state, CHECKPOINT_PATH)

            scheduler.step()

        visualizer.save('training_curves.png')

    except KeyboardInterrupt:
        logger.info('Training interrupted by user')
    finally:
        visualizer.close()

    logger.info(f'Training completed. Best accuracy: {best_acc:.2f}%')
    # Coarse success flag: best validation accuracy above 50%.
    return best_acc > 50


In [None]:
# Cell 5: Setup Function
def setup_training(resume=True, subset_size=None):
    """Set up datasets and model for training.

    Args:
        resume: Not used here; kept so call sites can mirror the flag they pass
            to ``train_model_interactive`` (checkpoint loading happens there).
        subset_size: If truthy, keep only the first ``subset_size`` training
            samples and the first ``subset_size // 5`` validation samples
            (at least one) for quick experiments.

    Returns:
        (model, train_loader, val_loader)
    """
    train_dataset = get_dataset("cifar100", train=True, transform=transform_train)
    val_dataset = get_dataset("cifar100", train=False, transform=transform_test)

    if subset_size:
        train_dataset = torch.utils.data.Subset(
            train_dataset,
            range(min(subset_size, len(train_dataset)))
        )
        # Keep at least one validation sample: an empty validation set would
        # make the accuracy computation divide by zero.
        val_size = min(max(1, subset_size // 5), len(val_dataset))
        val_dataset = torch.utils.data.Subset(val_dataset, range(val_size))

    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError if given with num_workers=0.
    loader_kwargs = {'num_workers': NUM_WORKERS, 'pin_memory': True}
    if NUM_WORKERS > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=3)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,  # keep every training batch the same size
        **loader_kwargs
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE*2,  # no gradients stored, so larger batches fit
        shuffle=False,
        **loader_kwargs
    )

    model = get_model("cifar100", "wrn-28-10")

    return model, train_loader, val_loader


In [None]:
# Cell 6: Run Training
# Choose a scenario with the two switches below instead of (un)commenting code:
#   - resume from checkpoint:  RESUME = True,  SUBSET_SIZE = None
#   - train from scratch:      RESUME = False, SUBSET_SIZE = None
#     (the checkpoint at CHECKPOINT_PATH is overwritten as soon as the new run
#      records a validation accuracy above its own best, which starts at 0%)
#   - quick run on a subset:   RESUME = True,  SUBSET_SIZE = 10000
RESUME = True
SUBSET_SIZE = None

model, train_loader, val_loader = setup_training(resume=RESUME, subset_size=SUBSET_SIZE)
# Pass EPOCHS explicitly so the value configured in the configuration cell is
# honoured rather than the function's own default.
train_model_interactive(model, train_loader, val_loader, epochs=EPOCHS, resume=RESUME)