# Parameter estimation
Running this notebook performs the parameter estimation for one experiment. Set the `expt` variable to choose which experiment (1 or 2) to analyse.

In [None]:
# Automatically reload imported local modules (e.g. `plotting`) when their source changes.
%load_ext autoreload
%autoreload 2

In [None]:
# data + modelling
import numpy as np
import numpy.matlib  # for repmat, used in calc_log_loss()
import pandas as pd
import pymc3 as pm
import math
import os
from sklearn.metrics import log_loss
import random

# plotting
import seaborn as sns
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from matplotlib import gridspec
plt.rcParams.update({'font.size': 14})

from plotting import plot_data

In [None]:
# Seed numpy's global random number generator so that numpy-based randomness
# in this notebook (e.g. choosing example participants) is reproducible.
SEED = 123
np.random.seed(SEED)

# Keyword arguments forwarded to pm.sample():
#   tune, draws   -- number of warm-up and retained iterations
#   chains, cores -- number of chains and parallel worker processes
#   nuts_kwargs   -- options for the NUTS sampler (here a high target_accept)
#   random_seed   -- seeds the sampler itself
sample_options = dict(
    tune=2000,
    draws=5000,
    chains=4,
    cores=4,
    nuts_kwargs=dict(target_accept=0.95),
    random_seed=SEED,
)

In [None]:
SHOULD_SAVE = False  # set True to write figures and the parameter-estimate CSV to disk

In [None]:
out_dir = 'output/'

# Create the output folder plus one subfolder per experiment.
# os.makedirs creates any missing parents, and exist_ok=True makes the cell
# idempotent, so separate existence checks are unnecessary.
# The loop variable is deliberately NOT called `expt`: that name is the
# experiment selector set further down.
for experiment in (1, 2):
    os.makedirs(f'{out_dir}expt{experiment}/', exist_ok=True)

Experiment-specific information

NOTE: Set the `expt` variable to either 1 or 2 and run the notebook to do parameter estimation for that experiment.

In [None]:
# Which experiment to analyse: 1 or 2.
expt = 2
if expt not in (1, 2):
    # Fail here rather than with an obscure error in a later cell.
    raise ValueError(f'expt must be 1 or 2, got {expt!r}')
data_file = f'raw_data_expt{expt}/EXPERIMENT{expt}DATA.csv'

## Import data

In [None]:
# Trial-level data for the selected experiment. index_col=False stops pandas
# from using the first column as the row index.
data = pd.read_csv(data_file, index_col=False)

In [None]:
# Peek at the first few trials
data.head()

### Create a participant/group lookup table
This lets us look up the group (i.e. condition) that each participant belongs to. We want an array with one entry per participant, whose value is that participant's condition number.

In [None]:
# One row per participant: the distinct (id, condition) pairs, sorted by id.
participant_conditions = (data[['id', 'condition']]
                          .drop_duplicates()
                          .sort_values('id')
                          .reset_index(drop=True))

# The model indexes group-level parameters with `group[participant]` and sizes
# participant-level parameters as max(id) + 1. Both require that each
# participant belongs to exactly one condition and that ids are 0..N-1.
assert participant_conditions['id'].is_unique, \
    'some participants appear in more than one condition'
assert (participant_conditions['id'].to_numpy()
        == np.arange(len(participant_conditions))).all(), \
    'participant ids must be contiguous integers starting at 0'

group = participant_conditions['condition'].to_numpy()
group

In [None]:
# Conditions are assumed to be coded 0..n_groups-1, so the count is the largest code + 1.
n_groups = group.max() + 1
n_groups

## Grab values from dataframe

In [None]:
# Participant ids are assumed to run 0..N-1, so N is the largest id + 1.
n_participants = data['id'].max() + 1
n_participants

In [None]:
# Trial-level arrays consumed by the PyMC3 model:
#   RA, DA -- reward and delay of option A
#   RB, DB -- reward and delay of option B
#   R      -- observed response, modelled as Bernoulli(P(choose option B))
#   p      -- participant id of each trial
RA, RB, DA, DB, R = (data[column].values for column in ('RA', 'RB', 'DA', 'DB', 'R'))
p = data['id'].values

## Define PyMC3 model

In [None]:
def V(reward, delay, logk, logs):
    '''Present subjective value of a prospect: its reward multiplied by the
    discount fraction at its delay.

    The discount parameters arrive on the log scale and are exponentiated
    with PyMC3 math ops, so they may be symbolic model variables.'''
    rate = pm.math.exp(logk)
    exponent = pm.math.exp(logs)
    discount = discount_function(delay, rate, exponent)
    return reward * discount


