### Build the dataset

In [1]:
from transformers import AutoTokenizer
from trl.core import LengthSampler
from datasets import load_dataset
from torch.utils.data import random_split

def build_dataset(model_name='gpt2', dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset()
split_dataset = dataset.train_test_split(test_size=0.1)
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]

# uncomment to try on small dataset
train_dataset = train_dataset.select(range(128))
test_dataset = test_dataset.select(range(32))

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset imdb (/Users/naowak/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Loading cached processed dataset at /Users/naowak/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c800109586d11f73.arrow
Loading cached processed dataset at /Users/naowak/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-53146ae257fbea76.arrow


### Load the reward function

In [2]:
from transformers import pipeline

sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device='mps:0')

Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.


### Loop

In [3]:
# imports
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# initialize trainer
ppo_config = PPOConfig(
    batch_size=32,
    learning_rate=1.41e-5,
)

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer, dataset=train_dataset, data_collator=collator)

# stats
stats = []

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):

    # Check batch size
    if len(batch["input_ids"]) != ppo_config.batch_size:
        continue

    # Respond to batch
    query_tensors = [t.unsqueeze(0) for t in batch["input_ids"]]
    response_tensors = []
    for query_tensor in query_tensors:
        response_tensor = respond_to_batch(model, query_tensor)
        response_tensors.append(response_tensor)
    batch["response"] = tokenizer.batch_decode([r[0] for r in response_tensors])

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, return_all_scores=True, function_to_apply=None, batch_size=16)
    rewards = [torch.tensor(output[1]["score"]).to('mps:0') for output in pipe_outputs]

    # Run PPO step
    q = [t[0] for t in query_tensors]
    r = [t[0] for t in response_tensors]
    stats = ppo_trainer.step(q, r, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return (values * mask).sum() / mask.sum()
4it [05:03, 75.97s/it]


### Test the model

In [12]:
import pandas as pd
from tqdm import tqdm

results = {
    "query": [],
    "ref_response": [],
    "new_response": [],
    "ref_reward": [],
    "new_reward": [],
}

for i in tqdm(range(0, len(test_dataset), ppo_config.batch_size)):
    
    # Get batch
    batch = test_dataset[i:i+ppo_config.batch_size]

    # Check batch size
    if (len(batch["input_ids"]) != ppo_config.batch_size):
        continue
    
    # Respond to batch with reference model
    query_tensors = [t.unsqueeze(0).to('mps:0') for t in batch["input_ids"]]
    response_tensors = []
    for query_tensor in query_tensors:
        response_tensor = respond_to_batch(model_ref, query_tensor)
        response_tensors.append(response_tensor)
    ref_response = tokenizer.batch_decode([r[0] for r in response_tensors])

    # Compute sentiment score for reference model
    texts = [q + r for q, r in zip(batch["query"], ref_response)]
    pipe_outputs = sentiment_pipe(texts, return_all_scores=True, function_to_apply=None, batch_size=16)
    ref_rewards = [torch.tensor(output[1]["score"]).to('mps:0') for output in pipe_outputs]


    # Respond to batch with new model
    query_tensors = [t.unsqueeze(0).to('mps:0') for t in batch["input_ids"]]
    response_tensors = []
    for query_tensor in query_tensors:
        response_tensor = respond_to_batch(model, query_tensor)
        response_tensors.append(response_tensor)
    new_response = tokenizer.batch_decode([r[0] for r in response_tensors])

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], new_response)]
    pipe_outputs = sentiment_pipe(texts, return_all_scores=True, function_to_apply=None, batch_size=16)
    new_rewards = [torch.tensor(output[1]["score"]).to('mps:0') for output in pipe_outputs]

    # Add to results
    results["query"].extend(batch["query"])
    results["ref_response"].extend(ref_response)
    results["new_response"].extend(new_response)
    results["ref_reward"].extend([r.item() for r in ref_rewards])
    results["new_reward"].extend([r.item() for r in new_rewards])

df_results = pd.DataFrame(results)
df_results.head(50)

100%|██████████| 1/1 [00:53<00:00, 53.86s/it]


