In [1]:
import pandas as pd
import numpy as np
from scipy.optimize import minimize

In [2]:
def softmax(x):
    exp = np.exp(x)
    return exp / np.sum(exp)

In [3]:
def rescorla_wagner(q_val, epsilon_rew, epsilon_pun, epsilon_omi, reward):
    if reward > 0:
        return q_val + epsilon_rew * (reward - q_val)

    if reward < 0:
        return q_val + epsilon_pun * (reward - q_val)

    return q_val + epsilon_omi * (reward - q_val)

In [4]:
def reward(r_t, beta):
    if r_t > 0: return beta
    if r_t < 0: return -beta
    return 0

In [5]:
def log_likelihood(cues, actions, rewards, epsilon_rew, epsilon_pun, epsilon_omi, beta):
    n_stimuli = len(set(cues))
    n_actions = len(set(actions))

    q_vals = np.zeros((n_stimuli, n_actions))

    log_likelihood = 0

    for t, a_t in enumerate(actions):
        s_t = cues[t] - 1
        r_t = reward(rewards[t], beta)

        probs = softmax(q_vals[s_t])
        log_likelihood += np.log(probs[a_t])

        # Update the Q-values using Rescorla-Wagner
        q_vals[s_t, a_t] = rescorla_wagner(
            q_val = q_vals[s_t, a_t],
            epsilon_rew = epsilon_rew,
            epsilon_pun = epsilon_pun,
            epsilon_omi = epsilon_omi,
            reward = r_t
        )

    return log_likelihood

In [6]:
data = pd.read_csv("gen_data.csv")

In [7]:
def loss(params, cues, actions, rewards):
    epsilon_rew, epsilon_pun, epsilon_omi, beta = params
    return -log_likelihood(cues, actions, rewards, epsilon_rew, epsilon_pun, epsilon_omi, beta)

In [8]:
epsilon_bounds = (0.0000001, 0.99999)
beta_bounds = (0.0001, 9.9999)

In [9]:
min_loss = []

for subject_id in data.ID.unique():
    subject = data[ data.ID == subject_id ]

    cues = subject.cue.tolist()
    actions = subject.pressed.tolist()
    rewards = subject.outcome.tolist()

    res = minimize(
        fun = loss,
        x0 = [0.5, 0.5, 0.5, 5],
        bounds = [epsilon_bounds, epsilon_bounds, epsilon_bounds, beta_bounds],
        args = (cues, actions, rewards),
        method = "Nelder-Mead"
    )

    min_loss.append(res.fun)
    print(res.x)

[0.99998952 0.21427158 0.0673047  2.93919209]
[0.09184663 0.50362627 0.20397953 2.91049233]
[0.01854342 0.43550712 0.08017434 4.6164769 ]
[0.99999    0.20228824 0.09650828 2.82105765]
[0.29704141 0.17111278 0.05540489 3.44248434]
[0.36525298 0.10410854 0.03405016 2.89337266]
[0.99999    0.16442843 0.05305727 3.13243937]
[0.04837378 0.05324237 0.44897175 6.23900455]
[0.31431061 0.09867216 0.02668184 4.01211625]
[0.6123934  0.25267447 0.12794605 2.6286623 ]


In [10]:
np.sum(min_loss)

2792.919405767867