# Explainability-Aware Active Learning for Recommender Systems

**Authors and Contact Information:**

Anonymous


# Adding Libraries and Configuration

In [None]:
# === Standard & third-party imports ==========================================
from __future__ import annotations

import io
import os
import glob
import zipfile
import shutil
import logging
import tempfile
from typing import Iterable, Tuple

import requests
import numpy as np
import pandas as pd
import numba
from sklearn.metrics import pairwise_distances
import tqdm


# === Logging setup ============================================================
# Log to both console and a file; concise, timestamped format.
# The file handler writes to "run_logs.log" (relative path, so it is created in
# the current working directory).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[
        logging.FileHandler("run_logs.log"),
        logging.StreamHandler(),
    ],
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# Keep noisy Numba/CUDA internals quiet unless explicitly needed.
# Each listed logger is capped at WARNING and has propagation switched off, so
# its records are not forwarded to the root handlers configured above.
for name in (
    "numba",
    "numba.cuda",
    "numba.cuda.cudadrv",
    "numba.cuda.cudadrv.driver",
    "numba.cuda.cudadrv.memory",
):
    _lg = logging.getLogger(name)
    _lg.setLevel(logging.WARNING)
    _lg.propagate = False


# === CUDA detection ===========================================================
# Try importing Numba's CUDA; fall back to CPU-only if unavailable.
# USE_CUDA is the single switch read by the rest of the file to pick dtypes
# (float32 vs float64) and the GPU vs CPU code path.
# NOTE: if the import itself fails, the name `cuda` is never bound; later code
# checks `'cuda' in globals()` before defining CUDA kernels.
try:
    from numba import cuda
    USE_CUDA = cuda.is_available()
except Exception:
    # Any failure (missing toolkit, driver error, ...) means CPU-only.
    USE_CUDA = False

logger.info(
    "[GPU] CUDA detected — GPU path will be used where available."
    if USE_CUDA else
    "[CPU] No CUDA detected — running CPU-only path."
)

# === Hardware info (robust) ===================================================
# Purely informational: logs one line describing the GPU (when USE_CUDA) or the
# CPU. Every query is best-effort; any failure is downgraded to a warning so
# that it can never stop the run.
try:
    import platform
    np_ver = np.__version__
    nb_ver = numba.__version__

    if USE_CUDA:
        # Device identity
        dev = cuda.get_current_device()
        # NOTE: `name` rebinds the loop variable used in the logging setup above.
        name = getattr(dev, "name", "Unknown GPU")
        cc   = getattr(dev, "compute_capability", (None, None))
        # Driver/runtime versions: each lookup falls back to "unknown" on error.
        try:
            driver_ver  = ".".join(map(str, cuda.runtime.get_driver_version()))
        except Exception:
            driver_ver = "unknown"
        try:
            runtime_ver = ".".join(map(str, cuda.runtime.get_version()))
        except Exception:
            runtime_ver = "unknown"

        # Memory via context (portable across Numba versions)
        try:
            free_b, total_b = cuda.current_context().get_memory_info()
            total_gib = total_b / (1 << 30)  # bytes -> GiB
            free_gib  = free_b  / (1 << 30)
            mem_str   = f"VRAM total={total_gib:.2f} GiB, free={free_gib:.2f} GiB"
        except Exception:
            mem_str = "VRAM: (unavailable)"

        # The format string has two slots for the compute capability ("%s.%s"),
        # so a non-tuple value is padded with an empty second element.
        logger.info(
            "[GPU] Device=%s | CC=%s.%s | %s | Driver=%s | Runtime=%s | NumPy=%s | Numba=%s",
            name, *(cc if isinstance(cc, tuple) else (cc, "")), mem_str, driver_ver, runtime_ver, np_ver, nb_ver
        )
    else:
        # CPU identity (best-effort across OSes): first non-empty value wins.
        cpu_name = (platform.processor()
                    or getattr(platform.uname(), "processor", "")
                    or getattr(platform.uname(), "machine", "")
                    or "unknown CPU")
        logger.info(
            "[CPU] Processor=%s | Cores=%s | OS=%s %s | Python=%s | NumPy=%s | Numba=%s",
            cpu_name, os.cpu_count(),
            platform.system(), platform.release(),
            platform.python_version(), np_ver, nb_ver
        )
except Exception as e:
    logger.warning("Could not query hardware info: %s", e)


# Global Hyperparameters for Explainable Active Learning (ExAL) 

In [None]:
# =========================
# Active Learning settings
# =========================
num_iter   = 10   # number of active-learning iterations
SWITCH     = 5    # EXAL Max→Min switch iteration
                  # NOTE(review): the exact boundary (inclusive/exclusive) is
                  # decided by the comparison in the selection code — confirm there.

# =========================
# Optimization (SGD)
# =========================
ALPHA_INIT    = 0.01   # learning rate for the initial EMF pretraining
ALPHA_RETRAIN = 0.001  # learning rate for per-iteration online updates

# =========================
# Explainability (λ, W)
# =========================
# Decouple training vs selection effects:
# - LAMBDA_TRAIN: applies W in EMF training
# - LAMBDA_SELECT: applies W in EXAL selection (Min/Max/Min–Max)
# This enables the 2×2 study: (train λ ∈ {0, >0}) × (select λ ∈ {0, >0})
LAMBDA_TRAIN  = 0.0  # λ used in EMF training (0 and 0.005 are good candidates for study)
LAMBDA_SELECT = 0.5    # λ used in EXAL selection

# W construction
theta   = 0.0   # threshold on W_uj (0 → keep all)
NEIGHBOR= 20    # number of nearest neighbours (k) used to compute W

# =========================
# Training schedule
# =========================
INIT_STEPS  = 1000  # EMF pretraining epochs before active learning starts
ONLINE_STEP = 5     # SGD steps per user per AL iteration
BETA        = 0.15  # L2 regularization weight (β)

# =========================
# Model size
# =========================
K = 10  # latent dimension

# =========================
# Evaluation
# =========================
TopN = 10  # cutoff for MAP/NDCG/xP/xR

# =========================
# CUDA safety
# =========================
# MAX_K sizes the fixed-length local buffer inside the CUDA kernel, so K must
# not exceed it; fail fast here rather than inside a kernel launch.
MAX_K = 128  # must be ≥ K (CUDA kernel local buffer)
if USE_CUDA and K > MAX_K:
    raise ValueError(f"K={K} exceeds MAX_K={MAX_K}; increase MAX_K or lower K.")


# Datasets

In [None]:
def _download_and_extract(
    url: str,
    dest_dir: str,
    expect_files: Iterable[str] = (),
    timeout: Tuple[float, float] = (15.0, 60.0),
) -> None:
    """
    Download a ZIP archive and extract it under the current working directory.

    The archive is streamed to a temporary file, extracted member by member
    (preserving the archive's own directory structure, rooted at the CWD), and
    the temporary file is always removed afterwards — including when the
    download itself fails part-way.

    Parameters
    ----------
    url : str
        Location of the ZIP archive.
    dest_dir : str
        Directory that is expected to hold the extracted files. It is created
        (if missing) only after a successful extraction, so a failed download
        does not leave an empty directory behind that looks like a finished one.
    expect_files : Iterable[str]
        Paths relative to `dest_dir` that must exist after extraction.
    timeout : (connect, read)
        Timeout pair handed to `requests.get`.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    RuntimeError
        If the archive is corrupt, contains a member that would be written
        outside the CWD (Zip-Slip), or an expected file is missing afterwards.
    """
    tmp_path = None
    try:
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
                # Record the path first so the `finally` below can clean up
                # even if streaming raises half-way through.
                tmp_path = tmp.name
                for chunk in resp.iter_content(chunk_size=1 << 20):  # 1 MiB
                    if chunk:
                        tmp.write(chunk)

        try:
            with zipfile.ZipFile(tmp_path) as zf:
                # Safe extraction rooted at CWD, with Zip-Slip protection.
                cwd_abs = os.path.abspath(".")
                for member in zf.infolist():
                    out_path = os.path.abspath(os.path.join(cwd_abs, member.filename))
                    # commonpath handles the "CWD is the filesystem root" case
                    # that a plain prefix + separator test gets wrong.
                    try:
                        is_inside = os.path.commonpath([cwd_abs, out_path]) == cwd_abs
                    except ValueError:  # e.g. different drives on Windows
                        is_inside = False
                    if not is_inside:
                        raise RuntimeError(f"Unsafe zip path: {member.filename}")
                    if member.is_dir():
                        os.makedirs(out_path, exist_ok=True)
                    else:
                        os.makedirs(os.path.dirname(out_path), exist_ok=True)
                        with zf.open(member) as src, open(out_path, "wb") as dst:
                            shutil.copyfileobj(src, dst)
        except zipfile.BadZipFile as e:
            raise RuntimeError(f"Corrupt zip from {url}") from e
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup of the temp archive

    # Only now guarantee dest_dir exists (the archive normally creates it).
    os.makedirs(dest_dir, exist_ok=True)

    # Verify the expected files are present under dest_dir
    for relpath in expect_files:
        full = os.path.join(dest_dir, relpath)
        if not os.path.exists(full):
            raise RuntimeError(f"Expected file not found after extract: {full}")



def load_movielens(dataset: str = "100k") -> Tuple[pd.DataFrame, pd.DataFrame | None]:
    """
    Load MovieLens (100k or 1M) into a dense user×item ratings DataFrame (zeros = unrated).

    The archive is downloaded when any of the files this function reads is
    missing — not merely when the dataset directory is missing — so a
    directory left behind by an interrupted download is repaired on the next call.

    Parameters
    ----------
    dataset : {"100k", "1m"}
        Which MovieLens release to load.

    Returns
    -------
    data_M : DataFrame [n_users × n_items]
        Ratings matrix with unrated cells filled with 0. dtype=float32 when
        USE_CUDA is set, float64 otherwise. Rows and columns are renumbered to
        0..U-1 / 0..I-1 in ascending order of the original IDs.
    movies : DataFrame | None
        For 100k: (movieID, movie name, genre) with one comma-joined genre
        string per row. For 1M: None.

    Raises
    ------
    ValueError
        If `dataset` is neither "100k" nor "1m".
    """
    DT = np.float32 if USE_CUDA else np.float64
    rating_names = ["UserID", "movieID", "Rating", "Timestamp"]
    rating_dtypes = {"UserID": np.int32, "movieID": np.int32, "Rating": DT, "Timestamp": np.int64}

    def _fetch_if_missing(url, dest_dir, files):
        """Download + extract unless every required file is already on disk."""
        if all(os.path.exists(os.path.join(dest_dir, f)) for f in files):
            return
        _download_and_extract(url=url, dest_dir=dest_dir, expect_files=files, timeout=(15, 60))

    def _to_matrix(ratings):
        """Pivot long-format ratings into a dense user×item matrix (0 = unrated)."""
        return (
            ratings.pivot(index="UserID", columns="movieID", values="Rating")
                   .fillna(0.0)
                   .astype(DT, copy=False)
        )

    if dataset == "100k":
        _fetch_if_missing(
            "https://files.grouplens.org/datasets/movielens/ml-100k.zip",
            "ml-100k",
            ("u.data", "u.item"),
        )

        # Ratings (tab-separated): UserID, movieID, Rating, Timestamp
        data = pd.read_csv(
            "ml-100k/u.data",
            sep="\t",
            header=None,
            names=rating_names,
            dtype=rating_dtypes,
        )
        data_M = _to_matrix(data)

        # Item metadata: build a concise genre string per movie
        genre_cols = [
            "unknown", "Action", "Adventure", "Animation", "Children", "Comedy",
            "Crime", "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror",
            "Musical", "Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western",
        ]
        movies = pd.read_csv(
            "ml-100k/u.item",
            sep="|",
            header=None,
            encoding="ISO-8859-1",
            usecols=[0, 1, *range(5, 24)],
            names=["movieID", "movie name"] + genre_cols,
        )
        # 0/1 flags · "Genre," strings → concatenation of the flagged genres.
        movies["genre"] = movies[genre_cols].dot(pd.Index(genre_cols) + ",").str.rstrip(",")
        movies = movies[["movieID", "movie name", "genre"]]

    elif dataset == "1m":
        _fetch_if_missing(
            "https://files.grouplens.org/datasets/movielens/ml-1m.zip",
            "ml-1m",
            ("ratings.dat",),
        )

        # Ratings ('::'-separated): UserID::movieID::Rating::Timestamp
        data = pd.read_csv(
            "ml-1m/ratings.dat",
            sep="::",
            engine="python",   # needed for '::' separator
            header=None,
            names=rating_names,
            dtype=rating_dtypes,
        )
        data_M = _to_matrix(data)
        movies = None  # no item metadata is loaded for 1M

    else:
        raise ValueError("Unknown dataset. Use '100k' or '1m'.")

    # Reindex to consecutive integers (0..U-1, 0..I-1)
    data_M = data_M.reset_index(drop=True)
    data_M.columns = range(data_M.shape[1])

    logger.info(
        f"[Data] Loaded MovieLens-{dataset}: {data_M.shape[0]} users × {data_M.shape[1]} items | "
        f"dtype={data_M.values.dtype} (GPU={USE_CUDA})"
    )
    if movies is not None:
        logger.info(f"[Data] Movies metadata rows: {len(movies)}")

    return data_M, movies


# Data splitting

