# A New Explainable Active Learning Approach for Recommender Systems (ExAL)

**Authors and Contact Information:**

Anonymous


# Adding Libraries and Configuration

In [None]:
# Imports 
import glob, io, logging, os, random, shutil, zipfile
import numba, numpy as np, pandas as pd, requests, tqdm
from sklearn.metrics import pairwise_distances

# Logger Setup
# Root logging is configured once: every record goes to "run_logs.log"
# and is echoed on the console stream.
logging.basicConfig(
    handlers=[logging.FileHandler("run_logs.log"), logging.StreamHandler()],
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the functions in this notebook.
logger = logging.getLogger(__name__)


# Global Hyperparameters for Explainable Active Learning (ExAL) 

In [None]:

# Experiment-wide constants; several are used as default argument values
# by the functions defined below (e.g. initialize_model).
num_iter      = 10     # Number of active learning iterations (outer loop)
SWITCH        = 5      # Iteration to switch from ExAL-Max to ExAL-Min (for ExAL-Min-Max)

ALPHA_INIT    = 0.01   # Learning rate for initial EMF (matrix factorization) training
ALPHA_RETRAIN = 0.001  # Learning rate for online retraining (per AL iteration)

LAMDA         = 0.1    # Weight of the explainability term lamda * W[u, i] * (p_u - q_i) in the EMF updates
NEIGHBOR      = 20     # Number of neighbors (k) for explainability matrix W
BETA          = 0.15   # L2 (weight-decay) regularization strength applied to P and Q in EMF

INIT_STEPS    = 300    # Number of EMF training epochs before AL begins
ONLINE_STEP   = 10     # Number of SGD steps per user per AL iteration

TopN          = 25     # Top-N cutoff for recommendation evaluation (MAP, MEP, MER, etc)
K             = 10     # Dimension of latent feature vectors in MF/EMF


# Datasets

In [None]:
def _download_and_extract(url, dest='./', timeout=60):
    """Download the zip archive at ``url`` and unpack it into ``dest``."""
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures directly instead of handing an error page to ZipFile.
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(path=dest)


def load_movielens(dataset='100k'):
    """
    Load a MovieLens dataset, downloading it into the working directory
    when the extracted folder is not already present.

    Args:
      dataset: '100k' or '1m'.

    Returns:
      data_M: DataFrame of ratings (users x items, unrated = 0) whose row and
              column labels are renumbered to 0..n-1.
      movies: DataFrame with columns 'movieID', 'movie name', 'genre' for the
              100k dataset (movieID keeps the values read from u.item, it is
              NOT renumbered); None for 1M.

    Raises:
      ValueError: if ``dataset`` is neither '100k' nor '1m'.
      requests.HTTPError: if the download returns an HTTP error status.
    """
    rating_names = ['UserID', 'movieID', 'Rating', 'Timestamp']

    if dataset == '100k':
        if not os.path.exists('ml-100k'):
            _download_and_extract(
                "https://files.grouplens.org/datasets/movielens/ml-100k.zip"
            )

        # Load ratings and build rating matrix
        data = pd.read_table('ml-100k/u.data', names=rating_names)
        data_M = data.pivot(index='UserID', columns='movieID', values='Rating').fillna(0)

        # Load movie genres (one 0/1 indicator column per genre in u.item)
        genre_columns = [
            'unknown', 'Action', 'Adventure', 'Animation', 'Children', 'Comedy', 'Crime',
            'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical',
            'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'
        ]
        movies = pd.read_csv(
            'ml-100k/u.item',
            sep='|',
            header=None,
            encoding='ISO-8859-1',
            usecols=[0, 1, *range(5, 24)],
            names=['movieID', 'movie name', 'release date', 'video release date', 'IMDb URL'] + genre_columns
        )
        # Indicator-times-name dot product concatenates the names of the set
        # genres into one comma-separated string per movie.
        movies['genre'] = movies[genre_columns].dot(pd.Index(genre_columns) + ',').str.rstrip(',')
        movies = movies[['movieID', 'movie name', 'genre']]

    elif dataset == '1m':
        if not os.path.exists('ml-1m'):
            _download_and_extract(
                "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
            )

        # Load ratings and build rating matrix
        data = pd.read_csv('ml-1m/ratings.dat', sep='::', engine='python',
                           names=rating_names)
        data_M = data.pivot(index='UserID', columns='movieID', values='Rating').fillna(0)
        movies = None  # No movie info for 1M

    else:
        raise ValueError("Unknown dataset")

    # Reindex users and items to consecutive integers
    data_M = data_M.reset_index(drop=True)
    data_M.columns = range(data_M.shape[1])

    return data_M, movies


# Data splitting

In [None]:
def split(
    data_M,
    num_test_users=None,
    num_train_ratings=3,
    num_test_ratings=20,
    min_pool_size=10,
    rng=None
):
    """
    Splits user-item rating matrix into train, test, and pool sets.
    Aligns with paper methodology - test set used for AL selection.

    Args:
        data_M: user-item ratings (DataFrame or 2-D array), 0 = unrated.
        num_test_users: number of test users to draw; defaults to 10% of the
            eligible users (at least 1).
        num_train_ratings: ratings per test user kept in the train matrix.
        num_test_ratings: ratings per test user moved to the test matrix.
        min_pool_size: minimum number of ratings that must remain for the pool.
        rng: numpy Generator; a fresh unseeded one is created when None.

    Returns:
        (train, test, pool, test_users): three float matrices with the shape
        of ``data_M`` and an integer array of the selected test-user rows.
        Users that are not test users keep all their ratings in ``train``.
    """
    X = np.asarray(data_M)
    U, I = X.shape
    if rng is None:
        rng = np.random.default_rng()

    logger.info("=== SPLIT FUNCTION DEBUG ===")
    logger.info(f"Dataset shape: {U} users x {I} items")
    logger.info(f"Split sizes: train={num_train_ratings}, test={num_test_ratings}, min_pool={min_pool_size}")

    # --- 1. Identify eligible test users ---
    req = num_train_ratings + num_test_ratings + min_pool_size
    candidates = [u for u in range(U) if np.count_nonzero(X[u]) >= req]
    logger.info(f"Found {len(candidates)} users with >= {req} ratings")

    # --- 2. Select test users for AL scenario ---
    if num_test_users is None:
        num_test_users = max(1, int(0.1 * len(candidates)))
    rng.shuffle(candidates)
    # Explicit integer dtype so an empty selection is still usable as an index array.
    test_users = np.array(candidates[:num_test_users], dtype=int)
    logger.info(f"Selected {len(test_users)} test users")
    test_set = set(test_users.tolist())

    # --- 3. Allocate ratings into train, test, and pool matrices ---
    train = np.zeros((U, I), float)
    test  = np.zeros((U, I), float)
    pool  = np.zeros((U, I), float)

    split_stats = {"train": 0, "test": 0, "pool": 0}
    n_select = num_train_ratings + num_test_ratings

    # Users are visited in ascending order so the sequence of rng draws is
    # fully determined by the seed.
    for u in range(U):
        items = np.flatnonzero(X[u])
        if u not in test_set:
            # For non-test users, keep all available ratings for training
            train[u, items] = X[u, items]
            split_stats["train"] += len(items)
            continue

        # Test users come from `candidates`, so they have at least `req` ratings.
        sel = rng.choice(items, size=n_select, replace=False)
        train_items = sel[:num_train_ratings]
        test_items = sel[num_train_ratings:]

        # Train set: small set for initialization
        train[u, train_items] = X[u, train_items]
        # Test set: used for AL query selection AND final evaluation
        test[u, test_items] = X[u, test_items]
        # Pool: candidate ratings for AL queries (everything not selected above);
        # a set gives O(1) membership tests instead of scanning `sel` per item.
        chosen = set(sel.tolist())
        pool_items = [i for i in items if i not in chosen]
        pool[u, pool_items] = X[u, pool_items]

        split_stats["train"] += num_train_ratings
        split_stats["test"] += num_test_ratings
        split_stats["pool"] += len(pool_items)

    logger.info(f"Split statistics: {split_stats}")
    logger.info("=========================")

    return train, test, pool, test_users

# Initialize the model

In [None]:
def initialize_model(train, test, rng, lamda, steps=INIT_STEPS,
                     alpha=ALPHA_INIT, beta=BETA, K=K,
                     neighbor=NEIGHBOR, theta=0.0):
    """
    Build the starting EMF model for an ExAL run.

    Random factor matrices are drawn from ``rng`` (user factors first, then
    item factors), the explainability matrix is computed from ``train`` via
    ``calc_exp``, and the factors are fitted with ``EMF_with_explainability``.

    Returns:
      (P, Q, train_mae, test_mae, W) where ``test_mae`` is the mean absolute
      error of ``P @ Q.T`` on the non-zero entries of ``test`` (NaN when
      ``test`` has no non-zero entry).
    """
    n_users, n_items = train.shape
    user_factors = rng.random((n_users, K))
    item_factors = rng.random((n_items, K))

    W = calc_exp(train, neighbor=neighbor, theta=theta)
    user_factors, item_factors, train_mae = EMF_with_explainability(
        train, user_factors, item_factors, K,
        W=W, lamda=lamda, steps=steps, alpha=alpha, beta=beta
    )

    held_out = test != 0
    if np.any(held_out):
        predictions = user_factors.dot(item_factors.T)
        test_mae = np.abs(predictions[held_out] - test[held_out]).mean()
    else:
        test_mae = np.nan
    return user_factors, item_factors, train_mae, test_mae, W

def calc_exp(rate, neighbor=20, theta=0.0):
    """
    Compute the explainability matrix for every user.

    For each user the ``k = min(neighbor, U - 1)`` nearest other users under
    cosine distance on the rating rows are found; entry (u, i) is the
    fraction of those neighbours with a positive rating for item i.

    Args:
        rate: 2-D rating array (users x items), 0 = unrated.
        neighbor: requested neighbourhood size.
        theta: when > 0, scores strictly below theta are set to 0.

    Returns:
        Float array with the shape of ``rate``; all zeros when there is no
        other user to compare against.
    """
    U, I = rate.shape
    k = min(neighbor, U - 1)
    if k <= 0:
        return np.zeros((U, I), float)

    dist = pairwise_distances(rate, metric='cosine')
    # Exclude each user from their own neighbourhood explicitly. Relying on
    # "self sorts first" breaks when another row is at distance 0 as well
    # (duplicate or proportional rating rows), because the sort is not stable.
    own = np.arange(U)
    dist[own, own] = np.inf
    nn = np.argsort(dist, axis=1)[:, :k]

    # Index the boolean mask rather than the float matrix so the (U, k, I)
    # intermediate is 1 byte per entry.
    rated = rate > 0
    expl = rated[nn].sum(axis=1) / float(k)
    if theta > 0.0:
        expl[expl < theta] = 0.0
    return expl

def calc_exp_row(rate, u, neighbor=20, theta=0.0):
    """
    Compute the explainability scores of a single user ``u``.

    The ``k = min(neighbor, U - 1)`` nearest other users under cosine
    distance on the rating rows are found; entry i is the fraction of those
    neighbours with a positive rating for item i.

    Args:
        rate: 2-D rating array (users x items), 0 = unrated.
        u: row index of the user.
        neighbor: requested neighbourhood size.
        theta: when > 0, scores strictly below theta are set to 0.

    Returns:
        1-D float array of length ``rate.shape[1]``; all zeros when there is
        no other user to compare against.
    """
    U, I = rate.shape
    k = min(neighbor, U - 1)
    if k <= 0:
        return np.zeros(I, float)

    dist_u = pairwise_distances(rate[u][None, :], rate, metric='cosine')[0]
    # Exclude the user explicitly: position 0 of the sort is not guaranteed
    # to be `u` when another row is equally close (ties at distance 0).
    dist_u[u] = np.inf
    nn_idx = np.argsort(dist_u)[:k]
    expl_u = (rate[nn_idx, :] > 0).sum(axis=0) / float(k)
    if theta > 0.0:
        expl_u[expl_u < theta] = 0.0
    return expl_u




@numba.njit
def EMF_with_explainability(R, P, Q, K, W, lamda, steps, alpha, beta):
    """
    Stochastic-gradient fit of Explainable Matrix Factorization (EMF).

    Every positive entry of R is visited in row-major order for ``steps``
    epochs; P and Q are updated in place, one latent dimension at a time.

    Returns:
      - (P, Q, train_mae): the updated factors and the mean absolute error
        over the positive entries of R (0.0 when there are none).
    """
    Qt = Q.T  # view: writes below land in the caller's Q
    n_users, n_items = R.shape

    for _epoch in range(steps):
        for u in range(n_users):
            for i in range(n_items):
                if not (R[u, i] > 0):
                    continue
                # Error is computed once per rating, before the per-factor updates.
                err = R[u, i] - np.dot(P[u], Qt[:, i])
                for f in range(K):
                    p_uf = P[u, f]
                    q_fi = Qt[f, i]
                    # Explainability pull: draws p_u and q_i together, scaled by W[u, i].
                    pull = lamda * W[u, i] * (p_uf - q_fi)
                    step_p = 2 * err * q_fi - beta * p_uf - pull
                    step_q = 2 * err * p_uf - beta * q_fi + pull
                    P[u, f] = p_uf + alpha * step_p
                    Qt[f, i] = q_fi + alpha * step_q

    # Mean absolute error over the observed (positive) training entries
    abs_err_sum = 0.0
    n_obs = 0
    for u in range(n_users):
        for i in range(n_items):
            if not (R[u, i] > 0):
                continue
            abs_err_sum += abs(R[u, i] - np.dot(P[u], Qt[:, i]))
            n_obs += 1
    train_mae = abs_err_sum / n_obs if n_obs > 0 else 0.0
    return P, Qt.T, train_mae

@numba.njit
def retrain_online_exp(u, train, P_init, Q, W, alpha, beta, K, steps, lamda):
    """
    Online refresh of user u's latent vector.

    Starting from a copy of ``P_init[u]``, runs ``steps`` passes over the
    non-zero entries of ``train[u]`` applying the user-side EMF update;
    ``Q`` and ``P_init`` are left untouched.
    Returns the refreshed latent vector for user u.
    """
    vec = P_init[u].copy()
    item_factors_t = Q.T
    n_items = train.shape[1]

    for _pass in range(steps):
        for i in range(n_items):
            if train[u, i] == 0:
                continue
            err = train[u, i] - np.dot(vec, item_factors_t[:, i])
            for f in range(K):
                q_fi = item_factors_t[f, i]
                gap = vec[f] - q_fi
                update = 2.0 * err * q_fi - beta * vec[f] - lamda * W[u, i] * gap
                vec[f] += alpha * update
    return vec

def calc_avg(train):
    """
    Per-item mean of the non-zero ratings in ``train``.

    Columns without any non-zero entry get an average of 0.
    """
    column_totals = np.sum(train, axis=0)
    rated_counts = np.count_nonzero(train, axis=0)
    means = np.zeros_like(column_totals, dtype=float)
    has_ratings = rated_counts > 0
    # Divide only where at least one rating exists; other slots keep their 0.
    np.divide(column_totals, rated_counts, out=means, where=has_ratings)
    return means


# ExAL selections 

In [None]:
@numba.njit
def active_selection_exal_min(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    expl: np.ndarray,
    alpha: float,
    lamda: float
) -> int:
    """
    ExAL-Min query selection for user u.

    Each item m with a non-zero entry in ``pool[u]`` is scored by summing,
    over ``test_items``, the absolute value of
    ``1 - eR[u, j] + 2*alpha*((eR[u, m] - lR[m]) * Q_dot[m, j]
    + lamda * expl[u, m] * (eR[u, j] - Q_dot[m, j]))``.
    The item with the smallest score is returned (first one on ties), or -1
    when the pool row is empty.
    """
    # NOTE(review): eR is presumably the predicted-rating matrix, lR a per-item
    # reference rating and Q_dot the item-item factor products - confirm with caller.
    step = 2.0 * alpha
    chosen = -1
    lowest = np.inf

    for m in range(pool.shape[1]):
        if pool[u, m] != 0:
            offset = eR[u, m] - lR[m]
            expl_weight = lamda * expl[u, m]
            score = 0.0
            for j in test_items:
                dp = Q_dot[m, j]
                score += abs(
                    1.0 - eR[u, j]
                    + step * (offset * dp + expl_weight * (eR[u, j] - dp))
                )
            if score < lowest:
                lowest = score
                chosen = m
    return chosen

@numba.njit
def active_selection_exal_max(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    expl: np.ndarray,
    alpha: float,
    lamda: float
) -> int:
    """
    ExAL-Max query selection for user u.

    Each item m with a non-zero entry in ``pool[u]`` is scored by summing,
    over ``test_items``, the absolute value of
    ``1 - eR[u, j] + 2*alpha*((eR[u, m] - lR[m]) * Q_dot[m, j]
    + lamda * expl[u, m] * (eR[u, j] - Q_dot[m, j]))``.
    The item with the largest score is returned (first one on ties), or -1
    when the pool row is empty.
    """
    # NOTE(review): eR is presumably the predicted-rating matrix, lR a per-item
    # reference rating and Q_dot the item-item factor products - confirm with caller.
    step = 2.0 * alpha
    chosen = -1
    highest = -np.inf

    for m in range(pool.shape[1]):
        if pool[u, m] != 0:
            offset = eR[u, m] - lR[m]
            expl_weight = lamda * expl[u, m]
            score = 0.0
            for j in test_items:
                dp = Q_dot[m, j]
                score += abs(
                    1.0 - eR[u, j]
                    + step * (offset * dp + expl_weight * (eR[u, j] - dp))
                )
            if score > highest:
                highest = score
                chosen = m
    return chosen

@numba.njit
def active_selection_exal_max_min(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    expl: np.ndarray,
    alpha: float,
    lamda: float,
    iteration: int,
    switch_point: int = 5
) -> int:
    """
    Hybrid ExAL selection: delegates to ``active_selection_exal_max`` while
    ``iteration < switch_point`` and to ``active_selection_exal_min`` from
    ``switch_point`` onwards.
    """
    if iteration >= switch_point:
        return active_selection_exal_min(
            u, test_items, pool, eR, lR, Q_dot, expl, alpha, lamda
        )
    return active_selection_exal_max(
        u, test_items, pool, eR, lR, Q_dot, expl, alpha, lamda
    )

# Active Learning Baselines

In [None]:
@numba.njit
def active_selection_karimi(
    u: int,
    test_items: np.ndarray,
    pool: np.ndarray,
    eR: np.ndarray,
    lR: np.ndarray,
    Q_dot: np.ndarray,
    alpha: float
) -> int:
    """
    Baseline query selection without the explainability term.

    Each item m with a non-zero entry in ``pool[u]`` is scored by summing,
    over ``test_items``, ``|1 - eR[u, j] + 2*alpha*(eR[u, m] - lR[m]) * Q_dot[m, j]|``.
    The item with the smallest score is returned (first one on ties), or -1
    when the pool row is empty.
    """
    best_m = -1
    # Start from +inf (as active_selection_exal_min does) so that any finite
    # score is accepted; a finite sentinel could reject every pool item.
    best_score = np.inf

    for m in range(pool.shape[1]):
        if pool[u, m] == 0:
            continue
        Rum = eR[u, m]
        s = 0.0

        for idx in range(test_items.shape[0]):
            j = test_items[idx]
            inside = (
                1.0
                - eR[u, j]
                + 2.0 * alpha * (Rum - lR[m]) * Q_dot[m, j]
            )
            s += abs(inside)

        if s < best_score:
            best_score = s
            best_m = m
    return best_m

In [None]:
@numba.njit
def select_random(u, pool, rand_val):
    """
    Pick one item from user u's pool using an externally drawn ``rand_val``
    (expected in [0, 1)), so the caller's RNG controls reproducibility.
    Returns -1 if the pool row has no non-zero entry.
    """
    # Item indices still available for user u, in ascending order
    candidates = [i for i in range(pool.shape[1]) if pool[u, i] != 0]
    n_candidates = len(candidates)
    if n_candidates == 0:
        return -1

    # Scale rand_val onto a list position; clamp in case rand_val reaches 1.0
    position = int(rand_val * n_candidates)
    if position >= n_candidates:
        position = n_candidates - 1
    return candidates[position]


@numba.njit
def active_selection_uncertainty(u, pool, eR, midpoint=3.0):
    """
    Uncertainty baseline: among the items with a non-zero entry in
    ``pool[u]``, return the one whose ``eR[u, i]`` lies closest to
    ``midpoint`` (first one on ties); -1 when the pool row is empty.
    """
    closest_item = -1
    smallest_gap = np.inf
    for i in range(pool.shape[1]):
        if pool[u, i] == 0:
            continue
        gap = abs(eR[u, i] - midpoint)
        if gap < smallest_gap:
            smallest_gap = gap
            closest_item = i
    return closest_item



@numba.njit
def active_selection_highest_pred(u, pool, eR):
    """
    Highest-prediction baseline: among the items with a non-zero entry in
    ``pool[u]``, return the one with the largest ``eR[u, i]`` (first one on
    ties); -1 when the pool row is empty.
    """
    top_item = -1
    top_value = -np.inf
    for i in range(pool.shape[1]):
        if pool[u, i] == 0:
            continue
        value = eR[u, i]
        if value > top_value:
            top_value = value
            top_item = i
    return top_item


@numba.njit
def active_selection_highest_variance(u, pool, eR):
    """
    Highest-variance baseline.

    For every item with a non-zero entry in ``pool[u]`` the variance of the
    column ``eR[:, i]`` across all users is computed as E[x^2] - E[x]^2; the
    item with the largest variance is returned (first one on ties), or -1
    when the pool row is empty.
    """
    U, I = eR.shape
    best_idx = -1
    best_var = -1.0

    for i in range(I):
        if pool[u, i] == 0:
            continue
        # Column moments are accumulated only for candidate items; items
        # outside the pool can never be selected, so they are skipped.
        s = 0.0
        ss = 0.0
        for uu in range(U):
            x = eR[uu, i]
            s += x
            ss += x * x
        mean = s / U
        mean_sq = ss / U
        var = mean_sq - mean * mean
        if var > best_var:
            best_var = var
            best_idx = i

    return best_idx

# Evaluation metrics

In [None]:
@numba.njit
def topn(eR, n, u):
    """
    Indices of the n highest entries of ``eR[u]``, best first
    (fewer than n when the row is shorter).
    """
    return np.argsort(-eR[u])[:n]


@numba.njit
def calculate_MER(eR, W, users, n):
    """
    Mean Explainable Recall @n (MER):
    For each user, the number of explainable items (W[u, i] > 0) in the
    top-n list divided by the user's total number of explainable items,
    averaged over the users that have at least one explainable item.
    Returns 0.0 when no user qualifies.
    """
    total = 0.0
    count_u = 0
    I = W.shape[1]
    for ui in range(users.shape[0]):
        u = users[ui]

        # Count explainable items for user u
        expl_total = 0
        for j in range(I):
            if W[u, j] > 0:
                expl_total += 1
        if expl_total == 0:
            continue

        # Count explainable items in top-n. Iterate over the list actually
        # returned: it is shorter than n when n exceeds the number of items,
        # and indexing past its end would be an unchecked read under numba.
        top = topn(eR, n, u)
        cnt = 0
        for idx in range(top.shape[0]):
            if W[u, top[idx]] > 0:
                cnt += 1

        total += cnt / expl_total
        count_u += 1

    return total / count_u if count_u > 0 else 0.0


@numba.njit
def calculate_MEP(eR, W, users, n):
    """
    Mean Explainable Precision @n (MEP):
    Fraction of top-n items that are explainable (W[u, i] > 0), averaged
    over ``users``.

    Returns:
      (mep, total_expl, total_n): the mean precision (0.0 for no users), the
      total number of explainable recommended items, and ``n * len(users)``.
    """
    MEP = 0.0
    total_expl = 0
    total_n = 0
    U = users.shape[0]

    for ui in range(U):
        u = users[ui]
        top = topn(eR, n, u)

        # Count explainable items in top-n. Iterate over the list actually
        # returned: it is shorter than n when n exceeds the number of items,
        # and indexing past its end would be an unchecked read under numba.
        cnt = 0
        for idx in range(top.shape[0]):
            if W[u, top[idx]] > 0:
                cnt += 1

        MEP += cnt / n
        total_expl += cnt
        total_n += n

    return (MEP / U if U > 0 else 0.0), total_expl, total_n


@numba.njit
def calculate_MAP(eR, test, users, n):
    """
    Mean Average Precision @n (MAP):
    Measures ranking quality for known test items (non-zero entries of
    ``test``). Per-user average precision is normalised by
    ``min(n, number of relevant items)``.

    Users without any test item contribute 0 but are still counted in the
    denominator (the sum is divided by ``len(users)``).
    """
    # NOTE(review): calculate_MER / calculate_ndcg average over qualifying
    # users only; confirm that dividing by all users is intended here.
    total_ap = 0.0
    U = len(users)
    I = test.shape[1]

    for ui in range(U):
        u = users[ui]

        # Count relevant items for user u
        rel = 0
        for j in range(I):
            if test[u, j] != 0:
                rel += 1
        if rel == 0:
            continue

        top = topn(eR, n, u)
        hits = 0.0
        sum_prec = 0.0

        # Iterate over the list actually returned: it is shorter than n when
        # n exceeds the number of items, and indexing past its end would be
        # an unchecked read under numba.
        for rank in range(top.shape[0]):
            j = top[rank]
            if test[u, j] != 0:
                hits += 1
                sum_prec += hits / (rank + 1)

        denom = n if rel >= n else rel
        total_ap += sum_prec / denom

    return total_ap / U if U > 0 else 0.0


def calculate_ndcg(eR, test, users, n):
    """
    Normalized Discounted Cumulative Gain @n (NDCG) with binary relevance:
    an item is relevant when its entry in ``test`` is positive.

    The result is averaged over the users whose ideal DCG is positive, i.e.
    users with at least one relevant item; 0.0 when there is none.

    Reference: https://en.wikipedia.org/wiki/Discounted_cumulative_gain
    """
    def _binary_dcg(gains):
        # Each positive entry at 0-based position p contributes 1 / log2(p + 2).
        acc = 0.0
        for pos, gain in enumerate(gains):
            if gain > 0:
                acc += 1.0 / np.log2(pos + 2)
        return acc

    running = 0.0
    scored_users = 0
    for u in users:
        truth = test[u]
        # Best achievable DCG: the user's relevances in descending order.
        ideal = _binary_dcg(np.sort(truth)[::-1][:n])
        if ideal > 0:
            ranking = np.argsort(-eR[u])[:n]
            running += _binary_dcg(truth[ranking]) / ideal
            scored_users += 1

    return running / scored_users if scored_users > 0 else 0.0


def calculate_novelty(topN_items_per_user, item_popularity):
    """
    Mean novelty of the recommended items, where an item's novelty is
    -log2(popularity + 1e-6). Rarely rated items score high, popular ones low.

    Args:
        topN_items_per_user: List of arrays, each containing recommended item indices for one user
        item_popularity: Array where item_popularity[i] = number of users who rated item i

    Returns:
        float: Novelty averaged over every recommended item of every user
        (0.0 when nothing was recommended).

    Reference: https://castells.github.io/papers/recsys2011.pdf
    """
    # The 1e-6 offset keeps log2 finite for items nobody has rated.
    log_pop = np.log2(item_popularity + 1e-6)

    novelty_sum = 0.0
    n_recommended = 0
    for recs in topN_items_per_user:
        n_recommended += len(recs)
        # Negated because a larger log-popularity means a less novel item.
        novelty_sum += -np.sum(log_pop[recs])

    if n_recommended > 0:
        return novelty_sum / n_recommended
    return 0.0

def calculate_novelty_EFD(topN_items_per_user, item_popularity, num_users=None):
    """
    Inverse-frequency novelty score (named after Expected Free Discovery).
    Higher value = recommending more rare/novel items.

    Each item is scored ``normalizer / (item_popularity[i] + 1e-6)`` and the
    scores are averaged over every recommended item of every user.

    Args:
        topN_items_per_user: List of arrays, each containing recommended item indices for one user
        item_popularity: Array where item_popularity[i] = number of users who rated item i
        num_users: Number of users to normalise by, so that the item frequency
            is the fraction of users who rated the item. When None (default,
            the historical behaviour) the normaliser is the number of ITEMS
            with non-zero popularity, since the user count cannot be derived
            from ``item_popularity`` alone.

    Returns:
        float: Average score per recommended item (0.0 when the normaliser is
        zero or nothing was recommended).

    Reference: Vargas & Castells, RecSys 2011
    """
    # NOTE(review): this is a plain inverse frequency, not a log-based
    # measure - confirm against the cited paper's EFD definition.
    if num_users is None:
        normalizer = np.sum(item_popularity > 0).astype(float)
    else:
        normalizer = float(num_users)
    if normalizer == 0:
        return 0.0

    # Small epsilon avoids division by zero for items nobody has rated.
    item_freq = (item_popularity + 1e-6) / normalizer
    inverse_freq = 1.0 / item_freq

    total_efd = 0.0
    count = 0
    for top_items in topN_items_per_user:
        total_efd += np.sum(inverse_freq[top_items])
        count += len(top_items)

    return total_efd / count if count > 0 else 0.0

def calculate_item_coverage(topN_items_all_users, num_items):
    """
    Item Coverage: Fraction of catalog items that appear in at least one
    user's recommendation list.

    Args:
        topN_items_all_users: List of arrays, each with top-N items for a user
        num_items: Total number of items in catalog

    Returns:
        float: Coverage ratio (0-1, higher is better); 0.0 for an empty catalog.
    """
    # An empty catalog has nothing to cover; avoid dividing by zero.
    if num_items <= 0:
        return 0.0
    recommended = {item for user_items in topN_items_all_users for item in user_items}
    return len(recommended) / num_items



def calculate_gini_index(topN_items_all_users, num_items):
    """
    Gini Index: Measures recommendation concentration over the catalog
    (0 = every item recommended equally often; the maximum, reached when a
    single item receives all recommendations, is (n - 1) / n).

    Args:
        topN_items_all_users: List of recommendation arrays
        num_items: Total catalog size

    Returns:
        float: Gini coefficient (lower is better for diversity); 0.0 when the
        catalog is empty or nothing was recommended.
    """
    # Count how many times each item is recommended
    item_counts = np.zeros(num_items)
    for user_items in topN_items_all_users:
        for item in user_items:
            item_counts[item] += 1

    n = len(item_counts)
    total = np.sum(item_counts)
    # Without any recommendation the formula below is 0/0; report "no
    # concentration" instead of NaN.
    if n == 0 or total == 0:
        return 0.0

    # Gini coefficient formula on the ascending-sorted counts
    item_counts = np.sort(item_counts)
    index = np.arange(1, n + 1)
    gini = (2 * np.sum(index * item_counts)) / (n * total) - (n + 1) / n
    return gini



def calculate_ARP(topN_items_all_users, item_popularity):
    """
    Average Recommendation Popularity: mean of ``item_popularity`` over every
    recommended item of every user.

    Args:
        topN_items_all_users: List of recommendation arrays
        item_popularity: Array where item_popularity[i] = number of users who rated item i

    Returns:
        float: Average popularity (lower indicates more diverse/novel
        recommendations); 0.0 when nothing was recommended.
    """
    popularity_sum = 0
    n_recommended = 0
    for recs in topN_items_all_users:
        n_recommended += len(recs)
        for item in recs:
            popularity_sum += item_popularity[item]

    if n_recommended > 0:
        return popularity_sum / n_recommended
    return 0.0

In [None]:
def compute_item_popularity(rating_matrix):
    """
    Reference: https://dl.acm.org/doi/abs/10.1145/3109859.3109912

    Item popularity = number of non-zero entries in each column, i.e. how
    many users have a rating for the item.

    Args:
        rating_matrix (np.ndarray): User-item rating matrix [num_users x num_items]

    Returns:
        np.ndarray: 1D array of item popularity counts [num_items]
    """
    # Turn every entry into "is it non-zero?" and add the flags up per column.
    is_rated = np.asanyarray(rating_matrix).astype(np.bool_, copy=False)
    return is_rated.sum(axis=0, dtype=np.intp)


def assign_popularity_buckets(item_popularity):
    """
    Label every item as low / medium / high popularity using the 33.33rd and
    66.66th percentiles of ``item_popularity`` as cut points.
    This is used for popularity bias analysis.

    Args:
        item_popularity (np.ndarray): Array of item popularity counts [num_items]

    Returns:
        np.ndarray: 1D array with bucket assignments per item: 0=low, 1=medium, 2=high
    """
    low_cut, high_cut = np.percentile(item_popularity, [33.33, 66.66])
    labels = np.zeros_like(item_popularity, dtype=int)
    # Strictly above the lower cut -> at least medium ...
    labels[item_popularity > low_cut] = 1
    # ... and strictly above the upper cut -> high. Everything else stays 0.
    labels[item_popularity > high_cut] = 2
    return labels


def fraction_by_popularity_bucket(topN_items, popularity_buckets):
    """
    Given a set of recommended item indices (topN_items) and each item's popularity bucket,
    compute the fraction of recommendations in each bucket (low/medium/high).

    Args:
        topN_items (np.ndarray or list): Indices of recommended items
        popularity_buckets (np.ndarray): Array of item bucket assignments (output of assign_popularity_buckets)

    Returns:
        np.ndarray: 1D array with fraction of items in each bucket [frac_low, frac_medium, frac_high];
        all zeros when ``topN_items`` is empty.
    """
    # Counters for each bucket (0: low, 1: medium, 2: high).
    bucket_counts = np.zeros(3)
    n_items = len(topN_items)
    # No recommendations: return zero fractions rather than 0/0 = NaN.
    if n_items == 0:
        return bucket_counts

    for idx in topN_items:
        bucket_counts[popularity_buckets[idx]] += 1
    # Normalize to obtain the fraction for each bucket.
    return bucket_counts / n_items


# Main experiment loop

In [None]:
# DataFrame utility: stack frames after removing their entirely-NaN columns
def safe_concat(df_list, ignore_index=True):
    """
    Drop, from each DataFrame separately, the columns that contain only NaN,
    then concatenate the results row-wise.
    """
    trimmed = []
    for frame in df_list:
        trimmed.append(frame.dropna(axis=1, how='all'))
    return pd.concat(trimmed, ignore_index=ignore_index)


'''def get_lambda_for_strategy(strategy, lambda_value):
    """
    Only return the real lambda_value for the three ExAL methods.
    Otherwise (all other strategies) return 0 to disable the explainability term.
    """
    return lambda_value if strategy in ('EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max') else 0.0'''


In [None]:

def main(lambda_value, strategy_input=None, return_results=False, seed=None, n_iter=num_iter, results_folder=None, dataset='100k', freeze_Q=True, theta=0.0, neighbor=20,RECOMPUTE_W_EACH_ITER=True):
    """
    Run the active-learning (AL) experiment for a single seed.

    For every requested strategy the same initial train/test/pool split is restored,
    an initial model is trained, and `n_iter` AL iterations are executed. In each
    iteration every test user selects one pool item through the strategy, the rating
    is moved from the pool to the training matrix, and the user's latent factors are
    retrained online. Metrics are computed after every iteration.

    Side effects: a popularity-bucket CSV is written per strategy under
    "Results_{neighbor}/Popularity_Buckets" (independent of `results_folder` and of
    `return_results`).

    Args:
        lambda_value: Explainability weight; converted to float and passed as `lamda`
            to the model and selection routines.
        strategy_input (str, optional): Single strategy to run. When falsy, all eight
            built-in strategies are run.
        return_results (bool): If True, return the per-strategy metric frames and do
            not write the per-strategy results CSV.
        seed (int, optional): Seed for the NumPy Generator. When not None it is also
            appended to output file names.
        n_iter (int): Number of AL iterations.
        results_folder (str, optional): Directory for per-strategy result CSVs.
            Defaults to "Results_{neighbor}/seeds_results".
        dataset (str): Dataset identifier forwarded to load_movielens.
        freeze_Q (bool): If False, EMF_with_explainability is run at the end of every
            iteration and replaces both P and Q.
        theta (float): Threshold forwarded to the explainability computation.
        neighbor (int): Neighbor count forwarded to the explainability computation;
            also part of the output folder names.
        RECOMPUTE_W_EACH_ITER (bool): If True (and lambda > 0) the whole explainability
            matrix is recomputed at the start of each iteration; otherwise only the
            row of the user who just rated an item is refreshed.

    Returns:
        dict | None: Mapping "{strategy}_lambda_{lambda_value}_{dataset}" -> DataFrame
        with one row per iteration when `return_results` is True, otherwise None.

    Raises:
        ValueError: If a strategy name is not recognised.
    """
    # 1. Set output folders for results and popularity buckets
    if results_folder is None:
        results_folder = f"Results_{neighbor}/seeds_results"
    pop_folder = f"Results_{neighbor}/Popularity_Buckets"

    # Suffix that tags output files with the seed. The check is against None (not
    # truthiness) so that seed=0 is tagged like any other seed.
    seed_suffix = f"_seed_{seed}" if seed is not None else ""

    # 2. Initialize RNG for reproducibility (seed=None yields an unseeded generator)
    rng = np.random.default_rng(seed)

    # 3. Ensure output directories exist
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(pop_folder, exist_ok=True)

    # 4. Load dataset and prepare rating matrix
    data_M, _ = load_movielens(dataset)
    rate = pd.DataFrame(data_M)

    logger.info("=" * 60)
    logger.info("STARTING EXAL EXPERIMENT with freeze_Q="
                f"{freeze_Q}, lambda={lambda_value}, strategy={strategy_input}, "
                f"dataset={dataset}, seed={seed}, n_iter={n_iter}, "
                f"neighbor={neighbor}, theta={theta}")
    logger.info("=" * 60)
    logger.info(
        f"freeze_Q={freeze_Q} → "
        f"{'fixed Q during AL' if freeze_Q else 'EMF updates per iteration'}"
    )

    # 5. Split ratings into initial train / test / pool matrices
    fixed_test = None  # no explicit test-user count is passed to split()

    train_init, test_init, pool_init, test_user = split(
        rate,
        num_test_users=fixed_test,
        num_train_ratings=3,
        num_test_ratings=20,
        min_pool_size=10,
        rng=rng
    )
    # Enhanced logging for first 5 test users
    logger.info("==== Data Split Verification (First 5 Test Users) ====")

    for u in test_user[:5]:
        train_idx = np.where(train_init[u] != 0)[0]
        test_idx = np.where(test_init[u] != 0)[0]
        pool_idx = np.where(pool_init[u] != 0)[0]

        logger.info(f"User {u:3d}:")
        logger.info(f"  - Train: {len(train_idx):2d} items - {train_idx.tolist()[:5]}{'...' if len(train_idx) > 5 else ''}")
        logger.info(f"  - Test:  {len(test_idx):2d} items - {test_idx.tolist()[:5]}{'...' if len(test_idx) > 5 else ''} ")
        logger.info(f"  - Pool:  {len(pool_idx):2d} items - {pool_idx.tolist()[:5]}{'...' if len(pool_idx) > 5 else ''}")
    logger.info("="*55)

    # 6. Log split statistics for the first 5 test users
    logger.info("==== Data Split Check for First 5 Test Users ====")
    for u in test_user[:5]:
        train_idx = np.where(train_init[u] != 0)[0]
        test_idx = np.where(test_init[u] != 0)[0]
        pool_idx = np.where(pool_init[u] != 0)[0]
        logger.info(f"User {u:3d}: train={len(train_idx)}, test={len(test_idx)}, pool={len(pool_idx)}")
    logger.info("==============================================")

    # 7. Compute per-item average rating, item popularity, and assign popularity buckets
    lR = calc_avg(train_init)

    # Popularity is measured once on the initial training data and reused for all strategies
    item_popularity = compute_item_popularity(train_init)
    popularity_buckets = assign_popularity_buckets(item_popularity)

    # 8. Select strategies to run (user-defined or default list)
    if strategy_input:
        strategies = [strategy_input]
    else:
        strategies = [
            'EXAL-Min',
            'EXAL-Max',
            'EXAL-Min-Max',
            'KARIMI',
            'Uncertainty',
            'Random',
            'HighestPred',
            'HighestVar'
        ]

    results = {}

    # 9. Loop through each active learning strategy
    for strategy in strategies:
        logger.info(f"\n{'='*60}")
        logger.info(f"STRATEGY: {strategy}")
        logger.info(f"{'='*60}")

        # Reset splits and RNG for fair comparison
        rng = np.random.default_rng(seed)
        train = np.copy(train_init)
        test = np.copy(test_init)
        pool = np.copy(pool_init)

        # Initialize per-iteration metrics storage
        evolution = pd.DataFrame(columns=["Iteration", "MAE", "MEP", "MER", "F-Score", "MAP"])

        # Output naming for this strategy (shared by the results and popularity files)
        key = f"{strategy}_lambda_{lambda_value}_{dataset}"
        pop_path = os.path.join(pop_folder, f"{key}{seed_suffix}.csv")
        # Popularity rows are collected in memory and the file is rewritten from them,
        # so rows left on disk by an earlier run are never mixed into this run.
        pop_rows = []

        lamda_used = float(lambda_value)

        logger.info(f"Running {strategy} with lamda_used={lamda_used}")

        # 10. Initialize latent factors; compute initial explainability matrix
        P_init, Q_init, _, _, expl = initialize_model(
            train_init, test_init, rng,
            steps=INIT_STEPS, alpha=ALPHA_INIT, beta=BETA,
            K=K, neighbor=neighbor, theta=theta, lamda=lamda_used
        )

        # Log initial explainability matrix sparsity
        logger.info("\n[Explainability Matrix Sparsity Check]")
        logger.info("Initial explainable items per user (first 10 users):")
        for u in range(min(10, expl.shape[0])):
            num_expl = int(np.sum(expl[u] > 0))
            total_items = expl.shape[1]
            perc = 100.0 * num_expl / total_items
            logger.info(f"  User {u:2d}: {num_expl:4d} / {total_items} items explainable ({perc:.2f}%)")
        logger.info("-" * 55)

        # Work on copies so the initial factors stay untouched during the AL loop
        P, Q = np.copy(P_init), np.copy(Q_init)

        # Initialize MAE tracking
        train_mae_list = []
        test_mae_list = []

        # 11. Active Learning loop  (single loop)
        for iteration in tqdm.tqdm(range(n_iter), desc=f"[{strategy}] AL Iter"):
            logger.info(f"\n--- Iteration {iteration} ---")

            # Update per-item averages
            lR = calc_avg(train)
            if lamda_used > 0 and RECOMPUTE_W_EACH_ITER:
                expl = calc_exp(train, neighbor=neighbor, theta=theta)

            # Predictions for selection + item–item dots
            eR = P.dot(Q.T)
            Q_dot = Q.dot(Q.T)

            # Per-user selection + online update
            for u in tqdm.tqdm(test_user, desc=" Users", leave=False):
                test_items = np.where(test[u, :] != 0)[0]

                if strategy == 'EXAL-Min':
                    j = active_selection_exal_min(u, test_items, pool, eR, lR, Q_dot, expl, ALPHA_RETRAIN, lamda_used)
                elif strategy == 'EXAL-Max':
                    j = active_selection_exal_max(u, test_items, pool, eR, lR, Q_dot, expl, ALPHA_RETRAIN, lamda_used)
                elif strategy == 'EXAL-Min-Max':
                    j = active_selection_exal_max_min(u, test_items, pool, eR, lR, Q_dot, expl,
                                                    ALPHA_RETRAIN, lamda_used, iteration, switch_point=SWITCH)
                elif strategy == 'KARIMI':
                    j = active_selection_karimi(u, test_items, pool, eR, lR, Q_dot, ALPHA_RETRAIN)
                elif strategy == 'Random':
                    j = select_random(u, pool, rng.random())
                elif strategy == 'Uncertainty':
                    j = active_selection_uncertainty(u, pool, eR, midpoint=3.0)
                elif strategy == 'HighestPred':
                    j = active_selection_highest_pred(u, pool, eR)
                elif strategy == 'HighestVar':
                    j = active_selection_highest_variance(u, pool, eR)
                else:
                    raise ValueError(f"Unknown strategy: {strategy}")

                # A negative index is treated as "nothing selected" for this user
                if j >= 0:
                    # move picked rating to train
                    train[u, j] = pool[u, j]
                    pool[u, j]  = 0
                    lR = calc_avg(train)

                    # refresh explainability row for this user (fast mode only)
                    if lamda_used > 0 and not RECOMPUTE_W_EACH_ITER:
                        expl[u, :] = calc_exp_row(train, u, neighbor=neighbor, theta=theta)

                    # online update of P_u
                    P[u] = retrain_online_exp(u, train, P, Q, expl, ALPHA_RETRAIN, BETA, K, ONLINE_STEP, lamda_used)

            # optional EMF pass to update Q (and P) set False to activate online updates for both P and Q
            if not freeze_Q:
                P, Q, _ = EMF_with_explainability(train, P, Q, K,
                                                expl, lamda_used,
                                                steps=ONLINE_STEP,
                                                alpha=ALPHA_RETRAIN,
                                                beta=BETA)
                logger.info(f"[Iter {iteration}] Q updated (mean={Q.mean():.4f}, std={Q.std():.4f})")

            # === recompute predictions AFTER updates (for metrics) ===
            eR = P.dot(Q.T)

            # --- MAE ---
            mae_train = np.nanmean(np.abs(eR[train != 0] - train[train != 0]))
            mae_test  = np.nanmean(np.abs(eR[test  != 0] - test [test  != 0]))
            train_mae_list.append(mae_train)
            test_mae_list.append(mae_test)
            logger.info(f"[{strategy}] Iter {iteration}: TRAIN MAE={mae_train:.4f}, TEST MAE={mae_test:.4f}")

            # --- W stats ---
            num_expl_nonzero = np.sum(expl > 0)
            total_entries    = expl.shape[0] * expl.shape[1]
            percent_nonzero  = 100 * num_expl_nonzero / total_entries
            logger.info(f"[W Sparsity] Non-zero W entries: {num_expl_nonzero}/{total_entries} ({percent_nonzero:.4f}%)")

            percent_exp_user = np.sum(expl > 0, axis=1) / expl.shape[1]
            mean_cov = np.mean(percent_exp_user) * 100
            top5     = np.sort(percent_exp_user)[-5:] * 100
            bot5     = np.sort(percent_exp_user)[:5] * 100
            logger.info(f"[W Coverage] Mean explainable items/user: {mean_cov:.2f}%")
            logger.info(f"Top 5 users w/ most explainable items: {top5}")
            logger.info(f"Bottom 5 users w/ least explainable items: {bot5}")

            # --- Metrics ---
            # Items already in train or still in the pool are excluded from ranking
            mask = (train != 0) | (pool != 0)
            eR_masked = eR.copy()
            eR_masked[mask] = -np.inf

            MEP, total_expl, total_n = calculate_MEP(eR_masked, expl, test_user, TopN)
            MER = calculate_MER(eR_masked, expl, test_user, TopN)
            F   = 2*(MEP*MER)/(MEP+MER) if (MEP+MER)>0 else 0.0
            MAPv= calculate_MAP(eR_masked, test, test_user, TopN)

            # Top-N list per test user (same order as test_user), best score first
            topN_items, all_top = [], []
            for u in test_user:
                s = eR[u].copy()
                s[mask[u]] = -np.inf
                top = np.argsort(s)[-TopN:][::-1]
                topN_items.append(top)
                all_top.extend(top)

            coverage = calculate_item_coverage(topN_items, train.shape[1])
            gini = calculate_gini_index(topN_items, train.shape[1])
            arp  = calculate_ARP(topN_items, item_popularity)
            gini_diversity = 1 - gini
            novelty     = calculate_novelty(topN_items, item_popularity)
            novelty_efd = calculate_novelty_EFD(topN_items, item_popularity)
            ndcg        = calculate_ndcg(eR_masked, test, test_user, TopN)
            frac_pop    = fraction_by_popularity_bucket(np.array(all_top), popularity_buckets)
            frac_bias   = frac_pop[0] - frac_pop[2]

            # MAE restricted to recommended test items of the high / low popularity bucket.
            # topN_items is aligned with test_user, so the two are walked in lockstep.
            mae_high, mae_low = [], []
            for u, recs in zip(test_user, topN_items):
                high = [i for i in recs if popularity_buckets[i]==2 and test[u,i]!=0]
                low  = [i for i in recs if popularity_buckets[i]==0 and test[u,i]!=0]
                if high: mae_high.append(np.mean(np.abs(eR[u,high] - test[u,high])))
                if low:  mae_low .append(np.mean(np.abs(eR[u,low]  - test[u,low])))
            mean_high = np.nan if not mae_high else np.mean(mae_high)
            mean_low  = np.nan if not mae_low  else np.mean(mae_low)
            mae_bias  = mean_low - mean_high

            # Persist the popularity shares after every iteration so that partial
            # results survive an interrupted run.
            pop_rows.append({
                "Iteration": iteration,
                "Frac_Low":  frac_pop[0],
                "Frac_Med":  frac_pop[1],
                "Frac_High": frac_pop[2],
            })
            pd.DataFrame(pop_rows).to_csv(pop_path, index=False)

            iteration_df = pd.DataFrame({
                "Iteration":       [iteration],
                "MAE":             [mae_test],
                "Train_MAE":       [mae_train],
                "Overfit_Gap":     [mae_test-mae_train],
                "MEP":             [MEP],
                "MER":             [MER],
                "F-Score":         [F],
                "MAP":             [MAPv],
                "NDCG":            [ndcg],
                "Novelty":         [novelty],
                "Novelty_EFD":     [novelty_efd],
                "Gini":            [gini_diversity],
                "Coverage":        [coverage],
                "ARP":             [arp],
                "MAE_HighPop":     [mean_high],
                "MAE_LowPop":      [mean_low],
                "MAE_Pop_Bias":    [mae_bias],
                "Frac_Bias":       [frac_bias],
                "Total_Explained": [total_expl],
                "Total_Candidates":[total_n]
            })
            evolution = safe_concat([evolution, iteration_df], ignore_index=True)
            logger.info("Iteration stats:\n" + str(iteration_df))

        # === Save strategy results to disk ===
        results[key] = evolution
        if not return_results:
            path = os.path.join(results_folder, f"{key}{seed_suffix}.csv")
            evolution.to_csv(path, index=False)
            logger.info(f"Saved results to {path}")

            # Log the metrics of the last iteration. The loop variables only exist
            # when at least one iteration was executed.
            if n_iter > 0:
                logger.info(f"Iteration {iteration} complete:")
                logger.info(f"  - Train MAE: {mae_train:.4f}")
                logger.info(f"  - Test MAE: {mae_test:.4f}")
                logger.info(f"  - MEP: {MEP:.4f}, MER: {MER:.4f}, F-Score: {F:.4f}")

    logger.info("\n" + "="*60)
    logger.info("EXPERIMENT COMPLETE")
    logger.info("="*60)

    return results if return_results else None


In [None]:
def multi_seed_experiment(lambda_value, strategy_input=None, seeds=None, n_iter=num_iter,
                          results_folder=None, dataset='100k', freeze_Q=True,
                          theta=0.0, neighbor=20, recompute_w_each_iter=True):
    """
    Run the active learning experiment once per seed and average the results.

    For every seed, main() is called with return_results=True. Each returned
    per-strategy frame is saved as "<key>_seed_<seed>.csv" and the per-iteration mean
    over all seeds is saved as "AVG_<key>.csv" in "<results_folder>/seeds_results".
    Afterwards the per-seed popularity-bucket CSVs of the strategies that were run
    are averaged into "AVG_<key>.csv" inside "<results_folder>/Popularity_Buckets".

    Args:
        lambda_value: Explainability weight forwarded to main(); also part of the
            output file names.
        strategy_input (str, optional): Single strategy to run; all built-in
            strategies when falsy.
        seeds (iterable of int, optional): Seeds to run. Defaults to ten fixed seeds.
        n_iter (int): Number of AL iterations per run.
        results_folder (str, optional): Root output folder. Defaults to
            "Results_{neighbor}".
        dataset (str): Dataset identifier forwarded to main().
        freeze_Q (bool): Forwarded to main().
        theta (float): Forwarded to main().
        neighbor (int): Forwarded to main().
        recompute_w_each_iter (bool): Forwarded to main() as RECOMPUTE_W_EACH_ITER.

    Returns:
        None. All outputs are written to disk.
    """
    # 1. Set default folders for results
    if results_folder is None:
        results_folder = f"Results_{neighbor}"
    seeds_folder = os.path.join(results_folder, "seeds_results")
    # NOTE(review): main() writes its popularity CSVs to
    # "Results_{neighbor}/Popularity_Buckets" regardless of the folder passed to it,
    # so with a custom results_folder the files are not found here — confirm intent.
    pop_folder = os.path.join(results_folder, "Popularity_Buckets")

    # 2. Default seeds if not provided
    if seeds is None:
        seeds = [42, 101, 202, 303, 404, 505, 606, 707, 808, 909]
    else:
        # Materialise once: the seeds are walked again during post-processing
        seeds = list(seeds)

    # 3. Ensure output directories exist
    os.makedirs(seeds_folder, exist_ok=True)
    os.makedirs(pop_folder, exist_ok=True)
    logger.info(f"Running multi-seed experiment with seeds: {seeds}")

    # 4. Prepare result storage: strategy key -> list of per-seed DataFrames
    results_all = {}

    # 5. Run experiment for each seed
    for seed in seeds:
        # Also seed the legacy global RNGs in case any helper relies on them
        np.random.seed(seed)
        random.seed(seed)

        logger.info(f"==== Seed {seed} ====")

        # Run single-seed experiment
        results = main(
            lambda_value,
            strategy_input=strategy_input,
            return_results=True,
            seed=seed,
            n_iter=n_iter,
            results_folder=seeds_folder,
            dataset=dataset,
            theta=theta,
            neighbor=neighbor,
            freeze_Q=freeze_Q,
            RECOMPUTE_W_EACH_ITER=recompute_w_each_iter
        )

        # Save and collect results
        for strategy_key, df in results.items():
            out_path = os.path.join(seeds_folder, f"{strategy_key}_seed_{seed}.csv")
            df.to_csv(out_path, index=False)
            logger.info(f"Saved: {out_path}")

            results_all.setdefault(strategy_key, []).append(df.copy())

    # 6. Average results across seeds for each strategy
    for strategy_key, dfs in results_all.items():
        # Create MultiIndex DataFrame from list of runs
        concat_df = pd.concat(dfs, keys=range(len(dfs)), names=['Seed', 'Row'])

        # Group by iteration index and average across seeds
        avg_df = concat_df.groupby('Iteration').mean(numeric_only=True).reset_index()

        avg_csv = os.path.join(seeds_folder, f"AVG_{strategy_key}.csv")
        avg_df.to_csv(avg_csv, index=False)
        logger.info(f"Averaged results for {strategy_key} saved to {avg_csv}")

    logger.info(f"Multi-seed experiment completed. Averaged results in {seeds_folder}/.")

    # 7. Post-process and average popularity results, restricted to the strategies
    #    that were actually run (same selection rule as main()).
    if strategy_input:
        strategies = [strategy_input]
    else:
        strategies = [
            'EXAL-Min',
            'EXAL-Max',
            'EXAL-Min-Max',
            'KARIMI',
            'Uncertainty',
            'Random',
            'HighestPred',
            'HighestVar'
        ]
    for strat in strategies:
        pop_key = f"{strat}_lambda_{lambda_value}_{dataset}"

        # Only consider the seeds of this run, so leftover files from other seeds
        # in the folder are not averaged in. Duplicate seeds are used once.
        files = []
        for seed in dict.fromkeys(seeds):
            candidate = os.path.join(pop_folder, f"{pop_key}_seed_{seed}.csv")
            if os.path.isfile(candidate):
                files.append(candidate)

        avg_file = os.path.join(pop_folder, f"AVG_{pop_key}.csv")

        if len(files) == 1:
            # Only one file → copy it
            shutil.copyfile(files[0], avg_file)
            logger.info(f"[INFO] Only one Popularity file for {strat}: copied {files[0]} → {avg_file}")

        elif len(files) > 1:
            # Multiple seeds → average them
            dfs = [pd.read_csv(f) for f in files]
            concat = pd.concat(dfs, keys=range(len(dfs)), names=['Seed', 'Row'])
            avg_df = concat.groupby('Iteration').mean(numeric_only=True).reset_index()
            avg_df.to_csv(avg_file, index=False)
            logger.info(f"[INFO] Averaged Popularity results for {strat} saved to {avg_file}")

        else:
            # No files found
            logger.warning(f"[WARN] No Popularity files found for {strat}.")


In [None]:
if __name__ == '__main__':
    # Command-line entry point: parse the options and launch the multi-seed experiment.
    import sys
    import argparse
    # force=True removes the handlers installed by any earlier basicConfig call
    # before applying this configuration.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        force=True
    )
    logger = logging.getLogger(__name__)

    # When launched from a Jupyter kernel, discard the kernel's own command-line
    # arguments so that argparse falls back to the defaults below.
    if 'ipykernel' in sys.argv[0]:
        sys.argv = [sys.argv[0]]

    parser = argparse.ArgumentParser(
        description='Active Learning Experiment for MovieLens'
    )
    parser.add_argument('--lambda_value', type=float, default=LAMDA,
                        help='Lambda regularization parameter')
    parser.add_argument('--strategy', type=str, default=None,
                        help='AL strategy (e.g., "EXAL-Min")')
    parser.add_argument('--dataset', type=str, default='100k',
                        choices=['100k', '1m'],
                        help='MovieLens dataset: 100k or 1m')
    parser.add_argument('--theta', type=float, default=0.0,
                        help='Explainability threshold (theta) for W_{ui}')
    parser.add_argument('--neighbor', type=int, default=NEIGHBOR,
                        help='Number of neighbors for explainability matrix W')
    parser.add_argument('--freeze_Q', dest='freeze_Q', action='store_true',
                        help='Freeze item factors Q during AL iterations (paper-faithful).')
    parser.add_argument('--no-freeze_Q', dest='freeze_Q', action='store_false',
                        help='Allow iteration-end EMF updates to Q (deviation; experimental).')
    # Paired on/off flags: a lone store_true flag that defaults to True could never
    # be switched off, which made fast mode unreachable from the command line.
    parser.add_argument('--recompute_w_each_iter', dest='recompute_w_each_iter', action='store_true',
                        help='Recompute the full explainability matrix W every AL iteration (default).')
    parser.add_argument('--no-recompute_w_each_iter', dest='recompute_w_each_iter', action='store_false',
                        help='Fast mode: only refresh the W row of the user who just rated an item.')
    # NOTE(review): the CLI default for freeze_Q is False, whereas main() and
    # multi_seed_experiment() default to True — confirm which one is intended.
    parser.set_defaults(freeze_Q=False, recompute_w_each_iter=True)

    args = parser.parse_args()
    lambda_v = args.lambda_value
    strategy = args.strategy
    dataset = args.dataset
    theta = args.theta
    freeze_Q = args.freeze_Q
    neighbor = args.neighbor

    multi_seed_experiment(
        lambda_v,
        strategy_input=strategy,
        seeds=None,
        n_iter=num_iter,
        results_folder=None,
        dataset=dataset,
        freeze_Q=freeze_Q,
        theta=theta,
        neighbor=neighbor,
        recompute_w_each_iter=args.recompute_w_each_iter)
