In [3]:
def compute_area_between_curves(morf_curve, lerf_curve):
    """
    Trapezoidal area between the LeRF and MoRF curves (LeRF minus MoRF).

    Consecutive samples are treated as one unit apart on the x-axis. The summed
    area is then divided by the number of points in the curves, which is the
    normalization used throughout this notebook.

    Raises AssertionError when the two curves do not have the same length.
    """
    n_points = len(morf_curve)
    assert n_points == len(lerf_curve), "Curves must be of the same length"

    def gap(idx):
        # Signed vertical distance LeRF - MoRF at a single sample.
        return lerf_curve[idx] - morf_curve[idx]

    # One trapezoid per pair of neighbouring samples.
    total = 0.0
    for idx in range(1, n_points):
        total += 0.5 * (gap(idx) + gap(idx - 1))
    return total / n_points


def compute_aopc(morf_curve):
    # Compute the Area Over the Perturbation Curve (AOPC)
    reference = morf_curve[0]

    area = 0.0
    for i in range(1, len(morf_curve)):
        area += 0.5 * ((reference - morf_curve[i]) + (reference - morf_curve[i-1]))
    return area / len(morf_curve)

def normalized_abpc(morf_curve, lerf_curve):
    """
    ABPC divided by the value span covered by the two curves.

    The span is max(lerf_curve) - min(morf_curve). There is no guard against a
    zero span; the division is performed as-is.
    """
    area = compute_area_between_curves(morf_curve, lerf_curve)
    span = max(lerf_curve) - min(morf_curve)
    return area / span

def normalized_aopc(morf_curve):
    """
    AOPC divided by the value span of the MoRF curve.

    The span is max(morf_curve) - min(morf_curve). There is no guard against a
    zero span (e.g. a flat curve); the division is performed as-is.
    """
    area = compute_aopc(morf_curve)
    span = max(morf_curve) - min(morf_curve)
    return area / span


In [10]:
# patch_switch_signed_curves.py
# Append swapped/signed curves as `signed_MoRF` / `signed_LeRF` at the same supervoxels_perturbed x,
# and update summary metrics accordingly. Verbose by default.
# NOTE: DRY_RUN is currently False, so running this cell WRITES to W&B; set DRY_RUN = True to preview only.

import os
from typing import Dict, List, Optional, Tuple
import numpy as np
import wandb

# ===== USER CONFIG =====
# W&B location: runs are queried from "<ENTITY>/<PROJECT>" and filtered on config.group == GROUP.
ENTITY = "giuliosichili"
PROJECT = "automi"
GROUP   = "ABPC-volumes"

# Run selection: matched against config.supervoxel_type, config.aggregation_function
# and config.volume_code respectively.
TARGET_SV_TYPE = "FCC-organs"
TARGET_AGG     = "false_positive_aggregation"
TARGET_VOLUME_CODES = ["00044"]

X_KEY     = "supervoxels_perturbed"   # custom x-axis used during logging
PAGE_SIZE = 512  # page_size passed to run.scan_history()

DRY_RUN   = False   # False => main() WRITES to W&B. Set to True to only print (no writes).
VERBOSE   = True   # << PRINT per-run details >>
# New: only recompute scalar metrics, do NOT append new signed curves
RECOMPUTE_METRICS_ONLY = True
# =======================

# Metric helpers used below. They are defined in the metrics cell earlier in this
# notebook, so that cell must be run first:
#   compute_area_between_curves(morf_curve, lerf_curve)
#   compute_aopc(morf_curve)
#   normalized_abpc(morf_curve, lerf_curve)
#   normalized_aopc(morf_curve)

def _fmt(x: Optional[float], fixed4: bool = False) -> str:
    if x is None:
        return "None"
    try:
        xf = float(x)
        if np.isnan(xf):
            return "nan"
        return f"{xf:.4f}" if fixed4 else f"{xf:.8f}"
    except Exception:
        return str(x)

def fetch_target_runs(api: wandb.Api) -> List[wandb.apis.public.Run]:
    """
    Return the runs in ENTITY/PROJECT selected by the user config.

    A run matches when its config has group == GROUP,
    supervoxel_type == TARGET_SV_TYPE, aggregation_function == TARGET_AGG and
    a volume_code contained in TARGET_VOLUME_CODES.
    """
    conditions = (
        ("config.group", "$eq", GROUP),
        ("config.supervoxel_type", "$eq", TARGET_SV_TYPE),
        ("config.aggregation_function", "$eq", TARGET_AGG),
        ("config.volume_code", "$in", TARGET_VOLUME_CODES),
    )
    query = {field: {operator: wanted} for field, operator, wanted in conditions}
    matching = api.runs(f"{ENTITY}/{PROJECT}", filters=query)
    return list(matching)

