# Bayesian Stats

This notebook can be used to conduct Bayesian t-tests and ANOVAs with PyMC.

In [None]:
# imports

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns
import pymc as pm
import arviz as az
from scipy.stats import gaussian_kde, norm
import ipywidgets as widgets
from IPython.display import display, clear_output

print("Running with PyMC version:", pm.__version__)


In [None]:
# Paths
# Folders holding the logged training histories (one CSV per training run).
_LOG_ROOT = r'N:\isipd\projects\p_planetdw\data\methods_test\logs'

# U-Net and Swin Transformer runs trained on MACS
unet_macs_loss_dir = _LOG_ROOT + r'\unet_ae_samples'
swin_macs_loss_dir = _LOG_ROOT + r'\swin_ae_samples'

# U-Net and Swin Transformer runs trained on PS
unet_ps_loss_dir = _LOG_ROOT + r'\unet_ps_samples'
swin_ps_loss_dir = _LOG_ROOT + r'\swin_ps_samples'

# Metrics to analyse, and the subset for which a higher value is better.
metrics = [
    'loss',
    'specificity',
    'sensitivity',
    'IoU',
    'f1_score',
    'Hausdorff_distance',
]
maximize_metrics = {'specificity', 'sensitivity', 'IoU', 'f1_score'}

# Output folders for the analysis results of the MACS runs
# (these are the same folders the MACS logs are read from).
unet_macs_output_dir = unet_macs_loss_dir
swin_macs_output_dir = swin_macs_loss_dir

## Read and plot the data

In [None]:
# this function reads the metrics from CSV files in a given directory and returns them as a 3D numpy array, with
# the shape (number of files, number of epochs, number of metrics). It also returns a lookup dictionary for metric names

def read_metrics_as_array(directory, metrics):
    """
    Read per-run CSV training logs and collect the requested metrics in a 3D array.

    For every name in ``metrics`` both the training column (``name``) and the
    validation column (``'val_' + name``) are extracted. Columns that a file does
    not contain are filled with NaN. Runs may have different numbers of epochs
    (e.g. because of early stopping): shorter runs are padded with NaN at the end
    so that all runs fit into one rectangular array.

    Args:
        directory (str): Path to the directory containing the CSV files.
        metrics (list): Metric names to extract (without the ``val_`` prefix).
    Returns:
        tuple: A tuple containing:
            - data_array (np.ndarray): Float array of shape
              (num_files, max_num_epochs, 2 * len(metrics)).
            - lookup (dict): Maps each column name (``name`` and ``val_name``)
              to its index along the last axis of ``data_array``.
            - files (list): Sorted list of the CSV file names that were read;
              the order matches the first axis of ``data_array``.
    Raises:
        ValueError: If ``directory`` contains no CSV files.
    """
    files = sorted(f for f in os.listdir(directory) if f.endswith('.csv'))
    if not files:
        raise ValueError(f"No CSV files found in {directory!r}")

    # Interleave train/val names: [m0, val_m0, m1, val_m1, ...]
    metric_names = []
    for metric in metrics:
        metric_names.extend((metric, 'val_' + metric))

    per_file = []
    for file in files:
        df = pd.read_csv(os.path.join(directory, file))

        # Start from NaN so that missing columns need no special handling.
        file_data = np.full((len(df), len(metric_names)), np.nan)
        for col, name in enumerate(metric_names):
            if name in df.columns:
                file_data[:, col] = df[name].to_numpy(dtype=float)
        per_file.append(file_data)

    # Pad every run to the longest one; np.stack would fail on ragged lengths.
    max_epochs = max(block.shape[0] for block in per_file)
    data_array = np.full((len(per_file), max_epochs, len(metric_names)), np.nan)
    for i, block in enumerate(per_file):
        data_array[i, :block.shape[0], :] = block

    lookup = {name: idx for idx, name in enumerate(metric_names)}

    return data_array, lookup, files


In [None]:
# this function plots the val and train metric from the array of losses.

