# Simulation Analysis

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import re
import warnings

from scipy import stats
from sklearn.metrics import r2_score, mean_squared_error
import pingouin as pg

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

# Suppress FutureWarning messages whose originating module name matches
# "pingouin", so the printed statistics below are not buried in warnings.
warnings.filterwarnings("ignore", category=FutureWarning, module="pingouin")
# Use seaborn's "whitegrid" theme for the figures created in this notebook.
sns.set_theme(style="whitegrid")

## Configuration

In [None]:
# --- Paths ---
# Single-task simulation output, with one sub-folder per task.
SINGLE_TASKS_DIR = Path("..", "data", "model", "single_tasks")
TYPING_SIM_DIR = SINGLE_TASKS_DIR.joinpath("typing_simulation")
NBACK_SIM_DIR = SINGLE_TASKS_DIR.joinpath("nback_simulation")

# Preprocessed multitasking data for simulated and human participants.
SIM_DATA_DIR = Path("..", "data", "model", "simulated_participants_preprocessed")
HUMAN_DATA_DIR = Path("..", "data", "experiment", "data_preprocessed")

# --- Color palette ---
# Named hex colours shared by the figures in this notebook.
COLORS = dict(
    primary_dark="#000000",
    teal="#4A8A94",
    pink="#E5A3B8",
    white="#FFFFFF",
    teal_dark="#2E5A62",
    pink_dark="#D47A96",
)

# One base colour per interruption condition.
CONDITION_COLORS = dict(
    sequential=COLORS["teal"],
    interrupted=COLORS["pink"],
)

# Text, tick and axis-edge elements use the dark colour; figure and axes
# backgrounds use white.
plt.rcParams.update({
    **dict.fromkeys(
        ("text.color", "axes.labelcolor", "xtick.color", "ytick.color", "axes.edgecolor"),
        COLORS["primary_dark"],
    ),
    **dict.fromkeys(("figure.facecolor", "axes.facecolor"), COLORS["white"]),
})

# --- Analysis constants ---
# Display order of the four interruption x difficulty cells.
CONDITION_ORDER = [
    f"{mode}-{level}" for mode in ("Seq", "Int") for level in ("Easy", "Hard")
]
ACC_THRESH = 0.85      # minimum N-back accuracy for a trial to be kept
OUTLIER_THRESH = 3.5   # absolute robust-z cutoff used when flagging outlier trials
ALPHA = 0.05           # significance level for the ANOVA / post-hoc output
SUBJECT = "participant_id"
WITHIN = ["COND_interruption_condition", "COND_nback_level"]
# Participant x condition cell: subject column followed by the within factors.
GROUP_COLS = [SUBJECT, *WITHIN]

## Helper Functions

In [None]:
def coefficient_of_variation(x):
    """Return the sample SD (ddof=1) of *x* divided by its mean.

    A mean of exactly zero yields NaN instead of a division by zero.
    """
    center = np.mean(x)
    if center == 0:
        return np.nan
    spread = np.std(x, ddof=1)
    return spread / center


def create_condition_label(row):
    """Build the short condition label (e.g. 'Seq-Easy') for one row.

    The prefix is 'Int' when the text 'interrupted' occurs in the row's
    interruption condition and 'Seq' otherwise; the suffix is 'Easy' when the
    N-back level equals 1 and 'Hard' for any other value.
    """
    is_interrupted = "interrupted" in str(row["COND_interruption_condition"])
    is_easy = row["COND_nback_level"] == 1
    prefix = "Int" if is_interrupted else "Seq"
    suffix = "Easy" if is_easy else "Hard"
    return "-".join((prefix, suffix))


def robust_z(series: pd.Series) -> pd.Series:
    """Return the MAD-based modified z-score of every value in *series*.

    The score is ``0.6745 * (x - median) / MAD``, where MAD is the raw
    (unscaled) median absolute deviation from the median.

    The previous implementation asked SciPy for a MAD that was already
    normal-scaled (``scale="normal"`` divides the MAD by ~0.6745) and then
    multiplied by 0.6745 again, which shrank every score by a further factor
    of ~0.6745 and made any cutoff correspondingly more lenient.

    Missing values are skipped when computing the median and the MAD, so a
    single NaN no longer turns the scores of the whole series into NaN; the
    NaN entries themselves stay NaN in the result.

    Args:
        series: Numeric values to score.

    Returns:
        A float Series on the same index as *series*. When the MAD is zero
        (or cannot be computed, e.g. an empty or all-NaN input) every score
        is 0.0, so nothing is treated as an outlier.
    """
    med = series.median()
    deviations = series - med
    mad = deviations.abs().median()
    # `not mad > 0` is true for both a zero MAD and a NaN MAD.
    if not mad > 0:
        return pd.Series(0.0, index=series.index)
    return 0.6745 * deviations / mad


def flag_outliers_trial(df, dv, thresh=OUTLIER_THRESH):
    """Mark trials whose robust z-score on *dv* exceeds *thresh* in magnitude.

    Scores are computed separately inside every participant x condition cell
    (the GROUP_COLS grouping). Returns a boolean Series on ``df.index`` that
    is True for flagged trials; rows not visited by the grouping stay False.
    """
    flags = pd.Series(False, index=df.index)
    for _cell_key, cell in df.groupby(GROUP_COLS):
        scores = robust_z(cell[dv])
        flags.loc[cell.index] = scores.abs() > thresh
    return flags


def calculate_fit_metrics(human_data, model_data):
    """Return ``(r_squared, rmsd)`` for model values against human values.

    Both metrics are computed with *human_data* passed as the reference
    (first argument) and *model_data* as the prediction; RMSD is the square
    root of the mean squared error.
    """
    mse = mean_squared_error(human_data, model_data)
    return r2_score(human_data, model_data), np.sqrt(mse)


def format_mean_sd(mean_val, std_val, decimals=3):
    """Render a mean and SD as the string 'M (SD)' with *decimals* places."""
    spec = f".{decimals}f"
    return "{} ({})".format(format(mean_val, spec), format(std_val, spec))


def report_rm_anova(df_long, dv):
    """Run and print a 2-factor repeated-measures ANOVA with post-hoc tests.

    The ANOVA uses the WITHIN factors and the SUBJECT column. Its table is
    printed with a ``sig`` column marking effects with p < ALPHA, followed by
    the Holm-corrected pairwise tests whose corrected p-value is below ALPHA.

    Args:
        df_long: Long-format data containing *dv*, the WITHIN columns and the
            SUBJECT column.
        dv: Name of the dependent-variable column.

    Returns:
        The unmodified ANOVA table returned by ``pg.rm_anova``.
    """
    aov = pg.rm_anova(
        data=df_long, dv=dv, within=WITHIN, subject=SUBJECT, detailed=True
    )
    nice = (
        aov.rename(columns={"Source": "Effect", "ddof1": "df1", "ddof2": "df2", "p_unc": "p"})
        .loc[:, ["Effect", "SS", "df1", "df2", "MS", "F", "p", "ng2"]]
    )
    # Decide significance on the exact p-values. Rounding first would turn
    # e.g. p = 0.04996 into 0.0500 and silently drop its marker.
    is_significant = nice["p"] < ALPHA
    nice = nice.round({"SS": 4, "MS": 4, "F": 3, "p": 4, "ng2": 3})
    nice["sig"] = np.where(is_significant, "*", "")
    print(f"\n=== Repeated-measures ANOVA on {dv} ===")
    print(nice.to_string(index=False))

    ph = pg.pairwise_tests(
        data=df_long, dv=dv, within=WITHIN, subject=SUBJECT,
        padjust="holm", parametric=True, effsize="hedges",
    )
    sig_ph = ph[ph["p_corr"] < ALPHA]
    # Report the alpha actually used for filtering (".05" for ALPHA = 0.05)
    # instead of a hard-coded value that can drift from the constant.
    alpha_label = f"{ALPHA:g}".lstrip("0")
    print(f"\nPost-hoc paired t-tests (Holm-corrected, \u03b1 = {alpha_label})")
    if sig_ph.empty:
        print("  \u2013 none survive correction.")
    else:
        keep = (
            sig_ph[["A", "B", "T", "dof", "p_corr", "hedges"]]
            .rename(columns={"A": "Cell A", "B": "Cell B", "p_corr": "p_Holm", "hedges": "g"})
            .round({"T": 3, "p_Holm": 4, "g": 3})
        )
        print(keep.to_string(index=False))
    return aov

