# Training and evaluating multiple models in parallel

Often we need to train multiple copies of the same model in parallel. This is easy in Feedbax, thanks to automatic vectorization with `jax.vmap`.

!!! NOTE     
    TODO.
    Training model replicates is handled automatically if you're using a pre-built model-task pairing. If you want to train 3 replicates, pass `n_replicates=3`.

Let's see an example of training multiple replicates of the same model that differ only in their random initializations. We'll start with a function that builds our model, like the one we defined in an earlier [example](/feedbax/examples/1_train/#building-the-model-ourselves-using-core-feedbax).

TODO: just use `point_mass_nn`, and mention that we could also write our own `get_model` as in 1_train

In [27]:
import equinox as eqx
import jax

from feedbax.bodies import SimpleFeedback
from feedbax.iterate import Iterator
from feedbax.mechanics.mechanics import Mechanics
from feedbax.mechanics.plant import DirectForceInput
from feedbax.mechanics.skeleton.pointmass import PointMass
from feedbax.networks import SimpleStagedNetwork


def get_model(
    task,
    dt: float = 0.05, 
    mass: float = 1., 
    hidden_size: int = 50, 
    n_steps: int = 100, 
    *,  
    key: jax.Array,
):   
    key1, key2 = jax.random.split(key)

    plant = DirectForceInput(PointMass(mass=mass))
    mechanics = Mechanics(plant, dt)
    
    feedback_spec = dict(
        where=lambda state: (
            state.effector.pos,
            state.effector.vel,
        ),
        delay=0,
        noise_std=0.0,
    )
    
    # Determine the network input size automatically
    input_size = SimpleFeedback.get_nn_input_size(
        task, mechanics, feedback_spec
    )
    
    net = SimpleStagedNetwork(
        input_size,
        hidden_size,
        hidden_type=eqx.nn.GRUCell,
        out_size=plant.input_size,
        key=key1
    )
    
    body = SimpleFeedback(net, mechanics, feedback_spec=feedback_spec, key=key2)
    
    return Iterator(body, n_steps)

We'll also need a task to train our models on.

In [28]:
from feedbax.task import SimpleReaches

from feedbax.xabdeef.losses import simple_reach_loss


task = SimpleReaches(
    loss_func=simple_reach_loss(),
    workspace=((-1., -1.),  # ((x_min, y_min), (x_max, y_max))
               (1., 1.)), 
    n_steps=100,
)

We could build a single model, like we've done before.

In [29]:
key_init, key_train, key_eval = jax.random.split(jax.random.PRNGKey(0), 3)

model = get_model(task, key=key_init)

To build $N$ replicates of the model, we need to:

1) Get a different random key for each of the model replicates, so they'll be initialized differently.
2) Partially evaluate `get_model` so that it's only a function of the random key. In other words, we'll pass all the arguments that are the same across all model replicates, leaving only the random key to be passed.
3) Use `jax.vmap` to obtain a vectorized version of the partially evaluated `get_model`, then pass it all the keys at once to obtain all the model replicates at once.
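
The three steps can be sketched in plain JAX. This is only an illustration: `init_params` is a hypothetical toy stand-in for `get_model`, and the sizes are arbitrary.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_params(in_size, out_size, *, key):
    """Hypothetical stand-in for `get_model`: a dict of randomly initialized arrays."""
    key_w, key_b = jax.random.split(key)
    return {
        "weight": jax.random.normal(key_w, (out_size, in_size)),
        "bias": jax.random.normal(key_b, (out_size,)),
    }


n_replicates = 5

# 1) One random key per replicate.
keys = jax.random.split(jax.random.PRNGKey(0), n_replicates)

# 2) Fix the arguments that are shared by all replicates.
init_from_key = partial(init_params, 8, 2)

# 3) Vectorize over the keys. The key is keyword-only, so pass it by keyword.
ensemble = jax.vmap(lambda key: init_from_key(key=key))(keys)

# Each leaf now has a leading axis of size `n_replicates`.
print(jax.tree_util.tree_map(jnp.shape, ensemble))
```

Because every replicate receives a different key, the slices along the leading axis hold different initial values.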

Feedbax provides the function `get_ensemble` to perform these steps. We could just import it (`from feedbax import get_ensemble`), but let's see what the source looks like, with a couple of variables renamed to match the current situation.

In [30]:
from collections.abc import Callable
from functools import partial

def get_ensemble(
    get_model: Callable,
    n_replicates: int,
    *args,
    key: jax.Array,
    **kwargs,
):
    """Helper to vmap model generation over a set of random keys."""
    keys = jax.random.split(key, n_replicates)
    # Fix the arguments shared by all replicates.
    get_model_ = partial(get_model, *args, **kwargs)
    # `get_model` takes `key` as a keyword-only argument, so pass it by keyword.
    return eqx.filter_vmap(lambda key: get_model_(key=key))(keys)

!!! NOTE
    If you run into problems with `jax.vmap`, try using Equinox's `filter_vmap` as we've done above. It does the same thing, but by default it vectorizes only over the array leaves of a PyTree and passes other leaves through unchanged.

Now, let's get 5 model replicates.

In [36]:
n_replicates = 5

models = get_ensemble(
    get_model,
    n_replicates,
    task,
    dt=0.05,
    mass=1.,
    hidden_size=50,
    n_steps=100,
    key=key_init,
)

The replicated models are all stored in a single object. Consider the single model we constructed earlier:

In [37]:
model

Iterator(
  step=SimpleFeedback(
    net=SimpleStagedNetwork(
      out_size=2,
      hidden=GRUCell(
        weight_ih=f32[150,8],
        weight_hh=f32[150,50],
        bias=f32[150],
        bias_n=f32[50],
        input_size=8,
        hidden_size=50,
        use_bias=True
      ),
      hidden_size=50,
      hidden_noise_std=None,
      hidden_nonlinearity=None,
      encoder=None,
      encoding_size=None,
      readout=Linear(
        weight=f32[2,50],
        bias=f32[2],
        in_features=50,
        out_features=2,
        use_bias=True
      ),
      out_nonlinearity=<function <lambda>>,
      intervenors={'hidden': [], 'readout': []}
    ),
    mechanics=Mechanics(
      plant=DirectForceInput(
        skeleton=PointMass(mass=1.0),
        clip_states=True,
        intervenors={'clip_skeleton_state': []},
        muscle_model=None
      ),
      dt=0.05,
      solver=Euler(),
      intervenors={
        'convert_effector_force':
        [],
        'statics_step':
       

When we use `vmap` to construct multiple models at once, the result is essentially the same: a single object. However, each parameter in the [PyTree](/feedbax/examples/2_pytrees/) of the model now has a new leading dimension of size 5. 

This is very nice because it means 1) similar parameters are all stored together, and 2) overhead is kept to a minimum because we don't (say) construct a list with 5 model objects in it, each with their own copies of an identical set of methods.
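
To make the extra leading dimension concrete, here is a minimal sketch in plain JAX. The dict returned by the hypothetical `init_params` stands in for a model PyTree; it is not a Feedbax model.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    """Hypothetical stand-in for a model: a PyTree (here, a dict) of arrays."""
    return {"weight": jax.random.normal(key, (2, 50)), "bias": jnp.zeros(2)}


key = jax.random.PRNGKey(0)
single = init_params(key)
ensemble = jax.vmap(init_params)(jax.random.split(key, 5))

# Same tree structure, but every array leaf gains a leading axis of size 5.
for name in single:
    print(name, single[name].shape, "->", ensemble[name].shape)
```

The tree structure is unchanged; only the shapes of the array leaves differ.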

In [None]:
models

To train all these models at once, we can pass them to a `TaskTrainer` instance along with the argument `ensembled=True`.

In [None]:
import optax

from feedbax.trainer import TaskTrainer


trainer = TaskTrainer(
    optimizer=optax.inject_hyperparams(optax.adam)(
        learning_rate=1e-2
    ),
    checkpointing=True,
)

# Train only the GRU's weights and bias; the readout is not trained.
where_train = lambda model: (
    model.step.net.hidden.weight_hh, 
    model.step.net.hidden.weight_ih, 
    model.step.net.hidden.bias
)

models, train_history = trainer(
    task=task, 
    model=models,
    ensembled=True,
    n_batches=1000, 
    batch_size=250, 
    log_step=100,
    where_train=where_train,
    key=key_train,
)
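
Conceptually, training an ensemble amounts to vectorizing a single-model update step over the replicate axis. The following is a minimal sketch of that idea in plain JAX, with a toy linear model and plain gradient descent. It is not how `TaskTrainer` is implemented; `loss_fn`, `sgd_step`, and the choice to share one batch across replicates are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a toy linear model."""
    pred = x @ params["weight"] + params["bias"]
    return jnp.mean((pred - y) ** 2)


def sgd_step(params, x, y, lr=0.1):
    """One gradient-descent update for a single replicate."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


n_replicates = 5
key_init, key_data = jax.random.split(jax.random.PRNGKey(0))

# Stacked parameters: each leaf has a leading replicate axis.
params = {
    "weight": jax.random.normal(key_init, (n_replicates, 3)),
    "bias": jnp.zeros((n_replicates,)),
}

# In this sketch, every replicate sees the same batch of data.
x = jax.random.normal(key_data, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])

# Map over the replicate axis of `params`; broadcast the data.
ensemble_step = jax.jit(jax.vmap(sgd_step, in_axes=(0, None, None)))

losses = []
for _ in range(100):
    params, loss = ensemble_step(params, x, y)
    losses.append(loss)

losses = jnp.stack(losses)  # shape (n_steps, n_replicates)
```

Each call to `ensemble_step` updates all replicates at once and returns one loss value per replicate.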

We can then plot the mean and standard deviation of each loss term's history across the replicates:

In [None]:
from feedbax.plot import plot_mean_losses

plot_mean_losses(train_history)
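
The statistics themselves are simple reductions over the replicate axis. A minimal sketch, assuming a loss history stored as an array of shape `(n_batches, n_replicates)`; the array here is synthetic, and the actual layout of `train_history` may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss history with shape (n_batches, n_replicates).
n_batches, n_replicates = 1000, 5
noise = jax.random.normal(jax.random.PRNGKey(0), (n_batches, n_replicates))
losses = jnp.exp(-jnp.arange(n_batches)[:, None] / 200.0) * (1.0 + 0.1 * noise)

# Reduce over the replicate axis to get one curve per statistic.
mean_loss = losses.mean(axis=1)
std_loss = losses.std(axis=1)
```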

To evaluate the whole ensemble at once:

In [None]:
states = task.eval_ensemble(models, n_replicates)

Indexing a single model out of the ensemble using `tree_get_idx`
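
The underlying idea can be sketched with `jax.tree_util.tree_map`. The helper `get_replicate` below is hypothetical and only illustrates the operation; it is not the implementation of `tree_get_idx`, and the ensemble is a toy dict of arrays.

```python
import jax
import jax.numpy as jnp


def get_replicate(tree, idx):
    """Hypothetical helper: take index `idx` along the leading axis of every array leaf."""
    return jax.tree_util.tree_map(lambda x: x[idx], tree)


ensemble = {
    "weight": jnp.arange(5 * 2 * 3, dtype=jnp.float32).reshape(5, 2, 3),
    "bias": jnp.zeros((5, 2)),
}

# Pull out the third replicate; the leading axis disappears from every leaf.
third = get_replicate(ensemble, 2)
```

This sketch assumes every leaf is an array. A model whose PyTree also contains non-array leaves, such as functions, would need those leaves to be skipped.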

Example of how to make multiple plots at once