In [2]:
# =============================================================================
# Hierarchical LBA Model Fitting Script (Full Power Set)
#
# Description:
# This script implements the recommendation of Dr. Heathcote to perform a
# systematic model comparison using the "power set" of potential psychological
# mechanisms. It defines and fits 8 hierarchical Linear Ballistic Accumulator
# (LBA) models to test for the presence of:
#
# 1. Stimulus Bias (S): Differences in drift rates (v) across sessions.
# 2. Caution (C): Differences in start-point variability (A) across sessions.
# 3. Response Bias (R): Differences in the start-point component (k) across
#    sessions. This is equivalent to a start-point bias.
#
# The script fits all 8 possible combinations of these mechanisms (from a
# baseline model with none, to a full model with all three). Each model's
# trace will be saved to a .nc file, ready for a separate model comparison
# analysis.
#
# This script ONLY performs the model fitting.
# =============================================================================

# --- 1. Import Necessary Libraries ---
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import os
import traceback

# --- 2. Custom Log-Likelihood Function ---
# This is the mathematical heart of the LBA model. It's a custom function
# that tells PyMC how to calculate the log-probability of observing the
# actual data (the specific reaction times and choices) given a set of
# proposed model parameters (v, A, k, tau).
def logp(rt, choice, v_alc, v_soc, k_alc, k_soc, A, tau):
    """
    Custom LBA log-likelihood function for PyMC.
    Calculates the log-likelihood of observing the reaction times and choices.
    """
    # Define a small constant (epsilon) for numerical stability. This is crucial
    # to prevent calculations like log(0), which would result in -inf and crash
    # the sampler due to floating-point precision limits (underflow).
    epsilon = 1e-8

    # Convert the input data arrays to PyTensor tensors. This is necessary
    # because all math within a PyMC model must be done using PyTensor objects,
    # which allow for automatic differentiation.
    rt = pt.as_tensor_variable(rt)
    choice = pt.as_tensor_variable(choice)

    # --- LBA Parameter Setup ---
    # Drift rate standard deviation is fixed to 1 for model identifiability.
    s = 1.0
    # The decision threshold 'b' is the sum of start-point variability 'A'
    # and the threshold component 'k'.
    b_alc = A + k_alc
    b_soc = A + k_soc
    # Non-decision time 't0' is the parameter 'tau'.
    t0 = tau

    # 't' is the decision time: total RT minus non-decision time.
    t = rt - t0

    # --- Parameter Guards ---
    # Ensure key parameters are always positive by clipping them at epsilon.
    v_alc = pt.maximum(v_alc, epsilon)
    v_soc = pt.maximum(v_soc, epsilon)
    b_alc = pt.maximum(b_alc, epsilon)
    b_soc = pt.maximum(b_soc, epsilon)
    A = pt.maximum(A, epsilon)
    # Use a "safe" version of t for calculations to avoid errors.
    t_safe = pt.maximum(t, epsilon)

    # --- LBA PDF and CDF Calculations ---
    # These are the analytical solutions for the probability density function (pdf)
    # and cumulative distribution function (cdf) of the LBA model, as derived
    # by Brown & Heathcote (2008).
    # PDF: The probability of finishing at a specific time 't'.
    # CDF: The probability of finishing at or before time 't'.

    # For the Alcohol accumulator
    term1_alc = (b_alc - A - t_safe * v_alc) / (t_safe * s)
    term2_alc = (b_alc - t_safe * v_alc) / (t_safe * s)
    pdf_alc = (v_alc * pt.exp(pm.logp(pm.Normal.dist(0, 1), term1_alc)) -
               (b_alc - A) / t_safe * pt.exp(pm.logp(pm.Normal.dist(0, 1), term2_alc)))
    cdf_alc = 1 + ((b_alc - A - t_safe * v_alc) / A * pt.exp(pm.logp(pm.Normal.dist(0, 1), term1_alc)) -
                   (b_alc - t_safe * v_alc) / A * pt.exp(pm.logp(pm.Normal.dist(0, 1), term2_alc)))

    # For the Social accumulator
    term1_soc = (b_soc - A - t_safe * v_soc) / (t_safe * s)
    term2_soc = (b_soc - t_safe * v_soc) / (t_safe * s)
    pdf_soc = (v_soc * pt.exp(pm.logp(pm.Normal.dist(0, 1), term1_soc)) -
               (b_soc - A) / t_safe * pt.exp(pm.logp(pm.Normal.dist(0, 1), term2_soc)))
    cdf_soc = 1 + ((b_soc - A - t_safe * v_soc) / A * pt.exp(pm.logp(pm.Normal.dist(0, 1), term1_soc)) -
                   (b_soc - t_safe * v_soc) / A * pt.exp(pm.logp(pm.Normal.dist(0, 1), term2_soc)))

    # --- Final Log-Likelihood Calculation ---
    # The likelihood of a choice is the probability of the winning accumulator
    # finishing at time 't' (its PDF) multiplied by the probability of the losing
    # accumulator NOT finishing by time 't' (1 - its CDF).
    # We use pt.log and add them because we are working in log-probability space.
    # The pt.maximum(..., epsilon) calls prevent log(0) errors.
    log_likelihood = pt.switch(
        pt.eq(choice, 0), # If the choice was alcohol (0)
        pt.log(pt.maximum(pdf_alc, epsilon)) + pt.log(pt.maximum(1 - cdf_soc, epsilon)),
        # Otherwise (if the choice was social)
        pt.log(pt.maximum(pdf_soc, epsilon)) + pt.log(pt.maximum(1 - cdf_alc, epsilon))
    )

    # --- Final Stability Guard ---
    # This is a critical check. If the sampler proposes a non-decision time `tau`
    # that is longer than the actual reaction time `rt`, the decision time `t`
    # would be negative, which is physically impossible.
    # This line tells the sampler: if any `t` is negative, the log-probability
    # for this set of parameters is -infinity, so reject this step immediately.
    # Otherwise, return the sum of all trial log-likelihoods.
    return pt.switch(pt.any(t < 0), -np.inf, pt.sum(log_likelihood))


