# ROC Curve Analysis for CIFAR-10 CNN Experiments

This notebook computes ROC curves and AUC scores for CIFAR-10 CNN classification experiments.
It retrieves predictions with probability scores from the DerivaML catalog and compares them
against ground truth labels.

The notebook uses:
1. **Confidence scores** from the `Image_Classification` feature table in the catalog
2. **Full probability distributions** from the `prediction_probabilities.csv` execution asset (if available)

In [None]:
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, RocCurveDisplay
from sklearn.preprocessing import label_binarize

from deriva_ml import DerivaML, MLVocab, RID, DerivaMLConfig
from deriva_ml.dataset import DatasetConfigList
from deriva_ml.execution import ExecutionConfiguration, Execution

from hydra_zen import launch, zen, builds, store

## Parameters

Configure the execution RID to analyze. This should be the RID of a completed CIFAR-10 CNN execution.

In [None]:
# Parameters cell - these can be overridden by papermill
# Keep these as plain literal assignments so they remain simple to override.
dry_run: bool = False  # forwarded to the Hydra config through the `overrides` list built in the next cell
execution_rid: str = ""  # RID of the execution to analyze (leave empty to use latest)

In [None]:
# Hydra override strings handed to launch(); carries the parameter-cell dry_run flag.
overrides = ["cfg.dry_run={}".format(dry_run)]

## Load Configuration

Initialize the Hydra configuration store with DerivaML settings.

In [None]:
# Load the configuration store
import configs.deriva

In [None]:
@dataclass
class notebook_config:
    """Top-level configuration for this notebook: a DerivaML config plus a dry-run flag."""
    deriva_ml: DerivaMLConfig
    dry_run: bool

# Hydra defaults list: this config itself first, then the "local" option of the
# `deriva_ml` group.
# NOTE(review): the "local" option is presumably registered by the
# `import configs.deriva` cell above - confirm.
notebook_defaults = [
    "_self_",
    {"deriva_ml": "local"},
]

# hydra-zen config for notebook_config; dry_run defaults to False and
# the deriva_ml field is filled in through the defaults list above.
NotebookConfig = builds(
    notebook_config,
    populate_full_signature=True,
    dry_run=False,
    hydra_defaults=notebook_defaults
)

# Register the config under the name that launch() looks up below, then publish
# the store to Hydra.
# NOTE(review): overwrite_ok=True presumably allows this cell to be re-run
# without a duplicate-registration error - confirm against hydra-zen docs.
store(NotebookConfig, name="roc_analysis_config")
store.add_to_hydra_store(overwrite_ok=True)

# Run a Hydra job with the parameter-cell overrides applied and keep the job's
# return value as the resolved `config` used by the rest of the notebook.
config = launch(
    NotebookConfig,
    zen(NotebookConfig),
    version_base="1.3",
    config_name="roc_analysis_config",
    job_name="ROCAnalysis",
    overrides=overrides
).return_value

In [None]:
# Open the catalog connection described by the resolved configuration.
ml = DerivaML.instantiate(config.deriva_ml)
print("Connected to {}, catalog {}".format(ml.host_name, ml.catalog_id))

## Find Execution to Analyze

Use the specified execution RID if one is given; otherwise fall back to the most recent completed execution in the catalog.

In [None]:
# Resolve which execution to analyze.
# An explicitly supplied execution_rid always wins; otherwise the newest
# completed execution in the catalog is selected.
if execution_rid:
    print(f"Analyzing specified execution: {execution_rid}")
else:
    # Keep only executions whose Status is exactly "Completed".
    completed = [
        e for e in ml.get_table_as_dict("Execution")
        if e.get("Status") == "Completed"
    ]
    if not completed:
        raise ValueError("No completed executions found")

    # The rows are not sorted by this code, so "last element" is not a reliable
    # notion of "most recent". Pick the newest row explicitly instead.
    # NOTE(review): assumes each row carries the ERMrest system column RCT
    # (row creation time) as an ISO-8601 string, so string order matches time
    # order; rows without RCT fall back to ordering by RID - confirm.
    latest = max(completed, key=lambda e: (e.get("RCT") or "", e["RID"]))
    execution_rid = latest["RID"]
    print(f"Using most recent completed execution: {execution_rid}")