def discount_function(delay, k, s):
    '''MODIFIED Rachlin discount function, 1 / (1 + (k * delay)**s), as
    outlined in Vincent & Stewart (2018).

    Vincent, B. T., & Stewart, N. (2018, October 16). The case of muddled
    units in temporal discounting. https://doi.org/10.31234/osf.io/29sgd

    Evaluated element-wise, so `delay` may be a scalar, a numpy array, or a
    symbolic tensor (it is used with all of these in this notebook).
    '''
    scaled_delay = k * delay
    return 1 / (1.0 + scaled_delay ** s)


def Φ(VA, VB, ϵ=0.01):
    '''Psychometric function mapping the decision variable (VB - VA) to a
    response probability: the probability of choosing the delayed reward
    (option B).

    A logistic curve with slope 1.7 is rescaled into [ϵ, 1 - ϵ], which keeps
    predicted probabilities away from exactly 0 and 1.'''
    logistic = 1 / (1 + pm.math.exp(-1.7 * (VB - VA)))
    return ϵ + (1.0 - 2.0 * ϵ) * logistic

In [None]:
# Hierarchical model with groups
# Different (k, s) parameters for each participant
# Each participant also comes from a group

# Index of every group. It is used only to draw group-level (logk, logs)
# that are not attached to the data. Derived from n_groups, not hard-coded.
g = np.arange(n_groups)

with pm.Model() as group_model:
    # Hierarchical model with trials, participants, and groups.
    # Different (k, s) parameters for each participant; each participant
    # comes from one of `n_groups` groups.
    #
    # Observed data:
    # - RA, DA, RB, DB, R: trial level data
    # - group: group membership of each participant
    # - g: every group index, just used for group level inferences about (logk, logs)

    # Hyperpriors (one value per group)
    mu_logk = pm.Normal('mu_logk', mu=math.log(1/30), sd=2, shape=n_groups)
    sigma_logk = pm.Exponential('sigma_logk', 10, shape=n_groups)

    mu_logs = pm.Normal('mu_logs', mu=0, sd=0.5, shape=n_groups)
    sigma_logs = pm.Exponential('sigma_logs', 20, shape=n_groups)

    # Priors over parameters for each participant, drawn from their group's hyperparameters
    logk = pm.Normal('logk', mu=mu_logk[group], sd=sigma_logk[group], shape=n_participants)
    logs = pm.Normal('logs', mu=mu_logs[group], sd=sigma_logs[group], shape=n_participants)

    # Group level inferences, unattached from the data
    group_logk = pm.Normal('group_logk', mu=mu_logk[g], sd=sigma_logk[g], shape=n_groups)
    group_logs = pm.Normal('group_logs', mu=mu_logs[g], sd=sigma_logs[g], shape=n_groups)

    # Choice function: psychometric. P is the probability of choosing option B on each trial.
    P = pm.Deterministic('P', Φ(V(RA, DA, logk[p], logs[p]),
                                V(RB, DB, logk[p], logs[p])))

    # Likelihood of observations. The random variable is deliberately NOT
    # assigned back to `R`. That name holds the observed-response array, and
    # rebinding it would make a re-run of this cell pass the random variable
    # itself as `observed`.
    pm.Bernoulli('R', p=P, observed=R)


pm.model_to_graphviz(group_model)

## Sample from prior

In [None]:
# Sample from the prior of the model (no data involved).
with group_model:
    prior = pm.sample_prior_predictive(10_000)

# Group-level prior draws, flattened to 1-D. The names logk/logs are reused
# on purpose: the next cells summarise and plot these arrays.
logk, logs = (prior[name].flatten() for name in ('group_logk', 'group_logs'))

In [None]:
# Mean and variance of the group-level prior draws
for name, samples in (('logk', logk), ('logs', logs)):
    print(f'{name} mean = {samples.mean()}, variance = {samples.var()}')

In [None]:
# Prior mean used for mu_logk in the model, for comparison with the summary above
math.log(1/30)

Generate a figure to demonstrate the priors

In [None]:
fig = plt.figure(figsize=(10, 15))
gs = gridspec.GridSpec(3, 2)

# (a), (b): prior densities of the group-level log parameters
ax = fig.add_subplot(gs[0, 0])
sns.distplot(logk, ax=ax)
ax.set(xlabel=r'$\ln(k)$', ylabel='prior density', title="(a)")

