In [None]:
import numpy as np
import pandas as pd
import pymc3 as pm
from os.path import join, isfile
from os import listdir
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('ticks')
sns.set_context('paper')

import glambox as gb

%load_ext autoreload
%autoreload 2

In [None]:
# Simulate a dataset that stands in for collected data.
# With real participants we would not know
# 1) whether GLAM is an adequate model for the data
# 2) the data-generating parameters

# Group-level specification (mean, sd and bounds) for every GLAM parameter
parameters = {
    'v': {'mu': 0.6, 'sd': 0.25, 'bounds': (0, 1.5)},
    'gamma': {'mu': 0.1, 'sd': 0.4, 'bounds': (-1, 1)},
    's': {'mu': 0.25, 'sd': 0.05, 'bounds': (0.05, 0.75)},
    'tau': {'mu': 1.0, 'sd': 0.3, 'bounds': (0.1, 2)},
}

data_model = gb.GLAM()
data_model.simulate_group(
    kind='hierarchical',
    n_individuals=50,
    n_trials=200,
    n_items=3,
    parameters=parameters,
    value_range=(1, 10),
    seed=1,
)

# The simulated trials are exposed on the model's `data` attribute
data = data_model.data

In [None]:
# Write the simulated dataset to disk without the row index
data_path = join('examples', 'example_3', 'data', 'data.csv')
data.to_csv(data_path, index=False)

In [None]:
# Fit GLAM to the simulated dataset: build a model of kind 'individual'
# and sample with MCMC (5000 draws and 5000 tuning steps on 4 chains).
# NOTE(review): presumably 'individual' fits each participant separately - confirm in glambox docs.
glam = gb.GLAM(data=data)
glam.make_model(kind='individual')
glam.fit(method='MCMC', draws=5000, tune=5000, chains=4)

In [None]:
# Generate predictions from the fitted model (n_repeats=1) and keep them
# as the synthetic dataset used for the recovery analysis below.
glam.predict(n_repeats=1)
synthetic = glam.prediction

In [None]:
# Write the synthetic dataset to disk without the row index
synthetic_path = join('examples', 'example_3', 'data', 'synthetic.csv')
synthetic.to_csv(synthetic_path, index=False)

In [None]:
# For this synthetic dataset, we know the generating parameters:
# they are the estimates of the model that produced the synthetic data.
true_parameters = {parameter: glam.estimates[parameter].values
                   for parameter in ['v', 'gamma', 's', 'tau']}

# Save these generating parameters
true_param_df = pd.DataFrame(true_parameters)
# Number the subjects from the frame length instead of a hard-coded 50,
# so this cell keeps working when the number of simulated individuals changes.
true_param_df['subject'] = range(len(true_param_df))
true_param_df.to_csv(join('examples', 'example_3', 'results', 'true_parameters.csv'), index=False)

In [None]:
# Overlay the synthetic data (lines) on the aggregate plot of the simulated data
figure_path = join('examples', 'example_3', 'figures', 'data-vs-synthetic.png')
gb.plots.plot_aggregate(data,
                        line_data=[synthetic],
                        line_labels=['Synthetic Data'])
plt.savefig(figure_path, dpi=330)

In [None]:
# Recovery fit: fit a second model of kind 'individual' to the synthetic
# dataset, using the same MCMC settings as the first fit
# (5000 draws and 5000 tuning steps on 4 chains).
glam_rec = gb.GLAM(data=synthetic)
glam_rec.make_model(kind='individual')
glam_rec.fit(method='MCMC', draws=5000, tune=5000, chains=4)

In [None]:
# Write the recovered parameter estimates to disk without the row index
estimates_path = join('examples', 'example_3', 'results', 'glam_rec_estimates.csv')
glam_rec.estimates.to_csv(estimates_path, index=False)

