#Fit two step task with an associative algorithm

In [1]:
import os
import numpy as np
import glob
import csv
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats, optimize
from pandas import DataFrame, Series
import seaborn as sns
import random as rd
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
import scipy.stats
import patsy
from scipy.optimize import minimize
from scipy.optimize import basinhopping
from sklearn import linear_model
import multiprocessing
import random
from scipy.stats import norm
from scipy.stats import beta
##Code for analysis of fMRI experiment

In [515]:
ntrials = 200
alpha = .3
m=3.0
p=.4

In [516]:
def initialize():
    #initialize data structures
    objects = ['a1','a2','b1','b2','c1','c2']
    states = ['a','b','c','terminal']
    actions = ['1','2']

    #Initialize transition, rewards, values matrics
    transitions = {}
    rewards = {}
    V = {}
    associations = {}
    for s in states:
        transitions[s] = {}
        rewards[s] = {}
        for a in actions:
            transitions[s][a] = {}
            rewards[s][a] = 0
    for o in objects:
        V[o] = 0
        associations[o] = {}

    for o1 in associations:
        for o2 in objects:
            if o1 != o2: #avoid self associations
                associations[o1][o2] = 0

    #fill in transition probs
    for s1 in states:
        for a in actions:
            for s2 in states:
                transitions[s1][a][s2] = 0         
    transitions['b']['1']['terminal'] = 1
    transitions['b']['2']['terminal'] = 1
    transitions['c']['1']['terminal'] = 1
    transitions['c']['2']['terminal'] = 1
    transitions['a']['1']['b'] = .7
    transitions['a']['1']['c'] = .3
    transitions['a']['2']['b'] = .3
    transitions['a']['2']['c'] = .7

    #set up reward probs
    rewards['b']['1'] = .7
    rewards['b']['2'] = .3
    rewards['c']['1'] = .3
    rewards['c']['2'] = .7
#     rewards['b']['1'] = .8
#     rewards['b']['2'] = .6
#     rewards['c']['1'] = .2
#     rewards['c']['2'] = .5
    
    return transitions, rewards, V, associations, objects, states, actions

In [517]:
#gradually shift reward probabilitites to encourage learning
def update_rewards(rewards):
    for s in ['b','c']: #only update end states
        for a in actions:
            shift = 0#np.random.normal(0,.025)
            if (rewards[s][a] + shift > .75) or (rewards[s][a] + shift < .25): #reflecting boundaries
                rewards[s][a] = rewards[s][a] - shift
            else:
                rewards[s][a] = rewards[s][a] + shift
    return rewards

In [518]:
def get_reward(state,action,rewards):
    return scipy.stats.bernoulli.rvs(rewards[state][action])

In [519]:
def next_state(state,action):
    probs = map(lambda x: transitions[state][action][x], states)
    return np.random.choice(a=states,p=probs)

In [520]:
def get_action(state,V,last_a_action):
    Vs = map(lambda a: V[state+a],actions) #get values of each object in state
    if state == 'a': #model perseveration
        if last_a_action == '1':
            Vs[0] = Vs[0] + p
        else:
            Vs[1] = Vs[1] + p
    normalizing_constant = np.sum(map(lambda v: np.exp(m*v),Vs)) #get total value of state
    probs = map(lambda v: np.exp(v*m), Vs)
    probs = probs / normalizing_constant
    return np.random.choice(a=actions,p=probs)
# get_action('a',V,'1')

In [521]:
def update_associations(state,new_state,action,associations, nsteps):
    if new_state != 'terminal':
        nsteps +=1
        for a in actions:
            associations[state + action][new_state + a] = associations[state + action][new_state + a] + 1
            associations[new_state + a][state + action] = associations[state + action][new_state + a] #make symmetric
    return associations, nsteps

In [522]:
def update_value(rew,state,new_state,action,V,associations,nsteps):
    if new_state != 'terminal':
        delta = rew + max(V[new_state + actions[0]],V[new_state + actions[1]]) - V[state+action]
    else:
        delta = rew - V[state+action]
        
    V[state+action] = V[state+action] + alpha*delta
    
    #percolate value one step back, weighted by the strength of association
    for o in associations[state+action]:
        delta = rew - V[o]
        V[o] = V[o] + associations[state+action][o] * 4*alpha* delta / nsteps
    return V
# transitions, rewards, V, associations, objects, states, actions = initialize()
# nsteps = 0
# print V
# V = update_value(0,'a','b','1',V,associations,nsteps)
# print V

