In [8]:
# -*- coding: utf-8 -*-
import os, re, pickle as pkl, warnings
import numpy as np
from scipy.spatial.distance import cdist
from ot.lp import emd
# Silence noisy library warnings for the whole process (the filters are global, not per-cell).
# NOTE(review): ignoring *every* RuntimeWarning may also hide numpy warnings raised by the metric
# code below (e.g. nanmean over an all-NaN list, invalid divisions) — confirm this is intended.
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# =========================
# Switches
# =========================
# If True: use IMP-SASI (credit pair only to the stronger-univariate feature)
# If False: fall back to MAX rule (credit max(|pair|, |uni|) to BOTH ends)
# Both flags are read at call time by the dispatchers (primary_to_64 / tilde_to_64).
USE_SASI_FOR_PRIMARY = True   # affects pattern_gam / pattern_qlr / ebm (when (192,) or (2144,) shape)
USE_SASI_FOR_TILDE   = True   # affects SD/SDb/DISCR/PROD for NAM & QLR tilde metrics

# =========================
# FAST (interaction pairs)
# =========================
from interpret.utils import measure_interactions
def FAST(X_train, y_train, n_interactions, init_score=None, feature_names=None, feature_types=None):
    """Return the feature pairs reported by interpret's ``measure_interactions`` (FAST).

    Args:
        X_train, y_train: training data, forwarded unchanged.
        n_interactions: forwarded as ``interactions=``; presumably the number of
            top-ranked pairs to return (see the interpret documentation).
        init_score, feature_names, feature_types: forwarded unchanged.

    Returns:
        ``(pairs, None)``: ``pairs`` is a list of ``(i, j)`` tuples in the order
        ``measure_interactions`` yielded them (callers here use them as feature
        indices). The second element is always ``None`` so call sites can unpack
        two values.
    """
    ranked = measure_interactions(
        X_train,
        y_train,
        interactions=n_interactions,
        init_score=init_score,
        feature_names=feature_names,
        feature_types=feature_types,
    )
    # Each item unpacks as ((i, j), score); only the pair itself is kept.
    pair_list = [(first, second) for (first, second), _score in ranked]
    return pair_list, None

# =========================
# Metrics
# =========================
def importance_mass_accuracy(gt_mask, attribution):
    """Fraction of the total absolute attribution that falls on ground-truth features.

    Args:
        gt_mask: 1-D ndarray; entries equal to 1 mark ground-truth features.
        attribution: 1-D ndarray of the same length; signs are ignored.

    Returns:
        A float in [0, 1], or NaN for malformed (non-ndarray, non-1-D, length
        mismatch) or non-finite input. For an all-zero attribution the score is
        1.0 only when the mask is also empty, else 0.0. This mirrors the
        degenerate cases of ``calculate_emd_score_metric``.
    """
    if not isinstance(gt_mask, np.ndarray) or not isinstance(attribution, np.ndarray):
        return np.nan
    if gt_mask.ndim != 1 or attribution.ndim != 1 or len(gt_mask) != len(attribution):
        return np.nan
    abs_attr = np.abs(attribution)
    total = float(np.sum(abs_attr))
    if not np.isfinite(total):
        # NaN/inf attributions (e.g. the NaN placeholder vectors) cannot be scored.
        return np.nan
    on_target = gt_mask == 1
    if total == 0.0:
        # Decide from the mask itself. The on-mask mass is trivially 0 here, so testing
        # it would give an empty explanation a perfect score against any ground truth.
        return 1.0 if not np.any(on_target) else 0.0
    return float(np.sum(abs_attr[on_target]) / total)

def create_cost_matrix(edge):
    """Pairwise Euclidean distances between the cells of an ``edge x edge`` grid.

    Cells are enumerated in row-major order (cell ``k`` sits at row ``k // edge``,
    column ``k % edge``), matching a C-order ``flatten`` of the grid.

    Returns:
        ``(edge**2, edge**2)`` float array; an empty ``(0, 0)`` array for
        ``edge <= 0`` and ``[[0.0]]`` for a single cell.
    """
    if edge <= 0:
        return np.zeros((0, 0))
    if edge == 1:
        return np.array([[0.0]])
    # argwhere over an all-True grid lists every (row, col) coordinate in row-major order.
    grid_points = np.argwhere(np.ones((edge, edge), dtype=bool))
    return cdist(grid_points, grid_points)

def _tight_ot_denominator(gt_dist, cost_matrix):
    """Largest transport cost any source distribution can incur against ``gt_dist``.

    Row ``i`` of ``cost_matrix @ gt_dist`` is the cost of shipping a point mass at
    cell ``i`` onto ``gt_dist``. Every transport cost is a convex mixture of these,
    so the maximum row value is an upper bound on the cost, attained by a point mass.
    """
    per_source_cost = cost_matrix @ gt_dist
    return float(per_source_cost.max())

