In [6]:
# =========================
# Objective-channel incompatibility in DDM (PyDDM)
# FULL fit (robust LL) vs partial objective fits (ACC-only, RTc-only, RTe-only)
# Self-contained, robust across PyDDM API variants.
# =========================

!pip -q install pyddm pandas numpy

import numpy as np
import pandas as pd
import pyddm as ddm
from pyddm import Model
from pyddm.models import DriftConstant, NoiseConstant, BoundConstant, OverlayNonDecision, ICPointSourceCenter
from pyddm.sample import Sample
from pyddm.functions import fit_adjust_model
from pyddm.models import LossRobustLikelihood

# -------------------------
# Settings
# -------------------------
# Numerical grid shared by every model solved in this notebook.
dt = 0.005      # time step used when building models and integrating PDFs
T_dur = 3.0     # simulated duration; RTs outside [0, T_dur] are discarded later
eps = 1e-12     # floor used to avoid division by zero / log(0)
rng = np.random.default_rng(seed=0)

# Parameters of the data-generating model.
v_true, a_true, ter_true = 0.9, 1.2, 0.25
N = 800         # number of trials requested from the generator

# Exhaustive search grids for the partial-objective fits
# (25 x 21 x 10 parameter combinations per sweep).
v_grid = np.linspace(-3.0, 3.0, num=25)
a_grid = np.linspace(0.5, 2.5, num=21)
ter_grid = np.linspace(0.05, 0.50, num=10)

# -------------------------
# Model builders
# -------------------------
def model_const(v, a, ter):
    """Build a PyDDM ``Model`` with fixed (non-fittable) parameters.

    Parameters
    ----------
    v : value passed (as ``float``) to ``DriftConstant(drift=...)``.
    a : value passed (as ``float``) to ``BoundConstant(B=...)``.
    ter : value passed (as ``float``) to ``OverlayNonDecision(nondectime=...)``.

    Returns
    -------
    Model
        Model with noise fixed at 1, a centered point-source initial
        condition, the module-level ``dt`` / ``T_dur`` grid, and choices
        named ``"correct"`` and ``"error"``.
    """
    components = {
        "drift": DriftConstant(drift=float(v)),
        "noise": NoiseConstant(noise=1),
        "bound": BoundConstant(B=float(a)),
        "overlay": OverlayNonDecision(nondectime=float(ter)),
        "IC": ICPointSourceCenter(),
    }
    return Model(
        **components,
        dt=dt,
        T_dur=T_dur,
        choice_names=("correct", "error"),
    )

def model_fittable():
    """Build a PyDDM ``Model`` whose drift, bound and non-decision time are free.

    The three free parameters are declared as ``ddm.Fittable`` with these
    search ranges: drift in [-5, 5], bound ``B`` in [0.3, 2.5] and
    ``nondectime`` in [0.05, 0.8].  Noise is fixed at 1, the initial
    condition is a centered point source, and the module-level ``dt`` /
    ``T_dur`` grid and ``("correct", "error")`` choice names are used.

    Returns
    -------
    Model
        A model ready to be passed to ``fit_adjust_model``.
    """
    free_drift = ddm.Fittable(minval=-5, maxval=5)
    free_bound = ddm.Fittable(minval=0.3, maxval=2.5)
    free_nondecision = ddm.Fittable(minval=0.05, maxval=0.8)

    components = {
        "drift": DriftConstant(drift=free_drift),
        "noise": NoiseConstant(noise=1),
        "bound": BoundConstant(B=free_bound),
        "overlay": OverlayNonDecision(nondectime=free_nondecision),
        "IC": ICPointSourceCenter(),
    }
    return Model(
        **components,
        dt=dt,
        T_dur=T_dur,
        choice_names=("correct", "error"),
    )

# -------------------------
# Solution utilities (API-robust)
# -------------------------
def sol_t_domain(sol):
    """Return the time grid of a solution as a float NumPy array.

    Lookup order:
      1. ``sol.t_domain`` (attribute or zero-argument callable);
      2. ``sol.model.t_domain`` (attribute or zero-argument callable);
      3. a grid rebuilt from the module-level ``T_dur`` and ``dt``.
    """
    def _materialize(grid):
        # ``t_domain`` may be exposed as data or as a method.
        values = grid() if callable(grid) else grid
        return np.array(values, dtype=float)

    if hasattr(sol, "t_domain"):
        return _materialize(sol.t_domain)

    model = getattr(sol, "model", None)
    if hasattr(model, "t_domain"):
        return _materialize(model.t_domain)

    # Nothing exposed on the object: fall back to the notebook's own grid.
    return np.arange(0, T_dur + dt/2, dt, dtype=float)