def plot_losses(loss_array, metrics, metric_lookup, output_dir, show_plot=False, ylim=(0, 1)):
    """
    Plot training and validation curves for each metric and save the figure.

    One panel is drawn per metric. Every run is drawn as a thin light line
    (light blue for training, peach for validation) and the NaN-aware mean over
    runs is drawn on top as a thick line. The training mean is drawn even when
    no validation column exists for a metric.

    Args:
        loss_array (np.ndarray): 3D array of shape (num_files, num_epochs, num_metrics).
        metrics (list): Metric names to plot (without the ``val_`` prefix).
        metric_lookup (dict): Maps column names to their index along the last
            axis of ``loss_array``.
        output_dir (str): Directory in which ``losses_plot.png`` is written.
        show_plot (bool): Whether to display the figure interactively.
        ylim (tuple or None): ``(low, high)`` y-axis range applied to every
            panel. Pass ``None`` to let matplotlib autoscale, which is useful
            for metrics that are not bounded by 1 (e.g. loss or Hausdorff
            distance). Defaults to ``(0, 1)``.
    """
    epoch_axis = range(loss_array.shape[1])
    num_metrics = len(metrics)

    # One wide row with a panel per metric.
    plt.figure(figsize=(20, 2.5))

    for i, metric in enumerate(metrics):
        plt.subplot(1, num_metrics, i + 1)

        # Individual training runs
        train_runs = loss_array[:, :, metric_lookup[metric]]
        for run in train_runs:
            plt.plot(epoch_axis, run, color='lightblue', linewidth=1)

        # Individual validation runs (only if the column is known)
        val_metric = 'val_' + metric
        val_runs = None
        if val_metric in metric_lookup:
            val_runs = loss_array[:, :, metric_lookup[val_metric]]
            for run in val_runs:
                plt.plot(epoch_axis, run, color='peachpuff', linewidth=1)

        # Means are drawn last so they sit on top of the individual runs.
        plt.plot(epoch_axis, np.nanmean(train_runs, axis=0),
                 color='tab:blue', label='train', linewidth=2)
        if val_runs is not None:
            plt.plot(epoch_axis, np.nanmean(val_runs, axis=0),
                     color='tab:orange', label='val', linewidth=2)

        plt.title(metric)
        if ylim is not None:
            plt.ylim(*ylim)
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.grid(True)
        plt.gca().set_aspect('auto')  # let each panel fill its subplot area

    plt.tight_layout()
    output_path = os.path.join(output_dir, 'losses_plot.png')
    plt.savefig(output_path)
    if show_plot:
        plt.show()
    plt.close()
    print(f"Plot saved to {output_path}")

In [None]:
# Load the training histories of the U-Net and Swin Transformer runs (MACS and PS)
# and plot the curves of every model/dataset combination.

unet_macs, unet_macs_metric_lookup, unet_macs_file_names = read_metrics_as_array(
    unet_macs_loss_dir, metrics)
swin_macs, swin_macs_metric_lookup, swin_macs_file_names = read_metrics_as_array(
    swin_macs_loss_dir, metrics)
unet_ps, unet_ps_metric_lookup, unet_ps_file_names = read_metrics_as_array(
    unet_ps_loss_dir, metrics)
swin_ps, swin_ps_metric_lookup, swin_ps_file_names = read_metrics_as_array(
    swin_ps_loss_dir, metrics)

# (runs, lookup, folder the figure is written to); the PS figures go next to the PS logs.
for _runs, _lookup, _target_dir in (
    (unet_macs, unet_macs_metric_lookup, unet_macs_output_dir),
    (swin_macs, swin_macs_metric_lookup, swin_macs_output_dir),
    (unet_ps, unet_ps_metric_lookup, unet_ps_loss_dir),
    (swin_ps, swin_ps_metric_lookup, swin_ps_loss_dir),
):
    plot_losses(_runs, metrics, _lookup, _target_dir, show_plot=True)

In [None]:
# next we will define a function to get the metric values across epochs at the point of the lowest val_loss for each file in the data array.
# we also define a function to help interpret the Bayes factor values.

