# Advanced examples

- Advanced PyTree stuff, like in my robustness notebooks
- Training multiple types of models with a single call (i.e. not an ensemble, but different model/task/trainer pairings)
- Tree methods provided by Feedbax, e.g. `tree_map_unzip`

It would be good to point to this notebook as a source of uncertainty when I request feedback. For example, are the tree methods I've written the best way to do what I'm using them for?

In [None]:
from feedbax.loss import simple_reach_loss
from feedbax.task import SimpleReaches


task = SimpleReaches(
    loss_func=simple_reach_loss(n_steps),
    workspace=workspace,
    n_steps=n_steps,
    eval_grid_n=2,
    eval_n_directions=8,
    eval_reach_length=0.5,
)

In [None]:
from feedbax.tree import tree_unzip


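# Each value is a (model, task, trainer) triple, so unzipping yields three
# dicts that share these keys; the keys double as labels.
models, tasks, trainers = tree_unzip(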
    {
        "control": (model, task, trainer),
        "control_hebb": (model, task, trainer_hebb),
        "train_curl": (model_train_curl, task_train_curl, trainer),
        "train_curl_hebb": (model_train_curl, task_train_curl, trainer_hebb),
    }
)
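For reference, here's a minimal sketch of what unzipping a PyTree of tuples involves, using only `jax.tree_util`. The name `tree_unzip_sketch` is hypothetical and this is not Feedbax's implementation: it treats each tuple as a leaf, reads off the outer structure once, and rebuilds one tree per tuple position.

```python
import jax.tree_util as jtu


def tree_unzip_sketch(tree, tuple_cls=tuple):
    """Unzip a PyTree whose leaves are n-tuples into an n-tuple of PyTrees.

    Hypothetical helper for illustration; not Feedbax's `tree_unzip`.
    """
    # Stop flattening at the tuples, so each tuple counts as one leaf.
    tuples, treedef = jtu.tree_flatten(
        tree, is_leaf=lambda x: isinstance(x, tuple_cls)
    )
    if not tuples:
        raise ValueError("tree contains no tuples to unzip")
    n = len(tuples[0])
    if any(len(t) != n for t in tuples):
        raise ValueError("all tuples must have the same length")
    # Rebuild the outer structure once per tuple position.
    return tuple(
        jtu.tree_unflatten(treedef, [t[i] for t in tuples]) for i in range(n)
    )


# Toy stand-in for the {label: (model, task, trainer)} dict above.
pairings = {
    "control": ("model", "task", "trainer"),
    "control_hebb": ("model", "task", "trainer_hebb"),
}
toy_models, toy_tasks, toy_trainers = tree_unzip_sketch(pairings)
```

Because tuples are treated as leaves, this only works if the outer container isn't itself built from tuples; passing a dedicated tuple subclass as `tuple_cls` would avoid that. JAX's built-in `jax.tree_util.tree_transpose` does a similar job, but it requires every inner subtree to have the same structure, which won't generally hold when the tuples contain different models, tasks and trainers.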

In [None]:
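import equinox as eqx
from feedbax.tree import tree_map_unzip


# Train every (model, task, trainer) combination. Each trainer call returns a
# (trained model, training history) pair, which is unzipped into two dicts.
models, train_history = tree_map_unzip(
    lambda model, task, trainer: trainer(
        task=task,
        model=model,
        n_batches=n_batches,
        batch_size=batch_size,
        log_step=n_batches // 4,
        where_train=where_train,
        key=key_train,
    ),
    models, tasks, trainers,
    # Stop at Equinox modules, so each model, task and trainer is passed whole.
    is_leaf=lambda x: isinstance(x, eqx.Module),
)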
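A minimal sketch of the same map-then-unzip pattern in plain `jax.tree_util`, again under a hypothetical name (`tree_map_unzip_sketch`) and with no claim to match Feedbax's internals. It flattens the first tree with the caller's `is_leaf` (above, that's `isinstance(x, eqx.Module)`), flattens the other trees only down to that structure, applies `f` leaf-wise, and rebuilds one tree per output position.

```python
import jax.tree_util as jtu


def tree_map_unzip_sketch(f, tree, *rest, is_leaf=None):
    """Map a tuple-returning `f` over PyTrees, then unzip the results.

    Hypothetical helper for illustration; not Feedbax's `tree_map_unzip`.
    """
    # Leaves of the first tree (respecting `is_leaf`) define the structure.
    leaves, treedef = jtu.tree_flatten(tree, is_leaf=is_leaf)
    # Flatten the other trees only down to that structure.
    rest_leaves = [treedef.flatten_up_to(r) for r in rest]
    results = [f(*xs) for xs in zip(leaves, *rest_leaves)]
    if not results:
        raise ValueError("tree has no leaves to map over")
    # Rebuild one tree per position in the tuples returned by `f`.
    return tuple(
        jtu.tree_unflatten(treedef, [r[i] for r in results])
        for i in range(len(results[0]))
    )


# Toy example: two trees of numbers standing in for models and tasks.
products, sums = tree_map_unzip_sketch(
    lambda x, y: (x * y, x + y),
    {"a": 2, "b": 3},
    {"a": 10, "b": 20},
)
```

Reusing the input structure for the outputs, rather than re-flattening the results and looking for tuples, means tuples inside the inputs can't be confused with the tuples `f` returns.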

- Scheduling interventions with multiple models simultaneously (we could also do this with multiple tasks)

In [None]:
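import jax.numpy as jnp
import jax.random as jr
from feedbax.intervene import CurlField, schedule_intervenor


# Add a curl field to the mechanics of every model in `models`, with
# amplitude +/- `test_curl_abs` (the sign is drawn at random from `key`).
task_test_curl, models_test_curl = schedule_intervenor(
    task_test,
    models,
    CurlField.with_params(
        amplitude=lambda trial_spec, *, key: (
            test_curl_abs * jr.choice(key, jnp.array([-1, 1]))
        ),
        active=True,
    ),
    where=lambda model: model.step.mechanics,
    default_active=True,
)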
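To see what the `amplitude` lambda produces, here's a sketch that evaluates the same expression for a batch of keys with `jax.vmap`. It assumes the lambda is called with a separate key per trial; `toy_curl_abs = 0.5` and the batch of 16 keys are made-up values, and the unused `trial_spec` argument is dropped.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

toy_curl_abs = 0.5  # hypothetical; the notebook uses its own `test_curl_abs`


def sample_amplitude(key):
    # Same expression as the `amplitude` lambda: fixed magnitude, random sign.
    return toy_curl_abs * jr.choice(key, jnp.array([-1, 1]))


# One key per (hypothetical) trial, mapped over with vmap.
trial_keys = jr.split(jr.PRNGKey(0), 16)
amplitudes = jax.vmap(sample_amplitude)(trial_keys)
```

Every entry of `amplitudes` is either `+0.5` or `-0.5`, so each trial gets a curl field of the same strength but a randomly chosen direction.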

Evaluating a bunch of models at once, on the same test task

In [None]:
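# Derive an evaluation key; reusing it for every model means they all see
# the same random draws during evaluation.
key_eval, _ = jr.split(key_train)

losses_test, states = tree_map_unzip(
    lambda model: task_test_curl.eval_with_losses(model, key=key_eval),
    models_test_curl,
    is_leaf=lambda x: isinstance(x, eqx.Module),
)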

In [None]:
from feedbax.plot import plot_pos_vel_force_2D


def eval_plot(states, task, cmap='viridis'):
    trial_specs, _ = task.validation_trials

    return plot_pos_vel_force_2D(
        states,
        # Use the `task` argument, not a global test task.
        step=task.eval_n_directions // 8,
        endpoints=(
            trial_specs.init['mechanics.effector'].pos,
            trial_specs.goal.pos,
        ),
        cmap=cmap,
    )

figs, axs = tree_map_unzip(
    lambda model_states: eval_plot(model_states, task_test_curl),
    states,
    is_leaf=lambda x: isinstance(x, eqx.Module),
)

# `figs` is a dict with the same keys (labels) as the models.
for label, fig in figs.items():
    fig.suptitle(label)