# Analyze Training Results

This notebook loads and analyzes the training results from your GCP workflow run.

**Prerequisites:** You should have:
1. Completed training on a GCE VM
2. Downloaded results from GCS to `./results/<RUN_ID>/`

## 1. Load Metrics

First, let's find your training run and load the metrics.

In [None]:
import json
import os
import matplotlib.pyplot as plt

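# Use the most recent run (or set run_id manually).
# Assumes run IDs sort chronologically by name (e.g. timestamp-based IDs).
results_dir = "./results"
runs = sorted(
    d for d in os.listdir(results_dir)
    if os.path.isdir(os.path.join(results_dir, d))  # skip stray files such as .DS_Store
)
run_id = runs[-1]  # Last run in sorted order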
run_path = os.path.join(results_dir, run_id)
print(f"Analyzing run: {run_id}")
print(f"Run path: {run_path}")

# Load metrics
with open(os.path.join(run_path, "metrics.json")) as f:
    metrics = json.load(f)

print(f"\nTraining completed {len(metrics['epoch'])} epochs")
print(f"Final accuracy: {metrics['test_accuracy'][-1]:.2f}%")
print(f"Final loss: {metrics['test_loss'][-1]:.4f}")
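Before plotting, it can help to confirm that the metrics file has the shape the rest of this notebook relies on. The sketch below assumes `metrics.json` holds three parallel lists under the keys `epoch`, `test_loss`, and `test_accuracy` (the keys read above). `validate_metrics` is a hypothetical helper for illustration, not part of the training workflow, and the example values are made up.

```python
def validate_metrics(metrics, required=("epoch", "test_loss", "test_accuracy")):
    """Return a list of problems found in a metrics dict (empty if none)."""
    problems = []
    # Every key the notebook reads must be present.
    missing = [key for key in required if key not in metrics]
    if missing:
        problems.append(f"missing keys: {missing}")
        return problems
    # The lists are plotted against each other, so their lengths must match.
    lengths = {key: len(metrics[key]) for key in required}
    if len(set(lengths.values())) != 1:
        problems.append(f"length mismatch: {lengths}")
    elif min(lengths.values()) == 0:
        problems.append("no epochs recorded")
    return problems


# Illustrative values only, not real training results.
example = {"epoch": [1, 2], "test_loss": [0.9, 0.5], "test_accuracy": [70.0, 85.0]}
print(validate_metrics(example))  # prints []
```

Calling `validate_metrics(metrics)` on the loaded dict and checking for an empty list is one way to catch a truncated or partially written file early.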

## 2. Plot Training Curves

Visualize how loss and accuracy evolved during training.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Plot loss
ax1.plot(metrics['epoch'], metrics['test_loss'], 'b-o', linewidth=2, markersize=6)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Test Loss over Training')
ax1.grid(True, alpha=0.3)

# Plot accuracy
ax2.plot(metrics['epoch'], metrics['test_accuracy'], 'g-o', linewidth=2, markersize=6)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Test Accuracy over Training')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(run_path, 'training_curves.png'), dpi=150)
plt.show()

print(f"\nTraining curves saved to: {os.path.join(run_path, 'training_curves.png')}")
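The peak of the accuracy curve can also be located in code rather than by eye. This is a minimal sketch, assuming the same parallel-list layout as above; `best_epoch` is a hypothetical helper, and it reads the epoch label from `metrics['epoch']` instead of assuming epochs are numbered from 1.

```python
def best_epoch(metrics):
    """Return (epoch_label, accuracy) for the highest test accuracy.

    Ties resolve to the earliest epoch, because max() keeps the first maximum.
    """
    accuracies = metrics["test_accuracy"]
    best_idx = max(range(len(accuracies)), key=accuracies.__getitem__)
    return metrics["epoch"][best_idx], accuracies[best_idx]


# Illustrative values only, not real training results.
example = {"epoch": [1, 2, 3], "test_accuracy": [70.0, 85.0, 80.0]}
print(best_epoch(example))  # prints (2, 85.0)
```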

## 3. Inspect the Model

Load the saved model weights (assumed here to be a `state_dict`) and list the name and shape of each tensor.

In [None]:
import torch

model_path = os.path.join(run_path, 'model.pt')
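# weights_only=True restricts unpickling to tensors and plain containers;
# drop the argument if your PyTorch version predates it.
model_state = torch.load(model_path, map_location='cpu', weights_only=True)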

print("Model layers:")
for name in model_state.keys():
    shape = model_state[name].shape
    print(f"  {name}: {list(shape)}")

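# Counts every tensor in the state dict, which includes buffers
# (e.g. BatchNorm running statistics) as well as trainable parameters.
total_params = sum(p.numel() for p in model_state.values())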
print(f"\nTotal parameters: {total_params:,}")
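A single total hides where the parameters sit. The sketch below groups counts by the first component of each tensor name, assuming dotted names such as `fc1.weight`. To stay runnable without PyTorch it takes a plain mapping of names to shapes; `count_by_module` and the example layer names are hypothetical.

```python
from collections import defaultdict
from math import prod


def count_by_module(shapes):
    """Sum element counts per top-level module prefix.

    `shapes` maps tensor names (e.g. "fc1.weight") to shape tuples.
    """
    totals = defaultdict(int)
    for name, shape in shapes.items():
        module = name.split(".", 1)[0]  # text before the first dot
        totals[module] += prod(shape)   # prod(()) == 1, matching a scalar tensor
    return dict(totals)


# Hypothetical two-layer network, for illustration only.
example_shapes = {
    "fc1.weight": (128, 784),
    "fc1.bias": (128,),
    "fc2.weight": (10, 128),
    "fc2.bias": (10,),
}
print(count_by_module(example_shapes))  # prints {'fc1': 100480, 'fc2': 1290}
```

With the state dict loaded above, the input can be built as `{name: tuple(t.shape) for name, t in model_state.items()}`.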

## 4. Summary

Display a summary of your training run.

In [None]:
print("=" * 50)
print("TRAINING RUN SUMMARY")
print("=" * 50)
print(f"Run ID:          {run_id}")
print(f"Epochs:          {len(metrics['epoch'])}")
print(f"Final Accuracy:  {metrics['test_accuracy'][-1]:.2f}%")
print(f"Final Loss:      {metrics['test_loss'][-1]:.4f}")
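# Look up the epoch label instead of assuming epochs are numbered from 1.
best_idx = metrics['test_accuracy'].index(max(metrics['test_accuracy']))
print(f"Best Accuracy:   {metrics['test_accuracy'][best_idx]:.2f}% (epoch {metrics['epoch'][best_idx]})")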
print(f"Parameters:      {total_params:,}")
print("=" * 50)
print(f"\nFiles in {run_path}:")
for filename in sorted(os.listdir(run_path)):
    print(f"  - {filename}")