def get_best_metric(data_array, metric_lookup, metric, selection_metric=None):
    """
    Return, for every run, the value of ``metric`` at that run's best epoch.

    The best epoch of a run is the epoch with the lowest (NaN-ignoring) value of
    the selection metric. By default the validation loss (``'val_loss'``) is
    used; if it is not in ``metric_lookup``, or a run has no valid validation
    loss values, the training loss (``'loss'``) is used for that run instead.

    Args:
        data_array (np.ndarray): 3D array of shape (files, epochs, metrics).
        metric_lookup (dict): Maps column names to their index along the last axis.
        metric (str): Name of the column whose value is reported.
        selection_metric (str or None): Column used to pick the best epoch
            (lower is better). ``None`` selects ``'val_loss'`` with a fallback
            to ``'loss'`` as described above.
    Returns:
        np.ndarray: One value of ``metric`` per run (first axis of ``data_array``).
    Raises:
        KeyError: If ``metric`` or the selection column is not in ``metric_lookup``.
        ValueError: If a run has only NaN values in every candidate selection column.
    """
    if selection_metric is not None:
        candidates = [selection_metric]
    else:
        candidates = [name for name in ('val_loss', 'loss') if name in metric_lookup]
        if not candidates:
            raise KeyError('loss')

    candidate_columns = [
        np.asarray(data_array[:, :, metric_lookup[name]], dtype=float)
        for name in candidates
    ]
    values = data_array[:, :, metric_lookup[metric]]

    best_values = []
    for run_idx in range(data_array.shape[0]):
        # Use the first candidate column that holds at least one real number.
        run_selection = next(
            (col[run_idx] for col in candidate_columns
             if not np.all(np.isnan(col[run_idx]))),
            None,
        )
        if run_selection is None:
            raise ValueError(
                f"Run {run_idx} has no valid values for {candidates}; "
                "cannot determine its best epoch."
            )

        best_epoch = np.nanargmin(run_selection)
        best_values.append(values[run_idx, best_epoch])

    return np.array(best_values)

def interpret_bayes_factor(bf):
    """Map a Bayes factor (expected to be > 1) to a Jeffreys-style verbal label."""
    # (exclusive upper bound, label) pairs in ascending order
    evidence_scale = (
        (3, "anecdotal"),
        (10, "moderate"),
        (30, "strong"),
        (100, "very strong"),
    )
    for upper_bound, label in evidence_scale:
        if bf < upper_bound:
            return label
    return "extreme"

## BEST Bayesian T-test

The **BEST (Bayesian Estimation Supersedes the t-Test)** approach offers a modern, more informative alternative to the traditional **t-test**. While the t-test provides a p-value to assess whether two group means are significantly different, it relies heavily on assumptions like:

- Normality of data
- Equal variances
- A fixed significance threshold (e.g., *p* < 0.05)

Moreover, the t-test doesn't convey the size or uncertainty of the effect in an intuitive way.

In contrast, **BEST** uses Bayesian methods to estimate the full **posterior distribution** of the group means and their difference. This approach provides a richer understanding, including:

- How large the difference might be
- How uncertain we are about that difference
- The probability that one group is greater than the other

BEST is also more robust to common issues like unequal variances and outliers.

In short, **BEST supersedes the t-test** by delivering **more nuanced, probabilistic insights** into group comparisons, rather than a single binary decision.

