ModelInput and passing intervenor parameters to submodels #12

Open
mlprt opened this issue Feb 19, 2024 · 0 comments
mlprt commented Feb 19, 2024

All Feedbax models (subclasses of AbstractModel) have the signature (input: ModelInput, state: StateT, *, key: PRNGKeyArray) -> StateT. StateT was bound to AbstractState in AbstractModel, but is now bound to PyTree[Array] (see #24).

In AbstractStagedModel subclasses (#19) we perform a sequence of state operations by passing subsets of input and state to model components, and using the return values to make out-of-place updates to state.

An AbstractModel is generally a PyTree containing other AbstractModel nodes; i.e. Feedbax models are hierarchical. Typically, the outermost node in the model PyTree is an instance of Iterator, which is essentially a loop over a single step of the model (e.g. a SimpleFeedback instance) where all of the actual state operations happen.

The input to the outermost model node is not selected from the input to another AbstractModel that contains it, because there is none. Instead, its input is provided by an instance of AbstractTask. This task information is any trial-by-trial data that is unconditional on the internal operations of the model. For example, a reaching task like SimpleReaches will provide the model with the goal position it is expected to reach to, and the model will ultimately forward this to the controller (neural network) component.

An issue arises when we need to schedule interventions on a task/model that already exists. Interventions may change on a trial-by-trial basis. Any systematic trial-by-trial variations are specified by an AbstractTask. In particular, if the parameters of an AbstractIntervenor are expected to change across trials, then an AbstractTask should provide those changing parameters as part of the input to the model. The model will then need to make sure that these parameters are matched up to the right instance of AbstractIntervenor.

Perhaps there is a way for schedule_intervenor to work with AbstractTask to structure the intervention parameters so that, at each level of the model, AbstractStagedModel.__call__ can be made to send them on to the right component, until they reach the component that contains the instance of AbstractIntervenor they pertain to. I have not figured out how to do this.

My current solution is, when an intervenor is scheduled with schedule_intervenor, to assign it a string label that is unique among all the intervenors aggregated over all levels of the model PyTree. Intervention parameters are then included in input as a flat mapping from these unique labels to parameters. This flat mapping is passed as-is down through the hierarchy of model components: every AbstractStagedModel sees the same mapping, and simply tries to match the unique labels of its own intervenors with those in the mapping.
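As a minimal sketch of that flat mapping: the labels below ("CurlField_0", "AddNoise_1") and the parameter dicts are hypothetical, and select_own_params is an invented helper that mirrors the matching each staged model performs, not a feedbax function.

```python
# Hypothetical flat mapping from unique intervenor labels to their
# trial-by-trial parameters. In feedbax, labels are assigned by
# `schedule_intervenor` so they are unique across the whole model PyTree.
intervene = {
    "CurlField_0": {"amplitude": 1.5},
    "AddNoise_1": {"scale": 0.01},
}

def select_own_params(own_labels, intervene):
    # Each staged model sees the same full mapping and keeps only the
    # entries whose labels match its own intervenors, ignoring the rest.
    return {label: params for label, params in intervene.items()
            if label in own_labels}
```

Because the mapping is flat and the labels are globally unique, no model level needs to restructure it; unmatched entries are simply ignored at that level.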

This is what ModelInput is for: it's an eqx.Module with two fields, input and intervene: input contains the usual task information which, once it reaches the outermost AbstractStagedModel in the model, is selectively passed on to certain component(s) depending on the definition of model_spec (again, typically it's all sent to the neural network). On the other hand, intervene contains the flat mapping of intervention parameters, and is passed on as-is.
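A minimal stand-in for the two-field structure just described, using a frozen dataclass in place of eqx.Module. The field names follow the issue text; the rewrap helper is an invented name illustrating the reconstruction behavior, not feedbax API.

```python
from dataclasses import dataclass
from typing import Any, Mapping

@dataclass(frozen=True)
class ModelInput:
    input: Any                    # task information, selected stage by stage
    intervene: Mapping[str, Any]  # flat label -> intervenor params, passed as-is

def rewrap(subinput: Any, full: ModelInput) -> ModelInput:
    # A submodel receives only its selected subinput, but the intervene
    # mapping is forwarded unchanged down the hierarchy.
    return ModelInput(subinput, full.intervene)
```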

So, in AbstractStagedModel.__call__ we see something like:

feedbax/feedbax/_staged.py

Lines 152 to 160 in 8f080c6

callable_ = stage.callable(self)
subinput = stage.where_input(input.input, state)
# TODO: What's a less hacky way of doing this?
# I was trying to avoid introducing additional parameters to `AbstractStagedModel.__call__`
if isinstance(callable_, AbstractModel):
    callable_input = ModelInput(subinput, input.intervene)
else:
    callable_input = subinput

Here, we:

  1. Select subinput from input.input: the subset of the model input to pass to the current stage. (I haven't thought of a better name than input; maybe input.task_input.)
  2. If the component to be called is an AbstractModel, it accepts a ModelInput and might contain interventions, so we pass a reconstructed ModelInput with the same intervene value (i.e. the flat mapping) but with only subinput as its input. Otherwise, we pass subinput directly.

This seems pretty hacky to me and I'm not sure how it should be done better. I've considered adding another argument to the signature of AbstractModel, but that doesn't seem better. Also, I suppose I don't have to use ModelInput at all, and could just type input as a tuple.

@mlprt mlprt pinned this issue Feb 19, 2024
@mlprt mlprt changed the title from "Structure of model inputs, and intervenor specifications" to "ModelInput and passing intervenor parameters to submodels" Feb 26, 2024