'''

    # Loop over each lambda value
    for lambda_v in LAMDAs:
        logger.info(f"\n=== Running experiments for LAMDA = {lambda_v} ===")
        multi_seed_experiment(
            lambda_v,
            strategy_input=strategy,
            seeds=None,
            n_iter=num_iter,
            results_folder=None,  
            dataset=dataset,
            freeze_Q=freeze_Q,
            theta=theta,
            neighbor=neighbor,
            recompute_w_each_iter=args.recompute_w_each_iter
        )'''


# PLOT the Baselines

### Metrics Over Iterations

In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import logging

# Logging configuration for the plotting cell
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ======= Experiment constants =======
# Each name is re-bound to the value it already holds from the earlier cells, so
# this section fails fast with a NameError when those cells have not been run.
LAMDA      = LAMDA        # Explainability regularization strength
TopN       = TopN         # Number of recommendations per user
dataset    = dataset      # Dataset name: '100k' or '1m'
STEPint    = INIT_STEPS   # Steps for MF model initialization
STEPonl    = ONLINE_STEP  # Steps for online MF updates
num_iter   = num_iter     # Active learning iterations
SWITCH     = SWITCH       # Switch point for EXAL-Min-Max hybrid
NEIGHBOR   = NEIGHBOR     # k for k-nearest neighbor explainability