# --- 3. Model Building Function ---
# This is a single, flexible "factory" function that can build any of our 8 models.
# It uses boolean flags (S, C, R) to determine which parameters should be
# allowed to vary across the three experimental sessions.
def build_model(data, n_subjects, n_sessions, tau_upper, S=False, C=False, R=False, T=False):
    """
    A single, generalized function to build any of the 8 models.
    """
    coords = {
        "subject_idx": np.arange(n_subjects),
        "session": ["early", "late", "pun"]
    }

    with pm.Model(coords=coords) as model:
        # --- Priors for Group-Level Parameters (The Population Level) ---
        # These priors define our beliefs about the population of subjects.
        # Each subject's parameter will be drawn from these group distributions.
        # `_group_mu`: The average value of a parameter across all subjects.
        # `_group_sigma`: The standard deviation, representing individual differences.
        # The `dims="session" if S else None` logic is the core of the power set.
        # If a flag (e.g., S) is True, the parameter gets a separate group mean
        # for each session. Otherwise, it has one mean across all sessions.
        v_alc_group_mu = pm.Normal('v_alc_group_mu', mu=1, sigma=1, dims="session" if S else None)
        v_soc_group_mu = pm.Normal('v_soc_group_mu', mu=1, sigma=1, dims="session" if S else None)
        A_group_mu_log = pm.Normal('A_group_mu_log', mu=-1, sigma=1, dims="session" if C else None)
        k_group_mu_log = pm.Normal('k_group_mu_log', mu=-1, sigma=1, dims="session" if R else None)
        
        # Non-decision time (tau) is assumed to NOT be a stable trait of a subject
        tau_group_mu_log = pm.Normal('tau_group_mu_log', mu=-1, sigma=1, dims="session" if T else None)
        
        tau_group_sigma = pm.HalfNormal('tau_group_sigma', sigma=0.1)

        # Sigma parameters (subject variability) do not vary by session.
        v_alc_group_sigma = pm.HalfNormal('v_alc_group_sigma', sigma=1)
        v_soc_group_sigma = pm.HalfNormal('v_soc_group_sigma', sigma=1)
        A_group_sigma = pm.HalfNormal('A_group_sigma', sigma=0.5)
        k_group_sigma = pm.HalfNormal('k_group_sigma', sigma=0.5)
        """
        # Non-decision time (tau) is assumed to be a stable trait of a subject
        # and therefore does not vary by session.
        tau_group_mu_log = pm.Normal('tau_group_mu_log', mu=-1, sigma=1)
        tau_group_sigma = pm.HalfNormal('tau_group_sigma', sigma=0.1)
        """
        # --- Subject-Level Parameters (Non-centered Parameterization) ---
        # This is a standard and highly recommended technique for improving MCMC
        # sampling efficiency in hierarchical models. Instead of directly
        # estimating each subject's parameter, we estimate their `offset` from
        # the group mean, measured in units of standard deviations.
        v_alc_offset = pm.Normal('v_alc_offset', mu=0, sigma=1, dims=("subject_idx", "session") if S else "subject_idx")
        v_soc_offset = pm.Normal('v_soc_offset', mu=0, sigma=1, dims=("subject_idx", "session") if S else "subject_idx")
        A_offset = pm.Normal('A_offset', mu=0, sigma=1, dims=("subject_idx", "session") if C else "subject_idx")
        k_offset = pm.Normal('k_offset', mu=0, sigma=1, dims=("subject_idx", "session") if R else "subject_idx")
        tau_offset = pm.Normal('tau_offset', mu=0, sigma=1, dims=("subject_idx", "session") if T else "subject_idx")

        # --- Transform to Final Subject-Level Parameters ---
        # Here, we reconstruct the final, interpretable parameters for each
        # subject from the group parameters and the estimated offsets.
        # Formula: subject_param = group_mean + offset * group_std
        # We use pm.math.exp for parameters defined in log space (A, k, tau)
        # to ensure they are always positive.
        v_alc = pm.Deterministic('v_alcohol', v_alc_group_mu + v_alc_offset * v_alc_group_sigma, dims=("subject_idx", "session") if S else "subject_idx")
        v_soc = pm.Deterministic('v_social', v_soc_group_mu + v_soc_offset * v_soc_group_sigma, dims=("subject_idx", "session") if S else "subject_idx")
        A = pm.Deterministic('A', pm.math.exp(A_group_mu_log + A_offset * A_group_sigma), dims=("subject_idx", "session") if C else "subject_idx")
        k = pm.Deterministic('k', pm.math.exp(k_group_mu_log + k_offset * k_group_sigma), dims=("subject_idx", "session") if R else "subject_idx")
        tau = pm.Deterministic('tau', pm.math.exp(tau_group_mu_log + tau_offset * tau_group_sigma), dims=("subject_idx", "session") if T else "subject_idx")

        # --- Connect Likelihood to the Model ---
        # This is the final step where we "plug in" our custom logp function.
        # pm.Potential adds the output of our function (the total log-likelihood
        # of the data given the parameters) to the overall model log-probability.
        # The complex indexing (e.g., v_alc[...]) ensures that for each row of
        # data, we pass the correct parameter values for that specific trial's
        # subject and session into the logp function.
        pm.Potential('likelihood', logp(
            rt=data['rt'].values,
            choice=data['response'].values,
            v_alc=v_alc[data['subj_idx_code'].values, data['session_code'].values] if S else v_alc[data['subj_idx_code'].values],
            v_soc=v_soc[data['subj_idx_code'].values, data['session_code'].values] if S else v_soc[data['subj_idx_code'].values],
            k_alc=k[data['subj_idx_code'].values, data['session_code'].values] if R else k[data['subj_idx_code'].values],
            k_soc=k[data['subj_idx_code'].values, data['session_code'].values] if R else k[data['subj_idx_code'].values], # Assuming symmetric start points
            A=A[data['subj_idx_code'].values, data['session_code'].values] if C else A[data['subj_idx_code'].values],
            tau=tau[data['subj_idx_code'].values, data['session_code'].values] if T else tau[data['subj_idx_code'].values],
        ))
    return model


