In [None]:
# =========================================
# Updated Mutation Modules
# =========================================

def mutate_test_set(
    X_test: np.ndarray,
    partitions=None,
    mutation_rate: float = 0.10,
    seed: int = 0,
    a4_noise_std: float = 0.02,
    ensure_all_types: bool = True,
    *,
    data_type: str = "tabular",     # {"tabular","spatial","temporal"}
    k_ctx: int = 5,                 # used for spatial A3 (kNN)
    n_clusters: int | None = None,  # used for tabular A3 (KMeans); default ~sqrt(N)
    window: int | None = None,      # used for temporal A3; default ~10% of N
    extreme_mult: float = 0.75      # how far beyond observed range to push A1/A2
) -> tuple[np.ndarray, dict[int, object]]:
    """
    Inject **synthetic anomalies** into X_test for mutation-based evaluation.

    We create a mutated copy `X_mut` and a metadata map `metas` that records the ground-truth
    *explanation target* for each mutated row:

        • **A1 (feature-level corruption)**: pick one feature f and push its value **beyond** the
          observed range (to hi + m·span or lo − m·span).  Meta → (f, f).
        • **A2 (cross-feature conflict)**: pick two features f1≠f2 and push them to **opposite**
          extremes beyond the range to maximize inconsistency.  Meta → (f1, f2).
        • **A3 (contextual inconsistency)**:
              – data_type="tabular": KMeans clusters; replace row i with a donor from the
                **farthest cluster** (cluster whose centroid is farthest from i), maximizing
                context shift.
              – data_type="spatial": kNN; replace with the **farthest** point (top 5% distances),
                not in the nearest-neighbor set.
              – data_type="temporal": windowing; replace with a row **window** steps away
                (default ≈10% of N), i.e., donor = (i + window) mod N.
            Meta → "context".
        • **A4 (benign noise)**: add small Gaussian noise per feature (no anomaly expected).
          Meta → None.

    The function balances A1–A4 across the selected rows (when feasible) and never samples more
    rows than available.  For small N/D we fall back gracefully (e.g., A2→A1 if D<2).

    Args
    ----
    X_test : (N, D) ndarray
        Test matrix to mutate (rows are records).
    partitions : ignored (kept for API compatibility).
    mutation_rate : float
        Fraction of rows to mutate (capped by N). If `ensure_all_types` and N≥4, we ensure each
        type appears at least once.
    seed : int
        RNG seed.
    a4_noise_std : float
        Std multiplier for A4 benign noise (small so A4 remains mostly non-anomalous).
    ensure_all_types : bool
        If True and N≥4, distribute mutations across A1..A4.
    data_type : {"tabular","spatial","temporal"}
        Strategy for **A3** (context) mutations.
    k_ctx : int
        #neighbors for spatial kNN.
    n_clusters : int or None
        #clusters for KMeans in tabular A3. Default ≈ sqrt(N).
    window : int or None
        Temporal offset for A3 windowing. Default ≈ 10% of N (at least 5).
    extreme_mult : float
        How far beyond the observed [lo, hi] we push A1/A2. New value = hi + m·(hi−lo) or
        lo − m·(hi−lo). Use 0.5–1.0 to make anomalies stronger.

    Returns
    -------
    X_mut : (N, D) ndarray
        Mutated copy of X_test.
    metas : dict[int, object]
        Mapping row index → expected explanation label:
          A1→(f,f), A2→(f1,f2), A3→"context", A4→None.
    """
    rng = np.random.default_rng(seed)
    X_mut = np.array(X_test, copy=True)
    N, D  = X_mut.shape

    if N == 0:
        return X_mut, {}

    # ----- how many rows to mutate (cap by N) -----
    requested = int(round(mutation_rate * N))
    if ensure_all_types and N >= 4:
        requested = max(requested, 4)     # ensure each type appears
    requested = max(1, min(requested, N)) # never exceed N

    # sample indices WITHOUT replacement
    sel = rng.choice(N, size=requested, replace=False)

    # ----- split selected rows across A1..A4 (balanced, with remainder) -----
    if ensure_all_types and N >= 4:
        base, rem = divmod(requested, 4)
        counts = [base + (i < rem) for i in range(4)]
        type_order = (["A1"] * counts[0] +
                      ["A2"] * counts[1] +
                      ["A3"] * counts[2] +
                      ["A4"] * counts[3])
    else:
        all_types = ["A1", "A2", "A3", "A4"]
        type_order = [all_types[i % 4] for i in range(requested)]

    # ---- helpers ----
    def _lo_hi_span(col: np.ndarray) -> tuple[float, float, float]:
        lo, hi = np.nanpercentile(col, [0.1, 99.9])  # tighter tails → stronger extremes
        if not np.isfinite(lo) or not np.isfinite(hi):
            lo, hi = float(np.nanmin(col)), float(np.nanmax(col))
        if lo == hi:
            hi = lo + 1e-6
        return lo, hi, (hi - lo)

    metas: dict[int, object] = {}

    # ---------- Precompute structures for A3 ----------
    mode = (data_type or "tabular").lower()
    if mode == "static":
        mode = "tabular"

    # TABULAR (KMeans across rows → far cluster donor)
    if mode == "tabular" and N >= 2:
        try:
            from sklearn.cluster import KMeans
            C = n_clusters if n_clusters is not None else int(max(2, min(N, round(np.sqrt(N)))))
            try:
                km = KMeans(n_clusters=C, random_state=seed, n_init="auto")
            except TypeError:
                km = KMeans(n_clusters=C, random_state=seed, n_init=10)
            labels = km.fit_predict(X_test)
            centroids = km.cluster_centers_   # (C,D)
            # members per cluster
            from collections import defaultdict
            cluster_members = defaultdict(list)
            for i, lab in enumerate(labels):
                cluster_members[int(lab)].append(i)
        except Exception:
            labels, centroids, cluster_members = None, None, None

    # SPATIAL (distance matrix / farthest donor)
    if mode == "spatial" and N >= 2:
        # full pairwise Euclidean distances (ok for typical N; for very large N replace with ANN)
        diffs = X_test[:, None, :] - X_test[None, :, :]
        distM = np.sqrt(np.sum(diffs * diffs, axis=2))  # (N,N)

    # TEMPORAL (window offset)
    if mode == "temporal":
        if window is None:
            window = max(5, int(round(0.10 * N)))  # ~10% of N
        window = int(max(1, min(window, max(1, N-1))))

    # ---------- apply mutations ----------
    for idx, mtype in zip(sel, type_order):
        idx = int(idx)

        if mtype == "A1":
            # push ONE feature beyond observed range by extreme_mult * span
            f = int(rng.integers(0, max(1, D)))
            lo, hi, span = _lo_hi_span(X_test[:, f])
            med = np.nanmedian(X_test[:, f])
            if X_test[idx, f] < med:
                new_val = hi + extreme_mult * span
            else:
                new_val = lo - extreme_mult * span
            jitter = float(rng.normal(0.0, 0.02 * span))
            X_mut[idx, f] = float(new_val + jitter)
            metas[idx] = (f, f)

        elif mtype == "A2":
            # pick two features; prefer different partitions if provided
            if D >= 2:
                if partitions and len(partitions) >= 2 and sum(len(p) for p in partitions) == D:
                    # sample two distinct partitions then a feature from each
                    p1, p2 = rng.choice(len(partitions), size=2, replace=False)
                    f1 = int(rng.choice(partitions[p1]))
                    f2 = int(rng.choice(partitions[p2]))
                else:
                    f1, f2 = rng.choice(D, size=2, replace=False)
                lo1, hi1, s1 = _lo_hi_span(X_test[:, f1])
                lo2, hi2, s2 = _lo_hi_span(X_test[:, f2])
                # push to opposite sides, beyond range
                X_mut[idx, f1] = hi1 + extreme_mult * s1
                X_mut[idx, f2] = lo2 - extreme_mult * s2
                metas[idx] = (int(f1), int(f2))
            else:
                # fallback to A1
                f = 0
                lo, hi, span = _lo_hi_span(X_test[:, f])
                X_mut[idx, f] = hi + extreme_mult * span
                metas[idx] = (f, f)

        elif mtype == "A3":
            if N < 2:
                # fallback to A4
                noise = rng.normal(0.0, a4_noise_std, size=D).astype(float)
                X_mut[idx] = X_mut[idx] + noise
                metas[idx] = None
                continue

            if mode == "tabular" and labels is not None and centroids is not None:
                lab_i = int(labels[idx])
                # pick the farthest centroid from x_i
                d2c = np.linalg.norm(centroids - X_test[idx], axis=1)
                far_lab = int(np.argmax(d2c))
                # choose donor from far_lab with largest distance to x_i
                cand = cluster_members[far_lab]
                if len(cand) == 0:
                    cand = [j for j in range(N) if j != idx]
                dists = np.linalg.norm(X_test[cand] - X_test[idx], axis=1)
                donor = int(cand[int(np.argmax(dists))])

            elif mode == "spatial":
                # farthest (top 5%) non-neighbor
                drow = distM[idx].copy()
                drow[idx] = -np.inf
                # pick from top-5% farthest
                k = max(1, int(round(0.05 * N)))
                farset = np.argsort(drow)[-k:]
                donor = int(rng.choice(farset))

            elif mode == "temporal":
                donor = int((idx + window) % N)

            else:
                # generic fallback: farthest by Euclidean distance
                drow = np.linalg.norm(X_test - X_test[idx], axis=1)
                drow[idx] = -np.inf
                donor = int(np.argmax(drow))

            X_mut[idx] = X_test[donor]
            metas[idx] = "context"

        elif mtype == "A4":
            # small benign noise (kept small so A4 remains mostly non-anomalous)
            noise = rng.normal(0.0, a4_noise_std, size=D).astype(float)
            X_mut[idx] = X_mut[idx] + noise
            metas[idx] = None

        else:
            # safety fallback → benign noise
            noise = rng.normal(0.0, a4_noise_std, size=D).astype(float)
            X_mut[idx] = X_mut[idx] + noise
            metas[idx] = None

    return X_mut, metas