In [None]:
def split(
    data_M: pd.DataFrame,
    num_test_users: int | None = None,
    num_train_ratings: int = 3,
    num_test_ratings: int = 20,
    min_pool_size: int = 10,
    rng: np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build the cold-start train/test/pool matrices used by the AL loop.

    A cohort of "test users" is drawn from the users that have at least
    `num_train_ratings + num_test_ratings + min_pool_size` ratings. For each
    of them, a random subset of their rated items is split into a small
    warm-start (train) part and a held-out (test) part; every other rated
    item becomes an active-learning candidate (pool). All remaining users
    keep all of their ratings in train.

    Parameters
    ----------
    data_M : user×item ratings, 0 = unrated.
    num_test_users : cohort size; None → 10% of eligible users (at least 1).
        Always capped at the number of eligible users.
    num_train_ratings, num_test_ratings : per-test-user counts.
    min_pool_size : only used for the eligibility threshold.
    rng : NumPy Generator; a fresh unseeded one is created when None.

    Returns
    -------
    (train, test, pool, test_users): three U×I arrays (dtype of the input
    when it is floating, else float32 with USE_CUDA / float64 without) and
    the sorted indices of the selected test users.

    Raises
    ------
    ValueError if any of the three counts is negative.
    """
    X = data_M.values
    U, I = X.shape
    rng = np.random.default_rng() if rng is None else rng

    # Output dtype: keep a floating input dtype, otherwise pick by device.
    if np.issubdtype(X.dtype, np.floating):
        out_dtype = X.dtype
    else:
        out_dtype = np.float32 if USE_CUDA else np.float64

    # Argument validation
    if min(num_train_ratings, num_test_ratings, min_pool_size) < 0:
        raise ValueError("Split counts must be non-negative.")
    req = num_train_ratings + num_test_ratings + min_pool_size
    if req == 0:
        logger.warning("Split sizes sum to 0; eligible users get empty pool.")

    # Users with enough ratings to serve as test users
    ratings_per_user = np.count_nonzero(X, axis=1)
    candidates = np.where(ratings_per_user >= req)[0]
    n_cand = candidates.size

    if num_test_users is None:
        num_test_users = max(1, int(0.1 * n_cand))
    num_test_users = min(num_test_users, n_cand)

    # Random cohort, then sorted so the per-user draws happen in a fixed order
    rng.shuffle(candidates)
    test_users = np.sort(candidates[:num_test_users])
    is_test_user = np.zeros(U, dtype=bool)
    is_test_user[test_users] = True

    logger.info("=== SPLIT ===")
    logger.info(f"Dataset: {U} users × {I} items | dtype={out_dtype} | GPU={USE_CUDA}")
    logger.info(f"Per test user -> train={num_train_ratings}, test={num_test_ratings}, min_pool={min_pool_size}")
    logger.info(f"Eligible users (≥{req} ratings): {n_cand}")
    logger.info(f"Selected test users: {len(test_users)}")

    train = np.zeros((U, I), dtype=out_dtype)
    test = np.zeros((U, I), dtype=out_dtype)
    pool = np.zeros((U, I), dtype=out_dtype)
    split_stats = {"train": 0, "test": 0, "pool": 0}

    for u in range(U):
        rated = np.flatnonzero(X[u])
        n_items = rated.size

        # Regular users: everything they rated stays in train.
        if not (is_test_user[u] and n_items >= req):
            if n_items:
                train[u, rated] = X[u, rated]
                split_stats["train"] += n_items
            continue

        # Test users: draw the warm-start + held-out items in one go.
        n_select = num_train_ratings + num_test_ratings
        if n_select > n_items:  # defensive clip
            n_select = n_items
            logger.debug(f"User {u}: n_select clipped to {n_select} (had {n_items}).")

        picked = rng.choice(rated, size=n_select, replace=False)
        warm_items = picked[:num_train_ratings]
        held_out = picked[num_train_ratings:]

        if num_train_ratings:
            train[u, warm_items] = X[u, warm_items]
            split_stats["train"] += warm_items.size

        if num_test_ratings:
            test[u, held_out] = X[u, held_out]
            split_stats["test"] += held_out.size

        # Whatever was rated but not drawn above is a candidate item.
        if n_items > n_select:
            still_free = X[u] != 0
            still_free[picked] = False
            pool_items = np.flatnonzero(still_free)
            if pool_items.size:
                pool[u, pool_items] = X[u, pool_items]
                split_stats["pool"] += int(pool_items.size)

    logger.info(f"Split stats: {split_stats}")
    logger.info("=============")

    # C-contiguous outputs for the Numba/CUDA consumers
    return (
        np.ascontiguousarray(train),
        np.ascontiguousarray(test),
        np.ascontiguousarray(pool),
        test_users,
    )


# Initialize the model

In [None]:
def initialize_model(
    train: np.ndarray,
    test:  np.ndarray,
    rng: np.random.Generator,
    lamda: float,
    steps: int = INIT_STEPS,
    alpha: float = ALPHA_INIT,
    beta:  float = BETA,
    K: int = K,
    neighbor: int = NEIGHBOR,
    theta: float = 0.0,
    precomputed_W: np.ndarray | None = None):
    """
    Set up and pretrain the Explainable MF (EMF) model.

    Procedure
    ---------
    1) Take `precomputed_W` if given, otherwise build W with
       `calc_exp(train, neighbor, theta)`.
    2) Draw P (U×K) and then Q (I×K) uniformly from [0, 1) using `rng`.
    3) Pretrain with `EMF_with_explainability` for `steps` epochs.
    4) Compute the MAE on the non-zero entries of `test` (NaN if there are none).

    Notes
    -----
    • W is built regardless of λ.
    • All arrays are converted to float32 when USE_CUDA is set, float64
      otherwise, and made C-contiguous. When `precomputed_W` already has
      that dtype and layout it is returned as-is (no copy).

    Returns
    -------
    (P, Q, train_mae, test_mae, W)
    """
    n_users, n_items = train.shape
    work_dtype = np.float32 if USE_CUDA else np.float64

    def _as_work_array(arr):
        # Same dtype + C layout everywhere, copying only when necessary.
        return np.ascontiguousarray(arr.astype(work_dtype, copy=False))

    R_train = _as_work_array(train)
    R_test = _as_work_array(test)

    # Factor initialisation — P first, then Q, to keep the RNG stream order.
    P = np.ascontiguousarray(rng.random((n_users, K), dtype=work_dtype))
    Q = np.ascontiguousarray(rng.random((n_items, K), dtype=work_dtype))

    # Explainability weights
    if precomputed_W is not None:
        W = precomputed_W
    else:
        W = calc_exp(R_train, neighbor=neighbor, theta=theta)
    W = _as_work_array(W)

    # Pretraining
    P, Q, train_mae = EMF_with_explainability(
        R_train, P, Q, K, W, lamda, steps, alpha, beta
    )

    # Diagnostic: MAE on the held-out entries
    observed = R_test != 0
    if np.any(observed):
        predictions = P.dot(Q.T)
        test_mae = np.abs(predictions[observed] - R_test[observed]).mean()
    else:
        test_mae = np.nan

    return P, Q, train_mae, test_mae, W


def _rated_mask(rate: np.ndarray) -> np.ndarray:
    """Boolean vector over rows: True where the row holds at least one non-zero entry."""
    has_entry = rate != 0
    return has_entry.any(axis=1)


def calc_exp(rate: np.ndarray, neighbor: int = 50, theta: float = 0.0) -> np.ndarray:
    """
    Build full explainability matrix W (U×I).

    Definition
    ----------
    W[u,i] = (# of u’s k nearest neighbors who rated i) / k,
    where "rated" means `rate > 0`, neighbors are taken among users with at
    least one non-zero entry, and closeness is cosine distance between the
    binary rated/unrated rows. k = min(neighbor, M − 1) for M such users.

    A user is never counted as their own neighbor: the self-distance is
    forced to −inf before sorting, so self is always the entry that gets
    skipped even when other users have an identical profile (distance-0 ties).

    Edge cases
    ----------
    • If <2 rated users or k<=0 → zeros.
    • If theta>0 → hard-threshold values below theta to 0.
    • Rows of users without any rating stay zero.

    Cost
    ----
    O(M^2 I) time for the pairwise distances and the neighbor counting;
    O(M^2 + M I) memory (no M×k×I intermediate is materialised).

    Returns
    -------
    C-contiguous U×I array; float32 when USE_CUDA is set, float64 otherwise.
    """
    U, I = rate.shape
    DT = np.float32 if USE_CUDA else np.float64
    W = np.zeros((U, I), dtype=DT)

    mask_users = _rated_mask(rate)
    M = int(mask_users.sum())
    if M <= 1:
        return W

    k = min(neighbor, M - 1)
    if k <= 0:
        return W

    # Binary view for distances/counts; keep float for sklearn
    bin_rate = (rate > 0).astype(np.float64, copy=False)

    # Distances among rated users only
    sub = bin_rate[mask_users]                        # [M, I]
    dist = pairwise_distances(sub, metric='cosine')   # [M, M]

    # Guarantee that "self" sorts first regardless of ties with duplicates.
    row_ids = np.arange(M)
    dist[row_ids, row_ids] = -np.inf

    # k-NN indices (column 0 is self by construction → skipped)
    nn = np.argsort(dist, axis=1)[:, 1:k+1]           # [M, k]

    # Neighbor counts per item as adjacency @ ratings. The operands are 0/1
    # in float64, so the sums are exact integers (identical to summing rows).
    adjacency = np.zeros((M, M), dtype=np.float64)
    adjacency[row_ids[:, None], nn] = 1.0
    expl_sub = (adjacency @ sub) / float(k)           # [M, I]

    # Scatter back to full W
    idx_users = np.flatnonzero(mask_users)
    W[idx_users, :] = expl_sub.astype(DT, copy=False)

    if theta > 0.0:
        W[W < theta] = DT(0.0)
    return np.ascontiguousarray(W)


def calc_exp_row(rate: np.ndarray, u: int, neighbor: int = 50, theta: float = 0.0) -> np.ndarray:
    """
    One user’s explainability row W[u,:].

    Finds the k = min(neighbor, M − 1) nearest neighbors of user `u` among the
    M users with at least one non-zero entry (cosine distance on the binary
    `rate > 0` rows) and returns, per item, the fraction of those neighbors
    who rated it.

    User `u` is never its own neighbor: its self-distance is forced to −inf
    before sorting, so it is always the skipped entry — even when another
    user has an identical profile or the computed self-distance is not
    exactly zero.

    Edge cases
    ----------
    • If u has no ratings or k<=0 → zeros.
    • If theta>0 → values below theta are set to 0.

    Returns
    -------
    C-contiguous vector of length I; float32 when USE_CUDA is set, float64
    otherwise.
    """
    U, I = rate.shape
    DT = np.float32 if USE_CUDA else np.float64
    mask_users = _rated_mask(rate)
    if not mask_users[u]:
        return np.zeros(I, dtype=DT)

    idx_users = np.flatnonzero(mask_users)
    M = int(mask_users.sum())
    k = min(neighbor, M - 1)
    if k <= 0:
        return np.zeros(I, dtype=DT)

    # Position of u in compacted matrix
    pos = int(np.where(idx_users == u)[0][0])

    bin_rate = (rate > 0).astype(np.float64, copy=False)
    sub = bin_rate[mask_users]  # [M, I]

    # Distances from u to every rated user (including u itself)
    dist_u = pairwise_distances(sub[pos][None, :], sub, metric='cosine')[0]
    dist_u[pos] = -np.inf               # make sure self is first in the ordering
    nn_idx = np.argsort(dist_u)[1:k+1]  # skip self

    expl_u = sub[nn_idx, :].sum(axis=0) / float(k)
    if theta > 0.0:
        expl_u[expl_u < theta] = 0.0
    return np.ascontiguousarray(expl_u.astype(DT, copy=False))


@numba.njit
def EMF_with_explainability(
    R, P, Q, K, W, lamda, steps, alpha, beta
):
    """
    Explainable MF via SGD (Numba).

    Runs `steps` full passes over the non-zero entries of R (row-major
    order, no shuffling). For every observed (u, i) and every factor f:

        e     = R[u,i] − P[u]·Q[i]
        diff  = P[u,f] − Q[i,f]
        P[u,f] += alpha * (2·e·Q[i,f] − beta·P[u,f] − lamda·W[u,i]·diff)
        Q[i,f] += alpha * (2·e·P[u,f] − beta·Q[i,f] + lamda·W[u,i]·diff)

    Both gradients are computed from the values before the update of that
    factor; `e` is computed once per (u, i), before the factor loop.

    Parameters
    ----------
    R : (U, I) ratings, 0 = unobserved.
    P : (U, K) user factors — UPDATED IN PLACE.
    Q : (I, K) item factors — UPDATED IN PLACE (the function works on the
        transposed view `Q.T`, which shares memory with the argument).
    K : number of latent factors to use.
    W : (U, I) explainability weights.
    lamda, alpha, beta : explainability weight, learning rate, L2 weight.
    steps : number of passes over the observed entries.

    Returns
    -------
    (P, Q, train_mae) where Q is transposed back to (I, K) and train_mae is
    the mean absolute error over the observed entries after training
    (0.0 when R has no non-zero entry).
    """
    # Work with column-major access for items
    Q = Q.T
    U, I = R.shape

    # Gather coordinates of observed ratings
    # First pass: count them so the index arrays can be allocated exactly.
    nz = 0
    for u in range(U):
        for i in range(I):
            if R[u, i] != 0:
                nz += 1

    nnz_u = np.empty(nz, dtype=np.int64)
    nnz_i = np.empty(nz, dtype=np.int64)

    # Second pass: record the (u, i) coordinates in row-major order.
    idx = 0
    for u in range(U):
        for i in range(I):
            if R[u, i] != 0:
                nnz_u[idx] = u
                nnz_i[idx] = i
                idx += 1

    # SGD
    for _ in range(steps):
        for t in range(nz):
            u = nnz_u[t]
            i = nnz_i[t]
            r = R[u, i]

            # prediction and error
            s = 0.0
            for f in range(K):
                s += P[u, f] * Q[f, i]
            e = r - s

            Wi = W[u, i]
            for f in range(K):
                # Both gradients use the pre-update P[u,f] and Q[f,i].
                diff = P[u, f] - Q[f, i]
                grad_p = 2.0 * e * Q[f, i] - beta * P[u, f] - lamda * Wi * diff
                grad_q = 2.0 * e * P[u, f] - beta * Q[f, i] + lamda * Wi * diff
                P[u, f] += alpha * grad_p
                Q[f, i]  += alpha * grad_q

    # MAE over observed R (diagnostic)
    if nz == 0:
        train_mae = 0.0
    else:
        total_err = 0.0
        for t in range(nz):
            u = nnz_u[t]
            i = nnz_i[t]
            r = R[u, i]
            s = 0.0
            for f in range(K):
                s += P[u, f] * Q[f, i]
            total_err += abs(r - s)
        train_mae = total_err / nz

    # Hand Q back in its original (I, K) orientation.
    return P, Q.T, train_mae


@numba.njit
def retrain_online_exp(u, train, P_init, Q, W, alpha, beta, K, steps, lamda):
    """
    Refit the factor vector of a single user `u` with the item factors frozen.

    Starting from a copy of `P_init[u]`, performs `steps` passes over the
    non-zero entries of `train[u]` and applies, per rated item i and factor f,

        P_u[f] += alpha * (2·e·Q[i,f] − beta·P_u[f] − lamda·W[u,i]·(P_u[f] − Q[i,f]))

    with e = train[u,i] − P_u·Q[i] evaluated once per item, before the
    factor loop. Neither `P_init` nor `Q` is modified.

    Returns
    -------
    The updated user vector (same length as a row of `P_init`).
    """
    user_vec = P_init[u].copy()
    item_cols = Q.T          # column i of this view is item i's factor vector
    n_items = train.shape[1]

    for _epoch in range(steps):
        for i in range(n_items):
            rating = train[u, i]
            if rating == 0:
                continue  # unobserved entry

            err = rating - np.dot(user_vec, item_cols[:, i])
            w_ui = W[u, i]
            for f in range(K):
                gap = user_vec[f] - item_cols[f, i]
                step = 2.0 * err * item_cols[f, i] - beta * user_vec[f] - lamda * w_ui * gap
                user_vec[f] += alpha * step

    return user_vec


def calc_avg(train: np.ndarray) -> np.ndarray:
    """
    Mean of the non-zero entries of every column.

    Columns without any non-zero entry receive the mean over all non-zero
    entries of the matrix (0.0 if the matrix has none).
    Returns a float64 vector with one value per column.
    """
    n_rated = (train != 0).sum(axis=0)
    col_totals = train.sum(axis=0)

    # Fallback value for columns nobody rated
    overall = float(col_totals.sum()) / max(int(n_rated.sum()), 1)

    means = np.full(col_totals.shape, overall, dtype=float)
    rated_cols = n_rated > 0
    means[rated_cols] = col_totals[rated_cols] / n_rated[rated_cols]
    return means


# --- CUDA path: row-wise transfer variant (keep Q persistent per iteration) ---
# The kernel is only defined when `from numba import cuda` succeeded at import
# time; otherwise the name `_retrain_one_user_row_kernel` does not exist.
if 'cuda' in globals():
    @cuda.jit
    def _retrain_one_user_row_kernel(train_u, P_u, Q, W_u, alpha, beta, K, steps, lamda):
        """
        Single-thread kernel for updating one user vector P_u.

        Reads this user's rating row `train_u` and weight row `W_u` (both
        indexed by item) and the item factors `Q`, indexed as Q[item, factor].
        Applies the same per-item update as the CPU `retrain_online_exp`
        and writes the result back into `P_u` in place.

        Only thread 0 of block 0 does any work, so launch with [1, 1].
        Requires K ≤ MAX_K because the scratch buffer has MAX_K slots.
        The scratch buffer is float32, so the user vector is held in
        float32 during the update regardless of the dtype of `P_u`.
        """
        if cuda.threadIdx.x != 0 or cuda.blockIdx.x != 0:
            return

        I = train_u.shape[0]

        # Thread-local working copy of P_u (fixed size MAX_K, first K used)
        P_loc = cuda.local.array(MAX_K, numba.float32)
        for f in range(K):
            P_loc[f] = P_u[f]

        # SGD over rated items in this row
        for _ in range(steps):
            for i in range(I):
                r = train_u[i]
                if r != 0.0:
                    # dot(P_loc, Q[i])
                    s = 0.0
                    for f in range(K):
                        s += P_loc[f] * Q[i, f]
                    e = r - s
                    Wi = W_u[i]
                    for f in range(K):
                        diff = P_loc[f] - Q[i, f]
                        grad = 2.0 * e * Q[i, f] - beta * P_loc[f] - lamda * Wi * diff
                        P_loc[f] += alpha * grad

        # Write back
        for f in range(K):
            P_u[f] = P_loc[f]


def retrain_online_exp_gpu(
    u: int,
    train: np.ndarray,
    P_init: np.ndarray,
    Q: np.ndarray,
    W: np.ndarray,
    alpha: float,
    beta: float,
    K: int,
    steps: int,
    lamda: float,
) -> np.ndarray:
    """
    GPU drop-in for `retrain_online_exp` using **row-wise transfers**.

    Behavior
    --------
    • Copy only train[u,:], W[u,:] and P_init[u] to the device per call
      (all converted to float32).
    • Keep the device copy of Q between calls. It is re-uploaded when a
      different array object is passed, or when Q.shape[0] or K changes.
      Freshness is decided by object identity against a retained reference
      to the last Q — not by `id()`, which can be reused by a new array once
      the old one has been garbage-collected.
    • In-place modifications of the same Q object are NOT detected; call
      `retrain_online_exp_gpu_clear_cache()` after mutating Q in place.
    • Return updated P[u], cast back to the dtype of `P_init`.

    Raises
    ------
    RuntimeError
        If USE_CUDA is False.
    ValueError
        If K exceeds MAX_K (size of the kernel's scratch buffer).
    """
    if not USE_CUDA:
        raise RuntimeError("CUDA not available. Use CPU retrain_online_exp instead.")
    if K > MAX_K:
        raise ValueError(f"K={K} exceeds MAX_K={MAX_K}; increase MAX_K or lower K.")

    import warnings
    from numba.core.errors import NumbaPerformanceWarning

    # Persistent device buffers live on the function object itself.
    cache = retrain_online_exp_gpu.__dict__.setdefault("_cache", {})
    I = int(Q.shape[0])
    DT32 = np.float32

    shapes_changed = (cache.get("I") != I) or (cache.get("K") != K)
    q_is_stale = ("d_Q" not in cache) or shapes_changed or (cache.get("q_ref") is not Q)

    # Row buffers depend only on the shapes. Check before updating the cache.
    if ("d_train_u" not in cache) or shapes_changed:
        cache["d_train_u"] = cuda.device_array((I,), dtype=DT32)
        cache["d_W_u"]     = cuda.device_array((I,), dtype=DT32)
        cache["d_P_u"]     = cuda.device_array((K,), dtype=DT32)

    if q_is_stale:
        cache["I"] = I
        cache["K"] = K
        cache["q_ref"] = Q           # strong reference: keeps identity check valid
        cache["q_obj_id"] = id(Q)    # kept for anyone inspecting the cache
        cache["d_Q"] = cuda.to_device(np.ascontiguousarray(Q, dtype=DT32))  # (I,K)

    d_Q       = cache["d_Q"]
    d_train_u = cache["d_train_u"]
    d_W_u     = cache["d_W_u"]
    d_P_u     = cache["d_P_u"]

    # Host→device for this user only. Rows are made contiguous so the copy
    # also works when the host matrices are not C-ordered.
    d_train_u.copy_to_device(np.ascontiguousarray(train[u], dtype=DT32))
    d_W_u.copy_to_device(np.ascontiguousarray(W[u], dtype=DT32))
    d_P_u.copy_to_device(np.ascontiguousarray(P_init[u], dtype=DT32))

    # Launch kernel (single thread by design; suppress the resulting
    # low-occupancy performance warning).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", NumbaPerformanceWarning)
        _retrain_one_user_row_kernel[1, 1](
            d_train_u, d_P_u, d_Q, d_W_u,
            DT32(alpha), DT32(beta), np.int32(K), np.int32(steps), DT32(lamda)
        )

    # Device→host for P[u]
    P_u_out = d_P_u.copy_to_host()
    return P_u_out.astype(P_init.dtype, copy=False)


def retrain_online_exp_gpu_clear_cache() -> None:
    """
    Drop the buffer cache kept on `retrain_online_exp_gpu`.

    Removes the `_cache` attribute when it exists and is non-empty, releasing
    this module's references to the cached device arrays; the next GPU call
    re-creates and re-uploads everything.
    """
    fn_attrs = retrain_online_exp_gpu.__dict__
    if fn_attrs.get("_cache"):
        fn_attrs.pop("_cache", None)


# ExAL selections 

In [None]:
@numba.njit
def active_selection_exal_min(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    Q: np.ndarray,
    expl: np.ndarray,
    alpha: float,
    lamda: float,
    K: int,
) -> int:
    """
    EXAL-Min: return the pool item of user `u` with the *smallest* score.

    For every candidate m (pool[u, m] != 0) the score is

        Σ_j | 1 − eR[u,j] + 2·alpha·( (eR[u,m] − lR[m])·(Q_m·Q_j)
                                      + lamda·expl[u,m]·(eR[u,j] − Q_m·Q_j) ) |

    summed over j in `test_items`. Ties keep the lowest item index.

    Parameters
    ----------
    u          : user index.
    test_items : indices of the items to sum over (u's held-out items).
    pool       : (U, I); non-zero marks a candidate for that user.
    eR         : (U, I) predicted ratings.
    lR         : per-item reference value subtracted from the candidate's
                 prediction (described by the author as a per-item mean).
    Q_dot      : (I, I) item–item dot products, or an empty array to have
                 the dot products computed on the fly from Q.
    Q          : (I, K) item factors.
    expl       : (U, I) explainability weights W.
    alpha      : learning rate used inside the score.
    lamda      : weight of the explainability term.
    K          : number of latent factors used for on-the-fly dot products.

    Returns
    -------
    Item index, or -1 when `test_items` is empty or u has no candidate.
    """
    n_test = test_items.shape[0]
    if n_test == 0:
        return -1

    have_dots = Q_dot.size != 0
    n_items = pool.shape[1]
    winner = -1
    lowest = np.inf

    for cand in range(n_items):
        if pool[u, cand] == 0:
            continue  # not a candidate for this user

        pred_cand = eR[u, cand]
        q_cand = Q[cand]
        total = 0.0

        for pos in range(n_test):
            j = test_items[pos]

            # Q_cand · Q_j, looked up or accumulated factor by factor
            if have_dots:
                sim = Q_dot[cand, j]
            else:
                q_j = Q[j]
                acc = 0.0
                for f in range(K):
                    acc += q_cand[f] * q_j[f]
                sim = acc

            pred_j = eR[u, j]
            term = 1.0 - pred_j + 2.0 * alpha * (
                (pred_cand - lR[cand]) * sim
                + lamda * expl[u, cand] * (pred_j - sim)
            )
            total += abs(term)

        if total < lowest:
            lowest = total
            winner = cand

    return winner


@numba.njit
def active_selection_exal_max(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    Q: np.ndarray,
    expl: np.ndarray,
    alpha: float,
    lamda: float,
    K: int,
) -> int:
    """
    EXAL-Max: return the pool item of user `u` with the *largest* score.

    The score of a candidate m (pool[u, m] != 0) is

        Σ_j | 1 − eR[u,j] + 2·alpha·( (eR[u,m] − lR[m])·(Q_m·Q_j)
                                      + lamda·expl[u,m]·(eR[u,j] − Q_m·Q_j) ) |

    summed over j in `test_items`; Q_m·Q_j is read from `Q_dot` when that
    array is non-empty, otherwise computed from the first K factors of Q.
    Ties keep the lowest item index.

    Returns
    -------
    Item index, or -1 when `test_items` is empty or u has no candidate.
    """
    n_test = test_items.shape[0]
    if n_test == 0:
        return -1

    have_dots = Q_dot.size != 0
    n_items = pool.shape[1]
    winner = -1
    highest = -np.inf

    for cand in range(n_items):
        if pool[u, cand] == 0:
            continue  # not a candidate for this user

        pred_cand = eR[u, cand]
        q_cand = Q[cand]
        total = 0.0

        for pos in range(n_test):
            j = test_items[pos]

            # Q_cand · Q_j, looked up or accumulated factor by factor
            if have_dots:
                sim = Q_dot[cand, j]
            else:
                q_j = Q[j]
                acc = 0.0
                for f in range(K):
                    acc += q_cand[f] * q_j[f]
                sim = acc

            pred_j = eR[u, j]
            term = 1.0 - pred_j + 2.0 * alpha * (
                (pred_cand - lR[cand]) * sim
                + lamda * expl[u, cand] * (pred_j - sim)
            )
            total += abs(term)

        if total > highest:
            highest = total
            winner = cand

    return winner


@numba.njit
def active_selection_exal_max_min(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    Q: np.ndarray,
    expl: np.ndarray,
    alpha: float,
    lamda: float,
    iteration: int,
    switch_point: int,
    K: int,
) -> int:
    """
    EXAL Max→Min: EXAL-Max while `iteration < switch_point`, EXAL-Min from
    `iteration == switch_point` onwards.

    All other arguments are forwarded unchanged to the chosen strategy.
    NOTE(review): whether `iteration` is 0- or 1-based is decided by the
    caller — confirm the intended switch iteration there.
    """
    if iteration >= switch_point:
        return active_selection_exal_min(
            u, test_items, pool, eR, lR, Q_dot, Q, expl, alpha, lamda, K
        )
    return active_selection_exal_max(
        u, test_items, pool, eR, lR, Q_dot, Q, expl, alpha, lamda, K
    )




@numba.njit
def active_selection_karimi(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    Q: np.ndarray,
    alpha: float,
    K: int,
) -> int:
    """
    Karimi baseline: the EXAL score without the explainability term.

    For every candidate m (pool[u, m] != 0) the score is

        Σ_j | 1 − eR[u,j] + 2·alpha·(eR[u,m] − lR[m])·(Q_m·Q_j) |

    summed over j in `test_items`; the candidate with the *smallest* score
    is returned (ties keep the lowest item index). Q_m·Q_j is read from
    `Q_dot` when that array is non-empty, otherwise computed from the first
    K factors of Q.

    Returns
    -------
    Item index, or -1 when `test_items` is empty or u has no candidate.
    """
    n_test = test_items.shape[0]
    if n_test == 0:
        return -1

    have_dots = Q_dot.size != 0
    n_items = pool.shape[1]
    winner = -1
    lowest = np.inf

    for cand in range(n_items):
        if pool[u, cand] == 0:
            continue  # not a candidate for this user

        pred_cand = eR[u, cand]
        q_cand = Q[cand]
        total = 0.0

        for pos in range(n_test):
            j = test_items[pos]

            # Q_cand · Q_j, looked up or accumulated factor by factor
            if have_dots:
                sim = Q_dot[cand, j]
            else:
                q_j = Q[j]
                acc = 0.0
                for f in range(K):
                    acc += q_cand[f] * q_j[f]
                sim = acc

            pred_j = eR[u, j]
            term = 1.0 - pred_j + 2.0 * alpha * ((pred_cand - lR[cand]) * sim)
            total += abs(term)

        if total < lowest:
            lowest = total
            winner = cand

    return winner


# Active Learning Baselines

In [None]:
@numba.njit
def select_random(u, pool, rand_val):
    """
    Pick one of user u's pool items (nonzero entries of pool[u]).

    The choice is fully determined by rand_val (expected in [0,1)):
      - count the candidates
      - take the k-th candidate in index order, k = floor(rand_val * count),
        clamped to count-1 so that rand_val == 1.0 is still valid
    Returns -1 if u has no candidates.
    """
    row = pool[u]
    n_items = row.shape[0]

    # First pass: how many candidates are there?
    n_cand = 0
    for idx in range(n_items):
        if row[idx] != 0:
            n_cand += 1
    if n_cand == 0:
        return -1

    # Scale rand_val to a candidate rank in [0, n_cand-1].
    remaining = min(int(rand_val * n_cand), n_cand - 1)

    # Second pass: count down until the requested candidate is reached.
    for idx in range(n_items):
        if row[idx] == 0:
            continue
        if remaining == 0:
            return idx
        remaining -= 1

    return -1  # not reached for rand_val in [0, 1]


@numba.njit
def active_selection_uncertainty(u, pool, eR, midpoint=3.0):
    """
    Uncertainty sampling:
    among user u's pool items, return the one whose prediction eR[u, i] lies
    closest to midpoint (default 3.0). The first such item wins ties.
    Returns -1 if no candidates.
    """
    n_items = pool.shape[1]
    chosen = -1
    smallest_gap = np.inf

    for item in range(n_items):
        if pool[u, item] == 0:
            continue
        gap = abs(eR[u, item] - midpoint)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = item

    return chosen


@numba.njit
def active_selection_highest_pred(u, pool, eR):
    """
    Greedy exploitation:
    return the pool item of user u with the largest prediction eR[u, i].
    The first such item wins ties; returns -1 if nothing is selected.
    """
    pool_row = pool[u]
    pred_row = eR[u]
    chosen = -1
    top_score = -np.inf

    for item in range(pool_row.shape[0]):
        if pool_row[item] == 0:
            continue
        score = pred_row[item]
        if score > top_score:
            top_score = score
            chosen = item

    return chosen


@numba.njit
def active_selection_highest_confidence(u, pool, eR, midpoint=3.0):
    """
    Most confident prediction:
    return the pool item of user u whose prediction is furthest from
    midpoint (largest absolute distance). The first such item wins ties;
    returns -1 if nothing is selected.
    """
    pool_row = pool[u]
    pred_row = eR[u]
    chosen = -1
    widest = -1.0

    for item in range(pool_row.shape[0]):
        if pool_row[item] == 0:
            continue
        margin = abs(pred_row[item] - midpoint)
        if margin > widest:
            widest = margin
            chosen = item

    return chosen


@numba.njit
def active_selection_highest_variance(u, pool, eR):
    """
    Highest global variance:
      among user u's pool items, pick the one whose predictions vary most
      across all users (rows of eR).

    The per-item variance is the population form Var = E[X^2] − (E[X])^2,
    accumulated over the rows of eR in order. It is evaluated only for items
    that are in u's pool, so no work is spent on (and no temporary arrays are
    allocated for) items that can never be selected.

    Returns the item index (first item wins ties), or -1 if no candidates.
    """
    n_users = eR.shape[0]
    n_items = eR.shape[1]
    inv_users = 1.0 / n_users

    best_idx = -1
    best_var = -1.0

    for i in range(n_items):
        if pool[u, i] == 0:
            continue

        # mean and mean-of-squares of column i across all users
        s = 0.0
        ss = 0.0
        for uu in range(n_users):
            x = eR[uu, i]
            s += x
            ss += x * x
        mu = s * inv_users
        var = ss * inv_users - mu * mu  # Var = E[X^2] − (E[X])^2

        if var > best_var:
            best_var = var
            best_idx = i

    return best_idx


# Evaluation metrics

In [None]:

# -----------------------------
# Top-N helper (Numba-safe)
# -----------------------------
@numba.njit
def topn(eR, n, u):
    """
    Indices of the (at most) n highest finite entries of eR[u], best first.

    Non-finite entries (NaN, ±inf) are never returned; if the row has no
    finite entry, an empty index array is returned.
    """
    row = eR[u]
    cand = np.arange(row.shape[0])[np.isfinite(row)]
    if cand.size == 0:
        return cand

    # argsort is ascending; reverse it to rank the best score first
    ranked = np.argsort(row[cand])[::-1]
    keep = min(n, ranked.size)
    return cand[ranked[:keep]]




# -----------------------------
# Explainability @N
# -----------------------------
@numba.njit
def calculate_MER(eR, W, users, n):
    """
    MER@N (Mean Explainable Recall).

    Per user:
      recall = (# explainable candidates inside the top-N list)
               / (# explainable candidates overall)
    where a candidate is an item with a finite prediction in eR[u] and
    "explainable" means W[u, item] > 0.
    Users without any finite prediction, or without any explainable
    candidate, do not contribute to the mean.

    Returns:
      scalar MER in [0,1] (0.0 when no user contributes).
    """
    recall_sum = 0.0
    n_scored = 0
    n_items = W.shape[1]

    for pos in range(users.shape[0]):
        u = users[pos]
        preds = eR[u]

        # the user needs at least one finite prediction
        any_finite = False
        for j in range(preds.shape[0]):
            if np.isfinite(preds[j]):
                any_finite = True
                break

        if any_finite:
            # denominator: all explainable candidates of this user
            denom = 0
            for j in range(n_items):
                if np.isfinite(preds[j]) and (W[u, j] > 0.0):
                    denom += 1

            if denom > 0:
                # numerator: explainable candidates that made the top-N
                ranked = topn(eR, n, u)
                hits = 0
                for r in range(ranked.shape[0]):
                    item = ranked[r]
                    if np.isfinite(preds[item]) and (W[u, item] > 0.0):
                        hits += 1

                recall_sum += (hits / float(denom))
                n_scored += 1

    if n_scored == 0:
        return 0.0
    return recall_sum / n_scored




@numba.njit
def calculate_MEP(eR, W, users, n):
    """
    MEP@N (Mean Explainable Precision).

    Per user:
      precision = (# explainable items among the first L ranked items) / L
      with L = min(N, # finite predictions in eR[u]);
      "explainable" means W[u, item] > 0.
    Users without any finite prediction are skipped.

    Returns:
      (MEP, total_explainable_found, total_positions_L)
      where the two totals are summed over the contributing users and MEP is
      0.0 when no user contributes.
    """
    precision_sum = 0.0
    found_total = 0
    slots_total = 0
    n_scored = 0
    n_cols = eR.shape[1]

    for pos in range(users.shape[0]):
        u = users[pos]
        preds = eR[u]

        # number of finite predictions bounds the list length
        n_finite = 0
        for j in range(n_cols):
            if np.isfinite(preds[j]):
                n_finite += 1

        if n_finite > 0:
            ranked = topn(eR, n, u)
            depth = min(n, n_finite)

            hits = 0
            for r in range(depth):
                item = ranked[r]
                if np.isfinite(preds[item]) and (W[u, item] > 0.0):
                    hits += 1

            precision_sum += hits / float(depth)
            found_total += hits
            slots_total += depth
            n_scored += 1

    mean_precision = precision_sum / n_scored if n_scored > 0 else 0.0
    return mean_precision, found_total, slots_total



# -----------------------------
# Ranking metrics
# -----------------------------
@numba.njit
def calculate_MAP(eR, test, users, n):
    """
    MAP@N.

    Per user:
      - R = number of nonzero entries in test[u] (the relevant items).
      - L = min(N, number of finite predictions in eR[u]).
      - AP = (sum of precision@rank over the ranks <= L that hold a relevant
              item) / min(R, L).
    Users with R == 0, with no finite prediction, or with min(R, L) <= 0 are
    skipped.

    Returns:
      mean AP over the contributing users (0.0 if there are none).
    """
    ap_sum = 0.0
    n_scored = 0
    n_items = test.shape[1]

    for pos in range(users.shape[0]):
        u = users[pos]

        # relevant items of this user
        n_rel = 0
        for j in range(n_items):
            if test[u, j] != 0:
                n_rel += 1
        if n_rel == 0:
            continue

        # rankable items of this user
        preds = eR[u]
        n_finite = 0
        for j in range(preds.shape[0]):
            if np.isfinite(preds[j]):
                n_finite += 1
        if n_finite == 0:
            continue

        ranked = topn(eR, n, u)
        depth = min(n, n_finite)
        norm = min(n_rel, depth)
        if norm <= 0:
            continue

        hits = 0.0
        prec_acc = 0.0
        for rank in range(depth):
            if test[u, ranked[rank]] != 0:
                hits += 1.0
                prec_acc += hits / float(rank + 1)

        ap_sum += (prec_acc / float(norm))
        n_scored += 1

    return ap_sum / n_scored if n_scored > 0 else 0.0





def calculate_ndcg(eR, test, users, n, graded=False):
    """
    NDCG@N averaged over users.

    Per user:
      - candidates are the items with a finite prediction in eR[u];
      - DCG is taken over the n best-scored candidates, with gain
        test[u, item] when graded is True and 1.0 for test[u, item] > 0
        otherwise, discounted by log2(rank + 1) (rank starting at 1);
      - IDCG is the DCG of the best possible ordering of the same candidates.
    Users without candidates or with IDCG == 0 are skipped.

    Returns:
      mean NDCG over the contributing users (0.0 if there are none).
    """
    def _discounted_gain(ordered_rel):
        # Sequential DCG over relevance values that are already in rank order.
        acc = 0.0
        for pos, value in enumerate(ordered_rel):
            gain = value if graded else (1.0 if value > 0 else 0.0)
            if gain > 0.0:
                acc += gain / np.log2(pos + 2.0)
        return acc

    score_sum = 0.0
    n_scored = 0

    for idx in range(users.shape[0]):
        u = users[idx]
        preds = eR[u]
        truth = test[u]

        # only finite predictions can be ranked
        cand = np.where(np.isfinite(preds))[0]
        if cand.size == 0:
            continue

        best_first = np.argsort(preds[cand])[::-1]
        ranking = cand[best_first[:n]]

        dcg = _discounted_gain(truth[ranking])
        # ideal ordering: candidate relevances sorted from high to low
        idcg = _discounted_gain(np.sort(truth[cand])[::-1][:n])

        if idcg > 0.0:
            score_sum += dcg / idcg
            n_scored += 1

    return (score_sum / n_scored) if n_scored > 0 else 0.0


# -----------------------------
# Popularity / Novelty / Diversity (CPU-side)
# -----------------------------
def calculate_item_coverage(topN_items_all_users, num_items):
    """
    Item Coverage (IC):
      number of distinct items found in the users' top-N lists, divided by
      num_items. Returns 0.0 when num_items <= 0.
    """
    if num_items <= 0:
        return 0.0
    recommended = {
        item
        for user_items in topN_items_all_users
        for item in user_items
    }
    return len(recommended) / float(num_items)


def calculate_gini_index(topN_items_all_users, num_items):
    """
    Gini index of how recommendation slots are spread over the catalog
    (0 = every item equally exposed, towards 1 = exposure concentrated).

    Every catalog item counts, including those never recommended. Item ids
    outside [0, num_items) are ignored. Returns 0.0 when num_items <= 0 or
    when nothing was recommended; the result is clamped to [0, 1].
    """
    if num_items <= 0:
        return 0.0

    # exposure[i] = how many top-N lists contain item i
    exposure = np.zeros(num_items, dtype=np.float64)
    for user_items in topN_items_all_users:
        for item in user_items:
            if not (0 <= item < num_items):
                continue
            exposure[item] += 1.0

    total = exposure.sum()
    if total == 0.0:
        return 0.0

    ascending = np.sort(exposure)
    size = float(num_items)
    ranks = np.arange(1, num_items + 1, dtype=np.float64)
    gini = (2.0 * (ranks * ascending).sum()) / (size * total) - (size + 1.0) / size

    # Numerical safety: keep the value inside [0, 1]
    return float(min(1.0, max(0.0, gini)))


def calculate_ARP(topN_items_all_users, item_popularity):
    """
    ARP (Average Recommendation Popularity):
      mean of item_popularity[item] over every slot of every top-N list
      (an item recommended to several users is counted once per list).
      Lower is better (more long-tail exposure).
      Returns 0.0 when no item was recommended.
    """
    slot_popularity = [
        float(item_popularity[item])
        for user_items in topN_items_all_users
        for item in user_items
    ]
    if not slot_popularity:
        return 0.0
    return sum(slot_popularity, 0.0) / len(slot_popularity)


def calculate_novelty_log2(topN_items_per_user, item_popularity, eps=1e-6):
    """
    Novelty (dataset-specific):
      mean of -log2(item_popularity[item] + eps) over all recommended slots.
      Returns 0.0 when nothing was recommended.
    """
    log_pop = np.log2(item_popularity + eps)

    total = 0.0
    n_recs = 0
    for rec_list in topN_items_per_user:
        n_recs += len(rec_list)
        total += float(-np.sum(log_pop[rec_list]))

    if not n_recs:
        return 0.0
    return total / n_recs


def calculate_novelty_IDF(topN_items_per_user, item_popularity, num_users, eps=1e-6):
    """
    IDF-style Novelty (dataset-comparable):
      novelty(i) = -log2( (pop(i)+eps) / num_users ),
      averaged over all recommended slots.
      Returns 0.0 when num_users <= 0 or nothing was recommended.
    """
    if num_users <= 0:
        return 0.0

    share = (item_popularity + eps) / float(num_users)
    information = -np.log2(share)

    total = 0.0
    n_recs = 0
    for rec_list in topN_items_per_user:
        n_recs += len(rec_list)
        total += float(np.sum(information[rec_list]))

    if not n_recs:
        return 0.0
    return total / n_recs


def calculate_novelty_EFD(topN_items_per_user, item_popularity, num_users, eps=1e-6):
    """
    EFD (Expected Free Discovery), as implemented here:
      mean of 1 / freq(i) over all recommended slots,
      where freq(i) = (pop(i)+eps)/num_users.
      Larger values suggest rarer items overall.
      Returns 0.0 when num_users <= 0 or nothing was recommended.
    """
    if num_users <= 0:
        return 0.0

    rarity = 1.0 / ((item_popularity + eps) / float(num_users))

    total = 0.0
    n_recs = 0
    for rec_list in topN_items_per_user:
        n_recs += len(rec_list)
        total += float(np.sum(rarity[rec_list]))

    if not n_recs:
        return 0.0
    return total / n_recs


def calculate_novelty_EPC(topN_items_per_user, item_popularity):
    """
    EPC (Expected Popularity Complement):
      mean of (1 - pop(i)/max_pop) over all recommended slots, where max_pop
      is the largest popularity truncated to an integer and floored at 1
      (so an all-zero popularity vector does not divide by zero).
      Returns 0.0 when nothing was recommended.
    """
    peak = max(1, int(item_popularity.max()))

    total = 0.0
    n_recs = 0
    for rec_list in topN_items_per_user:
        complement = 1.0 - (item_popularity[rec_list] / peak)
        total += float(np.sum(complement))
        n_recs += len(rec_list)

    if not n_recs:
        return 0.0
    return total / n_recs


In [None]:
# -----------------------------
# Popularity exposure / buckets
# -----------------------------
def compute_item_popularity(rating_matrix: np.ndarray) -> np.ndarray:
    """
    Count, for every item (column), how many entries are nonzero.

    Args:
        rating_matrix: shape (num_users, num_items), zeros = unrated.
    Returns:
        1D int64 array of length num_items.
    """
    raters_per_item = np.count_nonzero(rating_matrix, axis=0)
    # Cast explicitly so the dtype does not depend on the platform default.
    return raters_per_item.astype(np.int64)


def assign_popularity_buckets(item_popularity: np.ndarray) -> np.ndarray:
    """
    Split items into popularity tertiles: 0=low, 1=medium, 2=high.

    The cut points are the 33.33% and 66.66% percentiles of item_popularity:
      - pop ≤ q33        → 0 (low)
      - q33 < pop ≤ q66  → 1 (medium)
      - pop > q66        → 2 (high)
    When every item has the same popularity both cut points equal that value,
    so all items land in bucket 0.

    Args:
        item_popularity: 1D array length num_items.
    Returns:
        1D int array length num_items with values in {0,1,2}
        (an empty int array for empty input).
    """
    if item_popularity.size == 0:
        return item_popularity.astype(int)

    low_cut, high_cut = np.percentile(item_popularity, [33.33, 66.66])

    # Because low_cut <= high_cut, the bucket id is simply the number of
    # cut points an item strictly exceeds.
    above_low = (item_popularity > low_cut).astype(int)
    above_high = (item_popularity > high_cut).astype(int)
    return above_low + above_high


def fraction_by_popularity_bucket(topN_items: np.ndarray,
                                  popularity_buckets: np.ndarray) -> np.ndarray:
    """
    Share of low / medium / high popularity items in one list of item indices.

    Args:
        topN_items: 1D array of item indices.
        popularity_buckets: 1D array (len=num_items) with values {0,1,2}.
    Returns:
        1D float array [frac_low, frac_medium, frac_high]; all zeros when
        topN_items is empty.
    """
    size = int(len(topN_items))
    if size == 0:
        return np.zeros(3, dtype=float)

    # Look up each item's bucket, then histogram the bucket ids.
    tiers = popularity_buckets[topN_items]
    histogram = np.bincount(tiers, minlength=3)
    return histogram.astype(float) / float(size)


def popularity_exposure_gap(topN_items_all_users, popularity_buckets: np.ndarray) -> float:
    """
    Exposure gap over every user's recommendation list:
        gap = share(low) - share(high)
      with share(x) = (#recommended slots in bucket x) / (total #slots).

    A positive gap means low-popularity (bucket 0) items take more slots than
    high-popularity (bucket 2) items; bucket 1 only enters through the total.

    Args:
        topN_items_all_users: iterable of 1D arrays (one per user) of item indices.
        popularity_buckets: 1D array (len=num_items) with values {0,1,2}.
    Returns:
        float in [-1, 1]. Returns 0.0 if there are no recommended items at all.
    """
    n_low = 0.0
    n_high = 0.0
    n_slots = 0.0

    for rec_list in topN_items_all_users:
        size = len(rec_list)
        if size == 0:
            continue
        tiers = popularity_buckets[rec_list]
        n_low += float(np.sum(tiers == 0))
        n_high += float(np.sum(tiers == 2))
        n_slots += float(size)

    if n_slots == 0.0:
        return 0.0

    low_share = n_low / n_slots
    high_share = n_high / n_slots
    return low_share - high_share


# -----------------------------
# Q·Q^T precomputation with memory-awareness and caching
# -----------------------------
def maybe_precompute_Q_dot(
    Q: np.ndarray,
    max_frac_avail: float = 0.25,   # use up to 25% of currently available RAM
    abs_cap_gib: float = 2.0        # but never exceed this hard cap (GiB)
) -> np.ndarray:
    """
    Return Q_dot = Q @ Q.T if affordable under memory limits; otherwise, return
    an empty (0,0) sentinel.

    The last result is memoized together with a snapshot of the Q it was
    computed from, and is reused only while the incoming Q has the same shape,
    dtype and *values* as that snapshot. Comparing contents (instead of
    `id(Q)`) keeps the cache correct when Q is updated in place or when a new
    array happens to reuse the identity of a freed one.

    Rules:
      • bytes_needed = I*I*Q.dtype.itemsize, with I = Q.shape[0]
      • allowed = min(abs_cap_gib * 2**30,
                      max_frac_avail * psutil.virtual_memory().available)
        (only the hard cap applies when psutil cannot be used)
      • If bytes_needed <= allowed -> compute; else -> return empty.

    Args:
        Q: factor matrix with one row per item.
        max_frac_avail: fraction of currently available RAM that may be used.
        abs_cap_gib: absolute upper bound in GiB.
    Returns:
        The I×I product, or an empty (0,0) array of Q's dtype when Q has no
        rows or the product would exceed the memory budget.
    """
    I = int(Q.shape[0])
    if I == 0:
        return np.empty((0, 0), dtype=Q.dtype)

    # --- single-entry cache validated against the contents of Q ---
    cache = maybe_precompute_Q_dot.__dict__.setdefault("_cache", {})
    snapshot = cache.get("Q")
    if (
        snapshot is not None
        and "Q_dot" in cache
        and snapshot.shape == Q.shape
        and snapshot.dtype == Q.dtype
        and np.array_equal(snapshot, Q)
    ):
        return cache["Q_dot"]

    # --- how much can we use? ---
    GIB = float(1 << 30)
    MIB = float(1 << 20)
    allowed = abs_cap_gib * GIB
    try:
        import psutil  # optional
        avail = float(psutil.virtual_memory().available)
        allowed = min(allowed, max_frac_avail * avail)
    except Exception:
        # Deliberate best effort: psutil missing or unable to report memory,
        # so fall back to the hard cap only.
        pass

    bytes_needed = float(I) * float(I) * float(Q.dtype.itemsize)

    if bytes_needed <= allowed:
        Q_dot = Q.dot(Q.T)   # dtype preserved
        logger.info(f"[Q_dot] precomputed {I}×{I} ({bytes_needed/ MIB:.1f} MiB).")
    else:
        logger.info(
            f"[Q_dot] skipped: need {bytes_needed/MIB:.1f} MiB, "
            f"allowing ≤ {allowed/MIB:.1f} MiB."
        )
        Q_dot = np.empty((0, 0), dtype=Q.dtype)  # sentinel

    # Keep a private copy of Q so later in-place edits by the caller are
    # detected by the comparison above.
    cache["Q"] = Q.copy()
    cache["Q_dot"] = Q_dot
    return Q_dot

def clear_qdot_cache():
    """Drop whatever `maybe_precompute_Q_dot` has memoized (no-op if nothing is cached)."""
    vars(maybe_precompute_Q_dot).pop("_cache", None)


# -----------------------------
# DataFrame utility
# -----------------------------
def safe_concat(df_list, ignore_index: bool = True) -> pd.DataFrame:
    """
    Stack DataFrames after removing, from each input separately, the columns
    that contain nothing but NaN (useful when some per-iteration frames carry
    sparse diagnostics). A column dropped from one input still appears in the
    result if another input keeps it.

    Args:
        df_list: iterable of DataFrames.
        ignore_index: passed to pd.concat.
    Returns:
        Concatenated DataFrame. If df_list is empty, returns an empty DataFrame.
    """
    if not df_list:
        return pd.DataFrame()

    def _without_empty_columns(frame):
        return frame.dropna(axis=1, how='all')

    trimmed = list(map(_without_empty_columns, df_list))
    return pd.concat(trimmed, ignore_index=ignore_index)



# Make sure LAMBDA_SELECT exists before main() reads it.
# A value of None means “use lambda_value” for selection as well.
if "LAMBDA_SELECT" not in globals():
    LAMBDA_SELECT = None


# Main experiment loop

In [None]:
def main(lambda_value, strategy_input=None, return_results=False, seed=None, n_iter=num_iter,
         results_folder=None, dataset='100k', freeze_Q=False, theta=0.0, neighbor=NEIGHBOR,
         RECOMPUTE_W_EACH_ITER=True):
    """
    Run the active-learning experiment for one seed.

    For each strategy, the same initial train/test/pool split is replayed:
    in every iteration one pool item per test user is selected, moved into the
    training matrix, the user's factors are retrained online, and (unless
    freeze_Q is set) a full EMF pass follows. Accuracy, explainability,
    ranking and beyond-accuracy metrics are recorded per iteration.

    Args:
        lambda_value: explainability weight used for training. It is also used
            for selection unless the global LAMBDA_SELECT is set (not None).
        strategy_input: name of a single strategy to run; None (or any falsy
            value) runs all built-in strategies.
        return_results: if True, return the per-strategy tables and do not
            write the results CSVs; if False, write them and return None.
            The popularity-bucket CSVs are written in both cases.
        seed: seed for np.random.default_rng; None gives an unseeded generator.
        n_iter: number of active-learning iterations per strategy.
        results_folder: where the results CSVs go; defaults to
            "Results_{neighbor}/seeds_results".
        dataset: forwarded to load_movielens.
        freeze_Q: if True, skip the per-iteration EMF_with_explainability pass.
        theta: forwarded to calc_exp / calc_exp_row / initialize_model.
        neighbor: forwarded to the same helpers and used in the default
            output folder names.
        RECOMPUTE_W_EACH_ITER: if True, rebuild the explainability matrix with
            calc_exp at the start of every iteration; otherwise only the row of
            the user who just received a rating is refreshed (and only when
            one of the two λ values is positive).

    Returns:
        dict mapping "{strategy}_lambdaTrain_..._lambdaSel_..._{dataset}" to
        the per-iteration metrics DataFrame when return_results is True,
        otherwise None.

    Raises:
        ValueError: if a strategy name is not recognised.
    """
    global logger
    # Selection λ source: prefer global LAMBDA_SELECT; else fall back to training λ
    sel_lambda_source = (LAMBDA_SELECT if ('LAMBDA_SELECT' in globals() and LAMBDA_SELECT is not None)
                         else lambda_value)

    # File-name suffix for this seed. `is not None` (rather than truthiness)
    # so that seed=0 also gets its own "_seed_0" files.
    seed_suffix = f"_seed_{seed}" if seed is not None else ""

    # 1) Output dirs
    if results_folder is None:
        results_folder = f"Results_{neighbor}/seeds_results"
    pop_folder = f"Results_{neighbor}/Popularity_Buckets"

    # 2) RNG
    rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
    # NOTE: this log line consumes one draw from rng before the split below.
    logger.info(f"[seed={seed}] first rng draw = {rng.random():.6f}")

    # 3) FS setup
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(pop_folder, exist_ok=True)

    # 4) Data
    data_M, _ = load_movielens(dataset)
    rate = data_M  # already dense DataFrame

    logger.info("=" * 60)
    logger.info("STARTING EXAL EXPERIMENT with freeze_Q="
                f"{freeze_Q}, lambda_train={lambda_value}, lambda_select={sel_lambda_source}, "
                f"strategy={strategy_input}, dataset={dataset}, seed={seed}, n_iter={n_iter}, "
                f"neighbor={neighbor}, theta={theta}")
    logger.info("=" * 60)
    logger.info(f"{'fixed Q during AL (ablation)' if freeze_Q else 'EMF updates per iteration (paper-faithful)'}")

    fixed_test = None

    # Split for cold-start AL: tiny train, held-out test, remaining pool
    # (presumably three user×item arrays plus the test-user ids — confirm
    # against split()).
    train_init, test_init, pool_init, test_user = split(
        rate,
        num_test_users=fixed_test,
        num_train_ratings=3,
        num_test_ratings=20,
        min_pool_size=10,
        rng=rng
    )

    # Quick sanity for first 5 users
    logger.info("==== Data Split Check for First 5 Test Users ====")
    for u in test_user[:5]:
        train_idx = np.where(train_init[u] != 0)[0]
        test_idx = np.where(test_init[u] != 0)[0]
        pool_idx = np.where(pool_init[u] != 0)[0]
        logger.info(f"User {u:3d}: train={len(train_idx)}, test={len(test_idx)}, pool={len(pool_idx)}")
    logger.info("==============================================")

    # Item stats for selection/metrics.
    # Popularity (and its buckets) come from the initial train split and are
    # not refreshed while ratings are added during the AL loop.
    lR = calc_avg(train_init)
    item_popularity = compute_item_popularity(train_init)
    popularity_buckets = assign_popularity_buckets(item_popularity)

    # Strategy list
    if strategy_input:
        strategies = [strategy_input]
    else:
        strategies = [
            'EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max', 'KARIMI',
            'Uncertainty', 'Random', 'HighestPred', 'HighestConfidence', 'HighestVar'
        ]

    results = {}

    # 9) Run each strategy
    for strategy in strategies:
        logger.info(f"\n{'='*60}")
        logger.info(f"STRATEGY: {strategy}")
        logger.info(f"{'='*60}")

        # Reset state per strategy for fair comparison
        rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
        train = np.copy(train_init)
        test = np.copy(test_init)
        pool = np.copy(pool_init)

        # Per-iteration metrics table
        evolution = pd.DataFrame(columns=["Iteration", "MAE", "MEP", "MER", "F-Score", "MAP"])

        # Popularity-fraction rows of this run, rewritten to CSV every iteration
        pop_frames = []

        # Decoupled λ: training vs. selection
        lamda_train  = float(lambda_value)
        lamda_select = float(sel_lambda_source)
        logger.info(f"Running {strategy} with lamda_train={lamda_train}, lamda_select={lamda_select}")

        # Init EMF + explainability
        W0 = calc_exp(train_init, neighbor=neighbor, theta=theta)
        P_init, Q_init, _, _, expl = initialize_model(
            train_init, test_init, rng,
            steps=INIT_STEPS, alpha=ALPHA_INIT, beta=BETA,
            K=K, neighbor=neighbor, theta=theta, lamda=lamda_train,
            precomputed_W=W0
        )

        # W sparsity snapshot
        if (lamda_train > 0) or (lamda_select > 0):
            logger.info("\n[Explainability Matrix Sparsity Check]")
            logger.info("Initial explainable items per user (first 10 users):")
            for u in range(min(10, expl.shape[0])):
                num_expl = int(np.sum(expl[u] > 0))
                total_items = expl.shape[1]
                perc = 100.0 * num_expl / total_items
                logger.info(f"  User {u:2d}: {num_expl:4d} / {total_items} items explainable ({perc:.2f}%)")
            logger.info("-" * 55)
        else:
            logger.info("[Explainability] λ_train==0 and λ_select==0 → W only used for metrics.")

        # Working copies
        P, Q = np.copy(P_init), np.copy(Q_init)

        if USE_CUDA:
            train = train.astype(np.float32, copy=False)
            P     = P.astype(np.float32, copy=False)
            Q     = Q.astype(np.float32, copy=False)
            expl  = expl.astype(np.float32, copy=False)
            # NOTE(review): when RECOMPUTE_W_EACH_ITER is True, expl is replaced
            # by calc_exp(...) below without this cast — confirm the GPU path
            # accepts whatever dtype calc_exp returns.

        train_mae_list, test_mae_list = [], []

        # 11) AL loop
        for iteration in tqdm.tqdm(range(n_iter), desc=f"[{strategy}] AL Iter"):
            logger.info(f"\n--- Iteration {iteration} ---")

            # Exact lR this iter: per-item mean of the observed ratings, with
            # the global mean for items that have no rating yet.
            item_sums   = train.sum(axis=0)
            item_counts = (train != 0).sum(axis=0)
            global_sum  = float(item_sums.sum())
            global_cnt  = int(item_counts.sum())
            global_mean = global_sum / max(global_cnt, 1)

            lR = np.full(train.shape[1], global_mean, dtype=train.dtype)
            mask_pos = item_counts > 0
            lR[mask_pos] = item_sums[mask_pos] / item_counts[mask_pos]

            # Recompute W if needed
            if RECOMPUTE_W_EACH_ITER:
                expl = calc_exp(train, neighbor=neighbor, theta=theta)

            # Predictions (+ optional Q·Qᵀ cache).
            # eR is computed once here; the per-user online updates of P inside
            # the loop below do not refresh it until the metrics step.
            eR = P.dot(Q.T)
            Q_dot = maybe_precompute_Q_dot(Q)

            # Selection + online user update
            for u in tqdm.tqdm(test_user, desc=" Users", leave=False):
                test_items = np.where(test[u, :] != 0)[0]

                if strategy == 'EXAL-Min':
                    j = active_selection_exal_min(
                        u, test_items, pool, eR, lR, Q_dot, Q, expl, ALPHA_RETRAIN, lamda_select, K
                    )
                elif strategy == 'EXAL-Max':
                    j = active_selection_exal_max(
                        u, test_items, pool, eR, lR, Q_dot, Q, expl, ALPHA_RETRAIN, lamda_select, K
                    )
                elif strategy == 'EXAL-Min-Max':
                    j = active_selection_exal_max_min(
                        u, test_items, pool, eR, lR, Q_dot, Q, expl,
                        ALPHA_RETRAIN, lamda_select, iteration, switch_point=SWITCH, K=K
                    )

                elif strategy == 'KARIMI':
                    j = active_selection_karimi(
                        u, test_items, pool, eR, lR, Q_dot, Q, ALPHA_RETRAIN, K
                    )
                elif strategy == 'Random':
                    j = select_random(u, pool, rng.random())
                elif strategy == 'Uncertainty':
                    j = active_selection_uncertainty(u, pool, eR, midpoint=3.0)
                elif strategy == 'HighestPred':
                    j = active_selection_highest_pred(u, pool, eR)
                elif strategy == 'HighestConfidence':
                    j = active_selection_highest_confidence(u, pool, eR, midpoint=3.0)
                elif strategy == 'HighestVar':
                    j = active_selection_highest_variance(u, pool, eR)
                else:
                    raise ValueError(f"Unknown strategy: {strategy}")

                # The selectors return -1 when they have nothing to pick.
                if j >= 0:
                    # Move pick from pool→train
                    r = pool[u, j]
                    train[u, j] = r
                    pool[u, j]  = 0

                    # O(1) updates for lR/global stats
                    item_sums[j]   += r
                    item_counts[j] += 1
                    global_sum     += r
                    global_cnt     += 1
                    lR[j] = item_sums[j] / item_counts[j]

                    # Fast per-user W refresh if not rebuilding full W
                    if (lamda_train > 0 or lamda_select > 0) and not RECOMPUTE_W_EACH_ITER:
                        expl[u, :] = calc_exp_row(train, u, neighbor=neighbor, theta=theta)

                    # Online update for user u
                    P[u] = (
                        retrain_online_exp_gpu(u, train, P, Q, expl, ALPHA_RETRAIN, BETA, K, ONLINE_STEP, lamda_train)
                        if USE_CUDA else
                        retrain_online_exp(u, train, P, Q, expl, ALPHA_RETRAIN, BETA, K, ONLINE_STEP, lamda_train)
                    )

            # Optional EMF pass (updates Q and P)
            if not freeze_Q:
                P, Q, _ = EMF_with_explainability(train, P, Q, K, expl, lamda_train, ONLINE_STEP, ALPHA_RETRAIN, BETA)

            # Metrics (after updates)
            eR = P.dot(Q.T)

            # MAE
            mae_train = np.nanmean(np.abs(eR[train != 0] - train[train != 0]))
            mae_test  = np.nanmean(np.abs(eR[test  != 0] - test [test  != 0]))
            train_mae_list.append(mae_train)
            test_mae_list.append(mae_test)
            logger.info(f"[{strategy}] Iter {iteration}: TRAIN MAE={mae_train:.4f}, TEST MAE={mae_test:.4f}")

            # W diagnostics
            num_expl_nonzero = np.sum(expl > 0)
            total_entries    = expl.shape[0] * expl.shape[1]
            percent_nonzero  = 100 * num_expl_nonzero / total_entries
            logger.info(f"[W Sparsity] Non-zero W entries: {num_expl_nonzero}/{total_entries} ({percent_nonzero:.4f}%)")

            percent_exp_user = np.sum(expl > 0, axis=1) / expl.shape[1]
            mean_cov = np.mean(percent_exp_user) * 100
            top5     = np.sort(percent_exp_user)[-5:] * 100
            bot5     = np.sort(percent_exp_user)[:5] * 100
            logger.info(f"[W Coverage] Mean explainable items/user: {mean_cov:.2f}%")
            logger.info(f"Top 5 users w/ most explainable items: {top5}")
            logger.info(f"Bottom 5 users w/ least explainable items: {bot5}")

            # Mask seen items for ranking metrics: everything currently in
            # train or still in the pool is excluded, so only test items and
            # never-rated items remain rankable.
            mask = (train != 0) | (pool != 0)
            eR_masked = eR.copy()
            eR_masked[mask] = np.float32(-np.inf) if eR_masked.dtype == np.float32 else -np.inf

            # Explainability@N
            MEP, total_expl, total_n = calculate_MEP(eR_masked, expl, test_user, TopN)
            MER = calculate_MER(eR_masked, expl, test_user, TopN)
            F   = 2*(MEP*MER)/(MEP+MER) if (MEP+MER) > 0 else 0.0

            # Ranking
            MAPv = calculate_MAP(eR_masked, test, test_user, TopN)
            ndcg = calculate_ndcg(eR_masked, test, test_user, TopN, graded=True)

            # Build top-N per user for exposure/novelty metrics
            topN_items, all_top = [], []
            for u in test_user:
                s = eR[u].copy()
                s[mask[u]] = -np.inf
                valid = ~np.isneginf(s)
                if np.any(valid):
                    v_idx = np.where(valid)[0]
                    top_local = v_idx[np.argsort(s[v_idx])[-TopN:]][::-1]
                else:
                    top_local = np.empty(0, dtype=np.int64)
                topN_items.append(top_local)
                all_top.extend(top_local)

            # Beyond-accuracy
            num_items      = train.shape[1]
            ic             = calculate_item_coverage(topN_items, num_items)
            gini_conc      = calculate_gini_index(topN_items, num_items)
            diversity_1mG  = 1.0 - gini_conc
            arp            = calculate_ARP(topN_items, item_popularity)
            novelty_log2   = float(calculate_novelty_log2(topN_items, item_popularity))
            novelty_efd    = calculate_novelty_EFD(topN_items, item_popularity, num_users=train.shape[0])
            novelty_epc    = calculate_novelty_EPC(topN_items, item_popularity)

            # Popularity mix / exposure
            frac_pop       = fraction_by_popularity_bucket(np.array(all_top), popularity_buckets)
            frac_bias      = float(frac_pop[0] - frac_pop[2])
            exposure_gap   = popularity_exposure_gap(topN_items, popularity_buckets)

            # MAE by popularity bucket (diagnostic): only recommended items
            # that also carry a test rating contribute.
            mae_high, mae_low = [], []
            for u, recs in zip(test_user, topN_items):
                high = [i for i in recs if popularity_buckets[i] == 2 and test[u, i] != 0]
                low  = [i for i in recs if popularity_buckets[i] == 0 and test[u, i] != 0]
                if high:
                    mae_high.append(np.mean(np.abs(eR[u, high] - test[u, high])))
                if low:
                    mae_low.append(np.mean(np.abs(eR[u, low] - test[u, low])))
            mean_high = np.nan if not mae_high else float(np.mean(mae_high))
            mean_low  = np.nan if not mae_low  else float(np.mean(mae_low))
            mae_bias  = mean_low - mean_high

            # Persist popularity fractions (per strategy/λs).
            # The file is rewritten from this run's rows only, so rows left
            # behind by an earlier run with the same name are not carried over.
            pop_df = pd.DataFrame({
                "Iteration":[iteration],
                "Frac_Low": [frac_pop[0]],
                "Frac_Med": [frac_pop[1]],
                "Frac_High":[frac_pop[2]],
            })
            pop_frames.append(pop_df)
            pop_key = f"{strategy}_lambdaTrain_{lambda_value}_lambdaSel_{sel_lambda_source}_{dataset}"
            pop_path= os.path.join(pop_folder, f"{pop_key}{seed_suffix}.csv")
            pd.concat(pop_frames, ignore_index=True).to_csv(pop_path, index=False)

            # Row for main CSV
            iteration_df = pd.DataFrame({
                "Iteration":          [iteration],
                "MAE":                [mae_test],
                "Train_MAE":          [mae_train],
                "Overfit_Gap":        [mae_test - mae_train],
                "MEP":                [MEP],
                "MER":                [MER],
                "F-Score":            [F],
                "MAP":                [MAPv],
                "NDCG":               [ndcg],
                "IC_ItemCoverage":    [ic],
                "Gini_Concentration": [gini_conc],
                "Diversity_1mGini":   [diversity_1mG],
                "ARP":                [arp],
                "Novelty_Log2":       [novelty_log2],
                "Novelty_EFD":        [novelty_efd],
                "Novelty_EPC":        [novelty_epc],
                "ExposureGap_LminusH":[exposure_gap],
                "Frac_Bias":          [frac_bias],
                "MAE_HighPop":        [mean_high],
                "MAE_LowPop":         [mean_low],
                "MAE_Pop_Bias":       [mae_bias],
                "Total_Explained":    [total_expl],
                "Total_Candidates":   [total_n]
            })
            evolution = safe_concat([evolution, iteration_df], ignore_index=True)
            logger.info("Iteration stats:\n" + str(iteration_df))

        # Save per-strategy results
        key = f"{strategy}_lambdaTrain_{lambda_value}_lambdaSel_{sel_lambda_source}_{dataset}"
        results[key] = evolution
        if not return_results:
            path = os.path.join(results_folder, f"{key}{seed_suffix}.csv")
            evolution.to_csv(path, index=False)
            logger.info(f"Saved results to {path}")

            # Final iter log snapshot (only if at least one iteration ran;
            # otherwise the per-iteration variables below do not exist)
            if test_mae_list:
                logger.info(f"Iteration {iteration} complete:")
                logger.info(f"  - Train MAE: {mae_train:.4f}")
                logger.info(f"  - Test MAE: {mae_test:.4f}")
                logger.info(f"  - MEP: {MEP:.4f}, MER: {MER:.4f}, F-Score: {F:.4f}")

    logger.info("\n" + "="*60)
    logger.info("EXPERIMENT COMPLETE")
    logger.info("="*60)

    return results if return_results else None


In [None]:
def multi_seed_experiment(lambda_value, strategy_input=None, seeds=None, n_iter=num_iter,
                          results_folder=None, dataset='100k', freeze_Q=False,
                          theta=0.0, neighbor=NEIGHBOR, recompute_w_each_iter=True):
    """
    Run the active-learning experiment once per seed and average the results.

    Each seed is run through ``main(..., return_results=True)``. The returned
    per-strategy frames are written to ``<results_folder>/seeds_results`` and
    then averaged per ``Iteration`` into ``AVG_<key>.csv`` files in the same
    folder. Afterwards, per-seed CSVs found in
    ``<results_folder>/Popularity_Buckets`` are averaged (or copied verbatim
    when only one exists) into matching ``AVG_`` files.

    Args:
        lambda_value: Training lambda, forwarded to ``main`` and used in file names.
        strategy_input: Forwarded to ``main`` unchanged.
        seeds: Seeds to run; defaults to ten fixed seeds when ``None``.
        n_iter: Forwarded to ``main``.
        results_folder: Output root; defaults to ``Results_<neighbor>``.
        dataset: Dataset tag forwarded to ``main`` and used in file names.
        freeze_Q: Forwarded to ``main``.
        theta: Forwarded to ``main``.
        neighbor: Forwarded to ``main``; also names the default output root.
        recompute_w_each_iter: Forwarded to ``main`` as ``RECOMPUTE_W_EACH_ITER``.

    Returns:
        None. All results are written to disk.
    """
    # Selection lambda: a non-None global LAMBDA_SELECT overrides lambda_value.
    global_select = globals().get('LAMBDA_SELECT')
    sel_lambda_source = global_select if global_select is not None else lambda_value

    # 1. Set default folders for results
    if results_folder is None:
        results_folder = f"Results_{neighbor}"
    seeds_folder = os.path.join(results_folder, "seeds_results")
    pop_folder = os.path.join(results_folder, "Popularity_Buckets")

    # 2. Default seeds if not provided
    if seeds is None:
        seeds = [42, 101, 202, 303, 404, 505, 606, 707, 808, 909]

    # 3. Ensure output directories exist
    os.makedirs(seeds_folder, exist_ok=True)
    os.makedirs(pop_folder, exist_ok=True)
    logger.info(f"Running multi-seed experiment with seeds: {seeds}")

    # 4. Prepare result storage: strategy key -> list of per-seed frames
    results_all = {}

    # 5. Run experiment for each seed
    for seed in seeds:
        logger.info(f"Seed {seed}: first rng draw = {np.random.default_rng(seed).random():.6f}")

        # ---- GPU/Cache hygiene between seeds ----
        # Best effort: a missing helper or a failed clear must not abort the
        # run, but the reason is logged instead of being silently discarded.
        try:
            retrain_online_exp_gpu_clear_cache()
        except Exception as exc:
            logger.debug(f"retrain_online_exp_gpu_clear_cache skipped: {exc!r}")
        try:
            clear_qdot_cache()
        except Exception as exc:
            logger.debug(f"clear_qdot_cache skipped: {exc!r}")

        # Run single-seed experiment
        results = main(
            lambda_value,
            strategy_input=strategy_input,
            return_results=True,
            seed=seed,
            n_iter=n_iter,
            results_folder=seeds_folder,
            dataset=dataset,
            theta=theta,
            neighbor=neighbor,
            freeze_Q=freeze_Q,
            RECOMPUTE_W_EACH_ITER=recompute_w_each_iter
        )

        # Save and collect results
        for strategy_key, df in results.items():
            out_path = os.path.join(seeds_folder, f"{strategy_key}_seed_{seed}.csv")
            df.to_csv(out_path, index=False)
            logger.info(f"Saved: {out_path}")
            results_all.setdefault(strategy_key, []).append(df.copy())

    # 6. Average results across seeds for each strategy
    for strategy_key, dfs in results_all.items():
        concat_df = pd.concat(dfs, keys=range(len(dfs)), names=['Seed', 'Row'])
        avg_df = concat_df.groupby('Iteration').mean(numeric_only=True).reset_index()
        avg_csv = os.path.join(seeds_folder, f"AVG_{strategy_key}.csv")
        avg_df.to_csv(avg_csv, index=False)
        logger.info(f"Averaged results for {strategy_key} saved to {avg_csv}")

    logger.info(f"Multi-seed experiment completed. Averaged results in {seeds_folder}/.")

    # 7. Post-process and average popularity results for each strategy
    strategies = [
        'EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max',  'KARIMI',
        'Uncertainty', 'Random', 'HighestPred', 'HighestConfidence', 'HighestVar'
    ]

    for strat in strategies:
        # File names carry both lambdas (train/select), the dataset and the seed.
        stem = f"{strat}_lambdaTrain_{lambda_value}_lambdaSel_{sel_lambda_source}_{dataset}"
        pattern = os.path.join(pop_folder, f"{stem}_seed_*.csv")
        # Sorted so the averaging order (and the log output) is reproducible.
        files = sorted(
            f for f in glob.glob(pattern)
            if not os.path.basename(f).startswith('AVG_')
        )

        avg_file = os.path.join(pop_folder, f"AVG_{stem}.csv")

        if len(files) == 1:
            shutil.copyfile(files[0], avg_file)
            logger.info(f"[INFO] Only one Popularity file for {strat}: copied {files[0]} → {avg_file}")
        elif len(files) > 1:
            dfs = [pd.read_csv(f) for f in files]
            concat = pd.concat(dfs, keys=range(len(dfs)), names=['Seed', 'Row'])
            avg_df = concat.groupby('Iteration').mean(numeric_only=True).reset_index()
            avg_df.to_csv(avg_file, index=False)
            logger.info(f"[INFO] Averaged Popularity results for {strat} saved to {avg_file}")
        else:
            logger.warning(f"[WARN] No Popularity files found for {strat}.")


In [None]:
if __name__ == '__main__':
    # Command-line entry point: parse options, resolve the selection lambda and
    # the W-recompute policy, then run the multi-seed experiment.
    import sys
    import argparse

    # force=True replaces any handlers installed earlier in the session.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        force=True
    )
    logger = logging.getLogger(__name__)

    # Jupyter safety: drop argv noise
    if 'ipykernel' in sys.argv[0]:
        sys.argv = [sys.argv[0]]

    parser = argparse.ArgumentParser(
        description='Active Learning Experiment for MovieLens'
    )

    # λ for EMF training/updates (default = LAMBDA_TRAIN)
    parser.add_argument('--lambda_value', type=float, default=LAMBDA_TRAIN,
                        help='Lambda regularization parameter (λ) for EMF training/updates.')
    # Optional λ for selection; falls back to --lambda_value
    parser.add_argument('--lambda_select', type=float, default=None,
                        help='Lambda used by EXAL selection rules (defaults to --lambda_value if not set).')

    # Strategy choice (or "all")
    _valid_strategies = [
        'EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max', 'KARIMI',
        'Uncertainty', 'Random', 'HighestPred', 'HighestConfidence', 'HighestVar',
        'all'
    ]
    parser.add_argument('--strategy', type=str, default=None, choices=_valid_strategies,
                        help='AL strategy (e.g., "EXAL-Min"). Use "all" to run every strategy.')

    # Dataset + W parameters
    parser.add_argument('--dataset', type=str, default='100k',
                        choices=['100k', '1m'],
                        help='MovieLens dataset.')
    parser.add_argument('--theta', type=float, default=theta,
                        help='Explainability threshold θ for W_{ui}.')
    parser.add_argument('--neighbor', type=int, default=NEIGHBOR,
                        help='Number of neighbors (k) for explainability matrix W.')

    # Freeze Q during AL (ablation)
    parser.add_argument('--freeze_Q', dest='freeze_Q', action='store_true',
                        help='Freeze item factors Q during AL iterations.')
    parser.add_argument('--no-freeze_Q', dest='freeze_Q', action='store_false',
                        help='Allow iteration-end EMF updates to Q (default).')
    parser.set_defaults(freeze_Q=False)

    # Recompute W each iteration? (tri-state: auto / true / false)
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--recompute_w_each_iter', dest='recompute_w_each_iter',
                       action='store_true',
                       help='Recompute W every iteration.')
    group.add_argument('--no-recompute_w_each_iter', dest='recompute_w_each_iter',
                       action='store_false',
                       help='Do NOT recompute W every iteration.')
    parser.set_defaults(recompute_w_each_iter=None)

    # Parse CLI
    args = parser.parse_args()

    lambda_v  = args.lambda_value
    strategy  = args.strategy
    dataset   = args.dataset
    theta     = args.theta
    freeze_Q  = args.freeze_Q
    neighbor  = args.neighbor

    # Stash selection λ globally (keeps function signatures unchanged)
    if args.lambda_select is not None:
        LAMBDA_SELECT = args.lambda_select

    # Normalize "all"/"*" -> None (run all strategies)
    if strategy is not None and strategy.lower() in ('all', '*'):
        strategy = None

    # Resolve the selection λ the same way multi_seed_experiment does: a
    # missing or None global LAMBDA_SELECT falls back to the training λ
    # instead of raising NameError when --lambda_select was not given.
    _global_select = globals().get('LAMBDA_SELECT')
    _sel_src = _global_select if _global_select is not None else lambda_v

    # Auto default for recompute_w_each_iter: true if selection λ>0 (else false)
    if args.recompute_w_each_iter is None:
        recompute_w_each_iter = (_sel_src > 0)
        _how = "auto-set"
    else:
        recompute_w_each_iter = args.recompute_w_each_iter
        _how = "explicitly set"
    logger.info(f"[Explainability] RECOMPUTE_W_EACH_ITER {_how} to {recompute_w_each_iter} "
                f"(λ_train={lambda_v}, λ_select={_sel_src}, θ={theta})")

    # Run
    multi_seed_experiment(
        lambda_v,
        strategy_input=strategy,
        seeds=None,
        n_iter=num_iter,
        results_folder=None,
        dataset=dataset,
        freeze_Q=freeze_Q,
        theta=theta,
        neighbor=neighbor,
        recompute_w_each_iter=recompute_w_each_iter
    )

    # Final cache cleanup (best effort; failures are logged, never raised)
    try:
        retrain_online_exp_gpu_clear_cache()
    except Exception as exc:
        logger.debug(f"retrain_online_exp_gpu_clear_cache skipped: {exc!r}")
    try:
        clear_qdot_cache()
    except Exception as exc:
        logger.debug(f"clear_qdot_cache skipped: {exc!r}")


# END

# PLOT the Baselines

### Metrics Over Iterations

In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import logging

# ---------- Logging ----------
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------- Experiment selectors (EDIT THESE TO MATCH THE RUN YOU WANT TO PLOT) ----------
LAMBDA_TRAIN = 0.005   # 0.0 for MF; 0.005 for EMF
LAMBDA_SELECT = 0.5
dataset = '100k'
NEIGHBOR = 20
TopN = 10              # value shown in the plot title
SWITCH = 5             # x-position of the optional Min→Max guide line

neighbor = NEIGHBOR
logger.info(
    f"Generating enhanced plots for NEIGHBOR = {neighbor} | "
    f"λ_train={LAMBDA_TRAIN}, λ_select={LAMBDA_SELECT}, dataset={dataset}, TopN={TopN}"
)

# ---------- Styles ----------
# One row per method: (name, color, marker, linestyle). Row order is the
# plotting order, because csv_files below is built by iterating `styles`.
_STYLE_ROWS = (
    ('KARIMI',            'blue',    'o', '-'),
    ('Random',            'green',   's', '-'),
    ('HighestVar',        'cyan',    '^', '-'),
    ('HighestPred',       'magenta', 'v', '-'),
    ('HighestConfidence', 'purple',  '>', '-'),
    ('Uncertainty',       'orange',  'x', '-'),
    ('EXAL-Min',          'gold',    'D', '-'),
    ('EXAL-Max',          'black',   '*', '--'),
    ('EXAL-Min-Max',      'red',     'X', '--'),
)
styles = {
    _name: {'color': _color, 'marker': _marker, 'linestyle': _line}
    for _name, _color, _marker, _line in _STYLE_ROWS
}

# Legend grouping: category title -> methods listed under it.
method_categories = dict([
    ('Baselines', ['KARIMI', 'Random', 'HighestVar', 'HighestPred', 'HighestConfidence', 'Uncertainty']),
    ('Original ExAL', ['EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max']),
])

# ---------- Metrics to plot (must match iteration_df columns) ----------
# One row per metric: (column, axis title, which direction is better).
_METRIC_ROWS = (
    ('MAP',     'Mean Average Precision (MAP)', 'higher'),
    ('MEP',     'Explainable Precision (MEP)',  'higher'),
    ('MER',     'Explainable Recall (MER)',     'higher'),
    ('F-Score', 'Explainable F1',               'higher'),
    ('MAE',     'Mean Absolute Error (MAE)',    'lower'),
)
_ARROW_FOR = {'higher': '↑', 'lower': '↓'}

labels = {
    _col: f"{_title} {_ARROW_FOR[_better]} better"
    for _col, _title, _better in _METRIC_ROWS
}

direction_info = {
    _col: {'better': _better, 'arrow': _ARROW_FOR[_better]}
    for _col, _title, _better in _METRIC_ROWS
}

# ---------- Folders / files ----------
_run_root = f"Results_{neighbor}"
results_folder = f"{_run_root}/seeds_results"
plot_folder = f"{_run_root}/Plots_Enhanced_Metrics"
summary_folder = f"{_run_root}/summaries"
for _folder in (plot_folder, summary_folder):
    os.makedirs(_folder, exist_ok=True)


def _avg_csv_path(method_name):
    """Return the path of the seed-averaged CSV for one method."""
    return os.path.join(
        results_folder,
        f"AVG_{method_name}_lambdaTrain_{LAMBDA_TRAIN}_lambdaSel_{LAMBDA_SELECT}_{dataset}.csv"
    )


csv_files = dict(zip(styles, map(_avg_csv_path, styles)))

# ---------- Helpers ----------
def ensure_zero_row_generic(df, metric):
    """Guarantee a row with Iteration == 0 by duplicating the earliest row.

    The frame is sorted by ``Iteration`` first. If no row has ``Iteration == 0``,
    a copy of the first (lowest-iteration) row is added with its ``Iteration``
    set to 0. An empty frame is returned unchanged.

    Args:
        df: Frame with an ``Iteration`` column.
        metric: Unused; kept so the signature matches ``ensure_zero_row``.

    Returns:
        A frame sorted by ``Iteration`` with a fresh 0..n-1 index.
    """
    if df.empty:
        return df
    df = df.sort_values('Iteration').reset_index(drop=True)
    if (df['Iteration'] == 0).any():
        return df
    # Slice with a list so the copy stays a one-row DataFrame (column dtypes kept).
    first = df.iloc[[0]].copy()
    first['Iteration'] = 0
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    return (pd.concat([first, df], ignore_index=True)
            .sort_values('Iteration').reset_index(drop=True))

def ensure_zero_row(df, metric):
    """Add a synthetic Iteration-0 row, except for MAP and NDCG.

    For MAP/NDCG the frame is only sorted by ``Iteration`` and re-indexed, so
    no artificial starting point is added. Every other metric is delegated to
    ``ensure_zero_row_generic``.
    """
    if metric not in ('MAP', 'NDCG'):
        return ensure_zero_row_generic(df, metric)
    ordered = df.sort_values('Iteration')
    return ordered.reset_index(drop=True)

def create_organized_legend(ax, dfs):
    """Attach a legend to ``ax`` grouped by the entries of ``method_categories``.

    Only methods present in both ``dfs`` and ``styles`` are listed. Each
    non-empty category gets an invisible header entry followed by one styled
    entry per method. Returns the legend object created by ``ax.legend``.
    """
    handles = []
    for group_name, group_methods in method_categories.items():
        shown = [name for name in group_methods if name in dfs and name in styles]
        if not shown:
            continue
        # Header row: a line with no colour, so only its label is visible.
        handles.append(plt.Line2D([0], [0], color='none', label=f'─── {group_name} ───'))
        handles.extend(
            plt.Line2D(
                [0], [0],
                color=styles[name]['color'],
                marker=styles[name]['marker'],
                linestyle=styles[name]['linestyle'],
                linewidth=2,
                markersize=7,
                label=name,
            )
            for name in shown
        )
    legend = ax.legend(handles=handles, fontsize=22,
                       loc='center left', bbox_to_anchor=(1, 0.5))
    return legend

def highlight_best_performers(ax, dfs, metric):
    """Circle the method with the best final value of ``metric`` on ``ax``.

    "Best" is the smallest final value when ``direction_info[metric]['better']``
    is ``'lower'`` and the largest otherwise. Empty frames are ignored.

    Returns:
        ``(best_method, best_value)``, or ``(None, None)`` when there is
        nothing to compare.
    """
    if not dfs:
        return None, None

    last_values = {}
    for name, frame in dfs.items():
        if frame.empty:
            continue
        last_values[name] = frame[metric].iloc[-1]
    if not last_values:
        return None, None

    pick = min if direction_info[metric]['better'] == 'lower' else max
    winner = pick(last_values, key=last_values.get)
    winner_value = last_values[winner]

    if winner in dfs:
        winner_frame = dfs[winner]
        ax.scatter(winner_frame['Iteration'].iloc[-1], winner_value,
                   s=150, facecolors='none', edgecolors='red', linewidths=4,
                   label=f'Best: {winner}')
    return winner, winner_value

def annotate_final_points(ax, dfs, metric, fontsize=22):
    """Write each method's final ``metric`` value next to its last point.

    Args:
        ax: Matplotlib axes to annotate.
        dfs: Mapping of method name -> frame with ``Iteration`` and ``metric`` columns.
        metric: Column whose last value is printed (four decimals).
        fontsize: Font size of the annotation text.

    Empty frames are skipped (there is no last point to label), matching the
    way ``highlight_best_performers`` ignores them.
    """
    for frame in dfs.values():
        if frame.empty:
            continue
        x = frame['Iteration'].iloc[-1]
        y = frame[metric].iloc[-1]
        # Offset the text 5 points to the right of the marker.
        ax.annotate(f"{y:.4f}", (x, y), xytext=(5, 0), textcoords='offset points', fontsize=fontsize)

def analyze_trends(dfs, metric):
    """Summarise first-to-last change of ``metric`` for every method.

    Frames with fewer than two rows are skipped. The improvement is
    ``first - last`` for lower-is-better metrics and ``last - first``
    otherwise; a strictly positive improvement is labelled ``'improving'``,
    anything else ``'declining'``.

    Returns:
        dict of method -> {'improvement', 'trend', 'first', 'last'}.
    """
    lower_wins = direction_info[metric]['better'] == 'lower'
    summary = {}
    for name, frame in dfs.items():
        if len(frame) >= 2:
            series = frame.sort_values('Iteration')[metric]
            start, end = series.iloc[0], series.iloc[-1]
            gain = (start - end) if lower_wins else (end - start)
            if gain > 0:
                verdict = 'improving'
            else:
                verdict = 'declining'
            summary[name] = {'improvement': gain, 'trend': verdict,
                             'first': start, 'last': end}
    return summary

def sanity_check_metric(metric='MAP'):
    """Log the final value of ``metric`` for every averaged CSV that has it.

    Walks ``csv_files``, reads each existing file, and collects the last
    non-NaN value of ``metric``. Results are logged best-first according to
    ``direction_info``. Files that fail to load are reported as warnings.
    """
    hits = []
    for name, csv_path in csv_files.items():
        if not os.path.isfile(csv_path):
            continue
        try:
            frame = pd.read_csv(csv_path)
            usable = (metric in frame.columns
                      and 'Iteration' in frame.columns
                      and not frame.empty)
            if not usable:
                continue
            frame = frame[['Iteration', metric]].dropna()
            if frame.empty:
                continue
            hits.append((name, float(frame[metric].iloc[-1]), csv_path))
        except Exception as e:
            logger.warning(f"Failed reading {name} from {csv_path}: {e}")

    if not hits:
        logger.warning(f"No files with metric '{metric}' found for current selectors.")
        return

    higher_is_better = direction_info[metric]['better'] == 'higher'
    ranked = sorted(hits, key=lambda entry: entry[1], reverse=higher_is_better)
    logger.info(f"=== sanity_check_metric({metric}) ===")
    for m, v, p in ranked:
        logger.info(f"{m:20s}  final={v:.5f}   {p}")

# ---------- Plot all metrics ----------
# For every metric in `labels`: load each method's averaged CSV, draw one
# curve per method, mark the best final value, save the figure as PNG/PDF/SVG,
# log per-method trends, and remember the best method for the summary CSV.
best_rows = []

for metric, ylabel in labels.items():
    plt.figure(figsize=(16, 10))
    dfs = {}
    # Running min/max over all loaded curves; used for the y-axis limits below.
    min_val, max_val = np.inf, -np.inf

    # Load
    for method, path in csv_files.items():
        if not os.path.isfile(path):
            logger.debug(f"Missing file for {method}: {path}")
            continue
        try:
            df = pd.read_csv(path)
            if metric not in df.columns or 'Iteration' not in df.columns:
                logger.debug(f"Metric '{metric}' not in file for {method}: {path}")
                continue
            # Keep only the two needed columns and drop NaN / infinite values.
            df = df[['Iteration', metric]].dropna()
            df = df[np.isfinite(df[metric])]
            if df.empty:
                continue
            # May prepend a synthetic Iteration-0 row (not for MAP/NDCG).
            df = ensure_zero_row(df, metric)
            dfs[method] = df
            min_val = min(min_val, df[metric].min())
            max_val = max(max_val, df[metric].max())
        except Exception as e:
            logger.warning(f"Error loading {method} from {path}: {e}")

    # Log what actually loaded + finals
    loaded = list(dfs.keys())
    logger.info(f"[{metric}] loaded methods: {loaded}")
    if loaded:
        finals = {m: float(dfs[m][metric].iloc[-1]) for m in loaded}
        higher_better = (direction_info[metric]['better'] == 'higher')
        finals_sorted = sorted(finals.items(), key=lambda kv: kv[1], reverse=higher_better)
        logger.info(f"[{metric}] final values:")
        for m, v in finals_sorted:
            logger.info(f"  {m:20s} {v:.5f}")
    else:
        # Nothing to plot for this metric: discard the empty figure and move on.
        logger.warning(f"No data for metric '{metric}' at NEIGHBOR={neighbor}")
        plt.close()
        continue

    # Axis range with a small margin (5% of the span, or 0.01 for a flat range)
    margin = 0.05 * (max_val - min_val) if max_val > min_val else 0.01
    ylow, yhigh = min_val - margin, max_val + margin

    ax = plt.gca()

    # Optional vertical guide at the Min-Max switch
    if SWITCH is not None:
        try:
            ax.axvline(SWITCH, linestyle=':', linewidth=2, alpha=0.6)
            ax.text(SWITCH, yhigh, f"  switch={SWITCH}", va='top', ha='left', fontsize=10)
        except Exception:
            # The guide is decorative; a failure here must not stop the plot.
            pass

    # Draw curves (EXAL methods get a slightly thicker, fully opaque line)
    for method, df in dfs.items():
        style = styles[method]
        lw = 3.5 if 'EXAL' in method else 3.0
        ms = 7
        alpha = 1.0 if 'EXAL' in method else 0.9
        ax.plot(df['Iteration'], df[metric],
                label=method,
                linewidth=lw,
                marker=style['marker'],
                color=style['color'],
                linestyle=style['linestyle'],
                markersize=ms,
                alpha=alpha)

    # Circle the best final value and print every curve's final value.
    best_m, best_v = highlight_best_performers(ax, dfs, metric)
    annotate_final_points(ax, dfs, metric, fontsize=10)

    # Labels / title
    plt.xlabel('Active Learning Iteration', fontsize=20, fontweight='bold')
    plt.ylabel(ylabel, fontsize=20, fontweight='bold')

    title_metric = ylabel  # already has (abbr) in labels dict
    plt.title(
        f"{title_metric}\n"
        f"Dataset: MovieLens-{dataset.upper()} | λ_train={LAMBDA_TRAIN}, λ_select={LAMBDA_SELECT} "
        f"| Top-N={TopN} | Neighbors={neighbor}",
        fontsize=17, fontweight='bold', pad=25
    )

    plt.grid(True, alpha=0.7, linestyle='-', linewidth=0.9)
    plt.ylim(ylow, yhigh)
    # One tick per iteration, from 0 up to the largest iteration seen in any curve.
    max_iter = max(int(df['Iteration'].max()) for df in dfs.values())
    plt.xticks(range(0, max_iter + 1), fontsize=16)
    plt.yticks(fontsize=14)

    create_organized_legend(ax, dfs)
    plt.tight_layout()

    # Save (same base name, three formats)
    base = os.path.join(
        plot_folder,
        f"Enhanced_{metric}_lambdaTrain_{LAMBDA_TRAIN}_lambdaSelect_{LAMBDA_SELECT}_{dataset}_k{neighbor}"
    )
    plt.savefig(base + ".png", dpi=300, bbox_inches='tight', facecolor='white')
    plt.savefig(base + ".pdf", bbox_inches='tight', facecolor='white')
    plt.savefig(base + ".svg", bbox_inches='tight', facecolor='white')
    logger.info(f"Saved plots: {base}.(png|pdf|svg)")
    plt.show()

    # Trend logging (largest improvement first)
    trends = analyze_trends(dfs, metric)
    logger.info(f"\n=== TRENDS: {metric} ===")
    for method, t in sorted(trends.items(), key=lambda x: x[1]['improvement'], reverse=True):
        logger.info(f"{method:20} {t['trend']:10} "
                    f"({t['first']:.4f} → {t['last']:.4f}, Δ={t['improvement']:+.4f})")

    # Record best-per-metric row
    if best_m is not None and best_v is not None:
        best_rows.append({
            'Metric': metric,
            'Better': direction_info[metric]['better'],
            'Best_Method': best_m,
            'Best_FinalValue': float(best_v)
        })

# ---------- Save a compact "best methods" summary ----------
# One row per metric that had data: metric, direction, best method, final value.
if best_rows:
    best_df = pd.DataFrame(best_rows)
    out_csv = os.path.join(
        summary_folder,
        f"best_methods_lambdaTrain_{LAMBDA_TRAIN}_lambdaSelect_{LAMBDA_SELECT}_{dataset}_k{neighbor}.csv"
    )
    best_df.to_csv(out_csv, index=False)
    logger.info(f"Saved best-per-metric summary: {out_csv}")
else:
    logger.info("No best-per-metric summary generated (no data).")

# ---------- Quick one-shot check (run once if you like) ----------
sanity_check_metric('MAP')


In [None]:
# =========================
# Pareto plots (final iteration means): MAP vs MEP
# =========================
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Where to save figures: created up front (including parents) so every
# savefig call below can write into it.
out_dir = Path("stat_results/dual")
out_dir.mkdir(parents=True, exist_ok=True)

# Compute per-(Method,Condition) means from df_last
def _final_means(df: pd.DataFrame) -> pd.DataFrame:
    """Average MAP and MEP per (Condition, Method), limited to DISPLAY_ORDER.

    Rows whose MAP or MEP is NaN or infinite are discarded before averaging.
    Methods not listed in ``DISPLAY_ORDER`` are removed from the result, and
    the remaining rows are sorted by Condition and then by the position of the
    method in ``DISPLAY_ORDER``.

    Args:
        df: Frame with ``Condition``, ``Method``, ``MAP`` and ``MEP`` columns.

    Returns:
        Frame with columns ``Condition``, ``Method`` (ordered categorical),
        ``MAP`` and ``MEP``; an empty frame with those columns if ``df`` is empty.
    """
    if df.empty:
        return pd.DataFrame(columns=["Condition", "Method", "MAP", "MEP"])
    # drop inf/nan and group
    d = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["MAP", "MEP"]).copy()
    g = d.groupby(["Condition", "Method"], as_index=False).agg(MAP=("MAP", "mean"), MEP=("MEP", "mean"))
    # Casting to a categorical turns methods outside DISPLAY_ORDER into NaN;
    # drop those rows so only the listed methods remain (otherwise they would
    # be kept as unnamed points).
    g["Method"] = pd.Categorical(g["Method"], categories=DISPLAY_ORDER, ordered=True)
    g = g.dropna(subset=["Method"])
    g = g.sort_values(["Condition", "Method"]).reset_index(drop=True)
    return g

# NOTE(review): df_last (and DISPLAY_ORDER, read inside _final_means) are not
# defined in this cell; presumably they come from the statistics cell, which
# would then have to be executed first. Confirm the intended cell order.
means = _final_means(df_last)

# Pareto-efficiency (maximize both MAP and MEP)
def pareto_mask(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """
    Flag the Pareto-efficient points when both X and Y are to be maximized.

    A point is kept unless some other point is at least as good on both
    objectives and strictly better on at least one. Exact duplicates do not
    dominate each other, so all of them are kept.

    Returns a boolean array with one entry per point.
    """
    efficient = np.ones(len(X), dtype=bool)
    for idx, (x_i, y_i) in enumerate(zip(X, Y)):
        no_worse = (X >= x_i) & (Y >= y_i)
        strictly_better = (X > x_i) | (Y > y_i)
        # The point itself is never strictly better than itself, so it cannot
        # count as its own dominator.
        if np.any(no_worse & strictly_better):
            efficient[idx] = False
    return efficient

def _plot_one_pareto(df_cond: pd.DataFrame, cond_label: str, filename: Path):
    """Scatter MAP vs MEP for one condition, draw its Pareto frontier, save it.

    Args:
        df_cond: Frame with ``MAP``, ``MEP`` and ``Method`` columns.
        cond_label: Text appended to the plot title.
        filename: Output image path; parent folders are created if needed.

    Prints a notice and returns without plotting when ``df_cond`` is empty.
    """
    if df_cond.empty:
        print(f"[PARETO] No data for {cond_label}; skipped.")
        return

    map_values = df_cond["MAP"].to_numpy()
    mep_values = df_cond["MEP"].to_numpy()
    efficient = pareto_mask(map_values, mep_values)
    # Sorted so the frontier is drawn as a tidy left-to-right line.
    frontier = df_cond[efficient].sort_values(["MAP", "MEP"])

    # Aesthetics: highlight ExALs
    exal_methods = {"EXAL-Min", "EXAL-Max", "EXAL-Min-Max"}
    points = list(zip(df_cond["MAP"], df_cond["MEP"], df_cond["Method"]))

    plt.figure(figsize=(7.6, 6.6))
    for x_val, y_val, name in points:
        if name in exal_methods:
            colour, glyph = "#1f77b4", "s"
        else:
            colour, glyph = "#7f7f7f", "o"
        plt.scatter(x_val, y_val, s=90, c=colour, marker=glyph,
                    edgecolor="black", linewidth=1.0, zorder=3, label=name)

    # draw Pareto frontier
    plt.plot(frontier["MAP"], frontier["MEP"], linestyle="-", linewidth=2.2,
             color="#d62728", zorder=2, label="Pareto frontier")

    # method names next to the points
    for x_val, y_val, name in points:
        plt.annotate(str(name), (x_val, y_val), xytext=(5, 5),
                     textcoords="offset points", fontsize=11)

    plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6)
    plt.xlabel("MAP (final-iteration mean)")
    plt.ylabel("MEP (final-iteration mean)")
    plt.title(f"Pareto: MAP vs MEP  —  {cond_label}")

    # Legend UNDER the plot; a label seen more than once keeps its last handle.
    legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
    by_label = {}
    for text, handle in zip(legend_labels, legend_handles):
        by_label[text] = handle
    plt.legend(by_label.values(), by_label.keys(),
               loc="upper center", bbox_to_anchor=(0.5, -0.18),
               ncol=4, frameon=True)
    plt.tight_layout()
    plt.gcf().subplots_adjust(bottom=0.25)

    filename.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"[PARETO] Saved: {filename}")

# Split and plot MF / EMF
# Rows are split on the literal Condition values "MF" and "EMF".
mf_means  = means[means["Condition"] == "MF"].copy()
emf_means = means[means["Condition"] == "EMF"].copy()

# The λ values in the titles are hard-coded text, not read from configuration.
_plot_one_pareto(mf_means,  "MF (λ_train=0, λ_select=0.5)", out_dir / "pareto_MF.png")
_plot_one_pareto(emf_means, "EMF (λ_train=0.005, λ_select=0.5)", out_dir / "pareto_EMF.png")

# Optional: both side-by-side with legend below the whole figure
def _side_by_side(mf_df, emf_df, filename: Path):
    """Draw the MF and EMF Pareto plots next to each other and save the figure.

    Args:
        mf_df: Frame with ``MAP``, ``MEP`` and ``Method`` columns for MF.
        emf_df: Same columns for EMF.
        filename: Output image path; parent folders are created if needed.

    A panel whose frame is empty is switched off. One legend, de-duplicated by
    label, is placed under the whole figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14.2, 6.6), sharex=False, sharey=False)
    exal_set = {"EXAL-Min", "EXAL-Max", "EXAL-Min-Max"}

    panels = zip(
        axes,
        [mf_df, emf_df],
        ["MF (λ_train=0, λ_select=0.5)", "EMF (λ_train=0.005, λ_select=0.5)"],
    )
    for ax, dfc, title in panels:
        if dfc.empty:
            ax.axis("off")
            continue

        X = dfc["MAP"].to_numpy()
        Y = dfc["MEP"].to_numpy()
        mask = pareto_mask(X, Y)
        front = dfc[mask].sort_values(["MAP", "MEP"])

        for map_v, mep_v, method in zip(dfc["MAP"], dfc["MEP"], dfc["Method"]):
            c  = "#1f77b4" if method in exal_set else "#7f7f7f"
            mk = "s"       if method in exal_set else "o"
            ax.scatter(map_v, mep_v, s=90, c=c, marker=mk,
                       edgecolor="black", linewidth=1.0, zorder=3, label=method)
            ax.annotate(str(method), (map_v, mep_v), xytext=(5, 5),
                        textcoords="offset points", fontsize=11)

        ax.plot(front["MAP"], front["MEP"], linestyle="-", linewidth=2.2,
                color="#d62728", zorder=2, label="Pareto frontier")
        ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.6)
        ax.set_xlabel("MAP (final mean)")
        ax.set_ylabel("MEP (final mean)")
        ax.set_title(title)

    # One combined legend under the figure (dedup: last handle per label wins)
    h1, l1 = axes[0].get_legend_handles_labels()
    h2, l2 = axes[1].get_legend_handles_labels()
    uniq = dict(zip(l1 + l2, h1 + h2))
    fig.legend(uniq.values(), uniq.keys(),
               loc="upper center", bbox_to_anchor=(0.5, -0.06),
               ncol=5, frameon=True)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.18, wspace=0.10)
    filename.parent.mkdir(parents=True, exist_ok=True)
    # Save and close this figure explicitly rather than whichever figure
    # happens to be current.
    fig.savefig(filename, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[PARETO] Saved: {filename}")

# Combined MF | EMF figure, written next to the two single-condition plots.
_side_by_side(mf_means, emf_means, out_dir / "pareto_MF_EMF_side_by_side.png")


## t-test significance vs. Random baseline 

In [None]:
# =========================
# Dual-condition stats (MF vs EMF) across ALL active-selection baselines
# =========================
import os, re, glob, warnings
from typing import Dict, List, Optional, Tuple, Iterable
import numpy as np
import pandas as pd
from scipy.stats import ttest_rel
from scipy.stats import t as student_t  # for one-/two-sided p and CI
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from matplotlib.patches import Rectangle

# Silences every warning for the rest of the session (not just this cell).
warnings.filterwarnings("ignore")

# =========================
# Configuration (edit here)
# =========================
# Lambda pairs that identify which result files belong to each condition.
MF_EXPECTED  = (0.0,   0.5)   # (lambda_train, lambda_select)
EMF_EXPECTED = (0.005, 0.5)
LAMBDA_EPS   = 1e-9           # tolerance used by _float_equal when matching lambdas

DEFAULT_DATASET = "100k"
NEIGHBOR        = 20
ALPHA           = 0.05
MIN_SAMPLE_SIZE = 3

# test tails and CI level
# TAIL options: "two-sided" | "less" | "greater"
# - For vs Random, keep two-sided (neutral)
# - For DiD, set to "less" to test EMF > MF on higher-is-better metrics (MEP/MAP)
# NOTE(review): the line above says "less" but TAIL_DID below is "greater".
# Which tail tests EMF > MF depends on the sign convention of the DiD
# difference, which is computed outside this block — confirm and align the
# comment with the value.
TAIL_VSR  = "two-sided"
TAIL_DID  = "greater"
CI_LEVEL  = 0.95   # used for reporting two-sided 95% CI around the mean difference

# Reduce serial-correlation inflation in multi-iteration windows by collapsing
# within each seed first (avg of (method - random) within the window), then t-test across seeds.
CLUSTER_BY_SEED = True

# Methods recognised by the loaders; load_seed_csvs drops any other method.
METHODS_UNIVERSE = [
    "EXAL-Min", "EXAL-Max", "EXAL-Min-Max",
    "KARIMI",
    "Uncertainty",
    "HighestPred", "HighestConfidence", "HighestVar",
    "Random",  # kept for discovery/alignments; plotted as last column
]
REFERENCE_METHOD = "Random"

ACTIVE_METRICS   = ["MEP", "MAP"]  # higher is better

# Heatmap color mode: "cohen_d" | "percent" | "p"
HEATMAP_MODE = "cohen_d"
D_CAP        = 1.5   # cap |d| for coloring
PCT_CAP      = 25.0  # cap |%Δ| for coloring

# Multiple-comparison correction
#   FAMILY_SCOPE: "by_row" (each heatmap row is its family) or "global"
#   METHOD: "holm" (Holm step-down) or "bonferroni"
FAMILY_SCOPE  = "by_row"
ADJUST_METHOD = "bonferroni"

# Output folder for the statistics artefacts.
OUT_DIR = os.path.join("stat_results", "dual")

# Human-readable titles for the row-selection strategies.
STRATEGY_LABELS = {
    "final_only": "Final iteration",
    "last_n":     "Last 5 iterations >=5 (near-convergence stability)",
    "first_only": "First iteration",
    # we'll create a custom label for iteration 5 at call time
    "iter5":      "Iteration 5",
}

# =========================
# Filename parsing helpers
# =========================
def _float_equal(a: float, b: float, eps=LAMBDA_EPS) -> bool:
    """Return True when ``a`` and ``b`` differ by no more than the tolerance.

    The tolerance is the larger of ``eps`` and ``eps`` scaled by
    ``max(1, |a|, |b|)``.
    """
    magnitude = max(1.0, abs(a), abs(b))
    tolerance = max(eps, eps * magnitude)
    gap = abs(a - b)
    return gap <= tolerance

# Shared pattern pieces for result-file names.
# A lambda value as written into the file name: optional sign, decimal digits,
# optional exponent.
_LAMBDA_NUM = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:e[-+]?\d+)?"
# <method>_lambdaTrain_<lt>_lambdaSel_<ls>_<dataset>
_NAME_CORE = (
    r"(?P<method>[^_]+)"
    + r"_lambdaTrain_(?P<lt>" + _LAMBDA_NUM + r")"
    + r"_lambdaSel_(?P<ls>" + _LAMBDA_NUM + r")"
    + r"_(?P<dataset>[^_]+)"
)

# Per-seed file: <core>_seed_<n>.csv
SEED_RX = re.compile(r"^" + _NAME_CORE + r"_seed_(?P<seed>\d+)\.csv$", re.IGNORECASE)
# Seed-averaged file: AVG_<core>.csv
AVG_RX = re.compile(r"^AVG_" + _NAME_CORE + r"\.csv$", re.IGNORECASE)


def parse_meta(basename: str):
    """Extract the run metadata encoded in a result-file name.

    Returns a dict with ``method``, ``lt``, ``ls``, ``dataset``, ``seed`` and
    ``is_avg``. ``seed`` is ``None`` and ``is_avg`` is ``True`` for averaged
    files. Returns ``None`` when the name matches neither pattern.
    """
    for pattern, averaged in ((SEED_RX, False), (AVG_RX, True)):
        hit = pattern.match(basename)
        if hit is None:
            continue
        return {
            "method": hit["method"],
            "lt": float(hit["lt"]),
            "ls": float(hit["ls"]),
            "dataset": hit["dataset"],
            "seed": None if averaged else int(hit["seed"]),
            "is_avg": averaged,
        }
    return None

def list_seeds_dirs(base_dirs: Iterable[str]) -> List[str]:
    """Return every ``Results_*/seeds_results`` directory below the given roots.

    Roots that are not existing directories are ignored. The result contains
    absolute paths, without duplicates, in sorted order.
    """
    matches = set()
    for candidate in base_dirs:
        top = os.path.abspath(candidate)
        if not os.path.isdir(top):
            continue
        for current, _subdirs, _files in os.walk(top):
            segments = os.path.normpath(current).split(os.sep)
            if len(segments) < 2:
                continue
            parent, leaf = segments[-2], segments[-1]
            if leaf == "seeds_results" and parent.startswith("Results_"):
                matches.add(current)
    return sorted(matches)

def discover_files(roots: Iterable[str]) -> pd.DataFrame:
    """List every ``*.csv`` directly inside the given folders, with parsed metadata.

    Each row holds ``root``, ``file`` and ``parsed``. Rows whose file name is
    understood by ``parse_meta`` also carry its fields; other rows have
    ``parsed=False`` and no metadata.
    """
    records = []
    for folder in roots:
        for csv_path in glob.glob(os.path.join(folder, "*.csv")):
            info = parse_meta(os.path.basename(csv_path))
            record = dict(root=folder, file=csv_path, parsed=bool(info))
            if info:
                record.update(info)
            records.append(record)
    return pd.DataFrame(records)

def classify(df: pd.DataFrame, dataset_filter: Optional[str]) -> Tuple[List[str], List[str]]:
    """Split ledger rows into MF and EMF file lists by their lambda settings.

    Args:
        df: Ledger as produced by ``discover_files`` (needs a ``parsed`` column
            and, for parsed rows, ``dataset``, ``lt``, ``ls`` and ``file``).
        dataset_filter: When not None, only rows whose ``dataset`` equals this
            value (string comparison) are considered.

    Returns:
        ``(mf_files, emf_files)``, each sorted. A row goes to MF when its
        ``(lt, ls)`` matches ``MF_EXPECTED``; otherwise to EMF when it matches
        ``EMF_EXPECTED``; otherwise it is dropped. Rows for averaged ("AVG_")
        files are not filtered out here.
    """
    mf: List[str] = []
    emf: List[str] = []
    if df.empty or "parsed" not in df.columns:
        return mf, emf

    use = df[df["parsed"].eq(True)]
    # If no file name parsed, the metadata columns (dataset/lt/ls) do not exist
    # in the ledger at all; bail out before touching them.
    if use.empty:
        return mf, emf

    if dataset_filter is not None:
        use = use[use["dataset"].astype(str) == str(dataset_filter)]

    for lt, ls, path in zip(use["lt"], use["ls"], use["file"]):
        if _float_equal(lt, MF_EXPECTED[0]) and _float_equal(ls, MF_EXPECTED[1]):
            mf.append(path)
        elif _float_equal(lt, EMF_EXPECTED[0]) and _float_equal(ls, EMF_EXPECTED[1]):
            emf.append(path)
    return sorted(mf), sorted(emf)

def load_seed_csvs(files: List[str]) -> Dict[str, pd.DataFrame]:
    """Load per-seed result CSVs and group them by method.

    Files are skipped when their name is not recognised by ``parse_meta``, when
    they are averaged ("AVG_") files, or when their method is not in
    ``METHODS_UNIVERSE`` (checked before reading, so such files are never
    opened). Files that cannot be read, lack an ``Iteration`` column, or cannot
    be tagged are skipped with a logged warning instead of silently.

    Each loaded frame gets ``DatasetTag``, ``Method`` and ``Seed`` as its first
    three columns.

    Returns:
        Mapping of method name to the row-wise concatenation of its seed frames.
    """
    per_method: Dict[str, List[pd.DataFrame]] = {}
    for f in files:
        meta = parse_meta(os.path.basename(f))
        if not meta or meta["is_avg"]:
            continue
        if meta["method"] not in METHODS_UNIVERSE:
            continue
        try:
            df = pd.read_csv(f)
            if "Iteration" not in df.columns:
                logger.warning("Skipping %s: no 'Iteration' column.", f)
                continue
            df.insert(0, "Method", meta["method"])
            df.insert(1, "Seed", meta["seed"])
            df.insert(0, "DatasetTag", meta["dataset"])
        except Exception as exc:
            # Best-effort loader: one bad file must not abort the whole run,
            # but the reason is recorded so missing seeds can be traced.
            logger.warning("Skipping %s: %s", f, exc)
            continue
        per_method.setdefault(meta["method"], []).append(df)

    return {m: pd.concat(frames, ignore_index=True) for m, frames in per_method.items()}

def to_long(per_method: Dict[str, pd.DataFrame], cond: str) -> pd.DataFrame:
    """Stack the per-method frames into one frame tagged with a ``Condition`` column.

    ``Condition`` is inserted as the first column of a copy of each frame, so the
    inputs are left untouched. An empty mapping yields an empty DataFrame.
    """
    if not per_method:
        return pd.DataFrame()
    tagged = []
    for frame in per_method.values():
        clone = frame.copy()
        clone.insert(0, "Condition", cond)
        tagged.append(clone)
    return pd.concat(tagged, ignore_index=True)

# =========================
# Row selection (time windows)
# =========================
def select_rows(df: pd.DataFrame, strategy: str, num_iterations=5, at_iter: Optional[int]=None) -> pd.DataFrame:
    """Pick the rows of each (Method, Seed) run that fall in the requested window.

    Strategies:
        ``final_only`` / ``first_only``: the row with the largest / smallest
            ``Iteration`` of each run.
        ``last_n`` / ``first_n``: the ``num_iterations`` rows with the largest /
            smallest ``Iteration`` of each run (ties broken by row order).
        ``exact_iter``: rows whose ``Iteration`` equals ``at_iter``.

    An empty input is returned unchanged (and unvalidated).

    Raises:
        ValueError: for ``exact_iter`` without ``at_iter`` or an unknown strategy.
    """
    if df.empty:
        return df

    ordered = df.sort_values(["Method", "Seed", "Iteration"]).copy()
    by_run = ordered.groupby(["Method", "Seed"])["Iteration"]

    if strategy in ("final_only", "first_only"):
        picks = by_run.idxmax() if strategy == "final_only" else by_run.idxmin()
        return ordered.loc[picks].copy()

    if strategy in ("last_n", "first_n"):
        # Rank ascending for the first window, descending for the last window.
        ordered["r"] = by_run.rank(ascending=(strategy == "first_n"), method="first")
        return ordered[ordered["r"] <= num_iterations].drop(columns="r")

    if strategy == "exact_iter":
        if at_iter is None:
            raise ValueError("exact_iter requires at_iter=<int>")
        return ordered[ordered["Iteration"] == at_iter].copy()

    raise ValueError(f"Unknown strategy: {strategy}")

# =========================
# Stats helpers (tails + CI)
# =========================
def _t_and_ci(sample: np.ndarray, mu0: float = 0.0,
              tail: str = "two-sided",
              ci_level: float = 0.95) -> Tuple[float, float, Tuple[float,float]]:
    """One-sample t-test of ``mean(sample) == mu0`` with a two-sided CI.

    Non-finite entries of ``sample`` are dropped before testing. The p-value
    and critical value come from the module-level ``student_t`` distribution
    object (its ``cdf``, ``sf`` and ``ppf`` methods).

    Args:
        sample: Vector of (already aligned) paired differences.
        mu0: Hypothesised mean.
        tail: "two-sided", "less" or "greater".
        ci_level: Coverage of the two-sided confidence interval.

    Returns:
        ``(t_stat, p_value, (ci_lo, ci_hi))`` where the CI is for
        ``mean(sample) - mu0``. With fewer than two finite values the result is
        ``(nan, 1.0, (nan, nan))``. With zero variance the statistic is
        +/-inf (or 0.0), the p-value is 0.0 or 1.0, and the CI collapses to a point.

    Raises:
        ValueError: if ``tail`` is not one of the three supported values. This is
            checked up front so the outcome does not depend on the data.
    """
    if tail not in ("two-sided", "less", "greater"):
        raise ValueError("tail must be one of {'two-sided','less','greater'}")

    x = np.asarray(sample, float)
    x = x[np.isfinite(x)]
    n = x.size
    if n < 2:
        return np.nan, 1.0, (np.nan, np.nan)

    shift = float(np.mean(x)) - mu0
    sd = float(np.std(x, ddof=1))

    if sd == 0:
        # Degenerate: all differences identical, so the outcome is certain.
        if shift > 0:
            t_stat = np.inf
        elif shift < 0:
            t_stat = -np.inf
        else:
            t_stat = 0.0
        rejects = {
            "two-sided": shift != 0,
            "less": shift < 0,
            "greater": shift > 0,
        }[tail]
        return t_stat, (0.0 if rejects else 1.0), (shift, shift)

    se = sd / np.sqrt(n)
    t_stat = shift / se
    dof = n - 1

    if tail == "two-sided":
        p_val = 2.0 * min(student_t.cdf(t_stat, dof), student_t.sf(t_stat, dof))
    elif tail == "less":
        p_val = student_t.cdf(t_stat, dof)
    else:  # "greater"
        p_val = student_t.sf(t_stat, dof)

    # Two-sided CI at ci_level, regardless of the test's tail.
    alpha = 1.0 - ci_level
    t_crit = student_t.ppf(1 - alpha / 2, dof)
    ci_lo = shift - t_crit * se
    ci_hi = shift + t_crit * se
    return float(t_stat), float(p_val), (float(ci_lo), float(ci_hi))

def _keyed_df(vals, seeds, iters):
    """Assemble a ``seed`` / ``iter`` / ``val`` frame from three parallel sequences.

    ``seeds`` and ``iters`` are converted to integers and ``vals`` to floats.
    """
    columns = {
        "seed": np.asarray(seeds, dtype=int),
        "iter": np.asarray(iters, dtype=int),
        "val": np.asarray(vals, dtype=float),
    }
    return pd.DataFrame(columns)

def _delta_table(df_slice: pd.DataFrame, method: str, metric: str) -> pd.DataFrame:
    """Pair ``method`` with the reference method on (Seed, Iteration).

    Rows of ``method`` and of ``REFERENCE_METHOD`` are inner-joined on seed and
    iteration; pairs where either metric value is NaN or infinite are dropped.

    Returns:
        Frame with columns ``Seed``, ``Iteration``, ``Delta`` (method minus
        reference) and ``RefMean`` (the reference value of that pair). Empty,
        with the same columns, when ``metric`` is not a column of the slice.
    """
    out_cols = ["Seed", "Iteration", "Delta", "RefMean"]
    method_rows = df_slice[df_slice["Method"] == method]
    ref_rows = df_slice[df_slice["Method"] == REFERENCE_METHOD]
    if metric not in method_rows.columns or metric not in ref_rows.columns:
        return pd.DataFrame(columns=out_cols)

    left = _keyed_df(method_rows[metric].astype(float).to_numpy(),
                     method_rows["Seed"], method_rows["Iteration"])
    right = _keyed_df(ref_rows[metric].astype(float).to_numpy(),
                      ref_rows["Seed"], ref_rows["Iteration"])

    joined = pd.merge(left, right, on=["seed", "iter"], how="inner", suffixes=("_m", "_r"))
    joined = joined.replace([np.inf, -np.inf], np.nan)
    joined = joined.dropna(subset=["val_m", "val_r"])
    joined = joined.rename(columns={"seed": "Seed", "iter": "Iteration"})
    joined["Delta"] = joined["val_m"] - joined["val_r"]
    joined["RefMean"] = joined["val_r"]
    return joined[out_cols]

def _collapse_by_seed(deltas: pd.DataFrame) -> pd.DataFrame:
    """Reduce a delta table to one row per seed by averaging.

    Returns a frame with ``Seed``, mean ``Delta`` and mean ``RefMean``; an empty
    input is handed back unchanged.
    """
    if deltas.empty:
        return deltas
    per_seed = deltas.groupby("Seed", as_index=False)[["Delta", "RefMean"]].mean()
    return per_seed

def paired_vs_random(df_slice: pd.DataFrame, method: str, metric: str):
    """Paired t-test of ``method`` against the reference method on one metric.

    Differences come from ``_delta_table``. When ``CLUSTER_BY_SEED`` is set they
    are first averaged within each seed, so the test runs across seeds.

    Returns:
        ``(n, mean_diff, t_stat, p_raw, cohen_d, pct_change, ci_lo, ci_hi)``.
        ``cohen_d`` is mean difference over its sample standard deviation and
        ``pct_change`` is the mean difference relative to the absolute mean
        reference value, in percent. When no pairs exist or ``n`` is below
        ``MIN_SAMPLE_SIZE`` the statistics are NaN and ``p_raw`` is 1.0.
    """
    def _no_result(count):
        return count, np.nan, np.nan, 1.0, np.nan, np.nan, np.nan, np.nan

    per_pair = _delta_table(df_slice, method, metric)
    if per_pair.empty:
        return _no_result(0)

    sample = _collapse_by_seed(per_pair) if CLUSTER_BY_SEED else per_pair
    diffs = sample["Delta"].to_numpy()
    ref_vals = sample["RefMean"].to_numpy()
    n_obs = len(sample)
    if n_obs < MIN_SAMPLE_SIZE:
        return _no_result(n_obs)

    mean_diff = float(np.mean(diffs))
    sd_diff = float(np.std(diffs, ddof=1))
    ref_level = float(np.mean(ref_vals))
    t_stat, p_raw, (ci_lo, ci_hi) = _t_and_ci(diffs, mu0=0.0, tail=TAIL_VSR, ci_level=CI_LEVEL)

    if sd_diff != 0:
        effect = mean_diff / sd_diff
    else:
        effect = np.nan
    if abs(ref_level) > 1e-12:
        pct_change = mean_diff / abs(ref_level) * 100
    else:
        pct_change = np.nan

    return (n_obs, mean_diff, float(t_stat), float(p_raw),
            float(effect), float(pct_change), float(ci_lo), float(ci_hi))

def diff_in_diff(df_mf: pd.DataFrame, df_emf: pd.DataFrame, method: str, metric: str):
    """Test whether (MF - Random) and (EMF - Random) differ for one method/metric.

    The two delta tables are inner-joined on (Seed, Iteration). When
    ``CLUSTER_BY_SEED`` is set both deltas are averaged within each seed before
    the contrast is taken, so the test runs across seeds.

    Returns:
        ``(n, mean_diff, t_stat, p_raw, cohen_d, ci_lo, ci_hi)``. When nothing
        can be paired or ``n`` is below ``MIN_SAMPLE_SIZE`` the statistics are
        NaN and ``p_raw`` is 1.0.
    """
    def _no_result(count):
        return count, np.nan, np.nan, 1.0, np.nan, np.nan, np.nan

    mf_deltas = _delta_table(df_mf, method, metric)
    emf_deltas = _delta_table(df_emf, method, metric)
    if mf_deltas.empty or emf_deltas.empty:
        return _no_result(0)

    keys = ["Seed", "Iteration"]
    aligned = pd.merge(mf_deltas[keys + ["Delta"]],
                       emf_deltas[keys + ["Delta"]],
                       on=keys, suffixes=("_mf", "_emf"))
    if aligned.empty:
        return _no_result(0)

    if CLUSTER_BY_SEED:
        per_seed = aligned.groupby("Seed", as_index=False)[["Delta_mf", "Delta_emf"]].mean()
        contrast = per_seed["Delta_mf"].to_numpy() - per_seed["Delta_emf"].to_numpy()
        n_obs = len(per_seed)
    else:
        contrast = aligned["Delta_mf"].to_numpy() - aligned["Delta_emf"].to_numpy()
        n_obs = contrast.size

    if n_obs < MIN_SAMPLE_SIZE:
        return _no_result(n_obs)

    mean_diff = float(np.mean(contrast))
    sd_diff = float(np.std(contrast, ddof=1))
    t_stat, p_raw, (ci_lo, ci_hi) = _t_and_ci(contrast, mu0=0.0, tail=TAIL_DID, ci_level=CI_LEVEL)
    if sd_diff != 0:
        effect = mean_diff / sd_diff
    else:
        effect = np.nan
    return n_obs, mean_diff, float(t_stat), float(p_raw), float(effect), float(ci_lo), float(ci_hi)

# =========================
# Multiple-comparison adjustment
# =========================
def _adjust_series(p: pd.Series, method: str) -> pd.Series:
    """Adjust a family of p-values for multiple comparisons.

    Args:
        p: Raw p-values (any sequence; a Series keeps its index).
        method: "bonferroni" or "holm" (case-insensitive). Any other value
            returns the p-values unadjusted.

    Returns:
        Float Series aligned with ``p``. NaN entries stay NaN and do not count
        towards the family size (the same convention as R's ``p.adjust``);
        previously the Holm branch assigned them the running maximum.
    """
    p = pd.Series(p, dtype=float)
    if len(p) == 0:
        return p

    mth = method.lower()
    if mth not in ("bonferroni", "holm"):
        return p  # no adjustment

    raw = p.to_numpy()
    tested = ~np.isnan(raw)
    m = int(tested.sum())
    adjusted = np.full(raw.shape, np.nan)

    if mth == "bonferroni":
        adjusted[tested] = np.minimum(raw[tested] * m, 1.0)
    else:
        # Holm step-down: the k-th smallest p is scaled by (m - k), capped at 1,
        # then made monotone with a running maximum.
        positions = np.flatnonzero(tested)
        order = positions[np.argsort(raw[positions], kind="stable")]
        stepped = np.minimum(raw[order] * (m - np.arange(m)), 1.0)
        adjusted[order] = np.maximum.accumulate(stepped)

    return pd.Series(adjusted, index=p.index)

def adjust_table(df: pd.DataFrame, alpha: float, family: str, method: str) -> pd.DataFrame:
    """Add adjusted p-values and verdict columns to the vs-Random test table.

    With ``family == "by_row"`` the correction is applied separately within each
    Metric@(Condition) group (and a ``Metric_Tag`` column is added); any other
    value treats all rows as one family.

    Adds ``P_Adj``, ``Is_Significant`` (``P_Adj < alpha``), ``Direction``
    (sign of ``Mean_Diff``) and ``Performance`` (the direction when significant,
    otherwise "No Difference"). An empty input is returned unchanged.
    """
    if df.empty:
        return df

    result = df.copy()
    if family != "by_row":
        result["P_Adj"] = _adjust_series(result["P_Raw"], method).values
    else:
        result["Metric_Tag"] = (
            result["Metric"].astype(str) + "@(" + result["Condition"].astype(str) + ")"
        )
        result["P_Adj"] = np.nan
        groups = result.groupby("Metric_Tag", sort=False).groups
        for labels in groups.values():
            raw = result.loc[labels, "P_Raw"]
            result.loc[labels, "P_Adj"] = _adjust_series(raw, method).values

    result["Is_Significant"] = result["P_Adj"] < alpha
    gain = result["Mean_Diff"] > 0
    loss = result["Mean_Diff"] < 0
    result["Direction"] = np.select([gain, loss], ["Better", "Worse"], default="No Change")
    result["Performance"] = np.where(result["Is_Significant"], result["Direction"], "No Difference")
    return result

def adjust_did(df: pd.DataFrame, alpha: float, method: str) -> pd.DataFrame:
    """Correct the DiD p-values as one family and flag significant rows.

    Returns a copy of ``df`` with ``P_Adj`` and ``Is_Significant``
    (``P_Adj < alpha``); an empty input still gains both (empty) columns.
    """
    if df.empty:
        return df.assign(P_Adj=[], Is_Significant=[])
    adjusted = _adjust_series(df["P_Raw"], method).values
    return df.assign(P_Adj=adjusted, Is_Significant=adjusted < alpha)

# =========================
# Heatmap (fixed 4 rows) with REF
# =========================
def star_for(p_adj: float) -> str:
    """Translate an adjusted p-value into a significance marker.

    "***" below 0.001, "**" below 0.01, "*" below 0.05, otherwise "ns";
    NaN and infinite values are also "ns".
    """
    if np.isfinite(p_adj):
        for cutoff, marker in ((1e-3, "***"), (1e-2, "**"), (5e-2, "*")):
            if p_adj < cutoff:
                return marker
    return "ns"

def p_to_strength(p_adj: float, cap: float = 6.0) -> float:
    """Convert an adjusted p-value into a colour strength in [0, 1].

    The strength is ``-log10(p_adj)`` clipped to ``[0, cap]`` and divided by
    ``cap``. ``p_adj <= 0`` maps to full strength (1.0). NaN or infinite
    p-values carry no evidence and map to 0.0, in line with ``star_for``
    labelling them "ns"; previously they were drawn at full strength.
    """
    if not np.isfinite(p_adj):
        return 0.0
    if p_adj <= 0:
        return 1.0
    # Lower clip guards against p > 1 flipping the colour sign.
    evidence = min(max(-np.log10(p_adj), 0.0), cap)
    return float(evidence / cap)

def color_value(row, mode: str) -> float:
    """Compute the heatmap colour value for one test row.

    Modes:
        "cohen_d": ``Cohen_d`` clamped to +/-``D_CAP`` and divided by ``D_CAP``.
        "percent": ``Percent_Change`` clamped to +/-``PCT_CAP`` and divided by it.
        anything else ("p" mode): sign of ``Mean_Diff`` times
            ``p_to_strength(P_Adj)``.

    Non-finite effect sizes or percentages give 0.0.
    """
    def _clamped_ratio(value, limit):
        return max(min(value, limit), -limit) / limit

    mean_diff = row["Mean_Diff"]
    if mean_diff > 0:
        direction = 1.0
    elif mean_diff < 0:
        direction = -1.0
    else:
        direction = 0.0

    if mode == "cohen_d":
        effect = row["Cohen_d"]
        return _clamped_ratio(effect, D_CAP) if np.isfinite(effect) else 0.0
    if mode == "percent":
        percent = row["Percent_Change"]
        return _clamped_ratio(percent, PCT_CAP) if np.isfinite(percent) else 0.0
    return direction * p_to_strength(row["P_Adj"])  # "p" mode

def make_heatmap(
    df_dual: pd.DataFrame,
    strategy: str,
    out_dir: str,
    dataset_tag: str,
    methods_order: List[str],
    show: bool = True,
    title_override: Optional[str] = None
):
    """Draw and save the fixed four-row significance heatmap for one window.

    Rows are MEP/MAP under MF/EMF; columns are the methods of ``methods_order``
    that occur in ``df_dual`` followed by the reference method, which is drawn
    as a neutral "REF" column. Cell colour comes from ``color_value`` in
    ``HEATMAP_MODE``; cell text is the significance stars plus the percent change.

    Args:
        df_dual: Adjusted test table (uses ``Method``, ``Metric``, ``Condition``,
            ``Mean_Diff``, ``P_Adj`` and, when present, ``Percent_Change`` and
            ``Sample_Size``).
        strategy: Window key, used in the title and the file name.
        out_dir: Output directory (created if missing).
        dataset_tag: Dataset name used in the file name.
        methods_order: Column order; methods not listed here are not drawn.
        show: Whether to call ``plt.show()`` before closing the figure.
        title_override: Replaces the default strategy label in the title.

    The figure is written as PNG, PDF and SVG.
    """
    os.makedirs(out_dir, exist_ok=True)

    ordered_rows = ["MEP@(MF)","MEP@(EMF)","MAP@(MF)","MAP@(EMF)"]
    present = set(df_dual["Method"].unique())
    non_ref = [m for m in methods_order if m in present and m != REFERENCE_METHOD]
    methods = non_ref + [REFERENCE_METHOD]
    col_of = {}
    for j, m in enumerate(methods):
        col_of.setdefault(m, j)

    # One n for the title: median of the finite per-test sample sizes.
    n_for_title = None
    if "Sample_Size" in df_dual.columns:
        n_vals = pd.to_numeric(df_dual["Sample_Size"], errors="coerce")
        n_vals = n_vals[np.isfinite(n_vals)]
        if not n_vals.empty:
            n_for_title = int(np.median(n_vals))

    H = np.full((len(ordered_rows), len(methods)), np.nan)
    ANNO = np.full((len(ordered_rows), len(methods)), "", dtype=object)

    for _, r in df_dual.iterrows():
        tag = f"{r['Metric']}@({r['Condition']})"
        if tag not in ordered_rows or r["Method"] == REFERENCE_METHOD:
            continue
        j = col_of.get(r["Method"])
        if j is None:
            # Method absent from methods_order: it has no column to draw into.
            continue
        i = ordered_rows.index(tag)
        H[i, j] = color_value(r, HEATMAP_MODE)
        pct = r.get("Percent_Change", np.nan)
        pct_text = "NA" if not np.isfinite(pct) else f"{pct:+.1f}%"
        ANNO[i, j] = f"{star_for(r['P_Adj'])}\n{pct_text}"

    # REF column (last)
    j_ref = methods.index(REFERENCE_METHOD)
    for i in range(len(ordered_rows)):
        H[i, j_ref] = 0.0
        ANNO[i, j_ref] = "REF"

    # Plot
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(13.5, 7.6))
    cmap = plt.get_cmap("RdYlGn")
    im = ax.imshow(H, cmap=cmap, norm=TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1))

    # annotations (white text on strongly coloured cells for contrast)
    for i in range(len(ordered_rows)):
        for j in range(len(methods)):
            txt = ANNO[i, j]
            val = H[i, j]
            color = "white" if np.isfinite(val) and abs(val) > 0.5 else "black"
            ax.text(j, i, txt, ha="center", va="center", fontsize=18, fontweight="bold", color=color)

    ax.set_xticks(np.arange(len(methods)))
    ax.set_xticklabels(methods, rotation=30, ha="right",fontsize=12,fontweight="bold")
    ax.set_yticks(np.arange(len(ordered_rows)))
    ax.set_yticklabels(ordered_rows, fontsize=12,fontweight="bold")
    ax.set_xlabel("Active Learning Methods (Reference = Random)",fontsize=14)
    ax.set_ylabel("Metric@(Condition)")

    # Blue border around REF
    ax.add_patch(Rectangle((j_ref-0.5, -0.5), 1, len(ordered_rows), fill=False, edgecolor="blue", linewidth=3))

    mode_detail_map = {
        "p":       rf"(color = signed $-\log_{{10}}$ adjusted p; correction = {ADJUST_METHOD.title()})",
        "cohen_d": "(color = signed paired Cohen’s d)",
        # Report the cap actually applied by color_value rather than a fixed 25.
        "percent": f"(color = signed %Δ vs Random; capped at ±{PCT_CAP:g}%)",
    }
    mode_label = mode_detail_map.get(HEATMAP_MODE, "")
    lt_mf, ls_mf   = MF_EXPECTED
    lt_emf, ls_emf = EMF_EXPECTED
    if _float_equal(ls_mf, ls_emf):
        cond_label = (
            rf"$\lambda_{{select}}={ls_mf}$"
            r"  |  "
            rf"MF: $\lambda_{{train}}={lt_mf}$"
            r"  |  "
            rf"EMF: $\lambda_{{train}}={lt_emf}$"
        )
    else:
        cond_label = (
            rf"MF: $\lambda_{{train}}={lt_mf},\,\lambda_{{select}}={ls_mf}$"
            r"  |  "
            rf"EMF: $\lambda_{{train}}={lt_emf},\,\lambda_{{select}}={ls_emf}$"
        )
    title_main = title_override or STRATEGY_LABELS.get(strategy, strategy)
    # Omit the suffix entirely when no sample size is known (avoids "n=None").
    n_suffix = f"  (paired samples: n={n_for_title})" if n_for_title is not None else ""

    ax.set_title(title_main + n_suffix + "\n" + mode_label + "\n" + cond_label,
                 fontsize=16, fontweight="bold")

    cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label("Heatmap scale (−1 worse · 0 no diff · +1 better)")

    legend_text = (
        r"$\mathbf{GREEN}$ = Better than Random | "
        r"$\mathbf{RED}$ = Worse than Random | "
        r"$\mathbf{YELLOW}$ is No difference" "\n"
        rf"Stars use $\mathbf{{{ADJUST_METHOD.title()}}}$-adjusted p-values: "
        r"$\mathbf{(***)}$<0.001, $\mathbf{(**)}$<0.01, $\mathbf{(*)}$<0.05, and $\mathbf{ns}$ = not significant."
    )
    fig.text(0.5, -0.02, legend_text, ha="center", fontsize=15.5,
             bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.85))
    fig.tight_layout(rect=[0,0.05,1,0.98])

    base = os.path.join(out_dir, f"heatmap_dual_FIXED4_{HEATMAP_MODE}_{strategy}_neighbor_{NEIGHBOR}_dataset_{dataset_tag}")
    for ext in ["png","pdf","svg"]:
        # Save through the figure handle so an unrelated current figure can never be written.
        fig.savefig(f"{base}.{ext}", dpi=300, bbox_inches="tight")
    print(f"[PLOT] Saved heatmap: {base}.[png|pdf|svg]")

    if show:
        plt.show()
    plt.close(fig)