# ======= Line style per strategy: (name, color, marker, linestyle) =======
# The order of this table also fixes the order in which result files are loaded.
styles = {
    name: {'color': color, 'marker': marker, 'linestyle': linestyle}
    for name, color, marker, linestyle in (
        ('KARIMI',       'blue',    'o', '-'),
        ('Random',       'green',   's', '-'),
        ('HighestVar',   'cyan',    '^', '-'),
        ('HighestPred',  'magenta', 'v', '-'),
        ('Uncertainty',  'orange',  'x', '-'),
        ('EXAL-Min',     'gold',    'D', '-'),
        ('EXAL-Max',     'black',   '*', '--'),
        ('EXAL-Min-Max', 'red',     'X', '--'),
    )
}

# ======= Axis label per metric =======
labels = dict([
    ('MAE',     'Mean Absolute Error (MAE) ↓ better'),
    ('MEP',     'Mean Explainable Precision (MEP) ↑ better'),
    ('MER',     'Mean Explainable Recall (MER) ↑ better'),
    ('F-Score', 'Explainable F-Score ↑ better'),
    ('MAP',     'Mean Average Precision (MAP) ↑ better'),
])

# ======= Preferred direction per metric: (name, better, arrow) =======
direction_info = {
    name: {'better': better, 'arrow': arrow, 'color': 'green'}
    for name, better, arrow in (
        ('MAE',     'lower',  '↓'),
        ('MEP',     'higher', '↑'),
        ('MER',     'higher', '↑'),
        ('F-Score', 'higher', '↑'),
        ('MAP',     'higher', '↑'),
    )
}