# --- 4. Main Execution Block ---
if __name__ == "__main__":
    # --- Configuration ---
    data_file_path = r'C:\Users\drfox\LBA_Gemini\aIC_Choice.csv'
    output_dir = 'LBA_Model_Power_Set'
    os.makedirs(output_dir, exist_ok=True)

    # --- Sampling Settings ---
    # These settings control the MCMC sampler. For the final, publication-quality
    # run, you should use a high number of draws and tuning steps.
    SAMPLING_SETTINGS = {
        "draws": 2000, # Final value
        "tune": 1500,  # Final value
        "chains": 4,   # Number of parallel chains
        "cores": 1,    # Set to number of available CPU cores
        "target_accept": 0.95, # Helps with complex models
        "init": 'advi+adapt_diag' # Robust initialization method
    }

    # --- Data Preparation ---
    print("1. Loading and preparing data for model fitting...")
    all_data = pd.read_csv(data_file_path)
    all_data = all_data.dropna(subset=['rt', 'response'])
    all_data['response'] = all_data['response'].astype(int)

    # Convert string identifiers for subject and session into integer codes (0, 1, 2...).
    # This is essential for indexing the parameter arrays inside the PyMC model.
    all_data['subj_idx_code'] = pd.Categorical(all_data['subj_idx']).codes
    all_data['session_code'] = pd.Categorical(all_data['session_type'], categories=['early', 'late', 'pun'], ordered=True).codes

    n_subjects = all_data['subj_idx'].nunique()
    n_sessions = all_data['session_type'].nunique()
    tau_upper_limit = max(0.05, all_data['rt'].min() - 0.01)

    # --- Define the 8 Models of the Power Set ---
    # This dictionary is the control center for the script. It defines each of
    # the 8 models we want to fit by specifying which mechanisms (S, C, R)
    # are allowed to vary across sessions for that model.
    models_to_fit = {
        # Baseline model: Nothing varies
        "M_base": {"S": False, "C": False, "R": False},
        # Single-mechanism models
        "M_S":    {"S": True,  "C": False, "R": False},
        "M_C":    {"S": False, "C": True,  "R": False},
        "M_R":    {"S": False, "C": False, "R": True},
        # Double-mechanism models
        "M_SC":   {"S": True,  "C": True,  "R": False},
        "M_SR":   {"S": True,  "C": False, "R": True},
        "M_CR":   {"S": False, "C": True,  "R": True},
        # Triple-mechanism model
        "M_SCR":  {"S": True,  "C": True,  "R": True},
        # Adding tau
        "M_SCT":  {"S": True,  "C": True,  "R": False, "T": True},
    }

    # --- Model Fitting Loop ---
    # This loop iterates through our dictionary, building, fitting, and saving
    # each model one by one. This is the part that will take a very long time.
    for name, params in models_to_fit.items():
        print(f"\n--- Fitting {name} ---")

        # --- ADDED: Check if the model has already been fit ---
        trace_path = os.path.join(output_dir, f'trace_{name}.nc')
        if os.path.exists(trace_path):
            print(f"   Trace for {name} already exists. Skipping.")
            continue # Skip to the next model in the loop

        print(f"   Mechanisms varying: S={params['S']}, C={params['C']}, R={params['R']}")
        try:
            # Build the model using our flexible factory function
            model = build_model(all_data, n_subjects, n_sessions, tau_upper_limit, **params)
            # Run the NUTS sampler
            with model:
                trace = pm.sample(**SAMPLING_SETTINGS)

                # Save the results immediately after a model finishes. This is crucial
                # so that if the script fails on a later model, you don't lose
                # the progress from the ones that completed successfully.
                trace.to_netcdf(trace_path)
                print(f"   Trace for {name} saved to {trace_path}")

        except Exception as e:
            print(f"❌ ERROR fitting {name}: {e}")
            traceback.print_exc()

    print("\n\n=== Full Power Set Model Fitting Script Complete ===")



