In [2]:
# =============================================================================
# Consolidated Analysis Script for the Winning LBA Model (M_SC)
#
# Description:
# This script performs all necessary post-processing analyses on the winning
# model ('M_SC') from the power set comparison. It combines the functionality
# of several previous scripts into a single, streamlined workflow.
#
# The script will:
# 1. Calculate and print key convergence diagnostics (R-hat and ESS).
# 2. Generate and save diagnostic trace plots for supplementary materials.
# 3. Extract and save the mean posterior parameter estimates (median estimates are currently disabled).
# 4. Run a full posterior predictive simulation to validate the model's fit.
# 5. Save the posterior predictive check results to a final CSV file.
# =============================================================================

# --- 1. Import Necessary Libraries ---
import arviz as az
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# --- 2. Simulation Helper Function ---
def simulate_lba_race(n_trials, v_alc, v_soc, A, k, tau, s=1.0, rng=None):
    """
    Simulate ``n_trials`` two-accumulator LBA races for one parameter set.

    Parameters
    ----------
    n_trials : int
        Number of trials to simulate.
    v_alc, v_soc : float
        Mean drift rates of the alcohol and social accumulators.
    A : float
        Upper bound of the uniform start-point distribution U(0, A).
    k : float
        Threshold offset; the response threshold is ``b = A + k``.
    tau : float
        Non-decision time added to every winning finishing time.
    s : float, optional
        Trial-to-trial drift-rate standard deviation (default 1.0, the value
        that was previously hard-coded).
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global ``np.random`` state is
        used, so existing ``np.random.seed`` based reproducibility is unchanged.

    Returns
    -------
    simulated_choices : ndarray of int
        0 when the alcohol accumulator finishes first (ties included),
        1 when the social accumulator finishes strictly first.
    simulated_rts : ndarray of float
        Winning finishing time plus ``tau``.
    """
    rng = np.random if rng is None else rng
    b = A + k
    # Draw order (both start points, then both drifts) is kept fixed so a given
    # seed always produces the same simulation.
    start_alc = rng.uniform(0, A, n_trials)
    start_soc = rng.uniform(0, A, n_trials)
    # Negative sampled drifts are floored at 1e-8, so such an accumulator
    # effectively never finishes; callers in this script cap RTs at 120 s.
    drift_alc = np.maximum(1e-8, rng.normal(v_alc, s, n_trials))
    drift_soc = np.maximum(1e-8, rng.normal(v_soc, s, n_trials))
    time_alc = (b - start_alc) / drift_alc
    time_soc = (b - start_soc) / drift_soc
    simulated_choices = (time_alc > time_soc).astype(int)
    winner_rt = np.minimum(time_alc, time_soc)
    simulated_rts = winner_rt + tau
    return simulated_choices, simulated_rts

# --- 3. Main Execution Block ---
def _select_subject_session(posterior, var_name, subject_idx, session_label):
    """
    Return the flattened posterior draws of one parameter for one subject/session.

    ``posterior`` is the ``trace.posterior`` dataset. The ``session`` coordinate
    is selected only when the variable actually carries a ``session`` dimension,
    so session-invariant and session-varying parameters are handled the same
    way; selecting by subject alone would otherwise silently pool the draws of
    a session-varying parameter across sessions.
    """
    samples = posterior[var_name].sel(subject_idx=subject_idx)
    if 'session' in samples.dims:
        samples = samples.sel(session=session_label)
    return np.asarray(samples.values).ravel()


def _report_convergence(trace, model_name):
    """Print mean/min/max R-hat and bulk ESS across all summarised variables."""
    print(f"\n--- Section 1: Calculating Convergence Diagnostics for {model_name} ---")
    summary_df = az.summary(trace, kind='all')
    rhat_values = summary_df['r_hat']
    ess_values = summary_df['ess_bulk']

    print("\nR-hat (should be close to 1.0):")
    print(f"  Mean: {rhat_values.mean():.4f}, Min: {rhat_values.min():.4f}, Max: {rhat_values.max():.4f}")
    print("\nEffective Sample Size (ESS - higher is better):")
    print(f"  Mean: {ess_values.mean():.0f}, Min: {ess_values.min():.0f}, Max: {ess_values.max():.0f}\n")


