Multiple feedback channels and MultiModel #30

Open
mlprt opened this issue Mar 4, 2024 · 0 comments

SimpleFeedback allows for multiple channels of feedback to the neural network, with different delays and noise. For example, one typical configuration is to feed back "proprioceptive" variables (e.g. arm joint configuration, or muscle states) at a short delay, and "visual" variables (e.g. position of the end of the arm) at a longer delay.
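To make the two-channel configuration concrete, here is a hypothetical sketch (not feedbax's actual `ChannelSpec` API; the field names `where`, `delay`, and `noise_std` are illustrative) of proprioceptive feedback at a short delay and visual feedback at a longer one:

```python
# Hypothetical sketch, NOT feedbax's real API: each channel selects some
# variables from the mechanics state and has its own delay and noise level.
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ChannelSpec:
    where: Callable[[Any], Any]  # selects the fed-back variables from the mechanics state
    delay: int                   # feedback delay, in time steps (assumed units)
    noise_std: float             # std of additive Gaussian noise (assumed)


# "Proprioceptive" variables at a short delay; "visual" at a longer one.
feedback_specs = {
    "proprioception": ChannelSpec(
        where=lambda mechanics: mechanics["joint_angles"],
        delay=2,
        noise_std=0.01,
    ),
    "vision": ChannelSpec(
        where=lambda mechanics: mechanics["effector_pos"],
        delay=5,
        noise_std=0.05,
    ),
}
```

A container of specs like this is the kind of `PyTree[ChannelSpec]` the issue describes being supplied at construction time.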

In particular, SimpleFeedback:

  1. has a field `channels: PyTree[Channel]`. At construction time, the user supplies a `PyTree[ChannelSpec]` or a container of mappings; this is used to construct a `PyTree[Channel]`, which in turn is used to construct a `MultiModel`. However, see #3 ("Remove feedbax.channel.ChannelSpec?").
  2. has a model stage `"update_feedback"`:

feedbax/feedbax/bodies.py, lines 175 to 183 at b73dfb8:

```python
"update_feedback": Stage(
    callable=lambda self: self._feedback_module,
    where_input=lambda input, state: jax.tree_map(
        lambda spec: spec.where(state.mechanics),
        self._feedback_specs,
        is_leaf=lambda x: isinstance(x, ChannelSpec),
    ),
    where_state=lambda state: state.feedback,
),
```

MultiModel is a subclass of AbstractModel that has a field models: PyTree[AbstractModel], and expects to be passed state and input whose tree structures match that of models. When called, it maps the models, inputs, and states:

feedbax/feedbax/_model.py, lines 129 to 146 at b73dfb8:

```python
def __call__(
    self,
    input: ModelInput,
    state: PyTree[StateT, "T"],
    key: PRNGKeyArray,
) -> StateT:
    # TODO: This is hacky, because I want to pass intervenor stuff
    # through entirely. See `staged`
    return jax.tree_map(
        lambda model, input_, state, key: model(
            ModelInput(input_, input.intervene), state, key
        ),
        self.models,
        input.input,
        state,
        self._get_keys(key),
        is_leaf=lambda x: isinstance(x, AbstractModel),
    )
```

The PyTree structure of input.input matches models because of the tree_map performed in the definition of where_input for the "update_feedback" stage.

The structure of states matches too, because MultiModel (like all AbstractModel subclasses) provides an init method that returns a PyTree[ChannelState], and this is used to generate any initial state that is passed to the model.
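The mapping pattern can be reduced to a minimal, self-contained sketch: a pytree of "models" (here, plain callables standing in for `AbstractModel` leaves) is mapped over pytrees of inputs and states with matching structure, using `is_leaf` so each model object is treated as a leaf. Names here are illustrative, not feedbax's:

```python
# Minimal sketch of the MultiModel mapping pattern.
import jax
import jax.numpy as jnp


class TinyModel:
    """Stand-in for an AbstractModel leaf: new_state = state + gain * input."""

    def __init__(self, gain):
        self.gain = gain

    def __call__(self, input_, state):
        return state + self.gain * input_


models = {"proprioception": TinyModel(0.5), "vision": TinyModel(2.0)}

# Inputs and states must share the tree structure of `models`;
# the leaves may differ in shape across channels.
inputs = {"proprioception": jnp.ones(3), "vision": jnp.ones(2)}
states = {"proprioception": jnp.zeros(3), "vision": jnp.zeros(2)}

new_states = jax.tree_util.tree_map(
    lambda model, input_, state: model(input_, state),
    models,
    inputs,
    states,
    # Treat each TinyModel as a leaf, so tree_map zips the three pytrees
    # channel-by-channel rather than recursing into the model objects.
    is_leaf=lambda x: isinstance(x, TinyModel),
)
```

If the structures of `inputs` or `states` did not match `models`, the `tree_map` would fail, which is exactly why the `where_input` definition and the `init` method described above are arranged to produce matching pytrees.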

Is there a better way to include a PyTree of similar components in a model, all executed as part of a single model stage? With the current approach, intervenors can be added to individual Channel objects, but it may be inconvenient to refer to those objects (e.g. my_simple_feedback.channels.models['vision']).

I'm not sure how vmapping could be used here, as different channels can carry data of different shapes and dtypes.
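A small demonstration of the obstacle: vmapping over a channel axis would require stacking per-channel data into one array, which fails as soon as channels carry differently shaped data (the shapes below are illustrative):

```python
# Why a channel-axis vmap is awkward here: the per-channel arrays cannot
# be stacked into a single batched array when their shapes differ.
import jax.numpy as jnp

proprio = jnp.zeros(4)  # e.g. four joint/muscle variables (assumed shape)
vision = jnp.zeros(2)   # e.g. a 2D effector position (assumed shape)

mismatch = False
try:
    jnp.stack([proprio, vision])  # the stacking a vmap over channels would need
except Exception:
    mismatch = True  # shapes (4,) and (2,) cannot be stacked
```

This is the usual reason heterogeneous pytrees are handled with `tree_map` rather than `vmap`.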

The use of ModelInput in MultiModel is also not ideal, in particular because I think it makes sense for MultiModel to be a subclass of AbstractModel and not AbstractStagedModel; however, ModelInput is specifically used for carrying intervenor parameters along with other model inputs, and intervenors are associated with instances of AbstractStagedModel, and not with AbstractModel in general. See #12 for a more general discussion of ModelInput.

@mlprt mlprt added the structure label Mar 4, 2024