ax = fig.add_subplot(gs[0, 1])
sns.distplot(logs, ax=ax)
ax.set(xlabel=r'$\ln(s)$', ylabel='prior density', title="(b)")

# (c): discount functions implied by prior samples. The panel spans both
# remaining grid rows, so the bottom third of the figure is not left empty.
ax = fig.add_subplot(gs[1:, :])

prior_max_delay = 101  # seconds
n_samples_to_plot = min(500, len(logk))
delays = np.linspace(0, prior_max_delay, 500)
for n in range(n_samples_to_plot):
    ax.plot(delays, discount_function(delays, np.exp(logk[n]), np.exp(logs[n])),
            c='k', alpha=0.1)

# The LaTeX part is a raw string: '\c' in a normal string is an invalid
# escape sequence (a warning now, an error in future Python versions).
ax.set(xlabel="delay [seconds]",
       ylabel='discount fraction\n' r'$1/(1+(k \cdot delay)^s)$',
       title="(c)",
       xlim=[0, prior_max_delay],
       ylim=[0, 1])

fig.tight_layout()

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}priors.pdf', bbox_inches='tight')

## Sample from posterior

In [None]:
# Draw posterior samples using the settings in `sample_options`.
with group_model:
    trace = pm.sample(**sample_options)

## Diagnostics
Check that the sampler explored the posterior well, using an energy plot.

In [None]:
# Energy plot diagnostic for the NUTS sampler
pm.energyplot(trace)

## Export parameter estimate table
First we define some functions to calculate measures derived from the model.

In [None]:
def calc_AUC(logk, logs, max_delay=101, n_points=500):
    '''Calculate the Area Under the discount Curve (AUC).

    The discount function is evaluated at `n_points` evenly spaced delays in
    [0, max_delay]. Delays are then normalised to [0, 1], so the AUC is
    the area under the curve on the unit interval.

    Parameters
    ----------
    logk, logs : float
        Log of the discount rate k and of the exponent s.
    max_delay : float, default 101
        Longest delay considered; must be positive.
    n_points : int, default 500
        Number of delays used for the trapezoidal integration.

    Raises
    ------
    ValueError
        If max_delay is not positive (the normalisation would divide by zero).
    '''
    if max_delay <= 0:
        raise ValueError(f'max_delay must be positive, got {max_delay!r}')
    delays = np.linspace(0, max_delay, n_points)
    discount_fraction = discount_function(delays, np.exp(logk), np.exp(logs))
    # linspace's last point equals max_delay exactly, so this maps delays onto [0, 1].
    normalised_delays = delays / max_delay
    # np.trapz was renamed np.trapezoid in NumPy 2.0; the old name is deprecated.
    trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
    return trapezoid(discount_fraction, x=normalised_delays)


def calc_percent_predicted(R_predicted_prob, R_actual):
    '''Proportion of responses correctly predicted by the model.

    A response is predicted to be 1 when its predicted probability is
    strictly greater than 0.5, otherwise 0.

    Parameters
    ----------
    R_predicted_prob : array-like of float
        Predicted probability of response 1 for each trial.
    R_actual : array-like of {0, 1}
        Observed responses, same length as R_predicted_prob.

    Returns
    -------
    float
        Fraction in [0, 1] (not a percentage, despite the name).

    Raises
    ------
    ValueError
        If there are no responses or the two inputs differ in shape.
    '''
    R_predicted_prob = np.asarray(R_predicted_prob)
    R_actual = np.asarray(R_actual)
    if R_actual.size == 0:
        raise ValueError('cannot compute the proportion predicted for zero responses')
    if R_predicted_prob.shape != R_actual.shape:
        raise ValueError(f'shape mismatch: {R_predicted_prob.shape} predictions '
                         f'vs {R_actual.shape} responses')
    predicted_responses = np.where(R_predicted_prob > 0.5, 1, 0)
    # Vectorised mean instead of Python's sum() over a numpy array.
    return np.mean(predicted_responses == R_actual)


def calc_log_loss(R_predicted_prob, R_actual):
    '''Binary log loss (cross-entropy) of the observed responses given the
    model's predicted probability of response 1 on each trial.

    labels=[0, 1] is passed explicitly. When a participant gave the same
    response on every trial, sklearn cannot infer the second class from
    `R_actual` and would otherwise raise a ValueError.
    '''
    return log_loss(R_actual, R_predicted_prob, labels=[0, 1])

