# Mini-ChatGPT

This tutorial assumes the reader has some basic machine learning (ML) and natural language processing (NLP) knowledge, plus familiarity with transformer-based large language models (aka foundation models), but is new to reinforcement learning. If you know what a loss (aka cost, aka objective) function is and how a token (aka input_id) is represented by a vector (aka embedding), you should be good.

If not, here is a really great place to start: https://lena-voita.github.io/nlp_course/language_modeling.html

Special thanks to Leandro von Werra's https://github.com/lvwerra/trl repo.

Diagram from https://openai.com/blog/chatgpt/. ChatGPT is an improved InstructGPT.

<img src='https://cdn.openai.com/chatgpt/draft-20221129c/ChatGPT_Diagram.svg' height=600 width=800>

#### Figure 1

### Step 1

Supervised Fine-Tuning (SFT) is the process of taking a model already pre-trained on a very large corpus of data and re-training it carefully, so as not to erase the knowledge the model gained during pre-training ([catastrophic forgetting](https://github.com/clam004/intro_continual_learning)), but rather to slightly nudge the model to perform better in a more specific or narrow domain represented by a much smaller fine-tuning dataset.

There are already plenty of resources on Supervised Fine-Tuning (SFT), aka fine-tuning, some with impressively good explanations, that mistakenly claim to help you build a mini-ChatGPT. I will not cover SFT; instead I will point you to some of my favorite resources.

Here is an explanatory video by Hugging Face on decoder, or autoregressive, transformers like GPT: https://youtu.be/d_ixlCubqQw

I would start with this blog if you have done ML but not deep learning for NLP before; it is a great explanation of language modeling and the cross entropy loss function: https://lena-voita.github.io/nlp_course/language_modeling.html

Fine-tuning GPT-like models: https://huggingface.co/course/chapter7/6?fw=pt and [Guide to fine-tuning Text Generation models: GPT-2, GPT-Neo and T5](https://towardsdatascience.com/guide-to-fine-tuning-text-generation-models-gpt-2-gpt-neo-and-t5-dc5de6b3bc5e)

ChatGPT was likely fine-tuned on a dataset of instructions and examples of following them.
 
### Step 2

Fine-tuning usually means training for only a small number of epochs with a much smaller learning rate, while still using the same cross entropy loss function used during pre-training. But this way of learning is a bit limited. Can you tell why?

<img src="https://lena-voita.github.io/resources/lectures/lang_models/neural/one_step_loss_intuition-min.png" height=500 width=700>

#### Figure 2

This example shows how cross entropy rewards or penalizes the model based only on how much probability mass it assigns to "cat", just this one way of completing the sequence. But we know that in reality, when generating responses to instructions, there isn't one right way to do it. Even when there is one right answer, there are many ways to be good and many ways to be bad.
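To make this limitation concrete, here is a small pure-Python sketch with a made-up vocabulary and made-up probabilities. Two hypothetical models assign the same probability to the reference token "cat" but spread the rest very differently: model A favors reasonable alternatives, model B favors "cloud". Cross entropy scores them identically, because it only reads the probability of the one reference token.

```python
import math

def cross_entropy(probs, target_token):
    # the loss only looks at the probability assigned to the single reference token
    return -math.log(probs[target_token])

# Two hypothetical next-token distributions for the same context.
model_a = {"cat": 0.4, "kitten": 0.3, "dog": 0.2, "cloud": 0.1}
model_b = {"cat": 0.4, "kitten": 0.05, "dog": 0.05, "cloud": 0.5}

loss_a = cross_entropy(model_a, "cat")
loss_b = cross_entropy(model_b, "cat")
print(round(loss_a, 4), round(loss_b, 4))  # both equal -log(0.4), about 0.9163
```

A reward signal that judges the whole continuation, like the one used in Step 3, can tell these two behaviors apart.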

To keep this tutorial lightweight, intuitive, and runnable on commonly available compute, we have simplified the overall strategy in 2 major ways:

A. ChatGPT and InstructGPT are very specially trained with reinforcement learning (RL), but in general terms ChatGPT just takes some prompt, aka input text (the phrase or instructions), and generates an output, aka a continuation of that text (a response or answer). Representing long instructions and long answers in a transformer takes a lot more compute and memory, so instead we simplify the input text to the start of a movie review (the first few words or subwords) and the continuing text to a continuation of those first few words.

B. The reward modeling has been simplified compared to InstructGPT (https://openai.com/blog/instruction-following/), the model of which ChatGPT is a scaled-up and improved version.

Instead of training a reward model on human rankings of output text, which was done to make the reward signal more stable and reliable, we take the more direct route of using the logits from a classifier (like BERT) as the reward signal: "positive means do more like this, negative means do less like this", where positive means the review refers to a movie the reviewer thinks is good (a positive review).
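Here is a minimal sketch of how a classifier's output can become a scalar reward. It assumes, as the reward cell later in this notebook relies on with `output[1]["score"]`, that the sentiment pipeline returns one `{"label", "score"}` dict per class for each text, with raw logits rather than probabilities as scores. The scores below are made up, and selecting the POSITIVE entry by label instead of by position is just a defensive choice of this sketch.

```python
# Made-up classifier outputs for two texts: one raw, unnormalized score (logit) per class.
pipe_outputs = [
    [{"label": "NEGATIVE", "score": -1.2}, {"label": "POSITIVE", "score": 1.3}],
    [{"label": "NEGATIVE", "score": 1.9}, {"label": "POSITIVE", "score": -1.8}],
]

def positive_logit_reward(output):
    """Use the POSITIVE-class logit as the reward for one generated text."""
    return next(d["score"] for d in output if d["label"] == "POSITIVE")

rewards = [positive_logit_reward(o) for o in pipe_outputs]
print(rewards)  # [1.3, -1.8]
```

Because a logit is not squashed into [0, 1] the way a probability is, its sign carries the "do more / do less like this" meaning directly.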

### Step 3

Given our limited time, I am going to focus on giving you what is hard to find: good explanations of the parts that are important yet usually glossed over, or usually explained in a much more jargony, field-specific, technical, or mathematical manner.

That means we will be focusing on Step 3 in Figure 1: namely, how reinforcement learning, which popular media has so far mostly shown applied to computer games, can instead be applied to natural language, that is, to the generation of sequences of discrete tokens.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from minichatgpt.experiments.imdb import config, sent_kwargs
from minichatgpt import PPOTrainer, Lab
from minichatgpt.processdata.collators import imdb_dataloader_collator
from minichatgpt.trainer import PPOTrainer

In [3]:
# For the sake of the speed of this demonstration, the batch_size is temporarily decreased from 256 to 4
batch_size = 4
config.batch_size = batch_size
config.forward_batch_size = batch_size//2

config.seed

In [4]:
lab = Lab(config)

dataset = lab.build_dataset(dataset_name="imdb",input_min_text_length=2,input_max_text_length=8)

new_policy, old_policy, tokenizer = lab.init_policies_tokenizer()

lab.set_generation_config(do_sample=True,output_min_length=4,output_max_length=16,pad_token_id=tokenizer.eos_token_id)

ppo_trainer = lab.init_ppo_trainer(
    config, 
    new_policy, 
    old_policy, 
    tokenizer, 
    dataset, 
    dataloader_collator = imdb_dataloader_collator,
)

reward_model = lab.init_reward_model()

Found cached dataset imdb (/Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2bd6a5d7d39a840d.arrow
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2ecf25d24c93f132.arrow


#### A new prompt is sampled from the dataset

In the cells below, the text in our samples is represented by the token IDs in `input_ids`, and each token is either a word or a subword. Roughly speaking, the average token represents a word or subword around 3 characters long: 'can', 'ed', 'con', 'tion', etc. Because of `input_min_text_length=2, input_max_text_length=8` in `dataset = lab.build_dataset()` above, you will find a random assortment of sample lengths in each batch, but none will be shorter than 2 tokens or longer than 8 tokens.
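Here is a minimal sketch of how such variable-length prompts could be produced. It assumes, in the style of the TRL sentiment example this notebook builds on, that each review is tokenized and then cut to a random-length prefix. `truncate_prompts` is a hypothetical helper, the token IDs are made up, and `lab.build_dataset` may sample lengths differently (for example, with an exclusive upper bound).

```python
import random

def truncate_prompts(token_id_lists, min_len=2, max_len=8, seed=0):
    """Hypothetical helper: cut each tokenized review down to a random-length prefix."""
    rng = random.Random(seed)
    prompts = []
    for ids in token_id_lists:
        # randint is inclusive on both ends, so lengths fall in [min_len, max_len]
        n = rng.randint(min_len, max_len)
        prompts.append(ids[:n])
    return prompts

# made-up token ID sequences standing in for tokenized IMDB reviews
reviews = [list(range(100, 130)), list(range(200, 240)), list(range(300, 320))]
prompts = truncate_prompts(reviews)
print([len(p) for p in prompts])
```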

In reinforcement learning, an agent acts in an environment, like a computer game, by taking actions, like hitting the jump button or moving to the left, which may later result in more or less reward, like points in the game. Here the environment is the random assortment of samples, each of which the model can continue in many ways. The action is how the model chooses to continue those samples, aka prompts.
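A toy sketch of this mapping from RL vocabulary to text generation. Everything here is hypothetical: a six-word vocabulary, a policy that picks tokens uniformly at random, and a keyword-counting reward standing in for GPT-2 and the sentiment classifier. The state is the token sequence so far, each action appends one token, and the reward only arrives once the continuation is finished.

```python
import random

# Hypothetical toy setup: a six-word vocabulary, a random policy, and a keyword reward.
VOCAB = ["great", "boring", "fun", "awful", "movie", "."]
POSITIVE, NEGATIVE = {"great", "fun"}, {"boring", "awful"}

def toy_policy(state, rng):
    # stand-in for GPT-2: ignores the state and picks the next token uniformly at random
    return rng.choice(VOCAB)

def toy_reward(tokens):
    # stand-in for the sentiment classifier: +1 per positive word, -1 per negative word
    return sum(1 for t in tokens if t in POSITIVE) - sum(1 for t in tokens if t in NEGATIVE)

def run_episode(prompt, max_new_tokens=5, seed=0):
    rng = random.Random(seed)
    state = list(prompt)                 # state: the token sequence so far
    for _ in range(max_new_tokens):
        action = toy_policy(state, rng)  # action: the next token
        state.append(action)             # the environment appends it to the sequence
    response = state[len(prompt):]
    # reward: only given once the whole continuation has been generated
    return response, toy_reward(response)

response, reward = run_episode(["This", "film", "was"])
print(response, reward)
```

A real policy conditions on the state, which is exactly what PPO will adjust so that higher-reward continuations become more likely.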

In [5]:
for batch_step, batch in enumerate(ppo_trainer.dataloader):
    
    queries = batch['input_ids']
    
    break
    
print('the part of each batch are: ', batch.keys())
print('-'*50)
print('each batch has ', len(batch['query']), 'samples')
print('-'*50)
print('here are some samples examples ', batch['query'])
print('-'*50)
print('here are the token ids of those examples ', batch['input_ids'])

the part of each batch are:  dict_keys(['label', 'input_ids', 'query'])
--------------------------------------------------
each batch has  4 samples
--------------------------------------------------
here are some samples examples  ['I remember watching', "I couldn't hold back", 'A suspenseful thriller that', 'I like musicals']
--------------------------------------------------
here are the token ids of those examples  [tensor([  40, 3505, 4964]), tensor([  40, 3521,  470, 1745,  736]), tensor([   32, 43527,   913, 32251,   326]), tensor([   40,   588, 10530,    82])]


#### The policy generates an output

Run the next 2 cells several times. Notice that the initial query (the prompt) stays the same, but the predicted actions (the text continuation) keep changing. That's because our generation keyword arguments include `do_sample=True`.

In [6]:
#### Get response from gpt2
responses = []
for query in queries:
    gen_len = lab.output_length_sampler()
    lab.generation_kwargs["max_new_tokens"] = gen_len
    response = ppo_trainer.generate(query, **lab.generation_kwargs)
    responses.append(response.squeeze()[-gen_len:])
    
responses

[tensor([16089,   351,   337,     5,    38,    11,   290,   484,  6304]),
 tensor([ 379,  477,  287, 4964]),
 tensor([  468,   645,  7110,    11,  3805, 12209, 14586]),
 tensor([ 290, 8842,   11,  290,  314, 1842, 9281,   12,   35, 3383,  518,    7,
         8133, 4480,  257])]

In [7]:
batch['response'] = [tokenizer.decode(r.squeeze()) for r in responses]

concat_text = [batch['query'][i]+" >>> "+batch['response'][i] for i in range(len(batch['response']))]

concat_text

['I remember watching >>>  Raw with M&G, and they weren',
 "I couldn't hold back >>>  at all in watching",
 'A suspenseful thriller that >>>  has no plot, despite occasional grat',
 'I like musicals >>>  and fantasy, and I love Saint-Domingue(!)with a']

I used `>>>` to separate the prompt sampled from the dataset (left side) from the text generated by the policy (right side).

#### The reward model calculates a reward for the output

In [8]:
#### Compute sentiment score
texts = [q + r for q,r in zip(batch['query'], batch['response'])]
pipe_outputs = lab.reward_model(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

rewards



[tensor(1.3289), tensor(-1.8032), tensor(0.8134), tensor(2.2469)]

#### The reward is used to update the policy using PPO

PPO stands for Proximal Policy Optimization. All of the reinforcement learning update steps are performed by calling `stats = ppo_trainer.step(queries, responses, rewards)`. In fact, the entire Step 3 training loop is:

```python
from tqdm import tqdm

for batch_step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    
    queries = batch['input_ids']
    
    #### Get response from gpt2
    responses = []
    for query in queries:
        gen_len = lab.output_length_sampler()
        lab.generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **lab.generation_kwargs)
        responses.append(response.squeeze()[-gen_len:])
        
    batch['response'] = [tokenizer.decode(r.squeeze()) for r in responses]
    
    #### Compute sentiment score
    texts = [q + r for q,r in zip(batch['query'], batch['response'])]
    pipe_outputs = lab.reward_model(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
    
    #### Run PPO step 
    stats = ppo_trainer.step(queries, responses, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

```

But we are here to learn the details, so we will go through PPO step by step as it applies to sequences of discrete tokens.
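For orientation before the walk-through, here is a minimal pure-Python sketch of the per-token clipped surrogate objective from the PPO paper (Schulman et al., 2017). It is not the trainer's code: `clip_eps=0.2` is an assumed, commonly used value, the advantage is simply taken as given (in practice it is estimated with help from the value head), and the full loss used in practice also includes a value-function term.

```python
import math

def ppo_clipped_loss(logprob_new, logprob_old, advantage, clip_eps=0.2):
    """Per-token PPO policy loss (to minimize) using the clipped surrogate objective."""
    # probability ratio pi_new(a|s) / pi_old(a|s), computed from log probabilities
    ratio = math.exp(logprob_new - logprob_old)
    unclipped = ratio * advantage
    # keep the ratio inside [1 - clip_eps, 1 + clip_eps]
    clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps) * advantage
    # take the more pessimistic of the two terms, then negate to get a loss
    return -min(unclipped, clipped)

# the new policy made a good token (advantage > 0) much more likely; the gain is capped
print(ppo_clipped_loss(logprob_new=-1.0, logprob_old=-1.5, advantage=1.0))
```

The clipping is what makes the update "proximal": once the ratio leaves the band, pushing it further earns no extra objective, so the new policy has little incentive to move far from the old one in a single step.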

In [9]:
# check that queries, responses, and rewards are each a list of tensors
queries, responses, scores = ppo_trainer._step_safety_checker(batch_size, queries, responses, rewards)

#### Policy model architecture 

To be clear on what happens next, I find it helpful to remind myself of some details about our policy model. First, the policy is a causal language model with an added value head. We will talk about what the value head is later.

```python
# inside lab.init_policies_tokenizer():
self.new_policy = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
```

There are already many great lessons online about transformer architectures, and https://jalammar.github.io/illustrated-gpt2/ is a great place to start, so I will refer to their components here without explaining them first.

GPT-like models are stacks of repeating blocks, each passing its hidden state to the next block. The final, or last, hidden state (aka activations) is used as the input to both the language model head (`lm_head`) and the value head (`v_head`).
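A pure-Python sketch of two heads reading the same final hidden states, with tiny made-up sizes and random weights. It assumes both heads are plain linear layers: GPT-2's `lm_head` is a bias-free linear layer from the hidden size to the vocabulary size, and treating `v_head` as one linear layer to a single output is an assumption consistent with `self.v_head(last_hidden_state).squeeze(-1)` in the forward pass quoted further down.

```python
import random

rng = random.Random(0)
hidden_size, vocab_size, seq_len = 4, 5, 3   # GPT-2 small uses 768 and 50257

def linear(vectors, weight_rows, bias):
    """Apply the same linear layer to every position: one output per weight row."""
    return [[sum(h * w for h, w in zip(vec, row)) + b for row, b in zip(weight_rows, bias)]
            for vec in vectors]

def random_matrix(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

# made-up final hidden states, one vector per token position
last_hidden_state = random_matrix(seq_len, hidden_size)

# lm_head: hidden_size -> vocab_size logits at every position (no bias, as in GPT-2)
lm_logits = linear(last_hidden_state, random_matrix(vocab_size, hidden_size), [0.0] * vocab_size)

# v_head: hidden_size -> 1 scalar value at every position, then drop the size-1 dimension
values = [v[0] for v in linear(last_hidden_state, random_matrix(1, hidden_size), [0.0])]

print(len(lm_logits), len(lm_logits[0]), len(values))  # 3 5 3
```

So at every position the policy produces both a distribution over next tokens and a single number estimating how good the sequence so far is.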

Simply run `new_policy` in the next cell to see for yourself.

In [10]:
new_policy

AutoModelForCausalLMWithValueHead(
  (pretrained_model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
        

#### Preventing new policy from moving too far from old policy 

Next we find the log probabilities for each token in our responses. Notice that each tensor in `logprobs` has the same length as its corresponding response in `responses`. We calculate these log probabilities not to directly implement the REINFORCE policy gradient loss ([Policy Gradient Theorem](https://jonathan-hui.medium.com/rl-policy-gradients-explained-9b13b688b146)), but mainly to calculate the [Kullback–Leibler divergence](https://machinelearningmastery.com/divergence-between-probability-distributions/) (KL divergence) between the old, static policy and the new, dynamic policy.

You see, as of 2022, much of the effort in RL is still directed at combating high variance and sample inefficiency in training. In layman's terms, that means trying to keep the neural network from spinning out of control, or losing its mind, while it is being updated. Part of this involves keeping track of how far the behavior of the neural network in training has drifted from its initial behavior, and we measure this drift using the KL divergence between the two models' outputs.
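A small sketch of the KL divergence between two made-up next-token distributions, computed exactly and then estimated from samples drawn from the new policy. The sampled estimate averages `log pi - log rho` over the tokens the policy actually produced, which is the same per-token difference (`logprob_ - ref_logprob_`) that `compute_rewards` uses later in this notebook. The distributions and sample size are arbitrary choices for illustration.

```python
import math
import random

def kl_divergence(p, q):
    """Exact KL(p || q) in nats for two distributions over the same tokens."""
    return sum(p[t] * math.log(p[t] / q[t]) for t in p if p[t] > 0)

# hypothetical next-token distributions of the new policy (pi) and old policy (rho)
pi = {"good": 0.6, "bad": 0.1, "okay": 0.3}
rho = {"good": 0.4, "bad": 0.3, "okay": 0.3}

exact = kl_divergence(pi, rho)

# Monte Carlo estimate: sample tokens from pi and average log pi - log rho
rng = random.Random(0)
tokens = list(pi)
samples = rng.choices(tokens, weights=[pi[t] for t in tokens], k=20000)
estimate = sum(math.log(pi[t]) - math.log(rho[t]) for t in samples) / len(samples)
print(round(exact, 4), round(estimate, 4))
```

Individual per-token terms can be negative (here, sampling "bad" gives a negative log ratio), but on average over samples from the new policy they approach the non-negative exact KL.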

In [11]:
logprobs, ref_logprobs, values = ppo_trainer.batched_forward_pass(queries, responses)
logprobs

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[tensor([-9.1356, -2.0816, -6.7180, -5.0547, -4.3260, -1.9985, -1.2266, -3.6060,
         -5.2811]),
 tensor([-5.0621, -0.4335, -4.2776, -3.4706]),
 tensor([-2.9841, -3.7663, -2.3399, -1.3100, -7.0527, -7.6810, -5.0248]),
 tensor([ -2.0197,  -5.9065,  -1.8272,  -1.7268,  -1.4677,  -2.2598, -11.4134,
          -2.5132,  -3.0124,  -0.3093,  -0.5850,  -6.7908,  -3.5179,  -7.8780,
          -2.0105])]

In [12]:
values

[tensor([-1.0001, -2.9678, -2.1181, -3.0465, -4.3964, -4.3594, -3.4076, -1.7643,
          0.8133]),
 tensor([-2.0027, -1.7066, -3.0060, -1.9641]),
 tensor([-2.2580, -1.7337, -1.8779, -0.6323, -2.5215, -0.7152, -1.6352]),
 tensor([-2.9713, -1.1697, -1.1668, -2.0873, -1.4201, -3.3242, -0.8569,  0.3901,
         -1.4176, -1.0883, -1.2186, -0.4966, -2.8074, -2.0383, -0.8759])]

#### The output logits are the policy's actions

If the hidden state size is 768, for example if `new_policy.pretrained_model.lm_head` prints `Linear(in_features=768, out_features=50257, bias=False)`, then we are using our last hidden state to predict logits for each of the 50257 tokens in our vocabulary.

As you can see below, the policy's forward pass returns `logits, loss, value = model(input_tokens)`.

```python
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    # ... (other methods omitted)
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns the logits of the value head.
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
                (see `past_key_values` input) to speed up sequential decoding.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the wrapped model.
        """
        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            output_hidden_states=True,  # We force the model to output hidden states
            **kwargs,
        )

        last_hidden_state = base_model_output.hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        value = self.v_head(last_hidden_state).squeeze(-1)

        return (lm_logits, loss, value)
```

Logits are basically unnormalized scores for how likely a particular token is to come next. After the logits are pushed through a softmax layer, they can be interpreted as probabilities (horizontal bars below).

<img src="https://lena-voita.github.io/resources/lectures/lang_models/neural/nn_lm_idea_linear-min.png" height=600 width=800>

#### Figure 3.

The label "|V| tokens of them" in the diagram above says that there is a logit for every token in our vocabulary.
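The softmax step in Figure 3 can be sketched in pure Python. The logits below are made up, but deliberately large and negative like the real ones printed in the next cell: softmax only depends on differences between logits, so they still give a valid distribution. `log_softmax` writes the log of the softmax in the numerically stable `x - logsumexp(x)` form; `F.log_softmax`, used later, returns the same quantity.

```python
import math

def softmax(logits):
    # subtracting the max leaves the result unchanged but keeps exp() from under/overflowing
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    # log(softmax(x))_i = x_i - logsumexp(x)
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]

logits = [-15.1, -15.4, -18.6, -14.9]   # made-up scores for a 4-token vocabulary
probs = softmax(logits)
print([round(p, 3) for p in probs], sum(probs))
```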

The shape of the logits tensor is:

(batch_size, token_length_of_longest_sample_in_batch, number_of_tokens_in_vocabulary)

In [13]:
input_ids = ppo_trainer.data_collator(batch["input_ids"])["input_ids"]
input_kwargs = {"input_ids": input_ids}
logits, loss, values = new_policy(**input_kwargs)
print(logits.shape)
logits[0]

torch.Size([4, 5, 50257])


tensor([[-15.1416, -15.4052, -18.6318,  ..., -22.9367, -20.7308, -14.9200],
        [-57.1975, -58.6069, -65.6370,  ..., -66.9992, -66.3596, -60.5881],
        [-43.4085, -43.9128, -47.9591,  ..., -52.3836, -53.5728, -44.3646],
        [-52.6918, -51.6320, -54.7046,  ..., -60.7820, -60.7672, -49.4324],
        [-50.9333, -50.0535, -52.9739,  ..., -59.3635, -59.0772, -48.5471]],
       grad_fn=<SelectBackward0>)

#### Batched_forward_pass 

Below we dissect the operations done in `ppo_trainer.batched_forward_pass()` 

Before the next cell, I want to prepare you for one potentially confusing part of it, Step 3:

```python
# Step 3. Get the log probabilities of the generated (sampled from logits) tokens 
logp = F.log_softmax(logits[:, :-1, :], dim=2)
logprobs = torch.gather(logp, 2, input_ids[:, 1:].unsqueeze(2)).squeeze(-1)
```

The blue diagram below, Blue Figure, is a transformer. It shows several blue boxes that represent the vectors for each input token in the sequence (horizontally) and at each layer (vertically). These are equivalent to the `h`'s in the grey, green, and pink diagram above, Figure 3.

It leaves out a few steps between the deepest (aka last, or topmost) blue box and the word that comes out of it. Implicit in Blue Figure is that, just like in Figure 3, for a token or word to be output from the top, the `h`'s go through a linear layer to form `logits` and a softmax layer to form a vector of probabilities the same length as the number of possible tokens. Once you have this vector of probabilities, you have several options. You could choose the token with the highest probability (a greedy approach), or you could sample based on those probabilities, so that you have a higher chance of sampling reasonable tokens than unreasonable ones, i.e. `'I saw a cat on a' -> {'mat' > 'desk' > 'sofa' >>> 'cloud'}`. Sampling is what we chose when we set `do_sample=True` above, which is why every time you run `response = ppo_trainer.generate(query, **lab.generation_kwargs)` you get a different continuation of the beginning prompt text.
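The two decoding options can be contrasted with a few lines of pure Python. The probabilities are made up to mirror the `'I saw a cat on a'` example; `greedy` and `sample` are illustrative helpers, not functions from the trainer or from Hugging Face.

```python
import random

# made-up next-token probabilities for the prompt 'I saw a cat on a'
next_token_probs = {"mat": 0.5, "desk": 0.25, "sofa": 0.2, "cloud": 0.05}

def greedy(probs):
    # always pick the single most probable token
    return max(probs, key=probs.get)

def sample(probs, rng):
    # draw one token with chance proportional to its probability, as with do_sample=True
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]

rng = random.Random(0)
print(greedy(next_token_probs))                            # 'mat' on every call
print([sample(next_token_probs, rng) for _ in range(8)])   # changes with the random state
```

Sampling matters for RL: if the policy always produced the same continuation, it would never try alternatives that the reward model might score higher.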

The line of code `logp = F.log_softmax(logits[:, :-1, :], dim=2)` applies a softmax followed by a natural log (computed together in one numerically stable step), so what you get is log probabilities `logp`. But why do we only keep the logits before the last position, `:-1` in `logits[:, :-1, :]`? And why do we only keep the input IDs after the first one, `1:` in `logprobs = torch.gather(logp, 2, input_ids[:, 1:].unsqueeze(2)).squeeze(-1)`?

<img src="https://coriva.eu.org/images/nlp/lamdamodel.png" height=400 width=600>

#### Blue Figure.

Blue Figure supposes that tokens are words. Notice how the word at the top of one column of boxes shows up at the bottom of the next column to its right. That is, whatever token is sampled from the (softmaxed) logits gets appended to the newest, rightmost end of the growing sequence and fed back into the transformer to produce the next set of logits to sample from, and so on. So if `t` is a position in the sequence, the token at position `t` is sampled from the logits at position `t - 1`, one position behind it. The `torch.gather` function uses the tokens in `input_ids` to pick out the log probability in `logp` that corresponds to each token, so the two must be lined up: we drop the last position of the logits, since it predicts a token beyond the end of the sequence, and drop the first token of `input_ids`, since no earlier position predicts it.
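The alignment can be checked on a plain Python list before looking at tensors. The 5-token sequence is hypothetical; the two slices mirror `logits[:, :-1, :]` and `input_ids[:, 1:]` along the sequence dimension.

```python
# A hypothetical 5-token sequence; position t holds input_ids[t].
input_ids = ["I", "saw", "a", "cat", "on"]
seq_len = len(input_ids)

# logits at position t are the model's prediction for the token at t + 1
logit_positions = list(range(seq_len))[:-1]   # like logits[:, :-1, :]: drop the last position
target_tokens = input_ids[1:]                 # like input_ids[:, 1:]: drop the first token

pairs = list(zip(logit_positions, target_tokens))
print(pairs)  # [(0, 'saw'), (1, 'a'), (2, 'cat'), (3, 'on')]
```

Both slices have length `seq_len - 1`, so every remaining logit position is paired with exactly the token it was responsible for predicting.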

Here is a toy code snippet to show you how `gather` works. It assumes the batch size is 2 (2 sample sequences), each sequence is 3 tokens long, and the vocabulary size is only 4. Instead of probabilities or log probabilities, we just index out the logits using our `ids`, and those logits are conveniently `[1, 2, 3, 4]` at every position in the sequence.

```python
import torch

logits = torch.tensor([[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]],[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]])
ids = torch.tensor([[0, 0, 1], [2, 1, 0]])
ids = ids.unsqueeze(2) # gather needs logits and ids to have the same number of dimensions
print(ids.shape) # torch.Size([2, 3, 1])
selected_logits = torch.gather(logits, 2, ids)
print(selected_logits.shape) # torch.Size([2, 3, 1])
selected_logits = selected_logits.squeeze(-1) # we don't really need the last dimension, [2, 3, 1] -> [2, 3]
selected_logits 
```
result
```
tensor([[1, 1, 2],
        [3, 2, 1]])
```


In [14]:
import torch.nn.functional as F

all_logprobs = []
all_ref_logprobs = []
all_values = []

bs = lab.config.batch_size
fbs = lab.config.forward_batch_size

for i in range(int(bs / fbs)):
    
    query_batch = queries[i * fbs : (i + 1) * fbs]
    response_batch = responses[i * fbs : (i + 1) * fbs]
    
    sequence_tokens = [torch.cat([q, r]) for q, r in zip(query_batch, response_batch)]
    # Step 1. pads other sequences to the longest sample in the batch by 
    # adding tokenizer.pad_token (50256) to the right end of the sequence 
    input_ids = ppo_trainer.data_collator(sequence_tokens)["input_ids"]
    input_kwargs = {"input_ids": input_ids}
    
    print([tokenizer.decode(r.squeeze()) for r in input_ids])
    
    # Step 2. Forward Pass through the model to get the logits
    with torch.no_grad():
        logits, _, v = new_policy(**input_kwargs)
        ref_logits, _, _ = old_policy(**input_kwargs)
    
    # Step 3. Get the log probabilities of the generated (sampled from logits) tokens 
    logp = F.log_softmax(logits[:, :-1, :], dim=2)
    logprobs = torch.gather(logp, 2, input_ids[:, 1:].unsqueeze(2)).squeeze(-1)

    logp = F.log_softmax(ref_logits[:, :-1, :], dim=2)
    ref_logprobs = torch.gather(logp, 2, input_ids[:, 1:].unsqueeze(2)).squeeze(-1)

    # Step 4. use the lengths of the unpadded query and response to keep only
    # the log probs of the response tokens, cutting out the query and padded tokens
    for j in range(lab.config.forward_batch_size):

        start = len(query_batch[j]) - 1
        end = len(query_batch[j]) + len(response_batch[j]) - 1

        if len(logprobs[j, start:end]) < 2:
            raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")

        all_values.append(v[j, start - 1 : end - 1])
        all_logprobs.append(logprobs[j, start:end])
        all_ref_logprobs.append(ref_logprobs[j, start:end])
        
    break # stop after one forward_batch
    
all_logprobs

['I remember watching Raw with M&G, and they weren', "I couldn't hold back at all in watching<|endoftext|><|endoftext|><|endoftext|>"]


[tensor([-9.1356, -2.0816, -6.7180, -5.0547, -4.3260, -1.9985, -1.2266, -3.6060,
         -5.2811]),
 tensor([-5.0621, -0.4335, -4.2776, -3.4706])]
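To see why Step 4 uses `start = len(query) - 1` rather than `len(query)`, here is the index arithmetic on plain Python lists. The lengths match the first sample above (the 3-token query 'I remember watching' and its 9-token response); the token names and the 3 padding tokens are made up.

```python
# Hypothetical stand-ins for the first sample: a 3-token query and a 9-token response,
# right-padded with 3 pad tokens as if another sample in the batch were longer.
query = ["q0", "q1", "q2"]
response = ["r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8"]
padded_sequence = query + response + ["<pad>"] * 3

# After the shift, entry k of a logprobs row scores padded_sequence[k + 1].
scored_tokens = padded_sequence[1:]

start = len(query) - 1
end = len(query) + len(response) - 1
kept = scored_tokens[start:end]
print(start, end, kept == response)  # 2 11 True
```

The `- 1` in both bounds undoes the one-position shift from Step 3, and stopping at `end` drops the padding.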

#### Compute rewards

$ R(x,y) = r(x,y) - \beta \log \frac{\pi(y|x)}{\rho(y|x)} $

We are using the [PPO2 algorithm](https://github.com/openai/baselines/tree/master/baselines/ppo2) with a modified reward: the log probabilities are used to calculate a KL divergence penalty on the reward, with expectation $\beta KL(\pi, \rho)$. This penalty is described in equation (2) of [Fine-Tuning Language Models from Human Preferences](https://arxiv.org/pdf/1909.08593.pdf), where $x$ is the prompt, $y$ is the generated sequence, $\pi$ is the distribution of the new policy, $\rho$ is the distribution of the old policy (aka reference model, aka pretrained transformer), $KL(\pi, \rho)$ is the KL divergence between $\pi$ and $\rho$ measured in *nats*, and $\beta$ is dynamically adjusted to achieve a target $KL(\pi, \rho)$, such as *6 nats*.
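A minimal sketch of how $\beta$ can be adjusted toward the target, following the proportional-controller rule described in that paper: $e_t = \mathrm{clip}\left(\frac{KL(\pi_t, \rho) - KL_{target}}{KL_{target}}, -0.2, 0.2\right)$ and $\beta_{t+1} = \beta_t (1 + K_\beta e_t)$. The values `target_kl=6.0` and `k_beta=0.1` are assumptions for illustration, and the trainer's `kl_ctl` may implement the update with different details.

```python
def update_kl_coef(beta, observed_kl, target_kl=6.0, k_beta=0.1):
    """One step of a proportional controller for the KL coefficient beta."""
    # relative error from the target, clipped so beta changes by at most +/- 2% per step here
    error = min(max(observed_kl / target_kl - 1, -0.2), 0.2)
    # KL above target -> raise beta (stronger penalty); below target -> lower beta
    return beta * (1 + k_beta * error)

beta = 0.2
print(update_kl_coef(beta, observed_kl=12.0))  # too much drift: beta grows
print(update_kl_coef(beta, observed_kl=3.0))   # too little drift: beta shrinks
```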

In [15]:
# next we use these end of episode (end of sentence) rewards
logprobs, ref_logprobs, values = ppo_trainer.batched_forward_pass(queries, responses)
rewards, non_score_reward = ppo_trainer.compute_rewards(scores, logprobs, ref_logprobs)

rewards

[tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 1.3289]),
 tensor([-0.0000, -0.0000, -0.0000, -1.8032]),
 tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 0.8134]),
 tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 2.2469])]

In [16]:
# inside the compute_rewards function
# before any updates, the new policy is identical to the old policy, so logprob_ - ref_logprob_ = 0

rewards_, non_score_rewards_ = [], []

for score_, logprob_, ref_logprob_ in zip(scores, logprobs, ref_logprobs):
    kl = logprob_ - ref_logprob_ # log[pi(y|x)/rho(y|x)]
    non_score_reward_ = -ppo_trainer.kl_ctl.value * kl 
    non_score_rewards_.append(non_score_reward_)
    reward_ = non_score_reward_.clone()
    reward_[-1] += score_ # the reward model score is added only to the last token in each sequence
    rewards_.append(reward_)
    print(reward_)

tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 1.3289])
tensor([-0.0000, -0.0000, -0.0000, -1.8032])
tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 0.8134])
tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 2.2469])
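In the cell above every KL term is zero because the new policy has not been updated yet. Here is a pure-Python version of the same shaping, with made-up log probabilities for a policy that has already drifted and an assumed `beta=0.2`, to show the penalty appearing on every token while the classifier score still lands only on the last one.

```python
def shape_rewards(score, logprobs, ref_logprobs, beta):
    """Per-token rewards: -beta * (log pi - log rho) on every token, plus the score on the last."""
    rewards = [-beta * (lp - ref_lp) for lp, ref_lp in zip(logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards

# made-up values: the new policy now assigns its own tokens higher probability than the reference
logprobs = [-1.0, -2.0, -0.5]
ref_logprobs = [-1.5, -2.0, -1.0]
shaped = shape_rewards(score=1.3, logprobs=logprobs, ref_logprobs=ref_logprobs, beta=0.2)
print([round(r, 4) for r in shaped])
```

Tokens where the new policy has become more confident than the reference are charged a small penalty (-0.1 each here), and the final token's reward is the classifier score minus its own penalty, 1.3 - 0.1 = 1.2.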