In [None]:
def BEST(combined_array, group_one, group_two, metric, group_one_label='unet', group_two_label='swin', plot=True,
         sigma_low=0.1, sigma_high=10, random_seed=None):
    """
    Bayesian estimation of the difference between two groups (BEST, Kruschke 2013).

    Both groups are modelled with a Student-t likelihood that has a group-specific
    mean and standard deviation and a shared degrees-of-freedom parameter. After
    sampling, the function prints a posterior summary and a Savage-Dickey Bayes
    factor for the hypothesis "difference of means = 0", and optionally shows
    posterior and forest plots.

    Args:
        combined_array (pd.DataFrame): Must contain a column named ``metric`` with
            the pooled values of both groups; its mean and standard deviation set
            the priors on the group means.
        group_one (np.ndarray): Observed values of group 1.
        group_two (np.ndarray): Observed values of group 2.
        metric (str): Name of the metric column. It also determines whether lower
            values count as better (loss and Hausdorff distance, with or without
            the ``val_`` prefix).
        group_one_label (str): Name used for group 1 in variable names and output.
        group_two_label (str): Name used for group 2 in variable names and output.
        plot (bool): Whether to show the ArviZ posterior and forest plots.
        sigma_low (float): Lower bound of the Uniform prior on each group's std.
        sigma_high (float): Upper bound of the Uniform prior on each group's std.
        random_seed (int or None): Seed forwarded to ``pm.sample`` for
            reproducible draws.
    Returns:
        None
    Raises:
        ValueError: If the pooled metric values have zero or non-finite spread,
            in which case no valid prior scale can be derived.
    """
    # Metrics for which a smaller value means a better model.
    lower_is_better_metrics = ('val_loss', 'loss', 'val_Hausdorff_distance', 'Hausdorff_distance')

    metric_values = combined_array[metric].values

    # Weakly informative prior location/scale for the group means, taken from the pooled data.
    mu_m = metric_values.mean()
    mu_s = metric_values.std() * 2
    if not np.isfinite(mu_s) or mu_s <= 0:
        raise ValueError(
            f"Pooled values of '{metric}' have zero or non-finite spread; "
            "cannot derive a prior scale for the group means."
        )

    # NOTE(review): the default bounds [0.1, 10] are fixed and not scaled to the data.
    # For metrics bounded by 1 with a much smaller spread the lower bound can dominate
    # the posterior of the stds; pass tighter bounds via sigma_low / sigma_high if needed.

    with pm.Model() as model:

        # Priors for the group means: same Normal(mu_m, mu_s) prior for both groups.
        group1_mean = pm.Normal(f'{group_one_label}_mean', mu=mu_m, sigma=mu_s)
        group2_mean = pm.Normal(f'{group_two_label}_mean', mu=mu_m, sigma=mu_s)

        # Independent Uniform priors on the group stds (allows unequal variances).
        group1_std = pm.Uniform(f'{group_one_label}_std', lower=sigma_low, upper=sigma_high)
        group2_std = pm.Uniform(f'{group_two_label}_std', lower=sigma_low, upper=sigma_high)

        # Shared degrees of freedom: Exponential prior on (nu - 1), shifted so nu >= 1.
        # log10(nu) is exposed as well for diagnostics.
        nu_minus_one = pm.Exponential('nu_minus_one', lam=1/29)
        nu = pm.Deterministic('nu', nu_minus_one + 1)
        pm.Deterministic('nu_log10', np.log10(nu))

        # The likelihood is parameterised by precision (lam = sigma ** -2).
        lambda_group1 = group1_std**-2
        lambda_group2 = group2_std**-2

        # Student-t observation model per group (heavy tails make it robust to outliers).
        pm.StudentT(f'{group_one_label}_obs', mu=group1_mean, lam=lambda_group1, nu=nu, observed=group_one)
        pm.StudentT(f'{group_two_label}_obs', mu=group2_mean, lam=lambda_group2, nu=nu, observed=group_two)

        # Derived quantities: difference of means/stds and a Cohen's-d style effect size.
        diff_of_means = pm.Deterministic('diff_of_means', group1_mean - group2_mean)
        pm.Deterministic('diff_of_stds', group1_std - group2_std)
        pm.Deterministic('effect_size', diff_of_means / np.sqrt((group1_std**2 + group2_std**2) / 2))

        # Posterior sampling
        idata = pm.sample(tune=1000, draws=2000, chains=4, target_accept=0.95,
                          return_inferencedata=True, random_seed=random_seed)

    mean_names = [f'{group_one_label}_mean', f'{group_two_label}_mean']
    std_names = [f'{group_one_label}_std', f'{group_two_label}_std']

    if plot:
        print('\n---- Posterior for the means and stds ----')

        az.plot_posterior(idata, var_names=mean_names + std_names + ['nu_log10', 'nu'])
        plt.show()

        print('\n---- Posterior for the differences and effect size ----')

        az.plot_posterior(idata, var_names=['diff_of_means', 'diff_of_stds', 'effect_size'], ref_val=0)
        plt.show()

        print('\n---- Forest for means ----')

        az.plot_forest(idata, var_names=mean_names)
        plt.show()

        print('\n---- Forests for stds and nu ----')

        az.plot_forest(idata, var_names=std_names + ['nu'])
        plt.show()

    print('\n---- Model summary ----')

    summary = az.summary(idata, var_names=mean_names + ['diff_of_means', 'diff_of_stds', 'effect_size'])
    print(summary)

    print('\n---- Savage-Dickey Bayes Factor ----\n')

    # Posterior density of the mean difference at 0, estimated with a Gaussian KDE.
    diff_samples = idata.posterior['diff_of_means'].values.flatten()
    posterior_density_at_zero = gaussian_kde(diff_samples).evaluate(0)[0]

    # Prior density at 0: the difference of two independent Normal(mu_m, mu_s)
    # variables is Normal(0, sqrt(2) * mu_s).
    prior_sd_diff = np.sqrt(2) * mu_s
    prior_density_at_zero = norm.pdf(0, loc=0, scale=prior_sd_diff)

    # Savage-Dickey density ratio
    BF_01 = posterior_density_at_zero / prior_density_at_zero   # H0 over H1
    BF_10 = 1 / BF_01                                           # H1 over H0

    metric_is_lower_better = metric in lower_is_better_metrics
    mean_diff = diff_samples.mean()   # mean of group one minus mean of group two

    print(f"p(δ=0)        : {prior_density_at_zero:.4g}")
    print(f"p(δ=0 | data) : {posterior_density_at_zero:.4g}")

    if BF_10 > 1:
        # Data support a difference; the winner depends on the direction of the metric.
        group_one_wins = (mean_diff < 0) if metric_is_lower_better else (mean_diff > 0)
        if group_one_wins:
            winner, loser = group_one_label, group_two_label
        else:
            winner, loser = group_two_label, group_one_label

        label = interpret_bayes_factor(BF_10)
        print(f"Evidence for {winner} outperforming {loser}: "
            f"BF₁₀ = {BF_10:.4g}  ({label})\n")
    else:
        # Data support the null hypothesis of no difference.
        label = interpret_bayes_factor(BF_01)
        print(f"Evidence for no difference (H₀): "
            f"BF₀₁ = {BF_01:.4g}  ({label})")

