Interface to model objects, and Iterator #6

Open
mlprt opened this issue Feb 16, 2024 · 1 comment

Comments


mlprt commented Feb 16, 2024

This issue concerns converting single model steps into iterated models, and how this affects both the model's PyTree structure and references made to its components.

Generally, models (such as SimpleFeedback) are defined as a single iteration, and then wrapped in an Iterator object -- which is more or less a jax.lax loop.
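To make the wrapping concrete, here is a minimal sketch of the pattern, assuming hypothetical names (`SimpleFeedbackStep`, `n_steps`); a plain Python loop stands in for the compiled `jax.lax` loop the real Iterator uses:

```python
class SimpleFeedbackStep:
    """Hypothetical single-step model: one iteration of the loop."""
    def __call__(self, input_, state):
        # e.g. integrate the dynamics for one time step
        return state + input_


class Iterator:
    """Wraps a single model step and applies it over n_steps.

    The real Iterator compiles this loop with jax.lax; a plain
    Python loop is shown here only for clarity.
    """
    def __init__(self, step, n_steps):
        self._step = step
        self.n_steps = n_steps

    @property
    def step(self):
        # Expose the wrapped step, so users refer to model.step.*
        return self._step

    def __call__(self, input_, state):
        states = [state]
        for _ in range(self.n_steps):
            state = self._step(input_, state)
            states.append(state)
        # The returned states gain a time dimension relative to the step
        return states


model = Iterator(SimpleFeedbackStep(), n_steps=3)
states = model(1.0, 0.0)
```

Note how the single-step logic lives entirely inside `model.step`, while the Iterator itself contributes only the time dimension.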

However:

  • Almost the entire model PyTree is under model.step.*, from the user's perspective.
    • For example, when passing a where_train to a TaskTrainer call, they generally need to specify it as lambda model: model.step.net.
    • Similarly, whenever performing model surgery or the like, most references will be to model.step.*.
    • This differs from the structure of the state PyTree. For example, we have model.step.net but states.net. This is because Iterator adds a time dimension to the arrays in states.
  • In certain cases, we might have a model whose top level is not an Iterator, but which we will try to interact with using code that refers to model.step.
    • Currently, all AbstractModels provide a step property, which trivially returns self when the model is not an Iterator. AbstractIterator instead returns self._step, which is the field that the rest of the model PyTree is actually assigned to.
    • Should we be stricter about types, and (say) always assume that TaskTrainer is passed a model wrapped in Iterator?
  • In TaskTrainer._train_step we have to get initial states for the model, for all trials in a batch.
    • We start by vmapping model.init to obtain a default state. Since the input state to an iterated model is the same as to the model step, Iterator.init just returns self.step.init.
    • After _train_step obtains this default state, it modifies parts of it using state initialization data provided for the current batch of training trials (by the AbstractTask object). The model must then be asked to make the state internally consistent.
      • For example, the AbstractTask will typically give an initial position for the effector (e.g. arm endpoint). From the effector position we need to infer and update the mechanical configuration (e.g. joint angles). This is only necessary prior to the first time step for the trials in the batch, after which the states will be internally consistent by virtue of the model's operations. Thus we have a method AbstractStagedModel.state_consistency_update which is called once in _train_step.
      • So, do we add def _state_consistency_update(self): return self.step.state_consistency_update to Iterator similarly to what we've done with init? Currently, _train_step calls model._step.state_consistency_update.
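The delegation pattern under discussion can be sketched as follows. This is illustrative only: the class names mirror those in the issue, but the state fields and the consistency logic (inferring a joint angle from an effector position) are hypothetical stand-ins:

```python
class AbstractModel:
    """Base model: a non-iterated model trivially serves as its own step."""
    @property
    def step(self):
        return self

    def init(self):
        # Default state; field names are hypothetical
        return {"effector_pos": 0.0, "joint_angles": None}

    def state_consistency_update(self, state):
        # Hypothetical stand-in for inferring the mechanical
        # configuration (e.g. joint angles) from the effector position
        state = dict(state)
        state["joint_angles"] = state["effector_pos"] * 0.5
        return state


class Iterator(AbstractModel):
    """Delegates init -- and, per the question above, could also
    delegate state_consistency_update -- to the wrapped step."""
    def __init__(self, step_model):
        self._step = step_model

    @property
    def step(self):
        # The wrapped model is what model.step.* refers to
        return self._step

    def init(self):
        # The input state to an iterated model is the same as to the step
        return self._step.init()

    def state_consistency_update(self, state):
        # Delegation mirroring init, instead of _train_step reaching
        # into model._step directly
        return self._step.state_consistency_update(state)


model = Iterator(AbstractModel())
state = model.init()
state["effector_pos"] = 2.0  # as set from the task's init data
state = model.state_consistency_update(state)
```

With this delegation in place, _train_step could call model.state_consistency_update uniformly, whether or not the model is wrapped in an Iterator.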

I have considered modifying TaskTrainer to handle the model iteration over time, so the user does not explicitly instantiate an Iterator, and can refer to model.* instead of model.step.*. This would make sense in light of AbstractTask providing model inputs as trajectories over time, which Iterator indexes from using tree_take -- such that it does not make sense to use a non-iterated model with TaskTrainer. Should this change be adopted?
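The per-timestep indexing that motivates this is roughly the following. This sketch operates on a plain nested dict for self-containedness; the real tree_take operates on JAX pytrees, and the input trajectories here are invented examples:

```python
def tree_take(tree, i):
    """Index the time dimension of every leaf in a nested-dict 'pytree'.

    Sketch only: the real tree_take works on arbitrary JAX pytrees
    whose array leaves have a leading time dimension.
    """
    if isinstance(tree, dict):
        return {k: tree_take(v, i) for k, v in tree.items()}
    return tree[i]


# AbstractTask supplies inputs as trajectories over time...
inputs = {"target": [0.0, 1.0, 2.0], "feedback": [9.0, 8.0, 7.0]}

# ...and the iteration loop takes the slice for a single time step:
slice_1 = tree_take(inputs, 1)
```

Because the task always supplies trajectories in this shape, a non-iterated model has nothing sensible to do with them, which is the argument for letting TaskTrainer own the iteration.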


mlprt commented Feb 26, 2024

This may be solved when #21 is.

If the model step is no longer a component of an AbstractIterator but is merely passed to its __call__ method, then the iterated model is described as a Tuple[AbstractIterator, AbstractModel] rather than as simply an AbstractIterator which is composed of an AbstractModel.

In that case, TaskTrainer would not need to implement model iteration over time, but could still compose an instance of AbstractIterator so that the user doesn't need to pass around Tuple[AbstractIterator, AbstractModel] when they train the model.
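A minimal sketch of that alternative design, with the step passed to __call__ rather than stored on the iterator (again with a Python loop standing in for the jax.lax loop, and all names hypothetical):

```python
class Iterator:
    """Iterator that does not own the model step; it receives the
    step as an argument on each call."""
    def __init__(self, n_steps):
        self.n_steps = n_steps

    def __call__(self, model, input_, state):
        # The real implementation would use a jax.lax loop here
        for _ in range(self.n_steps):
            state = model(input_, state)
        return state


# The iterated model is then the pair (iterator, model)...
iterator = Iterator(n_steps=3)
step = lambda input_, state: state + input_

# ...which TaskTrainer could compose internally, so the user only
# ever handles the step model itself:
final = iterator(step, 1.0, 0.0)
```

Here the iterator's PyTree contains no model parameters at all, so user references collapse from model.step.* to model.*.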
