In [None]:
# patch_abpc_area.py
# Recompute ABPC_area for runs in group=="ABPC" and re-log at step=10.

import os
from typing import List
import numpy as np
import pandas as pd
import wandb


# ==== CONFIG ====
# ENTITY and PROJECT are joined as "<entity>/<project>" to form the W&B path
# whose runs are scanned.
ENTITY = "giuliosichili"
PROJECT = "automi"
GROUP   = "ABPC"      # target label (may live in run.group OR config.group OR tags)
# History step at which the old ABPC_area is read and the corrected one is logged.
STEP_FOR_AREA = 10
# Tag appended to each run once its metric has been patched.
TAG_AFTER_FIX = "fixed-metric"
DRY_RUN = True        # True = only print what would happen (no writes)

def fetch_runs_with_group(api: wandb.Api, entity: str, project: str, group: str) -> List[wandb.apis.public.Run]:
    """
    Collect the runs of ``entity/project`` that are labelled with ``group``.

    Lookup order (the first strategy that yields any run wins):
      1. server-side filter on the W&B run attribute ``group``;
      2. server-side filter on the config entry ``config.group``;
      3. download every run and keep those whose ``run.group``,
         ``config["group"]`` or tag list matches ``group``.

    Returns a plain list of run objects (possibly empty).
    """
    project_path = f"{entity}/{project}"

    # Server-side attempts, tried in order: (label used in the log line, filter).
    server_queries = (
        ("group", {"group": group}),
        ("config.group", {"config.group": {"$eq": group}}),
    )
    for label, query in server_queries:
        hits = api.runs(project_path, filters=query)
        if hits:
            print(f"[fetch] Found {len(hits)} runs via server filter {label}=='{group}'")
            return list(hits)

    # Neither server filter matched: pull everything and filter client-side.
    print("[fetch] Server filters returned 0; fetching all runs and filtering locally…")
    every_run = list(api.runs(project_path))

    matched = []
    for candidate in every_run:
        if getattr(candidate, "group", None) == group:
            matched.append(candidate)
        elif (candidate.config or {}).get("group") == group:
            matched.append(candidate)
        elif group in (candidate.tags or []):
            matched.append(candidate)

    print(f"[fetch] Local filter matched {len(matched)} / {len(every_run)} runs")
    return matched

class _SkipRun(Exception):
    """Signal that a run lacks the data needed to recompute its area."""


def _with_step_column(frame):
    """Return *frame* with a ``_step`` column, or ``None`` if no step column exists."""
    if "_step" in frame.columns:
        return frame
    if "step" in frame.columns:
        frame["_step"] = frame["step"]
        return frame
    return None


def _load_curves(run):
    """Return the ``(MoRF, LeRF)`` float arrays logged at steps 0..9 of *run*.

    Raises:
        _SkipRun: if the history is empty, has no step column, or does not
            contain exactly one complete row for each of the steps 0..9.
    """
    df = run.history(keys=["MoRF", "LeRF"], pandas=True)
    if df.empty:
        raise _SkipRun("empty history")

    df = _with_step_column(df)
    if df is None:
        raise _SkipRun("no _step/step column present")

    # Keep steps 0..9 with both values present.
    sub = (
        df.loc[df["_step"].between(0, 9, inclusive="both"), ["_step", "MoRF", "LeRF"]]
        .dropna()
        .sort_values("_step")
    )

    # Sanity: we expect exactly steps [0..9], no gaps and no duplicates.
    expected = np.arange(10)
    got = sub["_step"].to_numpy(dtype=int)
    if sub.shape[0] != 10 or not np.array_equal(got, expected):
        raise _SkipRun(f"need steps 0..9 with no gaps; got {got.tolist()}")

    return sub["MoRF"].to_numpy(dtype=float), sub["LeRF"].to_numpy(dtype=float)


def _read_old_area(run):
    """Return the last non-null ``ABPC_area`` logged at ``STEP_FOR_AREA``, or ``None``."""
    old_hist = run.history(keys=["ABPC_area"], pandas=True)
    if old_hist.empty:
        return None

    old_hist = _with_step_column(old_hist)
    if old_hist is None or "ABPC_area" not in old_hist.columns:
        return None

    vals = old_hist.loc[old_hist["_step"] == STEP_FOR_AREA, "ABPC_area"].dropna()
    if vals.empty:
        return None
    return float(vals.values[-1])