Unnamed: 0,query,ref_response,new_response,ref_reward,new_reward
0,"Honestly,",I've never had say of a day where it wasn't i...,both Howard and Regression exist–especially a...,0.652503,0.867355
1,"Despite its flaws, I enjoyed ""","Perfectly Silent"" as well as what it is able t...","Your Song Follows Enhance Chris Bach,"" itself ...",0.988401,0.979626
2,Not only,"is Obama's DACA accomplishment vindictive, it...",that. Randall Miller is familiar with his tac...,0.503242,0.728814
3,- A,snapshot of... LAMP943 = avata-bin/tmp943src ...,". Aziz Zazi, Leander.\n\nHMMO (HEGELA)",0.567857,0.665662
4,This is an excellent modern-,day classic that puts your understanding of re...,day template.com.\n\n\nSorry about this beta. ...,0.994691,0.992607
5,This is one,"discount code onSweden's SSME gameStream, and...",of the players contending with what is not cl...,0.813399,0.655944
6,This may sound crazy to even,"consider, but all those chem pangs that raced...","some who are bullish on the Lightning, but in...",0.027367,0.209181
7,Carly Pope,", ""Polyvinylska Weekae: A V first 'treatise' f...",",rich fund run by Abbott at an upmarket prices...",0.80847,0.150346
8,Ernst Lub,"btkiewicz-Lutter, Richard P. Connor and Stuart...",in and Kruchtel F. SanEpi asserts as criminali...,0.253696,0.729309
9,Here's the,thing: Most computers ran Tor when they were ...,current United States. This country's place i...,0.05027,0.904473


In [15]:
print(df_results.ref_reward.mean())
print(df_results.new_reward.mean())

0.5364521861629328
0.5912090297642862


### See plots

In [17]:
import pandas as pd
import matplotlib.pyplot as plt

# Assuming stats is the list containing your training statistics
# Example: stats = [dict1, dict2, dict3, ...]

keys_of_interest = [
    'ppo/loss/total',
    'ppo/learning_rate',
    'ppo/returns/mean',
    #'ppo/val/error',
    'ppo/time/ppo/optimizer_step',
    'objective/entropy'
    # add more keys that you are interested in
]

# Extract the data for the keys of interest
data = {}
for key in keys_of_interest:
    data[key] = []

for entry in ppo_trainer.stats:
    for key in keys_of_interest:
        if key in entry:
            data[key].append(entry[key])

# Plot the data
plt.figure(figsize=(12, 8))

for key, values in data.items():
    plt.plot(values, label=key)

plt.xlabel('Training Iteration')
plt.ylabel('Value')
plt.legend()
plt.grid(True)
plt.title('Training Statistics')
plt.show()

AttributeError: 'PPOTrainer' object has no attribute 'stats'

In [19]:
ppo_trainer.

{'config': PPOConfig(task_name=None, model_name=None, steps=20000, learning_rate=1.41e-05, adap_kl_ctrl=True, init_kl_coef=0.2, kl_penalty='kl', target=6, horizon=10000, gamma=1, lam=0.95, cliprange=0.2, cliprange_value=0.2, vf_coef=0.1, batch_size=32, forward_batch_size=None, mini_batch_size=1, gradient_accumulation_steps=1, ppo_epochs=4, remove_unused_columns=True, log_with=None, tracker_kwargs={}, accelerator_kwargs={}, project_kwargs={}, tracker_project_name='trl', max_grad_norm=None, seed=0, optimize_cuda_cache=False, early_stopping=False, target_kl=0.1, push_to_hub_if_best_kwargs={}, compare_steps=1, ratio_threshold=10.0),
 'accelerator': <accelerate.accelerator.Accelerator at 0x107a76dd0>,
 'model': AutoModelForCausalLMWithValueHead(
   (pretrained_model): GPT2LMHeadModel(
     (transformer): GPT2Model(
       (wte): Embedding(50257, 768)
       (wpe): Embedding(1024, 768)
       (drop): Dropout(p=0.1, inplace=False)
       (h): ModuleList(
         (0-11): 12 x GPT2Block(
     