# From automatic to manual

In the first example we successfully trained a model. However, almost everything was pre-built and we didn't get to see how the model was put together. Let's rebuild the same model, but more explicitly.

First, we'll define the model parameters we might want to vary in this example. 

Of course, we'll also import JAX.

In [1]:
import jax 


mass = 1.0  # Mass of the point mass
n_steps = 100  # Number of time steps per trial
dt = 0.05  # Duration of a time step
feedback_delay_steps = 0  # Number of steps to delay the feedback
feedback_noise_std = 0.1  # Standard deviation of Gaussian noise added to sensory feedback
workspace = ((-1., -1.),  # Workspace bounds ((x_min, y_min), (x_max, y_max))
             (1., 1.))
hidden_size = 50  # Number of units in the hidden layer of the controller

Now, define our random keys. We'll use `jax.random.split` to get several keys we can use for different purposes.

In [2]:
seed = 0 
key = jax.random.PRNGKey(seed)  # This is the parent key

# Split into three different keys for initialisation, training, and evaluation
key_init, key_train, key_eval = jax.random.split(key, 3)  
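
As a minimal sketch of why separate keys matter (the names `parent`, `key_a`, `key_b` and the sample shapes below are illustrative only, not part of the tutorial), drawing from the same key always reproduces the same numbers, while sibling keys produced by `jax.random.split` give different streams:

```python
import jax
import jax.numpy as jnp

# Illustrative parent key; any integer seed works the same way.
parent = jax.random.PRNGKey(0)
key_a, key_b, key_c = jax.random.split(parent, 3)

# Each subkey drives its own random stream.
sample_a = jax.random.normal(key_a, (3,))
sample_b = jax.random.normal(key_b, (3,))

# Reusing a key reproduces exactly the same numbers.
sample_a_again = jax.random.normal(key_a, (3,))

same_key_matches = bool(jnp.array_equal(sample_a, sample_a_again))
different_keys_match = bool(jnp.array_equal(sample_a, sample_b))
print(same_key_matches, different_keys_match)
```

This is why initialisation, training, and evaluation each get their own key: reusing one key for all three would correlate their randomness.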



## Manual pairing of a pre-built model with a task

Last time we saw that `feedbax.xabdeef` provides some pre-built pairings of tasks and models, which a single line of code can instantiate all at once, along with a `TaskTrainer`.

However, we can also manually choose a task to pair with a pre-built model, instead of pairing them automatically in the background. In this case, we'll use `point_mass_nn` to get the model, but define `SimpleReaches` explicitly.

In [1]:
from feedbax.task import SimpleReaches

from feedbax.xabdeef.losses import simple_reach_loss
from feedbax.xabdeef.models import point_mass_nn


task = SimpleReaches(
    loss_func=simple_reach_loss(),
    workspace=workspace, 
    n_steps=n_steps,
)

model = point_mass_nn(
    task,
    dt=dt,
    mass=mass,
    hidden_size=hidden_size, 
    n_steps=n_steps,
    feedback_delay_steps=feedback_delay_steps,
    feedback_noise_std=feedback_noise_std,
    key=key_init,
)


Note that:

- the task is given a *loss function*, `loss_func`, which determines how a model's performance on the task is scored. 
- `task` is passed to `point_mass_nn`, since the input shape of the neural network depends on the input data provided by the task for each trial.


We could use a `TaskTrainer` to train this `model` to perform this `task`, but let's go one level deeper first and define the model without using any functions from `feedbax.xabdeef`.

## Building the model ourselves using core Feedbax

Now we'll rebuild the exact same model, without using the pre-built function `point_mass_nn` that we imported from `feedbax.xabdeef`. We'll write our own function that lets us build copies of a model with a certain structure.

In [None]:
from feedbax.bodies import SimpleFeedback
from feedbax.iterate import SimpleIterator
from feedbax.mechanics.mechanics import Mechanics
from feedbax.mechanics.plant import SimplePlant
from feedbax.mechanics.skeleton.pointmass import PointMass
from feedbax.networks import SimpleNetwork


def get_model(
    task,
    dt: float = 0.05, 
    mass: float = 1., 
    hidden_size: int = 50, 
    n_steps: int = 100, 
    feedback_delay: int = 0,
    feedback_noise_std: float = 0.0,
    out_nonlinearity=lambda x: x,
    *,  # This asterisk forces us to pass `key` as a keyword argument, without needing to set a default value
    key,
):   
    key1, key2 = jax.random.split(key)

    plant = SimplePlant(PointMass(mass=mass))
    mechanics = Mechanics(plant, dt)
    
    feedback_spec = dict(
        where=lambda state: (
            state.effector.pos,
            state.effector.vel,
        ),
        delay=feedback_delay,
        noise_std=feedback_noise_std,
    )
    
    # Determine the network input size automatically
    input_size = SimpleFeedback.get_nn_input_size(
        task, mechanics, feedback_spec
    )
    
    net = SimpleNetwork(
        input_size,
        hidden_size,
        out_size=plant.input_size,
        noise_std=0.0,
        out_nonlinearity=out_nonlinearity, 
        key=key1
    )
    
    body = SimpleFeedback(net, mechanics, feedback_spec=feedback_spec, key=key2)
    
    return SimpleIterator(body, n_steps)
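
The bare `*` in the signature above is plain Python, not a Feedbax feature. A small sketch with a hypothetical toy constructor `build` (not a Feedbax function) shows the effect: `key` must be passed by name, and it cannot be omitted.

```python
def build(size, *, key):
    """Toy stand-in for a model constructor; `key` is keyword-only."""
    return {"size": size, "key": key}


# Passing `key` by name works.
ok = build(50, key=123)

# Passing it positionally raises a TypeError.
try:
    build(50, 123)
    positional_rejected = False
except TypeError:
    positional_rejected = True

# Leaving it out also raises a TypeError, since it has no default.
try:
    build(50)
    missing_rejected = False
except TypeError:
    missing_rejected = True

print(ok, positional_rejected, missing_rejected)
```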

If you look in `feedbax.xabdeef.models`, you'll see that `get_model` is more or less the same as `point_mass_nn`. Looking inside `get_model`, we can see explicitly how the model is constructed.

First we construct a "plant", which contains all of the mathematics necessary to simulate our point mass. In this case we use a `SimplePlant`, which is a plant that has a "skeleton"---here, just the point mass---which is controlled directly, and not by forces generated by "muscles".

A `Mechanics` object is used to associate the plant dynamics with a differential equation solver, with time steps discretized to the duration we specify.
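
To make the idea of a discretized solver concrete, here is a minimal pure-Python sketch of explicit Euler integration for a point mass. It illustrates the general technique only and is not how `Mechanics` is implemented: the function `euler_step` and the constant force are assumptions made for this example, and the toy ignores workspace bounds.

```python
def euler_step(pos, vel, force, mass, dt):
    """Advance a point mass by one explicit Euler step of duration dt."""
    acc = tuple(f / mass for f in force)
    new_pos = tuple(p + dt * v for p, v in zip(pos, vel))
    new_vel = tuple(v + dt * a for v, a in zip(vel, acc))
    return new_pos, new_vel


n_steps, dt, mass = 100, 0.05, 1.0
pos, vel = (0.0, 0.0), (0.0, 0.0)

# Push with a constant unit force along x for the whole trial.
for _ in range(n_steps):
    pos, vel = euler_step(pos, vel, force=(1.0, 0.0), mass=mass, dt=dt)

# Closed-form position under constant acceleration, for comparison.
exact_x = 0.5 * (1.0 / mass) * (n_steps * dt) ** 2

print(pos[0], exact_x)  # Euler lags slightly behind the exact value
```

The small gap between the two numbers is the discretization error of a first-order method. It shrinks as `dt` gets smaller.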

...     

Since all of this defines just a single pass through the feedback loop of our simulated body, we wrap it all in a `SimpleIterator`, which repeats that step `n_steps` times.
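
The division of labour between a single step and the iterator can be sketched in plain Python. Everything here is a hypothetical stand-in: `controller` is a hand-written proportional-derivative rule in place of the trained network, `body_step` plays the role of one feedback step in one dimension, and `iterate` mimics what an iterator does. None of it is Feedbax code.

```python
def controller(feedback, target):
    """Hand-tuned stand-in for the neural network (gains are arbitrary)."""
    pos, vel = feedback
    return 4.0 * (target - pos) - 4.0 * vel


def body_step(state, target, mass=1.0, dt=0.05):
    """One pass through the loop: feedback -> command -> Euler update."""
    pos, vel = state
    force = controller((pos, vel), target)
    return (pos + dt * vel, vel + dt * force / mass)


def iterate(step, state, n_steps, **kwargs):
    """Apply `step` repeatedly, keeping the whole state trajectory."""
    history = [state]
    for _ in range(n_steps):
        state = step(state, **kwargs)
        history.append(state)
    return history


history = iterate(body_step, (0.0, 0.0), n_steps=100, target=0.5)
final_pos, final_vel = history[-1]
print(len(history), final_pos)
```

The step function knows nothing about how many times it will be called. That is entirely the iterator's concern, which is why the two are separate objects in the model.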

TODO


Using our function, we construct our model just like we did with `point_mass_nn`:

In [9]:
model = get_model(
    task,
    dt=dt,
    mass=mass,
    hidden_size=hidden_size,
    n_steps=n_steps,
    feedback_delay=feedback_delay_steps,
    feedback_noise_std=feedback_noise_std,
    key=key_init, 
)

Now's a good time to check out the structure of our model. Since all our models are Equinox modules, we can simply print them and the result is a nice tree structure.

In [10]:
model

SimpleIterator(
  step=SimpleFeedback(
    net=SimpleNetwork(
      out_size=2,
      hidden=GRUCell(
        weight_ih=f32[150,8],
        weight_hh=f32[150,50],
        bias=f32[150],
        bias_n=f32[50],
        input_size=8,
        hidden_size=50,
        use_bias=True
      ),
      hidden_size=50,
      noise_std=None,
      hidden_nonlinearity=None,
      encoder=None,
      encoding_size=None,
      readout=Linear(
        weight=f32[2,50],
        bias=f32[2],
        in_features=50,
        out_features=2,
        use_bias=True
      ),
      out_nonlinearity=<function <lambda>>,
      intervenors={'hidden': [], 'readout': []}
    ),
    mechanics=Mechanics(
      plant=SimplePlant(
        skeleton=PointMass(mass=1.0),
        clip_states=True,
        intervenors={'clip_skeleton_state': []},
        muscle_model=None
      ),
      dt=0.05,
      solver=Euler(),
      intervenors={
        'convert_effector_force':
        [],
        'statics_step':
        [],
       

This tree reflects how we defined our model. We can see that:

- The topmost level is a `SimpleIterator`, which repeats a model step `n_steps` times.
- The model step is an instance of `SimpleFeedback`, which is composed of a `SimpleNetwork` that sends commands to, and receives feedback from, a `Mechanics` instance.
- `SimpleNetwork` consists of a `GRUCell` containing arrays of weights, as well as a `Linear` readout with its own weights. Its `encoder` is `None` because by default it is constructed without a separate layer for encoding its inputs.
- `Mechanics` uses the default `Euler` solver. This is usually sufficient, but in some cases it might be more accurate to use one of the higher-order solvers provided by `diffrax`.
- Each inner level of the model has a dictionary `intervenors` containing empty lists. Interventions will be addressed in a later example.

Note that the tree does not show us *how* each part of the model performs its computations, or how it *calls* other parts of the model. It only shows which parts of the model *contain* other parts, or parameters. 
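
The array shapes in the printed tree are enough to work out how many numbers the network holds. The sketch below copies the shapes from the output above into a plain dictionary (the key names are just labels chosen for this example) and sums their sizes. The leading dimension of 150 in the `GRUCell` weights is three times the hidden size of 50, presumably because the weights of the three GRU gates are stacked in a single array.

```python
from math import prod

# Shapes copied from the printed model tree.
shapes = {
    "hidden.weight_ih": (150, 8),
    "hidden.weight_hh": (150, 50),
    "hidden.bias": (150,),
    "hidden.bias_n": (50,),
    "readout.weight": (2, 50),
    "readout.bias": (2,),
}

counts = {name: prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())

for name, count in counts.items():
    print(f"{name:>18}: {count}")
print(f"{'total':>18}: {total}")
```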

## Training the model

Now we construct a `TaskTrainer`. We can explicitly tell the trainer to use the Adam optimizer, or any other optimizer, such as those provided by `optax`.

In [7]:
import optax

from feedbax.trainer import TaskTrainer


trainer = TaskTrainer(
    optimizer=optax.adam(learning_rate=1e-2)
)
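
For readers unfamiliar with Adam, the following is a rough pure-Python sketch of the standard Adam update rule applied to a single scalar. It is not `optax`'s implementation. The helper `adam_step` and the toy objective are assumptions made for illustration, with the learning rate set to the `1e-2` used above and the other hyperparameters set to the commonly used defaults.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter at (1-indexed) step t."""
    m = b1 * m + (1 - b1) * grad        # running mean of gradients
    v = b2 * v + (1 - b2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)           # bias correction for early steps
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Toy objective: f(x) = (x - 3)^2, with gradient 2 * (x - 3).
x, m, v = 0.0, 0.0, 0.0
trajectory = []
for t in range(1, 2001):
    grad = 2.0 * (x - 3.0)
    x, m, v = adam_step(x, grad, m, v, t)
    trajectory.append(x)

print(trajectory[0], trajectory[-1])
```

Note that the very first step moves the parameter by almost exactly the learning rate, regardless of how large the gradient is. This normalisation of step size is the main practical difference from plain gradient descent.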

Now we can train our model to perform the task. 

We'll be explicit about which part of the model we want to be optimized: in this case, all the parameters inside of the `GRUCell` layer, and all of the parameters inside of the `Linear` readout layer. We use `lambda` to define a function, which is just a way of picking out certain parts of a `model` that is passed to that function.

In [None]:
where_train = lambda model: (
    model.step.net.hidden,
    model.step.net.readout,
)

model, train_history = trainer(
    task=task, 
    model=model,
    n_batches=2000, 
    batch_size=250, 
    log_step=200,
    where_train=where_train,
    key=key_train,
)

Note that since the `GRUCell` and `Linear` layers contain all of the trainable JAX arrays in our `SimpleNetwork`, we could equivalently have written

In [None]:
where_train = lambda model: model.step.net

and nothing would have changed. Since Feedbax is based on JAX/Equinox, we inherit a lot of flexibility in referring to subsets of models.

Use of `lambda` is common in Feedbax, so it's important to recall that because a lambda is merely a convenient way of defining a function in-line, we could also have written

In [None]:
def where_train(model):
    return model.step.net

and then passed `where_train` to `trainer`, just as before.
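
The equivalence of the two forms can be checked on a stand-in object. In this sketch a nested `types.SimpleNamespace` with placeholder strings takes the place of a real Feedbax model, and `where_lambda` and `where_def` are names chosen for the example. The point is only that both functions pick out the same attributes.

```python
from types import SimpleNamespace

# Hypothetical stand-in with the same attribute layout as the real model.
net = SimpleNamespace(hidden="GRU weights", readout="readout weights")
model = SimpleNamespace(step=SimpleNamespace(net=net, mechanics="mechanics"))

# In-line definition with lambda ...
where_lambda = lambda model: (model.step.net.hidden, model.step.net.readout)


# ... and the same selection written as an ordinary function.
def where_def(model):
    return (model.step.net.hidden, model.step.net.readout)


selected_lambda = where_lambda(model)
selected_def = where_def(model)
print(selected_lambda == selected_def, selected_lambda)
```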

## Another example

Show the code for torque control of the two-link arm. 