# Parameter estimation
Running this notebook performs the parameter estimation for one experiment. Set the `expt` parameter to choose which experiment to analyse.

In [None]:
# Automatically reload edited project modules (e.g. `models`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# data + modelling
import numpy as np
import numpy.matlib  # for repmat, used in calc_log_loss()
import pandas as pd
import pymc3 as pm
import math
import os
# from sklearn.metrics import log_loss
import random

# plotting
import seaborn as sns
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from matplotlib import gridspec
plt.rcParams.update({'font.size': 14})

#from plotting import plot_data
from models import ModifiedRachlinFreeSlope

## Set up our options

In [None]:
# Seed NumPy's global random number generator so runs are reproducible.
SEED = 123
np.random.seed(SEED)

# Full-quality MCMC settings, intended for the final parameter estimates.
# NOTE(review): newer PyMC3 releases deprecate `nuts_kwargs` in favour of passing
# `target_accept` directly; confirm how `sample_from_posterior` forwards these.
sample_options = dict(
    tune=2000,
    draws=5000,
    chains=4,
    cores=4,
    nuts_kwargs=dict(target_accept=0.95),
    random_seed=SEED,
)

# Much cheaper settings for quick smoke-testing of the pipeline.
test_sample_options = dict(
    tune=50,
    draws=100,
    chains=2,
    cores=2,
    random_seed=SEED,
)

In [None]:
# Set to True to write the results table (CSV) and figures (PDF) to disk.
SHOULD_SAVE = False

In [None]:
out_dir = 'output/'

# Ensure the output folder and one subfolder per experiment exist.
# `exist_ok=True` makes this idempotent. The loop variable is deliberately NOT
# named `expt`: re-running this cell would otherwise overwrite the experiment
# selected further down the notebook (leaving it set to 2).
os.makedirs(out_dir, exist_ok=True)
for experiment_number in (1, 2):
    os.makedirs(f'{out_dir}expt{experiment_number}/', exist_ok=True)

## Experiment-specific information

NOTE: Set the `expt` variable to either 1 or 2 and run the notebook to do parameter estimation for that experiment.

In [None]:
expt = 2
# Group/condition labels later in the notebook exist only for experiments 1 and 2,
# so fail fast here instead of hitting a confusing NameError further down.
assert expt in (1, 2), f'expt must be 1 or 2, got {expt!r}'
data_file = f'raw_data_expt{expt}/EXPERIMENT{expt}DATA.csv'

## Import data

In [None]:
# Load this experiment's trial data from the output folder.
data = pd.read_csv(os.path.join(out_dir, data_file), index_col=False)

# Later cells select participants by the 'id' and 'condition' columns.
missing_columns = {'id', 'condition'} - set(data.columns)
assert not missing_columns, f'data file is missing expected column(s): {sorted(missing_columns)}'

In [None]:
# Preview the first rows of the imported data.
data.head()

## Modified Rachlin model (with free slope parameter)

In [None]:
# Build the modified Rachlin model (with a free slope parameter) for this experiment's data.
mr = ModifiedRachlinFreeSlope(data)

In [None]:
# The cheap `test_sample_options` are handy while developing, but they are far too
# short for the final estimates. Set this to False to use the full `sample_options`.
USE_TEST_SAMPLING = True

if SHOULD_SAVE and USE_TEST_SAMPLING:
    print('WARNING: results will be saved from the reduced test sampling settings; '
          'set USE_TEST_SAMPLING = False for final estimates.')

mr.sample_from_posterior(test_sample_options if USE_TEST_SAMPLING else sample_options)

### Check MCMC sampling diagnostics

In [None]:
# Energy plot of the posterior samples (an HMC/NUTS sampling diagnostic).
pm.energyplot(mr.posterior_samples)

In [None]:
# Posterior of logk for each participant, with R-hat convergence diagnostics.
pm.forestplot(mr.posterior_samples, var_names=['logk'], r_hat=True)

In [None]:
# Posterior of logs for each participant, with R-hat convergence diagnostics.
pm.forestplot(mr.posterior_samples, var_names=['logs'], r_hat=True)