# =========================
# Reporting
# =========================
def write_reports(df: pd.DataFrame, df_did: pd.DataFrame, strategy: str, out_dir: str):
    """Write the CSV reports for one window into ``out_dir``.

    Files written (``<s>`` is ``strategy``):
        ``detailed_dual_<s>.csv``: ``df`` as given.
        ``did_dual_<s>.csv``: ``df_did`` as given.
        ``summary_MF_<s>.csv`` / ``summary_EMF_<s>.csv``: per-method counts of
            Better / Worse / No Difference verdicts and the mean absolute
            percent change, reference method excluded.
        ``comparison_pivot_<s>.csv``: Method x Metric@(Condition) table of
            formatted cells.
    """
    os.makedirs(out_dir, exist_ok=True)

    detailed_csv = os.path.join(out_dir, f"detailed_dual_{strategy}.csv")
    df.to_csv(detailed_csv, index=False)
    print(f"[CSV] Detailed (MF & EMF vs Random): {detailed_csv}")

    did_csv = os.path.join(out_dir, f"did_dual_{strategy}.csv")
    df_did.to_csv(did_csv, index=False)
    print(f"[CSV] Difference-in-Differences: {did_csv}")

    def _share(count, total):
        return f"{count}/{total} ({(100*count/total):.1f}%)"

    for cond in ["MF","EMF"]:
        cond_rows = df[df["Condition"] == cond]
        seen = cond_rows["Method"].unique()
        summary = []
        for m in METHODS_UNIVERSE:
            if m not in seen or m == REFERENCE_METHOD:
                continue
            mdf = cond_rows[cond_rows["Method"] == m]
            if mdf.empty:
                continue
            verdicts = mdf["Performance"]
            tot = len(mdf)
            better = int((verdicts == "Better").sum())
            worse = int((verdicts == "Worse").sum())
            nodif = int((verdicts == "No Difference").sum())
            abs_pct = mdf["Percent_Change"].abs().replace([np.inf,-np.inf], np.nan).dropna()
            avg_abs_pct = float(abs_pct.mean())
            summary.append(dict(
                Method=m,
                Better=_share(better, tot),
                Worse=_share(worse, tot),
                No_Diff=_share(nodif, tot),
                Avg_abs_pct=f"{0 if np.isnan(avg_abs_pct) else avg_abs_pct:.2f}%"
            ))
        summary_csv = os.path.join(out_dir, f"summary_{cond}_{strategy}.csv")
        pd.DataFrame(summary).to_csv(summary_csv, index=False)
        print(f"[CSV] {cond} summary: {summary_csv}")

    def _cell(r):
        stars = star_for(r["P_Adj"])
        if np.isfinite(r["Percent_Change"]):
            pct = f"{r['Percent_Change']:+.1f}%"
        else:
            pct = "NA"
        if np.isfinite(r["Cohen_d"]):
            d = f"{r['Cohen_d']:.2f}"
        else:
            d = "NA"
        return f"{r['Performance']} | {pct} | {stars} | d={d}"

    table = df.copy()
    table["Metric_Tag"] = table["Metric"].astype(str) + "@(" + table["Condition"].astype(str) + ")"
    table["Cell"] = table.apply(_cell, axis=1)
    pivot = table.pivot_table(index="Method", columns="Metric_Tag", values="Cell",
                              aggfunc=lambda x: x.iloc[0])
    pivot = pivot.loc[:, ~pivot.columns.duplicated()]  # safety
    wanted = ["MEP@(MF)","MEP@(EMF)","MAP@(MF)","MAP@(EMF)"]
    pivot = pivot.reindex(columns=[c for c in wanted if c in pivot.columns])
    pivot_csv = os.path.join(out_dir, f"comparison_pivot_{strategy}.csv")
    pivot.to_csv(pivot_csv)
    print(f"[CSV] Pivot: {pivot_csv}")

