## Understand reinforcement learning

Look at [chat_rl.py](https://github.com/karpathy/nanochat/blob/master/scripts/chat_rl.py)

Well, his initial comment is:

```
Reinforcement learning on GSM8K via "GRPO".

I put GRPO in quotes because we actually end up with something a lot
simpler and more similar to just REINFORCE:

1) Delete trust region, so there is no KL regularization to a reference model
2) We are on policy, so there's no need for PPO ratio+clip.
3) We use GAPO style normalization that is token-level, not sequence-level.
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
```

I've never heard of GRPO, I have no idea what a trust region is, what it means to be on policy, or what GAPO is.

I see from skimming ahead that we're going to get into the task reward code I copied earlier.

I'll come back to looking up some of the terms in that comment. First, hand copy `run_gsm8k_eval()` into this notebook to understand it.

### run_gsm8k_eval()

```python
ddp_rank = 0
ddp_world_size = 1
device_batch_size = 1

def run_gsm8k_eval(task, tokenizer, engine,
                   max_examples=None,
                   num_samples=1,
                   max_completion_tokens=256,
                   temperature=0.0,
                   top_k=50):
    max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
    for idx in range(ddp_rank, max_examples, ddp_world_size):
        conversation = task[idx]
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)
        # the original notes a loop could be added if this doesn't hold; it won't on my Mac (device_batch_size = 1)
        assert num_samples <= device_batch_size
        generated_token_sequences, masks = engine.generate_batch(tokens,
                                                                 num_samples=num_samples,
                                                                 max_tokens=max_completion_tokens,
                                                                 temperature=temperature,
                                                                 top_k=top_k)
        outcomes = []
        for sample_tokens in generated_token_sequences:
            generated_tokens = sample_tokens[prefix_length:]
            generated_text = tokenizer.decode(generated_tokens)
            is_correct = task.evaluate(conversation, generated_text)
            outcomes.append({
                'is_correct': is_correct
            })

        record = {
            'idx': idx,
            'outcomes': outcomes,
        }
        yield record
```

^ OK, this part is very similar to how things work in chat_eval, but instead of checking whether any sample is correct, it records correct or not correct for each sample.

As a quick test on my Mac I'll have to set num_samples to 1, though (or add a loop).
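Here's roughly what such a loop could look like, as a sketch. `generate_in_chunks` and `StubEngine` are hypothetical names I made up, and the stub only mimics the `(sequences, masks)` return shape of `engine.generate_batch` used above. `get_batch()` below does essentially this with `num_sampling_steps`, also passing a different seed per chunk.

```python
def generate_in_chunks(engine, tokens, num_samples, device_batch_size, **gen_kwargs):
    """Hypothetical helper: collect num_samples completions, device_batch_size at a time."""
    assert num_samples % device_batch_size == 0, "keep it simple: require an exact multiple"
    sequences, masks = [], []
    for _ in range(num_samples // device_batch_size):
        seq_batch, mask_batch = engine.generate_batch(tokens, num_samples=device_batch_size, **gen_kwargs)
        sequences.extend(seq_batch)
        masks.extend(mask_batch)
    return sequences, masks


class StubEngine:
    """Hypothetical stand-in for Engine: appends one fake token to the prompt per sample."""
    def __init__(self):
        self.calls = 0

    def generate_batch(self, tokens, num_samples=1, **gen_kwargs):
        self.calls += 1
        sequences = [tokens + [100 + i] for i in range(num_samples)]
        masks = [[0] * len(tokens) + [1] for _ in range(num_samples)]  # 0 = prompt, 1 = sampled
        return sequences, masks


stub = StubEngine()
sequences, masks = generate_in_chunks(stub, [1, 2, 3], num_samples=4, device_batch_size=1, max_tokens=8)
print(len(sequences), stub.calls)  # 4 4
```

With `device_batch_size=1` this makes four separate `generate_batch` calls, which is the price of running on a small device.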

```python
import sys
sys.path.append('../my_nanochat')
from my_nanochat.my_checkpoint_manager import load_model
from my_nanochat.my_common import compute_init, autodetect_device_type
from my_nanochat.my_engine import Engine
from my_tasks.my_gsm8k import MyGSM8K
device_type = autodetect_device_type()
_, _, _, _, device = compute_init(device_type)
model, tokenizer, meta_data = load_model('sft', model_tag='d1', device=device, phase='eval')
engine = Engine(model, tokenizer)
task = MyGSM8K(subset="main", split="train")
```

```
Autodetected device type: mps
loading the model from /Users/ericsilberstein/.cache/my_nanochat/chatsft_checkpoints/d1 with step 9
Building model with config: {'sequence_len': 256, 'vocab_size': 65536, 'n_layer': 1, 'n_head': 1, 'n_kv_head': 1, 'n_embd': 64}
```


```python
next(run_gsm8k_eval(task, tokenizer, engine, max_examples=2, num_samples=1))
```

```
{'idx': 0, 'outcomes': [{'is_correct': 0}]}
```

### get_batch()

See what that gives us.

```python
import itertools

i = 0
for n in itertools.cycle([1, 3, 5]):
    print(n)
    i += 1
    if i == 10:
        break
```

```
1
3
5
1
3
5
1
3
5
1
```


```python
import sys
sys.path.append('../my_nanochat')
import torch
from my_nanochat.my_checkpoint_manager import load_model
from my_nanochat.my_common import compute_init, autodetect_device_type
from my_nanochat.my_engine import Engine
from my_tasks.my_gsm8k import MyGSM8K
device_type = autodetect_device_type()
_, _, _, _, device = compute_init(device_type)
model, tokenizer, meta_data = load_model('sft', model_tag='d1', device=device, phase='eval')
engine = Engine(model, tokenizer)
train_task = MyGSM8K(subset="main", split="train")

# get_batch() relies on globals from earlier cells (itertools, ddp_rank, ddp_world_size,
# device_batch_size) and from a later cell (num_samples, step, max_new_tokens,
# temperature, top_k, autocast_ctx), just like the module-level config in the script.
@torch.no_grad()
def get_batch():
    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    rank_indices = range(ddp_rank, len(train_task), ddp_world_size)
    for example_idx in itertools.cycle(rank_indices):
        conversation = train_task[example_idx]
        tokens = tokenizer.render_for_completion(conversation)
        prefix_len = len(tokens)

        model.eval() # this is pretty different, we're going to use the model in generating a batch
        generated_token_sequences = []
        masks = []
        num_sampling_steps = num_samples // device_batch_size
        for sampling_step in range(num_sampling_steps):
            seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
            with autocast_ctx:
                generated_token_sequences_batch, masks_batch = engine.generate_batch(
                    tokens,
                    num_samples=device_batch_size,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    seed=seed,
                )
            generated_token_sequences.extend(generated_token_sequences_batch)
            masks.extend(masks_batch)

        rewards = []
        for sample_tokens in generated_token_sequences:
            generated_tokens = sample_tokens[prefix_len:]
            generated_text = tokenizer.decode(generated_tokens)
            reward = train_task.reward(conversation, generated_text) # 1 or 0 right?
            rewards.append(reward)

        max_len = max(len(seq) for seq in generated_token_sequences)
        padded_generated_token_sequences = [seq + [assistant_end] * (max_len - len(seq)) for seq in generated_token_sequences]
        padded_masks = [mask + [0] * (max_len - len(mask)) for mask in masks]

        ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
        mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)

        inputs = ids[:, :-1]
        targets = ids[:, 1:].clone()
        targets[mask_ids[:, 1:] == 0] = -1
        rewards = torch.tensor(rewards, dtype=torch.float, device=device)

        mu = rewards.mean()
        advantages = rewards - mu

        yield generated_token_sequences, inputs, targets, rewards, advantages
```

```
Autodetected device type: mps
loading the model from /Users/ericsilberstein/.cache/my_nanochat/chatsft_checkpoints/d1 with step 9
Building model with config: {'sequence_len': 256, 'vocab_size': 65536, 'n_layer': 1, 'n_head': 1, 'n_kv_head': 1, 'n_embd': 64}
```


Before running it or thinking super carefully, my sense is that we're going to compute the loss from how well the model predicts the tokens it just generated, weighted somehow to favor the completions that were correct. I can sort of see why this would work: the model already has the "ability" to generate that good completion, and backprop here will adjust the weights to make it more likely. I don't understand, though, why the inputs/targets only contain the part after the shared prefix. To take an extreme (and possibly inappropriate) example, if this were a different task where the user asked a multiple-choice question and told the assistant to respond only with A, B, C, or D, how would this work?

But back to this GSM8K case, looking at `challenge-27-understand-chat-eval/chat-eval-data-examples.ipynb` to remember exactly what they look like...

Oh, wait: cutting out the prefix is only for calculating the reward for each sample. Each row of the batch contains the whole sequence, prompt included.

Suppose our initial prompt is:

`<|user_start|>Mary Darrell and Allen's ages are in the ratio of 7:11. If their total age now is 162, calculate Allen's age 10 years from now.<|user_end|><|assistant_start|>`

We generate 16 samples, resulting in these completions:

- ...a bunch of reasoning and calculating...### 109

- ...some other bunch of reasoning and calculating...### 7

- ...yet another bunch of reasoning and calculating...### 109

- 13 other samples

Samples 1 and 3 gave the right answer, sample 2 gave the wrong answer, and say the other 13 were wrong too. There's no guarantee that 1 and 3 are right for the right reason, but the others are definitely wrong, so our goal is to adjust the weights to make the model more likely to generate completions like 1 and 3, and less likely to generate the others, in similar situations in the future. We're reinforcing goodness. I think that's the idea.

And just as in SFT, we only want to learn to predict the tokens the assistant is supposed to write, not the user part of the conversation or the Python output.
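A toy version of how full rows (prompt included) become `inputs` and `targets`, mirroring the padding and masking lines in `get_batch()` above. The token ids are made up, `99` is a hypothetical stand-in for `<|assistant_end|>`, and I'm assuming the mask is 0 for prompt tokens and 1 for tokens the assistant sampled.

```python
import torch

# Hypothetical token ids: a 3-token prompt followed by each sample's completion.
seqs = [[1, 2, 3, 10, 11, 99], [1, 2, 3, 12, 99]]
masks = [[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1]]  # 0 = prompt, 1 = sampled by the assistant
pad_id = 99  # stand-in for <|assistant_end|>

max_len = max(len(s) for s in seqs)
ids = torch.tensor([s + [pad_id] * (max_len - len(s)) for s in seqs])
mask_ids = torch.tensor([m + [0] * (max_len - len(m)) for m in masks])  # padding gets mask 0

inputs = ids[:, :-1]                # every row still contains the whole prompt
targets = ids[:, 1:].clone()        # next-token targets
targets[mask_ids[:, 1:] == 0] = -1  # only assistant-written tokens are learned
print(targets.tolist())  # [[-1, -1, 10, 11, 99], [-1, -1, 12, 99, -1]]
```

The prompt is still in `inputs`, so the model sees it as context, but those positions never appear as targets.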

But how will we adjust the loss? Going back to the example above, say the rewards are as follows:

- 1.0
- 0
- 1.0
- 0 for the 13 others

I guess there are lots of ways. For example, we could normalize the rewards and then multiply each row's average cross-entropy loss by its reward and some constant.

"Advantages" seems like a clue. Looking at the last two lines before the `yield` in `get_batch()`, for this example we'll end up with

```
rewards = [1,0,1,0...0]
mu = 2 / 16 = 0.125
advantages = [0.875, -0.125, 0.875, -0.125...-0.125]
```
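Checking that arithmetic with PyTorch, and comparing it with the z-score version that item 4 of the nanochat comment says it drops. Using `Tensor.std()` with its default unbiased (n-1) estimate is my choice for illustration.

```python
import torch

# the 16-sample example: samples 1 and 3 correct, the other 14 wrong
rewards = torch.tensor([1.0, 0.0, 1.0] + [0.0] * 13)
mu = rewards.mean()
advantages = rewards - mu        # what get_batch() computes
print(mu.item())                 # 0.125
print(advantages[:3].tolist())   # [0.875, -0.125, 0.875]

# z-score normalization, which the comment says is NOT used
z_advantages = (rewards - mu) / rewards.std()

# if every sample gets the same reward (e.g. all wrong), sigma is 0
same = torch.zeros(4)
print((same - same.mean()).tolist())                                  # [0.0, 0.0, 0.0, 0.0]
print(torch.isnan((same - same.mean()) / same.std()).all().item())    # True
```

With all-equal rewards the mean-only version quietly gives zero advantages, so that prompt contributes nothing to the gradient, while dividing by sigma gives 0/0 unless an epsilon is added to the denominator.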

So maybe the actual loss calculation just adds each row's advantage to its loss, or adds it times a constant.

Try `get_batch()`:

```python
num_samples = 4
step = 1
max_new_tokens = 128
temperature = 1.0
top_k = 50
from contextlib import nullcontext
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
generated_token_sequences, inputs, targets, rewards, advantages = next(get_batch())
```

```python
(len(generated_token_sequences),
 len(generated_token_sequences[0]),
 len(generated_token_sequences[1]),
 len(generated_token_sequences[2]),
 len(generated_token_sequences[3]))
```

```
(4, 90, 182, 109, 120)
```

```python
inputs.shape
```

```
torch.Size([4, 181])
```

In [66]:
tokenizer.decode(inputs[0].tolist())

'<|bos|><|user_start|>Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|user_end|><|assistant_start|> not about of specific may,.\n is, about isTo bayonet not of of  to more of while and about will it what which, and on of.\n\n\n the,<|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_end|><|assistant_e

```python
tokenizer.decode(inputs[1].tolist())
```

```
'<|bos|><|user_start|>Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?<|user_end|><|assistant_start|> also point\n a their The andTo teller, that are, one point is can is willTo expensive will it will and as be can not and to the you.s\n that the will to can be while can.\n will other which in the, the each but to thatHowever, on a the the can also be, about that if can on that. to a of may. is the about\n more of a that on also one of that but of it and each. would as while a and a the,, but also can and or as a be a\n and\n\nHowever, may not.\n\n would with'
```

```python
targets[1]
```

```
tensor([   -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,   543,  1187,    10,   257,   472,   361,   288,
         2240, 63906,    44,   332,   345,    44,   550,  1187,   309,   400,
          309,   490,  2240,  5327,   490,   356,   490,   288,   343,   311,
          400,   434,   288,   287,   261,   348,    46,   115,    10,   332,
          261,   490,   287,   400,   311,  1095,   400,   307,   490,   534,
          491,   283,   261,    44,   261,   961,   540,   287,   332,  4238,
           44,   331,   257,   261,   261,   400,   543,   311,    44,   566,
          332,   711,   400,   331,   332,    46,   287,   257, 
```

```python
# expect all 0 because no way my d1 model got any right
rewards
```

```
tensor([0., 0., 0., 0.], device='mps:0')
```

```python
advantages
```

```
tensor([0., 0., 0., 0.], device='mps:0')
```

^ all more or less makes sense

### Loss calculation

Now how is loss actually calculated? Key code:

```
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)

# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()

# normalize by the number of valid tokens, number of passes, and examples_per_rank
num_valid = (targets >= 0).sum().clamp(min=1)
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)

# Note, there is no need to add PPO ratio+clip because we are on policy

# Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
loss = -pg_obj
```

What is the PG objective? From Google: the objective of policy gradient (PG) in reinforcement learning is to maximize the expected cumulative reward by directly optimizing the agent's policy.
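The textbook way to turn "maximize expected reward" into something differentiable is the REINFORCE (score-function) estimator. This is standard policy-gradient material written in my own notation, not something spelled out in nanochat: for a prompt $x$ and a completion $y$ sampled from the model $\pi_\theta$,

$$
\nabla_\theta\, \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\big[R(y)\big]
= \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\big(R(y) - b\big)\, \nabla_\theta \sum_{t} \log \pi_\theta(y_t \mid x, y_{<t})\Big]
$$

Subtracting a baseline $b$ that doesn't depend on $y$ leaves the expectation unchanged, because $\mathbb{E}_{y}\big[\nabla_\theta \log \pi_\theta(y \mid x)\big] = 0$. Using $\mu$, the mean reward of the samples for the same prompt, as the baseline gives $R(y) - \mu$, which is the advantage computed in `get_batch()`; the sum over $t$ is why each token's log-prob gets multiplied by its row's advantage.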

OK, so the basic idea: flip the sign of the per-token cross-entropy loss so that bigger is better (this gives the log-probability of each target token), multiply each row's flipped per-token values by that row's advantage (like 0.875 and -0.125 in my example above), add it all up, and flip the sign again so that smaller is better.
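A quick check that negated per-token cross-entropy really is the log-probability of each target token. `F.cross_entropy` and `F.log_softmax` are standard PyTorch; the random logits and sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # 4 positions, vocab of 10 (toy sizes)
targets = torch.tensor([1, 3, 5, 7])  # the token actually sampled at each position

# "flip" the per-token cross-entropy loss
logp_from_ce = -F.cross_entropy(logits, targets, reduction='none')
# log-probability of each target token, computed directly
logp_direct = F.log_softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)

print(torch.allclose(logp_from_ce, logp_direct))  # True
```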

However, there is also some normalization in between. Tokens whose target is -1 (the user's prompt and padding) contribute 0, so the sum grows with the number of valid assistant tokens. If one batch had tons of user tokens and only a few assistant tokens, the absolute value of the loss would likely be much smaller than for a batch with tons of assistant tokens, simply because we're adding everything together. Dividing by `num_valid` (and by `num_passes * examples_per_rank`) corrects for that and makes the loss more like how we normally calculate it, as a mean over the tokens.
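A toy run of the key loss code, as a sketch. I'm assuming `model(inputs, targets, loss_reduction='none')` returns per-token cross-entropy with `ignore_index=-1`, so `F.cross_entropy` on random logits stands in for the model, and `num_passes` and `examples_per_rank` are set to 1 just for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 6, 10                                  # rows, sequence length, vocab (toy sizes)
logits = torch.randn(B, T, V, requires_grad=True)   # stand-in for the model's output
targets = torch.tensor([
    [-1, -1, 3, 4, 5, -1],    # -1 = prompt or padding, ignored
    [-1, -1, -1, -1, 7, 2],
])
advantages = torch.tensor([0.5, -0.5])              # one advantage per row
num_passes, examples_per_rank = 1, 1                # hypothetical values for this toy

# per-token cross-entropy; ignore_index=-1 makes those positions exactly 0
ce = F.cross_entropy(logits.view(-1, V), targets.view(-1),
                     ignore_index=-1, reduction='none').view(B, T)
logp = -ce
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
num_valid = (targets >= 0).sum().clamp(min=1)       # 5 valid tokens here
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
loss = -pg_obj
loss.backward()
```

Ignored positions get exactly zero loss and zero gradient, and the row with a negative advantage gets gradients that push its sampled tokens' probabilities down.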

### Create my_chat_rl.py

Start hand copying the code.

What is this part of eval?

```
        for k in range(1, device_batch_size + 1):
            passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
```

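It's counting, for each k, how many problems had at least one correct answer among their first k samples: pass@1 counts a problem if the first sample is correct, pass@2 if either of the first two is correct, and so on.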
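A tiny made-up example of that line, with three problems and three samples each. The records mimic the shape `run_gsm8k_eval()` yields; the values are invented.

```python
device_batch_size = 3  # samples per problem in this toy example
records = [
    {"idx": 0, "outcomes": [{"is_correct": 0}, {"is_correct": 1}, {"is_correct": 0}]},
    {"idx": 1, "outcomes": [{"is_correct": 1}, {"is_correct": 0}, {"is_correct": 0}]},
    {"idx": 2, "outcomes": [{"is_correct": 0}, {"is_correct": 0}, {"is_correct": 0}]},
]
passk = [0] * device_batch_size
for k in range(1, device_batch_size + 1):
    # a problem counts toward pass@k if ANY of its first k samples is correct
    passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
print(passk)  # [1, 2, 2]
```

Dividing each count by `len(records)` would turn these into pass@k rates.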

(This reminds me of something that either is confusing or I was confused about in earlier code: When does n-shot mean the number of examples given at the beginning of the prompt and when does n-shot mean how many chances you have to get a correct answer?)

```python
B = 4
T = 5
fake_logp = torch.ones((B, T))
fake_advantages = torch.tensor([1,2,3,4])
fake_logp
```

```
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
```

```python
fake_advantages.unsqueeze(-1)
```

```
tensor([[1],
        [2],
        [3],
        [4]])
```

```python
fake_logp * fake_advantages.unsqueeze(-1)
```

```
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.]])
```

What does this comment mean?

`# Note, there is no need to add PPO ratio+clip because we are on policy`
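My reading of that comment, as a sketch: PPO multiplies the advantage by the ratio π_new/π_old and clips that ratio. If the samples were generated by the exact weights being updated (on policy), the ratio is 1 and the clipped objective has the same gradient as the plain objective in the key code above. This is the standard PPO clipped objective, not code from nanochat; `eps = 0.2` and the random log-probs are made up for illustration.

```python
import torch

torch.manual_seed(0)
# made-up per-token log-probs for 3 samples x 5 tokens under the CURRENT weights
logp_new = torch.randn(3, 5, requires_grad=True)
advantages = torch.tensor([0.875, -0.125, 0.875]).unsqueeze(-1)  # one per row
eps = 0.2  # a typical PPO clip range, chosen for illustration

# On policy: the samples came from these exact weights, so the "old" log-probs
# are the current ones, just treated as constants.
logp_old = logp_new.detach()

ratio = torch.exp(logp_new - logp_old)  # pi_new / pi_old, exactly 1 here
ppo_obj = torch.minimum(ratio * advantages,
                        torch.clamp(ratio, 1 - eps, 1 + eps) * advantages).sum()
(ppo_grad,) = torch.autograd.grad(ppo_obj, logp_new)

pg_obj = (logp_new * advantages).sum()  # the plain objective
(pg_grad,) = torch.autograd.grad(pg_obj, logp_new)

print(torch.allclose(ratio, torch.ones_like(ratio)))  # True
print(torch.allclose(ppo_grad, pg_grad))              # True
```

The ratio and clipping only start to matter when you take several optimizer steps on the same batch of samples, so the weights drift away from the ones that generated them.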

What exactly is a "rollout"? Maybe it's the idea of taking one example and "rolling it out" into many samples, and then taking those samples and their rewards and "rolling that out" into the model via forward and back prop?

Try it. It might be too much of a pain to get it to run all the way through on my Mac (never mind doing anything useful). I can tell from things left out of his code that he didn't try.

```python
import os
os.environ["PYTHONPATH"] = "../my_nanochat"
```

```
!python -m scripts.my_chat_rl \
    --model_tag=d1 \
    --source=sft \
    --device_batch_size=1 \
    --examples_per_step=4 \
    --num_samples=4 \
    --max_new_tokens=128 \
    --eval_examples=10 \
    --eval_every=5
```

```
overriding model_tag = d1
overriding source = sft
overriding device_batch_size = 1
overriding examples_per_step = 4
overriding num_samples = 4
overriding max_new_tokens = 128
overriding eval_examples = 10
overriding eval_every = 5
user_config: {'run': 'dummy', 'source': 'sft', 'dtype': 'bfloat16', 'device_type': '', 'device_batch_size': 1, 'examples_per_step': 4, 'num_samples': 4, 'max_new_tokens': 128, 'temperature': 1.0, 'top_k': 50, 'unembedding_lr': 0.004, 'embedding_lr': 0.2, 'matrix_lr': 0.02, 'weight_decay': 0.0, 'init_lr_frac': 0.05, 'num_epochs': 1, 'save_every': 60, 'eval_every': 5, 'eval_examples': 10}
Autodetected device type: mps
loading the model from /Users/ericsilberstein/.cache/my_nanochat/chatsft_checkpoints/d1 with step 9
Building model with config: {'sequence_len': 256, 'vocab_size': 65536, 'n_layer': 1, 'n_head': 1, 'n_kv_head': 1, 'n_embd': 64}
Calculated number of steps: 1868
Scaling the LR for the AdamW parameters proportional to 1/sqrt(64/768) = 3.4641016151377
```

Code added as part of this challenge:

- `my_chat_rl.py`