Associate types of intervenors with types of staged models #17

Open
mlprt opened this issue Feb 23, 2024 · 1 comment
Labels
enhancement (New feature or request), typing

Comments

@mlprt
Owner

mlprt commented Feb 23, 2024

Currently, to add a curl force field to a SimpleFeedback model, we need to write something like:

from feedbax.intervene import CurlField, add_intervenor

model_curl = add_intervenor(
    model, 
    CurlField.with_params(amplitude=-0.5),  # negative -> clockwise
    where=lambda m: m.step.mechanics,
)

However, the only part of model: SimpleFeedback to which it makes sense to add a CurlField is model.step.mechanics. And if we were simulating a Mechanics instance directly, instead of wrapping it in SimpleFeedback, it would make sense to add the intervention to model.step.

In principle, Feedbax could recognize this and not require us to specify where, so that we could simply write:

model_curl = add_intervenor(
    model, 
    CurlField.with_params(amplitude=-0.5),
)

A potential solution: type each kind of intervenor by (or associate it with) the subclass of AbstractStagedModel to which it makes sense to add it. When add_intervenor is called, we could then automatically figure out where that model type lives in the tree, and insert the intervenor there, for example with equinox.tree_at.
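A minimal sketch of this idea, in plain Python so it runs without Feedbax or JAX installed. Everything here is hypothetical: the stand-in classes `AbstractStagedModel`, `Loop`, `SimpleFeedback`, `Mechanics`, and `CurlField` only mimic the nesting `model.step.mechanics`, the `target_type` class attribute is an assumed way of associating an intervenor with a model type, and the recursive `set_at` helper plays the role that `equinox.tree_at` would play on a real PyTree.

```python
from dataclasses import dataclass, fields, is_dataclass, replace
from typing import Any


# Hypothetical stand-ins that mimic the nesting `model.step.mechanics`;
# these are NOT Feedbax's actual classes.
@dataclass(frozen=True)
class AbstractStagedModel:
    pass


@dataclass(frozen=True)
class Mechanics(AbstractStagedModel):
    intervenors: tuple = ()


@dataclass(frozen=True)
class SimpleFeedback(AbstractStagedModel):
    mechanics: Mechanics


@dataclass(frozen=True)
class Loop(AbstractStagedModel):  # hypothetical outer wrapper with a `.step`
    step: AbstractStagedModel


@dataclass(frozen=True)
class CurlField:
    amplitude: float
    # Unannotated class attribute (not a dataclass field): the model type
    # this intervenor is associated with.
    target_type = Mechanics


def find_paths(tree: Any, node_type: type, path: tuple = ()) -> list:
    """Return attribute paths to every node of `node_type` in `tree`."""
    found = [path] if isinstance(tree, node_type) else []
    if is_dataclass(tree):
        for f in fields(tree):
            found += find_paths(getattr(tree, f.name), node_type, path + (f.name,))
    return found


def get_at(tree: Any, path: tuple) -> Any:
    """Follow an attribute path down the tree."""
    for name in path:
        tree = getattr(tree, name)
    return tree


def set_at(tree: Any, path: tuple, value: Any) -> Any:
    """Return a copy of `tree` with the node at `path` replaced (like eqx.tree_at)."""
    if not path:
        return value
    head, rest = path[0], path[1:]
    return replace(tree, **{head: set_at(getattr(tree, head), rest, value)})


def add_intervenor(model: Any, intervenor: Any) -> Any:
    """Hypothetical: add `intervenor` to the single node of its target type."""
    paths = find_paths(model, intervenor.target_type)
    if len(paths) != 1:
        raise ValueError(
            f"expected one {intervenor.target_type.__name__}, found {len(paths)}"
        )
    node = get_at(model, paths[0])
    updated = replace(node, intervenors=node.intervenors + (intervenor,))
    return set_at(model, paths[0], updated)


model = Loop(step=SimpleFeedback(mechanics=Mechanics()))
model_curl = add_intervenor(model, CurlField(amplitude=-0.5))
print(model_curl.step.mechanics.intervenors)  # (CurlField(amplitude=-0.5),)
```

Applied to `Loop(step=Mechanics())`, the same call finds the path `("step",)`, which corresponds to adding the intervention to model.step when Mechanics is simulated directly. The original model is left unchanged, since every update returns a new tree.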

The outcome of this operation is ambiguous when the model PyTree contains multiple nodes of the same AbstractStagedModel subclass, or for general interventions like AddNoise, whose associated type would be AbstractStagedModel itself. So there would also need to be some mechanism to detect when there are multiple instances to which the intervention could be added, and perhaps an argument to add_intervenor that determines whether the intervention should be added to all of them, or whether an error should be raised if no disambiguating where has been passed.
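One way the disambiguation could work, sketched in plain Python under stated assumptions: `resolve_targets` and its `on_multiple` argument (`"error"` or `"all"`) are hypothetical names, the model classes are stand-ins rather than Feedbax's, and `where` is simplified to a tuple of attribute names instead of the lambda that add_intervenor actually takes.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Optional


# Hypothetical stand-ins; not Feedbax's actual classes.
@dataclass(frozen=True)
class AbstractStagedModel:
    pass


@dataclass(frozen=True)
class Mechanics(AbstractStagedModel):
    pass


@dataclass(frozen=True)
class Network(AbstractStagedModel):
    pass


@dataclass(frozen=True)
class TwoEffectorModel(AbstractStagedModel):
    """Hypothetical model containing two nodes of the same type."""
    left: Mechanics
    right: Mechanics
    net: Network


def find_paths(tree: Any, node_type: type, path: tuple = ()) -> list:
    """Return attribute paths to every node of `node_type` in `tree`."""
    found = [path] if isinstance(tree, node_type) else []
    if is_dataclass(tree):
        for f in fields(tree):
            found += find_paths(getattr(tree, f.name), node_type, path + (f.name,))
    return found


def resolve_targets(
    model: Any,
    target_type: type,
    where: Optional[tuple] = None,
    on_multiple: str = "error",
) -> list:
    """Hypothetical: decide which node path(s) should receive an intervenor.

    `where` is an explicit path; `on_multiple` is "error" or "all".
    """
    if on_multiple not in ("error", "all"):
        raise ValueError(f"unknown on_multiple={on_multiple!r}")
    if where is not None:
        return [where]  # an explicit location always disambiguates
    paths = find_paths(model, target_type)
    if not paths:
        raise ValueError(f"no {target_type.__name__} found in model")
    if len(paths) > 1 and on_multiple == "error":
        raise ValueError(
            f"{len(paths)} candidate {target_type.__name__} nodes {paths}; "
            "pass `where` or on_multiple='all'"
        )
    return paths


model = TwoEffectorModel(left=Mechanics(), right=Mechanics(), net=Network())
print(resolve_targets(model, Network))  # [('net',)]
print(resolve_targets(model, Mechanics, on_multiple="all"))  # [('left',), ('right',)]
```

For a general intervenor like AddNoise, whose target type would be AbstractStagedModel, every node matches, including the root (path `()`), so under the default `"error"` policy such an intervenor would always need an explicit where.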

mlprt added the enhancement (New feature or request) and typing labels on Feb 23, 2024
@mlprt
Owner Author

mlprt commented Feb 29, 2024

This seems related to #24, but I'm not sure exactly how.

  • A general intervenor like AddNoise may only expect that its input/output both be PyTree[Array]. Many different state types will have fields that satisfy this.
  • A more specific intervenor (e.g. CurlField) is particular to a type of state, e.g. MechanicsState. But what larger state PyTree is MechanicsState a field of? SimpleFeedbackState? We're still left with the problem of locating the MechanicsState inside the full PyTree of model states, or else expecting that the user will provide a where.
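To make the contrast between the two cases concrete, here is a plain-Python sketch in which tuples of floats stand in for JAX arrays and a recursive traversal stands in for `jax.tree_util`. The state classes and the helpers `is_array_like`, `is_pytree_of_arrays`, and `paths_where` are hypothetical; the point is only to count how many locations each kind of intervenor could match.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Callable


# Hypothetical state stand-ins; tuples of floats play the role of JAX arrays.
@dataclass(frozen=True)
class MechanicsState:
    position: tuple
    velocity: tuple


@dataclass(frozen=True)
class NetworkState:
    hidden: tuple


@dataclass(frozen=True)
class SimpleFeedbackState:
    mechanics: MechanicsState
    net: NetworkState


def is_array_like(x: Any) -> bool:
    """Stand-in for 'is an Array': a non-empty tuple of floats."""
    return isinstance(x, tuple) and len(x) > 0 and all(isinstance(v, float) for v in x)


def is_pytree_of_arrays(x: Any) -> bool:
    """Stand-in for PyTree[Array]: an array, or a dataclass whose fields all are."""
    if is_array_like(x):
        return True
    if is_dataclass(x) and not isinstance(x, type):
        return all(is_pytree_of_arrays(getattr(x, f.name)) for f in fields(x))
    return False


def paths_where(tree: Any, predicate: Callable[[Any], bool], path: tuple = ()) -> list:
    """Return attribute paths to every node in `tree` satisfying `predicate`."""
    found = [path] if predicate(tree) else []
    if is_dataclass(tree):
        for f in fields(tree):
            found += paths_where(getattr(tree, f.name), predicate, path + (f.name,))
    return found


state = SimpleFeedbackState(
    mechanics=MechanicsState(position=(0.0, 0.0), velocity=(0.1, -0.2)),
    net=NetworkState(hidden=(0.5, 0.5, 0.5)),
)

# A CurlField-like intervenor keyed on MechanicsState has a single candidate.
print(paths_where(state, lambda x: isinstance(x, MechanicsState)))  # [('mechanics',)]

# An AddNoise-like intervenor that only needs PyTree[Array] matches every subtree.
print(len(paths_where(state, is_pytree_of_arrays)))  # 6
```

Under these assumptions, type-based lookup narrows a CurlField-like intervenor to one location in the state, while the PyTree[Array] signature alone matches the whole state, both sub-states, and every array inside them, which is why a where (or some other disambiguation) seems hard to avoid for general intervenors.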