def calculate_emd_score_metric(gt_mask_flat, attribution_flat, grid_edge_length, base_cost_matrix, is_fni=False):
    """Optimal-transport similarity between an attribution map and a ground-truth mask on a grid.

    ``|attribution|`` and the mask are each normalised to sum to 1. The exact EMD
    between them is computed with POT's ``emd`` under ``base_cost_matrix``. It is then
    rescaled by the largest cost any source distribution could incur against the mask
    (``_tight_ot_denominator``), giving ``1 - cost / max_cost`` clipped to [0, 1].

    Args:
        gt_mask_flat: 1-D mask of length ``grid_edge_length**2``; its values are the target mass.
        attribution_flat: 1-D attribution of the same length; signs are ignored.
        grid_edge_length: side length of the square grid.
        base_cost_matrix: ``(n, n)`` ground cost between cells, ``n = grid_edge_length**2``.
            It is copied, never modified.
        is_fni: if True, moving mass between two ground-truth cells is free
            (False-Negative Invariant variant).

    Returns:
        A float in [0, 1]. The score is 1.0 when both distributions are empty and 0.0
        when exactly one is. NaN for malformed or non-finite input, a mis-shaped cost
        matrix, or a solver exception.
    """
    if not (isinstance(gt_mask_flat, np.ndarray) and isinstance(attribution_flat, np.ndarray)):
        return np.nan
    if gt_mask_flat.ndim != 1 or attribution_flat.ndim != 1:
        return np.nan
    n_cells = grid_edge_length * grid_edge_length
    if len(gt_mask_flat) != len(attribution_flat) or len(gt_mask_flat) != n_cells:
        return np.nan

    C = np.array(base_cost_matrix, dtype=np.float64, copy=True)
    if C.shape != (n_cells, n_cells):
        return np.nan
    if is_fni:
        idx = np.flatnonzero(gt_mask_flat == 1)
        if idx.size:
            C[np.ix_(idx, idx)] = 0.0

    t = gt_mask_flat.astype(np.float64)
    s = np.abs(attribution_flat).astype(np.float64)
    sum_t, sum_s = float(t.sum()), float(s.sum())

    # NaN placeholder vectors are passed in routinely; reject them explicitly instead of
    # relying on the OT solver to raise.
    if not (np.isfinite(sum_t) and np.isfinite(sum_s)):
        return np.nan
    if sum_t < 1e-9 and sum_s < 1e-9:
        return 1.0
    if sum_t < 1e-9 or sum_s < 1e-9:
        return 0.0

    q = t / sum_t
    p = s / sum_s

    try:
        # NOTE(review): if numItermax is reached POT only warns (see log["warning"]) and the
        # cost may be inexact; that case is not detected here.
        _plan, log = emd(np.ascontiguousarray(p), np.ascontiguousarray(q),
                         np.ascontiguousarray(C), numItermax=200000, log=True)
        ot_cost = float(log["cost"])
    except Exception:
        return np.nan

    denom = _tight_ot_denominator(q, C)
    if denom <= 1e-12:
        # Every relevant cost is zero (e.g. a single cell, or FNI with the mask covering the
        # grid). Fall back to the grid diagonal so the division is well defined.
        d_max = np.sqrt(2 * (grid_edge_length - 1)**2) if grid_edge_length > 1 else 1.0
        denom = max(d_max, 1.0)

    score = 1.0 - (ot_cost / denom)
    return float(np.clip(score, 0.0, 1.0))

# =========================
# Ground truth (8x8)
# =========================
D_2D_EDGE = 8
D = D_2D_EDGE ** 2

# The two 3x2 patterns (T- and L-shaped, per their names) that form the ground truth.
normal_t = np.array([[1, 0],
                     [1, 1],
                     [1, 0]])
normal_l = np.array([[1, 0],
                     [1, 0],
                     [1, 1]])

# 8x8 binary ground-truth mask: the T pattern at rows 1-3 / cols 1-2, the L at rows 4-6 / cols 5-6.
GT_MASK_2D = np.zeros((D_2D_EDGE, D_2D_EDGE), dtype=int)
GT_MASK_2D[1:1 + normal_t.shape[0], 1:1 + normal_t.shape[1]] = normal_t
GT_MASK_2D[4:4 + normal_l.shape[0], 5:5 + normal_l.shape[1]] = normal_l

# Row-major copy, matching the cell order used by create_cost_matrix.
GT_MASK_2D_FLAT = GT_MASK_2D.reshape(-1).copy()
COST_MATRIX_MAIN_EFFECTS = create_cost_matrix(D_2D_EDGE)

# =========================
# IO
# =========================

# "tilde" is the old name for \tilde{f}_i <- w_i f_i; the final version uses z_i <- w_i f_i, but the naming here has not been updated yet.
# There are two separate files because the original explanations did not need to be changed; the second file only holds the tilde metrics.
PRIMARY_PATH = './models/xai_tris/explanations_cameraready.pkl'
TILDE_PATH   = './models/xai_tris/explanations_cameraready_nam_tilde.pkl'