In [None]:
def make_rowdata(id, logk, logs, pdata, Ractual, Ppredicted):
    '''Build a one-row DataFrame of parameter estimates and fit measures for
    a single participant.

    Parameters
    ----------
    id : int
        Participant index used by the model.
    logk, logs : array
        Posterior samples of this participant's log(k) and log(s).
    pdata : DataFrame
        This participant's trials (rows of `data`).
    Ractual : array
        Observed responses on those trials.
    Ppredicted : array, shape (n_samples, n_trials)
        Posterior samples of P(choose B) on those trials.

    Point estimates of logk/logs are posterior means. Fit measures use the
    posterior median of Ppredicted per trial. The experiment-specific
    condition column ('reward_mag' for experiment 1, 'domain' for
    experiment 2) is selected by the global `expt`.

    Raises
    ------
    ValueError
        If `expt` is neither 1 nor 2.
    '''
    # Compare by value (dict lookup), never with `is`: identity tests against
    # int literals only work by accident and warn on Python >= 3.8.
    condition_columns = {1: 'reward_mag', 2: 'domain'}
    if expt not in condition_columns:
        raise ValueError(f'expt must be 1 or 2, got {expt!r}')
    condition_column = condition_columns[expt]

    logk_point_estimate = np.mean(logk)
    logs_point_estimate = np.mean(logs)
    P_median = np.median(Ppredicted, axis=0)  # computed once, used by both fit measures

    # Column order matches the CSV layout: the condition column follows 'paradigm'.
    rowdata = {'id': [id],
               'PID': pdata['Participant'].iloc[0],
               'logk': [logk_point_estimate],
               'logs': [logs_point_estimate],
               'paradigm': [pdata['paradigm'].values[0]],
               condition_column: [pdata[condition_column].values[0]],
               'AUC': calc_AUC(logk_point_estimate, logs_point_estimate),
               'percent_predicted': calc_percent_predicted(P_median, Ractual),
               'log_loss': calc_log_loss(P_median, Ractual)}
    return pd.DataFrame.from_dict(rowdata)
     

# One row of estimates per participant, concatenated once at the end.
rows = []
for pid in range(n_participants):
    # Trial rows of `data` belonging to this participant. The same mask selects
    # the matching columns of the trial-level deterministic P in the trace.
    trial_mask = (data['id'] == pid).to_numpy()
    pdata = data.loc[trial_mask]
    rows.append(make_rowdata(pid,
                             trace['logk'][:, pid],
                             trace['logs'][:, pid],
                             pdata,
                             pdata['R'].values,
                             trace['P'][:, trial_mask]))

parameter_estimates = pd.concat(rows, ignore_index=True)

if SHOULD_SAVE:
    # to_csv does not create missing folders.
    os.makedirs('analysis', exist_ok=True)
    parameter_estimates.to_csv(f'analysis/EXPERIMENT_{expt}_RESULTS.csv')

## Visualisation

In [None]:
# Posterior summary of each participant's log(k)
pm.forestplot(trace, varnames=['logk'])

In [None]:
# Posterior summary of each participant's log(s)
pm.forestplot(trace, varnames=['logs'])

## Group level

In [None]:
# Posterior summary of the group-level (logk, logs), one entry per group
pm.forestplot(trace, varnames=['group_logk', 'group_logs'])

### Visualise posterior predictions for each group

In [None]:
def group_plot(i, trace, data, n_samples_to_plot=100, max_delay=365):
    '''Plot posterior inferences for group (condition) `i`.

    Left panel: posterior samples of the group-level (logk, logs).
    Right panel: discount functions of the first `n_samples_to_plot`
    posterior samples (capped at the number available), over delays
    0..max_delay.

    `data` is accepted for call compatibility but is not used.
    '''
    logk = trace['group_logk'][:, i]
    logs = trace['group_logs'][:, i]

    delays = np.linspace(0, max_delay, 1000)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    # PARAMETER SPACE
    ax[0].scatter(logk, logs, alpha=0.1)
    ax[0].set(xlabel='logk', ylabel='logs', title='parameter space')

    # DATA SPACE: discount functions sampled from the posterior
    for n in range(min(n_samples_to_plot, len(logk))):
        ax[1].plot(delays, discount_function(delays, np.exp(logk[n]), np.exp(logs[n])),
                   c='k', alpha=0.1)

    ax[1].set(xlabel='delay [seconds]', ylabel='RA/RB', title='data space')
    