# ==================== Evaluate Explanations on Mutations ====================
def evaluate_explanations_on_mutations(
    explanations: Dict[int, Dict],
    metas: Dict[int, Optional[Tuple[int,int]]],
    partitions: List[List[int]],
    topM: int = 3,                 # how many top disagreement pairs to consider
    use_violated_deps: bool = True,# also accept any pair in ex["violated_dependencies"]
    y_true: Optional[np.ndarray] = None,  # optional: if provided, p90 is computed on clean rows
    ref_max_dis_clean: Optional[np.ndarray] = None,
) -> Dict[str, float]:
    """
    Mutation-based explanation evaluation (A1–A4).

    Each entry of `metas` (row index → expected label) is scored against the explanation
    stored under the same index in `explanations`; rows without an explanation are skipped.
    Feature indices in `metas` are mapped to partition indices through
    `_feature_to_partition_map(partitions)`.

    • A1 (feature corruption; both expected features map to the same partition k): success
      if ANY of the top-M disagreement pairs includes k, OR (when `use_violated_deps`) any
      violated dependency involves k.
    • A2 (cross-feature inconsistency): success if the unordered expected pair {k1,k2}
      is in the top-M disagreement pairs OR (when `use_violated_deps`) in the violated
      dependencies.
    • A3 (expected == "context"): success if ex['context_flag'] is truthy, or else if
      ex['context_similarity'] is present and < 0.9.
    • A4 (expected is None): success if there are no violated dependencies AND
      ex['max_disagreement'] <= p90.

    The A4 threshold p90 is the 0.90 quantile of a reference sample of max-disagreement
    values, chosen in this order:
      1. `ref_max_dis_clean`, when given with at least 20 values;
      2. explanations whose `y_true[i] == 0`, when `y_true` is given;
      3. explanations with no violated dependencies and no context flag.
    For (2) and (3) the constant 0.2 is used when fewer than 20 reference values exist.

    Explanation keys read: "disagreement_matrix", "violated_dependencies", "context_flag",
    "context_similarity", "max_disagreement".

    Returns a dict with the accuracy for "A1", "A2", "A3", "A4" and "overall"; a type with
    no scored rows gets NaN.
    """
    import numpy as np

    min_ref_size = 20      # minimum reference sample size for a data-driven p90
    default_p90 = 0.2      # threshold used when the reference sample is too small

    def _top_pairs_from_D(D, M: int = 3) -> set:
        """Return set of up to M pairs (k,l), k<l, with the largest non-NaN D[k,l]."""
        if D is None or M <= 0:
            return set()
        mat = np.asarray(D, dtype=float)
        if mat.ndim != 2:
            return set()
        rows, cols = np.triu_indices(mat.shape[0], k=1, m=mat.shape[1])
        vals = mat[rows, cols]
        # NaN sorts last in argsort, so after reversing it would rank as the *largest*
        # disagreement; drop undefined entries before ranking.
        valid = ~np.isnan(vals)
        if not valid.any():
            return set()
        rows, cols, vals = rows[valid], cols[valid], vals[valid]
        order = np.argsort(vals, kind="stable")[::-1][:M]  # descending by disagreement
        return {(int(rows[i]), int(cols[i])) for i in order}

    f2p = _feature_to_partition_map(partitions)
    total = {"A1":0, "A2":0, "A3":0, "A4":0}
    correct = {"A1":0, "A2":0, "A3":0, "A4":0}

    # --- Reference p90 for A4, prefer truly clean distribution ---
    if ref_max_dis_clean is not None and len(ref_max_dis_clean) >= min_ref_size:
        p90 = float(np.quantile(ref_max_dis_clean, 0.90))
    else:
        if y_true is not None:
            ref_max_dis = [
                explanations[i]["max_disagreement"]
                for i in explanations
                if y_true[i] == 0
            ]
        else:
            # approximate clean as those with no dep violations & no context flag
            ref_max_dis = [
                ex["max_disagreement"]
                for ex in explanations.values()
                if not ex.get("violated_dependencies", []) and not ex.get("context_flag", False)
            ]
        if len(ref_max_dis) >= min_ref_size:
            p90 = float(np.quantile(ref_max_dis, 0.90))
        else:
            p90 = default_p90

    for idx, expected in metas.items():
        ex = explanations.get(idx)
        if ex is None:
            continue

        if expected == "context":
            total["A3"] += 1
            ok = bool(ex.get("context_flag", False))
            if not ok and ex.get("context_similarity") is not None:
                ok = float(ex["context_similarity"]) < 0.9
            correct["A3"] += int(ok)

        elif expected is None:
            total["A4"] += 1
            ok = (len(ex.get("violated_dependencies", [])) == 0) and (ex["max_disagreement"] <= p90)
            correct["A4"] += int(ok)

        else:
            # expected is (f1, f2)
            f1, f2 = expected
            k1, k2 = f2p[int(f1)], f2p[int(f2)]
            exp_pair = tuple(sorted((k1, k2)))

            # Candidate pairs are only needed for A1/A2: top-M by D plus (optionally)
            # violated deps.
            cand_pairs = _top_pairs_from_D(ex.get("disagreement_matrix", None), M=topM)
            if use_violated_deps:
                cand_pairs |= {tuple(sorted(p)) for p in ex.get("violated_dependencies", [])}

            if k1 == k2:
                # A1: success if any candidate pair mentions k1
                total["A1"] += 1
                ok = any((k1 == a or k1 == b) for (a, b) in cand_pairs)
                correct["A1"] += int(ok)
            else:
                # A2: success if expected pair is among candidates
                total["A2"] += 1
                ok = (exp_pair in cand_pairs)
                correct["A2"] += int(ok)

    # Accuracies
    accs = {}
    for t in ["A1","A2","A3","A4"]:
        accs[t] = (correct[t] / total[t]) if total[t] > 0 else np.nan
    tot = sum(total.values())
    cor = sum(correct.values())
    accs["overall"] = (cor / tot) if tot > 0 else np.nan
    return accs