def _save_trace_figure(trace, var_names, title, fontsize, out_path, coords=None):
    """Draw an ArviZ trace plot, title it, save it at 300 dpi, and always close it."""
    try:
        az.plot_trace(trace, var_names=var_names, coords=coords)
        plt.suptitle(title, y=1.02, fontsize=fontsize)
        plt.tight_layout()
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Close even on failure so a broken plot does not leave an open figure.
        plt.close()


def _save_trace_plots(trace, model_name, subject_labels, output_dir):
    """Save group-level and representative-subject trace plots (best effort)."""
    print(f"--- Section 2: Generating Diagnostic Trace Plots for {model_name} ---")
    plots_output_dir = os.path.join(output_dir, "Trace_Plots")
    os.makedirs(plots_output_dir, exist_ok=True)

    # Plot Group-Level Parameters
    group_vars = [
        'v_alc_group_mu', 'v_soc_group_mu', 'A_group_mu_log',
        'v_alc_group_sigma', 'v_soc_group_sigma', 'A_group_sigma',
        'k_group_mu_log', 'k_group_sigma', 'tau_group_mu_log', 'tau_group_sigma'
    ]
    try:
        group_plot_path = os.path.join(plots_output_dir, "traceplot_group_parameters.png")
        _save_trace_figure(
            trace, group_vars,
            f"Trace Plots for Group-Level Parameters ({model_name})", 16, group_plot_path,
        )
        print(f"   Saved group-level trace plot to: {group_plot_path}")
    except Exception as e:
        # Deliberate best effort: report plotting failures without aborting the analysis.
        print(f"   ERROR generating group plots: {e}")

    # Plot Subject-Level Parameters for a Representative Subject
    subject_to_plot_idx = 0
    subject_label = subject_labels[subject_to_plot_idx]
    try:
        subject_plot_path_1 = os.path.join(plots_output_dir, f"traceplot_subject_{subject_label}_session_params.png")
        _save_trace_figure(
            trace, ['v_alcohol', 'v_social', 'A'],
            f"Trace Plots for Subject '{subject_label}' (Session-Varying)", 14, subject_plot_path_1,
            coords={'subject_idx': subject_to_plot_idx},
        )
        print(f"   Saved subject-level session-varying trace plot to: {subject_plot_path_1}")

        subject_plot_path_2 = os.path.join(plots_output_dir, f"traceplot_subject_{subject_label}_fixed_params.png")
        _save_trace_figure(
            trace, ['k', 'tau'],
            f"Trace Plots for Subject '{subject_label}' (Fixed)", 14, subject_plot_path_2,
            coords={'subject_idx': subject_to_plot_idx},
        )
        print(f"   Saved subject-level fixed trace plot to: {subject_plot_path_2}")
    except Exception as e:
        print(f"   ERROR generating subject plots: {e}")


def _extract_parameter_table(posterior, subject_labels, session_labels, include_medians=False):
    """Summarise each subject/session's posterior draws (plus derived b and caution) as a DataFrame."""
    param_names = ('v_alcohol', 'v_social', 'A', 'k', 'tau')
    param_results = []
    for subject_idx, subject_label in enumerate(tqdm(subject_labels, desc="Extracting Parameters")):
        for session_label in session_labels:
            draws = {name: _select_subject_session(posterior, name, subject_idx, session_label)
                     for name in param_names}

            # Derived quantities are computed draw-by-draw so the posterior
            # dependence between A and k is preserved.
            b_samples = draws['A'] + draws['k']
            caution_conventional_samples = b_samples - (draws['A'] / 2)

            columns = [
                ('v_alcohol', draws['v_alcohol']),
                ('v_social', draws['v_social']),
                ('A', draws['A']),
                ('k', draws['k']),
                ('tau', draws['tau']),
                ('b_threshold', b_samples),
                ('caution_conventional', caution_conventional_samples),
            ]
            row = {'subject': subject_label, 'session': session_label}
            for column, samples in columns:
                row[f'mean_{column}'] = samples.mean()
                if include_medians:
                    row[f'median_{column}'] = np.median(samples)
            param_results.append(row)
    return pd.DataFrame(param_results)


