### Vision Model Comparison

This notebook fetches the logged results of the following vision-model experiments from Weights & Biases (W&B) and compares them:
- ResNet-50 (baseline)
- EfficientNet-B0
- DenseNet-121
- ViT-B/16

All experiments share the same configuration except for the vision model architecture. Validation macro F1 (`val/f1`) is the primary comparison metric.

In [1]:
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Plot defaults: seaborn "whitegrid" style and 12x6-inch figures
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 6)})

# W&B public-API client, used by the cells below to query logged runs
api = wandb.Api()

In [2]:
# W&B project that every experiment run below was logged to
WANDB_PROJECT = "APCOMP215"

# Display label -> W&B run display name (should match the config filenames).
# Insertion order is preserved and determines the row/line order downstream.
EXPERIMENTS = dict([
    ("ResNet-50", "exp_001_baseline"),
    ("EfficientNet-B0", "exp_002_efficientnet_b0"),
    ("DenseNet-121", "exp_003_densenet121"),
    ("ViT-B/16", "exp_004_vit_b16"),
])


In [None]:
def fetch_run_data(model_name, run_name):
    """Look up the newest W&B run with display name `run_name` and collect its data.

    Parameters
    ----------
    model_name : str
        Label stored alongside the run data (a key of EXPERIMENTS).
    run_name : str
        W&B display name to search for in WANDB_PROJECT.

    Returns
    -------
    dict or None
        Keys "model_name", "run_name", "run_id", "state", "history" (as returned
        by run.history()), "summary" (run.summary) and "config" (run.config).
        Returns None when no run matches, or when any exception occurs while
        querying; the problem is printed instead of raised.
    """
    try:
        matches = api.runs(WANDB_PROJECT, filters={"display_name": run_name}, order="-created_at")
        if not len(matches):
            print(f"⚠️  No runs found for {run_name}")
            return None

        # order="-created_at" sorts newest first, so index 0 is the latest run
        latest = matches[0]
        print(f"✓ Found run: {latest.name} ({latest.state}) for {model_name}")
        run_history = latest.history()
        run_summary = latest.summary

        return dict(
            model_name=model_name,
            run_name=run_name,
            run_id=latest.id,
            state=latest.state,
            history=run_history,
            summary=run_summary,
            config=latest.config,
        )
    except Exception as err:
        # Best-effort: one missing or broken run should not stop the notebook
        print(f"❌ Error fetching {run_name}: {err}")
        return None

# Fetch every experiment, keeping only those whose run could be retrieved
experiment_data = {}
for label, display_name in EXPERIMENTS.items():
    fetched = fetch_run_data(label, display_name)
    if fetched:
        experiment_data[label] = fetched


In [None]:
# Extract final (run-summary) metrics for comparison
metrics_to_compare = [
    "val/f1",  # Macro F1 (main metric)
    "val/acc",  # Validation accuracy
    "val/loss",  # Validation loss
    "train/f1",  # Training macro F1
    "train/acc",  # Training accuracy
    "train/loss"  # Training loss
]

summary_data = []
for model_name, data in experiment_data.items():
    summary = data["summary"]
    summary_keys = list(summary.keys())
    row = {"Model": model_name}

    for metric in metrics_to_compare:
        # Resolve the logged key for `metric`, most specific match first. This
        # keeps "val/f1" from resolving to "train/f1" just because that key is
        # listed earlier and also ends in "f1":
        #   1. exact key                          e.g. "val/f1"
        #   2. key containing the metric          e.g. "val/f1_macro"
        #   3. key with same split + same suffix  e.g. "val_f1"
        split, _, suffix = metric.rpartition("/")
        candidates = (
            [key for key in summary_keys if key == metric]
            or [key for key in summary_keys if metric in key]
            or [key for key in summary_keys if key.endswith(suffix) and split in key]
        )

        if candidates:
            value = summary[candidates[0]]
            # Unwrap objects that expose a `.value` attribute.
            # NOTE(review): originally labelled "wandb artifacts"; confirm which
            # summary value types actually need this.
            if hasattr(value, 'value'):
                value = value.value
            row[metric] = value
        else:
            row[metric] = None

    summary_data.append(row)

# Passing `columns` keeps "Model" (and every metric column) present even when no
# runs were fetched, so set_index("Model") cannot raise KeyError on an empty list.
summary_df = pd.DataFrame(summary_data, columns=["Model", *metrics_to_compare])
summary_df = summary_df.set_index("Model")

# Display summary table
print("=" * 80)
print("FINAL METRICS COMPARISON")
print("=" * 80)
print(summary_df.round(4))
print("=" * 80)


In [None]:
# Plot training curves: one panel per metric, one line per model
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

metrics_to_plot = [
    ("val/f1", "Validation Macro F1", "Higher is better"),
    ("val/acc", "Validation Accuracy (%)", "Higher is better"),
    ("val/loss", "Validation Loss", "Lower is better"),
    ("train/f1", "Training Macro F1", "Higher is better"),
    ("train/acc", "Training Accuracy (%)", "Higher is better"),
    ("train/loss", "Training Loss", "Lower is better")
]

for ax, (metric_key, title, note) in zip(axes, metrics_to_plot):
    split, _, suffix = metric_key.rpartition("/")

    for model_name, data in experiment_data.items():
        history = data["history"]
        columns = list(history.columns)

        # Resolve the history column, most specific match first: exact, then
        # containing the metric, then same split + same suffix. This keeps
        # "val/acc" from falling back to an earlier "train/acc" column.
        candidates = (
            [col for col in columns if col == metric_key]
            or [col for col in columns if metric_key in col]
            or [col for col in columns if col.endswith(suffix) and split in col]
        )
        if not candidates:
            continue
        values = history[candidates[0]]

        # x-values: W&B's `_step` column when present, else the row position.
        # The fallback is a Series, so it can be boolean-masked like `values`;
        # a plain range cannot be indexed by a boolean mask.
        # NOTE(review): `_step` equals the epoch only if wandb.log is called
        # once per epoch. Confirm this, since the axis is labelled "Epoch".
        if "_step" in history.columns:
            steps = history["_step"]
        else:
            steps = pd.Series(np.arange(len(history)), index=history.index)

        # Rows where this metric was not logged are NaN; drop them
        logged = values.notna()
        ax.plot(steps[logged], values[logged], label=model_name, linewidth=2, alpha=0.8)

    ax.set_xlabel("Epoch", fontsize=11)
    ax.set_ylabel(title.split("(")[0].strip(), fontsize=11)
    ax.set_title(f"{title}\n({note})", fontsize=12, fontweight='bold')
    # Only add a legend if some model contributed a labelled line; otherwise
    # matplotlib warns that no labelled artists were found.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='best')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
