From b8a8adba13c58b42f3615308915630420701ae2b Mon Sep 17 00:00:00 2001 From: AndrewZhang599 Date: Sun, 3 May 2026 18:44:47 -0400 Subject: [PATCH 1/4] Plan for HSSM integration --- addm_andrew_dev /addm_hssm.md | 258 ++++++++++++++++++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 addm_andrew_dev /addm_hssm.md diff --git a/addm_andrew_dev /addm_hssm.md b/addm_andrew_dev /addm_hssm.md new file mode 100644 index 00000000..405a2a84 --- /dev/null +++ b/addm_andrew_dev /addm_hssm.md @@ -0,0 +1,258 @@ +# Plan: Integrating aDDM into HSSM + +## Context + +The attentional drift diffusion model (aDDM; Krajbich et al.) extends the standard DDM by modulating the drift rate based on **which option the subject is currently fixating**. A fast, differentiable JAX likelihood for the aDDM has been prototyped in the sibling repo [efficient-fpt](data/azhang/efficient-fpt) — specifically `get_addm_fptd_jax_fast` in [src/efficient_fpt_jax/multi_stage.py](data/azhang/efficient-fpt/src/efficient_fpt_jax/multi_stage.py). + +**Integration approach: vendor, do not depend.** Rather than add `efficient-fpt` as a dependency (it is not on PyPI, it ships compiled Cython for the simulator path that HSSM does not need, and bringing it in pulls a heavy build chain into HSSM's install), we **copy the relevant pure-JAX modules into HSSM** and own them going forward. HSSM already depends on `jax`/`jaxlib`, so the vendored code adds zero new transitive dependencies. The simulator (Cython) and NumPy/CPU paths from efficient-fpt are *not* vendored; only the `efficient_fpt_jax` subpackage's likelihood code. + +The goal is to wire the aDDM into HSSM so that users can write: + +```python +model = hssm.HSSM(model="addm", data=addm_trial_df, include=[...]) +model.sample() +``` + +where `addm_trial_df` contains standard columns (`rt`, `response`) plus **aDDM-specific per-trial arrays** (item values, fixation onsets, fixation counts, first-fixation flag). 
The aDDM needs per-trial covariates that are *not* themselves sampled parameters — exactly the pattern RLSSM already solves in HSSM. We therefore follow the RLSSM design so that aDDM lives alongside it rather than carving a new architectural lane. + +The intended outcome is a working `model="addm"` path inside HSSM that (a) validates aDDM-specific trial data, (b) composes the vendored JAX FPT likelihood with sampled parameters `{eta, kappa, sigma, a, b, x0, t}` (non-decision time optional), (c) exposes the standard HSSM hierarchical regression and sampling machinery, and (d) ships with a tutorial notebook and unit tests. + +## Design choice: config pattern, not subclass + +The plan creates an **`aDDMConfig`** dataclass plus a small submodule — no new `aDDM(HSSM)` subclass is introduced. (If the user prefers an explicit subclass for API discoverability, a thin `class aDDM(HSSM)` wrapper can be added on top.) + +--- + +## Step-by-step plan + +### Step 1 — Vendor the JAX likelihood code into HSSM + +**Source files to copy (from efficient-fpt):** + +| Source (efficient-fpt) | Destination (HSSM) | Purpose | +|---|---|---| +| `src/efficient_fpt_jax/multi_stage.py` | `src/hssm/addm/likelihoods/jax/multi_stage.py` | `get_addm_fptd_jax_fast`, `pad_sacc_array_safely` | +| `src/efficient_fpt_jax/single_stage.py` | `src/hssm/addm/likelihoods/jax/single_stage.py` | `fptd_single_jax`, `q_single_jax` (called by multi_stage) | +| `src/efficient_fpt_jax/utils.py` | `src/hssm/addm/likelihoods/jax/utils.py` | `GAUSS_LEGENDRE_30_X`, `GAUSS_LEGENDRE_30_W` quadrature constants | + +**Action:** + +1. Copy the three files above verbatim into a new `src/hssm/addm/likelihoods/jax/` package. +2. Update relative imports inside the copied files so they resolve within `hssm.addm.likelihoods.jax` (e.g. `from .single_stage import ...`, `from .utils import ...` — these are already relative in the source, so no change needed in practice; verify). +3. 
Create `src/hssm/addm/likelihoods/jax/__init__.py` exposing only the symbols HSSM needs: + ```python + from .multi_stage import get_addm_fptd_jax_fast, pad_sacc_array_safely + from .single_stage import fptd_single_jax, q_single_jax + from .utils import GAUSS_LEGENDRE_30_X, GAUSS_LEGENDRE_30_W + ``` +4. Add a header comment to each vendored file recording the upstream commit hash from `efficient-fpt` so future maintainers can diff against upstream when bug fixes land there. +5. **Do not** vendor `efficient_fpt_jax/batch.py` (not used by the per-trial likelihood we wrap), nor anything from the `efficient_fpt` (Cython/NumPy) subpackage — that path is the simulator and is not part of inference. +6. **Do not** add `efficient-fpt` to `pyproject.toml`. HSSM already depends on `jax`/`jaxlib`, so the vendored code introduces no new dependencies. + +**License/attribution:** efficient-fpt ships under a permissive license (see `efficient-fpt/LICENSE`); copy that license text into `src/hssm/addm/likelihoods/jax/LICENSE` (or `NOTICE`) so attribution travels with the code. If HSSM and efficient-fpt share authors/license, a brief `# Adapted from efficient-fpt (commit )` header is sufficient. + +**Rationale:** efficient-fpt is not on PyPI; its Cython compile chain is heavyweight and irrelevant to HSSM (HSSM only needs the inference likelihood, not the simulator); and pinning to a remote git dep would couple HSSM's CI to an unstable upstream. Vendoring lets HSSM ship a frozen, audited copy that evolves on HSSM's release cadence. + +**Drift management:** efficient-fpt continues to be the research home for the likelihood. When upstream changes, the vendored copy can be re-synced by re-copying the three files and rebuilding tests. The upstream-commit header in step 4 makes "what version are we on?" trivially answerable. (This is the same pattern HSSM already uses for `bayesflow` from a dev branch, just snapshotted instead of git-tracked.) 
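A small check can keep the upstream-commit headers from Step 1 honest. The sketch below is illustrative only: the header wording follows the `# Adapted from efficient-fpt (commit ...)` form mentioned above, but the exact regex, the `vendored_commits` and `check_vendored` helper names, and the idea of running this in CI are assumptions, not part of HSSM or efficient-fpt.

```python
import re
from pathlib import Path

# Assumed header format (hypothetical); the plan only asks that the upstream
# commit hash be recorded at the top of each vendored file.
HEADER_RE = re.compile(r"^# Adapted from efficient-fpt \(commit ([0-9a-f]{7,40})\)")
VENDORED_FILES = ("multi_stage.py", "single_stage.py", "utils.py")


def vendored_commits(pkg_dir):
    """Map each vendored file to the commit hash on its first line (None if absent)."""
    commits = {}
    for name in VENDORED_FILES:
        path = Path(pkg_dir) / name
        lines = path.read_text(encoding="utf-8").splitlines() if path.exists() else []
        match = HEADER_RE.match(lines[0]) if lines else None
        commits[name] = match.group(1) if match else None
    return commits


def check_vendored(pkg_dir):
    """Return (files lacking a header, whether all recorded hashes agree)."""
    commits = vendored_commits(pkg_dir)
    missing = sorted(name for name, sha in commits.items() if sha is None)
    recorded = {sha for sha in commits.values() if sha is not None}
    return missing, len(recorded) <= 1
```

Calling `check_vendored("src/hssm/addm/likelihoods/jax")` after a re-sync would flag any file whose header was lost, or a partial re-sync that left the three files on different upstream commits.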
+ +### Step 2 — Define the aDDM submodule layout + +Create a new package under `src/hssm/addm/`, mirroring `src/hssm/rl/`: + +``` +src/hssm/addm/ + __init__.py + likelihoods/ + __init__.py + builder.py # make_addm_logp_func / make_addm_logp_op + addm_jax.py # thin wrapper that imports from .jax and applies the attention process + jax/ # vendored from efficient_fpt_jax (Step 1) + __init__.py + multi_stage.py + single_stage.py + utils.py + attention_process.py # pluggable fixation/attention models +``` + +**Rationale:** aDDM is conceptually a two-stage model (attention process → SSM likelihood) just like RLSSM (learning process → SSM likelihood). Reusing the folder layout makes the parallel obvious to future maintainers. The vendored JAX code lives in its own `jax/` subdirectory so it stays clearly identifiable as upstream-derived, isolated from HSSM-original code in `builder.py` and `addm_jax.py`. + +### Step 3 — Add `aDDMConfig` dataclass in `config.py` + +**Critical file:** [src/hssm/config.py](data/azhang/HSSM/src/hssm/config.py) — add a new dataclass beneath `RLSSMConfig` (around line 457). + +```python +@dataclass +class aDDMConfig(BaseModelConfig): + """Config for the attentional DDM.""" + model_name: str = "addm" + list_params: list[str] = field( + default_factory=lambda: ["eta", "kappa", "sigma", "a", "b", "x0"] + ) + params_default: list[float] = field( + default_factory=lambda: [0.3, 1.0, 1.0, 2.0, 0.0, 0.0] + ) + response: list[str] = field(default_factory=lambda: ["rt", "response"]) + choices: tuple[int, ...] 
= (-1, 1) + # trial-level covariates consumed by the attention process: + extra_fields: list[str] | None = field( + default_factory=lambda: ["r1", "r2", "sacc_array", "d", "flag"] + ) + bounds: dict[str, tuple[float, float]] = field(default_factory=dict) + loglik_kind: str = "approx_differentiable" + attention_process: str | Callable = "standard_alternating" + description: str | None = "Attentional Drift Diffusion Model" + + def to_config(self) -> Config: ... +``` + +**Key design decisions:** +- `extra_fields` defaults to the five aDDM-specific columns that the JAX likelihood needs. These are **not** sampled parameters — they come from the data. +- `attention_process` is a pluggable hook (default `"standard_alternating"`) that maps `(r1, r2, flag, eta, kappa) → mu_array_padded` per trial. This mirrors `RLSSMConfig.learning_process`. +- `list_params` covers the sampled parameters. Non-decision time `t` is deliberately omitted initially; it can be added later via shifted RTs. +- `.to_config()` builds a standard HSSM `Config` pointing at the new likelihood op from Step 4, so downstream `HSSM.__init__` behavior is unchanged. + +**Reuses:** `BaseModelConfig` (config.py:48), `Config.from_defaults` registration flow (config.py:96–145), `register_model` (register.py:16–60). + +### Step 4 — Build the likelihood op in `addm/likelihoods/builder.py` + +Mirror [hssm/rl/likelihoods/builder.py](data/azhang/HSSM/src/hssm/rl/likelihoods/builder.py) (`make_rl_logp_func`, `make_rl_logp_op`). + +Two functions: + +1. **`make_addm_logp_func(attention_process)`** — returns a callable `logp(data, *args)` where: + - `data[:, 0]` = rt, `data[:, 1]` = response + - `args` are sampled parameters in `list_params` order: `eta, kappa, sigma, a, b, x0` + - extra fields `r1, r2, sacc_array, d, flag` are appended to `args` by the HSSM extra-field machinery (exactly as RLSSM does; see [data_validator.py:156](data/azhang/HSSM/src/hssm/data_validator.py#L156)). 
+ - Internally: call the attention process to build `mu_array_padded`, then call `get_addm_fptd_jax_fast(t=rt, d=d, mu_array=mu_array_padded, sacc_array=sacc_array, sigma=sigma, a=a, b=b, x0=x0)` vmapped over trials. + +2. **`make_addm_logp_op(attention_process)`** — wraps the JAX logp as a PyTensor `Op` with VJP, using the same pattern as `make_rl_logp_op`. Because the Op carries a VJP, NUTS gets gradients without hand-written derivatives. + +**Reuses:** +- `get_addm_fptd_jax_fast` and `pad_sacc_array_safely` — imported from the vendored `hssm.addm.likelihoods.jax` package (Step 1). +- `make_likelihood_callable` ([distribution_utils/dist.py:718](data/azhang/HSSM/src/hssm/distribution_utils/dist.py)) for PyTensor wrapping conventions. +- `apply_param_bounds_to_loglik` ([distribution_utils/dist.py:40–79](data/azhang/HSSM/src/hssm/distribution_utils/dist.py)) for parameter-bound enforcement. + +**Import contract:** `builder.py` imports the vendored likelihood as `from hssm.addm.likelihoods.jax import get_addm_fptd_jax_fast, pad_sacc_array_safely`. No code outside `hssm.addm` should import from the vendored `jax/` subpackage directly — keeping the import surface narrow makes a future re-vendor (or replacement with an upstream PyPI release, should one ever appear) a single-file change. + +### Step 5 — Attention process ("learning process" analog) + +File: `src/hssm/addm/attention_process.py`. + +Default implementation `standard_alternating(r1, r2, flag, eta, kappa, max_d) -> mu_padded`: + +``` +mu1 = kappa * (r1 - eta * r2) +mu2 = kappa * (eta * r1 - r2) +# alternate mu1/mu2 by stage parity, respecting `flag` for first fixation +# return shape (n_trials, max_d) +``` + +This reproduces the logic in [efficient-fpt addm.py:_build_mu_data_padded](data/azhang/efficient-fpt/src/efficient_fpt/addm.py) but in JAX for autodiff. Expose it via a registry so future variants (e.g., bias, drift offsets) can be registered by name — the same way `RLSSMConfig.learning_process` accepts either a string or dict. 
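The default attention process described above can be sketched in a few lines. This is a minimal NumPy illustration, not the upstream `_build_mu_data_padded` implementation. It assumes that `flag == 0` means the first fixation is on item 1 (the real encoding should be checked against efficient-fpt) and that `mu1` is the drift while item 1 is fixated. Only `arange`, `where`, and broadcasting are used, so a JAX version should amount to swapping `np` for `jax.numpy`.

```python
import numpy as np


def standard_alternating(r1, r2, flag, eta, kappa, max_d):
    """Return an (n_trials, max_d) array of per-stage drift rates."""
    r1 = np.asarray(r1, dtype=float)
    r2 = np.asarray(r2, dtype=float)
    flag = np.asarray(flag)

    mu1 = kappa * (r1 - eta * r2)  # drift while item 1 is fixated (assumed)
    mu2 = kappa * (eta * r1 - r2)  # drift while item 2 is fixated (assumed)

    # Stage indices 0..max_d-1, broadcast against trials.
    even_stage = (np.arange(max_d) % 2 == 0)[None, :]
    # ASSUMPTION: flag == 0 means the first fixation lands on item 1.
    first_is_item1 = (flag == 0)[:, None]

    # Item 1 is fixated on even stages when it was fixated first,
    # and on odd stages otherwise.
    use_mu1 = even_stage == first_is_item1
    return np.where(use_mu1, mu1[:, None], mu2[:, None])
```

Stages beyond a trial's fixation count `d` still receive alternating values here; the sketch assumes the likelihood ignores those padded entries.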
+ +### Step 6 — Register `"addm"` as a built-in model + +**Critical files:** +- [src/hssm/modelconfig/](data/azhang/HSSM/src/hssm/modelconfig/) — add `addm_config.py` in the same style as the existing per-model configs (e.g., `ddm_config.py`). +- [src/hssm/defaults.py](data/azhang/HSSM/src/hssm/defaults.py) — register `"addm"` in the default model list so `hssm.HSSM(model="addm", ...)` works out of the box. + +`addm_config.py` returns a dict with `response`, `list_params`, `choices`, `description`, and a `likelihoods` sub-dict keyed `"approx_differentiable"` whose `loglik` points to `make_addm_logp_op(...)` from Step 4 and `extra_fields=["r1","r2","sacc_array","d","flag"]`. + +**Reuses:** `register_model` (register.py:16–60) — already handles the registration flow; we just need to pass the right dict. + +### Step 7 — Data validation for aDDM-specific columns + +**Critical file:** [src/hssm/data_validator.py](data/azhang/HSSM/src/hssm/data_validator.py). + +The DataValidatorMixin currently validates that `extra_fields` columns exist ([line 46](data/azhang/HSSM/src/hssm/data_validator.py#L46)). aDDM also needs **shape validation** because `sacc_array` is a 2D array-of-arrays stored inside a DataFrame column. + +Add an optional `_validate_addm_columns()` method invoked when `model_config.model_name == "addm"`: +- `r1`, `r2`, `d`, `flag` must be 1D numeric with length `n_trials`. +- `sacc_array` must be a 2D array of shape `(n_trials, max_d)` (or a column of variable-length lists, padded internally via `pad_sacc_array_safely`). +- `d[i] <= sacc_array.shape[1]`. + +Minimally invasive: put the hook in `_post_check_data_sanity` and no-op for non-aDDM models. + +### Step 8 — Tests + +New file: `tests/test_addm_config.py`, patterned after [tests/test_rlssm_config.py](data/azhang/HSSM/tests/test_rlssm_config.py): + +1. `TestaDDMConfigCreation` — build `aDDMConfig`, assert defaults. +2. `TestaDDMConfigConversion` — `.to_config()` round-trip. +3. 
`TestaDDMLikelihood` — tiny synthetic dataset (10 trials); confirm that `logp` is finite, that the gradient w.r.t. each parameter is finite, and that the value matches a direct call to `get_addm_fptd_jax_fast`. +4. `TestaDDMEndToEnd` — 200-trial synthetic dataset, `hssm.HSSM(model="addm", ...)` builds, and a short MCMC run succeeds (smoke test, `draws=5, tune=5`). + +**Reuse:** test fixtures from `tests/conftest.py`. + +### Step 9 — Tutorial notebook + +Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tutorials/rlssm_tutorial.ipynb`: +- Load/simulate a small aDDM dataset (reuse `simulate_addm` from efficient-fpt example6). +- Build the HSSM model with `model="addm"`. +- Add a hierarchical regression on `eta` (e.g., by participant) to show what HSSM adds over calling efficient-fpt directly. +- Run `model.sample()` and plot posteriors via `arviz`. + +### Step 10 — Cleanup + +- Delete or rename the stale [addm_andrew_dev](data/azhang/HSSM/addm_andrew_dev) folder (it has a trailing space in its name, which is error-prone in shells and on some filesystems) once the new module is working. +- Update `README.md` example list and `mkdocs.yml` nav to include the new tutorial. 
+ +--- + +## Files to be created + +**HSSM-original code:** +- `src/hssm/addm/__init__.py` +- `src/hssm/addm/attention_process.py` +- `src/hssm/addm/likelihoods/__init__.py` +- `src/hssm/addm/likelihoods/builder.py` +- `src/hssm/addm/likelihoods/addm_jax.py` +- `src/hssm/modelconfig/addm_config.py` +- `tests/test_addm_config.py` +- `docs/tutorials/addm_tutorial.ipynb` + +**Vendored from efficient-fpt (verbatim copies, kept in their own subpackage):** +- `src/hssm/addm/likelihoods/jax/__init__.py` +- `src/hssm/addm/likelihoods/jax/multi_stage.py` ← `efficient_fpt_jax/multi_stage.py` +- `src/hssm/addm/likelihoods/jax/single_stage.py` ← `efficient_fpt_jax/single_stage.py` +- `src/hssm/addm/likelihoods/jax/utils.py` ← `efficient_fpt_jax/utils.py` +- `src/hssm/addm/likelihoods/jax/NOTICE` (or LICENSE) — upstream attribution and commit hash + +## Files to be modified + +- `src/hssm/config.py` — add `aDDMConfig` dataclass. +- `src/hssm/defaults.py` — register `"addm"` in the default model list. +- `src/hssm/data_validator.py` — add aDDM column-shape validation hook. +- `README.md`, `mkdocs.yml` — mention the new model. + +(`pyproject.toml` is **not** modified — no new dependencies are introduced. JAX is already a core dependency.) 
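The column-shape validation hook listed above for `data_validator.py` (Step 7) could look roughly like the sketch below. It is a standalone illustration under stated assumptions: the function name `validate_addm_columns` is hypothetical, it takes plain array-likes rather than a DataFrame, and it assumes `sacc_array` has already been padded to a rectangular `(n_trials, max_d)` shape. It does not use `DataValidatorMixin` or `pad_sacc_array_safely`.

```python
import numpy as np


def validate_addm_columns(r1, r2, d, flag, sacc_array):
    """Raise ValueError on inconsistent aDDM covariates; return (n_trials, max_d)."""
    sacc = np.asarray(sacc_array, dtype=float)
    if sacc.ndim != 2:
        raise ValueError(f"sacc_array must be 2D, got {sacc.ndim}D")
    n_trials, max_d = sacc.shape

    # r1, r2, d, flag must each be 1D numeric with one entry per trial.
    for name, col in (("r1", r1), ("r2", r2), ("d", d), ("flag", flag)):
        arr = np.asarray(col)
        if arr.ndim != 1 or arr.shape[0] != n_trials:
            raise ValueError(f"{name} must be 1D with length {n_trials}")
        if not np.issubdtype(arr.dtype, np.number):
            raise ValueError(f"{name} must be numeric")

    # Each trial's fixation count must fit inside the padded saccade array.
    if np.any(np.asarray(d) > max_d):
        raise ValueError(f"d must not exceed sacc_array.shape[1] == {max_d}")
    return n_trials, max_d
```

Inside HSSM the same checks would presumably read the columns from the DataFrame and run only when `model_config.model_name == "addm"`, so non-aDDM models are unaffected.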
+ +## Key functions/utilities to reuse (no re-implementation) + +| Purpose | Location | +|---|---| +| JAX FPT likelihood | `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` *(vendored)* | +| Safe padding of saccade arrays | `hssm.addm.likelihoods.jax.pad_sacc_array_safely` *(vendored)* | +| Likelihood op wrapping pattern | [hssm/rl/likelihoods/builder.py](data/azhang/HSSM/src/hssm/rl/likelihoods/builder.py) | +| Config → standard Config conversion | [config.RLSSMConfig.to_config](data/azhang/HSSM/src/hssm/config.py#L408) | +| Model registration | [register.register_model](data/azhang/HSSM/src/hssm/register.py#L16) | +| Extra-fields propagation into logp | [data_validator.DataValidatorMixin._update_extra_fields](data/azhang/HSSM/src/hssm/data_validator.py#L156) | +| Param bound enforcement | [distribution_utils.dist.apply_param_bounds_to_loglik](data/azhang/HSSM/src/hssm/distribution_utils/dist.py#L40) | + +--- + +## Verification + +End-to-end checks, in order: + +1. **Unit**: `pytest tests/test_addm_config.py -v` — all four test classes pass, including finite-gradient check. +2. **Likelihood parity**: in `test_addm_config.py::TestaDDMLikelihood`, assert HSSM's wrapped op returns the same value (to 1e-6) as a direct call to the vendored `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` on a 10-trial fixture. This confirms the HSSM extra-fields/op-wrapping plumbing does not corrupt the underlying JAX computation. (A separate, off-CI sanity script may also compare against an installed `efficient-fpt` checkout to detect drift between the vendored copy and upstream.) +3. **Smoke sample**: `hssm.HSSM(model="addm", data=synthetic_trials).sample(draws=5, tune=5)` completes without error and returns an `InferenceData`. +4. **Parameter recovery**: larger off-CI script (e.g., `tests/scripts/addm_recovery.py`) — simulate 1000 trials with known `(eta, kappa, a, b, x0, sigma)`, fit in HSSM, confirm posterior means within ~2σ of ground truth. 
Reuse the recovery setup from [efficient-fpt example8_empirical/parameter_recovery.ipynb](data/azhang/efficient-fpt/examples/example8_empirical). +5. **Tutorial runs clean**: `jupyter nbconvert --execute docs/tutorials/addm_tutorial.ipynb` finishes without errors. +6. **Docs build**: `mkdocs build` succeeds with the new tutorial in nav. + +## Open questions for the user + +1. **Subclass vs config-only**: confirm the config-pattern approach (no `class aDDM(HSSM)`) is acceptable, or whether a thin `hssm.aDDM` convenience class is desired on top. +2. **Non-decision time `t`**: include in v1 as an additional sampled parameter (shift RTs), or defer? +3. **Attention-process extensibility**: is the default `standard_alternating` enough, or should v1 already expose user-pluggable attention processes (e.g., non-alternating fixation patterns)? From 82812f5309d8619f81380dc763fbf699fcac4026 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Tue, 28 Apr 2026 13:56:41 -0400 Subject: [PATCH 2/4] update to include topic specific content (#951) --- README.md | 14 ++++++++------ docs/index.md | 16 +++++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 6a891c20..3f6baf00 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,15 @@ ### Overview -HSSM is a Python toolbox that provides a seamless combination of +HSSM is an open-source Python toolbox for computational modeling in cognitive +neuroscience. It supports a broad range of sequential sampling models used to +study decision-making, learning, and other cognitive processes — from basic +research to the analysis of clinical effects. Under the hood, HSSM combines state-of-the-art likelihood approximation methods with the wider ecosystem of -probabilistic programming languages. It facilitates flexible hierarchical model -building and inference via modern MCMC samplers. 
HSSM is user-friendly and -provides the ability to rigorously estimate the impact of neural and other -trial-by-trial covariates through parameter-wise mixed-effects models for a -large variety of cognitive process models. HSSM is a +probabilistic programming to enable flexible hierarchical Bayesian inference via +modern MCMC samplers. It is user-friendly and provides the ability to rigorously +estimate the impact of neural and other trial-by-trial covariates through +parameter-wise mixed-effects models. HSSM is a BRAINSTORM project in collaboration with the Center for Computation and Visualization and the Center for Computational Brain Science within the Carney Institute at Brown University. diff --git a/docs/index.md b/docs/index.md index eb7a337e..ca0a8edd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -13,13 +13,15 @@ ![GitHub Repo stars](https://img.shields.io/github/stars/lnccbrown/HSSM) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) -**HSSM** (Hierarchical Sequential Sampling Modeling) is a modern Python toolbox -that provides state-of-the-art likelihood approximation methods within the -Python Bayesian ecosystem. It facilitates hierarchical model building and -inference via fast and robust MCMC samplers. User-friendly, extensible, and -flexible, HSSM can rigorously estimate the impact of neural and other -trial-by-trial covariates through parameter-wise mixed-effects models for a -large variety of cognitive process models. +**HSSM** (Hierarchical Sequential Sampling Modeling) is a modern open-source +Python toolbox for computational modeling in cognitive neuroscience. It supports +a broad range of sequential sampling models used to study decision-making, +learning, and other cognitive processes — from basic research to the analysis of +clinical effects. 
HSSM provides state-of-the-art likelihood approximation +methods within the Python Bayesian ecosystem and facilitates hierarchical model +building and inference via fast and robust MCMC samplers. User-friendly, +extensible, and flexible, it can rigorously estimate the impact of neural and +other trial-by-trial covariates through parameter-wise mixed-effects models. HSSM is a [BRAINSTORM](https://ccbs.carney.brown.edu/brainstorm) project in collaboration with the From d58c466b3e3da4dc5d64ae36448e7c7fbe7e40b6 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Tue, 5 May 2026 13:13:12 -0400 Subject: [PATCH 3/4] HSSM base + RLSSM classes (#893) * Copy hssm.py to prepare for base class extraction * Extract HSSM base class to hssmbase.py with refactorings and add tests * refactor: extract init args logic to _get_init_args static method * refactor: reorganize initialization of input data and configuration in HSSM class * refactor: enhance comment clarity for model_config construction in HSSM class * refactor: improve handling of user-provided model_config and choices in HSSM class * refactor: implement model_config construction in a dedicated method * refactor: remove slow marker from multiple test functions in test_hssmbase * refactor: streamline model_config validation and enhance shortcut setup in HSSM class * refactor: enhance type annotation for model_config and add validation for list_params in HSSM class * refactor: replace DataValidator with DataValidatorMixin in HSSM and related tests * refactor: remove unused import of bambi in test_hssmbase * refactor: remove unused import of typing and simplify SupportedModels check in HSSM class * refactor: simplify sample_prior_predictive calls in test_sample_prior_predictive * refactor: correct typo in comments regarding inconsistent dimensions and coordinates * refactor: remove unused variable assignments in test_sample_prior_predictive * refactor: remove unused variable assignment in test_sample_prior_predictive * 
refactor: remove redundant assignment in sample_prior_predictive test * refactor: simplify HSSM instantiation in custom model tests * refactor: implement parameter initialization in DataValidatorMixin * refactor: assign HSSM instance to variable in test_custom_model * refactor: enhance parameter initialization in DataValidatorMixin and add response handling in HSSM * refactor: update parameter types in DataValidatorMixin constructor * refactor: assign HSSM instance to variable in test_custom_model * refactor: handle None response in response_c and response_str properties * refactor: simplify docstring in DataValidatorMixin class * refactor: remove unused variables in test_sample_prior_predictive * fix: correct typo in classproperty docstring * refactor: update condition to check for None in _update_extra_fields method * refactor: remove unused initialization arguments and related method from HSSM class * rename hssmbase.py to base.py * refactor: rename HSSM class to HSSMBase for clarity and consistency * refactor: replace HSSM with HSSMBase in test cases for consistency * fix: update load_model and state restoration methods to reference HSSMBase instead of HSSM * Make config a class variable * refactor: migrate missing data tests from test_data_validator.py to test_missing_data_mixin.py * test: add parameterized test for handling missing data as bool and float * test: add warning handling for dropping rows when missing_data is False * test: add error handling for invalid missing_data types in MissingDataMixin * test: add tests for deadline handling in MissingDataMixin * test: add additional tests for custom missing data handling and deadline logic in MissingDataMixin * test: refactor tests in MissingDataMixin to use dummy_model fixture for consistency * test: enhance DummyModel and fixtures for improved missing data and deadline handling * feat: integrate MissingDataMixin into HSSM class for enhanced data handling * refactor: move _handle_missing_data_and_deadline 
method missing data mixin * feat: implement MissingDataMixin for comprehensive handling of missing data and deadlines * feat: extend HSSMBase class with MissingDataMixin for improved data handling * fix: resolve mypy type checking issues in MissingDataMixin for deadline handling * test: mark test_sample_prior_predictive as expected to fail in CI * fix: add missing newline for improved readability in test_hssmbase.py * refactor: replace explicit choices validation with method call * refactor: improve missing data handling and update tests for edge cases * refactor: update tests for MissingDataMixin to handle missing data scenarios * fix: add type ignore for choices length calculation in HSSMBase * test: add comprehensive tests for MissingDataMixin's missing data handling * refactor: streamline missing data and deadline handling using MissingDataMixin * fix: remove uncessary check * refactor: simplify network assignment logic in MissingDataMixin * fix: remove unnecessary initialization of network in MissingDataMixin * refactor: update test structure and improve parameterization in MissingDataMixin tests * refactor: organize code sections with region markers in HSSMBase class * refactor: add region markers for clarity in HSSMBase class methods * feat: make HSSMBase an abstract class and define abstract method for model distribution * feat: refactor HSSM class to inherit from HSSMBase and remove mixins * fix: move data sanity check to the correct position in HSSMBase class * Implement feature X to enhance user experience and fix bug Y in module Z * test: remove obsolete test_hssmbase.py file * refactor: clean up imports in hssm.py for better readability * fix: update prior type hint in fill_defaults and from_defaults methods to include bmb.Prior * fix: update fill_defaults method to include bmb.Prior type hint for prior parameter * fix: add type ignore comments for model.list_params and DefaultParam.from_defaults parameters * fix: update fill_defaults method to include 
bmb.Prior type hint for prior parameter * fix: replace assertions with ValueError for loglik and list_params validation in HSSM class * refactor: remove unused imports from base.py * fix: update error message for missing list_params in HSSM initialization * fix: add validation for loglik_kind in HSSM class initialization * refactor: update comment style for clarity in _make_model_distribution method * fix: handle None values for response and choices in HSSMBase initialization * fix: streamline exception handling for missing list_params in HSSM initialization * Restore init args so tests pass * fix: update instance creation in HSSMBase to use class reference * refactor: remove extra _set_missing_data_and_deadline method from DataValidatorMixin * refactor: rename test class for clarity in missing data handling * fix: update exception message regex for list_params validation in HSSM * fix: improve error message for unspecified bounds in _make_default_prior function * fix: ensure model_name is retrieved correctly in RLSSMConfig initialization * fix: remove 'data' field from RLSSM_REQUIRED_FIELDS * Use base in HSSM class * Cast choices to list * Fix response assertion in test_from_defaults to use list instead of tuple * Refactor HSSM class to improve parameter handling in likelihood and distribution functions * Update response assertions in test_from_defaults to use lists instead of tuples * Restore hssm.py as in main * Restore param * Restore params * Restore regression_params * Restore simple param * Restore test_hsmm * Fix base for dimensionality problems * Fix mypy bugs * Remove duplicate comment regarding Bambi's kind parameter renaming * Fix RLSSMConfig to require model_name in config_dict * Update docstrings in HSSMBase for clarity on initial values and return types * Fix line too long * Add ssm_logp_func to RLSSMConfig and update validation tests * Add RLSSM model and utilities for reinforcement learning integration * Refactor RLSSM parameter handling and add 
custom prefix resolution for RL parameters * Add tests for RLSSM class covering initialization, validation, and model structure * Refactor loglik handling in RLSSM to improve type safety with casting * Add NaN value check for participant column in validate_balanced_panel function * Add validation for ssm_logp_func in RLSSMConfig to ensure it is callable and has required attributes * Add exclude rules for ruff and mypy hooks to skip tests directory * Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is callable and properly annotated * Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM * Reject missing data and deadline handling in RLSSM initialization to preserve trial sequence integrity * Add tests to validate error handling for missing data and deadline in RLSSM initialization * Refactor path handling for loading RLDM fixture dataset in tests * Add fixture to set floatX to float32 for module tests * Ensure params_is_trialwise aligns with list_params in RLSSM initialization * Clarify comments on default_priors in ModelConfig and remove unnecessary assertion for list_params * Update RLSSM to use to_numpy(copy=True) for extra_fields and add test for independent copies * Refactor parameter name resolution in RLSSM to handle underscores correctly and improve substring checks * Add test for _get_prefix method in RLSSM to ensure token-based matching * Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter and update tests accordingly * Fix comment in test_rlssm.py to clarify output shape of log-likelihood function * Update RLSSMConfig documentation to mark description as required * Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig initialization * Add dummy ssm_logp_func to tests and validate its presence in RLSSMConfig * Remove unused logging import from rlssm.py * Remove redundant exclude rule for ruff-format in pre-commit configuration * Add to_model_config method to RLSSMConfig for ModelConfig 
conversion * Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig and simplify Op parameter handling * Integrate Config and RLSSMConfig into HSSM and RLSSM classes for improved configuration handling * Update choices type from list to tuple for consistency in BaseModelConfig and DataValidatorMixin * Update choices type from list to tuple in test_constructor for consistency * Add deprecation warnings for model_config attributes in HSSMBase * Refactor HSSMBase to support BaseModelConfig and improve model_config handling * Add model configuration building methods to BaseModelConfig and Config classes * Refactor model configuration handling in HSSMBase and HSSM classes to delegate config building and improve attribute access * Add properties to BaseModelConfig for parameter and extra field counts * Refactor RLSSM attributes to use public naming convention for configuration and participant/trial counts * Refactor test_rlssm_panel_attrs to use public attributes for participant and trial counts * Refactor HSSMBase to streamline model configuration handling and update initialization parameters * Refactor BaseModelConfig and RLSSMConfig by removing unused abstract methods and adding a new method for building validated Config instances * Refactor HSSM class to remove Config inheritance and add initialization parameters for model configuration * Refactor RLSSM class to remove RLSSMConfig inheritance and streamline model configuration handling * Refactor Config and RLSSMConfig classes to use concrete types in method signatures * Update Config class parameter types for choices to improve type safety * Update choices method to accept a tuple for model_config.choices * Add tests for model configuration handling and choices logic in Config * Enhance HSSMBase initialization with safe default for constructor arguments and explicit error handling for missing snapshot * Update model_config validation to check for non-null choices * Refactor HSSM distribution method to 
use typed model_config attributes and avoid deprecated proxy properties * Update test cases to use tuples for choices in model configuration * Refactor RLSSM to utilize model_config for list_params and loglik, enhancing type safety and validation * Fix typo in comment regarding model_config choices validation * Refactor RLSSM tests to access model configuration attributes directly, ensuring consistency with updated model_config structure * Update attribute comparison in compare_hssm_class_attributes to use model_config for model_name * Update test assertions to access model configuration attributes directly * Refactor model configuration normalization to streamline choices handling and improve logging * Refactor choices handling in Config class to improve clarity and logging * Refactor _normalize_model_config_with_choices to improve input handling and choices normalization * Refactor likelihood callable construction to simplify logic and enhance clarity * Refactor _make_model_distribution to utilize model_config for loglik and loglik_kind * Fix formatting in HSSM class for consistency in likelihood callable parameters * Fix formatting in HSSM class for consistency in likelihood callable parameters * Refactor HSSM class to use typed model_config attributes directly and resolve loglik * Restore make_model_dist in HSSM * Remove deprecated properties and methods from HSSMBase class * Enhance HSSMBase class to prevent overwriting _init_args if already set in subclasses and exclude additional internal names from locals() snapshots during re-instantiation. * Clarify model_config parameter documentation in HSSMBase class to specify required fields and improve readability. * Enhance HSSMBase class documentation to clarify filtering of internal names in parameter mapping for safe unpickling. * Update model_config parameter documentation in HSSM class to support BaseModelConfig instance and clarify usage of dict for configuration. 
* Add test to validate external model config fallback in _build_model_config * Update sampling parameters in test_rlssm_sample_smoke for speed * Add RLSSM quickstart notebook for model instantiation and sampling demonstration * Add RLSSM Quickstart tutorial to navigation and plugins * Remove redundant next steps and streamline summary in RLSSM quickstart notebook * Refactor RLSSMConfig methods to simplify parameter handling and remove unused conversion tests * Fix handling of list_params in HSSMBase to ensure proper conversion from None * Refactor RLSSM to inject model configuration directly, removing unnecessary Config conversion * Update TestRLSSMConfigDefaults to reflect None for default parameters instead of fixed values * Refactor RLSSM to inject loglik and backend directly into a new RLSSMConfig instance, preserving the original configuration. * Add validation for missing bounds in RLSSMConfig parameters * Fix RLSSM to use model_config for ssm_logp_func and update test cases for default parameter bounds * Enhance RLSSM tests to align params_is_trialwise with list_params and add pickle round-trip verification * Add test to ensure RLSSMConfig.from_defaults raises NotImplementedError * Clarify RLSSMConfig.from_defaults behavior and raise NotImplementedError for unsupported usage * Inject JAX backend into RLSSMConfig during initialization * Refactor RLSSM class to use model_config instead of rlssm_config for consistency * Fix merge conflicts with base branch * Remove commented out lines * Remove RLSSMConfig import from __init__.py * Reorganize import statements by moving RLSSMConfig import to the correct position * Move RLSSMConfig import to the correct module in test files * Update docstring in __init__.py and exports * Remove RLSSMConfig class and its associated methods from config.py * Move RLSSMConfig class hssm.rl module * Refactor config.py to remove RLSSM-specific defaults and unify observed data constants * Fix formatting of error messages in 
TestRLSSMConfigValidation for consistency * Enhance validation in RLSSMConfig for ssm_logp_func attributes * Add validation test for non-callable values in ssm_logp_func.computed * Rename 'learning_process_loglik_kind' to 'learning_process_kind' in RLSSMConfig and related tests * Simplify response and list_params assignment in HSSMBase by removing conditional checks * Revert "Simplify response and list_params assignment in HSSMBase by removing conditional checks" This reverts commit 7cf8bca5db9e5f66859a0adafb4d439496a0e624. * Refactor RLSSMConfig to dynamically retrieve required fields for validation * Update RLSSMConfig to handle field exceptions in from_rlssm_dict method * Fix import path for RLSSM and RLSSMConfig; correct learning_process_loglik_kind key in RLSSMConfig; update model instantiation parameter name * Fix merge conflicts * Fix instantiation of HSSMBase in __setstate__ method to use class reference * Fix type hints * Merge base branch cp-main-sb and add empty data check to validate_balanced_panel Agent-Logs-Url: https://github.com/lnccbrown/HSSM/sessions/a86b0208-9ac2-480e-a585-320d6a0e1bbe Co-authored-by: cpaniaguam <68481491+cpaniaguam@users.noreply.github.com> * Fix ordering of empty-data check and use explicit None check in is_choice_only Agent-Logs-Url: https://github.com/lnccbrown/HSSM/sessions/a86b0208-9ac2-480e-a585-320d6a0e1bbe Co-authored-by: cpaniaguam <68481491+cpaniaguam@users.noreply.github.com> * Update RLSSM to raise NotImplementedError for unsupported missing_data and deadline handling * Revert "Fix ordering of empty-data check and use explicit None check in is_choice_only" This reverts commit f30186b7f149769b30d9c75fa51312392d200fdd. * Revert "Merge base branch cp-main-sb and add empty data check to validate_balanced_panel" This reverts commit 38ffc8d95631b46e329f20a86eb6c6778d857cd9. 
* Revert "Merge remote-tracking branch 'origin/cp-main-sb' into rlssm-class-make-model-dist" This reverts commit c696125b0129e94080605ba2042bed202e3fef13, reversing changes made to 1233cd7f531ab56f3df08bf85b200093372b09e4. * Update tests to raise NotImplementedError for unsupported missing_data and deadline handling Co-authored-by: Copilot * Update precommit * Enhance test_rlssm_get_prefix to validate fallback for unknown parameters * Move custom _get_prefix method to base * Add validation for contiguous participant rows in validate_balanced_panel Co-authored-by: Copilot * Update Config class to ignore choices when model_config is None Co-authored-by: Copilot * Clarify trialwise parameter handling in RLSSM by updating p_outlier exclusion logic Co-authored-by: Copilot * Refactor RLSSMConfig from_rlssm_dict method to derive required fields directly from the dataclass and improve validation for ssm_logp_func Co-authored-by: Copilot * Add handling for choice-only models in MissingDataMixin and update test fixture * Refactor is_choice_only assignment in HSSMBase to directly use model_config Co-authored-by: Copilot * Enhance RLSSMConfig to log warnings for missing 'response' and 'choices' in config_dict Co-authored-by: Copilot * Enhance RLSSMConfig docstring to detail fields for RLSSM likelihood pipeline Co-authored-by: Copilot --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot --- docs/tutorials/rlssm_quickstart.ipynb | 351 ++++ mkdocs.yml | 2 + src/hssm/__init__.py | 2 + src/hssm/base.py | 2160 +++++++++++++++++++++++++ src/hssm/config.py | 325 ++-- src/hssm/data_validator.py | 90 -- src/hssm/hssm.py | 2074 +----------------------- src/hssm/missing_data_mixin.py | 200 +++ src/hssm/param/param.py | 2 +- src/hssm/param/params.py | 3 + src/hssm/param/regression_param.py | 2 +- src/hssm/param/simple_param.py | 13 +- src/hssm/param/utils.py | 2 +- src/hssm/rl/__init__.py | 27 + src/hssm/rl/config.py | 219 
+++ src/hssm/rl/rlssm.py | 264 +++ src/hssm/rl/utils.py | 70 + tests/param/test_default_param.py | 5 + tests/test_config.py | 64 +- tests/test_data_validator.py | 78 - tests/test_hssm.py | 32 +- tests/test_missing_data_mixin.py | 221 +++ tests/test_rl_utils.py | 107 ++ tests/test_rlssm.py | 329 ++++ tests/test_rlssm_config.py | 357 ++-- tests/test_save_load.py | 4 +- 26 files changed, 4410 insertions(+), 2593 deletions(-) create mode 100644 docs/tutorials/rlssm_quickstart.ipynb create mode 100644 src/hssm/base.py create mode 100644 src/hssm/missing_data_mixin.py create mode 100644 src/hssm/rl/__init__.py create mode 100644 src/hssm/rl/config.py create mode 100644 src/hssm/rl/rlssm.py create mode 100644 src/hssm/rl/utils.py create mode 100644 tests/test_missing_data_mixin.py create mode 100644 tests/test_rl_utils.py create mode 100644 tests/test_rlssm.py diff --git a/docs/tutorials/rlssm_quickstart.ipynb b/docs/tutorials/rlssm_quickstart.ipynb new file mode 100644 index 00000000..d976b1e8 --- /dev/null +++ b/docs/tutorials/rlssm_quickstart.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1b9b429d", + "metadata": {}, + "source": [ + "# RLSSM Quickstart: Instantiation, Model Building, and Sampling\n", + "\n", + "This notebook provides a minimal end-to-end demonstration of the `RLSSM` class:\n", + "\n", + "1. **Load** a balanced-panel two-armed bandit dataset\n", + "2. **Define** an annotated learning function and the angle SSM log-likelihood\n", + "3. **Configure** and **instantiate** an `RLSSM` model\n", + "4. **Inspect** the built Bambi / PyMC model\n", + "5. 
**Run** a minimal 2-draw sampling smoke test\n", + "\n", + "For a full treatment — simulating data, hierarchical formulas, meaningful sampling, and posterior visualization — see:\n", + "- [rlssm_tutorial.ipynb](rlssm_tutorial.ipynb)\n", + "- [add_custom_rlssm_model.ipynb](add_custom_rlssm_model.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "bf38d7f7", + "metadata": {}, + "source": [ + "## 1. Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d764731", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import hssm\n", + "from hssm.rl import RLSSM, RLSSMConfig\n", + "from hssm.distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx\n", + "from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise\n", + "from hssm.utils import annotate_function\n", + "\n", + "# RLSSM requires float32 throughout (JAX default).\n", + "hssm.set_floatX(\"float32\", update_jax=True)" + ] + }, + { + "cell_type": "markdown", + "id": "df12303f", + "metadata": {}, + "source": [ + "## 2. Load the Dataset\n", + "\n", + "We use a small synthetic two-armed bandit dataset from the HSSM test fixtures. \n", + "It is a **balanced panel**: every participant has the same number of trials. \n", + "Columns: `participant_id`, `trial_id`, `rt`, `response`, `feedback`.\n", + "\n", + "> **Note:** You can also generate data with\n", + "> [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators).\n", + "> See `rlssm_tutorial.ipynb` for an example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2ef5f6e", + "metadata": {}, + "outputs": [], + "source": [ + "# Path relative to docs/tutorials/ when running inside the HSSM repo.\n", + "_fixture_path = Path(\"../../tests/fixtures/rldm_data.npy\")\n", + "raw = np.load(_fixture_path, allow_pickle=True).item()\n", + "data = pd.DataFrame(raw[\"data\"])\n", + "\n", + "n_participants = data[\"participant_id\"].nunique()\n", + "n_trials = len(data) // n_participants\n", + "\n", + "print(data.head())\n", + "print(f\"\\nParticipants: {n_participants} | Trials per participant: {n_trials}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8c310290", + "metadata": {}, + "source": [ + "## 3. Define the Learning Process\n", + "\n", + "The RL learning process is a JAX function that, given a subject's trial sequence, computes\n", + "the trial-wise drift rate `v` via a Q-learning update rule. \n", + "\n", + "`annotate_function` attaches `.inputs`, `.outputs`, and (optionally) `.computed` metadata\n", + "that the RLSSM likelihood builder uses to automatically construct the input matrix for the\n", + "decision process.\n", + "\n", + "- **inputs** — columns that the function reads (free parameters + data columns)\n", + "- **outputs** — what the function produces (here: `v`, the drift rate)\n", + "\n", + "Here we annotate the built-in `compute_v_subject_wise` function, which implements a simple\n", + "Rescorla-Wagner Q-learning update for a two-armed bandit task." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbcea122", + "metadata": {}, + "outputs": [], + "source": [ + "compute_v_annotated = annotate_function(\n", + " inputs=[\"rl_alpha\", \"scaler\", \"response\", \"feedback\"],\n", + " outputs=[\"v\"],\n", + ")(compute_v_subject_wise)\n", + "\n", + "print(\"Learning function inputs :\", compute_v_annotated.inputs)\n", + "print(\"Learning function outputs:\", compute_v_annotated.outputs)" + ] + }, + { + "cell_type": "markdown", + "id": "7a03305a", + "metadata": {}, + "source": [ + "## 4. Define the Decision (SSM) Log-Likelihood\n", + "\n", + "The decision process uses the **angle model** likelihood, loaded from an ONNX file.\n", + "`make_jax_matrix_logp_funcs_from_onnx` returns a JAX callable that accepts a\n", + "2-D matrix whose columns are `[v, a, z, t, theta, rt, response]` and returns\n", + "per-trial log-probabilities.\n", + "\n", + "We then annotate that callable so the builder knows:\n", + "- which columns the matrix contains (`inputs`)\n", + "- that `v` itself is *computed* by the learning function (not a free parameter)\n", + "\n", + "The ONNX file is loaded from the local test fixture when running inside the HSSM\n", + "repository; otherwise it is downloaded from the HuggingFace Hub (`franklab/HSSM`)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60bbc036", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the local fixture when available; fall back to HuggingFace download.\n", + "_local_onnx = Path(\"../../tests/fixtures/angle.onnx\").resolve()\n", + "_onnx_model = str(_local_onnx) if _local_onnx.exists() else \"angle.onnx\"\n", + "\n", + "_angle_logp_jax = make_jax_matrix_logp_funcs_from_onnx(model=_onnx_model)\n", + "\n", + "angle_logp_func = annotate_function(\n", + " inputs=[\"v\", \"a\", \"z\", \"t\", \"theta\", \"rt\", \"response\"],\n", + " outputs=[\"logp\"],\n", + " computed={\"v\": compute_v_annotated},\n", + ")(_angle_logp_jax)\n", + "\n", + "print(\"SSM logp inputs :\", angle_logp_func.inputs)\n", + "print(\"SSM logp outputs:\", angle_logp_func.outputs)\n", + "print(\"Computed deps :\", list(angle_logp_func.computed.keys()))" + ] + }, + { + "cell_type": "markdown", + "id": "cf8f5b63", + "metadata": {}, + "source": [ + "## 5. Configure the Model with `RLSSMConfig`\n", + "\n", + "`RLSSMConfig` collects all the information the RLSSM class needs:\n", + "\n", + "| Field | Purpose |\n", + "|-------|---------|\n", + "| `model_name` | Identifier string for the configuration |\n", + "| `decision_process` | Name of the SSM (e.g. 
`\"angle\"`) |\n", + "| `list_params` | Ordered list of *free* parameters to sample |\n", + "| `params_default` | Starting / default values for each parameter |\n", + "| `bounds` | Prior bounds for each parameter |\n", + "| `learning_process` | Dict mapping computed param name → annotated learning function |\n", + "| `extra_fields` | Extra data columns required by the learning function |\n", + "| `ssm_logp_func` | Annotated JAX callable for the decision-process likelihood |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4beba1bc", + "metadata": {}, + "outputs": [], + "source": [ + "rlssm_config = RLSSMConfig(\n", + " model_name=\"rlssm_angle_quickstart\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " decision_process=\"angle\",\n", + " decision_process_loglik_kind=\"approx_differentiable\",\n", + " learning_process_kind=\"blackbox\",\n", + " list_params=[\"rl_alpha\", \"scaler\", \"a\", \"theta\", \"t\", \"z\"],\n", + " params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5],\n", + " bounds={\n", + " \"rl_alpha\": (0.0, 1.0),\n", + " \"scaler\": (0.0, 10.0),\n", + " \"a\": (0.1, 3.0),\n", + " \"theta\": (-0.1, 0.1),\n", + " \"t\": (0.001, 1.0),\n", + " \"z\": (0.1, 0.9),\n", + " },\n", + " learning_process={\"v\": compute_v_annotated},\n", + " response=[\"rt\", \"response\"],\n", + " choices=[0, 1],\n", + " extra_fields=[\"feedback\"],\n", + " ssm_logp_func=angle_logp_func,\n", + ")\n", + "\n", + "print(\"Model name :\", rlssm_config.model_name)\n", + "print(\"Free params :\", rlssm_config.list_params)" + ] + }, + { + "cell_type": "markdown", + "id": "924ee4c7", + "metadata": {}, + "source": [ + "## 6. 
Instantiate the `RLSSM` Model\n", + "\n", + "Passing `data` and `rlssm_config` to `RLSSM`:\n", + "\n", + "- validates the balanced-panel requirement\n", + "- builds a differentiable PyTensor Op that chains the RL learning step and the\n", + " angle log-likelihood\n", + "- constructs the Bambi / PyMC model internally\n", + "\n", + "Note that `v` (the drift rate) is *not* a free parameter — it is computed inside\n", + "the Op by the Q-learning update and therefore does not appear in `model.params`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f8da79a", + "metadata": {}, + "outputs": [], + "source": [ + "model = RLSSM(data=data, model_config=rlssm_config)\n", + "\n", + "assert isinstance(model, RLSSM)\n", + "print(\"Model type :\", type(model).__name__)\n", + "print(\"Participants :\", model.n_participants)\n", + "print(\"Trials/subj :\", model.n_trials)\n", + "print(\"Free parameters :\", list(model.params.keys()))\n", + "assert \"rl_alpha\" in model.params, \"rl_alpha must be a free parameter\"\n", + "assert \"v\" not in model.params, \"v is computed, not a free parameter\"\n", + "model" + ] + }, + { + "cell_type": "markdown", + "id": "f7f39940", + "metadata": {}, + "source": [ + "## 7. Inspect the Built Model\n", + "\n", + "After construction, `model.model` exposes the underlying **Bambi model** and\n", + "`model.pymc_model` exposes the **PyMC model** context — useful for debugging\n", + "or customizing priors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0558ad4", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=== Bambi model ===\")\n", + "print(model.model)\n", + "\n", + "print(\"\\n=== PyMC model ===\")\n", + "print(model.pymc_model)" + ] + }, + { + "cell_type": "markdown", + "id": "f4e50110", + "metadata": {}, + "source": [ + "## 8. 
Sampling\n", + "\n", + "A minimal sampling run — 2 draws, 2 tuning steps, 1 chain — confirms that the full\n", + "computational graph (Q-learning scan → angle logp → NUTS gradient) is wired correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96ce3238", + "metadata": {}, + "outputs": [], + "source": [ + "trace = model.sample(draws=2, tune=2, chains=1, cores=1, sampler=\"numpyro\", target_accept=0.9)\n", + "\n", + "assert trace is not None\n", + "print(trace)" + ] + }, + { + "cell_type": "markdown", + "id": "a784a468", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook showed how to:\n", + "\n", + "1. Load a balanced-panel dataset (`rldm_data.npy`)\n", + "2. Annotate a Q-learning function with `annotate_function`\n", + "3. Load the angle ONNX likelihood and annotate it so the builder can assemble the input matrix\n", + "4. Define an `RLSSMConfig` and pass it to `RLSSM`\n", + "5. Confirm model structure (free params, Bambi / PyMC objects)\n", + "6. 
Run a 2-draw sampling smoke test that returns an `arviz.InferenceData` object" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hssm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 93b2696f..0ef2ad1b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,6 +44,7 @@ nav: - Hierarchical Variational Inference: tutorials/variational_inference_hierarchical.ipynb - Using HSSM low-level API directly with PyMC: tutorials/pymc.ipynb - Reinforcement Learning - Sequential Sampling Models (RLSSM): tutorials/rlssm_tutorial.ipynb + - RLSSM Quickstart: tutorials/rlssm_quickstart.ipynb - Add custom RLSSM models: tutorials/add_custom_rlssm_model.ipynb - Custom models: tutorials/jax_callable_contribution_onnx_example.ipynb - BayesFlow LRE Integration: tutorials/bayesflow_lre_integration.ipynb @@ -96,6 +97,7 @@ plugins: - tutorials/hssm_tutorial_workshop_2.ipynb - tutorials/add_custom_rlssm_model.ipynb - tutorials/rlssm_tutorial.ipynb + - tutorials/rlssm_quickstart.ipynb - tutorials/lapse_prob_and_dist.ipynb - tutorials/plotting.ipynb - tutorials/scientific_workflow_hssm.ipynb diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index 60dd7102..2f234d08 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -19,6 +19,7 @@ from .param import UserParam as Param from .prior import Prior from .register import register_model +from .rl import RLSSM from .simulator import simulate_data from .utils import check_data_for_rl, set_floatX @@ -31,6 +32,7 @@ __all__ = [ "HSSM", + "RLSSM", "Link", "load_data", "ModelConfig", diff --git a/src/hssm/base.py b/src/hssm/base.py new file mode 100644 index 00000000..806e0ede --- /dev/null 
+++ b/src/hssm/base.py @@ -0,0 +1,2160 @@ +"""HSSM: Hierarchical Sequential Sampling Models. + +A package based on pymc and bambi to perform Bayesian inference for hierarchical +sequential sampling models. + +This file defines the entry class HSSM. +""" + +import datetime +import logging +from abc import ABC, abstractmethod +from copy import deepcopy +from os import PathLike +from pathlib import Path +from typing import Any, Callable, Literal, Optional, Union, cast, get_args + +import arviz as az +import bambi as bmb +import cloudpickle as cpickle +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pymc as pm +import pytensor +import seaborn as sns +import xarray as xr +from bambi.model_components import DistributionalComponent +from bambi.transformations import transformations_namespace +from pymc.model.transform.conditioning import do + +from hssm._types import SupportedModels +from hssm.data_validator import DataValidatorMixin +from hssm.defaults import ( + INITVAL_JITTER_SETTINGS, + INITVAL_SETTINGS, +) +from hssm.distribution_utils import ( + make_family, +) +from hssm.missing_data_mixin import MissingDataMixin +from hssm.utils import ( + _compute_log_likelihood, + _get_alias_dict, + _print_prior, + _split_array, +) + +from . import plotting +from .config import BaseModelConfig +from .param import Params +from .param import UserParam as Param + +_logger = logging.getLogger("hssm") + +# NOTE: Temporary mapping from old sampler names to new ones in bambi 0.16.0 +_new_sampler_mapping: dict[str, Literal["pymc", "numpyro", "blackjax"]] = { + "mcmc": "pymc", + "nuts_numpyro": "numpyro", + "nuts_blackjax": "blackjax", +} + + +class classproperty: + """A decorator that combines the behavior of @property and @classmethod. + + This decorator allows you to define a property that can be accessed on the class + itself, rather than on instances of the class. 
It is useful for defining class-level + properties that need to perform some computation or access class-level data. + + This implementation is provided for compatibility with Python versions 3.10 through + 3.12, as one cannot combine the @property and @classmethod decorators across all + these versions. + + Example + ------- + class MyClass: + @classproperty + def my_class_property(cls): + return "This is a class property" + + print(MyClass.my_class_property) # Output: This is a class property + """ + + def __init__(self, fget): + self.fget = fget + + def __get__(self, instance, owner): # noqa: D105 + return self.fget(owner) + + +class HSSMBase(ABC, DataValidatorMixin, MissingDataMixin): + """The basic Hierarchical Sequential Sampling Model (HSSM) class. + + Parameters + ---------- + data + A pandas DataFrame with the minimum requirements of containing the data with the + columns "rt" and "response". + model + The name of the model to use. Currently supported models are "ddm", "ddm_sdv", + "full_ddm", "angle", "levy", "ornstein", "weibull", "race_no_bias_angle_4", + "ddm_seq2_no_bias". If any other string is passed, the model will be considered + custom, in which case all `model_config`, `loglik`, and `loglik_kind` have to be + provided by the user. + choices : optional + When an `int`, the number of choices that the participants can make. If `2`, the + choices are [-1, 1] by default. If anything greater than `2`, the choices are + [0, 1, ..., n_choices - 1] by default. If a `list` is provided, it should be the + list of choices that the participants can make. Defaults to `2`. If any value + other than the choices provided is found in the "response" column of the data, + an error will be raised. + include : optional + A list of dictionaries specifying parameter specifications to include in the + model. If left unspecified, defaults will be used for all parameter + specifications. Defaults to None. 
+ model_config + A fully initialised :class:`~hssm.config.BaseModelConfig` instance + (typically :class:`~hssm.config.Config`) produced by the subclass + before calling ``super().__init__``. All likelihood, parameter, and + data information used by :class:`HSSMBase` is drawn from this object, + and it must provide populated ``loglik`` and ``list_params`` fields. + p_outlier : optional + The fixed lapse probability or the prior distribution of the lapse probability. + Defaults to a fixed value of 0.05. When `None`, the lapse probability will not + be included in estimation. + lapse : optional + The lapse distribution. This argument is required only if `p_outlier` is not + `None`. Defaults to Uniform(0.0, 20.0). + global_formula : optional + A string that specifies a regression formula which will be used for all model + parameters. If you specify parameter-wise regressions in addition, these will + override the global regression for the respective parameter. + link_settings : optional + An optional string literal that indicates the link functions to use for each + parameter. Helpful for hierarchical models where sampling might get stuck/ + very slow. Can be one of the following: + + - `"log_logit"`: applies log link functions to positive parameters and + generalized logit link functions to parameters that have explicit bounds. + - `None`: unless otherwise specified, the `"identity"` link functions will be + used. + The default value is `None`. + prior_settings : optional + An optional string literal that indicates the prior distributions to use for + each parameter. Helpful for hierarchical models where sampling might get stuck/ + very slow. Can be one of the following: + + - `"safe"`: HSSM will scan all parameters in the model and apply safe priors to + all parameters that do not have explicit bounds. + - None: HSSM will use bambi to provide default priors for all parameters. Not + recommended when you are using hierarchical models. + The default value is `"safe"`.
+ extra_namespace : optional + Additional user-supplied variables with transformations or data to include in + the environment where the formula is evaluated. Defaults to `None`. + missing_data : optional + Specifies whether the model should handle missing data. Can be a `bool` or a + `float`. If `False` and the `rt` column contains -999.0, + the model will drop these rows and produce a warning. If `True`, the model will + treat code -999.0 as missing data. If a `float` is provided, the model will + treat this value as the missing data value. Defaults to `False`. + deadline : optional + Specifies whether the model should handle deadline data. Can be a `bool` or a + `str`. If `False`, the model will do nothing even if a deadline column is + provided. If `True`, the model will treat the `deadline` column as deadline + data. If a `str` is provided, the model will treat this value as the name of the + deadline column. Defaults to `False`. + loglik_missing_data : optional + A likelihood function for missing data. See the `loglik` parameter for how to + specify this likelihood function. If nothing is provided, + a default likelihood function will be used. This parameter is required only if + either `missing_data` or `deadline` is not `False`. Defaults to `None`. + process_initvals : optional + If `True`, the model will process the initial values. Defaults to `True`. + initval_jitter : optional + The jitter value for the initial values. Defaults to `0.01`. + **kwargs + Additional arguments passed to the `bmb.Model` object. + + Attributes + ---------- + data + A pandas DataFrame with at least two columns of "rt" and "response" indicating + the response time and responses. + list_params + The list of strs of parameter names. + model_name + The name of the model. + loglik: + The likelihood function or a path to an onnx file. + loglik_kind: + The kind of likelihood used.
+ model_config + A dictionary representing the model configuration. + model_distribution + The likelihood function of the model in the form of a pm.Distribution subclass. + family + A Bambi family object. + priors + A dictionary containing the prior distribution of parameters. + formula + A string representing the model formula. + link + A string or a dictionary representing the link functions for all parameters. + params + A list of Param objects representing model parameters. + initval_jitter + The jitter value for the initial values. + """ + + def __init__( + self, + data: pd.DataFrame, + model_config: BaseModelConfig, + include: list[dict[str, Any] | Param] | None = None, + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + global_formula: str | None = None, + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + missing_data: bool | float = False, + deadline: bool | str = False, + loglik_missing_data: ( + str | PathLike | Callable | pytensor.graph.Op | None + ) = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs, + ): + # ===== Input Data & Configuration ===== + self.data = data.copy() + self.global_formula = global_formula + self.link_settings = link_settings + self.prior_settings = prior_settings + self.missing_data_value = -999.0 + + # Store a safe default for the constructor-arguments snapshot so that + # pickling / save-load cannot raise AttributeError if a subclass forgets + # to call `_store_init_args(locals(), kwargs)` early. Subclasses are + # still expected to overwrite this with the real snapshot. However, + # do not overwrite if a subclass already set `_init_args` prior to + # calling `super().__init__()` (the subclass may capture its + # constructor args before delegating to the base class). 
+ if not hasattr(self, "_init_args"): + self._init_args: dict[str, Any] = {} + + # Set up additional namespace for formula evaluation + additional_namespace = transformations_namespace.copy() + if extra_namespace is not None: + additional_namespace.update(extra_namespace) + self.additional_namespace = additional_namespace + + # region ===== Inference Results (initialized to None/empty) ===== + self._inference_obj: az.InferenceData | None = None + self._inference_obj_vi: pm.Approximation | None = None + self._vi_approx = None + self._map_dict = None + # endregion + + # ===== Initial Values Configuration ===== + self._initvals: dict[str, Any] = {} + self.initval_jitter = initval_jitter + + # region ===== Store the pre-built config ===== + self.model_config: BaseModelConfig = model_config + # endregion + + # region ===== Set up shortcuts so old code will work ====== + self.response: list[str] = ( # type: ignore[assignment] + list(self.model_config.response) + if self.model_config.response is not None + else [] + ) + self.list_params = ( + list(self.model_config.list_params) + if self.model_config.list_params is not None + else None + ) + self.choices = self.model_config.choices # type: ignore[assignment] + self.model_name = self.model_config.model_name + self.loglik = self.model_config.loglik + self.loglik_kind = self.model_config.loglik_kind + self.extra_fields = self.model_config.extra_fields + # endregion + + # TODO: add to HSSMBase + self.response = cast("list[str]", self.response) + self.is_choice_only: bool = self.model_config.is_choice_only + + if self.choices is None: + raise ValueError( + "`choices` must be provided either in `model_config` or as an argument." + ) + + self._validate_choices() + + # region Avoid mypy error later (None.append). Should list_params be Optional? + if self.list_params is None: + raise ValueError( + "`list_params` must be provided in the model configuration." 
+ ) + # endregion + + self.n_choices = len(self.choices) # type: ignore[arg-type] + + self._pre_check_data_sanity() + + self._process_missing_data_and_deadline( + missing_data=missing_data, + deadline=deadline, + loglik_missing_data=loglik_missing_data, + ) + + # region ===== Process lapse distribution ===== + self.has_lapse = p_outlier is not None and p_outlier != 0 + self._check_lapse(lapse) + if self.has_lapse and self.list_params[-1] != "p_outlier": + self.list_params.append("p_outlier") + # endregion + + # Process all parameters + self.params = Params.from_user_specs( + model=self, # type: ignore[arg-type] + include=[] if include is None else include, + kwargs=kwargs, + p_outlier=p_outlier, + ) + self._parent = self.params.parent + self._parent_param = self.params.parent_param + + self._validate_fixed_vectors() + self.formula, self.priors, self.link = self.params.parse_bambi(model=self) # type: ignore[arg-type] + + # For parameters that have a regression backend, apply bounds at the likelihood + # level to ensure that the samples that are out of bounds + # are discarded (replaced with a large negative value). 
+ self.bounds = { + name: param.bounds + for name, param in self.params.items() + if param.is_regression and param.bounds is not None + } + + # Set p_outlier and lapse + self.p_outlier = self.params.get("p_outlier") + self.lapse = lapse if self.has_lapse else None + + self._post_check_data_sanity() + + self.model_distribution = self._make_model_distribution() + + self.family = make_family( + self.model_distribution, + self.list_params, + self.link, + self._parent, + ) + + self.model = bmb.Model( + self.formula, + data=self.data, + family=self.family, + priors=self.priors, # center_predictors=False + extra_namespace=self.additional_namespace, + **kwargs, + ) + + self._aliases = _get_alias_dict( + self.model, self._parent_param, self.response_c, self.response_str + ) + self.set_alias(self._aliases) + self.model.build() + + # region ===== Fix scalar deterministic dims for bambi >= 0.17 ===== + # Bambi >= 0.17 declares dims=("__obs__",) for intercept-only + # deterministics that actually have shape (1,). This causes an + # xarray CoordinateValidationError during pm.sample() when ArviZ + # tries to create a DataArray with mismatched dimension sizes. + # Fix by removing the dims declaration for these deterministics. + self._fix_scalar_deterministic_dims() + # endregion + + # region ===== Init vals and jitters ===== + if process_initvals: + self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS) + if self.initval_jitter > 0: + self._jitter_initvals( + jitter_epsilon=self.initval_jitter, + vector_only=True, + ) + # endregion + + # Make sure we reset rvs_to_initial_values --> Only None's + # Otherwise PyMC barks at us when asking to compute likelihoods + self.pymc_model.rvs_to_initial_values.update( + {key_: None for key_ in self.pymc_model.rvs_to_initial_values.keys()} + ) + _logger.info("Model initialized successfully.") + + @abstractmethod + def _make_model_distribution(self) -> type[pm.Distribution]: + """Make a pm.Distribution for the model. 
+ + This method must be implemented by subclasses to create the appropriate + distribution for the specific model type. + """ + ... + + def _fix_scalar_deterministic_dims(self) -> None: + """Fix dims metadata for scalar deterministics. + + Bambi >= 0.17 returns shape ``(1,)`` for intercept-only + deterministics but still declares ``dims=("__obs__",)``. This causes + an xarray ``CoordinateValidationError`` during ``pm.sample()`` because + the ``__obs__`` coordinate has ``n_obs`` entries. Removing the dims + declaration for these variables lets ArviZ handle them as + un-dimensioned arrays, avoiding the conflict. + """ + n_obs = len(self.data) + dims_dict = self.pymc_model.named_vars_to_dims + for det in self.pymc_model.deterministics: + if det.name not in dims_dict: + continue + dims = dims_dict[det.name] + if "__obs__" in dims: + # Check static shape: if it doesn't match n_obs, remove dims + try: + shape_0 = det.type.shape[0] + except (IndexError, TypeError): + continue + if shape_0 is not None and shape_0 != n_obs: + del dims_dict[det.name] + + def _validate_fixed_vectors(self) -> None: + """Validate that fixed-vector parameters have the correct length. + + Fixed-vector parameters (``prior=np.ndarray``) bypass Bambi's formula + system entirely --- they are passed as a scalar ``0.0`` placeholder to + Bambi, and the real vector is substituted inside + ``HSSMDistribution.logp()`` (see ``dist.py``). Because this + substitution is invisible to Bambi, we must validate the vector length + against ``len(self.data)`` up front to catch shape mismatches early. + """ + for name, param in self.params.items(): + if isinstance(param.prior, np.ndarray): + if len(param.prior) != len(self.data): + raise ValueError( + f"Fixed vector for parameter '{name}' has length " + f"{len(param.prior)}, but data has {len(self.data)} rows." + ) + + @classproperty + def supported_models(cls) -> tuple[SupportedModels, ...]: + """Get a tuple of all supported models. 
+ + Returns + ------- + tuple[SupportedModels, ...] + A tuple containing all supported model names. + """ + return get_args(SupportedModels) + + @staticmethod + def _store_init_args( + local_vars: dict[str, Any], extra_kwargs: dict[str, Any] + ) -> dict[str, Any]: + """Capture subclass ``__init__`` arguments for save/load serialisation. + + Call this at the very start of a subclass ``__init__`` before any local + variables are assigned, passing ``locals()`` and the ``**kwargs`` dict:: + + self._init_args = self._store_init_args(locals(), kwargs) + + Parameters + ---------- + local_vars + The ``locals()`` snapshot from the subclass ``__init__``. + extra_kwargs + The ``**kwargs`` dict captured by the subclass ``__init__``. + + Returns + ------- + dict[str, Any] + A mapping of parameter names to their values, suitable for + reconstructing the instance via ``cls(**init_args)``. + + Notes + ----- + The implementation filters out internal names that commonly appear in + ``locals()`` snapshots (for example, ``__class__`` and ``kwargs``) so + that the returned mapping is safe to pass back to the class + constructor during unpickling. + """ + # Exclude internal names that appear in locals() snapshots and are not + # valid constructor parameters when re-instantiating the class. + exclude_keys = {"self", "kwargs", "__class__"} + result = {k: v for k, v in local_vars.items() if k not in exclude_keys} + result.update(extra_kwargs) + return result + + def find_MAP(self, **kwargs): + """Perform Maximum A Posteriori estimation. + + Returns + ------- + dict + A dictionary containing the MAP estimates of the model parameters. 
+ """ + self._map_dict = pm.find_MAP(model=self.pymc_model, **kwargs) + return self._map_dict + + def sample( + self, + sampler: Literal["pymc", "numpyro", "blackjax", "nutpie", "laplace"] + | None = None, + init: str | None = None, + initvals: str | dict | None = None, + include_response_params: bool = False, + **kwargs, + ) -> az.InferenceData | pm.Approximation: + """Perform sampling using the `fit` method via bambi.Model. + + Parameters + ---------- + sampler: optional + The sampler to use. Can be one of "pymc", "numpyro", + "blackjax", "nutpie", or "laplace". If using `blackbox` likelihoods, + this cannot be "numpyro", "blackjax", or "nutpie". By default it is None, + and sampler will automatically be chosen: when the model uses the + `approx_differentiable` likelihood, and `jax` backend, "numpyro" will + be used. Otherwise, "pymc" (the default PyMC NUTS sampler) will be used. + + Note that the old sampler names such as "mcmc", "nuts_numpyro", + "nuts_blackjax" will be deprecated and removed in future releases. A warning + will be raised if any of these old names are used. + init: optional + Initialization method to use for the sampler. If any of the NUTS samplers + is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`. + initvals: optional + Pass initial values to the sampler. This can be a dictionary of initial + values for parameters of the model, or a string "map" to use initialization + at the MAP estimate. If "map" is used, the MAP estimate will be computed if + not already attached to the base class from prior call to `find_MAP`. + include_response_params: optional + Include parameters of the response distribution in the output. These usually + take more space than other parameters as there's one of them per + observation. Defaults to False. + kwargs + Other arguments passed to bmb.Model.fit(). Please see [here] + (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit) + for full documentation. 
+ + Returns + ------- + az.InferenceData | pm.Approximation + A reference to the `model.traces` object, which stores the traces of the + last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData` + instance if `sampler` is `"pymc"` (default), `"numpyro"`, + `"blackjax"` or `"laplace"`. + """ + # If initvals are None (default) + # we skip processing initvals here. + if sampler in _new_sampler_mapping: + _logger.warning( + f"Sampler '{sampler}' is deprecated. " + "Please use the new sampler names: " + "'pymc', 'numpyro', 'blackjax', 'nutpie', or 'laplace'." + ) + sampler = _new_sampler_mapping[sampler] # type: ignore + + if sampler == "vi": + raise ValueError( + "VI is not supported via the sample() method. " + "Please use the vi() method instead." + ) + + if initvals is not None: + if isinstance(initvals, dict): + kwargs["initvals"] = initvals + else: + if isinstance(initvals, str): + if initvals == "map": + if self._map_dict is None: + _logger.info( + "initvals='map' but no MAP " + "estimate precomputed. \n" + "Running MAP estimation first..." + ) + self.find_MAP() + kwargs["initvals"] = self._map_dict + else: + kwargs["initvals"] = self._map_dict + else: + raise ValueError( + "initvals argument must be a dictionary or 'map'" + " to use the MAP estimate." + ) + else: + kwargs["initvals"] = self._initvals + _logger.info("Using default initvals. \n") + + if sampler is None: + if ( + self.loglik_kind == "approx_differentiable" + and self.model_config.backend == "jax" + ): + sampler = "numpyro" + else: + sampler = "pymc" + + if self.loglik_kind == "blackbox": + if sampler in ["blackjax", "numpyro", "nutpie"]: + raise ValueError( + f"{sampler} sampler does not work with blackbox likelihoods."
+ ) + + if "step" not in kwargs: + kwargs |= {"step": pm.Slice(model=self.pymc_model)} + + if ( + self.loglik_kind == "approx_differentiable" + and self.model_config.backend == "jax" + and sampler == "pymc" + and kwargs.get("cores", None) != 1 + ): + _logger.warning( + "Parallel sampling might not work with `jax` backend and the PyMC NUTS " + + "sampler on some platforms. Please consider using `numpyro`, " + + "`blackjax`, or `nutpie` sampler if that is a problem." + ) + + if self._check_extra_fields(): + self._update_extra_fields() + + if init is None: + if sampler in ["pymc", "numpyro", "blackjax", "nutpie"]: + init = "adapt_diag" + else: + init = "auto" + + # If sampler is finally `numpyro` make sure + # the jitter argument is set to False + if sampler == "numpyro": + if "nuts_sampler_kwargs" in kwargs: + if kwargs["nuts_sampler_kwargs"].get("jitter"): + _logger.warning( + "The jitter argument is set to True. " + + "This argument is not supported " + + "by the numpyro backend. " + + "The jitter argument will be set to False." + ) + kwargs["nuts_sampler_kwargs"]["jitter"] = False + else: + kwargs["nuts_sampler_kwargs"] = {"jitter": False} + + if sampler != "pymc" and "step" in kwargs: + raise ValueError( + "`step` samplers (enabled by the `step` argument) are only supported " + "by the `pymc` sampler." + ) + + if self._inference_obj is not None: + _logger.warning( + "The model has already been sampled. Overwriting the previous " + + "inference object. Any previous reference to the inference object " + + "will still point to the old object." 
+ ) + + # Define whether likelihood should be computed + compute_likelihood = True + if "idata_kwargs" in kwargs: + if "log_likelihood" in kwargs["idata_kwargs"]: + compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True) + + omit_offsets = kwargs.pop("omit_offsets", False) + self._inference_obj = self.model.fit( + inference_method=( + "pymc" + if sampler in ["pymc", "numpyro", "blackjax", "nutpie"] + else sampler + ), + init=init, + include_response_params=include_response_params, + omit_offsets=omit_offsets, + **kwargs, + ) + + # Separate out log likelihood computation + if compute_likelihood: + self.log_likelihood(self._inference_obj, inplace=True) + + # Subset data vars in posterior + self._clean_posterior_group(idata=self._inference_obj) + return self.traces + + def vi( + self, + method: str = "advi", + niter: int = 10000, + draws: int = 1000, + return_idata: bool = True, + ignore_mcmc_start_point_defaults=False, + **vi_kwargs, + ) -> pm.Approximation | az.InferenceData: + """Perform Variational Inference. + + Parameters + ---------- + niter : int + The number of iterations to run the VI algorithm. Defaults to 10000. + method : str + The method to use for VI. Can be one of "advi", "fullrank_advi", "svgd", + or "asvgd". Defaults to "advi". + draws : int + The number of samples to draw from the posterior distribution. + Defaults to 1000. + return_idata : bool + If True, returns an InferenceData object. Otherwise, returns the + approximation object directly. Defaults to True. + + Returns + ------- + pm.Approximation or az.InferenceData: The mean field approximation object. + """ + if self.loglik_kind == "analytical": + _logger.warning( + "VI is not recommended for the analytical likelihood," + " since gradients can be brittle." + ) + elif self.loglik_kind == "blackbox": + raise ValueError( + "VI is not supported for blackbox likelihoods, " + "since likelihood gradients are needed!"
+ ) + + if ("start" not in vi_kwargs) and not ignore_mcmc_start_point_defaults: + _logger.info("Using MCMC starting point defaults.") + vi_kwargs["start"] = self._initvals + + # Run variational inference directly from pymc model + with self.pymc_model: + self._vi_approx = pm.fit(n=niter, method=method, **vi_kwargs) + + # Sample from the approximate posterior + if self._vi_approx is not None: + self._inference_obj_vi = self._vi_approx.sample(draws) + + # Post-processing + self._clean_posterior_group(idata=self._inference_obj_vi) + + # Return the InferenceData object if return_idata is True + if return_idata: + return self._inference_obj_vi + # Otherwise return the approximation object directly + return self.vi_approx + + def _clean_posterior_group(self, idata: az.InferenceData | None = None): + """Clean up the posterior group of the InferenceData object. + + Parameters + ---------- + idata : az.InferenceData + The InferenceData object to clean up. Must not be None; a ValueError is + raised otherwise. + """ + # # Logic behind which variables to keep: + # # We essentially want to get rid of + # # all the trial-wise variables. + + # # We drop all distributional components, IF they are deterministics + # # (in which case they will be trial wise systematically) + # # and we keep distributional components, IF they are + # # basic random variables (in which case they should never + # # appear trial-wise). + if idata is None: + raise ValueError( + "The InferenceData object is None. Cannot clean up the posterior group." + ) + elif not hasattr(idata, "posterior"): + raise ValueError( + "The InferenceData object does not have a posterior group. " + + "Cannot clean up the posterior group."
+ ) + + vars_to_keep = set(idata["posterior"].data_vars.keys()).difference( + set( + key_ + for key_ in self.model.distributional_components.keys() + if key_ in [var_.name for var_ in self.pymc_model.deterministics] + ) + ) + vars_to_keep_clean = [ + var_ + for var_ in vars_to_keep + if isinstance(var_, str) and "_mean" not in var_ + ] + + setattr( + idata, + "posterior", + idata["posterior"][vars_to_keep_clean], + ) + + def log_likelihood( + self, + idata: az.InferenceData | None = None, + data: pd.DataFrame | None = None, + inplace: bool = True, + keep_likelihood_params: bool = False, + ) -> az.InferenceData | None: + """Compute the log likelihood of the model. + + Parameters + ---------- + idata : optional + The `InferenceData` object returned by `HSSM.sample()`. If not provided, + data : optional + A pandas DataFrame with values for the predictors that are used to obtain + out-of-sample predictions. If omitted, the original dataset is used. + inplace : optional + If `True` will modify idata in-place and append a `log_likelihood` group to + `idata`. Otherwise, it will return a copy of idata with the predictions + added, by default True. + keep_likelihood_params : optional + If `True`, the trial wise likelihood parameters that are computed + on route to getting the log likelihood are kept in the `idata` object. + Defaults to False. See also the method `add_likelihood_parameters_to_idata`. + + Returns + ------- + az.InferenceData | None + InferenceData or None + """ + if self._inference_obj is None and idata is None: + raise ValueError( + "Neither has the model been sampled yet nor" + + " an idata object has been provided." + ) + + if idata is None: + if self._inference_obj is None: + raise ValueError( + "The model has not been sampled yet. " + + "Please provide an idata object." 
+ ) + else: + idata = self._inference_obj + + # Actual likelihood computation + idata = _compute_log_likelihood(self.model, idata, data, inplace) + + # clean up posterior: + if not keep_likelihood_params: + self._clean_posterior_group(idata=idata) + + if inplace: + return None + else: + return idata + + def add_likelihood_parameters_to_idata( + self, + idata: az.InferenceData | None = None, + inplace: bool = False, + ) -> az.InferenceData | None: + """Add likelihood parameters to the InferenceData object. + + Parameters + ---------- + idata : az.InferenceData + The InferenceData object returned by HSSM.sample(). + inplace : bool + If True, the likelihood parameters are added to idata in-place. Otherwise, + a copy of idata with the likelihood parameters added is returned. + Defaults to False. + + Returns + ------- + az.InferenceData | None + InferenceData or None + """ + if idata is None: + if self._inference_obj is None: + raise ValueError("No idata provided and model not yet sampled!") + else: + idata = self.model._compute_likelihood_params( # pylint: disable=protected-access + deepcopy(self._inference_obj) + if not inplace + else self._inference_obj + ) + else: + idata = self.model._compute_likelihood_params( # pylint: disable=protected-access + deepcopy(idata) if not inplace else idata + ) + return idata + + def sample_posterior_predictive( + self, + idata: az.InferenceData | None = None, + data: pd.DataFrame | None = None, + inplace: bool = True, + include_group_specific: bool = True, + kind: Literal["response", "response_params"] = "response", + draws: int | float | list[int] | np.ndarray | None = None, + safe_mode: bool = True, + ) -> az.InferenceData | None: + """Perform posterior predictive sampling from the HSSM model. + + Parameters + ---------- + idata : optional + The `InferenceData` object returned by `HSSM.sample()`. If not provided, + the `InferenceData` from the last time `sample()` is called will be used. 
+ data : optional + An optional data frame with values for the predictors that are used to + obtain out-of-sample predictions. If omitted, the original dataset is used. + inplace : optional + If `True` will modify idata in-place and append a `posterior_predictive` + group to `idata`. Otherwise, it will return a copy of idata with the + predictions added, by default True. + include_group_specific : optional + If `True` will make predictions including the group specific effects. + Otherwise, predictions are made with common effects only (i.e. + group-specific effects are set to zero), by default True. + kind: optional + Indicates the type of prediction required. Can be `"response_params"` or + `"response"`. The first returns draws from the posterior distribution of the + likelihood parameters, while the latter returns the draws from the posterior + predictive distribution (i.e. the posterior probability distribution for a + new observation) in addition to the posterior distribution. Defaults to + `"response"`. + draws: optional + The number of samples to draw from the posterior predictive distribution + from each chain. + When it's an integer >= 1, the number of samples to be extracted from the + `draw` dimension. If this integer is larger than the number of posterior + samples in each chain, all posterior samples will be used + in posterior predictive sampling. When a float between 0 and 1, the + proportion of samples from the draw dimension from each chain to be used in + posterior predictive sampling. If this proportion is very + small, at least one sample will be used. When None, all posterior samples + will be used. Defaults to None. + safe_mode: bool + If True, the function will split the draws into chunks of 10 to avoid memory + issues. Defaults to True. + + Raises + ------ + ValueError + If the model has not been sampled yet and idata is not provided.
+ + Returns + ------- + az.InferenceData | None + InferenceData or None + """ + if idata is None: + if self._inference_obj is None: + raise ValueError( + "The model has not been sampled yet. " + + "Please either provide an idata object or sample the model first." + ) + idata = self._inference_obj + _logger.info( + "idata=None, we use the traces assigned to the HSSM object as idata." + ) + + if idata is not None: + if "posterior_predictive" in idata.groups(): + del idata["posterior_predictive"] + _logger.warning( + "pre-existing posterior_predictive group deleted from idata. \n" + ) + + if self._check_extra_fields(data): + self._update_extra_fields(data) + + if isinstance(draws, np.ndarray): + draws = draws.astype(int) + elif isinstance(draws, list): + draws = np.array(draws).astype(int) + elif isinstance(draws, int | float): + draws = np.arange(int(draws)) + elif draws is None: + draws = idata["posterior"].draw.values + else: + raise ValueError( + "draws must be an integer, " + "a list of integers, or a numpy array." 
+ ) + + assert isinstance(draws, np.ndarray) + + # Make a copy of idata, set the `posterior` group to be a random sub-sample + # of the original (draw dimension gets sub-sampled) + + idata_copy = idata.copy() + + if (draws.shape != idata["posterior"].draw.values.shape) or ( + (draws.shape == idata["posterior"].draw.values.shape) + and not np.allclose(draws, idata["posterior"].draw.values) + ): + # Reassign posterior to sub-sampled version + setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws)) + + if kind == "response": + # If we run kind == 'response' we actually run the observation RV + if safe_mode: + # safe mode splits the draws into chunks of 10 to avoid + # memory issues (TODO: Figure out the source of memory issues) + split_draws = _split_array( + idata_copy["posterior"].draw.values, divisor=10 + ) + + posterior_predictive_list = [] + for samples_tmp in split_draws: + tmp_posterior = idata["posterior"].sel(draw=samples_tmp) + setattr(idata_copy, "posterior", tmp_posterior) + self.model.predict( + idata_copy, kind, data, True, include_group_specific + ) + posterior_predictive_list.append(idata_copy["posterior_predictive"]) + + if inplace: + idata.add_groups( + posterior_predictive=xr.concat( + posterior_predictive_list, dim="draw" + ) + ) + # for inplace, we don't return anything + return None + else: + # Reassign original posterior to idata_copy + setattr(idata_copy, "posterior", idata["posterior"]) + # Add new posterior predictive group to idata_copy + del idata_copy["posterior_predictive"] + idata_copy.add_groups( + posterior_predictive=xr.concat( + posterior_predictive_list, dim="draw" + ) + ) + return idata_copy + else: + if inplace: + # If not safe-mode + # We call .predict() directly without any + # chunking of data. 
+ + # .predict() is called on the copy of idata + # since we still subsampled (or assigned) the draws + self.model.predict( + idata_copy, kind, data, True, include_group_specific + ) + + # posterior predictive group added to idata + idata.add_groups( + posterior_predictive=idata_copy["posterior_predictive"] + ) + # don't return anything if inplace + return None + else: + # Not safe mode and not inplace + # Function acts as very thin wrapper around + # .predict(). It just operates on the + # idata_copy object + return self.model.predict( + idata_copy, kind, data, False, include_group_specific + ) + elif kind == "response_params": + # If kind == 'response_params', we don't need to run the RV directly, + # there shouldn't really be any significant memory issues here, + # we can simply ignore settings, since the computational overhead + # should be very small --> nudges user towards good outputs. + _logger.warning( + "The kind argument is set to 'response_params': " + + "the 'draws' argument will be ignored!" + ) + return self.model.predict( + idata, kind, data, inplace, include_group_specific + ) + else: + raise ValueError("`kind` must be either 'response' or 'response_params'.") + + def plot_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid: + """Produce a posterior predictive plot. + + Equivalent to calling `hssm.plotting.plot_predictive()` with the + model. Please see that function for + [full documentation][hssm.plotting.plot_predictive]. + + Returns + ------- + mpl.axes.Axes | sns.FacetGrid + The matplotlib axis or seaborn FacetGrid object containing the plot. + """ + return plotting.plot_predictive(self, **kwargs) + + def plot_quantile_probability(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid: + """Produce a quantile probability plot. + + Equivalent to calling `hssm.plotting.plot_quantile_probability()` with the + model. Please see that function for + [full documentation][hssm.plotting.plot_quantile_probability].
+ + Returns + ------- + mpl.axes.Axes | sns.FacetGrid + The matplotlib axis or seaborn FacetGrid object containing the plot. + """ + return plotting.plot_quantile_probability(self, **kwargs) + + def predict(self, **kwargs) -> az.InferenceData: + """Generate samples from the predictive distribution.""" + return self.model.predict(**kwargs) + + def sample_do( + self, params: dict[str, Any], draws: int = 100, return_model=False, **kwargs + ) -> az.InferenceData | tuple[az.InferenceData, pm.Model]: + """Generate samples from the predictive distribution using the `do-operator`.""" + do_model = do(self.pymc_model, params) + do_idata = pm.sample_prior_predictive(model=do_model, draws=draws, **kwargs) + + # clean up `rt,response_mean` to `v` + do_idata = self._drop_parent_str_from_idata(idata=do_idata) + + # rename otherwise inconsistent dims and coords + if "rt,response_extra_dim_0" in do_idata["prior_predictive"].dims: + setattr( + do_idata, + "prior_predictive", + do_idata["prior_predictive"].rename_dims( + {"rt,response_extra_dim_0": "rt,response_dim"} + ), + ) + if "rt,response_extra_dim_0" in do_idata["prior_predictive"].coords: + setattr( + do_idata, + "prior_predictive", + do_idata["prior_predictive"].rename_vars( + name_dict={"rt,response_extra_dim_0": "rt,response_dim"} + ), + ) + + if return_model: + return do_idata, do_model + return do_idata + + def sample_prior_predictive( + self, + draws: int = 500, + var_names: str | list[str] | None = None, + omit_offsets: bool = True, + random_seed: np.random.Generator | None = None, + ) -> az.InferenceData: + """Generate samples from the prior predictive distribution. + + Parameters + ---------- + draws + Number of draws to sample from the prior predictive distribution. Defaults + to 500. + var_names + A list of names of variables for which to compute the prior predictive + distribution. Defaults to ``None`` which means both observed and unobserved + RVs. + omit_offsets + Whether to omit offset terms. 
Defaults to ``True``. + random_seed + Seed for the random number generator. + + Returns + ------- + az.InferenceData + ``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and + ``observed_data``. + """ + prior_predictive = self.model.prior_predictive( + draws, var_names, omit_offsets, random_seed + ) + + # AF-COMMENT: Not sure if necessary to include the + # mean prior here (which adds deterministics that + # could be recomputed elsewhere) + prior_predictive.add_groups(posterior=prior_predictive.prior) + # Bambi >= 0.17 renamed kind="mean" to kind="response_params". + self.model.predict(prior_predictive, kind="response_params", inplace=True) + + # clean + setattr(prior_predictive, "prior", prior_predictive["posterior"]) + del prior_predictive["posterior"] + + if self._inference_obj is None: + self._inference_obj = prior_predictive + else: + self._inference_obj.extend(prior_predictive) + + # clean up `rt,response_mean` to `v` + idata = self._drop_parent_str_from_idata(idata=self._inference_obj) + + # rename otherwise inconsistent dims and coords + if "rt,response_extra_dim_0" in idata["prior_predictive"].dims: + setattr( + idata, + "prior_predictive", + idata["prior_predictive"].rename_dims( + {"rt,response_extra_dim_0": "rt,response_dim"} + ), + ) + if "rt,response_extra_dim_0" in idata["prior_predictive"].coords: + setattr( + idata, + "prior_predictive", + idata["prior_predictive"].rename_vars( + name_dict={"rt,response_extra_dim_0": "rt,response_dim"} + ), + ) + + # Update self._inference_obj to match the cleaned idata + self._inference_obj = idata + return deepcopy(self._inference_obj) + + @property + def pymc_model(self) -> pm.Model: + """Provide access to the PyMC model. + + Returns + ------- + pm.Model + The PyMC model built by bambi + """ + return self.model.backend.model + + def set_alias(self, aliases: dict[str, str | dict]): + """Set parameter aliases. 
+ + Sets the aliases according to the dictionary passed to it and rebuild the + model. + + Parameters + ---------- + aliases + A dict specifying the parameter names being aliased and the aliases. + """ + self.model.set_alias(aliases) + self.model.build() + + @property + def response_c(self) -> str: + """Return the response variable names in c() format.""" + if self.response is None: + return "c()" + if len(self.response) == 1: + return self.response[0] + return f"c({', '.join(self.response)})" + + @property + def response_str(self) -> str: + """Return the response variable names in string format.""" + if self.response is None: + return "" + if len(self.response) == 1: + return self.response[0] + return ",".join(self.response) + + # NOTE: can't annotate return type because the graphviz dependency is optional + def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"): + """Produce a graphviz Digraph from a built HSSM model. + + Requires graphviz, which may be installed most easily with `conda install -c + conda-forge python-graphviz`. Alternatively, you may install the `graphviz` + binaries yourself, and then `pip install graphviz` to get the python bindings. + See http://graphviz.readthedocs.io/en/stable/manual.html for more information. + + Parameters + ---------- + formatting + One of `"plain"` or `"plain_with_params"`. Defaults to `"plain"`. + name + Name of the figure to save. Defaults to `None`, no figure is saved. + figsize + Maximum width and height of figure in inches. Defaults to `None`, the + figure size is set automatically. If defined and the drawing is larger than + the given size, the drawing is uniformly scaled down so that it fits within + the given size. Only works if `name` is not `None`. + dpi + Point per inch of the figure to save. + Defaults to 300. Only works if `name` is not `None`. + fmt + Format of the figure to save. + Defaults to `"png"`. Only works if `name` is not `None`. 
+ + Returns + ------- + graphviz.Graph + The graph + """ + graph = self.model.graph(formatting, name, figsize, dpi, fmt) + + parent_param = self._parent_param + if parent_param.is_regression: + return graph + + # Modify the graph + # 1. Remove all nodes and edges related to `{parent}_mean`: + graph.body = [ + item for item in graph.body if f"{parent_param.name}_mean" not in item + ] + # 2. Add a new edge from parent to response + graph.edge(parent_param.name, self.response_str) + + return graph + + def compile_logp(self, keep_transformed: bool = False, **kwargs): + """Compile the log probability function for the model. + + Parameters + ---------- + keep_transformed : bool, optional + If True, keeps the transformed variables in the compiled function. + If False, removes value transforms before compilation. + Defaults to False. + **kwargs + Additional keyword arguments passed to PyMC's compile_logp: + - vars: List of variables. Defaults to None (all variables). + - jacobian: Whether to include log(|det(dP/dQ)|) term for + transformed variables. Defaults to True. + - sum: Whether to sum all terms instead of returning a vector. + Defaults to True. + + Returns + ------- + callable + A compiled function that computes the model log probability. + """ + if keep_transformed: + return self.pymc_model.compile_logp( + vars=kwargs.get("vars", None), + jacobian=kwargs.get("jacobian", True), + sum=kwargs.get("sum", True), + ) + else: + new_model = pm.model.transform.conditioning.remove_value_transforms( + self.pymc_model + ) + return new_model.compile_logp( + vars=kwargs.get("vars", None), + jacobian=kwargs.get("jacobian", True), + sum=kwargs.get("sum", True), + ) + + def plot_trace( + self, + data: az.InferenceData | None = None, + include_deterministic: bool = False, + tight_layout: bool = True, + **kwargs, + ) -> None: + """Generate trace plot with ArviZ but with additional convenience features. + + This is a simple wrapper for the az.plot_trace() function. 
By default, it + filters out the deterministic values from the plot. Please see the + [arviz documentation] + (https://arviz-devs.github.io/arviz/api/generated/arviz.plot_trace.html) + for additional parameters that can be specified. + + Parameters + ---------- + data : optional + An ArviZ InferenceData object. If None, the traces stored in the model will + be used. + include_deterministic : optional + Whether to include deterministic variables in the plot. Defaults to False. + Note that if include_deterministic is set to False and `var_names` is + provided, the `var_names` provided will be modified to also exclude the + deterministic values. If this is not desirable, set + `include_deterministic` to True. + tight_layout : optional + Whether to call plt.tight_layout() after plotting. Defaults to True. + """ + data = data or self.traces + if not isinstance(data, az.InferenceData): + raise TypeError("data must be an InferenceData object.") + + if not include_deterministic: + var_names = list( + set([var.name for var in self.pymc_model.free_RVs]).intersection( + set(list(data["posterior"].data_vars.keys())) + ) + ) + # var_names = self._get_deterministic_var_names(data) + if var_names: + if "var_names" in kwargs: + if isinstance(kwargs["var_names"], str): + if kwargs["var_names"] not in var_names: + var_names.append(kwargs["var_names"]) + kwargs["var_names"] = var_names + elif isinstance(kwargs["var_names"], list): + kwargs["var_names"] = list( + set(var_names) | set(kwargs["var_names"]) + ) + elif kwargs["var_names"] is None: + kwargs["var_names"] = var_names + else: + raise ValueError( + "`var_names` must be a string, a list of strings, or None." 
+ ) + else: + kwargs["var_names"] = var_names + az.plot_trace(data, **kwargs) + + if tight_layout: + plt.tight_layout() + + def summary( + self, + data: az.InferenceData | None = None, + include_deterministic: bool = False, + **kwargs, + ) -> pd.DataFrame | xr.Dataset: + """Produce a summary table with ArviZ but with additional convenience features. + + This is a simple wrapper for the az.summary() function. By default, it + filters out the deterministic values from the table. Please see the + [arviz documentation] + (https://arviz-devs.github.io/arviz/api/generated/arviz.summary.html) + for additional parameters that can be specified. + + Parameters + ---------- + data + An ArviZ InferenceData object. If None, the traces stored in the model will + be used. + include_deterministic : optional + Whether to include deterministic variables in the table. Defaults to False. + Note that if include_deterministic is set to False and `var_names` is + provided, the `var_names` provided will be modified to also exclude the + deterministic values. If this is not desirable, set + `include_deterministic` to True. + + Returns + ------- + pd.DataFrame | xr.Dataset + A pandas DataFrame or xarray Dataset containing the summary statistics. + """ + data = data or self.traces + if not isinstance(data, az.InferenceData): + raise TypeError("data must be an InferenceData object.") + + if not include_deterministic: + var_names = list( + set([var.name for var in self.pymc_model.free_RVs]).intersection( + set(list(data["posterior"].data_vars.keys())) + ) + ) + # var_names = self._get_deterministic_var_names(data) + if var_names: + kwargs["var_names"] = list(set(var_names + kwargs.get("var_names", []))) + return az.summary(data, **kwargs) + + def initial_point(self, transformed: bool = False) -> dict[str, np.ndarray]: + """Compute the initial point of the model. + + This is a slightly altered version of pm.initial_point.initial_point(). 
+ + Parameters + ---------- + transformed : bool, optional + If True, return the initial point in transformed space. + + Returns + ------- + dict + A dictionary containing the initial point of the model parameters. + """ + fn = pm.initial_point.make_initial_point_fn( + model=self.pymc_model, return_transformed=transformed + ) + return pm.model.Point(fn(None), model=self.pymc_model) + + def restore_traces( + self, traces: az.InferenceData | pm.Approximation | str | PathLike + ) -> None: + """Restore traces from an InferenceData object or a .netcdf file. + + Parameters + ---------- + traces + An InferenceData object or a path to a file containing the traces. + """ + if isinstance(traces, pm.Approximation): + self._inference_obj_vi = traces + return + + if isinstance(traces, (str, PathLike)): + traces = az.from_netcdf(traces) + self._inference_obj = cast("az.InferenceData", traces) + + def restore_vi_traces( + self, traces: az.InferenceData | pm.Approximation | str | PathLike + ) -> None: + """Restore VI traces from an InferenceData object or a .netcdf file. + + Parameters + ---------- + traces + An InferenceData object or a path to a file containing the VI traces. + """ + if isinstance(traces, pm.Approximation): + self._inference_obj_vi = traces + return + + if isinstance(traces, (str, PathLike)): + traces = az.from_netcdf(traces) + self._inference_obj_vi = cast("az.InferenceData", traces) + + def save_model( + self, + model_name: str | None = None, + allow_absolute_base_path: bool = False, + base_path: str | Path = "hssm_models", + save_idata_only: bool = False, + ) -> None: + """Save a HSSM model instance and its inference results to disk. + + Parameters + ---------- + model_name : str | None + Name to use for the saved model files. + If None, will use model.model_name with timestamp + allow_absolute_base_path : bool + Whether to allow absolute paths for base_path. + Defaults to False for safety. + base_path : str | Path + Base directory to save model files in. 
+ Must be relative path if allow_absolute_base_path=False. + Defaults to "hssm_models". + save_idata_only : bool + If True, only saves inference data (traces), not the model pickle. + Defaults to False (saves both model and traces). + + Raises + ------ + ValueError + If base_path is absolute and allow_absolute_base_path=False + """ + # Convert to Path object for cross-platform compatibility + base_path = Path(base_path) + + # Check if base_path is absolute (works on all platforms) + if not allow_absolute_base_path and base_path.is_absolute(): + raise ValueError( + "base_path must be a relative path if allow_absolute_base_path is False" + ) + + if model_name is None: + # Get date string format as suffix to model name + timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + model_name = f"{self.model_name}_{timestamp}" + + # Sanitize model_name and construct full path + model_name = model_name.replace(" ", "_") + model_path = Path(base_path).joinpath(model_name) + model_path.mkdir(parents=True, exist_ok=True) + + # Save model to pickle file + if not save_idata_only: + with open(model_path.joinpath("model.pkl"), "wb") as f: + cpickle.dump(self, f) + + # Save traces to netcdf file + if self._inference_obj is not None: + az.to_netcdf(self._inference_obj, model_path.joinpath("traces.nc")) + + # Save vi_traces to netcdf file + if self._inference_obj_vi is not None: + az.to_netcdf(self._inference_obj_vi, model_path.joinpath("vi_traces.nc")) + + @classmethod + def load_model( + cls, path: Union[str, Path] + ) -> Union["HSSMBase", dict[str, Optional[az.InferenceData]]]: + """Load a HSSM model instance and its inference results from disk. + + Parameters + ---------- + path : str | Path + Path to the model directory or model.pkl file. If a directory is provided, + will look for model.pkl, traces.nc and vi_traces.nc files within it. 
+ + Returns + ------- + HSSMBase or dict[str, az.InferenceData | None] + The loaded model instance (with inference results attached if available), + or a dictionary of traces-only InferenceData objects when no model.pkl is + found. + """ + # Convert path to Path object + path = Path(path) + + # If path points to a file, assume it's model.pkl + if path.is_file(): + model_dir = path.parent + model_path = path + else: + # Path points to directory + model_dir = path + model_path = model_dir.joinpath("model.pkl") + + # check if model_dir exists + if not model_dir.exists(): + raise FileNotFoundError(f"Model directory {model_dir} does not exist.") + + # check if model.pkl exists raise logging information if not + if not model_path.exists(): + _logger.info( + f"model.pkl file does not exist in {model_dir}. " + "Attempting to load traces only." + ) + if (not model_dir.joinpath("traces.nc").exists()) and ( + not model_dir.joinpath("vi_traces.nc").exists() + ): + raise FileNotFoundError(f"No traces found in {model_dir}.") + else: + idata_dict = cls.load_model_idata(model_dir) + return idata_dict + else: + # Load model from pickle file + with open(model_path, "rb") as f: + model = cpickle.load(f) + + # Load traces if they exist + traces_path = model_dir.joinpath("traces.nc") + if traces_path.exists(): + model.restore_traces(traces_path) + + # Load VI traces if they exist + vi_traces_path = model_dir.joinpath("vi_traces.nc") + if vi_traces_path.exists(): + model.restore_vi_traces(vi_traces_path) + return model + + @classmethod + def load_model_idata(cls, path: str | Path) -> dict[str, az.InferenceData | None]: + """Load the traces from a model directory. + + Parameters + ---------- + path : str | Path + Path to the model directory containing traces.nc and/or vi_traces.nc files. + + Returns + ------- + dict[str, az.InferenceData | None] + A dictionary with keys "idata_mcmc" and "idata_vi" containing the traces + from the model directory. 
If the traces do not exist, the corresponding + value will be None. + """ + idata_dict: dict[str, az.InferenceData | None] = {} + model_dir = Path(path) + # check if path exists + if not model_dir.exists(): + raise FileNotFoundError(f"Model directory {model_dir} does not exist.") + + # check if traces.nc exists + traces_path = model_dir.joinpath("traces.nc") + if not traces_path.exists(): + _logger.warning(f"traces.nc file does not exist in {model_dir}.") + idata_dict["idata_mcmc"] = None + else: + idata_dict["idata_mcmc"] = az.from_netcdf(traces_path) + + # check if vi_traces.nc exists + vi_traces_path = model_dir.joinpath("vi_traces.nc") + if not vi_traces_path.exists(): + _logger.warning(f"vi_traces.nc file does not exist in {model_dir}.") + idata_dict["idata_vi"] = None + else: + idata_dict["idata_vi"] = az.from_netcdf(vi_traces_path) + + return idata_dict + + def __getstate__(self): + """Get the state of the model for pickling. + + This method is called when pickling the model. + It returns a dictionary containing the constructor + arguments needed to recreate the model instance. + + Returns + ------- + dict + A dictionary containing the constructor arguments + under the key 'constructor_args'. + """ + # Provide a clear error when the initialization snapshot is missing or + # empty. This makes the contract explicit and avoids an AttributeError + # that is easy to miss for subclasses that forget to capture init args. + if not hasattr(self, "_init_args") or not self._init_args: + raise RuntimeError( + "Model state missing initialization snapshot; ensure subclasses " + "call _store_init_args(locals(), kwargs) early in __init__" + ) + + state = {"constructor_args": self._init_args} + return state + + def __setstate__(self, state): + """Set the state of the model when unpickling. + + This method is called when unpickling the model. 
It creates a new instance + using the constructor arguments stored in the state dictionary, + and copies its attributes to the current instance. + + Parameters + ---------- + state : dict + A dictionary containing the constructor arguments under the key + 'constructor_args'. + """ + new_instance = self.__class__(**state["constructor_args"]) + self.__dict__ = new_instance.__dict__ + + def __repr__(self) -> str: + """Create a representation of the model.""" + output = [ + "Hierarchical Sequential Sampling Model", + f"Model: {self.model_name}\n", + f"Response variable: {self.response_str}", + f"Likelihood: {self.loglik_kind}", + f"Observations: {len(self.data)}\n", + "Parameters:\n", + ] + + for param in self.params.values(): + if param.name == "p_outlier": + continue + output.append(f"{param.name}:") + + component = self.model.components[param.name] + + # Regression case: + if param.is_regression: + assert isinstance(component, DistributionalComponent) + output.append(f" Formula: {param.formula}") + output.append(" Priors:") + intercept_term = component.intercept_term + if intercept_term is not None: + output.append(_print_prior(intercept_term)) + for _, common_term in component.common_terms.items(): + output.append(_print_prior(common_term)) + for _, group_specific_term in component.group_specific_terms.items(): + output.append(_print_prior(group_specific_term)) + output.append(f" Link: {param.link}") + # None regression case + else: + if param.prior is None: + prior = ( + component.intercept_term.prior + if param.is_parent + else component.prior + ) + else: + prior = param.prior + output.append(f" Prior: {prior}") + output.append(f" Explicit bounds: {param.bounds}") + output.append( + " (ignored due to link function)" + if self.link_settings is not None + else "" + ) + + # TODO: Handle p_outlier regression correctly here. 
+ if self.p_outlier is not None: + output.append("") + output.append(f"Lapse probability: {self.p_outlier.prior}") + output.append(f"Lapse distribution: {self.lapse}") + + return "\n".join(output) + + def __str__(self) -> str: + """Create a string representation of the model.""" + return self.__repr__() + + @property + def traces(self) -> az.InferenceData | pm.Approximation: + """Return the trace of the model after sampling. + + Raises + ------ + ValueError + If the model has not been sampled yet. + + Returns + ------- + az.InferenceData | pm.Approximation + The trace of the model after the last call to `sample()`. + """ + if not self._inference_obj: + raise ValueError("Please sample the model first.") + + return self._inference_obj + + @property + def vi_idata(self) -> az.InferenceData: + """Return the variational inference approximation object. + + Raises + ------ + ValueError + If the model has not been sampled yet. + + Returns + ------- + az.InferenceData + The variational inference approximation object. + """ + if not self._inference_obj_vi: + raise ValueError( + "Please run variational inference first, " + "no variational posterior attached." + ) + + return self._inference_obj_vi + + @property + def vi_approx(self) -> pm.Approximation: + """Return the variational inference approximation object. + + Raises + ------ + ValueError + If the model has not been sampled yet. + + Returns + ------- + pm.Approximation + The variational inference approximation object. + """ + if not self._vi_approx: + raise ValueError( + "Please run variational inference first, " + "no variational approximation attached." + ) + + return self._vi_approx + + @property + def map(self) -> dict: + """Return the MAP estimates of the model parameters. + + Raises + ------ + ValueError + If the model has not been sampled yet. + + Returns + ------- + dict + A dictionary containing the MAP estimates of the model parameters. 
+ """ + if not self._map_dict: + raise ValueError("Please compute map first.") + + return self._map_dict + + @property + def initvals(self) -> dict: + """Return the initial values of the model parameters for sampling. + + Returns + ------- + dict + A dictionary containing the initial values of the model parameters. + This dict serves as the default for initial values, and can be passed + directly to the `.sample()` function. + """ + if self._initvals == {}: + self._initvals = self.initial_point() + return self._initvals + + def _check_lapse(self, lapse): + """Determine if p_outlier and lapse is specified correctly.""" + # Basically, avoid situations where only one of them is specified. + if self.has_lapse and lapse is None: + raise ValueError( + "You have specified `p_outlier`. Please also specify `lapse`." + ) + if lapse is not None and not self.has_lapse: + _logger.warning( + "You have specified the `lapse` argument to include a lapse " + + "distribution, but `p_outlier` is set to either 0 or None. " + + "Your lapse distribution will be ignored." + ) + if "p_outlier" in self.list_params and self.list_params[-1] != "p_outlier": + raise ValueError( + "Please do not include 'p_outlier' in `list_params`. " + + "We automatically append it to `list_params` when `p_outlier` " + + "parameter is not None" + ) + + def _get_deterministic_var_names(self, idata) -> list[str]: + """Filter out the deterministic variables in var_names.""" + var_names = [ + f"~{param_name}" + for param_name, param in self.params.items() + if (param.is_regression) + ] + + if f"{self._parent}_mean" in idata["posterior"].data_vars: + var_names.append(f"~{self._parent}_mean") + + # Parent parameters (always regression implicitly) + # which don't have a formula attached + # should be dropped from var_names, since the actual + # parent name shows up as a regression. 
+ if f"{self._parent}" in idata["posterior"].data_vars: + if self.params[self._parent].formula is None: + # Drop from var_names + var_names = [var for var in var_names if var != f"~{self._parent}"] + + return var_names + + def _drop_parent_str_from_idata( + self, idata: az.InferenceData | None + ) -> az.InferenceData: + """Drop the parent_str variable from an InferenceData object. + + Parameters + ---------- + idata + The InferenceData object to be modified. + + Returns + ------- + xr.Dataset + The modified InferenceData object. + """ + if idata is None: + raise ValueError("Please provide an InferenceData object.") + else: + for group in idata.groups(): + if ("rt,response_mean" in idata[group].data_vars) and ( + self._parent not in idata[group].data_vars + ): + setattr( + idata, + group, + idata[group].rename({"rt,response_mean": self._parent}), + ) + return idata + + def _postprocess_initvals_deterministic( + self, initval_settings: dict = INITVAL_SETTINGS + ) -> None: + """Set initial values for subset of parameters.""" + self._initvals = self.initial_point() + # Consider case where link functions are set to 'log_logit' + # or 'None' + if self.link_settings not in ["log_logit", None]: + _logger.info( + "Not preprocessing initial values, " + + "because none of the two standard link settings are chosen!" + ) + return None + + # Set initial values for particular parameters + for name_, starting_value in self.pymc_model.initial_point().items(): + # strip name of `_log__` and `_interval__` suffixes + name_tmp = name_.replace("_log__", "").replace("_interval__", "") + + # We need to check if the parameter is actually backed by + # a regression. + + # If not, we don't actually apply a link function to it as per default. + # Therefore we need to apply the initial value strategy corresponding + # to 'None' link function. + + # If the user actively supplies a link function, the user + # should also have supplied an initial value insofar it matters. 
+ + if self.params[self._get_prefix(name_tmp)].is_regression: + param_link_setting = self.link_settings + else: + param_link_setting = None + if name_tmp in initval_settings[param_link_setting].keys(): + if self._check_if_initval_user_supplied(name_tmp): + _logger.info( + "User supplied initial value detected for %s, \n" + " skipping overwrite with default value.", + name_tmp, + ) + continue + + # Apply specific settings from initval_settings dictionary + dtype = self._initvals[name_tmp].dtype + self._initvals[name_tmp] = np.array( + initval_settings[param_link_setting][name_tmp] + ).astype(dtype) + + def _get_prefix(self, name_str: str) -> str: + """Resolve parameter prefix, handling underscore-containing RL param names. + + The base-class implementation splits ``name_str`` on the first ``_`` and + returns that single token (e.g. ``"rl_alpha_Intercept" → "rl"``), which + breaks for RL parameters whose names contain underscores. It also uses a + substring check (``"p_outlier" in name_str``) for the lapse parameter, + which would misfire for any parameter whose name merely *contains* that + substring. + + This override replaces both heuristics with a single longest-prefix-first + token search: split on ``_``, then try joining 1…N tokens (longest first) + until a candidate is found in ``self.params``. This is both correct for + multi-token RL param names and collision-free for ``p_outlier``. 
+ """ + if "_" in name_str: + parts = name_str.split("_") + for i in range(len(parts), 0, -1): + candidate = "_".join(parts[:i]) + if hasattr(self, "params") and candidate in self.params: + return candidate + return name_str + + def _check_if_initval_user_supplied( + self, + name_str: str, + return_value: bool = False, + ) -> bool | float | int | np.ndarray | dict[str, Any] | None: + """Check if initial value is user-supplied.""" + # The function assumes that the name_str is either raw parameter name + # or `paramname_Intercept`, because we only really provide special default + # initial values for those types of parameters + + # `p_outlier` is the only basic parameter floating around that has + # an underscore in its name. + # We need to handle it separately. (Renaming might be better...) + if "_" in name_str: + if "p_outlier" not in name_str: + name_str_prefix = name_str.split("_")[0] + # name_str_suffix = "".join(name_str.split("_")[1:]) + name_str_suffix = name_str[len(name_str_prefix + "_") :] + else: + name_str_prefix = "p_outlier" + if name_str == "p_outlier": + name_str_suffix = "" + else: + # name_str_suffix = "".join(name_str.split("_")[2:]) + name_str_suffix = name_str[len("p_outlier_") :] + else: + name_str_prefix = name_str + name_str_suffix = "" + + tmp_param = name_str_prefix + if tmp_param == self._parent: + # If the parameter was parent it is automatically treated as a + # regression. 
+ if not name_str_suffix: + # No suffix --> Intercept + if isinstance(prior_tmp := self.params[tmp_param].prior, dict): + args_tmp = getattr(prior_tmp["Intercept"], "args") + if return_value: + return args_tmp.get("initval", None) + else: + return "initval" in args_tmp + else: + if return_value: + return None + return False + else: + # If the parameter has a suffix --> use it + if isinstance(prior_tmp := self.params[tmp_param].prior, dict): + args_tmp = getattr(prior_tmp[name_str_suffix], "args") + if return_value: + return args_tmp.get("initval", None) + else: + return "initval" in args_tmp + else: + if return_value: + return None + else: + return False + else: + # If the parameter is not a parent, it is treated as a regression + # only when actively specified as such. + if not name_str_suffix: + # If no suffix --> treat as basic parameter. + if isinstance(self.params[tmp_param].prior, float) or isinstance( + self.params[tmp_param].prior, np.ndarray + ): + if return_value: + return self.params[tmp_param].prior + else: + return True + elif isinstance(self.params[tmp_param].prior, bmb.Prior): + args_tmp = getattr(self.params[tmp_param].prior, "args") + if "initval" in args_tmp: + if return_value: + return args_tmp["initval"] + else: + return True + else: + if return_value: + return None + else: + return False + else: + if return_value: + return None + else: + return False + else: + # If suffix --> treat as regression and use suffix + if isinstance(prior_tmp := self.params[tmp_param].prior, dict): + args_tmp = getattr(prior_tmp[name_str_suffix], "args") + if return_value: + return args_tmp.get("initval", None) + else: + return "initval" in args_tmp + else: + if return_value: + return None + else: + return False + + def _jitter_initvals( + self, jitter_epsilon: float = 0.01, vector_only: bool = False + ) -> None: + """Apply controlled jitter to initial values.""" + if vector_only: + self.__jitter_initvals_vector_only(jitter_epsilon) + else: + 
self.__jitter_initvals_all(jitter_epsilon) + + def __jitter_initvals_vector_only(self, jitter_epsilon: float) -> None: + # Note: Calling our initial point function here + # --> operate on untransformed variables + initial_point_dict = self.initvals + for name_, starting_value in initial_point_dict.items(): + name_tmp = name_.replace("_log__", "").replace("_interval__", "") + if starting_value.ndim != 0 and starting_value.shape[0] != 1: + starting_value_tmp = starting_value + np.random.uniform( + -jitter_epsilon, jitter_epsilon, starting_value.shape + ).astype(np.float32) + + # Note: self._initvals shouldn't be None when this is called + dtype = self._initvals[name_tmp].dtype + self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype) + + def __jitter_initvals_all(self, jitter_epsilon: float) -> None: + # Note: Calling our initial point function here + # --> operate on untransformed variables + initial_point_dict = self.initvals + # initial_point_dict = self.pymc_model.initial_point() + for name_, starting_value in initial_point_dict.items(): + name_tmp = name_.replace("_log__", "").replace("_interval__", "") + starting_value_tmp = starting_value + np.random.uniform( + -jitter_epsilon, jitter_epsilon, starting_value.shape + ).astype(np.float32) + + dtype = self.initvals[name_tmp].dtype + self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype) diff --git a/src/hssm/config.py b/src/hssm/config.py index 9b9e0b12..f223b859 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -20,27 +20,17 @@ if TYPE_CHECKING: from pytensor.tensor.random.op import RandomVariable -# ====== Centralized RLSSM defaults ===== +import logging + +from ssms.config import model_config as ssms_model_config + +_logger = logging.getLogger("hssm") + + +# ====== Centralized SSM defaults ===== DEFAULT_SSM_OBSERVED_DATA = ["rt", "response"] -DEFAULT_RLSSM_OBSERVED_DATA = ["rt", "response"] DEFAULT_SSM_CHOICES = (0, 1) -RLSSM_REQUIRED_FIELDS = ( - "model_name", - 
"description", - "list_params", - "bounds", - "params_default", - "data", - "choices", - "decision_process", - "learning_process", - "response", - "decision_process_loglik_kind", - "learning_process_loglik_kind", - "extra_fields", -) - ParamSpec = Union[float, dict[str, Any], Prior, None] @@ -68,6 +58,9 @@ class BaseModelConfig(ABC): # Additional data requirements extra_fields: list[str] | None = None + # Random variable (simulator) for posterior predictive sampling + rv: Any | None = None + @abstractmethod def validate(self) -> None: """Validate configuration. Must be implemented by subclasses.""" @@ -78,6 +71,21 @@ def get_defaults(self, param: str) -> Any: """Get default values for a parameter. Must be implemented by subclasses.""" ... + @property + def n_params(self) -> int | None: + """Return the number of parameters.""" + return len(self.list_params) if self.list_params else None + + @property + def n_extra_fields(self) -> int | None: + """Return the number of extra fields.""" + return len(self.extra_fields) if self.extra_fields else None + + @property + def is_choice_only(self) -> bool: + """Return whether the model is choice-only (no RT).""" + return self.response is not None and len(self.response) == 1 + @dataclass class Config(BaseModelConfig): @@ -179,7 +187,7 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=loglik_kind, - response=DEFAULT_RLSSM_OBSERVED_DATA, + response=DEFAULT_SSM_OBSERVED_DATA, ) def update_loglik(self, loglik: Any | None) -> None: @@ -200,7 +208,7 @@ def update_choices(self, choices: tuple[int, ...] | None) -> None: Parameters ---------- - choices : tuple[int, ...] + choices : tuple[int, ...] | None A tuple of choices. """ if choices is None: @@ -217,7 +225,7 @@ def update_config(self, user_config: ModelConfig) -> None: User specified ModelConfig used update self. 
""" if user_config.response is not None: - self.response = user_config.response + self.response = list(user_config.response) # type: ignore[assignment] if user_config.list_params is not None: self.list_params = user_config.list_params if user_config.choices is not None: @@ -260,208 +268,62 @@ def get_defaults( """ return self.default_priors.get(param), self.bounds.get(param) - @property - def is_choice_only(self) -> bool: - """Check if the model is a choice-only model.""" - # Treat both None and an empty list as invalid configurations. - if not self.response: - raise ValueError( - "Please provide at least one `response` column in the configuration." - ) - return len(self.response) == 1 - - -@dataclass -class RLSSMConfig(BaseModelConfig): - """Config for reinforcement learning + sequential sampling models. - - This configuration class is designed for models that combine reinforcement - learning processes with sequential sampling decision models (RLSSM). - """ - - decision_process_loglik_kind: str = field(kw_only=True) - learning_process_loglik_kind: str = field(kw_only=True) - params_default: list[float] = field(kw_only=True) - decision_process: str | ModelConfig = field(kw_only=True) - learning_process: dict[str, Any] = field(kw_only=True) - - def __post_init__(self): - """Set default loglik_kind for RLSSM models if not provided.""" - if self.loglik_kind is None: - self.loglik_kind = "approx_differentiable" - - @property - def n_params(self) -> int | None: - """Return the number of parameters.""" - return len(self.list_params) if self.list_params else None - - @property - def n_extra_fields(self) -> int | None: - """Return the number of extra fields.""" - return len(self.extra_fields) if self.extra_fields else None - @classmethod - def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]): - """ - Create RLSSMConfig from a configuration dictionary. - - Parameters - ---------- - model_name : str - The name of the RLSSM model. 
- config_dict : dict[str, Any] - Dictionary containing model configuration. Expected keys: - - description: Model description (optional) - - list_params: List of parameter names (required) - - extra_fields: List of extra field names from data (required) - - params_default: Default parameter values (required) - - bounds: Parameter bounds (required) - - response: Response column names (required) - - choices: Valid choice values (required) - - decision_process: Decision process specification (required) - - learning_process: Learning process functions (required) - - decision_process_loglik_kind: Likelihood kind for decision process - (required) - - learning_process_loglik_kind: Likelihood kind for learning process - (required) - - Returns - ------- - RLSSMConfig - Configured RLSSM model configuration object. - """ - # Check for required fields and raise explicit errors if missing - for field_name in RLSSM_REQUIRED_FIELDS: - if field_name not in config_dict or config_dict[field_name] is None: - raise ValueError(f"{field_name} must be provided in config_dict") - - return cls( - model_name=model_name, - description=config_dict.get("description"), - list_params=config_dict["list_params"], - extra_fields=config_dict.get("extra_fields"), - params_default=config_dict["params_default"], - decision_process=config_dict["decision_process"], - learning_process=config_dict["learning_process"], - bounds=config_dict.get("bounds", {}), - response=config_dict["response"], - choices=config_dict["choices"], - decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], - learning_process_loglik_kind=config_dict["learning_process_loglik_kind"], - ) - - def validate(self) -> None: - """Validate RLSSM configuration. - - Raises - ------ - ValueError - If required fields are missing or inconsistent. + def _build_model_config( + cls, + model: SupportedModels | str, + loglik_kind: LoglikKind | None, + model_config: ModelConfig | dict | None, + choices: list[int] | tuple[int, ...] 
| None, + loglik: Any = None, + ) -> Config: + """Build and return a validated Config for standard HSSM models. + + Resolves defaults, normalizes dict/ModelConfig overrides, applies + choices and loglik precedence rules, then validates before returning. """ - if self.response is None: - raise ValueError("Please provide `response` columns in the configuration.") - if self.list_params is None: - raise ValueError("Please provide `list_params` in the configuration.") - if self.choices is None: - raise ValueError("Please provide `choices` in the configuration.") - if self.decision_process is None: - raise ValueError("Please specify a `decision_process`.") + config = cls.from_defaults(model, loglik_kind) - # Validate parameter defaults consistency - if self.params_default and self.list_params: - if len(self.params_default) != len(self.list_params): - raise ValueError( - f"params_default length ({len(self.params_default)}) doesn't " - f"match list_params length ({len(self.list_params)})" - ) - - def get_defaults( - self, param: str - ) -> tuple[float | None, tuple[float, float] | None]: - """Return default value and bounds for a parameter. + if model_config is not None: + final_config = _normalize_model_config_with_choices(model_config, choices) + config.update_config(final_config) - Parameters - ---------- - param - The name of the parameter. + # No model_config provided: apply `choices` when appropriate. + # If caller passed a SupportedModels string, ignore explicit `choices`. + if ( + model in get_args(SupportedModels) + and choices is not None + and model_config is None + ): + _logger.info( + "Model string is in SupportedModels. Ignoring choices arguments." 
+ ) - Returns - ------- - tuple - A tuple of (default_value, bounds) where: - - default_value is a float or None if not found - - bounds is a tuple (lower, upper) or None if not found - """ - # Try to find the parameter in list_params and get its default value - default_val = None - if self.list_params is not None: - try: - param_idx = self.list_params.index(param) - if self.params_default and param_idx < len(self.params_default): - default_val = self.params_default[param_idx] - except ValueError: - # Parameter not in list_params - pass - - return default_val, self.bounds.get(param) - - def to_config(self) -> Config: - """Convert to standard Config for compatibility with HSSM. - - This method transforms the RLSSM configuration into a standard Config - object that can be used with the existing HSSM infrastructure. - - Returns - ------- - Config - A Config object with RLSSM parameters mapped to standard format. - - Notes - ----- - The transformation converts params_default list to default_priors dict, - mapping parameter names to their default values. - """ - # Validate parameter defaults consistency before conversion - if self.params_default and self.list_params: - if len(self.params_default) != len(self.list_params): - raise ValueError( - f"params_default length ({len(self.params_default)}) doesn't " - f"match list_params length ({len(self.list_params)}). " - "This would result in silent data loss during conversion." + # If model is not a supported built-in, prefer explicit choices or + # fall back to ssms-simulators lookup when available. + if model not in get_args(SupportedModels): + if choices is not None: + config.update_choices(choices) + elif model in ssms_model_config: + config.update_choices(ssms_model_config[model]["choices"]) + _logger.info( + "choices argument passed as None, " + "but found %s in ssms-simulators. 
" + "Using choices, from ssm-simulators configs: %s", + model, + ssms_model_config[model]["choices"], ) - # Transform params_default list to default_priors dict - default_priors = ( - { - param: default - for param, default in zip(self.list_params, self.params_default) - } - if self.list_params and self.params_default - else {} - ) - - return Config( - model_name=self.model_name, - loglik_kind=self.loglik_kind, - response=self.response, - choices=self.choices, - list_params=self.list_params, - description=self.description, - bounds=self.bounds, - default_priors=cast( - "dict[str, float | dict[str, Any] | Any | None]", default_priors - ), - extra_fields=self.extra_fields, - backend=self.backend or "jax", # RLSSM typically uses JAX - loglik=self.loglik, - ) + config.update_loglik(loglik) + config.validate() + return config @dataclass class ModelConfig: """Representation for model_config provided by the user.""" - response: list[str] | None = None + response: tuple[str, ...] | None = None list_params: list[str] | None = None choices: tuple[int, ...] | None = None default_priors: dict[str, ParamSpec] = field(default_factory=dict) @@ -469,3 +331,46 @@ class ModelConfig: backend: Literal["jax", "pytensor"] | None = None rv: RandomVariable | None = None extra_fields: list[str] | None = None + + +def _normalize_model_config_with_choices( + model_config: "ModelConfig" | dict[str, Any], + choices: list[int] | tuple[int, ...] | None, +) -> "ModelConfig": + """Normalize a user-supplied model_config and apply choices. + + Returns a fresh :class:`ModelConfig` instance and does not mutate the + caller's objects. If both ``model_config`` and ``choices`` are provided + and ``model_config`` already contains ``choices``, the value from + ``model_config`` wins (and a log entry is emitted). + """ + # Normalize input to a mutable dict so we can coerce and avoid mutating + # the caller's objects. Build a fresh ModelConfig from that dict. 
+ if isinstance(model_config, ModelConfig): + mc: dict[str, Any] = { + k: getattr(model_config, k) for k in model_config.__dataclass_fields__ + } + else: + mc = model_config.copy() + + # Coerce any existing choices on the input to a tuple for immutability + if mc.get("choices") is not None: + mc["choices"] = tuple(mc["choices"]) + + # If caller didn't provide an explicit `choices` argument, return the + # normalized ModelConfig built from the input (fresh instance). + if choices is None: + return ModelConfig(**{k: v for k, v in mc.items() if v is not None}) + + # Caller provided choices; prefer the one embedded in model_config if + # present, otherwise apply the provided value (coerced to tuple). + if mc.get("choices") is not None: + _logger.info( + "choices list provided in both model_config and " + "as an argument directly. Using the one provided in " + "model_config. We recommend providing choices in model_config." + ) + return ModelConfig(**{k: v for k, v in mc.items() if v is not None}) + + mc["choices"] = tuple(choices) + return ModelConfig(**{k: v for k, v in mc.items() if v is not None}) diff --git a/src/hssm/data_validator.py b/src/hssm/data_validator.py index 01a6373e..4f6871c2 100644 --- a/src/hssm/data_validator.py +++ b/src/hssm/data_validator.py @@ -6,8 +6,6 @@ import numpy as np import pandas as pd -from hssm.defaults import MissingDataNetwork - _logger = logging.getLogger("hssm") @@ -103,55 +101,6 @@ def _post_check_data_sanity(self): # remaining check on missing data # which are coming AFTER the data validation # in the HSSM class, into this function? - def _handle_missing_data_and_deadline(self): - """Handle missing data and deadline.""" - if not self.missing_data and not self.deadline: - # In the case of choice only model, we don't need to do anything with the - # data. 
- if self.is_choice_only: - return - # In the case where missing_data is set to False, we need to drop the - # cases where rt = na_value - if pd.isna(self.missing_data_value): - na_dropped = self.data.dropna(subset=["rt"]) - else: - na_dropped = self.data.loc[ - self.data["rt"] != self.missing_data_value, : - ] - - if len(na_dropped) != len(self.data): - warnings.warn( - "`missing_data` is set to False, " - + "but you have missing data in your dataset. " - + "Missing data will be dropped.", - stacklevel=2, - ) - self.data = na_dropped - - elif self.missing_data and not self.deadline: - # In the case where missing_data is set to True, we need to replace the - # missing data with a specified na_value - - # Create a shallow copy to avoid modifying the original dataframe - if pd.isna(self.missing_data_value): - self.data["rt"] = self.data["rt"].fillna(-999.0) - else: - self.data["rt"] = self.data["rt"].replace( - self.missing_data_value, -999.0 - ) - - else: # deadline = True - if self.deadline_name not in self.data.columns: - raise ValueError( - "You have specified that your data has deadline, but " - + f"`{self.deadline_name}` is not found in your dataset." - ) - else: - self.data.loc[:, "rt"] = np.where( - self.data["rt"] < self.data[self.deadline_name], - self.data["rt"], - -999.0, - ) def _update_extra_fields(self, new_data: pd.DataFrame | None = None): """Update the extra fields data in self.model_distribution. 
@@ -174,45 +123,6 @@ def _update_extra_fields(self, new_data: pd.DataFrame | None = None): new_data[field].values for field in self.extra_fields ] - @staticmethod - def _set_missing_data_and_deadline( - missing_data: bool, deadline: bool, data: pd.DataFrame - ) -> MissingDataNetwork: - """Set missing data and deadline.""" - network = MissingDataNetwork.NONE - if not missing_data: - return network - if missing_data and not deadline: - network = MissingDataNetwork.CPN - elif missing_data and deadline: - network = MissingDataNetwork.OPN - # AF-TODO: GONOGO case not yet correctly implemented - # else: - # # TODO: This won't behave as expected yet, GONOGO needs to be split - # # into a deadline case and a non-deadline case. - # network = MissingDataNetwork.GONOGO - - if np.all(data["rt"] == -999.0): - if network in [MissingDataNetwork.CPN, MissingDataNetwork.OPN]: - # AF-TODO: I think we should allow invalid-only datasets. - raise ValueError( - "`missing_data` is set to True, but you have no valid data in your " - "dataset." - ) - # AF-TODO: This one needs refinement for GONOGO case - # elif network == MissingDataNetwork.OPN: - # raise ValueError( - # "`deadline` is set to True and `missing_data` is set to True, " - # "but ." - # ) - # else: - # raise ValueError( - # "`missing_data` and `deadline` are both set to True, - # "but you have " - # "no missing data and/or no rts exceeding the deadline." - # ) - return network - def _validate_choices(self): """ Ensure that `choices` is provided (not None). diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 2bf01d63..1e43582d 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -6,59 +6,40 @@ This file defines the entry class HSSM. 
""" -import datetime import logging -import typing from copy import deepcopy -from inspect import isclass, signature +from inspect import isclass from os import PathLike -from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union, cast, get_args +from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import cast as typing_cast -import arviz as az import bambi as bmb -import cloudpickle as cpickle -import matplotlib as mpl -import matplotlib.pyplot as plt import numpy as np import pandas as pd import pymc as pm -import pytensor -import seaborn as sns -import xarray as xr -from bambi.model_components import DistributionalComponent -from bambi.transformations import transformations_namespace -from pymc.model.transform.conditioning import do -from ssms.config import model_config as ssms_model_config from hssm._types import LoglikKind, SupportedModels -from hssm.data_validator import DataValidatorMixin from hssm.defaults import ( INITVAL_JITTER_SETTINGS, - INITVAL_SETTINGS, MissingDataNetwork, missing_data_networks_suffix, ) from hssm.distribution_utils import ( assemble_callables, make_distribution, - make_family, make_hssm_rv, make_likelihood_callable, make_missing_data_callable, ) from hssm.utils import ( - _compute_log_likelihood, - _get_alias_dict, - _print_prior, _rearrange_data, - _split_array, ) -from . import plotting +from .base import HSSMBase from .config import Config, ModelConfig -from .param import Params -from .param import UserParam as Param + +if TYPE_CHECKING: + from pytensor.graph.op import Op _logger = logging.getLogger("hssm") @@ -98,7 +79,7 @@ def __get__(self, instance, owner): # noqa: D105 return self.fget(owner) -class HSSM(DataValidatorMixin): +class HSSM(HSSMBase): """The basic Hierarchical Sequential Sampling Model (HSSM) class. Parameters @@ -124,9 +105,12 @@ class HSSM(DataValidatorMixin): model. If left unspecified, defaults will be used for all parameter specifications. Defaults to None. 
model_config : optional - A dictionary containing the model configuration information. If None is - provided, defaults will be used if there are any. Defaults to None. - Fields for this `dict` are usually: + A :class:`~hssm.config.BaseModelConfig` / :class:`~hssm.config.Config` + instance or a ``dict`` with model configuration information. The + constructor accepts a typed ``ModelConfig`` or a plain ``dict``; when a + ``dict`` is provided the library will build a typed :class:`Config` + via the factory function. If ``None`` is provided, defaults will be + used where available. Fields for this config are usually: - `"list_params"`: a list of parameters indicating the parameters of the model. The order in which the parameters are specified in this list is important. @@ -276,1753 +260,51 @@ def __init__( data: pd.DataFrame, model: SupportedModels | str = "ddm", choices: list[int] | None = None, - include: list[dict[str, Any] | Param] | None = None, + include: list[dict[str, Any] | Any] | None = None, model_config: ModelConfig | dict | None = None, loglik: ( - str | PathLike | Callable | pytensor.graph.Op | type[pm.Distribution] | None + str | PathLike | Callable | pm.Distribution | type[pm.Distribution] | None ) = None, loglik_kind: LoglikKind | None = None, p_outlier: float | dict | bmb.Prior | None = 0.05, - lapse: float | dict | bmb.Prior | None = None, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), global_formula: str | None = None, link_settings: Literal["log_logit"] | None = None, prior_settings: Literal["safe"] | None = "safe", extra_namespace: dict[str, Any] | None = None, missing_data: bool | float = False, deadline: bool | str = False, - loglik_missing_data: ( - str | PathLike | Callable | pytensor.graph.Op | None - ) = None, + loglik_missing_data: (str | PathLike | Callable | None) = None, process_initvals: bool = True, initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], - **kwargs, - ): - # Attach arguments 
to the instance - # so that we can easily define some - # methods that need to access these - # arguments (context: pickling / save - load). - - # Define a dict with all call arguments: - self._init_args = { - k: v for k, v in locals().items() if k not in ["self", "kwargs"] - } - if kwargs: - self._init_args.update(kwargs) - - self.data = data.copy() - self._inference_obj: az.InferenceData | None = None - self._initvals: dict[str, Any] = {} - self.initval_jitter = initval_jitter - self._inference_obj_vi: pm.Approximation | None = None - self._vi_approx = None - self._map_dict = None - self.global_formula = global_formula - - self.link_settings = link_settings - self.prior_settings = prior_settings - - self.missing_data_value = -999.0 - - additional_namespace = transformations_namespace.copy() - if extra_namespace is not None: - additional_namespace.update(extra_namespace) - self.additional_namespace = additional_namespace - - # Construct a model_config from defaults - self.model_config = Config.from_defaults(model, loglik_kind) - # Update defaults with user-provided config, if any - if model_config is not None: - if isinstance(model_config, dict): - if "choices" not in model_config: - if choices is not None: - model_config["choices"] = tuple(choices) - else: - if choices is not None: - _logger.info( - "choices list provided in both model_config and " - "as an argument directly." - " Using the one provided in model_config. \n" - "We recommend providing choices in model_config." - ) - elif isinstance(model_config, ModelConfig): - if model_config.choices is None: - if choices is not None: - model_config.choices = tuple(choices) - else: - if choices is not None: - _logger.info( - "choices list provided in both model_config and " - "as an argument directly." - " Using the one provided in model_config. \n" - "We recommend providing choices in model_config." 
- ) - - self.model_config.update_config( - model_config - if isinstance(model_config, ModelConfig) - else ModelConfig(**model_config) # also serves as dict validation - ) - else: - # Model config is not provided, but at this point was constructed from - # defaults. - if model not in typing.get_args(SupportedModels): - # TODO: ideally use self.supported_models above but mypy doesn't like it - if choices is not None: - self.model_config.update_choices(choices) - elif model in ssms_model_config: - self.model_config.update_choices( - ssms_model_config[model]["choices"] - ) - _logger.info( - "choices argument passed as None, " - "but found %s in ssms-simulators. " - "Using choices, from ssm-simulators configs: %s", - model, - ssms_model_config[model]["choices"], - ) - else: - # Model config already constructed from defaults, and model string is - # in SupportedModels. So we are guaranteed that choices are in - # self.model_config already. - - if choices is not None: - _logger.info( - "Model string is in SupportedModels." - " Ignoring choices arguments." - ) - - # Update loglik with user-provided value - self.model_config.update_loglik(loglik) - # Ensure that all required fields are valid - self.model_config.validate() - - # Set up shortcuts so old code will work - # TODO: add to HSSMBase - self.response = self.model_config.response[:] - self.list_params = self.model_config.list_params - self.choices = self.model_config.choices - self.model_name = self.model_config.model_name - self.loglik = self.model_config.loglik - self.loglik_kind = self.model_config.loglik_kind - self.extra_fields = self.model_config.extra_fields - - # TODO: add to HSSMBase - self.response = cast("list[str]", self.response) - self.is_choice_only: bool = self.model_config.is_choice_only - - if self.choices is None: - raise ValueError( - "`choices` must be provided either in `model_config` or as an argument." 
- ) - - self.n_choices = len(self.choices) - - self._validate_choices() - self._pre_check_data_sanity() - - # Process missing data setting - # AF-TODO: Could be a function in data validator? - # TODO: Move to the MissingDataMixin class when we have it - if self.is_choice_only and missing_data is not False: - raise ValueError("Choice-only models cannot have missing data.") - - if not self.is_choice_only: - if isinstance(missing_data, float): - if not ((self.data.rt == missing_data).any()): - raise ValueError( - f"missing_data argument is provided as a float {missing_data}, " - f"However, you have no RTs of {missing_data} in your dataset!" - ) - else: - self.missing_data = True - self.missing_data_value = missing_data - elif isinstance(missing_data, bool): - if missing_data and (not (self.data.rt == -999.0).any()): - raise ValueError( - "missing_data argument is provided as True, " - " so RTs of -999.0 are treated as missing. \n" - "However, you have no RTs of -999.0 in your dataset!" - ) - elif (not missing_data) and (self.data.rt == -999.0).any(): - # self.missing_data = True - raise ValueError( - "Missing data provided as False. \n" - "However, you have RTs of -999.0 in your dataset!" - ) - else: - self.missing_data = missing_data - else: - raise ValueError( - "missing_data argument must be a bool or a float! \n" - f"You provided: {type(missing_data)}" - ) - else: - self.missing_data = False - - if isinstance(deadline, str): - self.deadline = True - self.deadline_name = deadline - else: - self.deadline = deadline - self.deadline_name = "deadline" - - if ( - not self.missing_data and not self.deadline - ) and loglik_missing_data is not None: - raise ValueError( - "You have specified a loglik_missing_data function, but you have not " - + "set the missing_data or deadline flag to True." 
- ) - self.loglik_missing_data = loglik_missing_data + **kwargs: Any, + ) -> None: + # ===== save/load serialisation ===== + self._init_args = self._store_init_args(locals(), kwargs) - # Update data based on missing_data and deadline - self._handle_missing_data_and_deadline() - # Set self.missing_data_network based on `missing_data` and `deadline` - self.missing_data_network = self._set_missing_data_and_deadline( - self.missing_data, self.deadline, self.data + # Build typed Config via factory + config = Config._build_model_config( + model, loglik_kind, model_config, choices, loglik ) - if self.deadline: - # self.response is a tuple (from Config); use concatenation. - self.response.append(self.deadline_name) - - # Process lapse distribution - self.has_lapse = p_outlier is not None and p_outlier != 0 - self._check_lapse(lapse) - - # Process all parameters - self.params = Params.from_user_specs( - model=self, - include=[] if include is None else include, - kwargs=kwargs, + super().__init__( + data=data, + model_config=config, + include=include, p_outlier=p_outlier, - ) - - self._parent = self.params.parent - self._parent_param = self.params.parent_param - - self._validate_fixed_vectors() - self.formula, self.priors, self.link = self.params.parse_bambi(model=self) - - # For parameters that have a regression backend, apply bounds at the likelihood - # level to ensure that the samples that are out of bounds - # are discarded (replaced with a large negative value). 
- self.bounds = { - name: param.bounds - for name, param in self.params.items() - if param.is_regression and param.bounds is not None - } - - # Set p_outlier and lapse - self.p_outlier = self.params.get("p_outlier") - - self._post_check_data_sanity() - - self.model_distribution = self._make_model_distribution() - - self.family = make_family( - self.model_distribution, - self.list_params, - self.link, - self._parent, - ) - - self.model = bmb.Model( - self.formula, - data=self.data, - family=self.family, - priors=self.priors, # center_predictors=False - extra_namespace=self.additional_namespace, + lapse=lapse, + global_formula=global_formula, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + missing_data=missing_data, + deadline=deadline, + loglik_missing_data=loglik_missing_data, + process_initvals=process_initvals, + initval_jitter=initval_jitter, **kwargs, ) - self._aliases = _get_alias_dict( - self.model, self._parent_param, self.response_c, self.response_str - ) - self.set_alias(self._aliases) - self.model.build() - - # Bambi >= 0.17 declares dims=("__obs__",) for intercept-only - # deterministics that actually have shape (1,). This causes an - # xarray CoordinateValidationError during pm.sample() when ArviZ - # tries to create a DataArray with mismatched dimension sizes. - # Fix by removing the dims declaration for these deterministics. 
- self._fix_scalar_deterministic_dims() - - if process_initvals: - self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS) - if self.initval_jitter > 0: - self._jitter_initvals( - jitter_epsilon=self.initval_jitter, - vector_only=True, - ) - - # Make sure we reset rvs_to_initial_values --> Only None's - # Otherwise PyMC barks at us when asking to compute likelihoods - self.pymc_model.rvs_to_initial_values.update( - {key_: None for key_ in self.pymc_model.rvs_to_initial_values.keys()} - ) - _logger.info("Model initialized successfully.") - - def _fix_scalar_deterministic_dims(self) -> None: - """Fix dims metadata for scalar deterministics. - - Bambi >= 0.17 returns shape ``(1,)`` for intercept-only - deterministics but still declares ``dims=("__obs__",)``. This causes - an xarray ``CoordinateValidationError`` during ``pm.sample()`` because - the ``__obs__`` coordinate has ``n_obs`` entries. Removing the dims - declaration for these variables lets ArviZ handle them as - un-dimensioned arrays, avoiding the conflict. - """ - n_obs = len(self.data) - dims_dict = self.pymc_model.named_vars_to_dims - for det in self.pymc_model.deterministics: - if det.name not in dims_dict: - continue - dims = dims_dict[det.name] - if "__obs__" in dims: - # Check static shape: if it doesn't match n_obs, remove dims - try: - shape_0 = det.type.shape[0] - except (IndexError, TypeError): - continue - if shape_0 is not None and shape_0 != n_obs: - del dims_dict[det.name] - - def _validate_fixed_vectors(self) -> None: - """Validate that fixed-vector parameters have the correct length. - - Fixed-vector parameters (``prior=np.ndarray``) bypass Bambi's formula - system entirely --- they are passed as a scalar ``0.0`` placeholder to - Bambi, and the real vector is substituted inside - ``HSSMDistribution.logp()`` (see ``dist.py``). 
Because this - substitution is invisible to Bambi, we must validate the vector length - against ``len(self.data)`` up front to catch shape mismatches early. - """ - for name, param in self.params.items(): - if isinstance(param.prior, np.ndarray): - if len(param.prior) != len(self.data): - raise ValueError( - f"Fixed vector for parameter '{name}' has length " - f"{len(param.prior)}, but data has {len(self.data)} rows." - ) - - @classproperty - def supported_models(cls) -> tuple[SupportedModels, ...]: - """Get a tuple of all supported models. - - Returns - ------- - tuple[SupportedModels, ...] - A tuple containing all supported model names. - """ - return get_args(SupportedModels) - - @classmethod - def _store_init_args(cls, *args, **kwargs): - """Store initialization arguments using signature binding.""" - sig = signature(cls.__init__) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - return {k: v for k, v in bound_args.arguments.items() if k != "self"} - - def find_MAP(self, **kwargs): - """Perform Maximum A Posteriori estimation. - - Returns - ------- - dict - A dictionary containing the MAP estimates of the model parameters. - """ - self._map_dict = pm.find_MAP(model=self.pymc_model, **kwargs) - return self._map_dict - - def sample( - self, - sampler: Literal["pymc", "numpyro", "blackjax", "nutpie", "laplace"] - | None = None, - init: str | None = None, - initvals: str | dict | None = None, - include_response_params: bool = False, - **kwargs, - ) -> az.InferenceData | pm.Approximation: - """Perform sampling using the `fit` method via bambi.Model. - - Parameters - ---------- - sampler: optional - The sampler to use. Can be one of "pymc", "numpyro", - "blackjax", "nutpie", or "laplace". If using `blackbox` likelihoods, - this cannot be "numpyro", "blackjax", or "nutpie". By default it is None, - and sampler will automatically be chosen: when the model uses the - `approx_differentiable` likelihood, and `jax` backend, "numpyro" will - be used. 
Otherwise, "pymc" (the default PyMC NUTS sampler) will be used. - - Note that the old sampler names such as "mcmc", "nuts_numpyro", - "nuts_blackjax" will be deprecated and removed in future releases. A warning - will be raised if any of these old names are used. - init: optional - Initialization method to use for the sampler. If any of the NUTS samplers - is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`. - initvals: optional - Pass initial values to the sampler. This can be a dictionary of initial - values for parameters of the model, or a string "map" to use initialization - at the MAP estimate. If "map" is used, the MAP estimate will be computed if - not already attached to the base class from prior call to 'find_MAP'. - include_response_params: optional - Include parameters of the response distribution in the output. These usually - take more space than other parameters as there's one of them per - observation. Defaults to False. - kwargs - Other arguments passed to bmb.Model.fit(). Please see [here] - (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit) - for full documentation. - - Returns - ------- - az.InferenceData | pm.Approximation - A reference to the `model.traces` object, which stores the traces of the - last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData` - instance if `sampler` is `"pymc"` (default), `"numpyro"`, - `"blackjax"` or "`laplace". - """ - # If initvals are None (default) - # we skip processing initvals here. - if sampler in _new_sampler_mapping: - _logger.warning( - f"Sampler '{sampler}' is deprecated. " - "Please use the new sampler names: " - "'pymc', 'numpyro', 'blackjax', 'nutpie', or 'laplace'." - ) - sampler = _new_sampler_mapping[sampler] # type: ignore - - if sampler == "vi": - raise ValueError( - "VI is not supported via the sample() method. " - "Please use the vi() method instead." 
- ) - - if initvals is not None: - if isinstance(initvals, dict): - kwargs["initvals"] = initvals - else: - if isinstance(initvals, str): - if initvals == "map": - if self._map_dict is None: - _logger.info( - "initvals='map' but no map" - "estimate precomputed. \n" - "Running map estimation first..." - ) - self.find_MAP() - kwargs["initvals"] = self._map_dict - else: - kwargs["initvals"] = self._map_dict - else: - raise ValueError( - "initvals argument must be a dictionary or 'map'" - " to use the MAP estimate." - ) - else: - kwargs["initvals"] = self._initvals - _logger.info("Using default initvals. \n") - - if sampler is None: - if ( - self.loglik_kind == "approx_differentiable" - and self.model_config.backend == "jax" - ): - sampler = "numpyro" - else: - sampler = "pymc" - - if self.loglik_kind == "blackbox": - if sampler in ["blackjax", "numpyro", "nutpie"]: - raise ValueError( - f"{sampler} sampler does not work with blackbox likelihoods." - ) - - if "step" not in kwargs: - kwargs |= {"step": pm.Slice(model=self.pymc_model)} - - if ( - self.loglik_kind == "approx_differentiable" - and self.model_config.backend == "jax" - and sampler == "pymc" - and kwargs.get("cores", None) != 1 - ): - _logger.warning( - "Parallel sampling might not work with `jax` backend and the PyMC NUTS " - + "sampler on some platforms. Please consider using `numpyro`, " - + "`blackjax`, or `nutpie` sampler if that is a problem." - ) - - if self._check_extra_fields(): - self._update_extra_fields() - - if init is None: - if sampler in ["pymc", "numpyro", "blackjax", "nutpie"]: - init = "adapt_diag" - else: - init = "auto" - - # If sampler is finally `numpyro` make sure - # the jitter argument is set to False - if sampler == "numpyro": - if "nuts_sampler_kwargs" in kwargs: - if kwargs["nuts_sampler_kwargs"].get("jitter"): - _logger.warning( - "The jitter argument is set to True. " - + "This argument is not supported " - + "by the numpyro backend. 
" - + "The jitter argument will be set to False." - ) - kwargs["nuts_sampler_kwargs"]["jitter"] = False - else: - kwargs["nuts_sampler_kwargs"] = {"jitter": False} - - if sampler != "pymc" and "step" in kwargs: - raise ValueError( - "`step` samplers (enabled by the `step` argument) are only supported " - "by the `pymc` sampler." - ) - - if self._inference_obj is not None: - _logger.warning( - "The model has already been sampled. Overwriting the previous " - + "inference object. Any previous reference to the inference object " - + "will still point to the old object." - ) - - # Define whether likelihood should be computed - compute_likelihood = True - if "idata_kwargs" in kwargs: - if "log_likelihood" in kwargs["idata_kwargs"]: - compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True) - - omit_offsets = kwargs.pop("omit_offsets", False) - self._inference_obj = self.model.fit( - inference_method=sampler, - init=init, - include_response_params=include_response_params, - omit_offsets=omit_offsets, - **kwargs, - ) - - # Separate out log likelihood computation - if compute_likelihood: - self.log_likelihood(self._inference_obj, inplace=True) - - # Subset data vars in posterior - self._clean_posterior_group(idata=self._inference_obj) - return self.traces - - def vi( - self, - method: str = "advi", - niter: int = 10000, - draws: int = 1000, - return_idata: bool = True, - ignore_mcmc_start_point_defaults=False, - **vi_kwargs, - ) -> pm.Approximation | az.InferenceData: - """Perform Variational Inference. - - Parameters - ---------- - niter : int - The number of iterations to run the VI algorithm. Defaults to 3000. - method : str - The method to use for VI. Can be one of "advi" or "fullrank_advi", "svgd", - "asvgd".Defaults to "advi". - draws : int - The number of samples to draw from the posterior distribution. - Defaults to 1000. - return_idata : bool - If True, returns an InferenceData object. Otherwise, returns the - approximation object directly. 
Defaults to True. - - Returns - ------- - pm.Approximation or az.InferenceData: The mean field approximation object. - """ - if self.loglik_kind == "analytical": - _logger.warning( - "VI is not recommended for the analytical likelihood," - " since gradients can be brittle." - ) - elif self.loglik_kind == "blackbox": - raise ValueError( - "VI is not supported for blackbox likelihoods, " - " since likelihood gradients are needed!" - ) - - if ("start" not in vi_kwargs) and not ignore_mcmc_start_point_defaults: - _logger.info("Using MCMC starting point defaults.") - vi_kwargs["start"] = self._initvals - - # Run variational inference directly from pymc model - with self.pymc_model: - self._vi_approx = pm.fit(n=niter, method=method, **vi_kwargs) - - # Sample from the approximate posterior - if self._vi_approx is not None: - self._inference_obj_vi = self._vi_approx.sample(draws) - - # Post-processing - self._clean_posterior_group(idata=self._inference_obj_vi) - - # Return the InferenceData object if return_idata is True - if return_idata: - return self._inference_obj_vi - # Otherwise return the appromation object directly - return self.vi_approx - - def _clean_posterior_group(self, idata: az.InferenceData | None = None): - """Clean up the posterior group of the InferenceData object. - - Parameters - ---------- - idata : az.InferenceData - The InferenceData object to clean up. If None, the last InferenceData object - will be used. - """ - # # Logic behind which variables to keep: - # # We essentially want to get rid of - # # all the trial-wise variables. - - # # We drop all distributional components, IF they are deterministics - # # (in which case they will be trial wise systematically) - # # and we keep distributional components, IF they are - # # basic random-variabels (in which case they should never - # # appear trial-wise). - if idata is None: - raise ValueError( - "The InferenceData object is None. Cannot clean up the posterior group." 
- ) - elif not hasattr(idata, "posterior"): - raise ValueError( - "The InferenceData object does not have a posterior group. " - + "Cannot clean up the posterior group." - ) - - vars_to_keep = set(idata["posterior"].data_vars.keys()).difference( - set( - key_ - for key_ in self.model.distributional_components.keys() - if key_ in [var_.name for var_ in self.pymc_model.deterministics] - ) - ) - vars_to_keep_clean = [ - var_ - for var_ in vars_to_keep - if isinstance(var_, str) and "_mean" not in var_ - ] - - setattr( - idata, - "posterior", - idata["posterior"][vars_to_keep_clean], - ) - - def log_likelihood( - self, - idata: az.InferenceData | None = None, - data: pd.DataFrame | None = None, - inplace: bool = True, - keep_likelihood_params: bool = False, - ) -> az.InferenceData | None: - """Compute the log likelihood of the model. - - Parameters - ---------- - idata : optional - The `InferenceData` object returned by `HSSM.sample()`. If not provided, - data : optional - A pandas DataFrame with values for the predictors that are used to obtain - out-of-sample predictions. If omitted, the original dataset is used. - inplace : optional - If `True` will modify idata in-place and append a `log_likelihood` group to - `idata`. Otherwise, it will return a copy of idata with the predictions - added, by default True. - keep_likelihood_params : optional - If `True`, the trial wise likelihood parameters that are computed - on route to getting the log likelihood are kept in the `idata` object. - Defaults to False. See also the method `add_likelihood_parameters_to_idata`. - - Returns - ------- - az.InferenceData | None - InferenceData or None - """ - if self._inference_obj is None and idata is None: - raise ValueError( - "Neither has the model been sampled yet nor" - + " an idata object has been provided." - ) - - if idata is None: - if self._inference_obj is None: - raise ValueError( - "The model has not been sampled yet. " - + "Please provide an idata object." 
- ) - else: - idata = self._inference_obj - - # Actual likelihood computation - idata = _compute_log_likelihood(self.model, idata, data, inplace) - - # clean up posterior: - if not keep_likelihood_params: - self._clean_posterior_group(idata=idata) - - if inplace: - return None - else: - return idata - - def add_likelihood_parameters_to_idata( - self, - idata: az.InferenceData | None = None, - inplace: bool = False, - ) -> az.InferenceData | None: - """Add likelihood parameters to the InferenceData object. - - Parameters - ---------- - idata : az.InferenceData - The InferenceData object returned by HSSM.sample(). - inplace : bool - If True, the likelihood parameters are added to idata in-place. Otherwise, - a copy of idata with the likelihood parameters added is returned. - Defaults to False. - - Returns - ------- - az.InferenceData | None - InferenceData or None - """ - if idata is None: - if self._inference_obj is None: - raise ValueError("No idata provided and model not yet sampled!") - else: - idata = self.model._compute_likelihood_params( # pylint: disable=protected-access - deepcopy(self._inference_obj) - if not inplace - else self._inference_obj - ) - else: - idata = self.model._compute_likelihood_params( # pylint: disable=protected-access - deepcopy(idata) if not inplace else idata - ) - return idata - - def sample_posterior_predictive( - self, - idata: az.InferenceData | None = None, - data: pd.DataFrame | None = None, - inplace: bool = True, - include_group_specific: bool = True, - kind: Literal["response", "response_params"] = "response", - draws: int | float | list[int] | np.ndarray | None = None, - safe_mode: bool = True, - ) -> az.InferenceData | None: - """Perform posterior predictive sampling from the HSSM model. - - Parameters - ---------- - idata : optional - The `InferenceData` object returned by `HSSM.sample()`. If not provided, - the `InferenceData` from the last time `sample()` is called will be used. 
- data : optional - An optional data frame with values for the predictors that are used to - obtain out-of-sample predictions. If omitted, the original dataset is used. - inplace : optional - If `True` will modify idata in-place and append a `posterior_predictive` - group to `idata`. Otherwise, it will return a copy of idata with the - predictions added, by default True. - include_group_specific : optional - If `True` will make predictions including the group specific effects. - Otherwise, predictions are made with common effects only (i.e. group- - specific are set to zero), by default True. - kind: optional - Indicates the type of prediction required. Can be `"response_params"` or - `"response"`. The first returns draws from the posterior distribution of the - likelihood parameters, while the latter returns the draws from the posterior - predictive distribution (i.e. the posterior probability distribution for a - new observation) in addition to the posterior distribution. Defaults to - "response_params". - draws: optional - The number of samples to draw from the posterior predictive distribution - from each chain. - When it's an integer >= 1, the number of samples to be extracted from the - `draw` dimension. If this integer is larger than the number of posterior - samples in each chain, all posterior samples will be used - in posterior predictive sampling. When a float between 0 and 1, the - proportion of samples from the draw dimension from each chain to be used in - posterior predictive sampling.. If this proportion is very - small, at least one sample will be used. When None, all posterior samples - will be used. Defaults to None. - safe_mode: bool - If True, the function will split the draws into chunks of 10 to avoid memory - issues. Defaults to True. - - Raises - ------ - ValueError - If the model has not been sampled yet and idata is not provided. 
- - Returns - ------- - az.InferenceData | None - InferenceData or None - """ - if idata is None: - if self._inference_obj is None: - raise ValueError( - "The model has not been sampled yet. " - + "Please either provide an idata object or sample the model first." - ) - idata = self._inference_obj - _logger.info( - "idata=None, we use the traces assigned to the HSSM object as idata." - ) - - if idata is not None: - if "posterior_predictive" in idata.groups(): - del idata["posterior_predictive"] - _logger.warning( - "pre-existing posterior_predictive group deleted from idata. \n" - ) - - if self._check_extra_fields(data): - self._update_extra_fields(data) - - if isinstance(draws, np.ndarray): - draws = draws.astype(int) - elif isinstance(draws, list): - draws = np.array(draws).astype(int) - elif isinstance(draws, int | float): - draws = np.arange(int(draws)) - elif draws is None: - draws = idata["posterior"].draw.values - else: - raise ValueError( - "draws must be an integer, " + "a list of integers, or a numpy array." 
- ) - - assert isinstance(draws, np.ndarray) - - # Make a copy of idata, set the `posterior` group to be a random sub-sample - # of the original (draw dimension gets sub-sampled) - - idata_copy = idata.copy() - - if (draws.shape != idata["posterior"].draw.values.shape) or ( - (draws.shape == idata["posterior"].draw.values.shape) - and not np.allclose(draws, idata["posterior"].draw.values) - ): - # Reassign posterior to sub-sampled version - setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws)) - - if kind == "response": - # If we run kind == 'response' we actually run the observation RV - if safe_mode: - # safe mode splits the draws into chunks of 10 to avoid - # memory issues (TODO: Figure out the source of memory issues) - split_draws = _split_array( - idata_copy["posterior"].draw.values, divisor=10 - ) - - posterior_predictive_list = [] - for samples_tmp in split_draws: - tmp_posterior = idata["posterior"].sel(draw=samples_tmp) - setattr(idata_copy, "posterior", tmp_posterior) - self.model.predict( - idata_copy, kind, data, True, include_group_specific - ) - posterior_predictive_list.append(idata_copy["posterior_predictive"]) - - if inplace: - idata.add_groups( - posterior_predictive=xr.concat( - posterior_predictive_list, dim="draw" - ) - ) - # for inplace, we don't return anything - return None - else: - # Reassign original posterior to idata_copy - setattr(idata_copy, "posterior", idata["posterior"]) - # Add new posterior predictive group to idata_copy - del idata_copy["posterior_predictive"] - idata_copy.add_groups( - posterior_predictive=xr.concat( - posterior_predictive_list, dim="draw" - ) - ) - return idata_copy - else: - if inplace: - # If not safe-mode - # We call .predict() directly without any - # chunking of data. 
- - # .predict() is called on the copy of idata - # since we still subsampled (or assigned) the draws - self.model.predict( - idata_copy, kind, data, True, include_group_specific - ) - - # posterior predictive group added to idata - idata.add_groups( - posterior_predictive=idata_copy["posterior_predictive"] - ) - # don't return anything if inplace - return None - else: - # Not safe mode and not inplace - # Function acts as very thin wrapper around - # .predict(). It just operates on the - # idata_copy object - return self.model.predict( - idata_copy, kind, data, False, include_group_specific - ) - elif kind == "response_params": - # If kind == 'response_params', we don't need to run the RV directly, - # there shouldn't really be any significant memory issues here, - # we can simply ignore settings, since the computational overhead - # should be very small --> nudges user towards good outputs. - _logger.warning( - "The kind argument is set to 'mean', but 'draws' argument " - + "is not None: The draws argument will be ignored!" - ) - return self.model.predict( - idata, kind, data, inplace, include_group_specific - ) - else: - raise ValueError("`kind` must be either 'response' or 'response_params'.") - - def plot_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid: - """Produce a posterior predictive plot. - - Equivalent to calling `hssm.plotting.plot_predictive()` with the - model. Please see that function for - [full documentation][hssm.plotting.plot_predictive]. - - Returns - ------- - mpl.axes.Axes | sns.FacetGrid - The matplotlib axis or seaborn FacetGrid object containing the plot. - """ - return plotting.plot_predictive(self, **kwargs) - - def plot_quantile_probability(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid: - """Produce a quantile probability plot. - - Equivalent to calling `hssm.plotting.plot_quantile_probability()` with the - model. Please see that function for - [full documentation][hssm.plotting.plot_quantile_probability]. 
- - Returns - ------- - mpl.axes.Axes | sns.FacetGrid - The matplotlib axis or seaborn FacetGrid object containing the plot. - """ - return plotting.plot_quantile_probability(self, **kwargs) - - def predict(self, **kwargs) -> az.InferenceData: - """Generate samples from the predictive distribution.""" - return self.model.predict(**kwargs) - - def sample_do( - self, params: dict[str, Any], draws: int = 100, return_model=False, **kwargs - ) -> az.InferenceData | tuple[az.InferenceData, pm.Model]: - """Generate samples from the predictive distribution using the `do-operator`.""" - do_model = do(self.pymc_model, params) - do_idata = pm.sample_prior_predictive(model=do_model, draws=draws, **kwargs) - - # clean up `rt,response_mean` to `v` - do_idata = self._drop_parent_str_from_idata(idata=do_idata) - - # rename otherwise inconsistent dims and coords - if "rt,response_extra_dim_0" in do_idata["prior_predictive"].dims: - setattr( - do_idata, - "prior_predictive", - do_idata["prior_predictive"].rename_dims( - {"rt,response_extra_dim_0": "rt,response_dim"} - ), - ) - if "rt,response_extra_dim_0" in do_idata["prior_predictive"].coords: - setattr( - do_idata, - "prior_predictive", - do_idata["prior_predictive"].rename_vars( - name_dict={"rt,response_extra_dim_0": "rt,response_dim"} - ), - ) - - if return_model: - return do_idata, do_model - return do_idata - - def sample_prior_predictive( - self, - draws: int = 500, - var_names: str | list[str] | None = None, - omit_offsets: bool = True, - random_seed: np.random.Generator | None = None, - ) -> az.InferenceData: - """Generate samples from the prior predictive distribution. - - Parameters - ---------- - draws - Number of draws to sample from the prior predictive distribution. Defaults - to 500. - var_names - A list of names of variables for which to compute the prior predictive - distribution. Defaults to ``None`` which means both observed and unobserved - RVs. - omit_offsets - Whether to omit offset terms. 
Defaults to ``True``. - random_seed - Seed for the random number generator. - - Returns - ------- - az.InferenceData - ``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and - ``observed_data``. - """ - prior_predictive = self.model.prior_predictive( - draws, var_names, omit_offsets, random_seed - ) - - # AF-COMMENT: Not sure if necessary to include the - # mean prior here (which adds deterministics that - # could be recomputed elsewhere) - prior_predictive.add_groups(posterior=prior_predictive.prior) - # Bambi >= 0.17 renamed kind="mean" to kind="response_params". - self.model.predict(prior_predictive, kind="response_params", inplace=True) - - # clean - setattr(prior_predictive, "prior", prior_predictive["posterior"]) - del prior_predictive["posterior"] - - if self._inference_obj is None: - self._inference_obj = prior_predictive - else: - self._inference_obj.extend(prior_predictive) - - # clean up `rt,response_mean` to `v` - idata = self._drop_parent_str_from_idata(idata=self._inference_obj) - - # rename otherwise inconsistent dims and coords - if "rt,response_extra_dim_0" in idata["prior_predictive"].dims: - setattr( - idata, - "prior_predictive", - idata["prior_predictive"].rename_dims( - {"rt,response_extra_dim_0": "rt,response_dim"} - ), - ) - if "rt,response_extra_dim_0" in idata["prior_predictive"].coords: - setattr( - idata, - "prior_predictive", - idata["prior_predictive"].rename_vars( - name_dict={"rt,response_extra_dim_0": "rt,response_dim"} - ), - ) - - # Update self._inference_obj to match the cleaned idata - self._inference_obj = idata - return deepcopy(self._inference_obj) - - @property - def pymc_model(self) -> pm.Model: - """Provide access to the PyMC model. - - Returns - ------- - pm.Model - The PyMC model built by bambi - """ - return self.model.backend.model - - def set_alias(self, aliases: dict[str, str | dict]): - """Set parameter aliases. 
- - Sets the aliases according to the dictionary passed to it and rebuild the - model. - - Parameters - ---------- - aliases - A dict specifying the parameter names being aliased and the aliases. - """ - self.model.set_alias(aliases) - self.model.build() - - # TODO: update this to HSSMBase - @property - def response_c(self) -> str: - """Return the response variable names in c() format. - - New in 0.2.12: when model is choice-only and has deadline, the response - is not in the form of c(...). - """ - if self.response is None: - raise ValueError("Response is not defined.") - if self.is_choice_only and not self.deadline: - return self.response[0] - return f"c({', '.join(self.response)})" - - @property - def response_str(self) -> str: - """Return the response variable names in string format.""" - if self.response is None: - return "" - return ",".join(self.response) - - # NOTE: can't annotate return type because the graphviz dependency is optional - def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"): - """Produce a graphviz Digraph from a built HSSM model. - - Requires graphviz, which may be installed most easily with `conda install -c - conda-forge python-graphviz`. Alternatively, you may install the `graphviz` - binaries yourself, and then `pip install graphviz` to get the python bindings. - See http://graphviz.readthedocs.io/en/stable/manual.html for more information. - - Parameters - ---------- - formatting - One of `"plain"` or `"plain_with_params"`. Defaults to `"plain"`. - name - Name of the figure to save. Defaults to `None`, no figure is saved. - figsize - Maximum width and height of figure in inches. Defaults to `None`, the - figure size is set automatically. If defined and the drawing is larger than - the given size, the drawing is uniformly scaled down so that it fits within - the given size. Only works if `name` is not `None`. - dpi - Point per inch of the figure to save. - Defaults to 300. Only works if `name` is not `None`. 
- fmt - Format of the figure to save. - Defaults to `"png"`. Only works if `name` is not `None`. - - Returns - ------- - graphviz.Graph - The graph - """ - graph = self.model.graph(formatting, name, figsize, dpi, fmt) - - parent_param = self._parent_param - if parent_param.is_regression: - return graph - - # Modify the graph - # 1. Remove all nodes and edges related to `{parent}_mean`: - graph.body = [ - item for item in graph.body if f"{parent_param.name}_mean" not in item - ] - # 2. Add a new edge from parent to response - graph.edge(parent_param.name, self.response_str) - - return graph - - def compile_logp(self, keep_transformed: bool = False, **kwargs): - """Compile the log probability function for the model. - - Parameters - ---------- - keep_transformed : bool, optional - If True, keeps the transformed variables in the compiled function. - If False, removes value transforms before compilation. - Defaults to False. - **kwargs - Additional keyword arguments passed to PyMC's compile_logp: - - vars: List of variables. Defaults to None (all variables). - - jacobian: Whether to include log(|det(dP/dQ)|) term for - transformed variables. Defaults to True. - - sum: Whether to sum all terms instead of returning a vector. - Defaults to True. - - Returns - ------- - callable - A compiled function that computes the model log probability. 
- """ - if keep_transformed: - return self.pymc_model.compile_logp( - vars=kwargs.get("vars", None), - jacobian=kwargs.get("jacobian", True), - sum=kwargs.get("sum", True), - ) - else: - new_model = pm.model.transform.conditioning.remove_value_transforms( - self.pymc_model - ) - return new_model.compile_logp( - vars=kwargs.get("vars", None), - jacobian=kwargs.get("jacobian", True), - sum=kwargs.get("sum", True), - ) - - def plot_trace( - self, - data: az.InferenceData | None = None, - include_deterministic: bool = False, - tight_layout: bool = True, - **kwargs, - ) -> None: - """Generate trace plot with ArviZ but with additional convenience features. - - This is a simple wrapper for the az.plot_trace() function. By default, it - filters out the deterministic values from the plot. Please see the - [arviz documentation] - (https://arviz-devs.github.io/arviz/api/generated/arviz.plot_trace.html) - for additional parameters that can be specified. - - Parameters - ---------- - data : optional - An ArviZ InferenceData object. If None, the traces stored in the model will - be used. - include_deterministic : optional - Whether to include deterministic variables in the plot. Defaults to False. - Note that if include deterministic is set to False and and `var_names` is - provided, the `var_names` provided will be modified to also exclude the - deterministic values. If this is not desirable, set - `include deterministic` to True. - tight_layout : optional - Whether to call plt.tight_layout() after plotting. Defaults to True. 
- """ - data = data or self.traces - if not isinstance(data, az.InferenceData): - raise TypeError("data must be an InferenceData object.") - - if not include_deterministic: - var_names = list( - set([var.name for var in self.pymc_model.free_RVs]).intersection( - set(list(data["posterior"].data_vars.keys())) - ) - ) - # var_names = self._get_deterministic_var_names(data) - if var_names: - if "var_names" in kwargs: - if isinstance(kwargs["var_names"], str): - if kwargs["var_names"] not in var_names: - var_names.append(kwargs["var_names"]) - kwargs["var_names"] = var_names - elif isinstance(kwargs["var_names"], list): - kwargs["var_names"] = list( - set(var_names) | set(kwargs["var_names"]) - ) - elif kwargs["var_names"] is None: - kwargs["var_names"] = var_names - else: - raise ValueError( - "`var_names` must be a string, a list of strings, or None." - ) - else: - kwargs["var_names"] = var_names - az.plot_trace(data, **kwargs) - - if tight_layout: - plt.tight_layout() - - def summary( - self, - data: az.InferenceData | None = None, - include_deterministic: bool = False, - **kwargs, - ) -> pd.DataFrame | xr.Dataset: - """Produce a summary table with ArviZ but with additional convenience features. - - This is a simple wrapper for the az.summary() function. By default, it - filters out the deterministic values from the plot. Please see the - [arviz documentation] - (https://arviz-devs.github.io/arviz/api/generated/arviz.summary.html) - for additional parameters that can be specified. - - Parameters - ---------- - data - An ArviZ InferenceData object. If None, the traces stored in the model will - be used. - include_deterministic : optional - Whether to include deterministic variables in the plot. Defaults to False. - Note that if include_deterministic is set to False and and `var_names` is - provided, the `var_names` provided will be modified to also exclude the - deterministic values. If this is not desirable, set - `include_deterministic` to True. 
- - Returns - ------- - pd.DataFrame | xr.Dataset - A pandas DataFrame or xarray Dataset containing the summary statistics. - """ - data = data or self.traces - if not isinstance(data, az.InferenceData): - raise TypeError("data must be an InferenceData object.") - - if not include_deterministic: - var_names = list( - set([var.name for var in self.pymc_model.free_RVs]).intersection( - set(list(data["posterior"].data_vars.keys())) - ) - ) - # var_names = self._get_deterministic_var_names(data) - if var_names: - kwargs["var_names"] = list(set(var_names + kwargs.get("var_names", []))) - return az.summary(data, **kwargs) - - def initial_point(self, transformed: bool = False) -> dict[str, np.ndarray]: - """Compute the initial point of the model. - - This is a slightly altered version of pm.initial_point.initial_point(). - - Parameters - ---------- - transformed : bool, optional - If True, return the initial point in transformed space. - - Returns - ------- - dict - A dictionary containing the initial point of the model parameters. - """ - fn = pm.initial_point.make_initial_point_fn( - model=self.pymc_model, return_transformed=transformed - ) - return pm.model.Point(fn(None), model=self.pymc_model) - - def restore_traces( - self, traces: az.InferenceData | pm.Approximation | str | PathLike - ) -> None: - """Restore traces from an InferenceData object or a .netcdf file. - - Parameters - ---------- - traces - An InferenceData object or a path to a file containing the traces. - """ - if isinstance(traces, pm.Approximation): - self._inference_obj_vi = traces - return - - if isinstance(traces, (str, PathLike)): - traces = az.from_netcdf(traces) - self._inference_obj = cast("az.InferenceData", traces) - - def restore_vi_traces( - self, traces: az.InferenceData | pm.Approximation | str | PathLike - ) -> None: - """Restore VI traces from an InferenceData object or a .netcdf file. 
- - Parameters - ---------- - traces - An InferenceData object or a path to a file containing the VI traces. - """ - if isinstance(traces, pm.Approximation): - self._inference_obj_vi = traces - return - - if isinstance(traces, (str, PathLike)): - traces = az.from_netcdf(traces) - self._inference_obj_vi = cast("az.InferenceData", traces) - - def save_model( - self, - model_name: str | None = None, - allow_absolute_base_path: bool = False, - base_path: str | Path = "hssm_models", - save_idata_only: bool = False, - ) -> None: - """Save a HSSM model instance and its inference results to disk. - - Parameters - ---------- - model_name : str | None - Name to use for the saved model files. - If None, will use model.model_name with timestamp - allow_absolute_base_path : bool - Whether to allow absolute paths for base_path. - Defaults to False for safety. - base_path : str | Path - Base directory to save model files in. - Must be relative path if allow_absolute_base_path=False. - Defaults to "hssm_models". - save_idata_only : bool - If True, only saves inference data (traces), not the model pickle. - Defaults to False (saves both model and traces). 
- - Raises - ------ - ValueError - If base_path is absolute and allow_absolute_base_path=False - """ - # Convert to Path object for cross-platform compatibility - base_path = Path(base_path) - - # Check if base_path is absolute (works on all platforms) - if not allow_absolute_base_path and base_path.is_absolute(): - raise ValueError( - "base_path must be a relative path if allow_absolute_base_path is False" - ) - - if model_name is None: - # Get date string format as suffix to model name - timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") - model_name = f"{self.model_name}_{timestamp}" - - # Sanitize model_name and construct full path - model_name = model_name.replace(" ", "_") - model_path = base_path / model_name - model_path.mkdir(parents=True, exist_ok=True) - - # Save model to pickle file - if not save_idata_only: - with open(model_path.joinpath("model.pkl"), "wb") as f: - cpickle.dump(self, f) - - # Save traces to netcdf file - if self._inference_obj is not None: - az.to_netcdf(self._inference_obj, model_path.joinpath("traces.nc")) - - # Save vi_traces to netcdf file - if self._inference_obj_vi is not None: - az.to_netcdf(self._inference_obj_vi, model_path.joinpath("vi_traces.nc")) - - @classmethod - def load_model( - cls, path: Union[str, Path] - ) -> Union["HSSM", dict[str, Optional[az.InferenceData]]]: - """Load a HSSM model instance and its inference results from disk. - - Parameters - ---------- - path : str | Path - Path to the model directory or model.pkl file. If a directory is provided, - will look for model.pkl, traces.nc and vi_traces.nc files within it. - - Returns - ------- - HSSM - The loaded HSSM model instance with inference results attached if available. 
- """ - # Convert path to Path object - path = Path(path) - - # If path points to a file, assume it's model.pkl - if path.is_file(): - model_dir = path.parent - model_path = path - else: - # Path points to directory - model_dir = path - model_path = model_dir.joinpath("model.pkl") - - # check if model_dir exists - if not model_dir.exists(): - raise FileNotFoundError(f"Model directory {model_dir} does not exist.") - - # check if model.pkl exists raise logging information if not - if not model_path.exists(): - _logger.info( - f"model.pkl file does not exist in {model_dir}. " - "Attempting to load traces only." - ) - if (not model_dir.joinpath("traces.nc").exists()) and ( - not model_dir.joinpath("vi_traces.nc").exists() - ): - raise FileNotFoundError(f"No traces found in {model_dir}.") - else: - idata_dict = cls.load_model_idata(model_dir) - return idata_dict - else: - # Load model from pickle file - with open(model_path, "rb") as f: - model = cpickle.load(f) - - # Load traces if they exist - traces_path = model_dir.joinpath("traces.nc") - if traces_path.exists(): - model.restore_traces(traces_path) - - # Load VI traces if they exist - vi_traces_path = model_dir.joinpath("vi_traces.nc") - if vi_traces_path.exists(): - model.restore_vi_traces(vi_traces_path) - return model - - @classmethod - def load_model_idata(cls, path: str | Path) -> dict[str, az.InferenceData | None]: - """Load the traces from a model directory. - - Parameters - ---------- - path : str | Path - Path to the model directory containing traces.nc and/or vi_traces.nc files. - - Returns - ------- - dict[str, az.InferenceData | None] - A dictionary with keys "idata_mcmc" and "idata_vi" containing the traces - from the model directory. If the traces do not exist, the corresponding - value will be None. 
- """ - idata_dict: dict[str, az.InferenceData | None] = {} - model_dir = Path(path) - # check if path exists - if not model_dir.exists(): - raise FileNotFoundError(f"Model directory {model_dir} does not exist.") - - # check if traces.nc exists - traces_path = model_dir.joinpath("traces.nc") - if not traces_path.exists(): - _logger.warning(f"traces.nc file does not exist in {model_dir}.") - idata_dict["idata_mcmc"] = None - else: - idata_dict["idata_mcmc"] = az.from_netcdf(traces_path) - - # check if vi_traces.nc exists - vi_traces_path = model_dir.joinpath("vi_traces.nc") - if not vi_traces_path.exists(): - _logger.warning(f"vi_traces.nc file does not exist in {model_dir}.") - idata_dict["idata_vi"] = None - else: - idata_dict["idata_vi"] = az.from_netcdf(vi_traces_path) - - return idata_dict - - def __getstate__(self): - """Get the state of the model for pickling. - - This method is called when pickling the model. - It returns a dictionary containing the constructor - arguments needed to recreate the model instance. - - Returns - ------- - dict - A dictionary containing the constructor arguments - under the key 'constructor_args'. - """ - state = {"constructor_args": self._init_args} - return state - - def __setstate__(self, state): - """Set the state of the model when unpickling. - - This method is called when unpickling the model. It creates a new instance - of HSSM using the constructor arguments stored in the state dictionary, - and copies its attributes to the current instance. - - Parameters - ---------- - state : dict - A dictionary containing the constructor arguments under the key - 'constructor_args'. 
- """ - new_instance = HSSM(**state["constructor_args"]) - self.__dict__ = new_instance.__dict__ - - def __repr__(self) -> str: - """Create a representation of the model.""" - output = [ - "Hierarchical Sequential Sampling Model", - f"Model: {self.model_name}\n", - f"Response variable: {self.response_str}", - f"Likelihood: {self.loglik_kind}", - f"Observations: {len(self.data)}\n", - "Parameters:\n", - ] - - for param in self.params.values(): - if param.name == "p_outlier": - continue - output.append(f"{param.name}:") - - component = self.model.components[param.name] - - # Regression case: - if param.is_regression: - assert isinstance(component, DistributionalComponent) - output.append(f" Formula: {param.formula}") - output.append(" Priors:") - intercept_term = component.intercept_term - if intercept_term is not None: - output.append(_print_prior(intercept_term)) - for _, common_term in component.common_terms.items(): - output.append(_print_prior(common_term)) - for _, group_specific_term in component.group_specific_terms.items(): - output.append(_print_prior(group_specific_term)) - output.append(f" Link: {param.link}") - # None regression case - else: - if param.prior is None: - prior = ( - component.intercept_term.prior - if param.is_parent - else component.prior - ) - else: - prior = param.prior - output.append(f" Prior: {prior}") - output.append(f" Explicit bounds: {param.bounds}") - output.append( - " (ignored due to link function)" - if self.link_settings is not None - else "" - ) - - # TODO: Handle p_outlier regression correctly here. - if self.p_outlier is not None: - output.append("") - output.append(f"Lapse probability: {self.p_outlier.prior}") - output.append(f"Lapse distribution: {self.lapse}") - - return "\n".join(output) - - def __str__(self) -> str: - """Create a string representation of the model.""" - return self.__repr__() - - @property - def traces(self) -> az.InferenceData | pm.Approximation: - """Return the trace of the model after sampling. 
- - Raises - ------ - ValueError - If the model has not been sampled yet. - - Returns - ------- - az.InferenceData | pm.Approximation - The trace of the model after the last call to `sample()`. - """ - if not self._inference_obj: - raise ValueError("Please sample the model first.") - - return self._inference_obj - - @property - def vi_idata(self) -> az.InferenceData: - """Return the variational inference approximation object. - - Raises - ------ - ValueError - If the model has not been sampled yet. - - Returns - ------- - az.InferenceData - The variational inference approximation object. - """ - if not self._inference_obj_vi: - raise ValueError( - "Please run variational inference first, " - "no variational posterior attached." - ) - - return self._inference_obj_vi - - @property - def vi_approx(self) -> pm.Approximation: - """Return the variational inference approximation object. - - Raises - ------ - ValueError - If the model has not been sampled yet. - - Returns - ------- - pm.Approximation - The variational inference approximation object. - """ - if not self._vi_approx: - raise ValueError( - "Please run variational inference first, " - "no variational approximation attached." - ) - - return self._vi_approx - - @property - def map(self) -> dict: - """Return the MAP estimates of the model parameters. - - Raises - ------ - ValueError - If the model has not been sampled yet. - - Returns - ------- - dict - A dictionary containing the MAP estimates of the model parameters. - """ - if not self._map_dict: - raise ValueError("Please compute map first.") - - return self._map_dict - - @property - def initvals(self) -> dict: - """Return the initial values of the model parameters for sampling. - - Returns - ------- - dict - A dictionary containing the initial values of the model parameters. - This dict serves as the default for initial values, and can be passed - directly to the `.sample()` function. 
- """ - if self._initvals == {}: - self._initvals = self.initial_point() - return self._initvals - - def _check_lapse(self, lapse): - """Determine if p_outlier and lapse is specified correctly.""" - # Basically, avoid situations where only one of them is specified. - if self.list_params[-1] != "p_outlier": - if "p_outlier" in self.list_params: - raise ValueError( - "Please do not include 'p_outlier' in `list_params`. " - + "We automatically append it to `list_params` when `p_outlier` " - + "parameter is not None" - ) - if self.has_lapse: - if lapse is None: - if self.is_choice_only: - self.lapse = 1 / self.n_choices - else: - self.lapse = bmb.Prior("Uniform", lower=0.0, upper=20.0) - else: - self.lapse = lapse - - self.list_params.append("p_outlier") - return - - if lapse is not None: - _logger.warning( - "You have specified the `lapse` argument to include a lapse " - + "distribution, but `p_outlier` is set to either 0 or None. " - + "Your lapse distribution will be ignored." - ) - self.lapse = None - def _make_model_distribution(self) -> type[pm.Distribution]: """Make a pm.Distribution for the model.""" ### Logic for different types of likelihoods: @@ -2039,6 +321,15 @@ def _make_model_distribution(self) -> type[pm.Distribution]: if isclass(self.loglik) and issubclass(self.loglik, pm.Distribution): return self.loglik + # At this point, loglik should not be a type[Distribution] and should be set + + assert self.loglik is not None, "loglik should be set" + assert self.loglik_kind is not None, "loglik_kind should be set" + assert not (isclass(self.loglik) and issubclass(self.loglik, pm.Distribution)) + loglik_callable = typing_cast( + "Op | Callable[..., Any] | PathLike | str", self.loglik + ) + # params_is_trialwise_base: one entry per model param (excluding # p_outlier). Used for graph-level broadcasting in logp() and # make_distribution, where dist_params does not include extra_fields. 
@@ -2059,20 +350,20 @@ def _make_model_distribution(self) -> type[pm.Distribution]: if self.loglik_kind == "approx_differentiable": if self.model_config.backend == "jax": likelihood_callable = make_likelihood_callable( - loglik=self.loglik, + loglik=loglik_callable, loglik_kind="approx_differentiable", backend="jax", params_is_reg=params_is_trialwise, ) else: likelihood_callable = make_likelihood_callable( - loglik=self.loglik, + loglik=loglik_callable, loglik_kind="approx_differentiable", backend=self.model_config.backend, ) else: likelihood_callable = make_likelihood_callable( - loglik=self.loglik, + loglik=loglik_callable, loglik_kind=self.loglik_kind, backend=self.model_config.backend, ) @@ -2145,7 +436,7 @@ def dummy_simulator_func(*args, **kwargs): self.model_config.rv = make_hssm_rv( dummy_simulator_func, - list_params=self.list_params, + list_params=self.list_params or [], lapse=self.lapse, is_choice_only=True, ) @@ -2159,10 +450,15 @@ def dummy_simulator_func(*args, **kwargs): if isinstance(param.prior, np.ndarray) } + # Use the typed `model_config` attributes directly + _list_params = self.model_config.list_params + assert _list_params is not None, "list_params should be set" # for type checker + rv_name = getattr(self.model_config, "rv", None) or self.model_config.model_name + return make_distribution( - rv=self.model_config.rv or self.model_name, + rv=rv_name, loglik=self.loglik, - list_params=self.list_params, + list_params=_list_params, bounds=self.bounds, lapse=self.lapse, extra_fields=( @@ -2175,255 +471,3 @@ def dummy_simulator_func(*args, **kwargs): # TODO: add to HSSMBase is_choice_only=self.is_choice_only, ) - - def _get_deterministic_var_names(self, idata) -> list[str]: - """Filter out the deterministic variables in var_names.""" - var_names = [ - f"~{param_name}" - for param_name, param in self.params.items() - if (param.is_regression) - ] - - if f"{self._parent}_mean" in idata["posterior"].data_vars: - 
var_names.append(f"~{self._parent}_mean") - - # Parent parameters (always regression implicitly) - # which don't have a formula attached - # should be dropped from var_names, since the actual - # parent name shows up as a regression. - if f"{self._parent}" in idata["posterior"].data_vars: - if self.params[self._parent].formula is None: - # Drop from var_names - var_names = [var for var in var_names if var != f"~{self._parent}"] - - return var_names - - def _drop_parent_str_from_idata( - self, idata: az.InferenceData | None - ) -> az.InferenceData: - """Drop the parent_str variable from an InferenceData object. - - Parameters - ---------- - idata - The InferenceData object to be modified. - - Returns - ------- - xr.Dataset - The modified InferenceData object. - """ - if idata is None: - raise ValueError("Please provide an InferenceData object.") - else: - for group in idata.groups(): - if ("rt,response_mean" in idata[group].data_vars) and ( - self._parent not in idata[group].data_vars - ): - setattr( - idata, - group, - idata[group].rename({"rt,response_mean": self._parent}), - ) - return idata - - def _postprocess_initvals_deterministic( - self, initval_settings: dict = INITVAL_SETTINGS - ) -> None: - """Set initial values for subset of parameters.""" - self._initvals = self.initial_point() - # Consider case where link functions are set to 'log_logit' - # or 'None' - if self.link_settings not in ["log_logit", None]: - _logger.info( - "Not preprocessing initial values, " - + "because none of the two standard link settings are chosen!" - ) - return None - - # Set initial values for particular parameters - for name_, starting_value in self.pymc_model.initial_point().items(): - # strip name of `_log__` and `_interval__` suffixes - name_tmp = name_.replace("_log__", "").replace("_interval__", "") - - # We need to check if the parameter is actually backed by - # a regression. - - # If not, we don't actually apply a link function to it as per default. 
- # Therefore we need to apply the initial value strategy corresponding - # to 'None' link function. - - # If the user actively supplies a link function, the user - # should also have supplied an initial value insofar it matters. - - if self.params[self._get_prefix(name_tmp)].is_regression: - param_link_setting = self.link_settings - else: - param_link_setting = None - if name_tmp in initval_settings[param_link_setting].keys(): - if self._check_if_initval_user_supplied(name_tmp): - _logger.info( - "User supplied initial value detected for %s, \n" - " skipping overwrite with default value.", - name_tmp, - ) - continue - - # Apply specific settings from initval_settings dictionary - dtype = self._initvals[name_tmp].dtype - self._initvals[name_tmp] = np.array( - initval_settings[param_link_setting][name_tmp] - ).astype(dtype) - - def _get_prefix(self, name_str: str) -> str: - """Get parameters wise link setting function from parameter prefix.""" - # `p_outlier` is the only basic parameter floating around that has - # an underscore in it's name. - # We need to handle it separately. (Renaming might be better...) - if "_" in name_str: - if "p_outlier" not in name_str: - name_str_prefix = name_str.split("_")[0] - else: - name_str_prefix = "p_outlier" - else: - name_str_prefix = name_str - return name_str_prefix - - def _check_if_initval_user_supplied( - self, - name_str: str, - return_value: bool = False, - ) -> bool | float | int | np.ndarray | dict[str, Any] | None: - """Check if initial value is user-supplied.""" - # The function assumes that the name_str is either raw parameter name - # or `paramname_Intercept`, because we only really provide special default - # initial values for those types of parameters - - # `p_outlier` is the only basic parameter floating around that has - # an underscore in it's name. - # We need to handle it separately. (Renaming might be better...) 
- if "_" in name_str: - if "p_outlier" not in name_str: - name_str_prefix = name_str.split("_")[0] - # name_str_suffix = "".join(name_str.split("_")[1:]) - name_str_suffix = name_str[len(name_str_prefix + "_") :] - else: - name_str_prefix = "p_outlier" - if name_str == "p_outlier": - name_str_suffix = "" - else: - # name_str_suffix = "".join(name_str.split("_")[2:]) - name_str_suffix = name_str[len("p_outlier_") :] - else: - name_str_prefix = name_str - name_str_suffix = "" - - tmp_param = name_str_prefix - if tmp_param == self._parent: - # If the parameter was parent it is automatically treated as a - # regression. - if not name_str_suffix: - # No suffix --> Intercept - if isinstance(prior_tmp := self.params[tmp_param].prior, dict): - args_tmp = getattr(prior_tmp["Intercept"], "args") - if return_value: - return args_tmp.get("initval", None) - else: - return "initval" in args_tmp - else: - if return_value: - return None - return False - else: - # If the parameter has a suffix --> use it - if isinstance(prior_tmp := self.params[tmp_param].prior, dict): - args_tmp = getattr(prior_tmp[name_str_suffix], "args") - if return_value: - return args_tmp.get("initval", None) - else: - return "initval" in args_tmp - else: - if return_value: - return None - else: - return False - else: - # If the parameter is not a parent, it is treated as a regression - # only when actively specified as such. - if not name_str_suffix: - # If no suffix --> treat as basic parameter. 
- if isinstance(self.params[tmp_param].prior, float) or isinstance( - self.params[tmp_param].prior, np.ndarray - ): - if return_value: - return self.params[tmp_param].prior - else: - return True - elif isinstance(self.params[tmp_param].prior, bmb.Prior): - args_tmp = getattr(self.params[tmp_param].prior, "args") - if "initval" in args_tmp: - if return_value: - return args_tmp["initval"] - else: - return True - else: - if return_value: - return None - else: - return False - else: - if return_value: - return None - else: - return False - else: - # If suffix --> treat as regression and use suffix - if isinstance(prior_tmp := self.params[tmp_param].prior, dict): - args_tmp = getattr(prior_tmp[name_str_suffix], "args") - if return_value: - return args_tmp.get("initval", None) - else: - return "initval" in args_tmp - else: - if return_value: - return None - else: - return False - - def _jitter_initvals( - self, jitter_epsilon: float = 0.01, vector_only: bool = False - ) -> None: - """Apply controlled jitter to initial values.""" - if vector_only: - self.__jitter_initvals_vector_only(jitter_epsilon) - else: - self.__jitter_initvals_all(jitter_epsilon) - - def __jitter_initvals_vector_only(self, jitter_epsilon: float) -> None: - # Note: Calling our initial point function here - # --> operate on untransformed variables - initial_point_dict = self.initvals - for name_, starting_value in initial_point_dict.items(): - name_tmp = name_.replace("_log__", "").replace("_interval__", "") - if starting_value.ndim != 0 and starting_value.shape[0] != 1: - starting_value_tmp = starting_value + np.random.uniform( - -jitter_epsilon, jitter_epsilon, starting_value.shape - ).astype(np.float32) - - # Note: self._initvals shouldn't be None when this is called - dtype = self._initvals[name_tmp].dtype - self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype) - - def __jitter_initvals_all(self, jitter_epsilon: float) -> None: - # Note: Calling our initial point function here - # 
--> operate on untransformed variables - initial_point_dict = self.initvals - # initial_point_dict = self.pymc_model.initial_point() - for name_, starting_value in initial_point_dict.items(): - name_tmp = name_.replace("_log__", "").replace("_interval__", "") - starting_value_tmp = starting_value + np.random.uniform( - -jitter_epsilon, jitter_epsilon, starting_value.shape - ).astype(np.float32) - - dtype = self.initvals[name_tmp].dtype - self._initvals[name_tmp] = np.array(starting_value_tmp).astype(dtype) diff --git a/src/hssm/missing_data_mixin.py b/src/hssm/missing_data_mixin.py new file mode 100644 index 00000000..5c7012d3 --- /dev/null +++ b/src/hssm/missing_data_mixin.py @@ -0,0 +1,200 @@ +"""Mixin module for handling missing data and deadline logic in HSSM models.""" + +import numpy as np +import pandas as pd + +from hssm.defaults import MissingDataNetwork # noqa: F401 + + +class MissingDataMixin: + """Mixin for handling missing data and deadline logic in HSSM models. + + Parameters + ---------- + missing_data : optional + Specifies whether the model should handle missing data. Can be a `bool` + or a `float`. If `False`, and if the `rt` column contains -999.0, the + model will drop those rows and produce a warning. If `True`, the model + will treat -999.0 as missing data. If a `float` is provided, it will be + treated as the missing data value. Defaults to `False`. + deadline : optional + Specifies whether the model should handle deadline data. Can be a `bool` + or a `str`. If `False`, the model will not act even if a deadline column + is provided. If `True`, the model will treat the `deadline` column as + deadline data. If a `str` is provided, it is treated as the name of the + deadline column. Defaults to `False`. + loglik_missing_data : optional + A likelihood function for missing data. See the `loglik` parameter for + details. If not provided, a default likelihood is used. Required only if + either `missing_data` or `deadline` is not `False`. 
+ """ + + def _handle_missing_data_and_deadline(self): + """Handle missing data and deadline. + + Originally from DataValidatorMixin. Handles dropping, replacing, or masking + missing data and deadline values in self.data based on the current settings. + """ + import warnings + + if not self.missing_data and not self.deadline: + # In the case of choice only model, we don't need to do anything with the + # data. + if self.is_choice_only: + return + # In the case where missing_data is set to False, we need to drop the + # cases where rt = na_value + if pd.isna(self.missing_data_value): + na_dropped = self.data.dropna(subset=["rt"]) + else: + na_dropped = self.data.loc[ + self.data["rt"] != self.missing_data_value, : + ] + + if len(na_dropped) != len(self.data): + warnings.warn( + "`missing_data` is set to False, " + + "but you have missing data in your dataset. " + + "Missing data will be dropped.", + stacklevel=2, + ) + self.data = na_dropped + + elif self.missing_data and not self.deadline: + # In the case where missing_data is set to True, we need to replace the + # user-specified missing-data marker with the -999.0 sentinel + + # Note: this assigns to self.data["rt"] in place; no copy is made + if pd.isna(self.missing_data_value): + self.data["rt"] = self.data["rt"].fillna(-999.0) + else: + self.data["rt"] = self.data["rt"].replace( + self.missing_data_value, -999.0 + ) + + else: # deadline = True + if self.deadline_name not in self.data.columns: + raise ValueError( + "You have specified that your data has deadline, but " + + f"`{self.deadline_name}` is not found in your dataset."
+ ) + else: + self.data.loc[:, "rt"] = np.where( + self.data["rt"] < self.data[self.deadline_name], + self.data["rt"], + -999.0, + ) + + @staticmethod + def _set_missing_data_and_deadline( + missing_data: bool, deadline: bool, data: pd.DataFrame + ) -> MissingDataNetwork: + """Set missing data and deadline.""" + if not missing_data: + return MissingDataNetwork.NONE + network = MissingDataNetwork.OPN if deadline else MissingDataNetwork.CPN + # AF-TODO: GONOGO case not yet correctly implemented + # else: + # # TODO: This won't behave as expected yet, GONOGO needs to be split + # # into a deadline case and a non-deadline case. + # network = MissingDataNetwork.GONOGO + + if np.all(data["rt"] == -999.0): + # AF-TODO: I think we should allow invalid-only datasets. + raise ValueError( + "`missing_data` is set to True, but you have no valid data in your " + "dataset." + ) + # AF-TODO: This one needs refinement for GONOGO case + # elif network == MissingDataNetwork.OPN: + # raise ValueError( + # "`deadline` is set to True and `missing_data` is set to True, " + # "but ." + # ) + # else: + # raise ValueError( + # "`missing_data` and `deadline` are both set to True, + # "but you have " + # "no missing data and/or no rts exceeding the deadline." + # ) + return network + + def _process_missing_data_and_deadline( + self, missing_data: float | bool, deadline: bool | str, loglik_missing_data + ): + """ + Process missing data and deadline logic for the model's data. + + This method sets up missing data and deadline handling for the model. + It updates self.missing_data, self.missing_data_value, self.deadline, + self.deadline_name, and self.loglik_missing_data based on the arguments. + It also modifies self.data in-place to drop or replace missing/deadline + values as appropriate, and sets self.missing_data_network. + + Parameters + ---------- + missing_data : float or bool + If True, treat -999.0 as missing data. If a float, use that value + as the missing data marker. 
If False, drop missing data rows. + deadline : bool or str + If True, use the 'deadline' column for deadline logic. If a str, + use that column name. If False, ignore deadline logic. + loglik_missing_data : callable or None + Optional custom likelihood function for missing data. If not None, + must be used only when missing_data or deadline is True. + """ + if isinstance(missing_data, float): + if not ((self.data.rt == missing_data).any()): + raise ValueError( + f"missing_data argument is provided as a float {missing_data}, " + f"However, you have no RTs of {missing_data} in your dataset!" + ) + else: + self.missing_data = True + self.missing_data_value = missing_data + elif isinstance(missing_data, bool): + if missing_data and (not (self.data.rt == -999.0).any()): + raise ValueError( + "missing_data argument is provided as True, " + " so RTs of -999.0 are treated as missing. \n" + "However, you have no RTs of -999.0 in your dataset!" + ) + elif (not missing_data) and (self.data.rt == -999.0).any(): + raise ValueError( + "Missing data provided as False. \n" + "However, you have RTs of -999.0 in your dataset!" + ) + else: + self.missing_data = missing_data + else: + raise ValueError( + "missing_data argument must be a bool or a float! \n" + f"You provided: {type(missing_data)}" + ) + + if isinstance(deadline, str): + self.deadline = True + self.deadline_name = deadline + else: + self.deadline = deadline + self.deadline_name = "deadline" + + if ( + not self.missing_data and not self.deadline + ) and loglik_missing_data is not None: + raise ValueError( + "You have specified a loglik_missing_data function, but you have not " + "set the missing_data or deadline flag to True." 
+ ) + self.loglik_missing_data = loglik_missing_data + + # Update data based on missing_data and deadline + self._handle_missing_data_and_deadline() + # Set self.missing_data_network based on `missing_data` and `deadline` + self.missing_data_network = self._set_missing_data_and_deadline( + self.missing_data, self.deadline, self.data + ) + + if self.deadline and self.response is not None: # type: ignore[attr-defined] + if self.deadline_name not in self.response: # type: ignore[attr-defined] + self.response.append(self.deadline_name) # type: ignore[attr-defined] diff --git a/src/hssm/param/param.py b/src/hssm/param/param.py index 4ea8493f..3d1897dc 100644 --- a/src/hssm/param/param.py +++ b/src/hssm/param/param.py @@ -157,7 +157,7 @@ def is_trialwise(self) -> bool: def fill_defaults( self, - prior: dict[str, Any] | None = None, + prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None, bounds: tuple[float, float] | None = None, **kwargs, ) -> None: diff --git a/src/hssm/param/params.py b/src/hssm/param/params.py index f542c803..d90de35f 100644 --- a/src/hssm/param/params.py +++ b/src/hssm/param/params.py @@ -213,6 +213,7 @@ def collect_user_params( user_param = UserParam.from_dict(param) if isinstance(param, dict) else param if user_param.name is None: raise ValueError("Parameter name must be specified.") + assert model.list_params is not None, "list_params should be set" if user_param.name not in model.list_params: raise ValueError( f"Parameter {user_param.name} not found in list_params." @@ -229,6 +230,7 @@ def collect_user_params( # If any of the keys is found in `list_params` it is a parameter specification. 
# We add the parameter specification to `user_params` and remove it from # `kwargs` + assert model.list_params is not None, "list_params should be set" for param_name in model.list_params: # Update user_params only if param_name is in kwargs # and not already in user_params @@ -272,6 +274,7 @@ def make_params(model: HSSM, user_params: dict[str, UserParam]) -> dict[str, Par and model.loglik_kind != "approx_differentiable" ) + assert model.list_params is not None, "list_params should be set" for name in model.list_params: if name in user_params: param = make_param_from_user_param(model, name, user_params[name]) diff --git a/src/hssm/param/regression_param.py b/src/hssm/param/regression_param.py index c28d7c8b..64ac1816 100644 --- a/src/hssm/param/regression_param.py +++ b/src/hssm/param/regression_param.py @@ -111,7 +111,7 @@ def from_defaults( def fill_defaults( self, - prior: dict[str, Any] | None = None, + prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None, bounds: tuple[float, float] | None = None, **kwargs, ) -> None: diff --git a/src/hssm/param/simple_param.py b/src/hssm/param/simple_param.py index 70d00104..a71528ac 100644 --- a/src/hssm/param/simple_param.py +++ b/src/hssm/param/simple_param.py @@ -111,7 +111,7 @@ def from_user_param(cls, user_param: UserParam) -> "SimpleParam": def fill_defaults( self, - prior: dict[str, Any] | None = None, + prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None, bounds: tuple[float, float] | None = None, **kwargs, ) -> None: @@ -201,14 +201,17 @@ class DefaultParam(SimpleParam): def __init__( self, name: str, - prior: float | np.ndarray | dict[str, Any] | bmb.Prior, - bounds: tuple[float, float], + prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None, + bounds: tuple[float, float] | None, ) -> None: super().__init__(name, prior=prior, bounds=bounds) @classmethod def from_defaults( - cls, name: str, prior: dict[str, Any], bounds: tuple[int, int] + cls, + name: str, + prior: float | 
np.ndarray | dict[str, Any] | bmb.Prior | None, + bounds: tuple[float, float] | None, ) -> "DefaultParam": """Create a DefaultParam object from default values. @@ -248,7 +251,7 @@ def process_prior(self) -> None: def fill_defaults( self, - prior: dict[str, Any] | None = None, + prior: float | np.ndarray | dict[str, Any] | bmb.Prior | None = None, bounds: tuple[float, float] | None = None, **kwargs, ) -> None: diff --git a/src/hssm/param/utils.py b/src/hssm/param/utils.py index 96965f27..6d630f67 100644 --- a/src/hssm/param/utils.py +++ b/src/hssm/param/utils.py @@ -26,7 +26,7 @@ def _make_default_prior(bounds: tuple[float, float] | None) -> bmb.Prior: A bmb.Prior object representing the default prior for the provided bounds. """ if bounds is None: - raise ValueError("Parameter unspecified.") + raise ValueError("Bounds parameter unspecified.") lower, upper = bounds if np.isinf(lower) and np.isinf(upper): prior = bmb.Prior("Normal", mu=0.0, sigma=2.0) diff --git a/src/hssm/rl/__init__.py b/src/hssm/rl/__init__.py new file mode 100644 index 00000000..64e17bc4 --- /dev/null +++ b/src/hssm/rl/__init__.py @@ -0,0 +1,27 @@ +"""Reinforcement-learning extensions for HSSM. + +This subpackage groups components that integrate reinforcement-learning +learning rules with sequential-sampling decision models (SSMs). + +Public API (import from ``hssm.rl``): + +- ``RLSSM``: the RL + SSM model class implemented in :mod:`hssm.rl.rlssm`. +- ``RLSSMConfig``: the config class for RL + SSM models, implemented in + :mod:`hssm.rl.config`. +- ``validate_balanced_panel``: panel-balance utility in :mod:`hssm.rl.utils`. + +RL likelihood builders live in :mod:`hssm.rl.likelihoods.builder` and include +helpers such as :func:`~hssm.rl.likelihoods.builder.make_rl_logp_func` and +:func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`. 
+ +""" + +from .config import RLSSMConfig +from .rlssm import RLSSM +from .utils import validate_balanced_panel + +__all__ = [ + "RLSSM", + "RLSSMConfig", + "validate_balanced_panel", +] diff --git a/src/hssm/rl/config.py b/src/hssm/rl/config.py new file mode 100644 index 00000000..17ec2ea1 --- /dev/null +++ b/src/hssm/rl/config.py @@ -0,0 +1,219 @@ +"""RL-specific configuration classes. + +This module houses `RLSSMConfig` which was previously defined in +`hssm.config`. It is intentionally lightweight and re-uses +`BaseModelConfig` from :mod:`hssm.config` to avoid duplicating core +behaviour. +""" + +from __future__ import annotations + +import logging +from dataclasses import MISSING, dataclass, field, fields +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .._types import LoglikKind, SupportedModels + from ..config import ModelConfig + +from ..config import DEFAULT_SSM_CHOICES, DEFAULT_SSM_OBSERVED_DATA, BaseModelConfig + +_logger = logging.getLogger("hssm") + + +@dataclass +class RLSSMConfig(BaseModelConfig): + """Config for reinforcement learning + sequential sampling models. + + Extends :class:`BaseModelConfig` with the fields required by the RLSSM + likelihood pipeline. The key extra fields are: + + - ``ssm_logp_func``: the annotated JAX SSM log-likelihood function (see + below) whose ``computed`` dict drives per-parameter RL computations. + - ``learning_process``: a mapping that declares *how* each computed + parameter is specified (see below). + - ``decision_process``: the name (string) or :class:`ModelConfig` instance + that identifies the SSM decision process (e.g. ``"ddm"``, ``"angle"``). + - ``decision_process_loglik_kind`` / ``learning_process_kind``: string + tags that record which kind of likelihood and which kind of learning rule + are used (e.g. ``"approx_differentiable"`` / ``"blackbox"``). + + ssm_logp_func: + A JAX function decorated with ``@annotate_function``. 
It must carry: + + - ``.inputs`` — ordered list of all parameter names the function + expects (e.g. ``["v", "a", "z", "t", "theta", "rt", "response"]``). + - ``.outputs`` — list of output names (e.g. ``["logp"]``). + - ``.computed`` — dict mapping each *computed* parameter name to the + annotated function that produces it. For example:: + + {"v": compute_v_annotated} + + where ``compute_v_annotated`` is itself decorated with + ``@annotate_function`` and carries ``.inputs`` / ``.outputs``. + + ``make_rl_logp_op`` inspects ``ssm_logp_func.computed`` to resolve + which parameters come from data / sampled posteriors and which must + be computed by the RL learning rule at each gradient step. + + learning_process: + A dict keyed by the name of each computed parameter (matching the keys + in ``ssm_logp_func.computed``). Values record how that parameter is + specified. The dict is intentionally permissive — current supported + value forms are: + + - **callable** — an annotated function (or plain function) used to + compute the parameter. The actual computation at runtime is driven + by ``ssm_logp_func.computed``; this entry serves as declarative + documentation and for config serialisation / round-trip:: + + learning_process = {"v": compute_v_annotated} + + - **string** — a symbolic identifier for declarative / YAML-based + configs that can be resolved to a callable by the caller:: + + learning_process = {"v": "subject_wise_function"} + + An empty dict ``{}`` is valid when the SSM has no computed parameters + (i.e. ``ssm_logp_func.computed`` is also empty). + + .. note:: + The dict is *not* directly consumed by ``make_rl_logp_op``. + The actual compute functions used at runtime come from + ``ssm_logp_func.computed``. ``learning_process`` therefore acts + as a config-level record of intent and is useful for inspection, + serialisation, and future higher-level tooling. 
+ """ + + decision_process_loglik_kind: str = field(kw_only=True) + learning_process_kind: str = field(kw_only=True) + params_default: list[float] = field(kw_only=True) + decision_process: str | "ModelConfig" = field(kw_only=True) + learning_process: dict[str, Any] = field(kw_only=True) + ssm_logp_func: Any = field(default=None, kw_only=True) + + def __post_init__(self): # noqa: D105 + if self.loglik_kind is None: + self.loglik_kind = "approx_differentiable" + _logger.debug( + "RLSSMConfig: loglik_kind not specified; " + "defaulting to 'approx_differentiable'." + ) + + @classmethod + def from_defaults( # noqa: D102 + cls, model_name: "SupportedModels" | str, loglik_kind: "LoglikKind" | None + ) -> "RLSSMConfig": + raise NotImplementedError( + "RLSSMConfig does not support from_defaults(). " + "Use RLSSMConfig.from_rlssm_dict() or the constructor directly." + ) + + @classmethod + def from_rlssm_dict(cls, config_dict: dict[str, Any]) -> "RLSSMConfig": # noqa: D102 + # Derive required fields from the dataclass itself: a field is required + # iff it has no default and no default_factory. This keeps the dataclass + # as the single source of truth — no separate required-key list needed. + field_exceptions = ("loglik", "loglik_kind", "backend") + required_fields = [ + f.name + for f in fields(cls) + if f.name not in field_exceptions + and f.default is MISSING + and f.default_factory is MISSING # type: ignore[misc] + ] + for field_name in required_fields: + if field_name not in config_dict or config_dict[field_name] is None: + raise ValueError(f"{field_name} must be provided in config_dict") + + # ssm_logp_func has a dataclass default of None but is required in practice. 
+ if config_dict.get("ssm_logp_func") is None: + raise ValueError("ssm_logp_func must be provided in config_dict") + + init_kwargs = dict( + model_name=config_dict["model_name"], + description=config_dict.get("description"), + list_params=config_dict.get("list_params"), + extra_fields=config_dict.get("extra_fields"), + params_default=config_dict["params_default"], + decision_process=config_dict["decision_process"], + learning_process=config_dict["learning_process"], + ssm_logp_func=config_dict.get("ssm_logp_func"), + bounds=config_dict.get("bounds", {}), + decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], + learning_process_kind=config_dict["learning_process_kind"], + ) + + def _get_or_warn(key: str, default: Any) -> None: + if key not in config_dict: + _logger.warning( + "'%s' not specified in the RLSSM config; using default value: %r.", + key, + default, + ) + init_kwargs[key] = config_dict.get(key, default) + + _get_or_warn("response", DEFAULT_SSM_OBSERVED_DATA) + _get_or_warn("choices", DEFAULT_SSM_CHOICES) + + return cls(**init_kwargs) + + def validate(self) -> None: # noqa: D102 + if self.response is None: + raise ValueError("Please provide `response` columns in the configuration.") + if self.list_params is None: + raise ValueError("Please provide `list_params` in the configuration.") + if self.choices is None: + raise ValueError("Please provide `choices` in the configuration.") + if self.decision_process is None: + raise ValueError("Please specify a `decision_process`.") + + logpfunc = self.ssm_logp_func + if logpfunc is None: + raise ValueError( + "Please provide `ssm_logp_func`: the fully annotated JAX SSM " + "log-likelihood function required by `make_rl_logp_op`." + ) + if not callable(logpfunc): + raise ValueError( + f"`ssm_logp_func` must be a callable, but got {type(logpfunc)!r}." 
+ ) + missing_attrs = [ + attr + for attr in ("inputs", "outputs", "computed") + if not hasattr(logpfunc, attr) + ] + if missing_attrs: + raise ValueError( + "`ssm_logp_func` must be decorated with `@annotate_function` " + "so that it carries the attributes required by `make_rl_logp_op`. " + f"Missing attribute(s): {missing_attrs}. " + ) + + if not isinstance(logpfunc.computed, dict) or not all( + callable(v) for v in logpfunc.computed.values() + ): + raise ValueError( + "`ssm_logp_func.computed` must be a dictionary with callable values." + ) + + if self.params_default and self.list_params: + if len(self.params_default) != len(self.list_params): + raise ValueError( + f"params_default length ({len(self.params_default)}) doesn't " + f"match list_params length ({len(self.list_params)})" + ) + + if self.list_params: + missing_bounds = [p for p in self.list_params if p not in self.bounds] + if missing_bounds: + raise ValueError( + f"Missing bounds for parameter(s): {missing_bounds}. " + "Every parameter in `list_params` must have a corresponding " + "entry in `bounds`." + ) + + def get_defaults( # noqa: D102 + self, param: str + ) -> tuple[float | None, tuple[float, float] | None]: + return None, self.bounds.get(param) diff --git a/src/hssm/rl/rlssm.py b/src/hssm/rl/rlssm.py new file mode 100644 index 00000000..180cd12c --- /dev/null +++ b/src/hssm/rl/rlssm.py @@ -0,0 +1,264 @@ +"""RLSSM: Reinforcement Learning Sequential Sampling Model. + +This module defines the :class:`RLSSM` class, a subclass of :class:`HSSMBase` +for models that couple a reinforcement learning (RL) learning process with a +sequential sampling decision model (SSM). + +The key difference from :class:`HSSM` is the likelihood: + - ``HSSM`` wraps an analytical / ONNX / blackbox callable via + :func:`~hssm.distribution_utils.make_likelihood_callable`. 
+ - ``RLSSM`` builds a differentiable pytensor ``Op`` directly from an + :class:`~hssm.rl.likelihoods.builder.AnnotatedFunction` via + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op`, which internally + handles the RL learning rule and per-participant trial structure. + This Op is then passed straight to + :func:`~hssm.distribution_utils.make_distribution`, bypassing the + standard ``loglik`` / ``loglik_kind`` wrapping pipeline. +""" + +from dataclasses import replace +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import bambi as bmb +import pandas as pd +import pymc as pm + +if TYPE_CHECKING: + from pytensor.graph import Op + + +from hssm.defaults import ( + INITVAL_JITTER_SETTINGS, +) +from hssm.distribution_utils import make_distribution +from hssm.rl.likelihoods.builder import make_rl_logp_op +from hssm.rl.utils import validate_balanced_panel + +from ..base import HSSMBase +from .config import RLSSMConfig + + +class RLSSM(HSSMBase): + """Reinforcement Learning Sequential Sampling Model. + + Combines a reinforcement learning (RL) process with a sequential sampling + model (SSM) inside a single differentiable likelihood. The RL component + computes trial-wise intermediate parameters (e.g., drift rates) from the + learning history, which are then fed into the SSM log-likelihood. + + The likelihood is built via + :func:`~hssm.rl.likelihoods.builder.make_rl_logp_op` from the annotated + SSM function stored in *model_config.ssm_logp_func*. This produces a + differentiable pytensor ``Op`` that is passed directly to + :func:`~hssm.distribution_utils.make_distribution`, superseding the + ``loglik`` / ``loglik_kind`` dispatching used by :class:`~hssm.hssm.HSSM`. + + Parameters + ---------- + data : pd.DataFrame + Trial-level data. 
Must contain at least the response columns + specified in *model_config* (typically ``"rt"`` and ``"response"``), + a participant identifier column (default ``"participant_id"``), and + any extra fields listed in *model_config.extra_fields*. + The data **must** form a balanced panel: every participant must have + the same number of trials. + model_config : RLSSMConfig + Full configuration for the RLSSM model. Must have ``ssm_logp_func`` + set to the annotated JAX SSM log-likelihood function. + participant_col : str, optional + Name of the column that uniquely identifies participants. + Used to infer ``n_participants`` and ``n_trials`` from *data*. + Defaults to ``"participant_id"``. + include : list, optional + Parameter specifications forwarded to :class:`~hssm.base.HSSMBase`. + p_outlier : float | dict | bmb.Prior | None, optional + Lapse probability specification. Defaults to ``0.05``. + lapse : dict | bmb.Prior | None, optional + Lapse distribution. Defaults to ``Uniform(0, 20)``. + link_settings : Literal["log_logit"] | None, optional + Link-function preset. Defaults to ``None``. + prior_settings : Literal["safe"] | None, optional + Prior preset. Defaults to ``"safe"``. + extra_namespace : dict | None, optional + Extra variables for formula evaluation. Defaults to ``None``. + missing_data : bool | float, optional + Whether to handle missing RT data coded as ``-999.0``. + Defaults to ``False``. + deadline : bool | str, optional + Whether to handle deadline data. Defaults to ``False``. + loglik_missing_data : Callable | None, optional + Custom likelihood for missing observations. Defaults to ``None``. + process_initvals : bool, optional + Whether to post-process initial values. Defaults to ``True``. + initval_jitter : float, optional + Jitter magnitude for initial values. + Defaults to :data:`~hssm.defaults.INITVAL_JITTER_SETTINGS` epsilon. + **kwargs + Additional keyword arguments forwarded to :class:`bmb.Model`. 
+ + Attributes + ---------- + model_config : RLSSMConfig + The RLSSM configuration object (stored as ``self.model_config`` on + :class:`~hssm.base.HSSMBase` with the built ``loglik`` Op injected). + n_participants : int + Number of participants inferred from *data*. + n_trials : int + Number of trials per participant inferred from *data*. + """ + + def __init__( + self, + data: pd.DataFrame, + model_config: RLSSMConfig, + participant_col: str = "participant_id", + include: list[dict[str, Any] | Any] | None = None, + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict[str, Any] | None = None, + missing_data: bool | float = False, + deadline: bool | str = False, + loglik_missing_data: Callable | None = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs: Any, + ) -> None: + # ===== save/load serialisation ===== + self._init_args = self._store_init_args(locals(), kwargs) + + # Validate config (ensures ssm_logp_func is present, etc.) + model_config.validate() + + # RLSSM reshapes rows into (n_participants, n_trials, ...) by position, + # so _rearrange_data (which moves missing/deadline rows to the front) + # would scramble per-participant trial sequences and corrupt RL dynamics. + # Raise early so the user gets a clear message before model construction. + if missing_data is not False: + raise NotImplementedError( + "RLSSM currently does not support `missing_data` handling. " + "The RL log-likelihood Op relies on strict row order to recover " + "per-participant trial sequences; rearranging rows for missing RT " + "values would corrupt the RL learning dynamics. " + "Please remove missing trials from the data before passing it to RLSSM." 
 + ) + if deadline is not False: + raise NotImplementedError( + "RLSSM currently does not support `deadline` handling. " + "The RL log-likelihood Op relies on strict row order to recover " + "per-participant trial sequences; rearranging rows for deadline " + "trials would corrupt the RL learning dynamics. Please remove " + "deadline trials from the data before passing it to RLSSM." + ) + + # Infer panel structure and validate balance BEFORE calling super so any + # error surfaces before the expensive model-build steps. + n_participants, n_trials = validate_balanced_panel(data, participant_col) + + # Store RL-specific state on self BEFORE super().__init__() so that + # _make_model_distribution() (called from super) can access them. + self.config = model_config + self.n_participants = n_participants + self.n_trials = n_trials + + # Build the differentiable pytensor Op from the annotated SSM function. + # This Op supersedes the loglik/loglik_kind workflow: it is stored on + # model_config.loglik so that HSSMBase can access it uniformly via + # self.model_config.loglik, without any Config conversion. + # + # Fresh list() copies are passed to make_rl_logp_op so the closure inside + # captures its own isolated list objects. HSSMBase will later append + # "p_outlier" to self.list_params, and that mutation must NOT be visible + # to the Op's _validate_args_length check at sampling time. + loglik_op = make_rl_logp_op( + ssm_logp_func=model_config.ssm_logp_func, + n_participants=n_participants, + n_trials=n_trials, + data_cols=list(model_config.response), # type: ignore[arg-type] + list_params=list(model_config.list_params), # type: ignore[arg-type] + extra_fields=list(model_config.extra_fields or []), + ) + + # Build a new RLSSMConfig with the Op and backend injected, leaving + # the caller's object unmodified (dataclasses.replace creates a shallow + # copy with only the specified fields overridden).
+ # + # backend is hardcoded to "jax" because the entire RLSSM likelihood + # stack is JAX-only. See ssm_logp_func, make_rl_logp_op, and + # _make_model_distribution for details. + model_config = replace(model_config, loglik=loglik_op, backend="jax") + + super().__init__( + data=data, + model_config=model_config, + include=include, + p_outlier=p_outlier, + lapse=lapse, + link_settings=link_settings, + prior_settings=prior_settings, + extra_namespace=extra_namespace, + missing_data=missing_data, + deadline=deadline, + loglik_missing_data=loglik_missing_data, + process_initvals=process_initvals, + initval_jitter=initval_jitter, + **kwargs, + ) + + def _make_model_distribution(self) -> type[pm.Distribution]: + """Build a pm.Distribution using the pre-built RL log-likelihood Op. + + Unlike :meth:`HSSM._make_model_distribution`, this method does not go + through :func:`~hssm.distribution_utils.make_likelihood_callable`. + Instead it uses ``self.loglik`` directly — the differentiable pytensor + ``Op`` built in :meth:`__init__` from + ``self.model_config.ssm_logp_func``. + + The Op already handles: + - The RL learning rule (computing trial-wise intermediate parameters). + - The per-participant / per-trial data reshaping. + - Gradient computation via its embedded VJP. + + Missing-data network assembly (OPN / CPN) is not yet supported for + RLSSM and ``missing_data`` / ``deadline`` are rejected in ``__init__`` + before this method is ever reached. + """ + # Use self.list_params (managed by HSSMBase, includes p_outlier when + # has_lapse=True) rather than self.model_config.list_params (the original + # config list, never mutated by HSSMBase). + list_params = self.list_params + assert list_params is not None, "list_params must be set" + assert isinstance(list_params, list), ( + "list_params must be a list" + ) # for type checker + + # Every RLSSM distribution parameter is trialwise (the Op receives one + # value per trial). 
p_outlier is excluded to match the contract of + # make_distribution, which strips p_outlier before indexing this list. + params_is_trialwise = [True for name in list_params if name != "p_outlier"] + + extra_fields = self.model_config.extra_fields or [] + extra_fields_data = ( + None + if not extra_fields + else [self.data[field].to_numpy(copy=True) for field in extra_fields] + ) + + # The differentiable pytensor Op was stored on model_config.loglik during + # __init__; ensure it's present and cast for typing. + assert self.model_config.loglik is not None, "model_config.loglik must be set" + loglik_op = cast("Callable[..., Any] | Op", self.model_config.loglik) + + # RLSSMConfig carries no `rv` field; use model_name as the rv identifier. + rv_name = self.model_config.model_name + + return make_distribution( + rv=rv_name, + loglik=loglik_op, + list_params=list_params, + bounds=self.bounds, + lapse=self.lapse, + extra_fields=extra_fields_data, + params_is_trialwise=params_is_trialwise, + ) diff --git a/src/hssm/rl/utils.py b/src/hssm/rl/utils.py new file mode 100644 index 00000000..d35f08a9 --- /dev/null +++ b/src/hssm/rl/utils.py @@ -0,0 +1,70 @@ +"""Utility functions for reinforcement learning + SSM models.""" + +import pandas as pd + + +def validate_balanced_panel( + data: pd.DataFrame, + participant_col: str = "participant_id", +) -> tuple[int, int]: + """Validate that data forms a balanced panel and return its shape. + + A balanced panel requires every participant to have exactly the same number + of trials (rows in *data*). + + Parameters + ---------- + data : pd.DataFrame + The DataFrame to validate. + participant_col : str, optional + Name of the column that identifies participants. + Defaults to ``"participant_id"``. + + Returns + ------- + tuple[int, int] + ``(n_participants, n_trials)`` where *n_trials* is the number of rows + per participant. 
+ + Raises + ------ + ValueError + If *participant_col* is not present in *data*, or if the panel is + unbalanced (participants have different trial counts). + """ + if participant_col not in data.columns: + raise ValueError( + f"Column '{participant_col}' not found in data. " + "Please provide the correct participant column name via " + "`participant_col`." + ) + + n_null = data[participant_col].isna().sum() + if n_null > 0: + raise ValueError( + f"Column '{participant_col}' contains {n_null} NaN value(s). " + "All rows must have a valid participant identifier." + ) + + counts = data.groupby(participant_col).size() + if counts.nunique() != 1: + raise ValueError( + "Data must form balanced panels: all participants must have the " + f"same number of trials. Observed trial counts: {dict(counts)}" + ) + + # Check that each participant's rows are contiguous (not interleaved). + # make_rl_logp_op reshapes data as (n_participants, n_trials, ...) by row + # order, so interleaved rows would silently mix subjects and corrupt the + # RL learning dynamics. + n_runs = int((data[participant_col] != data[participant_col].shift()).sum()) + if n_runs != len(counts): + raise ValueError( + "Data rows must be contiguous per participant. " + "The RL likelihood reshapes data by row position; interleaved " + "participant rows will corrupt per-participant trial sequences. " + "Please sort the data by participant before passing it to RLSSM " + f"(e.g. data.sort_values('{participant_col}'))." 
+ ) + + return int(len(counts)), int(counts.iloc[0]) diff --git a/tests/param/test_default_param.py b/tests/param/test_default_param.py index 2a90be3f..3f86ebaa 100644 --- a/tests/param/test_default_param.py +++ b/tests/param/test_default_param.py @@ -4,6 +4,7 @@ from hssm import Prior from hssm.param.simple_param import DefaultParam +from hssm.param.utils import _make_default_prior def test_from_defaults(): @@ -40,3 +41,7 @@ def test_make_default_prior(bounds, prior): assert param.prior.name == prior.pop("name") for key, value in prior.items(): assert param.prior.args[key] == value + + +def test_make_default_prior_no_bounds(): + pytest.raises(ValueError, _make_default_prior, None) diff --git a/tests/test_config.py b/tests/test_config.py index 77d8284b..5685ca2b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,8 +1,11 @@ -import numpy as np +import logging + import pytest +import numpy as np -import hssm from hssm.config import Config, ModelConfig +import hssm + hssm.set_floatX("float32") @@ -87,3 +90,60 @@ def test_update_config(): assert v_prior.name == "Normal" assert v_bounds == (-np.inf, np.inf) + + +class TestConfigBuildModelConfigExtraLogic: + def test_build_model_config_dict_with_choices_conflict(self, caplog): + # model 'ddm' has defaults in hssm.defaults; use a minimal dict override + model_config = { + "response": ("rt", "response"), + "list_params": ["v", "a"], + "choices": (0, 1), + } + # provide a different choices argument — should log that model_config wins + with caplog.at_level(logging.INFO): + cfg = Config._build_model_config("ddm", None, model_config, choices=[1, 0]) + + assert isinstance(cfg, Config) + assert "choices list provided in both model_config" in caplog.text + + def test_build_model_config_modelconfig_adds_choices(self): + # Create a ModelConfig without choices and pass choices argument + mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None) + cfg = Config._build_model_config("ddm", None, mc, 
choices=(0, 1)) + # choices should be applied to resulting Config + assert cfg.choices == (0, 1) + + def test_build_model_config_uses_ssms_model_config(self, monkeypatch): + # High-level view of the test: ensures that when a model name is not in the built-in + # SupportedModels and no choices argument is passed, _build_model_config will consult + # the external ssms_model_config registry and use its defaults (here, the choices tuple). + # The monkeypatch fixture isolates the change and will be undone after the test. + + # Simulate an external ssms_model_config entry for a model not in SupportedModels + fake_model = "external_ssm" + fake_choices = (2, 3) + + # Monkeypatch the ssms_model_config mapping in the module + import hssm.config as cfgmod + + # Emulate an external package registering defaults for external_ssm. + # Ensures `_build_model_config` will consult `ssms_model_config` + # when the model name isn't in SupportedModels. + monkeypatch.setitem( + cfgmod.ssms_model_config, fake_model, {"choices": fake_choices} + ) + + # Build config with model not in SupportedModels and no choices arg. + # Provide a minimal ModelConfig and a dummy `loglik` so + # `Config.validate()` runs (loglik is required) while still + # exercising the ssms-simulators choices fallback. 
+ mc = ModelConfig(response=("rt", "response"), list_params=["v"], choices=None) + result = Config._build_model_config( + fake_model, + "analytical", + mc, + choices=None, + loglik=(lambda *a, **k: None), # required so Config.validate() passes + ) + assert result.choices == fake_choices diff --git a/tests/test_data_validator.py b/tests/test_data_validator.py index af176a8e..e3715129 100644 --- a/tests/test_data_validator.py +++ b/tests/test_data_validator.py @@ -152,29 +152,6 @@ def test_post_check_data_sanity_valid(base_data): dv_instance_no_missing._post_check_data_sanity() -def test_handle_missing_data_and_deadline_deadline_column_missing(base_data): - # Should raise ValueError if deadline is True but deadline_name column is missing - data = base_data.drop(columns=["deadline"]) - dv = DataValidatorTester( - data=data, - deadline=True, - ) - with pytest.raises(ValueError, match="`deadline` is not found in your dataset"): - dv._handle_missing_data_and_deadline() - - -def test_handle_missing_data_and_deadline_deadline_applied(base_data): - # Should set rt to -999.0 where rt >= deadline - base_data.loc[0, "rt"] = 2.0 # Exceeds deadline - dv = DataValidatorTester( - data=base_data, - deadline=True, - ) - dv._handle_missing_data_and_deadline() - assert dv.data.loc[0, "rt"] == -999.0 - assert all(dv.data.loc[1:, "rt"] < dv.data.loc[1:, "deadline"]) - - def test_update_extra_fields(monkeypatch): # Create a DataValidatorTester with extra_fields data = pd.DataFrame( @@ -207,61 +184,6 @@ class DummyModelDist: assert (dv.model_distribution.extra_fields[i] == data[field].values).all() -def test_set_missing_data_and_deadline(): - # No missing data and no deadline - data = pd.DataFrame({"rt": [0.5, 0.7]}) - assert ( - DataValidatorMixin._set_missing_data_and_deadline(False, False, data) - == MissingDataNetwork.NONE - ) - # Missing data but no deadline - data = pd.DataFrame({"rt": [0.5, -999.0]}) - assert ( - DataValidatorMixin._set_missing_data_and_deadline(True, False, data) 
- == MissingDataNetwork.CPN - ) - assert ( - DataValidatorMixin._set_missing_data_and_deadline(True, True, data) - == MissingDataNetwork.OPN - ) - # AF-TODO: I think GONOGO as a network category can go, - # but needs a little more thought, out of scope for PR, - # during which this was commented out. - # assert ( - # DataValidatorMixin._set_missing_data_and_deadline(True, True, data) - # == MissingDataNetwork.GONOGO - # ) - - -def test_set_missing_data_and_deadline_all_missing(): - data = pd.DataFrame({"rt": [-999.0, -999.0]}) - # cpn - with pytest.raises( - ValueError, - match="`missing_data` is set to True, but you have no valid data in your " - "dataset.", - ): - DataValidatorMixin._set_missing_data_and_deadline(True, False, data) - - # opn - with pytest.raises( - ValueError, - match="`missing_data` is set to True, but you have no valid data in your " - "dataset.", - ): - DataValidatorMixin._set_missing_data_and_deadline(True, True, data) - - # AF-TODO: GONOGO case not yet correctly implemented - # gonogo - # data = pd.DataFrame({"rt": [-999.0, -999.0]}) - # with pytest.raises( - # ValueError, - # match="`missing_data` is set to True, but you have no valid data in your " - # + "dataset.", - # ): - # DataValidatorMixin._set_missing_data_and_deadline(True, True, data) - - def test_validate_choices(): # ====== Valid choices ===== dv = DataValidatorTester( diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 2af50779..c3be947b 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -86,21 +86,19 @@ def test_custom_model(data_ddm): with pytest.raises( ValueError, match="When using a custom model, please provide a `loglik_kind.`" ): - model = HSSM(data=data_ddm, model="custom") + HSSM(data=data_ddm, model="custom") - with pytest.raises(ValueError, match="Please provide `list_params`"): - model = HSSM(data=data_ddm, model="custom", loglik_kind="analytical") + with pytest.raises(ValueError, match=r"^Please provide `list_params`"): + HSSM(data=data_ddm, 
model="custom", loglik_kind="analytical") - with pytest.raises(ValueError, match="Please provide `list_params`"): - model = HSSM( - data=data_ddm, model="custom", loglik=DDM, loglik_kind="analytical" - ) + with pytest.raises(ValueError, match=r"^Please provide `list_params`"): + HSSM(data=data_ddm, model="custom", loglik=DDM, loglik_kind="analytical") with pytest.raises( ValueError, - match="Please provide `list_params`", + match=r"^Please provide `list_params`", ): - model = HSSM( + HSSM( data=data_ddm, model="custom", loglik=DDM, @@ -125,9 +123,9 @@ def test_custom_model(data_ddm): loglik_kind="analytical", ) - assert model.model_name == "custom" - assert model.loglik_kind == "analytical" - assert model.list_params == ["v", "a", "z", "t", "p_outlier"] + assert model.model_config.model_name == "custom" + assert model.model_config.loglik_kind == "analytical" + assert model.model_config.list_params == ["v", "a", "z", "t", "p_outlier"] @pytest.mark.slow @@ -165,12 +163,12 @@ def test_sample_prior_predictive(data_ddm_reg): model_regression = HSSM( data=data_ddm_reg, include=[dict(name="v", formula="v ~ 1 + x")] ) - prior_predictive_3 = model_regression.sample_prior_predictive(draws=10) + model_regression.sample_prior_predictive(draws=10) model_regression_a = HSSM( data=data_ddm_reg, include=[dict(name="a", formula="a ~ 1 + x")] ) - prior_predictive_4 = model_regression_a.sample_prior_predictive(draws=10) + model_regression_a.sample_prior_predictive(draws=10) model_regression_multi = HSSM( data=data_ddm_reg, @@ -179,7 +177,7 @@ def test_sample_prior_predictive(data_ddm_reg): dict(name="a", formula="a ~ 1 + y"), ], ) - prior_predictive_5 = model_regression_multi.sample_prior_predictive(draws=10) + model_regression_multi.sample_prior_predictive(draws=10) data_ddm_reg.loc[:, "subject_id"] = np.arange(10) @@ -190,9 +188,7 @@ def test_sample_prior_predictive(data_ddm_reg): dict(name="a", formula="a ~ (1|subject_id) + y"), ], ) - prior_predictive_6 = 
model_regression_random_effect.sample_prior_predictive( - draws=10 - ) + model_regression_random_effect.sample_prior_predictive(draws=10) @pytest.mark.slow diff --git a/tests/test_missing_data_mixin.py b/tests/test_missing_data_mixin.py new file mode 100644 index 00000000..a5476a52 --- /dev/null +++ b/tests/test_missing_data_mixin.py @@ -0,0 +1,221 @@ +import pytest +import pandas as pd + +from hssm.missing_data_mixin import MissingDataMixin +from hssm.defaults import MissingDataNetwork + + +class DummyModel(MissingDataMixin): + """ + Dummy model for testing MissingDataMixin. + + This class provides stub implementations of methods that the mixin expects + to exist on the consuming class. These stubs allow us to verify, via mocks/spies, + that the mixin calls them as part of its logic. This is a common pattern for + testing mixins: the dummy class provides the required interface, and the test + checks the mixin's interaction with it. + """ + + def __init__(self, data): + self.data = data + self.response = ["response"] + self.missing_data_value = -999.0 + self.missing_data = False + self.deadline = False + self.is_choice_only = False + + +# region ===== Fixtures ===== +@pytest.fixture +def basic_data(): + return pd.DataFrame({"rt": [1.0, 2.0, -999.0], "response": [1, -1, 1]}) + + +@pytest.fixture +def dummy_model(basic_data): + return DummyModel(basic_data) + + +@pytest.fixture +def dummy_model_with_deadline(basic_data): + data = basic_data.assign(deadline=[2.0, 2.0, 2.0]) + return DummyModel(data) + + +# Indirect fixture dispatcher for parameterized model selection +@pytest.fixture +def model(request): + return request.getfixturevalue(request.param) + + +# endregion + + +class TestProcessMissingDataAndDeadline: + @pytest.mark.parametrize( + "model, deadline", + [ + ("dummy_model", False), + ("dummy_model_with_deadline", True), + ("dummy_model_with_deadline", "deadline"), + ], + indirect=["model"], + ) + def test_missing_data_false_raises_valueerror(self, model, 
deadline): + """ + Should raise ValueError if missing_data=False and -999.0 is present in rt column. + Covers all cases where deadline is False, True, or a string. + """ + with pytest.raises(ValueError, match="Missing data provided as False"): + model._process_missing_data_and_deadline( + missing_data=False, + deadline=deadline, + loglik_missing_data=None, + ) + + +# --- 2. Additional tests for new features and edge cases in MissingDataMixin --- +class TestMissingDataMixinNew: + def test_set_missing_data_network_set(self, dummy_model): + # missing_data True, deadline False + dummy_model._process_missing_data_and_deadline( + missing_data=True, deadline=False, loglik_missing_data=None + ) + assert dummy_model.missing_data_network == MissingDataNetwork.CPN + + # missing_data True, deadline True + dummy_model.data["deadline"] = 2.0 + dummy_model._process_missing_data_and_deadline( + missing_data=True, deadline=True, loglik_missing_data=None + ) + assert dummy_model.missing_data_network == MissingDataNetwork.OPN + + # missing_data False, deadline False (should raise ValueError due to -999.0 present) + with pytest.raises(ValueError, match="Missing data provided as False"): + dummy_model._process_missing_data_and_deadline( + missing_data=False, deadline=False, loglik_missing_data=None + ) + + def test_response_appended_with_deadline_name(self, dummy_model): + # Should append deadline_name to response if not present + dummy_model.data["deadline"] = 2.0 + dummy_model.response = ["response"] + dummy_model._process_missing_data_and_deadline( + missing_data=True, deadline="deadline", loglik_missing_data=None + ) + assert "deadline" in dummy_model.response + + def test_error_on_missing_data_false_with_missing(self, dummy_model): + # Should raise ValueError if missing_data is False and -999.0 is present + with pytest.raises(ValueError, match="Missing data provided as False"): + dummy_model._process_missing_data_and_deadline( + missing_data=False, deadline=False, 
loglik_missing_data=None + ) + + def test_missing_data_true_retains_missing_marker(self, dummy_model): + # Should retain -999.0 as missing marker if missing_data is True + dummy_model._process_missing_data_and_deadline( + missing_data=True, deadline=False, loglik_missing_data=None + ) + assert -999.0 in dummy_model.data.rt.values + + def test_deadline_sets_rt_to_missing_marker(self, dummy_model): + # Should set rt to -999.0 if above deadline + # Set up so that the second RT is above its deadline + dummy_model.data["rt"] = [1.0, 3.0, -999.0] # 3.0 > 2.5 + dummy_model.data["deadline"] = [1.5, 2.5, 2.5] + dummy_model._process_missing_data_and_deadline( + missing_data=True, deadline="deadline", loglik_missing_data=None + ) + # The first row rt=1.0 < 1.5, so not -999.0; second should be -999.0; third is already -999.0 + assert dummy_model.data.rt.iloc[0] == 1.0 + assert dummy_model.data.rt.iloc[1] == -999.0 + assert dummy_model.data.rt.iloc[2] == -999.0 + + def test_loglik_missing_data_error(self, dummy_model): + # Should raise if loglik_missing_data is set but both missing_data and deadline are False + dummy_model.data.rt = [1.0, 2.0, 3.0] # No -999.0 present + with pytest.raises( + ValueError, + match="loglik_missing_data function, but you have not set the missing_data or deadline flag to True", + ): + dummy_model._process_missing_data_and_deadline( + missing_data=False, deadline=False, loglik_missing_data=lambda x: x + ) + + def test_process_missing_data_and_deadline_updates_attributes(self, dummy_model): + """ + Test that _process_missing_data_and_deadline updates missing_data, deadline, deadline_name, and loglik_missing_data. 
+ """ + + # Set up a custom loglik function + def custom_loglik(x): + return x + + # Add a custom_deadline column to the data to satisfy the mixin's requirements + dummy_model.data["custom_deadline"] = 2.0 + # Call with missing_data True, deadline as string, and custom loglik + dummy_model._process_missing_data_and_deadline( + missing_data=True, + deadline="custom_deadline", + loglik_missing_data=custom_loglik, + ) + assert dummy_model.missing_data is True + assert dummy_model.deadline is True + assert dummy_model.deadline_name == "custom_deadline" + assert dummy_model.loglik_missing_data is custom_loglik + + def test_missing_data_value_custom(self, dummy_model): + custom_missing = -123.0 + # Add a row with custom missing value + dummy_model.data.loc[len(dummy_model.data)] = [custom_missing, 1] + dummy_model._process_missing_data_and_deadline( + missing_data=custom_missing, + deadline=False, + loglik_missing_data=None, + ) + assert dummy_model.missing_data is True + assert dummy_model.missing_data_value == custom_missing + # After processing, custom missing values are replaced with -999.0 + assert (dummy_model.data.rt == -999.0).any() + + def test_deadline_column_added_once(self, dummy_model, basic_data): + # Add a deadline_col to the data + data = basic_data.assign(deadline_col=[2.0, 2.0, 2.0]) + dummy_model.data = data + # Add deadline_col to response already + dummy_model.response.append("deadline_col") + # Should raise ValueError due to -999.0 in rt when missing_data=False + with pytest.raises(ValueError, match="Missing data provided as False"): + dummy_model._process_missing_data_and_deadline( + missing_data=False, + deadline="deadline_col", + loglik_missing_data=None, + ) + + def test_missing_data_and_deadline_together(self, dummy_model_with_deadline): + # Should set both flags + dummy_model_with_deadline._process_missing_data_and_deadline( + missing_data=True, + deadline=True, + loglik_missing_data=None, + ) + assert dummy_model_with_deadline.missing_data is 
True + assert dummy_model_with_deadline.deadline is True + assert dummy_model_with_deadline.deadline_name == "deadline" + + def test_handle_missing_data_and_deadline_direct(self, dummy_model): + """ + Directly test the _handle_missing_data_and_deadline method for coverage. + """ + # Call with no arguments, as expected by the mixin stub + dummy_model._handle_missing_data_and_deadline() + + def test_set_missing_data_and_deadline_edge_case(self, dummy_model): + all_missing = pd.DataFrame({"rt": [-999.0]}) + with pytest.raises(ValueError, match="no valid data in your dataset"): + dummy_model._set_missing_data_and_deadline( + missing_data=True, + deadline=False, + data=all_missing, + ) diff --git a/tests/test_rl_utils.py b/tests/test_rl_utils.py new file mode 100644 index 00000000..554fabd1 --- /dev/null +++ b/tests/test_rl_utils.py @@ -0,0 +1,107 @@ +"""Tests for hssm.rl.utils — validate_balanced_panel.""" + +import numpy as np +import pandas as pd +import pytest + +from hssm.rl.utils import validate_balanced_panel + + +def _make_panel( + n_participants: int, n_trials: int, participant_col: str = "participant_id" +) -> pd.DataFrame: + """Return a perfectly balanced, contiguous panel DataFrame.""" + ids = np.repeat(range(n_participants), n_trials) + return pd.DataFrame( + {participant_col: ids, "rt": np.random.rand(n_participants * n_trials)} + ) + + +class TestValidateBalancedPanelHappyPath: + def test_returns_correct_shape(self) -> None: + """Returns (n_participants, n_trials) for a balanced panel.""" + df = _make_panel(5, 20) + n_p, n_t = validate_balanced_panel(df) + assert n_p == 5 + assert n_t == 20 + + def test_single_participant(self) -> None: + """Single-participant panel is trivially balanced.""" + df = _make_panel(1, 10) + n_p, n_t = validate_balanced_panel(df) + assert n_p == 1 + assert n_t == 10 + + def test_custom_participant_col(self) -> None: + """Works when participant column has a non-default name.""" + df = _make_panel(3, 8, participant_col="subj_id") 
+ n_p, n_t = validate_balanced_panel(df, participant_col="subj_id") + assert n_p == 3 + assert n_t == 8 + + +class TestValidateBalancedPanelMissingColumn: + def test_missing_participant_col_raises(self) -> None: + """Raises ValueError when participant_col is absent from the DataFrame.""" + df = pd.DataFrame({"rt": [0.5, 0.6]}) + with pytest.raises(ValueError, match="not found in data"): + validate_balanced_panel(df) + + def test_wrong_participant_col_name_raises(self) -> None: + """Raises ValueError when a wrong column name is supplied.""" + df = _make_panel(2, 5) + with pytest.raises(ValueError, match="not found in data"): + validate_balanced_panel(df, participant_col="subject") + + +class TestValidateBalancedPanelNaN: + def test_nan_participant_id_raises(self) -> None: + """Raises ValueError when participant_col contains NaN.""" + df = _make_panel(3, 4) + df.loc[df.index[0], "participant_id"] = float("nan") + with pytest.raises(ValueError, match="NaN"): + validate_balanced_panel(df) + + +class TestValidateBalancedPanelUnbalanced: + def test_unbalanced_panel_raises(self) -> None: + """Raises ValueError when participants have different trial counts.""" + df = _make_panel(3, 5) + unbalanced = df.iloc[:-1].copy() # drop one row → participant 2 has 4 trials + with pytest.raises(ValueError, match="balanced panels"): + validate_balanced_panel(unbalanced) + + def test_one_participant_fewer_trials_raises(self) -> None: + """Raises ValueError when one participant has fewer trials than others.""" + df = _make_panel(3, 5) + # Drop last 2 rows of participant 2 so counts differ. 
+ mask = ~( + (df["participant_id"] == 2) + & (df.index >= df[df["participant_id"] == 2].index[-2]) + ) + unbalanced = df[mask].copy() + with pytest.raises(ValueError, match="balanced panels"): + validate_balanced_panel(unbalanced) + + +class TestValidateBalancedPanelContiguity: + def test_interleaved_participants_raises(self) -> None: + """Raises ValueError when participants' rows are interleaved (not contiguous). + + The RL likelihood reshapes data as (n_participants, n_trials, ...) by row + position, so interleaved rows would silently corrupt trial sequences. + """ + # Build interleaved data: [0, 1, 2, 0, 1, 2, ...] (3 participants × 4 trials) + ids = np.tile([0, 1, 2], 4) + df = pd.DataFrame({"participant_id": ids, "rt": np.random.rand(12)}) + with pytest.raises(ValueError, match="contiguous"): + validate_balanced_panel(df) + + def test_sorted_data_passes(self) -> None: + """Sorting an interleaved panel by participant_id makes it valid.""" + ids = np.tile([0, 1, 2], 4) + df = pd.DataFrame({"participant_id": ids, "rt": np.random.rand(12)}) + df_sorted = df.sort_values("participant_id").reset_index(drop=True) + n_p, n_t = validate_balanced_panel(df_sorted) + assert n_p == 3 + assert n_t == 4 diff --git a/tests/test_rlssm.py b/tests/test_rlssm.py new file mode 100644 index 00000000..973061c7 --- /dev/null +++ b/tests/test_rlssm.py @@ -0,0 +1,329 @@ +"""Tests for the RLSSM class. + +Mirrors the structure of tests/test_hssm.py, covering initialisation, +config validation, param keys, balanced-panel enforcement, the no-lapse +variant, bambi / PyMC model construction, and a sampling smoke test. 
+""" + +from collections.abc import Generator +from pathlib import Path + +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytensor +import pytest + +import hssm +from hssm.rl import RLSSM, RLSSMConfig +from hssm.rl.likelihoods.two_armed_bandit import compute_v_subject_wise +from hssm.utils import annotate_function + +# --------------------------------------------------------------------------- +# Module-level annotated helpers (shared by all tests) +# --------------------------------------------------------------------------- + +# Annotate the RL learning function: maps +# (rl_alpha, scaler, response, feedback) -> v +_compute_v_annotated = annotate_function( + inputs=["rl_alpha", "scaler", "response", "feedback"], + outputs=["v"], +)(compute_v_subject_wise) + + +# Annotated SSM log-likelihood function (simplified for testing). +# It receives a 2-D lan_matrix whose columns correspond to +# [v, a, z, t, theta, rt, response] +# and returns per-trial log-probabilities of shape (n_total_trials,). +@annotate_function( + inputs=["v", "a", "z", "t", "theta", "rt", "response"], + outputs=["logp"], + computed={"v": _compute_v_annotated}, +) +def _dummy_ssm_logp(lan_matrix: jnp.ndarray) -> jnp.ndarray: + """Return per-trial log-probabilities (column-sum); structural tests only.""" + # Return 1D (N,) — PyTensor declares the Op output as pt.vector(), so + # gradients arrive as (N,). A (N,1) return causes a VJP shape mismatch. 
+ return jnp.sum(lan_matrix, axis=1) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module", autouse=True) +def _set_floatx_float32() -> Generator[None, None, None]: + """Ensure float32 is used for this module's tests, then restore previous setting.""" + prev_floatx = pytensor.config.floatX + hssm.set_floatX("float32", update_jax=True) + try: + yield + finally: + hssm.set_floatX(prev_floatx, update_jax=True) + + +@pytest.fixture(scope="module") +def rldm_data() -> pd.DataFrame: + """Load the RLDM fixture dataset (balanced panel).""" + raw = np.load( + Path(__file__).parent / "fixtures" / "rldm_data.npy", allow_pickle=True + ).item() + return pd.DataFrame(raw["data"]) + + +@pytest.fixture(scope="module") +def rlssm_config() -> RLSSMConfig: + """Minimal but valid RLSSMConfig for the RLDM fixture dataset.""" + return RLSSMConfig( + model_name="rldm_test", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=["rl_alpha", "scaler", "a", "theta", "t", "z"], + params_default=[0.1, 1.0, 1.0, 0.0, 0.3, 0.5], + bounds={ + "rl_alpha": (0.0, 1.0), + "scaler": (0.0, 10.0), + "a": (0.1, 3.0), + "theta": (-0.1, 0.1), + "t": (0.001, 1.0), + "z": (0.1, 0.9), + }, + learning_process={"v": _compute_v_annotated}, + response=["rt", "response"], + choices=[0, 1], + extra_fields=["feedback"], + ssm_logp_func=_dummy_ssm_logp, + ) + + +# --------------------------------------------------------------------------- +# Initialisation & config-validation tests +# --------------------------------------------------------------------------- + + +def test_rlssm_init(rldm_data, rlssm_config) -> None: + """Basic RLSSM initialisation should succeed and return an RLSSM instance.""" + model = RLSSM(data=rldm_data, 
model_config=rlssm_config) + assert isinstance(model, RLSSM) + assert model.model_config.model_name == "rldm_test" + + +def test_rlssm_panel_attrs(rldm_data, rlssm_config) -> None: + """n_participants and n_trials should match the fixture data structure.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + + n_participants = rldm_data["participant_id"].nunique() + n_trials = len(rldm_data) // n_participants + + assert model.n_participants == n_participants + assert model.n_trials == n_trials + + +def test_rlssm_params_keys(rldm_data, rlssm_config) -> None: + """model.params should contain exactly list_params + p_outlier.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + expected = set(rlssm_config.list_params) | {"p_outlier"} + assert set(model.params.keys()) == expected + + +def test_rlssm_unbalanced_raises(rldm_data, rlssm_config) -> None: + """Dropping one row should make the panel unbalanced → ValueError.""" + unbalanced = rldm_data.iloc[:-1].copy() + with pytest.raises(ValueError, match="balanced panels"): + RLSSM(data=unbalanced, model_config=rlssm_config) + + +def test_rlssm_nan_participant_id_raises(rldm_data, rlssm_config) -> None: + """NaN in participant_id column should raise ValueError before groupby silently drops rows.""" + nan_data = rldm_data.copy() + nan_data.loc[nan_data.index[0], "participant_id"] = float("nan") + with pytest.raises(ValueError, match="NaN"): + RLSSM(data=nan_data, model_config=rlssm_config) + + +def test_rlssm_missing_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: + """RLSSMConfig without ssm_logp_func should raise ValueError on init.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + 
learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + # ssm_logp_func intentionally omitted → defaults to None + ) + with pytest.raises(ValueError, match="ssm_logp_func"): + RLSSM(data=rldm_data, model_config=bad_config) + + +def test_rlssm_unannotated_ssm_logp_func_raises(rldm_data, rlssm_config) -> None: + """A plain callable without @annotate_function attrs should raise ValueError.""" + bad_config = RLSSMConfig( + model_name="rldm_bad", + loglik_kind="approx_differentiable", + decision_process="angle", + decision_process_loglik_kind="approx_differentiable", + learning_process_kind="blackbox", + list_params=rlssm_config.list_params, + params_default=rlssm_config.params_default, + bounds=rlssm_config.bounds, + learning_process=rlssm_config.learning_process, + response=list(rlssm_config.response), + choices=list(rlssm_config.choices), + extra_fields=list(rlssm_config.extra_fields), + ssm_logp_func=lambda x: x, # callable but no .inputs/.outputs/.computed + ) + with pytest.raises(ValueError, match="annotate_function"): + RLSSM(data=rldm_data, model_config=bad_config) + + +def test_rlssm_missing_data_raises(rldm_data, rlssm_config) -> None: + """Passing missing_data!=False should raise NotImplementedError with 'missing_data' in msg.""" + with pytest.raises(NotImplementedError, match="missing_data"): + RLSSM(data=rldm_data, model_config=rlssm_config, missing_data=True) + + +def test_rlssm_deadline_raises(rldm_data, rlssm_config) -> None: + """Passing deadline!=False should raise NotImplementedError with 'deadline' in msg.""" + with pytest.raises(NotImplementedError, match="deadline"): + RLSSM(data=rldm_data, model_config=rlssm_config, deadline=True) + + +# --------------------------------------------------------------------------- +# Model-structure tests +# --------------------------------------------------------------------------- + + +def 
test_rlssm_params_is_trialwise_aligned(rldm_data, rlssm_config) -> None: + """params_is_trialwise must align with list_params (same length, p_outlier=False).""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model_config.list_params is not None + params_is_trialwise = [ + name != "p_outlier" for name in model.model_config.list_params + ] + assert len(params_is_trialwise) == len(model.model_config.list_params) + for name, is_tw in zip(model.model_config.list_params, params_is_trialwise): + if name == "p_outlier": + assert not is_tw, "p_outlier must be non-trialwise" + else: + assert is_tw, f"{name} must be trialwise" + + +def test_rlssm_get_prefix(rldm_data, rlssm_config) -> None: + """_get_prefix must use token-based matching, not substring search. + + - 'rl_alpha_Intercept' → 'rl_alpha' (underscore-containing RL param) + - 'p_outlier_log__' → 'p_outlier' (lapse param via token loop, not substring) + - 'a_Intercept' → 'a' (single-token standard param) + """ + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model._get_prefix("rl_alpha_Intercept") == "rl_alpha" + assert model._get_prefix("p_outlier_log__") == "p_outlier" + assert model._get_prefix("p_outlier") == "p_outlier" + assert model._get_prefix("a_Intercept") == "a" + # Fallback: not in params + assert model._get_prefix("unknown_param") == "unknown_param" + + +def test_rlssm_no_lapse(rldm_data, rlssm_config) -> None: + """Setting p_outlier=None should remove p_outlier from params.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config, p_outlier=None) + assert "p_outlier" not in model.params + + +def test_rlssm_model_built(rldm_data, rlssm_config) -> None: + """The bambi model should be built and the computed param 'v' absent from params.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.model is not None + # rl_alpha is a free (sampled) parameter + assert "rl_alpha" in model.params + # v is computed inside the Op; it must NOT appear 
as a free parameter + assert "v" not in model.params + + +def test_rlssm_extra_fields_are_copies(rldm_data, rlssm_config) -> None: + """extra_fields passed to make_distribution must be independent numpy copies. + + to_numpy(copy=True) should return a new buffer; if it returned a view, + in-place mutations of the DataFrame would silently corrupt the distribution. + """ + from unittest.mock import patch + + from hssm.distribution_utils import make_distribution as real_make_distribution + + model = RLSSM(data=rldm_data, model_config=rlssm_config) + captured: dict = {} + + def capturing_make_distribution(*args, **kwargs): + captured["extra_fields"] = kwargs.get("extra_fields") + return real_make_distribution(*args, **kwargs) + + with patch( + "hssm.rl.rlssm.make_distribution", side_effect=capturing_make_distribution + ): + model._make_model_distribution() + + assert captured.get("extra_fields") is not None + for field_name, arr in zip(rlssm_config.extra_fields, captured["extra_fields"]): + original = model.data[field_name].to_numpy() + assert not np.shares_memory(arr, original), ( + f"extra_fields['{field_name}'] shares memory with the DataFrame — " + "it is a view, not a copy" + ) + + +def test_rlssm_pymc_model(rldm_data, rlssm_config) -> None: + """pymc_model should be accessible after model construction.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + assert model.pymc_model is not None + + +# --------------------------------------------------------------------------- +# Slow sampling smoke test +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +def test_rlssm_sample_smoke(rldm_data, rlssm_config) -> None: + """Minimal sampling run should return an InferenceData object.""" + model = RLSSM(data=rldm_data, model_config=rlssm_config) + trace = model.sample( + draws=4, tune=50, chains=1, cores=1, sampler="numpyro", target_accept=0.9 + ) + assert trace is not None + + +def test_rlssm_pickle_round_trip( + 
rldm_data: pd.DataFrame, rlssm_config: RLSSMConfig +) -> None: + """cloudpickle round-trip must reconstruct an equivalent RLSSM. + + Verifies that __getstate__ / __setstate__ survive serialisation: + - The reconstructed object is a fresh RLSSM (not the same instance). + - n_participants and n_trials are preserved. + - list_params (including p_outlier) are preserved. + - model_config.model_name is preserved. + - model.model (bambi model) is rebuilt, confirming full re-initialisation. + """ + import cloudpickle + + model = RLSSM(data=rldm_data, model_config=rlssm_config) + blob = cloudpickle.dumps(model) + restored = cloudpickle.loads(blob) + + assert restored is not model + assert isinstance(restored, RLSSM) + assert restored.n_participants == model.n_participants + assert restored.n_trials == model.n_trials + assert restored.list_params == model.list_params + assert restored.model_config.model_name == model.model_config.model_name + assert restored.model is not None diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py index cd80d9d4..879ae3e7 100644 --- a/tests/test_rlssm_config.py +++ b/tests/test_rlssm_config.py @@ -1,8 +1,14 @@ import pytest import hssm -from hssm.config import Config, RLSSMConfig -from hssm.config import ModelConfig +from hssm.config import ( + DEFAULT_SSM_CHOICES, + DEFAULT_SSM_OBSERVED_DATA, + Config, + ModelConfig, +) +from hssm.rl import RLSSMConfig +from hssm.utils import annotate_function # Define constants for repeated data structures DEFAULT_RESPONSE = ("rt", "response") @@ -16,6 +22,13 @@ } +# Module-level annotated dummy used wherever from_rlssm_dict needs a valid +# ssm_logp_func but the test is not about ssm_logp_func itself. 
+@annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={}) +def _module_dummy_ssm_logp(x): + return x + + # Helper function to create a config dictionary def create_config_dict( model_name, @@ -28,7 +41,8 @@ def create_config_dict( learning_process={}, decision_process="ddm", decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", + ssm_logp_func=_module_dummy_ssm_logp, ): return dict( model_name=model_name, @@ -43,7 +57,8 @@ def create_config_dict( learning_process=learning_process, decision_process=decision_process, decision_process_loglik_kind=decision_process_loglik_kind, - learning_process_loglik_kind=learning_process_loglik_kind, + learning_process_kind=learning_process_kind, + ssm_logp_func=ssm_logp_func, data={}, ) @@ -51,17 +66,23 @@ def create_config_dict( # region fixtures and helpers @pytest.fixture def valid_rlssmconfig_kwargs(): + @annotate_function(inputs=["v", "rt", "response"], outputs=["logp"], computed={}) + def _dummy_ssm_logp_func(x): + return x + return dict( model_name="test_model", list_params=["alpha", "beta"], params_default=[0.5, 0.3], + bounds={"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, decision_process="ddm", response=["rt", "response"], choices=[0, 1], extra_fields=["feedback"], decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", learning_process={}, + ssm_logp_func=_dummy_ssm_logp_func, ) @@ -149,7 +170,7 @@ def test_from_rlssm_dict_cases( expected_choices, expected_learning_process, ): - config = RLSSMConfig.from_rlssm_dict(model_name, config_dict) + config = RLSSMConfig.from_rlssm_dict(config_dict) assert config.model_name == expected_model_name assert config.params_default == expected_params_default assert config.bounds == expected_bounds @@ -164,8 +185,9 @@ class TestRLSSMConfigValidation: [ ("response", None, "Please provide `response` columns"), ("list_params", None, 
"Please provide `list_params"), - ("choices", None, "Please provide `choices"), - ("decision_process", None, "Please specify a `decision_process"), + ("choices", None, "Please provide `choices`"), + ("decision_process", None, "Please specify a `decision_process`"), + ("ssm_logp_func", None, "Please provide `ssm_logp_func`"), ], ) def test_validate_missing_fields( @@ -184,7 +206,7 @@ def test_validate_missing_fields( "params_default", "decision_process", "decision_process_loglik_kind", - "learning_process_loglik_kind", + "learning_process_kind", "learning_process", ], ) @@ -201,17 +223,12 @@ def test_validate_success(self, valid_rlssmconfig_kwargs): config = RLSSMConfig(**valid_rlssmconfig_kwargs) config.validate() - def test_validate_params_default_mismatch(self): + def test_validate_params_default_mismatch(self, valid_rlssmconfig_kwargs): config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, + **{ + **valid_rlssmconfig_kwargs, + "params_default": [0.5], # length 1, but list_params has 2 entries + } ) with pytest.raises( ValueError, @@ -219,21 +236,84 @@ def test_validate_params_default_mismatch(self): ): config.validate() + def test_validate_ssm_logp_func_not_callable(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + config.ssm_logp_func = "not_a_callable" + with pytest.raises(ValueError, match="must be a callable"): + config.validate() + + def test_validate_ssm_logp_func_missing_annotations(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + # Replace with a plain callable that lacks @annotate_function attributes + config.ssm_logp_func = lambda x: x + with pytest.raises( + ValueError, match="must be decorated with `@annotate_function`" + ): + config.validate() 
+ + def test_validate_ssm_logp_func_computed_not_callable( + self, valid_rlssmconfig_kwargs + ): + """`computed` exists but contains non-callable values -> error.""" + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + # Inject a computed mapping with a non-callable value to trigger the + # specific validation branch. + config.ssm_logp_func.computed = {"x": "not_callable"} + with pytest.raises( + ValueError, + match=r"`ssm_logp_func.computed` must be a dictionary with callable values\.", + ): + config.validate() + + def test_validate_missing_bounds_for_param(self, valid_rlssmconfig_kwargs): + """validate() should raise early when a param has no bounds entry.""" + kwargs = {**valid_rlssmconfig_kwargs, "bounds": {}} # strip all bounds + config = RLSSMConfig(**kwargs) + with pytest.raises(ValueError, match="Missing bounds for parameter"): + config.validate() + + def test_from_defaults_raises(self): + """RLSSMConfig.from_defaults() must raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="from_defaults"): + RLSSMConfig.from_defaults("ddm", None) + class TestRLSSMConfigDefaults: @pytest.mark.parametrize( "list_params, params_default, bounds, param, expected_default, expected_bounds", [ + # params_default stores initialisation values, NOT priors. + # get_defaults always returns None for the prior so that + # prior_settings="safe" can assign priors from bounds. + # + # Case 1: queried param is present in bounds → bound returned. ( ["alpha", "beta", "gamma"], [0.5, 0.3, 0.2], - {"beta": (0.0, 1.0)}, + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0), "gamma": (0.0, 1.0)}, "beta", - 0.3, + None, + (0.0, 1.0), + ), + # Case 2: queried param is NOT in list_params and NOT in bounds + # (e.g. a typo or an extra lookup) → both None. + ( + ["alpha", "beta"], + [0.5, 0.3], + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, + "gamma", + None, + None, + ), + # Case 3: params_default may be empty; param in bounds → bound returned. 
+ ( + ["alpha", "beta"], + [], + {"alpha": (0.0, 1.0), "beta": (0.0, 1.0)}, + "alpha", + None, (0.0, 1.0), ), - (["alpha", "beta"], [0.5, 0.3], {"alpha": (0.0, 1.0)}, "gamma", None, None), - (["alpha", "beta"], [], {"alpha": (0.0, 1.0)}, "alpha", None, (0.0, 1.0)), ], ) def test_get_defaults_cases( @@ -254,7 +334,7 @@ def test_get_defaults_cases( response=["rt", "response"], choices=[0, 1], decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", learning_process={}, ) default_val, bounds_val = config.get_defaults(param) @@ -262,145 +342,6 @@ def test_get_defaults_cases( assert bounds_val == expected_bounds -class TestRLSSMConfigConversion: - @pytest.mark.parametrize( - "list_params, params_default, backend, expected_backend, expected_default_priors, raises", - [ - ( - ["alpha", "beta", "v", "a"], - [0.5, 0.3, 1.0, 1.5], - "jax", - "jax", - {"alpha": 0.5, "beta": 0.3, "v": 1.0, "a": 1.5}, - None, - ), - (["alpha"], [0.5], None, "jax", {"alpha": 0.5}, None), - (["alpha", "beta"], [], None, "jax", {}, None), - (["alpha", "beta", "gamma"], [0.5, 0.3], None, None, None, ValueError), - ], - ) - def test_to_config_cases( - self, - list_params, - params_default, - backend, - expected_backend, - expected_default_priors, - raises, - ): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=list_params, - params_default=params_default, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - backend=backend, - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - if raises: - with pytest.raises(raises): - rlssm_config.to_config() - else: - config = rlssm_config.to_config() - assert config.backend == expected_backend - assert config.default_priors == expected_default_priors - - def test_to_config(self): - rlssm_config = RLSSMConfig( - model_name="rlwm", - description="RLWM model", - list_params=["alpha", "beta", "v", 
"a"], - params_default=[0.5, 0.3, 1.0, 1.5], - bounds={ - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - }, - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - extra_fields=["feedback"], - backend="jax", - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = rlssm_config.to_config() - assert isinstance(config, Config) - assert config.model_name == "rlwm" - assert config.description == "RLWM model" - assert config.list_params == ["alpha", "beta", "v", "a"] - assert config.response == ["rt", "response"] - assert config.choices == [0, 1] - assert config.extra_fields == ["feedback"] - assert config.backend == "jax" - assert config.loglik_kind == "approx_differentiable" - assert config.bounds == { - "alpha": (0.0, 1.0), - "beta": (0.0, 1.0), - "v": (-3.0, 3.0), - "a": (0.3, 2.5), - } - assert config.default_priors == { - "alpha": 0.5, - "beta": 0.3, - "v": 1.0, - "a": 1.5, - } - - def test_to_config_defaults_backend(self): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha"], - params_default=[0.5], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = rlssm_config.to_config() - assert config.backend == "jax" - - def test_to_config_no_defaults(self): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta"], - params_default=[], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - config = rlssm_config.to_config() - assert config.default_priors == {} - - def test_to_config_mismatched_defaults_length(self): - rlssm_config = RLSSMConfig( - model_name="test_model", - list_params=["alpha", "beta", 
"gamma"], - params_default=[0.5, 0.3], - decision_process="ddm", - response=["rt", "response"], - choices=[0, 1], - decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", - learning_process={}, - ) - with pytest.raises( - ValueError, - match=r"params_default length \(2\) doesn't match list_params length \(3\)", - ): - rlssm_config.to_config() - - class TestRLSSMConfigLearningProcess: def test_learning_process(self): config = RLSSMConfig( @@ -412,7 +353,7 @@ def test_learning_process(self): choices=[0, 1], learning_process={"v": v_func, "a": a_func}, decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", ) assert "v" in config.learning_process assert "a" in config.learning_process @@ -429,7 +370,7 @@ def test_immutable_defaults(self): choices=[0, 1], learning_process={"v": v_func}, decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", ) config2 = RLSSMConfig( model_name="model2", @@ -440,7 +381,7 @@ def test_immutable_defaults(self): choices=[0, 1], learning_process={"a": a_func}, decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", ) config1.learning_process["v"] = "function1" assert "v" not in config2.learning_process @@ -457,18 +398,40 @@ def test_from_rlssm_dict_missing_required(self): "params_default": [0.0], "decision_process": "ddm", "learning_process": {}, - "learning_process_loglik_kind": "blackbox", + "learning_process_kind": "blackbox", "response": ["rt", "response"], "choices": [0, 1], "description": "desc", "bounds": {}, "data": {}, "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, } with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): - RLSSMConfig.from_rlssm_dict("test_model", config_dict) + RLSSMConfig.from_rlssm_dict(config_dict) + + def 
test_from_rlssm_dict_missing_ssm_logp_func(self): + # Should raise ValueError at construction time if ssm_logp_func is missing + config_dict = { + "model_name": "test_model", + "name": "test_model", + "list_params": ["alpha"], + "params_default": [0.0], + "decision_process": "ddm", + "learning_process": {}, + "learning_process_kind": "blackbox", + "decision_process_loglik_kind": "analytical", + "response": ["rt", "response"], + "choices": [0, 1], + "description": "desc", + "bounds": {}, + "data": {}, + "extra_fields": [], + } + with pytest.raises(ValueError, match="ssm_logp_func must be provided"): + RLSSMConfig.from_rlssm_dict(config_dict) def test_missing_decision_process_loglik_kind(self): with pytest.raises(TypeError): @@ -488,15 +451,16 @@ def test_missing_decision_process_loglik_kind(self): "data": {}, "decision_process": "ddm", "learning_process": {}, - "learning_process_loglik_kind": "blackbox", + "learning_process_kind": "blackbox", "response": ["rt", "response"], "choices": [0, 1], "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, } with pytest.raises( ValueError, match="decision_process_loglik_kind must be provided" ): - RLSSMConfig.from_rlssm_dict("test_model", config_dict) + RLSSMConfig.from_rlssm_dict(config_dict) def test_with_modelconfig_decision_process(self): decision_config = ModelConfig( @@ -512,6 +476,57 @@ def test_with_modelconfig_decision_process(self): response=["rt", "response"], choices=[0, 1], decision_process_loglik_kind="analytical", - learning_process_loglik_kind="blackbox", + learning_process_kind="blackbox", learning_process={}, ) + + +class TestRLSSMConfigDefaultWarnings: + """Warnings are emitted when 'response' or 'choices' are missing from config_dict.""" + + @pytest.fixture + def _base_config_dict(self): + return { + "model_name": "test_model", + "list_params": ["alpha"], + "params_default": [0.0], + "bounds": {"alpha": (0.0, 1.0)}, + "decision_process": "ddm", + "learning_process": {}, + "learning_process_kind": 
"blackbox", + "decision_process_loglik_kind": "analytical", + "extra_fields": [], + "ssm_logp_func": _module_dummy_ssm_logp, + } + + def test_warns_when_response_missing(self, _base_config_dict, caplog): + _base_config_dict["choices"] = (0, 1) + # 'response' deliberately omitted + with caplog.at_level("WARNING", logger="hssm"): + config = RLSSMConfig.from_rlssm_dict(_base_config_dict) + assert any("'response' not specified" in m for m in caplog.messages) + assert config.response == list(DEFAULT_SSM_OBSERVED_DATA) + + def test_warns_when_choices_missing(self, _base_config_dict, caplog): + _base_config_dict["response"] = ["rt", "response"] + # 'choices' deliberately omitted + with caplog.at_level("WARNING", logger="hssm"): + config = RLSSMConfig.from_rlssm_dict(_base_config_dict) + assert any("'choices' not specified" in m for m in caplog.messages) + assert config.choices == DEFAULT_SSM_CHOICES + + def test_warns_when_both_missing(self, _base_config_dict, caplog): + # Both 'response' and 'choices' omitted + with caplog.at_level("WARNING", logger="hssm"): + config = RLSSMConfig.from_rlssm_dict(_base_config_dict) + assert any("'response' not specified" in m for m in caplog.messages) + assert any("'choices' not specified" in m for m in caplog.messages) + assert config.response == list(DEFAULT_SSM_OBSERVED_DATA) + assert config.choices == DEFAULT_SSM_CHOICES + + def test_no_warning_when_both_provided(self, _base_config_dict, caplog): + _base_config_dict["response"] = ["rt", "response"] + _base_config_dict["choices"] = (0, 1) + with caplog.at_level("WARNING", logger="hssm"): + RLSSMConfig.from_rlssm_dict(_base_config_dict) + assert not any("not specified" in m for m in caplog.messages) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index b7b48d57..60244cf7 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -10,7 +10,9 @@ def compare_hssm_class_attributes(model_a, model_b): b = np.array([type(v) for k, v in model_b._init_args.items()]) 
assert (a == b).all(), "Init arg types not the same" assert (model_a.data).equals(model_b.data), "Data not the same" - assert model_a.model_name == model_b.model_name, "Model name not the same" + assert model_a.model_config.model_name == model_b.model_config.model_name, ( + "Model name not the same" + ) assert model_a.pymc_model._repr_latex_() == model_b.pymc_model._repr_latex_(), ( "Latex representation of model not the same" ) From 09bd6bc387611080cf9dbf779d8987da1c1e190b Mon Sep 17 00:00:00 2001 From: AndrewZhang599 Date: Tue, 5 May 2026 16:44:13 -0400 Subject: [PATCH 4/4] updates to plan after RLSSM changes --- addm_andrew_dev /addm_hssm.md | 267 ++++++++++++++++++++++++++-------- 1 file changed, 207 insertions(+), 60 deletions(-) diff --git a/addm_andrew_dev /addm_hssm.md b/addm_andrew_dev /addm_hssm.md index 405a2a84..f4eee15f 100644 --- a/addm_andrew_dev /addm_hssm.md +++ b/addm_andrew_dev /addm_hssm.md @@ -15,11 +15,20 @@ model.sample() where `addm_trial_df` contains standard columns (`rt`, `response`) plus **aDDM-specific per-trial arrays** (item values, fixation onsets, fixation counts, first-fixation flag). The aDDM needs per-trial covariates that are *not* themselves sampled parameters — exactly the pattern RLSSM already solves in HSSM. We therefore follow the RLSSM design so that aDDM lives alongside it rather than carving a new architectural lane. -The intended outcome is a working `model="addm"` path inside HSSM that (a) validates aDDM-specific trial data, (b) composes the vendored JAX FPT likelihood with sampled parameters `{eta, kappa, sigma, a, b, x0, t}` (non-decision time optional), (c) exposes the standard HSSM hierarchical regression and sampling machinery, and (d) ships with a tutorial notebook and unit tests. 
+The intended outcome is a working `aDDM(...)` class (and matching `aDDMConfig`) inside HSSM that (a) validates aDDM-specific trial data, (b) composes the vendored JAX FPT likelihood with sampled parameters `{eta, kappa, sigma, a, b, x0, t}` (non-decision time optional), (c) exposes the standard HSSM hierarchical regression and sampling machinery, and (d) ships with a tutorial notebook and unit tests. -## Design choice: config pattern, not subclass +## Design choice: subclass pattern (matching the new RLSSM architecture) -The plan creates an **`aDDMConfig`** dataclass plus a small submodule — no new `aDDM(HSSM)` subclass is introduced. (If the user prefers an explicit subclass for API discoverability, a thin `class aDDM(HSSM)` wrapper can be added on top.) +> **Architecture update (post-rebase, commit `b4228d1b` — "HSSM base + RLSSM classes" #893).** The HSSM base was refactored: `HSSMBase` is now an abstract base ([src/hssm/base.py](data/azhang/HSSM/src/hssm/base.py)) and both `HSSM` ([src/hssm/hssm.py](data/azhang/HSSM/src/hssm/hssm.py)) and `RLSSM` ([src/hssm/rl/rlssm.py](data/azhang/HSSM/src/hssm/rl/rlssm.py)) are concrete subclasses. `RLSSMConfig` was moved out of `hssm.config` into [src/hssm/rl/config.py](data/azhang/HSSM/src/hssm/rl/config.py) and **no longer exposes a `to_config()` method** — `HSSMBase` accepts any `BaseModelConfig` directly, and the family-specific subclass is responsible for building its own likelihood `Op` and stamping it onto the config via `dataclasses.replace(...)` before calling `super().__init__()`. The new `RLSSMConfig` also adds an `ssm_logp_func` field (an `@annotate_function`-decorated JAX function) and renamed `learning_process_loglik_kind` → `learning_process_kind`. + +Given this new architecture, the plan creates **both**: + +1. An `aDDM(HSSMBase)` concrete subclass — a peer of `HSSM` and `RLSSM`, exported from `hssm/__init__.py` as `hssm.aDDM`. +2. 
An `aDDMConfig(BaseModelConfig)` dataclass living in `src/hssm/addm/config.py`, peer of `RLSSMConfig` in `src/hssm/rl/config.py`. + +The `aDDM` subclass handles (i) validating aDDM-specific data shape, (ii) building the differentiable PyTensor `Op` from the vendored JAX likelihood, (iii) stamping that `Op` onto the config (via `replace(loglik=op, backend="jax")`), and (iv) overriding `_make_model_distribution` to bypass the standard `loglik_kind` dispatching — exactly mirroring `RLSSM.__init__` and `RLSSM._make_model_distribution`. + +(There is **no longer** a `to_config()` round-trip; that method existed only in the pre-rebase architecture and has been removed from `RLSSMConfig` as well.) --- @@ -57,61 +66,175 @@ The plan creates an **`aDDMConfig`** dataclass plus a small submodule — no new ### Step 2 — Define the aDDM submodule layout -Create a new package under `src/hssm/addm/`, mirroring `src/hssm/rl/`: +Create a new package under `src/hssm/addm/`, mirroring the **post-rebase** `src/hssm/rl/` layout (`config.py` and `rlssm.py` at the package root, plus a `likelihoods/` subpackage): ``` src/hssm/addm/ - __init__.py + __init__.py # exports aDDM, aDDMConfig + config.py # aDDMConfig — peer of hssm.rl.config.RLSSMConfig + addm.py # aDDM(HSSMBase) — peer of hssm.rl.rlssm.RLSSM + attention_process.py # pluggable fixation/attention models (default: standard_alternating) + utils.py # validate_addm_panel(...) 
— peer of hssm.rl.utils.validate_balanced_panel likelihoods/ __init__.py - builder.py # make_addm_logp_func / make_addm_logp_op - addm_jax.py # thin wrapper that imports from .jax and applies the attention process - jax/ # vendored from efficient_fpt_jax (Step 1) + builder.py # make_addm_logp_func / make_addm_logp_op + addm_jax.py # thin wrapper that imports from .jax and applies the attention process + jax/ # vendored from efficient_fpt_jax (Step 1) __init__.py multi_stage.py single_stage.py utils.py - attention_process.py # pluggable fixation/attention models ``` -**Rationale:** aDDM is conceptually a two-stage model (attention process → SSM likelihood) just like RLSSM (learning process → SSM likelihood). Reusing the folder layout makes the parallel obvious to future maintainers. The vendored JAX code lives in its own `jax/` subdirectory so it stays clearly identifiable as upstream-derived, isolated from HSSM-original code in `builder.py` and `addm_jax.py`. +**Rationale:** aDDM is conceptually a two-stage model (attention process → SSM likelihood) just like RLSSM (learning process → SSM likelihood). After the rebase, RLSSM is a real `HSSMBase` subclass with its own subpackage; the aDDM subpackage now mirrors that exactly — same files (`config.py`, `.py`, `utils.py`, `likelihoods/builder.py`), same conventions. The vendored JAX code lives in its own `jax/` subdirectory so it stays clearly identifiable as upstream-derived, isolated from HSSM-original code in `builder.py` and `addm_jax.py`. + +### Step 3 — Add `aDDMConfig` dataclass in `src/hssm/addm/config.py` -### Step 3 — Add `aDDMConfig` dataclass in `config.py` +**Critical file:** [src/hssm/addm/config.py](data/azhang/HSSM/src/hssm/addm/config.py) (new file) — peer of [src/hssm/rl/config.py](data/azhang/HSSM/src/hssm/rl/config.py). -**Critical file:** [src/hssm/config.py](data/azhang/HSSM/src/hssm/config.py) — add a new dataclass beneath `RLSSMConfig` (around line 457). 
+> **Note:** `aDDMConfig` is **not** added to `src/hssm/config.py`. After the rebase, family-specific configs live in their own subpackages (`hssm.rl.config.RLSSMConfig`, `hssm.addm.config.aDDMConfig`), with only `BaseModelConfig`, `Config`, and `ModelConfig` remaining in `hssm.config`. ```python +from dataclasses import dataclass, field +from typing import Any, Callable +from ..config import BaseModelConfig, DEFAULT_SSM_OBSERVED_DATA, DEFAULT_SSM_CHOICES + @dataclass class aDDMConfig(BaseModelConfig): """Config for the attentional DDM.""" - model_name: str = "addm" - list_params: list[str] = field( - default_factory=lambda: ["eta", "kappa", "sigma", "a", "b", "x0"] - ) - params_default: list[float] = field( - default_factory=lambda: [0.3, 1.0, 1.0, 2.0, 0.0, 0.0] + + # Required (kw_only) fields — pattern borrowed from RLSSMConfig + params_default: list[float] = field(kw_only=True) + attention_process: str | Callable | dict[str, Any] = field( + kw_only=True, default="standard_alternating" ) - response: list[str] = field(default_factory=lambda: ["rt", "response"]) - choices: tuple[int, ...] = (-1, 1) - # trial-level covariates consumed by the attention process: + + # aDDM-specific extra-field column names (defaultable) + # These are *data* columns, not sampled parameters. extra_fields: list[str] | None = field( default_factory=lambda: ["r1", "r2", "sacc_array", "d", "flag"] ) - bounds: dict[str, tuple[float, float]] = field(default_factory=dict) - loglik_kind: str = "approx_differentiable" - attention_process: str | Callable = "standard_alternating" - description: str | None = "Attentional Drift Diffusion Model" - def to_config(self) -> Config: ... + def __post_init__(self): + if self.loglik_kind is None: + self.loglik_kind = "approx_differentiable" + + @classmethod + def from_defaults(cls, model_name, loglik_kind): + raise NotImplementedError( + "aDDMConfig does not support from_defaults(). " + "Use the aDDM(...) 
constructor directly, or pass an aDDMConfig " + "instance built explicitly." + ) + + def validate(self) -> None: + # Mirror RLSSMConfig.validate: required fields, params_default vs + # list_params length parity, every list_params entry has bounds, etc. + ... + + def get_defaults(self, param): + return None, self.bounds.get(param) +``` + +**Key design decisions (post-rebase):** + +- **No `to_config()` method.** The new architecture has `HSSMBase` accept any `BaseModelConfig`; family-specific subclasses build the `loglik` `Op` themselves and stamp it onto the config via `dataclasses.replace(...)`. `RLSSMConfig` no longer has `to_config()` and neither will `aDDMConfig`. +- **No `from_defaults` registration.** Like `RLSSMConfig`, `aDDMConfig` raises `NotImplementedError` from `from_defaults`. Users construct it explicitly (or via a `from_addm_dict` classmethod, optional). Therefore **`aDDM` is *not* registered through the `default_model_config` / `register_model` pipeline** that `HSSM(model="ddm", ...)` uses — instead, users instantiate `hssm.aDDM(...)` directly. +- `extra_fields` defaults to the five aDDM-specific columns the JAX likelihood needs; these flow through the existing extra-fields machinery (data validator → `Op` `*args`) the same way they do for RLSSM. +- `attention_process` is a pluggable hook (default `"standard_alternating"`) that maps `(r1, r2, flag, eta, kappa) → mu_array_padded` per trial. This mirrors `RLSSMConfig.learning_process` semantically (declarative documentation; the actual callable is resolved by the builder). +- `list_params` covers the sampled parameters: `["eta", "kappa", "sigma", "a", "b", "x0"]`. Non-decision time `t` is deliberately omitted initially. +- `bounds` is required (every `list_params` entry must have an entry), matching `RLSSMConfig.validate`'s post-rebase strictness. 
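The validation rules listed above (length parity between `list_params` and `params_default`, and a bounds entry for every sampled parameter) can be sketched in isolation. This is a minimal, hypothetical stand-in, not the real `aDDMConfig`: the class name `SketchADDMConfig`, the helper `is_valid`, and the bound values are assumptions made for illustration only.

```python
from dataclasses import dataclass, field


@dataclass
class SketchADDMConfig:
    """Hypothetical stand-in for aDDMConfig; it does not inherit from BaseModelConfig."""

    list_params: list[str]
    params_default: list[float]
    bounds: dict[str, tuple[float, float]] = field(default_factory=dict)

    def validate(self) -> None:
        # One default value per sampled parameter.
        if len(self.params_default) != len(self.list_params):
            raise ValueError(
                f"params_default has {len(self.params_default)} entries, "
                f"list_params has {len(self.list_params)}."
            )
        # Every sampled parameter must declare bounds.
        missing = [p for p in self.list_params if p not in self.bounds]
        if missing:
            raise ValueError(f"Missing bounds for: {missing}")


def is_valid(cfg: SketchADDMConfig) -> bool:
    """Return True if validate() passes, False if it raises ValueError."""
    try:
        cfg.validate()
    except ValueError:
        return False
    return True


# Parameter names follow the plan; the bound values are illustrative placeholders.
ok = SketchADDMConfig(
    list_params=["eta", "kappa", "sigma", "a", "b", "x0"],
    params_default=[0.3, 1.0, 1.0, 2.0, 0.0, 0.0],
    bounds={
        "eta": (0.0, 1.0),
        "kappa": (0.0, 5.0),
        "sigma": (0.1, 5.0),
        "a": (0.3, 5.0),
        "b": (-1.0, 1.0),
        "x0": (-1.0, 1.0),
    },
)
```

Calling `validate()` inside `aDDM.__init__` before the `Op` is built, as the plan proposes, means a misconfigured parameter list fails early with a readable message rather than inside the JAX likelihood.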
+ +**Reuses:** `BaseModelConfig` ([src/hssm/config.py:48](data/azhang/HSSM/src/hssm/config.py#L48)), defaults `DEFAULT_SSM_OBSERVED_DATA` / `DEFAULT_SSM_CHOICES` ([src/hssm/config.py:24-26](data/azhang/HSSM/src/hssm/config.py#L24-L26)). + +### Step 3b — Add `aDDM` subclass in `src/hssm/addm/addm.py` + +**Critical file:** [src/hssm/addm/addm.py](data/azhang/HSSM/src/hssm/addm/addm.py) (new file) — peer of [src/hssm/rl/rlssm.py](data/azhang/HSSM/src/hssm/rl/rlssm.py). + +This step is **new in the post-rebase plan.** Mirror `RLSSM` (264 lines): + +```python +from dataclasses import replace +from typing import Any, Callable, Literal, cast +import bambi as bmb +import pandas as pd +import pymc as pm + +from hssm.distribution_utils import make_distribution +from hssm.defaults import INITVAL_JITTER_SETTINGS +from ..base import HSSMBase +from .config import aDDMConfig +from .likelihoods.builder import make_addm_logp_op +from .utils import validate_addm_panel + + +class aDDM(HSSMBase): + def __init__( + self, + data: pd.DataFrame, + model_config: aDDMConfig, + include: list | None = None, + p_outlier: float | dict | bmb.Prior | None = 0.05, + lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0), + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = "safe", + extra_namespace: dict | None = None, + missing_data: bool | float = False, + deadline: bool | str = False, + loglik_missing_data: Callable | None = None, + process_initvals: bool = True, + initval_jitter: float = INITVAL_JITTER_SETTINGS["jitter_epsilon"], + **kwargs, + ): + self._init_args = self._store_init_args(locals(), kwargs) + model_config.validate() + + # Same row-order argument as RLSSM: per-trial sacc_array shape and the + # JAX vmap rely on strict 1:1 row→trial correspondence; reordering + # missing/deadline rows would break the alignment. 
+ if missing_data is not False or deadline is not False: + raise NotImplementedError( + "aDDM does not support `missing_data` or `deadline` handling..." + ) + + validate_addm_panel(data, model_config.extra_fields) + + loglik_op = make_addm_logp_op( + attention_process=model_config.attention_process, + data_cols=list(model_config.response), + list_params=list(model_config.list_params), + extra_fields=list(model_config.extra_fields or []), + ) + + # Stamp the Op + backend onto a fresh config copy (do NOT mutate the + # caller's config). Identical pattern to RLSSM.__init__. + model_config = replace(model_config, loglik=loglik_op, backend="jax") + + super().__init__( + data=data, model_config=model_config, include=include, + p_outlier=p_outlier, lapse=lapse, + link_settings=link_settings, prior_settings=prior_settings, + extra_namespace=extra_namespace, + missing_data=missing_data, deadline=deadline, + loglik_missing_data=loglik_missing_data, + process_initvals=process_initvals, initval_jitter=initval_jitter, + **kwargs, + ) + + def _make_model_distribution(self) -> type[pm.Distribution]: + # Mirror RLSSM._make_model_distribution: bypass loglik_kind dispatching + # and feed the pre-built Op directly into make_distribution. + ... ``` **Key design decisions:** -- `extra_fields` defaults to the five aDDM-specific columns that the JAX likelihood needs. These are **not** sampled parameters — they come from the data. -- `attention_process` is a pluggable hook (default `"standard_alternating"`) that maps `(r1, r2, flag, eta, kappa) → mu_array_padded` per trial. This mirrors `RLSSMConfig.learning_process`. -- `list_params` covers the sampled parameters. Non-decision time `t` is deliberately omitted initially; it can be added later via shifted RTs. -- `.to_config()` builds a standard HSSM `Config` pointing at the new likelihood op from Step 4, so downstream `HSSM.__init__` behavior is unchanged. 
-**Reuses:** `BaseModelConfig` (config.py:48), `Config.from_defaults` registration flow (config.py:96–145), `register_model` (register.py:16–60). +- **`aDDM` is a peer of `HSSM` and `RLSSM`**, all three inheriting from `HSSMBase`. It is exported from `hssm/__init__.py` as `hssm.aDDM`. +- **`_make_model_distribution` is overridden** (same as `RLSSM`) because the aDDM `Op` already encapsulates the attention process + per-trial vmap; the standard `make_likelihood_callable` dispatch on `loglik_kind` should be bypassed. +- **No `participant_col`-style panel reshape** — unlike RLSSM (which reshapes rows into `(n_participants, n_trials, ...)` because the RL learning rule is per-subject), aDDM's likelihood is per-trial. The vmap inside the JAX `Op` handles the trial dimension; participants flow through bambi/HSSM hierarchical regression as usual. +- **`missing_data` / `deadline` are rejected up front**, same as RLSSM, because rearranging rows would break the strict trial→`sacc_array`-row correspondence. + +**Reuses:** `HSSMBase` ([src/hssm/base.py:92](data/azhang/HSSM/src/hssm/base.py#L92)), `make_distribution` ([src/hssm/distribution_utils](data/azhang/HSSM/src/hssm/distribution_utils)), `INITVAL_JITTER_SETTINGS` (defaults.py). ### Step 4 — Build the likelihood op in `addm/likelihoods/builder.py` @@ -149,15 +272,20 @@ mu2 = kappa * (eta * r1 - r2) This reproduces the logic in [efficient-fpt addm.py:_build_mu_data_padded](data/azhang/efficient-fpt/src/efficient_fpt/addm.py) but in JAX for autodiff. Expose it via a registry so future variants (e.g., bias, drift offsets) can be registered by name — the same way `RLSSMConfig.learning_process` accepts either a string or dict. 
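The name-based registry described above could look roughly like the following. This is a plain-Python sketch under several assumptions: the registry name `ATTENTION_PROCESSES`, the decorator `register_attention_process`, and the resolver `resolve_attention_process` are hypothetical; the extra `n_stages` argument (padded number of fixation stages) is added for illustration; `flag == 0` is assumed to mean the first fixation is on item 1; and the form of `mu1` is the usual aDDM counterpart of the `mu2` expression in the plan, not something this document states. The real implementation would be written with `jax.numpy` so that it stays differentiable.

```python
from typing import Callable

# Hypothetical registry mapping a process name to its callable.
ATTENTION_PROCESSES: dict[str, Callable[..., list[float]]] = {}


def register_attention_process(name: str):
    """Decorator that stores an attention process under `name`."""

    def decorator(fn):
        ATTENTION_PROCESSES[name] = fn
        return fn

    return decorator


@register_attention_process("standard_alternating")
def standard_alternating(r1, r2, flag, eta, kappa, n_stages):
    """Return per-stage drift rates for fixations that alternate between items."""
    mu1 = kappa * (r1 - eta * r2)  # drift while fixating item 1 (assumed form)
    mu2 = kappa * (eta * r1 - r2)  # drift while fixating item 2 (as in the plan)
    # Assumed convention: flag == 0 means the first fixation is on item 1.
    first, second = (mu1, mu2) if flag == 0 else (mu2, mu1)
    return [first if k % 2 == 0 else second for k in range(n_stages)]


def resolve_attention_process(spec):
    """Accept either a registered name or a user-supplied callable."""
    if callable(spec):
        return spec
    if spec not in ATTENTION_PROCESSES:
        raise ValueError(f"Unknown attention process: {spec!r}")
    return ATTENTION_PROCESSES[spec]


process = resolve_attention_process("standard_alternating")
drifts = process(3.0, 1.0, 0, 0.5, 2.0, 4)
```

Resolving the process once in the builder, rather than per trial, keeps the string-or-callable dispatch outside the function that would later be traced and vmapped by JAX.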
-### Step 6 — Register `"addm"` as a built-in model +### Step 6 — Export `aDDM` and `aDDMConfig` from the top-level package **Critical files:** -- [src/hssm/modelconfig/](data/azhang/HSSM/src/hssm/modelconfig/) — add `addm_config.py` in the same style as the existing per-model configs (e.g., `ddm_config.py`). -- [src/hssm/defaults.py](data/azhang/HSSM/src/hssm/defaults.py) — register `"addm"` in the default model list so `hssm.HSSM(model="addm", ...)` works out of the box. +- `src/hssm/addm/__init__.py` (new) — re-export `aDDM` and `aDDMConfig`, mirroring [src/hssm/rl/__init__.py](data/azhang/HSSM/src/hssm/rl/__init__.py). +- [src/hssm/__init__.py](data/azhang/HSSM/src/hssm/__init__.py) — add `from .addm import aDDM` (line 22 already has `from .rl import RLSSM`) and add `"aDDM"` to `__all__`. -`addm_config.py` returns a dict with `response`, `list_params`, `choices`, `description`, and a `likelihoods` sub-dict keyed `"approx_differentiable"` whose `loglik` points to `make_addm_logp_op(...)` from Step 4 and `extra_fields=["r1","r2","sacc_array","d","flag"]`. +> **Architecture update:** `aDDM` follows the **`RLSSM` registration pattern, not the `HSSM(model=...)` pattern**. Like `RLSSM`, it is *not* added to `default_model_config` and *not* registered through `register_model`. Users instantiate it directly: +> ```python +> import hssm +> model = hssm.aDDM(data=addm_trial_df, model_config=cfg, ...) +> ``` +> No `src/hssm/modelconfig/addm_config.py` is created — that directory holds dicts consumed by `register_model` for the analytical/ONNX-based `HSSM(model="ddm", ...)` flow, which doesn't apply to subclass-based families like RLSSM and aDDM. -**Reuses:** `register_model` (register.py:16–60) — already handles the registration flow; we just need to pass the right dict. +**Reuses:** `hssm/__init__.py` re-export pattern (see existing `RLSSM` import at line 22). 
### Step 7 — Data validation for aDDM-specific columns @@ -174,12 +302,21 @@ Minimally invasive: put the hook in `_post_check_data_sanity` and no-op for non- ### Step 8 — Tests -New file: `tests/test_addm_config.py`, patterned after [tests/test_rlssm_config.py](data/azhang/HSSM/tests/test_rlssm_config.py): +Create a `tests/addm/` directory mirroring [tests/rl/](data/azhang/HSSM/tests/rl/): + +- `tests/addm/test_addm_config.py` — patterned after [tests/test_rlssm_config.py](data/azhang/HSSM/tests/test_rlssm_config.py). +- `tests/addm/test_addm.py` — patterned after [tests/test_rlssm.py](data/azhang/HSSM/tests/test_rlssm.py). +- `tests/addm/test_addm_builder_output_shape.py` — patterned after [tests/test_rl_builder_output_shape.py](data/azhang/HSSM/tests/test_rl_builder_output_shape.py). +- `tests/addm/test_addm_likelihood.py` — patterned after [tests/test_rldm_likelihood.py](data/azhang/HSSM/tests/test_rldm_likelihood.py). + +Test classes: + +1. `TestaDDMConfigCreation` — build `aDDMConfig`, assert defaults, assert `validate()` raises on missing/inconsistent fields. +2. `TestaDDMConfigFromDefaults` — assert `aDDMConfig.from_defaults(...)` raises `NotImplementedError` (mirrors RLSSMConfig). +3. `TestaDDMLikelihood` — tiny synthetic dataset (10 trials), confirm `logp` is finite, gradient w.r.t. each parameter is finite, matches a direct call to the vendored `get_addm_fptd_jax_fast`. +4. `TestaDDMEndToEnd` — 200-trial synthetic dataset, `hssm.aDDM(data=..., model_config=...)` builds, a single MCMC draw succeeds (smoke test, `draws=5, tune=5`). -1. `TestaDDMConfigCreation` — build `aDDMConfig`, assert defaults. -2. `TestaDDMConfigConversion` — `.to_config()` round-trip. -3. `TestaDDMLikelihood` — tiny synthetic dataset (10 trials), confirm `logp` is finite, gradient w.r.t. each parameter is finite, matches a direct call to `get_addm_fptd_jax_fast`. -4. 
`TestaDDMEndToEnd` — 200-trial synthetic dataset, `hssm.HSSM(model="addm", ...)` builds, a single MCMC draw succeeds (smoke test, `draws=5, tune=5`). +> Removed: the obsolete `TestaDDMConfigConversion` (`.to_config()` round-trip) — there is no `to_config()` method in the post-rebase architecture. **Reuse:** test fixtures from `tests/conftest.py`. @@ -187,7 +324,7 @@ New file: `tests/test_addm_config.py`, patterned after [tests/test_rlssm_config. Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tutorials/rlssm_tutorial.ipynb`: - Load/simulate a small aDDM dataset (reuse `simulate_addm` from efficient-fpt example6). -- Build the HSSM model with `model="addm"`. +- Build the model with `hssm.aDDM(data=..., model_config=aDDMConfig(...))` — **not** `hssm.HSSM(model="addm", ...)`, since aDDM follows the RLSSM subclass pattern. - Add a hierarchical regression on `eta` (e.g., by participant) to showcase why using HSSM buys more than raw efficient-fpt. - Run `model.sample()` and plot posteriors via `arviz`. 
@@ -200,14 +337,19 @@ Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tut ## Files to be created -**HSSM-original code:** -- `src/hssm/addm/__init__.py` -- `src/hssm/addm/attention_process.py` +**HSSM-original code (mirrors `src/hssm/rl/` layout post-rebase):** +- `src/hssm/addm/__init__.py` — re-exports `aDDM`, `aDDMConfig` +- `src/hssm/addm/config.py` — `aDDMConfig` dataclass *(peer of `hssm/rl/config.py`)* +- `src/hssm/addm/addm.py` — `aDDM(HSSMBase)` class *(peer of `hssm/rl/rlssm.py`)* +- `src/hssm/addm/utils.py` — `validate_addm_panel` *(peer of `hssm/rl/utils.py`)* +- `src/hssm/addm/attention_process.py` — `standard_alternating` and registry - `src/hssm/addm/likelihoods/__init__.py` -- `src/hssm/addm/likelihoods/builder.py` -- `src/hssm/addm/likelihoods/addm_jax.py` -- `src/hssm/modelconfig/addm_config.py` -- `tests/test_addm_config.py` +- `src/hssm/addm/likelihoods/builder.py` — `make_addm_logp_func`, `make_addm_logp_op` +- `src/hssm/addm/likelihoods/addm_jax.py` — thin wrapper composing attention process + vendored JAX likelihood +- `tests/addm/test_addm_config.py` +- `tests/addm/test_addm.py` +- `tests/addm/test_addm_builder_output_shape.py` +- `tests/addm/test_addm_likelihood.py` - `docs/tutorials/addm_tutorial.ipynb` **Vendored from efficient-fpt (verbatim copies, kept in their own subpackage):** @@ -219,12 +361,14 @@ Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tut ## Files to be modified -- `src/hssm/config.py` — add `aDDMConfig` dataclass. -- `src/hssm/defaults.py` — register `"addm"` in the default model list. +- `src/hssm/__init__.py` — `from .addm import aDDM` and add `"aDDM"` to `__all__` (mirrors the existing `RLSSM` import on line 22). - `src/hssm/data_validator.py` — add aDDM column-shape validation hook. - `README.md`, `mkdocs.yml` — mention the new model. -(`pyproject.toml` is **not** modified — no new dependencies are introduced. JAX is already a core dependency.) 
+> **Removed from "files to be modified":** +> - ~~`src/hssm/config.py`~~ — `aDDMConfig` lives in its own subpackage at `src/hssm/addm/config.py`, not in the central `config.py`. (RLSSMConfig was moved out of `hssm.config` in the rebase for the same reason.) +> - ~~`src/hssm/defaults.py`~~ — aDDM is **not** registered in `default_model_config`; users instantiate `hssm.aDDM(...)` directly, same as `hssm.RLSSM(...)`. +> - ~~`pyproject.toml`~~ — no new dependencies (JAX is already core). ## Key functions/utilities to reuse (no re-implementation) @@ -232,9 +376,11 @@ Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tut |---|---| | JAX FPT likelihood | `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` *(vendored)* | | Safe padding of saccade arrays | `hssm.addm.likelihoods.jax.pad_sacc_array_safely` *(vendored)* | -| Likelihood op wrapping pattern | [hssm/rl/likelihoods/builder.py](data/azhang/HSSM/src/hssm/rl/likelihoods/builder.py) | -| Config → standard Config conversion | [config.RLSSMConfig.to_config](data/azhang/HSSM/src/hssm/config.py#L408) | -| Model registration | [register.register_model](data/azhang/HSSM/src/hssm/register.py#L16) | +| Likelihood `Op` builder pattern | [hssm/rl/likelihoods/builder.py](data/azhang/HSSM/src/hssm/rl/likelihoods/builder.py) | +| Subclass `__init__` / `_make_model_distribution` pattern | [hssm/rl/rlssm.py](data/azhang/HSSM/src/hssm/rl/rlssm.py) | +| Family-specific config dataclass pattern | [hssm/rl/config.RLSSMConfig](data/azhang/HSSM/src/hssm/rl/config.py) | +| Abstract base for HSSM model classes | [hssm/base.HSSMBase](data/azhang/HSSM/src/hssm/base.py#L92) | +| `make_distribution` (consumes pre-built loglik `Op`) | `hssm.distribution_utils.make_distribution` | | Extra-fields propagation into logp | [data_validator.DataValidatorMixin._update_extra_fields](data/azhang/HSSM/src/hssm/data_validator.py#L156) | | Param bound enforcement | 
[distribution_utils.dist.apply_param_bounds_to_loglik](data/azhang/HSSM/src/hssm/distribution_utils/dist.py#L40) | @@ -244,15 +390,16 @@ Create `docs/tutorials/addm_tutorial.ipynb` mirroring the structure of `docs/tut End-to-end checks, in order: -1. **Unit**: `pytest tests/test_addm_config.py -v` — all four test classes pass, including finite-gradient check. -2. **Likelihood parity**: in `test_addm_config.py::TestaDDMLikelihood`, assert HSSM's wrapped op returns the same value (to 1e-6) as a direct call to the vendored `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` on a 10-trial fixture. This confirms the HSSM extra-fields/op-wrapping plumbing does not corrupt the underlying JAX computation. (A separate, off-CI sanity script may also compare against an installed `efficient-fpt` checkout to detect drift between the vendored copy and upstream.) -3. **Smoke sample**: `hssm.HSSM(model="addm", data=synthetic_trials).sample(draws=5, tune=5)` completes without error and returns an `InferenceData`. -4. **Parameter recovery**: larger off-CI script (e.g., `tests/scripts/addm_recovery.py`) — simulate 1000 trials with known `(eta, kappa, a, b, x0, sigma)`, fit in HSSM, confirm posterior means within ~2σ of ground truth. Reuse the recovery setup from [efficient-fpt example8_empirical/parameter_recovery.ipynb](data/azhang/efficient-fpt/examples/example8_empirical). +1. **Unit**: `pytest tests/addm/ -v` — all test files pass, including finite-gradient check. +2. **Likelihood parity**: in `tests/addm/test_addm_likelihood.py`, assert HSSM's wrapped op returns the same value (to 1e-6) as a direct call to the vendored `hssm.addm.likelihoods.jax.get_addm_fptd_jax_fast` on a 10-trial fixture. This confirms the HSSM extra-fields/op-wrapping plumbing does not corrupt the underlying JAX computation. (A separate, off-CI sanity script may also compare against an installed `efficient-fpt` checkout to detect drift between the vendored copy and upstream.) +3. 
**Smoke sample**: `hssm.aDDM(data=synthetic_trials, model_config=cfg).sample(draws=5, tune=5)` completes without error and returns an `InferenceData`. +4. **Parameter recovery**: larger off-CI script (e.g., `tests/scripts/addm_recovery.py`) — simulate 1000 trials with known `(eta, kappa, a, b, x0, sigma)`, fit via `hssm.aDDM(...)`, confirm posterior means within ~2σ of ground truth. Reuse the recovery setup from [efficient-fpt example8_empirical/parameter_recovery.ipynb](data/azhang/efficient-fpt/examples/example8_empirical). 5. **Tutorial runs clean**: `jupyter nbconvert --execute docs/tutorials/addm_tutorial.ipynb` finishes without errors. 6. **Docs build**: `mkdocs build` succeeds with the new tutorial in nav. ## Open questions for the user -1. **Subclass vs config-only**: confirm the config-pattern approach (no `class aDDM(HSSM)`) is acceptable, or whether a thin `hssm.aDDM` convenience class is desired on top. +1. ~~**Subclass vs config-only**~~ — *resolved by the rebase.* The new architecture (`HSSMBase` + `RLSSM` subclass) settles this: aDDM follows the subclass pattern, with both `aDDM(HSSMBase)` and `aDDMConfig(BaseModelConfig)` introduced. 2. **Non-decision time `t`**: include in v1 as an additional sampled parameter (shift RTs), or defer? 3. **Attention-process extensibility**: is the default `standard_alternating` enough, or should v1 already expose user-pluggable attention processes (e.g., non-alternating fixation patterns)? +4. **Hierarchical regression target**: should the tutorial demonstrate hierarchical regression on `eta` (attention bias) by participant, or is a different parameter (e.g., `kappa`) more meaningful as a worked example?