Skip to content

Plan for aDDM HSSM integration#953

Open
AndrewZhang599 wants to merge 4 commits into
mainfrom
addm-integration
Open

Plan for aDDM HSSM integration#953
AndrewZhang599 wants to merge 4 commits into
mainfrom
addm-integration

Conversation

@AndrewZhang599
Copy link
Copy Markdown
Collaborator

Pull request with initial plan (in Markdown) for aDDM integration (Claude-aided). The general idea is to first copy over the JAX-implemented likelihood algorithm and Gaussian-Legendre constants from the following repository: (https://github.com/RiverFlowsInYou98/efficient-fpt)

Then, mirror the RLSSM config setup, including parameters (reward values, order of drift arrays, saccade array, and dimension of saccade array) as extra_fields.

One thing I'm not entirely certain how to incorporate: including nondecision time in some form for the parameter list.

AlexanderFengler and others added 3 commits May 5, 2026 16:45
* Copy hssm.py to prepare for base class extraction

* Extract HSSM base class to hssmbase.py with refactorings and add tests

* refactor: extract init args logic to _get_init_args static method

* refactor: reorganize initialization of input data and configuration in HSSM class

* refactor: enhance comment clarity for model_config construction in HSSM class

* refactor: improve handling of user-provided model_config and choices in HSSM class

* refactor: implement model_config construction in a dedicated method

* refactor: remove slow marker from multiple test functions in test_hssmbase

* refactor: streamline model_config validation and enhance shortcut setup in HSSM class

* refactor: enhance type annotation for model_config and add validation for list_params in HSSM class

* refactor: replace DataValidator with DataValidatorMixin in HSSM and related tests

* refactor: remove unused import of bambi in test_hssmbase

* refactor: remove unused import of typing and simplify SupportedModels check in HSSM class

* refactor: simplify sample_prior_predictive calls in test_sample_prior_predictive

* refactor: correct typo in comments regarding inconsistent dimensions and coordinates

* refactor: remove unused variable assignments in test_sample_prior_predictive

* refactor: remove unused variable assignment in test_sample_prior_predictive

* refactor: remove redundant assignment in sample_prior_predictive test

* refactor: simplify HSSM instantiation in custom model tests

* refactor: implement parameter initialization in DataValidatorMixin

* refactor: assign HSSM instance to variable in test_custom_model

* refactor: enhance parameter initialization in DataValidatorMixin and add response handling in HSSM

* refactor: update parameter types in DataValidatorMixin constructor

* refactor: assign HSSM instance to variable in test_custom_model

* refactor: handle None response in response_c and response_str properties

* refactor: simplify docstring in DataValidatorMixin class

* refactor: remove unused variables in test_sample_prior_predictive

* fix: correct typo in classproperty docstring

* refactor: update condition to check for None in _update_extra_fields method

* refactor: remove unused initialization arguments and related method from HSSM class

* rename hssmbase.py to base.py

* refactor: rename HSSM class to HSSMBase for clarity and consistency

* refactor: replace HSSM with HSSMBase in test cases for consistency

* fix: update load_model and state restoration methods to reference HSSMBase instead of HSSM

* Make config a class variable

* refactor: migrate missing data tests from test_data_validator.py to test_missing_data_mixin.py

* test: add parameterized test for handling missing data as bool and float

* test: add warning handling for dropping rows when missing_data is False

* test: add error handling for invalid missing_data types in MissingDataMixin

* test: add tests for deadline handling in MissingDataMixin

* test: add additional tests for custom missing data handling and deadline logic in MissingDataMixin

* test: refactor tests in MissingDataMixin to use dummy_model fixture for consistency

* test: enhance DummyModel and fixtures for improved missing data and deadline handling

* feat: integrate MissingDataMixin into HSSM class for enhanced data handling

* refactor: move _handle_missing_data_and_deadline method missing data mixin

* feat: implement MissingDataMixin for comprehensive handling of missing data and deadlines

* feat: extend HSSMBase class with MissingDataMixin for improved data handling

* fix: resolve mypy type checking issues in MissingDataMixin for deadline handling

* test: mark test_sample_prior_predictive as expected to fail in CI

* fix: add missing newline for improved readability in test_hssmbase.py

* refactor: replace explicit choices validation with method call

* refactor: improve missing data handling and update tests for edge cases

* refactor: update tests for MissingDataMixin to handle missing data scenarios

* fix: add type ignore for choices length calculation in HSSMBase

* test: add comprehensive tests for MissingDataMixin's missing data handling

* refactor: streamline missing data and deadline handling using MissingDataMixin

* fix: remove uncessary check

* refactor: simplify network assignment logic in MissingDataMixin

* fix: remove unnecessary initialization of network in MissingDataMixin

* refactor: update test structure and improve parameterization in MissingDataMixin tests

* refactor: organize code sections with region markers in HSSMBase class

* refactor: add region markers for clarity in HSSMBase class methods

* feat: make HSSMBase an abstract class and define abstract method for model distribution

* feat: refactor HSSM class to inherit from HSSMBase and remove mixins

* fix: move data sanity check to the correct position in HSSMBase class

* Implement feature X to enhance user experience and fix bug Y in module Z

* test: remove obsolete test_hssmbase.py file

* refactor: clean up imports in hssm.py for better readability

* fix: update prior type hint in fill_defaults and from_defaults methods to include bmb.Prior

* fix: update fill_defaults method to include bmb.Prior type hint for prior parameter

* fix: add type ignore comments for model.list_params and DefaultParam.from_defaults parameters

* fix: update fill_defaults method to include bmb.Prior type hint for prior parameter

* fix: replace assertions with ValueError for loglik and list_params validation in HSSM class

* refactor: remove unused imports from base.py

* fix: update error message for missing list_params in HSSM initialization

* fix: add validation for loglik_kind in HSSM class initialization

* refactor: update comment style for clarity in _make_model_distribution method

* fix: handle None values for response and choices in HSSMBase initialization

* fix: streamline exception handling for missing list_params in HSSM initialization

* Restore init args so tests pass

* fix: update instance creation in HSSMBase to use class reference

* refactor: remove extra _set_missing_data_and_deadline method from DataValidatorMixin

* refactor: rename test class for clarity in missing data handling

* fix: update exception message regex for list_params validation in HSSM

* fix: improve error message for unspecified bounds in _make_default_prior function

* fix: ensure model_name is retrieved correctly in RLSSMConfig initialization

* fix: remove 'data' field from RLSSM_REQUIRED_FIELDS

* Use base in HSSM class

* Cast choices to list

* Fix response assertion in test_from_defaults to use list instead of tuple

* Refactor HSSM class to improve parameter handling in likelihood and distribution functions

* Update response assertions in test_from_defaults to use lists instead of tuples

* Restore hssm.py as in main

* Restore param

* Restore params

* Restore regression_params

* Restore simple param

* Restore test_hsmm

* Fix base for dimensionality problems

* Fix mypy bugs

* Remove duplicate comment regarding Bambi's kind parameter renaming

* Fix RLSSMConfig to require model_name in config_dict

* Update docstrings in HSSMBase for clarity on initial values and return types

* Fix line too long

* Add ssm_logp_func to RLSSMConfig and update validation tests

* Add RLSSM model and utilities for reinforcement learning integration

* Refactor RLSSM parameter handling and add custom prefix resolution for RL parameters

* Add tests for RLSSM class covering initialization, validation, and model structure

* Refactor loglik handling in RLSSM to improve type safety with casting

* Add NaN value check for participant column in validate_balanced_panel function

* Add validation for ssm_logp_func in RLSSMConfig to ensure it is callable and has required attributes

* Add exclude rules for ruff and mypy hooks to skip tests directory

* Add validation tests for ssm_logp_func in RLSSMConfig to ensure it is callable and properly annotated

* Add tests for NaN participant_id and unannotated ssm_logp_func in RLSSM

* Reject missing data and deadline handling in RLSSM initialization to preserve trial sequence integrity

* Add tests to validate error handling for missing data and deadline in RLSSM initialization

* Refactor path handling for loading RLDM fixture dataset in tests

* Add fixture to set floatX to float32 for module tests

* Ensure params_is_trialwise aligns with list_params in RLSSM initialization

* Clarify comments on default_priors in ModelConfig and remove unnecessary assertion for list_params

* Update RLSSM to use to_numpy(copy=True) for extra_fields and add test for independent copies

* Refactor parameter name resolution in RLSSM to handle underscores correctly and improve substring checks

* Add test for _get_prefix method in RLSSM to ensure token-based matching

* Refactor RLSSMConfig.from_rlssm_dict to remove model_name parameter and update tests accordingly

* Fix comment in test_rlssm.py to clarify output shape of log-likelihood function

* Update RLSSMConfig documentation to mark description as required

* Add ssm_logp_func to RLSSM_REQUIRED_FIELDS and update RLSSMConfig initialization

* Add dummy ssm_logp_func to tests and validate its presence in RLSSMConfig

* Remove unused logging import from rlssm.py

* Remove redundant exclude rule for ruff-format in pre-commit configuration

* Add to_model_config method to RLSSMConfig for ModelConfig conversion

* Refactor RLSSM to delegate ModelConfig construction to RLSSMConfig and simplify Op parameter handling

* Integrate Config and RLSSMConfig into HSSM and RLSSM classes for improved configuration handling

* Update choices type from list to tuple for consistency in BaseModelConfig and DataValidatorMixin

* Update choices type from list to tuple in test_constructor for consistency

* Add deprecation warnings for model_config attributes in HSSMBase

* Refactor HSSMBase to support BaseModelConfig and improve model_config handling

* Add model configuration building methods to BaseModelConfig and Config classes

* Refactor model configuration handling in HSSMBase and HSSM classes to delegate config building and improve attribute access

* Add properties to BaseModelConfig for parameter and extra field counts

* Refactor RLSSM attributes to use public naming convention for configuration and participant/trial counts

* Refactor test_rlssm_panel_attrs to use public attributes for participant and trial counts

* Refactor HSSMBase to streamline model configuration handling and update initialization parameters

* Refactor BaseModelConfig and RLSSMConfig by removing unused abstract methods and adding a new method for building validated Config instances

* Refactor HSSM class to remove Config inheritance and add initialization parameters for model configuration

* Refactor RLSSM class to remove RLSSMConfig inheritance and streamline model configuration handling

* Refactor Config and RLSSMConfig classes to use concrete types in method signatures

* Update Config class parameter types for choices to improve type safety

* Update choices method to accept a tuple for model_config.choices

* Add tests for model configuration handling and choices logic in Config

* Enhance HSSMBase initialization with safe default for constructor arguments and explicit error handling for missing snapshot

* Update model_config validation to check for non-null choices

* Refactor HSSM distribution method to use typed model_config attributes and avoid deprecated proxy properties

* Update test cases to use tuples for choices in model configuration

* Refactor RLSSM to utilize model_config for list_params and loglik, enhancing type safety and validation

* Fix typo in comment regarding model_config choices validation

* Refactor RLSSM tests to access model configuration attributes directly, ensuring consistency with updated model_config structure

* Update attribute comparison in compare_hssm_class_attributes to use model_config for model_name

* Update test assertions to access model configuration attributes directly

* Refactor model configuration normalization to streamline choices handling and improve logging

* Refactor choices handling in Config class to improve clarity and logging

* Refactor _normalize_model_config_with_choices to improve input handling and choices normalization

* Refactor likelihood callable construction to simplify logic and enhance clarity

* Refactor _make_model_distribution to utilize model_config for loglik and loglik_kind

* Fix formatting in HSSM class for consistency in likelihood callable parameters

* Fix formatting in HSSM class for consistency in likelihood callable parameters

* Refactor HSSM class to use typed model_config attributes directly and resolve loglik

* Restore make_model_dist in HSSM

* Remove deprecated properties and methods from HSSMBase class

* Enhance HSSMBase class to prevent overwriting _init_args if already set in subclasses and exclude additional internal names from locals() snapshots during re-instantiation.

* Clarify model_config parameter documentation in HSSMBase class to specify required fields and improve readability.

* Enhance HSSMBase class documentation to clarify filtering of internal names in parameter mapping for safe unpickling.

* Update model_config parameter documentation in HSSM class to support BaseModelConfig instance and clarify usage of dict for configuration.

* Add test to validate external model config fallback in _build_model_config

* Update sampling parameters in test_rlssm_sample_smoke for speed

* Add RLSSM quickstart notebook for model instantiation and sampling demonstration

* Add RLSSM Quickstart tutorial to navigation and plugins

* Remove redundant next steps and streamline summary in RLSSM quickstart notebook

* Refactor RLSSMConfig methods to simplify parameter handling and remove unused conversion tests

* Fix handling of list_params in HSSMBase to ensure proper conversion from None

* Refactor RLSSM to inject model configuration directly, removing unnecessary Config conversion

* Update TestRLSSMConfigDefaults to reflect None for default parameters instead of fixed values

* Refactor RLSSM to inject loglik and backend directly into a new RLSSMConfig instance, preserving the original configuration.

* Add validation for missing bounds in RLSSMConfig parameters

* Fix RLSSM to use model_config for ssm_logp_func and update test cases for default parameter bounds

* Enhance RLSSM tests to align params_is_trialwise with list_params and add pickle round-trip verification

* Add test to ensure RLSSMConfig.from_defaults raises NotImplementedError

* Clarify RLSSMConfig.from_defaults behavior and raise NotImplementedError for unsupported usage

* Inject JAX backend into RLSSMConfig during initialization

* Refactor RLSSM class to use model_config instead of rlssm_config for consistency

* Fix merge conflicts with base branch

* Remove commented out lines

* Remove RLSSMConfig import from __init__.py

* Reorganize import statements by moving RLSSMConfig import to the correct position

* Move RLSSMConfig import to the correct module in test files

* Update docstring in __init__.py and exports

* Remove RLSSMConfig class and its associated methods from config.py

* Move RLSSMConfig class hssm.rl module

* Refactor config.py to remove RLSSM-specific defaults and unify observed data constants

* Fix formatting of error messages in TestRLSSMConfigValidation for consistency

* Enhance validation in RLSSMConfig for ssm_logp_func attributes

* Add validation test for non-callable values in ssm_logp_func.computed

* Rename 'learning_process_loglik_kind' to 'learning_process_kind' in RLSSMConfig and related tests

* Simplify response and list_params assignment in HSSMBase by removing conditional checks

* Revert "Simplify response and list_params assignment in HSSMBase by removing conditional checks"

This reverts commit 7cf8bca.

* Refactor RLSSMConfig to dynamically retrieve required fields for validation

* Update RLSSMConfig to handle field exceptions in from_rlssm_dict method

* Fix import path for RLSSM and RLSSMConfig; correct learning_process_loglik_kind key in RLSSMConfig; update model instantiation parameter name

* Fix merge conflicts

* Fix instantiation of HSSMBase in __setstate__ method to use class reference

* Fix type hints

* Merge base branch cp-main-sb and add empty data check to validate_balanced_panel

Agent-Logs-Url: https://github.com/lnccbrown/HSSM/sessions/a86b0208-9ac2-480e-a585-320d6a0e1bbe

Co-authored-by: cpaniaguam <68481491+cpaniaguam@users.noreply.github.com>

* Fix ordering of empty-data check and use explicit None check in is_choice_only

Agent-Logs-Url: https://github.com/lnccbrown/HSSM/sessions/a86b0208-9ac2-480e-a585-320d6a0e1bbe

Co-authored-by: cpaniaguam <68481491+cpaniaguam@users.noreply.github.com>

* Update RLSSM to raise NotImplementedError for unsupported missing_data and deadline handling

* Revert "Fix ordering of empty-data check and use explicit None check in is_choice_only"

This reverts commit f30186b.

* Revert "Merge base branch cp-main-sb and add empty data check to validate_balanced_panel"

This reverts commit 38ffc8d.

* Revert "Merge remote-tracking branch 'origin/cp-main-sb' into rlssm-class-make-model-dist"

This reverts commit c696125, reversing
changes made to 1233cd7.

* Update tests to raise NotImplementedError for unsupported missing_data and deadline handling

Co-authored-by: Copilot <copilot@github.com>

* Update precommit

* Enhance test_rlssm_get_prefix to validate fallback for unknown parameters

* Move custom _get_prefix method to base

* Add validation for contiguous participant rows in validate_balanced_panel

Co-authored-by: Copilot <copilot@github.com>

* Update Config class to ignore choices when model_config is None

Co-authored-by: Copilot <copilot@github.com>

* Clarify trialwise parameter handling in RLSSM by updating p_outlier exclusion logic

Co-authored-by: Copilot <copilot@github.com>

* Refactor RLSSMConfig from_rlssm_dict method to derive required fields directly from the dataclass and improve validation for ssm_logp_func

Co-authored-by: Copilot <copilot@github.com>

* Add handling for choice-only models in MissingDataMixin and update test fixture

* Refactor is_choice_only assignment in HSSMBase to directly use model_config

Co-authored-by: Copilot <copilot@github.com>

* Enhance RLSSMConfig to log warnings for missing 'response' and 'choices' in config_dict

Co-authored-by: Copilot <copilot@github.com>

* Enhance RLSSMConfig docstring to detail fields for RLSSM likelihood pipeline

Co-authored-by: Copilot <copilot@github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <copilot@github.com>
@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Copy Markdown
Member

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One high level comment:

I don't understand why so many files are touched here, seems kind of by accident.

Can you take out the plan .md file and maybe turn that into a clean PR that doesn't touch anything else?

I'll still leave my feedback here for now.

Copy link
Copy Markdown
Member

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plan is already quite reasonable.

One more high level comment:
When you address the feedback also please make sure you re-organize the plan around commits for the next iteration. Steps / Stages is the right idea, but we should really force the plan to operate on clean commit boundaries.


The goal is to wire the aDDM into HSSM so that users can write:

```python
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here, arguably, we would go for an ADDM (or even ASSM) class instead.
This would follow the RLSSM pattern.
@krishnbera we might want to keep RLASSM future capability in mind some of the decisions here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RLASSM might be better served if ASSM exists as a model under HSSM class? Ofc, internally ASSM will have a separate branch to handle ASSM-specific things. This would avoid bloating the HSSM package with various interoperability layers (RL <-> SSM, RL <-> ASSM, etc.).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And conceptually speaking, ASSM is just a decision process with a specific kind of covariate after all.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is not much conceptual difference from RLSSM, not adding a new class is probably the best idea, so the focus can just be on writing the decision process function


where `addm_trial_df` contains standard columns (`rt`, `response`) plus **aDDM-specific per-trial arrays** (item values, fixation onsets, fixation counts, first-fixation flag). The aDDM needs per-trial covariates that are *not* themselves sampled parameters — exactly the pattern RLSSM already solves in HSSM. We therefore follow the RLSSM design so that aDDM lives alongside it rather than carving a new architectural lane.

The intended outcome is a working `aDDM(...)` class (and matching `aDDMConfig`) inside HSSM that (a) validates aDDM-specific trial data, (b) composes the vendored JAX FPT likelihood with sampled parameters `{eta, kappa, sigma, a, b, x0, t}` (non-decision time optional), (c) exposes the standard HSSM hierarchical regression and sampling machinery, and (d) ships with a tutorial notebook and unit tests.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's more like it, seems to contradict the earlier statement?!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is confusing


Given this new architecture, the plan creates **both**:

1. An `aDDM(HSSMBase)` concrete subclass — a peer of `HSSM` and `RLSSM`, exported from `hssm/__init__.py` as `hssm.aDDM`.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to name ADDM for camelcase. Slightly confusing why the LLM didn't suggest that out of the box.

High level comment:
Which model are you using here?

5. **Do not** vendor `efficient_fpt_jax/batch.py` (not used by the per-trial likelihood we wrap), nor anything from the `efficient_fpt` (Cython/NumPy) subpackage — that path is the simulator and is not part of inference.
6. **Do not** add `efficient-fpt` to `pyproject.toml`. HSSM already depends on `jax`/`jaxlib`, so the vendored code introduces no new dependencies.

**License/attribution:** efficient-fpt ships under a permissive license (see `efficient-fpt/LICENSE`); copy that license text into `src/hssm/addm/likelihoods/jax/LICENSE` (or `NOTICE`) so attribution travels with the code. If HSSM and efficient-fpt share authors/license, a brief `# Adapted from efficient-fpt (commit <sha>)` header is sufficient.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frankmj note this licensing stuff.
Didn't consider this when working on the addm paper code itself. Any advice on how to handle this?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using the relevant parts as a dependency would solve this(?). See https://github.com/lnccbrown/HSSM/pull/953/changes#r3210200016


**License/attribution:** efficient-fpt ships under a permissive license (see `efficient-fpt/LICENSE`); copy that license text into `src/hssm/addm/likelihoods/jax/LICENSE` (or `NOTICE`) so attribution travels with the code. If HSSM and efficient-fpt share authors/license, a brief `# Adapted from efficient-fpt (commit <sha>)` header is sufficient.

**Rationale:** efficient-fpt is not on PyPI; its Cython compile chain is heavyweight and irrelevant to HSSM (HSSM only needs the inference likelihood, not the simulator); and pinning to a remote git dep would couple HSSM's CI to an unstable upstream. Vendoring lets HSSM ship a frozen, audited copy that evolves on HSSM's release cadence.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AndrewZhang599 here we might want to inject some context around parallel ssm-simulator work.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add this package (and the proposed lightweight dependency) to PyPI?


**Rationale:** efficient-fpt is not on PyPI; its Cython compile chain is heavyweight and irrelevant to HSSM (HSSM only needs the inference likelihood, not the simulator); and pinning to a remote git dep would couple HSSM's CI to an unstable upstream. Vendoring lets HSSM ship a frozen, audited copy that evolves on HSSM's release cadence.

**Drift management:** efficient-fpt continues to be the research home for the likelihood. When upstream changes, the vendored copy can be re-synced by re-copying the three files and rebuilding tests. The upstream-commit header in step 4 makes "what version are we on?" trivially answerable. (This is the same pattern HSSM already uses for `bayesflow` from a dev branch, just snapshotted instead of git-tracked.)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably not necessary tbh


### Step 3 — Add `aDDMConfig` dataclass in `src/hssm/addm/config.py`

**Critical file:** [src/hssm/addm/config.py](data/azhang/HSSM/src/hssm/addm/config.py) (new file) — peer of [src/hssm/rl/config.py](data/azhang/HSSM/src/hssm/rl/config.py).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

genercal comment: These file-paths should not be personalized, so that the plan isn't machine specific.

**Key design decisions (post-rebase):**

- **No `to_config()` method.** The new architecture has `HSSMBase` accept any `BaseModelConfig`; family-specific subclasses build the `loglik` `Op` themselves and stamp it onto the config via `dataclasses.replace(...)`. `RLSSMConfig` no longer has `to_config()` and neither will `aDDMConfig`.
- **No `from_defaults` registration.** Like `RLSSMConfig`, `aDDMConfig` raises `NotImplementedError` from `from_defaults`. Users construct it explicitly (or via a `from_addm_dict` classmethod, optional). Therefore **`aDDM` is *not* registered through the `default_model_config` / `register_model` pipeline** that `HSSM(model="ddm", ...)` uses — instead, users instantiate `hssm.aDDM(...)` directly.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is something that we want to change about RLSSM class as well (@krishnbera).

In fact, iiuc, we DO want the ability to go .from_defaults() and have a pre-registered aDDM.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this registration/name-based referencing will be implemented soon for RLSSM. so it naturally extends here as well.

- **No `to_config()` method.** The new architecture has `HSSMBase` accept any `BaseModelConfig`; family-specific subclasses build the `loglik` `Op` themselves and stamp it onto the config via `dataclasses.replace(...)`. `RLSSMConfig` no longer has `to_config()` and neither will `aDDMConfig`.
- **No `from_defaults` registration.** Like `RLSSMConfig`, `aDDMConfig` raises `NotImplementedError` from `from_defaults`. Users construct it explicitly (or via a `from_addm_dict` classmethod, optional). Therefore **`aDDM` is *not* registered through the `default_model_config` / `register_model` pipeline** that `HSSM(model="ddm", ...)` uses — instead, users instantiate `hssm.aDDM(...)` directly.
- `extra_fields` defaults to the five aDDM-specific columns the JAX likelihood needs; these flow through the existing extra-fields machinery (data validator → `Op` `*args`) the same way they do for RLSSM.
- `attention_process` is a pluggable hook (default `"standard_alternating"`) that maps `(r1, r2, flag, eta, kappa) → mu_array_padded` per trial. This mirrors `RLSSMConfig.learning_process` semantically (declarative documentation; the actual callable is resolved by the builder).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is important, and relates to above.

- **`aDDM` is a peer of `HSSM` and `RLSSM`**, all three inheriting from `HSSMBase`. It is exported from `hssm/__init__.py` as `hssm.aDDM`.
- **`_make_model_distribution` is overridden** (same as `RLSSM`) because the aDDM `Op` already encapsulates the attention process + per-trial vmap; the standard `make_likelihood_callable` dispatch on `loglik_kind` should be bypassed.
- **No `participant_col`-style panel reshape** — unlike RLSSM (which reshapes rows into `(n_participants, n_trials, ...)` because the RL learning rule is per-subject), aDDM's likelihood is per-trial. The vmap inside the JAX `Op` handles the trial dimension; participants flow through bambi/HSSM hierarchical regression as usual.
- **`missing_data` / `deadline` are rejected up front**, same as RLSSM, because rearranging rows would break the strict trial→`sacc_array`-row correspondence.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine as v0, eventually we need to develop a strategy around it.
We need to add an issue to address this later.


The goal is to wire the aDDM into HSSM so that users can write:

```python
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RLASSM might be better served if ASSM exists as a model under HSSM class? Ofc, internally ASSM will have a separate branch to handle ASSM-specific things. This would avoid bloating the HSSM package with various interoperability layers (RL <-> SSM, RL <-> ASSM, etc.).


The goal is to wire the aDDM into HSSM so that users can write:

```python
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And conceptually speaking, ASSM is just a decision process with a specific kind of covariate after all.

Given this new architecture, the plan creates **both**:

1. An `aDDM(HSSMBase)` concrete subclass — a peer of `HSSM` and `RLSSM`, exported from `hssm/__init__.py` as `hssm.aDDM`.
2. An `aDDMConfig(BaseModelConfig)` dataclass living in `src/hssm/addm/config.py`, peer of `RLSSMConfig` in `src/hssm/rl/config.py`.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would a canonical aDDM config look like? Giving/generating an example here would help the coding agents/docs/human implementation.

"instance built explicitly."
)