def sol_pdf_corr_err(sol):
    """Return ``(pdf_correct, pdf_error)`` of a solution as float arrays.

    If ``sol.pdf`` exists and is callable it is queried with the choice
    names ``"correct"`` and ``"error"``; otherwise the ``pdf_corr()`` /
    ``pdf_err()`` accessors are used.
    """
    by_name = getattr(sol, "pdf", None)
    if callable(by_name):
        correct, error = (
            np.array(by_name(label), dtype=float)
            for label in ("correct", "error")
        )
    else:
        # No name-based accessor: use the dedicated per-choice accessors.
        correct = np.array(sol.pdf_corr(), dtype=float)
        error = np.array(sol.pdf_err(), dtype=float)
    return correct, error

def pred_stats(sol):
    """Summarize a solved model by its predicted accuracy and mean RTs.

    The correct/error densities are integrated with a rectangle rule using
    the module-level ``dt``.  Conditional mean RTs are the first moments
    divided by the corresponding choice probability, floored at ``eps`` so
    that a zero-mass choice does not divide by zero.

    Returns
    -------
    dict
        ``{"acc": P(correct), "RTc": mean RT | correct,
        "RTe": mean RT | error, "pE": P(error)}`` as Python floats.
    """
    grid = sol_t_domain(sol)
    dens_correct, dens_error = sol_pdf_corr_err(sol)

    # Truncate all three arrays to a common length before combining them.
    n = min(len(grid), len(dens_correct), len(dens_error))
    grid = grid[:n]
    dens_correct = dens_correct[:n]
    dens_error = dens_error[:n]

    mass_correct = float(np.sum(dens_correct) * dt)
    mass_error = float(np.sum(dens_error) * dt)

    mean_rt_correct = float(np.sum(grid * dens_correct) * dt / max(mass_correct, eps))
    mean_rt_error = float(np.sum(grid * dens_error) * dt / max(mass_error, eps))

    return {
        "acc": mass_correct,
        "RTc": mean_rt_correct,
        "RTe": mean_rt_error,
        "pE": mass_error,
    }

# -------------------------
# Robust choice normalization
# -------------------------
def normalize_choice_to_int(series):
    """
    Recode a column of choices as 1.0 (correct), 0.0 (error) or NaN.

    Despite the name, the result is a *float* Series so that undecidable
    entries can be represented as NaN.

    Rules, by input dtype:

    * object / string: after trimming and lower-casing, the labels
      ``correct, c, 1, true, t`` map to 1.0 and ``error, e, 0, false, f``
      map to 0.0.  Any other non-missing value is parsed as a number and
      mapped to 1.0 if it is > 0, else 0.0.  Values that are missing or
      cannot be parsed as a finite number stay NaN.
    * bool: ``True`` -> 1.0, ``False`` -> 0.0.
    * anything else: parsed as a number; finite values map to 1.0 if > 0,
      else 0.0; non-finite or unparsable values become NaN.

    Parameters
    ----------
    series : pandas.Series
        Raw choice column.

    Returns
    -------
    pandas.Series
        Float Series carrying the same index as ``series``.
    """
    correct_labels = ["correct", "c", "1", "true", "t"]
    error_labels = ["error", "e", "0", "false", "f"]

    # Fill a positional buffer rather than assigning through ``.loc`` so that
    # duplicate or unsorted index labels cannot misalign the result.
    coded = np.full(len(series), np.nan, dtype=float)

    if series.dtype == object or pd.api.types.is_string_dtype(series):
        missing = series.isna().to_numpy(dtype=bool)
        labels = series.astype(str).str.strip().str.lower()

        is_correct = labels.isin(correct_labels).to_numpy(dtype=bool) & ~missing
        is_error = labels.isin(error_labels).to_numpy(dtype=bool) & ~missing
        coded[is_correct] = 1.0
        coded[is_error] = 0.0

        unresolved = ~(missing | is_correct | is_error)
        if unresolved.any():
            numeric = pd.to_numeric(series[unresolved], errors="coerce").to_numpy(
                dtype=float, na_value=np.nan
            )
            # Only finite numbers are decidable.  Comparing NaN with ``> 0``
            # yields False, which would silently record unparsable text as
            # an error trial, so those entries are left as NaN instead.
            finite = np.isfinite(numeric)
            positions = np.flatnonzero(unresolved)[finite]
            coded[positions] = (numeric[finite] > 0).astype(float)
        return pd.Series(coded, index=series.index, dtype="float")

    if pd.api.types.is_bool_dtype(series):
        return series.astype(float)

    numeric = pd.to_numeric(series, errors="coerce").to_numpy(
        dtype=float, na_value=np.nan
    )
    finite = np.isfinite(numeric)
    coded[finite] = (numeric[finite] > 0).astype(float)
    return pd.Series(coded, index=series.index, dtype="float")

