# CIFAR-100 Training Notebook
This notebook trains a model on CIFAR-100, printing training and test metrics after every epoch and plotting the resulting loss and accuracy curves.

In [None]:
import sys
sys.path.append('.')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import yaml
import logging

from models import get_model, get_dataset
from experiment.experiment import Trainer
from utils.device import get_device

# Configure the root logger: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger named after this module. NOTE(review): the visible cells report
# progress with print(), not with this logger.
logger = logging.getLogger(__name__)

In [None]:
# Load the experiment configuration. The encoding is given explicitly so the
# YAML file is decoded the same way on every platform, instead of relying on
# the locale's default text encoding.
with open('experiments/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Build the CIFAR-100 config: keys under dataset_defaults.cifar100 override the
# global defaults. This is a shallow merge -- a nested mapping in the dataset
# section replaces the corresponding default mapping wholesale.
cifar_config = {**config['defaults'], **config['dataset_defaults']['cifar100']}
print("Training configuration:")
for k, v in cifar_config.items():
    print(f"{k}: {v}")

In [None]:
# Pick the compute device and build the CIFAR-100 model on it.
device = get_device()
print(f"Using device: {device}")
model = get_model('cifar100')
model = model.to(device)

# Cross-entropy loss, with the label-smoothing factor taken from the config.
criterion = nn.CrossEntropyLoss(label_smoothing=cifar_config['label_smoothing'])

# SGD with momentum and weight decay; all three values come from the merged
# CIFAR-100 config.
sgd_hparams = {
    'lr': cifar_config['learning_rate'],
    'momentum': cifar_config['momentum'],
    'weight_decay': cifar_config['weight_decay'],
}
optimizer = torch.optim.SGD(model.parameters(), **sgd_hparams)

# Bundle everything into the project's Trainer (arguments passed positionally).
trainer = Trainer(model, criterion, optimizer, device, cifar_config)

In [None]:
# Load the train/test splits of CIFAR-100.
train_dataset = get_dataset('cifar100', train=True)
test_dataset = get_dataset('cifar100', train=False)

# Both loaders share batch size, worker count and pinning; only shuffling
# differs (train is shuffled each epoch, test is iterated in a fixed order).
loader_kwargs = {
    'batch_size': cifar_config['batch_size'],
    'num_workers': cifar_config['num_workers'],
    'pin_memory': cifar_config['pin_memory'],
}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
# Train for up to `num_epochs` epochs (numbered from 1), evaluating on the test
# split after each one and stopping early if the Trainer says so.
num_epochs = 10  # Start with 10 epochs for testing

# Per-epoch history; the plotting cell reads these four lists.
train_losses, train_accs = [], []
test_losses, test_accs = [], []

for epoch in tqdm(range(1, num_epochs + 1), desc='Training Progress'):
    # One pass over the training data. The epoch's loss/accuracy are read back
    # as the last entries of trainer.metrics (presumably appended by
    # train_epoch -- confirm against the Trainer implementation).
    trainer.train_epoch(train_loader, epoch)
    train_loss, train_acc = trainer.metrics['train_losses'][-1], trainer.metrics['train_accs'][-1]
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Evaluate on the held-out split; evaluate() returns a dict of metrics.
    test_metrics = trainer.evaluate(test_loader, epoch)
    test_loss, test_acc = test_metrics['loss'], test_metrics['accuracy']
    test_losses.append(test_loss)
    test_accs.append(test_acc)

    # Report the epoch, including the first param group's current learning rate.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch}/{num_epochs}:")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    print(f"Learning Rate: {current_lr:.6f}")

    if trainer.should_stop():
        print("Early stopping triggered")
        break

In [None]:
# Plot training curves: loss on the left, accuracy on the right.
from matplotlib.ticker import MaxNLocator


def _epochs_for(values):
    """Return 1-based epoch numbers, one per entry of a per-epoch history list.

    The training loop numbers epochs from 1, so plotting against these (rather
    than matplotlib's default 0-based index) keeps the "Epoch" axis aligned with
    the printed epoch numbers. Each series gets its own range: if training was
    interrupted after an epoch's train step but before its evaluation, the
    train lists are one entry longer than the test lists.
    """
    return range(1, len(values) + 1)


fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
loss_ax.plot(_epochs_for(train_losses), train_losses, label='Train Loss')
loss_ax.plot(_epochs_for(test_losses), test_losses, label='Test Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Test Loss')
loss_ax.legend()

# Accuracy plot
acc_ax.plot(_epochs_for(train_accs), train_accs, label='Train Accuracy')
acc_ax.plot(_epochs_for(test_accs), test_accs, label='Test Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training and Test Accuracy')
acc_ax.legend()

# Epochs are whole numbers; avoid fractional ticks (e.g. 1.5) on short runs.
for ax in (loss_ax, acc_ax):
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

fig.tight_layout()
plt.show()