A major design motivation for Feedbax is the common use case where a researcher wants to intervene on an existing optimal control experiment. In this issue, I describe the approach I've taken to this problem, and my uncertainties about it.
Currently, Feedbax implements the following solution: models are defined by Equinox modules of type `AbstractStagedModel`. Each type of model is treated as a series of operations performed on a shared PyTree of states. All state operations (a.k.a. stages) are defined in a consistent way, each as a collection of three things:

1. a model component to be called;
2. a function that selects the subset of model inputs/states to pass to the component;
3. a function that selects the subset of the state that the component returns/updates.
To define a new staged model, we subclass `AbstractStagedModel` and implement the property `model_spec`, where those three things are defined for each of the model's stages. `AbstractStagedModel` implements `__call__` itself, to perform the state operations defined in `model_spec`. For a more in-depth description, see the documentation.
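To illustrate the pattern, here is a deliberately simplified, hypothetical sketch in plain Python (not actual Feedbax code, and without Equinox): each stage bundles a component with two selector functions, and a generic loop threads the shared state through the stages.

```python
# Hypothetical sketch of the staged-model pattern; names are illustrative,
# not the Feedbax API.
from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable

@dataclass
class Stage:
    component: Callable    # the operation to call
    where_input: Callable  # selects the inputs to pass to the component
    where_state: Callable  # maps the component's result to a state update

def run_stages(model_spec: "OrderedDict[str, Stage]", state: dict) -> dict:
    """Apply each stage's component to its selected slice of the state."""
    for name, stage in model_spec.items():
        result = stage.component(stage.where_input(state))
        state = {**state, **stage.where_state(result)}
    return state

# A toy two-stage "model": double the input, then add a bias.
spec = OrderedDict(
    double=Stage(lambda x: 2 * x, lambda s: s["x"], lambda r: {"y": r}),
    bias=Stage(lambda y: y + 1, lambda s: s["y"], lambda r: {"out": r}),
)
final = run_stages(spec, {"x": 3})  # {"x": 3, "y": 6, "out": 7}
```

The point is that `__call__` (here, `run_stages`) is written once, generically, and individual models only declare *what* their stages are, not *how* to execute them.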
## What kind of PyTree is `model_spec`?
Currently, `model_spec` is defined as a property of type `OrderedDict[str, ModelStage]`. We use a mapping because it's nice for the stages to have names the user can refer to. However, we cannot use a `dict`, because its keys get sorted during PyTree flatten/unflatten operations, even though its entries maintain their insertion order since Python 3.7. `OrderedDict` doesn't have the same problem.
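The sorting behavior is easy to demonstrate with `jax.tree_util`: a plain `dict` flattens in sorted-key order, while an `OrderedDict` flattens in insertion order.

```python
# dict keys are sorted when JAX flattens a pytree; OrderedDict preserves
# insertion order. The keys below are chosen so the two orders differ.
from collections import OrderedDict
import jax

plain = {"step_net": 1, "muscle": 2}                 # inserted: step_net, muscle
ordered = OrderedDict([("step_net", 1), ("muscle", 2)])

leaves_plain = jax.tree_util.tree_leaves(plain)      # sorted keys: muscle, step_net
leaves_ordered = jax.tree_util.tree_leaves(ordered)  # insertion order preserved
```

Since stage order determines execution order, silently reordering the stages on flatten/unflatten would change the model's behavior, which is why `dict` is ruled out.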
`ModelStage` is an Equinox module whose fields describe the "three things" that define a stage. Using a module rather than (say) a tuple makes `model_spec` a little more readable. However, there have been some typing issues with `ModelStage`: Typing ModelStage #23.
## Model state objects
`AbstractStagedModel` is generic, and each of its final subclasses has a type argument that's some final subclass of `equinox.Module`. This is the type of state PyTree operated on by the model. Different staged models may operate on the same type of state object.
A subclass of `AbstractStagedModel` may be composed of other types of `AbstractStagedModel`, in which case the state PyTree associated with the higher-level model tends to be a composite of the state PyTrees associated with the components.
To subclass `AbstractStagedModel` we also have to implement an `init` method, which takes a key and returns a default instance of the model's state PyTree. I refer to this as "default state" rather than "initial state" to distinguish it from the state that has been updated (e.g. placing the arm at its starting position) at the beginning of a trial, based on the specifications provided by a task. See the documentation for a description of how these initial states are specified.
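The composite-state idea and the default-state constructor can be sketched together. The class and field names below are hypothetical (plain dataclasses standing in for Equinox modules), chosen only to show the shape of the relationship.

```python
# Hypothetical sketch: a higher-level model's state composes its components'
# states, and an init-style function builds the "default state".
from dataclasses import dataclass

@dataclass
class MechanicsState:       # state of a mechanics component
    position: float = 0.0
    velocity: float = 0.0

@dataclass
class NetworkState:         # state of a neural-network component
    hidden: float = 0.0

@dataclass
class FeedbackModelState:   # composite: one field per component's state
    mechanics: MechanicsState
    net: NetworkState

def init_feedback_model_state(key=None) -> FeedbackModelState:
    # Default state, prior to any task-specific trial initialization
    # (e.g. placing the arm at its starting position).
    return FeedbackModelState(MechanicsState(), NetworkState())

default = init_feedback_model_state()
```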
Having defined the model's computation as a series of state operations, the user can now insert interventions between the stages of an existing model, without needing to alter its source code. How?
All subclasses of `AbstractStagedModel` must include (#20) a field `intervenors: Mapping[str, Sequence[AbstractIntervenor]]`, which maps from the names of model stages to one or more instances of `AbstractIntervenor`. By performing surgery on this field, we can modify an existing model with interventions. `AbstractStagedModel.__call__` automatically interleaves the state operations defined in `intervenors` with those in `model_spec`.
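The interleaving logic amounts to something like the following stripped-down sketch (again hypothetical, with stages and intervenors reduced to plain state-to-state functions): before executing each named stage, run any intervenors registered under that name.

```python
# Hypothetical sketch of interleaving intervenors with model stages.
from collections import OrderedDict

def run_with_intervenors(model_spec, intervenors, state):
    """Run each stage, preceded by any intervenors registered for it."""
    for name, stage in model_spec.items():
        for intervene in intervenors.get(name, ()):
            state = intervene(state)   # intervenor: state -> state
        state = stage(state)           # stage: state -> state
    return state

# A toy model: a network emits a command, then a plant integrates it.
spec = OrderedDict(
    net=lambda s: {**s, "command": s["obs"] * 2},
    plant=lambda s: {**s, "pos": s["pos"] + s["command"]},
)

# Perturb the command just before the plant stage, without editing the model.
intervenors = {"plant": [lambda s: {**s, "command": s["command"] + 10}]}

out = run_with_intervenors(spec, intervenors, {"obs": 1, "pos": 0})
```

In Feedbax the surgery on the intervenors field would be done functionally (e.g. constructing a modified copy of the model), since Equinox modules are immutable; here a separate mapping stands in for that field.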
For more on what a subclass of `AbstractIntervenor` looks like, see the docs.
## What issues might there be with this approach?
Writing a `model_spec` instead of an imperative `__call__` is probably a little confusing at first.
Any model we want to intervene on using an `AbstractIntervenor` must be (re)written as an `AbstractStagedModel` with a `model_spec`. All the states we expect the user might want to intervene on must be included as fields in the respective state PyTree.
It is possible to use non-staged/vanilla Equinox modules as components, but they will essentially be black boxes that transform one part of their owner's state into another, without including any of their own internal states in the composite PyTree of states. See Include any neural network as a component of a staged model #2.
It's not obvious how to store the history of the outputs of an intervenor. An extra field `intervenors: dict[str, PyTree[Array]]` could be added to the state PyTree of the model it belongs to, into which intervenor "states" could be inserted. However, this might lead to issues with inconsistent PyTree structure.
Is there some other solution that would allow users to insert interventions into arbitrary points in a model, without needing to modify the model's source at intervention time? Perhaps there is a solution with hooks/callbacks that could work, especially if our models were stateful objects like they might be in PyTorch, and if we didn't need to pass around state PyTrees. But I'm not sure a solution like that is desirable in a JAX library, or what it would look like.
Returning now to the general design philosophy: consider that, in principle, there need only be a single, final class `StagedModel` with a single, trivial model stage that on its own does nothing to the state. Any potential subclass of `AbstractStagedModel` we might want to build could be replaced by a constructor that returns instances of this hypothetical `StagedModel`, but with an appropriate sequence of interventions inserted before each instance's single stage. That is, interventions and model stages both define operations on a model's state, and in principle they are interchangeable, though they are (currently) represented differently.
So, which state operations do we include in a model to begin with, and which do we leave to potentially be defined as interventions? That's an important tradeoff our approach leaves us with. I suspect there's no avoiding that problem—no free lunch. The people designing models will always need to rely on their domain expertise not to presume too much, or too little.
This discussion was converted from issue #19 on March 04, 2024 13:39.