# =========================
# Console analysis helpers
# =========================
def _pretty(sig: bool) -> str:
    """Render a significance flag as a check mark (truthy) or a dash (falsy)."""
    if sig:
        return "✓"
    return "–"

def _stars(p_adj: float) -> str:
    """Return the significance marker for an adjusted p-value.

    Delegates to ``star_for`` so the console tables and the heatmap always use
    the same thresholds (this function used to be a verbatim copy of it).
    """
    return star_for(p_adj)

def analyze_outputs(out_dir: str, strategy: str):
    """Reload the CSVs written for one window and print a console summary.

    Prints:
      (1) For MF and EMF, per metric in ``ACTIVE_METRICS``: a table comparing
          each non-reference method with Random, ordered Better, Worse, rest
          (larger absolute effect size first within Better/Worse).
      (2) A concise DiD (MF−R)−(EMF−R) summary listing significant rows.

    A CSV without any header (written from a column-less frame) is treated as an
    empty table instead of raising ``pandas.errors.EmptyDataError``.
    """
    def _read_or_empty(path):
        # pandas refuses to parse a file with no columns; treat it as "no rows".
        try:
            return pd.read_csv(path)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def _fmt_num(x, fmt):
        return "" if not np.isfinite(x) else format(x, fmt)

    def _rank(row):
        if row["Perf"] == "Better": return (0, -abs(float(row["dz"] or 0)))
        if row["Perf"] == "Worse":  return (1, -abs(float(row["dz"] or 0)))
        return (2, 0)

    det = os.path.join(out_dir, f"detailed_dual_{strategy}.csv")
    did = os.path.join(out_dir, f"did_dual_{strategy}.csv")
    print("\n[ANALYZE]", STRATEGY_LABELS.get(strategy, strategy))
    if not os.path.exists(det):
        print("  (No detailed CSV found.)")
        return

    df = _read_or_empty(det)

    # ----- A) Per-condition baseline vs Random tables -----
    if df.empty:
        print("  (Detailed table is empty.)")
    else:
        df = df.sort_values(["Condition", "Metric", "Method"]).reset_index(drop=True)
        for cond in ["MF", "EMF"]:
            sub = df[(df["Condition"] == cond)].copy()
            if sub.empty: continue
            print(f"\n  -- {cond} vs Random --")
            for metric in ACTIVE_METRICS:
                msub = sub[sub["Metric"] == metric].copy()
                if msub.empty: continue
                msub = msub[msub["Method"] != REFERENCE_METHOD]
                out_rows = []
                for _, r in msub.iterrows():
                    out_rows.append(dict(
                        Method   = r["Method"],
                        n        = int(r["Sample_Size"]),
                        MeanDiff = _fmt_num(r["Mean_Diff"], "+.4f"),
                        dz       = _fmt_num(r["Cohen_d"], ".2f"),
                        pctDelta = (_fmt_num(r["Percent_Change"], "+.1f") + "%") if np.isfinite(r["Percent_Change"]) else "",
                        p_adj    = _fmt_num(r.get("P_Adj", np.nan), ".3g"),
                        Sig      = _stars(r.get("P_Adj", np.nan)),
                        Perf     = r.get("Performance", "")
                    ))
                tab = pd.DataFrame(out_rows, columns=["Method","n","MeanDiff","dz","pctDelta","p_adj","Sig","Perf"])
                if tab.empty: continue
                tab["_order"] = tab.apply(_rank, axis=1)
                tab = tab.sort_values("_order").drop(columns="_order")
                print(f"     {metric}:")
                print(tab.to_string(index=False))

    # ----- B) DiD summary -----
    print("\n  -- DiD: (MF−R) − (EMF−R) --")
    if not os.path.exists(did):
        print("     (No DiD CSV found.)")
        return

    df_did = _read_or_empty(did)
    if df_did.empty:
        print("     No DiD rows.")
        return

    sig = df_did[df_did["Is_Significant"] == True].copy()
    if sig.empty:
        print("     No significant DiD.")
        return

    for metric in ACTIVE_METRICS:
        ss = sig[sig["Metric"] == metric]
        if ss.empty: continue
        pos = ss[ss["Mean_Diff"] > 0.0].sort_values("Cohen_d", ascending=False)
        neg = ss[ss["Mean_Diff"] < 0.0].sort_values("Cohen_d")
        if not pos.empty:
            names = ", ".join(f"{r.Method} (dz={r.Cohen_d:.2f}, p_adj={r.P_Adj:.3g})" for _, r in pos.iterrows())
            print(f"     {metric}: MF > EMF for {names}")
        if not neg.empty:
            names = ", ".join(f"{r.Method} (dz={r.Cohen_d:.2f}, p_adj={r.P_Adj:.3g})" for _, r in neg.iterrows())
            print(f"     {metric}: EMF > MF for {names}")

