In [180]:
import numpy as np
from scipy.optimize import minimize

In [181]:
def fit_evl1(parm, data):    
    chi = 0
    y = data[:, 0].astype(int)
    x = data[:, 1:3]
    nt = len(y)

    eta = np.exp(parm[0]) / (1 + np.exp(parm[0]))
    w = np.exp(parm[1]) / (1 + np.exp(parm[1]))
    c = np.exp(parm[2])
    if c > 10: c = 10  # clamp

    Q = np.zeros(4)

    for tt in range(1, nt):
        t = tt - 1
        yt = np.array([y[t] == i for i in range(1, 5)], dtype=float)
        ytt = np.array([y[tt] == i for i in range(1, 5)], dtype=float)
        xt = x[t]

        win = abs(xt[0]) / 100.0
        loss = abs(xt[1]) / 100.0

        Q = Q + eta * yt * ((1 - w) * win - w * loss - Q)
        th = (t / 10.0) ** c
        z = np.clip(th * Q, -100, 100)
        s = np.exp(z)
        p = s / np.sum(s)
        pp = 0.0001 + 0.9998 * p
        logpp = np.log(pp)



        chi += np.dot(logpp, ytt)

    return -2 * chi


In [182]:
def bin1(parm, data):
    if len(parm) != 4:
        raise ValueError(f"bin1() expects 4 parameters: got {len(parm)}")

    a, b, c, d = parm
    p = np.array([a, b, c, d])
    
    # Normalize to ensure valid probabilities
    if np.any(p < 0) or not np.isclose(np.sum(p), 1.0):
        return 1e10  # Penalize invalid parameter sets

    y = data[:, 0].astype(int)
    nt = len(y)
    chi = 0

    for t in range(1, nt):
        yt = np.array([y[t] == i for i in range(1, 5)], dtype=float)
        pp = np.clip(0.0001 + 0.9998 * p, 1e-8, 1.0)
        chi += np.dot(np.log(pp), yt)

    return -2 * chi


In [187]:
def run_igt_to_sgt(filename_igt='IGT4001to4002.txt', filename_sgt='SGT4001to4002.txt', sub=1, IGT=True):
    igt = np.loadtxt(filename_igt, usecols=(1, 2, 3))
    sgt = np.loadtxt(filename_sgt, usecols=(1, 2, 3))

    nt = 120
    if sub == 1:
        print("Doing sub1")
        dataIGT = igt[:nt]
        dataSGT = sgt[:nt]
    else:
        print("Doing sub2")
        dataIGT = igt[nt:2*nt]
        dataSGT = sgt[nt:2*nt]

    if IGT:
        dataE = dataIGT
        dataG = dataSGT
    else:
        dataE = dataSGT
        dataG = dataIGT

    print(dataIGT[:5])
    print(dataSGT[:5])

    parm0 = [0.0, 0.0, 0.0]
    res_evl = minimize(lambda p: fit_evl1(p, dataE), parm0, method='Nelder-Mead')
    parmE, chiE1 = res_evl.x, res_evl.fun

    y = dataE[:, 0]
    parm0b = [np.mean(y == i) for i in range(1, 5)]
    res_b = minimize(lambda p: bin1(p, dataE), parm0b, method='Nelder-Mead')
    chiB1 = res_b.fun

    eta = np.exp(parmE[0]) / (1 + np.exp(parmE[0]))
    w = np.exp(parmE[1]) / (1 + np.exp(parmE[1]))
    c = np.exp(parmE[2])

    print('EVL parm:')
    print('  learning   losswgt    choice')
    print([eta, w, c])
    print('Chi improvement Base Chi - EVL Chi fit:', chiB1 - chiE1)

    parmR = [0.25] * 4
    chiB2 = bin1(parmR, dataG)
    chiE2 = fit_evl1(parmE, dataG)
    print('Base Chi - EVL Chi generalization:', chiB2 - chiE2)