# ======= Legend grouping =======
method_categories = {
    'Baselines': [
        'KARIMI', 'Random', 'HighestVar', 'HighestPred', 'Uncertainty',
    ],
    'Original ExAL': [
        'EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max',
    ],
}

neighbor = NEIGHBOR
logger.info(f"Generating enhanced plots for NEIGHBOR = {neighbor}")

# Input folder with the averaged results and output folder for the figures
results_folder = f"Results_{neighbor}/seeds_results"
plot_folder    = f"Results_{neighbor}/Plots_Enhanced_Metrics"
os.makedirs(plot_folder, exist_ok=True)

# Path of the averaged results CSV for every styled method
csv_files = dict(
    (method, os.path.join(results_folder, f"AVG_{method}_lambda_{LAMDA}_{dataset}.csv"))
    for method in styles
)
def ensure_zero_row(df, metric):
    """
    Guarantee an Iteration==0 row; if missing, copy the first row and set Iteration=0.

    Args:
        df (pd.DataFrame): Frame with an 'Iteration' column.
        metric: Unused; kept so existing calls keep working.

    Returns:
        pd.DataFrame: `df` itself when it has no rows or already contains an
        Iteration==0 row; otherwise a new frame with the extra row, sorted by
        'Iteration'.
    """
    # An empty frame has no first row to copy, so it is returned unchanged.
    if len(df) == 0 or (df['Iteration'] == 0).any():
        return df
    # Select with a list so the copied row stays a DataFrame and every column keeps
    # its dtype (going through a Series would coerce all columns to a common dtype).
    first = df.iloc[[0]].assign(Iteration=0)
    return (pd.concat([first, df], ignore_index=True)
              .sort_values('Iteration'))
    
    