In [None]:
# Run BEST for U-Net vs. Swin Transformer on the MACS runs and show the output
# of every validation metric in its own tab.

tabs = []          # one Output widget per metric
tab_titles = []    # header text of each tab

for metric in metrics:

    full_metric = metric if metric.startswith('val_') else 'val_' + metric
    out = widgets.Output()

    # Everything printed or plotted inside this context is captured by the tab.
    with out:
        clear_output(wait=True)     # keeps the tab clean on reruns

        print(f"### Processing metric: {full_metric}\n")

        unet_macs_best = get_best_metric(unet_macs, unet_macs_metric_lookup, full_metric)
        swin_macs_best = get_best_metric(swin_macs, swin_macs_metric_lookup, full_metric)

        # Pooled table of both groups; BEST derives its priors from it.
        comparison_df = pd.concat(
            [
                pd.DataFrame({full_metric: unet_macs_best, 'group': 'U-Net | Aerial'}),
                pd.DataFrame({full_metric: swin_macs_best, 'group': 'Swin U-Net | Aerial'}),
            ],
            ignore_index=True,
        )

        BEST(comparison_df, unet_macs_best, swin_macs_best, full_metric, plot=True)

    tabs.append(out)
    tab_titles.append(full_metric)

tab_widget = widgets.Tab(children=tabs)

for i, title in enumerate(tab_titles):
    tab_widget.set_title(i, title)

display(tab_widget)


## Bayesian ANOVA

**Bayesian ANOVA** offers a more flexible and informative alternative to the traditional **frequentist ANOVA**. While classical ANOVA tests whether there are any statistically significant differences between group means, it comes with several limitations:

- It provides only a *p-value*, not the size or uncertainty of effects.
- It assumes normality, homogeneity of variance, and fixed effects.
- It gives no direct probability for hypotheses—just whether the null is rejected or not.

In contrast, **Bayesian ANOVA** uses probability distributions to directly model uncertainty and effect sizes. Instead of asking whether group means are different *in general*, it estimates:

- The **posterior distribution** of each group mean
- The **probability** of differences between groups
- The **credible intervals** that reflect uncertainty in estimates

Bayesian ANOVA can also naturally handle more complex models (e.g., hierarchical structures, unequal variances) and allows for **model comparison using Bayes Factors**, offering a principled way to weigh evidence for competing hypotheses.

In short, **Bayesian ANOVA** goes beyond just testing for significance—it provides a **deeper, probabilistic understanding** of group differences, effect sizes, and model credibility.