In [None]:
def plot_recovery_individual(model, generating_parameters,
                             parameters=['v', 'gamma', 's', 'tau'],
                             xlimits=dict(v=[-0.1, 0.1],
                                          gamma=[-0.25, 0.25],
                                          s=[-0.1, 0.1],
                                          tau=[-1, 1]),
                             figsize=gb.plots.cm2inch(18, 6),
                             fontsize=7):
    """
    Plot parameter recovery results from individually fitted models.

    For every parameter, one column of two panels is drawn:
    a histogram of (recovered - generating) on top, and below it one
    point (recovered - generating) plus one horizontal HPD bar per
    individual. Points and bars are green when the generating value lies
    strictly inside the HPD interval, red otherwise.

    Args:
        model: Fitted GLAM model of type 'individual'. Its `estimates`
            table must contain the columns 'subject', `<parameter>`,
            `<parameter>_hpd_2.5` and `<parameter>_hpd_97.5` for every
            entry of `parameters`.
        generating_parameters (dict): Dictionary of data generating
            parameters; maps each parameter name to one value per
            individual, ordered like the rows of `model.estimates`.
        parameters (list, optional): List of parameters to include
        xlimits (dict, optional): Maps each parameter name to the
            [lower, upper] x-axis limits of its column; the histogram
            uses 20 equally wide bins spanning the same range.
        figsize (tuple, optional): Figure size
        fontsize (int, optional): Font size for axis labels, tick labels
            and panel labels.

    Returns:
        tuple: `(fig, axs)`, where `axs` is a dict keyed by `(row, column)`;
        row 0 holds the histograms, row 1 the per-individual panels.
    """
    from string import ascii_uppercase

    parameter_names = {'v': 'v',
                       'gamma': r'$\gamma$',
                       's': r'$\sigma$',
                       'tau': r'$\tau$'}

    n_individuals = len(generating_parameters[parameters[0]])
    n_parameters = len(parameters)

    # Construct long dataframe,
    # every row is one parameter of one subject
    recovery = []
    for parameter in parameters:
        lower_column = parameter + '_hpd_2.5'
        upper_column = parameter + '_hpd_97.5'
        recovery_p = model.estimates[['subject', parameter, lower_column, upper_column]].copy()
        recovery_p = recovery_p.rename(columns={parameter: 'recovered',
                                                lower_column: 'recovered_hpd_lower',
                                                upper_column: 'recovered_hpd_upper'})
        recovery_p['parameter'] = parameter
        recovery_p['generating'] = generating_parameters[parameter]
        recovery.append(recovery_p)
    recovery = pd.concat(recovery)
    # Recovery counts as successful when the generating value lies strictly inside the HPD
    recovery['success'] = ((recovery['generating'] > recovery['recovered_hpd_lower']) &
                           (recovery['generating'] < recovery['recovered_hpd_upper'])).values

    # Plot
    fig = plt.figure(figsize=figsize, dpi=330)
    axs = {}
    # One grid column per parameter (the column count was hard-coded to 4 before).
    # Row 0 holds the histogram, row 1 is left empty as a spacer, rows 2-4 hold
    # the per-individual panel. The original asked for rowspan=4 from row 2,
    # which overruns the 5-row grid; 3 rows is what actually fits.
    grid_shape = (5, n_parameters)
    for p in range(n_parameters):
        axs[(0, p)] = plt.subplot2grid(grid_shape, (0, p), rowspan=1)
        axs[(1, p)] = plt.subplot2grid(grid_shape, (2, p), rowspan=3, sharex=axs[(0, p)])

    for p, parameter in enumerate(parameters):
        parameter_df = recovery.loc[recovery['parameter'] == parameter]

        # Histogram of differences
        delta = (parameter_df['recovered'] - parameter_df['generating']).values
        axs[(0, p)].hist(delta,
                         color='black', alpha=0.3,
                         bins=np.linspace(*xlimits[parameter], 21))
        axs[(0, p)].axvline(0,
                            color='black', linewidth=0.5, alpha=0.7)
        axs[(0, p)].set_ylabel('Freq.',
                               fontsize=fontsize)
        # The x axis is shared with the panel below, so hide the duplicate tick labels
        for label in axs[(0, p)].get_xticklabels():
            label.set_visible(False)

        # Individual HPDs around true value
        ## Success Color coding (False -> 'red', True -> 'green')
        color = np.array(['red', 'green'])[parameter_df['success'].values.astype(int)]

        ## Vertical, indicating zero difference
        axs[(1, p)].axvline(0, color='black', zorder=-1, linewidth=0.5, alpha=0.7)
        ## Difference posterior mean - generating
        axs[(1, p)].scatter(x=delta,
                            y=range(n_individuals),
                            color=color,
                            s=4,
                            marker='o', facecolor='white',
                            zorder=2)
        ## HPD
        axs[(1, p)].hlines(y=range(n_individuals),
                           xmin=parameter_df['recovered_hpd_lower'].values - parameter_df['generating'].values,
                           xmax=parameter_df['recovered_hpd_upper'].values - parameter_df['generating'].values,
                           linewidth=0.5,
                           zorder=1,
                           color=color)

        ## Labels (fall back to the raw name for parameters without a LaTeX label)
        axs[(1, p)].set_xlabel(r'$\Delta$' + parameter_names.get(parameter, parameter),
                               fontsize=fontsize)
        axs[(1, p)].set_ylabel('Participant',
                               fontsize=fontsize)

        ## Limits
        axs[(1, p)].set_xlim(*xlimits[parameter])

    # Figure-level finishing, done once after all panels are drawn.
    # (This used to run inside the loop above, which drew every panel
    # letter once per parameter and recomputed the layout each time.)
    ## Panel Labels, row by row: A, B, ... for the histograms, then the lower panels
    panel_axes = [axs[(row, column)]
                  for row in (0, 1)
                  for column in range(n_parameters)]
    for label, ax in zip(ascii_uppercase, panel_axes):
        ax.tick_params(axis='both', which='major', labelsize=fontsize)
        ax.text(-0.3, 1.05, label, transform=ax.transAxes,
                fontsize=fontsize, fontweight='bold', va='top')
    sns.despine()
    fig.tight_layout(h_pad=-1)

    # Participant ticks are removed only after tight_layout, as in the original order
    for column in range(n_parameters):
        axs[(1, column)].set_yticks([])

    return fig, axs

In [None]:
# Draw the recovery figure for the refitted model and save it to disk
recovery_figure_path = join('examples', 'example_3', 'figures', 'deltaGenRec.png')
fig, axs = plot_recovery_individual(glam_rec, true_parameters)
plt.savefig(recovery_figure_path, dpi=330)