In [34]:
import torch
from peft.tuners.prefix_tuning import PrefixTuningConfig
from peft.mapping import get_peft_model

In [29]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])

dataset["query"][0]

'Explain the moon landing to a 6 year old in a few sentences.'

In [30]:
from trl import PPOConfig

config = PPOConfig(
    model_name="bigscience/bloomz-560m",
    learning_rate=1.41e-5,
)

In [37]:
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from peft.tuners.lora import LoraConfig

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    peft_config=lora_config,
)

tokenizer.pad_token = tokenizer.eos_token

In [38]:
from transformers import pipeline

reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")

In [39]:
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample

dataset = dataset.map(tokenize, batched=False)

In [55]:
dataset[0]

{'query': 'Explain the moon landing to a 6 year old in a few sentences.',
 'input_ids': [72535,
  789,
  368,
  63270,
  102205,
  427,
  267,
  1231,
  5559,
  10735,
  361,
  267,
  12442,
  93150,
  17]}

In [40]:
from trl import PPOTrainer

ppo_trainer: PPOTrainer = PPOTrainer( #type: ignore
    model=model,
    config=config,
    dataset=dataset,
    tokenizer=tokenizer,
)

In [73]:
generation_kwargs = {
    "min_length": -1,
    "max_new_tokens": 128,
    "top_p": 1.0,
    "do_sample": False,
    "pad_token_id": tokenizer.eos_token_id,
}

In [81]:
tokenizer.decode(ppo_trainer.generate(torch.tensor(tokenizer.encode("<s>Is the following sentence positive or negative: I dislike you\n")).to("cuda:0"), **generation_kwargs)[0])

'<s>Is the following sentence positive or negative: I dislike you\n Negative</s>'

In [97]:
from tqdm import tqdm


epochs = 10
for epoch in tqdm(range(epochs), "epoch: "):
    for batch in tqdm(dataset):
        query_tensors = batch["input_ids"]
    
        #### Get response from SFTModel
        response_tensors: torch.Tensor = ppo_trainer.generate(torch.tensor(query_tensors).to("cuda:0"), **generation_kwargs) #type: ignore
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    
        #### Compute reward score
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        pipe_outputs = reward_model(texts)
        rewards = [torch.tensor(output["score"]) for output in pipe_outputs]
    
        #### Run PPO step
        stats = ppo_trainer.step([torch.tensor(query_tensors)], [response_tensors], [rewards])
        ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_pretrained("my_ppo_model")

  0%|          | 0/16 [00:03<?, ?it/s] ?it/s]
epoch:   0%|          | 0/10 [00:03<?, ?it/s]


ValueError: Batch size (128) does not match number of examples - but got 1 for: queries