print(f"Execution RID: {execution_rid}")

## Load Probability Data

Try to load full probability distributions from the CSV asset. If not available, fall back to confidence scores from the catalog.

In [None]:
# Look through this execution's assets for the per-class probability CSV.
pb = ml.pathBuilder()
exec_assets = pb.schemas[ml.ml_schema].tables["Execution_Asset"]

# Every Execution_Asset row whose Execution column equals the chosen RID.
asset_query = exec_assets.filter(exec_assets.Execution == execution_rid)
assets = list(asset_query.entities().fetch())

# First asset with the expected filename, or None when there is no such asset.
prob_csv_asset = next(
    (row for row in assets if row.get("Filename") == "prediction_probabilities.csv"),
    None,
)

# Later cells branch on this flag to decide where the ROC scores come from.
use_full_probs = bool(prob_csv_asset)
if use_full_probs:
    print(f"Found prediction_probabilities.csv asset (RID: {prob_csv_asset['RID']})")
else:
    print("No prediction_probabilities.csv found, will use confidence scores from catalog")

In [None]:
# Fetch the predicted labels recorded for the chosen execution.
prediction_table = pb.schemas[ml.ml_schema].tables["Execution_Image_Image_Classification"]

# Restrict the table to rows whose Execution column equals the chosen RID.
prediction_query = prediction_table.filter(prediction_table.Execution == execution_rid)
predictions = list(prediction_query.entities().fetch())

print(f"Found {len(predictions)} predictions for execution {execution_rid}")

# Nothing to analyze without predictions, so stop here.
if len(predictions) == 0:
    raise ValueError(f"No predictions found for execution {execution_rid}")

# True as soon as one row carries a non-null Confidence value.
has_confidence = False
for row in predictions:
    if row.get("Confidence") is not None:
        has_confidence = True
        break
print(f"Confidence scores available: {has_confidence}")

In [None]:
# Ground truth is read from the Image_Image_Classification table in the domain schema.
ground_truth_table = pb.schemas[ml.domain_schema].tables["Image_Image_Classification"]

# The table is fetched in full; no filter is applied here.
ground_truth = list(ground_truth_table.entities().fetch())
print(f"Found {len(ground_truth)} ground truth labels")

# Image RID -> class name. If an image has several rows, the last one fetched wins.
gt_lookup = dict((row["Image"], row["Image_Class"]) for row in ground_truth)

In [None]:
# Build index-aligned lists of ground truth, predictions, confidences and image
# RIDs. Predictions whose image has no ground truth label are dropped.
y_true = []
y_pred = []
y_confidence = []  # Confidence of predicted class
image_rids = []

for pred in predictions:
    image_rid = pred["Image"]
    if image_rid not in gt_lookup:
        continue
    # The Confidence key can be present but null; dict.get's default only covers
    # a missing key, so handle None explicitly. Both cases default to 1.0.
    confidence = pred.get("Confidence")
    y_true.append(gt_lookup[image_rid])
    y_pred.append(pred["Image_Class"])
    y_confidence.append(1.0 if confidence is None else confidence)
    image_rids.append(image_rid)

print(f"Matched {len(y_true)} predictions with ground truth")

# Fail here with a clear message rather than later inside the ROC computation.
if not y_true:
    raise ValueError("None of the predictions could be matched to a ground truth label")

# Class names are taken from the ground truth labels that were actually matched.
class_names = sorted(set(y_true))
n_classes = len(class_names)
print(f"Classes ({n_classes}): {class_names}")

## Load Full Probability Distributions (if available)

