# Scientific Workflow with HSSM

<center> <img src="./images/HSSM_logo.png"> </center>

Welcome to the scientific workflow tutorial. Starting from a basic experimental dataset, we will iteratively move from a very simple HSSM model toward a model that captures many of the main patterns we can identify in the data.

Along the way we try to achieve the following balance:

1. Illustrate how HSSM can be used for real scientific workflows. HSSM helps us with model building, running the stats, and reporting results.
2. Allow this tutorial to be used as a **first** look into HSSM, setting aside conceptually advanced features that are discussed in the many dedicated tutorials you can find in the [documentation](https://lnccbrown.github.io/HSSM/).

## Colab Instructions

If you would like to run this tutorial on Google Colab, please click this [link](https://colab.research.google.com/github/lnccbrown/HSSM/blob/drop-more-notebooks-from-execute/carney_workshop/discovery_journey.ipynb).

Once you are *in the colab*:

1. Follow the **installation instructions below** (uncomment the respective code).
2. **Restart your runtime**.

**NOTE**:

You may want to *switch your runtime* to have a GPU or TPU. To do so, go to *Runtime* > *Change runtime type* and select the desired hardware accelerator.
Note that if you switch your runtime you have to follow the installation instructions again.

##### Install hssm

In [1]:
# If running this on Colab, please uncomment the next line
# !pip install hssm

##### Download tutorial data

In [2]:
# # Data Files
# !wget -P  data/carney_workshop_2025_data/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/data/carney_workshop_2025_data/carney_workshop_2025_full.parquet
# !wget -P  data/carney_workshop_2025_data/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/data/carney_workshop_2025_data/carney_workshop_2025_modeling.parquet
# !wget -P  data/carney_workshop_2025_data/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/data/carney_workshop_2025_data/carney_workshop_2025_parameters.pkl

# # Presampled traces
# !wget -P  idata/basic_ddm/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/basic_ddm/traces.nc
# !wget -P  idata/ddm_hier/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/ddm_hier/traces.nc
# !wget -P  idata/angle_hier/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/angle_hier/traces.nc
# !wget -P  idata/angle_hier_v2/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/angle_hier_v2/traces.nc
# !wget -P  idata/angle_hier_v3/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/angle_hier_v3/traces.nc
# !wget -P  idata/angle_hier_v4/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/angle_hier_v4/traces.nc
# !wget -P  idata/angle_v5/ https://raw.githubusercontent.com/lnccbrown/HSSM/drop-more-notebooks-from-execute/carney_workshop/idata/angle_v5/traces.nc

## Start of Tutorial

#### Load modules

In [3]:
import hssm
import pandas as pd
import pickle
import numpy as np
import arviz as az
from matplotlib import pyplot as plt

#### Load workshop data

In [4]:
def load_data(filename_base: str,
              folder: str = "data") -> pd.DataFrame:
    """Load the saved modeling data from a parquet file.

    Parameters
    ----------
    filename_base : str
        Base filename used when saving files
    folder : str, optional
        Folder containing saved files, by default "data"

    Returns
    -------
    pd.DataFrame
        DataFrame with the modeling data, read from
        ``{folder}/{filename_base}_modeling.parquet``.
    """
    df_modeling = pd.read_parquet(f"{folder}/{filename_base}_modeling.parquet")
    return df_modeling

In [5]:
workshop_data  = load_data(filename_base = "carney_workshop_2025",
                           folder = "data/carney_workshop_2025_data")

#### Load Plotting Utilities

In [6]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def plot_rt_by_choice(df: pd.DataFrame,
                      categorical_column: str | None = None,
                      colors: dict[str, str] | dict[int, str] | None = None,
                      ax: plt.Axes | None = None):
    """Histogram of signed RTs (rt * response), optionally split by a categorical column."""
    if ax is None:
        ax = plt.gca()
    bins = np.linspace(-5, 5, 50)
    if categorical_column is None:
        ax.hist(df['rt'] * df['response'],
                bins=bins,
                histtype='step',
                density=True,
                color='tab:blue')
    else:
        for cond in df[categorical_column].unique():
            df_cond = df[df[categorical_column] == cond]
            ax.hist(df_cond['rt'] * df_cond['response'],
                    bins=bins,
                    label=f'Condition {cond}',
                    histtype='step',
                    density=True,
                    color=colors[cond])
    # Axis labels apply to both the split and the unsplit plot
    ax.set_xlabel('RT * Choice')
    ax.set_ylabel('Density')
    return ax
        
def inset_bar_plot(df: pd.DataFrame, 
                   categorical_column: str,
                   response_options: list[int],
                   colors: dict[str, str] | dict[int, str] | None = None,
                   ax: plt.Axes | None = None):
    
    axins = inset_axes(ax, 
                       width="35%",
                       height="35%",
                       loc='upper left',
                       borderpad=2.75)
    bar_width = 0.55
    for j, resp in enumerate(response_options):
        for k, cond in enumerate(df[categorical_column].unique()):
            k_displace = -1 if k == 0 else 1
            df_cond = df[df[categorical_column] == cond]
            prop = (df_cond[df_cond.response == resp].shape[0] / len(df_cond))
            axins.bar((resp + ((bar_width / 2) * k_displace)), 
                        prop,
                        width=bar_width,
                        fill = False,
                        edgecolor=colors[cond],
                        label=f'Response {resp}')
    axins.set_xticks(response_options)
    axins.set_ylim(0, 1)
    axins.set_yticks([0.0, 0.5, 1])
    axins.set_title('choice proportion / option', fontsize=8)
    axins.tick_params(axis='both', which='major', labelsize=7)
    axins.set_xlabel('')
    axins.set_ylabel('')
    return ax

def inset_bar_plot_vertical(df: pd.DataFrame,
                            categorical_column: str,
                            response_options: list[int],
                            colors: dict[str, str] | dict[int, str] | None = None,
                            ax: plt.Axes | None = None):
    
    axins = inset_axes(ax,
                       width="35%",
                       height="35%",
                       loc='upper left',
                       borderpad=2.25)
    bar_width = 0.55
    for j, resp in enumerate(response_options):
        for k, cond in enumerate(df[categorical_column].unique()):
            k_displace = -1 if k == 0 else 1
            df_cond = df[df[categorical_column] == cond]
            rt_mean = (df_cond[df_cond.response == resp]).rt.mean()
            axins.barh((resp + ((bar_width / 2) * k_displace)), 
                       rt_mean,
                       height=bar_width,
                       fill = False,
                       edgecolor=colors[cond],
                       label=f'Response {resp}')

    axins.set_yticks(response_options)
    axins.set_xticks([0.0, 1., 2.])
    axins.set_title('rt-mean by choice option', fontsize=8)
    axins.tick_params(axis='both', which='major', labelsize=7)
    axins.set_xlabel('')
    axins.set_ylabel('')
    return ax

def plot_rt_hists(df: pd.DataFrame,
                  by_participant: bool = True,
                  split_by_column: str | None = None,
                  inset_plot: str | None = "choice_proportion",
                  cols: int = 5):
    if split_by_column is not None:
        colors = {cond: color for cond, color in zip(df[split_by_column].unique(), 
                            ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple'])}
    else:
        colors = None

    if by_participant:
        # Get unique participant IDs
        participants = df['participant_id'].unique()

        # Set up subplot grid (adjust cols as needed)
        rows = (len(participants) + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, 
                                 figsize=(cols*4, rows*3), 
                                 sharey=True, sharex=True)
        axes = axes.flatten()
        for i, pid in enumerate(participants):
            ax = axes[i]
            df_part = df[df['participant_id'] == pid]
            ax = plot_rt_by_choice(df_part,
                                   split_by_column,
                                   colors,
                                   ax)
            
            # Take care of inset plots
            if inset_plot == "choice_proportion":
                ax = inset_bar_plot(df_part, 
                                    split_by_column,
                                    df['response'].unique(),
                                    colors,
                                    ax)
            elif inset_plot == "rt_mean":
                ax = inset_bar_plot_vertical(df_part, 
                                             split_by_column,
                                             df['response'].unique(),
                                             colors,
                                             ax)
            # Only add a legend when there are labelled (split) histograms
            if i == 0 and split_by_column is not None:
                ax.legend(title=split_by_column, loc='best', fontsize='small')

        # Hide unused axes
        for j in range(i+1, len(axes)):
            axes[j].set_visible(False)

        plt.tight_layout()
        title_suffix = f", split by {split_by_column}" if split_by_column else ""
        plt.suptitle(f"RT by participant{title_suffix}", y=1.02)
        plt.show()
    else:
        fig, ax = plt.subplots(1, 1, figsize=(4, 3))
        ax = plot_rt_by_choice(df,
                               split_by_column,
                               colors,
                               ax)

        # Take care of inset plots
        if inset_plot == "choice_proportion":
            ax = inset_bar_plot(df,
                                split_by_column,
                                df['response'].unique(),
                                colors,
                                ax)
        elif inset_plot == "rt_mean":
            ax = inset_bar_plot_vertical(df,
                                         split_by_column,
                                         df['response'].unique(),
                                         colors,
                                         ax)
        
        if split_by_column is not None:
            ax.legend(title=split_by_column, loc='best', fontsize='small')
        plt.tight_layout()
        title_suffix = f", split by {split_by_column}" if split_by_column else ""
        plt.suptitle(f"RT by trial{title_suffix}", y=1.02)
        plt.show()

## Exploratory Data Analysis

<center> <img src="images/Experiment.png" height=300 width=700> </center>


Now that the setup is done, let's get to the meat of it! The picture above gives us an idea of where the dataset we are going to work with comes from (alert: the backstory may or may not be real).

**20** subjects performed **250** trials each of a basic *Random dot motion* task. The task seemingly had two important manipulations.

1. A **costly fail** condition, in which subjects get punished for mistakes.
2. A trial-by-trial manipulation of **difficulty** (in the **Random dot motion** task, this refers to the degree of coherence with which the dots move in a particular direction).

Let's take a look at the actual dataframe.

In [None]:
workshop_data

#### Adding a few columns

As prep work for plotting etc., we add a few columns here. They will be motivated later (close your eyes :)).

In [8]:
# Binary version of difficulty
workshop_data['bin_difficulty'] = workshop_data['continuous_difficulty'].apply(lambda x: 'high' if x > 0 else 'low')

# Ordinal variable with 3 quantile-based levels of difficulty
workshop_data['quantile_difficulty'] = pd.qcut(workshop_data['continuous_difficulty'],
                                               3, labels = ['-1', '0', '1'])

# Binary (median-split) version of the quantile-based difficulty
workshop_data['quantile_difficulty_binary'] = pd.qcut(workshop_data['continuous_difficulty'],
                                                      2, labels = ['-1', '1'])

# String-valued copy of the lagged response (response_l1), used as a plotting condition;
# every value other than -1 is mapped to '1'
workshop_data['response_l1_plotting'] = workshop_data['response_l1'].apply(lambda x: str(-1) if x == -1 else str(1))


#### Most basic reaction time plot

In [None]:
plot_rt_hists(workshop_data, 
              by_participant = False, 
              split_by_column = None,
              inset_plot = None)

So far so good. The global reaction time pattern looks commensurate with what we might expect from a basic Sequential Sampling Model (SSM), so the basic DDM might be a good place to start.

## Basic Model: DDM

<center> <img src="images/DDM_with_params_pic.png" height="500" width="500"> </center>

The picture above illustrates the basic [Drift Diffusion Model](https://pmc.ncbi.nlm.nih.gov/articles/PMC2474742/). Note the parameters:

1. `v` the drift rate (how much evidence do I collect on average per unit of time)
2. `a` the boundary separation (how much evidence do I need to commit to a choice)
3. `z` how biased am I toward a particular choice a priori
4. `ndt` (we will simply call it `t` below), the delay between being exposed to a stimulus and starting the actual evidence accumulation process
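To make the roles of the four parameters concrete, here is a minimal simulation sketch of single DDM trials. It assumes a textbook convention (lower bound at 0, upper bound at `a`, starting point `z * a`, unit diffusion noise) and a simple Euler-Maruyama discretization; HSSM's own simulators may parameterize the model differently, and `simulate_ddm_trial` is a hypothetical helper, not part of HSSM.

```python
import numpy as np

def simulate_ddm_trial(v, a, z, t, dt=0.001, max_time=10.0, rng=None):
    """Simulate one DDM trial with a simple Euler-Maruyama scheme.

    Assumed convention: lower bound at 0, upper bound at a,
    starting point z * a, unit diffusion noise.
    Returns (rt, choice) with choice = 1 (upper) or -1 (lower);
    returns (nan, 0) if no bound is reached within max_time.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = z * a                      # a-priori bias: start a fraction z of the way up
    elapsed = 0.0
    sqrt_dt = np.sqrt(dt)
    while 0.0 < x < a:
        if elapsed >= max_time:
            return np.nan, 0
        # Drift v adds evidence on average; the noise term makes each path random
        x += v * dt + sqrt_dt * rng.standard_normal()
        elapsed += dt
    choice = 1 if x >= a else -1
    return t + elapsed, choice     # non-decision time t is added to the decision time

rng = np.random.default_rng(0)
trials = [simulate_ddm_trial(v=1.0, a=1.5, z=0.5, t=0.3, rng=rng) for _ in range(200)]
rts = np.array([rt for rt, _ in trials])
choices = np.array([c for _, c in trials])
print(f"P(upper) = {(choices == 1).mean():.2f}, mean RT = {np.nanmean(rts):.2f}s")
```

Raising `v` makes upper-bound choices more frequent and faster, raising `a` makes responses slower but more consistent, shifting `z` biases the choice, and `t` shifts every RT by a constant.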

In [None]:
BasicDDMModel = hssm.HSSM(data = workshop_data,
                          model = "ddm",
                          loglik_kind = "approx_differentiable",
                          global_formula = "y ~ 1",
                          noncentered = False,
                          )

In [None]:
BasicDDMModel

In [None]:
BasicDDMModel.graph()

In [13]:
try:
    # Load pre-computed traces
    BasicDDMModel.restore_traces(traces = "idata/basic_ddm/traces.nc")
except Exception:  # e.g. no saved traces yet: sample from scratch
    # Sample posterior
    basic_ddm_idata = BasicDDMModel.sample(chains = 2,
                                            sampler = "nuts_numpyro",
                                            tune = 500,
                                            draws = 500,
                                        )

    # Sample posterior predictive
    BasicDDMModel.sample_posterior_predictive(draws = 200,
                                              safe_mode = True)

    # Save Model
    BasicDDMModel.save_model(model_name = "basic_ddm",
                             allow_absolute_base_path = True,
                             base_path = "idata/",
                             save_idata_only = True)

In [None]:
BasicDDMModel.traces

In [None]:
az.summary(BasicDDMModel.traces)
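Rather than eyeballing the summary table, one can flag parameters that fail common convergence rules of thumb. The sketch below is a hypothetical helper (not part of HSSM or ArviZ) that filters an `az.summary()` table on its `r_hat`, `ess_bulk` and `ess_tail` columns; the thresholds (R-hat above 1.01, effective sample size below 400) are widely used defaults, assumed here rather than prescribed by this tutorial.

```python
import pandas as pd

def flag_convergence_issues(summary: pd.DataFrame,
                            rhat_max: float = 1.01,
                            ess_min: float = 400.0) -> pd.DataFrame:
    """Return rows of an az.summary() table that fail common R-hat / ESS rules of thumb."""
    bad = ((summary["r_hat"] > rhat_max)
           | (summary["ess_bulk"] < ess_min)
           | (summary["ess_tail"] < ess_min))
    return summary.loc[bad, ["r_hat", "ess_bulk", "ess_tail"]]

# Made-up table with the column names az.summary() produces
toy_summary = pd.DataFrame({"r_hat": [1.00, 1.03, 1.00],
                            "ess_bulk": [950.0, 210.0, 380.0],
                            "ess_tail": [800.0, 300.0, 700.0]},
                           index=["v", "a", "z"])
flagged = flag_convergence_issues(toy_summary)
print(flagged)
# in the notebook: flag_convergence_issues(az.summary(BasicDDMModel.traces))
```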

In [None]:
az.plot_trace(BasicDDMModel.traces)
plt.tight_layout()

In [None]:
az.plot_forest(BasicDDMModel.traces)
plt.tight_layout()

In [None]:
az.plot_pair(BasicDDMModel.traces,
             kind="kde",
             marginals=True)

In [None]:
ax = hssm.plotting.plot_model_cartoon(
    BasicDDMModel,
    n_samples=10,
    bins=20,
    plot_pp_mean=True,
    plot_pp_samples=False,
    n_trajectories=2,  # extra arguments for the underlying plot_model_cartoon() function
);

In [None]:
ax = hssm.plotting.plot_quantile_probability(BasicDDMModel, 
                                             cond="quantile_difficulty",
                                             )
ax.set_ylim(0, 3);
# ax.set_xlim(-0.1, 1.1);

In [None]:
ax = hssm.plotting.plot_quantile_probability(BasicDDMModel, 
                                             cond="costly_fail_condition",
                                             )
ax.set_ylim(0, 3);

In [None]:
ax = hssm.plotting.plot_quantile_probability(BasicDDMModel, 
                                             cond="response_l1_plotting",
                                             )
ax.set_ylim(0, 3);

In [None]:
# Posterior predictive
BasicDDMModel.plot_posterior_predictive(step = True, 
                                        col = 'participant_id',
                                        col_wrap = 5,
                                        bins = np.linspace(-5,5, 50))

#### Taking stock

We can observe a few patterns here. 

- First, the reaction time distributions are clearly not the same for every subject; we need to account for that.
- Second, it does seem like the tail of the predicted reaction time distribution is longer than that of the observed subject data. (This was less clear when looking only at the global pattern...)

We will now adjust our model to tackle these patterns one by one. Let's begin by specializing our parameters by subject. 

In **Bayesian Inference** we approach this by introducing a **Hierarchy**: we assume that subject-level parameters derive from a common **group distribution**.

Inference then proceeds over the parameters of this group distribution, as well as over the subject-wise parameters.

Hierarchies serve as a form of **regularization** of our parameter estimates: the group distribution allows us to share information between the single-subject parameter estimates.

You don't **have** to use a hierarchy. We could instead introduce a subject-wise parameterization, e.g. by simply treating `participant_id` as a **categorical** variable / collection of **dummy** variables without using any notion of a group distribution (and you are welcome to try this).
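To make the "sharing information" point concrete, here is a minimal sketch of partial pooling in the simplest possible setting: a normal model with known variances, where each subject's estimate is a precision-weighted average of that subject's own mean and the group mean. This is a textbook illustration under those simplifying assumptions, not what HSSM computes internally (HSSM samples group and subject parameters jointly); `partial_pooling` is a hypothetical helper.

```python
import numpy as np

def partial_pooling(subject_means, n_trials, group_mean, group_sd, trial_sd):
    """Precision-weighted compromise between subject means and the group mean.

    Assumes a normal-normal model with known group_sd and trial_sd.
    """
    subject_means = np.asarray(subject_means, dtype=float)
    n_trials = np.asarray(n_trials, dtype=float)
    prec_data = n_trials / trial_sd**2    # precision of each subject's own mean
    prec_group = 1.0 / group_sd**2        # precision contributed by the group distribution
    weight = prec_data / (prec_data + prec_group)
    return weight * subject_means + (1.0 - weight) * group_mean

# Two hypothetical subjects with the same observed mean but different amounts of data
shrunk = partial_pooling(subject_means=[2.0, 2.0], n_trials=[250, 5],
                         group_mean=1.0, group_sd=0.5, trial_sd=1.0)
print(shrunk)
```

With these made-up numbers, the subject with 250 trials stays close to its own mean (about 1.98), while the subject with only 5 trials is pulled much further toward the group mean of 1.0 (about 1.56).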

## DDM Hierarchical

Moving on to our first hierarchical model. As a first step, we set the `global_formula` argument to `y ~ (1|participant_id)`, which is equivalent to `y ~ 1 + (1|participant_id)`
(use `y ~ 0 + (1|participant_id)` if you explicitly don't want to create an intercept).

This will make all parameters of our model hierarchical.
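Generatively, `(1|participant_id)` can be read as: every participant gets their own value of a parameter, scattered around a shared group-level value. The sketch below simulates that idea for a single parameter with NumPy, using hypothetical group-level values; the exact parameterization and variable names HSSM (via Bambi) uses for these terms may differ.

```python
import numpy as np

rng = np.random.default_rng(1)
n_subjects, n_trials = 20, 250

# Hypothetical group-level values for one parameter (here the drift rate v)
v_intercept = 0.5        # group-level intercept (the "1" in the formula)
v_sigma = 0.3            # spread of participants around the intercept

# "(1|participant_id)": one offset per participant, drawn from a shared distribution
v_offsets = rng.normal(loc=0.0, scale=v_sigma, size=n_subjects)
v_subject = v_intercept + v_offsets

# Every trial inherits the value of the participant who produced it
participant_id = np.repeat(np.arange(n_subjects), n_trials)
v_trial = v_subject[participant_id]
print(v_subject.round(2))
```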

In [None]:
DDMHierModel = hssm.HSSM(data = workshop_data,
                         model = "ddm",
                         loglik_kind = "approx_differentiable",
                         global_formula = "y ~ (1|participant_id)", # New
                         noncentered = False,
                        )

In [25]:
try:
    # Load pre-computed traces
    DDMHierModel.restore_traces(traces = "idata/ddm_hier/traces.nc")
except Exception:  # e.g. no saved traces yet: sample from scratch
    # Sample posterior
    ddm_hier_idata = DDMHierModel.sample(chains = 2,
                                             sampler = "nuts_numpyro",
                                             tune = 500,
                                             draws = 500,
                                            )

    # Sample posterior predictive
    DDMHierModel.sample_posterior_predictive(draws = 200,
                                             safe_mode = True)

    # Save Model
    DDMHierModel.save_model(model_name = "ddm_hier",
                              allow_absolute_base_path = True,
                              base_path = "idata/",
                              save_idata_only = True)

In [None]:
DDMHierModel.graph()

In [None]:
az.plot_trace(DDMHierModel.traces)
plt.tight_layout()

In [None]:
ax = hssm.plotting.plot_model_cartoon(
    DDMHierModel,
    col = "participant_id",
    col_wrap = 5,
    n_samples=100,
    bin_size=0.2,
    plot_pp_mean=True,
    # color_pp_mean = "red",
    # color_pp = "black",
    plot_pp_samples=False,
    n_trajectories=2,  # extra arguments for the underlying plot_model_cartoon() function
);

#### Comparing Parameter Loadings

In [None]:
az.summary(DDMHierModel.traces,
           filter_vars = "like",
           var_names = ["~participant_id"]).sort_index()

In [None]:
az.summary(BasicDDMModel.traces).sort_index()

The group-level parameter means of the two models are quite similar. Allowing subject-wise variation, however, dramatically improves our fit to the data!

#### Quantitative Model Comparison

In [None]:
az.compare(
    {"DDM": BasicDDMModel.traces, 
     "DDM Hierarchical": DDMHierModel.traces}
)

#### Comparing predictions

In [None]:
# Posterior predictive
DDMHierModel.plot_posterior_predictive(step = True, 
                                       col_wrap = 5,
                                       bins = np.linspace(-5,5, 50));

In [None]:
# Posterior predictive
BasicDDMModel.plot_posterior_predictive(step = True, 
                                        col_wrap = 5,
                                        bins = np.linspace(-5,5, 50));

In [None]:
# Posterior predictive
DDMHierModel.plot_posterior_predictive(step = True, 
                                       col = 'participant_id',
                                       col_wrap = 5,
                                       bins = np.linspace(-5,5, 50))

#### Taking Stock

Let's take stock again of any obvious potential for improving our model. We are now capturing the data much better subject by subject. Looking closely, however, it seems that the tail behavior of the observed and the predicted data differs for a few subjects.

The particularly suspicious subjects are:

- `participant_id = 1`
- `participant_id = 14`
- `participant_id = 15`
- `participant_id = 17`


For these participants (and there are others), the model-predicted data seem to have a wider tail than what we actually observe in our dataset.

This will motivate a change in the Sequential Sampling Model that we apply. 

## Angle Model Hierarchical

Given what we concluded about the tail behavior of the observed RTs, we will adjust our SSM to allow for **linear collapsing bounds**. HSSM ships with such a model,
and we can apply it to our data simply by changing the `model` argument. The corresponding model is called the `angle` model in our lingo and is illustrated conceptually below.



<center> <img src="./images/ANGLE_with_params_pic.png" height=500 width=500> </center>
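One way to write down a linear collapse is a bound that starts at `a` and shrinks at a rate set by the angle `theta`. The sketch below assumes b(t) = a - t * tan(theta), clipped at 0; this is our reading of the illustration above, and `angle_boundary` is a hypothetical helper rather than an HSSM function.

```python
import numpy as np

def angle_boundary(t, a, theta):
    """Assumed linear collapse: b(t) = a - t * tan(theta), clipped at 0."""
    t = np.asarray(t, dtype=float)
    return np.clip(a - t * np.tan(theta), 0.0, None)

t = np.linspace(0.0, 2.0, 5)
print(angle_boundary(t, a=1.5, theta=0.0))   # theta = 0: constant bound, as in the DDM
print(angle_boundary(t, a=1.5, theta=0.5))   # theta > 0: the bound shrinks over time
```

With `theta = 0` the bound is constant, which recovers the DDM; a larger `theta` makes the bound collapse faster, so late responses become less likely and the right tail of the RT distribution gets shorter.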


In [None]:
AngleHierModel = hssm.HSSM(data = workshop_data,
                           model = "angle",
                           loglik_kind = "approx_differentiable",
                           global_formula = "y ~ (1|participant_id)",
                           noncentered = False,
                          )

In [None]:
AngleHierModel.graph()

In [37]:
try:
    # Load pre-computed traces
    AngleHierModel.restore_traces(traces = "idata/angle_hier/traces.nc")
except Exception:  # e.g. no saved traces yet: sample from scratch
    # Sample posterior
    angle_hier_idata = AngleHierModel.sample(chains = 2,
                                             sampler = "nuts_numpyro",
                                             tune = 500,
                                             draws = 500,
                                            )

    # Sample posterior predictive
    AngleHierModel.sample_posterior_predictive(draws = 200,
                                               safe_mode = True)

    # Save Model
    AngleHierModel.save_model(model_name = "angle_hier",
                              allow_absolute_base_path = True,
                              base_path = "idata/",
                              save_idata_only = True)

In [None]:
ax = hssm.plotting.plot_model_cartoon(
    AngleHierModel,
    col = 'participant_id',
    col_wrap = 5,
    n_samples=10,
    bin_size=0.2,
    plot_pp_mean=True,
    plot_pp_samples=False,
    n_trajectories=2,  # extra arguments for the underlying plot_model_cartoon() function
);

We can take it up a notch and include parameter uncertainty in `plot_model_cartoon()`. This helps us assess how certain we are about the degree of boundary collapse.
Let's see what that looks like!

In [None]:
ax = hssm.plotting.plot_model_cartoon(
    AngleHierModel,
    col = 'participant_id',
    col_wrap = 5,
    n_samples=50,
    bin_size=0.2,
    plot_pp_mean=True,
    plot_pp_samples=True,
    n_trajectories=2,  # extra arguments for the underlying plot_model_cartoon() function
);

#### Angle (theta) parameter Bayesian t-test

In [None]:
az.plot_posterior(AngleHierModel.traces,
                  var_names = ["theta"],
                  ref_val = 0,
                  kind = "hist",
                  ref_val_color = "red",
                  histtype = "step")
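The "Bayesian t-test" above boils down to asking how much posterior mass lies on either side of the reference value. A minimal sketch of that computation with hypothetical draws; in the notebook one could pass the `theta` draws from `AngleHierModel.traces.posterior`, assuming the variable is stored under that name as the plotting call suggests. `posterior_tail_prob` is a hypothetical helper.

```python
import numpy as np

def posterior_tail_prob(draws, ref_val=0.0):
    """Fraction of posterior draws that lie above ref_val."""
    draws = np.asarray(draws, dtype=float).ravel()   # flatten chains x draws
    return float((draws > ref_val).mean())

# Hypothetical posterior draws for theta
draws = np.array([0.12, 0.25, -0.03, 0.31, 0.18, 0.22, 0.09, 0.27])
print(posterior_tail_prob(draws))  # 0.875
```

A value near 1 (or 0) means almost all posterior mass lies above (or below) the reference value; it summarizes the posterior and is not a frequentist p-value.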

In [None]:
# Posterior predictive
AngleHierModel.plot_posterior_predictive(step = True, 
                                         col_wrap = 5,
                                         bins = np.linspace(-5,5, 50))

In [None]:
# Posterior predictive
AngleHierModel.plot_posterior_predictive(step = True, 
                                        col = 'participant_id',
                                        col_wrap = 5,
                                        bins = np.linspace(-5,5, 50))

Visually, it seems like we did improve the fit (even though the visual improvement is much smaller than what we witnessed when introducing the hierarchy in the first place).

Let us corroborate the visual intuition via formal model comparison.

#### Quantitative Model Comparison

In [None]:
az.compare(
    {"DDM": BasicDDMModel.traces,
     "DDM Hierarchical": DDMHierModel.traces,
     "Angle Hierarchical": AngleHierModel.traces}
)

Good, introducing the `angle` model seems to have helped quite a bit, although, as the simple visual inspection suggested, the improvement in `elpd_loo` is not as large as
the improvement from going from a simple model to a hierarchical model (even though that hierarchical model used a misspecified SSM).
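A common heuristic for reading `az.compare()` output is to relate each model's `elpd_diff` (its ELPD deficit relative to the top-ranked model) to `dse` (the standard error of that difference). The hypothetical helper below computes that ratio; the toy table uses made-up numbers and model names purely for illustration, and any cut-off you apply to the ratio (a few standard errors is a frequent choice) is a rule of thumb, not a formal test.

```python
import pandas as pd

def elpd_diff_ratio(comparison: pd.DataFrame) -> pd.Series:
    """elpd_diff / dse for each model relative to the top-ranked model."""
    ratio = comparison["elpd_diff"] / comparison["dse"]
    # The top-ranked model has elpd_diff = dse = 0, which yields NaN; report 0 instead
    return ratio.fillna(0.0)

# Toy table with made-up values, using the column names az.compare() returns
comparison = pd.DataFrame({"elpd_diff": [0.0, 12.0, 150.0],
                           "dse": [0.0, 5.0, 20.0]},
                          index=["model_A", "model_B", "model_C"])
print(elpd_diff_ratio(comparison))
# in the notebook: elpd_diff_ratio(az.compare({...}))
```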

So what next? On the surface, it looks like we have a model that fits our data quite well. 


Let's take another look at our data to identify more patterns that we may not capture with our current efforts.



### Further EDA

Maybe it is time to look more directly at the effects of our experiment manipulations.

Below are a few graphs to understand what might be happening.

In [None]:
# Posterior predictive
AngleHierModel.plot_posterior_predictive(step = True, 
                                         col = 'costly_fail_condition',
                                         bins = np.linspace(-5,5, 50))
plt.tight_layout()

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModel, 
                                             cond="costly_fail_condition",
                                             )
ax.set_ylim(0, 3);

In [None]:
plot_rt_hists(workshop_data,
              by_participant = True,
              split_by_column = "costly_fail_condition",
              inset_plot =  "choice_proportion")

In [None]:
plot_rt_hists(workshop_data,
              by_participant = True,
              split_by_column = "costly_fail_condition",
              inset_plot =  "rt_mean")

We can identify two patterns. 

1. On average in the `costly_fail_condition` participants seem to make slightly fewer mistakes
2. On average in the `costly_fail_condition` participants seem to take a little longer for their choices!

This meshes with how we expect the incentives to act. Participants should be slightly more cautious to get it right, if mistakes are costly!

In the context of SSMs, this is usually mapped onto the *decision threshold* (parameter `a`), so maybe we should try to incorporate the `costly_fail_condition` in the regression
function for that parameter in our model.
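Both patterns can also be checked numerically rather than only visually. The sketch below computes, per participant and condition, the proportion of `response == 1` choices and the mean RT, which are the quantities shown in the inset plots above. `condition_summary` is a hypothetical helper, and the toy data (including the 0/1 coding of `costly_fail_condition`) is made up; whether `response == 1` corresponds to a correct choice depends on how the dataset codes responses.

```python
import pandas as pd

def condition_summary(df: pd.DataFrame, condition: str) -> pd.DataFrame:
    """Per participant and condition: share of response == 1 choices and mean RT."""
    return (df.assign(upper=df["response"] == 1)
              .groupby(["participant_id", condition])
              .agg(p_upper=("upper", "mean"), mean_rt=("rt", "mean"))
              .reset_index())

# Made-up toy data with the same column names as workshop_data
toy = pd.DataFrame({"participant_id": [0, 0, 0, 0],
                    "costly_fail_condition": [0, 0, 1, 1],
                    "response": [1, -1, 1, 1],
                    "rt": [0.8, 1.0, 1.2, 1.4]})
summary = condition_summary(toy, "costly_fail_condition")
print(summary)
# in the notebook: condition_summary(workshop_data, "costly_fail_condition")
```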

## Addressing costly fail condition


To include parameter-specific regressions, we can rely on the `include` argument in HSSM. Let's illustrate this.

In [None]:
AngleHierModelV2 = hssm.HSSM(data = workshop_data,
                             model = "angle",
                             loglik_kind = "approx_differentiable",
                             global_formula = "y ~ (1|participant_id)",
                             include = [{"name": "a",
                                         "formula": "a ~ (1 + C(costly_fail_condition)|participant_id)"}],
                             noncentered = False,
                            )
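Inside the `a` formula, `C(costly_fail_condition)` marks the condition as categorical. Formula-based libraries such as the one HSSM builds on typically encode such a predictor with treatment (dummy) coding: an intercept for the reference level plus an indicator column for the other level, so that `a` equals the intercept in one condition and the intercept plus a shift in the other. The random-slope term `(1 + C(costly_fail_condition)|participant_id)` then gives each participant their own intercept and shift. A minimal pandas sketch of that coding, with made-up trials and hypothetical coefficient values for a single participant:

```python
import pandas as pd

# Made-up condition labels for four trials (0 = reference level, 1 = other level)
cond_df = pd.DataFrame({"costly_fail_condition": [0, 1, 0, 1]})

# Treatment coding: intercept column plus one indicator for the non-reference level
design = pd.get_dummies(cond_df["costly_fail_condition"], prefix="cfc",
                        drop_first=True, dtype=float)
design.insert(0, "Intercept", 1.0)

# Hypothetical coefficients for one participant: baseline a and the condition shift
beta = [1.2, 0.3]
a_per_trial = design.to_numpy() @ beta
print(design)
print(a_per_trial)
```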

In [None]:
AngleHierModelV2.graph()

In [50]:
try:
    # Load pre-computed traces
    AngleHierModelV2.restore_traces(traces = "idata/angle_hier_v2/traces.nc")
except Exception:  # e.g. no saved traces yet: sample from scratch
    # Sample posterior
    angle_hier_v2_idata = AngleHierModelV2.sample(chains = 2,
                                                  sampler = "nuts_numpyro",
                                                  tune = 500,
                                                  draws = 500,
                                                 )

    # Sample posterior predictive
    AngleHierModelV2.sample_posterior_predictive(draws = 200,
                                                 safe_mode = True)

    # Save Model
    AngleHierModelV2.save_model(model_name = "angle_hier_v2",
                                allow_absolute_base_path = True,
                                base_path = "idata/",
                                save_idata_only = True)

In [None]:
az.plot_trace(AngleHierModelV2.traces)
plt.tight_layout()

In [None]:
az.plot_posterior(AngleHierModelV2.traces,
                  var_names = ["a_C(costly_fail_condition)|participant_id_mu"],
                  ref_val = 0,
                  kind = "hist",
                  ref_val_color = "red",
                  histtype = "step")

In [None]:
# Posterior predictive
AngleHierModelV2.plot_posterior_predictive(step = True, 
                                           # row = 'participant_id',
                                           col = 'costly_fail_condition',
                                           bins = np.linspace(-5, 5, 50),
                                           )
plt.tight_layout()
plt.show()

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModelV2, 
                                             cond="costly_fail_condition",
                                             )
ax.set_ylim(0, 3);

Quite an improvement!
Let's see what our quantitative model comparison metrics say.


### Quantitative Model Comparison

In [None]:
az.compare(
    {"DDM": BasicDDMModel.traces,
     "DDM Hierarchical": DDMHierModel.traces,
     "Angle Hierarchical": AngleHierModel.traces,
     "Angle Hierarchical Cost": AngleHierModelV2.traces}
)

Good, we now incorporated the `costly_fail_condition` in a conceptually coherent manner. 

Let's take a look at `difficulty` next. 

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModelV2, 
                                             cond="quantile_difficulty_binary",
                                             )
ax.set_ylim(0, 3);

In [None]:
plot_rt_hists(workshop_data,
              by_participant = True,
              split_by_column = "quantile_difficulty_binary",
              inset_plot =  "rt_mean")

In [None]:
plot_rt_hists(workshop_data,
              by_participant = True,
              split_by_column = "quantile_difficulty_binary",
              inset_plot =  "choice_proportion")

We see a similar pattern. Difficulty affects choice probability; however, the effect on RT is less clear.

What parameter should difficulty map onto? Usually it maps onto the **rate of evidence accumulation**, which is the *drift* (`v`) parameter in most SSMs.

We will move ahead and try this. To add a specialized regression for `v`, we can add another parameter dictionary to the list we pass to `include`.
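Why would a drift regression capture a change in choice probability? For a DDM with unit noise, bounds at 0 and `a`, and starting point `z * a` (a textbook convention assumed here; HSSM's internal parameterization may differ), the probability of reaching the upper bound has the closed form P(upper) = (1 - exp(-2 v a z)) / (1 - exp(-2 v a)), which increases with `v`. The sketch below plugs a hypothetical linear drift regression into that formula; `p_upper` is a hypothetical helper, and the coefficients and the sign of the difficulty effect are made up.

```python
import numpy as np

def p_upper(v, a, z):
    """P(upper bound) for a DDM with unit noise, bounds at 0 and a, start at z * a."""
    v = np.asarray(v, dtype=float)
    safe_v = np.where(v == 0, 1.0, v)      # avoid 0/0; the v -> 0 limit is just z
    p = (1 - np.exp(-2 * safe_v * a * z)) / (1 - np.exp(-2 * safe_v * a))
    return np.where(v == 0, z, p)

# Hypothetical drift regression: v = beta0 + beta1 * difficulty (values and sign made up)
beta0, beta1 = 0.5, 0.4
difficulty = np.array([-1.0, 0.0, 1.0])
v = beta0 + beta1 * difficulty
print(p_upper(v, a=1.5, z=0.5))
```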



## Addressing difficulty

In [None]:
AngleHierModelV3 = hssm.HSSM(data = workshop_data,
                             model = "angle",
                             loglik_kind = "approx_differentiable",
                             global_formula = "y ~ (1|participant_id)",
                             include = [{"name": "a",
                                         "formula": "a ~ (1 + C(costly_fail_condition)|participant_id)"},
                                         {"name": "v",
                                         "formula": "v ~ (1 + continuous_difficulty|participant_id)"},
                                        ],
                             noncentered = False,
                            )

In [None]:
AngleHierModelV3.graph()

In [61]:
try:
    # Load pre-computed traces
    AngleHierModelV3.restore_traces(traces = "idata/angle_hier_v3/traces.nc")
except:
    # Sample posterior
    angle_hier_idata = AngleHierModelV3.sample(chains = 2,
                                             sampler = "nuts_numpyro",
                                             tune = 500,
                                             draws = 500,
                                            )

    # Sample posterior predictive
    AngleHierModelV3.sample_posterior_predictive(draws = 200,
                                                 safe_mode = True)

    # Save Model
    AngleHierModelV3.save_model(model_name = "angle_hier_v3",
                                allow_absolute_base_path = True,
                                base_path = "idata/",
                                save_idata_only = True)

In [None]:
az.plot_posterior(AngleHierModelV3.traces,
                  var_names = ["v_continuous_difficulty|participant_id_mu"],
                  ref_val = 0,
                  kind = "hist",
                  ref_val_color = "red",
                  histtype = "step")
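The histogram compares the posterior of the group-level difficulty slope against the red reference line at zero. One minimal way to put a number on "significant" is to summarize the draws directly: report the share of draws on one side of zero and a central interval. The sketch below uses synthetic draws as a stand-in for the real posterior. It also uses an equal-tailed 94% interval rather than the highest-density interval (HDI) that `az.plot_posterior` displays by default, and the two differ for skewed posteriors.

```python
import random
import statistics


def summarize_effect(draws):
    """Mean, share of draws above zero, and a 94% equal-tailed interval."""
    percentiles = statistics.quantiles(draws, n=100, method="inclusive")
    return {
        "mean": statistics.fmean(draws),
        "p_positive": sum(d > 0 for d in draws) / len(draws),
        "interval_94": (percentiles[2], percentiles[96]),  # 3rd and 97th percentiles
    }


# Synthetic stand-in for 2 chains x 500 draws (NOT the tutorial's posterior).
# With real traces, one could flatten e.g.
# AngleHierModelV3.traces.posterior["v_continuous_difficulty|participant_id_mu"].values.ravel()
random.seed(0)
fake_draws = [random.gauss(0.05, 0.015) for _ in range(1000)]
print(summarize_effect(fake_draws))
```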

Looks like the effect on `v` is small (to the trained eye :)), but it is significant!
Let's check if we can account for the data pattern we missed previously.

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModelV3, 
                                             cond="quantile_difficulty_binary",
                                             )
ax.set_ylim(0, 3);

In [None]:
# Posterior predictive
AngleHierModelV3.plot_posterior_predictive(step = True, 
                                           col = "quantile_difficulty_binary",
                                           bins = np.linspace(-5,5, 50))
plt.tight_layout()

Success! This looks much better.

#### Quantitative Model Comparison

In [None]:
az.compare(
    {
     "DDM": BasicDDMModel.traces,
     "DDM Hierarchical": DDMHierModel.traces,
     "Angle Hierarchical": AngleHierModel.traces,
     "Angle Hierarchical Cost": AngleHierModelV2.traces,
     "Angle Hierarchical Cost/Diff": AngleHierModelV3.traces
     }
)

### Anything else?

At this point, we have a model that fits the data quite well. 

We found that adding a hierarchy significantly improves our fit, that the `angle` model dominates the basic `ddm` model for our data, and we incorporated effects of our experimental manipulations.

A natural next step is to check for patterns based on more generic properties of human choice data that we may be able to reason about. 

Anything that comes to mind? Let's take another look at our dataset for some inspiration.

In [None]:
workshop_data

At the risk of stating the obvious, we have a column that has gone unused so far: `response_l1`, the lagged response.

Maybe this hints at some level of *stickiness* in the choice behavior? How could we incorporate this?

Let us first investigate if there is indeed such a pattern in the data!
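A model-free check for stickiness is the repetition rate: how often a participant repeats their previous response. The sketch assumes a hypothetical -1/1 response coding. It builds the lag within one participant's trial sequence, because lagging across participant boundaries would pair unrelated trials. A participant who simply prefers one option also repeats often. The observed rate is therefore best compared with the chance rate p² + (1 - p)², where p is that participant's overall choice proportion.

```python
def lag_responses(responses):
    """Previous-trial response for one participant's sequence (None on the first trial)."""
    return [None] + list(responses[:-1])


def repeat_rate(responses):
    """Fraction of trials (from the second on) that repeat the previous response."""
    pairs = [(r, prev) for r, prev in zip(responses, lag_responses(responses)) if prev is not None]
    return sum(r == prev for r, prev in pairs) / len(pairs)


def chance_repeat_rate(responses, upper=1):
    """Repeat rate expected if choices were independent with the observed proportion p."""
    p = sum(r == upper for r in responses) / len(responses)
    return p * p + (1 - p) * (1 - p)


# Hypothetical choices of one participant, coded -1 / 1.
choices = [1, 1, 1, -1, -1, 1, 1, 1, -1, -1]
print(f"observed: {repeat_rate(choices):.3f}, chance: {chance_repeat_rate(choices):.3f}")
```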

In [None]:
# Posterior predictive
AngleHierModelV3.plot_posterior_predictive(step = True, 
                                           col = 'response_l1_plotting',
                                           bins = np.linspace(-5, 5, 50))
plt.tight_layout()

Indeed, there does seem to be a pattern here that our model misses so far!

To incorporate choice *stickiness*, a reasonable candidate parameter is `z`, the a priori choice bias.
Maybe this parameter is affected by the last choice taken?
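How would a shift in `z` show up in the choices? The plain DDM has a closed form for the probability of ending at the upper bound, assuming constant bounds, a unit diffusion coefficient, and a start point at `z * a`. That probability increases with `z`. The sketch below adds a hypothetical stickiness term in which the previous response (coded -1/1) nudges `z` up or down. All numbers are made up, and the `angle` model's collapsing bounds are ignored.

```python
import math


def p_upper(v, a, z):
    """P(upper bound) for a Wiener process: drift v, unit diffusion, bounds 0 and a, start z * a."""
    if abs(v) < 1e-12:
        return z  # zero drift: the hitting probability equals the relative start point
    return (1.0 - math.exp(-2.0 * v * z * a)) / (1.0 - math.exp(-2.0 * v * a))


def sticky_z(previous_response, z_intercept=0.5, z_slope=0.05):
    """Hypothetical trial-wise bias: z = intercept + slope * previous response (-1 or 1)."""
    return z_intercept + z_slope * previous_response


a, v = 1.5, 0.3  # hypothetical boundary separation and drift
for prev in (-1, 1):
    z = sticky_z(prev)
    print(f"previous = {prev:+d}: z = {z:.2f}, P(upper) = {p_upper(v, a, z):.3f}")
```

With a positive slope, an upper response on the previous trial raises the probability of another upper response, and a lower response lowers it. This is the repetition pattern a stickiness effect predicts.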

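Let's try to incorporate this. We will add a third parameter dictionary to `include`, regressing `z` on `response_l1` with participant-specific intercepts and slopes.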

## Addressing Stickiness

In [None]:
AngleHierModelV4 = hssm.HSSM(data = workshop_data,
                             model = "angle",
                             loglik_kind = "approx_differentiable",
                             global_formula = "y ~ (1|participant_id)",
                             include = [{"name": "a",
                                         "formula": "a ~ (1 + C(costly_fail_condition)|participant_id)"},
                                         {"name": "v",
                                          "formula": "v ~ (1 + continuous_difficulty|participant_id)"},
                                         {"name": "z",
                                          "formula": "z ~ (1 + response_l1|participant_id)"},
                                        ],
                             noncentered = False,
                            )

In [None]:
AngleHierModelV4.graph()

In [70]:
try:
    # Load pre-computed traces
    AngleHierModelV4.restore_traces(traces = "idata/angle_hier_v4/traces.nc")
except:
    # Sample posterior
    angle_hier_idata = AngleHierModelV4.sample(chains = 2,
                                             sampler = "nuts_numpyro",
                                             tune = 500,
                                             draws = 500,
                                            )

    # Sample posterior predictive
    AngleHierModelV4.sample_posterior_predictive(draws = 200,
                                                 safe_mode = True)

    # Save Model
    AngleHierModelV4.save_model(model_name = "angle_hier_v4",
                                allow_absolute_base_path = True,
                                base_path = "idata/",
                                save_idata_only = True)

In [None]:
az.plot_trace(AngleHierModelV4.traces,
              divergences = None);
plt.tight_layout()

**Note**:

We can see some rather interesting artifacts in the chains above. Around samples `300-375`, our solid-blue chain seems to have gotten stuck. This hints at problems with the posterior geometry of this model.
One helpful diagnostic here is whether we observe many `divergences` during sampling.

Let's take a look below (note that we change the `divergences` argument from `None` to its default, `'auto'`).


In [None]:
az.plot_trace(AngleHierModelV4.traces,
              divergences = 'auto');
plt.tight_layout()

Indeed, we observe a few divergences here. As rigorous scientists, we should now try to get to the bottom of this (it happens often when hierarchical models are fit naively to real experimental data).
In the context of this tutorial, however, we will let it slide, since it would warrant a longer detour.
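One common remedy for divergences in hierarchical models, which we have not tried here, is the non-centered parameterization. HSSM exposes it through the `noncentered` argument we have been setting to `False`. Both parameterizations describe the same distribution for the participant-level effects; they differ in what the sampler explores. In the centered form, the sampler draws the effects directly, so their spread is tied to the group-level `sigma`, which produces the funnel-shaped geometry that often triggers divergences. In the non-centered form, it draws standard-normal offsets and rescales them. The plain-Python sketch below, with hypothetical values, only checks that the two constructions agree in distribution. It says nothing about how HSSM's sampler behaves on this model.

```python
import random
import statistics

random.seed(42)
MU, SIGMA, N = 0.5, 0.2, 20_000  # hypothetical group mean, group spread, number of draws

# Centered: participant-level effects drawn directly around the group mean.
centered = [random.gauss(MU, SIGMA) for _ in range(N)]

# Non-centered: draw standard-normal offsets, then shift by MU and scale by SIGMA.
offsets = [random.gauss(0.0, 1.0) for _ in range(N)]
noncentered = [MU + SIGMA * eps for eps in offsets]

for label, values in (("centered", centered), ("non-centered", noncentered)):
    print(f"{label:>12}: mean = {statistics.fmean(values):.3f}, sd = {statistics.stdev(values):.3f}")
```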

Let's move on and focus on whether we actually identify a significant **choice stickiness** effect with our analysis:

In [None]:
az.plot_posterior(AngleHierModelV4.traces,
                  var_names = ["z_response_l1|participant_id_mu"],
                  ref_val = 0,
                  kind = "hist",
                  ref_val_color = "red",
                  histtype = "step")

We observe a significant effect on the `z` parameter; in fact, a mean effect of `0.073` suggests a fairly large choice-stickiness effect.
By direct comparison, we might expect this effect to have a larger overall impact on our model fit than the effect of difficulty on `v` that we investigated in the previous section.

In [None]:
# Posterior predictive
AngleHierModelV4.plot_posterior_predictive(step = True, 
                                           col = 'response_l1_plotting',
                                           bins = np.linspace(-5, 5, 50))
plt.tight_layout()

#### Quantitative Model Comparison

In [None]:
az.compare(
    {"DDM": BasicDDMModel.traces,
     "DDM Hierarchical": DDMHierModel.traces,
     "Angle Hierarchical": AngleHierModel.traces,
     "Angle Hierarchical Cost": AngleHierModelV2.traces,
     "Angle Hierarchical Cost/Diff": AngleHierModelV3.traces,
     "Angle Hierarchical Cost/Diff/Sticky": AngleHierModelV4.traces}
)

And indeed, the improvement in `elpd_loo` is even more substantial than the one generated by incorporating the `difficulty` effect.

#### Taking Stock

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModelV4, 
                                             cond="response_l1_plotting",
                                             )
ax.set_ylim(0, 3);

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModelV4, 
                                             cond="costly_fail_condition",
                                             )
ax.set_ylim(0, 3);

In [None]:
ax = hssm.plotting.plot_quantile_probability(AngleHierModelV4, 
                                             cond="quantile_difficulty_binary",
                                             )
ax.set_ylim(0, 3);

## Sanity check: was the hierarchy really necessary?

In [None]:
AngleModelV5 = hssm.HSSM(data = workshop_data,
                         model = "angle",
                         loglik_kind = "approx_differentiable",
                         global_formula = "y ~ 1",
                         include = [{"name": "a",
                                        "formula": "a ~ 1 + C(costly_fail_condition)"},
                                        {"name": "v",
                                        "formula": "v ~ 1 + continuous_difficulty"},
                                        {"name": "z",
                                        "formula": "z ~ 1 + response_l1"},
                                    ],
                         noncentered = False,
                         )
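`AngleModelV5` drops every `|participant_id` grouping, so all participants share a single set of regression coefficients (complete pooling). Fitting each participant separately would be the opposite extreme (no pooling). The hierarchical models sit in between (partial pooling): each participant's estimate is pulled toward the group mean, and more strongly when that participant contributes little data. The sketch below shows this for the textbook normal-normal model with known variances, using hypothetical numbers. HSSM's hierarchical models are considerably more complex, so take it as intuition only.

```python
def partial_pooling(participant_mean, n_trials, sigma, group_mean, tau):
    """Posterior mean of one participant's parameter in a normal-normal model with
    known within-participant sd (sigma), group mean, and between-participant sd (tau):
    a precision-weighted average of the participant's own mean and the group mean."""
    w_data = n_trials / sigma**2
    w_group = 1.0 / tau**2
    return (w_data * participant_mean + w_group * group_mean) / (w_data + w_group)


# Hypothetical numbers: a participant whose own mean is 2.0, group mean 0.0.
for n in (1, 10, 100):
    est = partial_pooling(2.0, n, sigma=1.0, group_mean=0.0, tau=1.0)
    print(f"n_trials = {n:>3}: no pooling = 2.000, partial pooling = {est:.3f}, complete pooling = 0.000")
```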

In [80]:
try:
    # Load pre-computed traces
    AngleModelV5.restore_traces(traces = "idata/angle_v5/traces.nc")
except:
    # Sample posterior
    angle_hier_idata = AngleModelV5.sample(chains = 2,
                                           sampler = "nuts_numpyro",
                                           tune = 500,
                                           draws = 500,
                                           )

    # Sample posterior predictive
    AngleModelV5.sample_posterior_predictive(draws = 200,
                                             safe_mode = True)

    # Save Model
    AngleModelV5.save_model(model_name = "angle_v5",
                            allow_absolute_base_path = True,
                            base_path = "idata/",
                            save_idata_only = True)

In [None]:
az.compare(
    {
     "DDM": BasicDDMModel.traces,
     "DDM Hierarchical": DDMHierModel.traces,
     "Angle Hierarchical": AngleHierModel.traces,
     "Angle Hierarchical Cost": AngleHierModelV2.traces,
     "Angle Hierarchical Cost/Diff": AngleHierModelV3.traces,
     "Angle Hierarchical Cost/Diff/Sticky": AngleHierModelV4.traces,
     "Angle Cost/Diff/Sticky": AngleModelV5.traces
     }
)

## The End

So far, so good: we completed a fairly comprehensive model exploration and generated quite a few insights!
We could, of course, keep trying more and more complex models, and there may well be more to find out here... we leave this up to you and hope that HSSM will continue to help you along the way :).

## Pointers to Advanced Topics

We are only scratching the surface of what can be done with [HSSM](https://github.com/lnccbrown/HSSM/), let alone the broader ecosystem supporting [simulation-based inference (SBI)](https://simulation-based-inference.org/).

Check out our simulator package, [ssm-simulators](https://github.com/lnccbrown/ssm-simulators), as well as our little neural network library for training [LANs](https://elifesciences.org/articles/65074), [lanfactory](https://github.com/lnccbrown/LANfactory).

Exciting work is being done (more on this in the next tutorial) on connecting to other packages in the wider ecosystem, such as [BayesFlow](https://bayesflow.org/main/index.html) and the [sbi](https://sbi-dev.github.io/sbi/v0.24.0/) package.

Here is a taste of advanced topics with links to corresponding tutorials:

- [Variational Inference with HSSM](https://lnccbrown.github.io/HSSM/tutorials/variational_inference/)
- [Build PyMC models with HSSM random variables](https://lnccbrown.github.io/HSSM/tutorials/pymc/)
- [Connect compiled models to third party MCMC libraries](https://lnccbrown.github.io/HSSM/tutorials/compile_logp/)
- [Construct custom models from simulators and contributed likelihoods](https://lnccbrown.github.io/HSSM/tutorials/jax_callable_contribution_onnx_example/)
- [Using link functions to transform parameters](https://lnccbrown.github.io/HSSM/api/link/#hssm.Link)

You will find this and much more in the [official documentation](https://lnccbrown.github.io/HSSM/).