def _run_posterior_predictive(posterior, all_data, subject_labels, session_labels,
                              n_trials=500, rt_cap=120.0):
    """Simulate every posterior draw per subject/session and compare choice/RT summaries with the data."""
    param_names = ('v_alcohol', 'v_social', 'A', 'k', 'tau')
    pps_results = []
    for subject_idx, subject_label in enumerate(tqdm(subject_labels, desc="Simulating Subjects")):
        for session_label in session_labels:
            draws = {name: _select_subject_session(posterior, name, subject_idx, session_label)
                     for name in param_names}

            choice_chunks, rt_chunks = [], []
            for v_alc, v_soc, A, k, tau in zip(draws['v_alcohol'], draws['v_social'],
                                               draws['A'], draws['k'], draws['tau']):
                sim_choices, sim_rts = simulate_lba_race(
                    n_trials=n_trials, v_alc=v_alc, v_soc=v_soc, A=A, k=k, tau=tau
                )
                choice_chunks.append(sim_choices)
                rt_chunks.append(sim_rts)

            # Concatenate the per-draw arrays once instead of extending a Python
            # list element by element (millions of boxed numpy scalars).
            if choice_chunks:
                all_sim_choices = np.concatenate(choice_chunks)
                all_sim_rts = np.minimum(np.concatenate(rt_chunks), rt_cap)
            else:
                all_sim_choices = np.array([], dtype=int)
                all_sim_rts = np.array([], dtype=float)
            is_alc_sim = (all_sim_choices == 0)

            real_data_slice = all_data[(all_data['subj_idx'] == subject_label) & (all_data['session_type'] == session_label)]
            # NOTE(review): any row whose 'response' is not 0 (e.g. NaN) is counted
            # as a social choice below -- confirm the data has no such rows.
            is_alc_real = (real_data_slice['response'] == 0)

            pps_results.append({
                'subject': subject_label, 'session': session_label,
                'sim_prop_alcohol': np.mean(is_alc_sim),
                'real_prop_alcohol': np.mean(is_alc_real) if not is_alc_real.empty else np.nan,
                'sim_mean_rt_alcohol': np.mean(all_sim_rts[is_alc_sim]) if np.any(is_alc_sim) else np.nan,
                'real_mean_rt_alcohol': real_data_slice.loc[is_alc_real, 'rt'].mean(),
                'sim_mean_rt_social': np.mean(all_sim_rts[~is_alc_sim]) if not np.all(is_alc_sim) else np.nan,
                'real_mean_rt_social': real_data_slice.loc[~is_alc_real, 'rt'].mean(),
                'drift_diff_score': np.mean(draws['v_alcohol'] - draws['v_social'])
            })
    return pd.DataFrame(pps_results)


