In [1]:
import json
from IPython.display import display
from pathlib import Path
import pandas as pd
import numpy as np
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from openpyxl.utils import get_column_letter


# ---------------------------------------------------------------------------
#  I/O
# ---------------------------------------------------------------------------
def load_results(path: Path) -> list[dict]:
    """Load the inference results stored as JSON at *path*.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever ``json.load`` decodes from the file; the annotation documents
        the expected shape (a list of per-annotation dicts) but it is not
        validated here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Read as UTF-8 explicitly so non-ASCII labels decode the same way on
    # every platform instead of depending on the locale's default encoding.
    with path.open(encoding="utf-8") as f:
        return json.load(f)


# ---------------------------------------------------------------------------
#  Per-entry helpers
# ---------------------------------------------------------------------------
def check_top_k(entry: dict, k: int) -> bool:
    """Return True when the ground-truth label is among the *k* best candidates.

    Args:
        entry: One result record; reads the ``"gt_label"``, ``"names"`` and
            ``"accuracy"`` keys (missing or empty values are tolerated).
        k: Number of highest-scoring candidates to inspect.

    Returns:
        ``True`` if ``entry["gt_label"]`` is one of the *k* names with the
        highest scores, ``False`` otherwise (including when the label, the
        names or the scores are missing).
    """
    gt = entry.get("gt_label")
    names = entry.get("names") or []
    accs = entry.get("accuracy") or []
    if not gt or not names or not accs:
        return False

    # Only rank positions that have both a name and a score, so a length
    # mismatch between the two lists cannot index past the end of `names`.
    usable = min(len(names), len(accs))
    scores = np.asarray(accs[:usable], dtype=float)

    # Stable sort on the negated scores: descending order with ties resolved
    # in favour of the earliest candidate, which is the same winner that
    # np.argmax picks. This keeps the top-1 hit consistent with an
    # argmax-based prediction.
    ranked = np.argsort(-scores, kind="stable")[:k]
    return gt in [names[i] for i in ranked]


def top_n_names(entry: dict, n: int = 5) -> str:
    """Return the *n* best-scoring candidate names as a comma-separated string.

    Args:
        entry: One result record; reads the ``"names"`` and ``"accuracy"``
            keys (missing or empty values are tolerated).
        n: Maximum number of names to include.

    Returns:
        The names joined with ``", "`` in descending score order, or ``""``
        when the record has no names (or no scores to rank them by).
    """
    names = entry.get("names") or []
    accs = entry.get("accuracy") or []
    if not names:
        return ""

    # Rank only positions that have both a name and a score so mismatched
    # list lengths cannot raise IndexError.
    usable = min(len(names), len(accs))
    scores = np.asarray(accs[:usable], dtype=float)

    # Stable descending order: on equal scores the earliest candidate comes
    # first, so the first listed name matches what np.argmax would select.
    ranked = np.argsort(-scores, kind="stable")[:n]
    return ", ".join(names[i] for i in ranked)


# ---------------------------------------------------------------------------
#  Normalise raw JSON → flat DataFrame (one row per annotation)
# ---------------------------------------------------------------------------
def normalize_results(entries: list[dict], tag: str) -> pd.DataFrame:
    """Flatten raw result records into a DataFrame with one row per record.

    Args:
        entries: Result records; each is read through ``dict.get`` so absent
            keys simply yield ``None``.
        tag: Prefix for the model-specific columns (``<tag>_pred``,
            ``<tag>_score``, ``<tag>_top5``).

    Returns:
        A DataFrame holding the identifying keys, the ground-truth label, the
        best-scoring prediction with its score, the top-5 name string and the
        top-1/3/5 hit flags.
    """
    pred_col = f"{tag}_pred"
    score_col = f"{tag}_score"
    top5_col = f"{tag}_top5"

    def to_row(entry: dict) -> dict:
        labels = entry.get("names") or []
        scores = entry.get("accuracy") or []
        if scores:
            winner = int(np.argmax(scores))
            prediction, confidence = labels[winner], scores[winner]
        else:
            # No scores at all: no prediction, and the score defaults to 0.
            prediction, confidence = None, 0
        return {
            "image_id": entry.get("image_id"),
            "ann_id": entry.get("ann_id"),
            "drawn_fish_id": entry.get("drawn_fish_id"),
            "gt_label": entry.get("gt_label"),
            pred_col: prediction,
            score_col: confidence,
            top5_col: top_n_names(entry, 5),
            "top1_hit": check_top_k(entry, 1),
            "top3_hit": check_top_k(entry, 3),
            "top5_hit": check_top_k(entry, 5),
        }

    return pd.DataFrame([to_row(entry) for entry in entries])


# ---------------------------------------------------------------------------
#  Per-class metrics
# ---------------------------------------------------------------------------
def per_class_metrics(df: pd.DataFrame, classes: list[str]) -> pd.DataFrame:
    """Compute support, hit counts and recall per ground-truth class.

    Args:
        df: Frame with a ``gt_label`` column and the boolean hit columns
            ``top1_hit``, ``top3_hit`` and ``top5_hit``.
        classes: Class labels to report; rows whose ``gt_label`` is not in
            this list are ignored.

    Returns:
        A frame indexed by *classes* with the columns ``support``,
        ``t{k}_abs`` (number of hits) and ``t{k}_rec`` (hits / support) for
        k in 1, 3, 5. Classes without any matching row get 0 everywhere, so
        they count as recall 0 in any average taken over the result.
    """
    ks = (1, 3, 5)
    # Single source of truth for the column layout so that the empty-input
    # branch returns exactly the same column order as the normal branch.
    columns = ["support"] + [f"t{k}_abs" for k in ks] + [f"t{k}_rec" for k in ks]

    valid = df[df["gt_label"].isin(classes)]
    if valid.empty:
        return pd.DataFrame(0.0, index=classes, columns=columns)

    stats = valid.groupby("gt_label").agg(
        support=("gt_label", "count"),
        t1_abs=("top1_hit", "sum"),
        t3_abs=("top3_hit", "sum"),
        t5_abs=("top5_hit", "sum"),
    )
    for k in ks:
        stats[f"t{k}_rec"] = (stats[f"t{k}_abs"] / stats["support"]).fillna(0)

    # reindex adds the requested classes that had no rows (as NaN -> 0) and
    # puts the result in the caller's class order.
    return stats.reindex(classes).fillna(0)


# ---------------------------------------------------------------------------
#  Build comparison table (per-class + weighted avg + macro avg)
# ---------------------------------------------------------------------------
def compare_models(df_old, df_new, classes, old_tag="v93", new_tag="v10"):
    """Build the per-class comparison table plus two summary rows.

    Args:
        df_old: Normalised results of the baseline model.
        df_new: Normalised results of the candidate model.
        classes: Class labels to include (one table row each).
        old_tag: Column-label prefix for the baseline model.
        new_tag: Column-label prefix for the candidate model.

    Returns:
        A DataFrame with one row per class, sorted by ``Δ Top-1 Recall``
        descending, followed by a ``WEIGHTED AVG`` row (total hits / total
        samples) and a ``MACRO AVG`` row (unweighted mean of the per-class
        recalls; classes without samples contribute recall 0). ``Samples``
        is taken from the candidate model's support.
    """
    ks = (1, 3, 5)
    m_old = per_class_metrics(df_old, classes)
    m_new = per_class_metrics(df_new, classes)

    c = pd.DataFrame(index=classes)
    c["Samples"] = m_new["support"].astype(int)

    for k in ks:
        c[f"{old_tag} Top-{k} Hits"]   = m_old[f"t{k}_abs"].astype(int)
        c[f"{new_tag} Top-{k} Hits"]   = m_new[f"t{k}_abs"].astype(int)
        c[f"{old_tag} Top-{k} Recall"] = m_old[f"t{k}_rec"]
        c[f"{new_tag} Top-{k} Recall"] = m_new[f"t{k}_rec"]
        c[f"Δ Top-{k} Recall"]         = m_new[f"t{k}_rec"] - m_old[f"t{k}_rec"]

    # Each model's overall recall must be divided by that model's own sample
    # count; the two only coincide when both frames cover the same rows.
    n_old = m_old["support"].sum()
    n_new = c["Samples"].sum()

    def weighted_recalls(k, old_hits, new_hits):
        return (old_hits / n_old if n_old else 0,
                new_hits / n_new if n_new else 0)

    def macro_recalls(k, old_hits, new_hits):
        return (c[f"{old_tag} Top-{k} Recall"].mean(),
                c[f"{new_tag} Top-{k} Recall"].mean())

    def summary_row(label, recalls):
        # Keys are inserted in the same order as the columns of `c`, so the
        # summary rows line up with the per-class body when concatenated.
        values = {"Samples": n_new}
        for k in ks:
            old_hits = c[f"{old_tag} Top-{k} Hits"].sum()
            new_hits = c[f"{new_tag} Top-{k} Hits"].sum()
            old_rec, new_rec = recalls(k, old_hits, new_hits)
            values[f"{old_tag} Top-{k} Hits"]   = old_hits
            values[f"{new_tag} Top-{k} Hits"]   = new_hits
            values[f"{old_tag} Top-{k} Recall"] = old_rec
            values[f"{new_tag} Top-{k} Recall"] = new_rec
            values[f"Δ Top-{k} Recall"]         = new_rec - old_rec
        return pd.Series(values, name=label, dtype=float)

    weighted = summary_row("WEIGHTED AVG (Overall Accuracy)", weighted_recalls)
    macro = summary_row("MACRO AVG (Mean Per-Class Recall)", macro_recalls)

    body = c.sort_values("Δ Top-1 Recall", ascending=False)
    return pd.concat([body, pd.DataFrame([weighted, macro])])


# ---------------------------------------------------------------------------
#  Console summary
# ---------------------------------------------------------------------------
def print_summary(df_old, df_new, classes, title="", old_tag="v93", new_tag="v10"):
    """Print micro- and macro-averaged top-1/3/5 recall for two models.

    Args:
        df_old: Normalised results of the baseline model.
        df_new: Normalised results of the candidate model.
        classes: Class labels the metrics are restricted to.
        title: Heading shown on the banner line.
        old_tag: Column heading for the baseline model.
        new_tag: Column heading for the candidate model.

    The banner reports the candidate model's sample count. The micro figure
    is total hits over total samples; the macro figure is the unweighted
    mean of the per-class recalls.
    """
    m_old = per_class_metrics(df_old, classes)
    m_new = per_class_metrics(df_new, classes)
    n_cls = len(classes)
    n_sam = int(m_new["support"].sum())

    def micro(m, k):
        # Divide by the support of the model being measured, not by the
        # other model's sample count.
        total = m["support"].sum()
        return m[f"t{k}_abs"].sum() / total if total else 0

    def macro(m, k):
        return m[f"t{k}_rec"].mean()

    w = 70
    print(f"\n{'=' * w}")
    print(f"  {title}  |  Classes: {n_cls}  |  Samples: {n_sam:,}")
    print(f"{'=' * w}")
    print(f"  {'Metric':<30} {old_tag:>10} {new_tag:>10} {'Δ':>10}")
    print(f"  {'-' * 63}")

    for label, getter in [
        ("Weighted (micro)", micro),
        ("Macro (per-class)", macro),
    ]:
        for k in (1, 3, 5):
            v_old = getter(m_old, k)
            v_new = getter(m_new, k)
            d = v_new - v_old
            # The "+" flag keeps the sign attached to the number, and the
            # width of 10 matches the header columns.
            print(f"  Top-{k} {label:<24} {v_old:>10.2%} {v_new:>10.2%} {d:>+10.2%}")
        print()

# ---------------------------------------------------------------------------
#  Top-N errors per class for a given model
# ---------------------------------------------------------------------------
def top_errors_per_class(df: pd.DataFrame, pred_col: str, top5_col: str,
                         n: int = 5) -> pd.DataFrame:
    """Summarise, per ground-truth class, the most frequent wrong predictions.

    Args:
        df: Frame with a ``gt_label`` column and the prediction column.
        pred_col: Name of the column holding the top-1 prediction.
        top5_col: Accepted for interface compatibility; not used.
        n: Maximum number of distinct wrong predictions listed per class.

    Returns:
        One row per non-null ground-truth class with ``Class``,
        ``Total Samples``, ``Errors (Top-1)``, ``Error Rate`` and, for the
        i-th most frequent wrong prediction, ``#i Wrong Pred``, ``#i Count``
        and ``#i % of Errors``. Rows are sorted by error count descending;
        classes with equal counts stay in ascending class order. An input
        without labelled rows yields an empty frame with the four base
        columns.
    """
    base_columns = ["Class", "Total Samples", "Errors (Top-1)", "Error Rate"]

    rows = []
    # A single groupby pass replaces re-filtering the whole frame for every
    # class. Groups come back sorted by label and null labels are dropped.
    for cls, cls_df in df.groupby("gt_label"):
        cls_total = len(cls_df)
        wrong_preds = cls_df.loc[cls_df[pred_col] != cls, pred_col]
        n_errors = len(wrong_preds)

        row = {
            "Class": cls,
            "Total Samples": cls_total,
            "Errors (Top-1)": n_errors,
            "Error Rate": n_errors / cls_total if cls_total else 0,
        }

        # value_counts() is empty when there are no errors, so the division
        # below is never reached with n_errors == 0.
        top_wrong = wrong_preds.value_counts().head(n)
        for i, (pred, count) in enumerate(top_wrong.items(), 1):
            row[f"#{i} Wrong Pred"] = pred
            row[f"#{i} Count"] = count
            row[f"#{i} % of Errors"] = count / n_errors
        rows.append(row)

    if not rows:
        # Sorting a column-less frame by "Errors (Top-1)" would raise KeyError.
        return pd.DataFrame(columns=base_columns)

    result = pd.DataFrame(rows)
    # Stable sort makes the order of tied classes deterministic.
    result = result.sort_values("Errors (Top-1)", ascending=False, kind="stable")
    return result.reset_index(drop=True)


# Cell sentinel: reaching this line means every definition above executed.
print("Functions loaded ✓")

Functions loaded ✓


In [2]:
# ---------------------------------------------------------------------------
#  Inputs / output locations
# ---------------------------------------------------------------------------
output_path = Path("validation_comparison_v93_vs_v10.xlsx")
v93_path = Path("inference_results_v93.json")
v10_path = Path("inference_results_v10.json")

# ---------------------------------------------------------------------------
#  One normalised DataFrame per model (v93 is read first, then v10)
# ---------------------------------------------------------------------------
df_v93 = normalize_results(load_results(v93_path), "v93")
df_v10 = normalize_results(load_results(v10_path), "v10")

# Class sets are the distinct labels each model actually *predicted*
# (null predictions excluded); `common` is their intersection.
cls_v93 = sorted(set(df_v93["v93_pred"].dropna()))
cls_v10 = sorted(set(df_v10["v10_pred"].dropna()))
common = sorted(set(cls_v93).intersection(cls_v10))

print(f"Loaded:  v93 = {len(df_v93):,} rows,  v10 = {len(df_v10):,} rows")
print(f"Predicted classes — v93: {len(cls_v93)},  v10: {len(cls_v10)},  common: {len(common)}")

# ---------------------------------------------------------------------------
#  Console summary for every class set, in a fixed order
# ---------------------------------------------------------------------------
for _label_set, _heading in (
    (common, "COMMON CLASSES (intersection)"),
    (cls_v93, "v93 LABEL SET"),
    (cls_v10, "v10 LABEL SET"),
):
    print_summary(df_v93, df_v10, _label_set, _heading)

Loaded:  v93 = 168,612 rows,  v10 = 168,612 rows
Predicted classes — v93: 639,  v10: 775,  common: 638

  COMMON CLASSES (intersection)  |  Classes: 638  |  Samples: 149,304
  Metric                                v93        v10          Δ
  --------------------------------------------------------------
  Top-1 Weighted (micro)            96.16%    95.90%   -0.26%
  Top-3 Weighted (micro)            98.04%    98.69% +   0.64%
  Top-5 Weighted (micro)            98.19%    98.97% +   0.78%

  Top-1 Macro (per-class)           95.26%    96.69% +   1.44%
  Top-3 Macro (per-class)           97.53%    98.96% +   1.43%
  Top-5 Macro (per-class)           97.75%    99.18% +   1.43%


  v93 LABEL SET  |  Classes: 639  |  Samples: 149,304
  Metric                                v93        v10          Δ
  --------------------------------------------------------------
  Top-1 Weighted (micro)            96.16%    95.90%   -0.26%
  Top-3 Weighted (micro)            98.04%    98.69% +   0.64%
  Top

In [3]:
# ---------------------------------------------------------------------------
#  Excel export with formatting
# ---------------------------------------------------------------------------
# Sheet name -> class list; one per-class comparison sheet is written per entry.
SETS = {
    "Common_Classes": common,
    "v93_Set":        cls_v93,
    "v10_Set":        cls_v10,
}

# ---- styles ----
HEADER_FILL = PatternFill("solid", fgColor="4472C4")
HEADER_FONT = Font(bold=True, size=11, color="FFFFFF")
SUMMARY_FILL = PatternFill("solid", fgColor="FFF2CC")
SUMMARY_FONT = Font(bold=True, size=11)
GREEN_FILL = PatternFill("solid", fgColor="C6EFCE")
GREEN_FONT = Font(color="006100")
RED_FILL   = PatternFill("solid", fgColor="FFC7CE")
RED_FONT   = Font(color="9C0006")
_s = Side(style="thin")
THIN = Border(left=_s, right=_s, top=_s, bottom=_s)

# ---- thresholds (all values are fractions, e.g. 0.001 == 0.1 %) ----
DELTA_EPS = 0.001        # deltas within ±DELTA_EPS stay uncoloured
ERR_RATE_HIGH = 0.3      # error rates above this are painted red
ERR_RATE_LOW = 0.05      # error rates below this are painted green
SUMMARY_TAGS = ("AVG", "WEIGHTED", "MACRO")
ERROR_SHEETS = ("v93_Top_Errors", "v10_Top_Errors")


def _header_columns(ws, matches):
    """Return the 1-based indexes of header cells whose text satisfies *matches*."""
    return {
        idx
        for idx, cell in enumerate(ws[1], 1)
        if cell.value and matches(str(cell.value))
    }


def _color_delta(cell):
    """Paint a delta cell green when clearly positive, red when clearly negative."""
    if cell.value > DELTA_EPS:
        cell.fill, cell.font = GREEN_FILL, GREEN_FONT
    elif cell.value < -DELTA_EPS:
        cell.fill, cell.font = RED_FILL, RED_FONT


def _color_error_rate(cell):
    """Paint an error-rate cell red when high, green when low."""
    if cell.value > ERR_RATE_HIGH:
        cell.fill, cell.font = RED_FILL, RED_FONT
    elif cell.value < ERR_RATE_LOW:
        cell.fill, cell.font = GREEN_FILL, GREEN_FONT


def _style_data_rows(ws, pct_cols, color_cols, colorize, summary_tags=()):
    """Border every data cell, format percentages and colour selected columns.

    Rows whose first cell contains one of *summary_tags* additionally get the
    summary fill/font on every cell outside *color_cols*.
    """
    for row in ws.iter_rows(min_row=2, max_row=ws.max_row):
        first = row[0].value
        is_summary = bool(first) and any(tag in str(first) for tag in summary_tags)
        for cell in row:
            cell.border = THIN
            is_number = isinstance(cell.value, (int, float))
            if is_number and cell.column in pct_cols:
                cell.number_format = "0.00%"
            if is_number and cell.column in color_cols:
                colorize(cell)
            if is_summary and cell.column not in color_cols:
                cell.fill = SUMMARY_FILL
                cell.font = SUMMARY_FONT


with pd.ExcelWriter(output_path, engine="openpyxl") as writer:

    # ── 1. Summary sheet ─────────────────────────────────────────────────
    summary_rows = []
    for sname, classes in SETS.items():
        mo = per_class_metrics(df_v93, classes)
        mn = per_class_metrics(df_v10, classes)
        # Each model's accuracy is divided by its own sample count; the
        # "Samples" column reports the v10 count.
        n_old = int(mo["support"].sum())
        n_new = int(mn["support"].sum())
        row = {"Set": sname, "Classes": len(classes), "Samples": n_new}
        for k in (1, 3, 5):
            oh, nh = int(mo[f"t{k}_abs"].sum()), int(mn[f"t{k}_abs"].sum())
            row[f"v93 Top-{k} Acc"]      = oh / n_old if n_old else 0
            row[f"v10 Top-{k} Acc"]      = nh / n_new if n_new else 0
            row[f"Δ Top-{k} Acc"]        = row[f"v10 Top-{k} Acc"] - row[f"v93 Top-{k} Acc"]
            row[f"v93 Top-{k} Macro"]    = mo[f"t{k}_rec"].mean()
            row[f"v10 Top-{k} Macro"]    = mn[f"t{k}_rec"].mean()
            row[f"Δ Top-{k} Macro"]      = row[f"v10 Top-{k} Macro"] - row[f"v93 Top-{k} Macro"]
        summary_rows.append(row)
    pd.DataFrame(summary_rows).to_excel(writer, sheet_name="Summary", index=False)

    # ── 2. Per-class comparison sheets ────────────────────────────────────
    for sname, classes in SETS.items():
        compare_models(df_v93, df_v10, classes).to_excel(writer, sheet_name=sname)

    # ── 3. Raw data (merged) ─────────────────────────────────────────────
    # Outer merge keeps annotations present in only one model's results;
    # the shared hit columns receive the _v93 / _v10 suffixes.
    raw = df_v93.merge(
        df_v10,
        on=["image_id", "ann_id", "drawn_fish_id", "gt_label"],
        suffixes=("_v93", "_v10"),
        how="outer",
    )
    raw.to_excel(writer, sheet_name="Raw_Data", index=False)

    # ── 4. Top-5 errors per class for each model ────────────────────────
    err_v93 = top_errors_per_class(df_v93, "v93_pred", "v93_top5")
    err_v10 = top_errors_per_class(df_v10, "v10_pred", "v10_top5")
    err_v93.to_excel(writer, sheet_name="v93_Top_Errors", index=False)
    err_v10.to_excel(writer, sheet_name="v10_Top_Errors", index=False)

    # ── 5. Class dictionaries ─────────────────────────────────────────────
    # Wrapping in pd.Series lets the three lists have different lengths.
    pd.DataFrame({
        "v93 classes":    pd.Series(cls_v93),
        "v10 classes":    pd.Series(cls_v10),
        "common classes": pd.Series(common),
    }).to_excel(writer, sheet_name="Class_Lists", index=False)

    # ==================================================================
    #  Formatting pass
    # ==================================================================
    wb = writer.book
    for sname in wb.sheetnames:
        ws = wb[sname]

        # --- header row ---
        for cell in ws[1]:
            cell.font = HEADER_FONT
            cell.fill = HEADER_FILL
            cell.alignment = Alignment(horizontal="center", wrap_text=True)
            cell.border = THIN

        # --- freeze (comparison sheets also pin their index column) ---
        ws.freeze_panes = "B2" if sname in SETS else "A2"

        # --- auto-width (capped at 42) ---
        for ci, col in enumerate(ws.columns, 1):
            width = max((len(str(c.value or "")) for c in col), default=8)
            ws.column_dimensions[get_column_letter(ci)].width = min(width + 3, 42)

        # --- comparison sheets: recall %, delta coloring, summary rows ---
        if sname in SETS:
            delta_ci = _header_columns(ws, lambda h: "Δ" in h)
            recall_ci = _header_columns(ws, lambda h: "Recall" in h)
            _style_data_rows(ws, delta_ci | recall_ci, delta_ci,
                             _color_delta, SUMMARY_TAGS)

        # --- Summary sheet: % format + delta coloring ---
        elif sname == "Summary":
            pct_ci = _header_columns(
                ws, lambda h: any(t in h for t in ("Acc", "Macro", "Δ")))
            delta_ci = _header_columns(ws, lambda h: "Δ" in h)
            _style_data_rows(ws, pct_ci, delta_ci, _color_delta)

        # --- Top-Errors sheets: % format + error-rate coloring ---
        elif sname in ERROR_SHEETS:
            pct_ci = _header_columns(
                ws, lambda h: "Error Rate" in h or "% of Errors" in h)
            err_rate_ci = _header_columns(ws, lambda h: h == "Error Rate")
            _style_data_rows(ws, pct_ci, err_rate_ci, _color_error_rate)

print(f"Excel saved → {output_path}")

Excel saved → validation_comparison_v93_vs_v10.xlsx
