In [2]:
import os
import pickle as pkl
import warnings
import numpy as np
from scipy.spatial.distance import cdist
from ot.lp import emd # Needs POT (Python Optimal Transport) library

# Silence noisy library warnings so notebook output stays readable.
# NOTE(review): ignoring *every* RuntimeWarning also hides numpy's "Mean of empty slice"
# warning raised by np.nanmean / np.nanstd on all-NaN inputs, so a metric that failed for
# every seed shows up as a silent NaN in the aggregated results.
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

from interpret.utils import measure_interactions
def FAST(X_train, y_train, n_interactions, init_score=None, feature_names=None, feature_types=None):
    """Rank candidate feature pairs with interpret's ``measure_interactions`` (FAST).

    Args:
        X_train, y_train: training data forwarded unchanged to ``measure_interactions``.
        n_interactions: number of pairs requested (passed as ``interactions=``).
        init_score: optional initial scores / model forwarded as ``init_score``.
        feature_names, feature_types: optional feature metadata forwarded unchanged.

    Returns:
        ``(pairs, elapsed_seconds)``: ``pairs`` is a list of ``(i, j)`` tuples in the order
        ``measure_interactions`` yielded them. The second element of each yielded item
        (presumably the interaction strength) is discarded.
    """
    import time

    # perf_counter is monotonic; time.time() can jump if the wall clock is adjusted.
    t0 = time.perf_counter()
    interactions = measure_interactions(
        X_train,
        y_train,
        interactions=n_interactions,
        init_score=init_score,
        feature_names=feature_names,
        feature_types=feature_types,
    )
    # Each item unpacks as ((i, j), score); keep only the index pair.
    pairs = [(i, j) for (i, j), _ in interactions]
    return pairs, time.perf_counter() - t0

In [3]:
# Seeds of the model runs to process; `seed` below selects the first one.
SEEDS = [2025, 1283123, 3043040, 8238238, 123123]
SAVE_DIR = "./models/xai_tris"  # Directory where models are saved
EXPLANATIONS_DIR = "./explanations/xai_tris"  # Directory to save MLP SHAP/IG .npy files
os.makedirs(EXPLANATIONS_DIR, exist_ok=True)
DEVICE_STR = 'cpu'  # Use 'cuda:0' or similar if GPU is available and desired

seed = SEEDS[0]

In [4]:
# Load the precomputed explanations. Presumably a dict: scenario name -> {method name ->
# attribution array or list}. Verify against the script that wrote the pickle.
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
EXPLANATIONS_PKL_PATH = './models/xai_tris/explanations_xor_dist_corr_std_qlr.pkl'
with open(EXPLANATIONS_PKL_PATH, "rb") as f:
    explanations = pkl.load(f)

In [5]:
# Inspect the loaded scenario names. In the captured output each ends in "_<k>", and the
# metric code strips that last token to pool runs of the same scenario.
explanations.keys()

dict_keys(['xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated_3', 'xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated_2', 'xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated_0', 'xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated_1', 'xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated_4'])

In [11]:
from interpret.utils import measure_interactions
def FAST(X_train, y_train, n_interactions, init_score=None, feature_names=None, feature_types=None):
    """Rank candidate feature pairs with interpret's ``measure_interactions`` (FAST).

    Args:
        X_train, y_train: training data forwarded unchanged to ``measure_interactions``.
        n_interactions: number of pairs requested (passed as ``interactions=``).
        init_score: optional initial scores / model forwarded as ``init_score``.
        feature_names, feature_types: optional feature metadata forwarded unchanged.

    Returns:
        ``(pairs, elapsed_seconds)``: ``pairs`` is a list of ``(i, j)`` tuples in the order
        ``measure_interactions`` yielded them. The second element of each yielded item
        (presumably the interaction strength) is discarded.
    """
    import time

    # perf_counter is monotonic; time.time() can jump if the wall clock is adjusted.
    t0 = time.perf_counter()
    interactions = measure_interactions(
        X_train,
        y_train,
        interactions=n_interactions,
        init_score=init_score,
        feature_names=feature_names,
        feature_types=feature_types,
    )
    # Each item unpacks as ((i, j), score); keep only the index pair.
    pairs = [(i, j) for (i, j), _ in interactions]
    return pairs, time.perf_counter() - t0


# Main-effect grid: D_2D_EDGE x D_2D_EDGE cells, flattened row-major into D_FLAT features.
D_2D_EDGE = 8
D_FLAT = D_2D_EDGE ** 2

# 3x2 binary stamps for the "T" and "L" shapes.
normal_t = np.array([[1, 0],
                     [1, 1],
                     [1, 0]])
normal_l = np.array([[1, 0],
                     [1, 0],
                     [1, 1]])

# Ground-truth image: T at rows 1-3 / cols 1-2, L at rows 4-6 / cols 5-6.
GT_MASK_2D = np.zeros((D_2D_EDGE,) * 2, dtype=int)
GT_MASK_2D[slice(1, 4), slice(1, 3)] = normal_t
GT_MASK_2D[slice(4, 7), slice(5, 7)] = normal_l

GT_MASK_2D_FLAT = GT_MASK_2D.flatten()  # independent copy, row-major


# def create_gt_mask_interactions_1d(main_effect_ground_truth_flat, custom_interaction_pairs, d_flat):
#     if len(main_effect_ground_truth_flat) != d_flat:
#         raise ValueError(f"Length of main_effect_ground_truth_flat ({len(main_effect_ground_truth_flat)}) must match d_flat ({d_flat}).")
#     interaction_ground_truth = []
#     for i, j in custom_interaction_pairs:
#         if not (0 <= i < d_flat and 0 <= j < d_flat):
#             interaction_ground_truth.append(0) # Invalid pair, append 0
#             continue
#         interaction_ground_truth.append(1 if main_effect_ground_truth_flat[i] == 1 and main_effect_ground_truth_flat[j] == 1 else 0)
#     return np.concatenate((main_effect_ground_truth_flat, np.array(interaction_ground_truth, dtype=int)))

# def create_qlr_ground_truth_mask_1d(main_effect_ground_truth_flat, d_flat):
#     if not isinstance(main_effect_ground_truth_flat, np.ndarray):
#         raise TypeError("main_effect_ground_truth_flat must be a numpy array.")
#     if main_effect_ground_truth_flat.ndim != 1:
#         raise ValueError("main_effect_ground_truth_flat must be a 1D array.")
#     if len(main_effect_ground_truth_flat) != d_flat:
#         raise ValueError(f"Length of main_effect_ground_truth_flat ({len(main_effect_ground_truth_flat)}) "
#                          f"must match d_flat ({d_flat}).")
#     if not np.all(np.isin(main_effect_ground_truth_flat, [0, 1])):
#         raise ValueError("main_effect_ground_truth_flat should only contain 0s and 1s.")

#     num_main_features = d_flat

#     # Get indices for the upper triangle, including the diagonal (for x_i*x_j where i <= j)
#     # These are the same indices used in the `quadratic_features` function.
#     inds_0, inds_1 = np.triu_indices(num_main_features, 0)
    
#     num_quadratic_terms = len(inds_0)
#     quadratic_part_mask = np.zeros(num_quadratic_terms, dtype=int)

#     for i in range(num_quadratic_terms):
#         idx_feature_1 = inds_0[i]
#         idx_feature_2 = inds_1[i]
        
#         # The mask for the quadratic term is 1 if both corresponding main features are 1
#         if main_effect_ground_truth_flat[idx_feature_1] == 1 and \
#            main_effect_ground_truth_flat[idx_feature_2] == 1:
#             quadratic_part_mask[i] = 1
#         else:
#             quadratic_part_mask[i] = 0
            
#     # Concatenate the main effect mask with the quadratic effect mask
#     full_mask = np.concatenate((main_effect_ground_truth_flat, quadratic_part_mask))
    
#     return full_mask


# Helper: get flat-index sets for T and L shapes (row-major flatten)
def _xor_shape_index_sets(d_edge, normal_t, normal_l):
    T_rows = slice(1, 4)  # [1:4)
    T_cols = slice(1, 3)  # [1:3)
    L_rows = slice(4, 7)  # [4:7)
    L_cols = slice(5, 7)  # [5:7)

    # Bounds check like your original code
    if d_edge < 7:
        # Not enough room; return empty sets to avoid false positives
        return set(), set()

    mask_t = np.zeros((d_edge, d_edge), dtype=int)
    mask_l = np.zeros((d_edge, d_edge), dtype=int)
    mask_t[T_rows, T_cols] = normal_t
    mask_l[L_rows, L_cols] = normal_l

    # Flat indices where the mini-pattern has ones
    idx_t = set(np.flatnonzero(mask_t.flatten()))
    idx_l = set(np.flatnonzero(mask_l.flatten()))
    return idx_t, idx_l


def create_gt_mask_interactions_1d(main_effect_ground_truth_flat, custom_interaction_pairs, d_flat):
    """XOR ground-truth mask for models with an explicit interaction list (FAST pairs, etc.).

    The returned vector is ``d_flat`` main-effect entries (all 0, since XOR carries no main
    effect) followed by one entry per pair in ``custom_interaction_pairs``. A pair is 1 iff
    one endpoint lies in the T shape and the other in the L shape. Out-of-range pairs and
    self-pairs are 0. ``main_effect_ground_truth_flat`` is only validated, not used.

    Raises:
        TypeError: if the main-effect mask is not a numpy array.
        ValueError: if it is not 1-D or its length differs from ``d_flat``.
    """
    gt = main_effect_ground_truth_flat
    if not isinstance(gt, np.ndarray):
        raise TypeError("main_effect_ground_truth_flat must be a numpy array.")
    if gt.ndim != 1:
        raise ValueError("main_effect_ground_truth_flat must be a 1D array.")
    if len(gt) != d_flat:
        raise ValueError(f"Length of main_effect_ground_truth_flat ({len(gt)}) must match d_flat ({d_flat}).")

    # Shape positions come from the module-level grid, not from d_flat.
    t_cells, l_cells = _xor_shape_index_sets(D_2D_EDGE, normal_t, normal_l)

    def _cross_flag(a, b):
        in_range = 0 <= a < d_flat and 0 <= b < d_flat
        if not in_range or a == b:
            return 0
        spans_shapes = (a in t_cells and b in l_cells) or (a in l_cells and b in t_cells)
        return 1 if spans_shapes else 0

    pair_flags = [_cross_flag(a, b) for (a, b) in custom_interaction_pairs]
    return np.concatenate((np.zeros(d_flat, dtype=int), np.array(pair_flags, dtype=int)))


def create_qlr_ground_truth_mask_1d(main_effect_ground_truth_flat, d_flat):
    """XOR ground-truth mask in PatternQLR feature order (mains, then upper-tri pairs i <= j).

    The returned vector is ``d_flat`` zeros (no XOR main effect) followed by one entry per
    ``np.triu_indices(d_flat, 0)`` pair. An entry is 1 iff ``i != j`` and the pair joins the
    T shape with the L shape. ``main_effect_ground_truth_flat`` is only validated, not used.

    Raises:
        TypeError: if the main-effect mask is not a numpy array.
        ValueError: if it is not 1-D or its length differs from ``d_flat``.
    """
    gt = main_effect_ground_truth_flat
    if not isinstance(gt, np.ndarray):
        raise TypeError("main_effect_ground_truth_flat must be a numpy array.")
    if gt.ndim != 1:
        raise ValueError("main_effect_ground_truth_flat must be a 1D array.")
    if len(gt) != d_flat:
        raise ValueError(f"Length of main_effect_ground_truth_flat ({len(gt)}) must match d_flat ({d_flat}).")

    t_cells, l_cells = _xor_shape_index_sets(D_2D_EDGE, normal_t, normal_l)
    t_list, l_list = list(t_cells), list(l_cells)

    # Same (i, j) ordering as the QLR quadratic feature map, diagonal included.
    rows, cols = np.triu_indices(d_flat, 0)
    spans_shapes = ((np.isin(rows, t_list) & np.isin(cols, l_list))
                    | (np.isin(rows, l_list) & np.isin(cols, t_list)))
    quadratic_part_mask = (spans_shapes & (rows != cols)).astype(int)  # diagonal never counts

    return np.concatenate((np.zeros(d_flat, dtype=int), quadratic_part_mask))


def importance_mass_accuracy(gt_mask, attribution):
    """Importance Mass Accuracy: fraction of absolute attribution mass on ground-truth features.

    Args:
        gt_mask: 1-D numpy array; entries equal to 1 mark ground-truth features.
        attribution: 1-D numpy array with the same shape (signed values allowed; abs is taken).

    Returns:
        ``sum(|attr|[gt == 1]) / sum(|attr|)``, or np.nan for non-array or mismatched inputs.
        If the attribution has no mass at all, the result is 1.0 when the GT mask is also
        empty and 0.0 otherwise. This is the same convention as the EMD score, so an
        all-zero explanation no longer receives a perfect score.
    """
    if not isinstance(gt_mask, np.ndarray) or not isinstance(attribution, np.ndarray):
        return np.nan
    # Identical shapes guarantee the boolean GT selection lines up with the attribution.
    if attribution.ndim != 1 or gt_mask.shape != attribution.shape:
        return np.nan

    in_gt = gt_mask == 1
    abs_attribution = np.abs(attribution)
    total_mass = np.sum(abs_attribution)

    if total_mass == 0:
        # With zero total mass the in-GT mass is necessarily zero too, so decide on the GT:
        # perfect only when there is nothing to find.
        return 1.0 if not np.any(in_gt) else 0.0
    return np.sum(abs_attribution[in_gt]) / total_mass