# =========================
# Overview table helpers (new)
# =========================
def _count_raw_pairs_for_window(mf_slice: pd.DataFrame, emf_slice: pd.DataFrame) -> int:
    """Count the raw (Seed, Iteration) pairs available for a window.

    For each slice the distinct (Seed, Iteration) pairs of the reference method
    are counted (0 for a missing or empty slice, or one without reference rows).
    The result is the smaller of the two counts when both are positive,
    otherwise whichever is positive, otherwise 0.
    """
    counts = []
    for window in (mf_slice, emf_slice):
        if window is None or window.empty:
            counts.append(0)
            continue
        ref_rows = window[window["Method"] == REFERENCE_METHOD]
        if ref_rows.empty:
            counts.append(0)
            continue
        counts.append(ref_rows.drop_duplicates(["Seed","Iteration"]).shape[0])

    positive = [c for c in counts if c]
    return min(positive) if positive else 0

def _overview_row(strategy_key: str, paired_df: pd.DataFrame, n_raw: int) -> dict:
    """Build one overview-table row from a window's adjusted test table.

    Args:
        strategy_key: Window key; shown through ``STRATEGY_LABELS`` when a label
            exists (now also for the empty case, so the ``Strategy`` column is
            labelled consistently).
        paired_df: Table with ``Is_Significant``, ``Performance`` and
            ``Percent_Change`` columns; None or empty yields an all-zero row.
        n_raw: Raw pair count reported in the ``n`` column.

    Returns:
        Dict with ``Strategy``, ``n``, ``SigTests`` ("sig/total"), ``SigPct``,
        ``Better``, ``Worse`` and ``AvgAbsPct`` (mean absolute finite percent
        change, two decimals).
    """
    label = STRATEGY_LABELS.get(strategy_key, strategy_key)
    if paired_df is None or paired_df.empty:
        return dict(Strategy=label, n=0, SigTests="0/0", SigPct="0.0%", Better=0, Worse=0, AvgAbsPct="0.00")

    total_tests = int(paired_df.shape[0])
    sig_tests   = int((paired_df["Is_Significant"] == True).sum())
    better      = int((paired_df["Performance"] == "Better").sum())
    worse       = int((paired_df["Performance"] == "Worse").sum())
    abs_pct     = paired_df["Percent_Change"].abs().replace([np.inf, -np.inf], np.nan).dropna()
    avg_abs_pct = 0.0 if abs_pct.empty else float(abs_pct.mean())

    return dict(
        Strategy   = label,
        n          = n_raw,
        SigTests   = f"{sig_tests}/{total_tests}",
        SigPct     = f"{(100.0*sig_tests/total_tests if total_tests else 0.0):.1f}%",
        Better     = better,
        Worse      = worse,
        AvgAbsPct  = f"{avg_abs_pct:.2f}"
    )