In [523]:
#run trial
def take_step(state,rewards,associations,V,nsteps,output,last_a_action):
    if state == 'terminal': #end state
        rewards = update_rewards(rewards)
        return rewards, associations, V, nsteps 

    #do standard MDP stuff
    action = get_action(state,V,last_a_action)
    new_state = next_state(state,action)
    rew = get_reward(state,action,rewards)
    if state == 'a':
        last_a_action = action
        
    #log what's happening
    output['rew'].append(rew)
    output['action'].append(action)
    output['newstate'].append(new_state)
    output['state'].append(state)
    
    #update values and associations
    associations,nsteps = update_associations(state,new_state,action,associations,nsteps) #update associations
    value = update_value(rew,state,new_state,action,V,associations,nsteps)
    
    return take_step(new_state,rewards,associations,V,nsteps,output,last_a_action)

In [524]:
ntrials = 5000
transitions, rewards, V, associations, objects, states, actions = initialize()
nsteps = 0.0
last_a_action = '1'
output = {'state':[],'action':[],'newstate':[],'rew':[]}
for i in range(ntrials):
    rewards, associations, V, nsteps = take_step('a',rewards,associations,V,nsteps,output,last_a_action)
# output = pd.DataFrame(output)

In [525]:
#analyze common and rare transitions and add to DF
output['transition_type'] = []
for n,s in enumerate(output['newstate']):
    if s  == 'terminal':
        output['transition_type'].append('end')
    elif (s == 'b' and output['action'][n] == '1') or (s == 'c' and output['action'][n] == '2'):
        output['transition_type'].append('common')
    elif (s == 'b' and output['action'][n] == '2') or (s == 'c' and output['action'][n] == '1'):
        output['transition_type'].append('rare')
output = pd.DataFrame(output)

In [526]:
#calculate stay and switch
output['stay'] = np.nan
a_indices =  output[output['state'] == 'a'].index
a_indices = a_indices.values
stay_or_switch = ['np.nan']
for n,idx in enumerate(a_indices):
    if n>0:
        last_action = output.iloc[a_indices[n-1]].action
        current_action = output.iloc[a_indices[n]].action
        if last_action == current_action:
            stay_or_switch.append('stay')
        else:
            stay_or_switch.append('switch')
output.ix[output['state'] == 'a','stay']  = stay_or_switch

In [527]:
results = {'rewarded':{'common':[],'rare':[]},'nonrewarded':{'common':[],'rare':[]}}
for r in ['rewarded','nonrewarded']:
    if r == 'rewarded':
        rew = 1
    else:
        rew = 0
        
    indices = output[(output['newstate']=='terminal') & (output['rew'] == rew)].index[:-1]
    transition_type = output.iloc[indices-1]['transition_type'].values
    action = output.iloc[indices+1]['stay'].values
    
    
    for c in ['common','rare']:
        choices = list(action[transition_type == c])
        results[r][c] = choices.count('stay')/float(len(choices))

results = pd.DataFrame(results)
print results

        nonrewarded  rewarded
common     0.555556  0.726651
rare       0.640646  0.636927


In [449]:
for o1 in associations:
    for o2 in associations:
        if o1 != o2:
            print o1,o2,associations[o1][o2]/nsteps
print V
print output

a1 a2 0.0
a1 b1 0.6168
a1 b2 0.6168
a1 c2 0.2622
a1 c1 0.2622
a2 a1 0.0
a2 b1 0.0358
a2 b2 0.0358
a2 c2 0.0852
a2 c1 0.0852
b1 a1 0.6168
b1 a2 0.0358
b1 b2 0.0
b1 c2 0.0
b1 c1 0.0
b2 a1 0.6168
b2 a2 0.0358
b2 b1 0.0
b2 c2 0.0
b2 c1 0.0
c2 a1 0.2622
c2 a2 0.0852
c2 b1 0.0
c2 b2 0.0
c2 c1 0.0
c1 a1 0.2622
c1 a2 0.0852
c1 b1 0.0
c1 b2 0.0
c1 c2 0.0
{'a1': 0.6806570538546526, 'a2': 0.7285485558948297, 'b1': 0.5480558591056993, 'b2': 0.2978784262829568, 'c2': 0.5933646005688671, 'c1': 0.1767515953230546}
     action  newstate  rew state transition_type    stay
0         1         c    0     a            rare  np.nan
1         1  terminal    1     c             end     NaN
2         2         c    0     a          common  switch
3         2  terminal    1     c             end     NaN
4         2         b    0     a            rare    stay
5         2  terminal    0     b             end     NaN
6         1         c    0     a            rare  switch
7         1  terminal    0     c       