# The loop variable must not be called `group`: that name holds the
# participant -> group lookup array used by the model. Rebinding it to an int
# would silently break any re-run of the model cell.
# Iterate over however many groups the trace actually contains.
for group_index in range(trace['group_logk'].shape[1]):
    group_plot(group_index, trace, data)
    # Only write files when saving is enabled, like every other figure here.
    if SHOULD_SAVE:
        plt.savefig(f'{out_dir}expt{expt}/group{group_index}.pdf', bbox_inches='tight')

In [None]:
# Human-readable condition names, indexed by condition code. The ordering must
# match the condition coding in the data (see the data import notebook).
# A dict lookup compares by value; `expt is 1` relied on int identity.
GROUP_NAMES_BY_EXPT = {
    1: ['Deferred, low',
        'Online, low',
        'Deferred, high',
        'Online, high'],
    2: ['Deferred, gain',
        'Online, gain',
        'Deferred, loss',
        'Online, loss'],
}
group_name = GROUP_NAMES_BY_EXPT[expt]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

# Posterior samples of each group's (logk, logs), one colour per group.
# Local names are used so the global logk/logs are not clobbered.
for i, name in enumerate(group_name):
    ax.scatter(trace['group_logk'][:, i], trace['group_logs'][:, i],
               alpha=0.01, label=name)

leg = ax.legend()

# Make legend markers opaque despite the very transparent scatter points.
# Matplotlib 3.7 renamed `legendHandles` to `legend_handles`; the old name was later removed.
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

ax.set(xlabel='logk', ylabel='logs', title='parameter space')

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}expt{expt}/group_param_space.pdf', bbox_inches='tight')

## Visualise group mean parameter values

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

# Posterior samples of each group's mean parameters (mu_logk, mu_logs).
# Local names are used so the global logk/logs are not clobbered.
for i, name in enumerate(group_name):
    ax.scatter(trace['mu_logk'][:, i], trace['mu_logs'][:, i],
               alpha=0.01, label=name)

leg = ax.legend()

# Make legend markers opaque despite the very transparent scatter points.
# Matplotlib 3.7 renamed `legendHandles` to `legend_handles`; the old name was later removed.
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

ax.set(xlabel='logk', ylabel='logs', title=f'Experiment {expt}')

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}expt{expt}/group_mean_estimates_in_param_space.pdf', bbox_inches='tight')

In [None]:
# Scratch check: option-B reward (RB) on the first trial of participant 1.
# NOTE(review): this rebinds the globals `id` (shadowing the builtin) and
# `pdata`; it looks like leftover exploration and could be removed.
id=1
pdata = data.loc[data['id'] == id]
pdata['RB'].values[0]

## Participant level plots

In [None]:
def participant_plot(id, trace, data, n_samples_to_plot=100, legend=True):
    '''Plot posterior inferences for one participant.

    Left panel: posterior samples of (logk, logs) in parameter space.
    Right panel: the participant's data with posterior discount functions
    overlaid, drawn by `plot_data_space`.

    Parameters
    ----------
    id : int
        Participant id (value of data['id'] and column of trace['logk']).
    trace : PyMC3 trace containing 'logk' and 'logs'.
    data : trial-level DataFrame.
    n_samples_to_plot : int
        Number of posterior discount curves drawn in the data-space panel.
    legend : bool
        Currently ignored; the data-space panel is drawn without a legend.
    '''
    logk = trace['logk'][:, id]
    logs = trace['logs'][:, id]

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    # PARAMETER SPACE ==================================
    ax[0].scatter(logk, logs, alpha=0.1)
    ax[0].set(xlabel='logk', ylabel='logs', title='parameter space')

    # DATA SPACE =======================================
    plot_data_space(id, ax[1], data, logk, logs, n_samples_to_plot=n_samples_to_plot)