def _read_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): unpickling can execute arbitrary code; only use this on trusted local artifacts.
    """
    with open(path, 'rb') as handle:
        return pkl.load(handle)


explanations = _read_pickle(PRIMARY_PATH)
explanations_nam = _read_pickle(TILDE_PATH)

# =========================
# Scenario helpers
# =========================
def base_scenario_name(key: str):
    """Drop a trailing ``_<digits>`` suffix (presumably a run/seed index) from a scenario key.

    Keys without such a suffix are returned unchanged. Only the last suffix is
    removed, e.g. ``"a_1_2" -> "a_1"``.
    """
    match = re.search(r"^(.*)_([0-9]+)$", key)
    if match is None:
        return key
    return match.group(1)

def load_xor_training(scenario_full_name: str):
    """Load the training split stored in ``./data/xai_tris/<scenario_full_name>.pkl``.

    Returns:
        ``(data.x_train.float(), data.y_train)`` on success. Returns ``(None, None)`` when
        the file is missing or cannot be loaded. This is best-effort by design: callers
        treat ``None`` as "no interaction pairs available" and fall back to univariate scores.
    """
    path = f'./data/xai_tris/{scenario_full_name}.pkl'
    if not os.path.exists(path):
        return None, None
    try:
        # Close the handle deterministically; the previous bare open() leaked it.
        # NOTE(review): unpickling executes arbitrary code — only point this at trusted local files.
        with open(path, "rb") as fh:
            data = pkl.load(fh)
        return data.x_train.float(), data.y_train
    except Exception:
        return None, None

# =========================
# Generic helpers
# =========================
def reduce_to_vector_per_seed(obj):
    """Collapse one seed's raw explanation payload into a single 1-D float vector.

    Payload formats differ by method, so several shapes are accepted:

    * ``None`` or an empty list/tuple -> ``None``.
    * list/tuple of scalars -> those scalars as a float array.
    * list/tuple of arrays -> ``(n, 1)`` items are squeezed to ``(n,)``. If every item is
      then 1-D, the items are stacked and averaged element-wise (they must share a length).
      Otherwise each item is flattened and the flattened items are averaged element-wise.
    * 1-D array -> returned as float.
    * 2-D array -> ``(B, 64)`` is averaged over rows, ``(n, 1)`` yields the column,
      ``(1, n)`` yields the row, and anything else is flattened.
    * 3-D ``(B, 64, 1)`` array -> averaged over ``B``.
    * anything else (including 0-d) -> flattened.

    NOTE(review): all averages are over *signed* values. Callers that take ``np.abs``
    afterwards get |mean(attr)|, not mean(|attr|), so opposite-signed local
    attributions can cancel. Confirm this is the intended globalization.

    Returns:
        1-D float ndarray, or ``None`` when there is nothing to reduce.
    """
    if obj is None: return None
    if isinstance(obj, (list, tuple)):
        if len(obj) == 0: return None
        if all(np.isscalar(x) or (isinstance(x, np.generic) and np.ndim(x) == 0) for x in obj):
            return np.asarray(obj, dtype=float)
        arrs = [np.asarray(x) for x in obj]
        # Treat column vectors (n, 1) as plain (n,) vectors.
        arrs = [a.squeeze(-1) if (a.ndim == 2 and a.shape[-1] == 1) else a for a in arrs]
        if all(a.ndim == 1 for a in arrs):
            return np.mean(np.stack(arrs, 0), 0).astype(float)
        # NOTE(review): ragged items make np.array(..., dtype=float) raise ValueError on
        # recent numpy. For equal-length items `stacked` is always 2-D, so the ravel
        # branch below appears unreachable.
        stacked = np.array([np.ravel(a) for a in arrs], dtype=float)
        return np.mean(stacked, 0) if stacked.ndim == 2 else np.ravel(stacked).astype(float)
    arr = np.asarray(obj)
    if arr.ndim == 1: return arr.astype(float)
    if arr.ndim == 2:
        # (B, 64): one row per sample -> mean over samples; (1, 64) also lands here.
        if arr.shape[1] == 64: return np.mean(arr, 0).astype(float)
        if arr.shape[1] == 1:  return arr[:, 0].astype(float)
        if arr.shape[0] == 1:  return arr[0, :].astype(float)
        return np.ravel(arr).astype(float)
    if arr.ndim == 3 and arr.shape[1] == 64 and arr.shape[2] == 1:
        return np.mean(arr[:, :, 0], 0).astype(float)
    return np.ravel(arr).astype(float)

# =========================
# SASI reducers (IMP-SASI)
# =========================
def _imp_sasi_from_unary_and_pairs(u, pairs_idx, w_pairs, nonneg=True):
    """IMP-SASI: fold pairwise weights into per-feature scores, crediting one end only.

    Start from the unary scores ``u``. For each pair ``(i, j)`` with weight ``w``, the
    weight is credited only to the endpoint with the larger unary score (ties go to
    ``i``). That endpoint's score is raised to ``w`` when ``w`` exceeds its current value.

    Args:
        u: unary scores, one per feature.
        pairs_idx: iterable of ``(i, j)`` feature-index pairs.
        w_pairs: one weight per pair (``zip`` stops at the shorter of the two).
        nonneg: if True, compare absolute values of ``u`` and of the weights.

    Returns:
        float ndarray shaped like ``u``.
    """
    unary = np.asarray(u, dtype=float)
    if nonneg:
        unary = np.abs(unary)
    credited = unary.copy()
    for (a, b), weight in zip(pairs_idx, w_pairs):
        strength = abs(float(weight)) if nonneg else float(weight)
        owner = a if unary[a] >= unary[b] else b
        if strength > credited[owner]:
            credited[owner] = max(unary[owner], strength)
    return credited

def nam_imp_sasi_from_192(vec192, scenario_key, d=64, nonneg=True):
    """Reduce a NAM vector ``[d unary | 128 pair]`` to ``d`` scores with IMP-SASI.

    Pair indices come from FAST on the scenario's training data, which is only
    loaded when ``scenario_key`` contains ``'xor'``. If the vector size is not
    ``d + 128``, or 128 pairs cannot be obtained, only the unary part is returned
    (absolute when ``nonneg``).
    """
    flat = np.asarray(vec192, dtype=float)
    unary_raw = flat[:d]
    if flat.size != d + 128:
        return np.abs(unary_raw) if nonneg else unary_raw
    pair_weights = flat[d:]

    pairs = []
    if 'xor' in scenario_key.lower():
        X_fit, y_fit = load_xor_training(scenario_key)
        if X_fit is not None and y_fit is not None:
            try:
                pairs, _ = FAST(X_fit, y_fit, n_interactions=128)
            except Exception:
                pairs = []

    if len(pairs) == 128:
        return _imp_sasi_from_unary_and_pairs(unary_raw, pairs, pair_weights, nonneg=nonneg)
    return np.abs(unary_raw) if nonneg else unary_raw

def qlr_imp_sasi_from_2144(vec2144, d=64, nonneg=True):
    """Reduce a QLR vector ``[d linear | upper triangle incl. diagonal]`` to ``d`` IMP-SASI scores.

    The expected size is ``d + d*(d+1)/2``, which is 2144 for the default ``d=64``. It is
    now derived from ``d``; previously it was hard-coded, contradicting the parameter.
    Triangle entries follow ``np.triu_indices(d, 0)`` order. Only off-diagonal (cross)
    terms are credited. NOTE(review): the diagonal (squared) terms are ignored here,
    whereas ``qlr_max_reduce_2144_to_64`` credits them to their feature — confirm
    this asymmetry is intended.

    Returns:
        ``d`` scores. If the size does not match, returns the linear part only
        (absolute when ``nonneg``).
    """
    v = np.asarray(vec2144, dtype=float)
    n_tri = d * (d + 1) // 2
    if v.size != d + n_tri:
        return np.abs(v[:d]) if nonneg else v[:d]
    u = v[:d]
    tri = v[d:]
    i0, i1 = np.triu_indices(d, 0)
    cross = i0 < i1
    pairs = list(zip(i0[cross].tolist(), i1[cross].tolist()))
    return _imp_sasi_from_unary_and_pairs(u, pairs, tri[cross], nonneg=nonneg)

# =========================
# MAX reducers (credit both ends)
# =========================
def nam_max_reduce_192_to_64(vec192, scenario_key, d=64):
    """Reduce a NAM vector ``[d unary | 128 pair]`` to ``d`` scores with the MAX rule.

    Every feature starts at ``|unary|``. Each FAST pair ``(i, j)`` raises BOTH endpoints
    to ``|pair weight|`` when that is larger. Pairs are only looked up for scenarios
    whose key contains ``'xor'``. Without exactly 128 pairs, only ``|unary|`` is returned.
    """
    flat = np.asarray(vec192, dtype=float)
    if flat.size != d + 128:
        return np.abs(flat[:d])
    unary = np.abs(flat[:d])
    pair_strength = np.abs(flat[d:])

    pairs = []
    if 'xor' in scenario_key.lower():
        X_fit, y_fit = load_xor_training(scenario_key)
        if X_fit is not None and y_fit is not None:
            try:
                pairs, _ = FAST(X_fit, y_fit, n_interactions=128)
            except Exception:
                pairs = []

    credited = unary.copy()
    if len(pairs) != 128 or pair_strength.size != 128:
        return credited
    for (a, b), strength in zip(pairs, pair_strength):
        if strength > credited[a]:
            credited[a] = strength
        if strength > credited[b]:
            credited[b] = strength
    return credited

def qlr_max_reduce_2144_to_64(vec2144, d=64):
    """Reduce a QLR vector ``[d linear | upper triangle incl. diagonal]`` to ``d`` scores with MAX.

    The expected size is ``d + d*(d+1)/2``, which is 2144 for the default ``d=64``. It is
    now derived from ``d``; previously it was hard-coded, contradicting the parameter.
    Each triangle term ``(a, b)`` (in ``np.triu_indices(d, 0)`` order) raises both ``a``
    and ``b`` to ``|w|`` when larger. Diagonal terms raise only their own feature. If
    the size does not match, returns ``|linear part|``.
    """
    v = np.asarray(vec2144, dtype=float)
    if v.size != d + d * (d + 1) // 2:
        return np.abs(v[:d])
    linear = np.abs(v[:d])
    tri = np.abs(v[d:])
    rows, cols = np.triu_indices(d, 0)
    out = linear.copy()
    # Plain loop on purpose: a NaN weight must be skipped (as `w > out` is False), which
    # np.maximum.at would not do.
    for a, b, w in zip(rows, cols, tri):
        if w > out[a]:
            out[a] = w
        if a != b and w > out[b]:
            out[b] = w
    return out

# =========================
# Dispatchers
# =========================
def primary_to_64(expl_raw, method: str, scenario_key: str):
    """Map a PAT/EBM seed payload to 64 per-feature scores (SASI vs MAX per USE_SASI_FOR_PRIMARY).

    * ``pattern_gam`` / ``ebm``: 192-long vectors use the NAM reducers.
    * ``pattern_qlr``: 2144-long vectors use the QLR reducers.
    * Any other size for these methods: ``|first 64 entries|``.
    * Unreducible payloads or other methods: a 64-long NaN vector.
    """
    vec = reduce_to_vector_per_seed(expl_raw)
    if vec is None:
        return np.full(64, np.nan, dtype=float)
    size = vec.size

    if method in ('pattern_gam', 'ebm'):
        if size == 192:
            if USE_SASI_FOR_PRIMARY:
                return nam_imp_sasi_from_192(vec, scenario_key)
            return nam_max_reduce_192_to_64(vec, scenario_key)
        # Covers the exact-64 case as well.
        return np.abs(vec[:64])

    if method == 'pattern_qlr':
        if size == 2144:
            if USE_SASI_FOR_PRIMARY:
                return qlr_imp_sasi_from_2144(vec)
            return qlr_max_reduce_2144_to_64(vec)
        return np.abs(vec[:64])

    return np.full(64, np.nan, dtype=float)

def tilde_to_64(vec, family_tag, scenario_key):
    """Map an SD/SDb/DISCR/PROD vector to 64 scores (SASI vs MAX per USE_SASI_FOR_TILDE).

    The input is flattened first. A 192-long NAM vector or a 2144-long QLR vector goes
    through the matching reducer. Every other case (including exactly 64 entries or an
    unknown family) returns ``|first 64 entries|``.
    """
    flat = np.ravel(np.asarray(vec))
    size = flat.size
    sasi = USE_SASI_FOR_TILDE

    if family_tag == 'NAM' and size == 192:
        if sasi:
            return nam_imp_sasi_from_192(flat, scenario_key)
        return nam_max_reduce_192_to_64(flat, scenario_key)
    if family_tag == 'QLR' and size == 2144:
        if sasi:
            return qlr_imp_sasi_from_2144(flat)
        return qlr_max_reduce_2144_to_64(flat)
    return np.abs(flat[:64])

# =========================
# Metric computation adapter
# =========================
def compute_metrics(vec64):
    """Score a 64-long importance vector against the 8x8 ground-truth mask.

    Returns:
        ``{'IMA', 'EMD', 'FNI_EMD'}`` scores. All three are NaN when ``vec64`` is
        ``None`` or does not have exactly 64 entries.
    """
    if vec64 is None or vec64.size != 64:
        return dict.fromkeys(('IMA', 'EMD', 'FNI_EMD'), np.nan)
    ima = importance_mass_accuracy(GT_MASK_2D_FLAT, vec64)
    emd_plain = calculate_emd_score_metric(
        GT_MASK_2D_FLAT, vec64, D_2D_EDGE, COST_MATRIX_MAIN_EFFECTS, is_fni=False)
    emd_fni = calculate_emd_score_metric(
        GT_MASK_2D_FLAT, vec64, D_2D_EDGE, COST_MATRIX_MAIN_EFFECTS, is_fni=True)
    return {'IMA': ima, 'EMD': emd_plain, 'FNI_EMD': emd_fni}

def nanmeanstd(xs):
    """Mean and (population) standard deviation of ``xs``, ignoring NaNs.

    Accepts lists, tuples and ndarrays. The previous ``if not xs`` raised on
    multi-element arrays. Returns ``(nan, nan)`` for ``None``, empty input, or input
    that is entirely NaN; in the all-NaN case this avoids numpy's empty-slice warning.
    """
    if xs is None:
        return (np.nan, np.nan)
    arr = np.asarray(xs, dtype=float).ravel()
    if arr.size == 0 or np.all(np.isnan(arr)):
        return (np.nan, np.nan)
    return np.nanmean(arr), np.nanstd(arr)

# =========================
# Evaluate all methods
# =========================
# Methods scored from the primary explanations pickle. Each is routed below to a reducer that
# produces one 64-long importance vector per seed, which is then scored with compute_metrics.
EVAL_METHODS = ['pattern_gam', 'pattern_qlr', 'kernel_svm', 'ebm', 'shap', 'ig', 'pattern_net', 'pattern_attribution']

results_raw_collection = {}   # {sc_base: {method: {'IMA':[], 'EMD':[], 'FNI_EMD':[]}}}
# NOTE(review): rebound to a fresh {} again right before aggregation; nothing reads it in between.
final_aggregated_results = {}

# ---- Primary explanations (PAT/EBM use SASI or MAX per switch) ----
# Scores are pooled under the scenario's base name (trailing _<digits> stripped), so all keys
# sharing a base name append to the same per-method lists.
for scenario_key, method_dict in explanations.items():
    sc_base = base_scenario_name(scenario_key)
    results_raw_collection.setdefault(sc_base, {})

    for method in EVAL_METHODS:
        if method not in method_dict: continue
        payloads = method_dict[method]  # list of seeds

        ml = results_raw_collection[sc_base].setdefault(method, {'IMA': [], 'EMD': [], 'FNI_EMD': []})
        for seed_payload in payloads:
            if method in ('pattern_gam', 'pattern_qlr', 'ebm'):
                vec64 = primary_to_64(seed_payload, method=method, scenario_key=scenario_key)
            elif method in ('shap', 'ig', 'pattern_net', 'pattern_attribution'):
                v = reduce_to_vector_per_seed(seed_payload)
                if v is None: vec64 = np.full(64, np.nan)
                else:
                    # NOTE(review): reduce_to_vector_per_seed averages *signed* values, so this yields
                    # |mean(attr)|, not mean(|attr|); opposite-signed local attributions can cancel.
                    if v.size == 64: vec64 = np.abs(v)
                    else:
                        # NOTE(review): (B, 64) and (B, 64, 1) payloads were already collapsed by
                        # reduce_to_vector_per_seed, so size > 64 here comes from another shape and
                        # keeping only the first 64 entries may be wrong — confirm.
                        vec64 = np.abs(v[:64]) if v.size >= 64 else np.full(64, np.nan)
            elif method == 'kernel_svm':
                v = reduce_to_vector_per_seed(seed_payload)
                vec64 = np.abs(v) if (v is not None and v.size == 64) else np.full(64, np.nan)
            else:
                # Unreachable for the current EVAL_METHODS: every entry is routed above.
                vec64 = np.full(64, np.nan)

            # NaN placeholder vectors yield NaN scores, which nanmeanstd later ignores.
            mets = compute_metrics(vec64)
            ml['IMA'].append(mets['IMA']); ml['EMD'].append(mets['EMD']); ml['FNI_EMD'].append(mets['FNI_EMD'])

# ---- Tilde metrics (NAM & QLR): SD / SDb / DISCR / PROD; SASI or MAX per switch ----
TILDE_METHOD_LABELS = {
    'NAM': {'SD_tilde': 'SD_NAM', 'SDb_tilde': 'SDb_NAM', 'DISCR_tilde': 'DISCR_NAM'},
    'QLR': {'SD_tilde': 'SD_QLR', 'SDb_tilde': 'SDb_QLR', 'DISCR_tilde': 'DISCR_QLR'},
}

# (family tag, list key inside each explanations_nam[scenario] record, label for SD*SDb products)
_TILDE_FAMILY_SOURCES = (
    ('NAM', 'nam_tilde_metrics', 'PROD_NAM'),
    ('QLR', 'qlr_tilde_metrics', 'PROD_QLR'),
)


def _record_tilde_metrics(bucket, method_label, raw_vec, family_tag, scenario_key):
    """Reduce *raw_vec* to 64 scores, evaluate it, and append the metrics under *method_label*."""
    scores = compute_metrics(tilde_to_64(raw_vec, family_tag=family_tag, scenario_key=scenario_key))
    lists = bucket.setdefault(method_label, {'IMA': [], 'EMD': [], 'FNI_EMD': []})
    for metric_name in ('IMA', 'EMD', 'FNI_EMD'):
        lists[metric_name].append(scores[metric_name])


for scenario_key, blocks in explanations_nam.items():
    sc_base = base_scenario_name(scenario_key)
    bucket = results_raw_collection.setdefault(sc_base, {})
    # NAM entries are processed before QLR entries for each scenario.
    for family_tag, list_key, prod_label in _TILDE_FAMILY_SOURCES:
        for entry in blocks.get(list_key, []):
            for k_src, k_dst in TILDE_METHOD_LABELS[family_tag].items():
                if k_src in entry:
                    _record_tilde_metrics(bucket, k_dst, entry[k_src], family_tag, scenario_key)
            # PROD is the element-wise SD * SDb product, formed before reduction to 64.
            sd_vec = entry.get('SD_tilde', None)
            sdb_vec = entry.get('SDb_tilde', None)
            if sd_vec is not None and sdb_vec is not None:
                prod_vec = np.asarray(sd_vec, float) * np.asarray(sdb_vec, float)
                _record_tilde_metrics(bucket, prod_label, prod_vec, family_tag, scenario_key)

# ---- Aggregate ----
def _summarize_metric_lists(ml):
    """Collapse per-seed metric lists into mean/std entries (NaNs ignored via nanmeanstd)."""
    ima_mean, ima_std = nanmeanstd(ml['IMA'])
    emd_mean, emd_std = nanmeanstd(ml['EMD'])
    fni_mean, fni_std = nanmeanstd(ml['FNI_EMD'])
    return {
        'IMA_mean': ima_mean, 'IMA_std': ima_std,
        'EMD_mean': emd_mean, 'EMD_std': emd_std,
        'FNI_EMD_mean': fni_mean, 'FNI_EMD_std': fni_std,
    }


# {sc_base: {method: {'IMA_mean', 'IMA_std', 'EMD_mean', ..., 'FNI_EMD_std'}}}
final_aggregated_results = {
    sc_base: {meth: _summarize_metric_lists(ml) for meth, ml in meth_data.items()}
    for sc_base, meth_data in results_raw_collection.items()
}


# =========================
#      Print summary 
# =========================
def brief_summary(d):
    """Render aggregated results as plain text.

    Each scenario (sorted by name) contributes a ``[scenario]`` header, one indented
    line per method (sorted by method name) with IMA / EMD / FNI-EMD as mean±std at
    three decimals, and a trailing blank line.

    Args:
        d: ``{scenario: {method: {'IMA_mean': ..., 'IMA_std': ..., ...}}}``.

    Returns:
        The lines joined with newlines.
    """
    def _method_line(meth, vals):
        return (
            f"  {meth}: IMA={vals['IMA_mean']:.3f}±{vals['IMA_std']:.3f}, "
            f"EMD={vals['EMD_mean']:.3f}±{vals['EMD_std']:.3f}, "
            f"FNI-EMD={vals['FNI_EMD_mean']:.3f}±{vals['FNI_EMD_std']:.3f}"
        )

    out = []
    for scenario in sorted(d):
        out.append(f"[{scenario}]")
        out.extend(_method_line(meth, vals) for meth, vals in sorted(d[scenario].items()))
        out.append("")
    return "\n".join(out)

# Header first, then the per-scenario summary text.
print("=== Aggregated results (means ± stds) ===")
summary_text = brief_summary(final_aggregated_results)
print(summary_text)

=== Aggregated results (means ± stds) ===
[linear_additive_1d1p_0.10_correlated]
  DISCR_NAM: IMA=0.797±0.067, EMD=0.905±0.026, FNI-EMD=0.932±0.023
  DISCR_QLR: IMA=0.665±0.054, EMD=0.872±0.021, FNI-EMD=0.886±0.019
  PROD_NAM: IMA=0.955±0.026, EMD=0.835±0.072, FNI-EMD=0.991±0.006
  PROD_QLR: IMA=0.961±0.022, EMD=0.975±0.015, FNI-EMD=0.990±0.005
  SD_NAM: IMA=0.483±0.051, EMD=0.788±0.066, FNI-EMD=0.889±0.022
  SD_QLR: IMA=0.487±0.002, EMD=0.862±0.001, FNI-EMD=0.869±0.001
  SDb_NAM: IMA=0.850±0.065, EMD=0.926±0.027, FNI-EMD=0.956±0.020
  SDb_QLR: IMA=0.804±0.098, EMD=0.924±0.036, FNI-EMD=0.936±0.032
  ebm: IMA=0.566±0.003, EMD=0.838±0.011, FNI-EMD=0.904±0.001
  ig: IMA=0.857±0.036, EMD=0.936±0.024, FNI-EMD=0.960±0.009
  kernel_svm: IMA=0.797±0.075, EMD=0.917±0.027, FNI-EMD=0.927±0.023
  pattern_attribution: IMA=0.418±0.171, EMD=0.749±0.063, FNI-EMD=0.805±0.071
  pattern_gam: IMA=0.395±0.095, EMD=0.775±0.035, FNI-EMD=0.800±0.035
  pattern_net: IMA=0.397±0.191, EMD=0.776±0.068, FNI-EMD=0.7

In [9]:
# =========================
# Transposed LaTeX tables
# =========================
def parse_scenario_name_for_table(scenario_base_name):
    """Turn a scenario base name into a table column label like ``"XOR DIST WHITE"``.

    Matching is case-insensitive and first-match-wins. The base type is XOR, then MULT,
    then LIN (``linear`` or ``additive``), else OTHER. ``DIST`` is included when the name
    contains ``_distractor`` or ``distractor_``. The background is WHITE, then CORR
    (``correlated``), else UNK_BG.
    """
    lowered = scenario_base_name.lower()
    base_rules = (("xor", "XOR"), ("multiplicative", "MULT"), ("linear", "LIN"), ("additive", "LIN"))
    bg_rules = (("white", "WHITE"), ("correlated", "CORR"))

    base_type_str = next((tag for needle, tag in base_rules if needle in lowered), "OTHER")
    bg_str = next((tag for needle, tag in bg_rules if needle in lowered), "UNK_BG")
    has_distractor = "_distractor" in lowered or "distractor_" in lowered

    parts = [base_type_str, "DIST", bg_str] if has_distractor else [base_type_str, bg_str]
    return " ".join(parts)

def _parse_label_parts(parsed_label):
    """Split a column label into ``(base_type, background_group)``.

    Everything after the first token is the background group (e.g.
    ``"XOR DIST WHITE" -> ("XOR", "DIST WHITE")``). A single-token label gets
    ``"UNK_BG"``. An empty label raises IndexError.
    """
    tokens = parsed_label.split()
    head, rest = tokens[0], tokens[1:]
    if not rest:
        return head, "UNK_BG"
    return head, " ".join(rest)

# Column groups, left to right, in the LaTeX tables. Labels whose base type is not listed
# here (e.g. "OTHER") get no column.
BASE_TYPES = ["LIN", "MULT", "XOR"]
# Sub-column order within each group; entries are the background part of the labels produced
# by parse_scenario_name_for_table.
BG_ORDER = ["WHITE", "CORR", "DIST WHITE", "DIST CORR", "UNK_BG"]

# LaTeX row label for each method key. SDb_* keys are shown as SD of the P-variant
# (PGAM / PQLR).
METHOD_DISPLAY = {
    'pattern_gam': r'$\text{PAT}^{\text{GAM}}$',
    'pattern_qlr': r'$\text{PAT}^{\text{QLR}}$',
    'SD_NAM': r'$\text{SD}(f^{\text{GAM}})$',
    'SDb_NAM': r'$\text{SD}(f^{\text{PGAM}})$',
    'DISCR_NAM': r'$\text{DISCR}(f^{\text{GAM}})$',
    'PROD_NAM': r'$\text{PROD}(f^{\text{GAM}})$',
    'SD_QLR': r'$\text{SD}(f^{\text{QLR}})$',
    'SDb_QLR': r'$\text{SD}(f^{\text{PQLR}})$',
    'DISCR_QLR': r'$\text{DISCR}(f^{\text{QLR}})$',
    'PROD_QLR': r'$\text{PROD}(f^{\text{QLR}})$',
    'ebm': r'EBM',
    'kernel_svm': r'Kernel Pattern',
    'pattern_net': r'PatternNet',
    'pattern_attribution': r'PatternAttribution',
    'shap': r'SHAP',
    'ig': r'Int. Grads.',
}
# Row order in the tables; present methods missing from this list are appended alphabetically.
METHOD_ORDER = [
    'pattern_gam', 'pattern_qlr',
    'SD_NAM', 'SDb_NAM', 'PROD_NAM', 'DISCR_NAM',
    'SD_QLR', 'SDb_QLR', 'PROD_QLR', 'DISCR_QLR',
    'ebm', 'kernel_svm', 'pattern_net', 'pattern_attribution', 'shap', 'ig'
]

def format_latex_value(mean_val, std_val, precision=2, is_bold=False):
    """Format ``mean ± std`` as an inline LaTeX math cell, optionally in bold.

    Returns ``"-"`` when either value is ``None`` or NaN. Otherwise returns
    ``$m \\pm s$`` (or ``$\\mathbf{m \\pm s}$``) with ``precision`` decimals.
    """
    missing = (mean_val is None or std_val is None
               or np.isnan(mean_val) or np.isnan(std_val))
    if missing:
        return "-"
    core = f"{mean_val:.{precision}f} \\pm {std_val:.{precision}f}"
    if is_bold:
        return f"$\\mathbf{{{core}}}$"
    return f"${core}$"

def _latex_escape_text(text):
    """Escape LaTeX special characters in a plain-text header cell (e.g. ``UNK_BG`` -> ``UNK\\_BG``)."""
    for raw, escaped in (("&", r"\&"), ("%", r"\%"), ("#", r"\#"), ("_", r"\_")):
        text = text.replace(raw, escaped)
    return text


def _collect_table_data(aggregated_metrics):
    """Re-key aggregated metrics by table column label; also return the sets of methods and labels seen."""
    data_for_tables, present_methods, present_labels = {}, set(), set()
    for scenario_base_name, methods_data in aggregated_metrics.items():
        parsed_label = parse_scenario_name_for_table(scenario_base_name)
        present_labels.add(parsed_label)
        # Scenarios that map to the same label are merged; a later scenario wins per method.
        bucket = data_for_tables.setdefault(parsed_label, {})
        for method_name, metrics_values in methods_data.items():
            present_methods.add(method_name)
            bucket[method_name] = metrics_values
    return data_for_tables, present_methods, present_labels


def _ordered_columns(present_labels):
    """Order present labels by BASE_TYPES then BG_ORDER; return ``(columns, [(group, span), ...])``."""
    columns, group_spans = [], []
    for bt in BASE_TYPES:
        candidates = (f"{bt} {bg}".replace("  ", " ").strip() for bg in BG_ORDER)
        subcols = [label for label in candidates if label in present_labels]
        if subcols:
            columns.extend(subcols)
            group_spans.append((bt, len(subcols)))
    return columns, group_spans


def _column_best_rounded(data_for_tables, columns, ordered_methods, mean_key, precision):
    """Per column, the largest mean rounded to *precision*, or None when the column has no finite value."""
    best = []
    for col_label in columns:
        col_data = data_for_tables.get(col_label, {})
        vals = []
        for m in ordered_methods:
            mean_val = col_data.get(m, {}).get(mean_key)
            if mean_val is not None and not np.isnan(mean_val):
                vals.append(np.around(mean_val, precision))
        best.append(np.max(vals) if vals else None)
    return best


def generate_latex_tables_transposed(aggregated_metrics, precision=2, shade_mult=True):
    """Build one LaTeX table per metric (IMA, EMD, FNI-EMD) with methods as rows and scenarios as columns.

    Columns are grouped by base type (BASE_TYPES) and ordered by background (BG_ORDER).
    Scenarios whose label has another base type are not shown. Rows follow METHOD_ORDER,
    then any remaining methods alphabetically. In each column, every method whose
    rounded mean equals the best rounded mean is bolded.

    Header labels are now LaTeX-escaped; previously a raw ``_`` (e.g. ``UNK_BG``) broke
    compilation. The emitted document must define the ``M`` column type and the
    ``colMULT`` colour when ``shade_mult`` is True.

    Args:
        aggregated_metrics: ``{scenario_base: {method: {'IMA_mean', 'IMA_std', ...}}}``.
        precision: decimals shown and used for tie detection.
        shade_mult: if True, MULT columns use column type ``M`` and a shaded group header.

    Returns:
        ``{"IMA" | "EMD" | "FNI-EMD": latex_source}``.
    """
    data_for_tables, present_methods, present_labels = _collect_table_data(aggregated_metrics)
    columns, group_spans = _ordered_columns(present_labels)

    ordered_methods = [m for m in METHOD_ORDER if m in present_methods]
    extras = sorted(m for m in present_methods if m not in ordered_methods)
    ordered_methods.extend(extras)

    metrics_config = [
        ("IMA", "IMA_mean", "IMA_std", "Importance Mass Accuracy (IMA)"),
        ("EMD", "EMD_mean", "EMD_std", "Earth Mover's Distance (EMD)"),
        ("FNI-EMD", "FNI_EMD_mean", "FNI_EMD_std", "False-Negative Invariant EMD (FNI-EMD)")
    ]

    latex_output_tables = {}

    for metric_key, mean_key, std_key, caption_title in metrics_config:
        colspec = "l|"
        for col_label in columns:
            base_type, _ = _parse_label_parts(col_label)
            colspec += "M" if (shade_mult and base_type == "MULT") else "c"
        table_cols_format = colspec

        top_header_cells = [""]
        for group_name, span in group_spans:
            group_text = _latex_escape_text(group_name)
            if shade_mult and group_name == "MULT":
                top_header_cells.append(rf"\multicolumn{{{span}}}{{c}}{{\cellcolor{{colMULT}} {group_text}}}")
            else:
                top_header_cells.append(rf"\multicolumn{{{span}}}{{c}}{{{group_text}}}")
        top_header_line = " & ".join(top_header_cells) + r" \\"

        sub_header_cells = ["Method"] + [_latex_escape_text(_parse_label_parts(c)[1]) for c in columns]
        sub_header_line = " & ".join(sub_header_cells) + r" \\ \hline\hline"

        col_best_rounded = _column_best_rounded(data_for_tables, columns, ordered_methods, mean_key, precision)

        s = []
        s.append(f"% Transposed LaTeX Table for {metric_key} (MULT banded)")
        s.append(r"\begin{table}[htbp]")
        s.append(r"\centering")
        s.append(rf"\caption{{{caption_title}. Values are mean $\pm$ standard deviation. Best result per column (ties at {precision} dp) is emboldened.}}")
        s.append(rf"\label{{tab:{metric_key.lower().replace('-', '')}_results_transposed}}")
        s.append(r"\resizebox{\textwidth}{!}{")
        s.append(rf"\begin{{tabular}}{{{table_cols_format}}}")
        s.append(r"\hline")
        s.append(top_header_line)
        s.append(sub_header_line)

        for m in ordered_methods:
            disp = METHOD_DISPLAY.get(m, m.replace("_", r"\_"))
            row_cells = [disp]
            for j, col_label in enumerate(columns):
                md = data_for_tables.get(col_label, {}).get(m, {})
                mean_val = md.get(mean_key)
                std_val = md.get(std_key)
                is_best = (
                    mean_val is not None and not np.isnan(mean_val) and
                    col_best_rounded[j] is not None and
                    np.around(mean_val, precision) == col_best_rounded[j]
                )
                row_cells.append(format_latex_value(mean_val, std_val, precision=precision, is_bold=is_best))
            s.append(" & ".join(row_cells) + r" \\")
        s.append(r"\hline")
        s.append(r"\end{tabular}")
        s.append(r"}")
        s.append(r"\end{table}")

        latex_output_tables[metric_key] = "\n".join(s)

    return latex_output_tables

latex_tables = generate_latex_tables_transposed(final_aggregated_results, precision=2, shade_mult=True)
# Emit each table wrapped in begin/end marker comments, one item per output line.
for metric_name, table_code in latex_tables.items():
    print(
        f"\n% --- LaTeX Table for {metric_name} (transposed, MULT banded) ---",
        table_code,
        "% --- End of LaTeX Table ---",
        sep="\n",
    )


% --- LaTeX Table for IMA (transposed, MULT banded) ---
% Transposed LaTeX Table for IMA (MULT banded)
\begin{table}[htbp]
\centering
\caption{Importance Mass Accuracy (IMA). Values are mean $\pm$ standard deviation. Best result per column (ties at 2 dp) is emboldened.}
\label{tab:ima_results_transposed}
\resizebox{\textwidth}{!}{
\begin{tabular}{l|ccccMMMMcccc}
\hline
 & \multicolumn{4}{c}{LIN} & \multicolumn{4}{c}{\cellcolor{colMULT} MULT} & \multicolumn{4}{c}{XOR} \\
Method & WHITE & CORR & DIST WHITE & DIST CORR & WHITE & CORR & DIST WHITE & DIST CORR & WHITE & CORR & DIST WHITE & DIST CORR \\ \hline\hline
$\text{PAT}^{\text{GAM}}$ & $0.86 \pm 0.02$ & $0.40 \pm 0.10$ & $0.89 \pm 0.01$ & $0.36 \pm 0.11$ & $0.81 \pm 0.04$ & $0.36 \pm 0.11$ & $0.82 \pm 0.03$ & $0.38 \pm 0.11$ & $0.19 \pm 0.05$ & $0.26 \pm 0.09$ & $0.30 \pm 0.09$ & $0.24 \pm 0.10$ \\
$\text{PAT}^{\text{QLR}}$ & $0.70 \pm 0.01$ & $0.52 \pm 0.07$ & $0.74 \pm 0.00$ & $0.37 \pm 0.04$ & $0.16 \pm 0.02$ & $0.21 \pm 0.02$ & 