def collect_curves_with_x_and_last_step(run: wandb.apis.public.Run) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """
    Read the logged MoRF/LeRF values of a run and align them on X_KEY.

    History rows without a usable integer step are ignored. The x position of a
    row is its X_KEY value, or the step when the row has no X_KEY entry; rows
    whose x cannot be turned into an int are ignored. When the same x occurs
    more than once, the value seen last wins.

    Returns (morf, lerf, xs, last_step_seen): the two curves restricted to the
    x positions present in both, the sorted x positions, and the largest step
    encountered (-1 if none). The three arrays are empty when the curves share
    no x position.
    """
    morf_by_x: Dict[int, float] = {}
    lerf_by_x: Dict[int, float] = {}
    max_step = -1

    for record in run.scan_history(page_size=PAGE_SIZE):
        raw_step = record.get("_step", record.get("step"))
        if raw_step is None:
            continue
        try:
            step_no = int(raw_step)
        except Exception:
            continue
        max_step = max(max_step, step_no)

        # x-axis: prefer X_KEY, fall back to the step when the key is absent.
        raw_x = record.get(X_KEY, step_no)
        try:
            x_pos = int(raw_x)
        except Exception:
            try:
                x_pos = int(float(raw_x))
            except Exception:
                continue

        # Store whichever of the two curves this row carries.
        for metric_key, bucket in (("MoRF", morf_by_x), ("LeRF", lerf_by_x)):
            raw_value = record.get(metric_key, None)
            if raw_value is None:
                continue
            try:
                bucket[x_pos] = float(raw_value)
            except Exception:
                pass

    shared_x = sorted(morf_by_x.keys() & lerf_by_x.keys())
    if not shared_x:
        return np.array([]), np.array([]), np.array([]), max_step

    morf = np.array([morf_by_x[pos] for pos in shared_x], dtype=float)
    lerf = np.array([lerf_by_x[pos] for pos in shared_x], dtype=float)
    xs = np.array(shared_x, dtype=int)
    return morf, lerf, xs, max_step