def _write_overview_table(rows: list, out_dir: str, filename_base: str = "overview_dual_windows"):
    """Save the overview rows as a CSV and as a LaTeX ``tabular`` snippet.

    Args:
        rows: Dicts with ``Strategy``, ``n``, ``SigTests``, ``SigPct``,
            ``Better``, ``Worse`` and ``AvgAbsPct``.
        out_dir: Output directory (created if missing).
        filename_base: File stem for ``<stem>.csv`` and ``<stem>.tex``.

    In the LaTeX output, unescaped ``%``, ``&`` and ``#`` in the Strategy and
    SigPct cells are backslash-escaped: a bare ``%`` (as in "25.0%") starts a
    LaTeX comment and would swallow the rest of the row. The CSV keeps the raw
    values. The ``.tex`` file is written as UTF-8.
    """
    import re

    def _tex(text) -> str:
        # Escape only characters that are never valid bare in a table cell and
        # that are not already escaped.
        return re.sub(r"(?<!\\)([%&#])", r"\\\1", str(text))

    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame(rows, columns=["Strategy","n","SigTests","SigPct","Better","Worse","AvgAbsPct"])
    csv_path = os.path.join(out_dir, f"{filename_base}.csv")
    df.to_csv(csv_path, index=False)

    tex_lines = [
        r"\begin{tabular}{l r r r r r r}",
        r"\toprule",
        r"Strategy & $n$ & Sig.\ Tests & Sig.\ (\%) & Better & Worse & Avg.\ $|\%\Delta|$ \\",
        r"\midrule",
    ]
    for _, r in df.iterrows():
        tex_lines.append(
            f"{_tex(r['Strategy'])} & {r['n']} & {r['SigTests']} & {_tex(r['SigPct'])} & "
            f"{int(r['Better'])} & {int(r['Worse'])} & {float(r['AvgAbsPct']):.2f} \\\\"
        )
    tex_lines += [r"\bottomrule", r"\end{tabular}"]
    tex_str = "\n".join(tex_lines)
    tex_path = os.path.join(out_dir, f"{filename_base}.tex")
    with open(tex_path, "w", encoding="utf-8") as f:
        f.write(tex_str)

    print(f"[CSV] Overview table: {csv_path}")
    print(f"[TEX] LaTeX table:    {tex_path}")