run_igt_to_sgt(sub=1)
run_igt_to_sgt(sub=2)
run_igt_to_sgt(sub=1, IGT=False)
run_igt_to_sgt(sub=2, IGT=False)

Doing sub1
[[  4.  50.   0.]
 [  2. 100.   0.]
 [  1. 100. 300.]
 [  4.  50.   0.]
 [  1. 100.   0.]]
[[  4.   0.  50.]
 [  3.   0. 100.]
 [  2.  50.   0.]
 [  1.   0. 525.]
 [  1. 100.   0.]]
EVL parm:
  learning   losswgt    choice
[np.float64(0.002768836307331574), np.float64(0.6006429473549582), np.float64(1.8462997529844363)]
Chi improvement Base Chi - EVL Chi fit: 37.55632083321137
Base Chi - EVL Chi generalization: 27.14039982164195
Doing sub2
[[2.00e+00 1.00e+02 1.25e+03]
 [1.00e+00 1.00e+02 2.00e+02]
 [4.00e+00 5.00e+01 0.00e+00]
 [4.00e+00 5.00e+01 0.00e+00]
 [4.00e+00 5.00e+01 0.00e+00]]
[[  1. 100.   0.]
 [  2.   0. 325.]
 [  4.   0.  50.]
 [  3.   0. 100.]
 [  1. 100.   0.]]
EVL parm:
  learning   losswgt    choice
