# COSMOS Visual Zoobot Evaluation

Use this notebook to inspect the `test_set.csv` predictions (class probabilities per galaxy) and compare them against ground-truth labels to build confusion matrices, per-class statistics, and other quick diagnostics.

## 1. Configure file locations

* `PREDICTIONS_CSV` points to the exported probability table (e.g. the `test_set.csv` downloaded from the cluster).
* Provide ground-truth labels either by (a) pointing `GROUND_TRUTH_CSV` to a saved test catalog that contains `id_str` plus either `label` (integer class index or class name) or `true_label` (class name), or (b) setting `REBUILD_GROUND_TRUTH = True` and supplying the same catalog/DB paths that were used during training so the notebook can recreate the splits locally. If both are set, `GROUND_TRUTH_CSV` takes precedence.


In [None]:
from pathlib import Path

# --- Predictions to evaluate (required) ---
# Probability table exported by the model, e.g. the test_set.csv copied from the cluster.
PREDICTIONS_CSV = Path("/Users/marchuertascompany/Documents/data/COSMOS-Web/zoobot/ilbert/test_set.csv")

# --- Ground truth, option A: saved catalog ---
# CSV with the true labels (id_str, label[, label_name, ...]); leave as None when unused.
GROUND_TRUTH_CSV = None  # Path("/Users/.../test_catalog.csv")

# --- Ground truth, option B: rebuild the test split locally ---
# The settings below are only read by rebuild_test_catalog(), i.e. when this flag is True.
REBUILD_GROUND_TRUTH = False

# Label database and image stamps that the rebuild reads from.
VISUAL_LABELS = Path("/n07data/ilbert/COSMOS-Web/photoz_MASTER_v3.1.0/MORPHO/visualmorpho_COSMOSWeb_v7.db")
SQLITE_TABLE = "morphology"
STAMP_DIR = Path("/n03data/huertas/COSMOS-Web/zoobot/stamps/f150w")
FILTER_NAME = "F150W"
FILENAME_TEMPLATE = "{filter}_{id}.jpg"

# Selection and split settings; presumably these must match the training run
# for the rebuilt test split to line up with the predictions (verify against training config).
KEEP_AMBIGUOUS = False
MAX_GALAXIES = None
SEED = 42
TEST_FRACTION = 0.2
VAL_FRACTION = 0.1


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from moprhology.zoobot import train_on_cosmos_visual as cosmos
from IPython.display import display

In [None]:
def rebuild_test_catalog():
    """Recreate the test catalog locally from the visual labels and stamps.

    Chains the `cosmos` helpers in order (load the visual catalog, attach
    stamps, optionally subsample, split) using the notebook's configuration
    constants, and returns the third element of the split, i.e. the test
    catalog. The train and validation splits are computed but discarded.
    """
    visual_labels = cosmos.load_visual_catalog(VISUAL_LABELS, SQLITE_TABLE)
    with_stamps = cosmos.attach_stamps(
        df_visual=visual_labels,
        stamp_dir=STAMP_DIR,
        filename_template=FILENAME_TEMPLATE,
        filter_name=FILTER_NAME,
        keep_ambiguous=KEEP_AMBIGUOUS,
    )
    selected = cosmos.maybe_subsample(with_stamps, MAX_GALAXIES, SEED)
    splits = cosmos.stratified_splits(
        selected,
        test_fraction=TEST_FRACTION,
        val_fraction=VAL_FRACTION,
        seed=SEED,
    )
    # The helper returns (train, val, test); only the test split is evaluated here.
    _train_split, _val_split, test_split = splits
    return test_split


In [None]:
# Read id_str as text in every input so the merge key has the same dtype on
# both sides (and keeps any leading zeros) regardless of where it came from.
predictions = pd.read_csv(PREDICTIONS_CSV, dtype={'id_str': str})
if 'id_str' not in predictions.columns:
    # Later cells select and merge on this column; fail here with a clear message.
    raise ValueError("Predictions file must contain an 'id_str' column.")