# =========================
# Orchestrator
# =========================
def run_dual_analysis(
    base_dirs: Iterable[str],
    dataset_filter: Optional[str] = DEFAULT_DATASET,
    strategies = None,
    out_dir: str = OUT_DIR,
    show_plots: bool = True
):
    """
    Main entry: discover files, run tests, write CSVs, plot heatmaps, print analysis,
    and produce an overview summary table across windows.

    Parameters
    ----------
    base_dirs : Iterable[str]
        Directories handed to ``list_seeds_dirs``. Any iterable is accepted
        (it is materialised once); a single path given as ``str`` or
        ``os.PathLike`` is treated as a one-element list.
    dataset_filter : Optional[str]
        Passed to ``classify``; when falsy, ``"mixed"`` is used as the heatmap
        dataset tag.
    strategies : list of (str, dict) or None
        ``(strategy, kwargs)`` pairs forwarded to ``select_rows``. ``None``
        selects the four default windows defined below.
    out_dir : str
        Output directory for the discovery ledger, reports, plots and the
        overview table; created when missing.
    show_plots : bool
        Forwarded to ``make_heatmap`` as ``show``.

    Returns
    -------
    None
        Returns early (after printing a message) when no seed folders, no CSV
        files, or no MF/EMF matches are found.
    """
    # Default windows: first, iter5, final, last5>=5
    if strategies is None:
        strategies = [
            ("first_only", {}),
            ("exact_iter", {"at_iter": 5}),
            ("final_only", {}),
            ("last_n",     {"num_iterations": 5}),
        ]

    # Materialise once: base_dirs is used both for discovery and for the
    # diagnostic message, and a generator would be exhausted by the first use.
    # A bare string must not be iterated character by character.
    if isinstance(base_dirs, (str, os.PathLike)):
        search_dirs = [base_dirs]
    else:
        search_dirs = list(base_dirs)

    os.makedirs(out_dir, exist_ok=True)

    roots = list_seeds_dirs(search_dirs)
    if not roots:
        print(f"[AUTO] No 'Results_*/seeds_results' folders under: {search_dirs}")
        return

    ledger = discover_files(roots)
    if ledger.empty:
        print("[AUTO] No CSV files discovered.")
        return
    led_path = os.path.join(out_dir, "discovery_ledger.csv")
    ledger.to_csv(led_path, index=False)
    print(f"[DISCOVERY] Ledger saved: {led_path}")

    mf_files, emf_files = classify(ledger, dataset_filter)
    if not mf_files and not emf_files:
        print("[INFO] No MF/EMF matches for given dataset filter.")
        return

    mf_long  = to_long(load_seed_csvs(mf_files),  "MF")  if mf_files  else pd.DataFrame()
    emf_long = to_long(load_seed_csvs(emf_files), "EMF") if emf_files else pd.DataFrame()

    if mf_long.empty:
        print("[WARN] No usable seeded MF files (missing Iteration or no matches).")
    if emf_long.empty:
        print("[WARN] No usable seeded EMF files (missing Iteration or no matches).")

    # Restrict both conditions to the methods listed in METHODS_UNIVERSE.
    if not mf_long.empty:
        mf_long = mf_long[mf_long["Method"].isin(METHODS_UNIVERSE)]
    if not emf_long.empty:
        emf_long = emf_long[emf_long["Method"].isin(METHODS_UNIVERSE)]

    # accumulate overview rows
    overview_rows = []

    for strategy, kwargs in strategies:
        # Pretty name (special-case exact_iter->iter5 label)
        title_override = None
        strategy_key_for_table = strategy
        if strategy == "exact_iter" and kwargs.get("at_iter", None) == 5:
            strategy_key_for_table = "iter5"
            title_override = STRATEGY_LABELS["iter5"]

        print(f"\n===== {STRATEGY_LABELS.get(strategy_key_for_table, strategy)} =====")
        mf_slice  = select_rows(mf_long,  strategy=strategy, **kwargs) if not mf_long.empty else pd.DataFrame()
        emf_slice = select_rows(emf_long, strategy=strategy, **kwargs) if not emf_long.empty else pd.DataFrame()

        # A) MF/EMF vs Random
        rows = []
        for cond, dfc in (("MF", mf_slice), ("EMF", emf_slice)):
            if dfc.empty:
                continue
            present = dfc["Method"].unique()
            for method in [m for m in METHODS_UNIVERSE if m != REFERENCE_METHOD and m in present]:
                for metric in ACTIVE_METRICS:
                    n, mean_diff, t_stat, p_raw, d, pct, ci_lo, ci_hi = paired_vs_random(dfc, method, metric)
                    # Comparisons with fewer than MIN_SAMPLE_SIZE pairs are not reported.
                    if n < MIN_SAMPLE_SIZE:
                        continue
                    rows.append(dict(
                        Condition=cond, Method=method, Metric=metric,
                        Sample_Size=n, Mean_Diff=mean_diff, T_Stat=t_stat,
                        P_Raw=p_raw, Cohen_d=d, Percent_Change=pct,
                        CI_Lo=ci_lo, CI_Hi=ci_hi, Tail=TAIL_VSR
                    ))
        paired_df = pd.DataFrame(rows)
        if paired_df.empty:
            # Nothing to adjust, write or plot for this window; it also gets no overview row.
            print("[INFO] No aligned pairs under this window (check Random presence & seed overlap).")
            continue
        paired_df = adjust_table(paired_df, alpha=ALPHA, family=FAMILY_SCOPE, method=ADJUST_METHOD)

        # B) DiD: (MF−R) − (EMF−R)
        did_rows = []
        if not mf_slice.empty and not emf_slice.empty:
            # Only methods present in both conditions can be contrasted.
            common_methods = sorted(set(METHODS_UNIVERSE) & set(mf_slice["Method"].unique()) & set(emf_slice["Method"].unique()))
            common_methods = [m for m in common_methods if m != REFERENCE_METHOD]
            for method in common_methods:
                for metric in ACTIVE_METRICS:
                    n, mean_diff, t_stat, p_raw, d, ci_lo, ci_hi = diff_in_diff(mf_slice, emf_slice, method, metric)
                    if n < MIN_SAMPLE_SIZE:
                        continue
                    did_rows.append(dict(
                        Method=method, Metric=metric, Sample_Size=n,
                        Mean_Diff=mean_diff, T_Stat=t_stat, P_Raw=p_raw, Cohen_d=d,
                        CI_Lo=ci_lo, CI_Hi=ci_hi, Tail=TAIL_DID
                    ))
        did_df = pd.DataFrame(did_rows)
        did_df = adjust_did(did_df, alpha=ALPHA, method=ADJUST_METHOD)

        # Write + Plot + Analyze
        write_reports(paired_df, did_df, strategy_key_for_table, out_dir)
        make_heatmap(
            paired_df,
            strategy=strategy_key_for_table,
            out_dir=out_dir,
            dataset_tag=(dataset_filter or "mixed"),
            methods_order=METHODS_UNIVERSE,
            show=show_plots,
            title_override=title_override
        )
        analyze_outputs(out_dir, strategy_key_for_table)

        # Overview row
        n_raw = _count_raw_pairs_for_window(mf_slice, emf_slice)
        overview_rows.append(_overview_row(strategy_key_for_table, paired_df, n_raw))

    # Save a single overview across all processed windows
    _write_overview_table(overview_rows, out_dir, filename_base="overview_dual_windows")

# =========================
# Example call
# =========================
if __name__ == "__main__":
    # Scan the current tree only; append further root directories to the list
    # (e.g. "LINK_TO_ANOTHER_ROOT") to include other result trees.
    run_dual_analysis(["."], dataset_filter="100k", show_plots=True)
