In [3]:
def compute_area_between_curves(morf_curve, lerf_curve):
    """
    Compute the area between the LeRF and MoRF curves using the trapezoidal rule.

    Adjacent points are treated as one unit apart, so each trapezoid contributes the mean of
    the two neighbouring (LeRF - MoRF) gaps. Following the literature, the total is normalized
    by dividing by the number of points N = len(morf_curve) (points, not the N - 1 intervals).

    Args:
        morf_curve: indexable, sized sequence of MoRF scores.
        lerf_curve: sequence of LeRF scores with the same length as ``morf_curve``.

    Returns:
        The normalized area between the curves.

    Raises:
        AssertionError: if the curves have different lengths.
        ZeroDivisionError: if the curves are empty.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``; the exception
    # type stays AssertionError so existing callers see no difference.
    if len(morf_curve) != len(lerf_curve):
        raise AssertionError("Curves must be of the same length")

    gaps = [lerf - morf for morf, lerf in zip(morf_curve, lerf_curve)]
    area = 0.0
    for previous, current in zip(gaps, gaps[1:]):
        area += 0.5 * (current + previous)
    return area / len(morf_curve)


def compute_aopc(morf_curve):
    """
    Area Over the Perturbation Curve (AOPC) of a MoRF curve.

    The first MoRF value is used as the unperturbed reference. The drop (reference - score)
    is integrated with the trapezoidal rule over unit-spaced points, and the result is divided
    by the number of points N = len(morf_curve). An empty curve raises IndexError.
    """
    baseline = morf_curve[0]
    drops = [baseline - score for score in morf_curve]
    total = 0.0
    for before, after in zip(drops, drops[1:]):
        total += 0.5 * (after + before)
    return total / len(morf_curve)

def normalized_abpc(morf_curve, lerf_curve):
    """
    Return the ABPC area scaled by the score span ``max(LeRF) - min(MoRF)``.

    Args:
        morf_curve: MoRF scores.
        lerf_curve: LeRF scores, same length as ``morf_curve``.

    Raises:
        ZeroDivisionError: if the span is zero. This is raised explicitly because NumPy scalar
            division would otherwise return inf/nan (with only a RuntimeWarning) instead of
            failing, and that non-finite value could end up stored as a metric.
        Any exception from ``compute_area_between_curves`` (e.g. mismatched lengths).
    """
    abpc = compute_area_between_curves(morf_curve, lerf_curve)
    span = max(lerf_curve) - min(morf_curve)  # not named ``range``: avoid shadowing the builtin
    if span == 0:
        raise ZeroDivisionError("normalized_abpc: max(lerf_curve) - min(morf_curve) is zero")
    return abpc / span

def normalized_aopc(morf_curve):
    """
    Return the AOPC scaled by the MoRF score span ``max(MoRF) - min(MoRF)``.

    Args:
        morf_curve: MoRF scores.

    Raises:
        ZeroDivisionError: if the curve is flat (zero span). This is raised explicitly because
            NumPy scalar division would otherwise return inf/nan (with only a RuntimeWarning)
            instead of failing, and that non-finite value could end up stored as a metric.
        Any exception from ``compute_aopc`` (e.g. an empty curve).
    """
    aopc = compute_aopc(morf_curve)
    span = max(morf_curve) - min(morf_curve)  # not named ``range``: avoid shadowing the builtin
    if span == 0:
        raise ZeroDivisionError("normalized_aopc: max(morf_curve) - min(morf_curve) is zero")
    return aopc / span


In [5]:
# patch_abpc_area_v3.py
# Recompute ABPC_area, AOPC, norm_ABPC, norm_AOPC from MoRF/LeRF only.
# Uses user-provided functions:
#   - compute_area_between_curves(morf_curve, lerf_curve)
#   - compute_aopc(morf_curve)
#   - normalized_abpc(morf_curve, lerf_curve)
#   - normalized_aopc(morf_curve)
#
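# DRY_RUN=True previews the recomputed metrics without writing anything. With DRY_RUN=False
# (the value currently set below), only the run SUMMARY is updated; history steps are written
# only if WRITE_HISTORY=True.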

import os
from typing import List, Optional, Dict
import numpy as np
import pandas as pd
import wandb

# ===== CONFIG =====
ENTITY = "giuliosichili"
PROJECT = "automi"
GROUP   = "ABPC"              # can live in run.group or config.group or tags

DRY_RUN = False                # True = preview only, nothing written; False = run summaries ARE written
VERBOSE = True               # print per-run progress lines (the final "Done." tally always prints)

WRITE_HISTORY = False         # if True, also log to history at step_ABPC (only when it exceeds the last logged step)
PAGE_SIZE = 512               # scan_history page size

# Exact history key spellings tried in order for the MoRF/LeRF values
# (only these spellings are checked; this is not a general case-insensitive match)
MORF_KEYS = ("MoRF", "morf")
LERF_KEYS = ("LeRF", "lerf")

# If True, copy the current summary metrics into *_prev keys before overwriting them
PRESERVE_PREV_SUMMARY = True
# ===================

def _fmt(x: Optional[float]) -> str:
    if x is None:
        return "None"
    try:
        xf = float(x)
        if np.isnan(xf):
            return "nan"
        return f"{xf:.8f}"
    except Exception:
        return str(x)

def fetch_runs_with_group(api: wandb.Api, entity: str, project: str, group: str) -> List[wandb.apis.public.Run]:
    """
    Return the runs of ``entity/project`` that belong to ``group``.

    Strategies are tried in order, and the first non-empty result is returned:
      1. server-side filter on the run's W&B group;
      2. server-side filter on ``config.group``;
      3. fetch every run and keep those whose run.group, config["group"], or tags match.
    """
    path = f"{entity}/{project}"

    server_side_attempts = (
        ({"group": group}, f"run.group == '{group}'"),
        ({"config.group": {"$eq": group}}, f"config.group == '{group}'"),
    )
    for filters, label in server_side_attempts:
        matched = list(api.runs(path, filters=filters))
        if matched:
            if VERBOSE:
                print(f"[fetch] Found {len(matched)} via {label}")
            return matched

    if VERBOSE:
        print("[fetch] Fallback: fetch all, filter locally…")
    everything = list(api.runs(path))

    def belongs(candidate) -> bool:
        # Same precedence as the server-side attempts, plus a tag match.
        if getattr(candidate, "group", None) == group:
            return True
        if (candidate.config or {}).get("group") == group:
            return True
        return group in (candidate.tags or [])

    kept = [candidate for candidate in everything if belongs(candidate)]
    if VERBOSE:
        print(f"[fetch] Local matched {len(kept)} / {len(everything)}")
    return kept

def collect_morf_lerf(run: wandb.apis.public.Run) -> Dict[str, np.ndarray]:
    """
    Read the run's history once and align its MoRF and LeRF values by step.

    For every history row that has a step (``_step``, falling back to ``step``), the first key
    from MORF_KEYS / LERF_KEYS holding a non-None value is recorded under that step. Values
    that fail int()/float() conversion are ignored, and a later row overwrites an earlier one
    for the same step.

    Returns a dict with:
      - 'morf': float array of MoRF values at the common steps (empty if there are none)
      - 'lerf': float array of LeRF values at those same steps
      - 'steps': the sorted steps present in both curves
      - 'last_step_seen': the largest numeric step of any row (-1 if there was none)
    """
    morf_by_step: Dict[int, float] = {}
    lerf_by_step: Dict[int, float] = {}
    highest_step = -1

    def first_present(row, candidates):
        # First candidate key present in the row with a non-None value, else None.
        for key in candidates:
            if key in row and row[key] is not None:
                return key
        return None

    targets = ((MORF_KEYS, morf_by_step), (LERF_KEYS, lerf_by_step))

    for row in run.scan_history(page_size=PAGE_SIZE):
        step = row.get("_step", row.get("step"))
        if step is None:
            continue

        # Remember the highest step so an optional history write can be placed after it.
        if isinstance(step, (int, float)):
            try:
                highest_step = max(highest_step, int(step))
            except Exception:
                pass

        for candidates, store in targets:
            key = first_present(row, candidates)
            if key is None:
                continue
            try:
                store[int(step)] = float(row[key])
            except Exception:
                pass

    shared_steps = sorted(morf_by_step.keys() & lerf_by_step.keys())
    if not shared_steps:
        return {"morf": np.array([]), "lerf": np.array([]), "steps": np.array([]), "last_step_seen": highest_step}

    return {
        "morf": np.array([morf_by_step[s] for s in shared_steps], dtype=float),
        "lerf": np.array([lerf_by_step[s] for s in shared_steps], dtype=float),
        "steps": np.array(shared_steps, dtype=int),
        "last_step_seen": highest_step,
    }

def compute_metrics_from_curves(morf: np.ndarray, lerf: np.ndarray) -> Dict[str, Optional[float]]:
    """
    Compute the four summary metrics from aligned MoRF/LeRF curves.

    Uses the metric functions defined earlier in the notebook:
      - ABPC_area: compute_area_between_curves(morf, lerf)   (already divided by N inside)
      - AOPC:      compute_aopc(morf)
      - norm_ABPC: normalized_abpc(morf, lerf)
      - norm_AOPC: normalized_aopc(morf)

    Each metric is computed independently. A metric is reported as None, and the others still
    run, if its function is missing or raises, or if it returns a non-finite value (nan/inf,
    e.g. from a zero normalization range). The reason is printed when VERBOSE. None metrics
    are not written to the run summary, so a failed metric never overwrites a stored value.

    Returns:
        Dict with keys "ABPC_area", "AOPC", "norm_ABPC", "norm_AOPC" -> float or None.
    """
    # Lambdas defer the name lookup, so a missing metric function surfaces as a NameError
    # inside the try below instead of breaking the whole call.
    calculators = (
        ("ABPC_area", lambda: compute_area_between_curves(morf, lerf)),
        ("AOPC", lambda: compute_aopc(morf)),
        ("norm_ABPC", lambda: normalized_abpc(morf, lerf)),
        ("norm_AOPC", lambda: normalized_aopc(morf)),
    )

    out: Dict[str, Optional[float]] = {}
    for name, calculate in calculators:
        value: Optional[float]
        try:
            value = float(calculate())
        except Exception as exc:  # best effort: one failing metric must not block the others
            if VERBOSE:
                print(f"  [metrics] {name} failed: {type(exc).__name__}: {exc}")
            value = None
        else:
            if not np.isfinite(value):
                if VERBOSE:
                    print(f"  [metrics] {name} is non-finite ({value}); dropping it")
                value = None
        out[name] = value
    return out

def _summary_patch_for(
    run: wandb.apis.public.Run,
    metrics: Dict[str, Optional[float]],
    step_abpc: int,
) -> Dict[str, object]:
    """Build the summary patch: optional *_prev backups, ABPC_step, then every non-None metric."""
    updates: Dict[str, object] = {}
    if PRESERVE_PREV_SUMMARY:
        summary = run.summary or {}
        missing = object()
        for key in ("ABPC_area", "AOPC", "norm_ABPC", "norm_AOPC"):
            prev_key = f"{key}_prev"
            # Back up only once. On a re-run, the live value is this script's own earlier
            # patch, and copying it into *_prev would destroy the original value.
            if summary.get(prev_key, missing) is missing:
                updates[prev_key] = summary.get(key)
    updates["ABPC_step"] = step_abpc
    updates.update({k: v for k, v in metrics.items() if v is not None})
    return updates


def _append_to_history(run: wandb.apis.public.Run, updates: Dict[str, object], step: int) -> None:
    """Log ``updates`` into the run's history at ``step`` by resuming the run with wandb.init."""
    # The explicit id/resume arguments are all wandb.init needs to attach to this run.
    # WANDB_RUN_ID / WANDB_RESUME are deliberately NOT exported through os.environ: they would
    # persist in this process and make any later wandb.init() resume this run. The context
    # manager finishes the session even if log() raises.
    with wandb.init(entity=run.entity, project=run.project, id=run.id, resume="allow") as session:
        session.log(updates, step=step)


def _patch_single_run(run: wandb.apis.public.Run) -> bool:
    """
    Recompute the metrics for one run and, unless DRY_RUN, write them back.

    Returns True when the run was patched (or would be, in DRY_RUN), and False when it was
    skipped because no step has both MoRF and LeRF logged. Exceptions propagate to the caller.
    """
    data = collect_morf_lerf(run)
    morf, lerf, steps = data["morf"], data["lerf"], data["steps"]

    if morf.size < 1 or lerf.size < 1:
        if VERBOSE:
            print(f"- {run.id} SKIP: no common MoRF/LeRF steps")
        return False

    step_abpc = int(len(steps))  # ABPC step = number of aligned points N
    metrics = compute_metrics_from_curves(morf, lerf)

    if VERBOSE:
        print(
            f"- {run.id} N={len(steps)} step_ABPC={step_abpc} "
            f"| ABPC={_fmt(metrics['ABPC_area'])} "
            f"| AOPC={_fmt(metrics['AOPC'])} "
            f"| norm_ABPC={_fmt(metrics['norm_ABPC'])} "
            f"| norm_AOPC={_fmt(metrics['norm_AOPC'])}"
        )

    if DRY_RUN:
        return True

    updates = _summary_patch_for(run, metrics, step_abpc)
    # Summary-only patch through the public API (no resume/init needed).
    run.summary.update(updates, overwrite=True)

    # Optionally also append to history, but only at a step beyond everything already logged,
    # because W&B expects history steps to increase monotonically.
    if WRITE_HISTORY and step_abpc > data["last_step_seen"]:
        _append_to_history(run, updates, step_abpc)
    return True


def main():
    """
    Entry point: recompute and patch the metrics of every run in ENTITY/PROJECT that belongs to GROUP.

    A failure in one run is printed (when VERBOSE) and counted as skipped, so one bad run does
    not abort the batch. A one-line tally is printed at the end.
    """
    api = wandb.Api()
    runs = fetch_runs_with_group(api, ENTITY, PROJECT, GROUP)

    patched = 0
    skipped = 0
    for run in runs:
        try:
            ok = _patch_single_run(run)
        except Exception as e:  # top-level per-run boundary: report and continue
            if VERBOSE:
                print(f"- {run.id} ERROR: {e}")
            ok = False
        if ok:
            patched += 1
        else:
            skipped += 1

    print(f"Done. Candidate runs: {len(runs)} | Patched (or would patch in DRY_RUN): {patched} | Skipped: {skipped} | DRY_RUN={DRY_RUN}")

# Start the patch when this is executed as a script. In a Jupyter/IPython kernel, __name__ is
# also "__main__", so executing this cell starts the patch immediately (honouring DRY_RUN).
if __name__ == "__main__":
    main()

[fetch] Found 96 via config.group == 'ABPC'
- 16iodovw N=10 step_ABPC=10 | ABPC=0.00199486 | AOPC=0.00220119 | norm_ABPC=0.67137021 | norm_AOPC=0.76069163
- xs5huj3i N=10 step_ABPC=10 | ABPC=-0.00004985 | AOPC=-0.00012146 | norm_ABPC=-0.20336164 | norm_AOPC=-0.49549179
- ljaplvnu N=10 step_ABPC=10 | ABPC=0.11041753 | AOPC=0.13775711 | norm_ABPC=0.60037063 | norm_AOPC=0.74902345
- 7tgx01wa N=10 step_ABPC=10 | ABPC=0.00205025 | AOPC=0.00232290 | norm_ABPC=0.66252818 | norm_AOPC=0.76335680
- mai6yw32 N=10 step_ABPC=10 | ABPC=0.00079952 | AOPC=0.00062514 | norm_ABPC=0.68075597 | norm_AOPC=0.77197001
- 62hhyfhu N=10 step_ABPC=10 | ABPC=0.00016520 | AOPC=-0.00010564 | norm_ABPC=0.33646310 | norm_AOPC=-0.41138230
- 990p57nj N=10 step_ABPC=10 | ABPC=0.11649967 | AOPC=0.24538905 | norm_ABPC=0.35908484 | norm_AOPC=0.75635829
- n5lg3g0d N=10 step_ABPC=10 | ABPC=0.00063590 | AOPC=0.00073101 | norm_ABPC=0.56716979 | norm_AOPC=0.75956969
- sd3uk39v N=10 step_ABPC=10 | ABPC=0.00088221 | AOPC=0.000907