# Parameter estimation and model comparison
This notebook will conduct parameter estimation for the following models:
- Hyperbolic discount function
- Modified Rachlin discount function

We then compare the models using the WAIC metric and find evidence that the modified Rachlin model is superior. WAIC takes into account not only goodness of fit but also model complexity, so this justifies using the modified Rachlin discount function in preference to the hyperbolic discount function.

Proceeding with the modified Rachlin discount function, we export the (posterior mean) parameter estimates for subsequent statistical testing. We also visualise various aspects of the data.

In [None]:
# Install Black autoformatter with: pip install nb-black
%load_ext lab_black

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# data + modelling
import numpy as np
import pandas as pd
import pymc3 as pm
import os

# plotting
import seaborn as sns

%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from matplotlib import gridspec

plt.rcParams.update({"font.size": 14})

from models import (
    ModifiedRachlin,
    ModifiedRachlinFreeSlope,
    HyperbolicFreeSlope,
)

In [None]:
import arviz as az

# Report library versions so results can be tied to a specific environment;
# the ArviZ version is this cell's displayed output.
print("PyMC3 version: {}".format(pm.__version__))
az.__version__

Experiment-specific information

NOTE: Set the `expt` variable to either 1 or 2 and run the notebook to do parameter estimation for that experiment.

In [None]:
# Which experiment to analyse (1 or 2), and the processed data file for it.
expt = 1
data_file = "data/processed/EXPERIMENT{}DATA.csv".format(expt)

In [None]:
# Human-readable group labels. Index i labels group i of the group-level
# parameters (e.g. trace["mu_logk"][:, i]) used throughout this notebook.
# `==` (not `is`) is the correct way to compare integers; an unknown experiment
# number fails loudly here instead of leaving `group_name` undefined.
if expt == 1:
    group_name = ["Deferred, low", "Online, low", "Deferred, high", "Online, high"]
elif expt == 2:
    group_name = ["Deferred, gain", "Online, gain", "Deferred, loss", "Online, loss"]
else:
    raise ValueError(f"expt must be 1 or 2, got {expt!r}")

Set up our options

In [None]:
# Initialize random number generator
SEED = 123
np.random.seed(SEED)

# Define sampler options
sample_options = {
    "tune": 2000,
    "draws": 5000,
    "chains": 2,
    "cores": 2,
    "nuts_kwargs": {"target_accept": 0.95},
    "random_seed": SEED,
}

# # less ambitious sampling for testing purposes
# sample_options = {'tune': 500, 'draws': 1000,
#                   'chains': 2, 'cores': 2, # 'nuts_kwargs': {'target_accept': 0.95},
#                   'random_seed': SEED}

In [None]:
# Master switch for the `if SHOULD_SAVE:` export blocks below (figures and CSV files).
SHOULD_SAVE = False

In [None]:
out_dir = "output"

# Ensure the output folder and one subfolder per experiment exist.
# makedirs creates missing parents, and exist_ok=True makes this idempotent
# without the check-then-create race of os.path.exists + os.makedirs.
for e in (1, 2):
    os.makedirs(os.path.join(out_dir, f"expt{e}"), exist_ok=True)

# Import data

In [None]:
# Load this experiment's processed data. Columns used later include `id` and `condition`.
data = pd.read_csv(data_file, index_col=False)

In [None]:
# Preview the first rows of the loaded data.
data.head()

In [None]:
# Display the experiment number as a sanity check before the (slow) sampling.
expt

# Parameter estimation

## Hyperbolic model

In [None]:
# Build the hyperbolic (free-slope) model for these data and sample from its
# posterior using the sampler options defined above; the resulting trace is
# available as `h_free.posterior_samples`.
h_free = HyperbolicFreeSlope(data)
h_free.sample_from_posterior(sample_options)

Examine sampling diagnostics (energy plot and $\hat{R}$)

In [None]:
# Energy plot diagnostic for the NUTS sampler run of the hyperbolic model.
pm.energyplot(h_free.posterior_samples)

In [None]:
# Posterior intervals for `logk`, with the R-hat convergence statistic shown.
pm.forestplot(h_free.posterior_samples, var_names=["logk"], r_hat=True)

In [None]:
# Posterior intervals for `α`, with the R-hat convergence statistic shown.
pm.forestplot(h_free.posterior_samples, var_names=["α"], r_hat=True)

## Modified Rachlin model

In [None]:
# Build the modified Rachlin (free-slope) model and sample from its posterior with
# the same sampler options; the trace is available as `mr_free.posterior_samples`.
mr_free = ModifiedRachlinFreeSlope(data)
mr_free.sample_from_posterior(sample_options)

Examine sampling diagnostics (energy plot and $\hat{R}$)

In [None]:
# Energy plot diagnostic for the NUTS sampler run of the modified Rachlin model.
pm.energyplot(mr_free.posterior_samples)

In [None]:
# Posterior intervals for `logk`, with the R-hat convergence statistic shown.
pm.forestplot(mr_free.posterior_samples, var_names=["logk"], r_hat=True)

In [None]:
# Posterior intervals for `logs`, with the R-hat convergence statistic shown.
pm.forestplot(mr_free.posterior_samples, var_names=["logs"], r_hat=True)

In [None]:
# Posterior intervals for `α`, with the R-hat convergence statistic shown.
pm.forestplot(mr_free.posterior_samples, var_names=["α"], r_hat=True)

# Model comparison
PyMC3 and ArviZ support model comparison using WAIC. See https://docs.pymc.io/notebooks/model_comparison.html for more info.

In [None]:
# WAIC for the hyperbolic model, computed from its trace and model.
# NOTE(review): this value is not referenced again in this notebook; the model
# comparison below is done with az.compare instead.
hyperbolic_free_waic = pm.waic(h_free.posterior_samples, h_free.model)

In [None]:
# WAIC for the modified Rachlin model, computed from its trace and model.
# NOTE(review): this value is not referenced again in this notebook; the model
# comparison below is done with az.compare instead.
free_waic = pm.waic(mr_free.posterior_samples, mr_free.model)

In [None]:
# Give each model a human-readable name, intended as labels in the comparison table.
# NOTE(review): in PyMC3, Model.name is also used as a prefix for variable names;
# renaming a model after its variables were created may affect name-based lookups
# (e.g. model["logk"]) — confirm nothing downstream relies on them.
mr_free.model.name = "Modified Rachlin, free slope"
h_free.model.name = "Hyperbolic, free slope"

In [None]:
# Compare the two models with WAIC.
# - ic="waic" is passed explicitly: depending on the ArviZ version, the default
#   criterion may be PSIS-LOO, which would not match the WAIC comparison this
#   notebook reports.
# - String keys give readable row labels; using the model objects themselves as
#   keys puts their repr in the table index.
df_comp_WAIC = az.compare(
    {
        "Modified Rachlin, free slope": mr_free.posterior_samples,
        "Hyperbolic, free slope": h_free.posterior_samples,
    },
    ic="waic",
)
df_comp_WAIC

In [None]:
# Display the hyperbolic model's trace object (inspection only).
h_free.posterior_samples

In [None]:
# Model comparison with short labels, used for the comparison plot below.
# ic="waic" is explicit so the criterion matches the WAIC comparison described
# in this notebook regardless of the installed ArviZ version's default.
model_dict = {
    "Hyperbolic": h_free.posterior_samples,
    "Modified Rachlin": mr_free.posterior_samples,
}
comp = az.compare(model_dict, ic="waic")

In [None]:
#  ax = az.plot_compare(comp)

In [None]:
# Plot the model comparison. Like the other export cells, the figure is only
# written to disk when SHOULD_SAVE is enabled.
ax = az.plot_compare(comp)

if SHOULD_SAVE:
    ax.get_figure().savefig(
        f"{out_dir}/expt{expt}/expt{expt}_model_comparison.pdf", bbox_inches="tight"
    )

Based on the model comparison, we proceed with the Modified Rachlin model.

In [None]:
# Proceed with the modified Rachlin model and drop the per-model names so later
# cells refer only to `model`.
# NOTE: re-running this cell on its own raises NameError, because mr_free has
# already been deleted.
model = mr_free
del mr_free, h_free

# Export parameter estimate table
We calculate measures derived from the model (using the model's `calc_results` method) and export them as a table.

In [None]:
# Table of parameter estimates and derived measures for this experiment, computed
# by the model (presumably posterior means, per the notebook intro — see
# models.calc_results to confirm).
parameter_estimates = model.calc_results(expt)
parameter_estimates

In [None]:
# Export the parameter estimates for statistical testing elsewhere. The target
# folder is created if missing, because the output-folder setup above only
# creates `out_dir` and its per-experiment subfolders.
if SHOULD_SAVE:
    os.makedirs("analysis", exist_ok=True)
    parameter_estimates.to_csv(f"analysis/EXPERIMENT_{expt}_RESULTS.csv")

# Visualisation

## Group level

In [None]:
# Group-level posteriors for log k and log s, with the R-hat convergence statistic.
pm.forestplot(
    model.posterior_samples, var_names=["group_logk", "group_logs"], r_hat=True
)

## Visualise posterior predictions for each group

In [None]:
# Posterior predictive figure for each group; optionally saved as a PDF named
# after the group label.
# NOTE(review): plt.savefig saves the *current* figure — this assumes
# model.group_plot() leaves its own figure current.
for group_idx in range(len(group_name)):
    label = group_name[group_idx]
    model.group_plot(group_idx)
    if SHOULD_SAVE:
        pdf_path = f"{out_dir}/expt{expt}/expt{expt}_{label}.pdf"
        plt.savefig(pdf_path, bbox_inches="tight")

In [None]:
# trace = model.posterior_samples

# fig, ax = plt.subplots(1, 1, figsize=(8,8))

# for i in range(4):
#     logk = trace['group_logk'][:,i]
#     logs = trace['group_logs'][:,i]
#     ax.scatter(logk, logs, alpha=0.1, label=group_name[i])
    
# leg = ax.legend()

# for lh in leg.legendHandles: 
#     lh.set_alpha(1)
    
# ax.set(xlabel='logk', ylabel='logs', title='parameter space')

# if SHOULD_SAVE:
#     plt.savefig(f'{out_dir}expt{expt}/group_param_space.pdf', bbox_inches='tight')

## Visualise group mean parameter values

In [None]:
# Shorthand for the selected model's posterior samples, used by the plotting cells below.
trace = model.posterior_samples

In [None]:
# Re-apply the base font size for the figures that follow.
plt.rcParams["font.size"] = 14

In [None]:
# Scatter of posterior samples of each group's mean parameters in (log k, s) space,
# where s = exp(mu_logs); one colour per group.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

for i in range(4):
    ax.scatter(
        trace["mu_logk"][:, i],
        np.exp(trace["mu_logs"][:, i]),
        alpha=0.1,
        label=group_name[i],
    )

# Make legend markers fully opaque so they stay legible despite alpha=0.1 points.
# NOTE(review): `legendHandles` is deprecated in newer matplotlib (renamed
# `legend_handles`) — confirm against the installed version.
leg = ax.legend()
for handle in leg.legendHandles:
    handle.set_alpha(1)

ax.set(xlabel=r"$\log(k)$", ylabel=r"$s$", title=f"Experiment {expt}")

if SHOULD_SAVE:
    pdf_path = f"{out_dir}/expt{expt}/expt{expt}_group_mean_estimates_in_param_space.pdf"
    plt.savefig(pdf_path, bbox_inches="tight")

Create joint plot

In [None]:
# Filled KDE contours of the group-mean posteriors in (log k, s) space, one colour
# map per group (cols[i] is used for group i, in group_name order).
# NOTE(review): positional x/y and `shade`/`shade_lowest` are deprecated in newer
# seaborn — confirm against the installed version.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

cols = ["Reds", "Blues", "Greens", "Purples"]

for i, cmap in enumerate(cols):
    x = trace["mu_logk"][:, i]
    y = np.exp(trace["mu_logs"][:, i])
    sns.kdeplot(x, y, ax=ax, cmap=cmap, shade=True, shade_lowest=False, cbar=False)

# Raw strings: "\l" is not a valid escape sequence and otherwise triggers a warning.
ax.set(xlabel=r"$\log(k)$", ylabel=r"$s$")

# Horizontal reference line at s = 1.
ax.axhline(y=1, c="k", lw=1)

# Only write to disk when saving is enabled, consistent with the other export cells.
savename = f"{out_dir}/expt{expt}_group_means_contour.pdf"
if SHOULD_SAVE:
    plt.savefig(savename, bbox_inches="tight")

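The seaborn KDE plot above does not give the style we want, so below we build the contour plot more manually: we estimate each group's density on a grid and draw a single black-and-white contour per group.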

In [None]:
# Display the experiment number; it selects the plotting window in the next cell.
expt

In [None]:
import scipy.stats as stats

# Plotting window for the contour figure (x: log k, y: s), per experiment.
# `==` (not `is`) for integer comparison; unknown experiments fail loudly.
if expt == 1:
    xmin, xmax = -5, -2.5
    ymin, ymax = 0.5, 2.5
elif expt == 2:
    xmin, xmax = -6, -2.5
    ymin, ymax = 0.5, 3.0
else:
    raise ValueError(f"expt must be 1 or 2, got {expt!r}")


def density_estimation(m1, m2, xlim=None, ylim=None, n_grid=100):
    """Evaluate a 2-D Gaussian KDE of paired samples on a regular grid.

    Parameters
    ----------
    m1, m2 : array-like
        Paired 1-D samples for the x and y dimensions (same length).
    xlim, ylim : (float, float), optional
        Grid extent. Default to the module-level ``(xmin, xmax)`` and
        ``(ymin, ymax)`` set above for the current experiment.
    n_grid : int, optional
        Number of grid points along each dimension (default 100).

    Returns
    -------
    X, Y, Z : numpy.ndarray
        ``(n_grid, n_grid)`` arrays holding the grid x and y coordinates and the
        KDE density at each grid point, suitable for ``Axes.contour(X, Y, Z)``.
    """
    x_lo, x_hi = (xmin, xmax) if xlim is None else xlim
    y_lo, y_hi = (ymin, ymax) if ylim is None else ylim
    # A complex step in mgrid means "number of points", inclusive of both ends.
    step = complex(0, n_grid)
    X, Y = np.mgrid[x_lo:x_hi:step, y_lo:y_hi:step]
    positions = np.vstack([X.ravel(), Y.ravel()])
    kernel = stats.gaussian_kde(np.vstack([m1, m2]))
    Z = np.reshape(kernel(positions).T, X.shape)
    return X, Y, Z

In [None]:
# Black-and-white version of the group-mean contour plot: one KDE contour per
# group, distinguished by line style and width instead of colour.
import matplotlib.lines as mlines

fig, ax = plt.subplots(1, 1, figsize=(8, 8))

linestyles = ["solid", "dashed", "solid", "dashed"]
linewidths = [2, 2, 4, 4]
# Each contour is drawn where the density equals this fraction of the group's peak.
contour_level = 0.05

# Contour sets can't be labelled for the legend directly, so build proxy Line2D
# artists with the matching style for each group.
proxy_lines = [
    mlines.Line2D([], [], color="k", marker=None, lw=lw, linestyle=ls, label=label)
    for label, lw, ls in zip(group_name, linewidths, linestyles)
]

for i in range(4):
    x = trace["mu_logk"][:, i]
    y = np.exp(trace["mu_logs"][:, i])

    # KDE on a grid, normalised so each group's peak density is 1.
    X, Y, Z = density_estimation(x, y)
    Z = Z / np.max(Z)

    ax.contour(
        X,
        Y,
        Z,
        [contour_level],
        colors="k",
        linewidths=linewidths[i],
        linestyles=linestyles[i],
    )

ax.legend(handles=proxy_lines, loc="upper left")
# Raw strings: "\l" is not a valid escape sequence and otherwise triggers a warning.
ax.set(xlabel=r"$\ln(k)$", ylabel=r"$s$")
ax.axhline(y=1, c="k", lw=1)

# Only write to disk when saving is enabled, consistent with the other export cells.
savename = f"{out_dir}/expt{expt}_group_means_contourBW.pdf"
if SHOULD_SAVE:
    plt.savefig(savename, bbox_inches="tight")

Additional plots. First get the data into long format.

In [None]:
# Show which experiment and group labels the following plots refer to.
# A plain loop avoids the stray `[None, None, None, None]` output that a list
# comprehension used only for its print side effect would display.
print(f"Experiment: {expt}\n")
for label in group_name:
    print(label)

In [None]:
def get_long_format_data(trace, expt):
    """Stack group-mean posterior samples into a long-format DataFrame.

    Parameters
    ----------
    trace : mapping
        Posterior samples supporting ``trace["mu_logk"]`` and ``trace["mu_logs"]``,
        each a 2-D array indexed as ``[sample, group]`` for groups 0-3.
    expt : int
        Experiment number (1 or 2); selects the factor labels.

    Returns
    -------
    pandas.DataFrame
        One row per (group, sample), ordered group 0's samples first, then group
        1, and so on. Columns are ``logk``, ``s`` (= exp(mu_logs)), ``Condition``,
        and either ``Magnitude`` (expt 1) or ``Domain`` (expt 2).

    Raises
    ------
    ValueError
        If ``expt`` is not 1 or 2.
    """
    n_groups = 4
    mu_logk = trace["mu_logk"]
    mu_logs = trace["mu_logs"]

    # Concatenate group columns in order 0..3 so the repeated labels line up.
    logk = np.concatenate([mu_logk[:, g] for g in range(n_groups)])
    s = np.concatenate([np.exp(mu_logs[:, g]) for g in range(n_groups)])

    # Samples per group, taken from the trace instead of being hard-coded
    # (previously 10000 = draws * chains), so any sampler settings work.
    n_samples = len(mu_logk[:, 0])

    condition = np.repeat(["Deferred", "Online", "Deferred", "Online"], n_samples)

    if expt == 1:
        magnitude = np.repeat(["Low", "Low", "High", "High"], n_samples)
        return pd.DataFrame(
            {"logk": logk, "s": s, "Condition": condition, "Magnitude": magnitude}
        )
    if expt == 2:
        domain = np.repeat(["Gain", "Gain", "Loss", "Loss"], n_samples)
        return pd.DataFrame(
            {"logk": logk, "s": s, "Condition": condition, "Domain": domain}
        )
    raise ValueError(f"expt must be 1 or 2, got {expt!r}")

In [None]:
# Long-format table of group-mean posterior samples, used by the violin plots below.
df = get_long_format_data(trace, expt)
df.head()

Plot

In [None]:
# Fill colour palette for the violin plots: white for Deferred, light grey for Online.
my_pal = dict(Deferred=[1, 1, 1], Online=[0.75, 0.75, 0.75])

In [None]:
# Violin plots of the group-mean posteriors: log k (left panel) and s (right
# panel). Hue is the condition; the x axis is the experiment's second factor.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

if expt == 1:
    factor = "Magnitude"
elif expt == 2:
    factor = "Domain"
else:
    raise ValueError(f"expt must be 1 or 2, got {expt!r}")

for axis, param in zip(ax, ["logk", "s"]):
    sns.violinplot(
        x=factor,
        y=param,
        hue="Condition",
        data=df,
        palette=my_pal,
        ax=axis,
        split=False,
        inner=None,
    )

# Horizontal reference line at s = 1 on the s panel.
ax[1].axhline(y=1, c="k", lw=1)

# Only write to disk when saving is enabled, consistent with the other export cells.
savename = f"{out_dir}/expt{expt}_group_means.pdf"
if SHOULD_SAVE:
    plt.savefig(savename, bbox_inches="tight")

## Participant level plots

Do one example

In [None]:
# Plot a single participant (id 0) as an example.
model.participant_plot(0)

In [None]:
# Number of distinct participant ids in the data.
n_participants = data["id"].unique().size
n_participants

🔥 Export all participant level plots. This takes a while to do. 🔥 

In [None]:
# Export one figure per participant (slow). Each figure is closed after saving so
# the notebook does not accumulate hundreds of open figures.
# NOTE(review): assumes participant ids are the integers 0..n_participants-1.
if SHOULD_SAVE:
    for participant_id in range(n_participants):
        print(f"{participant_id} of {n_participants}")
        model.participant_plot(participant_id)
        pdf_path = f"{out_dir}/expt{expt}/id{participant_id}_expt{expt}.pdf"
        plt.savefig(pdf_path, bbox_inches="tight")
        plt.close(plt.gcf())

## Demo figure
We plot example data and parameter estimates for each condition (rows), using a number of randomly chosen participants (columns).

In [None]:
def ids_in_condition(data, condition):
    """Return the unique ``id`` values of rows whose ``condition`` equals `condition`.

    Ids come back as an array in order of first appearance.
    """
    in_condition = data["condition"] == condition
    return data.loc[in_condition, "id"].unique()

In [None]:
plt.rcParams.update({"font.size": 14})

N_CONDITIONS = 4
N_EXAMPLES = 3  # number of columns (randomly chosen participants per condition)

fig, ax = plt.subplots(N_CONDITIONS, N_EXAMPLES, figsize=(15, 13))

# Row titles are the group labels. Ordering is crucial: row i must correspond to
# condition code i in the data (see the data import notebook for the key).
row_headings = group_name

pad = 13  # in points
for axis, row_title in zip(ax[:, 0], row_headings):
    axis.annotate(
        row_title,
        xy=(0, 0.5),
        xytext=(-axis.yaxis.labelpad - pad, 0),
        xycoords=axis.yaxis.label,
        textcoords="offset points",
        size="large",
        ha="center",
        va="center",
        rotation=90,
    )

fig.tight_layout()

# Plot N_EXAMPLES randomly chosen participants from each condition.
# NOTE: np.random.choice(..., replace=False) raises ValueError if a condition has
# fewer than N_EXAMPLES participants.
for condition in range(N_CONDITIONS):
    valid_ids = ids_in_condition(data, condition)
    ids = np.random.choice(valid_ids, N_EXAMPLES, replace=False)

    for col, exemplar_id in enumerate(ids):
        # assumes participant ids index the columns of trace["logk"]/trace["logs"] — TODO confirm
        model.plot_participant_data_space(
            ax[condition, col],
            (trace["logk"][:, exemplar_id], trace["logs"][:, exemplar_id]),
            exemplar_id,
        )
        # Row titles identify the condition, so per-panel titles are removed.
        ax[condition, col].set_title("")

fig.tight_layout()

# Keep x labels only on the bottom row...
for condition in range(N_CONDITIONS - 1):
    for exemplar in range(N_EXAMPLES):
        ax[condition, exemplar].set(xlabel=None)

# ...and y labels only on the first column.
for condition in range(N_CONDITIONS):
    for exemplar in range(1, N_EXAMPLES):
        ax[condition, exemplar].set(ylabel=None)

if SHOULD_SAVE:
    plt.savefig(f"{out_dir}/example_fits_experiment{expt}.pdf", bbox_inches="tight")