In [None]:
# Posterior of α, with R-hat convergence diagnostics.
pm.forestplot(mr.posterior_samples, var_names=['α'], r_hat=True)

## Export parameter estimate table
The model's `calc_results` method computes a table of parameter estimates and measures derived from the model for this experiment.

In [None]:
# Table of parameter estimates and derived measures for this experiment.
parameter_estimates = mr.calc_results(expt)
parameter_estimates

In [None]:
if SHOULD_SAVE:
    # 'analysis/' is not created by the output-folder cell, so make sure it
    # exists; otherwise `to_csv` raises FileNotFoundError on a fresh checkout.
    os.makedirs('analysis', exist_ok=True)
    parameter_estimates.to_csv(f'analysis/EXPERIMENT_{expt}_RESULTS.csv')

## Visualisation

## Group level

In [None]:
# Group-level posteriors of logk and logs, with R-hat convergence diagnostics.
pm.forestplot(mr.posterior_samples, var_names=['group_logk', 'group_logs'], r_hat=True)

### Visualise posterior predictions for each group

In [None]:
# Draw the posterior predictive plot for each of the four groups and, when
# saving is enabled, write each one to its own PDF.
for group_index in range(4):
    mr.group_plot(group_index)
    if not SHOULD_SAVE:
        continue
    pdf_path = f'{out_dir}expt{expt}/group{group_index}.pdf'
    plt.savefig(pdf_path, bbox_inches='tight')

In [None]:
# Human-readable labels for the four groups, indexed by group/condition number.
# A dict lookup replaces `expt is 1`: `is` tests object identity, works for small
# ints only as a CPython detail, and triggers a SyntaxWarning on Python 3.8+.
# An unsupported `expt` now raises KeyError here instead of leaving `group_name`
# undefined (or stale from an earlier run).
GROUP_NAMES_BY_EXPT = {
    1: ['Deferred, low',
        'Online, low',
        'Deferred, high',
        'Online, high'],
    2: ['Deferred, gain',
        'Online, gain',
        'Deferred, loss',
        'Online, loss'],
}
group_name = GROUP_NAMES_BY_EXPT[expt]

In [None]:
trace = mr.posterior_samples

# Scatter the joint posterior of the group-level (logk, logs) for each group.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

for i, name in enumerate(group_name):
    logk = trace['group_logk'][:, i]
    logs = trace['group_logs'][:, i]
    ax.scatter(logk, logs, alpha=0.1, label=name)

leg = ax.legend()

# Make legend markers opaque even though the scatter points are translucent.
# `Legend.legendHandles` was renamed `legend_handles` in Matplotlib 3.7 (and the
# old name later removed), so support both spellings.
handles = getattr(leg, 'legend_handles', None)
if handles is None:
    handles = leg.legendHandles
for lh in handles:
    lh.set_alpha(1)

ax.set(xlabel='logk', ylabel='logs', title='parameter space')

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}expt{expt}/group_param_space.pdf', bbox_inches='tight')

## Visualise group mean parameter values

In [None]:
# Scatter the joint posterior of each group's mean (mu_logk, mu_logs).
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

for i, name in enumerate(group_name):
    logk = trace['mu_logk'][:, i]
    logs = trace['mu_logs'][:, i]
    ax.scatter(logk, logs, alpha=0.1, label=name)

leg = ax.legend()

# Make legend markers opaque even though the scatter points are translucent.
# `Legend.legendHandles` was renamed `legend_handles` in Matplotlib 3.7 (and the
# old name later removed), so support both spellings.
handles = getattr(leg, 'legend_handles', None)
if handles is None:
    handles = leg.legendHandles
for lh in handles:
    lh.set_alpha(1)

ax.set(xlabel='logk', ylabel='logs', title=f'Experiment {expt}')

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}expt{expt}/group_mean_estimates_in_param_space.pdf', bbox_inches='tight')

## Participant level plots