If the CSV asset exists, download and parse it for more accurate ROC curves.

In [None]:
# Load full probability distributions if available.
# Result: `prob_matrix` is an array with one row per entry of `image_rids` and
# one column per entry of `class_names`, or None when the CSV cannot be used.
# `use_full_probs` is set to match that outcome at the end of the cell.
import io

prob_matrix = None

# Column of the CSV that identifies the image each row belongs to.
rid_col = "Image_RID"

# Branch on the asset itself (not on use_full_probs) so re-running this cell
# gives the same result.
url = prob_csv_asset.get("URL") if prob_csv_asset else None
if url:
    # Download through the catalog binding so its authenticated session is used.
    # NOTE(review): `url` looks like a server-relative object-store path; confirm
    # that ml.catalog.get() resolves it against the host root and not against
    # the catalog's own base path.
    response = ml.catalog.get(url)

    # Parse CSV content
    prob_df = pd.read_csv(io.StringIO(response.text))
    print(f"Loaded probability CSV with {len(prob_df)} rows")
    print(f"Columns: {list(prob_df.columns)}")

    # One "prob_<class>" column is expected per class, plus the RID column.
    prob_cols = [f"prob_{c}" for c in class_names]
    required_cols = [rid_col, *prob_cols]
    if all(col in prob_df.columns for col in required_cols):
        # RID -> list of per-class probabilities (a later duplicate RID wins).
        rid_to_probs = dict(
            zip(prob_df[rid_col], prob_df[prob_cols].to_numpy().tolist())
        )

        # Images missing from the CSV get a uniform distribution; report how
        # many so the substitution is visible.
        uniform = [1 / n_classes] * n_classes
        n_missing = sum(1 for rid in image_rids if rid not in rid_to_probs)
        if n_missing:
            print(f"Warning: {n_missing} images not in CSV, using uniform probabilities for them")

        # Build probability matrix aligned with y_true
        prob_matrix = np.array([rid_to_probs.get(rid, uniform) for rid in image_rids])
        print(f"Built probability matrix: {prob_matrix.shape}")
    else:
        print(f"Missing probability columns, expected: {required_cols}")
elif prob_csv_asset:
    print("prediction_probabilities.csv asset has no URL, cannot download it")

# The flag is true only when a matrix was actually built, so later cells that
# report the data source stay truthful.
use_full_probs = prob_matrix is not None

print(f"\nUsing full probability distributions: {use_full_probs}")

## Compute ROC Curves

For multi-class classification, we compute ROC curves using a one-vs-rest approach.
When full probabilities are available, we get smooth ROC curves. Otherwise, we use
confidence scores for a stepped approximation. If neither is available, we fall back
to the hard predicted labels, which yields the coarsest curves.

In [None]:
# Convert class names to numeric indices (position in the sorted class_names list).
class_to_idx = {name: idx for idx, name in enumerate(class_names)}
y_true_idx = np.array([class_to_idx[c] for c in y_true], dtype=int)
y_pred_idx = np.array([class_to_idx[c] for c in y_pred], dtype=int)

# One-hot encode by looking up rows of an identity matrix. This always gives
# one column per class; label_binarize collapses to a single column when there
# are exactly two classes, which would break the per-class column indexing
# used for the ROC curves.
one_hot = np.eye(n_classes, dtype=int)
y_true_bin = one_hot[y_true_idx]

# Create score matrix for ROC computation (rows: samples, columns: classes).
if use_full_probs and prob_matrix is not None:
    # Use full probability distributions
    y_score = prob_matrix
    print("Using full probability distributions for ROC curves")
elif has_confidence:
    # Predicted class gets its confidence; the remaining mass (1 - conf) is
    # spread evenly over the other classes. A null confidence is treated as 1.0.
    conf = np.array([1.0 if c is None else c for c in y_confidence], dtype=float)
    if n_classes > 1:
        remaining = (1.0 - conf) / (n_classes - 1)
    else:
        remaining = np.zeros_like(conf)
    y_score = np.tile(remaining[:, None], (1, n_classes))
    y_score[np.arange(len(y_pred_idx)), y_pred_idx] = conf
    print("Using confidence scores for ROC curves")