def create_organized_legend(ax, dfs):
    """
    Attach a legend to `ax` whose entries are grouped by strategy category.

    Only methods that are present in `dfs` and have an entry in `styles` are listed.
    Each non-empty category is introduced by an invisible line that serves as a
    section header. Returns the legend object created by `ax.legend`.
    """
    handles = []
    for group_name, group_members in method_categories.items():
        present = [name for name in group_members if name in dfs and name in styles]
        if not present:
            continue

        # Header entry: a line without color, shown only through its label
        handles.append(plt.Line2D([0], [0], color='none', label=f'─── {group_name} ───'))

        for name in present:
            look = styles[name]
            proxy = plt.Line2D(
                [0], [0],
                color=look['color'],
                marker=look['marker'],
                linestyle=look['linestyle'],
                linewidth=2,
                markersize=7,
                label=name,
            )
            handles.append(proxy)

    # Place the legend outside the axes, vertically centered on the right edge
    return ax.legend(
        handles=handles,
        fontsize=12,
        loc='center left',
        bbox_to_anchor=(1, 0.5),
    )

def highlight_best_performers(ax, dfs, metric):
    """
    Circle the last data point of the method with the best final value of `metric`.

    "Best" is the smallest final value when `direction_info` marks the metric as
    lower-is-better, otherwise the largest. Nothing is drawn when `dfs` is empty or
    contains only empty frames.
    """
    if not dfs:
        return

    # Value in the last row of every non-empty frame
    last_scores = {
        name: frame[metric].iloc[-1]
        for name, frame in dfs.items()
        if not frame.empty
    }
    if not last_scores:
        return

    pick = min if direction_info[metric]['better'] == 'lower' else max
    winner = pick(last_scores, key=last_scores.get)

    # Hollow red ring around the winner's final point
    ax.scatter(
        dfs[winner]['Iteration'].iloc[-1],
        last_scores[winner],
        s=150,
        facecolors='none',
        edgecolors='red',
        linewidths=3,
        label=f'Best: {winner}',
    )

def add_direction_indicator(ax, metric, ylow, yhigh):
    """
    Placeholder for a "higher/lower is better" annotation on the axes.

    The indicator is disabled to keep the plots clean, so the function accepts its
    arguments, draws nothing and returns None.
    """
    # Intentionally a no-op
    return None

def analyze_trends(dfs, metric):
    """
    Summarise, per method, how `metric` changed between the first and the last row.

    The improvement is signed so that a positive number always means "got better":
    first - last for lower-is-better metrics, last - first otherwise. Methods with
    fewer than two rows are left out.

    Returns:
        dict: method -> {'improvement', 'trend', 'first', 'last'}, where 'trend' is
        'improving' for a strictly positive improvement and 'declining' otherwise.
    """
    lower_is_better = direction_info[metric]['better'] == 'lower'

    summary = {}
    for name, frame in dfs.items():
        if len(frame) < 2:
            continue

        start = frame[metric].iloc[0]
        end = frame[metric].iloc[-1]
        if lower_is_better:
            gain = start - end
        else:
            gain = end - start

        summary[name] = {
            'improvement': gain,
            'trend': 'improving' if gain > 0 else 'declining',
            'first': start,
            'last': end,
        }
    return summary

# Plot metrics across all AL strategies: one figure per metric listed in `labels`.
# Each figure shows one learning curve per method whose averaged results CSV is
# available, and is saved as PNG, PDF and SVG.
for metric, ylabel in labels.items():
    plt.figure(figsize=(16, 10))  # Large, clear plots for paper-quality figures
    dfs = {}
    # Running min/max of the metric over all loaded methods (used for the y-limits)
    min_val, max_val = np.inf, -np.inf

    # Load results from CSVs
    for method, path in csv_files.items():
        if not os.path.isfile(path):
            logger.debug(f"Missing file for {method}: {path}")
            continue
        try:
            df = pd.read_csv(path)
            if metric not in df.columns:
                logger.debug(f"Metric {metric} not in file for {method}: {path}")
                continue
            # Keep only rows where the metric is a finite number
            df_clean = df[df[metric].notna() & np.isfinite(df[metric])]
            if df_clean.empty:
                logger.warning(f"No valid data for {metric} in {method}")
                continue
            dfs[method] = df_clean
            min_val = min(min_val, df_clean[metric].min())
            max_val = max(max_val, df_clean[metric].max())
        except Exception as e:
            # A file that cannot be loaded is reported and skipped; the other methods are still plotted
            logger.warning(f"Error loading {method}: {e}")
            continue

    # Nothing to plot for this metric: discard the empty figure and move on
    if not dfs:
        logger.warning(f"No data for metric {metric} at NEIGHBOR={neighbor}")
        plt.close()
        continue

    # Pad the y-range by 5% of the data range (fixed 0.01 when all values are equal)
    margin = 0.05 * (max_val - min_val) if max_val > min_val else 0.01
    ylow  = min_val - margin
    yhigh = max_val + margin

    # Plot each strategy's learning curve
    for method, df in dfs.items():
        style = styles[method]
        # ExAL curves are drawn slightly thicker and fully opaque
        linewidth = 3.5 if 'EXAL' in method else 3.0
        markersize = 7 if 'EXAL' in method else 7  # both branches currently use the same size
        alpha = 1.0 if 'EXAL' in method else 0.8
        plt.plot(
            df['Iteration'], df[metric],
            label=method,
            linewidth=linewidth,
            marker=style['marker'],
            color=style['color'],
            linestyle=style['linestyle'],
            markersize=markersize,
            alpha=alpha
        )

    # Highlight top performer and add better direction
    highlight_best_performers(plt.gca(), dfs, metric)
    add_direction_indicator(plt.gca(), metric, ylow, yhigh)

    # Plot formatting for clarity and publication
    plt.xlabel('Active Learning Iteration', fontsize=20, fontweight='bold')
    plt.ylabel(ylabel, fontsize=20, fontweight='bold')
    direction = direction_info[metric]
    # The title starts with the part of the axis label that precedes the first "("
    plt.title(
        f"{ylabel.split('(')[0].strip()} Evolution Across AL Iterations\n"
        f"Dataset: MovieLens-{dataset.upper()} | λ={LAMDA} | Top-N={TopN} | "
        f"Neighbors={neighbor} |{direction['arrow']} Better",
        fontsize=17, fontweight='bold', pad=25
    )
    plt.grid(True, alpha=0.7, linestyle='-', linewidth=0.9)
    plt.ylim(ylow, yhigh)
    # One x tick per integer iteration, from 0 up to the largest iteration of any method
    max_iter = max(df['Iteration'].max() for df in dfs.values())
    plt.xticks(range(0, int(max_iter) + 1), fontsize=20)
    plt.yticks(fontsize=14)

    # Legend and layout
    create_organized_legend(plt.gca(), dfs)
    plt.tight_layout()

    # ----- Save plots in all desired formats -----
    # The PDF and SVG paths are derived from the PNG path by swapping the extension
    save_path_png = os.path.join(plot_folder, f"Enhanced_{metric}_Comparison_LAMDA_({LAMDA}).png")
    save_path_pdf = save_path_png.replace('.png', '.pdf')
    save_path_svg = save_path_png.replace('.png', '.svg')

    plt.savefig(save_path_png, dpi=300, bbox_inches='tight', facecolor='white')
    plt.savefig(save_path_pdf, bbox_inches='tight', facecolor='white')
    plt.savefig(save_path_svg, bbox_inches='tight', facecolor='white')
    logger.info(f"Saved PNG plot:  {save_path_png}")
    logger.info(f"Saved PDF plot:  {save_path_pdf}")
    logger.info(f"Saved SVG plot:  {save_path_svg}")

    plt.show()

    # Trend analysis for report or log: methods are listed from largest to smallest improvement
    trends = analyze_trends(dfs, metric)
    logger.info(f"\n=== TRENDS ANALYSIS FOR {metric} ===")
    for method, trend_info in sorted(trends.items(), key=lambda x: x[1]['improvement'], reverse=True):
        logger.info(f"{method:20}: {trend_info['trend']:10} "
                   f"({trend_info['first']:.4f} → {trend_info['last']:.4f}, "
                   f"Δ={trend_info['improvement']:+.4f})")


## t-test significance vs. Random baseline 

In [None]:
import os
import glob
import pandas as pd
import numpy as np
from scipy.stats import ttest_rel, shapiro, wilcoxon, normaltest
import seaborn as sns
import matplotlib.pyplot as plt
from tabulate import tabulate
from statsmodels.stats.multitest import multipletests
import warnings
from typing import Dict, List, Tuple
from scipy import stats

# === CONFIGURATION ===
class Config:
    """Settings for the statistical comparison of methods against a baseline.

    The experiment identifiers (``lambda_value``, ``neighbor``, ``dataset``)
    are read from the module-level names ``LAMDA``, ``NEIGHBOR`` and
    ``dataset`` at construction time.  Constructing an instance also creates
    the ``stat_dir`` output directory if it does not exist yet.
    """

    def __init__(self):
        # Experiment identifiers, taken from names defined elsewhere in the file.
        self.lambda_value = LAMDA
        self.neighbor = NEIGHBOR
        self.dataset = dataset

        # Methods under comparison and the baseline they are compared against.
        self.all_methods = [
            'EXAL-Min', 'EXAL-Max', 'EXAL-Min-Max',
            'KARIMI', 'Random', 'HighestPred', 'HighestVar', 'Uncertainty',
        ]
        self.reference_method = 'Random'
        self.metrics = ["MAE", "MEP", "MER", "F-Score", "MAP"]

        # One row per metric:
        # (key, higher_is_better, full name, description, expected improvement)
        metric_rows = (
            ("MAE", False, "Mean Absolute Error",
             "Average prediction error", "decrease"),
            ("MEP", True, "Mean Explainable Precision",
             "Fraction of top-N items that are explainable", "increase"),
            ("MER", True, "Mean Explainable Recall",
             "Fraction of explainable items retrieved in top-N", "increase"),
            ("F-Score", True, "Explainable F-Score",
             "Harmonic mean of MEP and MER", "increase"),
            ("MAP", True, "Mean Average Precision",
             "Ranking quality for known test items", "increase"),
        )
        self.metric_info = {}
        for key, higher_is_better, full_name, description, expected in metric_rows:
            self.metric_info[key] = {
                "direction": higher_is_better,
                "full_name": full_name,
                "description": description,
                "expected_improvement": expected,
            }

        # Shortcut: metric name -> True when higher values are better.
        self.metric_direction = {
            name: info["direction"] for name, info in self.metric_info.items()
        }

        # Testing parameters.
        self.alpha = 0.05            # significance level
        self.use_two_sided = True    # two-sided tests for exploratory analysis
        self.min_sample_size = 3

        # Input and output locations.
        self.result_dir = f"Results_{self.neighbor}/seeds_results"
        self.stat_dir = "stat_results"
        os.makedirs(self.stat_dir, exist_ok=True)