# Class-probability columns are identified by their 'p_' prefix.
prob_cols = [col for col in predictions.columns if col.startswith('p_')]
if not prob_cols:
    raise ValueError("No probability columns (prefixed with 'p_') found in predictions file.")

# GROUND_TRUTH_CSV takes precedence over REBUILD_GROUND_TRUTH when both are set;
# truth stays None when neither is configured and metric cells are skipped.
truth = None
if GROUND_TRUTH_CSV is not None:
    truth = pd.read_csv(GROUND_TRUTH_CSV, dtype={'id_str': str})
elif REBUILD_GROUND_TRUTH:
    truth = rebuild_test_catalog()

if truth is not None:
    if 'id_str' not in truth.columns:
        raise ValueError("Ground-truth catalog must contain an 'id_str' column.")
    # A rebuilt catalog may carry non-string ids; normalise without mutating the source frame.
    truth = truth.assign(id_str=truth['id_str'].astype(str))

predictions.head()


## 2. Attach class predictions

Derive the most probable class for each galaxy and the associated confidence.

In [None]:
# Class names; assumed to be listed in the same order as the probability
# columns (TODO confirm against the export that produced the predictions file).
label_names = cosmos.CLASS_COLUMNS
if len(label_names) != len(prob_cols):
    print("Warning: probability column count does not match CLASS_COLUMNS length.")

# One row per galaxy, one column per class probability.
prob_matrix = predictions[prob_cols].to_numpy()

# Winning class per row, its probability, and the matching class name.
pred_indices = np.argmax(prob_matrix, axis=1)
pred_confidence = np.max(prob_matrix, axis=1)
pred_labels = [label_names[class_idx] for class_idx in pred_indices]

predictions['pred_label'] = pred_labels
predictions['pred_idx'] = pred_indices
predictions['pred_confidence'] = pred_confidence

predictions[['id_str', 'pred_label', 'pred_confidence', *prob_cols]].head()


## 3. Merge with ground truth (if available)

Metrics are computed only when true labels are provided. Either set `GROUND_TRUTH_CSV` or set `REBUILD_GROUND_TRUTH = True` in the configuration cell of Section 1.

In [None]:
if truth is None:
    print("Ground-truth labels not available. Provide GROUND_TRUTH_CSV or enable REBUILD_GROUND_TRUTH to continue with metrics.")
else:
    # Keep only the join key and whichever label column(s) the catalog provides.
    truth_cols = ['id_str'] + [col for col in ('label', 'true_label') if col in truth.columns]
    merged = predictions.merge(truth[truth_cols], on='id_str', how='inner', suffixes=('', '_true'))

    class_names = list(cosmos.CLASS_COLUMNS)
    label_to_index = {name: idx for idx, name in enumerate(class_names)}

    if 'label' in merged.columns:
        # 'label' may hold class names or integer indices. Decide by content
        # rather than dtype so string-typed (non-object) columns are handled too.
        if len(merged) > 0 and merged['label'].isin(class_names).all():
            merged['true_label'] = merged['label']
            merged['true_idx'] = merged['true_label'].map(label_to_index)
        else:
            merged['true_idx'] = merged['label'].astype(int)
            # Reject indices outside the class list: a negative index would
            # otherwise wrap around and silently pick the wrong class name.
            out_of_range = ~merged['true_idx'].between(0, len(class_names) - 1)
            if out_of_range.any():
                raise ValueError(
                    f"Ground-truth 'label' contains {int(out_of_range.sum())} value(s) outside "
                    f"the valid class index range 0..{len(class_names) - 1}."
                )
            merged['true_label'] = merged['true_idx'].map(lambda idx: class_names[idx])
    elif 'true_label' in merged.columns:
        merged['true_idx'] = merged['true_label'].map(label_to_index)
        # Names missing from CLASS_COLUMNS map to NaN, which would break the metric cells.
        unknown = merged['true_idx'].isna()
        if unknown.any():
            unknown_names = sorted(merged.loc[unknown, 'true_label'].astype(str).unique())
            raise ValueError(f"Ground-truth 'true_label' contains unknown class names: {unknown_names}")
        merged['true_idx'] = merged['true_idx'].astype(int)
    else:
        raise ValueError("Ground-truth data must contain a 'label' (int) or 'true_label' (string) column.")

    print(f"Merged {len(merged)} rows with ground truth out of {len(predictions)} predictions.")
    if len(merged) == 0:
        print("Warning: no predictions matched the ground truth on 'id_str'; check that both files use the same identifiers.")
    # A bare expression inside an if/else block is not rendered by Jupyter, so display explicitly.
    display(merged.head())


