## Model fitting demo for the RLWM task
This demo shows how to fit the RLWM (Reinforcement Learning - Working Memory) model to data from the CNTRACTS RLWM task.

There should be one sample data file ('demo_data.csv') included next to this script.
The script demonstrates how to fit the model to a single participant's data using maximum likelihood estimation (MLE).

Author: Krishn Bera (krishn_bera@brown.edu)

### Summary of the model

The model used here is a variant of the original RLWM model (Collins & Frank, 2012). It captures learning behavior through the interaction of two parallel processes: reinforcement learning (RL) and working memory (WM). The RL system learns incrementally, accumulating reward values for state-action pairs over many trials. The WM system is a fast, one-shot learning system that is capacity-limited and prone to forgetting. Because both systems contribute to every choice, the model can capture the WM-specific effects of load (set size) and delay on learning.
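The contrast between the two systems can be made concrete with a few lines of arithmetic. The minimal sketch below uses hypothetical parameter values (alpha = 0.1, rho = 0.9, C = 3) and mirrors the update and weighting rules used later in RLWM_LL: after a single reward the RL value moves only a fraction alpha toward 1, while the WM value jumps straight to 1, and the WM weight shrinks once the set size exceeds the capacity C.

```python
# Sketch: one rewarded trial in each system, plus the effect of set size on WM reliance.
# All parameter values below are hypothetical; the update rules mirror RLWM_LL.
num_actions = 3
alpha = 0.1           # hypothetical RL learning rate
rho, C = 0.9, 3       # hypothetical WM weight and WM capacity
reward = 1

q_rl = 1 / num_actions    # both systems start from the uniform value 1/num_actions
q_wm = 1 / num_actions

# RL: incremental delta-rule update, moving only a fraction alpha toward the reward
q_rl += alpha * (reward - q_rl)
# WM: one-shot update (learning rate fixed at 1), jumping straight to the reward
q_wm += 1.0 * (reward - q_wm)

print(f"RL value after one reward: {q_rl:.3f}")   # prints 0.400
print(f"WM value after one reward: {q_wm:.3f}")   # prints 1.000

# Load effect: the WM weight is diluted once the set size exceeds capacity C
wm_weight = {ns: rho * min(1, C / ns) for ns in [2, 3, 4, 5, 6]}
for ns, w in wm_weight.items():
    print(f"set size {ns}: WM weight = {w:.3f}")
```

With these values the WM weight stays at rho for set sizes up to C and then falls as C / set size, which is how the model expresses the load effect.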

--

Collins, A.G.E. and Frank, M.J. (2012), How much of reinforcement learning is working memory, not reinforcement learning? A behavioral, computational, and neurogenetic analysis. European Journal of Neuroscience, 35: 1024-1035. https://doi.org/10.1111/j.1460-9568.2011.07980.x


### Summary of the parameters estimated by the model

- alpha -> learning rate of the RL system
- phi -> WM decay rate
- rho -> weight of the WM system
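- gamma -> perseveration parameter (scales the learning rates down when the reward prediction error is negative)
- epsilon -> noise/attention-lapse parameter
- C -> working memory capacity (an integer, fit by grid search over C_list rather than by the optimizer)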
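For reference, the choice and update rules implemented in RLWM_LL below can be written out as equations. This is a transcription of the code rather than of the original paper: $n_s$ is the number of distinct stimuli in a block (set_size), $n_A$ is num_actions, $\beta$ is fixed at 100, all Q values start at $1/n_A$, and $\kappa$ is notation introduced here for the learning-rate scaling applied when the RL prediction error is negative.

$$
w = \rho \cdot \min\!\left(1, \frac{C}{n_s}\right)
$$

$$
\pi_{X}(a \mid s) = \frac{\exp\big(\beta\, Q_{X}(s,a)\big)}{\sum_{b=1}^{n_A} \exp\big(\beta\, Q_{X}(s,b)\big)}, \qquad X \in \{\mathrm{RL}, \mathrm{WM}\}
$$

$$
\pi(a \mid s) = (1-\epsilon)\big[w\,\pi_{\mathrm{WM}}(a \mid s) + (1-w)\,\pi_{\mathrm{RL}}(a \mid s)\big] + \frac{\epsilon}{n_A}
$$

$$
\delta_X = r - Q_X(s,a), \qquad
\kappa = \begin{cases} 1 & \delta_{\mathrm{RL}} \ge 0 \\ \gamma & \delta_{\mathrm{RL}} < 0 \end{cases}
$$

$$
Q_{\mathrm{RL}}(s,a) \leftarrow Q_{\mathrm{RL}}(s,a) + \kappa\,\alpha\,\delta_{\mathrm{RL}}, \qquad
Q_{\mathrm{WM}}(s,a) \leftarrow Q_{\mathrm{WM}}(s,a) + \kappa\,\delta_{\mathrm{WM}}
$$

$$
Q_{\mathrm{WM}}(s',b) \leftarrow Q_{\mathrm{WM}}(s',b) + \phi\left(\frac{1}{n_A} - Q_{\mathrm{WM}}(s',b)\right) \quad \text{for all } s', b \text{ on every trial}
$$

$$
-\log \mathcal{L} = -\sum_t \log \pi(a_t \mid s_t)
$$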


In [1]:
# import packages
import numpy as np
import random
import pandas as pd
from scipy.optimize import minimize

In [2]:
# specify the RLWM model config as a dictionary
# the dictionary contains the model name as the key, and a dictionary of parameter information as the value
# the parameter information should include the parameter names and the parameter bounds (specified as two lists for the lower and upper bounds)

model_config_rl = {
    "RLWM": {
        "params": ["alpha", "phi", "rho", "gamma", "epsilon"], # Note: C parameter is not included here
        "param_bounds": [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0]],
        },
}

In [3]:
# specify the fixed parameters and optimization settings

model_rl = 'RLWM' # the model name (must be one of the keys in model_config_rl)
num_actions = 3 # the number of actions in the RLWM task
beta = 100 # the inverse temperature parameter in the softmax function

C_list = [2, 3, 4, 5] # C is the working memory capacity. C_list is the list of C values to be iterated over during the optimization
n_restarts = 20 # the number of random restarts for each C value during the optimization 

In [4]:
# function to calculate softmax of the Q values
# beta is the inverse temperature parameter
def softmax(q_val, beta):
    q_val = np.array(q_val)*beta
    q_val = np.exp(q_val)
    q_val = q_val / np.sum(q_val)
    return q_val
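The softmax above is safe in this notebook because every Q value stays in [0, 1], so beta * Q is at most 100 and np.exp stays far below the float64 overflow point (roughly exp(709)). If beta or the value scale were larger, the usual fix is to subtract the maximum before exponentiating, which leaves the probabilities unchanged. The sketch below (stable_softmax is a hypothetical name, not part of the original notebook) shows that variant and how beta controls how deterministic the policy is.

```python
import numpy as np

def stable_softmax(q_val, beta):
    """Softmax of beta * q_val; subtracting the max first avoids overflow in np.exp."""
    z = beta * np.asarray(q_val, dtype=float)
    z = z - z.max()              # a constant shift does not change the normalized result
    p = np.exp(z)
    return p / p.sum()

q = [0.6, 0.3, 0.1]              # hypothetical Q values for one stimulus
for beta in [1, 10, 100]:
    print(beta, np.round(stable_softmax(q, beta), 4))

# beta * q = 1000 would overflow a naive np.exp, but the shifted version stays finite
print(stable_softmax([10.0, 9.0, 8.0], 100))
```

At beta = 100, as used here, the policy is nearly greedy, so most choice variability in the fitted model has to come from the RL/WM mixture and from epsilon.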

In [5]:
# function to sample random starting points for random restarts during the optimization
# the function returns a list of randomly sampled starting points for the parameters
def sample_uniform_starting_pts(model_rl):

    alpha_low = model_config_rl[model_rl]['param_bounds'][0][0]
    alpha_high = model_config_rl[model_rl]['param_bounds'][1][0]

    phi_low = model_config_rl[model_rl]['param_bounds'][0][1]
    phi_high = model_config_rl[model_rl]['param_bounds'][1][1]

    rho_low = model_config_rl[model_rl]['param_bounds'][0][2]
    rho_high = model_config_rl[model_rl]['param_bounds'][1][2]

    gamma_low = model_config_rl[model_rl]['param_bounds'][0][3]
    gamma_high = model_config_rl[model_rl]['param_bounds'][1][3]

    epsilon_low = model_config_rl[model_rl]['param_bounds'][0][4]
    epsilon_high = model_config_rl[model_rl]['param_bounds'][1][4]

    
    alpha = random.uniform(alpha_low, alpha_high)
    phi = random.uniform(phi_low, phi_high)
    rho = random.uniform(rho_low, rho_high)
    gamma = random.uniform(gamma_low, gamma_high)
    epsilon = random.uniform(epsilon_low, epsilon_high)

    starting_pts = [alpha, phi, rho, gamma, epsilon]
    
    return starting_pts

In [6]:
# function to check for boundary violations of the parameters
# the function returns True if there is a boundary violation, and False otherwise
def check_boundary_violations(alpha, phi, rho, gamma, epsilon):
    
    if alpha < model_config_rl[model_rl]['param_bounds'][0][0] or alpha > model_config_rl[model_rl]['param_bounds'][1][0]:
        return True
    if phi < model_config_rl[model_rl]['param_bounds'][0][1] or phi > model_config_rl[model_rl]['param_bounds'][1][1]:
        return True
    if rho < model_config_rl[model_rl]['param_bounds'][0][2] or rho > model_config_rl[model_rl]['param_bounds'][1][2]:
        return True
    if gamma < model_config_rl[model_rl]['param_bounds'][0][3] or gamma > model_config_rl[model_rl]['param_bounds'][1][3]:
        return True
    if epsilon < model_config_rl[model_rl]['param_bounds'][0][4] or epsilon > model_config_rl[model_rl]['param_bounds'][1][4]:
        return True
    
    return False

In [7]:
# function to calculate the log likelihood of the RLWM model
# this function is called by the scipy.optimize.minimize function during the optimization
def RLWM_LL(params, subj_data, num_actions, C, beta):

    # get the parameter values
    alpha = params[0] 
    phi = params[1]  
    rho = params[2] 
    gamma = params[3] 
    epsilon = params[4] 

    # check for boundary violations, return inf if there is a boundary violation
    if check_boundary_violations(alpha, phi, rho, gamma, epsilon) == True:
        return np.inf

    # get the list of blocks in the data
    block_list = np.unique(subj_data['block_id'])

    # initialize the log likelihood
    subj_ll = 0
    subj_trl_idx = 0
    
    # iterate over blocks
    for bl in block_list:

        # extract the data for the current block
        block_data = subj_data.loc[subj_data['block_id'] == bl]

        # get the set size for the current block
        set_size = len(np.unique(block_data['stim_id']))

        # get the trials, rewards, and actions for the current block
        trials = block_data['stim_id'].values
        reward_list = block_data['feedback'].values
        action_list = block_data['resp'].values

        # initialize the Q values for the RL and WM systems
        q_RL = np.ones((set_size, num_actions)) * 1/num_actions
        q_WM = np.ones((set_size, num_actions)) * 1/num_actions

        # initialize the weight term for the WM system
        weight = rho * min(1, C/set_size)
        
        # initialize the policy
        pol = np.zeros(num_actions)

        # iterate over trials
        for tr in np.arange(len(trials)):

            # get the current state
            state = int(trials[tr])

            # compute the RL and WM policies using the Q values
            pol_RL = softmax(q_RL[state, :], beta)
            pol_WM = softmax(q_WM[state, :], beta)

            # compute the mixed policy as the weighted sum of the RL and WM policies
            pol = weight * pol_WM + (1-weight) * pol_RL

            # compute the final policy as a mixture of the mixed policy and the noisy (uniform) policy
            pol_final = (1 - epsilon) * pol + epsilon * np.tile([1/num_actions], num_actions)

            # get the action
            action = int(action_list[tr])

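            # get the reward
            # Note: the model uses a binary reward; feedback codes 1 and 2 are both treated as reward = 1, and 0 as reward = 0
            if reward_list[tr] == 1 or reward_list[tr] == 2:
                reward = 1
            elif reward_list[tr] == 0:
                reward = 0
            else:
                # fail loudly instead of silently reusing the previous trial's reward
                raise ValueError(f"unexpected feedback value {reward_list[tr]} in block {bl}, trial {tr}")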

            # compute the log likelihood for the trial and add it to subject's log likelihood
            subj_ll += np.log(pol_final[action])
            
            # update the Q values for the RL and WM systems (the WM learning rate is 1)
            # if the RL reward prediction error (RPE) is negative, both updates are scaled down by the perseveration parameter gamma
            if (reward - q_RL[state, action]) >= 0:
                q_RL[state, action] = q_RL[state, action] + alpha * (reward - q_RL[state, action])
                q_WM[state, action] = q_WM[state, action] + 1 * (reward - q_WM[state, action])
            else:
                q_RL[state, action] = q_RL[state, action] + gamma * alpha * (reward - q_RL[state, action])
                q_WM[state, action] = q_WM[state, action] + gamma * 1 * (reward - q_WM[state, action])

            # WM decay on each trial, update the Q values for the WM system
            q_WM = q_WM + phi * ((1/num_actions)-q_WM)

            subj_trl_idx += 1
    
    # return the negative log likelihood
    return -subj_ll
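A common sanity check for a likelihood function like RLWM_LL is parameter recovery: simulate choices from known parameters, fit them, and check that the fit returns similar values. The sketch below is a hypothetical simulator for a single block that mirrors the choice and update rules of RLWM_LL. It assumes one fixed correct action per stimulus, deterministic binary reward, and shuffled stimulus order; the real CNTRACTS trial structure and feedback schedule may differ.

```python
import numpy as np

def simulate_rlwm_block(params, C, set_size, n_reps=10, num_actions=3, beta=100, seed=None):
    """Simulate one RLWM block (hypothetical helper, not part of the original notebook).

    Assumptions: each stimulus has one fixed correct action, reward is 1 if correct
    and 0 otherwise, and each stimulus appears n_reps times in shuffled order.
    The choice and update rules mirror RLWM_LL.
    """
    alpha, phi, rho, gamma, epsilon = params
    rng = np.random.default_rng(seed)

    correct = rng.integers(num_actions, size=set_size)              # correct action per stimulus
    stims = rng.permutation(np.repeat(np.arange(set_size), n_reps))

    q_RL = np.full((set_size, num_actions), 1 / num_actions)
    q_WM = np.full((set_size, num_actions), 1 / num_actions)
    weight = rho * min(1, C / set_size)

    def softmax(q):
        p = np.exp(beta * (q - q.max()))
        return p / p.sum()

    stim_out, resp_out, reward_out = [], [], []
    for s in stims:
        # mixed RL/WM policy plus uniform lapse, as in RLWM_LL
        pol = weight * softmax(q_WM[s]) + (1 - weight) * softmax(q_RL[s])
        pol = (1 - epsilon) * pol + epsilon / num_actions
        a = rng.choice(num_actions, p=pol)
        r = int(a == correct[s])

        # gamma scales both updates when the RL prediction error is negative
        lr_scale = 1.0 if r - q_RL[s, a] >= 0 else gamma
        q_RL[s, a] += lr_scale * alpha * (r - q_RL[s, a])
        q_WM[s, a] += lr_scale * (r - q_WM[s, a])

        # WM decay toward the uniform value on every trial
        q_WM += phi * (1 / num_actions - q_WM)

        stim_out.append(int(s))
        resp_out.append(int(a))
        reward_out.append(r)
    return stim_out, resp_out, reward_out

stims, resps, rewards = simulate_rlwm_block([0.1, 0.2, 0.9, 0.5, 0.01], C=3, set_size=3, seed=0)
print(len(stims), sum(rewards) / len(rewards))
```

To run a recovery check, the simulated stimuli, responses, and rewards would need to be assembled into a DataFrame with the block_id, stim_id, resp, and feedback columns that RLWM_LL reads.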

In [8]:
# load csv file as a dataframe
data = pd.read_csv('demo_data.csv')

In [9]:
# inspect the data
data

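| Unnamed: 0 | subj_idx | ns | block_id | stim_id | rt | resp | corr_resp | acc | feedback | FBProb | stim_ctr | delay | early_late | pcor |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 268 | 4 | 0 | 0 | 0.523952 | 0 | 0 | 1 | 1 | 0.2 | 1.0 | 0 | 0 | 0 |
| 1 | 268 | 4 | 0 | 0 | 0.800853 | 0 | 0 | 1 | 1 | 0.2 | 2.0 | 1 | -1 | 1 |
| 2 | 268 | 4 | 0 | 1 | 0.539283 | 1 | 2 | 0 | 0 | 0.8 | 1.0 | 0 | 0 | 0 |
| 3 | 268 | 4 | 0 | 1 | 0.816042 | 2 | 2 | 1 | 2 | 0.8 | 2.0 | 0 | 0 | 0 |
| 4 | 268 | 4 | 0 | 2 | 0.953407 | 0 | 0 | 1 | 1 | 0.5 | 1.0 | 0 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 354 | 268 | 5 | 9 | 1 | 0.587742 | 0 | 0 | 1 | 2 | 0.8 | 10.0 | 5 | 1 | 9 |
| 355 | 268 | 5 | 9 | 2 | 0.574108 | 2 | 2 | 1 | 1 | 0.5 | 10.0 | 5 | 1 | 7 |
| 356 | 268 | 5 | 9 | 0 | 0.561117 | 0 | 0 | 1 | 1 | 0.2 | 10.0 | 8 | 1 | 9 |
| 357 | 268 | 5 | 9 | 4 | 0.564499 | 2 | 2 | 1 | 1 | 0.2 | 10.0 | 5 | 1 | 7 |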
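RLWM_LL makes two assumptions about these columns that are worth checking before fitting. It takes the set size of a block to be the number of distinct stim_id values in that block and uses stim_id directly as a row index into the Q-value arrays, so within each block the stimulus IDs must be exactly 0, 1, ..., set_size - 1. It also only handles feedback codes 0, 1, and 2. The sketch below is a hypothetical helper (not part of the original notebook) that reports violations of both assumptions.

```python
from collections import defaultdict

def check_rlwm_inputs(block_ids, stim_ids, feedback):
    """Return a list of problems that would break RLWM_LL's assumptions (hypothetical helper)."""
    problems = []

    # collect the distinct stimulus IDs seen in each block
    stims_per_block = defaultdict(set)
    for b, s in zip(block_ids, stim_ids):
        stims_per_block[b].add(int(s))

    # RLWM_LL indexes a (set_size x num_actions) array with stim_id, so IDs must be 0..set_size-1
    for b, stims in sorted(stims_per_block.items()):
        if stims != set(range(len(stims))):
            problems.append(f"block {b}: stim_id values {sorted(stims)} are not 0..{len(stims) - 1}")

    # RLWM_LL maps feedback 1 and 2 to reward 1, and 0 to reward 0
    bad_fb = sorted({int(f) for f in feedback} - {0, 1, 2})
    if bad_fb:
        problems.append(f"unexpected feedback codes: {bad_fb}")
    return problems
```

Applied to the loaded data, `check_rlwm_inputs(data['block_id'], data['stim_id'], data['feedback'])` should return an empty list if both assumptions hold.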


In [10]:
# initialize the best negative log likelihood to infinity
best_negLL = np.inf

# iterate over the C values to find the best fit C value
for C in C_list:

    # perform multiple restarts of the optimization to avoid local minima
    for n_r in range(n_restarts):

        # specify arguments for the RLWM_LL function; these are fixed arguments passed to the RLWM_LL function
        args = (data, num_actions, C, beta)
        
        # sample starting points for the parameters
        x0 = sample_uniform_starting_pts(model_rl)

        # minimize the negative log likelihood (i.e., maximum likelihood estimation) with the Nelder-Mead simplex algorithm
        res = minimize(RLWM_LL, x0, args=args, method='Nelder-Mead')

        # get the negative log likelihood at the parameters found by this restart
        m_negLL = RLWM_LL(res.x, data, num_actions, C, beta)

        # if the negative log likelihood is lower than the current best, update the best parameters
        if m_negLL < best_negLL:
            best_x = res.x
            best_x = np.append(best_x, C)
            best_negLL = m_negLL

In [11]:
# print the best fit parameters
print('Best fit parameters: \n')
for i in range(len(best_x)-1):
    print(model_config_rl[model_rl]['params'][i], " - ", np.round(best_x[i], 5))
print("C - ", best_x[-1])

Best fit parameters: 

alpha  -  0.02315
phi  -  0.23377
rho  -  1.0
gamma  -  0.42864
epsilon  -  0.00617
C -  3.0
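The minimized negative log likelihood can also be converted into the standard information criteria, for example to compare this fit against a model variant with different parameters. The sketch below uses the textbook definitions AIC = 2k + 2·NLL and BIC = k·ln(n) + 2·NLL; aic_bic is a hypothetical helper, and whether the grid-searched C counts as a free parameter is a modeling choice, assumed here.

```python
import math

def aic_bic(neg_ll, n_params, n_trials):
    """Akaike and Bayesian information criteria from a minimized negative log likelihood."""
    aic = 2 * n_params + 2 * neg_ll
    bic = n_params * math.log(n_trials) + 2 * neg_ll
    return aic, bic
```

In the notebook this could be called as `aic_bic(best_negLL, n_params=len(best_x), n_trials=len(data))`, where `len(best_x)` is 6 because C was appended to the five optimized parameters.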
