In [34]:
# =============================================================================
# Significance Tests for Experimental Results
# =============================================================================

import os
import pandas as pd
import numpy as np
from scipy.stats import wilcoxon

# ------------------------------
# Paths and dataset configuration
# ------------------------------

# Root folder holding one sub-folder of summary CSVs per dataset type
# (os.pardir is "..", resolved relative to the working directory).
base_dir = os.path.join(os.pardir, "results")

# Dataset types to load (project default list; adjust to your project).
from definitions import DEFAULT_DATA_GEN_TYPES

# Values of the CSV "metric" column that are tested.
metrics_to_plot = ["mse_test", "wasser_test"]

# Method registry shared with the plotting code (adjust via METHODS_CONFIG).
# The '__true__' entry is removed -- presumably the ground-truth reference
# rather than a competing method.
from experiment.plot_config import METHODS_CONFIG
METHODS_CONFIG = [cfg for cfg in METHODS_CONFIG if cfg[0] != '__true__']
# Method key (element 0) -> display label (element 3) used in table rows.
# NOTE(review): `is not True` is an identity test against the bool singleton,
# so it never filters a string key; looks redundant with the line above.
method_display = dict(
    (cfg[0], cfg[3]) for cfg in METHODS_CONFIG if cfg[0] is not True
)

# Our proposed variants; every other configured method is a baseline.
our_methods = ["ours-linear_u_diag", "ours-lnl_u_diag"]
all_methods = list(method_display)
baselines = [name for name in all_methods if name not in our_methods]

# Significance level of each one-sided test.
alpha = 0.05
# Practical-improvement margin: a "win" needs a score ratio below 1 - beta.
beta = 0.05

# ------------------------------
# Load CSVs
# ------------------------------
# all_data: dataset type -> DataFrame read from
# <base_dir>/<type>/summary_00_01_00/mean/df-<type>-summary_00_01_00.csv
all_data = {}
for data_type in DEFAULT_DATA_GEN_TYPES:
    csv_path = os.path.join(
        base_dir,
        "{}/summary_00_01_00/mean".format(data_type),
        "df-{}-summary_00_01_00.csv".format(data_type),
    )
    all_data[data_type] = df = pd.read_csv(csv_path)

def median_log_ratio_ci(our, base, n_boot=10000, alpha=0.05, rng=None):
    """Bootstrap a percentile CI for the median of ``log(our) - log(base)``.

    Parameters
    ----------
    our, base : array-like
        Paired scores for our method and a baseline (broadcast against each
        other). Values are clipped below at 1e-12 so exact zeros do not give
        ``log(0) = -inf``.
    n_boot : int
        Number of bootstrap resamples.
    alpha : float
        Two-sided miscoverage; the interval is the ``[alpha/2, 1 - alpha/2]``
        percentile range of the bootstrap medians.
    rng : None, int, or numpy.random.Generator, optional
        Randomness source. ``None`` (default) keeps the previous behaviour of
        drawing from NumPy's global legacy state, so ``np.random.seed`` still
        controls it. An int seed or a ``Generator`` makes the interval
        reproducible independently of global state.

    Returns
    -------
    tuple
        ``(lo, hi)`` in log space; ``np.exp`` turns them into a CI on the
        multiplicative ratio ``our / base``.

    Raises
    ------
    ValueError
        If there are no paired values, or the shapes cannot be broadcast.
    """
    # Work on plain float ndarrays. With a pandas Series, ``log_ratio[idx]``
    # below would be a *label* lookup and pick wrong rows (or raise KeyError)
    # whenever the index is not 0..n-1, e.g. after filtering.
    our, base = np.broadcast_arrays(
        np.atleast_1d(np.asarray(our, dtype=float)),
        np.atleast_1d(np.asarray(base, dtype=float)),
    )
    # Tiny positive floor so exact zeros do not produce -inf logs.
    our = np.maximum(our, 1e-12)
    base = np.maximum(base, 1e-12)
    log_ratio = np.log(our) - np.log(base)

    n = len(log_ratio)
    if n == 0:
        raise ValueError("median_log_ratio_ci needs at least one paired value")

    if rng is None:
        resample = np.random.randint  # legacy global state, as before
    else:
        resample = np.random.default_rng(rng).integers

    # Resample pairs with replacement and record the median of each resample.
    boot_stats = [
        np.median(log_ratio[resample(0, n, size=n)]) for _ in range(n_boot)
    ]
    lo = np.percentile(boot_stats, 100 * (alpha / 2))
    hi = np.percentile(boot_stats, 100 * (1 - alpha / 2))
    return lo, hi  # in log-space; exp gives multiplicative CI
# ------------------------------
# Initialize results dictionary
# ------------------------------
# Structure: results[our_method][baseline_method][data_type] is a list holding
# one {'ours_better': bool, 'baseline_better': bool} dict per metric, in
# metrics_to_plot order (currently [mse, wasser]). A data_type key is only
# created when at least one paired comparison was actually run.
results = {our: {base: {} for base in baselines} for our in our_methods}

# ------------------------------
# Run paired Wilcoxon tests and record detailed outcomes
# ------------------------------
# For instances matched on (data_idx, env_idx) we form the log ratio
# r_i = log(our_i) - log(base_i) and compare it with the practical margin
# threshold = log(1 - beta) via two one-sided Wilcoxon signed-rank tests:
#   ours_better:     H1 median(r - threshold) < 0   (ours < (1 - beta) * base)
#   baseline_better: H1 median(-r - threshold) < 0  (base < (1 - beta) * ours)
eps = 1e-12  # floor that guards log(0) for exact-zero scores
threshold = np.log(1 - beta)