## Single-Task Simulations

### Data Loading

In [None]:
def load_typing_data(directory: Path) -> pd.DataFrame:
    """Load all typing simulation CSVs and standardize column names.

    Every ``participant_sim_<n>_output.csv`` file in *directory* is read,
    tagged with its participant id (taken from the file name) and its columns
    are renamed to the ``OUT_*`` names used elsewhere in the notebook. A file
    that cannot be processed is reported and skipped.

    Args:
        directory: Folder containing the typing simulation CSV files.

    Returns:
        One DataFrame with the trials of all successfully loaded files.

    Raises:
        FileNotFoundError: If no file could be loaded from *directory*
            (previously this surfaced as an opaque ``pd.concat`` ValueError).
    """
    all_files = sorted(directory.glob("participant_sim_*_output.csv"))
    print(f"Found {len(all_files)} typing simulation files.")

    filename_pattern = re.compile(r"participant_(sim_\d+)_output\.csv")

    rename_map = {
        "actual_duration_s": "OUT_actual_trial_duration_sec",
        "subjective_duration_s": "OUT_time_estimate_seconds",
        "so_ratio": "OUT_time_estimation_ratio",
        "time_per_letter": "OUT_time_per_letter",
        "word_length": "OUT_target_word_length",
    }

    frames = []
    for filepath in all_files:
        match = filename_pattern.search(filepath.name)
        if not match:
            continue
        # Best-effort per file: one unreadable or malformed CSV is reported
        # and skipped instead of aborting the whole load.
        try:
            df = pd.read_csv(filepath)
            df["participant_id"] = match.group(1)
            df = df.rename(columns=rename_map)
            # Raw so_ratio is stored as a percentage (e.g. 99.1 = 0.991)
            df["OUT_time_estimation_ratio"] = df["OUT_time_estimation_ratio"] / 100.0
            # Model produces no typing errors in single-task mode
            df["OUT_typing_distance"] = 0.0
        except Exception as e:
            print(f"  Could not process {filepath.name}: {e}")
        else:
            frames.append(df)

    if not frames:
        raise FileNotFoundError(
            f"No typing simulation files could be loaded from {directory}"
        )

    combined = pd.concat(frames, ignore_index=True)
    print(f"Loaded {combined['participant_id'].nunique()} participants, {len(combined)} trials.")
    return combined


def load_nback_data(directory: Path) -> pd.DataFrame:
    """Load all N-back simulation CSVs and standardize column names.

    Every ``participant_sim_<n>_<1|2>back_output.csv`` file in *directory* is
    read, tagged with the participant id and N-back level parsed from its
    file name, and its columns are renamed to the ``OUT_*`` names used
    elsewhere in the notebook. A file that cannot be processed is reported
    and skipped.

    Args:
        directory: Folder containing the N-back simulation CSV files.

    Returns:
        One DataFrame with the trials of all successfully loaded files.

    Raises:
        FileNotFoundError: If no file could be loaded from *directory*
            (previously this surfaced as an opaque ``pd.concat`` ValueError).
    """
    all_files = sorted(directory.glob("participant_sim_*_*back_output.csv"))
    print(f"Found {len(all_files)} N-back simulation files.")

    filename_pattern = re.compile(r"participant_(sim_\d+)_(1|2)back_output\.csv")

    rename_map = {
        "actual_duration_s": "OUT_actual_trial_duration_sec",
        "subjective_duration_s": "OUT_time_estimate_seconds",
        "so_ratio": "OUT_time_estimation_ratio",
        "accuracy": "OUT_nback_accuracy",
        "hits": "OUT_nback_hits",
        "misses": "OUT_nback_misses",
        "false_alarms": "OUT_nback_false_alarms",
        "correct_rejections": "OUT_nback_correct_rejections",
    }

    frames = []
    for filepath in all_files:
        match = filename_pattern.search(filepath.name)
        if not match:
            continue
        pid, nback_level_str = match.groups()
        # Best-effort per file: one unreadable or malformed CSV is reported
        # and skipped instead of aborting the whole load.
        try:
            df = pd.read_csv(filepath)
            df["participant_id"] = pid
            # BUG FIX: The raw CSV has nback_level=2 in ALL files (including
            # 1-back). The filename is the ground truth, so overwrite here.
            df["nback_level"] = int(nback_level_str)
            df = df.rename(columns=rename_map)
            # Raw so_ratio is stored as a percentage
            df["OUT_time_estimation_ratio"] = df["OUT_time_estimation_ratio"] / 100.0
        except Exception as e:
            print(f"  Could not process {filepath.name}: {e}")
        else:
            frames.append(df)

    if not frames:
        raise FileNotFoundError(
            f"No N-back simulation files could be loaded from {directory}"
        )

    combined = pd.concat(frames, ignore_index=True)
    print(f"Loaded {combined['participant_id'].nunique()} participants, {len(combined)} trials.")
    return combined

In [None]:
# Load both single-task simulation data sets; these frames are reused by the
# summary and plotting cells that follow.
typing_sim_df = load_typing_data(TYPING_SIM_DIR)
nback_sim_df = load_nback_data(NBACK_SIM_DIR)

# Quick look at the typing data: dimensions and first rows.
print("\n--- Typing simulation ---")
print(f"Shape: {typing_sim_df.shape}")
display(typing_sim_df.head(3))

# Quick look at the N-back data: dimensions and first rows.
print("\n--- N-back simulation ---")
print(f"Shape: {nback_sim_df.shape}")
display(nback_sim_df.head(3))

# Verify repeated-measures structure: each participant should have
# equal counts at 1-back and 2-back
# (rows = participants, columns = N-back levels; a missing combination is
# shown as 0).
print("\nTrials per participant per N-back level:")
print(
    nback_sim_df.groupby("participant_id")["nback_level"]
    .value_counts()
    .unstack()
    .fillna(0)
    .astype(int)
)

### Summary Statistics

