# Training Results Visualization

This notebook demonstrates how to use the `plot_curves` utilities to visualize training results from the ResNet reproduction experiments.

In [None]:
import matplotlib.pyplot as plt
from utils.plot_curves import (
    load_training_log,
    plot_training_curves,
    plot_loss_curves,
    plot_single_model_curves,
    compare_multiple_models,
)

%matplotlib inline

## 1. Load Training Logs

First, let's load the training logs for PlainNet-20 and ResNet-20.

In [None]:
# Load logs
plain20_log = load_training_log('results/plain20/logs/PlainNet20_20251106-161914.json')
resnet20_log = load_training_log('results/logs/ResNet20_20251106-024420.json')

print(f"PlainNet-20: {len(plain20_log['train'])} epochs")
print(f"ResNet-20: {len(resnet20_log['train'])} epochs")

# Show sample data
print("\nSample epoch data (ResNet-20, epoch 1):")
print(f"  Train: {resnet20_log['train'][0]}")
print(f"  Test:  {resnet20_log['test'][0]}")
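
The cell above relies on each log being a JSON object with per-epoch `'train'` and `'test'` lists. The sketch below shows how such a file could be read and sanity-checked with the standard library alone. `load_log_sketch` is a hypothetical stand-in, not the actual `load_training_log`. It also assumes one test evaluation per training epoch, which may not match how these logs were written.

```python
import json


def load_log_sketch(path):
    """Hypothetical stand-in for load_training_log.

    Assumes the JSON file holds {'train': [...], 'test': [...]}, one record
    per epoch in each list (one test evaluation per training epoch).
    """
    with open(path, encoding="utf-8") as f:
        log = json.load(f)
    # Both splits must be present as per-epoch lists
    for split in ("train", "test"):
        if not isinstance(log.get(split), list):
            raise ValueError(f"{path}: missing per-epoch list '{split}'")
    # Plotting train and test against the same epoch axis needs equal lengths
    if len(log["train"]) != len(log["test"]):
        raise ValueError(
            f"{path}: {len(log['train'])} train epochs vs "
            f"{len(log['test'])} test epochs"
        )
    return log
```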

## 2. Compare PlainNet vs ResNet (Error Rate)

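This reproduces the 20-layer curves from Figure 6 of the original ResNet paper (training on CIFAR-10), plotting PlainNet-20 and ResNet-20 together. In the paper, this figure illustrates the degradation problem: deeper plain networks reach higher training error than shallower ones, while residual networks keep improving as depth grows.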

In [None]:
fig = plot_training_curves(
    plain20_log,
    resnet20_log,
    title='CIFAR-10: PlainNet-20 vs ResNet-20 (Error Rate)',
    plot_error=True,
    figsize=(14, 5)
)
plt.show()
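
Figure 6 of the paper plots error in percent, which for top-1 classification is `100 × (1 − accuracy)`. A minimal sketch of that conversion, assuming each per-epoch record stores accuracy as a fraction in [0, 1] under an `'accuracy'` key. Both the key name and the helper `error_rate_curves` are hypothetical, not part of `utils.plot_curves`.

```python
def error_rate_curves(log, key="accuracy", percent=True):
    """Hypothetical helper: per-epoch train/test error from accuracy.

    Assumes log['train'] and log['test'] are lists of per-epoch dicts whose
    `key` entry is an accuracy fraction in [0, 1].
    """
    scale = 100.0 if percent else 1.0
    # error = 1 - accuracy, optionally expressed in percent
    train_err = [(1.0 - rec[key]) * scale for rec in log["train"]]
    test_err = [(1.0 - rec[key]) * scale for rec in log["test"]]
    return train_err, test_err
```

The two lists could then be passed to `plt.plot`; the paper draws training error with dashed lines and test error with bold lines.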

## 3. Compare PlainNet vs ResNet (Accuracy)

Same comparison but showing accuracy instead of error rate.

In [None]:
fig = plot_training_curves(
    plain20_log,
    resnet20_log,
    title='CIFAR-10: PlainNet-20 vs ResNet-20 (Accuracy)',
    plot_error=False,
    figsize=(14, 5)
)
plt.show()

## 4. Loss Curves

Comparing training and test loss curves.

In [None]:
fig = plot_loss_curves(
    plain20_log,
    resnet20_log,
    title='CIFAR-10: Training and Test Loss',
    figsize=(14, 5)
)
plt.show()
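
One way to read the training and test loss curves together is the per-epoch gap between them. A minimal sketch, assuming each per-epoch record stores its mean loss under a `'loss'` key; the helper `loss_gap` is hypothetical.

```python
def loss_gap(log, key="loss"):
    """Hypothetical helper: per-epoch (test loss - train loss) for one model.

    Assumes log['train'] and log['test'] are aligned lists of per-epoch
    dicts that store the mean loss under `key`.
    """
    return [te[key] - tr[key] for tr, te in zip(log["train"], log["test"])]
```

A gap that keeps growing while the training loss falls usually points to overfitting. In typical training loops the training loss is averaged over the epoch while the weights are still changing (and with data augmentation, if used), so the gap can be negative early in training.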

## 5. Detailed Single Model Analysis

View all metrics for a single model in one figure.

In [None]:
fig = plot_single_model_curves(
    resnet20_log,
    model_name='ResNet-20',
    figsize=(14, 8)
)
plt.show()
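
For a quick text version of the single-model figure, the sketch below reports, for every numeric metric in each split, the last value and the best value with its epoch. `summarize_log` is a hypothetical helper, and it assumes that keys containing `'loss'` should be minimized and every other metric maximized.

```python
def summarize_log(log):
    """Hypothetical helper: text summary of a log shaped like
    {'train': [per-epoch dict, ...], 'test': [per-epoch dict, ...]}."""
    summary = {}
    for split in ("train", "test"):
        records = log.get(split, [])
        if not records:
            continue
        for key, value in records[0].items():
            # Skip the epoch counter and any non-numeric fields
            if key == "epoch" or isinstance(value, bool) or not isinstance(value, (int, float)):
                continue
            series = [rec[key] for rec in records]
            # Assumption: metrics named '*loss*' are minimized, all others maximized
            pick = min if "loss" in key else max
            best = pick(series)
            summary[(split, key)] = {
                "last": series[-1],
                "best": best,
                "best_epoch": series.index(best) + 1,  # 1-based, first occurrence
            }
    return summary
```

In the notebook this could be printed with `for (split, key), stats in summarize_log(resnet20_log).items(): print(split, key, stats)`.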

## 6. Compare Multiple Models

Compare test accuracy across all available models.

In [None]:
from pathlib import Path

# Collect all available models
log_paths = {
    'PlainNet-20': 'results/plain20/logs/PlainNet20_20251106-161914.json',
    'ResNet-20': 'results/logs/ResNet20_20251106-024420.json',
}

# Check for 32-layer models
plain32 = Path('results/plain32/logs/PlainNet32_20251106-203115.json')
resnet32 = Path('results/resnet32/logs/ResNet32_20251106-215354.json')

if plain32.exists():
    log_paths['PlainNet-32'] = str(plain32)
if resnet32.exists():
    log_paths['ResNet-32'] = str(resnet32)

fig = compare_multiple_models(
    log_paths,
    metric='accuracy',
    split='test',
    title='Test Accuracy Comparison',
    figsize=(12, 7)
)
plt.show()
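
The cell above hard-codes timestamped file names, so a newer training run is ignored unless the path is edited by hand. A sketch of an alternative that picks the most recent log per model, assuming every log sits in some `.../logs/` directory under `results/` and is named `<ModelName>_<YYYYmmdd-HHMMSS>.json` as in the paths above. The helper `latest_logs` is hypothetical.

```python
from pathlib import Path


def latest_logs(root="results"):
    """Hypothetical helper: map each model name to its newest JSON log.

    Assumes files named <ModelName>_<YYYYmmdd-HHMMSS>.json inside any
    .../logs/ directory under `root`. With that fixed-width timestamp,
    lexicographic order is chronological order.
    """
    latest = {}
    # '**' also matches zero directories, so root/logs/*.json is included
    for path in Path(root).glob("**/logs/*.json"):
        model, sep, stamp = path.stem.rpartition("_")
        if not sep:
            continue  # file name does not follow <Model>_<timestamp>
        if model not in latest or stamp > latest[model][0]:
            latest[model] = (stamp, path)
    return {model: str(path) for model, (_, path) in sorted(latest.items())}
```

The returned dict has the same name-to-path shape as `log_paths`, so it could be passed to `compare_multiple_models`; its keys are the file-name prefixes (for example `ResNet20`) rather than the display names used above.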

## 7. Key Observations

From the plots above, we can observe:

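1. **Training Degradation**: PlainNet-20 shows higher training error than ResNet-20 at the same depth. The degradation problem itself (a deeper plain network reaching higher training error than a shallower one) requires comparing the training error of PlainNet-32 with that of PlainNet-20.
2. **Test Performance**: ResNet-20 achieves higher test accuracy than PlainNet-20.
3. **Convergence**: ResNet converges faster and more smoothly.
4. **Depth**: When the 32-layer logs are available, the gap between PlainNet and ResNet is more pronounced at 32 layers than at 20.

These results are consistent with the key findings of the original ResNet paper.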
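
Observations 1 and 3 can be checked numerically rather than by eye. In the paper, degradation means a deeper plain network ending with higher *training* error than a shallower one, so `degradation_gap` is meant to compare two plain networks (for example `plain20_log` and a PlainNet-32 log loaded the same way). All three helpers are hypothetical and assume accuracy is stored as a fraction in [0, 1] under `'accuracy'`.

```python
def final_train_error(log, key="accuracy"):
    """Training error (%) at the last epoch; assumes accuracy in [0, 1]."""
    return (1.0 - log["train"][-1][key]) * 100.0


def degradation_gap(shallow_log, deep_log, key="accuracy"):
    """Final training-error difference (deep - shallow), in percentage points.

    Positive means the deeper network trains worse, the degradation pattern;
    intended for two plain networks of different depth.
    """
    return final_train_error(deep_log, key) - final_train_error(shallow_log, key)


def epochs_to_reach(log, threshold, split="test", key="accuracy"):
    """First 1-based epoch whose accuracy reaches `threshold`, else None."""
    for epoch, rec in enumerate(log[split], start=1):
        if rec[key] >= threshold:
            return epoch
    return None
```

Comparing `epochs_to_reach(resnet20_log, 0.8)` with `epochs_to_reach(plain20_log, 0.8)` gives one concrete measure of "converges faster"; the 0.8 threshold is an arbitrary choice.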

## 8. Custom Analysis

You can also extract and analyze specific metrics:

In [None]:
from utils.plot_curves import extract_metrics
import numpy as np

# Extract per-epoch accuracy curves (train and test); np.asarray makes
# .max()/.argmax() work even if extract_metrics returns plain lists
plain_train_acc, plain_test_acc = (np.asarray(a) for a in extract_metrics(plain20_log, 'accuracy'))
resnet_train_acc, resnet_test_acc = (np.asarray(a) for a in extract_metrics(resnet20_log, 'accuracy'))

print("Final Results (last epoch):")
print("=" * 50)
print("PlainNet-20:")
print(f"  Train Accuracy: {plain_train_acc[-1]*100:.2f}%")
print(f"  Test Accuracy:  {plain_test_acc[-1]*100:.2f}%")
print("\nResNet-20:")
print(f"  Train Accuracy: {resnet_train_acc[-1]*100:.2f}%")
print(f"  Test Accuracy:  {resnet_test_acc[-1]*100:.2f}%")

# Difference in percentage points; the '+' format flag prints the sign,
# so a negative gap is shown correctly
print("\nImprovement (ResNet-20 minus PlainNet-20):")
print(f"  Train: {(resnet_train_acc[-1] - plain_train_acc[-1])*100:+.2f} pp")
print(f"  Test:  {(resnet_test_acc[-1] - plain_test_acc[-1])*100:+.2f} pp")

# Best test accuracy and the (1-based) epoch where it occurred
print("\nBest Test Accuracy:")
print(f"  PlainNet-20: {plain_test_acc.max()*100:.2f}% (epoch {plain_test_acc.argmax()+1})")
print(f"  ResNet-20:   {resnet_test_acc.max()*100:.2f}% (epoch {resnet_test_acc.argmax()+1})")
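
Last-epoch accuracy can fluctuate from one epoch to the next, so a single epoch may over- or under-state the gap between the two models. A minimal sketch that averages the final `n` epochs instead; `mean_last_n` and its default `n=5` are illustrative choices, not part of `utils.plot_curves`.

```python
def mean_last_n(series, n=5):
    """Hypothetical helper: mean of the last `n` values of a per-epoch series.

    Works on lists or NumPy arrays; uses the whole series if it is shorter
    than `n`.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    tail = list(series)[-n:]
    if not tail:
        raise ValueError("empty series")
    return sum(tail) / len(tail)
```

For example, `print(f"ResNet-20 test (mean of last 5): {mean_last_n(resnet_test_acc)*100:.2f}%")`.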