In [None]:
def plot_data_space(id, ax, data, logk, logs, n_samples_to_plot=50):
    '''Plot one participant's choices together with posterior discount curves.

    The participant's trials are drawn with `plot_data` (no legend). On top
    of them, RB * discount_function(delay) is drawn as thin translucent lines
    for the first `n_samples_to_plot` posterior samples, and as a thick line
    for the posterior-median (logk, logs). RB is the option-B reward on the
    participant's first trial.

    Parameters
    ----------
    id : participant id (value of data['id']).
    ax : matplotlib Axes to draw into.
    data : trial-level DataFrame.
    logk, logs : array
        Posterior samples of this participant's log(k) and log(s).
    n_samples_to_plot : int
        Number of posterior curves to draw (capped at the number of samples).
    '''
    # plot the data
    pdata = data.loc[data['id'] == id]
    plot_data(pdata, ax, legend=False)

    RB = pdata['RB'].values[0]  # loop-invariant: hoisted out of the plotting loop
    max_delay = np.max(pdata['DB'].values) * 1.1
    delays = np.linspace(0, max_delay, 1000)

    # discount functions sampled from the posterior
    for n in range(min(n_samples_to_plot, len(logk))):
        ax.plot(delays,
                RB * discount_function(delays, np.exp(logk[n]), np.exp(logs[n])),
                c='k', alpha=0.1)

    # Posterior-median discount function: the median over ALL samples, not a
    # single sample.
    ax.plot(delays,
            RB * discount_function(delays, np.exp(np.median(logk)), np.exp(np.median(logs))),
            c='k', linewidth=3)

    # Participant id label: at y = -1 when RB is negative, otherwise y = 1
    # (this also covers RB == 0, which previously left text_y unset).
    text_y = -1. if RB < 0 else 1.
    ax.text(2, text_y, f'participant id: {id}',
            horizontalalignment='left',
            verticalalignment='center',
            fontsize=10)

    ax.set(xlabel='delay [seconds]',
           ylabel='immediate reward [cents]')
    ax.set_xlim(left=0)

In [None]:
# Example: parameter-space and data-space plots for participant 0
participant_plot(0, trace, data)

🔥 Export all participant level plots. This takes a while to do. 🔥 

In [None]:
# Save one PDF per participant. Each figure is closed right after saving so
# the notebook does not render, or keep in memory, one plot per participant.
if SHOULD_SAVE:
    for participant_id in range(n_participants):
        print(f'{participant_id} of {n_participants}')
        participant_plot(participant_id, trace, data, legend=False)
        plt.savefig(f'{out_dir}expt{expt}/id{participant_id}.pdf', bbox_inches='tight')
        plt.close()  # closes the current figure, i.e. the one just saved

## Demo figure
Plot example data and parameter estimates: one row per condition, with a different randomly chosen participant in each column.

In [None]:
def ids_in_condition(data, condition):
    '''Return the unique participant ids that took part in `condition`,
    in order of first appearance.'''
    in_condition = data['condition'] == condition
    return data.loc[in_condition, 'id'].unique()

In [None]:
plt.rcParams.update({'font.size': 14})

N_EXAMPLES = 3  # number of columns: example participants per condition

# Row titles, one per condition. The ordering is crucial: it must follow the
# condition coding in the data (see the data import notebook for the key).
# Reuses `group_name`, which is defined for the current experiment above.
row_headings = group_name
n_conditions = len(row_headings)

# squeeze=False keeps `ax` 2-D even if N_EXAMPLES is set to 1.
fig, ax = plt.subplots(n_conditions, N_EXAMPLES, figsize=(15, 13), squeeze=False)

pad = 13  # in points
for axis, row_title in zip(ax[:, 0], row_headings):
    axis.annotate(row_title, xy=(0, 0.5), xytext=(-axis.yaxis.labelpad - pad, 0),
                  xycoords=axis.yaxis.label, textcoords='offset points',
                  size='large', ha='center', va='center', rotation=90)

# One row per condition; each column shows a different randomly chosen participant.
for condition in range(n_conditions):
    valid_ids = ids_in_condition(data, condition)
    # Sample WITHOUT replacement so a participant is never shown twice in a row.
    # Capped in case a condition has fewer than N_EXAMPLES participants.
    n_to_pick = min(N_EXAMPLES, len(valid_ids))
    ids = np.random.choice(valid_ids, n_to_pick, replace=False)

    for col, exemplar_id in enumerate(ids):
        plot_data_space(exemplar_id, ax[condition, col], data,
                        trace['logk'][:, exemplar_id], trace['logs'][:, exemplar_id])

fig.tight_layout()

# Keep x labels only on the bottom row, and y labels only on the first column.
for axis in ax[:-1, :].flat:
    axis.set(xlabel=None)
for axis in ax[:, 1:].flat:
    axis.set(ylabel=None)

if SHOULD_SAVE:
    fig.savefig(f'{out_dir}example_fits_experiment{expt}.pdf', bbox_inches='tight')