In [1]:
#!/usr/bin/env python3
# ---------------------------------------------------------
#  fit_human_reversal.py  (Python ≥3.8, NumPy ≥1.19, SciPy ≥1.7)
#  ---------------------------------------------------------
#  Model–behaviour fitting for the FitzGerald et al. (2017)
#  probabilistic reversal‑learning dataset (OpenNeuro ds000222).
#
#  • CoDA  : prospective TD + retrospective splitting
#  • ANCCR : retrospective learning only
#
#  Outputs:
#  ─ summary_human_fits.csv
#  ─ figures/human_fit_BIC.png
# ---------------------------------------------------------
import pathlib, tarfile, json, math, tempfile, shutil, urllib.request, itertools
import numpy as np
import pandas  as pd
from   scipy.optimize import minimize
import matplotlib.pyplot as plt

# ---------------------------------------------------------
# 1. Fetch / unpack one behavioural tar‑ball ----------------
# ---------------------------------------------------------
DATA_URL = (
    "https://openneuro.org/crn/datasets/ds000222/snapshots/"
    "00001/files/ds000222_RL_behaviour.tar?download=1"
)
data_root = pathlib.Path("data_raw")
beh_tar   = data_root / "ds000222_beh.tar"

if not beh_tar.exists():
    print("⏳ Downloading behavioural archive …")
    data_root.mkdir(exist_ok=True, parents=True)
    urllib.request.urlretrieve(DATA_URL, beh_tar)

print("📦 Extracting …")
tmpdir = tempfile.TemporaryDirectory()
with tarfile.open(beh_tar) as tar:
    tar.extractall(path=tmpdir.name)

# ---------------------------------------------------------
# 2. Load each subject’s CSV → list of trials ---------------
# ---------------------------------------------------------
def get_subject_files(root):
    return sorted(pathlib.Path(root).glob("sub-*/*beh.tsv"))

def load_trials_tsv(file_path):
    tsv = pd.read_csv(file_path, sep="\t")
    # Convert to simple vectors
    choice = tsv["choice"].values       # 0 (for left) / 1 (for right)
    reward = tsv["reward"].values.astype(int)  # 0/1
    return choice, reward

beh_files = get_subject_files(tmpdir.name)
print(f"✅ Found {len(beh_files)} behavioural files")

# ---------------------------------------------------------
# 3. Utility – softmax & neg‑log‑lik ------------------------
# ---------------------------------------------------------
def softmax(q, beta):
    expq = np.exp(beta * (q - np.max(q)))
    return expq / expq.sum()

# ---------------------------------------------------------
# 4‑A.  CoDA likelihood (with context splitting) ------------
# ---------------------------------------------------------
def negloglik_coda(params, choices, outcomes):
    """params = (alpha, beta, splitThreshold)"""
    alpha, beta, th = params
    # State 0 initial Q‑values
    Q   = [[0.0, 0.0]]            # list of [Q_left,Q_right] per state
    ctx = 0                       # active state index
    recent_err = []
    nll = 0.0
    for c, r in zip(choices, outcomes):
        p_left = softmax(Q[ctx], beta)[0]
        prob   = p_left if c == 0 else 1 - p_left
        nll   -= np.log(prob + 1e-10)
        # TD update on chosen arm
        delta  = r - Q[ctx][c]
        Q[ctx][c] += alpha * delta
        # store absolute error for split test
        recent_err.append(abs(delta))
        if len(recent_err) > 4:
            recent_err.pop(0)
        # split if four consecutive |δ|>th
        if len(recent_err) == 4 and min(recent_err) > th:
            Q.append(Q[ctx].copy())   # clone values
            ctx = len(Q)-1
            recent_err.clear()
    return nll

# ---------------------------------------------------------
# 4‑B.  ANCCR‑like retrospective likelihood -----------------
# ---------------------------------------------------------
def negloglik_anccr(params, choices, outcomes):
    """params = (alpha, beta)"""
    alpha, beta = params
    Q = [0.0, 0.0]
    nll = 0.0
    for c, r in zip(choices, outcomes):
        p_left = softmax(Q, beta)[0]
        prob   = p_left if c == 0 else 1 - p_left
        nll   -= np.log(prob + 1e-10)
        # retrospective update only at reward
        if r == 1:
            Q[c] += alpha * (1 - Q[c])
        else:
            Q[c] += alpha * (0 - Q[c])
    return nll

# ---------------------------------------------------------
# 5. Fit each model per subject ----------------------------
# ---------------------------------------------------------
from   functools import partial
results = []

for beh_file in beh_files:
    sub_id = beh_file.parts[-2]  # "sub‑XX"
    choice, rew = load_trials_tsv(beh_file)
    # --- CoDA fit (3 params) ---
    bounds_coda = [(0.001,0.8),(1,12),(0.1,1.5)]
    res_coda = minimize(
        partial(negloglik_coda, choices=choice, outcomes=rew),
        x0=[0.2,5.0,0.6], bounds=bounds_coda, method="L-BFGS-B"
    )
    # --- ANCCR fit (2 params) ---
    bounds_an  = [(0.001,0.8),(1,12)]
    res_an     = minimize(
        partial(negloglik_anccr, choices=choice, outcomes=rew),
        x0=[0.2,5.0], bounds=bounds_an, method="L-BFGS-B"
    )
    n          = len(choice)
    bic_coda   = len(res_coda.x)*np.log(n) + 2*res_coda.fun
    bic_an     = len(res_an.x)*np.log(n)   + 2*res_an.fun
    results.append(
        dict(sub=sub_id,
             coda_nll=res_coda.fun, coda_bic=bic_coda,
             anccr_nll=res_an.fun,  anccr_bic=bic_an,
             coda_alpha=res_coda.x[0], coda_beta=res_coda.x[1], coda_th=res_coda.x[2],
             an_alpha=res_an.x[0],    an_beta=res_an.x[1])
    )
    print(f"{sub_id}:  BIC  CoDA={bic_coda:6.1f}  <  ANCCR={bic_an:6.1f}")

summary_df = pd.DataFrame(results).sort_values("sub")
summary_df.to_csv("summary_human_fits.csv", index=False)
print("📑 Saved per‑subject fit stats → summary_human_fits.csv")

# ---------------------------------------------------------
# 6.  Group‑level BIC bar‑plot ------------------------------
# ---------------------------------------------------------
mean_bic = summary_df[["coda_bic","anccr_bic"]].mean()
plt.figure(figsize=(4,4))
plt.bar(["CoDA","ANCCR"], mean_bic, color=["skyblue","salmon"])
plt.ylabel("Mean BIC  (lower is better)")
plt.title("Human reversal learning – model comparison")
plt.tight_layout()
pathlib.Path("figures").mkdir(exist_ok=True)
plt.savefig("figures/human_fit_BIC.png", dpi=300)
print("🖼️ Saved group BIC plot → figures/human_fit_BIC.png")


⏳ Downloading behavioural archive …


HTTPError: HTTP Error 404: Not Found