# Models with stages

In this example, we'll consider the basic structure of Feedbax models, how that makes it possible to modify them with arbitrary interventions, and how to write models of your own that you can apply interventions to. 

Feedbax models are defined as Equinox `Module` objects, which may be composed of other `Module` objects, forming a nested structure of model components and their parameters. [Ref](/feedbax/examples/pytrees/#equinox)

After constructing a model, we use it by calling it like a function. This works because its class defines a `__call__` method: in Python, giving a class a method with this specific name is the standard way to make its instances callable like functions.
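
Here is a plain-Python sketch, using a hypothetical `Scale` class, of how defining `__call__` makes an instance callable:

```python
class Scale:
    """Hypothetical example: multiplies its input by a fixed factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        # Runs whenever an instance is called like a function
        return self.factor * x


double = Scale(2)
print(double(21))  # 42
print(callable(double))  # True
```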

A lot can happen inside of `__call__`. Consider the following abridged definition of [`feedbax.bodies.SimpleFeedback`][feedbax.bodies.SimpleFeedback], which is a model of a single time step of a neural network sending a command to a mechanical model, based on sensory feedback it receives:

```python
from collections.abc import Callable

import equinox as eqx
import jax.random as jr
from jaxtyping import Array, PRNGKeyArray, PyTree

# Feedbax component and state types (module paths may differ between versions)
from feedbax.channel import Channel, ChannelState
from feedbax.mechanics import Mechanics, MechanicsState
from feedbax.nn import NetworkState


class SimpleFeedbackState(eqx.Module):
    mechanics: MechanicsState
    network: NetworkState
    feedback: ChannelState


class SimpleFeedback(eqx.Module):
    """Model of one step around a feedback loop between a neural network
    and a mechanical model.
    """
    net: eqx.Module
    mechanics: Mechanics
    feedback_channel: Channel
    where_feedback: Callable[[SimpleFeedbackState], PyTree] = \
        lambda state: state.mechanics.plant.skeleton

    def __call__(
        self,
        input: PyTree[Array],
        state: SimpleFeedbackState,
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:

        # A separate key for each component we call
        key1, key2, key3 = jr.split(key, 3)

        feedback_state = self.feedback_channel(
            self.where_feedback(state),
            state.feedback,
            key1,
        )

        network_state = self.net(
            (input, feedback_state.output),
            state.network,
            key2,
        )

        mechanics_state = self.mechanics(
            network_state.output,
            state.mechanics,
            key3,
        )

        return SimpleFeedbackState(
            mechanics=mechanics_state,
            network=network_state,
            feedback=feedback_state,
        )

    # ...
    # We've omitted a bunch of other stuff from this definition!
```

There are several things to notice here:

1. `SimpleFeedback` is an Equinox `Module` subclass, and it's composed of other `Module` objects. `Mechanics` and `Channel` are also `Module`s defined in a similar way, with their own parameters and submodules. This is why we can refer to `model.mechanics.plant.input_size`, for example, when `model` is a `SimpleFeedback` object.
2. `__call__` takes a model state—a `SimpleFeedbackState`—and returns a new one. Like all Feedbax models, it also takes `input`, which is any input to the model in addition to its prior state. In general, the input to a model may be an arbitrary tree structure of data stored in JAX arrays (`PyTree[Array]`).
3. `__call__` contains several steps. Each step involves calling one of the components of the model, which takes its own part of the model state, and returns a new version of that state. Each component also takes some other information as input, in addition to its prior state.     

    === "`self.feedback_channel`"
        
        - takes `state.feedback` (a `ChannelState` object) as its prior state;
        - also takes as input `self.where_feedback(state)`, which is the part of `state` we want to store in the state of the feedback channel, to be retrieved on some later time step, depending on the delay;
        - returns an updated `ChannelState`, which we assign to `feedback_state`.
        
        Note that the default for `self.where_feedback` is `lambda state: state.mechanics.plant.skeleton`, which means that our sensory feedback consists of the full state of the skeleton—typically, the positions and velocities of some joints.
        
    === "`self.net`"
    
        - takes `state.network` (a `NetworkState` object) as prior state;
        - also takes as input `(input, feedback_state.output)`—note that `input` is the entire argument passed to `__call__` itself;
        - returns an updated `NetworkState`, which we assign to `network_state`.
        
        This is the only step in the model that receives the `input` that was passed to `SimpleFeedback` itself. This is because the input to the model is typically information the network needs to complete the task, such as the position of the goal it should reach toward. The input to each of the other model steps is some other part of the model state.
    
    === "`self.mechanics`"

        - takes `state.mechanics` (a `MechanicsState` object) as its prior state;
        - also takes as input `network_state.output`, where `network_state` contains the updated `NetworkState` returned by `self.net`: `network_state.output` is the command we want to send to the mechanical model;
        - returns an updated `MechanicsState`, which we assign to `mechanics_state`.
        
4. After calling all the model steps and getting updated substates, we use them to build a new `SimpleFeedbackState` to return.

What if we want to interfere with the command the neural network generates, after we call `self.net` but before we call `self.mechanics`? We could write a new module with an extra component that operates on `NetworkState`, and call it at the right moment:

```py hl_lines="1 5 30 31"
class SimpleFeedbackPerturbNetworkOutput(eqx.Module):
    net: eqx.Module
    mechanics: Mechanics
    feedback_channel: Channel
    intervention: eqx.Module
    where_feedback: Callable[[SimpleFeedbackState], PyTree] = \
        lambda state: state.mechanics.plant.skeleton

    def __call__(
        self,
        input: PyTree[Array],
        state: SimpleFeedbackState,
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:

        key1, key2, key3 = jr.split(key, 3)

        feedback_state = self.feedback_channel(
            self.where_feedback(state),
            state.feedback,
            key1,
        )

        network_state = self.net(
            (input, feedback_state.output),
            state.network,
            key2,
        )

        # modifies `network_state.output` somehow
        network_state = self.intervention(network_state)

        mechanics_state = self.mechanics(
            network_state.output,
            state.mechanics,
            key3,
        )

        return SimpleFeedbackState(
            mechanics=mechanics_state,
            network=network_state,
            feedback=feedback_state,
        )
```

It would be pretty inconvenient to have to do this every time we want to intervene a little. Once we have a model, it'd be nice to experiment on it quickly. And if we have a different model that is similar enough to `SimpleFeedback` that it could make sense to use the same kind of `NetworkState` intervention on it that we just used, we wouldn't want to have to manually rewrite that model, too. 

Thankfully we can do something about this. Start by noticing that each step in the `__call__` method of our original `SimpleFeedback`:

- is defined as a modification of some part of the model state—each operation we perform returns some part of `SimpleFeedbackState`;
- calls a model component in a consistent way: no matter if we're calling `self.feedback_channel`, `self.net`, or `self.mechanics`, our call always looks like `self.something(input_to_something, state_associated_with_something, key)`.

That means we can define each step in `__call__` with three pieces of information: 

1. What model component to call; e.g. `self.net`;
2. How to select the input to that model component, out of all the `input` and `state` given to `SimpleFeedback`;
3. How to select the state associated with (and modified by) that model component, out of the full `state` of `SimpleFeedback`.

!!! NOTE
    The `key` passed to each model component is no big deal. We just have to be sure to split up the `key` passed to `__call__`, so that each model component gets a different key.
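
For example, with `jax.random.split` (conventionally imported as `jr`, as in the code above), one key becomes one key per component. The variable names here are only illustrative:

```python
import jax.random as jr

key = jr.PRNGKey(0)

# Split one key into one key per model component
key_feedback, key_net, key_mechanics = jr.split(key, 3)

# Different keys give different random draws, e.g. for noise in each component
noise_feedback = jr.normal(key_feedback, (2,))
noise_net = jr.normal(key_net, (2,))
```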
    
OK, let's try to do that. We'll define a class called `ModelStage` that holds the three pieces of information required to define each model stage. Then we'll write a `model_spec` that describes every stage of our model in these terms.

```python
from collections.abc import Callable

import equinox as eqx


class ModelStage(eqx.Module):
    func: Callable  # Given the model, returns the component to call
    where_input: Callable  # Selects the component's input from `input` and `state`
    where_state: Callable  # Selects the component's part of `state`


# Note: `self` is not defined here; this spec only becomes usable once we
# move it inside the model class, as we do below.
model_spec = {
    'update_feedback': ModelStage(
        # See explanation below for why we define this as a lambda!
        func=lambda self: self.feedback_channel,
        where_input=lambda input, state: self.where_feedback(state),
        where_state=lambda state: state.feedback,
    ),
    'net_step': ModelStage(
        func=lambda self: self.net,
        where_input=lambda input, state: (input, state.feedback.output),
        where_state=lambda state: state.network,
    ),
    'mechanics_step': ModelStage(
        func=lambda self: self.mechanics,
        where_input=lambda input, state: state.network.output,
        where_state=lambda state: state.mechanics,
    ),
}
```

!!! NOTE   
    Each of the entries in `ModelStage` is a function, which we define with `lambda`. 
    
    For `where_input` and `where_state`, this is similar to what we've seen in earlier examples. For example, given the `input` and `state` passed to `__call__`, `where_input` is a function that selects which parts will be passed to the component of the current model stage. 
    
    Why do we define `func` as `#!py lambda self: self.something` rather than just `#!py self.something`? It's to make sure that the reference to the component "stays fresh". If this doesn't make sense to you, don't worry about it at this point. Just remember that if you write your own staged models, you will need to write your `model_spec` this way.

In order to insert interventions at arbitrary points, here's what we'll do: 

1. define `model_spec` as an attribute of `SimpleFeedback`;
2. define `__call__` so that it calls each of the entries in `model_spec`, passing them their respective input and state, and using their return value to update the model state.    

    !!! Warning ""    
    
        Importantly, the way we define `__call__` will no longer allow our model stages to assign, or refer, to intermediate variables like `feedback_state`. This is why in the `model_spec` we just defined, the input to `self.net` includes `state.feedback.output`, where previously we had passed `feedback_state.output`.
        
        In our new `__call__`, we'll update `state` *immediately* after each stage, rather than assigning to intermediate variables and then finally constructing a new `SimpleFeedbackState`. Thus any change one stage makes is available to subsequent stages through `state` itself, as with `state.feedback.output`.
        
3. give `SimpleFeedback` a new attribute `intervenors`, where we can insert additional components that intervene on the model's state, keyed by *the name of the model stage they should be applied before*. For example, if this attribute is set to `{'mechanics_step': [some_intervention]}`, then `some_intervention` would be called *immediately before* `self.mechanics` is called.



```python
class SimpleFeedback(eqx.Module):
    net: eqx.Module
    mechanics: Mechanics
    feedback_channel: Channel
    # Maps each stage label to a list of intervenors applied before that stage.
    # (Fields without defaults must come before fields with defaults.)
    intervenors: dict[str, list[eqx.Module]]
    where_feedback: Callable[[SimpleFeedbackState], PyTree] = \
        lambda state: state.mechanics.plant.skeleton

    @property
    def model_spec(self):
        return {
            'update_feedback': ModelStage(
                func=lambda self: self.feedback_channel,
                where_input=lambda input, state: self.where_feedback(state),
                where_state=lambda state: state.feedback,
            ),
            'net_step': ModelStage(
                func=lambda self: self.net,
                where_input=lambda input, state: (input, state.feedback.output),
                where_state=lambda state: state.network,
            ),
            'mechanics_step': ModelStage(
                func=lambda self: self.mechanics,
                where_input=lambda input, state: state.network.output,
                where_state=lambda state: state.mechanics,
            ),
        }

    def __call__(
        self,
        input: PyTree[Array],
        state: SimpleFeedbackState,
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:

        # Get a different key for each stage of the model.
        keys = jr.split(key, len(self.model_spec))

        # Loop through the model stages, pairing them with their keys.
        for (label, stage), key_stage in zip(self.model_spec.items(), keys):

            # Apply any intervenors assigned to this stage (there may be none).
            for intervenor in self.intervenors.get(label, []):
                state = intervenor(state)

            # Look up the stage's component on this model, then call it to get
            # the updated part of the state associated with the stage.
            component = stage.func(self)
            new_component_state = component(
                stage.where_input(input, state),
                stage.where_state(state),
                key_stage,
            )

            # Modify the full model state
            state = eqx.tree_at(
                stage.where_state,  # Part to modify
                state,  # What is modified (full state)
                new_component_state,  # New part to insert
            )

        return state
```

Our model is now structured so that it's possible to [insert interventions](/feedbax/examples/3_intervening/#adding-a-force-field) among its stages, without rewriting the whole thing each time!

The way we've defined `__call__` here is quite general. Actually, the real [`feedbax.bodies.SimpleFeedback`][feedbax.bodies.SimpleFeedback] doesn't define `__call__` itself, but inherits it from [`feedbax.AbstractStagedModel`][feedbax.AbstractStagedModel]. Each staged model, including `SimpleFeedback`, just has to define `model_spec` (and a couple of other smaller things).

Defining models as a sequence of named state operations has some additional advantages, beyond being able to insert interventions among the stages. For one, it makes it easy to [log the details](/feedbax/examples/debugging/#logging-details-of-model-execution) of our model stages as they are executed, which is useful for debugging.

## Pretty printing of model stages

Another advantage of staged models is that it's easy to print out a tree of operations, in the sequence they are performed by the model.

Feedbax provides the function [`pprint_model_spec`][feedbax.pprint_model_spec] for this purpose.

```python
import jax
from feedbax import pprint_model_spec
from feedbax.xabdeef import point_mass_nn_simple_reaches

context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))

pprint_model_spec(context.model.step)
```

Running this prints:

```
update_feedback: MultiModel
nn_step: SimpleStagedNetwork
  hidden: GRUCell
  readout: SimpleStagedNetwork._output
mechanics_step: Mechanics
  convert_effector_force: PointMass.update_state_given_effector_force
  statics_step: DirectForceInput
    clip_skeleton_state: DirectForceInput._clip_state
  dynamics_step: Mechanics._dynamics_step
  get_effector: PointMass.effector
```


Each line corresponds to a call to a model component. When the model component is also a staged model, the stages of that component are indented on the lines that follow.

For example, the `"nn_step"` stage of `SimpleFeedback` (which we called `"net_step"` in our simplified version above) is a call to `self.net`. Here, `self.net` is a `SimpleStagedNetwork`, which is also a subclass of `AbstractStagedModel`. Its own stages are `"hidden"`, which calls an `equinox.nn.GRUCell`, followed by `"readout"`, which calls the `_output` method of the `SimpleStagedNetwork` object.

## Writing a staged model

What components are needed.

Example: similar to `SimpleFeedback`, but with two neural networks with a `Channel` between them?

### Using simple functions as stages

That is, stages that modify part of the state but may not have an associated state of their own.
Wrappers.

### Using non-staged components 

- Using existing neural networks as controllers 
    - And the downside (intervenors)