# Module-level settings object; constructing it also creates the output
# directory (see Config.__init__).
config = Config()
# NOTE(review): this suppresses every warning category for the rest of the
# process, which presumably also hides warnings raised by the statistical
# routines used below -- confirm that is intended.
warnings.filterwarnings('ignore')

# === STATISTICAL HELPER FUNCTIONS ===

def cohens_d_paired(x, y):
    """Calculate Cohen's d effect size for paired samples.

    The effect size is the mean of the pairwise differences ``x - y`` divided
    by their sample standard deviation (``ddof=1``).

    Args:
        x: Sequence of measurements for the first condition.
        y: Sequence of measurements for the second condition, paired
            element-wise with ``x``.

    Returns:
        The effect size as a float.  ``nan`` when fewer than two pairs are
        given (the sample standard deviation is undefined), ``0.0`` when all
        differences are zero, and ``+inf`` / ``-inf`` when the differences
        are a non-zero constant (the sign follows the mean difference).
    """
    differences = np.asarray(x) - np.asarray(y)

    # The sample standard deviation needs at least two observations; return
    # NaN explicitly instead of letting NumPy emit runtime warnings.
    if differences.size < 2:
        return np.nan

    mean_diff = np.mean(differences)
    std_diff = np.std(differences, ddof=1)  # Sample standard deviation

    if std_diff == 0:
        if mean_diff == 0:
            return 0.0
        # Keep the direction of the effect: a constant negative difference
        # must not be reported as +inf.
        return np.inf if mean_diff > 0 else -np.inf

    return mean_diff / std_diff

def interpret_effect_size(d):
    """Interpret Cohen's d effect size according to conventional guidelines.

    Args:
        d: Effect size; only its magnitude is used.

    Returns:
        ``"negligible"`` (|d| < 0.2), ``"small"`` (|d| < 0.5), ``"medium"``
        (|d| < 0.8) or ``"large"`` otherwise.  ``"undefined"`` is returned
        for NaN, which would otherwise fail every comparison and be
        mislabelled as ``"large"``.
    """
    abs_d = abs(d)
    if np.isnan(abs_d):
        return "undefined"
    if abs_d < 0.2:
        return "negligible"
    if abs_d < 0.5:
        return "small"
    if abs_d < 0.8:
        return "medium"
    return "large"

def check_normality(data, alpha=0.05):
    """Test whether ``data`` looks normally distributed.

    Shapiro-Wilk is used for samples of up to 50 values and the
    D'Agostino-Pearson test for larger ones.  Samples with fewer than three
    values are not tested.

    Args:
        data: Sequence of numeric observations.
        alpha: Significance level; the sample counts as normal when the
            test's p-value is greater than ``alpha``.

    Returns:
        A dict with keys ``sample_size``, ``is_normal``, ``shapiro_stat``,
        ``shapiro_p``, ``dagostino_stat``, ``dagostino_p`` and ``test_used``.
        Only the fields of the test that actually ran are filled in; the
        others stay NaN.  If the test raises, ``test_used`` carries the error
        text and ``is_normal`` stays False.
    """
    sample = np.array(data)
    count = len(sample)

    outcome = dict(
        sample_size=count,
        is_normal=False,
        shapiro_stat=np.nan,
        shapiro_p=np.nan,
        dagostino_stat=np.nan,
        dagostino_p=np.nan,
        test_used='None',
    )

    # Neither test is run on fewer than three observations.
    if count < 3:
        outcome['test_used'] = 'Insufficient data'
        return outcome

    try:
        if count <= 50:
            # Small sample: Shapiro-Wilk.
            run_test, prefix, label = shapiro, 'shapiro', 'Shapiro-Wilk'
        else:
            # Larger sample: D'Agostino-Pearson.
            run_test, prefix, label = normaltest, 'dagostino', 'D\'Agostino-Pearson'

        statistic, p_value = run_test(sample)
        outcome.update({
            f'{prefix}_stat': statistic,
            f'{prefix}_p': p_value,
            'is_normal': p_value > alpha,
            'test_used': label,
        })
    except Exception as err:
        outcome['test_used'] = f'Error: {str(err)}'

    return outcome

# === DATA LOADING ===
_LOAD_STRATEGIES = (
    'final_only', 'last_n', 'all_iterations', 'convergence_period', 'early_late'
)


def _seed_from_path(file_path, fallback):
    """Parse the seed number from a ``..._seed_<n>.csv`` file name."""
    # Look at the file name only, so directory names cannot confuse the parse.
    _, marker, tail = os.path.basename(file_path).rpartition('_seed_')
    if not marker:
        return fallback
    try:
        return int(tail.split('.')[0])
    except ValueError:
        return fallback


def _select_rows(df, strategy, num_iterations, min_iteration):
    """Return the positional row indices of ``df`` selected by ``strategy``."""
    n_rows = len(df)
    if strategy == 'final_only':
        return [n_rows - 1]
    if strategy == 'last_n':
        return list(range(max(0, n_rows - num_iterations), n_rows))
    if strategy == 'all_iterations':
        return list(range(n_rows))
    if strategy == 'convergence_period':
        return [i for i, iteration in enumerate(df['Iteration'])
                if iteration >= min_iteration]
    if strategy == 'early_late':
        early = range(min(3, n_rows))
        late = range(max(0, n_rows - 3), n_rows)
        # Short runs make the two windows overlap; each row is used once.
        return sorted(set(early) | set(late))
    raise ValueError(f"Unknown strategy: {strategy}")