def _relog_area(run, corrected_area, old_at_10):
    """Resume *run*, log the corrected area, and patch its summary, tags and notes.

    The resumed run is always finished, even when logging fails, so that a
    failure cannot leave an active run behind for the next iteration.
    """
    # Explicit id/resume arguments are used instead of writing WANDB_RUN_ID /
    # WANDB_RESUME into os.environ: environment variables would outlive this
    # call and redirect any later wandb.init() in the same process.
    resumed = wandb.init(entity=run.entity, project=run.project, id=run.id, resume="allow")
    try:
        # NOTE(review): W&B requires logged steps to increase; if the resumed
        # run's history already reaches STEP_FOR_AREA this call may be dropped
        # by the SDK, leaving only the summary patched -- confirm in the UI.
        resumed.log({"ABPC_area": corrected_area}, step=STEP_FOR_AREA)

        # --- Patch summary for dashboards/filters + provenance ---
        if old_at_10 is not None:
            resumed.summary["ABPC_area_old"] = old_at_10
        resumed.summary["ABPC_area"] = corrected_area
        resumed.summary["ABPC_area_fixed"] = True

        # Tags/notes are best-effort: report a failure but do not abort the fix.
        try:
            # list(...) because the tags may come back as a tuple, which
            # cannot be concatenated with a list.
            tags = list(resumed.tags or [])
            if TAG_AFTER_FIX not in tags:
                tags.append(TAG_AFTER_FIX)
            resumed.tags = tags
            resumed.notes = (resumed.notes or "") + "\nPatched ABPC_area via compute_area_between_curves()."
            # update() is a public-API Run method; the object returned by
            # wandb.init() may not provide it, so call it only when present.
            push = getattr(resumed, "update", None)
            if callable(push):
                push()
        except Exception as exc:
            print(f"- {run.id}  WARN: could not update tags/notes: {exc}")
    finally:
        resumed.finish()


def main():
    """Recompute ``ABPC_area`` for every run labelled ``GROUP`` and re-log it.

    For each matched run the MoRF/LeRF values at steps 0..9 are read, the area
    is recomputed with ``compute_area_between_curves`` and printed next to the
    value previously logged at ``STEP_FOR_AREA``. Unless ``DRY_RUN`` is true,
    the run is then resumed and the corrected value is written to its history,
    summary, tags and notes. Runs with incomplete data or errors are counted as
    skipped; the loop never aborts on a single failing run.
    """
    api = wandb.Api()
    runs = fetch_runs_with_group(api, ENTITY, PROJECT, GROUP)

    print("\n== Sanity check of first 5 matched runs ==")
    for r in runs[:5]:
        print(f"  - id={r.id} name={r.name} | run.group={getattr(r,'group',None)} "
              f"| config.group={(r.config or {}).get('group')} | tags={r.tags}")

    print(f"\nProceeding on {len(runs)} runs…\n")
    fixed = 0
    skipped = 0

    for run in runs:
        try:
            morf, lerf = _load_curves(run)

            # NOTE(review): compute_area_between_curves is not defined or
            # imported in the visible part of this file; it must already be in
            # scope (e.g. defined in an earlier notebook cell) when main() runs.
            corrected_area = float(compute_area_between_curves(morf, lerf))

            old_at_10 = _read_old_area(run)
            print(f"- {run.id}  corrected_area={corrected_area:.8f}  (old@10={old_at_10})")

            # Dry runs are reported above but counted as neither fixed nor skipped.
            if DRY_RUN:
                continue

            _relog_area(run, corrected_area, old_at_10)
            fixed += 1

        except _SkipRun as why:
            print(f"- {run.id}  SKIP: {why}")
            skipped += 1
        except Exception as e:
            # Per-run boundary: report and move on to the next run.
            print(f"- {run.id}  ERROR: {e}")
            skipped += 1

    print(f"\nDone. Fixed: {fixed}, Skipped: {skipped}, Dry-run: {DRY_RUN}")

# Run the patch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()