In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
%matplotlib notebook

In [None]:
path = 'Data/StochasticTask/'
full = 'BehaviourDriftingRewardsFull.csv'
striat = 'BehaviourDriftingRewardsStriatum.csv'
hipp = 'BehaviourDriftingRewardsHippocampus.csv'


def load_behaviour(filename, directory=path):
    """Read one behaviour CSV from ``directory`` using its first column as the index.

    Replaces ``pd.DataFrame.from_csv``, which was deprecated in pandas 0.21 and
    removed in pandas 1.0. ``from_csv`` also defaulted to ``parse_dates=True``;
    that is not reproduced here because no visible cell relies on the index values.
    """
    return pd.read_csv(os.path.join(directory, filename), index_col=0)


data_full = load_behaviour(full)
data_hipp = load_behaviour(hipp)
data_striat = load_behaviour(striat)

In [None]:
# Quick look at the first rows of the hippocampus-model behaviour data.
data_hipp.head()

In [None]:
terminal_states = [3, 4, 5, 6]


def is_common_or_rare(action, out):
    """Classify a first-stage action / terminal state pair as a 'common' or 'rare' transition.

    Terminal states 3 and 4 are the common outcome of 'left'; 5 and 6 are the
    common outcome of 'right'. Reaching the other side counts as 'rare'.

    Raises:
        ValueError: if ``action`` is not 'left'/'right' or ``out`` is not one of 3, 4, 5, 6.
    """
    # Which action makes this terminal state the common outcome.
    if out in (3, 4):
        common_action = 'left'
    elif out in (5, 6):
        common_action = 'right'
    else:
        common_action = None

    if common_action is None or action not in ('left', 'right'):
        raise ValueError('The combination of action and outcome does not make sense')
    return 'common' if action == common_action else 'rare'

In [None]:
def add_relevant_columns(dataframe):
    """Add previous-trial and transition-type columns to ``dataframe`` in place.

    Columns added (the shifts are computed per agent, grouped by ``Agent_nr``):
      * PreviousAction, PreviousStart, PreviousReward -- Action1, StartState and
        Reward shifted down one row within each agent (NaN on each agent's first row);
      * Stay -- True where Action1 equals PreviousAction (False on each agent's first row);
      * Transition -- 'common'/'rare' from is_common_or_rare(Action1, Terminus);
      * PreviousTransition -- Transition shifted down one row within each agent.

    Rows are assumed to be in trial order within each agent -- TODO confirm.

    Returns:
        The same ``dataframe`` object (it is also mutated in place).

    Raises:
        ValueError: propagated from is_common_or_rare for an unexpected action/outcome pair.
    """
    shifted = dataframe.groupby('Agent_nr')[['Action1', 'StartState', 'Reward']].shift(1)
    dataframe['PreviousAction'] = shifted['Action1']
    dataframe['PreviousStart'] = shifted['StartState']
    dataframe['PreviousReward'] = shifted['Reward']
    dataframe['Stay'] = dataframe['PreviousAction'] == dataframe['Action1']
    # A plain comprehension keeps full strings. np.vectorize without otypes fixes the
    # output dtype from the FIRST call, so if the first row was 'rare' (<U4) every
    # later 'common' was silently truncated to 'comm'. It also failed on empty input.
    dataframe['Transition'] = [
        is_common_or_rare(action, terminus)
        for action, terminus in zip(dataframe['Action1'], dataframe['Terminus'])
    ]
    dataframe['PreviousTransition'] = dataframe.groupby('Agent_nr')['Transition'].shift(1)
    return dataframe

In [None]:
# Derive the previous-trial / transition columns in place for each model variant.
for behaviour_frame in (data_full, data_striat, data_hipp):
    add_relevant_columns(behaviour_frame)

In [None]:
# First rows of agent 0 in the striatum run, to sanity-check the derived columns.
data_striat.loc[data_striat['Agent_nr'] == 0].head()

In [None]:
def compute_mean_stay_prob(data):
    """Mean and standard error of ``Stay`` per (PreviousTransition, PreviousReward) group.

    Only rows with ``Trial > 0`` are used (presumably the first trial has no
    previous trial to compare against -- confirm). Rows whose PreviousTransition or
    PreviousReward is NaN are dropped by groupby's default NaN handling.

    Returns:
        (means, sems): two Series indexed by (PreviousTransition, PreviousReward),
        computed over exactly the same rows.
    """
    # Filter once so the means and the SEMs describe the same set of trials
    # (previously only the means were filtered).
    later_trials = data[data['Trial'] > 0]
    # Work on a float column so mean/sem do not depend on pandas' bool handling.
    stay = later_trials['Stay'].astype(float)
    grouped = stay.groupby([later_trials['PreviousTransition'], later_trials['PreviousReward']])
    means = grouped.mean()
    sems = grouped.sem()
    return means, sems