def load_data(config, strategy='final_only', num_iterations=3, min_iteration=5):
    """Load data for all methods from individual seed files with multiple strategies.

    For every method in ``config.all_methods`` the files matching
    ``<method>_lambda_<lambda>_<dataset>_seed_*.csv`` in ``config.result_dir``
    are read, and the metric values of the rows chosen by ``strategy`` are
    collected.

    Args:
        config: Object providing ``neighbor``, ``result_dir``,
            ``lambda_value``, ``dataset``, ``all_methods`` and ``metrics``.
        strategy: Row selection per file: ``'final_only'`` (last row),
            ``'last_n'`` (last ``num_iterations`` rows), ``'all_iterations'``,
            ``'convergence_period'`` (rows whose ``Iteration`` is at least
            ``min_iteration``) or ``'early_late'`` (first three and last
            three rows, each row used once).
        num_iterations: Window size for ``'last_n'``.
        min_iteration: Threshold for ``'convergence_period'``.

    Returns:
        ``(method_results, method_metadata)``.  ``method_results[method][metric]``
        is the list of valid (non-NaN, finite) values.
        ``method_metadata[method]`` holds the lists ``iterations``, ``seeds``
        and ``temporal_order`` with one entry per loaded row that contributed
        at least one metric value, so they line up with the metric lists
        whenever no metric value is missing.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    # Validate up front: inside the per-file error handling below an unknown
    # strategy would be reported as a "failed to load" warning for each file.
    if strategy not in _LOAD_STRATEGIES:
        raise ValueError(f"Unknown strategy: {strategy}")

    print(f"\n===== LOADING DATA: NEIGHBOR = {config.neighbor} =====")
    print(f"Data extraction strategy: {strategy}")

    # Initialize storage with metadata
    method_results = {m: {metric: [] for metric in config.metrics}
                      for m in config.all_methods}
    method_metadata = {m: {'iterations': [], 'seeds': [], 'temporal_order': []}
                       for m in config.all_methods}

    for method in config.all_methods:
        # Find all seed files for this method
        pattern = os.path.join(
            config.result_dir,
            f"{method}_lambda_{config.lambda_value}_{config.dataset}_seed_*.csv"
        )
        files = sorted(glob.glob(pattern))
        print(f"  Found {len(files)} files for {method}")

        results = method_results[method]
        metadata = method_metadata[method]

        for file_idx, file_path in enumerate(files):
            # Best effort per file: an unreadable or malformed file is
            # reported and skipped without aborting the whole load.
            try:
                df = pd.read_csv(file_path)
                if df.empty or 'Iteration' not in df.columns:
                    continue

                seed_num = _seed_from_path(file_path, fallback=file_idx)
                present_metrics = [m for m in config.metrics if m in df.columns]

                for row_idx in _select_rows(df, strategy, num_iterations, min_iteration):
                    row = df.iloc[row_idx]
                    iteration = row.get('Iteration', row_idx)

                    stored_any = False
                    for metric in present_metrics:
                        value = row.get(metric, np.nan)
                        # Only store valid numeric values
                        if not pd.isna(value) and np.isfinite(value):
                            results[metric].append(value)
                            stored_any = True

                    # Record metadata once per row, not once per metric, so
                    # that seeds/iterations stay aligned with the value lists.
                    if stored_any:
                        metadata['temporal_order'].append(len(metadata['seeds']))
                        metadata['iterations'].append(iteration)
                        metadata['seeds'].append(seed_num)

            except Exception as e:
                print(f"Warning: Failed to load {file_path}: {e}")

    # Print comprehensive sample information
    print(f"\nSample sizes per method (strategy: {strategy}):")
    for method in config.all_methods:
        # Largest per-metric count; does not rely on a particular metric name.
        sample_size = max(
            (len(values) for values in method_results[method].values()), default=0
        )
        unique_seeds = len(set(method_metadata[method]['seeds']))
        avg_iterations_per_seed = sample_size / unique_seeds if unique_seeds > 0 else 0
        print(f"  {method}: {sample_size} observations from {unique_seeds} seeds "
              f"({avg_iterations_per_seed:.1f} iterations/seed)")

    return method_results, method_metadata

# === STATISTICAL ANALYSIS ===
def _normality_stat_and_p(normality_results):
    """Return ``(statistic, p_value)`` of whichever normality test produced a p-value."""
    for prefix in ('shapiro', 'dagostino'):
        p_val = normality_results.get(f'{prefix}_p', np.nan)
        if not pd.isna(p_val):
            return normality_results.get(f'{prefix}_stat', np.nan), p_val
    return np.nan, np.nan


def _run_paired_test(method_vals, ref_vals, differences, mean_diff, std_diff,
                     is_normal, alternative):
    """Run the paired significance test; return ``(p_value, test_used, test_statistic)``."""
    n_pairs = len(differences)
    try:
        if std_diff == 0:
            # No variance in differences
            test_used = 'No variance in differences'
            print(f"  Test: {test_used}")
            return (1.0 if mean_diff == 0 else 0.0), test_used, np.nan

        if is_normal and n_pairs >= 5:
            # PAIRED T-TEST (parametric)
            t_stat, p_value = ttest_rel(method_vals, ref_vals, alternative=alternative)
            print(f"  Test: Paired t-test, t({n_pairs - 1})={t_stat:.4f}, p={p_value:.6f}")
            return p_value, 'Paired t-test', t_stat

        # WILCOXON SIGNED-RANK TEST (non-parametric)
        if not np.any(differences != 0):
            print(f"  Test: All differences are zero")
            return 1.0, 'All differences are zero', np.nan

        try:
            # The p-value computation mode is left at its default ('auto');
            # the keyword that selects it has different names across SciPy
            # versions, so it is not passed explicitly.
            w_stat, p_value = wilcoxon(differences,
                                       alternative=alternative,
                                       zero_method='wilcox')
            print(f"  Test: Wilcoxon signed-rank, W={w_stat:.4f}, p={p_value:.6f}")
            return p_value, 'Wilcoxon signed-rank test', w_stat
        except Exception as e:
            # Fallback for edge cases
            print(f"  Test: Wilcoxon failed - {str(e)}")
            return np.nan, f'Wilcoxon failed: {str(e)}', np.nan

    except Exception as e:
        print(f"  Test: Failed - {str(e)}")
        return np.nan, f'Test failed: {str(e)}', np.nan


def _apply_bonferroni(results_df, alpha):
    """Add the Bonferroni-corrected columns to ``results_df`` in place."""
    print(f"\n===== MULTIPLE TESTING CORRECTION =====")

    p_values = results_df['P_Value'].to_numpy(dtype=float)
    valid_p_mask = ~np.isnan(p_values)
    n_tests = int(valid_p_mask.sum())

    # Failed tests (NaN p-value) keep NaN / False in the corrected columns.
    corrected_p_values = np.full(len(p_values), np.nan)
    significant_after_correction = np.zeros(len(p_values), dtype=bool)

    if n_tests > 0:
        print(f"Number of tests performed: {n_tests}")
        print(f"Bonferroni correction: α_corrected = {alpha}/{n_tests} = {alpha/n_tests:.6f}")

        reject, corrected_p, _, _ = multipletests(
            p_values[valid_p_mask],
            alpha=alpha,
            method='bonferroni'
        )
        corrected_p_values[valid_p_mask] = corrected_p
        significant_after_correction[valid_p_mask] = reject

    results_df['P_Value_Corrected'] = corrected_p_values
    results_df['Significant_After_Correction'] = significant_after_correction

    # Keep the original Better/Worse label only where the result is still
    # significant after correction.
    results_df['Performance_Corrected'] = [
        'Test Failed' if not has_p else (label if still_significant else 'No Difference')
        for has_p, still_significant, label in zip(
            valid_p_mask, significant_after_correction, results_df['Performance']
        )
    ]

    if n_tests > 0:
        n_significant_before = results_df['Is_Significant'].sum()
        n_significant_after = results_df['Significant_After_Correction'].sum()
        print(f"Significant results before correction: {n_significant_before}")
        print(f"Significant results after correction: {n_significant_after}")
        print(f"Results lost to multiple testing: {n_significant_before - n_significant_after}")


def perform_comprehensive_statistical_tests(config, method_results):
    """Perform comprehensive statistical tests comparing each method to the baseline.

    For every non-reference method and every metric, the method's values are
    paired position-by-position with the reference method's values.  The
    differences are checked for normality; a paired t-test is used when they
    look normal and there are at least 5 pairs, otherwise a Wilcoxon
    signed-rank test.  Cohen's d is reported as effect size, and all p-values
    are finally Bonferroni-corrected.

    Args:
        config: Object providing ``reference_method``, ``all_methods``,
            ``metrics``, ``metric_direction``, ``alpha``, ``use_two_sided``
            and ``min_sample_size``.
        method_results: ``{method: {metric: list of values}}``.  Values at
            the same position in the method's and the reference's list are
            treated as one pair.

    Returns:
        A DataFrame with one row per (method, metric) comparison that had
        enough data.  Whenever it has rows, it also contains
        ``P_Value_Corrected``, ``Significant_After_Correction`` and
        ``Performance_Corrected``.  Comparisons whose value lists differ in
        length, or that have fewer than ``config.min_sample_size`` valid
        pairs, are skipped with a printed message.
    """
    print(f"\n===== COMPREHENSIVE STATISTICAL ANALYSIS vs {config.reference_method} =====")
    print(f"Test type: {'Two-sided' if config.use_two_sided else 'One-sided'}")
    print(f"Significance level: α = {config.alpha}")
    print(f"Multiple testing correction: Bonferroni")

    records = []
    ref_data = method_results[config.reference_method]

    # STEP 1: Perform all pairwise comparisons
    for method in config.all_methods:
        if method == config.reference_method:
            continue  # Skip self-comparison

        method_data = method_results[method]

        for metric in config.metrics:
            print(f"\nAnalyzing {method} vs {config.reference_method} on {metric}...")

            # Values are paired by position in the two lists.
            ref_vals = np.array(ref_data[metric])
            method_vals = np.array(method_data[metric])

            # A paired test needs one reference value per method value; with
            # unequal lengths the pairing is unknown, so skip the comparison
            # instead of failing on the element-wise operations below.
            if len(ref_vals) != len(method_vals):
                print(f"  Skipped: Unpaired data (reference n={len(ref_vals)}, "
                      f"method n={len(method_vals)})")
                continue

            # STEP 2: Data validation and cleaning
            valid_idx = ~(pd.isna(ref_vals) | pd.isna(method_vals) |
                          np.isinf(ref_vals) | np.isinf(method_vals))
            ref_vals_clean = ref_vals[valid_idx]
            method_vals_clean = method_vals[valid_idx]

            # Check minimum sample size
            n_pairs = len(ref_vals_clean)
            if n_pairs < config.min_sample_size:
                print(f"  Skipped: Insufficient data (n={n_pairs} < {config.min_sample_size})")
                continue

            # STEP 3: Descriptive statistics
            ref_mean = np.mean(ref_vals_clean)
            ref_std = np.std(ref_vals_clean, ddof=1)
            method_mean = np.mean(method_vals_clean)
            method_std = np.std(method_vals_clean, ddof=1)

            # Differences for the paired analysis (method minus reference)
            differences = method_vals_clean - ref_vals_clean
            mean_diff = np.mean(differences)
            std_diff = np.std(differences, ddof=1)
            diff_min = np.min(differences)
            diff_max = np.max(differences)
            diff_median = np.median(differences)

            # Relative change calculation
            if abs(ref_mean) > 1e-10:  # Avoid division by zero
                percent_change = (mean_diff / abs(ref_mean)) * 100
            else:
                percent_change = np.nan

            # STEP 4: Check normality of differences
            normality_results = check_normality(differences, config.alpha)
            is_normal = normality_results['is_normal']
            # Report the statistic/p-value of the test that actually ran; the
            # result dict always carries both key pairs, one of them as NaN.
            normality_stat, normality_p = _normality_stat_and_p(normality_results)

            print(f"  Sample size: n={n_pairs}")
            print(f"  Mean difference: {mean_diff:.6f} ({percent_change:+.2f}%)")
            print(f"  Normality test: {normality_results['test_used']}, p={normality_p:.4f}")
            print(f"  Distribution: {'Normal' if is_normal else 'Non-normal'}")

            # STEP 5: Determine test direction
            is_higher_better = config.metric_direction[metric]
            if config.use_two_sided:
                alternative = 'two-sided'
                print(f"  Test direction: Two-sided (exploratory)")
            elif is_higher_better:
                alternative = 'greater'  # H1: method > baseline
                print(f"  Test direction: One-sided (expect method > baseline)")
            else:
                alternative = 'less'     # H1: method < baseline
                print(f"  Test direction: One-sided (expect method < baseline)")

            # STEP 6: Choose and perform appropriate statistical test
            p_value, test_used, test_statistic = _run_paired_test(
                method_vals_clean, ref_vals_clean, differences,
                mean_diff, std_diff, is_normal, alternative
            )

            # STEP 7: Calculate effect size
            effect_size = cohens_d_paired(method_vals_clean, ref_vals_clean)
            effect_interpretation = interpret_effect_size(effect_size)
            print(f"  Effect size: Cohen's d = {effect_size:.4f} ({effect_interpretation})")

            # STEP 8: Determine significance and performance
            is_significant = not pd.isna(p_value) and p_value < config.alpha

            if pd.isna(p_value):
                performance = 'Test Failed'
            elif not is_significant:
                performance = 'No Difference'
            else:
                # Significant - determine if better or worse
                method_is_better = (mean_diff > 0) if is_higher_better else (mean_diff < 0)
                performance = 'Better' if method_is_better else 'Worse'

            print(f"  Result: {performance} (p={p_value:.6f}, {'significant' if is_significant else 'not significant'})")

            # STEP 9: Store comprehensive results
            records.append({
                # Method and metric info
                'Method': method,
                'Metric': metric,
                'Metric_Direction': 'Higher better' if is_higher_better else 'Lower better',

                # Sample information
                'Sample_Size': n_pairs,
                'Valid_Pairs': n_pairs,

                # Descriptive statistics
                'Method_Mean': method_mean,
                'Method_Std': method_std,
                'Ref_Mean': ref_mean,
                'Ref_Std': ref_std,

                # Difference statistics
                'Mean_Diff': mean_diff,
                'Std_Diff': std_diff,
                'Diff_Min': diff_min,
                'Diff_Max': diff_max,
                'Diff_Median': diff_median,
                'Percent_Change': percent_change,

                # Normality testing
                'Normality_Test': normality_results['test_used'],
                'Normality_Statistic': normality_stat,
                'Normality_P': normality_p,
                'Is_Normal': is_normal,

                # Statistical test results
                'Test_Used': test_used,
                'Test_Statistic': test_statistic,
                'Alternative_Hypothesis': alternative,
                'P_Value': p_value,
                'Is_Significant': is_significant,

                # Effect size
                'Effect_Size': effect_size,
                'Effect_Interpretation': effect_interpretation,

                # Performance assessment
                'Performance': performance
            })

    # Convert to DataFrame
    results_df = pd.DataFrame(records)

    # STEP 10: Apply multiple testing correction
    if len(results_df) > 0:
        _apply_bonferroni(results_df, config.alpha)

    return results_df

def aggregate_by_seed(values, seeds):
    """Average the values that share a seed.

    ``values`` and ``seeds`` are walked in lockstep (stopping at the shorter
    one).  The result maps each seed, in order of first appearance, to the
    mean of its values.
    """
    buckets = {}
    for value, seed in zip(values, seeds):
        if seed not in buckets:
            buckets[seed] = []
        buckets[seed].append(float(value))

    averaged = {}
    for seed, members in buckets.items():
        averaged[seed] = np.mean(members)
    return averaged

def align_by_seed(method_vals, method_seeds, ref_vals, ref_seeds):
    """Pair per-seed averages of a method with those of the reference.

    Both inputs are first averaged per seed; only seeds present on both
    sides are kept, in ascending seed order.

    Returns:
        Two arrays of equal length: the method's per-seed means and the
        reference's per-seed means for the shared seeds.
    """
    method_by_seed = aggregate_by_seed(method_vals, method_seeds)
    ref_by_seed = aggregate_by_seed(ref_vals, ref_seeds)

    shared_seeds = sorted(set(method_by_seed).intersection(ref_by_seed))

    method_aligned = []
    ref_aligned = []
    for seed in shared_seeds:
        method_aligned.append(method_by_seed[seed])
        ref_aligned.append(ref_by_seed[seed])
    return np.array(method_aligned), np.array(ref_aligned)

# === CLEAN VISUALIZATION ===
def create_clean_heatmap(config, df, strategy='unknown'):
    """Create a clean, professional heatmap for statistical comparisons.

    Each cell (metric x method) is coloured by the outcome of the comparison
    against ``config.reference_method`` and annotated with significance
    stars and the percent change.  The figure is saved as PNG, PDF and SVG
    in ``config.stat_dir`` and then shown.

    Args:
        config: Object providing ``all_methods``, ``reference_method``,
            ``metrics``, ``metric_info``, ``alpha``, ``lambda_value``,
            ``neighbor`` and ``stat_dir``.
        df: Results table with at least the columns ``Method``, ``Metric``,
            ``P_Value``, ``Performance`` and ``Percent_Change``; ``Test_Used``
            and ``Alternative_Hypothesis`` are summarised in the title when
            present.
        strategy: Data extraction strategy name, used in the title and in
            the output file names.

    Returns:
        None.  Prints a message and returns early when ``df`` is empty.
    """
    if df.empty:
        print("No data available for heatmap")
        return

    # Set style
    plt.style.use('default')
    sns.set_palette("husl")

    # Create figure with better dimensions
    fig, ax = plt.subplots(figsize=(14, 9))

    # Prepare data - exclude reference method from display
    methods = [m for m in config.all_methods if m != config.reference_method]
    metrics = config.metrics
    reference = config.reference_method

    # Cell colours on a -1..1 scale; cells without a result stay neutral (0).
    heat_data = np.zeros((len(metrics), len(methods)))
    annotations = np.full((len(metrics), len(methods)), "", dtype=object)

    for _, row in df.iterrows():
        # Rows for methods or metrics that are not displayed are ignored.
        if row['Method'] not in methods or row['Metric'] not in metrics:
            continue

        method_idx = methods.index(row['Method'])
        metric_idx = metrics.index(row['Metric'])

        p_val = row['P_Value']
        performance = row['Performance']
        percent_change = row['Percent_Change']

        if pd.isna(p_val) or performance == 'No Difference':
            heat_data[metric_idx, method_idx] = 0
            annotations[metric_idx, method_idx] = f"ns\n{percent_change:+.1f}%"
            continue

        # Colour intensity and stars follow the p-value tier.
        if p_val < 0.001:
            magnitude, stars = 1.0, "***"
        elif p_val < 0.01:
            magnitude, stars = 0.75, "**"
        else:
            magnitude, stars = 0.5, "*"

        # Positive (green) for 'Better', negative (red) for anything else.
        sign = 1 if performance == 'Better' else -1
        heat_data[metric_idx, method_idx] = sign * magnitude
        annotations[metric_idx, method_idx] = f"{stars}\n{percent_change:+.1f}%"

    # Create enhanced metric labels
    metric_labels = []
    for metric in metrics:
        direction = "↓" if not config.metric_info[metric]["direction"] else "↑"
        metric_labels.append(f"{metric} {direction}")

    # Create heatmap
    im = ax.imshow(heat_data, cmap='RdYlGn', aspect='auto', vmin=-1, vmax=1)

    # Add annotations with proper sizing
    for i in range(len(metrics)):
        for j in range(len(methods)):
            ax.text(j, i, annotations[i, j], ha="center", va="center",
                    fontsize=17, fontweight='bold')

    # Set ticks and labels with proper spacing
    ax.set_xticks(np.arange(len(methods)))
    ax.set_yticks(np.arange(len(metrics)))
    ax.set_xticklabels(methods, rotation=0, ha='center', fontsize=13)
    ax.set_yticklabels(metric_labels, fontsize=12)

    # Minor ticks sit on the cell borders and carry the white separator grid.
    ax.set_xticks(np.arange(len(methods)+1)-0.5, minor=True)
    ax.set_yticks(np.arange(len(metrics)+1)-0.5, minor=True)
    ax.grid(which="minor", color="white", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", size=0)

    # Remove major ticks
    ax.tick_params(which="major", length=0)

    # Strategy display mapping
    strategy_display = {
        'final_only': 'Final Iteration Only',
        'all_iterations': 'All Iterations',
        'convergence_period': 'Convergence Period (Iter 5+)',
        'last_n': 'Last N Iterations',
        'early_late': 'Early + Late Iterations'
    }.get(strategy, strategy)

    # Summarize which statistical test was used
    if 'Test_Used' in df.columns:
        test_counts = df['Test_Used'].value_counts()
        if len(test_counts) == 1:
            test_summary = test_counts.index[0]
        else:
            test_summary = f"Mixed ({test_counts.index[0]}: {test_counts.iloc[0]})"
    else:
        test_summary = "Unknown"

    # Alternative hypothesis summary
    if 'Alternative_Hypothesis' in df.columns:
        alt_counts = df['Alternative_Hypothesis'].value_counts()
        alt_summary = alt_counts.index[0] if len(alt_counts) == 1 else "Mixed"
    else:
        alt_summary = "Unknown"

    # Title: the headline goes on its own line so it does not run into the
    # strategy description.
    ax.set_title(
        f"Statistical Comparison vs {reference} Baseline\n"
        f"Strategy: {strategy_display} | λ={config.lambda_value} | Neighbors={config.neighbor}\n"
        f"Test: {test_summary} | Alternative: {alt_summary}",
        fontsize=15, fontweight='bold', pad=25
    )

    ax.set_xlabel('Active Learning Methods', fontsize=15, labelpad=15)
    ax.set_ylabel('Evaluation Metrics', fontsize=15, labelpad=15)

    # Add colorbar with better positioning
    cbar = plt.colorbar(im, ax=ax, shrink=0.6, aspect=20, pad=0.02)
    cbar.set_label(f'Performance vs {reference}', fontsize=15, labelpad=15)
    cbar.set_ticks([-1, -0.5, 0, 0.5, 1])
    cbar.set_ticklabels(['Much Worse', 'Worse', 'No Diff', 'Better', 'Much Better'], fontsize=12)

    # Bonferroni-corrected alpha: only tests that produced a p-value count.
    num_tests = int(df['P_Value'].notna().sum())
    alpha_corrected = config.alpha / num_tests if num_tests > 0 else np.nan

    legend_text = (
        f"GREEN = Better than {reference} | RED = Worse than {reference} | YELLOW = No significant difference\n"
        f"*** p<0.001 | ** p<0.01 | * p<0.05 | ns = not significant | ↑ = Higher is better | ↓ = Lower is better\n"
        f"Bonferroni-corrected α = {alpha_corrected:.6f} (Original α = {config.alpha}) for {num_tests} tests"
    )

    plt.figtext(0.5, 0.08, legend_text, ha='center', fontsize=10,
                bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8))

    # Adjust layout to prevent overlap
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.18, top=0.90, left=0.12, right=0.95)

    # Save with strategy in filename
    base_path = os.path.join(config.stat_dir, f"heatmap_{strategy}_neighbor_{config.neighbor}_lambda_{config.lambda_value}")
    for ext in ['png', 'pdf', 'svg']:
        plt.savefig(f"{base_path}.{ext}", dpi=300, bbox_inches='tight')

    print(f"\nSaved heatmap for {strategy_display}: {base_path}.[png|pdf|svg]")
    plt.show()

# === REPORTING ===
def generate_summary_report(config, df, strategy='unknown'):
    """Print a summary of the statistical results and save it to CSV.

    Prints, per method, how many comparisons came out Better / Worse /
    No Difference together with the mean effect size, then the method with
    the best mean for each metric.  The full results table and the
    per-method summary are written to ``config.stat_dir``.

    Args:
        config: Object providing ``metrics``, ``metric_direction``,
            ``metric_info``, ``stat_dir``, ``neighbor`` and ``lambda_value``.
        df: Results table with the columns ``Method``, ``Metric``,
            ``Performance``, ``Effect_Size``, ``Effect_Interpretation``,
            ``Method_Mean``, ``Percent_Change`` and ``P_Value``.
        strategy: Data extraction strategy name, used in the file names.

    Returns:
        None.  Prints a message and returns early when ``df`` is empty.
    """
    if df.empty:
        print("No data available for reporting")
        return

    def share(count, total):
        # "<count>/<total> (<percentage>%)"
        return f"{count}/{total} ({count/total*100:.1f}%)"

    rule = "=" * 80
    print("\n" + rule)
    print("STATISTICAL ANALYSIS SUMMARY")
    print(rule)

    # --- Per-method outcome counts ---
    print("\n=== PERFORMANCE SUMMARY vs RANDOM ===")
    summary_data = []
    for method in df['Method'].unique():
        rows = df[df['Method'] == method]
        outcomes = rows['Performance']
        total = len(rows)
        mean_effect = rows['Effect_Size'].mean()

        summary_data.append({
            'Method': method,
            'Better': share((outcomes == 'Better').sum(), total),
            'Worse': share((outcomes == 'Worse').sum(), total),
            'No_Diff': share((outcomes == 'No Difference').sum(), total),
            'Avg_Effect_Size': f"{mean_effect:.3f}"
        })

    print(tabulate(summary_data, headers='keys', tablefmt='grid'))

    # --- Best method per metric (by mean value) ---
    print("\n=== BEST PERFORMERS PER METRIC ===")
    for metric in config.metrics:
        candidates = df[df['Metric'] == metric]
        if candidates.empty:
            continue

        means = candidates['Method_Mean']
        # Highest mean wins when higher is better, lowest mean otherwise.
        best_label = means.idxmax() if config.metric_direction[metric] else means.idxmin()
        best_row = candidates.loc[best_label]

        print(f"\n{metric} ({config.metric_info[metric]['full_name']}):")
        print(f"  Best: {best_row['Method']} (Mean: {best_row['Method_Mean']:.4f})")
        print(f"  vs Random: {best_row['Percent_Change']:+.2f}% change, p={best_row['P_Value']:.4f}")
        print(f"  Effect size: {best_row['Effect_Size']:.3f} ({best_row['Effect_Interpretation']})")

    # --- Persist the results ---
    file_suffix = f"{strategy}_neighbor_{config.neighbor}_lambda_{config.lambda_value}.csv"

    output_path = os.path.join(config.stat_dir, f"detailed_results_{file_suffix}")
    df.to_csv(output_path, index=False)
    print(f"\nDetailed results saved to: {output_path}")

    summary_path = os.path.join(config.stat_dir, f"summary_{file_suffix}")
    pd.DataFrame(summary_data).to_csv(summary_path, index=False)
    print(f"Summary saved to: {summary_path}")

def create_strategy_comparison_plot(all_results, config):
    """Draw a 2x2 figure contrasting the data-extraction strategies.

    Panels (left-to-right, top-to-bottom): mean sample size, percentage of
    significant tests, mean absolute effect size, and the
    Better / No Difference / Worse split per strategy. The figure is written
    under ``config.stat_dir`` as PNG and PDF and then shown.

    Args:
        all_results: Mapping of strategy name -> results table exposing the
            'Sample_Size', 'Is_Significant', 'Effect_Size' and 'Performance'
            columns. A strategy whose table has length 0 is plotted as zeros.
        config: Object providing ``lambda_value``, ``neighbor`` and
            ``stat_dir``.

    Returns:
        None. With fewer than two strategies only a notice is printed and
        nothing is drawn.
    """
    if len(all_results) < 2:
        print("Need at least 2 strategies for comparison")
        return

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(
        f'Strategy Comparison: Statistical Power Analysis\n'
        f'λ={config.lambda_value}, Neighbors={config.neighbor}',
        fontsize=16, fontweight='bold',
    )

    strategies = list(all_results.keys())

    def summarise(reducer):
        """Apply ``reducer`` to each non-empty table; empty tables yield 0."""
        return [
            reducer(all_results[name]) if len(all_results[name]) > 0 else 0
            for name in strategies
        ]

    def label_bars(axis, rects, values, lift, template):
        """Write every strictly positive value just above its bar."""
        for rect, value in zip(rects, values):
            if value > 0:
                axis.text(rect.get_x() + rect.get_width()/2, rect.get_height() + lift,
                          template.format(value), ha='center', va='bottom')

    # Panels 1-3 share one recipe: reduce each table to a number, draw a bar
    # per strategy, then annotate the bars. Only the pieces listed here differ.
    simple_panels = [
        # 1. Sample sizes comparison
        dict(axis=axes[0, 0],
             reducer=lambda table: table['Sample_Size'].mean(),
             color='skyblue',
             title='Average Sample Size per Strategy',
             ylabel='Sample Size (n)',
             ylim=None, lift=1, template='{:.1f}'),
        # 2. Significance rates comparison
        dict(axis=axes[0, 1],
             reducer=lambda table: table['Is_Significant'].mean() * 100,
             color='lightcoral',
             title='Statistical Significance Rate',
             ylabel='% Tests Significant',
             ylim=(0, 100), lift=1, template='{:.1f}%'),
        # 3. Effect sizes comparison
        dict(axis=axes[1, 0],
             reducer=lambda table: table['Effect_Size'].abs().mean(),
             color='lightgreen',
             title="Average |Effect Size| (Cohen's d)",
             ylabel="|Cohen's d|",
             ylim=None, lift=0.01, template='{:.3f}'),
    ]

    for panel in simple_panels:
        axis = panel['axis']
        values = summarise(panel['reducer'])
        rects = axis.bar(strategies, values, color=panel['color'], alpha=0.7)
        axis.set_title(panel['title'])
        axis.set_ylabel(panel['ylabel'])
        axis.tick_params(axis='x', rotation=0)
        if panel['ylim'] is not None:
            axis.set_ylim(*panel['ylim'])
        label_bars(axis, rects, values, panel['lift'], panel['template'])

    # 4. Performance outcomes
    def outcome_shares(name):
        """Return the percentage of Better / Worse / No Difference rows."""
        table = all_results[name]
        n_rows = len(table)
        if n_rows == 0:
            return {'Strategy': name, 'Better': 0, 'Worse': 0, 'No_Diff': 0}
        verdicts = table['Performance']
        return {
            'Strategy': name,
            'Better': (verdicts == 'Better').sum() / n_rows * 100,
            'Worse': (verdicts == 'Worse').sum() / n_rows * 100,
            'No_Diff': (verdicts == 'No Difference').sum() / n_rows * 100,
        }

    outcome_axis = axes[1, 1]
    performance_data = [outcome_shares(name) for name in strategies]

    if performance_data:
        perf_df = pd.DataFrame(performance_data)
        x = np.arange(len(strategies))
        width = 0.25

        # Three bars per strategy, placed side by side around each tick.
        grouped_bars = (
            (x - width, 'Better', 'Better', 'green'),
            (x, 'No_Diff', 'No Difference', 'yellow'),
            (x + width, 'Worse', 'Worse', 'red'),
        )
        for positions, column, legend_label, color in grouped_bars:
            outcome_axis.bar(positions, perf_df[column], width,
                             label=legend_label, color=color, alpha=0.7)

        outcome_axis.set_title('Performance Outcomes Distribution')
        outcome_axis.set_ylabel('% of Tests')
        outcome_axis.set_xlabel('Strategy')
        outcome_axis.set_xticks(x)
        outcome_axis.set_xticklabels(strategies)
        outcome_axis.legend()
        outcome_axis.set_ylim(0, 100)

    plt.tight_layout()

    # Save the comparison plot in both raster and vector form
    file_stem = f"strategy_comparison_plot_neighbor_{config.neighbor}_lambda_{config.lambda_value}"
    comparison_path = os.path.join(config.stat_dir, file_stem)
    for ext in ('png', 'pdf'):
        plt.savefig(f"{comparison_path}.{ext}", dpi=300, bbox_inches='tight')

    print(f"Strategy comparison plot saved: {comparison_path}.[png|pdf]")
    plt.show()

def main():
    """Run the ExAL statistical analysis once per data-extraction strategy.

    Reads the module-level ``config``. After printing the configuration it
    aborts early if ``lambda_value``, ``neighbor`` or ``dataset`` still hold
    their placeholder strings. Otherwise, for each strategy it loads the
    data, runs the statistical tests, draws the heatmap, writes the summary
    report and a per-strategy CSV, and prints a short summary. An exception
    raised while processing one strategy is reported and the loop moves on to
    the next strategy. When more than one strategy succeeds, a comparison plot
    and a comparison CSV are produced as well.

    Returns:
        None.
    """
    banner = "=" * 100
    print(banner)
    print("COMPREHENSIVE STATISTICAL ANALYSIS FOR EXAL - ENHANCED VERSION WITH PLOTTING")
    print(banner)
    print("Configuration:")
    print(f"  - Lambda: {config.lambda_value}")
    print(f"  - Neighbors: {config.neighbor}")
    print(f"  - Dataset: {config.dataset}")
    print(f"  - Reference Method: {config.reference_method}")
    print(f"  - Significance Level: {config.alpha}")
    sidedness = 'Two-sided' if config.use_two_sided else 'One-sided'
    print(f"  - Test Type: {sidedness}")

    # Check configuration: placeholder strings mean Config was never filled in
    still_placeholder = (
        config.lambda_value == "LAMDA"
        or config.neighbor == "NEIGHBOR"
        or config.dataset == "dataset"
    )
    if still_placeholder:
        print("\nWARNING: Please set actual values for lambda_value, neighbor, and dataset in the Config class!")
        print("Current values are placeholders and will cause file loading to fail.")
        return

    # CHOOSE DATA EXTRACTION STRATEGY: (name, description, extra load_data kwargs)
    strategies_to_compare = (
        ('final_only', 'Traditional: Final iteration only', {}),
        ('all_iterations', 'Enhanced: All iterations for maximum power', {}),
        ('convergence_period', 'Stable: Iterations 5+ (convergence period)', {'min_iteration': 5}),
    )

    all_results = {}
    section_rule = '=' * 60

    for strategy, description, kwargs in strategies_to_compare:
        strategy_tag = strategy.upper()
        print(f"\n{section_rule}")
        print(f"ANALYSIS WITH STRATEGY: {strategy_tag}")
        print(f"Description: {description}")
        print(section_rule)

        try:
            # Load data with current strategy
            method_results, _metadata = load_data(config, strategy=strategy, **kwargs)

            # Check if we have sufficient data (MAE samples summed over all methods)
            loaded_samples = sum(
                len(method_results[method]['MAE']) for method in config.all_methods
            )
            if loaded_samples == 0:
                print(f"Error: No data loaded for strategy {strategy}. Check file paths and configuration.")
                continue

            # Perform statistical analysis
            results_df = perform_comprehensive_statistical_tests(config, method_results)
            if results_df.empty:
                print(f"No statistical results generated for strategy {strategy}")
                continue

            # Store results
            all_results[strategy] = results_df

            # Create visualizations for this strategy
            print(f"\n=== CREATING VISUALIZATIONS FOR {strategy_tag} ===")
            create_clean_heatmap(config, results_df, strategy)

            # Generate detailed report
            generate_summary_report(config, results_df, strategy)

            # Save strategy-specific results
            analysis_name = f"analysis_{strategy}_neighbor_{config.neighbor}_lambda_{config.lambda_value}.csv"
            output_path = os.path.join(config.stat_dir, analysis_name)
            results_df.to_csv(output_path, index=False)
            print(f"Results saved to: {output_path}")

            # Quick summary for this strategy
            print(f"\n=== QUICK SUMMARY FOR {strategy_tag} ===")
            significant_count = results_df['Is_Significant'].sum()
            total_tests = len(results_df)
            print(f"Significant results: {significant_count}/{total_tests} ({significant_count/total_tests*100:.1f}%)")

            # Show sample sizes achieved
            sample_sizes = sorted(results_df['Sample_Size'].unique())
            print(f"Sample sizes achieved: {sample_sizes}")

            # Show effect sizes for key comparisons
            exal_rows = results_df[results_df['Method'].str.contains('EXAL')]
            if not exal_rows.empty:
                avg_effect_size = exal_rows['Effect_Size'].abs().mean()
                print(f"Average |effect size| for EXAL methods: {avg_effect_size:.3f}")

        except Exception as exc:
            # One failing strategy must not prevent the remaining ones from running.
            print(f"Error processing strategy {strategy}: {str(exc)}")

    # COMPARISON OF STRATEGIES
    completed = len(all_results)
    if completed > 1:
        wide_rule = '=' * 80
        print(f"\n{wide_rule}")
        print("STRATEGY COMPARISON SUMMARY")
        print(wide_rule)

        # Create strategy comparison visualization
        create_strategy_comparison_plot(all_results, config)

        comparison_data = []
        for strategy, results_df in all_results.items():
            significant_count = results_df['Is_Significant'].sum()
            total_tests = len(results_df)
            verdicts = results_df['Performance']

            comparison_data.append({
                'Strategy': strategy,
                'Avg_Sample_Size': f"{results_df['Sample_Size'].mean():.1f}",
                'Significant_Tests': f"{significant_count}/{total_tests}",
                'Sig_Percentage': f"{significant_count/total_tests*100:.1f}%",
                'Better_Performance': f"{(verdicts == 'Better').sum()}",
                'Worse_Performance': f"{(verdicts == 'Worse').sum()}",
                'Avg_Effect_Size': f"{results_df['Effect_Size'].abs().mean():.3f}",
            })

        comparison_df = pd.DataFrame(comparison_data)
        print(tabulate(comparison_df, headers='keys', tablefmt='grid'))

        # Save comparison
        comparison_name = f"strategy_comparison_neighbor_{config.neighbor}_lambda_{config.lambda_value}.csv"
        comparison_path = os.path.join(config.stat_dir, comparison_name)
        comparison_df.to_csv(comparison_path, index=False)
        print(f"\nStrategy comparison saved to: {comparison_path}")

    elif completed == 1:
        only_strategy = list(all_results.keys())[0]
        print(f"\nOnly one strategy completed successfully: {only_strategy}")
    else:
        print("\nNo strategies completed successfully. Check your configuration and data files.")

    print(f"\n{banner}")
    print("COMPREHENSIVE STATISTICAL ANALYSIS WITH FULL PLOTTING COMPLETE!")
    print(f"All results, visualizations, and reports saved in: {config.stat_dir}/")
    print(banner)

# Entry point: run the full analysis only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()