def create_cost_matrix(grid_edge_length):
    """Pairwise Euclidean distances between the cells of a square grid.

    Cells are ordered row-major (flat index ``r * grid_edge_length + c``), matching a flattened
    ``(grid_edge_length, grid_edge_length)`` array.

    Returns:
        ``(n, n)`` distance matrix with ``n = grid_edge_length ** 2``. A zero-size grid gives an
        empty ``(0, 0)`` array and a single cell gives ``[[0.0]]``.
    """
    if grid_edge_length == 0:
        return np.array([]).reshape(0, 0)

    n_cells = grid_edge_length * grid_edge_length
    if n_cells == 1:
        return np.array([[0.0]])

    # np.indices yields a row plane and a column plane; flattening both row-major and
    # pairing them gives (r, c) coordinates in the same order as nested r/c loops.
    row_plane, col_plane = np.indices((grid_edge_length, grid_edge_length)).reshape(2, -1)
    cell_coords = np.column_stack((row_plane, col_plane))
    return cdist(cell_coords, cell_coords)

# Euclidean cost between main-effect grid cells (8x8 grid -> 64 features, so a 64x64 matrix),
# used as the base transport cost for EMD / FNI-EMD.
COST_MATRIX_MAIN_EFFECTS = create_cost_matrix(D_2D_EDGE)

def calculate_emd_score_metric(gt_mask_flat, attribution_flat, grid_edge_length, base_cost_matrix, is_fni=False):
    """Normalized EMD score between a ground-truth mask and an attribution on a square grid.

    Both vectors become distributions (GT / sum(GT) and |attr| / sum(|attr|)) and are compared
    with an exact optimal-transport solve (``ot.lp.emd``) over ``base_cost_matrix``. The score
    is ``1 - cost / d_max``, where ``d_max`` is the grid diagonal; higher is better.

    Args:
        gt_mask_flat: 1-D 0/1 numpy array of length ``grid_edge_length ** 2``.
        attribution_flat: 1-D numpy array of the same length; signs are discarded.
        grid_edge_length: edge of the square grid the flat vectors index (row-major).
        base_cost_matrix: cost matrix over the grid cells; it is copied, never modified.
        is_fni: FNI variant. The cost between any two GT cells is set to 0 before solving.

    Returns:
        Float score. 1.0 if both vectors carry no mass, 0.0 if exactly one is empty, 1.0 for a
        single-cell grid. np.nan on invalid input or if the transport solver raises.
    """
    if not (isinstance(gt_mask_flat, np.ndarray) and gt_mask_flat.ndim == 1 and
            isinstance(attribution_flat, np.ndarray) and attribution_flat.ndim == 1 and
            len(gt_mask_flat) == len(attribution_flat) and
            len(gt_mask_flat) == grid_edge_length * grid_edge_length):
        return np.nan

    # Work on a float64 copy so the shared base matrix is never mutated.
    cost_matrix = np.array(base_cost_matrix, dtype=np.float64)
    if is_fni:
        gt_indices = np.flatnonzero(gt_mask_flat == 1)
        if gt_indices.size:
            # Keep only indices that fit the matrix (same bounds rule as before).
            rows = gt_indices[gt_indices < cost_matrix.shape[0]]
            cols = gt_indices[gt_indices < cost_matrix.shape[1]]
            cost_matrix[np.ix_(rows, cols)] = 0.0

    abs_attribution = np.abs(attribution_flat)
    sum_gt = np.sum(gt_mask_flat)
    sum_attr = np.sum(abs_attribution)

    if sum_gt < 1e-9 and sum_attr < 1e-9:  # both effectively empty
        return 1.0
    if sum_gt < 1e-9 or sum_attr < 1e-9:  # exactly one empty
        return 0.0

    # A single-cell grid has zero diameter: no mass can move, so the score is perfect.
    if grid_edge_length <= 1:
        return 1.0

    dist_gt = np.ascontiguousarray(gt_mask_flat.astype(np.float64) / sum_gt, dtype=np.float64)
    dist_attr = np.ascontiguousarray(abs_attribution.astype(np.float64) / sum_attr, dtype=np.float64)

    try:
        # With log=True, emd returns (transport_plan, log_dict); the optimal cost is log_dict['cost'].
        _, emd_log = emd(dist_gt, dist_attr, np.ascontiguousarray(cost_matrix, dtype=np.float64),
                         numItermax=200000, log=True)
    except Exception:
        return np.nan  # best-effort: a failed solve is reported as missing, not a crash
    transport_cost = emd_log['cost']

    d_max = np.sqrt(2 * (grid_edge_length - 1) ** 2)  # grid diagonal, > 0 here
    return 1 - (transport_cost / d_max)

# --- Main Processing Logic ---
def process_explanations_metrics(explanations_dict, gt_mask_main_flat, d_flat_main, d_edge_main, cost_mat_main):
    """Score every (scenario, method) explanation with IMA, EMD and FNI-EMD, then aggregate.

    Args:
        explanations_dict: mapping scenario name -> {method name -> explanation}. An explanation
            should be a numpy array or a list (converted with np.array). Anything else is
            recorded as NaN for all three metrics. Scenario names containing 'translations'
            are skipped.
        gt_mask_main_flat: flat 0/1 main-effect ground-truth mask of length d_flat_main.
        d_flat_main: number of main-effect features (D_FLAT = D_2D_EDGE ** 2 at the call site).
        d_edge_main: edge length of the square main-effect grid used by the EMD metrics.
        cost_mat_main: cost matrix over the grid cells, passed to calculate_emd_score_metric.

    Returns:
        Nested dict {scenario_base_name: {method_name: {'IMA_mean', 'IMA_std', 'EMD_mean',
        'EMD_std', 'FNI_EMD_mean', 'FNI_EMD_std'}}}. scenario_base_name drops the last
        '_'-separated token of the scenario name, so '..._0' ... '..._4' pool into one entry.
        Means and stds ignore NaNs.

    Side effects:
        For scenarios whose name contains 'xor', unpickles './data/xai_tris/<scenario>.pkl'
        and runs FAST (128 pairs) on its x_train / y_train to obtain interaction pairs.
    """
    results_raw_collection = {}

    for scenario_full_name, methods_data in explanations_dict.items():
        if 'translations' in scenario_full_name:
            continue
        scenario_parts = scenario_full_name.split('_')
        # NOTE(review): with the int() parse commented out nothing here can raise ValueError,
        # so the except branch is dead code.
        try:
            # dataset_ind = int(scenario_parts[-1]) # Not used directly in aggregation key structure here
            scenario_base_name = '_'.join(scenario_parts[:-1])
        except ValueError:
            scenario_base_name = scenario_full_name
            # print(f"Warning: Could not parse dataset index for {scenario_full_name}. Using full name as base.")

        # FAST interaction pairs are computed once per scenario and shared by all its methods.
        interaction_pairs_for_scenario = []
        if 'xor' in scenario_full_name.lower():
            data_path = f'./data/xai_tris/{scenario_full_name}.pkl'
            # NOTE(review): open() is outside the try below, so a missing data file aborts the
            # whole run, while a FAST failure is swallowed. pickle.load must only see trusted files.
            with open(data_path, "rb") as f:
                data = pkl.load(f)
                
                X_train_tensor = data.x_train.float() 
                y_train_tensor = data.y_train

            try:
                interaction_pairs_for_scenario, _ = FAST(X_train_tensor, y_train_tensor, n_interactions=128)
            except Exception as e:
                # Best-effort: an empty pair list yields an interaction GT with no pairs,
                # which skews IMA for the interaction methods without any visible error.
                # print(f"Warning: Placeholder FAST failed for {scenario_full_name}: {e}")
                interaction_pairs_for_scenario = []
        
        if scenario_base_name not in results_raw_collection:
            results_raw_collection[scenario_base_name] = {}

        for method_name, explanation_content in methods_data.items():
            if method_name not in results_raw_collection[scenario_base_name]:
                results_raw_collection[scenario_base_name][method_name] = {'IMA': [], 'EMD': [], 'FNI_EMD': []}

            current_explanation = explanation_content
            if isinstance(current_explanation, list): # For "ebm 192 <--- LIST"
                current_explanation = np.array(current_explanation)
            
            if not isinstance(current_explanation, np.ndarray):
                # print(f"Warning: Data for {scenario_full_name}/{method_name} is not array/list. Skipping.")
                results_raw_collection[scenario_base_name][method_name]['IMA'].append(np.nan)
                results_raw_collection[scenario_base_name][method_name]['EMD'].append(np.nan)
                results_raw_collection[scenario_base_name][method_name]['FNI_EMD'].append(np.nan)
                continue

            # Process attribution: average local, flatten global.
            # Multi-row arrays are averaged over the leading axis, except a (d_flat_main, 1)
            # column, which is treated as a global explanation and flattened.
            # NOTE(review): signed values are averaged here, and the metrics take abs() only
            # afterwards, so opposite-signed local attributions can cancel. Confirm this is intended.
            processed_attr = np.array([])
            if current_explanation.ndim > 1 and current_explanation.shape[0] > 1 and \
               not (current_explanation.shape[0] == d_flat_main and current_explanation.ndim == 2 and current_explanation.shape[1] == 1 and d_flat_main > 1):
                # processed_attr = np.mean(current_explanation, axis=0).squeeze()
                if current_explanation[0].ndim == 3:
                    # 4-D input: mean over axis 0, squeeze singleton dims, then mean over axis 0 again.
                    # processed_attr = np.mean(np.vstack(current_explanation.squeeze(2)), axis=0)
                    processed_attr = np.mean(np.mean(np.array(current_explanation), axis=0).squeeze(), axis=0)
                else:
                    # 2-D / 3-D input: stack rows along axis 0 and average them.
                    processed_attr = np.mean(np.vstack(current_explanation), axis=0)
            else:
                processed_attr = current_explanation.flatten()
            
            # --- IMA Calculation ---
            # Interaction-aware models are scored on mains + interaction terms; all others
            # on the first d_flat_main (main-effect) entries only.
            gt_mask_for_ima = None
            attr_for_ima = None

            if method_name in ['pattern_gam', 'ebm', 'nam'] and 'xor' in scenario_full_name.lower():
                gt_mask_for_ima = create_gt_mask_interactions_1d(gt_mask_main_flat, interaction_pairs_for_scenario, d_flat_main)
                attr_for_ima = processed_attr
            elif method_name == 'pattern_qlr':
                gt_mask_for_ima = create_qlr_ground_truth_mask_1d(gt_mask_main_flat, d_flat_main)
                attr_for_ima = processed_attr
            else: 
                gt_mask_for_ima = gt_mask_main_flat
                attr_for_ima = processed_attr[:d_flat_main] 

            # Adjust attribution length for IMA if mismatch (truncate/pad).
            # Missing entries are zero-padded and extra entries are dropped.
            if gt_mask_for_ima is not None and attr_for_ima is not None and len(attr_for_ima) != len(gt_mask_for_ima):
                # print(f"Warning IMA {scenario_full_name}/{method_name}: attr len {len(attr_for_ima)}, GT len {len(gt_mask_for_ima)}. Adjusting.")
                temp_attr = np.zeros(len(gt_mask_for_ima))
                common_len = min(len(attr_for_ima), len(gt_mask_for_ima))
                temp_attr[:common_len] = attr_for_ima[:common_len]
                attr_for_ima = temp_attr
            
            ima_val = importance_mass_accuracy(gt_mask_for_ima, attr_for_ima)
            results_raw_collection[scenario_base_name][method_name]['IMA'].append(ima_val)

            # --- EMD & FNI-EMD (main effects: first d_flat_main features, d_edge_main grid) ---
            attr_main_eff = processed_attr[:d_flat_main]
            gt_main_eff = gt_mask_main_flat # GT for EMD/FNI is always main effects

            emd_val = np.nan
            fni_emd_val = np.nan
            if len(attr_main_eff) == d_flat_main and d_flat_main > 0 : # Ensure correct length for main effects
                emd_val = calculate_emd_score_metric(gt_main_eff, attr_main_eff, d_edge_main, cost_mat_main, is_fni=False)
                fni_emd_val = calculate_emd_score_metric(gt_main_eff, attr_main_eff, d_edge_main, cost_mat_main, is_fni=True)
            
            results_raw_collection[scenario_base_name][method_name]['EMD'].append(emd_val)
            results_raw_collection[scenario_base_name][method_name]['FNI_EMD'].append(fni_emd_val)

    # --- Aggregation of results ---
    # NaN-ignoring mean/std per metric; an empty list maps to NaN.
    final_aggregated_results = {}
    for sc_base, meth_data in results_raw_collection.items():
        final_aggregated_results[sc_base] = {}
        for meth, metrics_lists in meth_data.items():
            final_aggregated_results[sc_base][meth] = {
                'IMA_mean': np.nanmean(metrics_lists['IMA']) if metrics_lists['IMA'] else np.nan,
                'IMA_std': np.nanstd(metrics_lists['IMA']) if metrics_lists['IMA'] else np.nan,
                'EMD_mean': np.nanmean(metrics_lists['EMD']) if metrics_lists['EMD'] else np.nan,
                'EMD_std': np.nanstd(metrics_lists['EMD']) if metrics_lists['EMD'] else np.nan,
                'FNI_EMD_mean': np.nanmean(metrics_lists['FNI_EMD']) if metrics_lists['FNI_EMD'] else np.nan,
                'FNI_EMD_std': np.nanstd(metrics_lists['FNI_EMD']) if metrics_lists['FNI_EMD'] else np.nan,
            }
    return final_aggregated_results

