In [1]:
%load_ext autoreload
%autoreload 2

In [19]:
import torch

from minichatgpt.experiments.imdb import config, sent_kwargs
from minichatgpt import PPOTrainer, Lab

# for the loss calculation
from minichatgpt.core import whiten

In [3]:
# Demo speed-up: run with a batch size of 4 instead of the usual 256.
# forward_batch_size is set to half of that.
batch_size = 4
config.batch_size, config.forward_batch_size = batch_size, batch_size // 2

In [4]:
# Build the experiment pieces from the shared config: the IMDB dataset,
# the new/old policies plus tokenizer, the generation settings, the PPO
# trainer, and the reward model.
lab = Lab(config)

dataset = lab.build_dataset(
    dataset_name="imdb",
    input_min_text_length=2,
    input_max_text_length=8,
)

new_policy, old_policy, tokenizer = lab.init_policies_tokenizer()

lab.set_generation_config(
    do_sample=True,
    output_min_length=4,
    output_max_length=16,
    pad_token_id=tokenizer.eos_token_id,  # reuse the EOS id as the padding id
)

ppo_trainer = lab.init_ppo_trainer(config, new_policy, old_policy, tokenizer, dataset)
# NOTE(review): later cells call lab.reward_model(...), not this variable.
reward_model = lab.init_reward_model()

Found cached dataset imdb (/Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2bd6a5d7d39a840d.arrow
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2ecf25d24c93f132.arrow


In [5]:
# Take just the first batch from the dataloader for this walkthrough.
batch = next(iter(ppo_trainer.dataloader))
queries = batch['input_ids']

#### Get response from gpt2
responses = []
for query in queries:
    # Sample a response length for this query.
    gen_len = lab.output_length_sampler()
    # Override max_new_tokens for this call only, instead of mutating the
    # shared lab.generation_kwargs dict in place.
    response = ppo_trainer.generate(
        query, **{**lab.generation_kwargs, "max_new_tokens": gen_len}
    )
    # Keep the last gen_len tokens as the generated continuation.
    # NOTE(review): if generation can stop before gen_len new tokens (e.g. on
    # EOS), this slice would also pick up prompt tokens — confirm what
    # ppo_trainer.generate returns.
    responses.append(response.squeeze()[-gen_len:])

batch['response'] = [tokenizer.decode(r.squeeze()) for r in responses]

#### Compute sentiment score
texts = [q + r for q, r in zip(batch['query'], batch['response'])]
pipe_outputs = lab.reward_model(texts, **sent_kwargs)
# NOTE(review): output[1] is assumed to be the positive-sentiment label;
# confirm the label order returned by the reward pipeline.
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

queries, responses, scores = ppo_trainer._step_safety_checker(batch_size, queries, responses, rewards)



In [6]:
# Run the sampled query/response pairs through the policies to get per-token
# log-probs, reference log-probs and value estimates, then turn the sentiment
# scores into per-token rewards.
# From the printed output, every token's reward is ~0 except the last one
# before any training step; presumably a per-token penalty plus the score on
# the final token — confirm in compute_rewards.
logprobs, ref_logprobs, values = ppo_trainer.batched_forward_pass(
    queries,
    responses,
)
rewards, non_score_reward = ppo_trainer.compute_rewards(
    scores,
    logprobs,
    ref_logprobs,
)

print(rewards)

[tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 1.3289]), tensor([-0.0000, -0.0000, -0.0000, -1.8032]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 0.8134]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 2.2469])]


In [7]:
# One PPO optimisation step on a single sample (idx = 0), to show what
# ppo_trainer.loss computes. A full update would cover every index in
# range(config.batch_size); keeping idx explicit also leaves it well defined
# for the cells below.
idx = 0

# Every argument gets a leading batch dimension of 1 via unsqueeze(0); the last
# one is the full token sequence (query followed by response).
loss_p, loss_v, train_stats = ppo_trainer.loss(
    logprobs[idx].unsqueeze(0),
    values[idx].unsqueeze(0),
    rewards[idx].unsqueeze(0),
    queries[idx].unsqueeze(0),
    responses[idx].unsqueeze(0),
    torch.cat([queries[idx], responses[idx]]).unsqueeze(0),
)

# Combined objective: policy loss plus value loss.
# NOTE(review): train_stats shows loss/total (0.3311) far below loss/value
# (3.3105), suggesting loss_v is already scaled by a coefficient — confirm.
loss = loss_p + loss_v

ppo_trainer.optimizer.zero_grad()
ppo_trainer.accelerator.backward(loss)
ppo_trainer.optimizer.step()

train_stats

{'loss/policy': tensor(-8.9407e-08, grad_fn=<MeanBackward0>),
 'loss/value': tensor(3.3105, grad_fn=<MulBackward0>),
 'loss/total': tensor(0.3311, grad_fn=<AddBackward0>),
 'policy/entropy': tensor(4.1144, grad_fn=<MeanBackward0>),
 'policy/approxkl': tensor(0., grad_fn=<MulBackward0>),
 'policy/policykl': tensor(0., grad_fn=<MeanBackward0>),
 'policy/clipfrac': tensor(0., dtype=torch.float64),
 'policy/advantages': tensor([[-0.6389, -0.4488,  1.6807,  0.1677,  1.6868, -0.7205, -0.4856, -0.3320,
          -0.9094]]),
 'policy/advantages_mean': tensor(8.9407e-08),
 'policy/ratio': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.]], grad_fn=<ExpBackward0>),
 'returns/mean': tensor(1.8394),
 'returns/var': tensor(0.0901),
 'val/vpred': tensor(2.8901, grad_fn=<MeanBackward0>),
 'val/error': tensor(2.1865, grad_fn=<MeanBackward0>),
 'val/clipfrac': tensor(0.7778, dtype=torch.float64),
 'val/mean': tensor(4.1992),
 'val/var': tensor(1.1751)}

In [8]:
# Re-score the same batch now that the optimizer has taken a step: the
# non-final per-token rewards are no longer zero (compare with the earlier print).
forward_outputs = ppo_trainer.batched_forward_pass(queries, responses)
logprobs, ref_logprobs, values = forward_outputs
rewards, non_score_reward = ppo_trainer.compute_rewards(scores, logprobs, ref_logprobs)
print(rewards)

[tensor([ 0.0442,  0.1176, -0.1157, -0.0230, -0.3808,  0.0233,  0.0230,  0.0929,
         1.3803]), tensor([-2.1509e-03, -4.7445e-06, -5.3867e-03, -1.7988e+00]), tensor([-5.4193e-04,  2.0544e-03, -1.4941e-03,  7.2184e-04, -9.1734e-04,
         1.5621e-04,  8.1624e-01]), tensor([ 1.9403e-03,  1.2625e-03,  3.7740e-03,  6.0673e-03,  1.9946e-03,
         8.4896e-04, -1.9379e-03,  7.0903e-03,  6.0678e-03,  5.0527e-03,
         1.0809e-02, -6.8638e-03,  8.0862e-03,  1.0620e-02,  2.2539e+00])]


In [22]:
# Inspect one sample from the batch, using the tensors re-scored in the
# previous cell. idx is set here so this cell does not depend on whatever value
# an earlier cell left behind; 0 is the sample the loss cell above trained on.
idx = 0

# Add a leading batch dimension of 1, matching the arguments given to ppo_trainer.loss.
old_logprobs_ = logprobs[idx].unsqueeze(0)
values_ = values[idx].unsqueeze(0)
rewards_ = rewards[idx].unsqueeze(0)
queries_ = queries[idx].unsqueeze(0)
responses_ = responses[idx].unsqueeze(0)
# Full token sequence: query followed by response.
model_input_ = torch.cat([queries[idx], responses[idx]]).unsqueeze(0)

In [20]:
# Show the two GAE hyper-parameters used below: lam (lambda) and gamma (discount).
cfg = ppo_trainer.config
print('lambda', cfg.lam)
print('gamma', cfg.gamma)

lambda 0.95
gamma 1


#### Generalized advantage estimator

The variable `lastgaelam` means "last GAE lambda": the [generalized advantage estimate](https://arxiv.org/pdf/1506.02438.pdf) (GAE) carried over from the previously processed step.
There are already many resources explaining what the [advantage](https://huggingface.co/blog/deep-rl-a2c) function is, so I will not go into it too much. In the equations below, $\hat{A}^{(k)}_{t}$ is an estimate of the advantage. It answers the question: "Compared to the average, or expected, total reward we should get from state $s_{t}$ until the end, how much more or less did we get specifically as a result of the action we took at step t, rather than any of the other actions we could have taken — not counting the actions we took before or after step t?"

$\hat{A}^{(1)}_{t}$ does this by using only the actual reward $r_t$ we got immediately after taking action $a_{t}$ in state $s_{t}$. It adds $\gamma V(s_{t+1})$ to estimate the rest of the future. It then subtracts $V(s_t)$, the expected total reward averaged across all the action options at state $s_t$, good and bad.

\begin{align}
\hat{A}_t^{(1)} &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(2)} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t) \\
\cdots &= \cdots \\ 
\hat{A}_t^{(\infty)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots - V(s_t)
\end{align}

$\hat{A}^{(2)}_{t}$ is similar to $\hat{A}^{(1)}_{t}$, except that we incorporate 2 steps of actual reward before estimating the rest with $V$, and so on for larger k.


`lam`, or lambda ($\lambda$), is the weighting parameter. It is explained intuitively in this other lesson about [Exponentially Weighted Moving Average](https://medium.com/mlearning-ai/exponentially-weighted-average-5eed00181a09) (EWMA), where it is called $\beta$. Basically, the higher $\lambda$ is, the more weight you place on values other than the most immediate one.

<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*u3MIYRnLguhjvM0tr72wBA.png" width=600 height=400>

What you see is that the lower $\beta$ is, the noisier the signal. That's because with a lower $\beta$ we take less account of the more stable past values. Instead, the moving average changes a lot based on the most recent, volatile data point. With a higher $\beta$ we weight the known, now static past values more heavily, which produces a smoother curve.

However, and I'm sorry for doing this, with respect to GAE the situation is reversed in both ways compared to the example shown in the graph. So why did I show it to you? The example is easier to understand, and the relationship is similar, only reversed; the GAE relationship itself is harder to describe. Once you see the simpler relationship, I think it is easier to take the inverse of a relationship you do understand than to explain the relationship in a more confusing setting.

In a typical time-series EWMA, the most immediate data is the most recent data point in the past, and the other data lies farther in the past. In GAE, by contrast, the most immediate data is the next reward, and the other data is the rewards we estimate we might get further in the future, via the state-value function $V(s)$. A higher $\lambda$ therefore weights these far-future estimates more heavily, at the expense of the ones in your immediate future, which are more certain.

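$\hat{A}_t^{GAE(\gamma,\lambda)}$ is the generalized advantage estimator. This [Blog on GAE](https://danieltakeshi.github.io/2017/04/02/notes-on-the-generalized-advantage-estimation-paper/) explains it well. The higher lambda is, the more future steps (k's) you take into account. To expand it, write each k-step estimator as a sum of one-step TD residuals:

$$\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l}^V, \qquad \delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)$$

So $\hat{A}_t^{(1)} = \delta_t^V$, and the GAE becomes: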

\begin{align}
\hat{A}_t^{GAE(\gamma,\lambda)} &= (1-\lambda)\Big(\hat{A}_{t}^{(1)} + \lambda \hat{A}_{t}^{(2)} + \lambda^2 \hat{A}_{t}^{(3)} + \cdots \Big) \\
&= (1-\lambda)\Big(\delta_t^V + \lambda(\delta_t^V + \gamma \delta_{t+1}^V) + \lambda^2(\delta_t^V + \gamma \delta_{t+1}^V + \gamma^2 \delta_{t+2}^V)+ \cdots \Big)  \\
&= (1-\lambda)\Big( \delta_t^V(1+\lambda+\lambda^2+\cdots) + \gamma\delta_{t+1}^V(\lambda+\lambda^2+\cdots) + \cdots \Big) \\
&= (1-\lambda)\left(\delta_t^V \frac{1}{1-\lambda} + \gamma \delta_{t+1}^V\frac{\lambda}{1-\lambda} + \cdots\right) \\
&= \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}^{V}
\end{align}

***
The tradeoff here is that the estimators $A^{(k)}_{t}$ with small k have low variance but high bias, whereas those with large k have low bias but high variance. Why?

I think of it based on the number of terms. With small k, we have fewer terms to sum over, which means low variance. However, the bias is relatively large because the estimator does not make use of the extra "exact" information in $r_K$ for $K > k$.

Here's another way to think of it, as emphasized in the paper: $V(s_t)$ is constant across the estimator class, so it does not affect the relative bias or variance among the estimators. The differences arise entirely from the k-step returns.
***

In RL and machine learning, we call this noise the variance, as in the bias–variance tradeoff.

I like to sum this up as "the L in lambda is for long-termism". Depending on the choices we make today, there are many possible futures we could end up in. So a larger $\lambda$ means looking further into the long term, and more variance.

Basically, like many tradeoffs, there is a point of balance for your particular problem. As in the example below, you get poor learning not only when lambda is too high, but also when it is too low.

<img src="https://d3i71xaburhd42.cloudfront.net/ca11ba7b2991fe07b7a99b3a3aeba2486ed36261/9-Figure4-1.png">

I'm going to rewrite the set of equations above to better mirror the code that we will implement below.

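First, let's add the $\lambda$ term. From here on, $\hat{A}_t^{(k)}$ denotes the $\lambda$-weighted sum of the first k TD residuals. This is not the same as the unweighted k-step estimator above; the two coincide only for k = 1.

\begin{align}
\hat{A}_t^{(1)} &= \delta^{V}_{t} = r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(2)} &= \delta^{V}_{t} + (\gamma \lambda) \delta^{V}_{t+1} \\
\hat{A}_t^{(k)} &= \sum_{l=0}^{k-1} (\gamma \lambda)^l \delta_{t+l}^{V}
\end{align}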

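Next, rewrite $\hat{A}_t^{(T - t)}$, the sum that runs to the end of a sequence of T steps, recursively in terms of step t + 1. This lets us calculate the GAE for each step t by iterating backwards, from the last step t = T - 1 down to the first step t = 0:

\begin{align}
\delta^{V}_{t} &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(T - t)} &= \delta^{V}_{t} + (\gamma \lambda) \hat{A}_{t+1}^{(T - t - 1)}
\end{align}

Past the final step there is nothing left, so $V(s_T) = 0$ and $\hat{A}_T^{(0)} = 0$. In the code below, `delta` is $\delta^V_t$, and `lastgaelam` holds $\hat{A}_{t+1}$ from the previous iteration.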


In [33]:
print(f"{rewards_} {rewards_.shape}")  # per-token rewards of the selected sample and their shape


tensor([[ 0.0442,  0.1176, -0.1157, -0.0230, -0.3808,  0.0233,  0.0230,  0.0929,
          1.3803]]) torch.Size([1, 9])


In [40]:
print(f"{values_} {values_.shape}")  # per-token value estimates of the selected sample and their shape

tensor([[5.6171, 4.9112, 1.9119, 3.3269, 2.0643, 4.7600, 4.6683, 1.0787, 3.4501]]) torch.Size([1, 9])


In [38]:
# Number of generated tokens = length of the last axis of the per-token rewards.
gen_len = rewards_.size(-1)
print('gen_len', gen_len)

gen_len 9


In [45]:
# Generalized Advantage Estimation, computed backwards in time:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lam * A_{t+1}
# `lastgaelam` carries A_{t+1} from the previously processed (later) step and
# starts at 0 because nothing follows the final token.
gamma = ppo_trainer.config.gamma
lam = ppo_trainer.config.lam

lastgaelam = 0
advantages_reversed = []

for t in range(gen_len - 1, -1, -1):
    # Value of the next state; the state after the last token is valued at 0.
    if t == gen_len - 1:
        nextvalues = 0.0
    else:
        nextvalues = values_[:, t + 1]

    # One-step TD residual delta_t (the same as A_t^(1) in the markdown above).
    delta = rewards_[:, t] + gamma * nextvalues - values_[:, t]

    # Add the discounted, lambda-weighted advantage of the following step.
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)

    print(t, nextvalues, lastgaelam)

print(' ')
print(advantages_reversed)
print(' ')

# The list runs from the last step down to t = 0: flip it back into time order,
# stack to (gen_len, batch) and transpose to (batch, gen_len).
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

print(advantages)


8 0.0 tensor([-2.0698])
7 tensor([3.4501]) tensor([0.4980])
6 tensor([1.0787]) tensor([-3.0935])
5 tensor([4.6683]) tensor([-3.0072])
4 tensor([4.7600]) tensor([-0.5420])
3 tensor([2.0643]) tensor([-1.8004])
2 tensor([3.3269]) tensor([-0.4111])
1 tensor([1.9119]) tensor([-3.2723])
0 tensor([4.9112]) tensor([-3.7704])
 
[tensor([-2.0698]), tensor([0.4980]), tensor([-3.0935]), tensor([-3.0072]), tensor([-0.5420]), tensor([-1.8004]), tensor([-0.4111]), tensor([-3.2723]), tensor([-3.7704])]
 
tensor([[-3.7704, -3.2723, -0.4111, -1.8004, -0.5420, -3.0072, -3.0935,  0.4980,
         -2.0698]])


In [None]:
returns = values_ + advantages  # value targets: advantage estimate A_t plus baseline V(s_t)

In [47]:
advantages.size()


torch.Size([1, 9])

In [29]:
# Display the value targets computed as `advantages + values_`.
# NOTE(review): the saved output of this cell predates the (unexecuted) cell
# that defines `returns`; re-run the notebook top to bottom to refresh it.
returns

tensor([[1.8467, 1.6389, 1.5008, 1.5264, 1.5223, 1.7528, 1.5748, 1.5767, 1.3803]])

In [42]:
# Scratch: a small list used to try out list.append in the next cell.
a = list(range(1, 3))
a

[1, 2]

In [43]:
# Scratch: grow the list in place. Not idempotent: every re-run adds another 3.
a.extend([3])
a

[1, 2, 3]