## Background

### Masking Inputs

One of the most popular features of [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is setting the following configuration value:

```{.yaml filename="config.yml"}
train_on_inputs: false
```

If you declare a [dataset format](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset) such as `alpaca` or `chatml`, axolotl knows which parts of each example are inputs (i.e., the human turns) and which are outputs (i.e., the assistant turns), and it masks the input labels so that your model can focus on predicting the outputs only.
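A common way to implement this kind of masking, and the convention Hugging Face models follow, is to set the label of every masked token to `-100`. PyTorch's cross-entropy loss ignores that value by default (`ignore_index=-100`). The pure-Python sketch below is a simplified illustration, not axolotl's code. It assumes you already have the log-probability the model assigned to each correct token (the hypothetical `token_logprobs` list), and it averages the negative log-likelihood over the unmasked positions only.

```python
IGNORE_INDEX = -100  # label value that marks a token as masked


def masked_mean_nll(token_logprobs: list[float], labels: list[int]) -> float:
    """Average negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    kept = [-lp for lp, label in zip(token_logprobs, labels) if label != IGNORE_INDEX]
    if not kept:
        raise ValueError("every position is masked; there is nothing to train on")
    return sum(kept) / len(kept)


# Hypothetical values: positions 1 and 3 are inputs, so their labels are -100.
token_logprobs = [-0.1, -2.0, -0.5, -3.0]
labels = [5, IGNORE_INDEX, 7, IGNORE_INDEX]
loss = masked_mean_nll(token_logprobs, labels)
print(round(loss, 4))  # prints 0.3: only positions 0 and 2 count, (0.1 + 0.5) / 2
```

However poorly the model predicts the masked positions, they do not change the loss. The sketch also omits the one-position shift between inputs and labels that causal language models apply during training.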

### You may not want prompt templates

However, there are many situations where you don't want to use one of these formats or templates (I usually don't!).  This is because they can:

- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can quickly become footguns if you don't include them correctly at inference time.
- Enforce a _chat_ interface when you do not want one. Sometimes you just want to fine-tune a model for a very specific task and do NOT want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.

### The `input_output` format

You can construct your prompts without a template by using the `input_output` format. To do so, set `type: input_output` in your configuration file like this:

```{.yaml filename="config.yml"}
train_on_inputs: false # Mask segments of your data
datasets:
  - path: output.jsonl
    type: input_output  # use template free prompt construction
```

Unlike `type: completion`, which is also template-free, `type: input_output` allows you to mask segments of your text. More details on how this works are described below.
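To make the difference concrete, here is a toy sketch of how segment-level masking can work. It is not axolotl's implementation: `toy_tokenize` is a hypothetical stand-in that treats each whitespace-separated word as one token, whereas a real tokenizer splits text into subwords. Each segment is tokenized and appended in order, and tokens from segments with `label: false` get the label `-100`. As a simplified comparison, `completion_labels` shows the completion-style alternative in which every token is a training target.

```python
IGNORE_INDEX = -100  # label value for masked (untrained) tokens


def toy_tokenize(text: str, vocab: dict[str, int]) -> list[int]:
    """Hypothetical tokenizer: one ID per whitespace-separated word, assigned on first sight."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


def input_output_example(segments: list[dict], vocab: dict[str, int]) -> tuple[list[int], list[int]]:
    """Concatenate segments in order and mask the labels of segments with label=False."""
    input_ids: list[int] = []
    labels: list[int] = []
    for seg in segments:
        ids = toy_tokenize(seg["text"], vocab)
        input_ids.extend(ids)
        labels.extend(ids if seg["label"] else [IGNORE_INDEX] * len(ids))
    return input_ids, labels


def completion_labels(input_ids: list[int]) -> list[int]:
    """Completion-style training: every token is a training target."""
    return list(input_ids)


segments = [
    {"label": True, "text": "<s>Hello\n"},
    {"label": True, "text": "hi there!. "},
    {"label": False, "text": "goodbye "},
    {"label": True, "text": "farewell</s>"},
]
vocab: dict[str, int] = {}
input_ids, labels = input_output_example(segments, vocab)
print(input_ids)                     # [0, 1, 2, 3, 4]
print(labels)                        # [0, 1, 2, -100, 4]
print(completion_labels(input_ids))  # [0, 1, 2, 3, 4]
```

Both variants see exactly the same input tokens; only the labels differ, and only the `goodbye ` segment is masked.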

## Usage

This is how you can use the `input_output` format:

### 1. Prepare Data

```python
import json

# Build one example row: a list of segments, each with a text and a label
# (label=False means the segment is masked during training)
data = {"segments": [
    {"label": True, "text": "<s>Hello\n"},
    {"label": True, "text": "hi there!. "},
    {"label": False, "text": "goodbye "},
    {"label": True, "text": "farewell</s>"}]
}

# Write 500 copies of the row to a JSONL file (one JSON object per line)
with open("output.jsonl", "w") as file:
    for _ in range(500):
        file.write(json.dumps(data) + "\n")
```

To use the `input_output` format, collect your data into a JSONL file in the following format (below is the first row of `output.jsonl`, pretty-printed):

```bash
head -n1 output.jsonl | python -m json.tool
```

```json
{
    "segments": [
        {
            "label": true,
            "text": "<s>Hello\n"
        },
        {
            "label": true,
            "text": "hi there!. "
        },
        {
            "label": false,
            "text": "goodbye "
        },
        {
            "label": true,
            "text": "farewell</s>"
        }
    ]
}
```


Set `label: false` when you want to mask a segment of text so that the model isn't trained on it. Some things to keep in mind:

1. **EOS, BOS, spaces, newlines etc. are entirely up to you.  Axolotl concatenates all the segments as-is.**  The tokenizer doesn't add anything additional.  Notice how I added spaces, newlines, `<s>` (BOS), and `</s>` (EOS) myself.
2. Make sure you check the materialized output to validate that the prompt is assembled the way you intend.
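Because the segments are concatenated verbatim, a quick local check before preprocessing can catch a missing space or a forgotten EOS. The helper below is a hypothetical sketch that uses only the standard library: it validates that a JSONL line has the `segments` structure shown above and returns both the full prompt and the concatenation of the `label: true` segments (the text whose tokens contribute to the loss). It checks structure only; axolotl's own validation may differ.

```python
import json


def assemble(line: str) -> tuple[str, str]:
    """Validate one JSONL line and return (full_text, trained_text)."""
    record = json.loads(line)
    segments = record.get("segments") if isinstance(record, dict) else None
    if not isinstance(segments, list) or not segments:
        raise ValueError("expected a non-empty 'segments' list")
    for i, seg in enumerate(segments):
        if (not isinstance(seg, dict)
                or not isinstance(seg.get("label"), bool)
                or not isinstance(seg.get("text"), str)):
            raise ValueError(f"segment {i} needs a boolean 'label' and a string 'text'")
    # Segments are joined as-is: no separators, BOS, or EOS are added here.
    full_text = "".join(seg["text"] for seg in segments)
    trained_text = "".join(seg["text"] for seg in segments if seg["label"])
    return full_text, trained_text


line = json.dumps({"segments": [
    {"label": True, "text": "<s>Hello\n"},
    {"label": True, "text": "hi there!. "},
    {"label": False, "text": "goodbye "},
    {"label": True, "text": "farewell</s>"},
]})
full_text, trained_text = assemble(line)
print(repr(full_text))     # '<s>Hello\nhi there!. goodbye farewell</s>'
print(repr(trained_text))  # '<s>Hello\nhi there!. farewell</s>'
```

To check a whole file, you could call `assemble` on every line of `output.jsonl` and print the first few results.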

### 2. Use `type: input_output`

Let's materialize the data in `output.jsonl` by setting `type: input_output` in our axolotl config:

```{.yaml filename="training_config.yaml"}
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49

datasets:
  - path: output.jsonl
    type: input_output
val_set_size: 0.1

sequence_len: 896
sample_packing: false

micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002

train_on_inputs: false
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```
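As a back-of-the-envelope check on this config, the sketch below estimates how many optimizer steps one epoch takes. It assumes a single GPU, that `val_set_size: 0.1` holds out 10% of the 500 generated rows (rounded up), that no rows are dropped or deduplicated, and that with `sample_packing: false` each micro-batch holds exactly `micro_batch_size` rows. axolotl's actual step count may differ.

```python
import math

num_rows = 500                   # rows written to output.jsonl
val_set_size = 0.1               # from training_config.yaml
micro_batch_size = 2             # rows per forward/backward pass
gradient_accumulation_steps = 3  # micro-batches per optimizer step

num_val = math.ceil(num_rows * val_set_size)
num_train = num_rows - num_val
effective_batch_size = micro_batch_size * gradient_accumulation_steps
micro_batches = math.ceil(num_train / micro_batch_size)
optimizer_steps = math.ceil(micro_batches / gradient_accumulation_steps)
print(num_train, effective_batch_size, optimizer_steps)  # 450 6 75
```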


You can use the following command to materialize your data. The `--debug` flag prints the tokens along with their labels, so you can verify that the correct items are being ignored:

```bash
python -m axolotl.cli.preprocess training_config.yaml --debug
```

```text
[2024-03-05 23:36:41,948] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-05 23:36:43,224] [INFO] [axolotl.normalize_config:178] [PID:607731] [RANK:0] GPU memory usage baseline: 0.000GB (+0.498GB misc)
[2024-03-05 23:36:43,725] [DEBUG] [axolotl.load_tokenizer:245] [PID:607731] [RANK:0] EOS: 2 / </s>
[2024-03-05 23:36:43,725] [DEBUG] [axolotl.load_tokenizer:246] [PID:607731] [RANK:0] BOS: 1 / <s>
[2024-03-05 23:36:43,725] [DEBUG] [axolotl.load_tokenizer:247] [PID:607731] [RANK:
```

If you look closely at the full output, you'll find the following line, which axolotl prints to help you debug prompt construction (because we used the `--debug` flag):

```{.md .code-overflow-wrap}
<s>(1, 1) Hello(22557, 22557) (13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```

The format is `decoded_token(label, token_id)`. For example, `<s>(1, 1)` means that the token is `<s>`, the label is `1`, and the token ID is `1`. When the label is `-100`, that token is ignored during training.
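If you want to check this output programmatically, you can parse the `decoded_token(label, token_id)` pairs with a regular expression. The sketch below is a hypothetical helper that works on the debug line exactly as printed above. It assumes decoded tokens contain no whitespace or parentheses, which holds for this example but not in general (note that the newline and space tokens appear with an empty decoded string here).

```python
import re

# Matches `decoded_token(label, token_id)`; the decoded token may be empty.
PAIR = re.compile(r"(\S*)\((-?\d+), (\d+)\)")


def parse_debug_line(line: str) -> list[tuple[str, int, int]]:
    """Return (decoded_token, label, token_id) for every pair in the debug line."""
    return [(tok, int(label), int(tid)) for tok, label, tid in PAIR.findall(line)]


debug_line = (
    "<s>(1, 1) Hello(22557, 22557) (13, 13) hi(12014, 12014) there(736, 736) "
    "!(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) "
    "(-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)"
)
pairs = parse_debug_line(debug_line)
masked = [(tok, tid) for tok, label, tid in pairs if label == -100]
print(len(pairs))  # 14
print(masked)      # [('good', 1179), ('bye', 17664), ('', 28705)]
```

The three masked tokens correspond to the `goodbye ` segment, including its trailing space.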


Here is another way to check the materialized output, which I personally like:

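```python
import os

import yaml
from datasets import load_from_disk
from transformers import AutoTokenizer

# The preprocess step saved the tokenized dataset in a subdirectory of last_run_prepared/
directory = sorted(os.listdir('last_run_prepared/'))

# Load the same tokenizer that axolotl used, based on the training config
with open('training_config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)

# Load the materialized dataset from disk
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```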

```python
row = ds[0]
print(tok.decode(row['input_ids']))
```

```text
<s> Hello
 hi there!.  goodbye  farewell</s>
```


We can check that the right tokens are ignored by comparing the labels to each token:

```python
import pandas as pd

# One row per token: the decoded token, its label, and its token ID
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id': i}
              for i, l in zip(row['input_ids'], row['labels'])])
```

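|   | token | label | id |
|---|-------|-------|----|
| 0 | `<s>` | 1 | 1 |
| 1 | Hello | 22557 | 22557 |
| 2 | `\n` | 13 | 13 |
| 3 | hi | 12014 | 12014 |
| 4 | there | 736 | 736 |
| 5 | ! | 28808 | 28808 |
| 6 | . | 28723 | 28723 |
| 7 |  | 28705 | 28705 |
| 8 | good | -100 | 1179 |
| 9 | bye | -100 | 17664 |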


If we look at the input data, the table above looks correct! (The JSONL row is repeated below for reference.)

```bash
head -n1 output.jsonl | python -m json.tool
```

```json
{
    "segments": [
        {
            "label": true,
            "text": "<s>Hello\n"
        },
        {
            "label": true,
            "text": "hi there!. "
        },
        {
            "label": false,
            "text": "goodbye "
        },
        {
            "label": true,
            "text": "farewell</s>"
        }
    ]
}
```
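Beyond eyeballing the table, you can group consecutive tokens by whether they are masked, which shows each trained or ignored span at a glance. This is a hypothetical helper, not part of axolotl. Adjacent segments with the same `label` merge into a single run, because the labels alone don't record segment boundaries. The token IDs below are copied from the `--debug` output shown earlier; with the objects loaded above, you could instead pass `row['input_ids']` and `row['labels']` and decode each run with `tok.decode`.

```python
from itertools import groupby

IGNORE_INDEX = -100  # label value for tokens excluded from the loss


def label_runs(input_ids: list[int], labels: list[int]) -> list[tuple[bool, list[int]]]:
    """Split token IDs into maximal runs of trained (True) or masked (False) tokens."""
    pairs = zip(input_ids, labels)
    return [
        (trained, [tid for tid, _ in group])
        for trained, group in groupby(pairs, key=lambda p: p[1] != IGNORE_INDEX)
    ]


# Token IDs and labels copied from the --debug output.
input_ids = [1, 22557, 13, 12014, 736, 28808, 28723, 28705, 1179, 17664, 28705, 19111, 5458, 2]
labels = [1, 22557, 13, 12014, 736, 28808, 28723, 28705, -100, -100, -100, 19111, 5458, 2]
for trained, ids in label_runs(input_ids, labels):
    print("trained" if trained else "masked ", ids)
```

For this row the helper yields three runs: the first two segments merged into one trained run, the masked `goodbye ` tokens, and the final trained `farewell</s>` run.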


## Resources

1. [Pull request](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1346) that added this feature.
2. Axolotl [debugging guide](https://hamel.dev/blog/posts/axolotl/).
3. Axolotl prompt construction [notes](https://github.com/hamelsmu/hamel-site/issues/11).