#### PPO RL Algorithm with Natural Language Sequences

This RL update step uses the log probabilities, value function estimates, stepwise rewards, the prompt tokens, and the generated tokens.

Below we will study what's going on in `ppo_trainer.compute_logits_vpred` and `ppo_trainer.loss`. As you can see below, when you run these together, the policy is updated such that the rewards change.

Let's first quickly go through all the steps from part 1 to get the same log probabilities, value function estimates, stepwise rewards, prompt tokens, and generated tokens, so that we can continue from where we left off.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from minichatgpt.experiments.imdb import config, sent_kwargs
from minichatgpt import Lab
from minichatgpt.processdata.collators import imdb_dataloader_collator

# for the loss calculation
from minichatgpt.core import whiten, logprobs_from_logits

In [3]:
# Demo-speed override: shrink the batch from 256 down to 4 so every cell runs quickly.
# The forward batch size is set to half of the (reduced) batch size.
batch_size = 4
config.batch_size, config.forward_batch_size = batch_size, batch_size // 2
config.seed

0

In [4]:
# Build the experiment harness from the shared config, then load the IMDB dataset
# with the input text length bounded by 2 and 8 (units as defined by build_dataset).
lab = Lab(config)
dataset = lab.build_dataset(
    dataset_name="imdb",
    input_min_text_length=2,
    input_max_text_length=8,
)

Found cached dataset imdb (/Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2bd6a5d7d39a840d.arrow
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2ecf25d24c93f132.arrow


In [5]:
# Two policies and their tokenizer (by name: the policy being trained and the one it
# is compared against — confirm in Lab.init_policies_tokenizer).
new_policy, old_policy, tokenizer = lab.init_policies_tokenizer()

# Sampling-based generation; the tokenizer's EOS token id is reused as the pad token id.
lab.set_generation_config(
    do_sample=True,
    output_min_length=4,
    output_max_length=16,
    pad_token_id=tokenizer.eos_token_id,
)

# The PPO trainer ties together the config, both policies, the tokenizer and the data.
ppo_trainer = lab.init_ppo_trainer(
    config,
    new_policy,
    old_policy,
    tokenizer,
    dataset,
    dataloader_collator=imdb_dataloader_collator,
)
reward_model = lab.init_reward_model()

In [6]:
# Roll out a single batch: generate one response per query, then score every
# query + response text with the sentiment reward model.
for batch_step, batch in enumerate(ppo_trainer.dataloader):
    queries = batch['input_ids']

    # Generate a continuation for each query; the length is drawn afresh per query.
    responses = []
    for query in queries:
        gen_len = lab.output_length_sampler()
        lab.generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **lab.generation_kwargs)
        # keep only the trailing gen_len tokens of the generated sequence
        responses.append(response.squeeze()[-gen_len:])

    batch['response'] = [tokenizer.decode(tokens.squeeze()) for tokens in responses]

    # Sentiment score of the concatenated query + response strings.
    texts = [prompt + reply for prompt, reply in zip(batch['query'], batch['response'])]
    pipe_outputs = lab.reward_model(texts, **sent_kwargs)
    rewards = [torch.tensor(result[1]["score"]) for result in pipe_outputs]

    # a single batch is all this walkthrough needs
    break

queries, responses, scores = ppo_trainer._step_safety_checker(batch_size, queries, responses, rewards)

scores

[tensor(-1.1495), tensor(-1.8032), tensor(-0.8332), tensor(-1.9958)]

In [7]:
# Forward pass over the (query, response) pairs. By their names, the three results are
# the log-probs under the current policy, the log-probs under the reference policy,
# and the value estimates — TODO confirm against batched_forward_pass.
old_logprobs, ref_logprobs, values = ppo_trainer.batched_forward_pass(queries, responses)

# Combine the one-score-per-sample list with the two sets of log-probs into
# a list of reward tensors (one tensor per sample, one entry per response token).
rewards, non_score_reward = ppo_trainer.compute_rewards(scores, old_logprobs, ref_logprobs)

print(rewards)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -1.1495]), tensor([-0.0000, -0.0000, -0.0000, -1.8032]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.8332]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -1.9958])]


In [8]:
# Mirrors train_minibatch() (line 419 of ppo_trainer.py): run one PPO optimisation
# step on a single sample.
#
# NOTE: `idx` is intentionally left behind as a notebook-level variable; later cells
# reuse it to select the same sample. Because the loop stops after its first
# iteration, `idx` ends up as 0.
for idx in range(config.batch_size):

    # Forward pass on query + response for this one sample (batch dimension of 1).
    new_logprobs, vpred, logits = ppo_trainer.compute_logits_vpred(
        model_input=torch.cat([queries[idx], responses[idx]]).unsqueeze(0),
        query=queries[idx].unsqueeze(0),
        response=responses[idx].unsqueeze(0),
        rewards=rewards[idx].unsqueeze(0),
    )

    # Policy loss, value loss and a dict of training statistics.
    loss_p, loss_v, train_stats = ppo_trainer.loss(
        old_logprobs[idx].unsqueeze(0),
        values[idx].unsqueeze(0),
        rewards[idx].unsqueeze(0),
        logits,
        vpred,
        new_logprobs,
    )

    loss = loss_p + loss_v

    # Standard optimiser step: clear old gradients, backpropagate, update weights.
    ppo_trainer.optimizer.zero_grad()
    ppo_trainer.accelerator.backward(loss)
    ppo_trainer.optimizer.step()

    # only the first sample is used for this demonstration
    break

train_stats

{'loss/policy': tensor(1.9206e-07, grad_fn=<MeanBackward0>),
 'loss/value': tensor(12.0091, grad_fn=<MulBackward0>),
 'loss/total': tensor(1.2009, grad_fn=<AddBackward0>),
 'policy/entropy': tensor(4.1144, grad_fn=<MeanBackward0>),
 'policy/approxkl': tensor(0., grad_fn=<MulBackward0>),
 'policy/policykl': tensor(0., grad_fn=<MeanBackward0>),
 'policy/clipfrac': tensor(0., dtype=torch.float64),
 'policy/advantages': tensor([[ 0.3414,  1.0291,  0.9201,  1.0319,  0.5275, -0.5273, -0.6148, -0.9588,
          -1.7491]]),
 'policy/advantages_mean': tensor(-1.9206e-07),
 'policy/ratio': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.]], grad_fn=<ExpBackward0>),
 'returns/mean': tensor(-0.4744),
 'returns/var': tensor(0.1508),
 'val/vpred': tensor(3.2707, grad_fn=<MeanBackward0>),
 'val/error': tensor(23.7453, grad_fn=<MeanBackward0>),
 'val/clipfrac': tensor(0.2222, dtype=torch.float64),
 'val/mean': tensor(2.3194),
 'val/var': tensor(0.5796)}

In [9]:
# Repeat the forward pass and reward computation now that the optimizer step above
# has been applied, so the printed rewards can be compared with the earlier ones.
old_logprobs, ref_logprobs, values = ppo_trainer.batched_forward_pass(queries, responses)

# Same inputs as before (scores are unchanged); only the log-probs have been recomputed.
rewards, non_score_reward = ppo_trainer.compute_rewards(scores, old_logprobs, ref_logprobs)

print(rewards)

[tensor([-0.0373, -0.1523, -0.0710, -0.1819, -0.2028,  0.0330,  0.0162,  0.2848,
        -1.0883]), tensor([ 4.2279e-03,  1.4152e-03, -4.4575e-03, -1.8094e+00]), tensor([-9.5544e-04,  4.1749e-03,  1.2315e-03,  1.1430e-04, -1.5267e-03,
        -2.0228e-03, -8.3362e-01]), tensor([-2.6575e-03, -2.8696e-03,  1.8173e-03,  3.6266e-03, -8.3869e-04,
        -5.2579e-03,  9.5940e-04,  8.7895e-03,  2.5551e-03,  3.8439e-03,
         5.9714e-03, -1.1896e-02,  1.6894e-02, -4.4158e-03, -1.9859e+00])]


#### Advantage = Returns - Values