In [None]:
def fit_bayesian_anova(df, metric, lower_is_better=False, robust=True, hierarchical=True, tune=1000, draws=2000, chains=4, target_accept=0.95):
    """
    Fit a Bayesian one-way ANOVA model with optional partial pooling and heavy tails.

    Args:
        df (pd.DataFrame): Must contain the column ``metric`` (outcome) and a
            column ``"group"`` (group label of every row).
        metric (str): Name of the outcome column.
        lower_is_better (bool): Stored as metadata for downstream utilities.
        robust (bool): If True use a Student-t likelihood with an estimated
            degrees-of-freedom parameter; if False use a Normal likelihood.
        hierarchical (bool): If True the group means share a common Normal
            distribution (partial pooling); otherwise they get independent priors.
        tune (int): Number of tuning steps per chain.
        draws (int): Number of posterior draws per chain.
        chains (int): Number of chains.
        target_accept (float): Target acceptance rate passed to ``pm.sample``.
    Returns:
        arviz.InferenceData: Posterior samples including ``log_likelihood``.
        ``attrs`` holds ``"metric"``, ``"lower_is_better"`` and ``"groups"``
        (group names in the order of the ``mu`` / ``sigma`` dimensions).
    Raises:
        ValueError: If any row has a missing group label.
    """
    # 1) Prepare data
    y = df[metric].values
    # Drop unused categories so the number of modelled groups matches the stored names.
    groups = df["group"].astype("category").cat.remove_unused_categories()
    g_idx = groups.cat.codes.values  # convert categories into 0…K-1
    if (g_idx < 0).any():
        # A code of -1 marks a missing label and would silently index the last group.
        raise ValueError("Column 'group' contains missing values.")
    K = len(groups.cat.categories)

    # 2) Set priors based on observed data scale
    mu_m, mu_s = y.mean(), y.std() * 2
    sigma_low, sigma_high = 1e-1, 10  # bound sigma to [0.1, 10]

    with pm.Model() as model:
        # --- Priors on group means ---------------------------------------
        if hierarchical:
            # group means share a global distribution (partial pooling)
            mu_grand = pm.Normal("mu_grand", mu=mu_m, sigma=mu_s)
            tau = pm.HalfNormal("tau", sigma=mu_s)
            mu = pm.Normal("mu", mu=mu_grand, sigma=tau, shape=K)
        else:
            # independent priors for each group mean
            mu = pm.Normal("mu", mu=mu_m, sigma=mu_s, shape=K)

        # --- Priors on group standard deviations -------------------------
        sigma = pm.Uniform("sigma", lower=sigma_low, upper=sigma_high, shape=K)

        # --- Likelihood ---------------------------------------------------
        if robust:
            # Student-t with estimated degrees of freedom (nu >= 1) for robustness.
            nu = pm.Exponential("nu_minus_1", 1/29) + 1
            pm.StudentT("obs", mu=mu[g_idx], sigma=sigma[g_idx], nu=nu, observed=y)
        else:
            # A Student-t with nu = inf has no finite log-density in practice,
            # so the Normal limit is used directly.
            pm.Normal("obs", mu=mu[g_idx], sigma=sigma[g_idx], observed=y)

        # 3) Sample from the posterior, including log_likelihood for LOO/WAIC
        idata = pm.sample(
            tune=tune,
            draws=draws,
            chains=chains,
            target_accept=target_accept,
            return_inferencedata=True,
            idata_kwargs={"log_likelihood": True}
        )

    # 4) Store metadata for downstream utilities
    idata.attrs["metric"] = metric
    idata.attrs["lower_is_better"] = lower_is_better
    idata.attrs["groups"] = groups.cat.categories.tolist()
    return idata


def fit_bayesian_null(df, metric, robust=True, tune=1000, draws=2000, chains=4, target_accept=0.95):
    """
    Fit a Bayesian null model in which all observations share one mean and one std.

    The model serves as the baseline against which the ANOVA model is compared.

    Args:
        df (pd.DataFrame): Must contain the column ``metric``.
        metric (str): Name of the outcome column.
        robust (bool): If True use a Student-t likelihood with an estimated
            degrees-of-freedom parameter; if False use a Normal likelihood.
        tune (int): Number of tuning steps per chain.
        draws (int): Number of posterior draws per chain.
        chains (int): Number of chains.
        target_accept (float): Target acceptance rate passed to ``pm.sample``.
    Returns:
        arviz.InferenceData: Posterior samples including ``log_likelihood``.
    """
    y = df[metric].values
    # Same data-scaled priors as the ANOVA model so the two models are comparable.
    mu_m, mu_s = y.mean(), y.std() * 2
    sigma_low, sigma_high = 1e-1, 10

    with pm.Model() as null_model:
        mu = pm.Normal("mu", mu=mu_m, sigma=mu_s)  # single mean
        sigma = pm.Uniform("sigma", lower=sigma_low, upper=sigma_high)

        if robust:
            nu = pm.Exponential("nu_minus_1", 1/29) + 1
            pm.StudentT("obs", mu=mu, sigma=sigma, nu=nu, observed=y)
        else:
            # A Student-t with nu = inf has no finite log-density in practice,
            # so the Normal limit is used directly.
            pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

        # include log_likelihood for model comparison
        idata_null = pm.sample(
            tune=tune,
            draws=draws,
            chains=chains,
            target_accept=target_accept,
            return_inferencedata=True,
            idata_kwargs={"log_likelihood": True}
        )
    return idata_null