In [None]:
def create_nback_summary_table(nback_df: pd.DataFrame) -> pd.DataFrame:
    """Summary table comparing N-back levels on performance and timing.

    Trials are first averaged per participant and N-back level (plus the
    coefficient of variation of the time estimates); the table then reports
    the mean and SD of those participant-level values for each level.

    The rows are taken from the N-back levels actually present in *nback_df*.
    The earlier version hard-coded the index ``[1, 2]`` and filled it
    positionally, which mislabelled rows or raised when the data held a
    different set of levels.

    Args:
        nback_df: Trial-level N-back data with ``participant_id``,
            ``nback_level`` and the ``OUT_*`` metric columns used below.

    Returns:
        DataFrame indexed by "N-back Level" with one "<label> [M (SD)]"
        string column per metric.
    """
    metric_cols = ["OUT_nback_accuracy", "OUT_time_estimate_seconds", "OUT_time_estimation_ratio"]

    participant_means = (
        nback_df.groupby(["participant_id", "nback_level"])[metric_cols].mean()
    )

    participant_cvs = (
        nback_df.groupby(["participant_id", "nback_level"])["OUT_time_estimate_seconds"]
        .apply(coefficient_of_variation)
        .rename("CV")
    )

    participant_summary = pd.merge(participant_means, participant_cvs, on=["participant_id", "nback_level"])
    summary_stats = participant_summary.groupby("nback_level").agg(["mean", "std"])

    labels = {
        "N-back Accuracy": "OUT_nback_accuracy",
        "Time Estimate (s)": "OUT_time_estimate_seconds",
        "S/O Ratio": "OUT_time_estimation_ratio",
        "CV of Time Estimates": "CV",
    }

    # Reuse the observed levels as the row index so that the positionally
    # assigned strings below always line up with their level.
    table = pd.DataFrame(index=summary_stats.index.rename("N-back Level"))
    for label, col in labels.items():
        means = summary_stats[(col, "mean")]
        stds = summary_stats[(col, "std")]
        table[f"{label} [M (SD)]"] = [format_mean_sd(m, s) for m, s in zip(means, stds)]

    return table


def create_typing_summary_table(typing_df: pd.DataFrame) -> pd.DataFrame:
    """Build a one-column 'M (SD)' table for the typing simulation.

    Metrics are averaged per participant first (together with the coefficient
    of variation of each participant's time estimates); the table reports the
    mean and SD of those participant-level values.
    """
    metric_cols = ["OUT_time_per_letter", "OUT_time_estimate_seconds", "OUT_time_estimation_ratio"]
    by_participant = typing_df.groupby("participant_id")

    per_participant = pd.merge(
        by_participant[metric_cols].mean(),
        by_participant["OUT_time_estimate_seconds"]
        .apply(coefficient_of_variation)
        .rename("CV"),
        on="participant_id",
    )
    overall = per_participant.agg(["mean", "std"])

    labels = {
        "Time-per-Letter (s)": "OUT_time_per_letter",
        "Time Estimate (s)": "OUT_time_estimate_seconds",
        "S/O Ratio": "OUT_time_estimation_ratio",
        "CV of Time Estimates": "CV",
    }

    records = [
        [label, format_mean_sd(overall.loc["mean", col], overall.loc["std", col])]
        for label, col in labels.items()
    ]
    result = pd.DataFrame(records, columns=["Metric", "Value [M (SD)]"])
    return result.set_index("Metric")

In [None]:
# Render the participant-level M (SD) summary for the N-back simulation.
print("--- N-back Simulation Summary ---")
display(create_nback_summary_table(nback_sim_df))

# Render the participant-level M (SD) summary for the typing simulation.
print("\n--- Typing Simulation Summary ---")
display(create_typing_summary_table(typing_sim_df))

### N-back Difficulty Effects

In [None]:
def plot_nback_normalized_impact(nback_df: pd.DataFrame):
    """Line plot of z-scored accuracy and S/O ratio across N-back levels.

    Both metrics are averaged per participant and level, standardised over
    all participant-level rows (mean 0, SD 1), and drawn as one line each so
    their change with difficulty can be compared on a common scale.
    """
    z_columns = {
        "OUT_nback_accuracy": "Accuracy (Z-score)",
        "OUT_time_estimation_ratio": "S/O Ratio (Z-score)",
    }
    id_columns = ["participant_id", "nback_level"]
    value_label = "Normalized Value (Z-score)"

    per_participant = (
        nback_df.groupby(id_columns)[list(z_columns)].mean().reset_index()
    )

    # Standardise each metric across the pooled participant x level means.
    for raw_col, z_col in z_columns.items():
        values = per_participant[raw_col]
        per_participant[z_col] = (values - values.mean()) / values.std()

    long_form = per_participant.melt(
        id_vars=id_columns,
        value_vars=list(z_columns.values()),
        var_name="Metric",
        value_name=value_label,
    )

    fig, ax = plt.subplots(figsize=(10, 7))
    line_style = dict(
        hue="Metric",
        style="Metric",
        markers=True,
        dashes=False,
        palette=["#1f77b4", "#ff7f0e"],
        linewidth=2.5,
        markersize=8,
    )
    sns.lineplot(data=long_form, x="nback_level", y=value_label, ax=ax, **line_style)

    # Reference line at the pooled mean (z = 0).
    ax.axhline(0, color="grey", linestyle="--", linewidth=1, alpha=0.8)
    ax.set_title(
        "Similar Impact of N-back Difficulty on Performance and Time Perception",
        fontsize=18,
        pad=20,
    )
    ax.set_xlabel("N-back Condition", fontsize=14, labelpad=15)
    ax.set_ylabel(value_label, fontsize=14, labelpad=15)
    ax.set_xticks([1, 2])
    ax.set_xticklabels(["1-Back (Easy)", "2-Back (Hard)"])
    ax.tick_params(axis="both", which="major", labelsize=12)
    ax.legend(title="Metric", fontsize=12, title_fontsize=13)
    sns.despine(trim=True)
    plt.tight_layout()
    plt.show()


plot_nback_normalized_impact(nback_sim_df)

In [None]:
def plot_nback_clustered_bar(nback_df: pd.DataFrame, y_bottom=None, y_top=None):
    """Clustered bar chart comparing accuracy and S/O ratio in absolute values.

    Metrics are averaged per participant and N-back level, then across
    participants, and drawn as two bars (accuracy, S/O ratio) per level.

    Args:
        nback_df: Trial-level N-back data with ``participant_id``,
            ``nback_level``, ``OUT_nback_accuracy`` and
            ``OUT_time_estimation_ratio``.
        y_bottom: Optional lower y-axis limit.
        y_top: Optional upper y-axis limit. Each limit is applied on its own;
            previously a limit was silently ignored unless both were given.
    """
    metrics = ["OUT_nback_accuracy", "OUT_time_estimation_ratio"]

    # Step 1: per-participant means, so every participant weighs equally.
    participant_summary = (
        nback_df.groupby(["participant_id", "nback_level"])[metrics].mean().reset_index()
    )

    # Step 2: grand mean per N-back level.
    condition_means = (
        participant_summary.groupby("nback_level")
        .agg(accuracy_mean=("OUT_nback_accuracy", "mean"),
             so_ratio_mean=("OUT_time_estimation_ratio", "mean"))
        .reset_index()
    )

    plot_data = pd.melt(
        condition_means,
        id_vars="nback_level",
        value_vars=["accuracy_mean", "so_ratio_mean"],
        var_name="Metric",
        value_name="Mean Value",
    )
    plot_data["Metric"] = plot_data["Metric"].replace({
        "accuracy_mean": "N-back Accuracy",
        "so_ratio_mean": "S/O Ratio",
    })

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.barplot(
        data=plot_data,
        x="nback_level",
        y="Mean Value",
        hue="Metric",
        palette=["#1f77b4", "#ff7f0e"],
        edgecolor="black",
        ax=ax,
    )

    ax.set_title("Impact of N-back Difficulty on Performance and Time Perception", fontsize=18, pad=20)
    ax.set_xlabel("N-back Condition", fontsize=14, labelpad=15)
    ax.set_ylabel("Mean Value", fontsize=14, labelpad=15)
    # Bars sit at categorical positions 0 and 1.
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["1-Back (Easy)", "2-Back (Hard)"])
    ax.tick_params(axis="both", which="major", labelsize=12)
    ax.legend(title="Metric", fontsize=12, title_fontsize=13, loc="upper right")
    # set_ylim leaves a limit passed as None unchanged, so a single bound can
    # be applied; skip the call entirely when no bound was requested.
    if y_bottom is not None or y_top is not None:
        ax.set_ylim(bottom=y_bottom, top=y_top)
    sns.despine(trim=True)
    plt.tight_layout()
    plt.show()


plot_nback_clustered_bar(nback_sim_df, y_bottom=0.4, y_top=1.0)

## Multitasking Simulations

### Data Loading

In [None]:
def load_simulation_data(directory, acc_thresh=ACC_THRESH):
    """Load simulation CSVs, return (all_data_with_labels, clean_data_for_analysis).

    Args:
        directory: Folder with ``participant_sim_<n>_output_CLEAN.csv`` files.
        acc_thresh: Minimum ``OUT_nback_accuracy`` for a main-phase trial to
            enter the clean subset.

    Returns:
        Tuple ``(all_data, clean_data)``. ``all_data`` holds every loaded
        trial plus a categorical ``condition_label`` column. ``clean_data``
        holds the main-phase trials that pass the accuracy threshold and are
        not flagged as outliers on any screened dependent variable; it also
        carries the ``outlier_*`` flag columns.

    Raises:
        FileNotFoundError: If *directory* has no ``participant_*_CLEAN.csv``
            file, or none whose name matches the simulation naming pattern
            (the latter previously surfaced as an opaque ``pd.concat``
            ValueError).
    """
    all_files = sorted(directory.glob("participant_*_CLEAN.csv"))
    if not all_files:
        raise FileNotFoundError(f"No preprocessed files in {directory}")

    filename_pattern = re.compile(r"participant_(sim_\d+)_output_CLEAN\.csv")

    frames = []
    for filepath in all_files:
        pid_match = filename_pattern.search(filepath.name)
        if not pid_match:
            continue
        df = pd.read_csv(filepath)
        df["participant_id"] = pid_match.group(1)
        frames.append(df)

    # The glob is broader than the file-name pattern, so files may exist
    # without any of them being a simulation output.
    if not frames:
        raise FileNotFoundError(
            f"No simulated-participant files (participant_sim_<n>_output_CLEAN.csv) in {directory}"
        )

    all_data = pd.concat(frames, ignore_index=True)

    # Coerce the analysis columns to numbers; unparseable entries become NaN.
    numeric_cols = [
        "OUT_time_estimation_ratio", "OUT_normalized_absolute_error",
        "OUT_time_estimate_seconds", "OUT_nback_accuracy",
        "OUT_time_per_letter", "OUT_actual_trial_duration_sec",
    ]
    for col in numeric_cols:
        if col in all_data.columns:
            all_data[col] = pd.to_numeric(all_data[col], errors="coerce")

    all_data["condition_label"] = all_data.apply(create_condition_label, axis=1)
    all_data["condition_label"] = pd.Categorical(
        all_data["condition_label"], categories=CONDITION_ORDER, ordered=True,
    )

    print(f"Loaded {all_data['participant_id'].nunique()} simulated participants, "
          f"{len(all_data)} total trials.")

    # Analysis subset: main phase only, accuracy at or above the threshold.
    df_main = all_data.query("OUT_experiment_phase == 'main'").copy()
    df_main = df_main[df_main["OUT_nback_accuracy"] >= acc_thresh].copy()

    # Flag outliers per dependent variable; a trial flagged on any of them is
    # removed from the clean subset.
    dvs_to_screen = [
        "OUT_time_per_letter",
        "OUT_time_estimation_ratio",
        "OUT_normalized_absolute_error",
    ]
    for dv in dvs_to_screen:
        if dv in df_main.columns:
            df_main[f"outlier_{dv}"] = flag_outliers_trial(df_main, dv)
    mask_any = df_main.filter(regex=r"^outlier_").any(axis=1)
    clean_data = df_main[~mask_any].copy()

    n_removed = mask_any.sum()
    print(f"Clean subset: {len(clean_data)} trials "
          f"({n_removed} outliers removed from {len(df_main)} main-phase trials).")
    return all_data, clean_data


def load_human_data(directory, acc_thresh=ACC_THRESH):
    """Load and clean human experimental data for model comparison.

    Args:
        directory: Folder with ``participant_<digits>_output_CLEAN.csv`` files.
        acc_thresh: Minimum ``OUT_nback_accuracy`` for a main-phase trial to
            be kept.

    Returns:
        DataFrame with the main-phase trials that pass the accuracy threshold
        and are not flagged as outliers on the S/O ratio or the normalized
        absolute error; it also carries the ``outlier_*`` flag columns.

    Raises:
        FileNotFoundError: If no matching file is found in *directory*
            (previously this surfaced as an opaque ``pd.concat`` ValueError),
            mirroring ``load_simulation_data``.
    """
    filename_pattern = re.compile(r"participant_(\d+)_output_CLEAN\.csv")

    frames = []
    for filepath in sorted(directory.glob("participant_*_output_CLEAN.csv")):
        pid_match = filename_pattern.search(filepath.name)
        if not pid_match:
            continue
        df = pd.read_csv(filepath)
        df["participant_id"] = pid_match.group(1)
        frames.append(df)

    if not frames:
        raise FileNotFoundError(f"No preprocessed human files in {directory}")

    df_all = pd.concat(frames, ignore_index=True)

    # Coerce the analysis columns to numbers; unparseable entries become NaN.
    numeric_cols = [
        "OUT_time_estimation_ratio", "OUT_normalized_absolute_error",
        "OUT_time_estimate_seconds", "OUT_nback_accuracy",
        "OUT_time_per_letter", "OUT_actual_trial_duration_sec",
    ]
    for col in numeric_cols:
        if col in df_all.columns:
            df_all[col] = pd.to_numeric(df_all[col], errors="coerce")

    # Analysis subset: main phase only, accuracy at or above the threshold.
    df_main = df_all.query("OUT_experiment_phase == 'main'").copy()
    df_main = df_main[df_main["OUT_nback_accuracy"] >= acc_thresh].copy()

    # Flag outliers per dependent variable; a trial flagged on either one is
    # dropped.
    dvs_to_screen = [
        "OUT_time_estimation_ratio",
        "OUT_normalized_absolute_error",
    ]
    for dv in dvs_to_screen:
        if dv in df_main.columns:
            df_main[f"outlier_{dv}"] = flag_outliers_trial(df_main, dv)
    mask_any = df_main.filter(regex=r"^outlier_").any(axis=1)
    df_clean = df_main[~mask_any].copy()

    print(f"Loaded {df_all['participant_id'].nunique()} human participants, "
          f"{len(df_clean)} clean trials.")
    return df_clean

In [None]:
# all_sim_data: every simulated trial with condition labels (used by the
# descriptive plots); df_clean_sim / df_clean_human: screened main-phase
# trials (used for the model-fit comparison).
all_sim_data, df_clean_sim = load_simulation_data(SIM_DATA_DIR)
df_clean_human = load_human_data(HUMAN_DATA_DIR)