The term `lastgaelam` means the last [generalized advantage estimator](https://arxiv.org/pdf/1506.02438.pdf) (GAE) lambda value.
There are already many resources explaining what the [Advantage](https://huggingface.co/blog/deep-rl-a2c) function is, so I will not go into it too much. In the equations below, $A^{(k)}_{t}$ is the advantage, which answers the question: "Compared to the expected reward we should get from state $ s_{t} $ until the end of the episode, how much more or less did we get specifically as a result of taking the action we took at step $t$, rather than any of the other actions we could have taken, not counting the actions we took before or after step $t$?" As a simple equation:

$$ Advantage = Returns - Values $$

`returns` is the discounted sum of rewards from step $t$ onward, $ R(t) = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots + \gamma^{T-t} r_{T} $, where $T$ is the last timestep of the episode and gamma $ \gamma $ is the discount factor: when $ \gamma \ll 1 $ it emphasizes short-term rewards over long-term rewards, and when $ \gamma = 1 $ it weights long-term and short-term rewards equally.

`values` is the model's prediction of what the returns will be at any given timestep t and state $s_t$


In [10]:
# GAE hyper-parameters taken from the trainer's config.
print(f'lambda {ppo_trainer.config.lam!s} gamma {ppo_trainer.config.gamma!s}')

lambda 0.95 gamma 1


#### Generalized Advantage Estimator (GAE)

$A^{(1)}_{t}$ does this by combining only the actual reward $r_t$ we got immediately after taking action $ a_{t} $ in $ s_{t} $ with $\gamma V(s_{t+1})$, which estimates the rest of the future, and then subtracting $V(s_t)$, the expected total reward averaged across all the action options at state $s_t$, good and bad.

\begin{align}
\hat{A}_t^{(1)} &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(2)} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t) \\
\cdots &= \cdots \\ 
\hat{A}_t^{(\infty)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots - V(s_t)
\end{align}

$A^{(2)}_{t}$ is similar to $A^{(1)}_{t}$, only we incorporate 2 steps of actual reward in the future then estimate the rest, and so on and so on.


lam, or lambda $ \lambda $, is the weight parameter. It is taught intuitively in this other lesson about the [Exponentially Weighted Moving Average](https://medium.com/mlearning-ai/exponentially-weighted-average-5eed00181a09) (EWMA), only in that lesson it is called $ \beta $. Basically, the higher $ \lambda $ is, the more weight you are placing on values other than the most immediate one.

<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*u3MIYRnLguhjvM0tr72wBA.png" width=600 height=400>

What you see is that the lower $ \beta $ is, the noisier the signal. That's because the lower the beta, the less we take into account the more stable past values; instead, the moving average changes a lot based on the most recent, volatile piece of data. With a higher beta we weigh the past, known, and now static values more heavily, thereby producing a smoother curve.

However, and I'm sorry for doing this, with respect to GAE the situation is reversed in both ways from the example shown in the graph. So why did I show it to you? The example is easier to understand, and the GAE relationship is similar, only reversed and harder to describe directly. Once you see the relationship in the graph, I think it is easier to take the inverse of a relationship you do understand than to have it explained in a more confusing setting.

Whereas in a typical time-series EWMA the most immediate data is the most recent data in the past and the other data is the data in the farther past, in GAE the most immediate data is the next reward and the other data is the rewards we estimate we might get in the future, via the state value function $V(s)$. $ \lambda $ is therefore higher when you want to weight these far-future estimates more heavily, at the expense of the ones in your immediate future, which are more certain.

$\hat{A}_t^{GAE(\gamma,\lambda)}$  is the generalized advantage estimator. This [Blog on GAE](https://danieltakeshi.github.io/2017/04/02/notes-on-the-generalized-advantage-estimation-paper/) explains it well. The higher lambda is the more future steps (k's) you are taking into account

\begin{align}
\hat{A}_t^{GAE(\gamma,\lambda)} &= (1-\lambda)\Big(\hat{A}_{t}^{(1)} + \lambda \hat{A}_{t}^{(2)} + \lambda^2 \hat{A}_{t}^{(3)} + \cdots \Big) \\
&= (1-\lambda)\Big(\delta_t^V + \lambda(\delta_t^V + \gamma \delta_{t+1}^V) + \lambda^2(\delta_t^V + \gamma \delta_{t+1}^V + \gamma^2 \delta_{t+2}^V)+ \cdots \Big)  \\
&= (1-\lambda)\Big( \delta_t^V(1+\lambda+\lambda^2+\cdots) + \gamma\delta_{t+1}^V(\lambda+\lambda^2+\cdots) + \cdots \Big) \\
&= (1-\lambda)\left(\delta_t^V \frac{1}{1-\lambda} + \gamma \delta_{t+1}^V\frac{\lambda}{1-\lambda} + \cdots\right) \\
&= \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}^{V}
\end{align}

***
The tradeoff here is that the estimators $A^{(k)}_{t}$ with small $k$ have low variance but high bias, whereas those with large $k$ have low bias but high variance. Why?

I think of it based on the number of terms. With small $k$, we have fewer sampled reward terms to sum over (which means low variance). However, the bias is relatively large because the estimator does not make use of the extra "exact" information in the rewards $r_{t+l}$ with $l \ge k$.

Here's another way to think of it, as emphasized in the paper: $V(s_t)$ is constant among the estimator class, so it does not affect the relative bias or variance among the estimators; differences arise entirely from the $k$-step returns.
***

In RL and machine learning, we call this noise the variance, as in the bias-variance tradeoff.

I like to sum this up as "the L in lambda is for longtermism": depending on the choices we make today, there are many variants of the future we could end up in, so a larger $ \lambda $ means more long-termism and more variance.

Basically, like many tradeoffs, there exists a point of balance for your particular problem. As in the example below, you get bad learning not only when lambda is too high, but also when it is too low.

<img src="https://d3i71xaburhd42.cloudfront.net/ca11ba7b2991fe07b7a99b3a3aeba2486ed36261/9-Figure4-1.png">

I'm going to rewrite the set of equations above to better mirror the code that we will implement below.

First, let's add the lambda term:

\begin{align}
\hat{A}_t^{(1)} &= \delta^{V}_{t} = r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(2)} &= \delta^{V}_{t} + (\gamma \lambda) \delta^{V}_{t+1} \\
\hat{A}_t^{(k)} &= \sum_{l=0}^{k-1} (\gamma \lambda)^l \delta_{t+l}^{V}
\end{align}

Next, rewrite $\hat{A}_t^{(T - t)}$ in terms of the estimate at step $t + 1$, so that we can calculate the GAE for each
step $t$ in the sequence, working backwards from the last step $T$ to the first step $0$:

\begin{align}
\delta^{V}_{t} &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(T - t)} &= \delta^{V}_{t} + (\gamma \lambda) \hat{A}_{t+1}^{(T - t - 1)}
\end{align}

