In [31]:
import os

# Route outbound HTTP(S) traffic (e.g. `from_pretrained` downloads) through a
# local proxy. `setdefault` only fills these in when they are not already
# exported, so a proxy configured in the launching shell is respected instead
# of being silently overwritten.
PROXY_URL = 'http://127.0.0.1:7890'
os.environ.setdefault('http_proxy', PROXY_URL)
os.environ.setdefault('https_proxy', PROXY_URL)

# Disable NCCL peer-to-peer and InfiniBand transports. They are set here,
# before `torch` is imported in the next cell, so they are in place before any
# NCCL communicator could be created. An explicit shell export still wins.
os.environ.setdefault('NCCL_P2P_DISABLE', '1')
os.environ.setdefault('NCCL_IB_DISABLE', '1')
from IPython.display import Image

In [2]:
import torch
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

  torch.utils._pytree._register_pytree_node(
    PyTorch 2.0.1 with CUDA 1108 (you have 2.2.2+cu121)
    Python  3.10.13 (you have 3.10.13)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


[2024-05-25 18:28:16,128] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [35]:
# `!export` runs in a short-lived subshell, so it cannot change the kernel's own
# environment; NCCL_P2P_DISABLE is set through `os.environ` in the first cell instead.
# !export NCCL_P2P_DISABLE="1"

## overall

- https://github.com/huggingface/trl/blob/main/examples/hello_world.py
- OpenRLHF: https://github.com/OpenLLMAI/OpenRLHF
    - https://arxiv.org/abs/2405.11143
- PPO-penalty（PPO1）
  
    $$
    \begin{split}
    &J^{\theta'}_{PPO}(\theta)=J^{\theta'}(\theta)-\beta\,\mathrm{KL}(\theta,\theta'),\quad J^{\theta'}(\theta)=\mathbb E_{s_t,a_t\sim \pi_{\theta'}}\left[\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta'}(a_t|s_t)}A^{\theta'}(s_t,a_t)\right]\\
    &\mathcal{L}^{\text{PENALTY}}(\theta) = \mathbb{E}_t \left[ \hat{A}_t \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)} - \beta D_{KL} \left( \pi_{\theta_{\text{old}}}(\cdot | s_t) \parallel \pi_\theta(\cdot | s_t) \right) \right]
    \end{split}
    $$

- PPO-clip（PPO2）

    $$
    J_{PPO2}^{\theta^k}(\theta) \approx \sum_{(s_t, a_t)} \min \left( \frac{p_\theta(a_t | s_t)}{p_{\theta^k}(a_t | s_t)} A^{\theta^k}(s_t, a_t), \ 
    \text{clip} \left( \frac{p_\theta(a_t | s_t)}{p_{\theta^k}(a_t | s_t)}, 1 - \epsilon, 1 + \epsilon \right) A^{\theta^k}(s_t, a_t) \right)
    $$


In [4]:
# OpenRLHF diagram (see the OpenRLHF link above); path is relative to this notebook.
openrlhf_fig = '../../imgs/openrlhf.png'
Image(url=openrlhf_fig, width=600)

## model vs. model_ref

In [6]:
# 1. load a pretrained model
# Two independent GPT-2 copies with a value head: `model` is the policy passed to
# PPOTrainer for training, `model_ref` is passed as the reference model for the KL term.
base_checkpoint = "gpt2"
model = AutoModelForCausalLMWithValueHead.from_pretrained(base_checkpoint)
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(base_checkpoint)

- `AutoModelForCausalLMWithValueHead`
    - `ValueHead`:  `self.summary = nn.Linear(hidden_size, 1)` (hidden_size => 1)
        - `value = self.v_head(last_hidden_state).squeeze(-1)`
- `model` vs. `model_ref`
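    - `model`: $\pi_\theta$, the policy being trained (with a value head). `model_ref`: the reference policy $\pi_{\text{ref}}$ ($=\pi_{\text{sft}}$), which is not updated and is used only for the KL penalty. Note that $\pi_{\theta_{old}}$ in the PPO ratio is *not* `model_ref`: it is `model` itself as it was at rollout time (its log-probs are recorded before the optimisation epochs).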
- `AdaptiveKLController`（https://arxiv.org/pdf/1909.08593）
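    - 如下公式所示，$\pi_t$ 为当前策略，$\rho$ 为参考策略（预训练 / SFT 模型）；$e_t$ 为实际 KL 相对 $\text{KL}_{\text{target}}$ 的比例误差，并被截断到 $[-0.2, 0.2]$：$e_t=\text{clip}\left(\frac{\text{KL}(\pi_t,\rho)-\text{KL}_{\text{target}}}{\text{KL}_{\text{target}}},-0.2,0.2\right)$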
    - log-space controller

    $$
    \log\beta_{t+1}=\log\beta_t+\log(1+K_\beta e_t)
    $$

In [32]:
# Illustration of the adaptive KL coefficient controller (adap_kl_coef.png).
adaptive_kl_fig = '../../imgs/adap_kl_coef.png'
Image(url=adaptive_kl_fig, width=500)

In [29]:
# Evaluating `model` here would display the full module repr (GPT-2 + value head);
# presumably left commented out to keep the notebook output short.
# model

In [7]:
# Before PPOTrainer is constructed, the policy weights are expected on the CPU.
policy_backbone = model.pretrained_model
policy_backbone.device

device(type='cpu')

In [8]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token (it prints as None), so padding reuses EOS.
pad_and_eos = (tokenizer.pad_token, tokenizer.eos_token)
print(*pad_and_eos)
tokenizer.pad_token = pad_and_eos[1]

None <|endoftext|>


In [9]:
# '<|endoftext|>' encodes to a single token id and decodes back unchanged.
eos_ids = tokenizer.encode('<|endoftext|>')
print(eos_ids)
print(tokenizer.decode(eos_ids))

[50256]
<|endoftext|>


In [25]:
# Base vocabulary size; the highest id (50256) is '<|endoftext|>'.
gpt2_vocab_size = tokenizer.vocab_size
gpt2_vocab_size

50257

In [10]:
# 2. initialize trainer
# One (query, response, score) triple per PPO step, so batch and mini-batch are both 1.
ppo_config = dict(mini_batch_size=1, batch_size=1)
config = PPOConfig(**ppo_config)
# Arguments in order: config, policy model, reference model, tokenizer.
# Building the trainer also moves both models to the accelerator device
# (compare the device printed before and after this cell).
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)



In [11]:
# Inside PPOTrainer's constructor the models go through `accelerate`, roughly:
# model, ... = self.accelerator.prepare(model, ...)
# which is why both models report cuda:0 in the next cell after reporting cpu earlier.

In [19]:
# Policy first, then reference model: both should now be on the accelerator device.
for wrapped in (model, model_ref):
    print(wrapped.pretrained_model.device)

cuda:0
cuda:0


In [12]:
# 3. encode a query
# Note: the trailing space in the prompt becomes its own token (220) in GPT-2's BPE.
query_txt = "This morning I went to the "
target_device = model.pretrained_model.device
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(target_device)
query_tensor

tensor([[1212, 3329,  314, 1816,  284,  262,  220]], device='cuda:0')

In [13]:
# 4. generate model response
# Plain sampling from the model's full next-token distribution:
#   - top_k=0 disables top-k filtering (an int, as `generate` expects; 0.0 only
#     worked because it compares equal to 0),
#   - top_p=1.0 disables nucleus filtering,
#   - min_length=-1 disables the minimum-length constraint,
#   - pad_token_id reuses EOS because GPT-2 has no dedicated pad token.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}

## `ppo_trainer.generate`

In [20]:
# PPOTrainer.generate takes a list of 1-D query tensors; iterating the
# (1, seq_len) batch yields one row per query.
[row for row in query_tensor]

[tensor([1212, 3329,  314, 1816,  284,  262,  220], device='cuda:0')]

In [14]:
# Sample a continuation; return_prompt=False keeps only the newly generated tokens.
queries_as_list = [q for q in query_tensor]
response_tensor = ppo_trainer.generate(queries_as_list, return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

In [21]:
# Decoded continuation only: the prompt is absent because generate() was called with return_prompt=False.
response_txt

'vernacular and found myself at a bar, cook, with a wife. Buggas together in'

In [24]:
# For comparison: calling `model.generate` directly returns the prompt followed by
# the continuation (see the output below), whereas
# `ppo_trainer.generate(..., return_prompt=False)` above returns only the new tokens.
# tokenizer.decode(model.generate(
#     input_ids=query_tensor,
#     **generation_kwargs
# )[0])

'This morning I went to the \xa0Budweiser looking for health info checks for the Tick tell check and noticed "...!..."'

In [15]:
# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
# One scalar float score per response, on the same device as the policy.
reward = [torch.ones((), device=model.pretrained_model.device)]

## `ppo_trainer.step`

In [16]:
# Run one PPO optimisation step on the single (query, response, score) triple.
# With batch_size=1, std-based statistics inside `step` are computed over one
# element, which presumably triggers the std() warnings shown below.
step_queries = [query_tensor[0]]
step_responses = [response_tensor[0]]
train_stats = ppo_trainer.step(queries=step_queries, responses=step_responses, scores=reward)

  std_scores = data["scores"].std()
  stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
  stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()


```
def step(
        self,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        scores: List[torch.FloatTensor],
        response_masks: Optional[List[torch.LongTensor]] = None,
    ):

    ...  # earlier setup elided
    all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
        self.model,
        queries,
        responses,
        model_inputs,
        ...,  # remaining arguments elided
    )
```

- `input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]`：拼接 queries & responses；
- `ppo_trainer.batched_forward_pass`
    - logprobs: $\log\pi_\theta(a_t|s_t)$


```

# ppo_trainer.batched_forward_pass
# 27 = 7 query tokens + 20 response tokens in this run
# logits.shape == (1, 27, 50257), values.shape == (1, 27)
logits, _, values = model(**input_kwargs)

# shift labels for next-token prediction: logits at position i score token i+1
# logprobs.shape == (1, 26)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
```

- 同样地对于 model_ref 再算一遍

    ```
    ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
        self.model if self.is_peft_model else self.ref_model,
        queries,
        responses,
        model_inputs,
        ...,  # remaining arguments elided
    )
    ```

- 计算 rewards

    ```
    rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
    ```

    - kl-penalty
    
        $$
        \text{KL}_{\text{penalty}} =\log\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{ref}}}(a_t|s_t)}=\log \pi_{\theta}(a_t|s_t) - \log \pi_{\theta_{\text{ref}}}(a_t|s_t)
        $$
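    - per-token non-score reward: `non_score_reward = -self.kl_ctl.value * kl`, i.e. $-\beta\,\text{KL}_{\text{penalty}}$
    - total reward = this KL penalty at every response token, plus the preference-model (external RM) score added only at the last response token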
- `values, advantages, returns = self.compute_advantages(values, rewards, masks)`
    - `delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]`
        
        $$
        \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
        $$
    - GAE-$\lambda$, accumulated backwards over the response tokens (`lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam`)

        $$
        \hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}
        $$
    - `returns = advantages + values`

In [30]:
# Figure of the PPO loss as implemented in TRL (trl_ppo_loss.png).
trl_loss_fig = '../../imgs/trl_ppo_loss.png'
Image(url=trl_loss_fig, width=500)

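- `ratio = torch.exp(logprobs - old_logprobs)`, where `old_logprobs` are the policy's own log-probs recorded in the rollout forward pass before this step's updates (not `model_ref`'s log-probs)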

$$
\exp(\log\pi_\theta-\log\pi_{\theta_{old}})=\exp\left(\log\frac{\pi_\theta}{\pi_{\theta_{old}}}\right)=\frac{\pi_\theta}{\pi_{\theta_{old}}}
$$

In [34]:
# Print the name of every statistic returned by ppo_trainer.step, one per line.
for stat_name in train_stats:
    print(stat_name)

objective/kl
objective/kl_dist
objective/logprobs
objective/ref_logprobs
objective/kl_coef
objective/entropy
ppo/mean_non_score_reward
ppo/mean_scores
ppo/std_scores
tokens/queries_len_mean
tokens/queries_len_std
tokens/queries_dist
tokens/responses_len_mean
tokens/responses_len_std
tokens/responses_dist
ppo/loss/policy
ppo/loss/value
ppo/loss/total
ppo/policy/entropy
ppo/policy/approxkl
ppo/policy/policykl
ppo/policy/clipfrac
ppo/policy/advantages
ppo/policy/advantages_mean
ppo/policy/ratio
ppo/returns/mean
ppo/returns/var
ppo/val/vpred
ppo/val/error
ppo/val/clipfrac
ppo/val/mean
ppo/val/var
ppo/val/var_explained
ppo/learning_rate
time/ppo/forward_pass
time/ppo/compute_rewards
time/ppo/compute_advantages
time/ppo/optimize_step
time/ppo/calc_stats
time/ppo/total