def compute_bayes_factor(idata_alt, idata_null):
    """
    Approximate BF_10 by a pseudo-Bayes factor based on predictive performance.

    The expected log pointwise predictive density (elpd) of both models is
    estimated with PSIS-LOO; if that raises a TypeError, ValueError or
    AttributeError, WAIC is tried instead. Both estimators need the
    ``log_likelihood`` group in the InferenceData.

    The elpd is requested on the log scale, so the pseudo-Bayes factor is

        BF_10 ≈ exp(elpd_alt - elpd_null)

    without an extra factor 1/2 (that factor only applies to information
    criteria on the deviance scale, i.e. -2 * elpd). This is a predictive
    approximation, not a marginal-likelihood Bayes factor.

    Args:
        idata_alt (arviz.InferenceData): Fitted alternative model.
        idata_null (arviz.InferenceData): Fitted null model.
    Returns:
        float: Approximate Bayes factor in favour of the alternative model.
    """
    try:
        # scale="log" is set explicitly so the result does not depend on rcParams.
        elpd_alt = az.loo(idata_alt, pointwise=False, scale="log").elpd_loo
        elpd_null = az.loo(idata_null, pointwise=False, scale="log").elpd_loo
    except (TypeError, ValueError, AttributeError):
        # Fall back to WAIC if LOO fails or the result lacks the expected field.
        elpd_alt = az.waic(idata_alt, pointwise=False, scale="log").elpd_waic
        elpd_null = az.waic(idata_null, pointwise=False, scale="log").elpd_waic

    # The elpd difference is a log predictive-density ratio.
    bf_10 = np.exp(elpd_alt - elpd_null)
    return bf_10


def prob_each_is_best(idata):
    """
    Posterior probability that each group has the best mean.

    For every posterior draw the group with the lowest mean (if the stored
    ``lower_is_better`` flag is set) or the highest mean (otherwise) is the
    winner; the returned probabilities are the share of draws each group wins.

    Returns:
        pd.Series: Named ``"p(best)"``, indexed by the stored group names.
    """
    lower_wins = idata.attrs["lower_is_better"]

    # Collapse chains and draws: resulting array has shape (groups, samples).
    draws_by_group = idata.posterior["mu"].stack(sample=("chain", "draw")).values

    pick_winner = np.argmin if lower_wins else np.argmax
    winners = pick_winner(draws_by_group, axis=0)

    # Count the wins of every group, including groups that never win.
    num_groups = draws_by_group.shape[0]
    win_counts = np.bincount(winners, minlength=num_groups)

    return pd.Series(win_counts / winners.size, index=idata.attrs["groups"], name="p(best)")


def _hdi_bounds(samples, prob=0.95):
    """Return (low, high) of the narrowest interval holding ``prob`` of the 1-D samples."""
    ordered = np.sort(np.asarray(samples, dtype=float))
    n = ordered.size
    if n == 0:
        return np.nan, np.nan
    # Number of steps between the two end points of a candidate interval.
    span = min(int(np.floor(prob * n)), n - 1)
    widths = ordered[span:] - ordered[:n - span]
    start = int(np.argmin(widths))
    return ordered[start], ordered[start + span]


def pairwise_contrasts(idata, rope=None):
    """
    Summarize all pairwise differences between the posterior group means.

    For each pair A vs. B (A listed before B in the stored group names):
      - mean_diff: posterior mean of mu_A - mu_B
      - hdi_low / hdi_high: 95% highest-density interval of that difference,
        i.e. the narrowest interval containing 95% of the draws
      - p_A_gt_B: share of posterior draws with mu_A > mu_B
      - p_in_rope (only if ``rope`` is given): share of draws strictly inside
        the Region Of Practical Equivalence

    Args:
        idata: InferenceData with a posterior variable ``"mu"`` and the group
            names stored in ``attrs["groups"]``.
        rope (sequence or None): ``(low, high)`` bounds of the ROPE.
    Returns:
        pd.DataFrame: One row per pair of groups.
    """
    # Collapse chains and draws: resulting array has shape (groups, samples).
    means = idata.posterior["mu"].stack(sample=("chain", "draw")).values
    K, _ = means.shape
    names = idata.attrs["groups"]
    rows = []

    # Loop over each unique pair
    for i in range(K - 1):
        for j in range(i + 1, K):
            diff = means[i] - means[j]
            hdi_low, hdi_high = _hdi_bounds(diff, prob=0.95)
            row = dict(
                A=names[i],
                B=names[j],
                mean_diff=diff.mean(),
                hdi_low=hdi_low,
                hdi_high=hdi_high,
                p_A_gt_B=(diff > 0).mean(),
            )
            if rope is not None:
                # Fraction of draws within the ROPE
                row["p_in_rope"] = ((rope[0] < diff) & (diff < rope[1])).mean()
            rows.append(row)
    return pd.DataFrame(rows)