def main():
    """
    Run all post-processing analyses on the winning model.

    Steps: convergence diagnostics, trace plots, per-subject/session mean
    parameter estimates (``Winning_Model_Parameters.csv``), and a full
    posterior predictive simulation (``Posterior_Predictive_Check_Results.csv``).
    Returns early (printing an error) if the trace or data file is missing.
    """
    # --- Configuration ---
    WINNING_MODEL_NAME = "M_SC"
    DATA_FILE_PATH = r'C:\Users\drfox\LBA_Gemini\aIC_Choice.csv'
    RESULTS_DIR = 'LBA_Model_SC_EL_Only'
    OUTPUT_DIR = 'Winning_Model_Analysis_M_SC_EL_Only'
    # The 'pun' session was removed from this early/late-only analysis.
    SESSION_LABELS = ['early', 'late']
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    trace_file_path = os.path.join(RESULTS_DIR, f"trace_{WINNING_MODEL_NAME}.nc")

    # --- 1. Load Data and Winning Model Trace ---
    print(f"--- Loading Data and Trace for Winning Model: {WINNING_MODEL_NAME} ---")
    if not os.path.exists(trace_file_path):
        print(f"ERROR: Trace file not found at '{trace_file_path}'")
        return
    if not os.path.exists(DATA_FILE_PATH):
        print(f"ERROR: Data file for labels not found at '{DATA_FILE_PATH}'")
        return

    trace = az.from_netcdf(trace_file_path)
    all_data = pd.read_csv(DATA_FILE_PATH)
    # Subject index i in the trace is assumed to correspond to the i-th sorted
    # category of 'subj_idx' -- TODO confirm this matches the model-fitting script.
    subject_labels = pd.Categorical(all_data['subj_idx']).categories.tolist()

    # --- 2. Convergence Diagnostics ---
    _report_convergence(trace, WINNING_MODEL_NAME)

    # --- 3. Diagnostic Trace Plots ---
    _save_trace_plots(trace, WINNING_MODEL_NAME, subject_labels, OUTPUT_DIR)

    # --- 4. Extract Parameter Estimates ---
    print(f"\n--- Section 3: Extracting Parameter Estimates from {WINNING_MODEL_NAME} ---")
    params_df = _extract_parameter_table(trace.posterior, subject_labels, SESSION_LABELS)
    params_csv_path = os.path.join(OUTPUT_DIR, "Winning_Model_Parameters.csv")
    params_df.to_csv(params_csv_path, index=False)
    print(f"   Parameter estimates saved to: {params_csv_path}")

    # --- 5. Posterior Predictive Simulation ---
    print(f"\n--- Section 4: Running Posterior Predictive Simulation for {WINNING_MODEL_NAME} ---")
    pps_results_df = _run_posterior_predictive(trace.posterior, all_data, subject_labels, SESSION_LABELS)
    pps_csv_path = os.path.join(OUTPUT_DIR, "Posterior_Predictive_Check_Results.csv")
    pps_results_df.to_csv(pps_csv_path, index=False)
    print(f"   Posterior predictive check results saved to: {pps_csv_path}")

    print("\n=== Consolidated Analysis Complete ===")


if __name__ == '__main__':
    main()



--- Loading Data and Trace for Winning Model: M_SC ---

--- Section 1: Calculating Convergence Diagnostics for M_SC ---

R-hat (should be close to 1.0):
  Mean: 1.0000, Min: 1.0000, Max: 1.0000

Effective Sample Size (ESS - higher is better):
  Mean: 14567, Min: 2798, Max: 25076

--- Section 2: Generating Diagnostic Trace Plots for M_SC ---
   Saved group-level trace plot to: Winning_Model_Analysis_M_SC_EL_Only\Trace_Plots\traceplot_group_parameters.png
   Saved subject-level session-varying trace plot to: Winning_Model_Analysis_M_SC_EL_Only\Trace_Plots\traceplot_subject_'R1'_session_params.png
   Saved subject-level fixed trace plot to: Winning_Model_Analysis_M_SC_EL_Only\Trace_Plots\traceplot_subject_'R1'_fixed_params.png

--- Section 3: Extracting Parameter Estimates from M_SC ---


Extracting Parameters: 100%|██████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 282.15it/s]


   Parameter estimates saved to: Winning_Model_Analysis_M_SC_EL_Only\Winning_Model_Parameters.csv

--- Section 4: Running Posterior Predictive Simulation for M_SC ---


Simulating Subjects: 100%|█████████████████████████████████████████████████████████████| 23/23 [01:17<00:00,  3.36s/it]

   Posterior predictive check results saved to: Winning_Model_Analysis_M_SC_EL_Only\Posterior_Predictive_Check_Results.csv

=== Consolidated Analysis Complete ===





In [2]:
# =============================================================================
# LBA Model Posterior Predictive Check via Tertile Analysis
#
# Description:
# This script performs a detailed posterior predictive check on the winning LBA
# model (M_SCT) by comparing the distributions of simulated and observed
# reaction times (RTs).
#
# Instead of just comparing means, this analysis splits the RT data for each
# subject, session, and choice type (alcohol vs. social) into three equal-sized
# bins (tertiles) and compares the boundaries of these bins. This provides a
# much richer view of the model's ability to capture the entire shape of the
# RT distribution, especially its tails (i.e., the slowest responses).
#
# The script will:
# 1. Load the winning model trace and the real behavioral data.
# 2. For each subject and session:
#    a. Run a large number of simulated trials using the posterior-mean model parameters.
#    b. Calculate the 33.3rd and 66.7th percentile values (the tertile boundaries)
#       for both the real RTs and the simulated RTs.
#    c. Do this separately for alcohol and social choices.
# 3. Save the results to a comprehensive CSV file for plotting and analysis.
# =============================================================================