In [None]:
# id=1
# pdata = data.loc[data['id'] == id]
# pdata['RB'].values[0]

In [None]:
# Quick look at the model fit for a single participant (id 0).
mr.participant_plot(0)

🔥 Export all participant-level plots (runs only when `SHOULD_SAVE` is `True`). This takes a while. 🔥

In [None]:
if SHOULD_SAVE:
    # `n_participants` was never defined in this notebook, so this cell raised a
    # NameError whenever saving was enabled. Derive it from the data instead.
    # NOTE(review): assumes the values in data['id'] are the participant indices
    # that `participant_plot` (and the trace columns) expect — confirm.
    participant_ids = np.sort(data['id'].unique())
    n_participants = len(participant_ids)

    for participant_id in participant_ids:
        print(f'{participant_id} of {n_participants}')
        mr.participant_plot(participant_id)

        savename = f'{out_dir}expt{expt}/id{participant_id}.pdf'
        plt.savefig(savename, bbox_inches='tight')

        # Close the figure to avoid very heavy plotting inside the notebook
        plt.close(plt.gcf())

## Demo figure
We plot example data and parameter estimates with one row per condition, and one randomly chosen participant from that condition in each column.

In [None]:
def ids_in_condition(data, condition):
    '''Return the unique participant ids that took part in `condition`.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial data with 'id' and 'condition' columns.
    condition : scalar
        Condition code to filter on.

    Returns
    -------
    numpy.ndarray
        Unique ids, in order of first appearance.
    '''
    in_condition = data['condition'].eq(condition)
    return data.loc[in_condition, 'id'].unique()

In [None]:
plt.rcParams.update({'font.size': 14})

N_EXAMPLES = 3    # number of columns (example participants per condition)
N_CONDITIONS = 4  # number of rows

# Ordering of these is crucial... see the data import notebook for the key.
# Dict lookup instead of `expt is 1`: `is` tests identity, not equality.
ROW_HEADINGS_BY_EXPT = {
    1: ['Deferred, low',
        'Online, low',
        'Deferred, high',
        'Online, high'],
    2: ['Deferred, gain',
        'Online, gain',
        'Deferred, loss',
        'Online, loss'],
}
row_headings = ROW_HEADINGS_BY_EXPT[expt]

# squeeze=False keeps `ax` 2-D even when N_EXAMPLES == 1.
fig, ax = plt.subplots(N_CONDITIONS, N_EXAMPLES, figsize=(15, 13), squeeze=False)

# Row titles placed to the left of each first-column y label.
pad = 13  # in points
for axis, row_title in zip(ax[:, 0], row_headings):
    axis.annotate(row_title, xy=(0, 0.5), xytext=(-axis.yaxis.labelpad - pad, 0),
                  xycoords=axis.yaxis.label, textcoords='offset points',
                  size='large', ha='center', va='center', rotation=90)

fig.tight_layout()

# A dedicated, seeded generator makes the chosen exemplars reproducible and
# independent of how much global random state earlier cells (e.g. MCMC) consumed.
rng = np.random.default_rng(SEED)

for condition in range(N_CONDITIONS):
    valid_ids = ids_in_condition(data, condition)
    # Sampling without replacement fails if a condition has too few participants.
    n_shown = min(N_EXAMPLES, len(valid_ids))
    ids = rng.choice(valid_ids, n_shown, replace=False)

    for col, exemplar_id in enumerate(ids):
        mr.plot_participant_data_space(ax[condition, col],
                                       (trace['logk'][:, exemplar_id],
                                        trace['logs'][:, exemplar_id]),
                                       exemplar_id)

    # Hide any panels left empty because the condition had too few participants.
    for col in range(n_shown, N_EXAMPLES):
        ax[condition, col].set_visible(False)

fig.tight_layout()

# Keep x labels only on the bottom row, and y labels only on the first column.
for axis in ax[:-1, :].flat:
    axis.set(xlabel='')
for axis in ax[:, 1:].flat:
    axis.set(ylabel='')

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}example_fits_experiment{expt}.pdf', bbox_inches='tight')