In [None]:
# Participant and trial counts for both data sets. The human participant
# count is taken from the cleaned trials only.
print(f"Simulation: {all_sim_data['participant_id'].nunique()} participants, "
      f"{len(all_sim_data)} total trials, {len(df_clean_sim)} clean trials")
print(f"Human: {df_clean_human['participant_id'].nunique()} participants, "
      f"{len(df_clean_human)} clean trials")

# Sort ids like "sim_12" by their numeric suffix rather than alphabetically.
sim_ids = sorted(
    all_sim_data["participant_id"].unique(),
    key=lambda x: int(x.split("_")[1]),
)
print(f"\nSimulation participants: {sim_ids}")
display(all_sim_data.head(3))

### Descriptive Visualizations

In [None]:
def plot_beeswarm_timing(data_df):
    """Swarm plot of the S/O ratio, one swarm per experimental condition.

    Conditions follow CONDITION_ORDER. A dashed line marks a ratio of 1.0 and
    a faint dotted vertical line separates the sequential from the
    interrupted conditions.
    """
    swarm_colors = [COLORS[name] for name in ("teal", "teal_dark", "pink", "pink_dark")]
    tick_labels = [
        f"{mode}\n{level}"
        for mode in ("Sequential", "Interrupted")
        for level in ("Easy", "Hard")
    ]

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.swarmplot(
        data=data_df,
        x="condition_label",
        y="OUT_time_estimation_ratio",
        hue="condition_label",
        order=CONDITION_ORDER,
        hue_order=CONDITION_ORDER,
        palette=swarm_colors,
        size=3.8,
        alpha=0.7,
        legend=False,
        ax=ax,
    )

    # Reference lines: ratio of 1.0 and the sequential/interrupted divider.
    ax.axhline(1.0, color="orange", linestyle="--", linewidth=1.5, alpha=0.7)
    ax.axvline(
        x=1.5, color=COLORS["primary_dark"], linestyle=":", alpha=0.2, linewidth=1
    )

    ax.set_title(
        "Simulated Participants: Timing Accuracy Across Conditions",
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    ax.set_xlabel("")
    ax.set_ylabel("Subjective / Objective Ratio", fontsize=12)
    ax.set_xticks(range(len(CONDITION_ORDER)))
    ax.set_xticklabels(tick_labels)
    ax.set_ylim(0.6, 1.1)
    sns.despine()
    plt.tight_layout()
    plt.show()


plot_beeswarm_timing(all_sim_data)

In [None]:
def plot_performance_bars(data_df):
    """Side-by-side bar charts of mean trial duration and mean time per letter.

    Each panel shows one bar per condition (in CONDITION_ORDER) with the mean
    printed above it; the y-axis starts at 95% of the smallest mean.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    bar_palette = ["#f3a712", "#e4572e", "#c0326e", "#e86af0"]

    def draw_panel(ax, col, ylabel, title, fmt):
        # Mean of `col` per condition, ordered like CONDITION_ORDER.
        means = data_df.groupby("condition_label", observed=True)[col].mean()
        summary = means.reindex(CONDITION_ORDER).reset_index()

        bars = sns.barplot(
            data=summary, x="condition_label", y=col, hue="condition_label",
            palette=bar_palette, edgecolor="black", linewidth=1.5,
            legend=False, ax=ax,
        )
        # Write each bar's value just above its top edge.
        for patch in bars.patches:
            height = patch.get_height()
            ax.annotate(
                format(height, fmt),
                (patch.get_x() + patch.get_width() / 2.0, height),
                ha="center", va="center", xytext=(0, 9),
                textcoords="offset points", fontsize=11, fontweight="bold",
            )
        ax.set_title(title, fontsize=14, pad=15)
        ax.set_xlabel("Condition", fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_ylim(bottom=summary[col].min() * 0.95)
        sns.despine(left=True)
        ax.grid(axis="y", linestyle="--", alpha=0.7)

    draw_panel(
        axes[0], "OUT_actual_trial_duration_sec",
        "Mean Duration (seconds)", "Mean Trial Duration by Condition", ".2f",
    )
    draw_panel(
        axes[1], "OUT_time_per_letter",
        "Mean Time Per Letter (seconds)", "Mean Time Per Letter by Condition", ".3f",
    )

    plt.tight_layout()
    plt.show()


plot_performance_bars(all_sim_data)

In [None]:
def plot_nback_accuracy_by_level(data_df):
    """Bar chart of mean N-back accuracy (in percent) per N-back level.

    The input frame is not modified; the percentage column is added to a
    copy. Every bar is annotated with its value.
    """
    pct_col = "nback_accuracy_percent"
    frame = data_df.assign(**{pct_col: data_df["OUT_nback_accuracy"] * 100})

    summary = frame.groupby("COND_nback_level")[pct_col].mean().reset_index()
    summary["label"] = summary["COND_nback_level"].map("{}-Back".format)

    fig, ax = plt.subplots(figsize=(8, 6))
    bars = sns.barplot(
        data=summary, x="label", y=pct_col, hue="label",
        palette=["#f3a712", "#e4572e"], edgecolor="black", linewidth=1.5,
        legend=False, ax=ax,
    )
    # Write each bar's percentage just above its top edge.
    for patch in bars.patches:
        height = patch.get_height()
        ax.annotate(
            f"{height:.2f}%",
            (patch.get_x() + patch.get_width() / 2.0, height),
            ha="center", va="center", xytext=(0, 9),
            textcoords="offset points", fontsize=12, fontweight="bold",
        )

    ax.set_title("Mean N-Back Accuracy by Task Difficulty", fontsize=14, pad=15)
    ax.set_xlabel("N-Back Level", fontsize=12)
    ax.set_ylabel("Mean Accuracy (%)", fontsize=12)
    lowest = summary[pct_col].min()
    highest = summary[pct_col].max()
    ax.set_ylim(bottom=lowest * 0.3, top=highest * 1.05)
    sns.despine(left=True)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()


plot_nback_accuracy_by_level(all_sim_data)

In [None]:
def plot_time_perception_by_nback(data_df):
    """Two bar panels: normalized absolute error and S/O ratio per N-back level.

    For each metric the four condition means are computed first; the "easy"
    bar is the average of the two Easy condition means and the "hard" bar the
    average of the two Hard condition means.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    palette = ["#f3a712", "#e4572e"]
    easy_cells = ["Seq-Easy", "Int-Easy"]
    hard_cells = ["Seq-Hard", "Int-Hard"]

    def draw_panel(ax, col, ylabel, title, fmt):
        cell_means = (
            data_df.groupby("condition_label", observed=True)[col]
            .mean().reset_index()
        )
        in_easy = cell_means["condition_label"].isin(easy_cells)
        in_hard = cell_means["condition_label"].isin(hard_cells)
        summary = pd.DataFrame({
            "N-Back Level": ["1-Back (Easy)", "2-Back (Hard)"],
            "value": [cell_means[in_easy][col].mean(), cell_means[in_hard][col].mean()],
        })

        bars = sns.barplot(
            data=summary, x="N-Back Level", y="value", hue="N-Back Level",
            palette=palette, edgecolor="black", linewidth=1.5,
            legend=False, ax=ax,
        )
        # Write each bar's value just above its top edge.
        for patch in bars.patches:
            height = patch.get_height()
            ax.annotate(
                format(height, fmt),
                (patch.get_x() + patch.get_width() / 2.0, height),
                ha="center", va="center", xytext=(0, 9),
                textcoords="offset points", fontsize=12, fontweight="bold",
            )
        ax.set_title(title, fontsize=14, pad=15)
        ax.set_xlabel("N-Back Level", fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        if col != "OUT_normalized_absolute_error":
            # Ratio panel: zoomed axis with a reference line at 1.0.
            ax.set_ylim(bottom=summary["value"].min() * 0.95, top=1.05)
            ax.axhline(1.0, color="black", linestyle="--", linewidth=1, alpha=0.7)
        else:
            # Error panel: axis anchored at zero.
            ax.set_ylim(bottom=0, top=summary["value"].max() * 1.15)
        sns.despine(left=True)
        ax.grid(axis="y", linestyle="--", alpha=0.7)

    draw_panel(
        axes[0], "OUT_normalized_absolute_error",
        "Mean Normalized Absolute Error", "Absolute Error by Task Difficulty", ".3f",
    )
    draw_panel(
        axes[1], "OUT_time_estimation_ratio",
        "Mean S/O Ratio", "S/O Ratio by Task Difficulty", ".3f",
    )

    plt.tight_layout()
    plt.show()


plot_time_perception_by_nback(all_sim_data)

In [None]:
def _interruption_bar_panel(ax, summary, x, y, palette, fmt, title, ylabel):
    """Draw one annotated bar panel (one bar per row of *summary*) on *ax*."""
    bars = sns.barplot(x=x, y=y, hue=x, data=summary,
                       palette=palette, edgecolor="black", linewidth=1.5,
                       legend=False, ax=ax)
    # Write each bar's value just above its top edge.
    for patch in bars.patches:
        height = patch.get_height()
        ax.annotate(format(height, fmt),
                    (patch.get_x() + patch.get_width() / 2., height),
                    ha="center", va="center", xytext=(0, 9),
                    textcoords="offset points", fontsize=11, fontweight="bold")
    ax.set_title(title, fontsize=13, pad=10)
    ax.set_xlabel("")
    ax.set_ylabel(ylabel, fontsize=11)


def plot_metrics_by_interruption(data_df):
    """S/O ratio, absolute error, CV, and trial duration by interruption condition (2x2).

    The S/O ratio, absolute error and trial duration panels show trial-level
    means per interruption condition. The CV panel shows the mean of the
    per-participant, per-condition coefficients of variation of the time
    estimates (an undefined CV is counted as 0).

    Args:
        data_df: Trial-level data with ``participant_id``,
            ``COND_interruption_condition``, ``condition_label`` and the
            ``OUT_*`` columns plotted below.
    """
    int_palette = ["#c0326e", "#e86af0"]
    factor = "COND_interruption_condition"
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    def mean_by_interruption(col):
        return data_df.groupby(factor)[col].mean().reset_index()

    # S/O Ratio
    ax = axes[0, 0]
    col = "OUT_time_estimation_ratio"
    s = mean_by_interruption(col)
    _interruption_bar_panel(ax, s, factor, col, int_palette, ".3f",
                            "Mean S/O Ratio by Interruption", "S/O Ratio")
    ax.axhline(1.0, color="black", linestyle="--", linewidth=1, alpha=0.7)
    ax.set_ylim(bottom=s[col].min() * 0.95, top=1.05)

    # Absolute Error
    ax = axes[0, 1]
    col = "OUT_normalized_absolute_error"
    s = mean_by_interruption(col)
    _interruption_bar_panel(ax, s, factor, col, int_palette, ".3f",
                            "Mean Absolute Error by Interruption", "Absolute Error")
    ax.set_ylim(bottom=0, top=s[col].max() * 1.15)

    # CV
    ax = axes[1, 0]
    cv_per_part = (
        data_df.groupby(["participant_id", "condition_label"], observed=True)["OUT_time_estimate_seconds"]
        .agg(sd="std", mean="mean").reset_index()
    )
    cv_per_part["cv"] = (cv_per_part["sd"] / cv_per_part["mean"]).fillna(0)
    cv_per_part["interruption"] = cv_per_part["condition_label"].apply(
        lambda x: "interrupted" if "Int" in x else "sequential"
    )
    s = cv_per_part.groupby("interruption")["cv"].mean().reset_index()
    _interruption_bar_panel(ax, s, "interruption", "cv", int_palette, ".3f",
                            "Mean CV by Interruption", "Coefficient of Variation")
    cv_top = s["cv"].max() * 1.05
    # Keep the zoomed-in 0.16 floor when the bars reach above it; otherwise
    # start at 0, because a bottom limit above the top limit would make
    # matplotlib draw the axis inverted.
    cv_bottom = 0.16 if cv_top > 0.16 else 0
    ax.set_ylim(bottom=cv_bottom, top=cv_top)

    # Trial Duration
    ax = axes[1, 1]
    col = "OUT_actual_trial_duration_sec"
    s = mean_by_interruption(col)
    _interruption_bar_panel(ax, s, factor, col, int_palette, ".2f",
                            "Mean Trial Duration by Interruption", "Duration (seconds)")
    ax.set_ylim(bottom=s[col].min() * 0.95, top=s[col].max() * 1.05)

    for ax in axes.flat:
        sns.despine(left=True, ax=ax)
        ax.grid(axis="y", linestyle="--", alpha=0.7)

    plt.suptitle("Main Effects of Interruption Condition",
                 fontsize=15, fontweight="bold", y=1.01)
    plt.tight_layout()
    plt.show()


plot_metrics_by_interruption(all_sim_data)

### Model Fitting

In [None]:
# Names of the per-participant time-perception metrics produced below.
TP_METRIC_COLS = ["mean_SO_ratio", "mean_abs_error", "cv_estimate_seconds"]


def aggregate_time_perception(df_clean):
    """Aggregate time-perception metrics in two steps.

    Step 1 computes, for every participant x condition cell, the mean S/O
    ratio, the mean normalized absolute error and the coefficient of
    variation of the time estimates. Step 2 averages those participant-level
    values within each interruption x N-back condition.

    Returns:
        Tuple ``(part_means, summary)`` holding the step-1 and step-2 frames.
    """
    per_participant = df_clean.groupby(GROUP_COLS, as_index=False).agg(
        mean_SO_ratio=pd.NamedAgg("OUT_time_estimation_ratio", "mean"),
        mean_abs_error=pd.NamedAgg("OUT_normalized_absolute_error", "mean"),
        cv_estimate_seconds=pd.NamedAgg(
            "OUT_time_estimate_seconds", coefficient_of_variation
        ),
    )
    condition_keys = ["COND_interruption_condition", "COND_nback_level"]
    grand_means = per_participant.groupby(condition_keys)[TP_METRIC_COLS].mean()
    return per_participant, grand_means.reset_index()


# Apply the same two-step aggregation to the cleaned human and simulated
# trials; *_part holds participant-level values, *_summary the condition
# grand means.
human_tp_part, human_tp_summary = aggregate_time_perception(df_clean_human)
model_tp_part, model_tp_summary = aggregate_time_perception(df_clean_sim)

print("--- Human Time Perception Summary ---")
display(human_tp_summary)
print("\n--- Model Time Perception Summary ---")
display(model_tp_summary)

In [None]:
def aggregate_by_conditions(df, col):
    """Grand mean of *col* per interruption x N-back condition.

    Values are averaged within each participant x condition cell first, then
    across participants, so every participant contributes equally.
    """
    per_participant = df.groupby(GROUP_COLS)[col].mean().reset_index()
    condition_keys = ["COND_interruption_condition", "COND_nback_level"]
    grand_means = per_participant.groupby(condition_keys)[col].mean()
    return grand_means.reset_index()

def aggregate_by_nback(df, col):
    """Grand mean of ``col`` per n-back level, ignoring interruption condition.

    ``col`` is averaged per participant and n-back level first; the returned
    frame holds the mean of those participant-level values for each level.
    """
    participant_keys = ["participant_id", "COND_nback_level"]
    per_participant = df.groupby(participant_keys)[col].mean().reset_index()
    level_means = per_participant.groupby("COND_nback_level")[col].mean()
    return level_means.reset_index()


# Two-step condition means for the remaining outcome measures.
# N-back accuracy is summarised per n-back level only; trial duration and
# time per letter are summarised per interruption condition x n-back level.

# Human data
human_nback_summary = aggregate_by_nback(df_clean_human, "OUT_nback_accuracy")
human_duration_summary = aggregate_by_conditions(df_clean_human, "OUT_actual_trial_duration_sec")
human_tpl_summary = aggregate_by_conditions(df_clean_human, "OUT_time_per_letter")

# Model data
model_nback_summary = aggregate_by_nback(df_clean_sim, "OUT_nback_accuracy")
model_duration_summary = aggregate_by_conditions(df_clean_sim, "OUT_actual_trial_duration_sec")
model_tpl_summary = aggregate_by_conditions(df_clean_sim, "OUT_time_per_letter")

# Side-by-side tables: the shared value column gets _human / _model suffixes.
print("--- N-Back Accuracy ---")
print(human_nback_summary.merge(
    model_nback_summary,
    on="COND_nback_level",
    suffixes=("_human", "_model"),
))

print("\n--- Trial Duration ---")
print(human_duration_summary.merge(
    model_duration_summary,
    on=["COND_interruption_condition", "COND_nback_level"],
    suffixes=("_human", "_model"),
))

print("\n--- Time Per Letter ---")
print(human_tpl_summary.merge(
    model_tpl_summary,
    on=["COND_interruption_condition", "COND_nback_level"],
    suffixes=("_human", "_model"),
))

In [None]:
# --- Goodness of fit (R² and RMSD) between human and model condition means ---
# Every entry pairs the human and model summary rows on their condition
# columns and passes the two aligned columns to calculate_fit_metrics.
fit_results = []

# Time perception metrics (all three live in the same summary tables)
tp_fit_df = human_tp_summary.merge(
    model_tp_summary,
    on=["COND_interruption_condition", "COND_nback_level"],
    suffixes=("_human", "_model"),
)

# summary column -> label shown in the results table
for metric, label in {
    "mean_SO_ratio": "SO_ratio",
    "mean_abs_error": "absolute_error",
    "cv_estimate_seconds": "CV",
}.items():
    r2, rmsd = calculate_fit_metrics(
        tp_fit_df[metric + "_human"], tp_fit_df[metric + "_model"]
    )
    fit_results.append(dict(Metric=label, R_squared=r2, RMSD=rmsd))

# N-back accuracy (summarised per n-back level only)
nback_fit = human_nback_summary.merge(
    model_nback_summary,
    on="COND_nback_level",
    suffixes=("_human", "_model"),
)
r2, rmsd = calculate_fit_metrics(
    nback_fit["OUT_nback_accuracy_human"],
    nback_fit["OUT_nback_accuracy_model"],
)
fit_results.append(dict(Metric="nback_accuracy", R_squared=r2, RMSD=rmsd))

# Trial duration
dur_fit = human_duration_summary.merge(
    model_duration_summary,
    on=["COND_interruption_condition", "COND_nback_level"],
    suffixes=("_human", "_model"),
)
r2, rmsd = calculate_fit_metrics(
    dur_fit["OUT_actual_trial_duration_sec_human"],
    dur_fit["OUT_actual_trial_duration_sec_model"],
)
fit_results.append(dict(Metric="trial_duration", R_squared=r2, RMSD=rmsd))

# Time per letter
tpl_fit = human_tpl_summary.merge(
    model_tpl_summary,
    on=["COND_interruption_condition", "COND_nback_level"],
    suffixes=("_human", "_model"),
)
r2, rmsd = calculate_fit_metrics(
    tpl_fit["OUT_time_per_letter_human"],
    tpl_fit["OUT_time_per_letter_model"],
)
fit_results.append(dict(Metric="time_per_letter", R_squared=r2, RMSD=rmsd))

# One row per metric; rounded for display only (R² to 2 decimals, RMSD to 3).
fit_results_df = pd.DataFrame(fit_results).set_index("Metric")
fit_results_df["R_squared"] = fit_results_df["R_squared"].round(2)
fit_results_df["RMSD"] = fit_results_df["RMSD"].round(3)
print("Model Fit Results (R\u00b2 and RMSD)")
display(fit_results_df)

In [None]:
def plot_model_fit_grid():
    """Plot human vs. model condition means, one panel per metric (2 x 3 grid).

    Takes no arguments: the human_*/model_* summary tables are read from
    module scope. Five panels show the four interruption x n-back conditions
    (sequential first, then interrupted); the n-back accuracy panel shows the
    two n-back levels only. Y-limits are fixed per panel.
    """
    condition_cols = ["COND_interruption_condition", "COND_nback_level"]

    def _values_by_nback(summary, column):
        # Values ordered by n-back level.
        return summary.sort_values("COND_nback_level")[column].values

    def _values_by_condition(summary, column):
        # Sequential rows first, then interrupted; each half ordered by n-back level.
        ordered = summary.sort_values(condition_cols)
        interruption = ordered["COND_interruption_condition"]
        return np.concatenate([
            ordered[interruption == level][column].values
            for level in ("sequential", "interrupted")
        ])

    # (tick positions, tick labels, value extractor) for the two x-axis layouts
    nback_axis = ([1, 2], ["1-back", "2-back"], _values_by_nback)
    condition_axis = (
        [1, 2, 3, 4],
        ["Seq\nEasy", "Seq\nHard", "Int\nEasy", "Int\nHard"],
        _values_by_condition,
    )

    # (human table, model table, metric column, title, y label, y limits, n-back-only axis?)
    panels = [
        (human_tp_summary, model_tp_summary, "mean_abs_error",
         "Absolute Error", "Absolute Error", (0.03, 0.20), False),
        (human_tp_summary, model_tp_summary, "mean_SO_ratio",
         "S/O Ratio", "Subjective/Objective Ratio", (0.7, 1.05), False),
        (human_tp_summary, model_tp_summary, "cv_estimate_seconds",
         "Coefficient of Variation", "CV", (0.08, 0.20), False),
        (human_nback_summary, model_nback_summary, "OUT_nback_accuracy",
         "N-back Accuracy", "Accuracy", (0.98, 1.0), True),
        (human_duration_summary, model_duration_summary, "OUT_actual_trial_duration_sec",
         "Trial Duration", "Duration (s)", (20, 25), False),
        (human_tpl_summary, model_tpl_summary, "OUT_time_per_letter",
         "Time per Letter", "Time per Letter (s)", (0.8, 2), False),
    ]

    fig, grid = plt.subplots(2, 3, figsize=(15, 8))

    for ax, panel in zip(grid.flatten(), panels):
        human_table, model_table, metric, title, ylabel, ylim, nback_only = panel
        ticks, tick_labels, extract = nback_axis if nback_only else condition_axis

        human_vals = extract(human_table, metric)
        model_vals = extract(model_table, metric)

        # Human: solid line with circles. Model: dashed line with squares.
        ax.plot(ticks, human_vals, "o-", color=COLORS["teal_dark"], linewidth=2,
                markersize=8, markeredgecolor="white", markeredgewidth=1)
        ax.plot(ticks, model_vals, "s--", color=COLORS["pink_dark"], linewidth=2,
                markersize=7, markeredgecolor="white", markeredgewidth=1)

        # Reference line at a ratio of 1.0 on the S/O panel.
        if metric == "mean_SO_ratio":
            ax.axhline(y=1.0, color=COLORS["primary_dark"], linestyle=":", alpha=0.3)

        ax.set_title(title, fontsize=11, fontweight="bold", pad=5)
        ax.set_ylabel(ylabel, fontsize=9)
        ax.set_ylim(ylim)
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels, fontsize=8)
        ax.tick_params(axis="y", labelsize=8)
        ax.yaxis.grid(True, alpha=0.15, linestyle="--")
        ax.set_axisbelow(True)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)

        # Faint divider between the sequential and interrupted halves.
        if not nback_only:
            ax.axvline(x=2.5, color=COLORS["primary_dark"], linestyle=":", alpha=0.2)

    fig.suptitle("Model Fit: Comparison Across Metrics",
                 fontsize=15, fontweight="bold", y=1.02)

    # Single figure-level legend built from proxy artists mirroring the line styles.
    human_handle = Line2D([0], [0], color=COLORS["teal_dark"], linestyle="-", marker="o",
                          markersize=8, label="Human", markeredgecolor="white",
                          markeredgewidth=1)
    model_handle = Line2D([0], [0], color=COLORS["pink_dark"], linestyle="--", marker="s",
                          markersize=7, label="Model", markeredgecolor="white",
                          markeredgewidth=1)
    fig.legend(handles=[human_handle, model_handle], loc="upper center",
               bbox_to_anchor=(0.5, 0.98), ncol=2, frameon=True,
               fontsize=10, columnspacing=3)

    plt.tight_layout()
    # Applied after tight_layout so these margins are the ones that take effect.
    plt.subplots_adjust(top=0.88, bottom=0.05, left=0.08, right=0.98)
    plt.show()