# --- 1. Import Necessary Libraries ---
import arviz as az
import os
import pandas as pd
import numpy as np
from tqdm import tqdm

# --- 2. Helper Functions ---

def simulate_lba_race(n_trials, v_alc, v_soc, A, k, tau):
    """
    Draw ``n_trials`` simulated LBA races between an alcohol and a social accumulator.

    The parameters are used exactly as given (in this script the caller passes
    posterior means). Per trial and accumulator, the start point is drawn from
    U(0, A) and the drift rate from N(v, 1), with negative drifts floored at
    1e-8; an accumulator finishes at (A + k - start) / drift.

    Returns
    -------
    choices : ndarray of int
        0 when the alcohol accumulator finishes first (ties included),
        1 when the social accumulator finishes strictly first.
    rts : ndarray of float
        The winning finishing time plus the non-decision time ``tau``.
    """
    threshold = A + k
    drift_sd = 1.0  # drift-rate SD fixed for identifiability

    # Random draws happen in a fixed order: both start points, then both drifts.
    start_points = [np.random.uniform(0, A, n_trials) for _ in ('alcohol', 'social')]
    drift_rates = [np.maximum(1e-8, np.random.normal(mean_drift, drift_sd, n_trials))
                   for mean_drift in (v_alc, v_soc)]

    finish_alc, finish_soc = [(threshold - start) / drift
                              for start, drift in zip(start_points, drift_rates)]

    choices = (finish_soc < finish_alc).astype(int)
    rts = np.minimum(finish_alc, finish_soc) + tau
    return choices, rts

def calculate_tertiles(rt_data, lower=33.3, upper=66.7):
    """
    Return the tertile boundaries of an array of reaction times.

    Parameters
    ----------
    rt_data : array-like of float
        Reaction times (list, ndarray, or pandas Series). NaN entries are
        ignored, so missing RTs do not turn both boundaries into NaN.
    lower, upper : float, optional
        Percentile levels of the two boundaries (defaults 33.3 and 66.7).

    Returns
    -------
    tuple of float
        ``(lower_boundary, upper_boundary)``, or ``(nan, nan)`` when fewer than
        three non-NaN values are available.
    """
    values = np.asarray(rt_data, dtype=float).ravel()
    values = values[~np.isnan(values)]
    if values.size < 3:  # too few observations for meaningful tertiles
        return np.nan, np.nan

    tertile_1_boundary, tertile_2_boundary = np.percentile(values, [lower, upper])
    return tertile_1_boundary, tertile_2_boundary

# --- 3. Main Execution Block ---
def _select_subject_session(posterior, var_name, subject_idx, session_label):
    """
    Return the flattened posterior draws of one parameter for one subject/session.

    ``posterior`` is the ``trace.posterior`` dataset. The ``session`` coordinate
    is selected only when the variable actually carries a ``session`` dimension,
    so session-invariant and session-varying parameters are handled the same
    way; selecting by subject alone would otherwise silently pool the draws of
    a session-varying parameter across sessions.
    """
    samples = posterior[var_name].sel(subject_idx=subject_idx)
    if 'session' in samples.dims:
        samples = samples.sel(session=session_label)
    return np.asarray(samples.values).ravel()


