14 changes: 8 additions & 6 deletions README.md
@@ -20,13 +20,15 @@
 
 ### Overview
 
-HSSM is a Python toolbox that provides a seamless combination of
+HSSM is an open-source Python toolbox for computational modeling in cognitive
+neuroscience. It supports a broad range of sequential sampling models used to
+study decision-making, learning, and other cognitive processes — from basic
+research to the analysis of clinical effects. Under the hood, HSSM combines
 state-of-the-art likelihood approximation methods with the wider ecosystem of
-probabilistic programming languages. It facilitates flexible hierarchical model
-building and inference via modern MCMC samplers. HSSM is user-friendly and
-provides the ability to rigorously estimate the impact of neural and other
-trial-by-trial covariates through parameter-wise mixed-effects models for a
-large variety of cognitive process models. HSSM is a
+probabilistic programming to enable flexible hierarchical Bayesian inference via
+modern MCMC samplers. It is user-friendly and provides the ability to rigorously
+estimate the impact of neural and other trial-by-trial covariates through
+parameter-wise mixed-effects models. HSSM is a
 <a href="https://ccbs.carney.brown.edu/brainstorm">BRAINSTORM</a> project in
 collaboration with the Center for Computation and Visualization and the Center
 for Computational Brain Science within the Carney Institute at Brown University.
405 changes: 405 additions & 0 deletions addm_andrew_dev /addm_hssm.md

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions docs/index.md
@@ -13,13 +13,15 @@
 ![GitHub Repo stars](https://img.shields.io/github/stars/lnccbrown/HSSM)
 [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
 
-**HSSM** (Hierarchical Sequential Sampling Modeling) is a modern Python toolbox
-that provides state-of-the-art likelihood approximation methods within the
-Python Bayesian ecosystem. It facilitates hierarchical model building and
-inference via fast and robust MCMC samplers. User-friendly, extensible, and
-flexible, HSSM can rigorously estimate the impact of neural and other
-trial-by-trial covariates through parameter-wise mixed-effects models for a
-large variety of cognitive process models.
+**HSSM** (Hierarchical Sequential Sampling Modeling) is a modern open-source
+Python toolbox for computational modeling in cognitive neuroscience. It supports
+a broad range of sequential sampling models used to study decision-making,
+learning, and other cognitive processes — from basic research to the analysis of
+clinical effects. HSSM provides state-of-the-art likelihood approximation
+methods within the Python Bayesian ecosystem and facilitates hierarchical model
+building and inference via fast and robust MCMC samplers. User-friendly,
+extensible, and flexible, it can rigorously estimate the impact of neural and
+other trial-by-trial covariates through parameter-wise mixed-effects models.
 
 HSSM is a [BRAINSTORM](https://ccbs.carney.brown.edu/brainstorm) project in
 collaboration with the
351 changes: 351 additions & 0 deletions docs/tutorials/rlssm_quickstart.ipynb
@@ -0,0 +1,351 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1b9b429d",
"metadata": {},
"source": [
"# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n",
"\n",
"This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n",
"\n",
"1. **Load** a balanced-panel two-armed bandit dataset\n",
"2. **Define** an annotated learning function and the angle SSM log-likelihood\n",
"3. **Configure** and **instantiate** an `RLSSM` model\n",
"4. **Inspect** the built Bambi / PyMC model\n",
"5. **Run** a minimal 2-draw sampling smoke test\n",
"\n",
"For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n",
"- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n",
"- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "bf38d7f7",
"metadata": {},
"source": [
"## 1. Imports and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d764731",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import hssm\n",
"from hssm.rl import RLSSM, RLSSMConfig\n",
"from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n",
"from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n",
"from hssm.utils import annotate_function\n",
"\n",
"# RLSSM requires float32 throughout (JAX default).\n",
"hssm.set_floatX(\"float32\", update_jax=True)"
]
},
{
"cell_type": "markdown",
"id": "df12303f",
"metadata": {},
"source": [
"## 2. Load the Dataset\n",
"\n",
"We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n",
"It is a **balanced panel**: every participant has the same number of trials. \n",
"Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n",
"\n",
"> **Note:** You can also generate data with\n",
"> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n",
"> See `rlssm_tutorial.ipynb` for an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2ef5f6e",
"metadata": {},
"outputs": [],
"source": [
"# Path relative to docs/tutorials/ when running inside the HSSM repo.\n",
"_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n",
"raw = np.load(_fixture_path, allow_pickle=True).item()\n",
"data = pd.DataFrame(raw[\"data\"])\n",
"\n",
"n_participants = data[\"participant_id\"].nunique()\n",
"# Integer division is exact here only because the panel is balanced.\n",
"n_trials = len(data) // n_participants\n",
"\n",
"print(data.head())\n",
"print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")"
]
},
{
"cell_type": "markdown",
"id": "8c310290",
"metadata": {},
"source": [
"## 3. Define the Learning Process\n",
"\n",
"The RL learning process is a JAX function that, given a subject's trial sequence, computes\n",
"the trial-wise drift rate `v` via a Q-learning update rule. \n",
"\n",
"`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n",
"that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n",
"decision process.\n",
"\n",
"- **inputs** — columns that the function reads (free parameters + data columns)\n",
"- **outputs** — what the function produces (here: `v`, the drift rate)\n",
"\n",
"Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n",
"Rescorla-Wagner Q-learning update for a two-armed bandit task."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbcea122",
"metadata": {},
"outputs": [],
"source": [
"compute_v_annotated = annotate_function(\n",
"    inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n",
"    outputs=[\"v\"],\n",
")(compute_v_subject_wise)\n",
"\n",
"print(\"Learning function inputs :\", compute_v_annotated.inputs)\n",
"print(\"Learning function outputs:\", compute_v_annotated.outputs)"
]
},
{
"cell_type": "markdown",
"id": "7a03305a",
"metadata": {},
"source": [
"## 4. Define the Decision (SSM) Log-Likelihood\n",
"\n",
"The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n",
"`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n",
"2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n",
"per-trial log-probabilities.\n",
"\n",
"We then annotate that callable so the builder knows:\n",
"- which columns the matrix contains (`inputs`)\n",
"- that `v` itself is *computed* by the learning function (not a free parameter)\n",
"\n",
"The ONNX file is loaded from the local test fixture when running inside the HSSM\n",
"repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60bbc036",
"metadata": {},
"outputs": [],
"source": [
"# Use the local fixture when available; fall back to HuggingFace download.\n",
"_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n",
"_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n",
"\n",
"_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n",
"\n",
"angle_logp_func = annotate_function(\n",
"    inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n",
"    outputs=[\"logp\"],\n",
"    computed={\"v\": compute_v_annotated},\n",
")(_angle_logp_jax)\n",
"\n",
"print(\"SSM logp inputs :\", angle_logp_func.inputs)\n",
"print(\"SSM logp outputs:\", angle_logp_func.outputs)\n",
"print(\"Computed deps :\", list(angle_logp_func.computed.keys()))"
]
},
{
"cell_type": "markdown",
"id": "cf8f5b63",
"metadata": {},
"source": [
"## 5. Configure the Model with `RLSSMConfig`\n",
"\n",
"`RLSSMConfig` collects the information the `RLSSM` class needs. The main fields are:\n",
"\n",
"| Field | Purpose |\n",
"|-------|---------|\n",
"| `model_name` | Identifier string for the configuration |\n",
"| `decision_process` | Name of the SSM (e.g. `\"angle\"`) |\n",
"| `list_params` | Ordered list of *free* parameters to sample |\n",
"| `params_default` | Starting / default values for each parameter |\n",
"| `bounds` | Prior bounds for each parameter |\n",
"| `learning_process` | Dict mapping computed param name → annotated learning function |\n",
"| `extra_fields` | Extra data columns required by the learning function |\n",
"| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4beba1bc",
"metadata": {},
"outputs": [],
"source": [
"rlssm_config = RLSSMConfig(\n",
"    model_name=\"rlssm_angle_quickstart\",\n",
"    loglik_kind=\"approx_differentiable\",\n",
"    decision_process=\"angle\",\n",
"    decision_process_loglik_kind=\"approx_differentiable\",\n",
"    learning_process_kind=\"blackbox\",\n",
"    list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n",
"    params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n",
"    bounds={\n",
"        \"rl_alpha\": (0.0, 1.0),\n",
"        \"scaler\": (0.0, 10.0),\n",
"        \"a\": (0.1, 3.0),\n",
"        \"theta\": (-0.1, 0.1),\n",
"        \"t\": (0.001, 1.0),\n",
"        \"z\": (0.1, 0.9),\n",
"    },\n",
"    learning_process={\"v\": compute_v_annotated},\n",
"    response=[\"rt\", \"response\"],\n",
"    choices=[0, 1],\n",
"    extra_fields=[\"feedback\"],\n",
"    ssm_logp_func=angle_logp_func,\n",
")\n",
"\n",
"print(\"Model name :\", rlssm_config.model_name)\n",
"print(\"Free params :\", rlssm_config.list_params)"
]
},
{
"cell_type": "markdown",
"id": "924ee4c7",
"metadata": {},
"source": [
"## 6. Instantiate the `RLSSM` Model\n",
"\n",
"Passing `data` and `rlssm_config` to `RLSSM`:\n",
"\n",
"- validates the balanced-panel requirement\n",
"- builds a differentiable PyTensor Op that chains the RL learning step and the\n",
" angle log-likelihood\n",
"- constructs the Bambi / PyMC model internally\n",
"\n",
"Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n",
"the Op by the Q-learning update and therefore does not appear in `model.params`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f8da79a",
"metadata": {},
"outputs": [],
"source": [
"model = RLSSM(data=data, model_config=rlssm_config)\n",
"\n",
"assert isinstance(model, RLSSM)\n",
"print(\"Model type :\", type(model).__name__)\n",
"print(\"Participants :\", model.n_participants)\n",
"print(\"Trials/subj :\", model.n_trials)\n",
"print(\"Free parameters :\", list(model.params.keys()))\n",
"assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n",
"assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n",
"model"
]
},
{
"cell_type": "markdown",
"id": "f7f39940",
"metadata": {},
"source": [
"## 7. Inspect the Built Model\n",
"\n",
"After construction, `model.model` exposes the underlying **Bambi model** and\n",
"`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n",
"or customizing priors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0558ad4",
"metadata": {},
"outputs": [],
"source": [
"print(\"=== Bambi model ===\")\n",
"print(model.model)\n",
"\n",
"print(\"\\n=== PyMC model ===\")\n",
"print(model.pymc_model)"
]
},
{
"cell_type": "markdown",
"id": "f4e50110",
"metadata": {},
"source": [
"## 8. Sampling\n",
"\n",
"A minimal sampling run (2 draws, 2 tuning steps, 1 chain) confirms that the full\n",
"computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly.\n",
"It is a smoke test only: the draws are far too few to support any inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96ce3238",
"metadata": {},
"outputs": [],
"source": [
"trace = model.sample(\n",
"    draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9\n",
")\n",
"\n",
"assert trace is not None\n",
"print(trace)"
]
},
{
"cell_type": "markdown",
"id": "a784a468",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"This notebook showed how to:\n",
"\n",
"1. Load a balanced-panel dataset (`rldm_data.npy`)\n",
"2. Annotate a Q-learning function with `annotate_function`\n",
"3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n",
"4. Define an `RLSSMConfig` and pass it to `RLSSM`\n",
"5. Confirm model structure (free params, Bambi / PyMC objects)\n",
"6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hssm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
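
Note on section 3 of the notebook ("Define the Learning Process"): the text says the learning function turns a subject's trial sequence into a trial-wise drift rate `v` via a Rescorla-Wagner Q-learning update, but the recursion itself is never spelled out. The sketch below is a plain-Python illustration of that idea for one participant. It is not HSSM's `compute_v_subject_wise`: the helper name `sketch_drift_rates`, the starting Q-value of 0.5, and the choice to define drift as `scaler` times the Q-value difference are all assumptions made for illustration.

```python
def sketch_drift_rates(rl_alpha, scaler, responses, feedback, q_init=0.5):
    """Return one drift rate per trial for a single participant.

    Hypothetical helper for illustration only; not part of HSSM.
    responses: chosen arm per trial (0 or 1).
    feedback:  reward observed on that trial.
    """
    q_values = [q_init, q_init]  # assumed starting value for both arms
    drift_rates = []
    for choice, reward in zip(responses, feedback):
        # The drift for this trial uses only what was learned on earlier trials.
        # Assumption: drift is the scaled difference between the two Q-values.
        drift_rates.append(scaler * (q_values[1] - q_values[0]))
        # Rescorla-Wagner update: move the chosen arm toward the observed reward.
        prediction_error = reward - q_values[choice]
        q_values[choice] += rl_alpha * prediction_error
    return drift_rates


if __name__ == "__main__":
    v = sketch_drift_rates(
        rl_alpha=0.5,
        scaler=2.0,
        responses=[1, 1, 0],
        feedback=[1.0, 0.0, 1.0],
    )
    print(v)  # [0.0, 0.5, -0.25]
```

Because each trial's drift depends on every earlier trial of the same participant, the recursion has to run once per participant over an ordered trial sequence. Running it over equal-length sequences is presumably what the balanced-panel requirement in section 2 makes possible.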