### Understand SFT

I see in [speedrun.sh](https://github.com/karpathy/nanochat/blob/master/speedrun.sh) the next thing after midtraining is to do supervised finetuning via [chat_sft.py](https://github.com/karpathy/nanochat/blob/master/scripts/chat_sft.py). See what that's about.

The training dataset looks similar to the one from midtraining. Compare.

Midtraining:

```
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
])
```

SFT:

```
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
])
```

```
                            mid?            sft?
---------------------------------------------------
SmolTalk                    yes (460K)      yes (10K)
MMLU                        yes             no
ARC-Easy                    no              yes
ARC-Challenge               no              yes
GSM8K                       yes             yes
identity                    yes             yes
SimpleSpelling              yes             yes
SpellingBee                 yes             yes
```

Does anything jump out about why MMLU is only in midtraining, ARC is only in SFT, and SmolTalk drops from 460K rows to 10K?

Look at challenge-26-understand-midtrain/midtrain-data-examples.ipynb. I don't have ARC in there. Make a new notebook in this challenge with that: `sft-data-examples.ipynb`

Not sure. From looking at very few examples, the SmolTalk and MMLU conversations have much more text than ARC. ARC is all multiple choice but so is MMLU.

Keep looking through script. Maybe it will become clear.

I see now we're finally going to get into using those masks, so the model won't be learning to generate user stuff or python output. (See notes under `Tokenizer.render_conversation()` in `challenge-26-understand-midtrain/understand-midtrain.ipynb`.)

And I see when we call `F.cross_entropy()` in GPT we pass `ignore_index=-1` so a -1 in our target list of tokens will mean to ignore that prediction when calculating loss. For example:

In [10]:
import torch
import torch.nn.functional as F
# expect these two to be the same
a = F.cross_entropy(torch.tensor([[0.1, 0.2],[0.3, 0.4]]), torch.tensor([0,-1], dtype=torch.long), ignore_index=-1)
b = F.cross_entropy(torch.tensor([[0.1, 0.2]]), torch.tensor([0], dtype=torch.long))
a,b

(tensor(0.7444), tensor(0.7444))

So, repeating an example like the one from `challenge-26-understand-midtrain/understand-midtrain.ipynb`, but now using -1 and clearer fake token numbering, this is what we want:

```
ids to train on: 3  4  1  5  6  2  7
mask:            0  0  0  0  0  1  1

say 3 = bos
    4 = user_start
    1 = a token from the user
    5 = user_end
    6 = assistant_start
    2 = a token from assistant
    7 = assistant_end

Our "normal" target would be: 4  1  5  6  2  7  ?

However, we're only trying to learn how to predict 2 and 7 so we don't want to count the other predictions in our loss.

So our modified target is: -1 -1 -1 -1  2  7 -1 

Seeing it all together:

inputs:   3  4  1  5  6  2  7
targets: -1 -1 -1 -1  2  7 -1 
```

### sft_data_generator()

Hand copy sft_data_generator() to `my_chat_sft.py` and see what else besides masking, if anything, is different than in mid train.

In [9]:
import sys
sys.path.append('../my_nanochat')
from scripts.my_chat_sft import sft_data_generator, train_ds
from my_nanochat.my_tokenizer import get_tokenizer

In [6]:
inputs, targets = next(sft_data_generator(train_ds, batch_size=1))

In [7]:
inputs.shape, targets.shape

(torch.Size([1, 329]), torch.Size([1, 329]))

In [13]:
# expect one conversation (document) in the input, expect to see lots of -1s in the target
tokenizer = get_tokenizer()
tokenizer.decode(inputs[0].tolist())

'<|bos|><|user_start|>A woman is trying to decide whether it will be quicker to take an airplane or drive herself to a job interview. If she drives herself, the trip will take her 3 hours and 15 minutes.  If she takes an airplane, she will first need to drive 10 minutes to the airport, and then wait 20 minutes to board the plane.  After that, she will be on the airplane for one-third of the time it would have taken her to drive herself before landing in the destination city. Finally, it will take her an additional 10 minutes to get off the airplane and arrive at her interview site after the plane lands.  Given this information, how many minutes faster is it for her to take the airplane?<|user_end|><|assistant_start|>First, we must convert the driving time of 3 hours and 15 minutes to minutes. Since there are 60 minutes in an hour, driving takes a total of 3*60 + 15 = <|python_start|>3*60+15<|python_end|><|output_start|>195<|output_end|>195 minutes.\nNext, the woman will be on the airpl

In [14]:
targets[0]

tensor([   -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1, 

^ Those -1s in the 2nd ~half must be for things like <|output_start|>195<|output_end|> 

The very last target, 65531, is not -1 as it was in the example I came up with by hand, and I'm not sure why.

First sanity check that the -1s are what I think. If so, 65533 should be <|python_end|>

In [16]:
tokenizer.decode([65533])

'<|python_end|>'

Now what's that last 65531?

In [17]:
tokenizer.decode([65531])

'<|assistant_end|>'

I get it. In my tiny hand example above I left the final id in the input, but in reality we remove it since we're not trying to predict from it and we have nothing to check it against. So I should have done this:

```
ids to train on: 3  4  1  5  6  2  7
mask:            0  0  0  0  0  1  1

inputs:          3  4  1  5  6  2
targets:        -1 -1 -1 -1  2  7 
```
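The corrected shift can be written out as a few lines of code. This is a minimal sketch using plain Python lists instead of tensors so it runs without torch; the function name `make_inputs_and_targets` is made up for illustration and is not from `chat_sft.py`.

```python
def make_inputs_and_targets(ids, mask, ignore_index=-1):
    """Shift by one position and hide every target whose mask is 0."""
    inputs = ids[:-1]        # drop the last id: nothing follows it to predict
    shifted_ids = ids[1:]    # the target at position i is the id at position i + 1
    shifted_mask = mask[1:]  # the mask travels with the id it describes
    targets = [t if m == 1 else ignore_index
               for t, m in zip(shifted_ids, shifted_mask)]
    return inputs, targets


ids = [3, 4, 1, 5, 6, 2, 7]
mask = [0, 0, 0, 0, 0, 1, 1]
inputs, targets = make_inputs_and_targets(ids, mask)
print(inputs)   # [3, 4, 1, 5, 6, 2]
print(targets)  # [-1, -1, -1, -1, 2, 7]
```

Note that the mask is shifted along with the ids, so a target is kept only when the token being predicted is one the assistant produced.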

So this seems different from mid training and base training in 2 ways:

- We ignore the loss on certain predictions (e.g. user tokens, python output)

- Each row of the batch is a complete conversation (doc). For example, if we have a batch of size 3 corresponding to 3 conversations, and the first is 4 tokens, the second is 5 tokens, and the third is 6 tokens, we'll get inputs like this (one less token for the reason explained above):

```
T T T X X
T T T T X 
T T T T T 
```

Where T is a token from the conversation and X is whatever we pad with. We're willing to waste 3 tokens of processing.
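The padded layout above can be sketched in plain Python. This is a rough illustration, not the code from `chat_sft.py`: the name `pad_batch`, the pad id of 0, and the fake token ids are all assumptions. Padded target positions are set to -1 so they never count toward the loss.

```python
def pad_batch(conversations, pad_id, ignore_index=-1):
    """Build padded inputs and targets from a list of (ids, mask) pairs."""
    # Each row loses one token to the shift, so the width is longest - 1.
    width = max(len(ids) for ids, _ in conversations) - 1
    inputs, targets = [], []
    for ids, mask in conversations:
        n_pad = width - (len(ids) - 1)
        row_targets = [t if m == 1 else ignore_index
                       for t, m in zip(ids[1:], mask[1:])]
        inputs.append(ids[:-1] + [pad_id] * n_pad)
        targets.append(row_targets + [ignore_index] * n_pad)
    return inputs, targets


# Three fake conversations of 4, 5 and 6 tokens (3 = bos, 6 = assistant_start,
# 2 = assistant token, 7 = assistant_end in this made-up numbering).
conversations = [
    ([3, 6, 2, 7], [0, 0, 1, 1]),
    ([3, 6, 2, 2, 7], [0, 0, 1, 1, 1]),
    ([3, 6, 2, 2, 2, 7], [0, 0, 1, 1, 1, 1]),
]
inputs, targets = pad_batch(conversations, pad_id=0)
for row in inputs:
    print(row)
# [3, 6, 2, 0, 0]
# [3, 6, 2, 2, 0]
# [3, 6, 2, 2, 2]
wasted = sum(row.count(0) for row in inputs)
print(wasted)  # 3
```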

But in mid training and base training, we just pack them all in and would move on to the next conversation (doc):

```
T T T T T T
T T T T T T
T T T T T T
```

In mid training, the batch size and sequence length are specified ahead of time and stay constant. In SFT training, each batch will be the width of the longest conversation (doc). (How is the batch size picked? Not sure, maybe will see that later. Also, do we do anything to group conversations of similar lengths?)
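For contrast, here is a rough sketch of the packing approach as I understand it: concatenate every conversation into one token stream and cut it into fixed-size rows, ignoring document boundaries. The function name `pack_rows` and the fake token ids are assumptions for illustration; the real data loaders stream tokens rather than building one big list.

```python
def pack_rows(conversations, batch_size, seq_len):
    """Cut a concatenated token stream into batch_size rows of seq_len tokens."""
    stream = [tok for ids in conversations for tok in ids]
    needed = batch_size * seq_len + 1  # +1 so the last input has a target
    if len(stream) < needed:
        raise ValueError("not enough tokens to fill one batch")
    chunk = stream[:needed]
    flat_inputs, flat_targets = chunk[:-1], chunk[1:]
    inputs = [flat_inputs[r * seq_len:(r + 1) * seq_len] for r in range(batch_size)]
    targets = [flat_targets[r * seq_len:(r + 1) * seq_len] for r in range(batch_size)]
    return inputs, targets


# 3 = bos in this made-up numbering; other ids mark which conversation a token is from.
conversations = [[3, 1, 1, 1], [3, 2, 2, 2, 2], [3, 4, 4, 4, 4, 4]]
inputs, targets = pack_rows(conversations, batch_size=3, seq_len=3)
print(inputs)   # [[3, 1, 1], [1, 3, 2], [2, 2, 2]]
print(targets)  # [[1, 1, 1], [3, 2, 2], [2, 2, 3]]
```

No tokens are wasted on padding, but the second row has a bos in the middle and the third row starts partway through a conversation.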

Why does it work this way? This gets back to the question I had in `challenge-09-understand-model-input/understand-model-input.ipynb` where I was surprised we **didn't** line up by "something" at the start of each row and instead just packed the tokens in without regard to document boundary.

Maybe it just doesn't matter when we're learning more basic stuff and we don't want to waste any processing power. But now that we're fine-tuning, we need the model to see full conversations that include all of the user_start, user_end, assistant_start, assistant_end, etc. tokens. We don't want to cut any conversations off in the middle, and we don't want to train on conversations that start in the middle. Maybe (even less sure) it's also helpful to have the beginning with bos placed exactly at the start of the sequence for the positional embedding stuff to work a little better???

Just to sanity check I understand the padding, let's also do a batch size of 2.

In [19]:
inputs, targets = next(sft_data_generator(train_ds, batch_size=2))
inputs.shape, targets.shape

(torch.Size([2, 1107]), torch.Size([2, 1107]))

In [21]:
inputs

tensor([[65527, 65528,    65,  ..., 65531, 65531, 65531],
        [65527, 65528,  1708,  ...,   309,  1058,    46]], device='mps:0')

^ Yes, you can see the 2nd conversation is longer and so the end of the first one is padded.

In [23]:
tokenizer.decode(inputs[0].tolist())

'<|bos|><|user_start|>A woman is trying to decide whether it will be quicker to take an airplane or drive herself to a job interview. If she drives herself, the trip will take her 3 hours and 15 minutes.  If she takes an airplane, she will first need to drive 10 minutes to the airport, and then wait 20 minutes to board the plane.  After that, she will be on the airplane for one-third of the time it would have taken her to drive herself before landing in the destination city. Finally, it will take her an additional 10 minutes to get off the airplane and arrive at her interview site after the plane lands.  Given this information, how many minutes faster is it for her to take the airplane?<|user_end|><|assistant_start|>First, we must convert the driving time of 3 hours and 15 minutes to minutes. Since there are 60 minutes in an hour, driving takes a total of 3*60 + 15 = <|python_start|>3*60+15<|python_end|><|output_start|>195<|output_end|>195 minutes.\nNext, the woman will be on the airpl

In [24]:
tokenizer.decode(inputs[1].tolist())

"<|bos|><|user_start|>How does standard deviation measure variability or dispersion of a data set? Can you describe a data set that would have a low standard deviation and another that would have a high standard deviation?<|user_end|><|assistant_start|>Standard deviation measures the variability or dispersion of a data set by quantifying how spread out the individual data points are from the mean value. A low standard deviation indicates that the data points tend to be close to the mean, while a high standard deviation indicates that the data points are more spread out.\n\nA data set with a low standard deviation would be one where the values are relatively consistent and do not deviate much from the mean. For example, consider the scores of a group of students on a math test, where most students scored between 80 and 90, with a mean score of 85. In this case, the standard deviation would be low, perhaps around 2-3 points, indicating that the scores are clustered close to the mean.\n\n

^ wow, most of the tokens in the first row of the batch are wasted

Looking ahead in the code, I see that batch_size is in fact fixed for the whole run. That could mean that each step varies quite a bit in terms of number of tokens. It also means we need to worry about a conversation being too long, but maybe no conversation in the training data is? (At least for training on something like an H100?)

In fact, now copying the rest of the code, I see he defaults device_batch_size to 4 with a comment "max to avoid OOM." I'm not sure whether adjusting the batch size per step and/or organizing conversations to maximize GPU use costs too much compute to be worthwhile, or whether it just wasn't done in this code. One way to do it would be to maintain a small buffer of tokenized conversations; each yield would choose conversations of similar length and return a batch with an appropriate but varying number of rows. However, I wonder whether efficient training needs a compiled model, and whether compiled models expect or do best with a fixed batch shape. If the batch keeps changing shape, then it needs to rejigger which operations get done to which memory.
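To make the buffer idea concrete, here is a hypothetical sketch (nothing like this is in the code I'm copying). It sorts a buffer of conversation lengths and greedily fills batches under a token budget, so the number of rows varies per batch. The names `bucket_batches`, `padding_waste` and `max_tokens_per_batch` are made up; the lengths 329 and 1107 are the two widths seen above and the rest are invented.

```python
def bucket_batches(lengths, max_tokens_per_batch):
    """Group conversation indices so each padded batch stays within a token budget."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current = [], []
    for i in order:
        # Sorted ascending, so lengths[i] would be the padded width after adding i.
        if current and (len(current) + 1) * lengths[i] > max_tokens_per_batch:
            batches.append(current)
            current = []
        current.append(i)
    if current:
        batches.append(current)
    return batches


def padding_waste(lengths, batches):
    """Total pad tokens when each batch is padded to its longest row."""
    waste = 0
    for batch in batches:
        width = max(lengths[i] for i in batch)
        waste += sum(width - lengths[i] for i in batch)
    return waste


lengths = [329, 1107, 310, 1050, 95, 120]
fixed = [[0, 1], [2, 3], [4, 5]]  # fixed batch size of 2, original order
bucketed = bucket_batches(lengths, max_tokens_per_batch=2400)
print(bucketed)                          # [[4, 5, 2, 0], [3, 1]]
print(padding_waste(lengths, fixed))     # 1543
print(padding_waste(lengths, bucketed))  # 519
```

A conversation longer than the budget still ends up alone in its own batch, so this doesn't solve the too-long problem by itself. Sorting also makes the order within the buffer less random, which may or may not matter for training.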

Now seeing that he commented this line out: `# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs` makes me realize that, even as currently implemented, the batch shape isn't fixed: the height is, but not the width. So maybe it would help to organize the conversations by similar length.

### Rest of code

Fill out the rest of `my_chat_sft.py` with a combination of copying and pasting from `my_mid_train.py` and hand copying from his chat_sft.py

From this comment `the number of "active" tokens of supervision seen` I'm now thinking supervised / supervision in SFT relates to masking, only certain tokens "supervise" the training or the "loss". Something like that.
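My reading of "active" tokens is the targets that are not -1. A minimal sketch of counting them, using plain lists and a made-up function name `count_active_tokens`; with tensors I'd expect something like `(targets != -1).sum()` to give the same count.

```python
def count_active_tokens(targets, ignore_index=-1):
    """Count target positions that actually contribute to the loss."""
    return sum(1 for row in targets for t in row if t != ignore_index)


# A made-up padded batch of targets: -1 marks user tokens and padding.
targets = [
    [-1, 2, 7, -1, -1],
    [-1, 2, 2, 7, -1],
    [-1, 2, 2, 2, 7],
]
active = count_active_tokens(targets)
total = sum(len(row) for row in targets)
print(active, total)  # 9 15
```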

While doing this, I noticed I left out updating the muon_momentum in mid_train. Maybe I should redo mid train before doing sft.

added code, try, though wonder if it will OOM on my mac

In [26]:
import os
os.environ["PYTHONPATH"] = "../my_nanochat"

In [31]:
!python -m scripts.my_chat_sft \
    --model_tag=d4 \
    --num_iterations=10 \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --eval_every=5 \
    --eval_steps=10 \
    --eval_metrics_every=5 \
    --eval_metrics_max_problems=2

overriding model_tag = d4
overriding num_iterations = 10
overriding device_batch_size = 1
overriding target_examples_per_step = 4
overriding eval_every = 5
overriding eval_steps = 10
overriding eval_metrics_every = 5
overriding eval_metrics_max_problems = 2
user_config: {'run': 'dummy', 'source': 'mid', 'device_type': '', 'dtype': 'bfloat16', 'device_batch_size': 1, 'num_epochs': 1, 'num_iterations': 10, 'target_examples_per_step': 4, 'unembedding_lr': 0.004, 'embedding_lr': 0.2, 'matrix_lr': 0.02, 'weight_decay': 0.0, 'init_lr_frac': 0.02, 'eval_every': 5, 'eval_steps': 10, 'eval_metrics_every': 5, 'eval_metrics_max_problems': 2}
Autodetected device type: mps
loading the model from /Users/ericsilberstein/.cache/my_nanochat/mid_checkpoints/d4 with step 9
Building model with config: {'sequence_len': 128, 'vocab_size': 65536, 'n_layer': 4, 'n_head': 2, 'n_kv_head': 2, 'n_embd': 256}
Target examples per step: 4
Device batch size: 1
Examples per step is device_batch_size * ddp_world_size: 

^ it's failing because the conversations are longer than 10 times the sequence length. Maybe just to test on my mac I should make something to skip conversations above a certain size.
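One way that skipping could look, as a hypothetical sketch: wrap whatever yields tokenized conversations in a generator that drops the long ones. The name `skip_long_conversations` and its `max_tokens` parameter are invented for illustration, and the (ids, mask) pairs are fake data.

```python
def skip_long_conversations(conversations, max_tokens):
    """Yield only the (ids, mask) pairs that fit within max_tokens."""
    for ids, mask in conversations:
        if len(ids) <= max_tokens:
            yield ids, mask
        # Longer conversations are silently dropped (fine for a smoke test).


# Fake conversations of 5, 12 and 8 tokens.
fake = [([1] * 5, [0] * 5), ([1] * 12, [0] * 12), ([1] * 8, [0] * 8)]
kept = list(skip_long_conversations(fake, max_tokens=8))
print([len(ids) for ids, _ in kept])  # [5, 8]
```

Dropping data like this changes what the model trains on, so it only makes sense for checking that the script runs on a small machine.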

In [34]:
!python -m scripts.my_chat_sft \
    --model_tag=d4 \
    --num_iterations=10 \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --eval_every=5 \
    --eval_steps=10 \
    --eval_metrics_every=5 \
    --eval_metrics_max_problems=2 \
    --max_data_tokens=1280

overriding model_tag = d4
overriding num_iterations = 10
overriding device_batch_size = 1
overriding target_examples_per_step = 4
overriding eval_every = 5
overriding eval_steps = 10
overriding eval_metrics_every = 5
overriding eval_metrics_max_problems = 2
overriding max_data_tokens = 1280
user_config: {'run': 'dummy', 'source': 'mid', 'device_type': '', 'dtype': 'bfloat16', 'device_batch_size': 1, 'num_epochs': 1, 'num_iterations': 10, 'max_data_tokens': 1280, 'target_examples_per_step': 4, 'unembedding_lr': 0.004, 'embedding_lr': 0.2, 'matrix_lr': 0.02, 'weight_decay': 0.0, 'init_lr_frac': 0.02, 'eval_every': 5, 'eval_steps': 10, 'eval_metrics_every': 5, 'eval_metrics_max_problems': 2}
Autodetected device type: mps
loading the model from /Users/ericsilberstein/.cache/my_nanochat/mid_checkpoints/d4 with step 9
Building model with config: {'sequence_len': 128, 'vocab_size': 65536, 'n_layer': 4, 'n_head': 2, 'n_kv_head': 2, 'n_embd': 256}
Target examples per step: 4
Device batch size: 

^ ok, completed. Guess none of the 2 MMLU and 2 ARC conversations were too long.

Now try without num_iterations and cancel.

In [35]:
!python -m scripts.my_chat_sft \
    --model_tag=d4 \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --eval_every=5 \
    --eval_steps=10 \
    --eval_metrics_every=5 \
    --eval_metrics_max_problems=2 \
    --max_data_tokens=1280

overriding model_tag = d4
overriding device_batch_size = 1
overriding target_examples_per_step = 4
overriding eval_every = 5
overriding eval_steps = 10
overriding eval_metrics_every = 5
overriding eval_metrics_max_problems = 2
overriding max_data_tokens = 1280
user_config: {'run': 'dummy', 'source': 'mid', 'device_type': '', 'dtype': 'bfloat16', 'device_batch_size': 1, 'num_epochs': 1, 'num_iterations': -1, 'max_data_tokens': 1280, 'target_examples_per_step': 4, 'unembedding_lr': 0.004, 'embedding_lr': 0.2, 'matrix_lr': 0.02, 'weight_decay': 0.0, 'init_lr_frac': 0.02, 'eval_every': 5, 'eval_steps': 10, 'eval_metrics_every': 5, 'eval_metrics_max_problems': 2}
Autodetected device type: mps
loading the model from /Users/ericsilberstein/.cache/my_nanochat/mid_checkpoints/d4 with step 9
Building model with config: {'sequence_len': 128, 'vocab_size': 65536, 'n_layer': 4, 'n_head': 2, 'n_kv_head': 2, 'n_embd': 256}
Target examples per step: 4
Device batch size: 1
Examples per step is device_b

Code added as part of this challenge:

- `my_chat_sft.py`