In [None]:
# (mean, SEM) stay probability per (previous transition, previous reward) for each model variant.
(mean_striat, sem_striat), (mean_hipp, sem_hipp), (mean_full, sem_full) = [
    compute_mean_stay_prob(frame) for frame in (data_striat, data_hipp, data_full)
]

In [None]:
# Mean stay probability per (PreviousTransition, PreviousReward) group for the hippocampus model.
mean_hipp

In [None]:
# Plain lists in the groupby's sorted key order; plot_daw_style indexes them by position.
mean_mf, sem_mf = mean_striat.tolist(), sem_striat.tolist()
mean_mb, sem_mb = mean_hipp.tolist(), sem_hipp.tolist()

mean_combined, sem_combined = mean_full.tolist(), sem_full.tolist()



In [None]:
def plot_daw_style(ax, data, yerr, title='', ylim=(.5, 1)):
    """Draw a stay-probability bar chart on ``ax`` in the style of Daw et al.

    Two groups of bars are drawn above the tick labels 'Rewarded' and 'Unrewarded':
    a blue pair from ``data[1]`` (Rewarded) and ``data[0]`` (Unrewarded), and a red
    pair from ``data[3]`` (Rewarded) and ``data[2]`` (Unrewarded). Error bars come
    from ``yerr`` in the same order. With the lists built from compute_mean_stay_prob
    in this notebook, blue is presumably the 'common' group and red the 'rare' group,
    with rewards coded 0/1 -- confirm against the groupby index.

    Args:
        ax: matplotlib Axes to draw on. Only Axes methods are used, so pyplot's
            "current axes" is left untouched.
        data: 4 bar heights (any sequence convertible to a float array).
        yerr: 4 error-bar sizes, same order as ``data``.
        title: panel title.
        ylim: (bottom, top) y-axis limits; defaults to the previous fixed (0.5, 1).

    Returns:
        ``ax``, for chaining.

    Raises:
        ValueError: if ``data`` or ``yerr`` does not hold exactly 4 values.
    """
    heights = np.asarray(data, dtype=float)
    errors = np.asarray(yerr, dtype=float)
    if heights.shape != (4,) or errors.shape != (4,):
        raise ValueError(
            'plot_daw_style expects 4 values in data and yerr, got shapes %s and %s'
            % (heights.shape, errors.shape))

    bar_width = 0.2
    gap = 0.05
    # x positions of the blue (first) and red (second) bars for [Rewarded, Unrewarded].
    x_first = np.array([0.125, 0.625])
    x_second = x_first + bar_width + gap
    # Centre each tick label under its pair of bars.
    x_ticks = x_first + (bar_width + gap) / 2

    # Index [1, 0] / [3, 2] puts the rewarded entry first, matching the tick label order.
    ax.bar(x_first, heights[[1, 0]], width=bar_width, color='blue', yerr=errors[[1, 0]], capsize=4)
    ax.bar(x_second, heights[[3, 2]], width=bar_width, color='red', yerr=errors[[3, 2]], capsize=4)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(['Rewarded', 'Unrewarded'], fontsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.set_title(title, fontsize=16)
    ax.set_ylim(ylim)
    ax.set_xlim(0, 1)
    return ax


In [None]:
# Stay-probability panels for the three model variants, sharing one y-axis.
fig, axes = plt.subplots(1, 3, figsize=(10, 2.5), sharey=True)

panels = [
    ('Striatum', mean_mf, sem_mf),
    ('Hippocampus', mean_mb, sem_mb),
    ('Full model', mean_combined, sem_combined),
]
for ax, (panel_title, panel_means, panel_sems) in zip(axes, panels):
    plot_daw_style(ax, panel_means, panel_sems, title=panel_title)

# Legend and y-label only on the left-most panel. They are set through the Axes object
# rather than plt.sca + plt.ylabel, so the result does not depend on pyplot's current axes.
leg = axes[0].legend(['Common', 'Rare'], fontsize=12, frameon=False, handlelength=0.7)
axes[0].set_ylabel('Stay probability', fontsize=12)

plt.show()

In [None]:
# Fraction of hippocampus-model trials whose transition was classified as 'rare'.
data_hipp['Transition'].eq('rare').sum() / data_hipp['Transition'].count()

In [None]:
# Hippocampus-model trials that follow a rewarded trial and themselves had a rare transition.
b = data_hipp[(data_hipp['PreviousReward'] == 1) & (data_hipp['Transition'] == 'rare')]

In [None]:
# Display the filtered frame b (PreviousReward == 1 and Transition == 'rare').
b

In [None]:
# help(manifold.Isomap.fit)
# NOTE(review): `manifold` is never imported in this notebook, so this stray interactive
# lookup raised NameError and broke Restart & Run All. Re-enable it together with the
# matching import (presumably `from sklearn import manifold`) if it is still needed.