# Echo the run configuration before the (slow) metric computation starts.
print("\n".join([
    "--- Configuration ---",
    f"D_2D_EDGE (grid edge for main effects): {D_2D_EDGE}",
    f"D_FLAT (total main effect features): {D_FLAT}",
    f"Sum of GT_MASK_2D_FLAT: {np.sum(GT_MASK_2D_FLAT)}",
    f"Cost matrix for main effects shape: {COST_MATRIX_MAIN_EFFECTS.shape}",
    "--- Starting Metric Calculation ---",
]))

# Re-load so this cell does not depend on the earlier loading cell having been run.
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open('./models/xai_tris/explanations_xor_dist_corr_std_qlr.pkl', "rb") as f:
    explanations = pkl.load(f)

aggregated_metrics = process_explanations_metrics(
    explanations, GT_MASK_2D_FLAT, D_FLAT, D_2D_EDGE, COST_MATRIX_MAIN_EFFECTS
)

print("\n--- Aggregated Results ---")
for scenario_name, method_data in aggregated_metrics.items():
    print(f"\nScenario Type: {scenario_name}")
    for method, scores in method_data.items():
        print(f"  Method: {method}")
        # Labels are padded to a common width so the metric columns line up.
        for label, key in (("IMA    ", "IMA"), ("EMD    ", "EMD"), ("FNI-EMD", "FNI_EMD")):
            print(f"    {label}: Mean = {scores[key + '_mean']:.4f}, Std = {scores[key + '_std']:.4f}")

--- Configuration ---
D_2D_EDGE (grid edge for main effects): 8
D_FLAT (total main effect features): 64
Sum of GT_MASK_2D_FLAT: 8
Cost matrix for main effects shape: (64, 64)
--- Starting Metric Calculation ---

--- Aggregated Results ---

Scenario Type: xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated
  Method: pattern_gam
    IMA    : Mean = 0.5688, Std = 0.0803
    EMD    : Mean = 0.8095, Std = 0.0182
    FNI-EMD: Mean = 0.8196, Std = 0.0202
  Method: pattern_qlr
    IMA    : Mean = 0.2278, Std = 0.0385
    EMD    : Mean = 0.7832, Std = 0.0221
    FNI-EMD: Mean = 0.7941, Std = 0.0250
  Method: kernel_svm
    IMA    : Mean = 0.1231, Std = 0.0520
    EMD    : Mean = 0.7781, Std = 0.0351
    FNI-EMD: Mean = 0.7869, Std = 0.0377
  Method: ebm
    IMA    : Mean = 0.0304, Std = 0.0184
    EMD    : Mean = 0.8151, Std = 0.0017
    FNI-EMD: Mean = 0.8213, Std = 0.0017
  Method: shap
    IMA    : Mean = 0.4130, Std = 0.0354
    EMD    : Mean = 0.8734, Std = 0.0237
    FNI-EMD: Mean = 0.

In [15]:
import numpy as np
import pickle as pkl
from scipy.spatial.distance import cdist
from ot import emd
from interpret.utils import measure_interactions
import os

# ----------------------
# FAST (unchanged)
# ----------------------
def FAST(X_train, y_train, n_interactions, init_score=None, feature_names=None, feature_types=None):
    """Rank candidate feature pairs with interpret's ``measure_interactions`` (FAST).

    Args:
        X_train, y_train: training data forwarded unchanged to ``measure_interactions``.
        n_interactions: number of pairs requested (passed as ``interactions=``).
        init_score: optional initial scores / model forwarded as ``init_score``.
        feature_names, feature_types: optional feature metadata forwarded unchanged.

    Returns:
        ``(pairs, elapsed_seconds)``: ``pairs`` is a list of ``(i, j)`` tuples in the order
        ``measure_interactions`` yielded them. The second element of each yielded item
        (presumably the interaction strength) is discarded.
    """
    import time

    # perf_counter is monotonic; time.time() can jump if the wall clock is adjusted.
    t0 = time.perf_counter()
    interactions = measure_interactions(
        X_train,
        y_train,
        interactions=n_interactions,
        init_score=init_score,
        feature_names=feature_names,
        feature_types=feature_types,
    )
    # Each item unpacks as ((i, j), score); keep only the index pair.
    pairs = [(i, j) for (i, j), _ in interactions]
    return pairs, time.perf_counter() - t0


# ----------------------
# Config / shapes
# ----------------------
# Main-effect grid: D_2D_EDGE x D_2D_EDGE cells, flattened row-major into D_FLAT features.
D_2D_EDGE = 8
D_FLAT = D_2D_EDGE ** 2

# 3x2 binary stamps for the "T" and "L" shapes.
normal_t = np.array([[1, 0],
                     [1, 1],
                     [1, 0]])
normal_l = np.array([[1, 0],
                     [1, 0],
                     [1, 1]])
GT_MASK_2D = np.zeros((D_2D_EDGE,) * 2, dtype=int)

# Visual pattern (kept for reference/plots): T at rows 1-3 / cols 1-2, L at rows 4-6 / cols 5-6.
# Both stamps fit only when the grid reaches index 6, i.e. D_2D_EDGE >= 7.
if D_2D_EDGE < 7:
    print(f"Warning: D_2D_EDGE ({D_2D_EDGE}) is too small for the example GT_MASK_2D pattern. GT mask will be mostly zeros.")
else:
    GT_MASK_2D[slice(1, 4), slice(1, 3)] = normal_t
    GT_MASK_2D[slice(4, 7), slice(5, 7)] = normal_l

# The ORIGINAL main-effects mask (non-XOR meaning); an independent row-major copy.
GT_MASK_2D_FLAT = GT_MASK_2D.flatten()


# ----------------------
# XOR helpers
# ----------------------
def _xor_shape_index_sets(d_edge, normal_t, normal_l):
    """Return sets of flat indices for the T and L mini-shapes."""
    T_rows = slice(1, 4)  # [1:4)
    T_cols = slice(1, 3)  # [1:3)
    L_rows = slice(4, 7)  # [4:7)
    L_cols = slice(5, 7)  # [5:7)

    if d_edge < 7:
        return set(), set()

    mask_t = np.zeros((d_edge, d_edge), dtype=int)
    mask_l = np.zeros((d_edge, d_edge), dtype=int)
    mask_t[T_rows, T_cols] = normal_t
    mask_l[L_rows, L_cols] = normal_l

    idx_t = set(np.flatnonzero(mask_t.flatten()))
    idx_l = set(np.flatnonzero(mask_l.flatten()))
    return idx_t, idx_l

def _is_xor_scenario(name: str) -> bool:
    return 'xor' in str(name).lower()


# ----------------------
# Interaction pair builders
# ----------------------
def qlr_pairs_upper_tri(d):
    """All index pairs (i, j) with i <= j, as Python ints, in np.triu_indices order."""
    rows, cols = np.triu_indices(d, 0)
    return [(int(r), int(c)) for r, c in zip(rows, cols)]

def get_fast_pairs_for_scenario(scenario_full_name, n_interactions=128):
    """Load the scenario's dataset and compute its top FAST interaction pairs.

    Args:
        scenario_full_name: scenario key; data is read from './data/xai_tris/<name>.pkl'.
        n_interactions: number of pairs requested from FAST.

    Returns:
        List of (i, j) pairs. Returns [] silently when the data file does not exist. Returns []
        with a warning when loading or FAST fails, so the failure stays visible while the
        caller can still continue (best-effort).
    """
    import warnings

    data_path = f'./data/xai_tris/{scenario_full_name}.pkl'
    if not os.path.exists(data_path):
        return []
    try:
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted, locally built data.
        with open(data_path, "rb") as f:
            data = pkl.load(f)
        # .float() suggests x_train is a torch tensor; TODO confirm the data object's schema.
        X_train_tensor = data.x_train.float()
        y_train_tensor = data.y_train
        pairs, _ = FAST(X_train_tensor, y_train_tensor, n_interactions=n_interactions)
        return pairs
    except Exception as exc:
        # Previously swallowed silently. An empty pair list changes the interaction ground
        # truth, so surface the reason.
        warnings.warn(
            f"FAST pairs unavailable for {scenario_full_name!r} "
            f"({type(exc).__name__}: {exc}); continuing with no interaction pairs.",
            stacklevel=2,
        )
        return []


# ----------------------
# Ground-truth constructors (64 + interactions)
# ----------------------
def create_gt_mask_interactions_generic(
    main_effect_ground_truth_flat: np.ndarray,
    custom_interaction_pairs,
    d_flat: int,
    d_edge: int,
    is_xor: bool,
    idx_t=None,
    idx_l=None
):
    """Ground-truth mask for methods with an explicit interaction list (NAM/EBM/PatternGAM).

    The result is ``d_flat`` main-effect entries followed by one entry per pair, in the order
    ``custom_interaction_pairs`` yields them. Any iterable of (i, j) pairs is accepted,
    including generators.

    XOR:
        - mains = 0
        - interaction = 1 iff one endpoint is in T and the other in L. Out-of-range pairs and
          self-pairs are 0. T/L index sets default to ``_xor_shape_index_sets(d_edge, ...)``.
    non-XOR:
        - mains = main_effect_ground_truth_flat (as int)
        - interaction = 1 iff both endpoints are in range and are 1 in the main GT.

    Raises:
        ValueError: if the main-effect mask is not a 1-D numpy array of length ``d_flat``.
    """
    if not isinstance(main_effect_ground_truth_flat, np.ndarray) or main_effect_ground_truth_flat.ndim != 1:
        raise ValueError("main_effect_ground_truth_flat must be a 1D numpy array.")
    if len(main_effect_ground_truth_flat) != d_flat:
        raise ValueError("main_effect_ground_truth_flat length must equal d_flat.")

    def _in_range(i, j):
        return 0 <= i < d_flat and 0 <= j < d_flat

    if is_xor:
        if idx_t is None or idx_l is None:
            idx_t, idx_l = _xor_shape_index_sets(d_edge, normal_t, normal_l)
        main_part = np.zeros(d_flat, dtype=int)
        inter = [
            1 if (_in_range(i, j) and i != j
                  and ((i in idx_t and j in idx_l) or (i in idx_l and j in idx_t)))
            else 0
            for (i, j) in custom_interaction_pairs
        ]
    else:
        main_part = main_effect_ground_truth_flat.astype(int)
        inter = [
            1 if (_in_range(i, j) and main_part[i] == 1 and main_part[j] == 1) else 0
            for (i, j) in custom_interaction_pairs
        ]
    return np.concatenate([main_part, np.array(inter, dtype=int)])


def create_qlr_ground_truth_mask_generic(
    main_effect_ground_truth_flat: np.ndarray,
    d_flat: int,
    d_edge: int,
    is_xor: bool,
    idx_t=None,
    idx_l=None
):
    """Ground-truth mask in PatternQLR order: d_flat mains, then all np.triu pairs (i <= j).

    XOR:
        - mains = 0
        - quadratic term (i, j) = 1 iff i != j and the pair joins T with L; diagonal = 0.
          T/L index sets default to ``_xor_shape_index_sets(d_edge, ...)``.
    non-XOR:
        - mains = main_effect_ground_truth_flat (as int)
        - quadratic term (i, j) = 1 iff both main entries are 1. On the diagonal this reduces
          to the main entry itself.

    Raises:
        ValueError: if the main-effect mask is not a 1-D numpy array of length ``d_flat``.
    """
    if not isinstance(main_effect_ground_truth_flat, np.ndarray) or main_effect_ground_truth_flat.ndim != 1:
        raise ValueError("main_effect_ground_truth_flat must be a 1D numpy array.")
    if len(main_effect_ground_truth_flat) != d_flat:
        raise ValueError("main_effect_ground_truth_flat length must equal d_flat.")

    rows, cols = np.triu_indices(d_flat, 0)

    if is_xor:
        if idx_t is None or idx_l is None:
            idx_t, idx_l = _xor_shape_index_sets(d_edge, normal_t, normal_l)
        t_members, l_members = list(idx_t), list(idx_l)
        row_in_t, row_in_l = np.isin(rows, t_members), np.isin(rows, l_members)
        col_in_t, col_in_l = np.isin(cols, t_members), np.isin(cols, l_members)
        joins_shapes = (row_in_t & col_in_l) | (row_in_l & col_in_t)
        quad = (joins_shapes & (rows != cols)).astype(int)
        return np.concatenate([np.zeros(d_flat, dtype=int), quad])

    main_part = main_effect_ground_truth_flat.astype(int)
    active = main_part == 1
    quad = (active[rows] & active[cols]).astype(int)
    return np.concatenate([main_part, quad])