[np.float64(0.08211490067807217), np.float64(0.9999999999999812), np.float64(1.0279594656879296e-95)]
Chi improvement Base Chi - EVL Chi fit: -3.1133664445624163
Base Chi - EVL Chi generalization: -42.72047468362416
Doing sub1
[[  4.  50.   0.]
 [  2. 100.   0.]

In [56]:
def fit_wsls(parm, data):
    chi = 0
    y = data[:, 0].astype(int)
    x = data[:, 1:3]
    nt = len(y)

    thresh = np.exp(parm[0])
    stay_rew = np.exp(parm[1])
    loss_pun = np.exp(parm[2])

    c = np.exp(parm[3])
    c = np.clip(c, 0.01, 5)

    Q = np.zeros(4)

    for tt in range(1, nt):
        t = tt - 1
        yt = np.array([y[t] == i for i in range(1, 5)], dtype=float)
        ytt = np.array([y[tt] == i for i in range(1, 5)], dtype=float)
        xt = x[t]

        win = abs(xt[0]) / 100.0
        loss = abs(xt[1]) / 100.0

        if win - loss > thresh:
            Q = yt * stay_rew
        elif loss - win > thresh:
            Q = yt * loss_pun

        th = np.clip((t / 10) ** c, 1e-5, 1e5)
        z = np.clip(th * Q, -100, 100)
        s = np.exp(z) + 1e-10
        p = s / np.sum(s)
        pp = np.clip(0.0001 + 0.9998 * p, 1e-8, 1.0)

        chi += np.dot(np.log(pp), ytt)

    return -2 * chi


In [76]:
def run_wsls(filename_igt='IGT4001to4002.txt', filename_sgt='SGT4001to4002.txt', sub=1, IGT=True):

    print("Subject", sub, "Model", IGT)
    # Load and prepare data
    igt = np.loadtxt(filename_igt, usecols=(1, 2, 3))
    sgt = np.loadtxt(filename_sgt, usecols=(1, 2, 3))
    nt = 120

    if sub == 1:
        dataIGT = igt[:nt]
        dataSGT = sgt[:nt]
    else:
        dataIGT = igt[nt:2*nt]
        dataSGT = sgt[nt:2*nt]

    # For fitting and generalization
    if IGT:
        dataE = dataIGT  # Estimation data
        dataG = dataSGT  # Generalization test
    else:
        dataE = dataSGT  # Estimation data
        dataG = dataIGT  # Generalization test

    # Print data sources
    print("Estimation data:", "IGT" if np.all(dataE == dataIGT) else "SGT")
    print("Generalization data:", "IGT" if np.all(dataG == dataIGT) else "SGT")

    # Initial WSLS parameter values
    parm0 = [0.0, 0.0, 0.0, 0.0]
    res_wsls = minimize(lambda p: fit_wsls(p, dataE), parm0)
    parmWSLS = res_wsls.x
    chiE1 = res_wsls.fun

    # Baseline fit
    y = dataE[:, 0]
    parm0b = [np.mean(y == i) for i in range(1, 5)]
    res_b = minimize(lambda p: bin1(p, dataE), parm0b)
    chiB1 = res_b.fun

    # Convert parameters
    thresh = np.exp(parmWSLS[0])
    stayRew = np.exp(parmWSLS[1])
    lossPun = np.exp(parmWSLS[2])
    c = np.exp(parmWSLS[3])

    print("WSLS parameters:")
    print("  thresh   stayReward   lossPunishment   choice")
    print([thresh, stayRew, lossPun, c])

    print("Chi improvement Base Chi - WSLS Chi fit (G²):")
    paramDiff = len(parmWSLS) - 3
    print((chiB1 - chiE1) - (paramDiff * np.log(nt)))

    # Generalization
    parmR = [0.25, 0.25, 0.25, 0.25]
    chiB2 = bin1(parmR, dataG)
    chiE2 = fit_wsls(parmWSLS, dataG)
    print("Base Chi - WSLS Chi generalization:")
    print((chiB2 - chiE2) - (paramDiff * np.log(nt)))
    
run_wsls()
run_wsls(sub=2)
run_wsls(sub=1, IGT=False)
run_wsls(sub=2, IGT=False)

Subject 1 Model True
Estimation data: IGT
Generalization data: SGT
WSLS parameters:
  thresh   stayReward   lossPunishment   choice
[np.float64(1.0), np.float64(1.0), np.float64(4.102176688481295e-12), np.float64(136603.62357182716)]
Chi improvement Base Chi - WSLS Chi fit (G²):
-10.687541843199433
Base Chi - WSLS Chi generalization:
-113.79193691039309
Subject 2 Model True
Estimation data: IGT
Generalization data: SGT
WSLS parameters:
  thresh   stayReward   lossPunishment   choice
[np.float64(1.0), np.float64(1.0), np.float64(9.649065187417758e-12), np.float64(80017.56087924719)]
Chi improvement Base Chi - WSLS Chi fit (G²):
-24.73372506503035
Base Chi - WSLS Chi generalization:
-125.78309749628825
Subject 1 Model False
Estimation data: SGT
Generalization data: IGT
WSLS parameters:
  thresh   stayReward   lossPunishment   choice
[np.float64(1.0), np.float64(2.2122408225720944), np.float64(0.8850998043506944), np.float64(0.009334703414039382)]
Chi improvement Base Chi - WSLS Chi fit (

In [164]:
parm_matlab = [20.0, -0.2438, 0.0]  # Raw parameters
chi_train = fit_evl1(parm_matlab, dataE)
chi_gen = fit_evl1(parm_matlab, dataG)

print("Python result at MATLAB parameters:")
print("Train χ² =", chi_train)
print("Gen   χ² =", chi_gen)
print("Gen improvement =", bin1([0.25, 0.25, 0.25, 0.25], dataG) - chi_gen)

Python result at MATLAB parameters:
Train χ² = 656.1779333786076
Gen   χ² = 904.2341245170041
Gen improvement = -574.343661811105


In [189]:
igt = np.loadtxt("IGT4001to4002.txt", usecols=(1, 2, 3))
print(igt[120:125, :])

[[2.00e+00 1.00e+02 1.25e+03]
 [1.00e+00 1.00e+02 2.00e+02]
 [4.00e+00 5.00e+01 0.00e+00]
 [4.00e+00 5.00e+01 0.00e+00]
 [4.00e+00 5.00e+01 0.00e+00]]