The above two expressions correspond to the steps commented `2.` and `3.` in the GAE loop further below, respectively.

In [11]:
# Value estimates for the sample at index 0 of the batch (one entry per position shown in the output).
values[0]

tensor([2.3230, 1.6293, 1.4225, 1.2504, 1.7187, 2.5169, 2.5238, 2.6697, 3.2744])

In [12]:
# Pull one sample out of the batch and give it a leading batch dimension of 1
# (via unsqueeze(0)) so it can be pushed through the PPO computations on its own.
# `idx` is whatever the loop variable of the training-step cell was left at.

old_logprobs_ = old_logprobs[idx].unsqueeze(0)
values_ = values[idx].unsqueeze(0)
rewards_ = rewards[idx].unsqueeze(0)
queries_ = queries[idx].unsqueeze(0)
responses_ = responses[idx].unsqueeze(0)
# full model input: query tokens followed by response tokens
model_input_ = torch.cat([queries[idx],responses[idx]]).unsqueeze(0)

print(old_logprobs_, old_logprobs_.shape)
print(' ')
print(values_, values_.shape) # value estimates for this sample
print(' ')
print(rewards_, rewards_.shape)
print(' ')
print(model_input_, model_input_.shape)

tensor([[-8.9493, -1.3201, -6.3630, -4.1454, -3.3118, -2.1634, -1.3075, -5.0301,
         -5.5869]]) torch.Size([1, 9])
 
tensor([[2.3230, 1.6293, 1.4225, 1.2504, 1.7187, 2.5169, 2.5238, 2.6697, 3.2744]]) torch.Size([1, 9])
 
tensor([[-0.0373, -0.1523, -0.0710, -0.1819, -0.2028,  0.0330,  0.0162,  0.2848,
         -1.0883]]) torch.Size([1, 9])
 
tensor([[   40,  3505,  4964, 16089,   351,   337,     5,    38,    11,   290,
           484,  6304]]) torch.Size([1, 12])


`ppo_trainer.compute_logits_vpred()` seems to do something very similar to `ppo_trainer.batched_forward_pass()`, except that it runs only the new_policy and also returns the logits. The value predictions it returns are also shifted one position into the future.

In [14]:
# Forward pass on the single sample. In the recorded output the log-probs equal
# old_logprobs_, while vpred_ is values_ moved along by one position.
new_logprobs_, vpred_, logits_ = ppo_trainer.compute_logits_vpred(model_input_, queries_, responses_, rewards_)

print(new_logprobs_, new_logprobs_.shape)
print(' ')
print(vpred_, vpred_.shape)

tensor([[-8.9493, -1.3201, -6.3630, -4.1454, -3.3118, -2.1634, -1.3075, -5.0301,
         -5.5869]], grad_fn=<SliceBackward0>) torch.Size([1, 9])
 
tensor([[ 1.6293,  1.4225,  1.2504,  1.7187,  2.5169,  2.5238,  2.6697,  3.2744,
         10.4572]], grad_fn=<SliceBackward0>) torch.Size([1, 9])


In [14]:
# Generalized advantage estimation for one sample, computed with a backward recursion.
# gen_len is the number of steps, taken from the last dimension of rewards_.
gen_len = rewards_.shape[-1]
print('gen_len', gen_len)

# lastgaelam carries A_{t+1} from one iteration to the next; it is 0 past the final step
lastgaelam = 0
advantages_reversed = []