def validate(self) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can be explicit about the validation checks we need here.

**Key design decisions (post-rebase):**

- **No `to_config()` method.** The new architecture has `HSSMBase` accept any `BaseModelConfig`; family-specific subclasses build the `loglik` `Op` themselves and stamp it onto the config via `dataclasses.replace(...)`. `RLSSMConfig` no longer has `to_config()` and neither will `aDDMConfig`.
- **No `from_defaults` registration.** Like `RLSSMConfig`, `aDDMConfig` raises `NotImplementedError` from `from_defaults`. Users construct it explicitly (or via a `from_addm_dict` classmethod, optional). Therefore **`aDDM` is *not* registered through the `default_model_config` / `register_model` pipeline** that `HSSM(model="ddm", ...)` uses — instead, users instantiate `hssm.aDDM(...)` directly.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this registration/name-based referencing will be implemented soon for RLSSM. so it naturally extends here as well.

- `data[:, 0]` = rt, `data[:, 1]` = response
- `args` are sampled parameters in `list_params` order: `eta, kappa, sigma, a, b, x0`
- extra fields `r1, r2, sacc_array, d, flag` are appended to `args` by the HSSM extra-field machinery (exactly as RLSSM does; see [data_validator.py:156](data/azhang/HSSM/src/hssm/data_validator.py#L156)).
- Internally: call the attention process to build `mu_array_padded`, then call `get_addm_fptd_jax_fast(t=rt, d=d, mu_array=mu_array, sacc_array=sacc_array, sigma, a, b, x0)` vmapped over trials.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexanderFengler @AndrewZhang599
I am not too familiar with aDDM, but would be worthwhile to think about the fundamental dimension for vmap here. Is it possible to do better than trial dim?

data: pd.DataFrame,
model_config: aDDMConfig,
include: list | None = None,
p_outlier: float | dict | bmb.Prior | None = 0.05,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i assume outliers must be a critical aspect of aDDM (i believe there is a measurement component to it unlike the feedback component in RLSSM). are there basic validation checks we can make on this front? for v0, it is ok to ignore this aspect.


The attentional drift diffusion model (aDDM; Krajbich et al.) extends the standard DDM by modulating the drift rate based on **which option the subject is currently fixating**. A fast, differentiable JAX likelihood for the aDDM has been prototyped in the sibling repo [efficient-fpt](data/azhang/efficient-fpt) — specifically `get_addm_fptd_jax_fast` in [src/efficient_fpt_jax/multi_stage.py](data/azhang/efficient-fpt/src/efficient_fpt_jax/multi_stage.py).

**Integration approach: vendor, do not depend.** Rather than add `efficient-fpt` as a dependency (it is not on PyPI, it ships compiled Cython for the simulator path that HSSM does not need, and bringing it in pulls a heavy build chain into HSSM's install), we **copy the relevant pure-JAX modules into HSSM** and own them going forward. HSSM already depends on `jax`/`jaxlib`, so the vendored code adds zero new transitive dependencies. The simulator (Cython) and NumPy/CPU paths from efficient-fpt are *not* vendored; only the `efficient_fpt_jax` subpackage's likelihood code.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who owns/maintains the efficient-fpt package? I think it'd be sensible to split it into two parts, the simulators which are Cython dependent and the rest which is what's needed for this use case. This would avoid passing on the heavy Cython dependency to downstream projects unnecessarily. The new lightweight package could serve as a common dependency for both hssm and efficient_fpt.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prob won't keep efficient-fpt around (it's a repo attached to a methods paper), just extract the core likelihoods into HSSM.

The forward simulation part of efficient-fpt will go into ssm-simulators.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants