# ZooBot Fine-Tuning on NGVS Morphology Labels (Curated Example)

This notebook is a **curated** example of fine-tuning a pretrained ZooBot CNN on NGVS galaxy cutouts,
then generating predictions and inspecting failure cases.

**What this demonstrates**
- Building a ZooBot-compatible `catalog` (`id_str`, `file_loc`, label columns)
- Fine-tuning (head-only vs deeper-layer) via `n_layers`
- Saving predictions to CSV for reproducible downstream analysis
- Qualitative error inspection (top/bottom predictions; incorrect vs correct)

> Note: Image data and some paths are project-specific. Replace placeholders under **Configuration**.


## 0) Configuration

Set paths and choose a label column to train on (example: **E vs non-E**).


In [None]:
import os
import logging
import pandas as pd
import numpy as np
import torch

# --- Logging / device ---
logging.basicConfig(level=logging.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --- Paths (EDIT THESE) ---
DATA_DIR = "/path/to/ngvs_jpegs/"        # directory containing image files
CATALOG_CSV = "/path/to/ngvs_catalog.csv"  # must include NGVS_name + morphology label columns
CHECKPOINT_LOC = "/path/to/zoobot_checkpoint.ckpt"  # pretrained ZooBot checkpoint
SAVE_DIR = "./outputs/run_001"            # where checkpoints + predictions will be written
os.makedirs(SAVE_DIR, exist_ok=True)

# --- Task definition ---
# Example binary task: Elliptical ("E") vs non-E
MORPH_COL = "Morphology_code"  # column in your catalog CSV with labels like 'E', 'ES', 'EI', ...
LABEL_COL = "is_E"


## 1) Load catalog and build ZooBot fields

ZooBot expects a pandas DataFrame (`catalog`) that includes:
- `id_str`: unique identifier per image
- `file_loc`: path to image file
- label columns (e.g., `is_E`)
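
Before handing the catalog to ZooBot, it can help to check these requirements explicitly. The helper below is a minimal sketch: `validate_catalog` is a hypothetical name (not part of ZooBot), and the toy rows are invented for illustration.

```python
import pandas as pd


def validate_catalog(catalog, label_cols):
    """Return a list of human-readable problems; an empty list means no problem was found."""
    problems = []
    required = ["id_str", "file_loc", *label_cols]
    missing = [c for c in required if c not in catalog.columns]
    if missing:
        problems.append(f"missing columns: {missing}")
        return problems  # the remaining checks need these columns
    if catalog["id_str"].duplicated().any():
        problems.append("id_str values are not unique")
    if catalog[required].isna().any().any():
        problems.append("null values in required columns")
    return problems


# Toy catalog (made-up rows, not NGVS data) with a duplicated id_str
toy = pd.DataFrame({
    "id_str": ["a", "b", "b"],
    "file_loc": ["a.jpeg", "b.jpeg", "b.jpeg"],
    "is_E": [1, 0, 0],
})
print(validate_catalog(toy, ["is_E"]))
```

Returning a list of problems rather than raising on the first one makes it easier to fix several catalog issues in a single pass.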


In [None]:
# Load metadata
catalog = pd.read_csv(CATALOG_CSV)

# Required fields
catalog["id_str"] = catalog["NGVS_name"].astype(str)
catalog["file_loc"] = catalog["id_str"].apply(lambda x: os.path.join(DATA_DIR, f"{x}.jpeg"))

# Drop duplicate IDs (optional but recommended)
catalog = catalog.drop_duplicates(subset=["id_str"]).reset_index(drop=True)

# Filter to rows where the image exists (avoids runtime surprises)
catalog = catalog[catalog["file_loc"].apply(os.path.exists)].reset_index(drop=True)

print("Rows after filtering:", len(catalog))
catalog.head()


## 2) Build label column(s)

Here we build a simple **binary** target: `is_E`.
You can extend this pattern to other binary tasks or to multi-class labels.
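
For the multi-class case, one option is to map each morphology code to an integer class index. The sketch below uses three illustrative codes (`E`, `ES`, `EI`, borrowed from the comment in the configuration cell) and a hypothetical `morph_class` column. How ZooBot should consume that column (for example, the matching `num_classes`) is not shown here and should be checked against the ZooBot documentation.

```python
import pandas as pd

# Toy rows; the codes are illustrative, not a complete NGVS scheme
toy = pd.DataFrame({"Morphology_code": ["E", "ES", "EI", "E", None]})

# Explicit, fixed mapping so class indices stay stable across runs
CLASS_TO_INDEX = {"E": 0, "ES": 1, "EI": 2}

# Unmapped or missing codes become NaN; drop them rather than guessing a class
toy["morph_class"] = toy["Morphology_code"].map(CLASS_TO_INDEX)
labelled = toy.dropna(subset=["morph_class"]).copy()
labelled["morph_class"] = labelled["morph_class"].astype(int)

print(labelled["morph_class"].tolist())
```

A hand-written mapping is used instead of automatic category codes so that adding or removing a subtype later does not silently renumber the existing classes.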


In [None]:
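# Binary label example: E vs non-E
# Note: rows with a missing morphology code compare unequal to "E" and are labelled 0
catalog[LABEL_COL] = (catalog[MORPH_COL] == "E").astype(int)

print("Missing morphology codes:", catalog[MORPH_COL].isna().sum())
print(catalog[LABEL_COL].value_counts())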


## 3) Create the ZooBot DataModule

We use `GalaxyDataModule` to handle image loading, resizing/cropping, and batching.


In [None]:
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

label_cols = [LABEL_COL]

datamodule = GalaxyDataModule(
    label_cols=label_cols,
    catalog=catalog,
    batch_size=32,
    resize_after_crop=224,
    num_workers=2,
)

datamodule.setup()


## 4) Fine-tune ZooBot

Key knob: `n_layers`
- `n_layers=0` → **head-only** fine-tuning
- `n_layers>0` → fine-tune deeper layers (more adaptation, more risk of overfitting)

This code uses ZooBot's Lightning trainer helpers.
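
To make the distinction concrete, the framework-free sketch below lists which parts of a network would receive gradient updates for a given `n_layers`. The block names and the `trainable_parts` helper are hypothetical. It illustrates the idea only, under the assumption that layers are counted from the end closest to the head. It is not ZooBot's implementation, which may group layers differently.

```python
def trainable_parts(encoder_blocks, n_layers):
    """Return the names of the parts that would be updated during fine-tuning.

    The new classification head is always trained; in addition, the last
    `n_layers` encoder blocks (those closest to the head) are unfrozen.
    """
    if n_layers < 0:
        raise ValueError("n_layers must be non-negative")
    # Slicing clamps automatically, so n_layers larger than the encoder unfreezes every block
    unfrozen = encoder_blocks[-n_layers:] if n_layers > 0 else []
    return list(unfrozen) + ["head"]


# Hypothetical encoder with four blocks, ordered from input to output
blocks = ["block1", "block2", "block3", "block4"]

print(trainable_parts(blocks, 0))  # head-only
print(trainable_parts(blocks, 2))  # head plus the two blocks nearest the head
```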


In [None]:
from zoobot.pytorch.training import finetune

# Choose head-only vs deeper fine-tuning
N_LAYERS_TO_FINETUNE = 0   # 0=head-only; try 2 for deeper fine-tune

model = finetune.FinetuneableZoobotClassifier(
    checkpoint_loc=CHECKPOINT_LOC,
    num_classes=2,
    n_layers=N_LAYERS_TO_FINETUNE,
)

trainer = finetune.get_trainer(
    save_dir=SAVE_DIR,
    accelerator="auto",
    devices="auto",
    max_epochs=50,  # adjust as needed
)

trainer.fit(model, datamodule)
best_checkpoint = trainer.checkpoint_callback.best_model_path
print("Best checkpoint:", best_checkpoint)


## 5) Predict on the full catalog and save to CSV

Saving predictions makes your experiment reproducible and enables analysis without re-running inference.


In [None]:
from zoobot.pytorch.predictions import predict_on_catalog

finetuned_model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(best_checkpoint)

PRED_CSV = os.path.join(SAVE_DIR, "finetuned_predictions.csv")

predict_on_catalog.predict(
    catalog=catalog,
    model=finetuned_model,
    n_samples=1,
    label_cols=label_cols,
    save_loc=PRED_CSV,
    trainer_kwargs={"accelerator": "auto", "devices": "auto"},
    datamodule_kwargs={"num_workers": 2},
)

print("Wrote:", PRED_CSV)


## 6) Merge predictions with labels + quick diagnostics

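At minimum, compute accuracy and inspect confidence extremes. The predictions cover the full catalog, including galaxies used for training, so the accuracy reported here is likely optimistic relative to held-out performance.
If you later add calibration plots (reliability diagram / ECE), this is where they fit.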
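
As a starting point for the calibration check, the sketch below computes a binned expected calibration error for a binary classifier. It uses the positive-class variant (mean predicted probability vs. observed positive fraction per bin) with equal-width bins. The function name and the choice of 10 bins are assumptions, not something ZooBot provides.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=10):
    """Binned ECE for binary predictions.

    probs:  predicted probability of the positive class, shape (N,)
    labels: true labels in {0, 1}, shape (N,)
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Assign each prediction to one of n_bins equal-width bins on [0, 1];
    # a probability of exactly 1.0 goes into the last bin
    bin_ids = np.minimum((probs * n_bins).astype(int), n_bins - 1)

    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if not in_bin.any():
            continue
        mean_prob = probs[in_bin].mean()   # average predicted P(positive) in the bin
        frac_pos = labels[in_bin].mean()   # observed positive fraction in the bin
        ece += in_bin.mean() * abs(mean_prob - frac_pos)  # weight by bin occupancy
    return float(ece)


# Toy scores (invented for illustration)
print(expected_calibration_error([0.1, 0.1, 0.9, 0.9], [0, 0, 1, 1]))
```

With the merged `df` built in this section, the call would look like `expected_calibration_error(df[pred_col], df[LABEL_COL])`. The per-bin `mean_prob` and `frac_pos` pairs are also the points a reliability diagram would plot.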


In [None]:
# Read id_str as a string so the merge key matches catalog["id_str"], even for numeric-looking names
pred = pd.read_csv(PRED_CSV, dtype={"id_str": str})

# The prediction column name depends on the label name; in many setups it looks like f"{LABEL_COL}_pred".
# Prefer a *_pred column that mentions LABEL_COL, otherwise fall back to the first *_pred column.
pred_cols = [c for c in pred.columns if c.endswith("_pred")]
if not pred_cols:
    raise ValueError(f"No *_pred column found in {PRED_CSV}; columns: {list(pred.columns)}")
pred_col = next((c for c in pred_cols if LABEL_COL in c), pred_cols[0])

df = pred.merge(
    catalog[["id_str", "file_loc", LABEL_COL, MORPH_COL]],
    on="id_str",
    how="left",
)

df["pred_label"] = (df[pred_col] >= 0.5).astype(int)
acc = (df["pred_label"] == df[LABEL_COL]).mean()
print("Accuracy:", round(acc, 4))

df[["id_str", MORPH_COL, LABEL_COL, pred_col]].head()


## 7) Qualitative failure inspection (high-signal)

Image grids are often the quickest way for a reviewer to see what the model gets wrong:
- show the galaxies with the **highest** and **lowest** predicted scores (the most confident E and non-E calls, respectively)
- highlight incorrect predictions in red
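
One way to surface the incorrect predictions directly is to rank misclassified galaxies by how far their score lies from the decision threshold. The sketch below assumes a 0.5 threshold, as in section 6. `most_confident_errors`, the `margin` column, and the toy rows are hypothetical.

```python
import pandas as pd


def most_confident_errors(df, pred_col, label_col, n=9, threshold=0.5):
    """Return the n misclassified rows whose score is farthest from the threshold."""
    out = df.copy()
    out["pred_label"] = (out[pred_col] >= threshold).astype(int)
    out["margin"] = (out[pred_col] - threshold).abs()  # distance from the decision boundary
    wrong = out[out["pred_label"] != out[label_col]]
    return wrong.sort_values("margin", ascending=False).head(n)


# Toy predictions (invented): "a" and "c" are wrong, "b" and "d" are right
toy = pd.DataFrame({
    "id_str": ["a", "b", "c", "d"],
    "is_E_pred": [0.95, 0.60, 0.20, 0.40],
    "is_E": [0, 1, 1, 0],
})
worst = most_confident_errors(toy, "is_E_pred", "is_E", n=2)
print(worst["id_str"].tolist())
```

Because the returned frame keeps every original column, a frame that also carries `file_loc` and the morphology column can be passed straight to a plotting helper such as `show_grid`.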


In [None]:
from PIL import Image
import matplotlib.pyplot as plt

def show_grid(subdf, pred_col, n=6, ncols=3, title=None):
    n = min(n, len(subdf))
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4*ncols, 3.5*nrows))
    axes = np.array(axes).reshape(-1)

    for i in range(n):
        row = subdf.iloc[i]
        im = Image.open(row["file_loc"])
        axes[i].imshow(im)

        correct = int(row["pred_label"]) == int(row[LABEL_COL])
        color = "white" if correct else "red"

        axes[i].text(0.02, 0.92, f"True: {row[LABEL_COL]}", color=color, transform=axes[i].transAxes, fontsize=11)
        axes[i].text(0.02, 0.82, f"Pred: {row[pred_col]:.2f}", color=color, transform=axes[i].transAxes, fontsize=11)
        axes[i].text(0.02, 0.08, f"{row[MORPH_COL]}", color=color, transform=axes[i].transAxes, fontsize=10)

        axes[i].axis("off")

    for j in range(n, len(axes)):
        axes[j].axis("off")

    if title:
        fig.suptitle(title, fontsize=14)
    plt.tight_layout()
    return fig

# Highest scores = most confident "E" calls; lowest scores = most confident "non-E" calls
top = df.sort_values(pred_col, ascending=False).head(9)
bot = df.sort_values(pred_col, ascending=True).head(9)

fig1 = show_grid(top, pred_col, n=9, title="Highest predicted scores (most confident E)")
fig2 = show_grid(bot, pred_col, n=9, title="Lowest predicted scores (most confident non-E)")

# Save curated figures
FIG_DIR = os.path.join(SAVE_DIR, "figures")
os.makedirs(FIG_DIR, exist_ok=True)
fig1.savefig(os.path.join(FIG_DIR, "top_confidence.png"), dpi=200)
fig2.savefig(os.path.join(FIG_DIR, "bottom_confidence.png"), dpi=200)
print("Saved figures to:", FIG_DIR)


## 8) Next steps (optional)

To make this repo application-ready, consider adding:
- a reliability diagram / expected calibration error (ECE)
- stratified error analysis (e.g., performance vs image quality / magnitude / subtype)
- a small `scripts/` entry point that runs sections 0–6 without the notebook
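
As a starting point for the stratified error analysis, the sketch below computes per-group accuracy and sample counts, here grouped by morphology subtype. `accuracy_by_group` is a hypothetical helper and the toy rows are invented. It assumes a frame that already has a thresholded `pred_label` column, like the merged `df` from section 6.

```python
import pandas as pd


def accuracy_by_group(df, group_col, label_col, pred_label_col="pred_label"):
    """Per-group accuracy and sample count, sorted with the worst group first."""
    tmp = pd.DataFrame({
        "group": df[group_col],
        "correct": df[pred_label_col] == df[label_col],
    })
    summary = (
        tmp.groupby("group")["correct"]
        .agg(["mean", "size"])
        .rename(columns={"mean": "accuracy", "size": "n"})
    )
    return summary.sort_values("accuracy")


# Toy rows (invented): E is half right, ES is all right, EI is all wrong
toy = pd.DataFrame({
    "Morphology_code": ["E", "E", "ES", "ES", "EI"],
    "is_E": [1, 1, 0, 0, 0],
    "pred_label": [1, 0, 0, 0, 1],
})
summary = accuracy_by_group(toy, "Morphology_code", "is_E")
print(summary)
```

Reporting `n` next to each accuracy makes it clear when a low score comes from only a handful of galaxies.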