# -------------------------
# Simulate data (Solution.sample recommended)
# -------------------------
# Draw N synthetic trials from the solved generating model.
# NOTE(review): the `rng` created in the settings is not passed to this draw in
# the visible code -- confirm how the installed PyDDM seeds Solution.sample if
# run-to-run reproducibility matters.
sol_true = model_const(v_true, a_true, ter_true).solve()
samp = sol_true.sample(N)

df = samp.to_pandas_dataframe()
df = df.rename(columns={"RT":"rt","choice":"choice"})

# RT screening comes first so that only finite, in-window RTs reach the
# choice recoding below.
_rt_in_window = (df["rt"] >= 0) & (df["rt"] <= T_dur)
df = df[np.isfinite(df["rt"]) & _rt_in_window].copy()

# Recode choices to 1 (correct) / 0 (error); rows left as NaN are dropped.
df["choice"] = normalize_choice_to_int(df["choice"])
df = df[np.isfinite(df["choice"])].copy()
df["choice"] = df["choice"].astype(int)
df = df.reset_index(drop=True)

# Descriptive statistics of the retained trials.
_n_trials = len(df)
_is_correct = df["choice"]==1
_is_error = df["choice"]==0
_nan = float("nan")

obs = {
    "N": int(_n_trials),
    "N_correct": int(_is_correct.sum()),
    "N_error": int(_is_error.sum()),
    "acc": float(df["choice"].mean()) if _n_trials else _nan,
    "RTmin": float(df["rt"].min()) if _n_trials else _nan,
    "RTmax": float(df["rt"].max()) if _n_trials else _nan,
    "RTc_mean": float(df.loc[_is_correct,"rt"].mean()) if _is_correct.any() else _nan,
    "RTe_mean": float(df.loc[_is_error,"rt"].mean()) if _is_error.any() else _nan,
}

print("=== OBSERVED (simulated) ===")
print(obs)
print("TRUE generator:", {"v":v_true,"a":a_true,"Ter":ter_true})

# -------------------------
# FULL fit (robust LL)
# -------------------------
# Wrap the cleaned trials in a PyDDM Sample and fit all three free parameters
# with the robust-likelihood loss.
samp_fit = Sample.from_pandas_dataframe(
    df,
    rt_column_name="rt",
    choice_column_name="choice",
)
mf = fit_adjust_model(
    sample=samp_fit,
    model=model_fittable(),
    lossfunction=LossRobustLikelihood,
    verbose=False,
)

# Pull the fitted values out of the nested {component: {parameter: value}}
# mapping and store them as plain floats under the notebook's short names.
p = mf.parameters()
theta_FULL = {
    short: float(p[component][parameter])
    for short, (component, parameter) in {
        "v": ("drift", "drift"),
        "a": ("bound", "B"),
        "Ter": ("overlay", "nondectime"),
    }.items()
}
print("\n=== ESTIMATED PARAMS ===")
print("theta_FULL (robust LL):", theta_FULL)

