# Control

I'm using this notebook to:

1. Explore what it takes to add a new control dataset.
2. Make notes that we can use for discussion and knowledge-share.

My plan, for now, is to copy into this notebook the bare minimum lines needed to run a training/evaluation step using a ControlTask.

Every time I hit a road block, I'll make note of the issue, try my best to get around it (not necessarily in the correct way), and rinse/repeat.

If you're scanning for places where I have questions, look for formatting ***? like this ?***.

In [42]:
# If you edit code while this notebook is running,
# you might need to reload the module for the changes
# to take effect. You could also resolve the issue
# by restarting the kernel. But reloading the module
# is quicker and less destructive. I'm adding the 
# reload at the top so you can just run this cell
# after making any breaking edits to the code.
import importlib
import gato.tasks.control_task
importlib.reload(gato.tasks.control_task)

import minari
import gymnasium as gym
from gato.tasks.control_task import ControlTask
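
Why re-run the `from gato.tasks.control_task import ControlTask` line after the reload? `importlib.reload` re-executes the module and updates the module object in place. However, any name that was already bound with `from ... import ...` keeps pointing at the old object. Here's a small, self-contained demo of that behavior. It uses a throwaway module written to a temp directory; `demo_mod` and `Thing` are made-up names, not part of gato:

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Skip .pyc files so the reload below always recompiles the edited source.
sys.dont_write_bytecode = True

# Write a throwaway module to disk and make it importable.
tmp_dir = Path(tempfile.mkdtemp())
module_path = tmp_dir / "demo_mod.py"
module_path.write_text("class Thing:\n    version = 1\n")
sys.path.insert(0, str(tmp_dir))

import demo_mod
from demo_mod import Thing as StaleThing  # bound to the class object that exists *now*

# Simulate editing the source while the kernel keeps running.
module_path.write_text("class Thing:\n    version = 22\n")
importlib.invalidate_caches()
importlib.reload(demo_mod)

from demo_mod import Thing as FreshThing  # re-import after the reload

print(demo_mod.Thing.version)  # 22: the module attribute now points at the new class
print(StaleThing.version)      # 1: the earlier `from` import still points at the old class
print(FreshThing.version)      # 22
```

So the ordering in the cell above matters: the reload runs first, and then the `from ... import` picks up the edits. Objects created before the reload, such as an existing `control` instance, are still instances of the old class and need to be rebuilt.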

## BabyAI Minari dataset

I grabbed a random dataset from the ones created with [the bot.py script in Santiago's baby-ai-dataset repo](https://github.com/snat-s/baby-ai-dataset/blob/master/scripts/bot.py#L64).

In [7]:
dataset_name = 'BabyAI-GoToOpen-v0'
dataset = minari.load_dataset(dataset_name)

In [14]:
dataset, dataset.spec.env_spec.id

(<minari.dataset.minari_dataset.MinariDataset at 0x7fdb736327a0>,
 'BabyAI-GoToOpen-v0')

In [20]:
env = gym.make(dataset.spec.env_spec)

# train.py

## Training arguments

I'm using [Namespace](https://docs.python.org/3/library/argparse.html#argparse.Namespace) to hack up an `args` object so that I don't have to go through the hassle of `parser = ArgumentParser(); parser.add_argument(...); args = parser.parse_args(['foo', 'bar', 'baz', ...])`.

Browsing the code, the only `args` attribute I see `ControlTask` accessing is `args.patch_size`, so that's the only one I'm bothering to add.
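
For reference, here's a small sketch showing that the `Namespace` shortcut and a real `ArgumentParser` produce equivalent objects for this one argument. The `--patch_size` flag declared here is only for illustration. It isn't copied from `train.py`.

```python
from argparse import ArgumentParser, Namespace

# The "proper" way: declare the argument, then parse an explicit argv list.
parser = ArgumentParser()
parser.add_argument("--patch_size", type=int)
parsed_args = parser.parse_args(["--patch_size", "4"])

# The shortcut used in this notebook: build the Namespace directly.
hacked_args = Namespace(patch_size=4)

# Namespace equality compares the attribute dictionaries.
print(parsed_args == hacked_args)  # True
```

The catch is that the shortcut skips everything the parser would normally give you, like type conversion and defaults. Any attribute you forget to set only shows up later, as an `AttributeError` when some code tries to read it.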

In [43]:
from argparse import Namespace

In [44]:
args = Namespace(patch_size=4)

In [45]:
context_len = 512
training_prompt_len_proportion = 0.5
share_prompt_episodes = True
top_5_prompting = None

# ControlTask

`train.py` creates a `ControlTask`.

## Supported observation spaces

The first error I encounter is that the BabyAI environment has a Dict observation space.

***? What would it take to support a Dict (or any other) observation space ?***

Looking at how the observation space is used, I see `ControlTask` branch on its type:

- If it's a `Box` whose shape has length 2 or 3 (i.e., it looks like an image):
    - Adds image transforms.
- Otherwise:
    - Sets `obs_str` to `'continuous_obs'`. That string ends up in `input_dict`, which gets passed to `predict_control`. `predict_control` doesn't check it. Instead, it passes it along to `tokenize_input_dicts`, and _that's_ where it gets checked.
        - It tokenizes the batch with the `continuous_obs_tokenizer`.

(Is this a safe condition? Might there be non-image continuous observations whose shape has length 2 or 3?)

### tokens_per_space

This is another place where the type of the observation space is checked.

- `Box`: `space.shape[0]`
- `Discrete`: `1`
- `Dict`: ?

***? What should this return for a Dict? How is it used? ?***

The return value of `tokens_per_space` gets assigned to `action_tokens` [here](https://github.com/eihli/NEKO/blob/b66b48b88117307a442c43a7f4d8701706670144/gato/tasks/control_task.py#L74). That value is then used to calculate `tokens_per_timestamp` and, eventually, to create and manipulate the shape of the `input_dict` in [ControlTask.evaluate](https://github.com/eihli/NEKO/blob/b66b48b88117307a442c43a7f4d8701706670144/gato/tasks/control_task.py#L138).

### Hacking my way past errors

I hard-coded some arbitrary values to get past the assertions.

For example, I added an `isinstance(space, Dict): return 1` to `tokens_per_space` [here](https://github.com/eihli/NEKO/blob/b66b48b88117307a442c43a7f4d8701706670144/gato/tasks/control_task.py#L23).

I'm sure it's going to blow up due to a dict size mismatch or something. I just want to get to that point so I can see that error and maybe understand what the value _should_ be.

In [46]:
env.observation_space

Dict('direction': Discrete(4), 'image': Box(0, 255, (7, 7, 3), uint8), 'mission': MissionSpace(<function BabyAIMissionSpace._gen_mission at 0x7fdb727b09d0>, None))

In [47]:
control = ControlTask('playground_env_name', env, dataset, context_len, args)

# Back to train.py now that we have a ControlTask

Now we're going to need a lot more args.

In [52]:
from gato.policy.gato_policy import GatoPolicy

In [76]:
args.device = 'cpu'
args.embed_dim = 768
args.layers = 8
args.heads = 24
args.dropout = 0.1
args.mu = 100
args.M = 256
args.resid_mid_channels = 128
args.continuous_tokens = 1024
args.discrete_tokens = 1024
args.sequence_length = 1024
args.disable_patch_pos_encoding = False
args.disable_inner_pos_encoding = False
args.activation_fn = 'gelu'
args.pretrained_lm = None
args.flash = False
args.tokenizer_model_name = 'gpt2'
args.pad_seq = False

In [78]:
model = GatoPolicy(
    device=args.device,
    embed_dim=args.embed_dim,
    layers=args.layers,
    heads=args.heads,
    dropout=args.dropout,
    mu=args.mu,
    M=args.M,
    patch_size=args.patch_size,
    resid_mid_channels=args.resid_mid_channels,
    continuous_tokens=args.continuous_tokens,
    discrete_tokens=args.discrete_tokens,
    context_len=args.sequence_length,
    use_patch_pos_encoding=not args.disable_patch_pos_encoding,
    use_pos_encoding=not args.disable_inner_pos_encoding,
    activation_fn=args.activation_fn,
    pretrained_lm=args.pretrained_lm,
    flash=args.flash,
    tokenizer_model_name=args.tokenizer_model_name,
    pad_seq=args.pad_seq,
)

In [80]:
args.embed_dim = model.embed_dim

# Trainer

***? What if we wanted to skip the trainer? Could I just run `evaluate` on the ControlTask myself ?***

`trainer.train` runs `train_iteration`.

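`train_iteration` calls `model.train()`. `model` is an `nn.Module` ([docs](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)) and we don't override its `train` method, so this call only switches the model into training mode (which enables dropout, etc.). It doesn't run any optimization by itself.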
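
`model.train()` is easy to misread as "run training." For any `nn.Module`, it sets `training = True` on the module and all of its submodules. That changes how layers like `Dropout` behave, but it doesn't compute gradients or update weights. Here's a quick check with a throwaway module (`toy` is a stand-in, not `GatoPolicy`):

```python
import torch
from torch import nn

toy = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(1, 8)

toy.train()             # flips the mode flag on toy and its children; no updates happen
print(toy.training)     # True
print(toy[1].training)  # True: the Dropout submodule follows along

toy.eval()              # training = False everywhere; Dropout now passes inputs through
print(toy.training)     # False
with torch.no_grad():
    out_a = toy(x)
    out_b = toy(x)
print(torch.equal(out_a, out_b))  # True: no random dropout in eval mode
```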

`GatoPolicy.forward` gets something called a "final_representation", which is the result of passing some token embeddings and a mask to the GPT2Model.

I'm getting a bit lost at this point. I'd love to have someone explain what's going on around this part of the code.

`forward` calls `tokenize_input_dicts(inputs)` [here](https://github.com/eihli/NEKO/blob/b66b48b88117307a442c43a7f4d8701706670144/gato/policy/gato_policy.py#L156). If you're tracking `inputs`, which is probably an important variable to track, then this line is probably important.

Continuing anyways...

We eventually call `predict_token` and return the `logits` (and, conditionally, the `loss`). [predict_token](https://github.com/eihli/NEKO/blob/b9facb61e7d48bf5f9fef9f4ec73b85b531e4aaf/gato/policy/gato_policy.py#L123) is an `nn.Linear`.
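
To make the `forward` flow easier to follow, here's a toy model with the same overall shape. Token ids become embeddings, a causal transformer produces the equivalent of the "final representation," and an `nn.Linear` head playing the role of `predict_token` turns that into logits, with an optional loss at the end. `TinyTokenPolicy` is hypothetical. It swaps `GPT2Model` for a small `nn.TransformerEncoder` with a causal mask, and it skips `tokenize_input_dicts` by taking integer tokens directly. It also uses a plain next-token cross-entropy loss, which may differ from how `GatoPolicy` masks or weights its loss.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyTokenPolicy(nn.Module):
    """Hypothetical, heavily simplified stand-in for the GatoPolicy.forward flow."""

    def __init__(self, vocab_size: int = 32, embed_dim: int = 16, heads: int = 2, layers: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        layer = nn.TransformerEncoderLayer(embed_dim, heads, dim_feedforward=32, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=layers)  # stands in for GPT2Model
        self.predict_token = nn.Linear(embed_dim, vocab_size, bias=False)  # embedding -> vocab logits

    def forward(self, tokens: torch.Tensor, compute_loss: bool = False):
        seq_len = tokens.shape[1]
        # Causal mask: position i may only attend to positions <= i.
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
        final_representation = self.transformer(self.embed(tokens), mask=causal_mask)
        logits = self.predict_token(final_representation)  # (batch, seq_len, vocab_size)
        if not compute_loss:
            return logits, None
        # Next-token prediction: logits at position i are scored against token i + 1.
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.shape[-1]),
            tokens[:, 1:].reshape(-1),
        )
        return logits, loss


torch.manual_seed(0)
toy_policy = TinyTokenPolicy()
tokens = torch.randint(0, 32, (2, 10))  # batch of 2 sequences, 10 tokens each
logits, loss = toy_policy(tokens, compute_loss=True)
print(logits.shape)  # torch.Size([2, 10, 32])
```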

In [82]:
from gato.training.trainer import Trainer

In [None]:
trainer = Trainer(
    model,
    optimizer,
    scheduler,
    accelerator,
    tasks,
    exp_name,
    args
)
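
This cell can't run as-is, because `optimizer`, `scheduler`, `accelerator`, `tasks`, and `exp_name` are never defined in this notebook. As a placeholder, here's a generic PyTorch optimizer with a linear-warmup learning-rate schedule. The learning rate, the warmup length, and the schedule shape are all assumptions rather than what `train.py` uses, and the model is a stand-in so the sketch runs on its own. My guess is that `tasks` is a list like `[control]` and `exp_name` is just a string. I haven't confirmed how `train.py` builds those (or the `accelerator`).

```python
import torch
from torch import nn

toy_model = nn.Linear(4, 4)  # stand-in for GatoPolicy so this cell is self-contained

# ASSUMED hyperparameters, not taken from train.py.
learning_rate = 1e-4
warmup_steps = 10

sketch_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=learning_rate)
# Multiply the base LR by a factor that ramps linearly up to 1.0, then stays there.
sketch_scheduler = torch.optim.lr_scheduler.LambdaLR(
    sketch_optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)

lrs = []
for _ in range(15):
    sketch_optimizer.step()  # a real loop would compute a loss and call backward() first
    sketch_scheduler.step()
    lrs.append(sketch_scheduler.get_last_lr()[0])
print(lrs[0], lrs[-1])
```

Swapping `toy_model.parameters()` for `model.parameters()` would point the optimizer at the real `GatoPolicy`.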

# Pause...

I just realized it would probably benefit me to explore what an existing _working_ dataset looks like. I'm going to take a break from this notebook to go do that. Maybe I'll do it below. Maybe I'll do it in a new notebook.