def main():
    """
    Run the tertile posterior predictive check for the winning model.

    For every subject and session, ``N_SIMULATIONS`` trials are simulated from
    the posterior-mean parameters (a plug-in check, not a full posterior
    predictive), and the 33.3rd/66.7th percentile RTs of simulated and observed
    trials are recorded separately for alcohol (response 0) and social
    (response 1) choices. Results go to ``Tertile_Analysis_Results.csv``.
    Returns early (printing an error) if the trace or data file is missing.
    """
    # --- Configuration ---
    WINNING_MODEL_NAME = "M_SCT"
    # Single backslashes: inside a raw string, doubled backslashes are literal.
    DATA_FILE_PATH = r'C:\Users\drfox\LBA_Gemini\aIC_Choice.csv'
    RESULTS_DIR = 'LBA_Model_Power_Set'
    OUTPUT_DIR = 'Winning_Model_Analysis_M_SCT'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    N_SIMULATIONS = 10000  # Use a large number for a stable RT distribution
    RT_CAP = 120.0  # cap applied to simulated RTs
    PARAM_NAMES = ('v_alcohol', 'v_social', 'A', 'k', 'tau')
    CHOICE_CODES = (('alc', 0), ('soc', 1))  # (column tag, response code)

    # --- Load Data and Trace ---
    print(f"--- Loading data and trace for winning model: {WINNING_MODEL_NAME} ---")
    trace_file_path = os.path.join(RESULTS_DIR, f"trace_{WINNING_MODEL_NAME}.nc")
    if not os.path.exists(trace_file_path):
        print(f"ERROR: Trace file not found at '{trace_file_path}'")
        return
    if not os.path.exists(DATA_FILE_PATH):
        print(f"ERROR: Data file not found at '{DATA_FILE_PATH}'")
        return

    trace = az.from_netcdf(trace_file_path)
    all_data = pd.read_csv(DATA_FILE_PATH)

    subject_labels = pd.Categorical(all_data['subj_idx']).categories.tolist()
    session_labels = ['early', 'late', 'pun']

    tertile_results = []

    print("\n--- Running Tertile Analysis ---")
    # --- Main Analysis Loop ---
    for subject_idx, subject_label in enumerate(tqdm(subject_labels, desc="Processing Subjects")):
        for session_label in session_labels:
            # Posterior means of each parameter. k and tau are sliced by session
            # only if they have a session dimension in this trace -- unlike M_SC,
            # M_SCT may let some of them vary by session (confirm against the model).
            means = {name: _select_subject_session(trace.posterior, name, subject_idx, session_label).mean()
                     for name in PARAM_NAMES}

            # --- Run Simulation ---
            sim_choices, sim_rts = simulate_lba_race(
                n_trials=N_SIMULATIONS,
                v_alc=means['v_alcohol'], v_soc=means['v_social'],
                A=means['A'], k=means['k'], tau=means['tau'],
            )
            sim_rts = np.minimum(sim_rts, RT_CAP)

            # --- Get Real Data ---
            real_data_slice = all_data[(all_data['subj_idx'] == subject_label) & (all_data['session_type'] == session_label)]

            # --- Tertiles per choice type (column order: sim t1, sim t2, real t1, real t2) ---
            row = {'subject': subject_label, 'session': session_label}
            for tag, code in CHOICE_CODES:
                sim_t1, sim_t2 = calculate_tertiles(sim_rts[sim_choices == code])
                real_t1, real_t2 = calculate_tertiles(
                    real_data_slice.loc[real_data_slice['response'] == code, 'rt'].values
                )
                row[f'sim_{tag}_tertile1'] = sim_t1
                row[f'sim_{tag}_tertile2'] = sim_t2
                row[f'real_{tag}_tertile1'] = real_t1
                row[f'real_{tag}_tertile2'] = real_t2
            tertile_results.append(row)

    # --- Save Final DataFrame ---
    tertile_df = pd.DataFrame(tertile_results)
    output_csv_path = os.path.join(OUTPUT_DIR, "Tertile_Analysis_Results.csv")
    tertile_df.to_csv(output_csv_path, index=False)

    print("\n--- Analysis Complete ---")
    print(f"Tertile analysis results saved to: {output_csv_path}")

# --- Run the main function ---
if __name__ == '__main__':
    main()


--- Loading data and trace for winning model: M_SCT ---

--- Running Tertile Analysis ---


Processing Subjects: 100%|█████████████████████████████████████████████████████████████| 23/23 [00:03<00:00,  6.18it/s]


--- Analysis Complete ---
Tertile analysis results saved to: Winning_Model_Analysis_M_SCT\Tertile_Analysis_Results.csv