# -------------------------
# Trial-level LL evaluator
# -------------------------
def trial_LL(v, a, ter, df_eval):
    """Summed trial-level log-likelihood of ``df_eval`` under fixed parameters.

    A model built by ``model_const(v, a, ter)`` is solved, each trial's RT is
    mapped to the nearest time bin (``round(rt / dt)``), and the density of
    the observed choice at that bin is looked up.  Densities are floored at
    ``eps`` before taking the log so that zero-density bins contribute a
    large negative but finite value.

    Parameters
    ----------
    v, a, ter : drift, bound and non-decision time passed to ``model_const``.
    df_eval : DataFrame with an ``"rt"`` column and a ``"choice"`` column
        where 1 means correct; every other value is scored against the
        error density.

    Returns
    -------
    float
        Sum of log densities, or ``-inf`` when no trial has a finite RT in
        ``[0, T_dur]`` or the solution exposes an empty density.
    """
    sol = model_const(v, a, ter).solve()
    pc, pe = sol_pdf_corr_err(sol)

    rt = df_eval["rt"].to_numpy(float)
    ch = df_eval["choice"].to_numpy(int)

    ok = np.isfinite(rt) & (rt >= 0) & (rt <= T_dur)
    rt, ch = rt[ok], ch[ok]

    n_bins = min(len(pc), len(pe))
    if rt.size == 0 or n_bins == 0:
        return -np.inf

    # Clip against the last index the density arrays actually have.  Deriving
    # the bound from int(T_dur / dt) depends on float truncation and on the
    # solver's grid length matching that count, which can index out of range.
    idx = np.clip(np.rint(rt / dt).astype(int), 0, n_bins - 1)

    dens = np.where(ch == 1, pc[idx], pe[idx])
    return float(np.sum(np.log(np.maximum(dens, eps))))

# Trial-level log-likelihood of the observed data at the full-fit parameters;
# the partial-objective fits below are reported relative to this value.
LL_FULL = trial_LL(theta_FULL["v"], theta_FULL["a"], theta_FULL["Ter"], df)

# -------------------------
# Partial objective grid-fits
# -------------------------
def fit_ACC_only(obs_acc):
    """Grid-search (v, a, Ter) to match the observed accuracy only.

    Every combination of ``v_grid`` x ``a_grid`` x ``ter_grid`` is solved and
    scored by the squared difference between predicted and observed
    accuracy.  Ties keep the first combination encountered.

    Parameters
    ----------
    obs_acc : float
        Observed proportion of correct trials; must be finite.

    Returns
    -------
    dict
        ``{"loss", "v", "a", "Ter", "pred_acc"}`` for the best grid point.

    Raises
    ------
    ValueError
        If ``obs_acc`` is NaN or infinite.  Every loss would then be NaN, no
        grid point could win, and the result would carry ``None`` parameters
        that only fail later, after the whole sweep has been paid for.
    """
    # Local import: the file has no top-level itertools import.
    from itertools import product

    if not np.isfinite(obs_acc):
        raise ValueError(f"obs_acc must be a finite number, got {obs_acc!r}")

    best = {"loss": np.inf, "v": None, "a": None, "Ter": None, "pred_acc": None}
    for v, a, ter in product(v_grid, a_grid, ter_grid):
        st = pred_stats(model_const(v, a, ter).solve())
        loss = (st["acc"] - obs_acc) ** 2
        if loss < best["loss"]:
            best = {
                "loss": float(loss),
                "v": float(v),
                "a": float(a),
                "Ter": float(ter),
                "pred_acc": float(st["acc"]),
            }
    return best

def fit_RT_only(which, obs_mean):
    """Grid-search (v, a, Ter) to match one conditional mean RT only.

    Every combination of ``v_grid`` x ``a_grid`` x ``ter_grid`` is solved and
    scored by the squared difference between the predicted and observed
    mean RT of the selected choice.  Ties keep the first combination
    encountered.

    Parameters
    ----------
    which : str
        ``"correct"`` targets the mean RT of correct responses; any other
        value targets the mean RT of error responses.
    obs_mean : float
        Observed mean RT for the selected choice; must be finite.

    Returns
    -------
    dict
        ``{"loss", "v", "a", "Ter", "pred_rt"}`` for the best grid point.

    Raises
    ------
    ValueError
        If ``obs_mean`` is NaN or infinite (for example when the sample holds
        no trials of the selected choice).  Every loss would then be NaN, no
        grid point could win, and the result would carry ``None`` parameters
        that only fail later, after the whole sweep has been paid for.
    """
    # Local import: the file has no top-level itertools import.
    from itertools import product

    if not np.isfinite(obs_mean):
        raise ValueError(
            f"obs_mean must be a finite number, got {obs_mean!r} (which={which!r})"
        )

    stat_key = "RTc" if which == "correct" else "RTe"

    best = {"loss": np.inf, "v": None, "a": None, "Ter": None, "pred_rt": None}
    for v, a, ter in product(v_grid, a_grid, ter_grid):
        pred = pred_stats(model_const(v, a, ter).solve())[stat_key]
        loss = (pred - obs_mean) ** 2
        if loss < best["loss"]:
            best = {
                "loss": float(loss),
                "v": float(v),
                "a": float(a),
                "Ter": float(ter),
                "pred_rt": float(pred),
            }
    return best