# ----------------------
# EMD / FNI-EMD
# ----------------------
def calculate_emd_score_metric(gt_mask_flat, attribution_flat, grid_edge_length, base_cost_matrix, is_fni=False):
    """
    64-D EMD/FNI-EMD (main effects grid). Kept for non-interaction methods.

    Turns the GT mask and |attribution| into distributions, solves exact optimal transport
    with ``emd`` over a copy of ``base_cost_matrix``, and returns ``1 - cost / d_max``
    (d_max = grid diagonal). With ``is_fni`` the cost between any two GT cells is zeroed.
    Returns np.nan on invalid input, 1.0 if both vectors are empty, and 0.0 if exactly one is.
    Solver errors propagate.
    """
    inputs_ok = (
        isinstance(gt_mask_flat, np.ndarray) and gt_mask_flat.ndim == 1
        and isinstance(attribution_flat, np.ndarray) and attribution_flat.ndim == 1
        and len(gt_mask_flat) == len(attribution_flat)
        and len(gt_mask_flat) == grid_edge_length * grid_edge_length
    )
    if not inputs_ok:
        return np.nan

    transport_costs = np.copy(base_cost_matrix)
    if is_fni:
        gt_cells = np.where(gt_mask_flat == 1)[0]
        if gt_cells.size:
            in_rows = gt_cells[gt_cells < transport_costs.shape[0]]
            in_cols = gt_cells[gt_cells < transport_costs.shape[1]]
            transport_costs[np.ix_(in_rows, in_cols)] = 0.0

    magnitude = np.abs(attribution_flat)
    gt_total = np.sum(gt_mask_flat)
    attr_total = np.sum(magnitude)
    gt_empty = gt_total < 1e-9
    attr_empty = attr_total < 1e-9
    if gt_empty and attr_empty:
        return 1.0
    if gt_empty or attr_empty:
        return 0.0

    source = np.ascontiguousarray(gt_mask_flat.astype(np.float64) / gt_total, dtype=np.float64)
    target = np.ascontiguousarray(magnitude.astype(np.float64) / attr_total, dtype=np.float64)
    cost_c = np.ascontiguousarray(transport_costs, dtype=np.float64)

    # log=True makes emd return (plan, log_dict); the optimal cost is log_dict['cost'].
    _, solver_log = emd(source, target, cost_c, numItermax=200000, log=True)
    total_cost = solver_log['cost']

    d_max = np.sqrt(2 * (grid_edge_length - 1)**2) if grid_edge_length > 1 else 0.0
    if d_max == 0:
        return 1.0 if np.isclose(total_cost, 0) else 0.0
    return 1 - (total_cost / d_max)

def _grid_coords_for_main(d_edge):
    coords = []
    for r in range(d_edge):
        for c in range(d_edge):
            coords.append((float(r), float(c)))
    return np.array(coords, dtype=float)

def _coords_for_interactions_from_pairs(d_edge, pairs):
    """Place each interaction (i, j) at the midpoint of its two main-effect grid cells.

    Returns a (len(pairs), 2) float array of (row, col) midpoints. An empty pair list gives
    an empty 1-D float array.
    """
    grid = _grid_coords_for_main(d_edge)
    midpoints = [tuple((grid[i] + grid[j]) / 2.0) for (i, j) in pairs]
    return np.array(midpoints, dtype=float)

def _emd_with_arbitrary_cost(gt_vec, attr_vec, cost_matrix, d_edge, is_fni=False):
    """EMD/FNI-EMD for arbitrary square cost matrix; normalization keeps main-grid Dmax."""
    if gt_vec.shape != attr_vec.shape:
        return np.nan
    if cost_matrix.shape[0] != cost_matrix.shape[1] or cost_matrix.shape[0] != gt_vec.shape[0]:
        return np.nan

    C = np.array(cost_matrix, dtype=float)
    if is_fni:
        gt_idx = np.flatnonzero(gt_vec > 0)
        if gt_idx.size > 0:
            C = C.copy()
            C[np.ix_(gt_idx, gt_idx)] = 0.0

    sum_gt = float(np.sum(gt_vec))
    sum_attr = float(np.sum(attr_vec))
    if sum_gt < 1e-9 and sum_attr < 1e-9:
        return 1.0
    if sum_gt < 1e-9 or sum_attr < 1e-9:
        return 0.0

    p = (gt_vec / sum_gt).astype(np.float64)
    q = (attr_vec / sum_attr).astype(np.float64)
    _, log = emd(p, q, C, numItermax=200000, log=True)

    d_max = np.sqrt(2 * (d_edge - 1)**2) if d_edge > 1 else 0.0
    if d_max == 0:
        return 1.0 if np.isclose(log['cost'], 0) else 0.0
    return 1.0 - (log['cost'] / d_max)


# ----------------------
# IMA
# ----------------------
def importance_mass_accuracy(gt_mask, attribution):
    """Importance Mass Accuracy: share of |attribution| mass that lies on GT features.

    Parameters
    ----------
    gt_mask : np.ndarray
        Mask whose entries equal to 1 mark ground-truth features.
    attribution : np.ndarray, 1-D
        Attribution per feature; only magnitudes count.

    Returns
    -------
    float
        mass_on_gt / total_mass. Returns 1.0 if the attribution has no mass at all, and
        np.nan for non-array inputs, non-1-D attributions or a length mismatch.
    """
    if not (isinstance(gt_mask, np.ndarray) and isinstance(attribution, np.ndarray)):
        return np.nan
    if attribution.ndim != 1 or len(gt_mask) != len(attribution):
        return np.nan

    magnitudes = np.abs(attribution)
    on_target = np.sum(magnitudes[gt_mask == 1])
    overall = np.sum(magnitudes)

    if overall == 0:
        return 1.0 if on_target == 0 else 0.0
    return on_target / overall


# ----------------------
# Cost matrix (64-D main grid)
# ----------------------
def create_cost_matrix(grid_edge_length):
    """Pairwise Euclidean distances between all cells of a square grid.

    Cells are enumerated row-major, so entry [a, b] is the distance between flat cell
    a = r_a * grid_edge_length + c_a and flat cell b.
    Returns a (0, 0) array for an empty grid and [[0.0]] for a single cell.
    """
    if grid_edge_length == 0:
        return np.array([]).reshape(0, 0)
    if grid_edge_length * grid_edge_length == 1:
        return np.array([[0.0]])
    row_idx, col_idx = np.indices((grid_edge_length, grid_edge_length))
    cell_xy = np.column_stack((row_idx.ravel(), col_idx.ravel()))
    return cdist(cell_xy, cell_xy)

# Euclidean ground-cost matrix between all cells of the D_2D_EDGE x D_2D_EDGE main-effect
# grid (row-major cell order); passed as cost_mat_main to process_explanations_metrics.
COST_MATRIX_MAIN_EFFECTS = create_cost_matrix(D_2D_EDGE)


# ----------------------
# Main processing (original metrics now full-dim when applicable)
# ----------------------
# Methods whose explanation vectors may carry interaction terms after the d_flat_main main
# effects. process_explanations_metrics scores them in the extended (mains + interactions)
# space whenever their flattened vector is longer than d_flat_main.
INTERACTION_METHODS = {'pattern_gam', 'ebm', 'nam', 'discr', 'pattern_qlr'}

def _scenario_base_from_full_name(scenario_full_name):
    """Drop the trailing '_<suffix>' (run index) from a scenario key; names without '_' are kept whole."""
    if '_' not in scenario_full_name:
        return scenario_full_name
    return scenario_full_name.rsplit('_', 1)[0]


def _collapse_explanation_to_vector(explanation, d_flat_main):
    """Reduce an explanation array to a single 1-D attribution vector.

    - A 1-D array, an array whose first axis has length <= 1, or a (d_flat_main, 1)
      column vector is simply flattened.
    - Otherwise the first axis is treated as a stack of explanations and averaged. When
      each stacked element is itself 3-D, the stack mean is squeezed and then averaged
      over its first axis once more.
    """
    is_column_vector = (explanation.ndim == 2 and explanation.shape[0] == d_flat_main
                        and explanation.shape[1] == 1 and d_flat_main > 1)
    if explanation.ndim > 1 and explanation.shape[0] > 1 and not is_column_vector:
        if np.ndim(explanation[0]) == 3:
            return np.mean(np.mean(np.array(explanation), axis=0).squeeze(), axis=0)
        return np.mean(np.vstack(explanation), axis=0)
    return explanation.flatten()


def _fit_length(vec, length):
    """Return vec[:length], zero-padded at the end when vec is shorter than length."""
    fitted = vec[:length]
    if fitted.shape[0] != length:
        padded = np.zeros(length, dtype=float)
        n = min(len(vec), length)
        padded[:n] = vec[:n]
        fitted = padded
    return fitted


def _metrics_main_grid(gt_main, attr_main, d_edge_main, cost_mat_main):
    """(IMA, EMD, FNI-EMD) on the d_edge_main x d_edge_main main-effect grid."""
    return (
        importance_mass_accuracy(gt_main, attr_main),
        calculate_emd_score_metric(gt_main, attr_main, d_edge_main, cost_mat_main, is_fni=False),
        calculate_emd_score_metric(gt_main, attr_main, d_edge_main, cost_mat_main, is_fni=True),
    )


def _metrics_full_space(gt_full, attr_full, pairs, d_edge_main):
    """(IMA, EMD, FNI-EMD) over main effects followed by interaction terms.

    Main effects sit on the grid. Each interaction term is placed at the midpoint of its
    two features, and the ground cost is the Euclidean distance between these points.
    """
    coords = np.vstack([
        _grid_coords_for_main(d_edge_main),
        _coords_for_interactions_from_pairs(d_edge_main, pairs),
    ])
    cost = cdist(coords, coords)
    gt_mass = gt_full.astype(float)
    attr_mass = np.abs(attr_full).astype(float)
    return (
        importance_mass_accuracy(gt_full, attr_full),
        _emd_with_arbitrary_cost(gt_mass, attr_mass, cost, d_edge_main, is_fni=False),
        _emd_with_arbitrary_cost(gt_mass, attr_mass, cost, d_edge_main, is_fni=True),
    )


def _append_metric_triplet(bucket, scores):
    """Append an (IMA, EMD, FNI-EMD) triple to a method's per-run result lists."""
    ima_val, emd_val, fni_val = scores
    bucket['IMA'].append(ima_val)
    bucket['EMD'].append(emd_val)
    bucket['FNI_EMD'].append(fni_val)


def _nan_mean_std(values):
    """(nanmean, nanstd) of a list of scores, or (nan, nan) when the list is empty."""
    if not values:
        return np.nan, np.nan
    return np.nanmean(values), np.nanstd(values)


def process_explanations_metrics(explanations_dict, main_effect_gt_flat, d_flat_main, d_edge_main, cost_mat_main):
    """
    Compute IMA, EMD and FNI-EMD for every (scenario, method) explanation and aggregate over runs.

    For interaction methods, the ORIGINAL metrics are computed in the full feature space:
      - NAM/EBM/PatternGAM/DISCR: length = d_flat_main + len(FAST pairs)
      - PatternQLR: length = d_flat_main + D*(D+1)/2 (mains + all degree-2 terms in triu order)
    Other methods, interaction methods without an interaction tail, and FAST-based methods
    when no FAST pairs are available are evaluated on the d_flat_main main-effect grid.

    Scenario keys are grouped by dropping their last '_'-separated token (presumably the
    run index), and per-run scores are aggregated per group.

    Returns
    -------
    dict
        {scenario_base: {method: {'IMA_mean', 'IMA_std', 'EMD_mean', 'EMD_std',
                                  'FNI_EMD_mean', 'FNI_EMD_std'}}}
    """
    results_raw_collection = {}

    for scenario_full_name, methods_data in explanations_dict.items():
        if 'translations' in scenario_full_name:
            continue

        scenario_base_name = _scenario_base_from_full_name(scenario_full_name)

        is_xor = _is_xor_scenario(scenario_full_name)
        idx_t, idx_l = _xor_shape_index_sets(d_edge_main, normal_t, normal_l) if is_xor else (None, None)

        # FAST pairs; their order must match the models' interaction ordering.
        pairs_for_scenario = get_fast_pairs_for_scenario(scenario_full_name, n_interactions=128)

        scenario_bucket = results_raw_collection.setdefault(scenario_base_name, {})

        for method_name, explanation_content in methods_data.items():
            bucket = scenario_bucket.setdefault(method_name, {'IMA': [], 'EMD': [], 'FNI_EMD': []})

            current_explanation = explanation_content
            if isinstance(current_explanation, list):
                current_explanation = np.array(current_explanation)
            if not isinstance(current_explanation, np.ndarray):
                _append_metric_triplet(bucket, (np.nan, np.nan, np.nan))
                continue

            processed_attr = _collapse_explanation_to_vector(current_explanation, d_flat_main)
            use_full = (method_name in INTERACTION_METHODS) and (processed_attr.size > d_flat_main)

            if method_name in {'nam', 'ebm', 'pattern_gam', 'discr'} and use_full:
                if len(pairs_for_scenario) == 0:
                    # Interaction order cannot be reconstructed: fall back to the main grid.
                    scores = _metrics_main_grid(main_effect_gt_flat, processed_attr[:d_flat_main],
                                                d_edge_main, cost_mat_main)
                else:
                    gt_full = create_gt_mask_interactions_generic(
                        main_effect_gt_flat, pairs_for_scenario, d_flat_main, d_edge_main, is_xor, idx_t, idx_l
                    )
                    attr_full = _fit_length(processed_attr, d_flat_main + len(pairs_for_scenario))
                    scores = _metrics_full_space(gt_full, attr_full, pairs_for_scenario, d_edge_main)

            elif method_name == 'pattern_qlr' and use_full:
                # QLR: mains + all quadratic terms in upper-triangular order (incl. diagonal).
                pairs_triu = qlr_pairs_upper_tri(d_flat_main)
                gt_full = create_qlr_ground_truth_mask_generic(
                    main_effect_gt_flat, d_flat_main, d_edge_main, is_xor, idx_t, idx_l
                )
                attr_full = _fit_length(processed_attr, d_flat_main + len(pairs_triu))
                scores = _metrics_full_space(gt_full, attr_full, pairs_triu, d_edge_main)

            else:
                # Non-interaction methods, or no interaction tail: main-grid evaluation.
                scores = _metrics_main_grid(main_effect_gt_flat, processed_attr[:d_flat_main],
                                            d_edge_main, cost_mat_main)

            _append_metric_triplet(bucket, scores)

    # --- Aggregation ---
    final_aggregated_results = {}
    for sc_base, meth_data in results_raw_collection.items():
        final_aggregated_results[sc_base] = {}
        for meth, metrics_lists in meth_data.items():
            summary = {}
            for metric in ('IMA', 'EMD', 'FNI_EMD'):
                mean_val, std_val = _nan_mean_std(metrics_lists[metric])
                summary[f'{metric}_mean'] = mean_val
                summary[f'{metric}_std'] = std_val
            final_aggregated_results[sc_base][meth] = summary
    return final_aggregated_results


# ----------------------
# Run
# ----------------------
# Echo the configuration. D_2D_EDGE, D_FLAT and GT_MASK_2D_FLAT are defined in an earlier
# cell (not visible here).
print(f"--- Configuration ---")
print(f"D_2D_EDGE (grid edge for main effects): {D_2D_EDGE}")
print(f"D_FLAT (total main effect features): {D_FLAT}")
print(f"Sum of GT_MASK_2D_FLAT (non-XOR main GT): {np.sum(GT_MASK_2D_FLAT)}")
print(f"Cost matrix for main effects shape: {COST_MATRIX_MAIN_EFFECTS.shape}")
print(f"--- Starting Metric Calculation ---")

# Same file as the earlier loading cell; reloaded here so this cell does not rely on it.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted artifacts.
with open('./models/xai_tris/explanations_xor_dist_corr_std_qlr.pkl', "rb") as f:
    explanations = pkl.load(f)

aggregated_metrics = process_explanations_metrics(
    explanations,
    GT_MASK_2D_FLAT,   # For XOR runs, mains are overridden to 0 inside the GT builders for interaction methods
    D_FLAT,
    D_2D_EDGE,
    COST_MATRIX_MAIN_EFFECTS
)

# Summary: mean and std over runs for each scenario group and method.
print(f"\n--- Aggregated Results ---")
for scenario_name, method_data in aggregated_metrics.items():
    print(f"\nScenario Type: {scenario_name}")
    for method, scores in method_data.items():
        print(f"  Method: {method}")
        print(f"    IMA    : Mean = {scores['IMA_mean']:.4f}, Std = {scores['IMA_std']:.4f}")
        print(f"    EMD    : Mean = {scores['EMD_mean']:.4f}, Std = {scores['EMD_std']:.4f}")
        print(f"    FNI-EMD: Mean = {scores['FNI_EMD_mean']:.4f}, Std = {scores['FNI_EMD_std']:.4f}")


--- Configuration ---
D_2D_EDGE (grid edge for main effects): 8
D_FLAT (total main effect features): 64
Sum of GT_MASK_2D_FLAT (non-XOR main GT): 8
Cost matrix for main effects shape: (64, 64)
--- Starting Metric Calculation ---

--- Aggregated Results ---

Scenario Type: xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated
  Method: pattern_gam
    IMA    : Mean = 0.5688, Std = 0.0803
    EMD    : Mean = 0.9209, Std = 0.0108
    FNI-EMD: Mean = 0.9319, Std = 0.0112
  Method: pattern_qlr
    IMA    : Mean = 0.2278, Std = 0.0385
    EMD    : Mean = 0.8809, Std = 0.0123
    FNI-EMD: Mean = 0.8876, Std = 0.0127
  Method: kernel_svm
    IMA    : Mean = 0.1231, Std = 0.0520
    EMD    : Mean = 0.7781, Std = 0.0351
    FNI-EMD: Mean = 0.7869, Std = 0.0377
  Method: ebm
    IMA    : Mean = 0.0304, Std = 0.0184
    EMD    : Mean = 0.8376, Std = 0.0549
    FNI-EMD: Mean = 0.8392, Std = 0.0541
  Method: shap
    IMA    : Mean = 0.4130, Std = 0.0354
    EMD    : Mean = 0.8734, Std = 0.0237
    

In [12]:
import numpy as np

def parse_scenario_name_for_table(scenario_base_name):
    """Turn a scenario base name into a compact table row label such as "XOR DIST CORR".

    The label is built from up to three space-separated parts:
      * base type: XOR, MULT, LIN (for "linear" or "additive"), otherwise OTHER;
      * "DIST" if the name contains "_distractor" or "distractor_";
      * background: WHITE, CORR, otherwise UNK_BG.
    Matching is case-insensitive, and the first matching token wins.
    """
    lowered = scenario_base_name.lower()

    base_lookup = (
        ("xor", "XOR"),
        ("multiplicative", "MULT"),
        ("linear", "LIN"),
        ("additive", "LIN"),
    )
    base_label = next((label for token, label in base_lookup if token in lowered), "OTHER")

    has_distractor = "_distractor" in lowered or "distractor_" in lowered
    distractor_label = "DIST" if has_distractor else ""

    background_lookup = (("white", "WHITE"), ("correlated", "CORR"))
    background_label = next((label for token, label in background_lookup if token in lowered), "UNK_BG")

    return f"{base_label} {distractor_label} {background_label}".replace("  ", " ").strip()

def get_scenario_sort_key_v2(parsed_scenario_name):
    """Sort key for labels produced by parse_scenario_name_for_table.

    Rows are ordered by:
      1. base type: LIN < MULT < XOR < anything else;
      2. non-distractor before distractor ("DIST");
      3. background: WHITE < CORR < anything else.
    """
    tokens = parsed_scenario_name.split(' ')

    base_rank = {"LIN": 0, "MULT": 1, "XOR": 2}.get(tokens[0], 3)
    dist_rank = int("DIST" in tokens)

    if "WHITE" in tokens:
        bg_rank = 0
    elif "CORR" in tokens:
        bg_rank = 1
    else:
        bg_rank = 2

    return (base_rank, dist_rank, bg_rank)


def format_latex_value(mean_val, std_val, precision=2, is_bold=False):
    """Render "mean +/- std" as an inline LaTeX math string.

    Returns "-" when either value is None or NaN. Both numbers use `precision` decimals.
    With is_bold=True, the expression is wrapped in a bold-math (mathbf) command.
    """
    missing = mean_val is None or std_val is None or np.isnan(mean_val) or np.isnan(std_val)
    if missing:
        return "-"

    body = f"{mean_val:.{precision}f} \\pm {std_val:.{precision}f}"
    return f"$\\mathbf{{{body}}}$" if is_bold else f"${body}$"

def generate_latex_tables_final_layout(aggregated_metrics): # Renamed main function
    """Build one LaTeX table per metric (IMA, EMD, FNI-EMD) from aggregated results.

    Rows are scenario labels from parse_scenario_name_for_table, ordered by
    get_scenario_sort_key_v2. Columns are methods: a fixed preferred order first, then
    any remaining methods alphabetically. In each row, the cell(s) whose mean is
    closest-equal to the row maximum are typeset in bold.

    Returns a dict mapping the metric key ("IMA", "EMD", "FNI-EMD") to the LaTeX source.
    """
    # Re-key rows by their readable label; a later scenario with the same label
    # overwrites earlier method entries.
    rows = {}
    methods_seen = set()
    for scenario_base_name, methods_data in aggregated_metrics.items():
        row = rows.setdefault(parse_scenario_name_for_table(scenario_base_name), {})
        for method_name, metrics_values in methods_data.items():
            methods_seen.add(method_name)
            row[method_name] = metrics_values

    row_order = sorted(rows, key=get_scenario_sort_key_v2)

    preferred = [
        'pattern_gam', 'pattern_qlr', 'nam', 'ebm', 'kernel_svm',
        'pattern_net', 'pattern_attribution', 'shap', 'ig'
    ]
    columns = [m for m in preferred if m in methods_seen]
    columns += sorted(m for m in methods_seen if m not in columns)

    metric_specs = (
        ("IMA", "IMA_mean", "IMA_std", "Importance Mass Accuracy (IMA)"),
        ("EMD", "EMD_mean", "EMD_std", "Earth Mover's Distance (EMD)"),
        ("FNI-EMD", "FNI_EMD_mean", "FNI_EMD_std", "Feature Preservation Index EMD (FNI-EMD)")
    )

    header = "Scenario & " + " & ".join(col.replace("_", "\\_") for col in columns) + " \\\\ \\hline\\hline\n"

    tables = {}
    for metric_key, mean_key, std_key, caption_title in metric_specs:
        pieces = [
            f"% LaTeX Table for {metric_key} with custom layout and best result emboldened\n",
            "\\begin{table}[htbp]\n",
            "\\centering\n",
            f"\\caption{{{caption_title}. Values are mean $\\pm$ standard deviation. Best result per row is emboldened.}}\n",
            f"\\label{{tab:{metric_key.lower().replace('-', '')}_results_final_layout}}\n",
            "\\resizebox{\\textwidth}{!}{\n",
            f"\\begin{{tabular}}{{l|{'c' * len(columns)}}}\n",
            "\\hline\n",
            header,
        ]

        for label in row_order:
            cells = rows.get(label, {})
            row_means = [cells.get(col, {}).get(mean_key) for col in columns]

            # Missing, NaN and -inf means never compete for "best in row".
            usable = [v for v in row_means if v is not None and not np.isnan(v) and v != -np.inf]
            best = np.nanmax(usable) if usable else -np.inf

            formatted = [label]
            for col, mean_val in zip(columns, row_means):
                std_val = cells.get(col, {}).get(std_key)
                is_best = (mean_val is not None and not np.isnan(mean_val)
                           and best != -np.inf and not np.isnan(best)
                           and np.isclose(mean_val, best))
                formatted.append(format_latex_value(mean_val, std_val, is_bold=is_best))

            pieces.append(" & ".join(formatted) + " \\\\\n")

        pieces.extend(["\\hline\n", "\\end{tabular}\n", "}\n", "\\end{table}\n"])
        tables[metric_key] = "".join(pieces)

    return tables

if __name__ == '__main__':
    # Render the tables from `aggregated_metrics`, computed by the metric-calculation cell.
    # No sample dictionary is defined here; in a notebook kernel __name__ is '__main__',
    # so this block always runs.

    print("--- Generating LaTeX Tables with Final Custom Row and Column Layout ---")
    latex_tables_final_layout = generate_latex_tables_final_layout(aggregated_metrics)

    for metric_name, table_code in latex_tables_final_layout.items():
        print(f"\n% --- LaTeX Table for {metric_name} (final layout, bolded best) ---")
        print(table_code)
        print("% --- End of LaTeX Table ---")

--- Generating LaTeX Tables with Final Custom Row and Column Layout ---

% --- LaTeX Table for IMA (final layout, bolded best) ---
% LaTeX Table for IMA with custom layout and best result emboldened
\begin{table}[htbp]
\centering
\caption{Importance Mass Accuracy (IMA). Values are mean $\pm$ standard deviation. Best result per row is emboldened.}
\label{tab:ima_results_final_layout}
\resizebox{\textwidth}{!}{
\begin{tabular}{l|cccccccccc}
\hline
Scenario & pattern\_gam & pattern\_qlr & nam & ebm & kernel\_svm & pattern\_net & pattern\_attribution & shap & ig & discr \\ \hline\hline
XOR DIST CORR & $0.57 \pm 0.08$ & $0.23 \pm 0.04$ & $0.36 \pm 0.04$ & $0.03 \pm 0.02$ & $0.12 \pm 0.05$ & $0.61 \pm 0.19$ & $\mathbf{0.77 \pm 0.13}$ & $0.41 \pm 0.04$ & $0.73 \pm 0.09$ & $0.12 \pm 0.03$ \\
\hline
\end{tabular}
}
\end{table}

% --- End of LaTeX Table ---

% --- LaTeX Table for EMD (final layout, bolded best) ---
% LaTeX Table for EMD with custom layout and best result emboldened
\begin{table}

In [8]:
# ==== ADDITIONS / REPLACEMENTS START HERE ====

import numpy as np
import pickle as pkl
from scipy.spatial.distance import cdist
from ot import emd

# -----------------------------
# (1) Redistribution functions
# -----------------------------

