# 4. Training sampler

JAXMAPP provides utilities for training sampler models from MAPP demonstrations, i.e., solved problem instances.

## Overview

- `scripts/create_training_data.py` creates a collection of MAPP problem instances and solves them using a random sampler (`RandomSampler`).
- `scripts/create_tfrecords.py` converts the collection created above into the `tfrecord` format.
- `scripts/train.py` trains a sampler model (for example, `CTRMNet` used in `CTRMSampler`).

To make full use of these scripts, some familiarity with Hydra and JAX is needed.
However, if you know how to handle tfrecord files, the training itself can be done with other deep learning frameworks such as PyTorch.


## Dataset generation

```console
$ python scripts/create_training_data.py
```

Running `create_training_data.py` creates a collection of MAPP problems and their solutions, based on the config in `scripts/config/create_training_data.yaml`.
By default, MAPP problem instances with `num_agents == 30` and varying `max_speeds` and `rads` are generated randomly and solved with `RandomSampler`.
The complete definition of the instances is given in `scripts/config/dataset/instance/hetero_fixednumagents.yaml`, which you can modify to produce different problem collections for training.
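As a rough illustration of what "varying `max_speeds` and `rads`" means, the sketch below samples a heterogeneous team of agents with non-overlapping start and goal positions. It is not JAXMAPP's generator: the unit-square workspace, the sampling ranges, the absence of obstacles, and the function name `sample_hetero_agents` are all hypothetical; the real settings come from `hetero_fixednumagents.yaml`.

```python
import numpy as np


def sample_hetero_agents(
    num_agents: int = 30,
    speed_range: tuple = (0.02, 0.05),  # hypothetical range, not from the JAXMAPP config
    rad_range: tuple = (0.01, 0.03),  # hypothetical range, not from the JAXMAPP config
    seed: int = 0,
    max_tries: int = 10_000,
) -> dict:
    """Sample per-agent speeds/radii and collision-free starts/goals in an assumed [0, 1]^2 workspace."""
    rng = np.random.default_rng(seed)
    max_speeds = rng.uniform(*speed_range, size=num_agents)
    rads = rng.uniform(*rad_range, size=num_agents)

    def place() -> np.ndarray:
        # Rejection sampling: accept a point only if its disk overlaps no previously placed disk.
        points = []
        for i in range(num_agents):
            for _ in range(max_tries):
                p = rng.uniform(rads[i], 1.0 - rads[i], size=2)
                if all(np.linalg.norm(p - q) > rads[i] + rads[j] for j, q in enumerate(points)):
                    points.append(p)
                    break
            else:
                raise RuntimeError(f"could not place agent {i}; try fewer or smaller agents")
        return np.stack(points)

    return {"starts": place(), "goals": place(), "max_speeds": max_speeds, "rads": rads}
```

Because each agent has its own radius, the overlap test compares the center distance against `rads[i] + rads[j]` rather than a single shared clearance.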

Note that solving each problem takes a considerable amount of time, so this script should be run on a workstation with a many-core CPU to make full use of `joblib.Parallel`.
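The parallelization pattern itself is the standard `joblib` one: each instance is generated and solved independently, so instances can be farmed out to worker processes. In this sketch `solve_instance` and `generate_dataset` are hypothetical placeholders; in the actual script, the worker body would generate an instance and run the solver.

```python
import numpy as np
from joblib import Parallel, delayed


def solve_instance(seed: int) -> dict:
    """Hypothetical stand-in for 'generate one MAPP instance and solve it'."""
    rng = np.random.default_rng(seed)  # per-instance seed keeps results reproducible
    starts = rng.uniform(0.0, 1.0, size=(30, 2))
    # A real implementation would call the planner here; this placeholder only returns the starts.
    return {"seed": seed, "starts": starts}


def generate_dataset(num_instances: int, n_jobs: int = -1) -> list:
    # n_jobs=-1 uses every available CPU core; results are returned in input order.
    return Parallel(n_jobs=n_jobs)(
        delayed(solve_instance)(seed) for seed in range(num_instances)
    )
```

Seeding each task by its index, rather than sharing one global RNG, keeps the dataset reproducible regardless of how `joblib` schedules the work.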

## Conversion to tfrecord

```console
$ python scripts/create_tfrecords.py
```

This script extracts the data needed to train your sampler from the generated problem instances and converts it into the tfrecord format.
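A minimal sketch of such a conversion, assuming each sample is a dict of NumPy arrays keyed by field name (e.g. `current_pos`, `occupancy`). Each array is stored as one serialized-tensor bytes feature, which preserves its shape and dtype. The helper names `to_example`, `write_tfrecord`, and `load_dataset` are hypothetical, and the actual script may encode features differently.

```python
import numpy as np
import tensorflow as tf


def to_example(sample: dict) -> tf.train.Example:
    """Serialize a dict of NumPy arrays into a tf.train.Example (one bytes feature per key)."""
    feature = {
        key: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(np.asarray(value)).numpy()])
        )
        for key, value in sample.items()
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def write_tfrecord(path: str, samples) -> None:
    with tf.io.TFRecordWriter(path) as writer:
        for sample in samples:
            writer.write(to_example(sample).SerializeToString())


def load_dataset(path: str, dtypes: dict) -> tf.data.Dataset:
    """Read records back; `dtypes` maps each key to the tf dtype it was stored with."""
    description = {key: tf.io.FixedLenFeature([], tf.string) for key in dtypes}

    def parse(record):
        raw = tf.io.parse_single_example(record, description)
        # parse_tensor needs the exact dtype that was serialized.
        return {key: tf.io.parse_tensor(raw[key], out_type=dtype) for key, dtype in dtypes.items()}

    return tf.data.TFRecordDataset(path).map(parse)
```

Because `load_dataset(...).as_numpy_iterator()` yields plain NumPy arrays, the same records can also feed a PyTorch or other non-TensorFlow training loop.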

Currently, the following data are stored:

- `current_pos`: current positions of all agents
- `previous_pos`: previous positions of all agents
- `next_pos`: next positions of all agents
- `goals`: goal positions of all agents
- `max_speeds`: maximum speed of each agent
- `rads`: radius (size) of each agent
- `occupancy`: occupancy map of the environment
- `cost_map`: cost-to-go map for each agent, computed with Dijkstra's algorithm
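For reference, a cost-to-go map of this kind can be computed by running Dijkstra's algorithm backward from the goal over a grid. The sketch below assumes a boolean occupancy grid (`True` = obstacle), 8-connected moves with Euclidean step costs, and a goal given as a cell index; `cost_to_go` and `cost_maps_for_agents` are hypothetical names, and JAXMAPP's own resolution, connectivity, and handling of agent radii may differ.

```python
import heapq

import numpy as np


def cost_to_go(occupancy: np.ndarray, goal: tuple) -> np.ndarray:
    """Shortest-path cost from every free cell to `goal` on an 8-connected grid.

    Returns np.inf for obstacle cells and cells that cannot reach the goal.
    Note: diagonal moves may squeeze between two diagonally adjacent obstacles.
    """
    h, w = occupancy.shape
    dist = np.full((h, w), np.inf)
    if occupancy[goal]:
        return dist  # goal inside an obstacle: nothing is reachable
    dist[goal] = 0.0
    heap = [(0.0, goal)]
    steps = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]
    while heap:
        d, (r, c) = heapq.heappop(heap)
        if d > dist[r, c]:
            continue  # stale heap entry; a shorter path was already found
        for dr, dc in steps:
            nr, nc = r + dr, c + dc
            if 0 <= nr < h and 0 <= nc < w and not occupancy[nr, nc]:
                nd = d + np.hypot(dr, dc)  # 1 for straight moves, sqrt(2) for diagonals
                if nd < dist[nr, nc]:
                    dist[nr, nc] = nd
                    heapq.heappush(heap, (nd, (nr, nc)))
    return dist


def cost_maps_for_agents(occupancy: np.ndarray, goal_cells) -> np.ndarray:
    """Stack one cost-to-go map per agent goal: shape (num_agents, H, W)."""
    return np.stack([cost_to_go(occupancy, tuple(g)) for g in goal_cells])
```

Since moves are symmetric, searching outward from the goal gives each cell's cost to reach the goal, so one search per agent yields that agent's whole map.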


By modifying the script, you can store additional data required for your training, such as `instance.starts` and `instance.obs.sdf`.
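To make the position fields concrete, the sketch below slices one solved trajectory into per-timestep samples, copying per-instance data (including an extra field such as `starts`) into every sample. The `(T, num_agents, 2)` trajectory layout, the choice to skip the first and last timesteps, and the function name `extract_samples` are assumptions for illustration, not the script's actual logic.

```python
import numpy as np


def extract_samples(trajectory: np.ndarray, static: dict) -> list:
    """Split a solved trajectory into (previous, current, next) training samples.

    trajectory: agent positions of shape (T, num_agents, 2) (assumed layout).
    static: per-instance fields shared by all timesteps, e.g. goals, max_speeds,
        rads, occupancy, cost_map, plus any extra data such as starts.
    """
    samples = []
    # Skip t = 0 and t = T - 1, where a previous or next position does not exist.
    for t in range(1, trajectory.shape[0] - 1):
        sample = {
            "previous_pos": trajectory[t - 1],
            "current_pos": trajectory[t],
            "next_pos": trajectory[t + 1],
        }
        sample.update(static)  # storing a new field only requires adding it to `static`
        samples.append(sample)
    return samples
```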

## Training

```console
$ python scripts/train.py
```

This script trains the sampler model specified in `scripts/config/train.yaml` and `scripts/config/model/**.yaml`. By default, it trains `CTRMNet`, which is used in `CTRMSampler`.
If your model is written in JAX and Flax, you can reuse this training script almost as-is.

Even if it is not, as long as your sampler can be trained in a supervised fashion, the training loop can be as simple as:

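```python
from tqdm import tqdm

# `config` (e.g. the Hydra config) and `datasets` (a dict mapping "train"/"val"
# to tf.data.Dataset objects) are assumed to be defined beforehand.
for epoch in range(config.num_epochs):
    for split in ["train", "val"]:
        for batch in tqdm(datasets[split].as_numpy_iterator(), desc=f"{split} epoch {epoch}"):
            ...  # single training (or validation) step using `batch`, a dict of NumPy arrays
```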
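For a model written in JAX, the elided step in the loop above might look like the following. This is a toy sketch, not `CTRMNet`: a hypothetical linear model regresses each agent's displacement (`next_pos - current_pos`) from its current position and goal using plain SGD, and the names `init_params`, `loss_fn`, `train_step`, `eval_step`, and `LEARNING_RATE` are illustrative. A Flax model with an optax optimizer would follow the same pattern.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # hypothetical value for plain SGD


def init_params(key, in_dim: int, out_dim: int) -> dict:
    return {"w": 0.01 * jax.random.normal(key, (in_dim, out_dim)), "b": jnp.zeros(out_dim)}


def loss_fn(params: dict, batch: dict) -> jnp.ndarray:
    # Features: current position and goal of each agent -> shape (B, N, 4).
    x = jnp.concatenate([batch["current_pos"], batch["goals"]], axis=-1)
    # Target: displacement to the next position -> shape (B, N, 2).
    target = batch["next_pos"] - batch["current_pos"]
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - target) ** 2)


@jax.jit
def train_step(params: dict, batch: dict):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Plain SGD update applied leaf-by-leaf over the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


@jax.jit
def eval_step(params: dict, batch: dict):
    return loss_fn(params, batch)
```

In the loop, `train_step` would run on the `"train"` split and `eval_step` on `"val"`, so validation never updates the parameters.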