else:
    # Fall back to hard predictions: score 1.0 for the predicted class, 0.0 otherwise.
    y_score = one_hot[y_pred_idx].astype(float)
    print("Using binary predictions for ROC curves (no probability data available)")

print(f"Score matrix shape: {y_score.shape}")

In [None]:
# One-vs-rest ROC curve and AUC per class, stored in dicts keyed by class name.
fpr, tpr, roc_auc = {}, {}, {}

for col, label in enumerate(class_names):
    class_fpr, class_tpr, _ = roc_curve(y_true_bin[:, col], y_score[:, col])
    fpr[label] = class_fpr
    tpr[label] = class_tpr
    roc_auc[label] = auc(class_fpr, class_tpr)

# Micro-average: flatten labels and scores so every (sample, class) pair counts once.
micro_fpr, micro_tpr, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
fpr["micro"] = micro_fpr
tpr["micro"] = micro_tpr
roc_auc["micro"] = auc(micro_fpr, micro_tpr)

# Macro-average: unweighted mean of the per-class AUC values.
per_class_auc = [roc_auc[label] for label in class_names]
roc_auc["macro"] = np.mean(per_class_auc)

print("\nAUC scores per class:")
for label, value in zip(class_names, per_class_auc):
    print(f"  {label}: {value:.4f}")
print(f"\nMicro-average AUC: {roc_auc['micro']:.4f}")
print(f"Macro-average AUC: {roc_auc['macro']:.4f}")

## Plot ROC Curves

In [None]:
# Draw every ROC curve on a single axes.
fig, ax = plt.subplots(figsize=(10, 8))

# Micro-average curve, drawn as a thick dotted line.
micro_label = f"Micro-average (AUC = {roc_auc['micro']:.2f})"
ax.plot(
    fpr["micro"],
    tpr["micro"],
    label=micro_label,
    color="deeppink",
    linestyle=":",
    linewidth=3,
)

# One curve per class, with colors sampled evenly from the tab10 colormap.
palette = plt.cm.tab10(np.linspace(0, 1, n_classes))
for curve_color, label in zip(palette, class_names):
    curve_label = f"{label} (AUC = {roc_auc[label]:.2f})"
    ax.plot(fpr[label], tpr[label], color=curve_color, label=curve_label)

# Diagonal reference line for a random classifier.
ax.plot([0, 1], [0, 1], "k--", label="Random (AUC = 0.50)")

# Axis limits, labels and title applied in one call.
ax.set(
    xlim=[0.0, 1.0],
    ylim=[0.0, 1.05],
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"ROC Curves - CIFAR-10 CNN (Execution {execution_rid})",
)
ax.legend(loc="lower right", fontsize=8)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary Statistics

In [None]:
# Overall accuracy: share of matched predictions whose class equals the ground truth.
correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
accuracy = correct / len(y_true) * 100

# Report the score source using the same condition that selected y_score, so
# the summary cannot claim "Full probabilities" when no probability matrix
# was actually built.
if use_full_probs and prob_matrix is not None:
    data_source = "Full probabilities"
elif has_confidence:
    data_source = "Confidence scores"
else:
    data_source = "Binary predictions"

banner = "=" * 50
print(f"\n{banner}")
print("CIFAR-10 CNN ROC Analysis Summary")
print(banner)
print(f"Execution RID: {execution_rid}")
print(f"Total predictions: {len(y_true)}")
print(f"Overall accuracy: {accuracy:.2f}%")
print(f"Data source: {data_source}")
print(f"Micro-average AUC: {roc_auc['micro']:.4f}")
print(f"Macro-average AUC: {roc_auc['macro']:.4f}")
print(banner)