In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib notebook

In [None]:
import os

# Where the simulated behaviour CSVs live, and where figures are written.
data_folder = 'Data/DeterministicTask'
fig_folder = 'figures/'

# Behaviour files. In the "Load data" cell, `mb` is loaded as the hippocampus
# learner, `mf` as the striatum learner and `full` as the combined learner.
full = 'BehaviourDriftingRewardsFull23July.csv'
mf = 'BehaviourDriftingRewardsStriatum23July.csv'
mb = 'BehaviourDriftingRewardsProperMB.csv'



In [None]:
# Load data
def load_behaviour(filename):
    """Read one behaviour CSV from ``data_folder`` into a DataFrame.

    ``DataFrame.from_csv`` was removed in pandas 1.0; ``read_csv`` with
    ``index_col=0`` and ``parse_dates=True`` reproduces its defaults (first
    column becomes the index, which pandas tries to parse as dates).
    """
    return pd.read_csv(os.path.join(data_folder, filename),
                       index_col=0, parse_dates=True)


data_striatum = load_behaviour(mf)
data_hipp = load_behaviour(mb)
data_full = load_behaviour(full)

In [None]:
data_striatum.head(n=5)  # preview the first rows of the striatum learner's data

In [None]:
def add_relevant_columns(dataframe):
    """Add previous-trial and repeat-indicator columns to ``dataframe`` in place.

    Within each ``Agent_nr`` group, ``Action1``, ``StartState`` and ``Reward``
    are shifted down one row into ``PreviousAction``, ``PreviousStart`` and
    ``PreviousReward``. ``Stay`` is True when ``Action1`` equals
    ``PreviousAction``; ``SameStart`` is True when ``StartState`` equals
    ``PreviousStart``. Each agent's first row gets NaN "previous" values, so its
    ``Stay`` and ``SameStart`` are False.

    NOTE(review): the shift follows row order, not the ``Trial`` column, so this
    assumes rows are already chronological within each agent.

    Returns None; the frame is mutated.
    """
    per_agent = dataframe.groupby(['Agent_nr'])
    lagged = {
        target: per_agent[source].shift(1)
        for target, source in (
            ('PreviousAction', 'Action1'),
            ('PreviousStart', 'StartState'),
            ('PreviousReward', 'Reward'),
        )
    }
    for target, values in lagged.items():
        dataframe[target] = values
    dataframe['Stay'] = dataframe.Action1 == dataframe.PreviousAction
    dataframe['SameStart'] = dataframe.PreviousStart == dataframe.StartState

In [None]:
# Derive the previous-trial / stay columns on every learner's data (in place).
for _frame in (data_striatum, data_hipp, data_full):
    add_relevant_columns(_frame)
del _frame

In [None]:
data_striatum.loc[data_striatum['Agent_nr'].eq(1)].head()  # first rows for agent 1

In [None]:
def compute_mean_stay_prob(data):
    """Return per-condition stay probabilities and their standard errors.

    Rows with ``Trial > 0`` are grouped by (``PreviousReward``, ``SameStart``)
    and the mean and SEM of ``Stay`` are computed per group from that same
    subset of rows. Both Series are returned in reverse of groupby's ascending
    key order: highest ``PreviousReward`` first and, within it,
    ``SameStart=True`` before ``False``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``Trial``, ``PreviousReward``, ``SameStart`` and ``Stay``.

    Returns
    -------
    (means, sems) : tuple of pandas.Series
        Both indexed by (PreviousReward, SameStart) and named ``Stay``.
        A group with a single row has a NaN SEM.
    """
    # Means and SEMs must describe the same trials, so filter once up front.
    valid = data[data['Trial'] > 0]
    # Cast to float so mean/sem treat the boolean Stay column identically
    # across pandas versions.
    stay = (valid.assign(Stay=valid['Stay'].astype(float))
                 .groupby(['PreviousReward', 'SameStart'])['Stay'])
    means = stay.mean()
    sems = stay.sem()
    return means.iloc[::-1], sems.iloc[::-1]

In [None]:
# Stay-probability means and SEMs per learner; the mb / mf / full suffixes
# follow the file-name variables used when loading each frame.
mean_mb, sem_mb = compute_mean_stay_prob(data_hipp)
mean_mf, sem_mf = compute_mean_stay_prob(data_striatum)
mean_full, sem_full = compute_mean_stay_prob(data_full)