# --- original sum version (absolute attributions; each pair split 0.5/0.5) ---
def transform_interaction_sum(pattern, interaction_pairs=(), d=64):
    """Fold pairwise interaction attributions back onto the d main features.

    Each feature gets |main attribution| plus half of |interaction attribution| for every
    pair it belongs to, so the two halves of a pair add back up to the pair's mass. A
    diagonal pair (i, i) therefore gives its full mass to feature i.

    `pattern[:d]` holds the main effects and `pattern[d + k]` the attribution of
    `interaction_pairs[k]`. Returns a float array of length d.
    """
    values = np.asarray(pattern)
    importance = np.zeros(d, dtype=float)
    importance += np.abs(values[:d])
    pair_mass = np.abs(values[d:])
    for idx, (first, second) in enumerate(interaction_pairs):
        half = 0.5 * pair_mass[idx]
        importance[first] += half
        importance[second] += half
    return importance

# --- weighted sum version (divide pairwise mass by node degree) ---
def transform_interaction_weighted_sum(pattern, interaction_pairs=(), d=64):
    """Fold pairwise interaction attributions onto the d main features, averaged per feature.

    Each feature i receives |main_i| plus the mean |interaction| over all pairs that
    contain i (a diagonal pair (i, i) is counted twice). Features with no incident pair
    keep only their main attribution.

    Parameters
    ----------
    pattern : array-like
        The first d entries are main-effect attributions. Entry d + k is the attribution
        of interaction_pairs[k].
    interaction_pairs : iterable of (int, int)
        Feature-index pairs, in the same order as the interaction tail of `pattern`.
    d : int
        Number of main-effect features.

    Returns
    -------
    np.ndarray of float, shape (d,)

    Raises
    ------
    IndexError
        If there are more pairs than interaction values, or a pair index is out of range.
    """
    pattern = np.asarray(pattern)
    # Cast to float up front: adding fractional per-degree means into an integer copy of
    # the main effects would silently truncate them.
    out = np.abs(pattern[:d]).astype(float)
    inter_vals = np.abs(pattern[d:]).astype(float)

    pairs = np.asarray(list(interaction_pairs), dtype=int).reshape(-1, 2)
    vals = inter_vals[np.arange(pairs.shape[0])]  # IndexError if fewer values than pairs

    inter_sum = np.zeros(d, dtype=float)
    degree = np.zeros(d, dtype=float)
    for col in (0, 1):
        # np.add.at accumulates correctly when the same feature index repeats.
        np.add.at(inter_sum, pairs[:, col], vals)
        np.add.at(degree, pairs[:, col], 1.0)

    touched = degree > 0
    out[touched] += inter_sum[touched] / degree[touched]
    return out

# --- max redistribution, signed (metrics later take abs anyway) ---
def transform_interaction_max_signed(pattern, interaction_pairs=(), d=None):
    """For each main feature, keep the signed attribution of largest magnitude among its
    own main effect and every interaction term that involves it.

    Ties keep the earlier candidate: the main effect first, then pairs in order.

    Parameters
    ----------
    pattern : array-like
        d main-effect attributions followed by one attribution per interaction pair.
    interaction_pairs : sequence of (int, int)
        Feature-index pairs, aligned with pattern[d:].
    d : int or None
        Number of main features; inferred as len(pattern) - len(interaction_pairs) if None.

    Returns
    -------
    np.ndarray of float, shape (d,)

    Raises
    ------
    ValueError
        On a length mismatch, d <= 0, or a pair index outside [0, d-1].
    """
    pattern = np.asarray(pattern)
    n_pairs = len(interaction_pairs)
    n_main = pattern.size - n_pairs if d is None else d
    if n_main + n_pairs != pattern.size or n_main <= 0:
        raise ValueError("Length mismatch: expected len(pattern) == d + len(interaction_pairs).")

    # incident[f] lists the pair positions that touch feature f (validated first).
    incident = [[] for _ in range(n_main)]
    for pos, (i, j) in enumerate(interaction_pairs):
        if not (0 <= i < n_main and 0 <= j < n_main):
            raise ValueError(f"interaction_pairs[{pos}] = ({i},{j}) out of range [0,{n_main-1}]")
        incident[i].append(pos)
        incident[j].append(pos)

    result = np.empty(n_main, dtype=float)
    for feat in range(n_main):
        winner = float(pattern[feat])
        for pos in incident[feat]:
            candidate = float(pattern[n_main + pos])
            if abs(candidate) > abs(winner):
                winner = candidate
        result[feat] = winner
    return result

# helper: unsigned version of max (for symmetry with the other two)
def transform_interaction_max_abs(pattern, interaction_pairs=(), d=64):
    """Unsigned counterpart of transform_interaction_max_signed.

    Returns the magnitude of the strongest main or interaction attribution touching each
    of the d main features. Note that the default is d=64 rather than d=None, so with the
    default the pattern must hold exactly 64 + len(interaction_pairs) entries.
    """
    strongest = transform_interaction_max_signed(pattern, interaction_pairs, d=d)
    return np.abs(strongest)


# ---------------------------------------
# (2) QLR interaction-pair construction
# ---------------------------------------
def qlr_pairs_upper_tri(d):
    """Upper-triangular index pairs (i, j) with i <= j of a d x d matrix.

    The order matches np.triu_indices(d, 0) (row-major), i.e. the layout of the QLR
    quadratic terms. Indices are returned as plain Python ints.
    """
    rows, cols = np.triu_indices(d, 0)
    return [(int(r), int(c)) for r, c in zip(rows, cols)]


# -------------------------------------------------
# (3) Metric functions (FNI -> FNI rename applied)
# -------------------------------------------------
def importance_mass_accuracy(gt_mask, attribution):
    """IMA: fraction of absolute attribution mass that falls on features with gt_mask == 1.

    Returns np.nan unless both inputs are numpy arrays, the attribution is 1-D and the
    lengths match. An all-zero attribution yields 1.0.
    """
    inputs_ok = isinstance(gt_mask, np.ndarray) and isinstance(attribution, np.ndarray)
    if not inputs_ok or attribution.ndim != 1 or len(gt_mask) != len(attribution):
        return np.nan

    weights = np.abs(attribution)
    total = np.sum(weights)
    captured = np.sum(weights[gt_mask == 1])
    if total != 0:
        return captured / total
    return 1.0 if captured == 0 else 0.0

def create_cost_matrix(grid_edge_length):
    """Euclidean ground-cost matrix for EMD on a grid_edge_length x grid_edge_length grid.

    Rows and columns follow row-major cell order. Returns an empty (0, 0) matrix for
    an edge length of 0 and [[0.0]] for a single cell.
    """
    n_cells = grid_edge_length * grid_edge_length
    if grid_edge_length == 0:
        return np.array([]).reshape(0, 0)
    if n_cells == 1:
        return np.array([[0.0]])
    # (2, g, g) index grids -> (g*g, 2) array of (row, col) points, row-major.
    points = np.indices((grid_edge_length, grid_edge_length)).reshape(2, -1).T
    return cdist(points, points)

def calculate_emd_score_metric(gt_mask_flat, attribution_flat, grid_edge_length, base_cost_matrix, is_fni=False):
    """EMD similarity score (1 - EMD / Dmax) between a GT mask and an attribution on a square grid.

    Parameters
    ----------
    gt_mask_flat : np.ndarray, 1-D
        Flattened ground-truth mask (entries == 1 mark GT cells).
    attribution_flat : np.ndarray, 1-D
        Flattened attribution; absolute values are used as mass.
    grid_edge_length : int
        Edge length g; both vectors must have length g * g.
    base_cost_matrix : array-like, shape (g*g, g*g)
        Ground cost between cells; it is copied, never modified.
    is_fni : bool
        If True, moving mass between two GT cells is free (GT x GT block set to 0).

    Returns
    -------
    float
        1 - EMD/Dmax with Dmax = grid diagonal. Returns 1.0 if both vectors carry no mass
        and 0.0 if exactly one does. Returns np.nan on invalid input or if the solver raises.
    """
    if not (isinstance(gt_mask_flat, np.ndarray) and gt_mask_flat.ndim == 1 and
            isinstance(attribution_flat, np.ndarray) and attribution_flat.ndim == 1 and
            len(gt_mask_flat) == len(attribution_flat) and
            len(gt_mask_flat) == grid_edge_length * grid_edge_length):
        return np.nan

    # C-contiguous float64 copy, as expected by the POT solver.
    cost_matrix = np.array(base_cost_matrix, dtype=np.float64, order='C')
    if is_fni:
        gt_indices = np.flatnonzero(gt_mask_flat == 1)
        if gt_indices.size > 0:
            # Only indices inside the matrix bounds are zeroed (same as the original guard).
            rows = gt_indices[gt_indices < cost_matrix.shape[0]]
            cols = gt_indices[gt_indices < cost_matrix.shape[1]]
            cost_matrix[np.ix_(rows, cols)] = 0.0

    sum_gt = np.sum(gt_mask_flat)
    abs_attribution = np.abs(attribution_flat)
    sum_attr = np.sum(abs_attribution)

    if sum_gt < 1e-9 and sum_attr < 1e-9:
        return 1.0
    if sum_gt < 1e-9 or sum_attr < 1e-9:
        return 0.0

    dist_gt = np.ascontiguousarray(gt_mask_flat.astype(np.float64) / sum_gt, dtype=np.float64)
    dist_attr = np.ascontiguousarray(abs_attribution.astype(np.float64) / sum_attr, dtype=np.float64)

    # Always a float: previously this variable held either 0.0 or the POT log dict,
    # which crashed np.isclose below whenever emd ran but d_max was 0.
    cost = 0.0
    if grid_edge_length * grid_edge_length > 1:
        try:
            _, log = emd(dist_gt, dist_attr, cost_matrix, numItermax=200000, log=True)
        except Exception:
            # Best effort: solver failures are reported as NaN rather than raised.
            return np.nan
        cost = float(log['cost'])

    d_max = np.sqrt(2 * (grid_edge_length - 1) ** 2) if grid_edge_length > 1 else 0.0
    if d_max == 0:
        return 1.0 if np.isclose(cost, 0) else 0.0

    return 1 - (cost / d_max)


# --------------------------------------------------------------------
# (4) Main processing extended with redistribution + FNI terminology
# --------------------------------------------------------------------
_METRIC_VARIANTS = ('original', 'sum', 'weighted', 'max')
_METRIC_NAMES = ('IMA', 'EMD', 'FNI-EMD')


def _empty_metric_store():
    """Fresh ``variant -> metric -> []`` container for one (scenario, method) pair."""
    return {variant: {metric: [] for metric in _METRIC_NAMES} for variant in _METRIC_VARIANTS}


def _append_metrics(bucket, ima_value, emd_value, fni_emd_value):
    """Append one IMA / EMD / FNI-EMD triple to a ``metric -> list`` bucket."""
    bucket['IMA'].append(ima_value)
    bucket['EMD'].append(emd_value)
    bucket['FNI-EMD'].append(fni_emd_value)


def _scenario_base_name(scenario_full_name):
    """Drop the trailing '_<token>' (e.g. the run index in '..._correlated_3'); names without '_' are kept."""
    return scenario_full_name.rsplit('_', 1)[0]


def _load_fast_pairs(scenario_full_name, fast_func_for_pairs, n_interactions, data_dir):
    """Re-run FAST on the scenario's training data to recover interaction pairs; [] (with a warning) on failure."""
    data_path = f'{data_dir}/{scenario_full_name}.pkl'
    try:
        with open(data_path, "rb") as f:
            data = pkl.load(f)  # NOTE: pickle is only safe on trusted, locally generated files
        pairs, _ = fast_func_for_pairs(data.x_train.float(), data.y_train, n_interactions=n_interactions)
        return pairs
    except Exception as exc:
        # Best effort: keep computing the other metrics, but don't hide the failure.
        warnings.warn(f"Could not recover FAST interaction pairs for {scenario_full_name!r} "
                      f"({exc!r}); continuing without interaction pairs.")
        return []


def _collapse_explanation(explanation, d_flat):
    """Reduce an explanation array to a 1-D attribution vector.

    Multi-row explanations are averaged over the leading axis (4-D stacks: over the first axis,
    squeezed, then again over the next) -- presumably the sample axis. A (d_flat, 1) column
    vector and anything 1-D is simply flattened.
    """
    is_column_vector = (explanation.ndim == 2 and explanation.shape[0] == d_flat
                        and explanation.shape[1] == 1 and d_flat > 1)
    if explanation.ndim > 1 and explanation.shape[0] > 1 and not is_column_vector:
        if np.ndim(explanation[0]) == 3:
            return np.mean(np.mean(np.array(explanation), axis=0).squeeze(), axis=0)
        return np.mean(np.vstack(explanation), axis=0)
    return explanation.flatten()


def _redistribute_to_main_effects(processed_attr, pairs, d):
    """Fold the interaction tail of ``processed_attr`` back onto its first ``d`` main effects.

    Returns the (sum, weighted, max) redistributions. Falls back to ``|main effects|`` for all
    three when there is no interaction tail, or when fewer pairs than tail entries are known.
    Extra pairs are trimmed, because padding cannot invent structure.
    """
    n_tail = max(0, len(processed_attr) - d)
    if n_tail > 0 and len(pairs) > n_tail:
        pairs = pairs[:n_tail]
    if n_tail == 0 or len(pairs) != n_tail:
        main_only = np.abs(processed_attr[:d])
        return main_only, main_only.copy(), main_only.copy()
    return (transform_interaction_sum(processed_attr, pairs, d=d),
            transform_interaction_weighted_sum(processed_attr, pairs, d=d),
            transform_interaction_max_abs(processed_attr, pairs, d=d))