# Render the figure; the function takes no arguments and reads the
# human_*/model_* summary tables from module scope.
plot_model_fit_grid()

### 2×2 Repeated-Measures ANOVAs

In [None]:
# Per-participant cell means for the simulated data: one row per
# GROUP_COLS combination, used as input to the repeated-measures ANOVAs.
agg_perf_sim = df_clean_sim.groupby(GROUP_COLS, as_index=False).agg(
    mean_time_per_letter=("OUT_time_per_letter", "mean"),
)

# Reuse the shared helper instead of repeating its aggregation spec, so the
# ANOVA inputs cannot drift from the tables used for model fitting.
# Element 0 of the returned pair is the per-participant table (columns
# mean_SO_ratio, mean_abs_error, cv_estimate_seconds).
agg_time_sim = aggregate_time_perception(df_clean_sim)[0]

# Sanity check on the aggregated table sizes.
print(f"Performance aggregation: {len(agg_perf_sim)} rows "
      f"({agg_perf_sim['participant_id'].nunique()} participants)")
print(f"Time perception aggregation: {len(agg_time_sim)} rows")

In [None]:
# Time per letter as the dependent variable. report_rm_anova is defined
# elsewhere in the notebook; presumably it runs and prints the
# repeated-measures ANOVA this section is named for — confirm there.
report_rm_anova(agg_perf_sim, "mean_time_per_letter")