## 4. Classification metrics

In [None]:
if truth is None:
    print("Skipping metrics because ground truth is missing.")
else:
    acc = accuracy_score(merged['true_idx'], merged['pred_idx'])
    print(f"Overall accuracy: {acc:.4f}")
    # Pass the full label list explicitly: without it, scikit-learn infers the
    # classes from the data, and target_names no longer lines up (or the call
    # raises) when a class is absent from both the true and predicted labels.
    report = classification_report(
        merged['true_idx'],
        merged['pred_idx'],
        labels=list(range(len(cosmos.CLASS_COLUMNS))),
        target_names=cosmos.CLASS_COLUMNS,
        zero_division=0
    )
    print(report)


## 5. Confusion matrix

In [None]:
if truth is None:
    print("Confusion matrix unavailable without ground truth.")
else:
    # Fix the label order so rows/columns always follow CLASS_COLUMNS,
    # including classes that do not occur in this sample.
    cm = confusion_matrix(
        merged['true_idx'],
        merged['pred_idx'],
        labels=list(range(len(cosmos.CLASS_COLUMNS)))
    )
    # Row-normalise (fraction of each true class assigned to each prediction).
    # Classes with no true examples have a zero row total; leave those rows at
    # 0 instead of dividing by zero, which would emit a warning and NaN cells.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

    fig, ax = plt.subplots(1, 2, figsize=(16, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax[0], xticklabels=cosmos.CLASS_COLUMNS, yticklabels=cosmos.CLASS_COLUMNS)
    ax[0].set_title('Confusion matrix (counts)')
    ax[0].set_xlabel('Predicted')
    ax[0].set_ylabel('True')

    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', ax=ax[1], xticklabels=cosmos.CLASS_COLUMNS, yticklabels=cosmos.CLASS_COLUMNS)
    ax[1].set_title('Confusion matrix (row-normalized)')
    ax[1].set_xlabel('Predicted')
    ax[1].set_ylabel('True')
    plt.tight_layout()


## 6. Per-class confidence summary

In [None]:
if truth is None:
    print("Provide ground-truth labels to compute per-class summaries.")
else:
    # Flag each galaxy as correctly classified on a temporary copy, then
    # aggregate per true class: row count, hit rate, and confidence statistics.
    per_class = (
        merged.assign(is_correct=merged['pred_idx'] == merged['true_idx'])
        .groupby('true_label')
        .agg(
            support=('true_label', 'size'),
            accuracy=('is_correct', 'mean'),
            mean_confidence=('pred_confidence', 'mean'),
            median_confidence=('pred_confidence', 'median'),
        )
        .sort_values('support', ascending=False)
    )
    display(per_class)


## 7. Inspect largest errors

In [None]:
if truth is None:
    print("Ground truth required to list misclassifications.")
else:
    # Probability the model assigned to the true class of every merged row.
    prob_array = merged[prob_cols].to_numpy()
    row_positions = np.arange(len(merged))
    true_probs = prob_array[row_positions, merged['true_idx']]

    # Keep only the misclassified rows; their index labels double as row
    # positions into true_probs because the merged frame has a default index.
    is_misclassified = merged['pred_label'] != merged['true_label']
    errors = merged[is_misclassified].copy()
    errors['true_prob'] = true_probs[errors.index]
    # How much more probability went to the wrong class than to the true one.
    errors['confidence_gap'] = errors['pred_confidence'] - errors['true_prob']

    # Show the 20 errors the model was most confident about.
    most_confident_errors = errors.sort_values('pred_confidence', ascending=False).head(20)
    display(most_confident_errors[['id_str', 'true_label', 'pred_label', 'pred_confidence', 'true_prob', 'confidence_gap']])
