# Saving and loading

We often want to save the parameters of a trained model so that we can use it again later without retraining.

All Feedbax components—including, automatically and for free, any that you might write—are [PyTrees](feedbax/examples/pytrees): they are represented as tree-structured data. Equinox is able to [save](https://docs.kidger.site/equinox/examples/serialisation/) this data to a file. 

Feedbax provides some functions that make saving and loading a little easier. If you prefer a different scheme, you can instead use the Equinox functions `tree_serialise_leaves` and `tree_deserialise_leaves` directly. See their [documentation](https://docs.kidger.site/equinox/api/serialisation/) for details.

Here's an example of how to use the functions provided by Feedbax. 

We'll start by writing a function that sets up the components we're going to want to save.

In [3]:
import jax

from feedbax.task import SimpleReaches

from feedbax.xabdeef.losses import simple_reach_loss
from feedbax.xabdeef.models import point_mass_nn

# The leading asterisk forces all the arguments to be passed as keyword arguments
def setup(*, workspace, n_steps, dt, hidden_size, key):

    task = SimpleReaches(
        loss_func=simple_reach_loss(), 
        workspace=workspace, 
        n_steps=n_steps
    )

    model = point_mass_nn(task, dt=dt, hidden_size=hidden_size, key=key)
    
    return task, model 

Next, call this function to get a task and model. We'll also want to save the hyperparameters used to set them up, so we first collect these in a dictionary and then unpack it into `setup`.

In [5]:
hyperparameters = dict(
    workspace=((-1., -1.),  # Workspace bounds ((x_min, y_min), (x_max, y_max))
               (1., 1.)),
    n_steps=100,  # Number of time steps per trial
    dt=0.05,  # Duration of a time step
    hidden_size=50,  # Number of units in the hidden layer of the controller
)

key_init, key_train = jax.random.split(jax.random.PRNGKey(0))

task, model = setup(**hyperparameters, key=key_init)
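
Because the hyperparameters are plain numbers and tuples, they are easy to store in a human-readable format. The sketch below assumes JSON via Python's standard `json` module, which is not necessarily the format `feedbax.save` uses. It shows one detail to watch for: tuples come back as lists. This matters if `setup`, or the objects it builds, expect tuples.

```python
import json

hyperparameters = dict(
    workspace=((-1., -1.), (1., 1.)),
    n_steps=100,
    dt=0.05,
    hidden_size=50,
)

# Round-trip through a JSON string, as would happen when writing to and reading from a file.
text = json.dumps(hyperparameters)
restored = json.loads(text)

print(restored["workspace"])  # [[-1.0, -1.0], [1.0, 1.0]]: tuples became lists
```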

Now train the model on the task. We only do a short training run, since here we care about saving the model, not about whether it has converged on a solution.

In [6]:
import optax

from feedbax.trainer import TaskTrainer


trainer = TaskTrainer(
    optimizer=optax.adam(learning_rate=1e-2)
)

model, _ = trainer(
    task=task, 
    model=model, 
    n_batches=1000, 
    batch_size=250, 
    where_train=lambda model: model.step.net,
    key=key_train,
)

Training step compiled.
Validation step compiled.


Training iteration: 0
	training loss: 3.94e+01
	validation loss: 5.14e+00

Training iteration: 100
	training loss: 2.86e-02
	validation loss: 5.32e-03

Training iteration: 200
	training loss: 8.28e-03
	validation loss: 1.58e-03

Training iteration: 300
	training loss: 6.48e-03
	validation loss: 1.25e-03

Training iteration: 400
	training loss: 5.37e-03
	validation loss: 1.06e-03

Training iteration: 500
	training loss: 4.46e-03
	validation loss: 9.42e-04

Training iteration: 600
	training loss: 3.98e-03
	validation loss: 8.68e-04

Training iteration: 700
	training loss: 3.58e-03
	validation loss: 8.12e-04

Training iteration: 800
	training loss: 3.34e-03
	validation loss: 7.66e-04

Training iteration: 900
	training loss: 2.94e-03
	validation loss: 7.34e-04


Now save the trained task and model, along with the hyperparameters used to construct them.

In [None]:
from feedbax import save, load

model_path = save(
    (task, model),
    hyperparameters=hyperparameters, 
)

!!! Note
    Python includes the module [`pickle`](https://docs.python.org/3/library/pickle.html), which can save and load entire Python objects without needing to specify, at loading time, how those objects were created. This seems convenient, but it is generally not good practice:

    - When loading, Python automatically executes code found in the pickle file in order to reconstruct the pickled objects. This is a security issue. If someone shares a pickled model with you, they (or an interloper) could insert harmful code into the file, and you may not know it's there until you load it.
    - You probably still have to keep track of how the objects in the pickle were created for your research to be reproducible in detail. It is better to do this explicitly.
    - Some of the components we use, such as `lambda` expressions, are not compatible with `pickle`.

    See the Equinox [documentation](https://docs.kidger.site/equinox/examples/serialisation/) for a similar discussion of these limitations.
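
The first and third points in the note can be checked with the standard library alone. In the sketch below, the class `Harmless` is a made-up example. Its `__reduce__` method tells `pickle` to call `print` during loading. A real attacker could name any other callable instead. Pickling a `lambda` fails because `pickle` stores functions by an importable name, and a `lambda` has no name it can be looked up by.

```python
import pickle


class Harmless:
    """Made-up stand-in for a malicious object.

    `__reduce__` tells pickle which callable to invoke, and with which
    arguments, when the data is loaded.
    """

    def __reduce__(self):
        # A malicious file could name a dangerous callable here instead of `print`.
        return (print, ("This ran just because the pickle was loaded!",))


payload = pickle.dumps(Harmless())
result = pickle.loads(payload)  # prints the message; `result` is print's return value, None

# Functions are pickled by name; a lambda's name cannot be looked up, so this fails.
try:
    pickle.dumps(lambda x: x + 1)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError):
    lambda_pickled = False

print(lambda_pickled)  # False
```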

In [None]:
task, model = load(model_path, setup)

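Loading works much as it does in Equinox. To deserialise the saved leaves, Equinox needs an existing PyTree with the same structure to fill them into. That is why, besides the path, `load` takes a function such as `setup` that can construct your task and model. The hyperparameters saved alongside them supply its arguments.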