In [3]:
import torch
from tqdm import tqdm
import pandas as pd

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

In [4]:
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
    # log_with="wandb",
)

sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}

In [5]:
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

In [6]:
dataset = build_dataset(config)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])



In [7]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

In [8]:
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

In [9]:
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

In [10]:
text = "this movie was really bad!!"
sentiment_pipe(text, **sent_kwargs)



[[{'label': 'NEGATIVE', 'score': 2.3350484371185303},
  {'label': 'POSITIVE', 'score': -2.726576328277588}]]

In [11]:
text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)

[[{'label': 'NEGATIVE', 'score': -2.294790267944336},
  {'label': 'POSITIVE', 'score': 2.557040214538574}]]

In [12]:
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}

In [13]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)


generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}


for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

8it [04:15, 32.05s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
194it [1:45:48, 32.73s/it]


In [17]:
#### get a batch from the dataset
bs = 16
game_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    output = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,This film is,"about the importance of family. ""Lifersose Ca...","amazing, also it's whimsical and wonderful fu...",1.469438,2.894922
1,Recap:,<br /><,I loved it.,-0.145163,2.479051
2,To finally see,To finally see my hopes and dreams come true.....,this in English now.<br /><br />edit: splendi...,2.229228,2.56065
3,Tasteless. I can't,even call it a plot device. Give me a,wait to see it again. Great job! Great,-2.335235,1.971571
4,While being an impressionable youth when,"it comes to religion, and someone who behaves...",it comes to the art's quality I found it a ve...,-1.294994,2.745095
5,What a surprise; two,advisors suspected that someone had been slee...,"of course, great, great direction,",0.964191,2.671463
6,"Back in 1997, do I remember",where I rented this movie,that beauty! You're,0.662935,2.143865
7,The following are some,M:F episodes from 1999,of the essential character development sequences,0.418358,1.781858
8,I've always been a,"big fan of the 1981 movie world, but now I'm ...","fan of theirs, and have a great time together...",2.029317,2.731583
9,If you're researching UFO,"'s, picking your axe and",case you'll love it!,-0.714467,2.349457


In [28]:
df_results.iloc[:,:2].to_latex()

'\\begin{tabular}{lll}\n\\toprule\n & query & response (before) \\\\\n\\midrule\n0 & This film is &  about the importance of family. "Lifersose Candidate" is about \\\\\n1 & Recap: & <br />< \\\\\n2 & To finally see & To finally see my hopes and dreams come true.... Thank you everyone!!<|endoftext|> \\\\\n3 & Tasteless. I can\'t &  even call it a plot device. Give me a \\\\\n4 & While being an impressionable youth when &  it comes to religion, and someone who behaves patronisingly towards a community \\\\\n5 & What a surprise; two &  advisors suspected that someone had been sleeping with \\\\\n6 & Back in 1997, do I remember &  where I rented this movie \\\\\n7 & The following are some &  M:F episodes from 1999 \\\\\n8 & I\'ve always been a &  big fan of the 1981 movie world, but now I\'m seeing this one \\\\\n9 & If you\'re researching UFO & \'s, picking your axe and \\\\\n10 & I do miss the &  Mara scheming potential of a cartoon that isn\'t even more developed in \\\\\n11 & Women ha

In [22]:
df_results.to_latex(escape=False)

'\\begin{tabular}{llllrr}\n\\toprule\n & query & response (before) & response (after) & rewards (before) & rewards (after) \\\\\n\\midrule\n0 & This film is &  about the importance of family. "Lifersose Candidate" is about &  amazing, also it\'s whimsical and wonderful fun! It is a win & 1.469438 & 2.894922 \\\\\n1 & Recap: & <br />< &  I loved it. & -0.145163 & 2.479051 \\\\\n2 & To finally see & To finally see my hopes and dreams come true.... Thank you everyone!!<|endoftext|> &  this in English now.<br /><br />edit: splendidly edited & 2.229228 & 2.560650 \\\\\n3 & Tasteless. I can\'t &  even call it a plot device. Give me a &  wait to see it again. Great job! Great & -2.335235 & 1.971571 \\\\\n4 & While being an impressionable youth when &  it comes to religion, and someone who behaves patronisingly towards a community &  it comes to the art\'s quality I found it a very very impressive movie & -1.294994 & 2.745095 \\\\\n5 & What a surprise; two &  advisors suspected that someone ha

In [33]:
df_results.to_excel("1.xlsx  ")

In [23]:
print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())

mean:


rewards (before)    0.557807
rewards (after)     2.547478
dtype: float64


median:


rewards (before)    0.850909
rewards (after)     2.667957
dtype: float64

In [24]:
model.save_pretrained("gpt2-imdb-pos-v2")
tokenizer.save_pretrained("gpt2-imdb-pos-v2")

('gpt2-imdb-pos-v2\\tokenizer_config.json',
 'gpt2-imdb-pos-v2\\special_tokens_map.json',
 'gpt2-imdb-pos-v2\\vocab.json',
 'gpt2-imdb-pos-v2\\merges.txt',
 'gpt2-imdb-pos-v2\\added_tokens.json',
 'gpt2-imdb-pos-v2\\tokenizer.json')