def _aggregate_metric_lists(results_raw):
    """Turn ``... -> metric -> list`` into ``... -> {'<metric>_mean', '<metric>_std'}`` (NaN-aware)."""
    aggregated = {}
    for sc_base, meth_data in results_raw.items():
        aggregated[sc_base] = {}
        for meth, variants in meth_data.items():
            aggregated[sc_base][meth] = {}
            for variant, metric_lists in variants.items():
                stats = {}
                for metric in _METRIC_NAMES:
                    values = metric_lists[metric]
                    stats[f'{metric}_mean'] = np.nanmean(values) if values else np.nan
                    stats[f'{metric}_std'] = np.nanstd(values) if values else np.nan
                aggregated[sc_base][meth][variant] = stats
    return aggregated


def process_explanations_metrics_with_redistribution(
    explanations_dict, gt_mask_main_flat, d_flat_main, d_edge_main, cost_mat_main,
    fast_func_for_pairs,  # function FAST to recover interaction pairs when needed
    n_fast_interactions=128, data_dir='./data/xai_tris',
):
    """
    Compute IMA / EMD / FNI-EMD for every (scenario, method) explanation, in four variants.

    Variants:
      - 'original': metrics on the explanation as given.
          * IMA: pattern_gam / ebm / nam on XOR scenarios, and pattern_qlr, are scored on their
            full vector against interaction-aware ground-truth masks. Every other method uses
            only its first ``d_flat_main`` entries against ``gt_mask_main_flat``.
          * EMD / FNI-EMD always use the first ``d_flat_main`` (main-effect) entries.
      - 'sum' / 'weighted' / 'max': interaction attributions are folded back onto the
        ``d_flat_main`` main effects before scoring against ``gt_mask_main_flat``.

    Grouping and filtering:
      - Scenario keys are grouped by dropping their last '_'-separated token (e.g. the run index
        in '..._correlated_3'), so repeated runs accumulate into the same lists.
      - Keys containing 'translations' are skipped.

    Parameters
    ----------
    explanations_dict : dict
        scenario_full_name -> {method_name -> explanation (array-like)}.
    gt_mask_main_flat : np.ndarray
        Flattened main-effect ground-truth mask (expected length ``d_flat_main``).
    d_flat_main : int
        Number of main-effect features.
    d_edge_main : int
        Edge of the square main-effect grid. EMD is only defined when
        ``d_edge_main ** 2 == d_flat_main``.
    cost_mat_main : np.ndarray
        Ground-cost matrix for EMD over the main-effect grid.
    fast_func_for_pairs : callable
        FAST-like function returning ``(pairs, elapsed)``; only called for XOR scenarios.
    n_fast_interactions : int, default 128
        Number of interaction pairs requested from FAST.
    data_dir : str, default './data/xai_tris'
        Directory holding ``<scenario_full_name>.pkl`` training data used by FAST.

    Returns
    -------
    results_raw : dict
        ``results_raw[scenario_base][method][variant]['IMA'|'EMD'|'FNI-EMD']`` -> list of values.
        NaN is used for unusable explanations.
    aggregated : dict
        ``aggregated[scenario_base][method][variant]`` -> dict with keys
        'IMA_mean', 'IMA_std', 'EMD_mean', 'EMD_std', 'FNI-EMD_mean', 'FNI-EMD_std'.
    """
    results_raw = {}

    for scenario_full_name, methods_data in explanations_dict.items():
        if 'translations' in scenario_full_name:
            continue

        scenario_base_name = _scenario_base_name(scenario_full_name)
        is_xor = 'xor' in scenario_full_name.lower()
        interaction_pairs_for_scenario = (
            _load_fast_pairs(scenario_full_name, fast_func_for_pairs, n_fast_interactions, data_dir)
            if is_xor else []
        )
        scenario_store = results_raw.setdefault(scenario_base_name, {})

        for method_name, explanation_content in methods_data.items():
            store = scenario_store.setdefault(method_name, _empty_metric_store())

            current_explanation = explanation_content
            if isinstance(current_explanation, list):
                current_explanation = np.array(current_explanation)
            if not isinstance(current_explanation, np.ndarray):
                # Unusable explanation: record NaN in every variant/metric so run counts stay aligned.
                for variant_bucket in store.values():
                    for values in variant_bucket.values():
                        values.append(np.nan)
                continue

            processed_attr = _collapse_explanation(current_explanation, d_flat_main)

            # ---- (4a) ORIGINAL metrics ----
            # All-zero main-effect mask handed to the GT builders. Presumably this is because
            # XOR carries no main-effect signal -- TODO confirm. It is rebuilt per method in
            # case a builder mutates it.
            xor_main_gt_flat = np.zeros(d_flat_main, dtype=int)
            if method_name in ('pattern_gam', 'ebm', 'nam') and is_xor:
                gt_mask_for_ima = create_gt_mask_interactions_1d(
                    xor_main_gt_flat, interaction_pairs_for_scenario, d_flat_main)
                attr_for_ima = processed_attr
            elif method_name == 'pattern_qlr':
                gt_mask_for_ima = create_qlr_ground_truth_mask_1d(xor_main_gt_flat, d_flat_main)
                attr_for_ima = processed_attr
            else:
                gt_mask_for_ima = gt_mask_main_flat
                attr_for_ima = processed_attr[:d_flat_main]

            if len(attr_for_ima) != len(gt_mask_for_ima):
                # Zero-pad or truncate so IMA always compares equal-length vectors.
                padded = np.zeros(len(gt_mask_for_ima))
                n_keep = min(len(attr_for_ima), len(gt_mask_for_ima))
                padded[:n_keep] = attr_for_ima[:n_keep]
                attr_for_ima = padded

            ima_val = importance_mass_accuracy(gt_mask_for_ima, attr_for_ima)

            # EMD / FNI-EMD on main effects only.
            attr_main_eff = processed_attr[:d_flat_main]
            emd_val = fni_emd_val = np.nan
            if len(attr_main_eff) == d_flat_main and d_flat_main > 0:
                emd_val = calculate_emd_score_metric(
                    gt_mask_main_flat, attr_main_eff, d_edge_main, cost_mat_main, is_fni=False)
                fni_emd_val = calculate_emd_score_metric(
                    gt_mask_main_flat, attr_main_eff, d_edge_main, cost_mat_main, is_fni=True)
            _append_metrics(store['original'], ima_val, emd_val, fni_emd_val)

            # ---- (4b) REDISTRIBUTIONS back to d_flat_main and metrics on them ----
            # QLR: upper-triangular pairs in the order used to build its feature vector;
            # other methods: FAST pairs (empty for non-XOR scenarios).
            if method_name == 'pattern_qlr':
                pairs = qlr_pairs_upper_tri(d_flat_main)
            else:
                pairs = interaction_pairs_for_scenario
            red_sum, red_weighted, red_max = _redistribute_to_main_effects(processed_attr, pairs, d_flat_main)

            for variant_name, vattr in (('sum', red_sum), ('weighted', red_weighted), ('max', red_max)):
                _append_metrics(
                    store[variant_name],
                    importance_mass_accuracy(gt_mask_main_flat, vattr),
                    calculate_emd_score_metric(gt_mask_main_flat, vattr, d_edge_main, cost_mat_main, is_fni=False),
                    calculate_emd_score_metric(gt_mask_main_flat, vattr, d_edge_main, cost_mat_main, is_fni=True),
                )

    return results_raw, _aggregate_metric_lists(results_raw)


# --------------------------------------------------------
# (5) Pretty printing: compact tables per metric
# --------------------------------------------------------
def print_metric_tables(aggregated_results):
    """
    Print plain-text IMA / EMD / FNI-EMD tables for every scenario.

    Each table lists methods (sorted alphabetically) as rows and the variants
    original / sum / weighted / max as columns. Cells read "mean ± std" with 4 decimals,
    or "nan" when either statistic is NaN.
    """
    import pandas as pd

    variant_cols = ['original', 'sum', 'weighted', 'max']
    sections = [
        ('IMA', "\nIMA (Importance Mass Accuracy):"),
        ('EMD', "\nEMD:"),
        ('FNI-EMD', "\nFNI-EMD:"),
    ]

    def cell(stats, metric):
        mean_v = stats[f'{metric}_mean']
        std_v = stats[f'{metric}_std']
        if np.isnan(mean_v) or np.isnan(std_v):
            return "nan"
        return f"{mean_v:.4f} ± {std_v:.4f}"

    for scenario_name, method_data in aggregated_results.items():
        row_labels = sorted(method_data)
        # Build every table first, then print, so a malformed scenario prints nothing partial.
        frames = {
            metric: pd.DataFrame(
                [[cell(method_data[label][variant], metric) for variant in variant_cols]
                 for label in row_labels],
                index=row_labels,
                columns=variant_cols,
            )
            for metric, _ in sections
        }

        print(f"\n=== Scenario: {scenario_name} ===")
        for metric, heading in sections:
            print(heading)
            print(frames[metric].to_string())


# ======================
# ===== USAGE EXAMPLE ==
# ======================
# Requires, defined in earlier cells of this notebook:
# - D_2D_EDGE, D_FLAT, GT_MASK_2D_FLAT, COST_MATRIX_MAIN_EFFECTS
# - FAST, create_gt_mask_interactions_1d, create_qlr_ground_truth_mask_1d,
#   importance_mass_accuracy, qlr_pairs_upper_tri,
#   transform_interaction_sum / transform_interaction_weighted_sum / transform_interaction_max_abs
EXPLANATIONS_PKL = './models/xai_tris/explanations_xor_dist_corr_std_qlr.pkl'

print("--- Configuration ---")
print(f"D_2D_EDGE (grid edge for main effects): {D_2D_EDGE}")
print(f"D_FLAT (total main effect features): {D_FLAT}")
print(f"Sum of GT_MASK_2D_FLAT: {np.sum(GT_MASK_2D_FLAT)}")
print(f"Cost matrix for main effects shape: {COST_MATRIX_MAIN_EFFECTS.shape}")
print("--- Starting Metric Calculation (with redistribution) ---")

# Reload so this cell does not depend on whatever state `explanations` has from earlier cells.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
with open(EXPLANATIONS_PKL, "rb") as f:
    explanations = pkl.load(f)

# Compute raw per-run metrics and their mean/std aggregates.
results_raw_all, aggregated_all = process_explanations_metrics_with_redistribution(
    explanations,
    GT_MASK_2D_FLAT,
    D_FLAT,
    D_2D_EDGE,
    COST_MATRIX_MAIN_EFFECTS,
    FAST,  # used to recover interaction pairs for XOR scenarios
)

# Print compact tables per scenario
print("\n--- Aggregated Results (Original, Sum, Weighted, Max) ---")
print_metric_tables(aggregated_all)

# ==== END ADDITIONS / REPLACEMENTS ====


--- Configuration ---
D_2D_EDGE (grid edge for main effects): 8
D_FLAT (total main effect features): 64
Sum of GT_MASK_2D_FLAT: 8
Cost matrix for main effects shape: (64, 64)
--- Starting Metric Calculation (with redistribution) ---

--- Aggregated Results (Original, Sum, Weighted, Max) ---

=== Scenario: xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated ===

IMA (Importance Mass Accuracy):
                            original              sum         weighted              max
discr                0.1178 ± 0.0273  0.8692 ± 0.0238  0.7617 ± 0.0510  0.8956 ± 0.0193
ebm                  0.0304 ± 0.0184  0.1441 ± 0.0599  0.0784 ± 0.0310  0.1187 ± 0.0659
ig                   0.7278 ± 0.0947  0.7278 ± 0.0947  0.7278 ± 0.0947  0.7278 ± 0.0947
kernel_svm           0.1231 ± 0.0520  0.1231 ± 0.0520  0.1231 ± 0.0520  0.1231 ± 0.0520
nam                  0.3641 ± 0.0440  0.5350 ± 0.0364  0.3663 ± 0.0808  0.4708 ± 0.0606
pattern_attribution  0.7687 ± 0.1277  0.7687 ± 0.1277  0.7687 ± 0.1277  0

In [9]:
# ======= Pretty tables with ordering, renaming, bolding, and exports =======
import os
import numpy as np
import pandas as pd



# Methods: row order (top to bottom) and the pretty names used as table row labels.
METHOD_ORDER = [
    "discr", "pattern_gam", "pattern_qlr", "kernel_svm",
    "pattern_attribution", "pattern_net", "nam", "ebm", "shap", "ig",
]

PRETTY_NAME = {
    "discr": "DISCR",
    "kernel_svm": "KernelPattern",
    "shap": "SHAP",
    "ig": "IG",
    "pattern_attribution": "PatternAttr",
    "pattern_net": "PatternNet",
    "nam": "NAM",
    "ebm": "EBM",
    "pattern_gam": "PatternGAM",
    "pattern_qlr": "PatternQLR",
}

# Columns: (internal variant key, displayed header).
# To use a different order or subset, reorder/remove tuples. Example:
# COL_SPECS = [("sum","Sum"), ("weighted","Weighted"), ("max","Max"), ("original","Original")]
COL_SPECS = [
    ("original", "Original (feature space)"),
    ("sum",      "Sum (redistributed)"),
    ("weighted", "Weighted sum (redistributed)"),
    ("max",      "Max (redistributed)"),
]