1. Loading and preparing data for model fitting...

--- Fitting M_base ---
   Trace for M_base already exists. Skipping.

--- Fitting M_S ---
   Trace for M_S already exists. Skipping.

--- Fitting M_C ---
   Trace for M_C already exists. Skipping.

--- Fitting M_R ---
   Trace for M_R already exists. Skipping.

--- Fitting M_SC ---
   Trace for M_SC already exists. Skipping.

--- Fitting M_SR ---
   Trace for M_SR already exists. Skipping.

--- Fitting M_CR ---
   Trace for M_CR already exists. Skipping.

--- Fitting M_SCR ---
   Trace for M_SCR already exists. Skipping.

--- Fitting M_SCT ---
   Mechanisms varying: S=True, C=True, R=False


Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...


Convergence achieved at 7600
Interrupted at 7,599 [3%]: Average Loss = 35,165


❌ ERROR fitting M_SCT: Initial evaluation of model at starting point failed!
Starting values:
{'v_alc_group_mu': array([0.58078595, 1.65252066, 0.63474878]), 'v_soc_group_mu': array([1.60305649, 0.82584417, 0.75681287]), 'A_group_mu_log': array([-1.06411811, -1.2729501 , -0.6104757 ]), 'k_group_mu_log': array(-3.19796383), 'tau_group_mu_log': array([ 1.1849156 , -0.17553242, -1.50502301]), 'tau_group_sigma_log__': array(-2.57088267), 'v_alc_group_sigma_log__': array(-0.55836067), 'v_soc_group_sigma_log__': array(-0.31392122), 'A_group_sigma_log__': array(-1.33268931), 'k_group_sigma_log__': array(-0.30629248), 'v_alc_offset': array([[-0.40094456,  0.71203411, -0.80920281],
       [-1.17650992,  0.98310257, -0.36228311],
       [ 0.46279949, -1.15308493,  0.32646702],
       [-0.01080394,  0.76816937, -0.84467209],
       [-0.59360142, -1.28626487, -0.32531806],
       [-0.10144885,  0.34600491, -0.42366648],
       [ 0.09984128, -0.68965985, -0.74357773],
       [-1.60357831,  0.551865

Traceback (most recent call last):
  File "C:\Users\drfox\AppData\Local\Temp\ipykernel_5084\281306268.py", line 281, in <module>
    trace = pm.sample(**SAMPLING_SETTINGS)
  File "C:\Users\drfox\anaconda3\envs\lba_env\lib\site-packages\pymc\sampling\mcmc.py", line 792, in sample
    model.check_start_vals(ip)
  File "C:\Users\drfox\anaconda3\envs\lba_env\lib\site-packages\pymc\model\core.py", line 1745, in check_start_vals
    raise SamplingError(
pymc.exceptions.SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'v_alc_group_mu': array([0.58078595, 1.65252066, 0.63474878]), 'v_soc_group_mu': array([1.60305649, 0.82584417, 0.75681287]), 'A_group_mu_log': array([-1.06411811, -1.2729501 , -0.6104757 ]), 'k_group_mu_log': array(-3.19796383), 'tau_group_mu_log': array([ 1.1849156 , -0.17553242, -1.50502301]), 'tau_group_sigma_log__': array(-2.57088267), 'v_alc_group_sigma_log__': array(-0.55836067), 'v_soc_group_sigma_log__': array(-0.31392122), 'A_group

## 