diff --git a/README.md b/README.md
index 6a891c20c..3f6baf005 100644
--- a/README.md
+++ b/README.md
@@ -20,13 +20,15 @@
### Overview
-HSSM is a Python toolbox that provides a seamless combination of
+HSSM is an open-source Python toolbox for computational modeling in cognitive
+neuroscience. It supports a broad range of sequential sampling models used to
+study decision-making, learning, and other cognitive processes — from basic
+research to the analysis of clinical effects. Under the hood, HSSM combines
state-of-the-art likelihood approximation methods with the wider ecosystem of
-probabilistic programming languages. It facilitates flexible hierarchical model
-building and inference via modern MCMC samplers. HSSM is user-friendly and
-provides the ability to rigorously estimate the impact of neural and other
-trial-by-trial covariates through parameter-wise mixed-effects models for a
-large variety of cognitive process models. HSSM is a
+probabilistic programming to enable flexible hierarchical Bayesian inference via
+modern MCMC samplers. It is user-friendly and provides the ability to rigorously
+estimate the impact of neural and other trial-by-trial covariates through
+parameter-wise mixed-effects models. HSSM is a
BRAINSTORM project in
collaboration with the Center for Computation and Visualization and the Center
for Computational Brain Science within the Carney Institute at Brown University.
diff --git a/addm_andrew_dev /addm_hssm.md b/addm_andrew_dev /addm_hssm.md
new file mode 100644
index 000000000..f4eee15f7
--- /dev/null
+++ b/addm_andrew_dev /addm_hssm.md
@@ -0,0 +1,405 @@
+# Plan: Integrating aDDM into HSSM
+
+## Context
+
+The attentional drift diffusion model (aDDM; Krajbich et al.) extends the standard DDM by modulating the drift rate based on **which option the subject is currently fixating**. A fast, differentiable JAX likelihood for the aDDM has been prototyped in the sibling repo [efficient-fpt](data/azhang/efficient-fpt) — specifically `get_addm_fptd_jax_fast` in [src/efficient_fpt_jax/multi_stage.py](data/azhang/efficient-fpt/src/efficient_fpt_jax/multi_stage.py).
+
+**Integration approach: vendor, do not depend.** Rather than add `efficient-fpt` as a dependency (it is not on PyPI, it ships compiled Cython for the simulator path that HSSM does not need, and bringing it in pulls a heavy build chain into HSSM's install), we **copy the relevant pure-JAX modules into HSSM** and own them going forward. HSSM already depends on `jax`/`jaxlib`, so the vendored code adds zero new transitive dependencies. The simulator (Cython) and NumPy/CPU paths from efficient-fpt are *not* vendored; only the `efficient_fpt_jax` subpackage's likelihood code.
+
+The goal is to wire the aDDM into HSSM so that users can write:
+
+```python
+model = hssm.aDDM(data=addm_trial_df, model_config=cfg, include=[...])
+model.sample()
+```
+
+where `addm_trial_df` contains standard columns (`rt`, `response`) plus **aDDM-specific per-trial arrays** (item values, fixation onsets, fixation counts, first-fixation flag). The aDDM needs per-trial covariates that are *not* themselves sampled parameters — exactly the pattern RLSSM already solves in HSSM. We therefore follow the RLSSM design so that aDDM lives alongside it rather than carving a new architectural lane.
+
+The intended outcome is a working `aDDM(...)` class (and matching `aDDMConfig`) inside HSSM that (a) validates aDDM-specific trial data, (b) composes the vendored JAX FPT likelihood with sampled parameters `{eta, kappa, sigma, a, b, x0}`, plus an optional non-decision time `t`, (c) exposes the standard HSSM hierarchical regression and sampling machinery, and (d) ships with a tutorial notebook and unit tests.
+
+## Design choice: subclass pattern (matching the new RLSSM architecture)
+
+> **Architecture update (post-rebase, commit `b4228d1b` — "HSSM base + RLSSM classes" #893).** The HSSM base was refactored: `HSSMBase` is now an abstract base ([src/hssm/base.py](data/azhang/HSSM/src/hssm/base.py)) and both `HSSM` ([src/hssm/hssm.py](data/azhang/HSSM/src/hssm/hssm.py)) and `RLSSM` ([src/hssm/rl/rlssm.py](data/azhang/HSSM/src/hssm/rl/rlssm.py)) are concrete subclasses. `RLSSMConfig` was moved out of `hssm.config` into [src/hssm/rl/config.py](data/azhang/HSSM/src/hssm/rl/config.py) and **no longer exposes a `to_config()` method** — `HSSMBase` accepts any `BaseModelConfig` directly, and the family-specific subclass is responsible for building its own likelihood `Op` and stamping it onto the config via `dataclasses.replace(...)` before calling `super().__init__()`. The new `RLSSMConfig` also adds an `ssm_logp_func` field (an `@annotate_function`-decorated JAX function) and renamed `learning_process_loglik_kind` → `learning_process_kind`.
+
+Given this new architecture, the plan creates **both**:
+
+1. An `aDDM(HSSMBase)` concrete subclass — a peer of `HSSM` and `RLSSM`, exported from `hssm/__init__.py` as `hssm.aDDM`.
+2. An `aDDMConfig(BaseModelConfig)` dataclass living in `src/hssm/addm/config.py`, peer of `RLSSMConfig` in `src/hssm/rl/config.py`.
+
+The `aDDM` subclass handles (i) validating aDDM-specific data shape, (ii) building the differentiable PyTensor `Op` from the vendored JAX likelihood, (iii) stamping that `Op` onto the config (via `replace(loglik=op, backend="jax")`), and (iv) overriding `_make_model_distribution` to bypass the standard `loglik_kind` dispatching — exactly mirroring `RLSSM.__init__` and `RLSSM._make_model_distribution`.
+
+(There is **no longer** a `to_config()` round-trip; that method existed only in the pre-rebase architecture and has been removed from `RLSSMConfig` as well.)
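
As a rough illustration of the stamping step described above, the sketch below uses a toy dataclass instead of HSSM's real config classes. `ToyConfig`, `stamp_loglik`, and `dummy_op` are hypothetical names introduced only for this example; the one real call is the standard-library `dataclasses.replace`.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a BaseModelConfig subclass (not HSSM code)."""

    model_name: str
    loglik: Optional[Callable[..., Any]] = None
    backend: Optional[str] = None


def stamp_loglik(config: ToyConfig, op: Callable[..., Any]) -> ToyConfig:
    """Return a copy of `config` that carries the likelihood op and backend."""
    # dataclasses.replace builds a new instance, so the caller's object
    # is left untouched.
    return replace(config, loglik=op, backend="jax")


def dummy_op(*args: Any) -> float:
    """Placeholder for the real PyTensor Op built from the JAX likelihood."""
    return 0.0


user_cfg = ToyConfig(model_name="addm")
stamped_cfg = stamp_loglik(user_cfg, dummy_op)
print(user_cfg.loglik is None, stamped_cfg.backend)
```

Because the original config is never mutated, the same user-supplied config object can be reused to build several models without one model's `Op` leaking into another.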
+
+---
+
+## Step-by-step plan
+
+### Step 1 — Vendor the JAX likelihood code into HSSM
+
+**Source files to copy (from efficient-fpt):**
+
+| Source (efficient-fpt) | Destination (HSSM) | Purpose |
+|---|---|---|
+| `src/efficient_fpt_jax/multi_stage.py` | `src/hssm/addm/likelihoods/jax/multi_stage.py` | `get_addm_fptd_jax_fast`, `pad_sacc_array_safely` |
+| `src/efficient_fpt_jax/single_stage.py` | `src/hssm/addm/likelihoods/jax/single_stage.py` | `fptd_single_jax`, `q_single_jax` (called by multi_stage) |
+| `src/efficient_fpt_jax/utils.py` | `src/hssm/addm/likelihoods/jax/utils.py` | `GAUSS_LEGENDRE_30_X`, `GAUSS_LEGENDRE_30_W` quadrature constants |
+
+**Action:**
+
+1. Copy the three files above verbatim into a new `src/hssm/addm/likelihoods/jax/` package.
+2. Verify that imports inside the copied files resolve within `hssm.addm.likelihoods.jax` (e.g. `from .single_stage import ...`, `from .utils import ...`). These are already relative in the source, so no change should be needed in practice.
+3. Create `src/hssm/addm/likelihoods/jax/__init__.py` exposing only the symbols HSSM needs:
+ ```python
+ from .multi_stage import get_addm_fptd_jax_fast, pad_sacc_array_safely
+ from .single_stage import fptd_single_jax, q_single_jax
+ from .utils import GAUSS_LEGENDRE_30_X, GAUSS_LEGENDRE_30_W
+ ```
+4. Add a header comment to each vendored file recording the upstream commit hash from `efficient-fpt` so future maintainers can diff against upstream when bug fixes land there.
+5. **Do not** vendor `efficient_fpt_jax/batch.py` (not used by the per-trial likelihood we wrap), nor anything from the `efficient_fpt` (Cython/NumPy) subpackage — that path is the simulator and is not part of inference.
+6. **Do not** add `efficient-fpt` to `pyproject.toml`. HSSM already depends on `jax`/`jaxlib`, so the vendored code introduces no new dependencies.
+
+**License/attribution:** efficient-fpt ships under a permissive license (see `efficient-fpt/LICENSE`); copy that license text into `src/hssm/addm/likelihoods/jax/LICENSE` (or `NOTICE`) so attribution travels with the code. If HSSM and efficient-fpt share authors/license, a brief `# Adapted from efficient-fpt (commit )` header is sufficient.
+
+**Rationale:** efficient-fpt is not on PyPI; its Cython compile chain is heavyweight and irrelevant to HSSM (HSSM only needs the inference likelihood, not the simulator); and pinning to a remote git dep would couple HSSM's CI to an unstable upstream. Vendoring lets HSSM ship a frozen, audited copy that evolves on HSSM's release cadence.
+
+**Drift management:** efficient-fpt continues to be the research home for the likelihood. When upstream changes, the vendored copy can be re-synced by re-copying the three files, updating the upstream-commit header, and re-running the tests. The upstream-commit header in step 4 makes "what version are we on?" trivially answerable. (This is the same pattern HSSM already uses for `bayesflow` from a dev branch, just snapshotted instead of git-tracked.)
+
+### Step 2 — Define the aDDM submodule layout
+
+Create a new package under `src/hssm/addm/`, mirroring the **post-rebase** `src/hssm/rl/` layout (`config.py` and `rlssm.py` at the package root, plus a `likelihoods/` subpackage):
+
+```
+src/hssm/addm/
+ __init__.py # exports aDDM, aDDMConfig
+ config.py # aDDMConfig — peer of hssm.rl.config.RLSSMConfig
+ addm.py # aDDM(HSSMBase) — peer of hssm.rl.rlssm.RLSSM
+ attention_process.py # pluggable fixation/attention models (default: standard_alternating)
+ utils.py # validate_addm_panel(...) — peer of hssm.rl.utils.validate_balanced_panel
+ likelihoods/
+ __init__.py
+ builder.py # make_addm_logp_func / make_addm_logp_op
+ addm_jax.py # thin wrapper that imports from .jax and applies the attention process
+ jax/ # vendored from efficient_fpt_jax (Step 1)
+ __init__.py
+ multi_stage.py
+ single_stage.py
+ utils.py
+```
+
+**Rationale:** aDDM is conceptually a two-stage model (attention process → SSM likelihood) just like RLSSM (learning process → SSM likelihood). After the rebase, RLSSM is a real `HSSMBase` subclass with its own subpackage; the aDDM subpackage now mirrors that exactly — same files (`config.py`, `.py`, `utils.py`, `likelihoods/builder.py`), same conventions. The vendored JAX code lives in its own `jax/` subdirectory so it stays clearly identifiable as upstream-derived, isolated from HSSM-original code in `builder.py` and `addm_jax.py`.
+
+### Step 3 — Add `aDDMConfig` dataclass in `src/hssm/addm/config.py`
+
+**Critical file:** [src/hssm/addm/config.py](data/azhang/HSSM/src/hssm/addm/config.py) (new file) — peer of [src/hssm/rl/config.py](data/azhang/HSSM/src/hssm/rl/config.py).
+
+> **Note:** `aDDMConfig` is **not** added to `src/hssm/config.py`. After the rebase, family-specific configs live in their own subpackages (`hssm.rl.config.RLSSMConfig`, `hssm.addm.config.aDDMConfig`), with only `BaseModelConfig`, `Config`, and `ModelConfig` remaining in `hssm.config`.
+
+```python
+from dataclasses import dataclass, field
+from typing import Any, Callable
+from ..config import BaseModelConfig, DEFAULT_SSM_OBSERVED_DATA, DEFAULT_SSM_CHOICES
+
+@dataclass
+class aDDMConfig(BaseModelConfig):
+ """Config for the attentional DDM."""
+
+ # Required (kw_only) fields — pattern borrowed from RLSSMConfig
+ params_default: list[float] = field(kw_only=True)
+ attention_process: str | Callable | dict[str, Any] = field(
+ kw_only=True, default="standard_alternating"
+ )
+
+ # aDDM-specific extra-field column names (defaultable)
+ # These are *data* columns, not sampled parameters.
+ extra_fields: list[str] | None = field(
+ default_factory=lambda: ["r1", "r2", "sacc_array", "d", "flag"]
+ )
+
+ def __post_init__(self):
+ if self.loglik_kind is None:
+ self.loglik_kind = "approx_differentiable"
+
+ @classmethod
+ def from_defaults(cls, model_name, loglik_kind):
+ raise NotImplementedError(
+ "aDDMConfig does not support from_defaults(). "
+ "Use the aDDM(...) constructor directly, or pass an aDDMConfig "
+ "instance built explicitly."
+ )
+
+ def validate(self) -> None:
+ # Mirror RLSSMConfig.validate: required fields, params_default vs
+ # list_params length parity, every list_params entry has bounds, etc.
+ ...
+
+ def get_defaults(self, param):
+ return None, self.bounds.get(param)
+```
+
+**Key design decisions (post-rebase):**
+
+- **No `to_config()` method.** The new architecture has `HSSMBase` accept any `BaseModelConfig`; family-specific subclasses build the `loglik` `Op` themselves and stamp it onto the config via `dataclasses.replace(...)`. `RLSSMConfig` no longer has `to_config()` and neither will `aDDMConfig`.
+- **No `from_defaults` registration.** Like `RLSSMConfig`, `aDDMConfig` raises `NotImplementedError` from `from_defaults`. Users construct it explicitly (or via a `from_addm_dict` classmethod, optional). Therefore **`aDDM` is *not* registered through the `default_model_config` / `register_model` pipeline** that `HSSM(model="ddm", ...)` uses — instead, users instantiate `hssm.aDDM(...)` directly.
+- `extra_fields` defaults to the five aDDM-specific columns the JAX likelihood needs; these flow through the existing extra-fields machinery (data validator → `Op` `*args`) the same way they do for RLSSM.
+- `attention_process` is a pluggable hook (default `"standard_alternating"`) that maps `(r1, r2, flag, eta, kappa, max_d) → mu_array_padded` per trial. This mirrors `RLSSMConfig.learning_process` semantically: the field is declarative documentation, and the actual callable is resolved by the builder.
+- `list_params` covers the sampled parameters: `["eta", "kappa", "sigma", "a", "b", "x0"]`. Non-decision time `t` is deliberately omitted initially.
+- `bounds` is required (every `list_params` entry must have an entry), matching `RLSSMConfig.validate`'s post-rebase strictness.
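
The consistency checks listed above (length parity between `params_default` and `list_params`, and a bounds entry for every parameter) could look roughly like the following. This is a minimal sketch, not `RLSSMConfig.validate` itself: `check_param_spec` is a hypothetical helper, and the default values and bounds are illustrative numbers that the plan does not specify.

```python
from typing import Dict, List, Sequence, Tuple


def check_param_spec(
    list_params: Sequence[str],
    params_default: Sequence[float],
    bounds: Dict[str, Tuple[float, float]],
) -> List[str]:
    """Return human-readable problems; an empty list means the spec is consistent."""
    problems: List[str] = []
    if len(params_default) != len(list_params):
        problems.append(
            f"params_default has {len(params_default)} entries, "
            f"list_params has {len(list_params)}"
        )
    for name in list_params:
        if name not in bounds:
            problems.append(f"no bounds for parameter '{name}'")
            continue
        lower, upper = bounds[name]
        if not lower < upper:
            problems.append(f"bounds for '{name}' are not ordered: {lower} >= {upper}")
    return problems


# Illustrative values only (assumptions, not defaults taken from the plan).
list_params = ["eta", "kappa", "sigma", "a", "b", "x0"]
bounds = {name: (0.0, 10.0) for name in list_params if name != "x0"}
print(check_param_spec(list_params, [0.5] * 6, bounds))
```

Collecting all problems before raising, rather than stopping at the first one, is a design choice of this sketch; the real `validate()` may simply raise on the first failure.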
+
+**Reuses:** `BaseModelConfig` ([src/hssm/config.py:48](data/azhang/HSSM/src/hssm/config.py#L48)), defaults `DEFAULT_SSM_OBSERVED_DATA` / `DEFAULT_SSM_CHOICES` ([src/hssm/config.py:24-26](data/azhang/HSSM/src/hssm/config.py#L24-L26)).
+
+### Step 3b — Add `aDDM` subclass in `src/hssm/addm/addm.py`
+
+**Critical file:** [src/hssm/addm/addm.py](data/azhang/HSSM/src/hssm/addm/addm.py) (new file) — peer of [src/hssm/rl/rlssm.py](data/azhang/HSSM/src/hssm/rl/rlssm.py).
+
+This step is **new in the post-rebase plan.** Mirror `RLSSM` (264 lines):
+
+```python
+from dataclasses import replace
+from typing import Any, Callable, Literal, cast
+import bambi as bmb
+import pandas as pd
+import pymc as pm
+
+from hssm.distribution_utils import make_distribution
+from hssm.defaults import INITVAL_JITTER_SETTINGS
+from ..base import HSSMBase
+from .config import aDDMConfig
+from .likelihoods.builder import make_addm_logp_op
+from .utils import validate_addm_panel
+
+
+class aDDM(HSSMBase):
+ def __init__(
+ self,
+ data: pd.DataFrame,
+ model_config: aDDMConfig,
+ include: list | None = None,
+ p_outlier: float | dict | bmb.Prior | None = 0.05,
+ lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0),
+ link_settings: Literal["log_logit"] | None = None,
+ prior_settings: Literal["safe"] | None = "safe",
+ extra_namespace: dict | None = None,
+ missing_data: bool | float = False,
+ deadline: bool | str = False,
+ loglik_missing_data: Callable | None = None,
+ process_initvals: bool = True,
+ initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"],
+ **kwargs,
+ ):
+ self._init_args = self._store_init_args(locals(), kwargs)
+ model_config.validate()
+
+ # Same row-order argument as RLSSM: per-trial sacc_array shape and the
+ # JAX vmap rely on strict 1:1 row→trial correspondence; reordering
+ # missing/deadline rows would break the alignment.
+ if missing_data is not False or deadline is not False:
+ raise NotImplementedError(
+ "aDDM does not support `missing_data` or `deadline` handling..."
+ )
+
+ validate_addm_panel(data, model_config.extra_fields)
+
+ loglik_op = make_addm_logp_op(
+ attention_process=model_config.attention_process,
+ data_cols=list(model_config.response),
+ list_params=list(model_config.list_params),
+ extra_fields=list(model_config.extra_fields or []),
+ )
+
+ # Stamp the Op + backend onto a fresh config copy (do NOT mutate the
+ # caller's config). Identical pattern to RLSSM.__init__.
+ model_config = replace(model_config, loglik=loglik_op, backend="jax")
+
+ super().__init__(
+ data=data, model_config=model_config, include=include,
+ p_outlier=p_outlier, lapse=lapse,
+ link_settings=link_settings, prior_settings=prior_settings,
+ extra_namespace=extra_namespace,
+ missing_data=missing_data, deadline=deadline,
+ loglik_missing_data=loglik_missing_data,
+ process_initvals=process_initvals, initval_jitter=initval_jitter,
+ **kwargs,
+ )
+
+ def _make_model_distribution(self) -> type[pm.Distribution]:
+ # Mirror RLSSM._make_model_distribution: bypass loglik_kind dispatching
+ # and feed the pre-built Op directly into make_distribution.
+ ...
+```
+
+**Key design decisions:**
+
+- **`aDDM` is a peer of `HSSM` and `RLSSM`**, all three inheriting from `HSSMBase`. It is exported from `hssm/__init__.py` as `hssm.aDDM`.
+- **`_make_model_distribution` is overridden** (same as `RLSSM`) because the aDDM `Op` already encapsulates the attention process + per-trial vmap; the standard `make_likelihood_callable` dispatch on `loglik_kind` should be bypassed.
+- **No `participant_col`-style panel reshape** — unlike RLSSM (which reshapes rows into `(n_participants, n_trials, ...)` because the RL learning rule is per-subject), aDDM's likelihood is per-trial. The vmap inside the JAX `Op` handles the trial dimension; participants flow through bambi/HSSM hierarchical regression as usual.
+- **`missing_data` / `deadline` are rejected up front**, same as RLSSM, because rearranging rows would break the strict trial→`sacc_array`-row correspondence.
+
+**Reuses:** `HSSMBase` ([src/hssm/base.py:92](data/azhang/HSSM/src/hssm/base.py#L92)), `make_distribution` ([src/hssm/distribution_utils](data/azhang/HSSM/src/hssm/distribution_utils)), `INITVAL_JITTER_SETTINGS` (defaults.py).
+
+### Step 4 — Build the likelihood op in `addm/likelihoods/builder.py`
+
+Mirror [hssm/rl/likelihoods/builder.py](data/azhang/HSSM/src/hssm/rl/likelihoods/builder.py) (`make_rl_logp_func`, `make_rl_logp_op`).
+
+Two functions:
+
+1. **`make_addm_logp_func(attention_process)`** — returns a callable `logp(data, *args)` where:
+ - `data[:, 0]` = rt, `data[:, 1]` = response
+ - `args` are sampled parameters in `list_params` order: `eta, kappa, sigma, a, b, x0`
+ - extra fields `r1, r2, sacc_array, d, flag` are appended to `args` by the HSSM extra-field machinery (exactly as RLSSM does; see [data_validator.py:156](data/azhang/HSSM/src/hssm/data_validator.py#L156)).
+ - Internally: call the attention process to build `mu_array_padded`, then call `get_addm_fptd_jax_fast(t=rt, d=d, mu_array=mu_array, sacc_array=sacc_array, sigma, a, b, x0)` vmapped over trials.
+
+2. **`make_addm_logp_op(attention_process, data_cols, list_params, extra_fields)`**: wraps the JAX logp as a PyTensor `Op` with a VJP, using the same pattern as `make_rl_logp_op`. NUTS then obtains gradients through JAX autodiff without any hand-written derivatives.
+
+**Reuses:**
+- `get_addm_fptd_jax_fast` and `pad_sacc_array_safely` — imported from the vendored `hssm.addm.likelihoods.jax` package (Step 1).
+- `make_likelihood_callable` ([distribution_utils/dist.py:718](data/azhang/HSSM/src/hssm/distribution_utils/dist.py)) for PyTensor wrapping conventions.
+- `apply_param_bounds_to_loglik` ([distribution_utils/dist.py:40–79](data/azhang/HSSM/src/hssm/distribution_utils/dist.py)) for parameter-bound enforcement.
+
+**Import contract:** `builder.py` imports the vendored likelihood as `from hssm.addm.likelihoods.jax import get_addm_fptd_jax_fast, pad_sacc_array_safely`. No code outside `hssm.addm` should import from the vendored `jax/` subpackage directly — keeping the import surface narrow makes a future re-vendor (or replacement with an upstream PyPI release, should one ever appear) a single-file change.
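
The argument contract described in this step (sampled parameters first, in `list_params` order, followed by the extra-field columns) can be made concrete with a small helper. This is a sketch under that stated ordering assumption; `split_logp_args` is a hypothetical name, and the string placeholders stand in for the per-trial arrays the real `logp` would receive.

```python
from typing import Any, Dict, Sequence, Tuple


def split_logp_args(
    args: Sequence[Any],
    list_params: Sequence[str],
    extra_fields: Sequence[str],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split the flat *args tuple into named parameters and named extra fields."""
    n_params = len(list_params)
    expected = n_params + len(extra_fields)
    if len(args) != expected:
        raise ValueError(f"expected {expected} positional arguments, got {len(args)}")
    # Assumption from the plan: parameters come first, extra fields are appended.
    params = dict(zip(list_params, args[:n_params]))
    extras = dict(zip(extra_fields, args[n_params:]))
    return params, extras


list_params = ["eta", "kappa", "sigma", "a", "b", "x0"]
extra_fields = ["r1", "r2", "sacc_array", "d", "flag"]
# Scalars and strings stand in for the arrays passed at sampling time.
args = (0.3, 1.0, 1.0, 2.0, 0.0, 0.0, "r1_col", "r2_col", "sacc_col", "d_col", "flag_col")
params, extras = split_logp_args(args, list_params, extra_fields)
print(params["kappa"], extras["sacc_array"])
```

Naming the arguments once at the top of `logp` keeps the rest of the function independent of positional order, which helps if `t` is later added to `list_params`.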
+
+### Step 5 — Attention process ("learning process" analog)
+
+File: `src/hssm/addm/attention_process.py`.
+
+Default implementation `standard_alternating(r1, r2, flag, eta, kappa, max_d) -> mu_padded`:
+
+```
+mu1 = kappa * (r1 - eta * r2)
+mu2 = kappa * (eta * r1 - r2)
+# alternate mu1/mu2 by stage parity, respecting `flag` for first fixation
+# return shape (n_trials, max_d)
+```
+
+This reproduces the logic in [efficient-fpt addm.py:_build_mu_data_padded](data/azhang/efficient-fpt/src/efficient_fpt/addm.py) but in JAX for autodiff. Expose it via a registry so future variants (e.g., bias, drift offsets) can be registered by name — the same way `RLSSMConfig.learning_process` accepts either a string or dict.
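
The pseudocode above can be spelled out as follows. This is a plain-Python sketch for readability, not the vendored implementation: a JAX version would vectorize the same logic over trials so it stays differentiable. The meaning of `flag` is an assumption made here (0 means the first fixation is on option 1, 1 means option 2); the plan does not define it.

```python
from typing import List, Sequence


def standard_alternating(
    r1: Sequence[float],
    r2: Sequence[float],
    flag: Sequence[int],
    eta: float,
    kappa: float,
    max_d: int,
) -> List[List[float]]:
    """Build the padded per-stage drift array with shape (n_trials, max_d)."""
    mu_padded = []
    for value1, value2, first in zip(r1, r2, flag):
        mu1 = kappa * (value1 - eta * value2)  # drift while fixating option 1
        mu2 = kappa * (eta * value1 - value2)  # drift while fixating option 2
        # ASSUMPTION: flag == 0 starts on option 1, flag == 1 starts on option 2;
        # fixations then alternate with stage parity.
        row = [mu1 if (stage + first) % 2 == 0 else mu2 for stage in range(max_d)]
        mu_padded.append(row)
    return mu_padded


mu = standard_alternating(
    r1=[3.0, 1.0], r2=[1.0, 2.0], flag=[0, 1], eta=0.5, kappa=2.0, max_d=3
)
print(mu)
```

Stages beyond a trial's actual fixation count `d` are still filled in here; the assumption is that the likelihood ignores them based on `d`, which is why padding with alternating values is harmless.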
+
+### Step 6 — Export `aDDM` and `aDDMConfig` from the top-level package
+
+**Critical files:**
+- `src/hssm/addm/__init__.py` (new) — re-export `aDDM` and `aDDMConfig`, mirroring [src/hssm/rl/__init__.py](data/azhang/HSSM/src/hssm/rl/__init__.py).
+- [src/hssm/__init__.py](data/azhang/HSSM/src/hssm/__init__.py) — add `from .addm import aDDM` (line 22 already has `from .rl import RLSSM`) and add `"aDDM"` to `__all__`.
+
+> **Architecture update:** `aDDM` follows the **`RLSSM` registration pattern, not the `HSSM(model=...)` pattern**. Like `RLSSM`, it is *not* added to `default_model_config` and *not* registered through `register_model`. Users instantiate it directly:
+> ```python
+> import hssm
+> model = hssm.aDDM(data=addm_trial_df, model_config=cfg, ...)
+> ```
+> No `src/hssm/modelconfig/addm_config.py` is created — that directory holds dicts consumed by `register_model` for the analytical/ONNX-based `HSSM(model="ddm", ...)` flow, which doesn't apply to subclass-based families like RLSSM and aDDM.
+
+**Reuses:** `hssm/__init__.py` re-export pattern (see existing `RLSSM` import at line 22).
+
+### Step 7 — Data validation for aDDM-specific columns
+
+**Critical file:** [src/hssm/data_validator.py](data/azhang/HSSM/src/hssm/data_validator.py).
+
+The DataValidatorMixin currently validates that `extra_fields` columns exist ([line 46](data/azhang/HSSM/src/hssm/data_validator.py#L46)). aDDM also needs **shape validation** because `sacc_array` is a 2D array-of-arrays stored inside a DataFrame column.
+
+Add an optional `_validate_addm_columns()` method invoked when `model_config.model_name == "addm"`:
+- `r1`, `r2`, `d`, `flag` must be 1D numeric with length `n_trials`.
+- `sacc_array` must be a 2D array of shape `(n_trials, max_d)` (or a column of variable-length lists, padded internally via `pad_sacc_array_safely`).
+- `d[i] <= sacc_array.shape[1]`.
+
+Minimally invasive: put the hook in `_post_check_data_sanity` and no-op for non-aDDM models.
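
The three shape rules above might be checked along these lines. The sketch works on plain Python sequences and assumes `sacc_array` has already been padded to a rectangular shape; `check_addm_columns` is a hypothetical name, not the planned `_validate_addm_columns()` method.

```python
from numbers import Real
from typing import Sequence


def check_addm_columns(
    r1: Sequence[float],
    r2: Sequence[float],
    d: Sequence[int],
    flag: Sequence[int],
    sacc_array: Sequence[Sequence[float]],
) -> int:
    """Validate aDDM column shapes and return max_d (the padded stage count)."""
    n_trials = len(r1)
    for name, column in (("r1", r1), ("r2", r2), ("d", d), ("flag", flag)):
        if len(column) != n_trials:
            raise ValueError(f"'{name}' has {len(column)} rows, expected {n_trials}")
        if not all(isinstance(value, Real) for value in column):
            raise ValueError(f"'{name}' must contain only numeric values")
    if len(sacc_array) != n_trials:
        raise ValueError(f"'sacc_array' has {len(sacc_array)} rows, expected {n_trials}")
    widths = {len(row) for row in sacc_array}
    if len(widths) != 1:
        raise ValueError("'sacc_array' rows differ in length; pad them first")
    max_d = widths.pop()
    for i, n_stages in enumerate(d):
        if n_stages > max_d:
            raise ValueError(f"d[{i}] = {n_stages} exceeds sacc_array width {max_d}")
    return max_d


max_d = check_addm_columns(
    r1=[3.0, 1.0],
    r2=[1.0, 2.0],
    d=[2, 3],
    flag=[0, 1],
    sacc_array=[[0.2, 0.5, 0.0], [0.1, 0.4, 0.9]],
)
print(max_d)
```

Raising `ValueError` with the offending column name keeps error messages actionable when the check runs inside model construction.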
+
+### Step 8 — Tests
+
+Create a `tests/addm/` directory mirroring [tests/rl/](data/azhang/HSSM/tests/rl/):
+
+- `tests/addm/test_addm_config.py` — patterned after [tests/test_rlssm_config.py](data/azhang/HSSM/tests/test_rlssm_config.py).
+- `tests/addm/test_addm.py` — patterned after [tests/test_rlssm.py](data/azhang/HSSM/tests/test_rlssm.py).
+- `tests/addm/test_addm_builder_output_shape.py` — patterned after [tests/test_rl_builder_output_shape.py](data/azhang/HSSM/tests/test_rl_builder_output_shape.py).
+- `tests/addm/test_addm_likelihood.py` — patterned after [tests/test_rldm_likelihood.py](data/azhang/HSSM/tests/test_rldm_likelihood.py).
+
+Test classes:
+
+1. `TestaDDMConfigCreation` — build `aDDMConfig`, assert defaults, assert `validate()` raises on missing/inconsistent fields.
+2. `TestaDDMConfigFromDefaults` — assert `aDDMConfig.from_defaults(...)` raises `NotImplementedError` (mirrors RLSSMConfig).
+3. `TestaDDMLikelihood` — tiny synthetic dataset (10 trials), confirm `logp` is finite, gradient w.r.t. each parameter is finite, matches a direct call to the vendored `get_addm_fptd_jax_fast`.
+4. `TestaDDMEndToEnd`: 200-trial synthetic dataset, `hssm.aDDM(data=..., model_config=...)` builds, and a very short MCMC run succeeds (smoke test, `draws=5, tune=5`).
+
+> Removed: the obsolete `TestaDDMConfigConversion` (`.to_config()` round-trip) — there is no `to_config()` method in the post-rebase architecture.
+
+**Reuse:** test fixtures from `tests/conftest.py`.
+
+### Step 9 — Tutorial notebook
+
+Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tutorials/rlssm_tutorial.ipynb`:
+- Load/simulate a small aDDM dataset (reuse `simulate_addm` from efficient-fpt example6).
+- Build the model with `hssm.aDDM(data=..., model_config=aDDMConfig(...))` — **not** `hssm.HSSM(model="addm", ...)`, since aDDM follows the RLSSM subclass pattern.
+- Add a hierarchical regression on `eta` (e.g., by participant) to show what HSSM adds over calling efficient-fpt directly.
+- Run `model.sample()` and plot posteriors via `arviz`.
+
+### Step 10 — Cleanup
+
+- Delete or rename the stale [addm_andrew_dev](data/azhang/HSSM/addm_andrew_dev) folder (it has a trailing space in its name, which is a foot-gun on many filesystems) once the new module is working.
+- Update `README.md` example list and `mkdocs.yml` nav to include the new tutorial.
+
+---
+
+## Files to be created
+
+**HSSM-original code (mirrors `src/hssm/rl/` layout post-rebase):**
+- `src/hssm/addm/__init__.py` — re-exports `aDDM`, `aDDMConfig`
+- `src/hssm/addm/config.py` — `aDDMConfig` dataclass *(peer of `hssm/rl/config.py`)*
+- `src/hssm/addm/addm.py` — `aDDM(HSSMBase)` class *(peer of `hssm/rl/rlssm.py`)*
+- `src/hssm/addm/utils.py` — `validate_addm_panel` *(peer of `hssm/rl/utils.py`)*
+- `src/hssm/addm/attention_process.py` — `standard_alternating` and registry
+- `src/hssm/addm/likelihoods/__init__.py`
+- `src/hssm/addm/likelihoods/builder.py` — `make_addm_logp_func`, `make_addm_logp_op`
+- `src/hssm/addm/likelihoods/addm_jax.py` — thin wrapper composing attention process + vendored JAX likelihood
+- `tests/addm/test_addm_config.py`
+- `tests/addm/test_addm.py`
+- `tests/addm/test_addm_builder_output_shape.py`
+- `tests/addm/test_addm_likelihood.py`
+- `docs/tutorials/addm_tutorial.ipynb`
+
+**Vendored from efficient-fpt (verbatim copies, kept in their own subpackage):**
+- `src/hssm/addm/likelihoods/jax/__init__.py`
+- `src/hssm/addm/likelihoods/jax/multi_stage.py` ← `efficient_fpt_jax/multi_stage.py`
+- `src/hssm/addm/likelihoods/jax/single_stage.py` ← `efficient_fpt_jax/single_stage.py`
+- `src/hssm/addm/likelihoods/jax/utils.py` ← `efficient_fpt_jax/utils.py`
+- `src/hssm/addm/likelihoods/jax/NOTICE` (or LICENSE) — upstream attribution and commit hash
+
+## Files to be modified
+
+- `src/hssm/__init__.py` — `from .addm import aDDM` and add `"aDDM"` to `__all__` (mirrors the existing `RLSSM` import on line 22).
+- `src/hssm/data_validator.py` — add aDDM column-shape validation hook.
+- `README.md`, `mkdocs.yml` — mention the new model.
+
+> **Removed from "files to be modified":**
+> - ~~`src/hssm/config.py`~~ — `aDDMConfig` lives in its own subpackage at `src/hssm/addm/config.py`, not in the central `config.py`. (RLSSMConfig was moved out of `hssm.config` in the rebase for the same reason.)
+> - ~~`src/hssm/defaults.py`~~ — aDDM is **not** registered in `default_model_config`; users instantiate `hssm.aDDM(...)` directly, same as `hssm.RLSSM(...)`.
+> - ~~`pyproject.toml`~~ — no new dependencies (JAX is already core).
+
+## Key functions/utilities to reuse (no re-implementation)
+
+| Purpose | Location |
+|---|---|
+| JAX FPT likelihood | `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` *(vendored)* |
+| Safe padding of saccade arrays | `hssm.addm.likelihoods.jax.pad_sacc_array_safely` *(vendored)* |
+| Likelihood `Op` builder pattern | [hssm/rl/likelihoods/builder.py](data/azhang/HSSM/src/hssm/rl/likelihoods/builder.py) |
+| Subclass `__init__` / `_make_model_distribution` pattern | [hssm/rl/rlssm.py](data/azhang/HSSM/src/hssm/rl/rlssm.py) |
+| Family-specific config dataclass pattern | [hssm/rl/config.RLSSMConfig](data/azhang/HSSM/src/hssm/rl/config.py) |
+| Abstract base for HSSM model classes | [hssm/base.HSSMBase](data/azhang/HSSM/src/hssm/base.py#L92) |
+| `make_distribution` (consumes pre-built loglik `Op`) | `hssm.distribution_utils.make_distribution` |
+| Extra-fields propagation into logp | [data_validator.DataValidatorMixin._update_extra_fields](data/azhang/HSSM/src/hssm/data_validator.py#L156) |
+| Param bound enforcement | [distribution_utils.dist.apply_param_bounds_to_loglik](data/azhang/HSSM/src/hssm/distribution_utils/dist.py#L40) |
+
+---
+
+## Verification
+
+End-to-end checks, in order:
+
+1. **Unit**: `pytest tests/addm/ -v` — all test files pass, including finite-gradient check.
+2. **Likelihood parity**: in `tests/addm/test_addm_likelihood.py`, assert HSSM's wrapped op returns the same value (to 1e-6) as a direct call to the vendored `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` on a 10-trial fixture. This confirms the HSSM extra-fields/op-wrapping plumbing does not corrupt the underlying JAX computation. (A separate, off-CI sanity script may also compare against an installed `efficient-fpt` checkout to detect drift between the vendored copy and upstream.)
+3. **Smoke sample**: `hssm.aDDM(data=synthetic_trials, model_config=cfg).sample(draws=5, tune=5)` completes without error and returns an `InferenceData`.
+4. **Parameter recovery**: larger off-CI script (e.g., `tests/scripts/addm_recovery.py`) that simulates 1000 trials with known `(eta, kappa, a, b, x0, sigma)`, fits via `hssm.aDDM(...)`, and confirms that each posterior mean lies within roughly two posterior standard deviations of the ground truth. Reuse the recovery setup from [efficient-fpt example8_empirical/parameter_recovery.ipynb](data/azhang/efficient-fpt/examples/example8_empirical).
+5. **Tutorial runs clean**: `jupyter nbconvert --execute docs/tutorials/addm_tutorial.ipynb` finishes without errors.
+6. **Docs build**: `mkdocs build` succeeds with the new tutorial in nav.
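
The recovery criterion in check 4 can be written down as a small helper. This is a sketch under the assumption that "within ~2σ" means two posterior standard deviations per parameter; `recovery_report` is a hypothetical name and the draws below are made-up numbers, not results from any fit.

```python
from statistics import mean, stdev
from typing import Dict, Sequence


def recovery_report(
    posterior: Dict[str, Sequence[float]],
    truth: Dict[str, float],
    n_sd: float = 2.0,
) -> Dict[str, bool]:
    """Flag, per parameter, whether the posterior mean is within n_sd SDs of truth."""
    report = {}
    for name, true_value in truth.items():
        draws = [float(x) for x in posterior[name]]
        report[name] = abs(mean(draws) - true_value) <= n_sd * stdev(draws)
    return report


# Made-up draws for illustration only.
toy_posterior = {
    "eta": [0.45, 0.50, 0.55, 0.48, 0.52],
    "kappa": [1.10, 1.20, 1.30, 1.15, 1.25],
}
report = recovery_report(toy_posterior, {"eta": 0.52, "kappa": 2.0})
print(report)
```

In a real recovery script the draws would come from the fitted model's `InferenceData`, flattened across chains, and a failing parameter would be reported rather than silently ignored.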
+
+## Open questions for the user
+
+1. ~~**Subclass vs config-only**~~ — *resolved by the rebase.* The new architecture (`HSSMBase` + `RLSSM` subclass) settles this: aDDM follows the subclass pattern, with both `aDDM(HSSMBase)` and `aDDMConfig(BaseModelConfig)` introduced.
+2. **Non-decision time `t`**: include in v1 as an additional sampled parameter (shift RTs), or defer?
+3. **Attention-process extensibility**: is the default `standard_alternating` enough, or should v1 already expose user-pluggable attention processes (e.g., non-alternating fixation patterns)?
+4. **Hierarchical regression target**: should the tutorial demonstrate hierarchical regression on `eta` (attention bias) by participant, or is a different parameter (e.g., `kappa`) more meaningful as a worked example?
diff --git a/docs/index.md b/docs/index.md
index eb7a337e6..ca0a8edd7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -13,13 +13,15 @@

[](https://github.com/astral-sh/ruff)
-**HSSM** (Hierarchical Sequential Sampling Modeling) is a modern Python toolbox
-that provides state-of-the-art likelihood approximation methods within the
-Python Bayesian ecosystem. It facilitates hierarchical model building and
-inference via fast and robust MCMC samplers. User-friendly, extensible, and
-flexible, HSSM can rigorously estimate the impact of neural and other
-trial-by-trial covariates through parameter-wise mixed-effects models for a
-large variety of cognitive process models.
+**HSSM** (Hierarchical Sequential Sampling Modeling) is a modern open-source
+Python toolbox for computational modeling in cognitive neuroscience. It supports
+a broad range of sequential sampling models used to study decision-making,
+learning, and other cognitive processes — from basic research to the analysis of
+clinical effects. HSSM provides state-of-the-art likelihood approximation
+methods within the Python Bayesian ecosystem and facilitates hierarchical model
+building and inference via fast and robust MCMC samplers. User-friendly,
+extensible, and flexible, it can rigorously estimate the impact of neural and
+other trial-by-trial covariates through parameter-wise mixed-effects models.
HSSM is a [BRAINSTORM](https://ccbs.carney.brown.edu/brainstorm) project in
collaboration with the
diff --git a/docs/tutorials/rlssm_quickstart.ipynb b/docs/tutorials/rlssm_quickstart.ipynb
new file mode 100644
index 000000000..d976b1e8c
--- /dev/null
+++ b/docs/tutorials/rlssm_quickstart.ipynb
@@ -0,0 +1,351 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "1b9b429d",
+ "metadata": {},
+ "source": [
+ "# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n",
+ "\n",
+ "This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n",
+ "\n",
+ "1. **Load** a balanced-panel two-armed bandit dataset\n",
+ "2. **Define** an annotated learning function and the angle SSM log-likelihood\n",
+ "3. **Configure** and **instantiate** an `RLSSM` model\n",
+ "4. **Inspect** the built Bambi / PyMC model\n",
+ "5. **Run** a minimal 2-draw sampling smoke test\n",
+ "\n",
+ "For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n",
+ "- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n",
+ "- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf38d7f7",
+ "metadata": {},
+ "source": [
+ "## 1. Imports and Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6d764731",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "import hssm\n",
+ "from hssm.rl import RLSSM, RLSSMConfig\n",
+ "from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n",
+ "from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n",
+ "from hssm.utils import annotate_function\n",
+ "\n",
+ "# RLSSM requires float32 throughout (JAX default).\n",
+ "hssm.set_floatX(\"float32\", update_jax=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "df12303f",
+ "metadata": {},
+ "source": [
+ "## 2. Load the Dataset\n",
+ "\n",
+ "We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n",
+ "It is a **balanced panel**: every participant has the same number of trials. \n",
+ "Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n",
+ "\n",
+ "> **Note:** You can also generate data with\n",
+ "> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n",
+ "> See `rlssm_tutorial.ipynb` for an example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2ef5f6e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Path relative to docs/tutorials/ when running inside the HSSM repo.\n",
+ "_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n",
+ "raw = np.load(_fixture_path, allow_pickle=True).item()\n",
+ "data = pd.DataFrame(raw[\"data\"])\n",
+ "\n",
+ "n_participants = data[\"participant_id\"].nunique()\n",
+ "n_trials = len(data) // n_participants\n",
+ "\n",
+ "print(data.head())\n",
+ "print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8c310290",
+ "metadata": {},
+ "source": [
+ "## 3. Define the Learning Process\n",
+ "\n",
+ "The RL learning process is a JAX function that, given a subject's trial sequence, computes\n",
+ "the trial-wise drift rate `v` via a Q-learning update rule. \n",
+ "\n",
+ "`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n",
+ "that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n",
+ "decision process.\n",
+ "\n",
+ "- **inputs** — columns that the function reads (free parameters + data columns)\n",
+ "- **outputs** — what the function produces (here: `v`, the drift rate)\n",
+ "\n",
+ "Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n",
+ "Rescorla-Wagner Q-learning update for a two-armed bandit task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bbcea122",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "compute_v_annotated = annotate_function(\n",
+ " inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n",
+ " outputs=[\"v\"],\n",
+ ")(compute_v_subject_wise)\n",
+ "\n",
+ "print(\"Learning function inputs :\", compute_v_annotated.inputs)\n",
+ "print(\"Learning function outputs:\", compute_v_annotated.outputs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7a03305a",
+ "metadata": {},
+ "source": [
+ "## 4. Define the Decision (SSM) Log-Likelihood\n",
+ "\n",
+ "The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n",
+ "`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n",
+ "2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n",
+ "per-trial log-probabilities.\n",
+ "\n",
+ "We then annotate that callable so the builder knows:\n",
+ "- which columns the matrix contains (`inputs`)\n",
+ "- that `v` itself is *computed* by the learning function (not a free parameter)\n",
+ "\n",
+ "The ONNX file is loaded from the local test fixture when running inside the HSSM\n",
+ "repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "60bbc036",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use the local fixture when available; fall back to HuggingFace download.\n",
+ "_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n",
+ "_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n",
+ "\n",
+ "_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n",
+ "\n",
+ "angle_logp_func = annotate_function(\n",
+ " inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n",
+ " outputs=[\"logp\"],\n",
+ " computed={\"v\": compute_v_annotated},\n",
+ ")(_angle_logp_jax)\n",
+ "\n",
+ "print(\"SSM logp inputs :\", angle_logp_func.inputs)\n",
+ "print(\"SSM logp outputs:\", angle_logp_func.outputs)\n",
+ "print(\"Computed deps :\", list(angle_logp_func.computed.keys()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cf8f5b63",
+ "metadata": {},
+ "source": [
+ "## 5. Configure the Model with `RLSSMConfig`\n",
+ "\n",
+ "`RLSSMConfig` collects all the information the RLSSM class needs:\n",
+ "\n",
+ "| Field | Purpose |\n",
+ "|-------|---------|\n",
+ "| `model_name` | Identifier string for the configuration |\n",
+ "| `decision_process` | Name of the SSM (e.g. `\"angle\"`) |\n",
+ "| `list_params` | Ordered list of *free* parameters to sample |\n",
+ "| `params_default` | Starting / default values for each parameter |\n",
+ "| `bounds` | Prior bounds for each parameter |\n",
+ "| `learning_process` | Dict mapping computed param name → annotated learning function |\n",
+ "| `extra_fields` | Extra data columns required by the learning function |\n",
+ "| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4beba1bc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rlssm_config = RLSSMConfig(\n",
+ " model_name=\"rlssm_angle_quickstart\",\n",
+ " loglik_kind=\"approx_differentiable\",\n",
+ " decision_process=\"angle\",\n",
+ " decision_process_loglik_kind=\"approx_differentiable\",\n",
+ " learning_process_kind=\"blackbox\",\n",
+ " list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n",
+ " params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n",
+ " bounds={\n",
+ " \"rl_alpha\": (0.0, 1.0),\n",
+ " \"scaler\": (0.0, 10.0),\n",
+ " \"a\": (0.1, 3.0),\n",
+ " \"theta\": (-0.1, 0.1),\n",
+ " \"t\": (0.001, 1.0),\n",
+ " \"z\": (0.1, 0.9),\n",
+ " },\n",
+ " learning_process={\"v\": compute_v_annotated},\n",
+ " response=[\"rt\", \"response\"],\n",
+ " choices=[0, 1],\n",
+ " extra_fields=[\"feedback\"],\n",
+ " ssm_logp_func=angle_logp_func,\n",
+ ")\n",
+ "\n",
+ "print(\"Model name :\", rlssm_config.model_name)\n",
+ "print(\"Free params :\", rlssm_config.list_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "924ee4c7",
+ "metadata": {},
+ "source": [
+ "## 6. Instantiate the `RLSSM` Model\n",
+ "\n",
+ "Passing `data` and `rlssm_config` to `RLSSM`:\n",
+ "\n",
+ "- validates the balanced-panel requirement\n",
+ "- builds a differentiable PyTensor Op that chains the RL learning step and the\n",
+ " angle log-likelihood\n",
+ "- constructs the Bambi / PyMC model internally\n",
+ "\n",
+ "Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n",
+ "the Op by the Q-learning update and therefore does not appear in `model.params`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1f8da79a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = RLSSM(data=data, model_config=rlssm_config)\n",
+ "\n",
+ "assert isinstance(model, RLSSM)\n",
+ "print(\"Model type :\", type(model).__name__)\n",
+ "print(\"Participants :\", model.n_participants)\n",
+ "print(\"Trials/subj :\", model.n_trials)\n",
+ "print(\"Free parameters :\", list(model.params.keys()))\n",
+ "assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n",
+ "assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n",
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7f39940",
+ "metadata": {},
+ "source": [
+ "## 7. Inspect the Built Model\n",
+ "\n",
+ "After construction, `model.model` exposes the underlying **Bambi model** and\n",
+ "`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n",
+ "or customizing priors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b0558ad4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"=== Bambi model ===\")\n",
+ "print(model.model)\n",
+ "\n",
+ "print(\"\\n=== PyMC model ===\")\n",
+ "print(model.pymc_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f4e50110",
+ "metadata": {},
+ "source": [
+ "## 8. Sampling\n",
+ "\n",
+ "A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n",
+ "computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "96ce3238",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n",
+ "\n",
+ "assert trace is not None\n",
+ "print(trace)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a784a468",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "This notebook showed how to:\n",
+ "\n",
+ "1. Load a balanced-panel dataset (`rldm_data.npy`)\n",
+ "2. Annotate a Q-learning function with `annotate_function`\n",
+ "3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n",
+ "4. Define an `RLSSMConfig` and pass it to `RLSSM`\n",
+ "5. Confirm model structure (free params, Bambi / PyMC objects)\n",
+ "6. Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "hssm",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
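Reviewer note on `docs/tutorials/rlssm_quickstart.ipynb`: the notebook describes the learning process as a Rescorla-Wagner Q-learning update that maps `rl_alpha`, `scaler`, `response`, and `feedback` to a trial-wise drift rate `v`, but it never shows the update itself. The sketch below is a plain-Python illustration of that idea for a single participant. It is not the body of `compute_v_subject_wise`. The function name `trialwise_drift`, the initial Q-values of 0.5, and the choice to define `v` as the scaled difference between the two Q-values are all assumptions made for illustration.

```python
def trialwise_drift(rl_alpha, scaler, responses, feedback, q_init=0.5):
    """Hypothetical sketch of a two-armed bandit Q-learning update.

    Returns one drift rate per trial, computed *before* that trial's
    feedback is used to update the Q-values.
    """
    q = [q_init, q_init]  # assumed starting values for the two arms
    drifts = []
    for choice, reward in zip(responses, feedback):
        # Drift rate: scaled difference between the two arm values (assumption).
        drifts.append(scaler * (q[1] - q[0]))
        # Rescorla-Wagner update of the chosen arm only.
        q[choice] += rl_alpha * (reward - q[choice])
    return drifts


drifts = trialwise_drift(
    rl_alpha=0.5,
    scaler=2.0,
    responses=[1, 1, 0],
    feedback=[1.0, 0.0, 1.0],
)
print(drifts)
```

Because each trial's drift depends on the outcomes of all earlier trials, the update is inherently sequential within a participant. That is consistent with the notebook's mention of a "Q-learning scan" and with its requirement that every participant contribute the same number of trials.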
diff --git a/mkdocs.yml b/mkdocs.yml
index 93b2696fd..0ef2ad1b4 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -44,6 +44,7 @@ nav:
- Hierarchical Variational Inference: tutorials/variational_inference_hierarchical.ipynb
- Using HSSM low-level API directly with PyMC: tutorials/pymc.ipynb
- Reinforcement Learning - Sequential Sampling Models (RLSSM): tutorials/rlssm_tutorial.ipynb
+ - RLSSM Quickstart: tutorials/rlssm_quickstart.ipynb
- Add custom RLSSM models: tutorials/add_custom_rlssm_model.ipynb
- Custom models: tutorials/jax_callable_contribution_onnx_example.ipynb
- BayesFlow LRE Integration: tutorials/bayesflow_lre_integration.ipynb
@@ -96,6 +97,7 @@ plugins:
- tutorials/hssm_tutorial_workshop_2.ipynb
- tutorials/add_custom_rlssm_model.ipynb
- tutorials/rlssm_tutorial.ipynb
+ - tutorials/rlssm_quickstart.ipynb
- tutorials/lapse_prob_and_dist.ipynb
- tutorials/plotting.ipynb
- tutorials/scientific_workflow_hssm.ipynb
diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py
index 60dd71020..2f234d086 100644
--- a/src/hssm/__init__.py
+++ b/src/hssm/__init__.py
@@ -19,6 +19,7 @@
from .param import UserParam as Param
from .prior import Prior
from .register import register_model
+from .rl import RLSSM
from .simulator import simulate_data
from .utils import check_data_for_rl, set_floatX
@@ -31,6 +32,7 @@
__all__ = [
"HSSM",
+ "RLSSM",
"Link",
"load_data",
"ModelConfig",
diff --git a/src/hssm/base.py b/src/hssm/base.py
new file mode 100644
index 000000000..806e0eded
--- /dev/null
+++ b/src/hssm/base.py
@@ -0,0 +1,2160 @@
+"""HSSM: Hierarchical Sequential Sampling Models.
+
+A package based on pymc and bambi to perform Bayesian inference for hierarchical
+sequential sampling models.
+
+This file defines the entry class HSSM.
+"""
+
+import datetime
+import logging
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from os import PathLike
+from pathlib import Path
+from typing import Any, Callable, Literal, Optional, Union, cast, get_args
+
+import arviz as az
+import bambi as bmb
+import cloudpickle as cpickle
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import pymc as pm
+import pytensor
+import seaborn as sns
+import xarray as xr
+from bambi.model_components import DistributionalComponent
+from bambi.transformations import transformations_namespace
+from pymc.model.transform.conditioning import do
+
+from hssm._types import SupportedModels
+from hssm.data_validator import DataValidatorMixin
+from hssm.defaults import (
+ INITVAL_JITTER_SETTINGS,
+ INITVAL_SETTINGS,
+)
+from hssm.distribution_utils import (
+ make_family,
+)
+from hssm.missing_data_mixin import MissingDataMixin
+from hssm.utils import (
+ _compute_log_likelihood,
+ _get_alias_dict,
+ _print_prior,
+ _split_array,
+)
+
+from . import plotting
+from .config import BaseModelConfig
+from .param import Params
+from .param import UserParam as Param
+
+_logger = logging.getLogger("hssm")
+
+# NOTE: Temporary mapping from old sampler names to new ones in bambi 0.16.0
+_new_sampler_mapping: dict[str, Literal["pymc", "numpyro", "blackjax"]] = {
+ "mcmc": "pymc",
+ "nuts_numpyro": "numpyro",
+ "nuts_blackjax": "blackjax",
+}
+
+
+class classproperty:
+ """A decorator that combines the behavior of @property and @classmethod.
+
+ This decorator allows you to define a property that can be accessed on the class
+ itself, rather than on instances of the class. It is useful for defining class-level
+ properties that need to perform some computation or access class-level data.
+
+ This implementation is provided for compatibility with Python versions 3.10 through
+ 3.12, as one cannot combine the @property and @classmethod decorators across all
+ these versions.
+
+ Example
+ -------
+ class MyClass:
+ @classproperty
+ def my_class_property(cls):
+ return "This is a class property"
+
+ print(MyClass.my_class_property) # Output: This is a class property
+ """
+
+ def __init__(self, fget):
+ self.fget = fget
+
+ def __get__(self, instance, owner): # noqa: D105
+ return self.fget(owner)
+
+
+class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin):
+ """The basic Hierarchical Sequential Sampling Model (HSSM) class.
+
+ Parameters
+ ----------
+ data
+ A pandas DataFrame containing the data. At a minimum, it must include the
+ columns "rt" and "response".
+ model
+ The name of the model to use. Currently supported models are "ddm", "ddm_sdv",
+ "full_ddm", "angle", "levy", "ornstein", "weibull", "race_no_bias_angle_4",
+ "ddm_seq2_no_bias". If any other string is passed, the model will be considered
+ custom, in which case all `model_config`, `loglik`, and `loglik_kind` have to be
+ provided by the user.
+ choices : optional
+ When an `int`, the number of choices that the participants can make. If `2`, the
+ choices are [-1, 1] by default. If anything greater than `2`, the choices are
+ [0, 1, ..., n_choices - 1] by default. If a `list` is provided, it should be the
+ list of choices that the participants can make. Defaults to `2`. If any value
+ other than the choices provided is found in the "response" column of the data,
+ an error will be raised.
+ include : optional
+ A list of dictionaries specifying parameter specifications to include in the
+ model. If left unspecified, defaults will be used for all parameter
+ specifications. Defaults to None.
+ model_config
+ A fully initialised :class:`~hssm.config.BaseModelConfig` instance
+ (typically :class:`~hssm.config.Config`) produced by the subclass
+ before calling ``super().__init__``. All likelihood, parameter, and
+ data information used by :class:`HSSMBase` is drawn from this object,
+ and it must provide populated ``loglik`` and ``list_params`` fields.
+ p_outlier : optional
+ The fixed lapse probability or the prior distribution of the lapse probability.
+ Defaults to a fixed value of 0.05. When `None`, the lapse probability will not
+ be included in estimation.
+ lapse : optional
+ The lapse distribution. This argument is required only if `p_outlier` is not
+ `None`. Defaults to Uniform(0.0, 20.0).
+ global_formula : optional
+ A string that specifies a regression formula which will be used for all model
+ parameters. If you specify parameter-wise regressions in addition, these will
+ override the global regression for the respective parameter.
+ link_settings : optional
+ An optional string literal that indicates the link functions to use for each
+ parameter. Helpful for hierarchical models where sampling might get stuck/
+ very slow. Can be one of the following:
+
+ - `"log_logit"`: applies log link functions to positive parameters and
+ generalized logit link functions to parameters that have explicit bounds.
+ - `None`: unless otherwise specified, the `"identity"` link functions will be
+ used.
+ The default value is `None`.
+ prior_settings : optional
+ An optional string literal that indicates the prior distributions to use for
+ each parameter. Helpful for hierarchical models where sampling might get stuck/
+ very slow. Can be one of the following:
+
+ - `"safe"`: HSSM will scan all parameters in the model and apply safe priors to
+ all parameters that do not have explicit bounds.
+ - `None`: HSSM will use bambi to provide default priors for all parameters. Not
+ recommended when you are using hierarchical models.
+ The default value is `"safe"`.
+ extra_namespace : optional
+ Additional user supplied variables with transformations or data to include in
+ the environment where the formula is evaluated. Defaults to `None`.
+ missing_data : optional
+ Specifies whether the model should handle missing data. Can be a `bool` or a
+ `float`. If `False`, and if the `rt` column in the data contains -999.0,
+ the model will drop these rows and produce a warning. If `True`, the model will
+ treat code -999.0 as missing data. If a `float` is provided, the model will
+ treat this value as the missing data value. Defaults to `False`.
+ deadline : optional
+ Specifies whether the model should handle deadline data. Can be a `bool` or a
+ `str`. If `False`, the model will do nothing even if a deadline column is
+ provided. If `True`, the model will treat the `deadline` column as deadline
+ data. If a `str` is provided, the model will treat this value as the name of the
+ deadline column. Defaults to `False`.
+ loglik_missing_data : optional
+ A likelihood function for missing data. Please see the `loglik` parameter to see
+ how to specify the likelihood function for this parameter. If nothing is provided,
+ a default likelihood function will be used. This parameter is required only if
+ either `missing_data` or `deadline` is not `False`. Defaults to `None`.
+ process_initvals : optional
+ If `True`, the model will process the initial values. Defaults to `True`.
+ initval_jitter : optional
+ The jitter value for the initial values. Defaults to `0.01`.
+ **kwargs
+ Additional arguments passed to the `bmb.Model` object.
+
+ Attributes
+ ----------
+ data
+ A pandas DataFrame with at least two columns of "rt" and "response" indicating
+ the response time and responses.
+ list_params
+ The list of strs of parameter names.
+ model_name
+ The name of the model.
+ loglik
+ The likelihood function or a path to an onnx file.
+ loglik_kind
+ The kind of likelihood used.
+ model_config
+ The model configuration, a `BaseModelConfig` instance.
+ model_distribution
+ The likelihood function of the model in the form of a pm.Distribution subclass.
+ family
+ A Bambi family object.
+ priors
+ A dictionary containing the prior distribution of parameters.
+ formula
+ A string representing the model formula.
+ link
+ A string or a dictionary representing the link functions for all parameters.
+ params
+ A list of Param objects representing model parameters.
+ initval_jitter
+ The jitter value for the initial values.
+ """
+
+ def __init__(
+ self,
+ data: pd.DataFrame,
+ model_config: BaseModelConfig,
+ include: list[dict[str, Any] | Param] | None = None,
+ p_outlier: float | dict | bmb.Prior | None = 0.05,
+ lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0),
+ global_formula: str | None = None,
+ link_settings: Literal["log_logit"] | None = None,
+ prior_settings: Literal["safe"] | None = "safe",
+ extra_namespace: dict[str, Any] | None = None,
+ missing_data: bool | float = False,
+ deadline: bool | str = False,
+ loglik_missing_data: (
+ str | PathLike | Callable | pytensor.graph.Op | None
+ ) = None,
+ process_initvals: bool = True,
+ initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"],
+ **kwargs,
+ ):
+ # ===== Input Data & Configuration =====
+ self.data = data.copy()
+ self.global_formula = global_formula
+ self.link_settings = link_settings
+ self.prior_settings = prior_settings
+ self.missing_data_value = -999.0
+
+ # Store a safe default for the constructor-arguments snapshot so that
+ # pickling / save-load cannot raise AttributeError if a subclass forgets
+ # to call `_store_init_args(locals(), kwargs)` early. Subclasses are
+ # still expected to overwrite this with the real snapshot. However,
+ # do not overwrite if a subclass already set `_init_args` prior to
+ # calling `super().__init__()` (the subclass may capture its
+ # constructor args before delegating to the base class).
+ if not hasattr(self, "_init_args"):
+ self._init_args: dict[str, Any] = {}
+
+ # Set up additional namespace for formula evaluation
+ additional_namespace = transformations_namespace.copy()
+ if extra_namespace is not None:
+ additional_namespace.update(extra_namespace)
+ self.additional_namespace = additional_namespace
+
+ # region ===== Inference Results (initialized to None/empty) =====
+ self._inference_obj: az.InferenceData | None = None
+ self._inference_obj_vi: pm.Approximation | None = None
+ self._vi_approx = None
+ self._map_dict = None
+ # endregion
+
+ # ===== Initial Values Configuration =====
+ self._initvals: dict[str, Any] = {}
+ self.initval_jitter = initval_jitter
+
+ # region ===== Store the pre-built config =====
+ self.model_config: BaseModelConfig = model_config
+ # endregion
+
+ # region ===== Set up shortcuts so old code will work ======
+ self.response: list[str] = ( # type: ignore[assignment]
+ list(self.model_config.response)
+ if self.model_config.response is not None
+ else []
+ )
+ self.list_params = (
+ list(self.model_config.list_params)
+ if self.model_config.list_params is not None
+ else None
+ )
+ self.choices = self.model_config.choices # type: ignore[assignment]
+ self.model_name = self.model_config.model_name
+ self.loglik = self.model_config.loglik
+ self.loglik_kind = self.model_config.loglik_kind
+ self.extra_fields = self.model_config.extra_fields
+ # endregion
+
+ # TODO: add to HSSMBase
+ self.response = cast("list[str]", self.response)
+ self.is_choice_only: bool = self.model_config.is_choice_only
+
+ if self.choices is None:
+ raise ValueError(
+ "`choices` must be provided either in `model_config` or as an argument."
+ )
+
+ self._validate_choices()
+
+ # region Avoid mypy error later (None.append). Should list_params be Optional?
+ if self.list_params is None:
+ raise ValueError(
+ "`list_params` must be provided in the model configuration."
+ )
+ # endregion
+
+ self.n_choices = len(self.choices) # type: ignore[arg-type]
+
+ self._pre_check_data_sanity()
+
+ self._process_missing_data_and_deadline(
+ missing_data=missing_data,
+ deadline=deadline,
+ loglik_missing_data=loglik_missing_data,
+ )
+
+ # region ===== Process lapse distribution =====
+ self.has_lapse = p_outlier is not None and p_outlier != 0
+ self._check_lapse(lapse)
+ if self.has_lapse and self.list_params[-1] != "p_outlier":
+ self.list_params.append("p_outlier")
+ # endregion
+
+ # Process all parameters
+ self.params = Params.from_user_specs(
+ model=self, # type: ignore[arg-type]
+ include=[] if include is None else include,
+ kwargs=kwargs,
+ p_outlier=p_outlier,
+ )
+ self._parent = self.params.parent
+ self._parent_param = self.params.parent_param
+
+ self._validate_fixed_vectors()
+ self.formula, self.priors, self.link = self.params.parse_bambi(model=self) # type: ignore[arg-type]
+
+ # For parameters that have a regression backend, apply bounds at the likelihood
+ # level to ensure that the samples that are out of bounds
+ # are discarded (replaced with a large negative value).
+ self.bounds = {
+ name: param.bounds
+ for name, param in self.params.items()
+ if param.is_regression and param.bounds is not None
+ }
+
+ # Set p_outlier and lapse
+ self.p_outlier = self.params.get("p_outlier")
+ self.lapse = lapse if self.has_lapse else None
+
+ self._post_check_data_sanity()
+
+ self.model_distribution = self._make_model_distribution()
+
+ self.family = make_family(
+ self.model_distribution,
+ self.list_params,
+ self.link,
+ self._parent,
+ )
+
+ self.model = bmb.Model(
+ self.formula,
+ data=self.data,
+ family=self.family,
+ priors=self.priors, # center_predictors=False
+ extra_namespace=self.additional_namespace,
+ **kwargs,
+ )
+
+ self._aliases = _get_alias_dict(
+ self.model, self._parent_param, self.response_c, self.response_str
+ )
+ self.set_alias(self._aliases)
+ self.model.build()
+
+ # region ===== Fix scalar deterministic dims for bambi >= 0.17 =====
+ # Bambi >= 0.17 declares dims=("__obs__",) for intercept-only
+ # deterministics that actually have shape (1,). This causes an
+ # xarray CoordinateValidationError during pm.sample() when ArviZ
+ # tries to create a DataArray with mismatched dimension sizes.
+ # Fix by removing the dims declaration for these deterministics.
+ self._fix_scalar_deterministic_dims()
+ # endregion
+
+ # region ===== Init vals and jitters =====
+ if process_initvals:
+ self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
+ if self.initval_jitter > 0:
+ self._jitter_initvals(
+ jitter_epsilon=self.initval_jitter,
+ vector_only=True,
+ )
+ # endregion
+
+ # Make sure we reset rvs_to_initial_values --> Only None's
+ # Otherwise PyMC barks at us when asking to compute likelihoods
+ self.pymc_model.rvs_to_initial_values.update(
+ {key_: None for key_ in self.pymc_model.rvs_to_initial_values.keys()}
+ )
+ _logger.info("Model initialized successfully.")
+
+ @abstractmethod
+ def _make_model_distribution(self) -> type[pm.Distribution]:
+ """Make a pm.Distribution for the model.
+
+ This method must be implemented by subclasses to create the appropriate
+ distribution for the specific model type.
+ """
+ ...
+
+ def _fix_scalar_deterministic_dims(self) -> None:
+ """Fix dims metadata for scalar deterministics.
+
+ Bambi >= 0.17 returns shape ``(1,)`` for intercept-only
+ deterministics but still declares ``dims=("__obs__",)``. This causes
+ an xarray ``CoordinateValidationError`` during ``pm.sample()`` because
+ the ``__obs__`` coordinate has ``n_obs`` entries. Removing the dims
+ declaration for these variables lets ArviZ handle them as
+ un-dimensioned arrays, avoiding the conflict.
+ """
+ n_obs = len(self.data)
+ dims_dict = self.pymc_model.named_vars_to_dims
+ for det in self.pymc_model.deterministics:
+ if det.name not in dims_dict:
+ continue
+ dims = dims_dict[det.name]
+ if "__obs__" in dims:
+ # Check static shape: if it doesn't match n_obs, remove dims
+ try:
+ shape_0 = det.type.shape[0]
+ except (IndexError, TypeError):
+ continue
+ if shape_0 is not None and shape_0 != n_obs:
+ del dims_dict[det.name]
+
+ def _validate_fixed_vectors(self) -> None:
+ """Validate that fixed-vector parameters have the correct length.
+
+ Fixed-vector parameters (``prior=np.ndarray``) bypass Bambi's formula
+ system entirely --- they are passed as a scalar ``0.0`` placeholder to
+ Bambi, and the real vector is substituted inside
+ ``HSSMDistribution.logp()`` (see ``dist.py``). Because this
+ substitution is invisible to Bambi, we must validate the vector length
+ against ``len(self.data)`` up front to catch shape mismatches early.
+ """
+ for name, param in self.params.items():
+ if isinstance(param.prior, np.ndarray):
+ if len(param.prior) != len(self.data):
+ raise ValueError(
+ f"Fixed vector for parameter '{name}' has length "
+ f"{len(param.prior)}, but data has {len(self.data)} rows."
+ )
+
+ @classproperty
+ def supported_models(cls) -> tuple[SupportedModels, ...]:
+ """Get a tuple of all supported models.
+
+ Returns
+ -------
+ tuple[SupportedModels, ...]
+ A tuple containing all supported model names.
+ """
+ return get_args(SupportedModels)
+
+ @staticmethod
+ def _store_init_args(
+ local_vars: dict[str, Any], extra_kwargs: dict[str, Any]
+ ) -> dict[str, Any]:
+ """Capture subclass ``__init__`` arguments for save/load serialisation.
+
+ Call this at the very start of a subclass ``__init__`` before any local
+ variables are assigned, passing ``locals()`` and the ``**kwargs`` dict::
+
+ self._init_args = self._store_init_args(locals(), kwargs)
+
+ Parameters
+ ----------
+ local_vars
+ The ``locals()`` snapshot from the subclass ``__init__``.
+ extra_kwargs
+ The ``**kwargs`` dict captured by the subclass ``__init__``.
+
+ Returns
+ -------
+ dict[str, Any]
+ A mapping of parameter names to their values, suitable for
+ reconstructing the instance via ``cls(**init_args)``.
+
+ Notes
+ -----
+ The implementation filters out internal names that commonly appear in
+ ``locals()`` snapshots (for example, ``__class__`` and ``kwargs``) so
+ that the returned mapping is safe to pass back to the class
+ constructor during unpickling.
+ """
+ # Exclude internal names that appear in locals() snapshots and are not
+ # valid constructor parameters when re-instantiating the class.
+ exclude_keys = {"self", "kwargs", "__class__"}
+ result = {k: v for k, v in local_vars.items() if k not in exclude_keys}
+ result.update(extra_kwargs)
+ return result
+
+ def find_MAP(self, **kwargs):
+ """Perform Maximum A Posteriori estimation.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the MAP estimates of the model parameters.
+ """
+ self._map_dict = pm.find_MAP(model=self.pymc_model, **kwargs)
+ return self._map_dict
+
+ def sample(
+ self,
+ sampler: Literal["pymc", "numpyro", "blackjax", "nutpie", "laplace"]
+ | None = None,
+ init: str | None = None,
+ initvals: str | dict | None = None,
+ include_response_params: bool = False,
+ **kwargs,
+ ) -> az.InferenceData | pm.Approximation:
+ """Perform sampling using the `fit` method via bambi.Model.
+
+ Parameters
+ ----------
+ sampler: optional
+ The sampler to use. Can be one of "pymc", "numpyro",
+ "blackjax", "nutpie", or "laplace". If using `blackbox` likelihoods,
+ this cannot be "numpyro", "blackjax", or "nutpie". By default it is None,
+ and sampler will automatically be chosen: when the model uses the
+ `approx_differentiable` likelihood, and `jax` backend, "numpyro" will
+ be used. Otherwise, "pymc" (the default PyMC NUTS sampler) will be used.
+
+ Note that the old sampler names such as "mcmc", "nuts_numpyro",
+ "nuts_blackjax" will be deprecated and removed in future releases. A warning
+ will be raised if any of these old names are used.
+ init: optional
+ Initialization method to use for the sampler. If any of the NUTS samplers
+ is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`.
+ initvals: optional
+ Pass initial values to the sampler. This can be a dictionary of initial
+ values for parameters of the model, or a string "map" to use initialization
+ at the MAP estimate. If "map" is used, the MAP estimate will be computed if
+ not already attached to the base class from a prior call to `find_MAP`.
+ include_response_params: optional
+ Include parameters of the response distribution in the output. These usually
+ take more space than other parameters as there's one of them per
+ observation. Defaults to False.
+ kwargs
+ Other arguments passed to bmb.Model.fit(). Please see [here]
+ (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
+ for full documentation.
+
+ Returns
+ -------
+ az.InferenceData | pm.Approximation
+ A reference to the `model.traces` object, which stores the traces of the
+ last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData`
+ instance if `sampler` is `"pymc"` (default), `"numpyro"`,
+ `"blackjax"`, `"nutpie"`, or `"laplace"`.
+ """
+ # Translate deprecated sampler names to the new names,
+ # warning the user about the deprecation.
+ if sampler in _new_sampler_mapping:
+ _logger.warning(
+ f"Sampler '{sampler}' is deprecated. "
+ "Please use the new sampler names: "
+ "'pymc', 'numpyro', 'blackjax', 'nutpie', or 'laplace'."
+ )
+ sampler = _new_sampler_mapping[sampler] # type: ignore
+
+ if sampler == "vi":
+ raise ValueError(
+ "VI is not supported via the sample() method. "
+ "Please use the vi() method instead."
+ )
+
+ if initvals is not None:
+ if isinstance(initvals, dict):
+ kwargs["initvals"] = initvals
+ else:
+ if isinstance(initvals, str):
+ if initvals == "map":
+ if self._map_dict is None:
+ _logger.info(
+ "initvals='map' but no MAP "
+ "estimate precomputed. \n"
+ "Running MAP estimation first..."
+ )
+ self.find_MAP()
+ kwargs["initvals"] = self._map_dict
+ else:
+ kwargs["initvals"] = self._map_dict
+ else:
+ raise ValueError(
+ "initvals argument must be a dictionary or 'map'"
+ " to use the MAP estimate."
+ )
+ else:
+ kwargs["initvals"] = self._initvals
+ _logger.info("Using default initvals. \n")
+
+ if sampler is None:
+ if (
+ self.loglik_kind == "approx_differentiable"
+ and self.model_config.backend == "jax"
+ ):
+ sampler = "numpyro"
+ else:
+ sampler = "pymc"
+
+ if self.loglik_kind == "blackbox":
+ if sampler in ["blackjax", "numpyro", "nutpie"]:
+ raise ValueError(
+ f"{sampler} sampler does not work with blackbox likelihoods."
+ )
+
+ if "step" not in kwargs:
+ kwargs |= {"step": pm.Slice(model=self.pymc_model)}
+
+ if (
+ self.loglik_kind == "approx_differentiable"
+ and self.model_config.backend == "jax"
+ and sampler == "pymc"
+ and kwargs.get("cores", None) != 1
+ ):
+ _logger.warning(
+ "Parallel sampling might not work with `jax` backend and the PyMC NUTS "
+ + "sampler on some platforms. Please consider using `numpyro`, "
+ + "`blackjax`, or `nutpie` sampler if that is a problem."
+ )
+
+ if self._check_extra_fields():
+ self._update_extra_fields()
+
+ if init is None:
+ if sampler in ["pymc", "numpyro", "blackjax", "nutpie"]:
+ init = "adapt_diag"
+ else:
+ init = "auto"
+
+ # If sampler is finally `numpyro` make sure
+ # the jitter argument is set to False
+ if sampler == "numpyro":
+ if "nuts_sampler_kwargs" in kwargs:
+ if kwargs["nuts_sampler_kwargs"].get("jitter"):
+ _logger.warning(
+ "The jitter argument is set to True. "
+ + "This argument is not supported "
+ + "by the numpyro backend. "
+ + "The jitter argument will be set to False."
+ )
+ kwargs["nuts_sampler_kwargs"]["jitter"] = False
+ else:
+ kwargs["nuts_sampler_kwargs"] = {"jitter": False}
+
+ if sampler != "pymc" and "step" in kwargs:
+ raise ValueError(
+ "`step` samplers (enabled by the `step` argument) are only supported "
+ "by the `pymc` sampler."
+ )
+
+ if self._inference_obj is not None:
+ _logger.warning(
+ "The model has already been sampled. Overwriting the previous "
+ + "inference object. Any previous reference to the inference object "
+ + "will still point to the old object."
+ )
+
+ # Define whether likelihood should be computed
+ compute_likelihood = True
+ if "idata_kwargs" in kwargs:
+ if "log_likelihood" in kwargs["idata_kwargs"]:
+ compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True)
+
+ omit_offsets = kwargs.pop("omit_offsets", False)
+ self._inference_obj = self.model.fit(
+ inference_method=(
+ "pymc"
+ if sampler in ["pymc", "numpyro", "blackjax", "nutpie"]
+ else sampler
+ ),
+ init=init,
+ include_response_params=include_response_params,
+ omit_offsets=omit_offsets,
+ **kwargs,
+ )
+
+ # Separate out log likelihood computation
+ if compute_likelihood:
+ self.log_likelihood(self._inference_obj, inplace=True)
+
+ # Subset data vars in posterior
+ self._clean_posterior_group(idata=self._inference_obj)
+ return self.traces
+
+ def vi(
+ self,
+ method: str = "advi",
+ niter: int = 10000,
+ draws: int = 1000,
+ return_idata: bool = True,
+ ignore_mcmc_start_point_defaults=False,
+ **vi_kwargs,
+ ) -> pm.Approximation | az.InferenceData:
+ """Perform Variational Inference.
+
+ Parameters
+ ----------
+ niter : int
+ The number of iterations to run the VI algorithm. Defaults to 10000.
+ method : str
+ The method to use for VI. Can be one of "advi", "fullrank_advi", "svgd",
+ or "asvgd". Defaults to "advi".
+ draws : int
+ The number of samples to draw from the posterior distribution.
+ Defaults to 1000.
+ return_idata : bool
+ If True, returns an InferenceData object. Otherwise, returns the
+ approximation object directly. Defaults to True.
+
+ Returns
+ -------
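+ pm.Approximation | az.InferenceData
+ An `InferenceData` object with draws from the approximate posterior if
+ `return_idata` is True. Otherwise, the fitted approximation object.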
+ """
+ if self.loglik_kind == "analytical":
+ _logger.warning(
+ "VI is not recommended for the analytical likelihood,"
+ " since gradients can be brittle."
+ )
+ elif self.loglik_kind == "blackbox":
+ raise ValueError(
+ "VI is not supported for blackbox likelihoods,"
+ " since likelihood gradients are needed!"
+ )
+
+ if ("start" not in vi_kwargs) and not ignore_mcmc_start_point_defaults:
+ _logger.info("Using MCMC starting point defaults.")
+ vi_kwargs["start"] = self._initvals
+
+ # Run variational inference directly from pymc model
+ with self.pymc_model:
+ self._vi_approx = pm.fit(n=niter, method=method, **vi_kwargs)
+
+ # Sample from the approximate posterior
+ if self._vi_approx is not None:
+ self._inference_obj_vi = self._vi_approx.sample(draws)
+
+ # Post-processing
+ self._clean_posterior_group(idata=self._inference_obj_vi)
+
+ # Return the InferenceData object if return_idata is True
+ if return_idata:
+ return self._inference_obj_vi
+ # Otherwise return the approximation object directly
+ return self.vi_approx
+
+ def _clean_posterior_group(self, idata: az.InferenceData | None = None):
+ """Clean up the posterior group of the InferenceData object.
+
+ Parameters
+ ----------
+ idata : az.InferenceData
+ The InferenceData object to clean up. If None, the last InferenceData object
+ will be used.
+ """
+ # # Logic behind which variables to keep:
+ # # We essentially want to get rid of
+ # # all the trial-wise variables.
+
+ # # We drop all distributional components, IF they are deterministics
+ # # (in which case they will be trial wise systematically)
+ # # and we keep distributional components, IF they are
+ # # basic random variables (in which case they should never
+ # # appear trial-wise).
+ if idata is None:
+ raise ValueError(
+ "The InferenceData object is None. Cannot clean up the posterior group."
+ )
+ elif not hasattr(idata, "posterior"):
+ raise ValueError(
+ "The InferenceData object does not have a posterior group. "
+ + "Cannot clean up the posterior group."
+ )
+
+ vars_to_keep = set(idata["posterior"].data_vars.keys()).difference(
+ set(
+ key_
+ for key_ in self.model.distributional_components.keys()
+ if key_ in [var_.name for var_ in self.pymc_model.deterministics]
+ )
+ )
+ vars_to_keep_clean = [
+ var_
+ for var_ in vars_to_keep
+ if isinstance(var_, str) and "_mean" not in var_
+ ]
+
+ setattr(
+ idata,
+ "posterior",
+ idata["posterior"][vars_to_keep_clean],
+ )
+
+ def log_likelihood(
+ self,
+ idata: az.InferenceData | None = None,
+ data: pd.DataFrame | None = None,
+ inplace: bool = True,
+ keep_likelihood_params: bool = False,
+ ) -> az.InferenceData | None:
+ """Compute the log likelihood of the model.
+
+ Parameters
+ ----------
+ idata : optional
+ The `InferenceData` object returned by `HSSM.sample()`. If not provided,
+ the `InferenceData` from the last call to `sample()` is used.
+ data : optional
+ A pandas DataFrame with values for the predictors that are used to obtain
+ out-of-sample predictions. If omitted, the original dataset is used.
+ inplace : optional
+ If `True` will modify idata in-place and append a `log_likelihood` group to
+ `idata`. Otherwise, it will return a copy of idata with the predictions
+ added, by default True.
+ keep_likelihood_params : optional
+ If `True`, the trial-wise likelihood parameters that are computed
+ en route to getting the log likelihood are kept in the `idata` object.
+ Defaults to False. See also the method `add_likelihood_parameters_to_idata`.
+
+ Returns
+ -------
+ az.InferenceData | None
+ InferenceData or None
+ """
+ if self._inference_obj is None and idata is None:
+ raise ValueError(
+ "The model has not been sampled yet and"
+ + " no idata object has been provided."
+ )
+
+ if idata is None:
+ if self._inference_obj is None:
+ raise ValueError(
+ "The model has not been sampled yet. "
+ + "Please provide an idata object."
+ )
+ else:
+ idata = self._inference_obj
+
+ # Actual likelihood computation
+ idata = _compute_log_likelihood(self.model, idata, data, inplace)
+
+ # clean up posterior:
+ if not keep_likelihood_params:
+ self._clean_posterior_group(idata=idata)
+
+ if inplace:
+ return None
+ else:
+ return idata
+
+ def add_likelihood_parameters_to_idata(
+ self,
+ idata: az.InferenceData | None = None,
+ inplace: bool = False,
+ ) -> az.InferenceData | None:
+ """Add likelihood parameters to the InferenceData object.
+
+ Parameters
+ ----------
+ idata : az.InferenceData
+ The InferenceData object returned by HSSM.sample().
+ inplace : bool
+ If True, the likelihood parameters are added to idata in-place. Otherwise,
+ a copy of idata with the likelihood parameters added is returned.
+ Defaults to False.
+
+ Returns
+ -------
+ az.InferenceData | None
+ InferenceData or None
+ """
+ if idata is None:
+ if self._inference_obj is None:
+ raise ValueError("No idata provided and model not yet sampled!")
+ else:
+ idata = self.model._compute_likelihood_params( # pylint: disable=protected-access
+ deepcopy(self._inference_obj)
+ if not inplace
+ else self._inference_obj
+ )
+ else:
+ idata = self.model._compute_likelihood_params( # pylint: disable=protected-access
+ deepcopy(idata) if not inplace else idata
+ )
+ return idata
+
+ def sample_posterior_predictive(
+ self,
+ idata: az.InferenceData | None = None,
+ data: pd.DataFrame | None = None,
+ inplace: bool = True,
+ include_group_specific: bool = True,
+ kind: Literal["response", "response_params"] = "response",
+ draws: int | float | list[int] | np.ndarray | None = None,
+ safe_mode: bool = True,
+ ) -> az.InferenceData | None:
+ """Perform posterior predictive sampling from the HSSM model.
+
+ Parameters
+ ----------
+ idata : optional
+ The `InferenceData` object returned by `HSSM.sample()`. If not provided,
+ the `InferenceData` from the last time `sample()` is called will be used.
+ data : optional
+ An optional data frame with values for the predictors that are used to
+ obtain out-of-sample predictions. If omitted, the original dataset is used.
+ inplace : optional
+ If `True` will modify idata in-place and append a `posterior_predictive`
+ group to `idata`. Otherwise, it will return a copy of idata with the
+ predictions added, by default True.
+ include_group_specific : optional
+ If `True` will make predictions including the group specific effects.
+ Otherwise, predictions are made with common effects only (i.e. group-
+ specific are set to zero), by default True.
+ kind: optional
+ Indicates the type of prediction required. Can be `"response_params"` or
+ `"response"`. The first returns draws from the posterior distribution of the
+ likelihood parameters, while the latter returns the draws from the posterior
+ predictive distribution (i.e. the posterior probability distribution for a
+ new observation) in addition to the posterior distribution. Defaults to
+ "response".
+ draws: optional
+ The number of samples to draw from the posterior predictive distribution
+ from each chain.
+ When it's an integer >= 1, the number of samples to be extracted from the
+ `draw` dimension. If this integer is larger than the number of posterior
+ samples in each chain, all posterior samples will be used
+ in posterior predictive sampling. When a float between 0 and 1, the
+ proportion of samples from the draw dimension from each chain to be used in
+ posterior predictive sampling. If this proportion is very
+ small, at least one sample will be used. When None, all posterior samples
+ will be used. Defaults to None.
+ safe_mode: bool
+ If True, the function will split the draws into chunks of 10 to avoid memory
+ issues. Defaults to True.
+
+ Raises
+ ------
+ ValueError
+ If the model has not been sampled yet and idata is not provided.
+
+ Returns
+ -------
+ az.InferenceData | None
+ InferenceData or None
+ """
+ if idata is None:
+ if self._inference_obj is None:
+ raise ValueError(
+ "The model has not been sampled yet. "
+ + "Please either provide an idata object or sample the model first."
+ )
+ idata = self._inference_obj
+ _logger.info(
+ "idata=None, we use the traces assigned to the HSSM object as idata."
+ )
+
+ if idata is not None:
+ if "posterior_predictive" in idata.groups():
+ del idata["posterior_predictive"]
+ _logger.warning(
+ "pre-existing posterior_predictive group deleted from idata. \n"
+ )
+
+ if self._check_extra_fields(data):
+ self._update_extra_fields(data)
+
+ if isinstance(draws, np.ndarray):
+ draws = draws.astype(int)
+ elif isinstance(draws, list):
+ draws = np.array(draws).astype(int)
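+ elif isinstance(draws, float) and 0 < draws < 1:
+ # Proportion of the draws in each chain; always use at least one draw.
+ draws = np.arange(max(1, int(draws * idata["posterior"].draw.size)))
+ elif isinstance(draws, int | float):
+ # Never request more draws than each chain contains.
+ draws = np.arange(min(int(draws), idata["posterior"].draw.size))
+ elif draws is None:
+ draws = idata["posterior"].draw.values
+ else:
+ raise ValueError(
+ "draws must be a number, a list of integers, or a numpy array."
+ )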
+
+ assert isinstance(draws, np.ndarray)
+
+ # Make a copy of idata, set the `posterior` group to be a sub-sample
+ # of the original (draw dimension gets sub-sampled)
+
+ idata_copy = idata.copy()
+
+ if (draws.shape != idata["posterior"].draw.values.shape) or (
+ (draws.shape == idata["posterior"].draw.values.shape)
+ and not np.allclose(draws, idata["posterior"].draw.values)
+ ):
+ # Reassign posterior to sub-sampled version
+ setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))
+
+ if kind == "response":
+ # If we run kind == 'response' we actually run the observation RV
+ if safe_mode:
+ # safe mode splits the draws into chunks of 10 to avoid
+ # memory issues (TODO: Figure out the source of memory issues)
+ split_draws = _split_array(
+ idata_copy["posterior"].draw.values, divisor=10
+ )
+
+ posterior_predictive_list = []
+ for samples_tmp in split_draws:
+ tmp_posterior = idata["posterior"].sel(draw=samples_tmp)
+ setattr(idata_copy, "posterior", tmp_posterior)
+ self.model.predict(
+ idata_copy, kind, data, True, include_group_specific
+ )
+ posterior_predictive_list.append(idata_copy["posterior_predictive"])
+
+ if inplace:
+ idata.add_groups(
+ posterior_predictive=xr.concat(
+ posterior_predictive_list, dim="draw"
+ )
+ )
+ # for inplace, we don't return anything
+ return None
+ else:
+ # Reassign original posterior to idata_copy
+ setattr(idata_copy, "posterior", idata["posterior"])
+ # Add new posterior predictive group to idata_copy
+ del idata_copy["posterior_predictive"]
+ idata_copy.add_groups(
+ posterior_predictive=xr.concat(
+ posterior_predictive_list, dim="draw"
+ )
+ )
+ return idata_copy
+ else:
+ if inplace:
+ # If not safe-mode
+ # We call .predict() directly without any
+ # chunking of data.
+
+ # .predict() is called on the copy of idata
+ # since we still subsampled (or assigned) the draws
+ self.model.predict(
+ idata_copy, kind, data, True, include_group_specific
+ )
+
+ # posterior predictive group added to idata
+ idata.add_groups(
+ posterior_predictive=idata_copy["posterior_predictive"]
+ )
+ # don't return anything if inplace
+ return None
+ else:
+ # Not safe mode and not inplace
+ # Function acts as very thin wrapper around
+ # .predict(). It just operates on the
+ # idata_copy object
+ return self.model.predict(
+ idata_copy, kind, data, False, include_group_specific
+ )
+ elif kind == "response_params":
+ # If kind == 'response_params', we don't need to run the RV directly,
+ # there shouldn't really be any significant memory issues here,
+ # we can simply ignore settings, since the computational overhead
+ # should be very small --> nudges user towards good outputs.
+ _logger.warning(
+ "The kind argument is set to 'response_params': "
+ + "the 'draws' argument will be ignored!"
+ )
+ return self.model.predict(
+ idata, kind, data, inplace, include_group_specific
+ )
+ else:
+ raise ValueError("`kind` must be either 'response' or 'response_params'.")
+
+ def plot_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
+ """Produce a posterior predictive plot.
+
+ Equivalent to calling `hssm.plotting.plot_predictive()` with the
+ model. Please see that function for
+ [full documentation][hssm.plotting.plot_predictive].
+
+ Returns
+ -------
+ mpl.axes.Axes | sns.FacetGrid
+ The matplotlib axis or seaborn FacetGrid object containing the plot.
+ """
+ return plotting.plot_predictive(self, **kwargs)
+
+ def plot_quantile_probability(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
+ """Produce a quantile probability plot.
+
+ Equivalent to calling `hssm.plotting.plot_quantile_probability()` with the
+ model. Please see that function for
+ [full documentation][hssm.plotting.plot_quantile_probability].
+
+ Returns
+ -------
+ mpl.axes.Axes | sns.FacetGrid
+ The matplotlib axis or seaborn FacetGrid object containing the plot.
+ """
+ return plotting.plot_quantile_probability(self, **kwargs)
+
+ def predict(self, **kwargs) -> az.InferenceData:
+ """Generate samples from the predictive distribution."""
+ return self.model.predict(**kwargs)
+
+ def sample_do(
+ self, params: dict[str, Any], draws: int = 100, return_model=False, **kwargs
+ ) -> az.InferenceData | tuple[az.InferenceData, pm.Model]:
+ """Generate samples from the predictive distribution using the `do-operator`."""
+ do_model = do(self.pymc_model, params)
+ do_idata = pm.sample_prior_predictive(model=do_model, draws=draws, **kwargs)
+
+ # clean up `rt,response_mean` to `v`
+ do_idata = self._drop_parent_str_from_idata(idata=do_idata)
+
+ # rename otherwise inconsistent dims and coords
+ if "rt,response_extra_dim_0" in do_idata["prior_predictive"].dims:
+ setattr(
+ do_idata,
+ "prior_predictive",
+ do_idata["prior_predictive"].rename_dims(
+ {"rt,response_extra_dim_0": "rt,response_dim"}
+ ),
+ )
+ if "rt,response_extra_dim_0" in do_idata["prior_predictive"].coords:
+ setattr(
+ do_idata,
+ "prior_predictive",
+ do_idata["prior_predictive"].rename_vars(
+ name_dict={"rt,response_extra_dim_0": "rt,response_dim"}
+ ),
+ )
+
+ if return_model:
+ return do_idata, do_model
+ return do_idata
+
+ def sample_prior_predictive(
+ self,
+ draws: int = 500,
+ var_names: str | list[str] | None = None,
+ omit_offsets: bool = True,
+ random_seed: np.random.Generator | None = None,
+ ) -> az.InferenceData:
+ """Generate samples from the prior predictive distribution.
+
+ Parameters
+ ----------
+ draws
+ Number of draws to sample from the prior predictive distribution. Defaults
+ to 500.
+ var_names
+ A list of names of variables for which to compute the prior predictive
+ distribution. Defaults to ``None`` which means both observed and unobserved
+ RVs.
+ omit_offsets
+ Whether to omit offset terms. Defaults to ``True``.
+ random_seed
+ Seed for the random number generator.
+
+ Returns
+ -------
+ az.InferenceData
+ ``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and
+ ``observed_data``.
+ """
+ prior_predictive = self.model.prior_predictive(
+ draws, var_names, omit_offsets, random_seed
+ )
+
+ # AF-COMMENT: Not sure if necessary to include the
+ # mean prior here (which adds deterministics that
+ # could be recomputed elsewhere)
+ prior_predictive.add_groups(posterior=prior_predictive.prior)
+ # Bambi >= 0.17 renamed kind="mean" to kind="response_params".
+ self.model.predict(prior_predictive, kind="response_params", inplace=True)
+
+ # clean
+ setattr(prior_predictive, "prior", prior_predictive["posterior"])
+ del prior_predictive["posterior"]
+
+ if self._inference_obj is None:
+ self._inference_obj = prior_predictive
+ else:
+ self._inference_obj.extend(prior_predictive)
+
+ # clean up `rt,response_mean` to `v`
+ idata = self._drop_parent_str_from_idata(idata=self._inference_obj)
+
+ # rename otherwise inconsistent dims and coords
+ if "rt,response_extra_dim_0" in idata["prior_predictive"].dims:
+ setattr(
+ idata,
+ "prior_predictive",
+ idata["prior_predictive"].rename_dims(
+ {"rt,response_extra_dim_0": "rt,response_dim"}
+ ),
+ )
+ if "rt,response_extra_dim_0" in idata["prior_predictive"].coords:
+ setattr(
+ idata,
+ "prior_predictive",
+ idata["prior_predictive"].rename_vars(
+ name_dict={"rt,response_extra_dim_0": "rt,response_dim"}
+ ),
+ )
+
+ # Update self._inference_obj to match the cleaned idata
+ self._inference_obj = idata
+ return deepcopy(self._inference_obj)
+
+ @property
+ def pymc_model(self) -> pm.Model:
+ """Provide access to the PyMC model.
+
+ Returns
+ -------
+ pm.Model
+ The PyMC model built by bambi
+ """
+ return self.model.backend.model
+
+ def set_alias(self, aliases: dict[str, str | dict]):
+ """Set parameter aliases.
+
+ Sets the aliases according to the dictionary passed to it and rebuild the
+ model.
+
+ Parameters
+ ----------
+ aliases
+ A dict specifying the parameter names being aliased and the aliases.
+ """
+ self.model.set_alias(aliases)
+ self.model.build()
+
+ @property
+ def response_c(self) -> str:
+ """Return the response variable names in c() format."""
+ if self.response is None:
+ return "c()"
+ if len(self.response) == 1:
+ return self.response[0]
+ return f"c({', '.join(self.response)})"
+
+ @property
+ def response_str(self) -> str:
+ """Return the response variable names in string format."""
+ if self.response is None:
+ return ""
+ if len(self.response) == 1:
+ return self.response[0]
+ return ",".join(self.response)
+
+ # NOTE: can't annotate return type because the graphviz dependency is optional
+ def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"):
+ """Produce a graphviz Digraph from a built HSSM model.
+
+ Requires graphviz, which may be installed most easily with `conda install -c
+ conda-forge python-graphviz`. Alternatively, you may install the `graphviz`
+ binaries yourself, and then `pip install graphviz` to get the python bindings.
+ See http://graphviz.readthedocs.io/en/stable/manual.html for more information.
+
+ Parameters
+ ----------
+ formatting
+ One of `"plain"` or `"plain_with_params"`. Defaults to `"plain"`.
+ name
+ Name of the figure to save. Defaults to `None`, no figure is saved.
+ figsize
+ Maximum width and height of figure in inches. Defaults to `None`, the
+ figure size is set automatically. If defined and the drawing is larger than
+ the given size, the drawing is uniformly scaled down so that it fits within
+ the given size. Only works if `name` is not `None`.
+ dpi
+ Dots per inch of the figure to save.
+ Defaults to 300. Only works if `name` is not `None`.
+ fmt
+ Format of the figure to save.
+ Defaults to `"png"`. Only works if `name` is not `None`.
+
+ Returns
+ -------
+ graphviz.Graph
+ The graph
+ """
+ graph = self.model.graph(formatting, name, figsize, dpi, fmt)
+
+ parent_param = self._parent_param
+ if parent_param.is_regression:
+ return graph
+
+ # Modify the graph
+ # 1. Remove all nodes and edges related to `{parent}_mean`:
+ graph.body = [
+ item for item in graph.body if f"{parent_param.name}_mean" not in item
+ ]
+ # 2. Add a new edge from parent to response
+ graph.edge(parent_param.name, self.response_str)
+
+ return graph
+
+ def compile_logp(self, keep_transformed: bool = False, **kwargs):
+ """Compile the log probability function for the model.
+
+ Parameters
+ ----------
+ keep_transformed : bool, optional
+ If True, keeps the transformed variables in the compiled function.
+ If False, removes value transforms before compilation.
+ Defaults to False.
+ **kwargs
+ Additional keyword arguments passed to PyMC's compile_logp:
+ - vars: List of variables. Defaults to None (all variables).
+ - jacobian: Whether to include log(|det(dP/dQ)|) term for
+ transformed variables. Defaults to True.
+ - sum: Whether to sum all terms instead of returning a vector.
+ Defaults to True.
+
+ Returns
+ -------
+ callable
+ A compiled function that computes the model log probability.
+ """
+ if keep_transformed:
+ return self.pymc_model.compile_logp(
+ vars=kwargs.get("vars", None),
+ jacobian=kwargs.get("jacobian", True),
+ sum=kwargs.get("sum", True),
+ )
+ else:
+ new_model = pm.model.transform.conditioning.remove_value_transforms(
+ self.pymc_model
+ )
+ return new_model.compile_logp(
+ vars=kwargs.get("vars", None),
+ jacobian=kwargs.get("jacobian", True),
+ sum=kwargs.get("sum", True),
+ )
+
+ def plot_trace(
+ self,
+ data: az.InferenceData | None = None,
+ include_deterministic: bool = False,
+ tight_layout: bool = True,
+ **kwargs,
+ ) -> None:
+ """Generate trace plot with ArviZ but with additional convenience features.
+
+ This is a simple wrapper for the az.plot_trace() function. By default, it
+ filters out the deterministic values from the plot. Please see the
+ [arviz documentation]
+ (https://arviz-devs.github.io/arviz/api/generated/arviz.plot_trace.html)
+ for additional parameters that can be specified.
+
+ Parameters
+ ----------
+ data : optional
+ An ArviZ InferenceData object. If None, the traces stored in the model will
+ be used.
+ include_deterministic : optional
+ Whether to include deterministic variables in the plot. Defaults to False.
+ Note that if `include_deterministic` is set to False and `var_names` is
+ provided, the `var_names` provided will be modified to also exclude the
+ deterministic values. If this is not desirable, set
+ `include_deterministic` to True.
+ tight_layout : optional
+ Whether to call plt.tight_layout() after plotting. Defaults to True.
+ """
+ data = data or self.traces
+ if not isinstance(data, az.InferenceData):
+ raise TypeError("data must be an InferenceData object.")
+
+ if not include_deterministic:
+ var_names = list(
+ set([var.name for var in self.pymc_model.free_RVs]).intersection(
+ set(list(data["posterior"].data_vars.keys()))
+ )
+ )
+ # var_names = self._get_deterministic_var_names(data)
+ if var_names:
+ if "var_names" in kwargs:
+ if isinstance(kwargs["var_names"], str):
+ if kwargs["var_names"] not in var_names:
+ var_names.append(kwargs["var_names"])
+ kwargs["var_names"] = var_names
+ elif isinstance(kwargs["var_names"], list):
+ kwargs["var_names"] = list(
+ set(var_names) | set(kwargs["var_names"])
+ )
+ elif kwargs["var_names"] is None:
+ kwargs["var_names"] = var_names
+ else:
+ raise ValueError(
+ "`var_names` must be a string, a list of strings, or None."
+ )
+ else:
+ kwargs["var_names"] = var_names
+ az.plot_trace(data, **kwargs)
+
+ if tight_layout:
+ plt.tight_layout()
+
+ def summary(
+ self,
+ data: az.InferenceData | None = None,
+ include_deterministic: bool = False,
+ **kwargs,
+ ) -> pd.DataFrame | xr.Dataset:
+ """Produce a summary table with ArviZ but with additional convenience features.
+
+ This is a simple wrapper for the az.summary() function. By default, it
+ filters out the deterministic values from the summary. Please see the
+ [arviz documentation]
+ (https://arviz-devs.github.io/arviz/api/generated/arviz.summary.html)
+ for additional parameters that can be specified.
+
+ Parameters
+ ----------
+ data
+ An ArviZ InferenceData object. If None, the traces stored in the model will
+ be used.
+ include_deterministic : optional
+ Whether to include deterministic variables in the summary. Defaults to
+ False. Note that if `include_deterministic` is set to False and
+ `var_names` is provided, the `var_names` provided will be modified to also
+ exclude the deterministic values. If this is not desirable, set
+ `include_deterministic` to True.
+
+ Returns
+ -------
+ pd.DataFrame | xr.Dataset
+ A pandas DataFrame or xarray Dataset containing the summary statistics.
+ """
+ data = data or self.traces
+ if not isinstance(data, az.InferenceData):
+ raise TypeError("data must be an InferenceData object.")
+
+ if not include_deterministic:
+ var_names = list(
+ set([var.name for var in self.pymc_model.free_RVs]).intersection(
+ set(list(data["posterior"].data_vars.keys()))
+ )
+ )
+ # var_names = self._get_deterministic_var_names(data)
+ if var_names:
+ kwargs["var_names"] = list(set(var_names + kwargs.get("var_names", [])))
+ return az.summary(data, **kwargs)
+
+ def initial_point(self, transformed: bool = False) -> dict[str, np.ndarray]:
+ """Compute the initial point of the model.
+
+ This is a slightly altered version of pm.initial_point.initial_point().
+
+ Parameters
+ ----------
+ transformed : bool, optional
+ If True, return the initial point in transformed space.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the initial point of the model parameters.
+ """
+ fn = pm.initial_point.make_initial_point_fn(
+ model=self.pymc_model, return_transformed=transformed
+ )
+ return pm.model.Point(fn(None), model=self.pymc_model)
+
+ def restore_traces(
+ self, traces: az.InferenceData | pm.Approximation | str | PathLike
+ ) -> None:
+ """Restore traces from an InferenceData object or a .netcdf file.
+
+ Parameters
+ ----------
+ traces
+ An InferenceData object or a path to a file containing the traces.
+ """
+ if isinstance(traces, pm.Approximation):
+ self._inference_obj_vi = traces
+ return
+
+ if isinstance(traces, (str, PathLike)):
+ traces = az.from_netcdf(traces)
+ self._inference_obj = cast("az.InferenceData", traces)
+
+ def restore_vi_traces(
+ self, traces: az.InferenceData | pm.Approximation | str | PathLike
+ ) -> None:
+ """Restore VI traces from an InferenceData object or a .netcdf file.
+
+ Parameters
+ ----------
+ traces
+ An InferenceData object or a path to a file containing the VI traces.
+ """
+ if isinstance(traces, pm.Approximation):
+ self._inference_obj_vi = traces
+ return
+
+ if isinstance(traces, (str, PathLike)):
+ traces = az.from_netcdf(traces)
+ self._inference_obj_vi = cast("az.InferenceData", traces)
+
+ def save_model(
+ self,
+ model_name: str | None = None,
+ allow_absolute_base_path: bool = False,
+ base_path: str | Path = "hssm_models",
+ save_idata_only: bool = False,
+ ) -> None:
+ """Save an HSSM model instance and its inference results to disk.
+
+ Parameters
+ ----------
+ model_name : str | None
+ Name to use for the saved model files.
+ If None, `self.model_name` with a timestamp suffix is used.
+ allow_absolute_base_path : bool
+ Whether to allow absolute paths for base_path.
+ Defaults to False for safety.
+ base_path : str | Path
+ Base directory to save model files in.
+ Must be relative path if allow_absolute_base_path=False.
+ Defaults to "hssm_models".
+ save_idata_only : bool
+ If True, only saves inference data (traces), not the model pickle.
+ Defaults to False (saves both model and traces).
+
+ Raises
+ ------
+ ValueError
+ If base_path is absolute and allow_absolute_base_path=False
+ """
+ # Convert to Path object for cross-platform compatibility
+ base_path = Path(base_path)
+
+ # Check if base_path is absolute (works on all platforms)
+ if not allow_absolute_base_path and base_path.is_absolute():
+ raise ValueError(
+ "base_path must be a relative path if allow_absolute_base_path is False"
+ )
+
+ if model_name is None:
+ # Get date string format as suffix to model name
+ timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
+ model_name = f"{self.model_name}_{timestamp}"
+
+ # Sanitize model_name and construct full path
+ model_name = model_name.replace(" ", "_")
+ model_path = Path(base_path).joinpath(model_name)
+ model_path.mkdir(parents=True, exist_ok=True)
+
+ # Save model to pickle file
+ if not save_idata_only:
+ with open(model_path.joinpath("model.pkl"), "wb") as f:
+ cpickle.dump(self, f)
+
+ # Save traces to netcdf file
+ if self._inference_obj is not None:
+ az.to_netcdf(self._inference_obj, model_path.joinpath("traces.nc"))
+
+ # Save vi_traces to netcdf file
+ if self._inference_obj_vi is not None:
+ az.to_netcdf(self._inference_obj_vi, model_path.joinpath("vi_traces.nc"))
+
+ @classmethod
+ def load_model(
+ cls, path: Union[str, Path]
+ ) -> Union["HSSMBase", dict[str, Optional[az.InferenceData]]]:
+ """Load an HSSM model instance and its inference results from disk.
+
+ Parameters
+ ----------
+ path : str | Path
+ Path to the model directory or model.pkl file. If a directory is provided,
+ will look for model.pkl, traces.nc and vi_traces.nc files within it.
+
+ Returns
+ -------
+ HSSMBase or dict[str, az.InferenceData | None]
+ The loaded model instance (with inference results attached if available),
+ or a dictionary of traces-only InferenceData objects when no model.pkl is
+ found.
+ """
+ # Convert path to Path object
+ path = Path(path)
+
+ # If path points to a file, assume it's model.pkl
+ if path.is_file():
+ model_dir = path.parent
+ model_path = path
+ else:
+ # Path points to directory
+ model_dir = path
+ model_path = model_dir.joinpath("model.pkl")
+
+ # check if model_dir exists
+ if not model_dir.exists():
+ raise FileNotFoundError(f"Model directory {model_dir} does not exist.")
+
+ # check if model.pkl exists; if not, log a message and fall back to traces
+ if not model_path.exists():
+ _logger.info(
+ f"model.pkl file does not exist in {model_dir}. "
+ "Attempting to load traces only."
+ )
+ if (not model_dir.joinpath("traces.nc").exists()) and (
+ not model_dir.joinpath("vi_traces.nc").exists()
+ ):
+ raise FileNotFoundError(f"No traces found in {model_dir}.")
+ else:
+ idata_dict = cls.load_model_idata(model_dir)
+ return idata_dict
+ else:
+ # Load model from pickle file
+ with open(model_path, "rb") as f:
+ model = cpickle.load(f)
+
+ # Load traces if they exist
+ traces_path = model_dir.joinpath("traces.nc")
+ if traces_path.exists():
+ model.restore_traces(traces_path)
+
+ # Load VI traces if they exist
+ vi_traces_path = model_dir.joinpath("vi_traces.nc")
+ if vi_traces_path.exists():
+ model.restore_vi_traces(vi_traces_path)
+ return model
+
+ @classmethod
+ def load_model_idata(cls, path: str | Path) -> dict[str, az.InferenceData | None]:
+ """Load the traces from a model directory.
+
+ Parameters
+ ----------
+ path : str | Path
+ Path to the model directory containing traces.nc and/or vi_traces.nc files.
+
+ Returns
+ -------
+ dict[str, az.InferenceData | None]
+ A dictionary with keys "idata_mcmc" and "idata_vi" containing the traces
+ from the model directory. If the traces do not exist, the corresponding
+ value will be None.
+ """
+ idata_dict: dict[str, az.InferenceData | None] = {}
+ model_dir = Path(path)
+ # check if path exists
+ if not model_dir.exists():
+ raise FileNotFoundError(f"Model directory {model_dir} does not exist.")
+
+ # check if traces.nc exists
+ traces_path = model_dir.joinpath("traces.nc")
+ if not traces_path.exists():
+ _logger.warning(f"traces.nc file does not exist in {model_dir}.")
+ idata_dict["idata_mcmc"] = None
+ else:
+ idata_dict["idata_mcmc"] = az.from_netcdf(traces_path)
+
+ # check if vi_traces.nc exists
+ vi_traces_path = model_dir.joinpath("vi_traces.nc")
+ if not vi_traces_path.exists():
+ _logger.warning(f"vi_traces.nc file does not exist in {model_dir}.")
+ idata_dict["idata_vi"] = None
+ else:
+ idata_dict["idata_vi"] = az.from_netcdf(vi_traces_path)
+
+ return idata_dict
+
+ def __getstate__(self):
+ """Get the state of the model for pickling.
+
+ This method is called when pickling the model.
+ It returns a dictionary containing the constructor
+ arguments needed to recreate the model instance.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the constructor arguments
+ under the key 'constructor_args'.
+ """
+ # Provide a clear error when the initialization snapshot is missing or
+ # empty. This makes the contract explicit and avoids an AttributeError
+ # that is easy to miss for subclasses that forget to capture init args.
+ if not hasattr(self, "_init_args") or not self._init_args:
+ raise RuntimeError(
+ "Model state missing initialization snapshot; ensure subclasses "
+ "call _store_init_args(locals(), kwargs) early in __init__"
+ )
+
+ state = {"constructor_args": self._init_args}
+ return state
+
+ def __setstate__(self, state):
+ """Set the state of the model when unpickling.
+
+ This method is called when unpickling the model. It creates a new instance
+ using the constructor arguments stored in the state dictionary,
+ and copies its attributes to the current instance.
+
+ Parameters
+ ----------
+ state : dict
+ A dictionary containing the constructor arguments under the key
+ 'constructor_args'.
+ """
+ new_instance = self.__class__(**state["constructor_args"])
+ self.__dict__ = new_instance.__dict__
+
+ def __repr__(self) -> str:
+ """Create a representation of the model."""
+ output = [
+ "Hierarchical Sequential Sampling Model",
+ f"Model: {self.model_name}\n",
+ f"Response variable: {self.response_str}",
+ f"Likelihood: {self.loglik_kind}",
+ f"Observations: {len(self.data)}\n",
+ "Parameters:\n",
+ ]
+
+ for param in self.params.values():
+ if param.name == "p_outlier":
+ continue
+ output.append(f"{param.name}:")
+
+ component = self.model.components[param.name]
+
+ # Regression case:
+ if param.is_regression:
+ assert isinstance(component, DistributionalComponent)
+ output.append(f" Formula: {param.formula}")
+ output.append(" Priors:")
+ intercept_term = component.intercept_term
+ if intercept_term is not None:
+ output.append(_print_prior(intercept_term))
+ for _, common_term in component.common_terms.items():
+ output.append(_print_prior(common_term))
+ for _, group_specific_term in component.group_specific_terms.items():
+ output.append(_print_prior(group_specific_term))
+ output.append(f" Link: {param.link}")
+ # Non-regression case
+ else:
+ if param.prior is None:
+ prior = (
+ component.intercept_term.prior
+ if param.is_parent
+ else component.prior
+ )
+ else:
+ prior = param.prior
+ output.append(f" Prior: {prior}")
+ output.append(f" Explicit bounds: {param.bounds}")
+ output.append(
+ " (ignored due to link function)"
+ if self.link_settings is not None
+ else ""
+ )
+
+ # TODO: Handle p_outlier regression correctly here.
+ if self.p_outlier is not None:
+ output.append("")
+ output.append(f"Lapse probability: {self.p_outlier.prior}")
+ output.append(f"Lapse distribution: {self.lapse}")
+
+ return "\n".join(output)
+
+ def __str__(self) -> str:
+ """Create a string representation of the model."""
+ return self.__repr__()
+
+ @property
+ def traces(self) -> az.InferenceData | pm.Approximation:
+ """Return the trace of the model after sampling.
+
+ Raises
+ ------
+ ValueError
+ If the model has not been sampled yet.
+
+ Returns
+ -------
+ az.InferenceData | pm.Approximation
+ The trace of the model after the last call to `sample()`.
+ """
+ if not self._inference_obj:
+ raise ValueError("Please sample the model first.")
+
+ return self._inference_obj
+
+ @property
+ def vi_idata(self) -> az.InferenceData:
+ """Return the InferenceData object produced by variational inference.
+
+ Raises
+ ------
+ ValueError
+ If variational inference has not been run yet.
+
+ Returns
+ -------
+ az.InferenceData
+ The InferenceData object holding the variational posterior.
+ """
+ if not self._inference_obj_vi:
+ raise ValueError(
+ "Please run variational inference first, "
+ "no variational posterior attached."
+ )
+
+ return self._inference_obj_vi
+
+ @property
+ def vi_approx(self) -> pm.Approximation:
+ """Return the variational inference approximation object.
+
+ Raises
+ ------
+ ValueError
+ If variational inference has not been run yet.
+
+ Returns
+ -------
+ pm.Approximation
+ The variational inference approximation object.
+ """
+ if not self._vi_approx:
+ raise ValueError(
+ "Please run variational inference first, "
+ "no variational approximation attached."
+ )
+
+ return self._vi_approx
+
+ @property
+ def map(self) -> dict:
+ """Return the MAP estimates of the model parameters.
+
+ Raises
+ ------
+ ValueError
+ If the MAP estimates have not been computed yet.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the MAP estimates of the model parameters.
+ """
+ if not self._map_dict:
+ raise ValueError("Please compute map first.")
+
+ return self._map_dict
+
+ @property
+ def initvals(self) -> dict:
+ """Return the initial values of the model parameters for sampling.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the initial values of the model parameters.
+ This dict serves as the default for initial values, and can be passed
+ directly to the `.sample()` function.
+ """
+ if self._initvals == {}:
+ self._initvals = self.initial_point()
+ return self._initvals
+
+ def _check_lapse(self, lapse):
+ """Determine if p_outlier and lapse are specified correctly."""
+ # Basically, avoid situations where only one of them is specified.
+ if self.has_lapse and lapse is None:
+ raise ValueError(
+ "You have specified `p_outlier`. Please also specify `lapse`."
+ )
+ if lapse is not None and not self.has_lapse:
+ _logger.warning(
+ "You have specified the `lapse` argument to include a lapse "
+ + "distribution, but `p_outlier` is set to either 0 or None. "
+ + "Your lapse distribution will be ignored."
+ )
+ if "p_outlier" in self.list_params and self.list_params[-1] != "p_outlier":
+ raise ValueError(
+ "Please do not include 'p_outlier' in `list_params`. "
+ + "We automatically append it to `list_params` when `p_outlier` "
+ + "parameter is not None"
+ )
+
+ def _get_deterministic_var_names(self, idata) -> list[str]:
+ """Filter out the deterministic variables in var_names."""
+ var_names = [
+ f"~{param_name}"
+ for param_name, param in self.params.items()
+ if (param.is_regression)
+ ]
+
+ if f"{self._parent}_mean" in idata["posterior"].data_vars:
+ var_names.append(f"~{self._parent}_mean")
+
+ # Parent parameters (always regression implicitly)
+ # which don't have a formula attached
+ # should be dropped from var_names, since the actual
+ # parent name shows up as a regression.
+ if f"{self._parent}" in idata["posterior"].data_vars:
+ if self.params[self._parent].formula is None:
+ # Drop from var_names
+ var_names = [var for var in var_names if var != f"~{self._parent}"]
+
+ return var_names
+
+ def _drop_parent_str_from_idata(
+ self, idata: az.InferenceData | None
+ ) -> az.InferenceData:
+ """Drop the parent_str variable from an InferenceData object.
+
+ Parameters
+ ----------
+ idata
+ The InferenceData object to be modified.
+
+ Returns
+ -------
+ az.InferenceData
+ The modified InferenceData object.
+ """
+ if idata is None:
+ raise ValueError("Please provide an InferenceData object.")
+ else:
+ for group in idata.groups():
+ if ("rt,response_mean" in idata[group].data_vars) and (
+ self._parent not in idata[group].data_vars
+ ):
+ setattr(
+ idata,
+ group,
+ idata[group].rename({"rt,response_mean": self._parent}),
+ )
+ return idata
+
+ def _postprocess_initvals_deterministic(
+ self, initval_settings: dict = INITVAL_SETTINGS
+ ) -> None:
+ """Set initial values for subset of parameters."""
+ self._initvals = self.initial_point()
+ # Consider case where link functions are set to 'log_logit'
+ # or 'None'
+ if self.link_settings not in ["log_logit", None]:
+ _logger.info(
+ "Not preprocessing initial values, "
+ + "because none of the two standard link settings are chosen!"
+ )
+ return None
+
+ # Set initial values for particular parameters
+ for name_, starting_value in self.pymc_model.initial_point().items():
+ # strip name of `_log__` and `_interval__` suffixes
+ name_tmp = name_.replace("_log__", "").replace("_interval__", "")
+
+ # We need to check if the parameter is actually backed by
+ # a regression.
+
+ # If not, no link function is applied to it by default.
+ # Therefore we need to apply the initial value strategy corresponding
+ # to the 'None' link function.
+
+ # If the user actively supplies a link function, the user
+ # should also have supplied an initial value insofar as it matters.
+
+ if self.params[self._get_prefix(name_tmp)].is_regression:
+ param_link_setting = self.link_settings
+ else:
+ param_link_setting = None
+ if name_tmp in initval_settings[param_link_setting].keys():
+ if self._check_if_initval_user_supplied(name_tmp):
+ _logger.info(
+ "User supplied initial value detected for %s, \n"
+ " skipping overwrite with default value.",
+ name_tmp,
+ )
+ continue
+
+ # Apply specific settings from initval_settings dictionary
+ dtype = self._initvals[name_tmp].dtype
+ self._initvals[name_tmp] = np.array(
+ initval_settings[param_link_setting][name_tmp]
+ ).astype(dtype)
+
+ def _get_prefix(self, name_str: str) -> str:
+ """Resolve parameter prefix, handling underscore-containing RL param names.
+
+ The base-class implementation splits ``name_str`` on the first ``_`` and
+ returns that single token (e.g. ``"rl_alpha_Intercept" → "rl"``), which
+ breaks for RL parameters whose names contain underscores. It also uses a
+ substring check (``"p_outlier" in name_str``) for the lapse parameter,
+ which would misfire for any parameter whose name merely *contains* that
+ substring.
+
+ This override replaces both heuristics with a single longest-prefix-first
+ token search: split on ``_``, then try joining 1…N tokens (longest first)
+ until a candidate is found in ``self.params``. This is both correct for
+ multi-token RL param names and collision-free for ``p_outlier``.
+ """
+ if "_" in name_str:
+ parts = name_str.split("_")
+ for i in range(len(parts), 0, -1):
+ candidate = "_".join(parts[:i])
+ if hasattr(self, "params") and candidate in self.params:
+ return candidate
+ return name_str
+
+ def _check_if_initval_user_supplied(
+ self,
+ name_str: str,
+ return_value: bool = False,
+ ) -> bool | float | int | np.ndarray | dict[str, Any] | None:
+ """Check if initial value is user-supplied."""
+ # The function assumes that the name_str is either raw parameter name
+ # or `paramname_Intercept`, because we only really provide special default
+ # initial values for those types of parameters
+
+ # `p_outlier` is the only basic parameter floating around that has
+ # an underscore in its name.
+ # We need to handle it separately. (Renaming might be better...)
+ if "_" in name_str:
+ if "p_outlier" not in name_str:
+ name_str_prefix = name_str.split("_")[0]
+ # name_str_suffix = "".join(name_str.split("_")[1:])
+ name_str_suffix = name_str[len(name_str_prefix + "_") :]
+ else:
+ name_str_prefix = "p_outlier"
+ if name_str == "p_outlier":
+ name_str_suffix = ""
+ else:
+ # name_str_suffix = "".join(name_str.split("_")[2:])
+ name_str_suffix = name_str[len("p_outlier_") :]
+ else:
+ name_str_prefix = name_str
+ name_str_suffix = ""
+
+ tmp_param = name_str_prefix
+ if tmp_param == self._parent:
+ # If the parameter was parent it is automatically treated as a
+ # regression.
+ if not name_str_suffix:
+ # No suffix --> Intercept
+ if isinstance(prior_tmp := self.params[tmp_param].prior, dict):
+ args_tmp = getattr(prior_tmp["Intercept"], "args")
+ if return_value:
+ return args_tmp.get("initval", None)
+ else:
+ return "initval" in args_tmp
+ else:
+ if return_value:
+ return None
+ return False
+ else:
+ # If the parameter has a suffix --> use it
+ if isinstance(prior_tmp := self.params[tmp_param].prior, dict):
+ args_tmp = getattr(prior_tmp[name_str_suffix], "args")
+ if return_value:
+ return args_tmp.get("initval", None)
+ else:
+ return "initval" in args_tmp
+ else:
+ if return_value:
+ return None
+ else:
+ return False
+ else:
+ # If the parameter is not a parent, it is treated as a regression
+ # only when actively specified as such.
+ if not name_str_suffix:
+ # If no suffix --> treat as basic parameter.
+ if isinstance(self.params[tmp_param].prior, float) or isinstance(
+ self.params[tmp_param].prior, np.ndarray
+ ):
+ if return_value:
+ return self.params[tmp_param].prior
+ else:
+ return True
+ elif isinstance(self.params[tmp_param].prior, bmb.Prior):
+ args_tmp = getattr(self.params[tmp_param].prior, "args")
+ if "initval" in args_tmp:
+ if return_value:
+ return args_tmp["initval"]
+ else:
+ return True
+ else:
+ if return_value:
+ return None
+ else:
+ return False
+ else:
+ if return_value:
+ return None
+ else:
+ return False
+ else:
+ # If suffix --> treat as regression and use suffix
+ if isinstance(prior_tmp := self.params[tmp_param].prior, dict):
+ args_tmp = getattr(prior_tmp[name_str_suffix], "args")
+ if return_value:
+ return args_tmp.get("initval", None)
+ else:
+ return "initval" in args_tmp
+ else:
+ if return_value:
+ return None
+ else:
+ return False
+
+ def _jitter_initvals(
+ self, jitter_epsilon: float = 0.01, vector_only: bool = False
+ ) -> None:
+ """Apply controlled jitter to initial values."""
+ if vector_only:
+ self.__jitter_initvals_vector_only(jitter_epsilon)
+ else:
+ self.__jitter_initvals_all(jitter_epsilon)
+
+ def __jitter_initvals_vector_only(self, jitter_epsilon: float) -> None:
+ # Note: Calling our initial point function here
+ # --> operate on untransformed variables
+ initial_point_dict = self.initvals
+ for name_, starting_value in initial_point_dict.items():
+ name_tmp = name_.replace("_log__", "").replace("_interval__", "")
+ if starting_value.ndim != 0 and starting_value.shape[0] != 1:
+ starting_value_tmp = starting_value + np.random.uniform(
+ -jitter_epsilon, jitter_epsilon, starting_value.shape
+ ).astype(np.float32)
+
+ # Note: self._initvals shouldn't be None when this is called
+ dtype = self._initvals[name_tmp].dtype
+ self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype)
+
+ def __jitter_initvals_all(self, jitter_epsilon: float) -> None:
+ # Note: Calling our initial point function here
+ # --> operate on untransformed variables
+ initial_point_dict = self.initvals
+ # initial_point_dict = self.pymc_model.initial_point()
+ for name_, starting_value in initial_point_dict.items():
+ name_tmp = name_.replace("_log__", "").replace("_interval__", "")
+ starting_value_tmp = starting_value + np.random.uniform(
+ -jitter_epsilon, jitter_epsilon, starting_value.shape
+ ).astype(np.float32)
+
+ dtype = self.initvals[name_tmp].dtype
+ self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype)
diff --git a/src/hssm/config.py b/src/hssm/config.py
index 9b9e0b12d..f223b859a 100644
--- a/src/hssm/config.py
+++ b/src/hssm/config.py
@@ -20,27 +20,17 @@
if TYPE_CHECKING:
from pytensor.tensor.random.op import RandomVariable
-# ====== Centralized RLSSM defaults =====
+import logging
+
+from ssms.config import model_config as ssms_model_config
+
+_logger = logging.getLogger("hssm")
+
+
+# ====== Centralized SSM defaults =====
DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"]
-DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"]
DEFAULT_SSM_CHOICES = (0, 1)
-RLSSM_REQUIRED_FIELDS = (
- "model_name",
- "description",
- "list_params",
- "bounds",
- "params_default",
- "data",
- "choices",
- "decision_process",
- "learning_process",
- "response",
- "decision_process_loglik_kind",
- "learning_process_loglik_kind",
- "extra_fields",
-)
-
ParamSpec = Union[float, dict[str, Any], Prior, None]
@@ -68,6 +58,9 @@ class BaseModelConfig(ABC):
# Additional data requirements
extra_fields: list[str] | None = None
+ # Random variable (simulator) for posterior predictive sampling
+ rv: Any | None = None
+
@abstractmethod
def validate(self) -> None:
"""Validate configuration. Must be implemented by subclasses."""
@@ -78,6 +71,21 @@ def get_defaults(self, param: str) -> Any:
"""Get default values for a parameter. Must be implemented by subclasses."""
...
+ @property
+ def n_params(self) -> int | None:
+ """Return the number of parameters."""
+ return len(self.list_params) if self.list_params else None
+
+ @property
+ def n_extra_fields(self) -> int | None:
+ """Return the number of extra fields."""
+ return len(self.extra_fields) if self.extra_fields else None
+
+ @property
+ def is_choice_only(self) -> bool:
+ """Return whether the model is choice-only (no RT)."""
+ return self.response is not None and len(self.response) == 1
+
@dataclass
class Config(BaseModelConfig):
@@ -179,7 +187,7 @@ def from_defaults(
return Config(
model_name=model_name,
loglik_kind=loglik_kind,
- response=DEFAULT_RLSSM_OBSERVED_DATA,
+ response=DEFAULT_SSM_OBSERVED_DATA,
)
def update_loglik(self, loglik: Any | None) -> None:
@@ -200,7 +208,7 @@ def update_choices(self, choices: tuple[int, ...] | None) -> None:
Parameters
----------
- choices : tuple[int, ...]
+ choices : tuple[int, ...] | None
A tuple of choices.
"""
if choices is None:
@@ -217,7 +225,7 @@ def update_config(self, user_config: ModelConfig) -> None:
User specified ModelConfig used update self.
"""
if user_config.response is not None:
- self.response = user_config.response
+ self.response = list(user_config.response) # type: ignore[assignment]
if user_config.list_params is not None:
self.list_params = user_config.list_params
if user_config.choices is not None:
@@ -260,208 +268,62 @@ def get_defaults(
"""
return self.default_priors.get(param), self.bounds.get(param)
- @property
- def is_choice_only(self) -> bool:
- """Check if the model is a choice-only model."""
- # Treat both None and an empty list as invalid configurations.
- if not self.response:
- raise ValueError(
- "Please provide at least one `response` column in the configuration."
- )
- return len(self.response) == 1
-
-
-@dataclass
-class RLSSMConfig(BaseModelConfig):
- """Config for reinforcement learning + sequential sampling models.
-
- This configuration class is designed for models that combine reinforcement
- learning processes with sequential sampling decision models (RLSSM).
- """
-
- decision_process_loglik_kind: str = field(kw_only=True)
- learning_process_loglik_kind: str = field(kw_only=True)
- params_default: list[float] = field(kw_only=True)
- decision_process: str | ModelConfig = field(kw_only=True)
- learning_process: dict[str, Any] = field(kw_only=True)
-
- def __post_init__(self):
- """Set default loglik_kind for RLSSM models if not provided."""
- if self.loglik_kind is None:
- self.loglik_kind = "approx_differentiable"
-
- @property
- def n_params(self) -> int | None:
- """Return the number of parameters."""
- return len(self.list_params) if self.list_params else None
-
- @property
- def n_extra_fields(self) -> int | None:
- """Return the number of extra fields."""
- return len(self.extra_fields) if self.extra_fields else None
-
@classmethod
- def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]):
- """
- Create RLSSMConfig from a configuration dictionary.
-
- Parameters
- ----------
- model_name : str
- The name of the RLSSM model.
- config_dict : dict[str, Any]
- Dictionary containing model configuration. Expected keys:
- - description: Model description (optional)
- - list_params: List of parameter names (required)
- - extra_fields: List of extra field names from data (required)
- - params_default: Default parameter values (required)
- - bounds: Parameter bounds (required)
- - response: Response column names (required)
- - choices: Valid choice values (required)
- - decision_process: Decision process specification (required)
- - learning_process: Learning process functions (required)
- - decision_process_loglik_kind: Likelihood kind for decision process
- (required)
- - learning_process_loglik_kind: Likelihood kind for learning process
- (required)
-
- Returns
- -------
- RLSSMConfig
- Configured RLSSM model configuration object.
- """
- # Check for required fields and raise explicit errors if missing
- for field_name in RLSSM_REQUIRED_FIELDS:
- if field_name not in config_dict or config_dict[field_name] is None:
- raise ValueError(f"{field_name} must be provided in config_dict")
-
- return cls(
- model_name=model_name,
- description=config_dict.get("description"),
- list_params=config_dict["list_params"],
- extra_fields=config_dict.get("extra_fields"),
- params_default=config_dict["params_default"],
- decision_process=config_dict["decision_process"],
- learning_process=config_dict["learning_process"],
- bounds=config_dict.get("bounds", {}),
- response=config_dict["response"],
- choices=config_dict["choices"],
- decision_process_loglik_kind=config_dict["decision_process_loglik_kind"],
- learning_process_loglik_kind=config_dict["learning_process_loglik_kind"],
- )
-
- def validate(self) -> None:
- """Validate RLSSM configuration.
-
- Raises
- ------
- ValueError
- If required fields are missing or inconsistent.
+ def _build_model_config(
+ cls,
+ model: SupportedModels | str,
+ loglik_kind: LoglikKind | None,
+ model_config: ModelConfig | dict | None,
+ choices: list[int] | tuple[int, ...] | None,
+ loglik: Any = None,
+ ) -> Config:
+ """Build and return a validated Config for standard HSSM models.
+
+ Resolves defaults, normalizes dict/ModelConfig overrides, applies
+ choices and loglik precedence rules, then validates before returning.
"""
- if self.response is None:
- raise ValueError("Please provide `response` columns in the configuration.")
- if self.list_params is None:
- raise ValueError("Please provide `list_params` in the configuration.")
- if self.choices is None:
- raise ValueError("Please provide `choices` in the configuration.")
- if self.decision_process is None:
- raise ValueError("Please specify a `decision_process`.")
+ config = cls.from_defaults(model, loglik_kind)
- # Validate parameter defaults consistency
- if self.params_default and self.list_params:
- if len(self.params_default) != len(self.list_params):
- raise ValueError(
- f"params_default length ({len(self.params_default)}) doesn't "
- f"match list_params length ({len(self.list_params)})"
- )
-
- def get_defaults(
- self, param: str
- ) -> tuple[float | None, tuple[float, float] | None]:
- """Return default value and bounds for a parameter.
+ if model_config is not None:
+ final_config = _normalize_model_config_with_choices(model_config, choices)
+ config.update_config(final_config)
- Parameters
- ----------
- param
- The name of the parameter.
+ # No model_config provided: apply `choices` when appropriate.
+ # If caller passed a SupportedModels string, ignore explicit `choices`.
+ if (
+ model in get_args(SupportedModels)
+ and choices is not None
+ and model_config is None
+ ):
+ _logger.info(
+ "Model string is in SupportedModels. Ignoring choices arguments."
+ )
- Returns
- -------
- tuple
- A tuple of (default_value, bounds) where:
- - default_value is a float or None if not found
- - bounds is a tuple (lower, upper) or None if not found
- """
- # Try to find the parameter in list_params and get its default value
- default_val = None
- if self.list_params is not None:
- try:
- param_idx = self.list_params.index(param)
- if self.params_default and param_idx < len(self.params_default):
- default_val = self.params_default[param_idx]
- except ValueError:
- # Parameter not in list_params
- pass
-
- return default_val, self.bounds.get(param)
-
- def to_config(self) -> Config:
- """Convert to standard Config for compatibility with HSSM.
-
- This method transforms the RLSSM configuration into a standard Config
- object that can be used with the existing HSSM infrastructure.
-
- Returns
- -------
- Config
- A Config object with RLSSM parameters mapped to standard format.
-
- Notes
- -----
- The transformation converts params_default list to default_priors dict,
- mapping parameter names to their default values.
- """
- # Validate parameter defaults consistency before conversion
- if self.params_default and self.list_params:
- if len(self.params_default) != len(self.list_params):
- raise ValueError(
- f"params_default length ({len(self.params_default)}) doesn't "
- f"match list_params length ({len(self.list_params)}). "
- "This would result in silent data loss during conversion."
+ # If model is not a supported built-in, prefer explicit choices or
+ # fall back to the ssm-simulators lookup when available.
+ if model not in get_args(SupportedModels):
+ if choices is not None:
+ config.update_choices(choices)
+ elif model in ssms_model_config:
+ config.update_choices(ssms_model_config[model]["choices"])
+ _logger.info(
+ "choices argument passed as None, "
+ "but found %s in ssm-simulators. "
+ "Using choices from ssm-simulators configs: %s",
+ model,
+ ssms_model_config[model]["choices"],
)
- # Transform params_default list to default_priors dict
- default_priors = (
- {
- param: default
- for param, default in zip(self.list_params, self.params_default)
- }
- if self.list_params and self.params_default
- else {}
- )
-
- return Config(
- model_name=self.model_name,
- loglik_kind=self.loglik_kind,
- response=self.response,
- choices=self.choices,
- list_params=self.list_params,
- description=self.description,
- bounds=self.bounds,
- default_priors=cast(
- "dict[str, float | dict[str, Any] | Any | None]", default_priors
- ),
- extra_fields=self.extra_fields,
- backend=self.backend or "jax", # RLSSM typically uses JAX
- loglik=self.loglik,
- )
+ config.update_loglik(loglik)
+ config.validate()
+ return config
@dataclass
class ModelConfig:
"""Representation for model_config provided by the user."""
- response: list[str] | None = None
+ response: tuple[str, ...] | None = None
list_params: list[str] | None = None
choices: tuple[int, ...] | None = None
default_priors: dict[str, ParamSpec] = field(default_factory=dict)
@@ -469,3 +331,46 @@ class ModelConfig:
backend: Literal["jax", "pytensor"] | None = None
rv: RandomVariable | None = None
extra_fields: list[str] | None = None
+
+
+def _normalize_model_config_with_choices(
+ model_config: ModelConfig | dict[str, Any],
+ choices: list[int] | tuple[int, ...] | None,
+) -> "ModelConfig":
+ """Normalize a user-supplied model_config and apply choices.
+
+ Returns a fresh :class:`ModelConfig` instance and does not mutate the
+ caller's objects. If both ``model_config`` and ``choices`` are provided
+ and ``model_config`` already contains ``choices``, the value from
+ ``model_config`` wins (and a log entry is emitted).
+ """
+ # Normalize input to a mutable dict so we can coerce and avoid mutating
+ # the caller's objects. Build a fresh ModelConfig from that dict.
+ if isinstance(model_config, ModelConfig):
+ mc: dict[str, Any] = {
+ k: getattr(model_config, k) for k in model_config.__dataclass_fields__
+ }
+ else:
+ mc = model_config.copy()
+
+ # Coerce any existing choices on the input to a tuple for immutability
+ if mc.get("choices") is not None:
+ mc["choices"] = tuple(mc["choices"])
+
+ # If caller didn't provide an explicit `choices` argument, return the
+ # normalized ModelConfig built from the input (fresh instance).
+ if choices is None:
+ return ModelConfig(**{k: v for k, v in mc.items() if v is not None})
+
+ # Caller provided choices; prefer the one embedded in model_config if
+ # present, otherwise apply the provided value (coerced to tuple).
+ if mc.get("choices") is not None:
+ _logger.info(
+ "choices list provided in both model_config and "
+ "as an argument directly. Using the one provided in "
+ "model_config. We recommend providing choices in model_config."
+ )
+ return ModelConfig(**{k: v for k, v in mc.items() if v is not None})
+
+ mc["choices"] = tuple(choices)
+ return ModelConfig(**{k: v for k, v in mc.items() if v is not None})
diff --git a/src/hssm/data_validator.py b/src/hssm/data_validator.py
index 01a6373ee..4f6871c2e 100644
--- a/src/hssm/data_validator.py
+++ b/src/hssm/data_validator.py
@@ -6,8 +6,6 @@
import numpy as np
import pandas as pd
-from hssm.defaults import MissingDataNetwork
-
_logger = logging.getLogger("hssm")
@@ -103,55 +101,6 @@ def _post_check_data_sanity(self):
# remaining check on missing data
# which are coming AFTER the data validation
# in the HSSM class, into this function?
- def _handle_missing_data_and_deadline(self):
- """Handle missing data and deadline."""
- if not self.missing_data and not self.deadline:
- # In the case of choice only model, we don't need to do anything with the
- # data.
- if self.is_choice_only:
- return
- # In the case where missing_data is set to False, we need to drop the
- # cases where rt = na_value
- if pd.isna(self.missing_data_value):
- na_dropped = self.data.dropna(subset=["rt"])
- else:
- na_dropped = self.data.loc[
- self.data["rt"] != self.missing_data_value, :
- ]
-
- if len(na_dropped) != len(self.data):
- warnings.warn(
- "`missing_data` is set to False, "
- + "but you have missing data in your dataset. "
- + "Missing data will be dropped.",
- stacklevel=2,
- )
- self.data = na_dropped
-
- elif self.missing_data and not self.deadline:
- # In the case where missing_data is set to True, we need to replace the
- # missing data with a specified na_value
-
- # Create a shallow copy to avoid modifying the original dataframe
- if pd.isna(self.missing_data_value):
- self.data["rt"] = self.data["rt"].fillna(-999.0)
- else:
- self.data["rt"] = self.data["rt"].replace(
- self.missing_data_value, -999.0
- )
-
- else: # deadline = True
- if self.deadline_name not in self.data.columns:
- raise ValueError(
- "You have specified that your data has deadline, but "
- + f"`{self.deadline_name}` is not found in your dataset."
- )
- else:
- self.data.loc[:, "rt"] = np.where(
- self.data["rt"] < self.data[self.deadline_name],
- self.data["rt"],
- -999.0,
- )
def _update_extra_fields(self, new_data: pd.DataFrame | None = None):
"""Update the extra fields data in self.model_distribution.
@@ -174,45 +123,6 @@ def _update_extra_fields(self, new_data: pd.DataFrame | None = None):
new_data[field].values for field in self.extra_fields
]
- @staticmethod
- def _set_missing_data_and_deadline(
- missing_data: bool, deadline: bool, data: pd.DataFrame
- ) -> MissingDataNetwork:
- """Set missing data and deadline."""
- network = MissingDataNetwork.NONE
- if not missing_data:
- return network
- if missing_data and not deadline:
- network = MissingDataNetwork.CPN
- elif missing_data and deadline:
- network = MissingDataNetwork.OPN
- # AF-TODO: GONOGO case not yet correctly implemented
- # else:
- # # TODO: This won't behave as expected yet, GONOGO needs to be split
- # # into a deadline case and a non-deadline case.
- # network = MissingDataNetwork.GONOGO
-
- if np.all(data["rt"] == -999.0):
- if network in [MissingDataNetwork.CPN, MissingDataNetwork.OPN]:
- # AF-TODO: I think we should allow invalid-only datasets.
- raise ValueError(
- "`missing_data` is set to True, but you have no valid data in your "
- "dataset."
- )
- # AF-TODO: This one needs refinement for GONOGO case
- # elif network == MissingDataNetwork.OPN:
- # raise ValueError(
- # "`deadline` is set to True and `missing_data` is set to True, "
- # "but ."
- # )
- # else:
- # raise ValueError(
- # "`missing_data` and `deadline` are both set to True,
- # "but you have "
- # "no missing data and/or no rts exceeding the deadline."
- # )
- return network
-
def _validate_choices(self):
"""
Ensure that `choices` is provided (not None).
diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index 2bf01d630..1e43582d4 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -6,59 +6,40 @@
This file defines the entry class HSSM.
"""
-import datetime
import logging
-import typing
from copy import deepcopy
-from inspect import isclass, signature
+from inspect import isclass
from os import PathLike
-from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Union, cast, get_args
+from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import cast as typing_cast
-import arviz as az
import bambi as bmb
-import cloudpickle as cpickle
-import matplotlib as mpl
-import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
-import pytensor
-import seaborn as sns
-import xarray as xr
-from bambi.model_components import DistributionalComponent
-from bambi.transformations import transformations_namespace
-from pymc.model.transform.conditioning import do
-from ssms.config import model_config as ssms_model_config
from hssm._types import LoglikKind, SupportedModels
-from hssm.data_validator import DataValidatorMixin
from hssm.defaults import (
INITVAL_JITTER_SETTINGS,
- INITVAL_SETTINGS,
MissingDataNetwork,
missing_data_networks_suffix,
)
from hssm.distribution_utils import (
assemble_callables,
make_distribution,
- make_family,
make_hssm_rv,
make_likelihood_callable,
make_missing_data_callable,
)
from hssm.utils import (
- _compute_log_likelihood,
- _get_alias_dict,
- _print_prior,
_rearrange_data,
- _split_array,
)
-from . import plotting
+from .base import HSSMBase
from .config import Config, ModelConfig
-from .param import Params
-from .param import UserParam as Param
+
+if TYPE_CHECKING:
+ from pytensor.graph.op import Op
_logger = logging.getLogger("hssm")
@@ -98,7 +79,7 @@ def __get__(self, instance, owner): # noqa: D105
return self.fget(owner)
-class HSSM(DataValidatorMixin):
+class HSSM(HSSMBase):
"""The basic Hierarchical Sequential Sampling Model (HSSM) class.
Parameters
@@ -124,9 +105,12 @@ class HSSM(DataValidatorMixin):
model. If left unspecified, defaults will be used for all parameter
specifications. Defaults to None.
model_config : optional
- A dictionary containing the model configuration information. If None is
- provided, defaults will be used if there are any. Defaults to None.
- Fields for this `dict` are usually:
+ A :class:`~hssm.config.ModelConfig` instance or a plain ``dict`` holding
+ the model configuration. The constructor passes either form to the
+ ``Config._build_model_config`` factory, which returns a typed
+ :class:`~hssm.config.Config`. If ``None`` is provided, defaults are
+ used where available. Defaults to None. Fields for this config are
+ usually:
- `"list_params"`: a list of parameters indicating the parameters of the model.
The order in which the parameters are specified in this list is important.
@@ -276,1753 +260,51 @@ def __init__(
data: pd.DataFrame,
model: SupportedModels | str = "ddm",
choices: list[int] | None = None,
- include: list[dict[str, Any] | Param] | None = None,
+ include: list[dict[str, Any] | Any] | None = None,
model_config: ModelConfig | dict | None = None,
loglik: (
- str | PathLike | Callable | pytensor.graph.Op | type[pm.Distribution] | None
+ str | PathLike | Callable | pm.Distribution | type[pm.Distribution] | None
) = None,
loglik_kind: LoglikKind | None = None,
p_outlier: float | dict | bmb.Prior | None = 0.05,
- lapse: float | dict | bmb.Prior | None = None,
+ lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0),
global_formula: str | None = None,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = "safe",
extra_namespace: dict[str, Any] | None = None,
missing_data: bool | float = False,
deadline: bool | str = False,
- loglik_missing_data: (
- str | PathLike | Callable | pytensor.graph.Op | None
- ) = None,
+ loglik_missing_data: (str | PathLike | Callable | None) = None,
process_initvals: bool = True,
initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"],
- **kwargs,
- ):
- # Attach arguments to the instance
- # so that we can easily define some
- # methods that need to access these
- # arguments (context: pickling / save - load).
-
- # Define a dict with all call arguments:
- self._init_args = {
- k: v for k, v in locals().items() if k not in ["self", "kwargs"]
- }
- if kwargs:
- self._init_args.update(kwargs)
-
- self.data = data.copy()
- self._inference_obj: az.InferenceData | None = None
- self._initvals: dict[str, Any] = {}
- self.initval_jitter = initval_jitter
- self._inference_obj_vi: pm.Approximation | None = None
- self._vi_approx = None
- self._map_dict = None
- self.global_formula = global_formula
-
- self.link_settings = link_settings
- self.prior_settings = prior_settings
-
- self.missing_data_value = -999.0
-
- additional_namespace = transformations_namespace.copy()
- if extra_namespace is not None:
- additional_namespace.update(extra_namespace)
- self.additional_namespace = additional_namespace
-
- # Construct a model_config from defaults
- self.model_config = Config.from_defaults(model, loglik_kind)
- # Update defaults with user-provided config, if any
- if model_config is not None:
- if isinstance(model_config, dict):
- if "choices" not in model_config:
- if choices is not None:
- model_config["choices"] = tuple(choices)
- else:
- if choices is not None:
- _logger.info(
- "choices list provided in both model_config and "
- "as an argument directly."
- " Using the one provided in model_config. \n"
- "We recommend providing choices in model_config."
- )
- elif isinstance(model_config, ModelConfig):
- if model_config.choices is None:
- if choices is not None:
- model_config.choices = tuple(choices)
- else:
- if choices is not None:
- _logger.info(
- "choices list provided in both model_config and "
- "as an argument directly."
- " Using the one provided in model_config. \n"
- "We recommend providing choices in model_config."
- )
-
- self.model_config.update_config(
- model_config
- if isinstance(model_config, ModelConfig)
- else ModelConfig(**model_config) # also serves as dict validation
- )
- else:
- # Model config is not provided, but at this point was constructed from
- # defaults.
- if model not in typing.get_args(SupportedModels):
- # TODO: ideally use self.supported_models above but mypy doesn't like it
- if choices is not None:
- self.model_config.update_choices(choices)
- elif model in ssms_model_config:
- self.model_config.update_choices(
- ssms_model_config[model]["choices"]
- )
- _logger.info(
- "choices argument passed as None, "
- "but found %s in ssms-simulators. "
- "Using choices, from ssm-simulators configs: %s",
- model,
- ssms_model_config[model]["choices"],
- )
- else:
- # Model config already constructed from defaults, and model string is
- # in SupportedModels. So we are guaranteed that choices are in
- # self.model_config already.
-
- if choices is not None:
- _logger.info(
- "Model string is in SupportedModels."
- " Ignoring choices arguments."
- )
-
- # Update loglik with user-provided value
- self.model_config.update_loglik(loglik)
- # Ensure that all required fields are valid
- self.model_config.validate()
-
- # Set up shortcuts so old code will work
- # TODO: add to HSSMBase
- self.response = self.model_config.response[:]
- self.list_params = self.model_config.list_params
- self.choices = self.model_config.choices
- self.model_name = self.model_config.model_name
- self.loglik = self.model_config.loglik
- self.loglik_kind = self.model_config.loglik_kind
- self.extra_fields = self.model_config.extra_fields
-
- # TODO: add to HSSMBase
- self.response = cast("list[str]", self.response)
- self.is_choice_only: bool = self.model_config.is_choice_only
-
- if self.choices is None:
- raise ValueError(
- "`choices` must be provided either in `model_config` or as an argument."
- )
-
- self.n_choices = len(self.choices)
-
- self._validate_choices()
- self._pre_check_data_sanity()
-
- # Process missing data setting
- # AF-TODO: Could be a function in data validator?
- # TODO: Move to the MissingDataMixin class when we have it
- if self.is_choice_only and missing_data is not False:
- raise ValueError("Choice-only models cannot have missing data.")
-
- if not self.is_choice_only:
- if isinstance(missing_data, float):
- if not ((self.data.rt == missing_data).any()):
- raise ValueError(
- f"missing_data argument is provided as a float {missing_data}, "
- f"However, you have no RTs of {missing_data} in your dataset!"
- )
- else:
- self.missing_data = True
- self.missing_data_value = missing_data
- elif isinstance(missing_data, bool):
- if missing_data and (not (self.data.rt == -999.0).any()):
- raise ValueError(
- "missing_data argument is provided as True, "
- " so RTs of -999.0 are treated as missing. \n"
- "However, you have no RTs of -999.0 in your dataset!"
- )
- elif (not missing_data) and (self.data.rt == -999.0).any():
- # self.missing_data = True
- raise ValueError(
- "Missing data provided as False. \n"
- "However, you have RTs of -999.0 in your dataset!"
- )
- else:
- self.missing_data = missing_data
- else:
- raise ValueError(
- "missing_data argument must be a bool or a float! \n"
- f"You provided: {type(missing_data)}"
- )
- else:
- self.missing_data = False
-
- if isinstance(deadline, str):
- self.deadline = True
- self.deadline_name = deadline
- else:
- self.deadline = deadline
- self.deadline_name = "deadline"
-
- if (
- not self.missing_data and not self.deadline
- ) and loglik_missing_data is not None:
- raise ValueError(
- "You have specified a loglik_missing_data function, but you have not "
- + "set the missing_data or deadline flag to True."
- )
- self.loglik_missing_data = loglik_missing_data
+ **kwargs: Any,
+ ) -> None:
+ # ===== save/load serialisation =====
+ self._init_args = self._store_init_args(locals(), kwargs)
- # Update data based on missing_data and deadline
- self._handle_missing_data_and_deadline()
- # Set self.missing_data_network based on `missing_data` and `deadline`
- self.missing_data_network = self._set_missing_data_and_deadline(
- self.missing_data, self.deadline, self.data
+ # Build typed Config via factory
+ config = Config._build_model_config(
+ model, loglik_kind, model_config, choices, loglik
)
- if self.deadline:
- # self.response is a tuple (from Config); use concatenation.
- self.response.append(self.deadline_name)
-
- # Process lapse distribution
- self.has_lapse = p_outlier is not None and p_outlier != 0
- self._check_lapse(lapse)
-
- # Process all parameters
- self.params = Params.from_user_specs(
- model=self,
- include=[] if include is None else include,
- kwargs=kwargs,
+ super().__init__(
+ data=data,
+ model_config=config,
+ include=include,
p_outlier=p_outlier,
- )
-
- self._parent = self.params.parent
- self._parent_param = self.params.parent_param
-
- self._validate_fixed_vectors()
- self.formula, self.priors, self.link = self.params.parse_bambi(model=self)
-
- # For parameters that have a regression backend, apply bounds at the likelihood
- # level to ensure that the samples that are out of bounds
- # are discarded (replaced with a large negative value).
- self.bounds = {
- name: param.bounds
- for name, param in self.params.items()
- if param.is_regression and param.bounds is not None
- }
-
- # Set p_outlier and lapse
- self.p_outlier = self.params.get("p_outlier")
-
- self._post_check_data_sanity()
-
- self.model_distribution = self._make_model_distribution()
-
- self.family = make_family(
- self.model_distribution,
- self.list_params,
- self.link,
- self._parent,
- )
-
- self.model = bmb.Model(
- self.formula,
- data=self.data,
- family=self.family,
- priors=self.priors, # center_predictors=False
- extra_namespace=self.additional_namespace,
+ lapse=lapse,
+ global_formula=global_formula,
+ link_settings=link_settings,
+ prior_settings=prior_settings,
+ extra_namespace=extra_namespace,
+ missing_data=missing_data,
+ deadline=deadline,
+ loglik_missing_data=loglik_missing_data,
+ process_initvals=process_initvals,
+ initval_jitter=initval_jitter,
**kwargs,
)
- self._aliases = _get_alias_dict(
- self.model, self._parent_param, self.response_c, self.response_str
- )
- self.set_alias(self._aliases)
- self.model.build()
-
- # Bambi >= 0.17 declares dims=("__obs__",) for intercept-only
- # deterministics that actually have shape (1,). This causes an
- # xarray CoordinateValidationError during pm.sample() when ArviZ
- # tries to create a DataArray with mismatched dimension sizes.
- # Fix by removing the dims declaration for these deterministics.
- self._fix_scalar_deterministic_dims()
-
- if process_initvals:
- self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
- if self.initval_jitter > 0:
- self._jitter_initvals(
- jitter_epsilon=self.initval_jitter,
- vector_only=True,
- )
-
- # Make sure we reset rvs_to_initial_values --> Only None's
- # Otherwise PyMC barks at us when asking to compute likelihoods
- self.pymc_model.rvs_to_initial_values.update(
- {key_: None for key_ in self.pymc_model.rvs_to_initial_values.keys()}
- )
- _logger.info("Model initialized successfully.")
-
- def _fix_scalar_deterministic_dims(self) -> None:
- """Fix dims metadata for scalar deterministics.
-
- Bambi >= 0.17 returns shape ``(1,)`` for intercept-only
- deterministics but still declares ``dims=("__obs__",)``. This causes
- an xarray ``CoordinateValidationError`` during ``pm.sample()`` because
- the ``__obs__`` coordinate has ``n_obs`` entries. Removing the dims
- declaration for these variables lets ArviZ handle them as
- un-dimensioned arrays, avoiding the conflict.
- """
- n_obs = len(self.data)
- dims_dict = self.pymc_model.named_vars_to_dims
- for det in self.pymc_model.deterministics:
- if det.name not in dims_dict:
- continue
- dims = dims_dict[det.name]
- if "__obs__" in dims:
- # Check static shape: if it doesn't match n_obs, remove dims
- try:
- shape_0 = det.type.shape[0]
- except (IndexError, TypeError):
- continue
- if shape_0 is not None and shape_0 != n_obs:
- del dims_dict[det.name]
-
- def _validate_fixed_vectors(self) -> None:
- """Validate that fixed-vector parameters have the correct length.
-
- Fixed-vector parameters (``prior=np.ndarray``) bypass Bambi's formula
- system entirely --- they are passed as a scalar ``0.0`` placeholder to
- Bambi, and the real vector is substituted inside
- ``HSSMDistribution.logp()`` (see ``dist.py``). Because this
- substitution is invisible to Bambi, we must validate the vector length
- against ``len(self.data)`` up front to catch shape mismatches early.
- """
- for name, param in self.params.items():
- if isinstance(param.prior, np.ndarray):
- if len(param.prior) != len(self.data):
- raise ValueError(
- f"Fixed vector for parameter '{name}' has length "
- f"{len(param.prior)}, but data has {len(self.data)} rows."
- )
-
- @classproperty
- def supported_models(cls) -> tuple[SupportedModels, ...]:
- """Get a tuple of all supported models.
-
- Returns
- -------
- tuple[SupportedModels, ...]
- A tuple containing all supported model names.
- """
- return get_args(SupportedModels)
-
- @classmethod
- def _store_init_args(cls, *args, **kwargs):
- """Store initialization arguments using signature binding."""
- sig = signature(cls.__init__)
- bound_args = sig.bind(*args, **kwargs)
- bound_args.apply_defaults()
- return {k: v for k, v in bound_args.arguments.items() if k != "self"}
-
- def find_MAP(self, **kwargs):
- """Perform Maximum A Posteriori estimation.
-
- Returns
- -------
- dict
- A dictionary containing the MAP estimates of the model parameters.
- """
- self._map_dict = pm.find_MAP(model=self.pymc_model, **kwargs)
- return self._map_dict
-
- def sample(
- self,
- sampler: Literal["pymc", "numpyro", "blackjax", "nutpie", "laplace"]
- | None = None,
- init: str | None = None,
- initvals: str | dict | None = None,
- include_response_params: bool = False,
- **kwargs,
- ) -> az.InferenceData | pm.Approximation:
- """Perform sampling using the `fit` method via bambi.Model.
-
- Parameters
- ----------
- sampler: optional
- The sampler to use. Can be one of "pymc", "numpyro",
- "blackjax", "nutpie", or "laplace". If using `blackbox` likelihoods,
- this cannot be "numpyro", "blackjax", or "nutpie". By default it is None,
- and sampler will automatically be chosen: when the model uses the
- `approx_differentiable` likelihood, and `jax` backend, "numpyro" will
- be used. Otherwise, "pymc" (the default PyMC NUTS sampler) will be used.
-
- Note that the old sampler names such as "mcmc", "nuts_numpyro",
- "nuts_blackjax" will be deprecated and removed in future releases. A warning
- will be raised if any of these old names are used.
- init: optional
- Initialization method to use for the sampler. If any of the NUTS samplers
- is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`.
- initvals: optional
- Pass initial values to the sampler. This can be a dictionary of initial
- values for parameters of the model, or a string "map" to use initialization
- at the MAP estimate. If "map" is used, the MAP estimate will be computed if
- not already attached to the base class from prior call to 'find_MAP'.
- include_response_params: optional
- Include parameters of the response distribution in the output. These usually
- take more space than other parameters as there's one of them per
- observation. Defaults to False.
- kwargs
- Other arguments passed to bmb.Model.fit(). Please see [here]
- (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
- for full documentation.
-
- Returns
- -------
- az.InferenceData | pm.Approximation
- A reference to the `model.traces` object, which stores the traces of the
- last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData`
- instance if `sampler` is `"pymc"` (default), `"numpyro"`,
- `"blackjax"` or "`laplace".
- """
- # If initvals are None (default)
- # we skip processing initvals here.
- if sampler in _new_sampler_mapping:
- _logger.warning(
- f"Sampler '{sampler}' is deprecated. "
- "Please use the new sampler names: "
- "'pymc', 'numpyro', 'blackjax', 'nutpie', or 'laplace'."
- )
- sampler = _new_sampler_mapping[sampler] # type: ignore
-
- if sampler == "vi":
- raise ValueError(
- "VI is not supported via the sample() method. "
- "Please use the vi() method instead."
- )
-
- if initvals is not None:
- if isinstance(initvals, dict):
- kwargs["initvals"] = initvals
- else:
- if isinstance(initvals, str):
- if initvals == "map":
- if self._map_dict is None:
- _logger.info(
- "initvals='map' but no map"
- "estimate precomputed. \n"
- "Running map estimation first..."
- )
- self.find_MAP()
- kwargs["initvals"] = self._map_dict
- else:
- kwargs["initvals"] = self._map_dict
- else:
- raise ValueError(
- "initvals argument must be a dictionary or 'map'"
- " to use the MAP estimate."
- )
- else:
- kwargs["initvals"] = self._initvals
- _logger.info("Using default initvals. \n")
-
- if sampler is None:
- if (
- self.loglik_kind == "approx_differentiable"
- and self.model_config.backend == "jax"
- ):
- sampler = "numpyro"
- else:
- sampler = "pymc"
-
- if self.loglik_kind == "blackbox":
- if sampler in ["blackjax", "numpyro", "nutpie"]:
- raise ValueError(
- f"{sampler} sampler does not work with blackbox likelihoods."
- )
-
- if "step" not in kwargs:
- kwargs |= {"step": pm.Slice(model=self.pymc_model)}
-
- if (
- self.loglik_kind == "approx_differentiable"
- and self.model_config.backend == "jax"
- and sampler == "pymc"
- and kwargs.get("cores", None) != 1
- ):
- _logger.warning(
- "Parallel sampling might not work with `jax` backend and the PyMC NUTS "
- + "sampler on some platforms. Please consider using `numpyro`, "
- + "`blackjax`, or `nutpie` sampler if that is a problem."
- )
-
- if self._check_extra_fields():
- self._update_extra_fields()
-
- if init is None:
- if sampler in ["pymc", "numpyro", "blackjax", "nutpie"]:
- init = "adapt_diag"
- else:
- init = "auto"
-
- # If sampler is finally `numpyro` make sure
- # the jitter argument is set to False
- if sampler == "numpyro":
- if "nuts_sampler_kwargs" in kwargs:
- if kwargs["nuts_sampler_kwargs"].get("jitter"):
- _logger.warning(
- "The jitter argument is set to True. "
- + "This argument is not supported "
- + "by the numpyro backend. "
- + "The jitter argument will be set to False."
- )
- kwargs["nuts_sampler_kwargs"]["jitter"] = False
- else:
- kwargs["nuts_sampler_kwargs"] = {"jitter": False}
-
- if sampler != "pymc" and "step" in kwargs:
- raise ValueError(
- "`step` samplers (enabled by the `step` argument) are only supported "
- "by the `pymc` sampler."
- )
-
- if self._inference_obj is not None:
- _logger.warning(
- "The model has already been sampled. Overwriting the previous "
- + "inference object. Any previous reference to the inference object "
- + "will still point to the old object."
- )
-
- # Define whether likelihood should be computed
- compute_likelihood = True
- if "idata_kwargs" in kwargs:
- if "log_likelihood" in kwargs["idata_kwargs"]:
- compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True)
-
- omit_offsets = kwargs.pop("omit_offsets", False)
- self._inference_obj = self.model.fit(
- inference_method=sampler,
- init=init,
- include_response_params=include_response_params,
- omit_offsets=omit_offsets,
- **kwargs,
- )
-
- # Separate out log likelihood computation
- if compute_likelihood:
- self.log_likelihood(self._inference_obj, inplace=True)
-
- # Subset data vars in posterior
- self._clean_posterior_group(idata=self._inference_obj)
- return self.traces
-
- def vi(
- self,
- method: str = "advi",
- niter: int = 10000,
- draws: int = 1000,
- return_idata: bool = True,
- ignore_mcmc_start_point_defaults=False,
- **vi_kwargs,
- ) -> pm.Approximation | az.InferenceData:
- """Perform Variational Inference.
-
- Parameters
- ----------
- niter : int
- The number of iterations to run the VI algorithm. Defaults to 3000.
- method : str
- The method to use for VI. Can be one of "advi" or "fullrank_advi", "svgd",
- "asvgd".Defaults to "advi".
- draws : int
- The number of samples to draw from the posterior distribution.
- Defaults to 1000.
- return_idata : bool
- If True, returns an InferenceData object. Otherwise, returns the
- approximation object directly. Defaults to True.
-
- Returns
- -------
- pm.Approximation or az.InferenceData: The mean field approximation object.
- """
- if self.loglik_kind == "analytical":
- _logger.warning(
- "VI is not recommended for the analytical likelihood,"
- " since gradients can be brittle."
- )
- elif self.loglik_kind == "blackbox":
- raise ValueError(
- "VI is not supported for blackbox likelihoods, "
- " since likelihood gradients are needed!"
- )
-
- if ("start" not in vi_kwargs) and not ignore_mcmc_start_point_defaults:
- _logger.info("Using MCMC starting point defaults.")
- vi_kwargs["start"] = self._initvals
-
- # Run variational inference directly from pymc model
- with self.pymc_model:
- self._vi_approx = pm.fit(n=niter, method=method, **vi_kwargs)
-
- # Sample from the approximate posterior
- if self._vi_approx is not None:
- self._inference_obj_vi = self._vi_approx.sample(draws)
-
- # Post-processing
- self._clean_posterior_group(idata=self._inference_obj_vi)
-
- # Return the InferenceData object if return_idata is True
- if return_idata:
- return self._inference_obj_vi
- # Otherwise return the appromation object directly
- return self.vi_approx
-
- def _clean_posterior_group(self, idata: az.InferenceData | None = None):
- """Clean up the posterior group of the InferenceData object.
-
- Parameters
- ----------
- idata : az.InferenceData
- The InferenceData object to clean up. If None, the last InferenceData object
- will be used.
- """
- # # Logic behind which variables to keep:
- # # We essentially want to get rid of
- # # all the trial-wise variables.
-
- # # We drop all distributional components, IF they are deterministics
- # # (in which case they will be trial wise systematically)
- # # and we keep distributional components, IF they are
- # # basic random-variabels (in which case they should never
- # # appear trial-wise).
- if idata is None:
- raise ValueError(
- "The InferenceData object is None. Cannot clean up the posterior group."
- )
- elif not hasattr(idata, "posterior"):
- raise ValueError(
- "The InferenceData object does not have a posterior group. "
- + "Cannot clean up the posterior group."
- )
-
- vars_to_keep = set(idata["posterior"].data_vars.keys()).difference(
- set(
- key_
- for key_ in self.model.distributional_components.keys()
- if key_ in [var_.name for var_ in self.pymc_model.deterministics]
- )
- )
- vars_to_keep_clean = [
- var_
- for var_ in vars_to_keep
- if isinstance(var_, str) and "_mean" not in var_
- ]
-
- setattr(
- idata,
- "posterior",
- idata["posterior"][vars_to_keep_clean],
- )
-
- def log_likelihood(
- self,
- idata: az.InferenceData | None = None,
- data: pd.DataFrame | None = None,
- inplace: bool = True,
- keep_likelihood_params: bool = False,
- ) -> az.InferenceData | None:
- """Compute the log likelihood of the model.
-
- Parameters
- ----------
- idata : optional
- The `InferenceData` object returned by `HSSM.sample()`. If not provided,
- data : optional
- A pandas DataFrame with values for the predictors that are used to obtain
- out-of-sample predictions. If omitted, the original dataset is used.
- inplace : optional
- If `True` will modify idata in-place and append a `log_likelihood` group to
- `idata`. Otherwise, it will return a copy of idata with the predictions
- added, by default True.
- keep_likelihood_params : optional
- If `True`, the trial wise likelihood parameters that are computed
- on route to getting the log likelihood are kept in the `idata` object.
- Defaults to False. See also the method `add_likelihood_parameters_to_idata`.
-
- Returns
- -------
- az.InferenceData | None
- InferenceData or None
- """
- if self._inference_obj is None and idata is None:
- raise ValueError(
- "Neither has the model been sampled yet nor"
- + " an idata object has been provided."
- )
-
- if idata is None:
- if self._inference_obj is None:
- raise ValueError(
- "The model has not been sampled yet. "
- + "Please provide an idata object."
- )
- else:
- idata = self._inference_obj
-
- # Actual likelihood computation
- idata = _compute_log_likelihood(self.model, idata, data, inplace)
-
- # clean up posterior:
- if not keep_likelihood_params:
- self._clean_posterior_group(idata=idata)
-
- if inplace:
- return None
- else:
- return idata
-
- def add_likelihood_parameters_to_idata(
- self,
- idata: az.InferenceData | None = None,
- inplace: bool = False,
- ) -> az.InferenceData | None:
- """Add likelihood parameters to the InferenceData object.
-
- Parameters
- ----------
- idata : az.InferenceData
- The InferenceData object returned by HSSM.sample().
- inplace : bool
- If True, the likelihood parameters are added to idata in-place. Otherwise,
- a copy of idata with the likelihood parameters added is returned.
- Defaults to False.
-
- Returns
- -------
- az.InferenceData | None
- InferenceData or None
- """
- if idata is None:
- if self._inference_obj is None:
- raise ValueError("No idata provided and model not yet sampled!")
- else:
- idata = self.model._compute_likelihood_params( # pylint: disable=protected-access
- deepcopy(self._inference_obj)
- if not inplace
- else self._inference_obj
- )
- else:
- idata = self.model._compute_likelihood_params( # pylint: disable=protected-access
- deepcopy(idata) if not inplace else idata
- )
- return idata
-
- def sample_posterior_predictive(
- self,
- idata: az.InferenceData | None = None,
- data: pd.DataFrame | None = None,
- inplace: bool = True,
- include_group_specific: bool = True,
- kind: Literal["response", "response_params"] = "response",
- draws: int | float | list[int] | np.ndarray | None = None,
- safe_mode: bool = True,
- ) -> az.InferenceData | None:
- """Perform posterior predictive sampling from the HSSM model.
-
- Parameters
- ----------
- idata : optional
- The `InferenceData` object returned by `HSSM.sample()`. If not provided,
- the `InferenceData` from the last time `sample()` is called will be used.
- data : optional
- An optional data frame with values for the predictors that are used to
- obtain out-of-sample predictions. If omitted, the original dataset is used.
- inplace : optional
- If `True` will modify idata in-place and append a `posterior_predictive`
- group to `idata`. Otherwise, it will return a copy of idata with the
- predictions added, by default True.
- include_group_specific : optional
- If `True` will make predictions including the group specific effects.
- Otherwise, predictions are made with common effects only (i.e. group-
- specific are set to zero), by default True.
- kind: optional
- Indicates the type of prediction required. Can be `"response_params"` or
- `"response"`. The first returns draws from the posterior distribution of the
- likelihood parameters, while the latter returns the draws from the posterior
- predictive distribution (i.e. the posterior probability distribution for a
- new observation) in addition to the posterior distribution. Defaults to
- `"response"`.
- draws: optional
- The number of samples to draw from the posterior predictive distribution
- from each chain.
- When it's an integer >= 1, the number of samples to be extracted from the
- `draw` dimension. If this integer is larger than the number of posterior
- samples in each chain, all posterior samples will be used
- in posterior predictive sampling. When a float between 0 and 1, the
- proportion of samples from the draw dimension from each chain to be used in
- posterior predictive sampling. If this proportion is very
- small, at least one sample will be used. When None, all posterior samples
- will be used. Defaults to None.
- safe_mode: bool
- If True, the function will split the draws into chunks of 10 to avoid memory
- issues. Defaults to True.
-
- Raises
- ------
- ValueError
- If the model has not been sampled yet and idata is not provided.
-
- Returns
- -------
- az.InferenceData | None
- InferenceData or None
- """
- if idata is None:
- if self._inference_obj is None:
- raise ValueError(
- "The model has not been sampled yet. "
- + "Please either provide an idata object or sample the model first."
- )
- idata = self._inference_obj
- _logger.info(
- "idata=None, we use the traces assigned to the HSSM object as idata."
- )
-
- if idata is not None:
- if "posterior_predictive" in idata.groups():
- del idata["posterior_predictive"]
- _logger.warning(
- "pre-existing posterior_predictive group deleted from idata. \n"
- )
-
- if self._check_extra_fields(data):
- self._update_extra_fields(data)
-
- if isinstance(draws, np.ndarray):
- draws = draws.astype(int)
- elif isinstance(draws, list):
- draws = np.array(draws).astype(int)
- elif isinstance(draws, int | float):
- draws = np.arange(int(draws))
- elif draws is None:
- draws = idata["posterior"].draw.values
- else:
- raise ValueError(
- "draws must be an integer, " + "a list of integers, or a numpy array."
- )
-
- assert isinstance(draws, np.ndarray)
-
- # Make a copy of idata, set the `posterior` group to be a random sub-sample
- # of the original (draw dimension gets sub-sampled)
-
- idata_copy = idata.copy()
-
- if (draws.shape != idata["posterior"].draw.values.shape) or (
- (draws.shape == idata["posterior"].draw.values.shape)
- and not np.allclose(draws, idata["posterior"].draw.values)
- ):
- # Reassign posterior to sub-sampled version
- setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))
-
- if kind == "response":
- # If we run kind == 'response' we actually run the observation RV
- if safe_mode:
- # safe mode splits the draws into chunks of 10 to avoid
- # memory issues (TODO: Figure out the source of memory issues)
- split_draws = _split_array(
- idata_copy["posterior"].draw.values, divisor=10
- )
-
- posterior_predictive_list = []
- for samples_tmp in split_draws:
- tmp_posterior = idata["posterior"].sel(draw=samples_tmp)
- setattr(idata_copy, "posterior", tmp_posterior)
- self.model.predict(
- idata_copy, kind, data, True, include_group_specific
- )
- posterior_predictive_list.append(idata_copy["posterior_predictive"])
-
- if inplace:
- idata.add_groups(
- posterior_predictive=xr.concat(
- posterior_predictive_list, dim="draw"
- )
- )
- # for inplace, we don't return anything
- return None
- else:
- # Reassign original posterior to idata_copy
- setattr(idata_copy, "posterior", idata["posterior"])
- # Add new posterior predictive group to idata_copy
- del idata_copy["posterior_predictive"]
- idata_copy.add_groups(
- posterior_predictive=xr.concat(
- posterior_predictive_list, dim="draw"
- )
- )
- return idata_copy
- else:
- if inplace:
- # If not safe-mode
- # We call .predict() directly without any
- # chunking of data.
-
- # .predict() is called on the copy of idata
- # since we still subsampled (or assigned) the draws
- self.model.predict(
- idata_copy, kind, data, True, include_group_specific
- )
-
- # posterior predictive group added to idata
- idata.add_groups(
- posterior_predictive=idata_copy["posterior_predictive"]
- )
- # don't return anything if inplace
- return None
- else:
- # Not safe mode and not inplace
- # Function acts as very thin wrapper around
- # .predict(). It just operates on the
- # idata_copy object
- return self.model.predict(
- idata_copy, kind, data, False, include_group_specific
- )
- elif kind == "response_params":
- # If kind == 'response_params', we don't need to run the RV directly,
- # there shouldn't really be any significant memory issues here,
- # we can simply ignore settings, since the computational overhead
- # should be very small --> nudges user towards good outputs.
- _logger.warning(
- "The kind argument is set to 'response_params', but 'draws' argument "
- + "is not None: The draws argument will be ignored!"
- )
- return self.model.predict(
- idata, kind, data, inplace, include_group_specific
- )
- else:
- raise ValueError("`kind` must be either 'response' or 'response_params'.")
-
- def plot_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
- """Produce a posterior predictive plot.
-
- Equivalent to calling `hssm.plotting.plot_predictive()` with the
- model. Please see that function for
- [full documentation][hssm.plotting.plot_predictive].
-
- Returns
- -------
- mpl.axes.Axes | sns.FacetGrid
- The matplotlib axis or seaborn FacetGrid object containing the plot.
- """
- return plotting.plot_predictive(self, **kwargs)
-
- def plot_quantile_probability(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
- """Produce a quantile probability plot.
-
- Equivalent to calling `hssm.plotting.plot_quantile_probability()` with the
- model. Please see that function for
- [full documentation][hssm.plotting.plot_quantile_probability].
-
- Returns
- -------
- mpl.axes.Axes | sns.FacetGrid
- The matplotlib axis or seaborn FacetGrid object containing the plot.
- """
- return plotting.plot_quantile_probability(self, **kwargs)
-
- def predict(self, **kwargs) -> az.InferenceData:
- """Generate samples from the predictive distribution."""
- return self.model.predict(**kwargs)
-
- def sample_do(
- self, params: dict[str, Any], draws: int = 100, return_model=False, **kwargs
- ) -> az.InferenceData | tuple[az.InferenceData, pm.Model]:
- """Generate samples from the predictive distribution using the `do-operator`."""
- do_model = do(self.pymc_model, params)
- do_idata = pm.sample_prior_predictive(model=do_model, draws=draws, **kwargs)
-
- # clean up `rt,response_mean` to `v`
- do_idata = self._drop_parent_str_from_idata(idata=do_idata)
-
- # rename otherwise inconsistent dims and coords
- if "rt,response_extra_dim_0" in do_idata["prior_predictive"].dims:
- setattr(
- do_idata,
- "prior_predictive",
- do_idata["prior_predictive"].rename_dims(
- {"rt,response_extra_dim_0": "rt,response_dim"}
- ),
- )
- if "rt,response_extra_dim_0" in do_idata["prior_predictive"].coords:
- setattr(
- do_idata,
- "prior_predictive",
- do_idata["prior_predictive"].rename_vars(
- name_dict={"rt,response_extra_dim_0": "rt,response_dim"}
- ),
- )
-
- if return_model:
- return do_idata, do_model
- return do_idata
-
- def sample_prior_predictive(
- self,
- draws: int = 500,
- var_names: str | list[str] | None = None,
- omit_offsets: bool = True,
- random_seed: np.random.Generator | None = None,
- ) -> az.InferenceData:
- """Generate samples from the prior predictive distribution.
-
- Parameters
- ----------
- draws
- Number of draws to sample from the prior predictive distribution. Defaults
- to 500.
- var_names
- A list of names of variables for which to compute the prior predictive
- distribution. Defaults to ``None`` which means both observed and unobserved
- RVs.
- omit_offsets
- Whether to omit offset terms. Defaults to ``True``.
- random_seed
- Seed for the random number generator.
-
- Returns
- -------
- az.InferenceData
- ``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and
- ``observed_data``.
- """
- prior_predictive = self.model.prior_predictive(
- draws, var_names, omit_offsets, random_seed
- )
-
- # AF-COMMENT: Not sure if necessary to include the
- # mean prior here (which adds deterministics that
- # could be recomputed elsewhere)
- prior_predictive.add_groups(posterior=prior_predictive.prior)
- # Bambi >= 0.17 renamed kind="mean" to kind="response_params".
- self.model.predict(prior_predictive, kind="response_params", inplace=True)
-
- # clean
- setattr(prior_predictive, "prior", prior_predictive["posterior"])
- del prior_predictive["posterior"]
-
- if self._inference_obj is None:
- self._inference_obj = prior_predictive
- else:
- self._inference_obj.extend(prior_predictive)
-
- # clean up `rt,response_mean` to `v`
- idata = self._drop_parent_str_from_idata(idata=self._inference_obj)
-
- # rename otherwise inconsistent dims and coords
- if "rt,response_extra_dim_0" in idata["prior_predictive"].dims:
- setattr(
- idata,
- "prior_predictive",
- idata["prior_predictive"].rename_dims(
- {"rt,response_extra_dim_0": "rt,response_dim"}
- ),
- )
- if "rt,response_extra_dim_0" in idata["prior_predictive"].coords:
- setattr(
- idata,
- "prior_predictive",
- idata["prior_predictive"].rename_vars(
- name_dict={"rt,response_extra_dim_0": "rt,response_dim"}
- ),
- )
-
- # Update self._inference_obj to match the cleaned idata
- self._inference_obj = idata
- return deepcopy(self._inference_obj)
-
- @property
- def pymc_model(self) -> pm.Model:
- """Provide access to the PyMC model.
-
- Returns
- -------
- pm.Model
- The PyMC model built by bambi
- """
- return self.model.backend.model
-
- def set_alias(self, aliases: dict[str, str | dict]):
- """Set parameter aliases.
-
- Sets the aliases according to the dictionary passed to it and rebuild the
- model.
-
- Parameters
- ----------
- aliases
- A dict specifying the parameter names being aliased and the aliases.
- """
- self.model.set_alias(aliases)
- self.model.build()
-
- # TODO: update this to HSSMBase
- @property
- def response_c(self) -> str:
- """Return the response variable names in c() format.
-
- New in 0.2.12: when model is choice-only and has deadline, the response
- is not in the form of c(...).
- """
- if self.response is None:
- raise ValueError("Response is not defined.")
- if self.is_choice_only and not self.deadline:
- return self.response[0]
- return f"c({', '.join(self.response)})"
-
- @property
- def response_str(self) -> str:
- """Return the response variable names in string format."""
- if self.response is None:
- return ""
- return ",".join(self.response)
-
- # NOTE: can't annotate return type because the graphviz dependency is optional
- def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"):
- """Produce a graphviz Digraph from a built HSSM model.
-
- Requires graphviz, which may be installed most easily with `conda install -c
- conda-forge python-graphviz`. Alternatively, you may install the `graphviz`
- binaries yourself, and then `pip install graphviz` to get the python bindings.
- See http://graphviz.readthedocs.io/en/stable/manual.html for more information.
-
- Parameters
- ----------
- formatting
- One of `"plain"` or `"plain_with_params"`. Defaults to `"plain"`.
- name
- Name of the figure to save. Defaults to `None`, no figure is saved.
- figsize
- Maximum width and height of figure in inches. Defaults to `None`, the
- figure size is set automatically. If defined and the drawing is larger than
- the given size, the drawing is uniformly scaled down so that it fits within
- the given size. Only works if `name` is not `None`.
- dpi
- Dots per inch of the figure to save.
- Defaults to 300. Only works if `name` is not `None`.
- fmt
- Format of the figure to save.
- Defaults to `"png"`. Only works if `name` is not `None`.
-
- Returns
- -------
- graphviz.Graph
- The graph
- """
- graph = self.model.graph(formatting, name, figsize, dpi, fmt)
-
- parent_param = self._parent_param
- if parent_param.is_regression:
- return graph
-
- # Modify the graph
- # 1. Remove all nodes and edges related to `{parent}_mean`:
- graph.body = [
- item for item in graph.body if f"{parent_param.name}_mean" not in item
- ]
- # 2. Add a new edge from parent to response
- graph.edge(parent_param.name, self.response_str)
-
- return graph
-
- def compile_logp(self, keep_transformed: bool = False, **kwargs):
- """Compile the log probability function for the model.
-
- Parameters
- ----------
- keep_transformed : bool, optional
- If True, keeps the transformed variables in the compiled function.
- If False, removes value transforms before compilation.
- Defaults to False.
- **kwargs
- Additional keyword arguments passed to PyMC's compile_logp:
- - vars: List of variables. Defaults to None (all variables).
- - jacobian: Whether to include log(|det(dP/dQ)|) term for
- transformed variables. Defaults to True.
- - sum: Whether to sum all terms instead of returning a vector.
- Defaults to True.
-
- Returns
- -------
- callable
- A compiled function that computes the model log probability.
- """
- if keep_transformed:
- return self.pymc_model.compile_logp(
- vars=kwargs.get("vars", None),
- jacobian=kwargs.get("jacobian", True),
- sum=kwargs.get("sum", True),
- )
- else:
- new_model = pm.model.transform.conditioning.remove_value_transforms(
- self.pymc_model
- )
- return new_model.compile_logp(
- vars=kwargs.get("vars", None),
- jacobian=kwargs.get("jacobian", True),
- sum=kwargs.get("sum", True),
- )
-
- def plot_trace(
- self,
- data: az.InferenceData | None = None,
- include_deterministic: bool = False,
- tight_layout: bool = True,
- **kwargs,
- ) -> None:
- """Generate trace plot with ArviZ but with additional convenience features.
-
- This is a simple wrapper for the az.plot_trace() function. By default, it
- filters out the deterministic values from the plot. Please see the
- [arviz documentation]
- (https://arviz-devs.github.io/arviz/api/generated/arviz.plot_trace.html)
- for additional parameters that can be specified.
-
- Parameters
- ----------
- data : optional
- An ArviZ InferenceData object. If None, the traces stored in the model will
- be used.
- include_deterministic : optional
- Whether to include deterministic variables in the plot. Defaults to False.
- Note that if `include_deterministic` is set to False and `var_names` is
- provided, the `var_names` provided will be modified to also exclude the
- deterministic values. If this is not desirable, set
- `include_deterministic` to True.
- tight_layout : optional
- Whether to call plt.tight_layout() after plotting. Defaults to True.
- """
- data = data or self.traces
- if not isinstance(data, az.InferenceData):
- raise TypeError("data must be an InferenceData object.")
-
- if not include_deterministic:
- var_names = list(
- set([var.name for var in self.pymc_model.free_RVs]).intersection(
- set(list(data["posterior"].data_vars.keys()))
- )
- )
- # var_names = self._get_deterministic_var_names(data)
- if var_names:
- if "var_names" in kwargs:
- if isinstance(kwargs["var_names"], str):
- if kwargs["var_names"] not in var_names:
- var_names.append(kwargs["var_names"])
- kwargs["var_names"] = var_names
- elif isinstance(kwargs["var_names"], list):
- kwargs["var_names"] = list(
- set(var_names) | set(kwargs["var_names"])
- )
- elif kwargs["var_names"] is None:
- kwargs["var_names"] = var_names
- else:
- raise ValueError(
- "`var_names` must be a string, a list of strings, or None."
- )
- else:
- kwargs["var_names"] = var_names
- az.plot_trace(data, **kwargs)
-
- if tight_layout:
- plt.tight_layout()
-
- def summary(
- self,
- data: az.InferenceData | None = None,
- include_deterministic: bool = False,
- **kwargs,
- ) -> pd.DataFrame | xr.Dataset:
- """Produce a summary table with ArviZ but with additional convenience features.
-
- This is a simple wrapper for the az.summary() function. By default, it
- filters out the deterministic values from the summary. Please see the
- [arviz documentation]
- (https://arviz-devs.github.io/arviz/api/generated/arviz.summary.html)
- for additional parameters that can be specified.
-
- Parameters
- ----------
- data
- An ArviZ InferenceData object. If None, the traces stored in the model will
- be used.
- include_deterministic : optional
- Whether to include deterministic variables in the summary. Defaults to False.
- Note that if include_deterministic is set to False and `var_names` is
- provided, the `var_names` provided will be modified to also exclude the
- deterministic values. If this is not desirable, set
- `include_deterministic` to True.
-
- Returns
- -------
- pd.DataFrame | xr.Dataset
- A pandas DataFrame or xarray Dataset containing the summary statistics.
- """
- data = data or self.traces
- if not isinstance(data, az.InferenceData):
- raise TypeError("data must be an InferenceData object.")
-
- if not include_deterministic:
- var_names = list(
- set([var.name for var in self.pymc_model.free_RVs]).intersection(
- set(list(data["posterior"].data_vars.keys()))
- )
- )
- # var_names = self._get_deterministic_var_names(data)
- if var_names:
- kwargs["var_names"] = list(set(var_names + kwargs.get("var_names", [])))
- return az.summary(data, **kwargs)
-
- def initial_point(self, transformed: bool = False) -> dict[str, np.ndarray]:
- """Compute the initial point of the model.
-
- This is a slightly altered version of pm.initial_point.initial_point().
-
- Parameters
- ----------
- transformed : bool, optional
- If True, return the initial point in transformed space.
-
- Returns
- -------
- dict
- A dictionary containing the initial point of the model parameters.
- """
- fn = pm.initial_point.make_initial_point_fn(
- model=self.pymc_model, return_transformed=transformed
- )
- return pm.model.Point(fn(None), model=self.pymc_model)
-
- def restore_traces(
- self, traces: az.InferenceData | pm.Approximation | str | PathLike
- ) -> None:
- """Restore traces from an InferenceData object or a .netcdf file.
-
- Parameters
- ----------
- traces
- An InferenceData object or a path to a file containing the traces.
- """
- if isinstance(traces, pm.Approximation):
- self._inference_obj_vi = traces
- return
-
- if isinstance(traces, (str, PathLike)):
- traces = az.from_netcdf(traces)
- self._inference_obj = cast("az.InferenceData", traces)
-
- def restore_vi_traces(
- self, traces: az.InferenceData | pm.Approximation | str | PathLike
- ) -> None:
- """Restore VI traces from an InferenceData object or a .netcdf file.
-
- Parameters
- ----------
- traces
- An InferenceData object or a path to a file containing the VI traces.
- """
- if isinstance(traces, pm.Approximation):
- self._inference_obj_vi = traces
- return
-
- if isinstance(traces, (str, PathLike)):
- traces = az.from_netcdf(traces)
- self._inference_obj_vi = cast("az.InferenceData", traces)
-
- def save_model(
- self,
- model_name: str | None = None,
- allow_absolute_base_path: bool = False,
- base_path: str | Path = "hssm_models",
- save_idata_only: bool = False,
- ) -> None:
- """Save a HSSM model instance and its inference results to disk.
-
- Parameters
- ----------
- model_name : str | None
- Name to use for the saved model files.
- If None, will use model.model_name with timestamp
- allow_absolute_base_path : bool
- Whether to allow absolute paths for base_path.
- Defaults to False for safety.
- base_path : str | Path
- Base directory to save model files in.
- Must be relative path if allow_absolute_base_path=False.
- Defaults to "hssm_models".
- save_idata_only : bool
- If True, only saves inference data (traces), not the model pickle.
- Defaults to False (saves both model and traces).
-
- Raises
- ------
- ValueError
- If base_path is absolute and allow_absolute_base_path=False
- """
- # Convert to Path object for cross-platform compatibility
- base_path = Path(base_path)
-
- # Check if base_path is absolute (works on all platforms)
- if not allow_absolute_base_path and base_path.is_absolute():
- raise ValueError(
- "base_path must be a relative path if allow_absolute_base_path is False"
- )
-
- if model_name is None:
- # Get date string format as suffix to model name
- timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
- model_name = f"{self.model_name}_{timestamp}"
-
- # Sanitize model_name and construct full path
- model_name = model_name.replace(" ", "_")
- model_path = base_path / model_name
- model_path.mkdir(parents=True, exist_ok=True)
-
- # Save model to pickle file
- if not save_idata_only:
- with open(model_path.joinpath("model.pkl"), "wb") as f:
- cpickle.dump(self, f)
-
- # Save traces to netcdf file
- if self._inference_obj is not None:
- az.to_netcdf(self._inference_obj, model_path.joinpath("traces.nc"))
-
- # Save vi_traces to netcdf file
- if self._inference_obj_vi is not None:
- az.to_netcdf(self._inference_obj_vi, model_path.joinpath("vi_traces.nc"))
-
- @classmethod
- def load_model(
- cls, path: Union[str, Path]
- ) -> Union["HSSM", dict[str, Optional[az.InferenceData]]]:
- """Load a HSSM model instance and its inference results from disk.
-
- Parameters
- ----------
- path : str | Path
- Path to the model directory or model.pkl file. If a directory is provided,
- will look for model.pkl, traces.nc and vi_traces.nc files within it.
-
- Returns
- -------
- HSSM
- The loaded HSSM model instance with inference results attached if available.
- """
- # Convert path to Path object
- path = Path(path)
-
- # If path points to a file, assume it's model.pkl
- if path.is_file():
- model_dir = path.parent
- model_path = path
- else:
- # Path points to directory
- model_dir = path
- model_path = model_dir.joinpath("model.pkl")
-
- # check if model_dir exists
- if not model_dir.exists():
- raise FileNotFoundError(f"Model directory {model_dir} does not exist.")
-
- # Check whether model.pkl exists; if not, log this and try to load traces only
- if not model_path.exists():
- _logger.info(
- f"model.pkl file does not exist in {model_dir}. "
- "Attempting to load traces only."
- )
- if (not model_dir.joinpath("traces.nc").exists()) and (
- not model_dir.joinpath("vi_traces.nc").exists()
- ):
- raise FileNotFoundError(f"No traces found in {model_dir}.")
- else:
- idata_dict = cls.load_model_idata(model_dir)
- return idata_dict
- else:
- # Load model from pickle file
- with open(model_path, "rb") as f:
- model = cpickle.load(f)
-
- # Load traces if they exist
- traces_path = model_dir.joinpath("traces.nc")
- if traces_path.exists():
- model.restore_traces(traces_path)
-
- # Load VI traces if they exist
- vi_traces_path = model_dir.joinpath("vi_traces.nc")
- if vi_traces_path.exists():
- model.restore_vi_traces(vi_traces_path)
- return model
-
- @classmethod
- def load_model_idata(cls, path: str | Path) -> dict[str, az.InferenceData | None]:
- """Load the traces from a model directory.
-
- Parameters
- ----------
- path : str | Path
- Path to the model directory containing traces.nc and/or vi_traces.nc files.
-
- Returns
- -------
- dict[str, az.InferenceData | None]
- A dictionary with keys "idata_mcmc" and "idata_vi" containing the traces
- from the model directory. If the traces do not exist, the corresponding
- value will be None.
- """
- idata_dict: dict[str, az.InferenceData | None] = {}
- model_dir = Path(path)
- # check if path exists
- if not model_dir.exists():
- raise FileNotFoundError(f"Model directory {model_dir} does not exist.")
-
- # check if traces.nc exists
- traces_path = model_dir.joinpath("traces.nc")
- if not traces_path.exists():
- _logger.warning(f"traces.nc file does not exist in {model_dir}.")
- idata_dict["idata_mcmc"] = None
- else:
- idata_dict["idata_mcmc"] = az.from_netcdf(traces_path)
-
- # check if vi_traces.nc exists
- vi_traces_path = model_dir.joinpath("vi_traces.nc")
- if not vi_traces_path.exists():
- _logger.warning(f"vi_traces.nc file does not exist in {model_dir}.")
- idata_dict["idata_vi"] = None
- else:
- idata_dict["idata_vi"] = az.from_netcdf(vi_traces_path)
-
- return idata_dict
-
- def __getstate__(self):
- """Get the state of the model for pickling.
-
- This method is called when pickling the model.
- It returns a dictionary containing the constructor
- arguments needed to recreate the model instance.
-
- Returns
- -------
- dict
- A dictionary containing the constructor arguments
- under the key 'constructor_args'.
- """
- state = {"constructor_args": self._init_args}
- return state
-
- def __setstate__(self, state):
- """Set the state of the model when unpickling.
-
- This method is called when unpickling the model. It creates a new instance
- of HSSM using the constructor arguments stored in the state dictionary,
- and copies its attributes to the current instance.
-
- Parameters
- ----------
- state : dict
- A dictionary containing the constructor arguments under the key
- 'constructor_args'.
- """
- new_instance = HSSM(**state["constructor_args"])
- self.__dict__ = new_instance.__dict__
-
- def __repr__(self) -> str:
- """Create a representation of the model."""
- output = [
- "Hierarchical Sequential Sampling Model",
- f"Model: {self.model_name}\n",
- f"Response variable: {self.response_str}",
- f"Likelihood: {self.loglik_kind}",
- f"Observations: {len(self.data)}\n",
- "Parameters:\n",
- ]
-
- for param in self.params.values():
- if param.name == "p_outlier":
- continue
- output.append(f"{param.name}:")
-
- component = self.model.components[param.name]
-
- # Regression case:
- if param.is_regression:
- assert isinstance(component, DistributionalComponent)
- output.append(f" Formula: {param.formula}")
- output.append(" Priors:")
- intercept_term = component.intercept_term
- if intercept_term is not None:
- output.append(_print_prior(intercept_term))
- for _, common_term in component.common_terms.items():
- output.append(_print_prior(common_term))
- for _, group_specific_term in component.group_specific_terms.items():
- output.append(_print_prior(group_specific_term))
- output.append(f" Link: {param.link}")
- # Non-regression case
- else:
- if param.prior is None:
- prior = (
- component.intercept_term.prior
- if param.is_parent
- else component.prior
- )
- else:
- prior = param.prior
- output.append(f" Prior: {prior}")
- output.append(f" Explicit bounds: {param.bounds}")
- output.append(
- " (ignored due to link function)"
- if self.link_settings is not None
- else ""
- )
-
- # TODO: Handle p_outlier regression correctly here.
- if self.p_outlier is not None:
- output.append("")
- output.append(f"Lapse probability: {self.p_outlier.prior}")
- output.append(f"Lapse distribution: {self.lapse}")
-
- return "\n".join(output)
-
- def __str__(self) -> str:
- """Create a string representation of the model."""
- return self.__repr__()
-
- @property
- def traces(self) -> az.InferenceData | pm.Approximation:
- """Return the trace of the model after sampling.
-
- Raises
- ------
- ValueError
- If the model has not been sampled yet.
-
- Returns
- -------
- az.InferenceData | pm.Approximation
- The trace of the model after the last call to `sample()`.
- """
- if not self._inference_obj:
- raise ValueError("Please sample the model first.")
-
- return self._inference_obj
-
- @property
- def vi_idata(self) -> az.InferenceData:
- """Return the InferenceData object produced by variational inference.
-
- Raises
- ------
- ValueError
- If variational inference has not been run yet.
-
- Returns
- -------
- az.InferenceData
- The InferenceData object holding the variational posterior.
- """
- if not self._inference_obj_vi:
- raise ValueError(
- "Please run variational inference first, "
- "no variational posterior attached."
- )
-
- return self._inference_obj_vi
-
- @property
- def vi_approx(self) -> pm.Approximation:
- """Return the variational inference approximation object.
-
- Raises
- ------
- ValueError
- If variational inference has not been run yet.
-
- Returns
- -------
- pm.Approximation
- The variational inference approximation object.
- """
- if not self._vi_approx:
- raise ValueError(
- "Please run variational inference first, "
- "no variational approximation attached."
- )
-
- return self._vi_approx
-
- @property
- def map(self) -> dict:
- """Return the MAP estimates of the model parameters.
-
- Raises
- ------
- ValueError
- If the MAP estimates have not been computed yet.
-
- Returns
- -------
- dict
- A dictionary containing the MAP estimates of the model parameters.
- """
- if not self._map_dict:
- raise ValueError("Please compute map first.")
-
- return self._map_dict
-
- @property
- def initvals(self) -> dict:
- """Return the initial values of the model parameters for sampling.
-
- Returns
- -------
- dict
- A dictionary containing the initial values of the model parameters.
- This dict serves as the default for initial values, and can be passed
- directly to the `.sample()` function.
- """
- if self._initvals == {}:
- self._initvals = self.initial_point()
- return self._initvals
-
- def _check_lapse(self, lapse):
- """Determine if p_outlier and lapse are specified correctly."""
- # Basically, avoid situations where only one of them is specified.
- if self.list_params[-1] != "p_outlier":
- if "p_outlier" in self.list_params:
- raise ValueError(
- "Please do not include 'p_outlier' in `list_params`. "
- + "We automatically append it to `list_params` when `p_outlier` "
- + "parameter is not None"
- )
- if self.has_lapse:
- if lapse is None:
- if self.is_choice_only:
- self.lapse = 1 / self.n_choices
- else:
- self.lapse = bmb.Prior("Uniform", lower=0.0, upper=20.0)
- else:
- self.lapse = lapse
-
- self.list_params.append("p_outlier")
- return
-
- if lapse is not None:
- _logger.warning(
- "You have specified the `lapse` argument to include a lapse "
- + "distribution, but `p_outlier` is set to either 0 or None. "
- + "Your lapse distribution will be ignored."
- )
- self.lapse = None
-
def _make_model_distribution(self) -> type[pm.Distribution]:
"""Make a pm.Distribution for the model."""
### Logic for different types of likelihoods:
@@ -2039,6 +321,15 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
if isclass(self.loglik) and issubclass(self.loglik, pm.Distribution):
return self.loglik
+ # At this point, loglik should not be a type[Distribution] and should be set
+
+ assert self.loglik is not None, "loglik should be set"
+ assert self.loglik_kind is not None, "loglik_kind should be set"
+ assert not (isclass(self.loglik) and issubclass(self.loglik, pm.Distribution))
+ loglik_callable = typing_cast(
+ "Op | Callable[..., Any] | PathLike | str", self.loglik
+ )
+
# params_is_trialwise_base: one entry per model param (excluding
# p_outlier). Used for graph-level broadcasting in logp() and
# make_distribution, where dist_params does not include extra_fields.
@@ -2059,20 +350,20 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
if self.loglik_kind == "approx_differentiable":
if self.model_config.backend == "jax":
likelihood_callable = make_likelihood_callable(
- loglik=self.loglik,
+ loglik=loglik_callable,
loglik_kind="approx_differentiable",
backend="jax",
params_is_reg=params_is_trialwise,
)
else:
likelihood_callable = make_likelihood_callable(
- loglik=self.loglik,
+ loglik=loglik_callable,
loglik_kind="approx_differentiable",
backend=self.model_config.backend,
)
else:
likelihood_callable = make_likelihood_callable(
- loglik=self.loglik,
+ loglik=loglik_callable,
loglik_kind=self.loglik_kind,
backend=self.model_config.backend,
)
@@ -2145,7 +436,7 @@ def dummy_simulator_func(*args, **kwargs):
self.model_config.rv = make_hssm_rv(
dummy_simulator_func,
- list_params=self.list_params,
+ list_params=self.list_params or [],
lapse=self.lapse,
is_choice_only=True,
)
@@ -2159,10 +450,15 @@ def dummy_simulator_func(*args, **kwargs):
if isinstance(param.prior, np.ndarray)
}
+ # Use the typed `model_config` attributes directly
+ _list_params = self.model_config.list_params
+ assert _list_params is not None, "list_params should be set" # for type checker
+ rv_name = getattr(self.model_config, "rv", None) or self.model_config.model_name
+
return make_distribution(
- rv=self.model_config.rv or self.model_name,
+ rv=rv_name,
loglik=self.loglik,
- list_params=self.list_params,
+ list_params=_list_params,
bounds=self.bounds,
lapse=self.lapse,
extra_fields=(
@@ -2175,255 +471,3 @@ def dummy_simulator_func(*args, **kwargs):
# TODO: add to HSSMBase
is_choice_only=self.is_choice_only,
)
-
- def _get_deterministic_var_names(self, idata) -> list[str]:
- """Filter out the deterministic variables in var_names."""
- var_names = [
- f"~{param_name}"
- for param_name, param in self.params.items()
- if (param.is_regression)
- ]
-
- if f"{self._parent}_mean" in idata["posterior"].data_vars:
- var_names.append(f"~{self._parent}_mean")
-
- # Parent parameters (always regression implicitly)
- # which don't have a formula attached
- # should be dropped from var_names, since the actual
- # parent name shows up as a regression.
- if f"{self._parent}" in idata["posterior"].data_vars:
- if self.params[self._parent].formula is None:
- # Drop from var_names
- var_names = [var for var in var_names if var != f"~{self._parent}"]
-
- return var_names
-
- def _drop_parent_str_from_idata(
- self, idata: az.InferenceData | None
- ) -> az.InferenceData:
- """Drop the parent_str variable from an InferenceData object.
-
- Parameters
- ----------
- idata
- The InferenceData object to be modified.
-
- Returns
- -------
- xr.Dataset
- The modified InferenceData object.
- """
- if idata is None:
- raise ValueError("Please provide an InferenceData object.")
- else:
- for group in idata.groups():
- if ("rt,response_mean" in idata[group].data_vars) and (
- self._parent not in idata[group].data_vars
- ):
- setattr(
- idata,
- group,
- idata[group].rename({"rt,response_mean": self._parent}),
- )
- return idata
-
- def _postprocess_initvals_deterministic(
- self, initval_settings: dict = INITVAL_SETTINGS
- ) -> None:
- """Set initial values for subset of parameters."""
- self._initvals = self.initial_point()
- # Consider case where link functions are set to 'log_logit'
- # or 'None'
- if self.link_settings not in ["log_logit", None]:
- _logger.info(
- "Not preprocessing initial values, "
- + "because none of the two standard link settings are chosen!"
- )
- return None
-
- # Set initial values for particular parameters
- for name_, starting_value in self.pymc_model.initial_point().items():
- # strip name of `_log__` and `_interval__` suffixes
- name_tmp = name_.replace("_log__", "").replace("_interval__", "")
-
- # We need to check if the parameter is actually backed by
- # a regression.
-
- # If not, we don't actually apply a link function to it as per default.
- # Therefore we need to apply the initial value strategy corresponding
- # to 'None' link function.
-
- # If the user actively supplies a link function, the user
- # should also have supplied an initial value insofar it matters.
-
- if self.params[self._get_prefix(name_tmp)].is_regression:
- param_link_setting = self.link_settings
- else:
- param_link_setting = None
- if name_tmp in initval_settings[param_link_setting].keys():
- if self._check_if_initval_user_supplied(name_tmp):
- _logger.info(
- "User supplied initial value detected for %s, \n"
- " skipping overwrite with default value.",
- name_tmp,
- )
- continue
-
- # Apply specific settings from initval_settings dictionary
- dtype = self._initvals[name_tmp].dtype
- self._initvals[name_tmp] = np.array(
- initval_settings[param_link_setting][name_tmp]
- ).astype(dtype)
-
- def _get_prefix(self, name_str: str) -> str:
- """Get parameters wise link setting function from parameter prefix."""
- # `p_outlier` is the only basic parameter floating around that has
- # an underscore in it's name.
- # We need to handle it separately. (Renaming might be better...)
- if "_" in name_str:
- if "p_outlier" not in name_str:
- name_str_prefix = name_str.split("_")[0]
- else:
- name_str_prefix = "p_outlier"
- else:
- name_str_prefix = name_str
- return name_str_prefix
-
- def _check_if_initval_user_supplied(
- self,
- name_str: str,
- return_value: bool = False,
- ) -> bool | float | int | np.ndarray | dict[str, Any] | None:
- """Check if initial value is user-supplied."""
- # The function assumes that the name_str is either raw parameter name
- # or `paramname_Intercept`, because we only really provide special default
- # initial values for those types of parameters
-
- # `p_outlier` is the only basic parameter floating around that has
- # an underscore in it's name.
- # We need to handle it separately. (Renaming might be better...)
- if "_" in name_str:
- if "p_outlier" not in name_str:
- name_str_prefix = name_str.split("_")[0]
- # name_str_suffix = "".join(name_str.split("_")[1:])
- name_str_suffix = name_str[len(name_str_prefix + "_") :]
- else:
- name_str_prefix = "p_outlier"
- if name_str == "p_outlier":
- name_str_suffix = ""
- else:
- # name_str_suffix = "".join(name_str.split("_")[2:])
- name_str_suffix = name_str[len("p_outlier_") :]
- else:
- name_str_prefix = name_str
- name_str_suffix = ""
-
- tmp_param = name_str_prefix
- if tmp_param == self._parent:
- # If the parameter was parent it is automatically treated as a
- # regression.
- if not name_str_suffix:
- # No suffix --> Intercept
- if isinstance(prior_tmp := self.params[tmp_param].prior, dict):
- args_tmp = getattr(prior_tmp["Intercept"], "args")
- if return_value:
- return args_tmp.get("initval", None)
- else:
- return "initval" in args_tmp
- else:
- if return_value:
- return None
- return False
- else:
- # If the parameter has a suffix --> use it
- if isinstance(prior_tmp := self.params[tmp_param].prior, dict):
- args_tmp = getattr(prior_tmp[name_str_suffix], "args")
- if return_value:
- return args_tmp.get("initval", None)
- else:
- return "initval" in args_tmp
- else:
- if return_value:
- return None
- else:
- return False
- else:
- # If the parameter is not a parent, it is treated as a regression
- # only when actively specified as such.
- if not name_str_suffix:
- # If no suffix --> treat as basic parameter.
- if isinstance(self.params[tmp_param].prior, float) or isinstance(
- self.params[tmp_param].prior, np.ndarray
- ):
- if return_value:
- return self.params[tmp_param].prior
- else:
- return True
- elif isinstance(self.params[tmp_param].prior, bmb.Prior):
- args_tmp = getattr(self.params[tmp_param].prior, "args")
- if "initval" in args_tmp:
- if return_value:
- return args_tmp["initval"]
- else:
- return True
- else:
- if return_value:
- return None
- else:
- return False
- else:
- if return_value:
- return None
- else:
- return False
- else:
- # If suffix --> treat as regression and use suffix
- if isinstance(prior_tmp := self.params[tmp_param].prior, dict):
- args_tmp = getattr(prior_tmp[name_str_suffix], "args")
- if return_value:
- return args_tmp.get("initval", None)
- else:
- return "initval" in args_tmp
- else:
- if return_value:
- return None
- else:
- return False
-
- def _jitter_initvals(
- self, jitter_epsilon: float = 0.01, vector_only: bool = False
- ) -> None:
- """Apply controlled jitter to initial values."""
- if vector_only:
- self.__jitter_initvals_vector_only(jitter_epsilon)
- else:
- self.__jitter_initvals_all(jitter_epsilon)
-
- def __jitter_initvals_vector_only(self, jitter_epsilon: float) -> None:
- # Note: Calling our initial point function here
- # --> operate on untransformed variables
- initial_point_dict = self.initvals
- for name_, starting_value in initial_point_dict.items():
- name_tmp = name_.replace("_log__", "").replace("_interval__", "")
- if starting_value.ndim != 0 and starting_value.shape[0] != 1:
- starting_value_tmp = starting_value + np.random.uniform(
- -jitter_epsilon, jitter_epsilon, starting_value.shape
- ).astype(np.float32)
-
- # Note: self._initvals shouldn't be None when this is called
- dtype = self._initvals[name_tmp].dtype
- self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype)
-
- def __jitter_initvals_all(self, jitter_epsilon: float) -> None:
- # Note: Calling our initial point function here
- # --> operate on untransformed variables
- initial_point_dict = self.initvals
- # initial_point_dict = self.pymc_model.initial_point()
- for name_, starting_value in initial_point_dict.items():
- name_tmp = name_.replace("_log__", "").replace("_interval__", "")
- starting_value_tmp = starting_value + np.random.uniform(
- -jitter_epsilon, jitter_epsilon, starting_value.shape
- ).astype(np.float32)
-
- dtype = self.initvals[name_tmp].dtype
- self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype)
diff --git a/src/hssm/missing_data_mixin.py b/src/hssm/missing_data_mixin.py
new file mode 100644
index 000000000..5c7012d38
--- /dev/null
+++ b/src/hssm/missing_data_mixin.py
@@ -0,0 +1,200 @@
+"""Mixin module for handling missing data and deadline logic in HSSM models."""
+
+import numpy as np
+import pandas as pd
+
+from hssm.defaults import MissingDataNetwork # noqa: F401
+
+
+class MissingDataMixin:
+ """Mixin for handling missing data and deadline logic in HSSM models.
+
+ Parameters
+ ----------
+ missing_data : optional
+ Specifies whether the model should handle missing data. Can be a `bool`
+ or a `float`. If `False`, and if the `rt` column contains -999.0, the
+ model will drop those rows and produce a warning. If `True`, the model
+ will treat -999.0 as missing data. If a `float` is provided, it will be
+ treated as the missing data value. Defaults to `False`.
+ deadline : optional
+ Specifies whether the model should handle deadline data. Can be a `bool`
+ or a `str`. If `False`, the model will not act even if a deadline column
+ is provided. If `True`, the model will treat the `deadline` column as
+ deadline data. If a `str` is provided, it is treated as the name of the
+ deadline column. Defaults to `False`.
+ loglik_missing_data : optional
+ A likelihood function for missing data. See the `loglik` parameter for
+ details. If not provided, a default likelihood is used. Allowed only if
+ either `missing_data` or `deadline` is not `False`.
+ """
+
+ def _handle_missing_data_and_deadline(self):
+ """Handle missing data and deadline.
+
+ Originally from DataValidatorMixin. Handles dropping, replacing, or masking
+ missing data and deadline values in self.data based on the current settings.
+ """
+ import warnings
+
+ if not self.missing_data and not self.deadline:
+ # For a choice-only model, we don't need to do anything with the
+ # data.
+ if self.is_choice_only:
+ return
+ # In the case where missing_data is set to False, we need to drop the
+ # cases where rt = na_value
+ if pd.isna(self.missing_data_value):
+ na_dropped = self.data.dropna(subset=["rt"])
+ else:
+ na_dropped = self.data.loc[
+ self.data["rt"] != self.missing_data_value, :
+ ]
+
+ if len(na_dropped) != len(self.data):
+ warnings.warn(
+ "`missing_data` is set to False, "
+ + "but you have missing data in your dataset. "
+ + "Missing data will be dropped.",
+ stacklevel=2,
+ )
+ self.data = na_dropped
+
+ elif self.missing_data and not self.deadline:
+ # In the case where missing_data is set to True, we need to replace the
+ # missing data with a specified na_value
+
+ # Note: this modifies the `rt` column of `self.data` in place
+ if pd.isna(self.missing_data_value):
+ self.data["rt"] = self.data["rt"].fillna(-999.0)
+ else:
+ self.data["rt"] = self.data["rt"].replace(
+ self.missing_data_value, -999.0
+ )
+
+ else: # deadline = True
+ if self.deadline_name not in self.data.columns:
+ raise ValueError(
+ "You have specified that your data has deadline, but "
+ + f"`{self.deadline_name}` is not found in your dataset."
+ )
+ else:
+ self.data.loc[:, "rt"] = np.where(
+ self.data["rt"] < self.data[self.deadline_name],
+ self.data["rt"],
+ -999.0,
+ )
+
+ @staticmethod
+ def _set_missing_data_and_deadline(
+ missing_data: bool, deadline: bool, data: pd.DataFrame
+ ) -> MissingDataNetwork:
+ """Set missing data and deadline."""
+ if not missing_data:
+ return MissingDataNetwork.NONE
+ network = MissingDataNetwork.OPN if deadline else MissingDataNetwork.CPN
+ # AF-TODO: GONOGO case not yet correctly implemented
+ # else:
+ # # TODO: This won't behave as expected yet, GONOGO needs to be split
+ # # into a deadline case and a non-deadline case.
+ # network = MissingDataNetwork.GONOGO
+
+ if np.all(data["rt"] == -999.0):
+ # AF-TODO: I think we should allow invalid-only datasets.
+ raise ValueError(
+ "`missing_data` is set to True, but you have no valid data in your "
+ "dataset."
+ )
+ # AF-TODO: This one needs refinement for GONOGO case
+ # elif network == MissingDataNetwork.OPN:
+ # raise ValueError(
+ # "`deadline` is set to True and `missing_data` is set to True, "
+ # "but ."
+ # )
+ # else:
+ # raise ValueError(
+ # "`missing_data` and `deadline` are both set to True, "
+ # "but you have "
+ # "no missing data and/or no rts exceeding the deadline."
+ # )
+ return network
+
+ def _process_missing_data_and_deadline(
+ self, missing_data: float | bool, deadline: bool | str, loglik_missing_data
+ ):
+ """
+ Process missing data and deadline logic for the model's data.
+
+ This method sets up missing data and deadline handling for the model.
+ It updates self.missing_data, self.missing_data_value, self.deadline,
+ self.deadline_name, and self.loglik_missing_data based on the arguments.
+ It also modifies self.data in-place to drop or replace missing/deadline
+ values as appropriate, and sets self.missing_data_network.
+
+ Parameters
+ ----------
+ missing_data : float or bool
+ If True, treat -999.0 as missing data. If a float, use that value
+ as the missing data marker. If False, drop missing data rows.
+ deadline : bool or str
+ If True, use the 'deadline' column for deadline logic. If a str,
+ use that column name. If False, ignore deadline logic.
+ loglik_missing_data : callable or None
+ Optional custom likelihood function for missing data. May be provided
+ only when missing_data or deadline is enabled.
+ """
+ if isinstance(missing_data, float):
+ if not ((self.data.rt == missing_data).any()):
+ raise ValueError(
+ f"missing_data argument is provided as a float {missing_data}, "
+ f"However, you have no RTs of {missing_data} in your dataset!"
+ )
+ else:
+ self.missing_data = True
+ self.missing_data_value = missing_data
+ elif isinstance(missing_data, bool):
+ if missing_data and (not (self.data.rt == -999.0).any()):
+ raise ValueError(
+ "missing_data argument is provided as True, "
+ " so RTs of -999.0 are treated as missing. \n"
+ "However, you have no RTs of -999.0 in your dataset!"
+ )
+ elif (not missing_data) and (self.data.rt == -999.0).any():
+ raise ValueError(
+ "Missing data provided as False. \n"
+ "However, you have RTs of -999.0 in your dataset!"
+ )
+ else:
+ self.missing_data = missing_data
+ else:
+ raise ValueError(
+ "missing_data argument must be a bool or a float! \n"
+ f"You provided: {type(missing_data)}"
+ )
+
+ if isinstance(deadline, str):
+ self.deadline = True
+ self.deadline_name = deadline
+ else:
+ self.deadline = deadline
+ self.deadline_name = "deadline"
+
+ if (
+ not self.missing_data and not self.deadline
+ ) and loglik_missing_data is not None:
+ raise ValueError(
+ "You have specified a loglik_missing_data function, but you have not "
+ "set the missing_data or deadline flag to True."
+ )
+ self.loglik_missing_data = loglik_missing_data
+
+ # Update data based on missing_data and deadline
+ self._handle_missing_data_and_deadline()
+ # Set self.missing_data_network based on `missing_data` and `deadline`
+ self.missing_data_network = self._set_missing_data_and_deadline(
+ self.missing_data, self.deadline, self.data
+ )
+
+ if self.deadline and self.response is not None: # type: ignore[attr-defined]
+ if self.deadline_name not in self.response: # type: ignore[attr-defined]
+ self.response.append(self.deadline_name) # type: ignore[attr-defined]
diff --git a/src/hssm/param/param.py b/src/hssm/param/param.py
index 4ea8493f5..3d1897dcf 100644
--- a/src/hssm/param/param.py
+++ b/src/hssm/param/param.py
@@ -157,7 +157,7 @@ def is_trialwise(self) -> bool:
def fill_defaults(
self,
- prior: dict[str, Any] | None = None,
+ prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None,
bounds: tuple[float, float] | None = None,
**kwargs,
) -> None:
diff --git a/src/hssm/param/params.py b/src/hssm/param/params.py
index f542c803d..d90de35ff 100644
--- a/src/hssm/param/params.py
+++ b/src/hssm/param/params.py
@@ -213,6 +213,7 @@ def collect_user_params(
user_param = UserParam.from_dict(param) if isinstance(param, dict) else param
if user_param.name is None:
raise ValueError("Parameter name must be specified.")
+ assert model.list_params is not None, "list_params should be set"
if user_param.name not in model.list_params:
raise ValueError(
f"Parameter {user_param.name} not found in list_params."
@@ -229,6 +230,7 @@ def collect_user_params(
# If any of the keys is found in `list_params` it is a parameter specification.
# We add the parameter specification to `user_params` and remove it from
# `kwargs`
+ assert model.list_params is not None, "list_params should be set"
for param_name in model.list_params:
# Update user_params only if param_name is in kwargs
# and not already in user_params
@@ -272,6 +274,7 @@ def make_params(model: HSSM, user_params: dict[str, UserParam]) -> dict[str, Par
and model.loglik_kind != "approx_differentiable"
)
+ assert model.list_params is not None, "list_params should be set"
for name in model.list_params:
if name in user_params:
param = make_param_from_user_param(model, name, user_params[name])
diff --git a/src/hssm/param/regression_param.py b/src/hssm/param/regression_param.py
index c28d7c8b0..64ac18163 100644
--- a/src/hssm/param/regression_param.py
+++ b/src/hssm/param/regression_param.py
@@ -111,7 +111,7 @@ def from_defaults(
def fill_defaults(
self,
- prior: dict[str, Any] | None = None,
+ prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None,
bounds: tuple[float, float] | None = None,
**kwargs,
) -> None:
diff --git a/src/hssm/param/simple_param.py b/src/hssm/param/simple_param.py
index 70d001047..a71528acf 100644
--- a/src/hssm/param/simple_param.py
+++ b/src/hssm/param/simple_param.py
@@ -111,7 +111,7 @@ def from_user_param(cls, user_param: UserParam) -> "SimpleParam":
def fill_defaults(
self,
- prior: dict[str, Any] | None = None,
+ prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None,
bounds: tuple[float, float] | None = None,
**kwargs,
) -> None:
@@ -201,14 +201,17 @@ class DefaultParam(SimpleParam):
def __init__(
self,
name: str,
- prior: float | np.ndarray | dict[str, Any] | bmb.Prior,
- bounds: tuple[float, float],
+ prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None,
+ bounds: tuple[float, float] | None,
) -> None:
super().__init__(name, prior=prior, bounds=bounds)
@classmethod
def from_defaults(
- cls, name: str, prior: dict[str, Any], bounds: tuple[int, int]
+ cls,
+ name: str,
+ prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None,
+ bounds: tuple[float, float] | None,
) -> "DefaultParam":
"""Create a DefaultParam object from default values.
@@ -248,7 +251,7 @@ def process_prior(self) -> None:
def fill_defaults(
self,
- prior: dict[str, Any] | None = None,
+ prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None,
bounds: tuple[float, float] | None = None,
**kwargs,
) -> None:
diff --git a/src/hssm/param/utils.py b/src/hssm/param/utils.py
index 96965f272..6d630f673 100644
--- a/src/hssm/param/utils.py
+++ b/src/hssm/param/utils.py
@@ -26,7 +26,7 @@ def _make_default_prior(bounds: tuple[float, float] | None) -> bmb.Prior:
A bmb.Prior object representing the default prior for the provided bounds.
"""
if bounds is None:
- raise ValueError("Parameter unspecified.")
+ raise ValueError("Bounds parameter unspecified.")
lower, upper = bounds
if np.isinf(lower) and np.isinf(upper):
prior = bmb.Prior("Normal", mu=0.0, sigma=2.0)
diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py
new file mode 100644
index 000000000..64e17bc41
--- /dev/null
+++ b/src/hssm/rl/__init__.py
@@ -0,0 +1,27 @@
+"""Reinforcement-learning extensions for HSSM.
+
+This subpackage groups components that integrate reinforcement-learning
+rules with sequential-sampling decision models (SSMs).
+
+Public API (import from ``hssm.rl``):
+
+- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`.
+- ``RLSSMConfig``: the config class for RL + SSM models, implemented in
+ :mod:`hssm.rl.config`.
+- ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`.
+
+RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include
+helpers such as :func:`~hssm.rl.likelihoods.builder.make_rl_logp_func` and
+:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`.
+
+"""
+
+from .config import RLSSMConfig
+from .rlssm import RLSSM
+from .utils import validate_balanced_panel
+
+__all__ = [
+ "RLSSM",
+ "RLSSMConfig",
+ "validate_balanced_panel",
+]
diff --git a/src/hssm/rl/config.py b/src/hssm/rl/config.py
new file mode 100644
index 000000000..17ec2ea18
--- /dev/null
+++ b/src/hssm/rl/config.py
@@ -0,0 +1,219 @@
+"""RL-specific configuration classes.
+
+This module houses `RLSSMConfig` which was previously defined in
+`hssm.config`. It is intentionally lightweight and re-uses
+`BaseModelConfig` from :mod:`hssm.config` to avoid duplicating core
+behaviour.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import MISSING, dataclass, field, fields
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from .._types import LoglikKind, SupportedModels
+ from ..config import ModelConfig
+
+from ..config import DEFAULT_SSM_CHOICES, DEFAULT_SSM_OBSERVED_DATA, BaseModelConfig
+
+_logger = logging.getLogger("hssm")
+
+
+@dataclass
+class RLSSMConfig(BaseModelConfig):
+ """Config for reinforcement learning + sequential sampling models.
+
+ Extends :class:`BaseModelConfig` with the fields required by the RLSSM
+ likelihood pipeline. The key extra fields are:
+
+ - ``ssm_logp_func``: the annotated JAX SSM log-likelihood function (see
+ below) whose ``computed`` dict drives per-parameter RL computations.
+ - ``learning_process``: a mapping that declares *how* each computed
+ parameter is specified (see below).
+ - ``decision_process``: the name (string) or :class:`ModelConfig` instance
+ that identifies the SSM decision process (e.g. ``"ddm"``, ``"angle"``).
+ - ``decision_process_loglik_kind`` / ``learning_process_kind``: string
+ tags that record which kind of likelihood and which kind of learning rule
+ are used (e.g. ``"approx_differentiable"`` / ``"blackbox"``).
+
+ ssm_logp_func:
+ A JAX function decorated with ``@annotate_function``. It must carry:
+
+ - ``.inputs`` — ordered list of all parameter names the function
+ expects (e.g. ``["v", "a", "z", "t", "theta", "rt", "response"]``).
+ - ``.outputs`` — list of output names (e.g. ``["logp"]``).
+ - ``.computed`` — dict mapping each *computed* parameter name to the
+ annotated function that produces it. For example::
+
+ {"v": compute_v_annotated}
+
+ where ``compute_v_annotated`` is itself decorated with
+ ``@annotate_function`` and carries ``.inputs`` / ``.outputs``.
+
+ ``make_rl_logp_op`` inspects ``ssm_logp_func.computed`` to resolve
+ which parameters come from data / sampled posteriors and which must
+ be computed by the RL learning rule at each gradient step.
+
+ learning_process:
+ A dict keyed by the name of each computed parameter (matching the keys
+ in ``ssm_logp_func.computed``). Values record how that parameter is
+ specified. The dict is intentionally permissive; currently supported
+ value forms are:
+
+ - **callable** — an annotated function (or plain function) used to
+ compute the parameter. The actual computation at runtime is driven
+ by ``ssm_logp_func.computed``; this entry serves as declarative
+ documentation and for config serialisation / round-trip::
+
+ learning_process = {"v": compute_v_annotated}
+
+ - **string** — a symbolic identifier for declarative / YAML-based
+ configs that can be resolved to a callable by the caller::
+
+ learning_process = {"v": "subject_wise_function"}
+
+ An empty dict ``{}`` is valid when the SSM has no computed parameters
+ (i.e. ``ssm_logp_func.computed`` is also empty).
+
+ .. note::
+ The dict is *not* directly consumed by ``make_rl_logp_op``.
+ The actual compute functions used at runtime come from
+ ``ssm_logp_func.computed``. ``learning_process`` therefore acts
+ as a config-level record of intent and is useful for inspection,
+ serialisation, and future higher-level tooling.
+ """
+
+ decision_process_loglik_kind: str = field(kw_only=True)
+ learning_process_kind: str = field(kw_only=True)
+ params_default: list[float] = field(kw_only=True)
+ decision_process: str | "ModelConfig" = field(kw_only=True)
+ learning_process: dict[str, Any] = field(kw_only=True)
+ ssm_logp_func: Any = field(default=None, kw_only=True)
+
+ def __post_init__(self): # noqa: D105
+ if self.loglik_kind is None:
+ self.loglik_kind = "approx_differentiable"
+ _logger.debug(
+ "RLSSMConfig: loglik_kind not specified; "
+ "defaulting to 'approx_differentiable'."
+ )
+
+ @classmethod
+ def from_defaults( # noqa: D102
+ cls, model_name: "SupportedModels" | str, loglik_kind: "LoglikKind" | None
+ ) -> "RLSSMConfig":
+ raise NotImplementedError(
+ "RLSSMConfig does not support from_defaults(). "
+ "Use RLSSMConfig.from_rlssm_dict() or the constructor directly."
+ )
+
+ @classmethod
+ def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": # noqa: D102
+ # Derive required fields from the dataclass itself: a field is required
+ # iff it has no default and no default_factory. This keeps the dataclass
+ # as the single source of truth — no separate required-key list needed.
+ field_exceptions = ("loglik", "loglik_kind", "backend")
+ required_fields = [
+ f.name
+ for f in fields(cls)
+ if f.name not in field_exceptions
+ and f.default is MISSING
+ and f.default_factory is MISSING # type: ignore[misc]
+ ]
+ for field_name in required_fields:
+ if field_name not in config_dict or config_dict[field_name] is None:
+ raise ValueError(f"{field_name} must be provided in config_dict")
+
+ # ssm_logp_func has a dataclass default of None but is required in practice.
+ if config_dict.get("ssm_logp_func") is None:
+ raise ValueError("ssm_logp_func must be provided in config_dict")
+
+ init_kwargs = dict(
+ model_name=config_dict["model_name"],
+ description=config_dict.get("description"),
+ list_params=config_dict.get("list_params"),
+ extra_fields=config_dict.get("extra_fields"),
+ params_default=config_dict["params_default"],
+ decision_process=config_dict["decision_process"],
+ learning_process=config_dict["learning_process"],
+ ssm_logp_func=config_dict.get("ssm_logp_func"),
+ bounds=config_dict.get("bounds", {}),
+ decision_process_loglik_kind=config_dict["decision_process_loglik_kind"],
+ learning_process_kind=config_dict["learning_process_kind"],
+ )
+
+ def _get_or_warn(key: str, default: Any) -> None:
+ if key not in config_dict:
+ _logger.warning(
+ "'%s' not specified in the RLSSM config; using default value: %r.",
+ key,
+ default,
+ )
+ init_kwargs[key] = config_dict.get(key, default)
+
+ _get_or_warn("response", DEFAULT_SSM_OBSERVED_DATA)
+ _get_or_warn("choices", DEFAULT_SSM_CHOICES)
+
+ return cls(**init_kwargs)
+
+ def validate(self) -> None: # noqa: D102
+ if self.response is None:
+ raise ValueError("Please provide `response` columns in the configuration.")
+ if self.list_params is None:
+ raise ValueError("Please provide `list_params` in the configuration.")
+ if self.choices is None:
+ raise ValueError("Please provide `choices` in the configuration.")
+ if self.decision_process is None:
+ raise ValueError("Please specify a `decision_process`.")
+
+ logpfunc = self.ssm_logp_func
+ if logpfunc is None:
+ raise ValueError(
+ "Please provide `ssm_logp_func`: the fully annotated JAX SSM "
+ "log-likelihood function required by `make_rl_logp_op`."
+ )
+ if not callable(logpfunc):
+ raise ValueError(
+ f"`ssm_logp_func` must be a callable, but got {type(logpfunc)!r}."
+ )
+ missing_attrs = [
+ attr
+ for attr in ("inputs", "outputs", "computed")
+ if not hasattr(logpfunc, attr)
+ ]
+ if missing_attrs:
+ raise ValueError(
+ "`ssm_logp_func` must be decorated with `@annotate_function` "
+ "so that it carries the attributes required by `make_rl_logp_op`. "
+ f"Missing attribute(s): {missing_attrs}. "
+ )
+
+ if not isinstance(logpfunc.computed, dict) or not all(
+ callable(v) for v in logpfunc.computed.values()
+ ):
+ raise ValueError(
+ "`ssm_logp_func.computed` must be a dictionary with callable values."
+ )
+
+ if self.params_default and self.list_params:
+ if len(self.params_default) != len(self.list_params):
+ raise ValueError(
+ f"params_default length ({len(self.params_default)}) doesn't "
+ f"match list_params length ({len(self.list_params)})"
+ )
+
+ if self.list_params:
+ missing_bounds = [p for p in self.list_params if p not in self.bounds]
+ if missing_bounds:
+ raise ValueError(
+ f"Missing bounds for parameter(s): {missing_bounds}. "
+ "Every parameter in `list_params` must have a corresponding "
+ "entry in `bounds`."
+ )
+
+ def get_defaults( # noqa: D102
+ self, param: str
+ ) -> tuple[float | None, tuple[float, float] | None]:
+ return None, self.bounds.get(param)
diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py
new file mode 100644
index 000000000..180cd12c8
--- /dev/null
+++ b/src/hssm/rl/rlssm.py
@@ -0,0 +1,264 @@
+"""RLSSM: Reinforcement Learning Sequential Sampling Model.
+
+This module defines the :class:`RLSSM` class, a subclass of :class:`HSSMBase`
+for models that couple a reinforcement learning (RL) learning process with a
+sequential sampling decision model (SSM).
+
+The key difference from :class:`HSSM` is the likelihood:
+ - ``HSSM`` wraps an analytical / ONNX / blackbox callable via
+ :func:`~hssm.distribution_utils.make_likelihood_callable`.
+ - ``RLSSM`` builds a differentiable pytensor ``Op`` directly from an
+ :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via
+ :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`, which internally
+ handles the RL learning rule and per-participant trial structure.
+ This Op is then passed straight to
+ :func:`~hssm.distribution_utils.make_distribution`, bypassing the
+ standard ``loglik`` / ``loglik_kind`` wrapping pipeline.
+"""
+
+from dataclasses import replace
+from typing import TYPE_CHECKING, Any, Callable, Literal, cast
+
+import bambi as bmb
+import pandas as pd
+import pymc as pm
+
+if TYPE_CHECKING:
+ from pytensor.graph import Op
+
+
+from hssm.defaults import (
+ INITVAL_JITTER_SETTINGS,
+)
+from hssm.distribution_utils import make_distribution
+from hssm.rl.likelihoods.builder import make_rl_logp_op
+from hssm.rl.utils import validate_balanced_panel
+
+from ..base import HSSMBase
+from .config import RLSSMConfig
+
+
+class RLSSM(HSSMBase):
+ """Reinforcement Learning Sequential Sampling Model.
+
+ Combines a reinforcement learning (RL) process with a sequential sampling
+ model (SSM) inside a single differentiable likelihood. The RL component
+ computes trial-wise intermediate parameters (e.g., drift rates) from the
+ learning history, which are then fed into the SSM log-likelihood.
+
+ The likelihood is built via
+ :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` from the annotated
+ SSM function stored in *model_config.ssm_logp_func*. This produces a
+ differentiable pytensor ``Op`` that is passed directly to
+ :func:`~hssm.distribution_utils.make_distribution`, superseding the
+ ``loglik`` / ``loglik_kind`` dispatching used by :class:`~hssm.hssm.HSSM`.
+
+ Parameters
+ ----------
+ data : pd.DataFrame
+ Trial-level data. Must contain at least the response columns
+ specified in *model_config* (typically ``"rt"`` and ``"response"``),
+ a participant identifier column (default ``"participant_id"``), and
+ any extra fields listed in *model_config.extra_fields*.
+ The data **must** form a balanced panel: every participant must have
+ the same number of trials.
+ model_config : RLSSMConfig
+ Full configuration for the RLSSM model. Must have ``ssm_logp_func``
+ set to the annotated JAX SSM log-likelihood function.
+ participant_col : str, optional
+ Name of the column that uniquely identifies participants.
+ Used to infer ``n_participants`` and ``n_trials`` from *data*.
+ Defaults to ``"participant_id"``.
+ include : list, optional
+ Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`.
+ p_outlier : float | dict | bmb.Prior | None, optional
+ Lapse probability specification. Defaults to ``0.05``.
+ lapse : dict | bmb.Prior | None, optional
+ Lapse distribution. Defaults to ``Uniform(0, 20)``.
+ link_settings : Literal["log_logit"] | None, optional
+ Link-function preset. Defaults to ``None``.
+ prior_settings : Literal["safe"] | None, optional
+ Prior preset. Defaults to ``"safe"``.
+ extra_namespace : dict | None, optional
+ Extra variables for formula evaluation. Defaults to ``None``.
+ missing_data : bool | float, optional
+ Not yet supported by RLSSM: any value other than ``False`` raises
+ ``NotImplementedError``. Defaults to ``False``.
+ deadline : bool | str, optional
+ Not yet supported by RLSSM: must be ``False`` (the default).
+ loglik_missing_data : Callable | None, optional
+ Custom likelihood for missing observations. Defaults to ``None``.
+ process_initvals : bool, optional
+ Whether to post-process initial values. Defaults to ``True``.
+ initval_jitter : float, optional
+ Jitter magnitude for initial values.
+ Defaults to ``INITVAL_JITTER_SETTINGS["jitter_epsilon"]`` in :mod:`hssm.defaults`.
+ **kwargs
+ Additional keyword arguments forwarded to :class:`bmb.Model`.
+
+ Attributes
+ ----------
+ model_config : RLSSMConfig
+ The RLSSM configuration object (stored as ``self.model_config`` on
+ :class:`~hssm.base.HSSMBase` with the built ``loglik`` Op injected).
+ n_participants : int
+ Number of participants inferred from *data*.
+ n_trials : int
+ Number of trials per participant inferred from *data*.
+ """
+
+ def __init__(
+ self,
+ data: pd.DataFrame,
+ model_config: RLSSMConfig,
+ participant_col: str = "participant_id",
+ include: list[dict[str, Any] | Any] | None = None,
+ p_outlier: float | dict | bmb.Prior | None = 0.05,
+ lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0),
+ link_settings: Literal["log_logit"] | None = None,
+ prior_settings: Literal["safe"] | None = "safe",
+ extra_namespace: dict[str, Any] | None = None,
+ missing_data: bool | float = False,
+ deadline: bool | str = False,
+ loglik_missing_data: Callable | None = None,
+ process_initvals: bool = True,
+ initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"],
+ **kwargs: Any,
+ ) -> None:
+ # ===== save/load serialisation =====
+ self._init_args = self._store_init_args(locals(), kwargs)
+
+ # Validate config (ensures ssm_logp_func is present, etc.)
+ model_config.validate()
+
+ # RLSSM reshapes rows into (n_participants, n_trials, ...) by position,
+ # so _rearrange_data (which moves missing/deadline rows to the front)
+ # would scramble per-participant trial sequences and corrupt RL dynamics.
+ # Raise early so the user gets a clear message before model construction.
+ if missing_data is not False:
+ raise NotImplementedError(
+ "RLSSM currently does not support `missing_data` handling. "
+ "The RL log-likelihood Op relies on strict row order to recover "
+ "per-participant trial sequences; rearranging rows for missing RT "
+ "values would corrupt the RL learning dynamics. "
+ "Please remove missing trials from the data before passing it to RLSSM."
+ )
+ if deadline is not False:
+ raise NotImplementedError(
+ "RLSSM currently does not support `deadline` handling. "
+ "The RL log-likelihood Op relies on strict row order to recover "
+ "per-participant trial sequences; rearranging rows for deadline "
+ "trials would corrupt the RL learning dynamics. Please remove "
+ "deadline trials from the data before passing it to RLSSM."
+ )
+
+ # Infer panel structure and validate balance BEFORE calling super so any
+ # error surfaces before the expensive model-build steps.
+ n_participants, n_trials = validate_balanced_panel(data, participant_col)
+
+ # Store RL-specific state on self BEFORE super().__init__() so that
+ # _make_model_distribution() (called from super) can access them.
+ self.config = model_config
+ self.n_participants = n_participants
+ self.n_trials = n_trials
+
+ # Build the differentiable pytensor Op from the annotated SSM function.
+ # This Op supersedes the loglik/loglik_kind workflow: it is stored on
+ # model_config.loglik so that HSSMBase can access it uniformly via
+ # self.model_config.loglik, without any Config conversion.
+ #
+ # Fresh list() copies are passed to make_rl_logp_op so the closure inside
+ # captures its own isolated list objects. HSSMBase will later append
+ # "p_outlier" to self.list_params, and that mutation must NOT be visible
+ # to the Op's _validate_args_length check at sampling time.
+ loglik_op = make_rl_logp_op(
+ ssm_logp_func=model_config.ssm_logp_func,
+ n_participants=n_participants,
+ n_trials=n_trials,
+ data_cols=list(model_config.response), # type: ignore[arg-type]
+ list_params=list(model_config.list_params), # type: ignore[arg-type]
+ extra_fields=list(model_config.extra_fields or []),
+ )
+
+ # Build a new RLSSMConfig with the Op and backend injected, leaving
+ # the caller's object unmodified (dataclasses.replace creates a shallow
+ # copy with only the specified fields overridden).
+ #
+ # backend is hardcoded to "jax" because the entire RLSSM likelihood
+ # stack is JAX-only. See ssm_logp_func, make_rl_logp_op, and
+ # _make_model_distribution for details.
+ model_config = replace(model_config, loglik=loglik_op, backend="jax")
+
+ super().__init__(
+ data=data,
+ model_config=model_config,
+ include=include,
+ p_outlier=p_outlier,
+ lapse=lapse,
+ link_settings=link_settings,
+ prior_settings=prior_settings,
+ extra_namespace=extra_namespace,
+ missing_data=missing_data,
+ deadline=deadline,
+ loglik_missing_data=loglik_missing_data,
+ process_initvals=process_initvals,
+ initval_jitter=initval_jitter,
+ **kwargs,
+ )
+
+ def _make_model_distribution(self) -> type[pm.Distribution]:
+ """Build a pm.Distribution using the pre-built RL log-likelihood Op.
+
+ Unlike :meth:`HSSM._make_model_distribution`, this method does not go
+ through :func:`~hssm.distribution_utils.make_likelihood_callable`.
+ Instead it uses ``self.model_config.loglik`` directly, i.e. the
+ differentiable pytensor ``Op`` built in :meth:`__init__` from
+ ``self.model_config.ssm_logp_func``.
+
+ The Op already handles:
+ - The RL learning rule (computing trial-wise intermediate parameters).
+ - The per-participant / per-trial data reshaping.
+ - Gradient computation via its embedded VJP.
+
+ Missing-data network assembly (OPN / CPN) is not yet supported for
+ RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__``
+ before this method is ever reached.
+ """
+ # Use self.list_params (managed by HSSMBase, includes p_outlier when
+ # has_lapse=True) rather than self.model_config.list_params (the original
+ # config list, never mutated by HSSMBase).
+ list_params = self.list_params
+ assert list_params is not None, "list_params must be set"
+ assert isinstance(list_params, list), (
+ "list_params must be a list"
+ ) # for type checker
+
+ # Every RLSSM distribution parameter is trialwise (the Op receives one
+ # value per trial). p_outlier is excluded to match the contract of
+ # make_distribution, which strips p_outlier before indexing this list.
+ params_is_trialwise = [True for name in list_params if name != "p_outlier"]
+
+ extra_fields = self.model_config.extra_fields or []
+ extra_fields_data = (
+ None
+ if not extra_fields
+ else [self.data[field].to_numpy(copy=True) for field in extra_fields]
+ )
+
+ # The differentiable pytensor Op was stored on model_config.loglik during
+ # __init__; ensure it's present and cast for typing.
+ assert self.model_config.loglik is not None, "model_config.loglik must be set"
+ loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik)
+
+ # RLSSMConfig carries no `rv` field; use model_name as the rv identifier.
+ rv_name = self.model_config.model_name
+
+ return make_distribution(
+ rv=rv_name,
+ loglik=loglik_op,
+ list_params=list_params,
+ bounds=self.bounds,
+ lapse=self.lapse,
+ extra_fields=extra_fields_data,
+ params_is_trialwise=params_is_trialwise,
+ )
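Note on the config handling in `RLSSM.__init__` above: the comments describe two related points, namely that `dataclasses.replace` leaves the caller's config object untouched, and that fresh `list()` copies are needed because the copy it makes is shallow. The following is a minimal sketch of that behaviour, not the HSSM implementation; `DemoConfig` is a hypothetical stand-in for `RLSSMConfig`, and `len` is used only as a placeholder for the built Op.

```python
from dataclasses import dataclass, field, replace


@dataclass
class DemoConfig:
    """Hypothetical stand-in for RLSSMConfig (illustration only)."""

    model_name: str
    list_params: list = field(default_factory=list)
    loglik: object = None
    backend: object = None


original = DemoConfig(model_name="demo", list_params=["v", "a"])

# replace() returns a new object; only the named fields are overridden.
updated = replace(original, loglik=len, backend="jax")

# The copy is shallow: both configs still point at the same list object.
shares_list = updated.list_params is original.list_params

# Taking a list() copy first isolates it from later in-place mutation.
isolated = list(original.list_params)
updated.list_params.append("p_outlier")
```

After the `append`, `original.list_params` also contains `"p_outlier"` because the list is shared, while `isolated` does not. This is the kind of mutation the fresh copies passed to `make_rl_logp_op` are meant to guard against.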
diff --git a/src/hssm/rl/utils.py b/src/hssm/rl/utils.py
new file mode 100644
index 000000000..d35f08a90
--- /dev/null
+++ b/src/hssm/rl/utils.py
@@ -0,0 +1,70 @@
+"""Utility functions for reinforcement learning + SSM models."""
+
+import pandas as pd
+
+
+def validate_balanced_panel(
+ data: pd.DataFrame,
+ participant_col: str = "participant_id",
+) -> tuple[int, int]:
+ """Validate that data forms a balanced panel and return its shape.
+
+ A balanced panel requires every participant to have exactly the same number
+ of trials (rows in *data*).
+
+ Parameters
+ ----------
+ data : pd.DataFrame
+ The DataFrame to validate.
+ participant_col : str, optional
+ Name of the column that identifies participants.
+ Defaults to ``"participant_id"``.
+
+ Returns
+ -------
+ tuple[int, int]
+ ``(n_participants, n_trials)`` where *n_trials* is the number of rows
+ per participant.
+
+ Raises
+ ------
+ ValueError
+ If *participant_col* is not present in *data*, or if the panel is
+ unbalanced (participants have different trial counts).
+ """
+ if participant_col not in data.columns:
+ raise ValueError(
+ f"Column '{participant_col}' not found in data. "
+ "Please provide the correct participant column name via "
+ "`participant_col`."
+ )
+
+ n_null = data[participant_col].isna().sum()
+ if n_null > 0:
+ raise ValueError(
+ f"Column '{participant_col}' contains {n_null} NaN value(s). "
+ "All rows must have a valid participant identifier."
+ )
+
+ counts = data.groupby(participant_col).size()
+ if counts.nunique() != 1:
+ raise ValueError(
+ "Data must form balanced panels: all participants must have the "
+ f"same number of trials. Observed trial counts: {dict(counts)}"
+ )
+
+ # Check that each participant's rows are contiguous (not interleaved).
+ # make_rl_logp_op reshapes data as (n_participants, n_trials, ...) by row
+ # order, so interleaved rows would silently mix subjects and corrupt the
+ # RL learning dynamics.
+ n_runs = int((data[participant_col] != data[participant_col].shift()).sum())
+ if n_runs != len(counts):
+ raise ValueError(
+ "Data rows must be contiguous per participant. "
+ "The RL likelihood reshapes data by row position; interleaved "
+ "participant rows will corrupt per-participant trial sequences. "
+ "Please sort the data by participant before passing it to RLSSM "
+ f"(e.g. data.sort_values('{participant_col}'))."
+ )
+
+ return int(len(counts)), int(counts.iloc[0])
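Note on `validate_balanced_panel` above: the function performs two structural checks, equal trial counts per participant and contiguous rows per participant. The following standard-library sketch shows the same two checks on a plain list of participant IDs. `check_panel` is a hypothetical helper written for illustration; it is not part of `hssm` and omits the column-presence and NaN checks.

```python
from collections import Counter
from itertools import groupby


def check_panel(ids):
    """Hypothetical helper: return (n_participants, n_trials) or raise ValueError."""
    counts = Counter(ids)
    if not counts:
        raise ValueError("no rows")
    # Balanced panel: every participant has the same number of rows.
    if len(set(counts.values())) != 1:
        raise ValueError(f"unbalanced panel: {dict(counts)}")
    # Contiguity: count runs of consecutive identical IDs. Contiguous data
    # has exactly one run per participant; interleaved data has more.
    n_runs = sum(1 for _ in groupby(ids))
    if n_runs != len(counts):
        raise ValueError("rows are not contiguous per participant")
    return len(counts), next(iter(counts.values()))


contiguous = [0, 0, 0, 1, 1, 1]
interleaved = [0, 1, 0, 1, 0, 1]

shape = check_panel(contiguous)

try:
    check_panel(interleaved)
    interleaved_ok = True
except ValueError:
    interleaved_ok = False
```

The run count plays the same role as the `shift()` comparison in the pandas version: a new run starts wherever an ID differs from the one on the previous row.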
diff --git a/tests/param/test_default_param.py b/tests/param/test_default_param.py
index 2a90be3f1..3f86ebaa7 100644
--- a/tests/param/test_default_param.py
+++ b/tests/param/test_default_param.py
@@ -4,6 +4,7 @@
from hssm import Prior
from hssm.param.simple_param import DefaultParam
+from hssm.param.utils import _make_default_prior
def test_from_defaults():
@@ -40,3 +41,7 @@ def test_make_default_prior(bounds, prior):
assert param.prior.name == prior.pop("name")
for key, value in prior.items():
assert param.prior.args[key] == value
+
+
+def test_make_default_prior_no_bounds():
+ pytest.raises(ValueError, _make_default_prior, None)
diff --git a/tests/test_config.py b/tests/test_config.py
index 77d8284b6..5685ca2b4 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,8 +1,11 @@
-import numpy as np
+import logging
+
import pytest
+import numpy as np
-import hssm
from hssm.config import Config, ModelConfig
+import hssm
+
hssm.set_floatX("float32")
@@ -87,3 +90,60 @@ def test_update_config():
assert v_prior.name == "Normal"
assert v_bounds == (-np.inf, np.inf)
+
+
+class TestConfigBuildModelConfigExtraLogic:
+ def test_build_model_config_dict_with_choices_conflict(self, caplog):
+ # model 'ddm' has defaults in hssm.defaults; use a minimal dict override
+ model_config = {
+ "response": ("rt", "response"),
+ "list_params": ["v", "a"],
+ "choices": (0, 1),
+ }
+ # provide a different choices argument — should log that model_config wins
+ with caplog.at_level(logging.INFO):
+ cfg = Config._build_model_config("ddm", None, model_config, choices=[1, 0])
+
+ assert isinstance(cfg, Config)
+ assert "choices list provided in both model_config" in caplog.text
+
+ def test_build_model_config_modelconfig_adds_choices(self):
+ # Create a ModelConfig without choices and pass choices argument
+ mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None)
+ cfg = Config._build_model_config("ddm", None, mc, choices=(0, 1))
+ # choices should be applied to resulting Config
+ assert cfg.choices == (0, 1)
+
+ def test_build_model_config_uses_ssms_model_config(self, monkeypatch):
+ # High-level view of the test: ensures that when a model name is not in the built-in
+ # SupportedModels and no choices argument is passed, _build_model_config will consult
+ # the external ssms_model_config registry and use its defaults (here, the choices tuple).
+ # The monkeypatch fixture isolates the change and will be undone after the test.
+
+ # Simulate an external ssms_model_config entry for a model not in SupportedModels
+ fake_model = "external_ssm"
+ fake_choices = (2, 3)
+
+ # Monkeypatch the ssms_model_config mapping in the module
+ import hssm.config as cfgmod
+
+ # Emulate an external package registering defaults for external_ssm.
+ # Ensures `_build_model_config` will consult `ssms_model_config`
+ # when the model name isn't in SupportedModels.
+ monkeypatch.setitem(
+ cfgmod.ssms_model_config, fake_model, {"choices": fake_choices}
+ )
+
+ # Build config with model not in SupportedModels and no choices arg.
+ # Provide a minimal ModelConfig and a dummy `loglik` so
+ # `Config.validate()` runs (loglik is required) while still
+ # exercising the ssms-simulators choices fallback.
+ mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None)
+ result = Config._build_model_config(
+ fake_model,
+ "analytical",
+ mc,
+ choices=None,
+ loglik=(lambda *a, **k: None), # required so Config.validate() passes
+ )
+ assert result.choices == fake_choices
diff --git a/tests/test_data_validator.py b/tests/test_data_validator.py
index af176a8e4..e37151296 100644
--- a/tests/test_data_validator.py
+++ b/tests/test_data_validator.py
@@ -152,29 +152,6 @@ def test_post_check_data_sanity_valid(base_data):
dv_instance_no_missing._post_check_data_sanity()
-def test_handle_missing_data_and_deadline_deadline_column_missing(base_data):
- # Should raise ValueError if deadline is True but deadline_name column is missing
- data = base_data.drop(columns=["deadline"])
- dv = DataValidatorTester(
- data=data,
- deadline=True,
- )
- with pytest.raises(ValueError, match="`deadline` is not found in your dataset"):
- dv._handle_missing_data_and_deadline()
-
-
-def test_handle_missing_data_and_deadline_deadline_applied(base_data):
- # Should set rt to -999.0 where rt >= deadline
- base_data.loc[0, "rt"] = 2.0 # Exceeds deadline
- dv = DataValidatorTester(
- data=base_data,
- deadline=True,
- )
- dv._handle_missing_data_and_deadline()
- assert dv.data.loc[0, "rt"] == -999.0
- assert all(dv.data.loc[1:, "rt"] < dv.data.loc[1:, "deadline"])
-
-
def test_update_extra_fields(monkeypatch):
# Create a DataValidatorTester with extra_fields
data = pd.DataFrame(
@@ -207,61 +184,6 @@ class DummyModelDist:
assert (dv.model_distribution.extra_fields[i] == data[field].values).all()
-def test_set_missing_data_and_deadline():
- # No missing data and no deadline
- data = pd.DataFrame({"rt": [0.5, 0.7]})
- assert (
- DataValidatorMixin._set_missing_data_and_deadline(False, False, data)
- == MissingDataNetwork.NONE
- )
- # Missing data but no deadline
- data = pd.DataFrame({"rt": [0.5, -999.0]})
- assert (
- DataValidatorMixin._set_missing_data_and_deadline(True, False, data)
- == MissingDataNetwork.CPN
- )
- assert (
- DataValidatorMixin._set_missing_data_and_deadline(True, True, data)
- == MissingDataNetwork.OPN
- )
- # AF-TODO: I think GONOGO as a network category can go,
- # but needs a little more thought, out of scope for PR,
- # during which this was commented out.
- # assert (
- # DataValidatorMixin._set_missing_data_and_deadline(True, True, data)
- # == MissingDataNetwork.GONOGO
- # )
-
-
-def test_set_missing_data_and_deadline_all_missing():
- data = pd.DataFrame({"rt": [-999.0, -999.0]})
- # cpn
- with pytest.raises(
- ValueError,
- match="`missing_data` is set to True, but you have no valid data in your "
- "dataset.",
- ):
- DataValidatorMixin._set_missing_data_and_deadline(True, False, data)
-
- # opn
- with pytest.raises(
- ValueError,
- match="`missing_data` is set to True, but you have no valid data in your "
- "dataset.",
- ):
- DataValidatorMixin._set_missing_data_and_deadline(True, True, data)
-
- # AF-TODO: GONOGO case not yet correctly implemented
- # gonogo
- # data = pd.DataFrame({"rt": [-999.0, -999.0]})
- # with pytest.raises(
- # ValueError,
- # match="`missing_data` is set to True, but you have no valid data in your "
- # + "dataset.",
- # ):
- # DataValidatorMixin._set_missing_data_and_deadline(True, True, data)
-
-
def test_validate_choices():
# ====== Valid choices =====
dv = DataValidatorTester(
diff --git a/tests/test_hssm.py b/tests/test_hssm.py
index 2af50779c..c3be947b2 100644
--- a/tests/test_hssm.py
+++ b/tests/test_hssm.py
@@ -86,21 +86,19 @@ def test_custom_model(data_ddm):
with pytest.raises(
ValueError, match="When using a custom model, please provide a `loglik_kind.`"
):
- model = HSSM(data=data_ddm, model="custom")
+ HSSM(data=data_ddm, model="custom")
- with pytest.raises(ValueError, match="Please provide `list_params`"):
- model = HSSM(data=data_ddm, model="custom", loglik_kind="analytical")
+ with pytest.raises(ValueError, match=r"^Please provide `list_params`"):
+ HSSM(data=data_ddm, model="custom", loglik_kind="analytical")
- with pytest.raises(ValueError, match="Please provide `list_params`"):
- model = HSSM(
- data=data_ddm, model="custom", loglik=DDM, loglik_kind="analytical"
- )
+ with pytest.raises(ValueError, match=r"^Please provide `list_params`"):
+ HSSM(data=data_ddm, model="custom", loglik=DDM, loglik_kind="analytical")
with pytest.raises(
ValueError,
- match="Please provide `list_params`",
+ match=r"^Please provide `list_params`",
):
- model = HSSM(
+ HSSM(
data=data_ddm,
model="custom",
loglik=DDM,
@@ -125,9 +123,9 @@ def test_custom_model(data_ddm):
loglik_kind="analytical",
)
- assert model.model_name == "custom"
- assert model.loglik_kind == "analytical"
- assert model.list_params == ["v", "a", "z", "t", "p_outlier"]
+ assert model.model_config.model_name == "custom"
+ assert model.model_config.loglik_kind == "analytical"
+ assert model.model_config.list_params == ["v", "a", "z", "t", "p_outlier"]
@pytest.mark.slow
@@ -165,12 +163,12 @@ def test_sample_prior_predictive(data_ddm_reg):
model_regression = HSSM(
data=data_ddm_reg, include=[dict(name="v", formula="v ~ 1 + x")]
)
- prior_predictive_3 = model_regression.sample_prior_predictive(draws=10)
+ model_regression.sample_prior_predictive(draws=10)
model_regression_a = HSSM(
data=data_ddm_reg, include=[dict(name="a", formula="a ~ 1 + x")]
)
- prior_predictive_4 = model_regression_a.sample_prior_predictive(draws=10)
+ model_regression_a.sample_prior_predictive(draws=10)
model_regression_multi = HSSM(
data=data_ddm_reg,
@@ -179,7 +177,7 @@ def test_sample_prior_predictive(data_ddm_reg):
dict(name="a", formula="a ~ 1 + y"),
],
)
- prior_predictive_5 = model_regression_multi.sample_prior_predictive(draws=10)
+ model_regression_multi.sample_prior_predictive(draws=10)
data_ddm_reg.loc[:, "subject_id"] = np.arange(10)
@@ -190,9 +188,7 @@ def test_sample_prior_predictive(data_ddm_reg):
dict(name="a", formula="a ~ (1|subject_id) + y"),
],
)
- prior_predictive_6 = model_regression_random_effect.sample_prior_predictive(
- draws=10
- )
+ model_regression_random_effect.sample_prior_predictive(draws=10)
@pytest.mark.slow
diff --git a/tests/test_missing_data_mixin.py b/tests/test_missing_data_mixin.py
new file mode 100644
index 000000000..a5476a52b
--- /dev/null
+++ b/tests/test_missing_data_mixin.py
@@ -0,0 +1,221 @@
+import pytest
+import pandas as pd
+
+from hssm.missing_data_mixin import MissingDataMixin
+from hssm.defaults import MissingDataNetwork
+
+
+class DummyModel(MissingDataMixin):
+ """
+ Dummy model for testing MissingDataMixin.
+
+ This class provides stub implementations of methods that the mixin expects
+ to exist on the consuming class. These stubs allow us to verify, via mocks/spies,
+ that the mixin calls them as part of its logic. This is a common pattern for
+ testing mixins: the dummy class provides the required interface, and the test
+ checks the mixin's interaction with it.
+ """
+
+ def __init__(self, data):
+ self.data = data
+ self.response = ["response"]
+ self.missing_data_value = -999.0
+ self.missing_data = False
+ self.deadline = False
+ self.is_choice_only = False
+
+
+# region ===== Fixtures =====
+@pytest.fixture
+def basic_data():
+ return pd.DataFrame({"rt": [1.0, 2.0, -999.0], "response": [1, -1, 1]})
+
+
+@pytest.fixture
+def dummy_model(basic_data):
+ return DummyModel(basic_data)
+
+
+@pytest.fixture
+def dummy_model_with_deadline(basic_data):
+ data = basic_data.assign(deadline=[2.0, 2.0, 2.0])
+ return DummyModel(data)
+
+
+# Indirect fixture dispatcher for parameterized model selection
+@pytest.fixture
+def model(request):
+ return request.getfixturevalue(request.param)
+
+
+# endregion
+
+
+class TestProcessMissingDataAndDeadline:
+ @pytest.mark.parametrize(
+ "model, deadline",
+ [
+ ("dummy_model", False),
+ ("dummy_model_with_deadline", True),
+ ("dummy_model_with_deadline", "deadline"),
+ ],
+ indirect=["model"],
+ )
+ def test_missing_data_false_raises_valueerror(self, model, deadline):
+ """
+ Should raise ValueError if missing_data=False and -999.0 is present in rt column.
+ Covers all cases where deadline is False, True, or a string.
+ """
+ with pytest.raises(ValueError, match="Missing data provided as False"):
+ model._process_missing_data_and_deadline(
+ missing_data=False,
+ deadline=deadline,
+ loglik_missing_data=None,
+ )
+
+
+# --- 2. Additional tests for new features and edge cases in MissingDataMixin ---
+class TestMissingDataMixinNew:
+ def test_set_missing_data_network_set(self, dummy_model):
+ # missing_data True, deadline False
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=True, deadline=False, loglik_missing_data=None
+ )
+ assert dummy_model.missing_data_network == MissingDataNetwork.CPN
+
+ # missing_data True, deadline True
+ dummy_model.data["deadline"] = 2.0
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=True, deadline=True, loglik_missing_data=None
+ )
+ assert dummy_model.missing_data_network == MissingDataNetwork.OPN
+
+ # missing_data False, deadline False (should raise ValueError due to -999.0 present)
+ with pytest.raises(ValueError, match="Missing data provided as False"):
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=False, deadline=False, loglik_missing_data=None
+ )
+
+ def test_response_appended_with_deadline_name(self, dummy_model):
+ # Should append deadline_name to response if not present
+ dummy_model.data["deadline"] = 2.0
+ dummy_model.response = ["response"]
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=True, deadline="deadline", loglik_missing_data=None
+ )
+ assert "deadline" in dummy_model.response
+
+ def test_error_on_missing_data_false_with_missing(self, dummy_model):
+ # Should raise ValueError if missing_data is False and -999.0 is present
+ with pytest.raises(ValueError, match="Missing data provided as False"):
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=False, deadline=False, loglik_missing_data=None
+ )
+
+ def test_missing_data_true_retains_missing_marker(self, dummy_model):
+ # Should retain -999.0 as missing marker if missing_data is True
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=True, deadline=False, loglik_missing_data=None
+ )
+ assert -999.0 in dummy_model.data.rt.values
+
+ def test_deadline_sets_rt_to_missing_marker(self, dummy_model):
+ # Should set rt to -999.0 if above deadline
+ # Set up so that the second RT is above its deadline
+ dummy_model.data["rt"] = [1.0, 3.0, -999.0] # 3.0 > 2.5
+ dummy_model.data["deadline"] = [1.5, 2.5, 2.5]
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=True, deadline="deadline", loglik_missing_data=None
+ )
+ # The first row rt=1.0 < 1.5, so not -999.0; second should be -999.0; third is already -999.0
+ assert dummy_model.data.rt.iloc[0] == 1.0
+ assert dummy_model.data.rt.iloc[1] == -999.0
+ assert dummy_model.data.rt.iloc[2] == -999.0
+
+ def test_loglik_missing_data_error(self, dummy_model):
+ # Should raise if loglik_missing_data is set but both missing_data and deadline are False
+ dummy_model.data.rt = [1.0, 2.0, 3.0] # No -999.0 present
+ with pytest.raises(
+ ValueError,
+ match="loglik_missing_data function, but you have not set the missing_data or deadline flag to True",
+ ):
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=False, deadline=False, loglik_missing_data=lambda x: x
+ )
+
+ def test_process_missing_data_and_deadline_updates_attributes(self, dummy_model):
+ """
+ Test that _process_missing_data_and_deadline updates missing_data, deadline, deadline_name, and loglik_missing_data.
+ """
+
+ # Set up a custom loglik function
+ def custom_loglik(x):
+ return x
+
+ # Add a custom_deadline column to the data to satisfy the mixin's requirements
+ dummy_model.data["custom_deadline"] = 2.0
+ # Call with missing_data True, deadline as string, and custom loglik
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=True,
+ deadline="custom_deadline",
+ loglik_missing_data=custom_loglik,
+ )
+ assert dummy_model.missing_data is True
+ assert dummy_model.deadline is True
+ assert dummy_model.deadline_name == "custom_deadline"
+ assert dummy_model.loglik_missing_data is custom_loglik
+
+ def test_missing_data_value_custom(self, dummy_model):
+ custom_missing = -123.0
+ # Add a row with custom missing value
+ dummy_model.data.loc[len(dummy_model.data)] = [custom_missing, 1]
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=custom_missing,
+ deadline=False,
+ loglik_missing_data=None,
+ )
+ assert dummy_model.missing_data is True
+ assert dummy_model.missing_data_value == custom_missing
+ # After processing, custom missing values are replaced with -999.0
+ assert (dummy_model.data.rt == -999.0).any()
+
+ def test_deadline_column_added_once(self, dummy_model, basic_data):
+ # Add a deadline_col to the data
+ data = basic_data.assign(deadline_col=[2.0, 2.0, 2.0])
+ dummy_model.data = data
+ # Add deadline_col to response already
+ dummy_model.response.append("deadline_col")
+ # Should raise ValueError due to -999.0 in rt when missing_data=False
+ with pytest.raises(ValueError, match="Missing data provided as False"):
+ dummy_model._process_missing_data_and_deadline(
+ missing_data=False,
+ deadline="deadline_col",
+ loglik_missing_data=None,
+ )
+
+ def test_missing_data_and_deadline_together(self, dummy_model_with_deadline):
+ # Should set both flags
+ dummy_model_with_deadline._process_missing_data_and_deadline(
+ missing_data=True,
+ deadline=True,
+ loglik_missing_data=None,
+ )
+ assert dummy_model_with_deadline.missing_data is True
+ assert dummy_model_with_deadline.deadline is True
+ assert dummy_model_with_deadline.deadline_name == "deadline"
+
+ def test_handle_missing_data_and_deadline_direct(self, dummy_model):
+ """
+ Directly test the _handle_missing_data_and_deadline method for coverage.
+ """
+ # Call with no arguments, as expected by the mixin stub
+ dummy_model._handle_missing_data_and_deadline()
+
+ def test_set_missing_data_and_deadline_edge_case(self, dummy_model):
+ all_missing = pd.DataFrame({"rt": [-999.0]})
+ with pytest.raises(ValueError, match="no valid data in your dataset"):
+ dummy_model._set_missing_data_and_deadline(
+ missing_data=True,
+ deadline=False,
+ data=all_missing,
+ )
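Note on the deadline behaviour exercised by `test_deadline_sets_rt_to_missing_marker` above: reaction times that exceed their deadline are replaced with the `-999.0` sentinel. The following is a small sketch of that masking step under stated assumptions. `apply_deadline` is a hypothetical helper, and the `>=` comparison is taken from the wording of the removed `test_handle_missing_data_and_deadline_deadline_applied` test; the mixin's actual comparison may differ.

```python
MISSING = -999.0  # sentinel value used throughout the tests above


def apply_deadline(rts, deadlines, missing=MISSING):
    """Hypothetical helper: mask RTs at or beyond their deadline."""
    # Assumption: a trial counts as missed when rt >= deadline.
    # Rows already equal to the sentinel stay unchanged, since the
    # sentinel is below any positive deadline.
    return [missing if rt >= dl else rt for rt, dl in zip(rts, deadlines)]


# Same inputs as the test: only the second RT exceeds its deadline.
masked = apply_deadline([1.0, 3.0, -999.0], [1.5, 2.5, 2.5])
```

With these inputs the first value is kept, the second is masked, and the third remains the sentinel, matching the three assertions in the test.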
diff --git a/tests/test_rl_utils.py b/tests/test_rl_utils.py
new file mode 100644
index 000000000..554fabd19
--- /dev/null
+++ b/tests/test_rl_utils.py
@@ -0,0 +1,107 @@
+"""Tests for hssm.rl.utils — validate_balanced_panel."""
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from hssm.rl.utils import validate_balanced_panel
+
+
+def _make_panel(
+ n_participants: int, n_trials: int, participant_col: str = "participant_id"
+) -> pd.DataFrame:
+ """Return a perfectly balanced, contiguous panel DataFrame."""
+ ids = np.repeat(range(n_participants), n_trials)
+ return pd.DataFrame(
+ {participant_col: ids, "rt": np.random.rand(n_participants * n_trials)}
+ )
+
+
+class TestValidateBalancedPanelHappyPath:
+ def test_returns_correct_shape(self) -> None:
+ """Returns (n_participants, n_trials) for a balanced panel."""
+ df = _make_panel(5, 20)
+ n_p, n_t = validate_balanced_panel(df)
+ assert n_p == 5
+ assert n_t == 20
+
+ def test_single_participant(self) -> None:
+ """Single-participant panel is trivially balanced."""
+ df = _make_panel(1, 10)
+ n_p, n_t = validate_balanced_panel(df)
+ assert n_p == 1
+ assert n_t == 10
+
+ def test_custom_participant_col(self) -> None:
+ """Works when participant column has a non-default name."""
+ df = _make_panel(3, 8, participant_col="subj_id")
+ n_p, n_t = validate_balanced_panel(df, participant_col="subj_id")
+ assert n_p == 3
+ assert n_t == 8
+
+
+class TestValidateBalancedPanelMissingColumn:
+ def test_missing_participant_col_raises(self) -> None:
+ """Raises ValueError when participant_col is absent from the DataFrame."""
+ df = pd.DataFrame({"rt": [0.5, 0.6]})
+ with pytest.raises(ValueError, match="not found in data"):
+ validate_balanced_panel(df)
+
+ def test_wrong_participant_col_name_raises(self) -> None:
+ """Raises ValueError when a wrong column name is supplied."""
+ df = _make_panel(2, 5)
+ with pytest.raises(ValueError, match="not found in data"):
+ validate_balanced_panel(df, participant_col="subject")
+
+
+class TestValidateBalancedPanelNaN:
+ def test_nan_participant_id_raises(self) -> None:
+ """Raises ValueError when participant_col contains NaN."""
+ df = _make_panel(3, 4)
+ df.loc[df.index[0], "participant_id"] = float("nan")
+ with pytest.raises(ValueError, match="NaN"):
+ validate_balanced_panel(df)
+
+
+class TestValidateBalancedPanelUnbalanced:
+ def test_unbalanced_panel_raises(self) -> None:
+ """Raises ValueError when participants have different trial counts."""
+ df = _make_panel(3, 5)
+ unbalanced = df.iloc[:-1].copy() # drop one row → participant 2 has 4 trials
+ with pytest.raises(ValueError, match="balanced panels"):
+ validate_balanced_panel(unbalanced)
+
+ def test_one_participant_fewer_trials_raises(self) -> None:
+ """Raises ValueError when one participant has fewer trials than others."""
+ df = _make_panel(3, 5)
+ # Drop last 2 rows of participant 2 so counts differ.
+ mask = ~(
+ (df["participant_id"] == 2)
+ & (df.index >= df[df["participant_id"] == 2].index[-2])
+ )
+ unbalanced = df[mask].copy()
+ with pytest.raises(ValueError, match="balanced panels"):
+ validate_balanced_panel(unbalanced)
+
+
+class TestValidateBalancedPanelContiguity:
+ def test_interleaved_participants_raises(self) -> None:
+ """Raises ValueError when participants' rows are interleaved (not contiguous).
+
+ The RL likelihood reshapes data as (n_participants, n_trials, ...) by row
+ position, so interleaved rows would silently corrupt trial sequences.
+ """
+ # Build interleaved data: [0, 1, 2, 0, 1, 2, ...] (3 participants × 4 trials)
+ ids = np.tile([0, 1, 2], 4)
+ df = pd.DataFrame({"participant_id": ids, "rt": np.random.rand(12)})
+ with pytest.raises(ValueError, match="contiguous"):
+ validate_balanced_panel(df)
+
+ def test_sorted_data_passes(self) -> None:
+ """Sorting an interleaved panel by participant_id makes it valid."""
+ ids = np.tile([0, 1, 2], 4)
+ df = pd.DataFrame({"participant_id": ids, "rt": np.random.rand(12)})
+ df_sorted = df.sort_values("participant_id").reset_index(drop=True)
+ n_p, n_t = validate_balanced_panel(df_sorted)
+ assert n_p == 3
+ assert n_t == 4
diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py
new file mode 100644
index 000000000..973061c73
--- /dev/null
+++ b/tests/test_rlssm.py
@@ -0,0 +1,329 @@
+"""Tests for the RLSSM class.
+
+Mirrors the structure of tests/test_hssm.py, covering initialisation,
+config validation, param keys, balanced-panel enforcement, the no-lapse
+variant, bambi / PyMC model construction, and a sampling smoke test.
+"""
+
+from collections.abc import Generator
+from pathlib import Path
+
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import pytensor
+import pytest
+
+import hssm
+from hssm.rl import RLSSM, RLSSMConfig
+from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise
+from hssm.utils import annotate_function
+
+# ---------------------------------------------------------------------------
+# Module-level annotated helpers (shared by all tests)
+# ---------------------------------------------------------------------------
+
+# Annotate the RL learning function: maps
+# (rl_alpha, scaler, response, feedback) -> v
+_compute_v_annotated = annotate_function(
+ inputs=["rl_alpha", "scaler", "response", "feedback"],
+ outputs=["v"],
+)(compute_v_subject_wise)
+
+
+# Annotated SSM log-likelihood function (simplified for testing).
+# It receives a 2-D lan_matrix whose columns correspond to
+# [v, a, z, t, theta, rt, response]
+# and returns per-trial log-probabilities of shape (n_total_trials,).
+@annotate_function(
+ inputs=["v", "a", "z", "t", "theta", "rt", "response"],
+ outputs=["logp"],
+ computed={"v": _compute_v_annotated},
+)
+def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray:
+ """Return per-trial logp (sum over columns of each row); structural tests only."""
+ # Return 1D (N,) — PyTensor declares the Op output as pt.vector(), so
+ # gradients arrive as (N,). A (N,1) return causes a VJP shape mismatch.
+ return jnp.sum(lan_matrix, axis=1)
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(scope="module", autouse=True)
+def _set_floatx_float32() -> Generator[None, None, None]:
+ """Ensure float32 is used for this module's tests, then restore previous setting."""
+ prev_floatx = pytensor.config.floatX
+ hssm.set_floatX("float32", update_jax=True)
+ try:
+ yield
+ finally:
+ hssm.set_floatX(prev_floatx, update_jax=True)
+
+
+@pytest.fixture(scope="module")
+def rldm_data() -> pd.DataFrame:
+ """Load the RLDM fixture dataset (balanced panel)."""
+ raw = np.load(
+ Path(__file__).parent / "fixtures" / "rldm_data.npy", allow_pickle=True
+ ).item()
+ return pd.DataFrame(raw["data"])
+
+
+@pytest.fixture(scope="module")
+def rlssm_config() -> RLSSMConfig:
+ """Minimal but valid RLSSMConfig for the RLDM fixture dataset."""
+ return RLSSMConfig(
+ model_name="rldm_test",
+ loglik_kind="approx_differentiable",
+ decision_process="angle",
+ decision_process_loglik_kind="approx_differentiable",
+ learning_process_kind="blackbox",
+ list_params=["rl_alpha", "scaler", "a", "theta", "t", "z"],
+ params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],
+ bounds={
+ "rl_alpha": (0.0, 1.0),
+ "scaler": (0.0, 10.0),
+ "a": (0.1, 3.0),
+ "theta": (-0.1, 0.1),
+ "t": (0.001, 1.0),
+ "z": (0.1, 0.9),
+ },
+ learning_process={"v": _compute_v_annotated},
+ response=["rt", "response"],
+ choices=[0, 1],
+ extra_fields=["feedback"],
+ ssm_logp_func=_dummy_ssm_logp,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Initialisation & config-validation tests
+# ---------------------------------------------------------------------------
+
+
+def test_rlssm_init(rldm_data, rlssm_config) -> None:
+ """Basic RLSSM initialisation should succeed and return an RLSSM instance."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ assert isinstance(model, RLSSM)
+ assert model.model_config.model_name == "rldm_test"
+
+
+def test_rlssm_panel_attrs(rldm_data, rlssm_config) -> None:
+ """n_participants and n_trials should match the fixture data structure."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+
+ n_participants = rldm_data["participant_id"].nunique()
+ n_trials = len(rldm_data) // n_participants
+
+ assert model.n_participants == n_participants
+ assert model.n_trials == n_trials
+
+
+def test_rlssm_params_keys(rldm_data, rlssm_config) -> None:
+ """model.params should contain exactly list_params + p_outlier."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ expected = set(rlssm_config.list_params) | {"p_outlier"}
+ assert set(model.params.keys()) == expected
+
+
+def test_rlssm_unbalanced_raises(rldm_data, rlssm_config) -> None:
+ """Dropping one row should make the panel unbalanced → ValueError."""
+ unbalanced = rldm_data.iloc[:-1].copy()
+ with pytest.raises(ValueError, match="balanced panels"):
+ RLSSM(data=unbalanced, model_config=rlssm_config)
+
+
+def test_rlssm_nan_participant_id_raises(rldm_data, rlssm_config) -> None:
+ """NaN in participant_id should raise ValueError before groupby silently drops rows."""
+ nan_data = rldm_data.copy()
+ nan_data.loc[nan_data.index[0], "participant_id"] = float("nan")
+ with pytest.raises(ValueError, match="NaN"):
+ RLSSM(data=nan_data, model_config=rlssm_config)
+
+
+def test_rlssm_missing_ssm_logp_func_raises(rldm_data, rlssm_config) -> None:
+ """RLSSMConfig without ssm_logp_func should raise ValueError on init."""
+ bad_config = RLSSMConfig(
+ model_name="rldm_bad",
+ loglik_kind="approx_differentiable",
+ decision_process="angle",
+ decision_process_loglik_kind="approx_differentiable",
+ learning_process_kind="blackbox",
+ list_params=rlssm_config.list_params,
+ params_default=rlssm_config.params_default,
+ bounds=rlssm_config.bounds,
+ learning_process=rlssm_config.learning_process,
+ response=list(rlssm_config.response),
+ choices=list(rlssm_config.choices),
+ extra_fields=list(rlssm_config.extra_fields),
+ # ssm_logp_func intentionally omitted → defaults to None
+ )
+ with pytest.raises(ValueError, match="ssm_logp_func"):
+ RLSSM(data=rldm_data, model_config=bad_config)
+
+
+def test_rlssm_unannotated_ssm_logp_func_raises(rldm_data, rlssm_config) -> None:
+ """A plain callable without @annotate_function attrs should raise ValueError."""
+ bad_config = RLSSMConfig(
+ model_name="rldm_bad",
+ loglik_kind="approx_differentiable",
+ decision_process="angle",
+ decision_process_loglik_kind="approx_differentiable",
+ learning_process_kind="blackbox",
+ list_params=rlssm_config.list_params,
+ params_default=rlssm_config.params_default,
+ bounds=rlssm_config.bounds,
+ learning_process=rlssm_config.learning_process,
+ response=list(rlssm_config.response),
+ choices=list(rlssm_config.choices),
+ extra_fields=list(rlssm_config.extra_fields),
+ ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed
+ )
+ with pytest.raises(ValueError, match="annotate_function"):
+ RLSSM(data=rldm_data, model_config=bad_config)
+
+
+def test_rlssm_missing_data_raises(rldm_data, rlssm_config) -> None:
+ """Passing missing_data!=False should raise NotImplementedError with 'missing_data' in msg."""
+ with pytest.raises(NotImplementedError, match="missing_data"):
+ RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True)
+
+
+def test_rlssm_deadline_raises(rldm_data, rlssm_config) -> None:
+ """Passing deadline!=False should raise NotImplementedError with 'deadline' in msg."""
+ with pytest.raises(NotImplementedError, match="deadline"):
+ RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True)
+
+
+# ---------------------------------------------------------------------------
+# Model-structure tests
+# ---------------------------------------------------------------------------
+
+
+def test_rlssm_params_is_trialwise_aligned(rldm_data, rlssm_config) -> None:
+ """params_is_trialwise must align with list_params (same length, p_outlier=False)."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ assert model.model_config.list_params is not None
+ params_is_trialwise = [
+ name != "p_outlier" for name in model.model_config.list_params
+ ]
+ assert len(params_is_trialwise) == len(model.model_config.list_params)
+ for name, is_tw in zip(model.model_config.list_params, params_is_trialwise):
+ if name == "p_outlier":
+ assert not is_tw, "p_outlier must be non-trialwise"
+ else:
+ assert is_tw, f"{name} must be trialwise"
+
+
+def test_rlssm_get_prefix(rldm_data, rlssm_config) -> None:
+ """_get_prefix must use token-based matching, not substring search.
+
+ - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param)
+ - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring)
+ - 'a_Intercept' → 'a' (single-token standard param)
+ """
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha"
+ assert model._get_prefix("p_outlier_log__") == "p_outlier"
+ assert model._get_prefix("p_outlier") == "p_outlier"
+ assert model._get_prefix("a_Intercept") == "a"
+ # Fallback: not in params
+ assert model._get_prefix("unknown_param") == "unknown_param"
+
+
+def test_rlssm_no_lapse(rldm_data, rlssm_config) -> None:
+ """Setting p_outlier=None should remove p_outlier from params."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None)
+ assert "p_outlier" not in model.params
+
+
+def test_rlssm_model_built(rldm_data, rlssm_config) -> None:
+ """The bambi model should be built and the computed param 'v' absent from params."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ assert model.model is not None
+ # rl_alpha is a free (sampled) parameter
+ assert "rl_alpha" in model.params
+ # v is computed inside the Op; it must NOT appear as a free parameter
+ assert "v" not in model.params
+
+
+def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None:
+ """extra_fields passed to make_distribution must be independent numpy copies.
+
+ to_numpy(copy=True) should return a new buffer; if it returned a view,
+ in-place mutations of the DataFrame would silently corrupt the distribution.
+ """
+ from unittest.mock import patch
+
+ from hssm.distribution_utils import make_distribution as real_make_distribution
+
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ captured: dict = {}
+
+ def capturing_make_distribution(*args, **kwargs):
+ captured["extra_fields"] = kwargs.get("extra_fields")
+ return real_make_distribution(*args, **kwargs)
+
+ with patch(
+ "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution
+ ):
+ model._make_model_distribution()
+
+ assert captured.get("extra_fields") is not None
+ for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]):
+ original = model.data[field_name].to_numpy()
+ assert not np.shares_memory(arr, original), (
+ f"extra_fields['{field_name}'] shares memory with the DataFrame — "
+ "it is a view, not a copy"
+ )
+
+
+def test_rlssm_pymc_model(rldm_data, rlssm_config) -> None:
+ """pymc_model should be accessible after model construction."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ assert model.pymc_model is not None
+
+
+# ---------------------------------------------------------------------------
+# Slow sampling smoke test
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.slow
+def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None:
+ """Minimal sampling run should return an InferenceData object."""
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ trace = model.sample(
+ draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9
+ )
+ assert trace is not None
+
+
+def test_rlssm_pickle_round_trip(
+ rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig
+) -> None:
+ """cloudpickle round-trip must reconstruct an equivalent RLSSM.
+
+ Verifies that an RLSSM survives serialisation via __getstate__ / __setstate__:
+ - The reconstructed object is a fresh RLSSM (not the same instance).
+ - n_participants and n_trials are preserved.
+ - list_params (including p_outlier) is preserved.
+ - model_config.model_name is preserved.
+ - model.model (bambi model) is rebuilt, confirming full re-initialisation.
+ """
+ import cloudpickle
+
+ model = RLSSM(data=rldm_data, model_config=rlssm_config)
+ blob = cloudpickle.dumps(model)
+ restored = cloudpickle.loads(blob)
+
+ assert restored is not model
+ assert isinstance(restored, RLSSM)
+ assert restored.n_participants == model.n_participants
+ assert restored.n_trials == model.n_trials
+ assert restored.list_params == model.list_params
+ assert restored.model_config.model_name == model.model_config.model_name
+ assert restored.model is not None
diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py
index cd80d9d40..879ae3e77 100644
--- a/tests/test_rlssm_config.py
+++ b/tests/test_rlssm_config.py
@@ -1,8 +1,14 @@
import pytest
import hssm
-from hssm.config import Config, RLSSMConfig
-from hssm.config import ModelConfig
+from hssm.config import (
+ DEFAULT_SSM_CHOICES,
+ DEFAULT_SSM_OBSERVED_DATA,
+ Config,
+ ModelConfig,
+)
+from hssm.rl import RLSSMConfig
+from hssm.utils import annotate_function
# Define constants for repeated data structures
DEFAULT_RESPONSE = ("rt", "response")
@@ -16,6 +22,13 @@
}
+# Module-level annotated dummy used wherever from_rlssm_dict needs a valid
+# ssm_logp_func but the test is not about ssm_logp_func itself.
+@annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={})
+def _module_dummy_ssm_logp(x):
+ return x
+
+
# Helper function to create a config dictionary
def create_config_dict(
model_name,
@@ -28,7 +41,8 @@ def create_config_dict(
learning_process={},
decision_process="ddm",
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
+ ssm_logp_func=_module_dummy_ssm_logp,
):
return dict(
model_name=model_name,
@@ -43,7 +57,8 @@ def create_config_dict(
learning_process=learning_process,
decision_process=decision_process,
decision_process_loglik_kind=decision_process_loglik_kind,
- learning_process_loglik_kind=learning_process_loglik_kind,
+ learning_process_kind=learning_process_kind,
+ ssm_logp_func=ssm_logp_func,
data={},
)
@@ -51,17 +66,23 @@ def create_config_dict(
# region fixtures and helpers
@pytest.fixture
def valid_rlssmconfig_kwargs():
+ @annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={})
+ def _dummy_ssm_logp_func(x):
+ return x
+
return dict(
model_name="test_model",
list_params=["alpha", "beta"],
params_default=[0.5, 0.3],
+ bounds={"alpha": (0.0, 1.0), "beta": (0.0, 1.0)},
decision_process="ddm",
response=["rt", "response"],
choices=[0, 1],
extra_fields=["feedback"],
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
learning_process={},
+ ssm_logp_func=_dummy_ssm_logp_func,
)
@@ -149,7 +170,7 @@ def test_from_rlssm_dict_cases(
expected_choices,
expected_learning_process,
):
- config = RLSSMConfig.from_rlssm_dict(model_name, config_dict)
+ config = RLSSMConfig.from_rlssm_dict(config_dict)
assert config.model_name == expected_model_name
assert config.params_default == expected_params_default
assert config.bounds == expected_bounds
@@ -164,8 +185,9 @@ class TestRLSSMConfigValidation:
[
("response", None, "Please provide `response` columns"),
("list_params", None, "Please provide `list_params"),
- ("choices", None, "Please provide `choices"),
- ("decision_process", None, "Please specify a `decision_process"),
+ ("choices", None, "Please provide `choices`"),
+ ("decision_process", None, "Please specify a `decision_process`"),
+ ("ssm_logp_func", None, "Please provide `ssm_logp_func`"),
],
)
def test_validate_missing_fields(
@@ -184,7 +206,7 @@ def test_validate_missing_fields(
"params_default",
"decision_process",
"decision_process_loglik_kind",
- "learning_process_loglik_kind",
+ "learning_process_kind",
"learning_process",
],
)
@@ -201,17 +223,12 @@ def test_validate_success(self, valid_rlssmconfig_kwargs):
config = RLSSMConfig(**valid_rlssmconfig_kwargs)
config.validate()
- def test_validate_params_default_mismatch(self):
+ def test_validate_params_default_mismatch(self, valid_rlssmconfig_kwargs):
config = RLSSMConfig(
- model_name="test_model",
- list_params=["alpha", "beta"],
- params_default=[0.5],
- decision_process="ddm",
- response=["rt", "response"],
- choices=[0, 1],
- decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
- learning_process={},
+ **{
+ **valid_rlssmconfig_kwargs,
+ "params_default": [0.5], # length 1, but list_params has 2 entries
+ }
)
with pytest.raises(
ValueError,
@@ -219,21 +236,84 @@ def test_validate_params_default_mismatch(self):
):
config.validate()
+ def test_validate_ssm_logp_func_not_callable(self, valid_rlssmconfig_kwargs):
+ config = RLSSMConfig(**valid_rlssmconfig_kwargs)
+ config.ssm_logp_func = "not_a_callable"
+ with pytest.raises(ValueError, match="must be a callable"):
+ config.validate()
+
+ def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwargs):
+ config = RLSSMConfig(**valid_rlssmconfig_kwargs)
+ # Replace with a plain callable that lacks @annotate_function attributes
+ config.ssm_logp_func = lambda x: x
+ with pytest.raises(
+ ValueError, match="must be decorated with `@annotate_function`"
+ ):
+ config.validate()
+
+ def test_validate_ssm_logp_func_computed_not_callable(
+ self, valid_rlssmconfig_kwargs
+ ):
+ """`computed` exists but contains non-callable values -> error."""
+ config = RLSSMConfig(**valid_rlssmconfig_kwargs)
+ # Inject a computed mapping with a non-callable value to trigger the
+ # specific validation branch.
+ config.ssm_logp_func.computed = {"x": "not_callable"}
+ with pytest.raises(
+ ValueError,
+ match=r"`ssm_logp_func.computed` must be a dictionary with callable values\.",
+ ):
+ config.validate()
+
+ def test_validate_missing_bounds_for_param(self, valid_rlssmconfig_kwargs):
+ """validate() should raise early when a param has no bounds entry."""
+ kwargs = {**valid_rlssmconfig_kwargs, "bounds": {}} # strip all bounds
+ config = RLSSMConfig(**kwargs)
+ with pytest.raises(ValueError, match="Missing bounds for parameter"):
+ config.validate()
+
+ def test_from_defaults_raises(self):
+ """RLSSMConfig.from_defaults() must raise NotImplementedError."""
+ with pytest.raises(NotImplementedError, match="from_defaults"):
+ RLSSMConfig.from_defaults("ddm", None)
+
class TestRLSSMConfigDefaults:
@pytest.mark.parametrize(
"list_params, params_default, bounds, param, expected_default, expected_bounds",
[
+ # params_default stores initialisation values, NOT priors.
+ # get_defaults always returns None for the prior so that
+ # prior_settings="safe" can assign priors from bounds.
+ #
+ # Case 1: queried param is present in bounds → bound returned.
(
["alpha", "beta", "gamma"],
[0.5, 0.3, 0.2],
- {"beta": (0.0, 1.0)},
+ {"alpha": (0.0, 1.0), "beta": (0.0, 1.0), "gamma": (0.0, 1.0)},
"beta",
- 0.3,
+ None,
+ (0.0, 1.0),
+ ),
+ # Case 2: queried param is NOT in list_params and NOT in bounds
+ # (e.g. a typo or an extra lookup) → both None.
+ (
+ ["alpha", "beta"],
+ [0.5, 0.3],
+ {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)},
+ "gamma",
+ None,
+ None,
+ ),
+ # Case 3: params_default may be empty; param in bounds → bound returned.
+ (
+ ["alpha", "beta"],
+ [],
+ {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)},
+ "alpha",
+ None,
(0.0, 1.0),
),
- (["alpha", "beta"], [0.5, 0.3], {"alpha": (0.0, 1.0)}, "gamma", None, None),
- (["alpha", "beta"], [], {"alpha": (0.0, 1.0)}, "alpha", None, (0.0, 1.0)),
],
)
def test_get_defaults_cases(
@@ -254,7 +334,7 @@ def test_get_defaults_cases(
response=["rt", "response"],
choices=[0, 1],
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
learning_process={},
)
default_val, bounds_val = config.get_defaults(param)
@@ -262,145 +342,6 @@ def test_get_defaults_cases(
assert bounds_val == expected_bounds
-class TestRLSSMConfigConversion:
- @pytest.mark.parametrize(
- "list_params, params_default, backend, expected_backend, expected_default_priors, raises",
- [
- (
- ["alpha", "beta", "v", "a"],
- [0.5, 0.3, 1.0, 1.5],
- "jax",
- "jax",
- {"alpha": 0.5, "beta": 0.3, "v": 1.0, "a": 1.5},
- None,
- ),
- (["alpha"], [0.5], None, "jax", {"alpha": 0.5}, None),
- (["alpha", "beta"], [], None, "jax", {}, None),
- (["alpha", "beta", "gamma"], [0.5, 0.3], None, None, None, ValueError),
- ],
- )
- def test_to_config_cases(
- self,
- list_params,
- params_default,
- backend,
- expected_backend,
- expected_default_priors,
- raises,
- ):
- rlssm_config = RLSSMConfig(
- model_name="test_model",
- list_params=list_params,
- params_default=params_default,
- decision_process="ddm",
- response=["rt", "response"],
- choices=[0, 1],
- backend=backend,
- decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
- learning_process={},
- )
- if raises:
- with pytest.raises(raises):
- rlssm_config.to_config()
- else:
- config = rlssm_config.to_config()
- assert config.backend == expected_backend
- assert config.default_priors == expected_default_priors
-
- def test_to_config(self):
- rlssm_config = RLSSMConfig(
- model_name="rlwm",
- description="RLWM model",
- list_params=["alpha", "beta", "v", "a"],
- params_default=[0.5, 0.3, 1.0, 1.5],
- bounds={
- "alpha": (0.0, 1.0),
- "beta": (0.0, 1.0),
- "v": (-3.0, 3.0),
- "a": (0.3, 2.5),
- },
- decision_process="ddm",
- response=["rt", "response"],
- choices=[0, 1],
- extra_fields=["feedback"],
- backend="jax",
- decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
- learning_process={},
- )
- config = rlssm_config.to_config()
- assert isinstance(config, Config)
- assert config.model_name == "rlwm"
- assert config.description == "RLWM model"
- assert config.list_params == ["alpha", "beta", "v", "a"]
- assert config.response == ["rt", "response"]
- assert config.choices == [0, 1]
- assert config.extra_fields == ["feedback"]
- assert config.backend == "jax"
- assert config.loglik_kind == "approx_differentiable"
- assert config.bounds == {
- "alpha": (0.0, 1.0),
- "beta": (0.0, 1.0),
- "v": (-3.0, 3.0),
- "a": (0.3, 2.5),
- }
- assert config.default_priors == {
- "alpha": 0.5,
- "beta": 0.3,
- "v": 1.0,
- "a": 1.5,
- }
-
- def test_to_config_defaults_backend(self):
- rlssm_config = RLSSMConfig(
- model_name="test_model",
- list_params=["alpha"],
- params_default=[0.5],
- decision_process="ddm",
- response=["rt", "response"],
- choices=[0, 1],
- decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
- learning_process={},
- )
- config = rlssm_config.to_config()
- assert config.backend == "jax"
-
- def test_to_config_no_defaults(self):
- rlssm_config = RLSSMConfig(
- model_name="test_model",
- list_params=["alpha", "beta"],
- params_default=[],
- decision_process="ddm",
- response=["rt", "response"],
- choices=[0, 1],
- decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
- learning_process={},
- )
- config = rlssm_config.to_config()
- assert config.default_priors == {}
-
- def test_to_config_mismatched_defaults_length(self):
- rlssm_config = RLSSMConfig(
- model_name="test_model",
- list_params=["alpha", "beta", "gamma"],
- params_default=[0.5, 0.3],
- decision_process="ddm",
- response=["rt", "response"],
- choices=[0, 1],
- decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
- learning_process={},
- )
- with pytest.raises(
- ValueError,
- match=r"params_default length \(2\) doesn't match list_params length \(3\)",
- ):
- rlssm_config.to_config()
-
-
class TestRLSSMConfigLearningProcess:
def test_learning_process(self):
config = RLSSMConfig(
@@ -412,7 +353,7 @@ def test_learning_process(self):
choices=[0, 1],
learning_process={"v": v_func, "a": a_func},
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
)
assert "v" in config.learning_process
assert "a" in config.learning_process
@@ -429,7 +370,7 @@ def test_immutable_defaults(self):
choices=[0, 1],
learning_process={"v": v_func},
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
)
config2 = RLSSMConfig(
model_name="model2",
@@ -440,7 +381,7 @@ def test_immutable_defaults(self):
choices=[0, 1],
learning_process={"a": a_func},
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
)
config1.learning_process["v"] = "function1"
assert "v" not in config2.learning_process
@@ -457,18 +398,40 @@ def test_from_rlssm_dict_missing_required(self):
"params_default": [0.0],
"decision_process": "ddm",
"learning_process": {},
- "learning_process_loglik_kind": "blackbox",
+ "learning_process_kind": "blackbox",
"response": ["rt", "response"],
"choices": [0, 1],
"description": "desc",
"bounds": {},
"data": {},
"extra_fields": [],
+ "ssm_logp_func": _module_dummy_ssm_logp,
}
with pytest.raises(
ValueError, match="decision_process_loglik_kind must be provided"
):
- RLSSMConfig.from_rlssm_dict("test_model", config_dict)
+ RLSSMConfig.from_rlssm_dict(config_dict)
+
+ def test_from_rlssm_dict_missing_ssm_logp_func(self):
+ # Should raise ValueError at construction time if ssm_logp_func is missing
+ config_dict = {
+ "model_name": "test_model",
+ "name": "test_model",
+ "list_params": ["alpha"],
+ "params_default": [0.0],
+ "decision_process": "ddm",
+ "learning_process": {},
+ "learning_process_kind": "blackbox",
+ "decision_process_loglik_kind": "analytical",
+ "response": ["rt", "response"],
+ "choices": [0, 1],
+ "description": "desc",
+ "bounds": {},
+ "data": {},
+ "extra_fields": [],
+ }
+ with pytest.raises(ValueError, match="ssm_logp_func must be provided"):
+ RLSSMConfig.from_rlssm_dict(config_dict)
def test_missing_decision_process_loglik_kind(self):
with pytest.raises(TypeError):
@@ -488,15 +451,16 @@ def test_missing_decision_process_loglik_kind(self):
"data": {},
"decision_process": "ddm",
"learning_process": {},
- "learning_process_loglik_kind": "blackbox",
+ "learning_process_kind": "blackbox",
"response": ["rt", "response"],
"choices": [0, 1],
"extra_fields": [],
+ "ssm_logp_func": _module_dummy_ssm_logp,
}
with pytest.raises(
ValueError, match="decision_process_loglik_kind must be provided"
):
- RLSSMConfig.from_rlssm_dict("test_model", config_dict)
+ RLSSMConfig.from_rlssm_dict(config_dict)
def test_with_modelconfig_decision_process(self):
decision_config = ModelConfig(
@@ -512,6 +476,57 @@ def test_with_modelconfig_decision_process(self):
response=["rt", "response"],
choices=[0, 1],
decision_process_loglik_kind="analytical",
- learning_process_loglik_kind="blackbox",
+ learning_process_kind="blackbox",
learning_process={},
)
+
+
+class TestRLSSMConfigDefaultWarnings:
+ """Warnings are emitted when 'response' or 'choices' is missing from config_dict."""
+
+ @pytest.fixture
+ def _base_config_dict(self):
+ return {
+ "model_name": "test_model",
+ "list_params": ["alpha"],
+ "params_default": [0.0],
+ "bounds": {"alpha": (0.0, 1.0)},
+ "decision_process": "ddm",
+ "learning_process": {},
+ "learning_process_kind": "blackbox",
+ "decision_process_loglik_kind": "analytical",
+ "extra_fields": [],
+ "ssm_logp_func": _module_dummy_ssm_logp,
+ }
+
+ def test_warns_when_response_missing(self, _base_config_dict, caplog):
+ _base_config_dict["choices"] = (0, 1)
+ # 'response' deliberately omitted
+ with caplog.at_level("WARNING", logger="hssm"):
+ config = RLSSMConfig.from_rlssm_dict(_base_config_dict)
+ assert any("'response' not specified" in m for m in caplog.messages)
+ assert config.response == list(DEFAULT_SSM_OBSERVED_DATA)
+
+ def test_warns_when_choices_missing(self, _base_config_dict, caplog):
+ _base_config_dict["response"] = ["rt", "response"]
+ # 'choices' deliberately omitted
+ with caplog.at_level("WARNING", logger="hssm"):
+ config = RLSSMConfig.from_rlssm_dict(_base_config_dict)
+ assert any("'choices' not specified" in m for m in caplog.messages)
+ assert config.choices == DEFAULT_SSM_CHOICES
+
+ def test_warns_when_both_missing(self, _base_config_dict, caplog):
+ # Both 'response' and 'choices' omitted
+ with caplog.at_level("WARNING", logger="hssm"):
+ config = RLSSMConfig.from_rlssm_dict(_base_config_dict)
+ assert any("'response' not specified" in m for m in caplog.messages)
+ assert any("'choices' not specified" in m for m in caplog.messages)
+ assert config.response == list(DEFAULT_SSM_OBSERVED_DATA)
+ assert config.choices == DEFAULT_SSM_CHOICES
+
+ def test_no_warning_when_both_provided(self, _base_config_dict, caplog):
+ _base_config_dict["response"] = ["rt", "response"]
+ _base_config_dict["choices"] = (0, 1)
+ with caplog.at_level("WARNING", logger="hssm"):
+ RLSSMConfig.from_rlssm_dict(_base_config_dict)
+ assert not any("not specified" in m for m in caplog.messages)
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
index b7b48d57f..60244cf7d 100644
--- a/tests/test_save_load.py
+++ b/tests/test_save_load.py
@@ -10,7 +10,9 @@ def compare_hssm_class_attributes(model_a, model_b):
b = np.array([type(v) for k, v in model_b._init_args.items()])
assert (a == b).all(), "Init arg types not the same"
assert (model_a.data).equals(model_b.data), "Data not the same"
- assert model_a.model_name == model_b.model_name, "Model name not the same"
+ assert model_a.model_config.model_name == model_b.model_config.model_name, (
+ "Model name not the same"
+ )
assert model_a.pymc_model._repr_latex_() == model_b.pymc_model._repr_latex_(), (
"Latex representation of model not the same"
)