In [None]:
# We build one tab per metric to explore group performance and run our ANOVA.

# Data sources of the ANOVA: every group name maps to its (data_array, metric_lookup) pair.
# Add further groups by inserting more entries.
group_data = {}
group_data["U-Net | Aerial"] = (unet_macs, unet_macs_metric_lookup)
group_data["U-Net | PS"] = (unet_ps, unet_ps_metric_lookup)
group_data["Swin U-Net | Aerial"] = (swin_macs, swin_macs_metric_lookup)
group_data["Swin U-Net | PS"] = (swin_ps, swin_ps_metric_lookup)

# Metrics for which a lower value is better (losses and distances),
# listed with and without the validation prefix.
minimal_metrics = {
    "loss",
    "val_loss",
    "Hausdorff_distance",
    "val_Hausdorff_distance",
}

# Helper to extract best-per-run for a given metric

def best_per_group(metric_name):
    """Return {group name: array with the best-epoch value of ``metric_name`` per run}."""
    # get_best_metric and group_data are taken from the notebook namespace.
    return {
        group_name: get_best_metric(runs, lookup, metric_name)
        for group_name, (runs, lookup) in group_data.items()
    }

# Build interactive tabs for each metric.
# A dedicated name is used so the notebook-level `metrics` list, which the loading
# and BEST cells rely on, is not overwritten when this cell runs.
anova_metrics = ["loss", "accuracy", "Hausdorff_distance"]  # customize your list
tabs, titles = [], []
for metric in anova_metrics:
    full_metric = f"val_{metric}"
    box = widgets.Output()
    with box:
        clear_output(wait=True)
        print(f"### Processing metric: {full_metric}\n")

        # Only metrics that were loaded for every group can be analysed; anything
        # else would fail with a KeyError when looking up the column.
        missing_in = [name for name, (_, lookup) in group_data.items() if full_metric not in lookup]
        if missing_in:
            print(f"Skipping {full_metric}: it was not loaded for {', '.join(missing_in)}. "
                  "Add the metric to the list used when reading the logs to include it.")
        else:
            # 1) Box-plot of best epoch values
            best_vals = best_per_group(full_metric)
            plt.boxplot(list(best_vals.values()), labels=list(best_vals.keys()), showfliers=True)
            plt.title(full_metric)
            plt.show()

            # 2) Prepare DataFrame for analysis (one row per run)
            df = pd.concat([
                pd.DataFrame({full_metric: v, "group": k})
                for k, v in best_vals.items()
            ], ignore_index=True)
            display(df)

            # 3) Fit the ANOVA model and the single-mean null model
            idata_alt = fit_bayesian_anova(
                df, full_metric,
                lower_is_better=full_metric in minimal_metrics,
                hierarchical=True, robust=True
            )
            idata_null = fit_bayesian_null(df, full_metric, robust=True)

            # 4) Summaries
            print("Posterior probability each group is the best:")
            display(prob_each_is_best(idata_alt).sort_values(ascending=False))

            print("\nPairwise contrasts (95% HDI and p):")
            display(pairwise_contrasts(idata_alt, rope=[-0.005, 0.005]))

            bf = compute_bayes_factor(idata_alt, idata_null)
            print(f"\nApproximate Bayes Factor BF_10: {bf:.2f}")

            # 5) Forest plot; hdi_prob is set explicitly so it matches the title.
            az.plot_forest(idata_alt, var_names="mu", hdi_prob=0.95)
            plt.title(f"{full_metric} – group means with 95% HDI")
            plt.show()

    tabs.append(box)
    titles.append(full_metric)

# Assemble and display the Tab widget
tab_widget = widgets.Tab(children=tabs)
for i, t in enumerate(titles):
    tab_widget.set_title(i, t)
display(tab_widget)