# Variant keys in column order -- derived from COL_SPECS so the two can never drift apart.
COL_ORDER = [key for key, _ in COL_SPECS]

# Metrics tabulated (and printed) for each scenario, in display order.
METRICS = ["IMA", "EMD", "FNI-EMD"]

def _fmt_cell(mean, std):
    if np.isnan(mean) or np.isnan(std):
        return "nan"
    return f"{mean:.4f} ± {std:.4f}"

def _bold_winners_as_markdown(df_means, df_strings):
    """
    Bold the max value per column in df_strings (string cells), using df_means (numeric) to decide winners.
    Ties: bold all tied maxima.
    """
    out = df_strings.copy()
    for col in df_means.columns:
        col_vals = df_means[col]
        if col_vals.isna().all():
            continue
        max_val = np.nanmax(col_vals.values.astype(float))
        winners = (np.abs(col_vals - max_val) < 1e-12)  # tie-safe
        for idx in df_means.index[winners]:
            out.loc[idx, col] = f"**{out.loc[idx, col]}**"
    return out

def _styler_from_strings_and_means(df_strings, df_means, caption=None):
    """
    Create a pandas Styler that shows df_strings but bolds the winners from df_means.
    """
    # build boolean mask for winners
    mask = pd.DataFrame(False, index=df_means.index, columns=df_means.columns)
    for col in df_means.columns:
        col_vals = df_means[col]
        if col_vals.isna().all():
            continue
        m = np.nanmax(col_vals.values.astype(float))
        mask[col] = (np.abs(col_vals - m) < 1e-12)

    def bold_mask(s):
        return ['font-weight: bold' if m else '' for m in mask[s.name]]

    sty = (df_strings.style
           .apply(bold_mask, axis=0)
           .set_properties(**{"white-space": "nowrap"})
           .set_table_styles([
               {"selector": "th", "props": [("text-align", "center")]},
               {"selector": "td", "props": [("text-align", "center"), ("padding", "6px 10px")]},
               {"selector": "caption", "props": [("caption-side", "top"), ("font-weight", "bold"), ("margin-bottom", "8px")]}
           ])
          )
    if caption:
        sty = sty.set_caption(caption)
    return sty

def build_tables_for_scenario(method_data, method_order, pretty_map, col_specs, metrics=None):
    """
    Build per-metric result tables (numeric, formatted, bolded, styled) for one scenario.

    Parameters
    ----------
    method_data : dict
        ``method -> variant -> stats``, where ``stats`` holds ``f"{metric}_mean"`` and
        ``f"{metric}_std"``. Missing stat keys become NaN; a missing variant raises KeyError.
    method_order : list of str
        Row order. Methods absent from ``method_data`` are skipped; methods not listed are dropped.
    pretty_map : dict
        Method key -> displayed row label (falls back to the key itself).
    col_specs : list of (internal_key, pretty_label)
        Variant key used to fetch stats, and the header shown for that column.
    metrics : list of str, optional
        Metrics to tabulate; defaults to the notebook-level ``METRICS``.

    Returns
    -------
    dict
        ``metric -> {"df_means", "df_stds", "df_fmt", "df_fmt_bold_md", "styler"}``.
    """
    if metrics is None:
        metrics = METRICS

    methods_present = [m for m in method_order if m in method_data]
    row_labels = [pretty_map.get(m, m) for m in methods_present]

    # Internal keys fetch the stats; pretty labels become the displayed column headers.
    col_keys = [key for key, _ in col_specs]
    col_labels = [label for _, label in col_specs]

    per_metric = {}
    for metric in metrics:
        mean_key, std_key = f"{metric}_mean", f"{metric}_std"
        means = [[method_data[m][ck].get(mean_key, np.nan) for ck in col_keys] for m in methods_present]
        stds = [[method_data[m][ck].get(std_key, np.nan) for ck in col_keys] for m in methods_present]

        df_means = pd.DataFrame(means, index=row_labels, columns=col_labels, dtype=float)
        df_stds = pd.DataFrame(stds, index=row_labels, columns=col_labels, dtype=float)

        # Format positionally so duplicate pretty labels cannot confuse label-based lookups.
        formatted = [
            [_fmt_cell(mu, sd) for mu, sd in zip(mean_row, std_row)]
            for mean_row, std_row in zip(df_means.values, df_stds.values)
        ]
        df_fmt = pd.DataFrame(formatted, index=row_labels, columns=col_labels, dtype=object)

        per_metric[metric] = {
            "df_means": df_means, "df_stds": df_stds,
            "df_fmt": df_fmt,
            "df_fmt_bold_md": _bold_winners_as_markdown(df_means, df_fmt),
            "styler": _styler_from_strings_and_means(df_fmt, df_means, caption=f"{metric}"),
        }
    return per_metric


def export_tables(per_metric, out_dir=None, fname_prefix="", to_png=False):
    """
    Write each metric's Styler to ``<out_dir>/<fname_prefix><metric>.html``, plus ``.png`` if requested.

    Parameters
    ----------
    per_metric : dict
        ``metric -> bundle`` where ``bundle["styler"]`` has ``to_html()``
        (as produced by ``build_tables_for_scenario``).
    out_dir : str or None
        Target directory, created if needed. Falsy -> nothing is written.
    fname_prefix : str
        Prepended to every file name.
    to_png : bool
        Also export PNGs via the optional ``dataframe_image`` package. If it is not installed,
        a single note is printed and only HTML is written.

    Returns
    -------
    dict
        ``metric -> {"html": path[, "png": path]}``; empty when ``out_dir`` is falsy.
    """
    saved = {}
    if not out_dir:
        return saved
    os.makedirs(out_dir, exist_ok=True)

    # Only try the optional dependency when PNGs are actually requested.
    dfi = None
    if to_png:
        try:
            import dataframe_image as dfi
        except ImportError:
            print("Note: dataframe_image not installed; skipping PNG export.")

    for metric, bundle in per_metric.items():
        sty = bundle["styler"]
        html_path = os.path.join(out_dir, f"{fname_prefix}{metric}.html")
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(sty.to_html())
        saved[metric] = {"html": html_path}

        if dfi is not None:
            png_path = os.path.join(out_dir, f"{fname_prefix}{metric}.png")
            dfi.export(sty, png_path)   # requires chrome or wkhtmltopdf depending on backend
            saved[metric]["png"] = png_path
    return saved

# -------------------------
# Render every scenario: markdown tables (winners in bold) + optional HTML/PNG export.
# -------------------------
# Relies on `aggregated_all` from the metric-computation cell.
EXPORT_TABLES = True           # set False to only print the markdown tables
EXPORT_DIR = "./table_exports"
EXPORT_PNG = False             # needs dataframe_image (+ chrome or wkhtmltopdf, depending on backend)

for scenario_name, method_data in aggregated_all.items():
    per_metric = build_tables_for_scenario(
        method_data,
        method_order=METHOD_ORDER,
        pretty_map=PRETTY_NAME,
        col_specs=COL_SPECS,
    )

    print(f"\n=== {scenario_name} ===")
    for metric in METRICS:
        print(f"\n{metric}:")
        # DataFrame.to_markdown needs the optional `tabulate` package.
        print(per_metric[metric]["df_fmt_bold_md"].to_markdown(tablefmt="github"))

    # Nice HTML/PNG versions for email or slides.
    if EXPORT_TABLES:
        saved_paths = export_tables(per_metric, out_dir=EXPORT_DIR,
                                    fname_prefix=f"{scenario_name}_", to_png=EXPORT_PNG)
        print("Saved:", saved_paths)



=== xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated ===

IMA:
|               | Original (feature space)   | Sum (redistributed)   | Weighted sum (redistributed)   | Max (redistributed)   |
|---------------|----------------------------|-----------------------|--------------------------------|-----------------------|
| DISCR         | 0.1178 ± 0.0273            | **0.8692 ± 0.0238**   | 0.7617 ± 0.0510                | **0.8956 ± 0.0193**   |
| PatternGAM    | 0.5688 ± 0.0803            | 0.6537 ± 0.0521       | 0.4889 ± 0.1072                | 0.6180 ± 0.0679       |
| PatternQLR    | 0.2278 ± 0.0385            | 0.4286 ± 0.0132       | 0.3414 ± 0.0063                | 0.5864 ± 0.0835       |
| KernelPattern | 0.1231 ± 0.0520            | 0.1231 ± 0.0520       | 0.1231 ± 0.0520                | 0.1231 ± 0.0520       |
| PatternAttr   | **0.7687 ± 0.1277**        | 0.7687 ± 0.1277       | **0.7687 ± 0.1277**            | 0.7687 ± 0.1277       |
| PatternNet    | 0.6141 ± 0.1915 

In [10]:
--- Aggregated Results (Original, Sum, Weighted, Max) ---

=== Scenario: xor_distractor_additive_1d1p_0.20_0.40_0.40_correlated ===

IMA (Importance Mass Accuracy):
                            original              sum         weighted              max
discr                0.1178 ± 0.0273  0.8692 ± 0.0238  0.7617 ± 0.0510  0.8956 ± 0.0193
ebm                  0.0336 ± 0.0179  0.1441 ± 0.0599  0.0784 ± 0.0310  0.1187 ± 0.0659
ig                   0.7278 ± 0.0947  0.7278 ± 0.0947  0.7278 ± 0.0947  0.7278 ± 0.0947
kernel_svm           0.1231 ± 0.0520  0.1231 ± 0.0520  0.1231 ± 0.0520  0.1231 ± 0.0520
nam                  0.4062 ± 0.0549  0.5350 ± 0.0364  0.3663 ± 0.0808  0.4708 ± 0.0606
pattern_attribution  0.7687 ± 0.1277  0.7687 ± 0.1277  0.7687 ± 0.1277  0.7687 ± 0.1277
pattern_gam          0.5973 ± 0.0788  0.6537 ± 0.0521  0.4889 ± 0.1072  0.6180 ± 0.0679
pattern_net          0.6141 ± 0.1915  0.6141 ± 0.1915  0.6141 ± 0.1915  0.6141 ± 0.1915
pattern_qlr          0.2422 ± 0.0334  0.4286 ± 0.0132  0.3414 ± 0.0063  0.5864 ± 0.0835
shap                 0.4130 ± 0.0354  0.4130 ± 0.0354  0.4130 ± 0.0354  0.4130 ± 0.0354

EMD:
                            original              sum         weighted              max
discr                0.7986 ± 0.0207  0.9457 ± 0.0056  0.9176 ± 0.0182  0.9430 ± 0.0052
ebm                  0.8151 ± 0.0017  0.7888 ± 0.0304  0.7700 ± 0.0186  0.7796 ± 0.0325
ig                   0.8995 ± 0.0460  0.8995 ± 0.0460  0.8995 ± 0.0460  0.8995 ± 0.0460
kernel_svm           0.7781 ± 0.0351  0.7781 ± 0.0351  0.7781 ± 0.0351  0.7781 ± 0.0351
nam                  0.8256 ± 0.0312  0.8918 ± 0.0115  0.8687 ± 0.0137  0.8779 ± 0.0144
pattern_attribution  0.9095 ± 0.0230  0.9095 ± 0.0230  0.9095 ± 0.0230  0.9095 ± 0.0230
pattern_gam          0.8095 ± 0.0182  0.9203 ± 0.0112  0.8816 ± 0.0176  0.9144 ± 0.0123
pattern_net          0.8891 ± 0.0687  0.8891 ± 0.0687  0.8891 ± 0.0687  0.8891 ± 0.0687
pattern_qlr          0.7832 ± 0.0221  0.8818 ± 0.0092  0.8564 ± 0.0100  0.9194 ± 0.0219
shap                 0.8734 ± 0.0237  0.8734 ± 0.0237  0.8734 ± 0.0237  0.8734 ± 0.0237

FNI-EMD:
                            original              sum         weighted              max
discr                0.8052 ± 0.0216  0.9729 ± 0.0065  0.9516 ± 0.0112  0.9790 ± 0.0046
ebm                  0.8213 ± 0.0017  0.8129 ± 0.0270  0.7850 ± 0.0229  0.8009 ± 0.0317
ig                   0.9580 ± 0.0128  0.9580 ± 0.0128  0.9580 ± 0.0128  0.9580 ± 0.0128
kernel_svm           0.7869 ± 0.0377  0.7869 ± 0.0377  0.7869 ± 0.0377  0.7869 ± 0.0377
nam                  0.8414 ± 0.0304  0.9150 ± 0.0127  0.8853 ± 0.0125  0.9058 ± 0.0167
pattern_attribution  0.9603 ± 0.0289  0.9603 ± 0.0289  0.9603 ± 0.0289  0.9603 ± 0.0289
pattern_gam          0.8196 ± 0.0202  0.9308 ± 0.0096  0.8970 ± 0.0196  0.9250 ± 0.0133
pattern_net          0.9235 ± 0.0400  0.9235 ± 0.0400  0.9235 ± 0.0400  0.9235 ± 0.0400
pattern_qlr          0.7941 ± 0.0250  0.8908 ± 0.0052  0.8679 ± 0.0087  0.9260 ± 0.0210
shap                 0.9112 ± 0.0105  0.9112 ± 0.0105  0.9112 ± 0.0105  0.9112 ± 0.0105

SyntaxError: invalid syntax (1941484341.py, line 1)