In [None]:
# Plain lists of the means, keeping compute_mean_stay_prob's ordering.
mean_mf, mean_mb, mean_full = map(list, (mean_mf, mean_mb, mean_full))

In [None]:
# Plain lists of the standard errors, in the same order as the means.
sem_mf, sem_mb, sem_full = map(list, (sem_mf, sem_mb, sem_full))

## Plotting

In [None]:
def plot_doll_style(ax, data, yerr, title='', ylim=(0.47, .96), xlim=(0, 1.6)):
    """Draw a grouped stay-probability bar chart on ``ax``.

    ``data`` and ``yerr`` each hold four values. ``data[:2]`` are drawn as
    light-gray bars and ``data[2:]`` as dark-gray bars; ``data[0]``/``data[2]``
    sit at the 'same' x tick and ``data[1]``/``data[3]`` at 'different'.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on (drawn via the Axes API; the current pyplot axes is
        not changed).
    data : sequence of 4 floats
        Bar heights (stay probabilities).
    yerr : sequence of 4 floats
        Error-bar sizes matching ``data``.
    title : str
        Axes title.
    ylim, xlim : (float, float)
        Axis limits; defaults are the previously hard-coded values.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``.

    Raises
    ------
    ValueError
        If ``data`` or ``yerr`` does not contain exactly four values.
    """
    if len(data) != 4 or len(yerr) != 4:
        # The pairwise slicing below only makes sense for 2x2 conditions.
        raise ValueError('plot_doll_style expects exactly 4 values in data and yerr, '
                         'got %d and %d' % (len(data), len(yerr)))

    lightgray = '#d1d1d1'
    darkgray = '#929292'
    bar_width = 0.2

    bars1, bars2 = data[:2], data[2:]
    errs1, errs2 = yerr[:2], yerr[2:]

    # x positions: one pair of bars per tick, dark bar right of the light one.
    r1 = np.arange(len(bars1)) * .8 + 1.5 * bar_width
    r2 = r1 + bar_width

    ax.bar(r1, bars1, width=bar_width, color=lightgray, yerr=errs1, capsize=4)
    ax.bar(r2, bars2, width=bar_width, color=darkgray, yerr=errs2, capsize=4)
    ax.set_ylabel('Stay probability', fontsize=15)
    ax.set_xticks(r1 + bar_width / 2)
    ax.set_xticklabels(['same', 'different'], fontsize=15)
    ax.tick_params(axis='y', labelsize=15)
    ax.set_title(title, fontsize=18)
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    return ax

### Plot stay probabilities for all three learners side by side
fig, axes = plt.subplots(1, 3, figsize=(9, 5))

plot_doll_style(axes[0], mean_mb, sem_mb, title='Hippocampus')
plot_doll_style(axes[1], mean_mf, sem_mf, title='Striatum')
plot_doll_style(axes[2], mean_full, sem_full, title='Combined')

# Legend labels map, in order, onto the two bar series: the light-gray bars
# (first two values) and the dark-gray bars (last two). compute_mean_stay_prob
# returns the highest PreviousReward first — NOTE(review): this assumes a
# rewarded trial is coded larger than an unrewarded one.
leg = axes[1].legend(['Reward', 'No reward'], fontsize=12, frameon=False, handlelength=.7)
leg.set_title('Previous outcome', prop={'size': 12})

plt.tight_layout()
# savefig raises FileNotFoundError if the output folder does not exist yet.
os.makedirs(fig_folder, exist_ok=True)
fig.savefig(os.path.join(fig_folder, 'DeterministicTaskResults.svg'))
plt.show()

In [None]:
# `means` only exists inside compute_mean_stay_prob, so `list(means)` raised
# NameError on a fresh kernel; show the per-learner stay probabilities instead.
pd.DataFrame({'Hippocampus': mean_mb, 'Striatum': mean_mf, 'Combined': mean_full})

In [None]:
# `data` was never defined at module level (NameError on a fresh kernel).
# NOTE(review): data_full chosen as the frame to inspect — confirm intent.
# numeric_only=True keeps this working on pandas 2 if non-numeric columns exist.
data_full.groupby(['SameStart', 'PreviousReward']).mean(numeric_only=True)

In [None]:
1 / 4 + 1 / 8  # TODO(review): leftover scratch arithmetic (= 0.375); purpose unclear, remove if unused