In [3]:
# =============================================================================
# LBA Model - RT Distribution Histogram Plotter
#
# Description:
# This script provides a detailed visual posterior predictive check by generating
# frequency histograms that directly compare the distribution of observed (real)
# reaction times (RTs) against the distribution of simulated RTs from the
# winning LBA model (M_SCT).
#
# The script will:
# 1. Load the winning model trace and the real behavioral data.
# 2. For each subject and each session (early, late, pun):
#    a. Run a large simulation using the subject's posterior-mean model parameters.
#    b. Create a plot with two panels:
#       - Left Panel: Overlaid histograms of real vs. simulated RTs for
#         ALCOHOL choices.
#       - Right Panel: Overlaid histograms of real vs. simulated RTs for
#         SOCIAL choices.
# 3. Save each plot as a separate PNG file in a dedicated output directory.
#
# This visualization is crucial for diagnosing model misfit. A good model
# will generate simulated RT distributions that closely match the shape,
# center, and spread of the real data.
# =============================================================================

# --- 1. Import Necessary Libraries ---
import arviz as az
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# --- 2. Helper Function ---

def simulate_lba_race(n_trials, v_alc, v_soc, A, k, tau):
    """
    Simulate LBA race outcomes for one parameter set.

    Returns a ``(choices, rts)`` pair of length-``n_trials`` arrays: ``choices``
    is 1 where the social accumulator reaches the threshold ``A + k`` strictly
    first and 0 otherwise (alcohol first, or a tie); ``rts`` is the first
    finishing time plus ``tau``. Start points are U(0, A) and drifts N(v, 1).
    """
    boundary = A + k
    sd = 1.0

    alc_start = np.random.uniform(0, A, n_trials)
    soc_start = np.random.uniform(0, A, n_trials)
    alc_rate = np.random.normal(v_alc, sd, n_trials)
    soc_rate = np.random.normal(v_soc, sd, n_trials)

    # Negative sampled drifts are floored at 1e-8: that accumulator practically never finishes.
    alc_finish = (boundary - alc_start) / np.maximum(alc_rate, 1e-8)
    soc_finish = (boundary - soc_start) / np.maximum(soc_rate, 1e-8)

    social_first = alc_finish > soc_finish
    return social_first.astype(int), np.minimum(alc_finish, soc_finish) + tau

# --- 3. Main Execution Block ---
def _select_subject_session(posterior, var_name, subject_idx, session_label):
    """
    Return the flattened posterior draws of one parameter for one subject/session.

    ``posterior`` is the ``trace.posterior`` dataset. The ``session`` coordinate
    is selected only when the variable actually carries a ``session`` dimension,
    so session-invariant and session-varying parameters are handled the same
    way; selecting by subject alone would otherwise silently pool the draws of
    a session-varying parameter across sessions.
    """
    samples = posterior[var_name].sel(subject_idx=subject_idx)
    if 'session' in samples.dims:
        samples = samples.sel(session=session_label)
    return np.asarray(samples.values).ravel()


def _plot_rt_panel(ax, real_rts, sim_rts, title, ylabel):
    """Overlay observed and simulated RT density histograms for one choice type on *ax*."""
    # Bins span 0 to the slowest observed RT (at least 20 s).
    # NOTE(review): simulated RTs beyond this range fall outside the bins and are
    # presumably excluded from the simulated density's normalisation -- confirm
    # this truncation is acceptable when comparing distribution tails.
    max_rt = max(real_rts.max() if not real_rts.empty else 20, 20)
    bins = np.linspace(0, max_rt, 50)

    sns.histplot(real_rts, bins=bins, ax=ax, color='royalblue', label='Observed RTs', stat='density')
    sns.histplot(sim_rts, bins=bins, ax=ax, color='orange', alpha=0.6, label='Simulated RTs', stat='density')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Reaction Time (s)', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.set_xlim(0, max_rt)