print("\n=== Partial objective grid fits ===")

# One exhaustive grid sweep per objective, each matching a single summary
# statistic of the observed data.
theta_ACC = fit_ACC_only(obs["acc"])
theta_RTc = fit_RT_only("correct", obs["RTc_mean"])
theta_RTe = fit_RT_only("error",   obs["RTe_mean"])

# Report the winning grid point together with its loss and matched statistic.
for _label, _theta, _pred_key in (
    ("theta_ACC", theta_ACC, "pred_acc"),
    ("theta_RTc", theta_RTc, "pred_rt"),
    ("theta_RTe", theta_RTe, "pred_rt"),
):
    print(f"{_label}:", {k: _theta[k] for k in ("v", "a", "Ter", "loss", _pred_key)})

# -------------------------
# Cross-evaluate all fits on FULL trial-level LL
# -------------------------
rows = []


def add(name, th):
    """Score parameter set ``th`` on the trial-level LL and append a row to ``rows``."""
    LL = trial_LL(th["v"], th["a"], th["Ter"], df)
    rows.append(
        {
            "name": name,
            "v": th["v"],
            "a": th["a"],
            "Ter": th["Ter"],
            "LL_total": LL,
            "ΔLL_vs_FULL": LL - LL_FULL,
        }
    )


# The full fit goes first so that rows[0] is the reference entry.
for _name, _theta in (
    ("FULL (robust LL)", theta_FULL),
    ("ACC-only (grid)", theta_ACC),
    ("RTc-only (grid)", theta_RTc),
    ("RTe-only (grid)", theta_RTe),
):
    add(_name, _theta)

# Table ranked from best to worst total log-likelihood.
T = pd.DataFrame(rows)
T = T.sort_values("LL_total", ascending=False)
T = T.reset_index(drop=True)

print("\n=== FULL trial-level LL ===")
print(f"LL(theta_FULL) = {LL_FULL:.6f}")
# Skip rows[0] (the full fit itself) and list each partial fit's deficit.
for r in rows[1:]:
    print(f"LL({r['name']}) = {r['LL_total']:.6f}   ΔLL vs FULL = {r['ΔLL_vs_FULL']:.6f}")

print("\nTABLE:")
display(T)

=== OBSERVED (simulated) ===
{'N': 765, 'N_correct': 692, 'N_error': 73, 'acc': 0.9045751633986928, 'RTmin': 0.33541649750138824, 'RTmax': 2.9733155682716976, 'RTc_mean': 1.2237400060121515, 'RTe_mean': 1.1156190544684663}
TRUE generator: {'v': 0.9, 'a': 1.2, 'Ter': 0.25}


Info: Params [0.98179677 1.1567299  0.26172472] gave 834.2686545236294
Info:pyddm:Params [0.98179677 1.1567299  0.26172472] gave 834.2686545236294



=== ESTIMATED PARAMS ===
theta_FULL (robust LL): {'v': 0.9817967720225952, 'a': 1.156729904345945, 'Ter': 0.2617247191093848}

=== Partial objective grid fits ===
theta_ACC: {'v': 2.25, 'a': 0.5, 'Ter': 0.5, 'loss': 5.679691762354787e-09, 'pred_acc': 0.9046505271284463}
theta_RTc: {'v': -0.5, 'a': 1.3, 'Ter': 0.05, 'loss': 1.3518126796917895e-07, 'pred_rt': 1.2233723359590207}
theta_RTe: {'v': 0.0, 'a': 0.8, 'Ter': 0.5, 'loss': 5.37980770550862e-08, 'pred_rt': 1.1153871103436106}

=== FULL trial-level LL ===
LL(theta_FULL) = -834.268655
LL(ACC-only (grid)) = -4025.403297   ΔLL vs FULL = -3191.134643
LL(RTc-only (grid)) = -1790.185140   ΔLL vs FULL = -955.916486
LL(RTe-only (grid)) = -2658.362122   ΔLL vs FULL = -1824.093467

TABLE:


Unnamed: 0,name,v,a,Ter,LL_total,ΔLL_vs_FULL
0,FULL (robust LL),0.981797,1.15673,0.261725,-834.268655,0.0
1,RTc-only (grid),-0.5,1.3,0.05,-1790.18514,-955.916486
2,RTe-only (grid),0.0,0.8,0.5,-2658.362122,-1824.093467
3,ACC-only (grid),2.25,0.5,0.5,-4025.403297,-3191.134643
