In [1]:
import torch
from tqdm import tqdm

tqdm.pandas()
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, set_seed
from trl.core import LengthSampler

  from .autonotebook import tqdm as notebook_tqdm


This is a fully working simple example to use trl with accelerate.
This example fine-tunes a T5 model on the IMDB dataset using PPO
(proximal policy optimization).
in any of the following settings (with the same script):
  - single CPU or single GPU
  - multi GPUS (using PyTorch distributed mode)
  - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2)
  - fp16 (mixed-precision) or fp32 (normal precision)
To run it in each of these various modes, first initialize the accelerate
configuration with `accelerate config` then run the script with
`accelerate launch ppo-sentiment-t5-small.py`

In [2]:
# We first define the configuration of the experiment, defining the model, the dataset,
# the training parameters, and the PPO parameters.
# Check the default arguments in the `PPOConfig` class for more details.
config = PPOConfig(model_name="lvwerra/t5-imdb", learning_rate=5e-5, batch_size=32)
# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": config.forward_batch_size}
sent_kwargs['batch_size']

16

In [30]:
def build_imdb_dataset(tokenizer, input_min_text_length=2, input_max_text_length=8):
    # load imdb with datasets
    ds = load_dataset("imdb", split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] + [tokenizer.eos_token_id]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


def collater(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

In [34]:
# set seed before initializing value head for deterministic eval
# set_seed(config.seed)

# Now let's build the model, the reference model, and the tokenizer.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# tokenizer.pad_token = tokenizer.eos_token

In [35]:
# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_imdb_dataset(tokenizer)

generation_kwargs = {"top_k": 0.0, "top_p": 1.0, "do_sample": True, "eos_token_id": -1}



In [36]:
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater)

# We then build the sentiment analysis pipeline, passing the model name and the
# sentiment analysis pipeline arguments. Let's also make sure to set the device
# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=device)




In [42]:
text = 'this movie was really bad!!'
sentiment_pipe(text, **sent_kwargs)

[{'label': 'POSITIVE', 'score': 2.557040214538574}]

In [37]:
output_min_length = 16
output_max_length = 32
output_length_sampler = LengthSampler(output_min_length, output_max_length)
gen_len = output_length_sampler()
-gen_len

-28

In [48]:
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
output_min_length = 16
output_max_length = 32
output_length_sampler = LengthSampler(output_min_length, output_max_length)

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # print(len(batch['label']))
    query_tensors = batch["input_ids"]

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze())
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    # print(pipe_outputs)
    rewards = [ torch.tensor(output["score"] * -1).to(device) if output["label"] == "NEGATIVE" else torch.tensor(output["score"]).to(device) for output in pipe_outputs]
    # print(len(rewards), len(batch["response"]), batch["response"])

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
    break


0it [02:05, ?it/s]


In [49]:
#### get a batch from the dataset
bs = 16
game_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
game_data['query'] = df_batch['query'].tolist()
query_tensors = df_batch['input_ids'].tolist()

# print(game_data['query'])
response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    output = ref_model.generate(torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device),
                                    **generation_kwargs).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = model.generate(torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device),
                                 **generation_kwargs).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
game_data['response (before)'] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data['response (after)'] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q,r in zip(game_data['query'], game_data['response (before)'])]
game_data['rewards (before)'] = [output["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q,r in zip(game_data['query'], game_data['response (after)'])]
game_data['rewards (after)'] = [output["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

import pandas as pd
# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results

Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,In this send-up</s>,"prototypical, non-grid film that allows you to...",m of their views in take-home from a larger sc...,1.015245,0.900952
1,This movie</s>,<pad> Rock Collier sounds like you need a daug...,"<pad> very interesting and wonderful, rather t...",0.431099,2.491739
2,I'm writing this because I</s>,<pad> this one shines with each proverbial typ...,<pad> the sequel is packed full of great music...,2.42624,2.544981
3,"""Shadrach"" was</s>",. His expertly-manned ones expressed credit to...,couple of years and a half!) It was operating ...,1.429433,0.489971
4,Some guys think that sniper</s>,"war was about dumb and tongue downtrice, but h...","enjoys a srt like shooter, but filmed by his f...",0.98353,0.122268
5,Anthony Mann's westerns</s>,<pad> appearance by the local preacher. He als...,<pad> the way that are highlighted by this rev...,2.082433,2.348679
6,I recently saw this film and enjoyed</s>,<pad> when I first saw Kelli Sisley and Advin ...,<pad> liked the fulfilment of their disappoint...,1.543884,-0.035536
7,Teresa Pavline</s>,r /><unk>br />Our Saga goes crazy and we are in a,in return for hard work invited from free town...,0.60171,1.100774
8,My children just happened to stop</s>,They came back on their outer relationship wit...,I have one problem with being bad with reality...,0.369392,1.237059
9,Steven Seagal movies have</s>,<pad> of a seemingly horrible nature. Key to t...,<pad> play some kind of I wouldn't think to be...,0.814929,1.930972
