In [1]:
from IPython.display import Image
from typing import Dict

In [2]:
import os
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['https_proxy'] = 'http://127.0.0.1:7890'

In [3]:
Image(url='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png', 
      width=500)

In [4]:
import torch
from tqdm import tqdm
import pandas as pd

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

[2024-04-05 09:58:11,304] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## config

In [107]:
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
    log_with="wandb",
    # default batch_size=256
    batch_size=1024
)

sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}

In [105]:
config.batch_size

1024

In [35]:
import wandb

wandb.init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlanchunhui[0m ([33mloveresearch[0m). Use [1m`wandb login --relogin`[0m to force relogin


## dataset & tasks

In [36]:
config.model_name

'lvwerra/gpt2-imdb'

In [37]:
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

In [39]:
ds = load_dataset("imdb", split="train")
ds

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [40]:
ds = ds.rename_column('text', 'review')
ds

Dataset({
    features: ['review', 'label'],
    num_rows: 25000
})

In [41]:
ds = ds.filter(lambda x: len(x['review']) > 200, batched=False)
ds

Filter:   0%|          | 0/25000 [00:00<?, ? examples/s]

Dataset({
    features: ['review', 'label'],
    num_rows: 24895
})

In [48]:
input_size = LengthSampler(min_value=2, max_value=8)
input_size()

5

In [49]:
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
    sample["query"] = tokenizer.decode(sample["input_ids"])
    return sample
ds = ds.map(tokenize, batched=False)
ds

Map:   0%|          | 0/24895 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors


Dataset({
    features: ['review', 'label', 'input_ids', 'query'],
    num_rows: 24895
})

In [51]:
ds.set_format(type="torch")
ds

Dataset({
    features: ['review', 'label', 'input_ids', 'query'],
    num_rows: 24895
})

In [52]:
ds[0]

{'review': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far 

In [53]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

In [63]:
ds[0].keys()

dict_keys(['review', 'label', 'input_ids', 'query'])

## models

- model (active model) vs ref_model (reference model)
    - 两个模型初始情况下一致（`CasualLMWithValueHead`），都是**第一阶段 sft** 而得到；
        - AutoModelForCausalLMWithValueHead
            - base_model & value_head
            - `return (lm_logits, loss, value)`
                - base_model_output: lm_logits, loss
                - value head output: value (one scalar)
    - model: 通过 ppo（RL）算法要去训练微调的模型；
    - ref_model：参考的基准模型（微调 model 不要偏离原始的 ref_model 太多）
        - 不能学了新的，忘了旧的；
    - value head：相当于 AC network（Actor Critic） 部分的 critic
- reward model（第二阶段）
    - sentiment-analysis

```
AutoModelForCausalLMWithValueHead
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        kwargs["output_hidden_states"] = True  # this had already been set in the LORA / PEFT examples
        kwargs["past_key_values"] = past_key_values
    
        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")
    
        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
    
        last_hidden_state = base_model_output.hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss
    
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
    
        value = self.v_head(last_hidden_state).squeeze(-1)
    
        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()
    
        return (lm_logits, loss, value)
```

In [57]:
config.model_name

'lvwerra/gpt2-imdb'

In [60]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# '<|endoftext|>'
tokenizer.pad_token = tokenizer.eos_token

In [90]:
# https://github.com/openai/summarize-from-feedback
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=ds, data_collator=collator)



VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112973688886996, max=1.0…

In [68]:
device = ppo_trainer.accelerator.device
device

device(type='cuda')

In [69]:
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

In [80]:
ds[5]['review']

"I would put this at the top of my list of films in the category of unwatchable trash! There are films that are bad, but the worst kind are the ones that are unwatchable but you are suppose to like them because they are supposed to be good for you! The sex sequences, so shocking in its day, couldn't even arouse a rabbit. The so called controversial politics is strictly high school sophomore amateur night Marxism. The film is self-consciously arty in the worst sense of the term. The photography is in a harsh grainy black and white. Some scenes are out of focus or taken from the wrong angle. Even the sound is bad! And some people call this art?<br /><br />"

In [76]:
# text = "this movie was really bad!!"
sentiment_pipe(ds[5]['review'], **sent_kwargs)



[[{'label': 'NEGATIVE', 'score': 2.0556585788726807},
  {'label': 'POSITIVE', 'score': -2.5166730880737305}]]

In [97]:
text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)[0][1]



{'label': 'POSITIVE', 'score': 2.557039737701416}

## PPOTrainer

```
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=ds, data_collator=collator)
```

- generate (rollout)
    - `response = ppo_trainer.generate(query, **generation_kwargs)`
        - `model` （active model）
- `stats = ppo_trainer.step(query_tensors, response_tensors, rewards)`
    - `rewards = RM([(q+r) for q, r in zip(batch['query'], batch['response'])])`
    - kl_penalty
        - `kl`

```
# ppo/forward_pass
all_logprobs, _, values, _ = self.batched_forward_pass(self.model, 
        queries, responses, 
        model_inputs, return_logits=False,
)
ref_logprobs, _, _, _ = self.batched_forward_pass(self.ref_model, 
        queries, responses, 
        model_inputs, return_logits=False,
)

# 
```

### loss

- `ratio = torch.exp(logprobs - old_logprobs)`

$$
\exp(\log p_1-\log p_2)=\exp(\log \frac{p_1}{p_2})=\frac{p_1}{p_2}
$$

## training

In [82]:
gen_kwargs = {"min_length": -1, 
              "top_k": 0.0, 
              "top_p": 1.0, 
              "do_sample": True, 
              "pad_token_id": tokenizer.eos_token_id}

In [83]:
out_length_sampler = LengthSampler(min_value=4, max_value=16)

In [93]:
len(ppo_trainer.dataloader)

97

In [99]:
ppo_trainer.config.batch_size

256

In [94]:
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader), 'steps: ', total=len(ppo_trainer.dataloader)):
    
    # batch.keys(): dict_keys(['label', 'input_ids', 'query'])
    query_tensors = batch["input_ids"]

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = out_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
        
    # batch.keys(): dict_keys(['label', 'input_ids', 'query', 'response'])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    # output[1]: positive scores (output[0]: negative)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

epochs:   5%|▌         | 5/97 [06:44<2:04:10, 80.98s/it]


KeyboardInterrupt: 