# COSMOS Visual Zoobot Evaluation

This notebook compares Zoobot predictions against the saved COSMOS-Web visual labels. Place both `test_set_*.csv` (model probabilities) and `test_catalog.csv` (ground-truth labels) on your local machine and set their paths below.

## 1. Configure file locations

Set the local paths to the exported prediction CSV (`test_set_*.csv`) and the `test_catalog.csv` containing the ground-truth labels. No remote catalogs or stamps are required.

In [None]:
import os
from pathlib import Path

# Directory holding the exported Zoobot CSVs. To run the notebook on another
# machine without editing this cell, set the COSMOS_ZOOBOT_DATA_DIR environment
# variable before starting the kernel; when it is unset the original local
# default below is used.
DATA_DIR = Path(
    os.environ.get(
        "COSMOS_ZOOBOT_DATA_DIR",
        "/Users/marchuertascompany/Documents/data/COSMOS-Web/zoobot/ilbert",
    )
)

# Model probabilities (`p_*` columns) and the ground-truth label catalog.
PREDICTIONS_CSV = DATA_DIR / "test_set_nano.csv"
GROUND_TRUTH_CSV = DATA_DIR / "test_catalog.csv"


In [None]:
import sys
from pathlib import Path

# Make the project checkout importable: put its root at the front of the
# module search path, but only once so re-running the cell adds no duplicates.
repo_root = Path("/Users/marchuertascompany/Documents/python_scripts/cosmosweb_lowQ")
_repo_root_entry = str(repo_root)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

# Third-party stack used by the cells below: tabular data, arrays, plotting
# and classification metrics.
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# Project-local module; presumably resolved through the repo_root entry that
# this cell inserts into sys.path just above.
# NOTE(review): "moprhology" looks like a misspelling of "morphology" — confirm
# the actual package directory name in the repository before changing it.
# `cosmos` is not referenced by any cell visible in this notebook.
from moprhology.zoobot import train_on_cosmos_visual as cosmos
from IPython.display import display

In [None]:
# Load the model output. Later cells need an `id_str` key (used for the preview
# and for the join with the catalog) and at least one probability column,
# recognised by its `p_` prefix.
predictions = pd.read_csv(PREDICTIONS_CSV)
if 'id_str' not in predictions.columns:
    # Fail here with a clear message instead of a KeyError in a later cell.
    raise ValueError("Predictions file must contain an 'id_str' column.")
prob_cols = [col for col in predictions.columns if col.startswith('p_')]
if not prob_cols:
    raise ValueError("No probability columns (prefixed with 'p_') found in predictions file.")

# Load the ground-truth catalog; it is joined to the predictions on `id_str`.
truth = pd.read_csv(GROUND_TRUTH_CSV)
if 'id_str' not in truth.columns:
    raise ValueError("Ground-truth catalog must contain an 'id_str' column.")

predictions.head()


## 2. Attach class predictions

Derive the most probable class for each galaxy and the associated confidence.

In [None]:
# Class names are the probability column names with the leading 'p_' removed,
# kept in the same order as `prob_cols` so positions double as class indices.
label_names = [name[len('p_'):] for name in prob_cols]

# Rows are galaxies, columns are classes.
prob_matrix = predictions[prob_cols].to_numpy()
pred_indices = np.argmax(prob_matrix, axis=1)      # winning class position
pred_confidence = np.max(prob_matrix, axis=1)      # probability of that class
pred_labels = [label_names[class_idx] for class_idx in pred_indices]

# Store the derived columns on the predictions table for the cells below.
predictions['pred_label'] = pred_labels
predictions['pred_idx'] = pred_indices
predictions['pred_confidence'] = pred_confidence

predictions[['id_str', 'pred_label', 'pred_confidence', *prob_cols]].head()


## 3. Merge with ground truth (if available)

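Metrics are computed only for predictions that have a matching true label. Set `GROUND_TRUTH_CSV` in Section 1; the catalog is joined to the predictions on `id_str`, and it must provide either an integer `label` column or a string `true_label` column.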

In [None]:
# Choose the ground-truth label column up front, so that a catalog with neither
# option fails with a clear message instead of a KeyError during the merge.
if 'label' in truth.columns:
    truth_cols = ['id_str', 'label']
elif 'true_label' in truth.columns:
    truth_cols = ['id_str', 'true_label']
else:
    raise ValueError("Ground-truth data must contain a 'label' (int) or 'true_label' (string) column.")

# Inner join: only galaxies present in both tables are kept for evaluation.
merged = predictions.merge(truth[truth_cols], on='id_str', how='inner', suffixes=('', '_true'))
label_to_index = {name: idx for idx, name in enumerate(label_names)}

if 'label' in merged.columns:
    # `label` may hold class names or integer class indices. Treat it as names
    # only when it is non-numeric and every value is a known class name; this
    # also covers pandas string/categorical dtypes, not just `object`.
    labels_are_names = (
        not pd.api.types.is_numeric_dtype(merged['label'])
        and merged['label'].isin(label_to_index.keys()).all()
    )
    if labels_are_names:
        merged['true_label'] = merged['label']
        merged['true_idx'] = merged['true_label'].map(label_to_index)
    else:
        merged['true_idx'] = merged['label'].astype(int)
        # Reject indices outside the class list; a negative index would
        # otherwise silently wrap around when looked up in `label_names`.
        out_of_range = ~merged['true_idx'].between(0, len(label_names) - 1)
        if out_of_range.any():
            bad_values = sorted(merged.loc[out_of_range, 'true_idx'].unique().tolist())
            raise ValueError(
                f"Ground-truth 'label' values {bad_values} are outside the valid "
                f"class index range 0..{len(label_names) - 1}."
            )
        merged['true_label'] = merged['true_idx'].map(lambda idx: label_names[idx])
else:
    merged['true_idx'] = merged['true_label'].map(label_to_index)
    # Names missing from the prediction classes map to NaN; report them here
    # rather than letting the metric cells fail on NaN targets.
    unknown = merged.loc[merged['true_idx'].isna(), 'true_label'].unique()
    if len(unknown) > 0:
        raise ValueError(
            f"Ground-truth 'true_label' values {sorted(map(str, unknown))} do not match "
            f"any predicted class in {label_names}."
        )
    merged['true_idx'] = merged['true_idx'].astype(int)

print(f"Merged {len(merged)} rows with ground truth out of {len(predictions)} predictions.")
merged.head()

## 4. Classification metrics

In [None]:
# Guard against running this cell before the merge cell has created `merged`.
if 'merged' not in globals():
    raise RuntimeError("Run the merge cell above before computing metrics.")

acc = accuracy_score(merged['true_idx'], merged['pred_idx'])
print(f"Overall accuracy: {acc:.4f}")

# Pass the full set of class indices explicitly. Without `labels`, scikit-learn
# infers the classes from the data and raises if their number differs from
# len(target_names), which happens when a class is absent from both the true
# and the predicted labels. This also matches the confusion-matrix cell.
report = classification_report(
    merged['true_idx'],
    merged['pred_idx'],
    labels=list(range(len(label_names))),
    target_names=label_names,
    zero_division=0
)
print(report)

## 5. Confusion matrix

In [None]:
# Guard against running this cell before the merge cell has created `merged`.
if 'merged' not in globals():
    raise RuntimeError("Run the merge cell above before plotting the confusion matrix.")

# Rows are true classes, columns are predicted classes; every class is listed
# explicitly so the matrix is always len(label_names) square.
cm = confusion_matrix(
    merged['true_idx'],
    merged['pred_idx'],
    labels=list(range(len(label_names)))
)

# Row-normalise to per-true-class fractions. A class with no true samples has a
# zero row sum; leave that row at 0 instead of dividing by zero (NaN + warning).
row_sums = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm,
    row_sums,
    out=np.zeros(cm.shape, dtype=float),
    where=row_sums != 0,
)

fig, ax = plt.subplots(1, 2, figsize=(16, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax[0], xticklabels=label_names, yticklabels=label_names)
ax[0].set_title('Confusion matrix (counts)')
ax[0].set_xlabel('Predicted')
ax[0].set_ylabel('True')

sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', ax=ax[1], xticklabels=label_names, yticklabels=label_names)
ax[1].set_title('Confusion matrix (row-normalized)')
ax[1].set_xlabel('Predicted')
ax[1].set_ylabel('True')
plt.tight_layout()

## 6. Per-class confidence summary

In [None]:
# Summarise each true class: number of galaxies, fraction predicted correctly,
# and the mean/median confidence of the winning class. Correctness is computed
# once on a temporary copy, so `merged` itself is left untouched.
per_class = (
    merged.assign(is_correct=merged['pred_idx'].eq(merged['true_idx']))
    .groupby('true_label')
    .agg(
        support=('true_label', 'size'),
        accuracy=('is_correct', 'mean'),
        mean_confidence=('pred_confidence', 'mean'),
        median_confidence=('pred_confidence', 'median'),
    )
    .sort_values('support', ascending=False)
)
display(per_class)


## 7. Inspect largest errors

In [None]:
# Probability the model assigned to the *true* class of every merged row.
prob_array = merged[prob_cols].to_numpy()
true_class_positions = merged['true_idx'].to_numpy(dtype=int)
true_probs = prob_array[np.arange(len(merged)), true_class_positions]

# Select misclassified rows with a positional boolean mask. Indexing
# `true_probs` by `errors.index` would treat index labels as positions, which
# is only correct while `merged` carries a default 0..n-1 index.
is_error = (merged['pred_label'] != merged['true_label']).to_numpy()
errors = merged.loc[is_error].copy()
errors['true_prob'] = true_probs[is_error]
# How much more probability went to the wrong winning class than to the truth.
errors['confidence_gap'] = errors['pred_confidence'] - errors['true_prob']

# Show the most confidently wrong predictions first.
display(errors.sort_values('pred_confidence', ascending=False).head(20)[
    ['id_str', 'true_label', 'pred_label', 'pred_confidence', 'true_prob', 'confidence_gap']
])