In [None]:
# Same report, once per time-perception metric column of agg_time_sim.
for dv in ["mean_SO_ratio", "mean_abs_error", "cv_estimate_seconds"]:
    report_rm_anova(agg_time_sim, dv)

In [None]:
def plot_interaction_panel(agg_df, metrics, ylabels, titles, suptitle):
    """Draw one grouped-bar panel per metric: n-back level x interruption condition.

    Parameters
    ----------
    agg_df : DataFrame with ``COND_nback_level``, ``COND_interruption_condition``
        and every column named in ``metrics``.
    metrics, ylabels, titles : parallel sequences, one entry per panel.
    suptitle : title placed above the whole figure.

    Bar heights are the means over the rows of ``agg_df`` in each cell and the
    error bars are the standard error of those rows. Two n-back levels are
    assumed (two bar positions are hard-coded).
    """
    n_panels = len(metrics)
    # squeeze=False always yields a 2-D array, so the single-panel case needs
    # no special handling; row 0 holds the axes left to right.
    fig, grid = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    axes = grid[0]

    cell_keys = ["COND_nback_level", "COND_interruption_condition"]
    bar_width = 0.35
    positions = np.arange(2)

    # (column of the unstacked table, x offset, fill colour, edge colour, legend label)
    bar_series = (
        ("sequential", -bar_width / 2, COLORS["teal"], COLORS["teal_dark"], "Sequential"),
        ("interrupted", bar_width / 2, COLORS["pink"], COLORS["pink_dark"], "Interrupted"),
    )

    for idx, (ax, metric, ylabel, title) in enumerate(zip(axes, metrics, ylabels, titles)):
        # Rows: n-back level; columns: interruption condition.
        cells = agg_df.groupby(cell_keys)[metric]
        cell_means = cells.mean().unstack()
        cell_sems = cells.sem().unstack()

        for column, offset, fill, edge, legend_label in bar_series:
            ax.bar(positions + offset, cell_means[column].values, bar_width,
                   yerr=cell_sems[column].values, capsize=5,
                   color=fill, edgecolor=edge, linewidth=1.5,
                   label=legend_label, alpha=0.8)

        ax.set_xlabel("Task Complexity", fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=12, fontweight="bold")
        ax.set_xticks(positions)
        ax.set_xticklabels(["1-back\n(Easy)", "2-back\n(Hard)"])

        if metric == "mean_SO_ratio":
            # NOTE(review): y=1.0 lies outside the limits set on the next line,
            # so this reference line is not visible — confirm whether the
            # limits or the line should change.
            ax.axhline(y=1.0, color=COLORS["primary_dark"], linestyle=":", alpha=0.3)
            ax.set_ylim(0.76, 0.97)

        ax.yaxis.grid(True, alpha=0.15, linestyle="--")
        ax.set_axisbelow(True)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)

        # One legend is enough; the colours are shared by all panels.
        if idx == 0:
            ax.legend(loc="upper left", frameon=True)

    fig.suptitle(suptitle, fontsize=14, fontweight="bold", y=1.02)
    plt.tight_layout()
    plt.show()


# Time-perception figure: absolute error and S/O ratio side by side.
plot_interaction_panel(
    agg_time_sim,
    metrics=["mean_abs_error", "mean_SO_ratio"],
    ylabels=["Absolute Error", "Subjective/Objective Ratio"],
    titles=["Absolute Error", "S/O Ratio"],
    suptitle="Model Predictions: Interaction Effect on Time Perception",
)

# Task-performance figure: a single panel for time per letter.
plot_interaction_panel(
    agg_perf_sim,
    metrics=["mean_time_per_letter"],
    ylabels=["Mean Time per Letter (s)"],
    titles=["Time per Letter"],
    suptitle="Model Predictions: Interaction Effect on Task Performance",
)