for data_type, df in all_data.items():
    for metric_idx, metric in enumerate(metrics_to_plot):
        df_metric = df[df["metric"] == metric]

        for our_method in our_methods:
            our_df = df_metric[df_metric["method"] == our_method]

            for baseline in baselines:
                baseline_df = df_metric[df_metric["method"] == baseline]

                # Merge so the comparison is paired on data_idx and env_idx.
                paired = pd.merge(
                    our_df, baseline_df,
                    on=["data_idx", "env_idx"],
                    suffixes=("_our", "_base"),
                )
                if paired.empty:
                    continue  # no pairs: nothing recorded for this data_type

                # A missing score (NaN) counts as the worst case: +inf.
                our = np.maximum(paired["val_our"].fillna(np.inf).to_numpy(), eps)
                base = np.maximum(paired["val_base"].fillna(np.inf).to_numpy(), eps)
                with np.errstate(invalid="ignore"):
                    log_ratio = np.log(our) - np.log(base)
                # If both methods failed on the same instance, inf - inf is NaN.
                # Count that pair as a tie. Left as NaN, wilcoxon's default
                # nan_policy="propagate" returns a NaN p-value, which would
                # silently mark the whole comparison as not significant.
                log_ratio = np.where(np.isnan(log_ratio), 0.0, log_ratio)

                _, p_value_1 = wilcoxon(log_ratio - threshold, alternative="less")
                _, p_value_2 = wilcoxon(-log_ratio - threshold, alternative="less")

                outcomes = results[our_method][baseline].setdefault(
                    data_type,
                    [{'ours_better': False, 'baseline_better': False}
                     for _ in metrics_to_plot],
                )
                if p_value_1 < alpha:
                    outcomes[metric_idx]['ours_better'] = True
                if p_value_2 < alpha:
                    outcomes[metric_idx]['baseline_better'] = True

                # For an effect-size interval on the median ratio, see
                # median_log_ratio_ci(our, base) (exp of its bounds).


# ------------------------------
# Generate LaTeX tables: one per "our method" with red/black stars
# ------------------------------
from IPython.display import display, Latex
from sympy import preview

def _significance_symbol(entry):
    """Return the LaTeX star markup for one metric's test outcome.

    ``entry`` is one ``{'ours_better': bool, 'baseline_better': bool}`` dict
    from ``results``. A black star marks a significant win for our method,
    and a red star marks a significant win for the baseline. The red star
    uses textcolor, so the document needs the xcolor or color package.
    No significant difference gives an empty string.
    """
    symbol = ""
    if entry.get('ours_better', False):
        symbol += "$\\ast$"  # black star
    if entry.get('baseline_better', False):
        symbol += "\\textcolor{red}{$\\ast$}"  # red star
    return symbol


# One table per "our method": rows are baselines and columns are dataset
# types. Each cell joins the per-metric stars with " / " in the order the
# outcomes were recorded (metrics_to_plot order: MSE, then Wasserstein).
for our_method in our_methods:
    latex_lines = [
        "\\begin{table}[h]",
        "\\centering",
        f"\\begin{{tabular}}{{l{'c' * len(DEFAULT_DATA_GEN_TYPES)}}}",
        "\\hline",
        "Baseline / Dataset & " + " & ".join(DEFAULT_DATA_GEN_TYPES) + " \\\\",
        "\\hline"
    ]

    for baseline in baselines:
        row = [method_display.get(baseline, baseline)]
        for data_type in DEFAULT_DATA_GEN_TYPES:
            outcomes = results[our_method][baseline].get(data_type)
            if outcomes is None:
                # No paired data, so no test ran: one empty slot per metric.
                cell = " / ".join([""] * len(metrics_to_plot))
            else:
                cell = " / ".join(_significance_symbol(entry) for entry in outcomes)
            row.append(cell)
        latex_lines.append(" & ".join(row) + " \\\\")

    latex_lines.extend([
        "\\hline",
        "\\end{tabular}",
        f"\\caption{{Significance tests for {method_display.get(our_method, our_method)}. "
        "Black star indicates our method significantly outperforms the baseline; "
        "red star indicates baseline significantly outperforms ours. "
        "Each cell shows MSE / Wasserstein results.}",
        "\\end{table}"
    ])

    latex_table = "\n".join(latex_lines)

    # Save the table in the working directory. Explicit UTF-8 so non-ASCII
    # display names do not depend on the platform's default encoding.
    filename = f"significance_{our_method}.tex"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(latex_table)

    print(latex_table)


\begin{table}[h]
\centering
\begin{tabular}{lcccccc}
\hline
Baseline / Dataset & linear-er & linear-sf & scm-er & scm-sf & sergio-er & sergio-sf \\
\hline
GIES &  / $\ast$ &  / $\ast$ &  / $\ast$ &  /  & $\ast$ / $\ast$ & $\ast$ / $\ast$ \\
IGSP & $\ast$ / $\ast$ &  / $\ast$ & $\ast$ / $\ast$ & \textcolor{red}{$\ast$} / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ \\
DCDI & $\ast$ / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ & \textcolor{red}{$\ast$} /  &  /  \\
LLC & $\ast$ / $\ast$ &  / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ \\
NODAGS & $\ast$ / $\ast$ &  / $\ast$ & $\ast$ / $\ast$ &  / $\ast$ & \textcolor{red}{$\ast$} / $\ast$ &  /  \\
KDS (Linear) &  /  &  /  & $\ast$ / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ & $\ast$ / $\ast$ \\
KDS (MLP) & \textcolor{red}{$\ast$} /  & $\ast$ / $\ast$ & $\ast$ / $\ast$ &  /  &  /  &  /  \\
\hline
\end{tabular}
\caption{Significance tests for \textbf{SKDS (Linear)}. Black star indicates our