def collect_signed_curves_if_present(run: wandb.apis.public.Run) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read previously logged `signed_MoRF` / `signed_LeRF` values aligned on X_KEY.

    The x position of a row is its X_KEY value, falling back to `_step` and then
    `step` when the key is absent; rows whose x cannot be turned into an int are
    ignored. When the same x occurs more than once, the value seen last wins.

    Returns (signed_morf, signed_lerf, xs) restricted to the x positions present
    in both curves, or three empty arrays when there is no such position.
    """
    def _as_int(value):
        # Direct int() first, then via float(); None signals "not convertible".
        try:
            return int(value)
        except Exception:
            pass
        try:
            return int(float(value))
        except Exception:
            return None

    signed_morf_by_x: Dict[int, float] = {}
    signed_lerf_by_x: Dict[int, float] = {}

    for record in run.scan_history(page_size=PAGE_SIZE):
        fallback_step = record.get("_step", record.get("step"))
        x_pos = _as_int(record.get(X_KEY, fallback_step))
        if x_pos is None:
            continue

        for metric_key, bucket in (
            ("signed_MoRF", signed_morf_by_x),
            ("signed_LeRF", signed_lerf_by_x),
        ):
            raw_value = record.get(metric_key)
            if raw_value is None:
                continue
            try:
                bucket[x_pos] = float(raw_value)
            except Exception:
                pass

    shared_x = sorted(signed_morf_by_x.keys() & signed_lerf_by_x.keys())
    if not shared_x:
        return np.array([]), np.array([]), np.array([])

    morf = np.array([signed_morf_by_x[pos] for pos in shared_x], dtype=float)
    lerf = np.array([signed_lerf_by_x[pos] for pos in shared_x], dtype=float)
    xs = np.array(shared_x, dtype=int)
    return morf, lerf, xs

def recompute_from_corrected(morf_fixed: np.ndarray, lerf_fixed: np.ndarray) -> Dict[str, float]:
    """Compute corrected metrics from corrected curves using the notebook's metric functions.

    Args:
        morf_fixed: corrected MoRF curve.
        lerf_fixed: corrected LeRF curve (same length as `morf_fixed`).

    Returns:
        Dict with keys "ABPC_area", "AOPC", "norm_ABPC" and "norm_AOPC". The two
        normalized metrics are NaN when their normalization denominator is zero.

    Raises:
        Whatever `compute_area_between_curves` / `compute_aopc` raise; the two
        un-normalized metrics are not guarded.
    """
    nan = float("nan")

    def _guarded(metric_fn, *curves) -> float:
        # With plain Python floats a zero denominator raises ZeroDivisionError.
        # With NumPy scalars (which is what array inputs produce) the division
        # instead returns inf/nan and only emits a RuntimeWarning, so catching
        # the exception alone is not enough: map every non-finite result to NaN.
        try:
            with np.errstate(divide="ignore", invalid="ignore"):
                value = float(metric_fn(*curves))
        except ZeroDivisionError:
            return nan
        return value if np.isfinite(value) else nan

    return {
        "ABPC_area": float(compute_area_between_curves(morf_fixed, lerf_fixed)),
        "AOPC": float(compute_aopc(morf_fixed)),
        "norm_ABPC": _guarded(normalized_abpc, morf_fixed, lerf_fixed),
        "norm_AOPC": _guarded(normalized_aopc, morf_fixed),
    }

def old_summary_metrics(run: wandb.apis.public.Run) -> Dict[str, Optional[float]]:
    """
    Snapshot the four metrics currently stored in the run summary.

    Returns a dict with keys "ABPC_area", "AOPC", "norm_ABPC" and "norm_AOPC".
    A value is None when the summary has no entry for it or the entry cannot be
    converted to float.
    """
    summary = run.summary or {}

    def as_float_or_none(key):
        raw = summary.get(key)
        if raw is None:
            return None
        try:
            return float(raw)
        except Exception:
            return None

    metric_names = ("ABPC_area", "AOPC", "norm_ABPC", "norm_AOPC")
    return {name: as_float_or_none(name) for name in metric_names}

def main():
    """
    Recompute (and, unless DRY_RUN, write back) the curve metrics of the target runs.

    For every run returned by `fetch_target_runs`:
      1. Pick the corrected curves. In metrics-only mode, already logged
         `signed_MoRF` / `signed_LeRF` curves are used when both have more than
         one point. Otherwise the raw MoRF/LeRF curves are read and swapped.
      2. Recompute the metrics and, if VERBOSE, print old vs. new values.
      3. Unless DRY_RUN, resume the run, optionally append the signed curves
         (only when RECOMPUTE_METRICS_ONLY is False), and update the summary,
         keeping the old values under `*_prev` keys.

    Any exception raised while handling a run is reported (if VERBOSE) and the
    run is counted as skipped. The resumed W&B session is always finished and
    the WANDB_RESUME / WANDB_RUN_ID environment variables are restored to their
    previous state, so a failure cannot leave a run open or make a later
    `wandb.init()` in this kernel resume the last patched run.
    """
    api = wandb.Api()
    runs = fetch_target_runs(api)
    print(f"[target] group='{GROUP}', sv_type='{TARGET_SV_TYPE}', agg='{TARGET_AGG}': {len(runs)} runs")

    patched, skipped = 0, 0
    for run in runs:
        try:
            xs = None  # ensure defined
            last_step = -1
            # Prefer already signed curves if present when metrics-only
            if RECOMPUTE_METRICS_ONLY:
                signed_morf, signed_lerf, signed_x = collect_signed_curves_if_present(run)
            else:
                signed_morf, signed_lerf, signed_x = np.array([]), np.array([]), np.array([])

            if RECOMPUTE_METRICS_ONLY and signed_morf.size > 1 and signed_lerf.size > 1:
                source = "signed_existing"
                morf_fixed, lerf_fixed = signed_morf, signed_lerf
                xs = signed_x
            else:
                morf, lerf, xs, last_step = collect_curves_with_x_and_last_step(run)
                if morf.size == 0 or lerf.size == 0:
                    if VERBOSE:
                        print(f"- {run.id} SKIP: no aligned MoRF/LeRF on '{X_KEY}'")
                    skipped += 1
                    continue
                # swap to get corrected
                morf_fixed, lerf_fixed = lerf.copy(), morf.copy()
                source = "swapped_from_raw"

            new_metrics = recompute_from_corrected(morf_fixed, lerf_fixed)
            old_metrics = old_summary_metrics(run)

            if VERBOSE:
                print(f"\nRun {run.id} | points={len(morf_fixed)} | source={source}")
                print("  ABPC_area : old =", _fmt(old_metrics["ABPC_area"], fixed4=False), " -> new =", _fmt(new_metrics["ABPC_area"], fixed4=False))
                print("  AOPC      : old =", _fmt(old_metrics["AOPC"],      fixed4=False), " -> new =", _fmt(new_metrics["AOPC"],      fixed4=False))
                print("  norm_ABPC : old =", _fmt(old_metrics["norm_ABPC"], fixed4=True),  " -> new =", _fmt(new_metrics["norm_ABPC"], fixed4=True))
                print("  norm_AOPC : old =", _fmt(old_metrics["norm_AOPC"], fixed4=True),  " -> new =", _fmt(new_metrics["norm_AOPC"], fixed4=True))

            if DRY_RUN:
                patched += 1
                continue

            # Remember the caller's W&B environment so it can be put back afterwards.
            wandb_env_keys = ("WANDB_RESUME", "WANDB_RUN_ID")
            saved_env = {key: os.environ.get(key) for key in wandb_env_keys}
            os.environ["WANDB_RESUME"] = "allow"
            os.environ["WANDB_RUN_ID"] = run.id

            session = None
            try:
                session = wandb.init(entity=run.entity, project=run.project, id=run.id, resume="allow")

                if not RECOMPUTE_METRICS_ONLY:
                    # we must have last_step; if not, collect again
                    if last_step < 0:
                        _, _, _, last_step = collect_curves_with_x_and_last_step(run)
                    # Append after the last existing step so history stays monotonic.
                    start_step = int(last_step) + 1
                    for i in range(len(morf_fixed)):
                        payload = {
                            "signed_MoRF": float(morf_fixed[i]),
                            "signed_LeRF": float(lerf_fixed[i]),
                            X_KEY: int(xs[i]) if xs is not None and i < len(xs) else i,
                        }
                        session.log(payload, step=start_step + i)

                # Update summary metrics
                session.summary.update({
                    "ABPC_area_prev": old_metrics["ABPC_area"],
                    "AOPC_prev": old_metrics["AOPC"],
                    "norm_ABPC_prev": old_metrics["norm_ABPC"],
                    "norm_AOPC_prev": old_metrics["norm_AOPC"],
                    **new_metrics,
                    "signed_curves": True,
                    "signed_metrics_refreshed": RECOMPUTE_METRICS_ONLY,
                    "signed_metrics_source": source,
                })
            finally:
                # Always close the resumed run, even when logging/updating failed.
                if session is not None:
                    session.finish()
                # Restore the environment exactly as it was before this run.
                for key, previous in saved_env.items():
                    if previous is None:
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = previous

            patched += 1

        except Exception as e:
            if VERBOSE:
                print(f"- {run.id} ERROR: {e}")
            skipped += 1

    print(f"\nDone. Candidate runs: {len(runs)} | Updated (or would update in DRY_RUN): {patched} | Skipped: {skipped} | DRY_RUN={DRY_RUN} | METRICS_ONLY={RECOMPUTE_METRICS_ONLY} | VERBOSE={VERBOSE}")

# Entry point: call main() when this code executes as `__main__` (the captured
# output below shows that running the notebook cell triggers it).
if __name__ == "__main__":
    main()

[target] group='ABPC-volumes', sv_type='FCC-organs', agg='false_positive_aggregation': 1 runs

Run o01qjx5z | points=400 | source=signed_existing
  ABPC_area : old = -0.00016442  -> new = 0.00016442
  AOPC      : old = 0.00007131  -> new = 0.00023573
  norm_ABPC : old = -1.4364  -> new = 0.5849
  norm_AOPC : old = 0.6230  -> new = 0.8385

Run o01qjx5z | points=400 | source=signed_existing
  ABPC_area : old = -0.00016442  -> new = 0.00016442
  AOPC      : old = 0.00007131  -> new = 0.00023573
  norm_ABPC : old = -1.4364  -> new = 0.5849
  norm_AOPC : old = 0.6230  -> new = 0.8385


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
ABPC_area,0.00016
ABPC_area_prev,-0.00016
AOPC,0.00024
AOPC_prev,7e-05
LeRF,-0.0001
LeRF_cache_hit_ratio_percent,21
LeRF_inference_time_sec,4.324
LeRF_volume_removed_mm3,5206010.77755
LeRF_volume_removed_pct,2.25082
LeRF_volume_removed_voxels,758939



Done. Candidate runs: 1 | Updated (or would update in DRY_RUN): 1 | Skipped: 0 | DRY_RUN=False | METRICS_ONLY=True | VERBOSE=True


In [9]:
# Finish whatever W&B run is still active in this kernel.
# NOTE(review): this cell's execution count (9) precedes the patch cell's (10),
# so it was run out of notebook order — confirm where it belongs.
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
ABPC_area,-0.00016
AOPC,7e-05
LeRF,-0.0001
LeRF_cache_hit_ratio_percent,21.0
LeRF_inference_time_sec,4.324
LeRF_volume_removed_mm3,5206010.77755
LeRF_volume_removed_pct,2.25082
LeRF_volume_removed_voxels,758939.0
MoRF,-0.0001
MoRF_cache_hit_ratio_percent,21.0