def main():
    """
    Simulate each subject/session from posterior-mean parameters and save RT histograms.

    For every subject and session, ``N_SIMULATIONS`` trials are simulated and a
    two-panel figure (alcohol choices left, social choices right) overlaying
    observed and simulated RT densities is written to
    ``<OUTPUT_DIR>/RT_Histogram_Plots/RT_Histogram_<subject>_<session>.png``.
    Returns early (printing an error) if the trace or data file is missing.
    """
    # --- Configuration ---
    WINNING_MODEL_NAME = "M_SCT"
    # Single backslashes: inside a raw string, doubled backslashes are literal.
    DATA_FILE_PATH = r'C:\Users\drfox\LBA_Gemini\aIC_Choice.csv'
    RESULTS_DIR = 'LBA_Model_Power_Set'
    OUTPUT_DIR = 'Winning_Model_Analysis_M_SCT'
    PLOT_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "RT_Histogram_Plots")
    os.makedirs(PLOT_OUTPUT_DIR, exist_ok=True)

    N_SIMULATIONS = 10000
    RT_CAP = 120.0  # cap applied to simulated RTs
    PARAM_NAMES = ('v_alcohol', 'v_social', 'A', 'k', 'tau')
    PANELS = ((0, 'Alcohol Choices', 'Density'), (1, 'Social Choices', ''))  # (response code, title, y-label)

    # --- Load Data and Trace ---
    print(f"--- Loading data and trace for model: {WINNING_MODEL_NAME} ---")
    trace_file_path = os.path.join(RESULTS_DIR, f"trace_{WINNING_MODEL_NAME}.nc")
    if not os.path.exists(trace_file_path):
        print(f"ERROR: Trace file not found at '{trace_file_path}'")
        return
    if not os.path.exists(DATA_FILE_PATH):
        print(f"ERROR: Data file not found at '{DATA_FILE_PATH}'")
        return

    trace = az.from_netcdf(trace_file_path)
    all_data = pd.read_csv(DATA_FILE_PATH)

    subject_labels = pd.Categorical(all_data['subj_idx']).categories.tolist()
    session_labels = ["early", "late", "pun"]

    print(f"\n--- Generating and saving {len(subject_labels) * len(session_labels)} histogram plots ---")

    # --- Main Plotting Loop ---
    for subject_idx, subject_label in enumerate(tqdm(subject_labels, desc="Processing Subjects")):
        for session_label in session_labels:
            # Posterior means; k/tau are sliced by session only if this trace gives them a session dimension.
            means = {name: _select_subject_session(trace.posterior, name, subject_idx, session_label).mean()
                     for name in PARAM_NAMES}

            # --- Run Simulation ---
            sim_choices, sim_rts = simulate_lba_race(
                n_trials=N_SIMULATIONS, v_alc=means['v_alcohol'], v_soc=means['v_social'],
                A=means['A'], k=means['k'], tau=means['tau'],
            )
            sim_rts = np.minimum(sim_rts, RT_CAP)

            # --- Get Real Data Slice ---
            real_data_slice = all_data[(all_data['subj_idx'] == subject_label) & (all_data['session_type'] == session_label)]

            # --- Create, Save and Close Plot ---
            fig, axes = plt.subplots(1, 2, figsize=(18, 7), sharey=True)
            try:
                fig.suptitle(f'RT Distributions for Subject: {subject_label} | Session: {session_label}', fontsize=18)
                for ax, (code, title, ylabel) in zip(axes, PANELS):
                    _plot_rt_panel(
                        ax,
                        real_data_slice[real_data_slice['response'] == code]['rt'],
                        sim_rts[sim_choices == code],
                        title,
                        ylabel,
                    )
                fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
                plot_filename = f"RT_Histogram_{subject_label}_{session_label}.png"
                fig.savefig(os.path.join(PLOT_OUTPUT_DIR, plot_filename), dpi=150)
            finally:
                # Always release the figure so a failed plot does not accumulate open figures.
                plt.close(fig)

    print(f"\n--- All plots saved successfully to: {PLOT_OUTPUT_DIR} ---")

# --- Run the main function ---
if __name__ == '__main__':
    main()


--- Loading data and trace for model: M_SCT ---

--- Generating and saving 69 histogram plots ---


Processing Subjects: 100%|█████████████████████████████████████████████████████████████| 23/23 [01:20<00:00,  3.49s/it]


--- All plots saved successfully to: Winning_Model_Analysis_M_SCT\RT_Histogram_Plots ---