# iterate backwards from the last step to the first: t = gen_len - 1 -> 0
for t in reversed(range(gen_len)):
    
    # 1. V(s_t+1); taken as 0.0 at the final step, where there is no next value
    nextvalues = values_[:, t + 1] if t < gen_len - 1 else 0.0  
    
    # 2. delta_t = r_t + gamma*V(s_t+1) - V(s_t)
    delta = rewards_[:, t] + ppo_trainer.config.gamma * nextvalues - values_[:, t]
    
    # 3. A_t = delta_t + gamma*lambda*A_t+1  (lastgaelam on the right-hand side is A_t+1)
    lastgaelam = delta + ppo_trainer.config.gamma * ppo_trainer.config.lam * lastgaelam
    
    advantages_reversed.append(lastgaelam)
    
    print(t, nextvalues, lastgaelam)
    
print(' ')
print(advantages_reversed)
print(' ')

# the list was filled last-step-first: flip it back into time order, stack the
# per-step tensors, and transpose so the step axis comes last
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

print(advantages)

8 0.0 tensor([2.8246])
7 tensor([-1.5782]) tensor([0.4757])
6 tensor([0.6362]) tensor([2.7972])
5 tensor([-1.7115]) tensor([2.1442])
4 tensor([-1.1644]) tensor([2.7566])
3 tensor([-1.4403]) tensor([4.6623])
2 tensor([-3.5342]) tensor([1.6699])
1 tensor([-0.6305]) tensor([4.3558])
0 tensor([-3.5340]) tensor([4.3599])
 
[tensor([2.8246]), tensor([0.4757]), tensor([2.7972]), tensor([2.1442]), tensor([2.7566]), tensor([4.6623]), tensor([1.6699]), tensor([4.3558]), tensor([4.3599])]
 
tensor([[4.3599, 4.3558, 1.6699, 4.6623, 2.7566, 2.1442, 2.7972, 0.4757, 2.8246]])


In [15]:
# Advantage = Returns - Values, so the returns are recovered as advantage + value.
# This uses the raw (un-whitened) advantages; the next cell rebinds `advantages`,
# so re-running this cell after it would give a different result.
returns = advantages + values_
returns

tensor([[0.5143, 0.8218, 1.0394, 1.1281, 1.3163, 0.9797, 1.0857, 1.1119, 1.2464]])

In [16]:
# Whitening standardises the advantages: subtract the mean and divide by the
# standard deviation, giving zero mean and a std of 1. detach() then returns a
# tensor that is cut off from the autograd graph.
advantages = whiten(advantages).detach()

advantages

tensor([[ 1.0581,  1.0552, -0.8836,  1.2764, -0.0992, -0.5413, -0.0699, -1.7456,
         -0.0501]])

In [19]:
# Didn't we already do this inside the batched forward pass? Yes — this step is
# redundant, and is repeated here only so the raw model outputs can be inspected.
input_kwargs = dict(input_ids=model_input_)
logits_, _, vpred_ = ppo_trainer.model(**input_kwargs)
print(logits_.shape)
logits_

torch.Size([1, 12, 50257])


tensor([[[ -14.9158,  -15.1586,  -18.4025,  ...,  -22.6887,  -20.5004,
           -14.6813],
         [ -56.8011,  -58.2155,  -65.1559,  ...,  -66.4373,  -65.7569,
           -60.1618],
         [ -42.7785,  -43.1388,  -47.1090,  ...,  -51.4935,  -52.7388,
           -43.6027],
         ...,
         [ -42.9758,  -43.9284,  -46.4706,  ...,  -50.7341,  -45.7978,
           -43.9313],
         [ -35.7480,  -36.7189,  -40.9218,  ...,  -45.1291,  -38.4662,
           -36.8048],
         [-252.0385, -250.1641, -257.2216,  ..., -275.8328, -284.1869,
          -251.5657]]], grad_fn=<UnsafeViewBackward0>)

In [23]:
# Logits and input_ids are one position apart: the token at position t was sampled
# from the logits at position t - 1. The logits are therefore trimmed at the end
# ([:, :-1]) and the token ids at the start ([:, 1:]) before computing log-probs.
logprob_ = logprobs_from_logits(logits_[:, :-1, :], model_input_[:, 1:])

# Keep the last gen_len log-probs. The value predictions use a window of the same
# length that stops one position earlier, i.e. it is shifted by one relative to
# the log-probs.
logprob_ = logprob_[:, -gen_len:]
vpred_ = vpred_[:, -gen_len - 1 : -1]

print(logprob_.shape, vpred_.shape)

torch.Size([1, 9]) torch.Size([1, 9])


$ L_{t}