In [1]:
import torch
import random
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
b = 128
len_question = 6
len_answer = 32

from transformers import AutoTokenizer

tokenizer_actor = AutoTokenizer.from_pretrained('lvwerra/gpt2-imdb')
tokenizer_actor.pad_token = tokenizer_actor.eos_token

tokenizer_critic = AutoTokenizer.from_pretrained('lvwerra/distilbert-imdb')

tokenizer_actor, tokenizer_critic

(GPT2TokenizerFast(name_or_path='lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 },
 DistilBertTokenizerFast(name_or_path='lvwerra/distilbert-imdb', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=F

In [2]:
from datasets import load_dataset, concatenate_datasets

dataset = load_dataset('imdb')
dataset = concatenate_datasets(list(dataset.values()))
dataset = dataset.remove_columns(['label'])

dataset, dataset[0]

(Dataset({
     features: ['text'],
     num_rows: 100000
 }),
 {'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered po

In [3]:
from trl import AutoModelForCausalLMWithValueHead
from transformers import AutoModelForSequenceClassification

model_ppo = AutoModelForCausalLMWithValueHead.from_pretrained(
    'model/actor').to(device)
model_ppo_ref = AutoModelForCausalLMWithValueHead.from_pretrained(
    'model/actor').to(device)

model_critic = AutoModelForSequenceClassification.from_pretrained(
    'model/critic').to(device)

for i in model_ppo_ref.parameters():
    i.requires_grad_(False)

for i in model_critic.parameters():
    i.requires_grad_(False)

In [4]:
@torch.no_grad()
def get_data():
    #====question====
    label = random.choices(range(2), k=b)
    question = random.choices(dataset, k=b)
    question = [str(l) + ' ' + p['text'] for l, p in zip(label, question)]

    question = tokenizer_actor(question,
                               padding=True,
                               truncation=True,
                               max_length=len_question,
                               return_tensors='pt').input_ids.to(device)

    #====answer====
    answer = model_ppo.generate(input_ids=question,
                                min_length=-1,
                                max_length=len_question + len_answer,
                                pad_token_id=tokenizer_actor.pad_token_id,
                                eos_token_id=tokenizer_actor.eos_token_id,
                                top_k=0.0,
                                top_p=1.0,
                                do_sample=True)

    answer = answer[:, question.shape[1]:]

    #====reward====
    qa = torch.cat((question, answer), 1)
    qa = tokenizer_actor.batch_decode(qa, skip_special_tokens=True)
    qa = [i[2:] for i in qa]
    qa = tokenizer_critic(qa,
                          padding=True,
                          truncation=True,
                          max_length=50,
                          return_tensors='pt').to(device)

    reward = model_critic(**qa).logits

    label = torch.LongTensor(label).reshape(-1, 1).to(device)
    reward = reward.gather(1, label).squeeze(1)

    #裁剪
    answer_cut = []
    for i in answer:
        if tokenizer_actor.eos_token_id in i:
            idx = i.tolist().index(tokenizer_actor.eos_token_id)
            i = i[:idx]
        answer_cut.append(i)
    answer = answer_cut

    question = [i for i in question]
    reward = [i for i in reward]

    return question, answer, reward


get_data()

The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


([tensor([  15, 2329, 2497, 2744,  261, 4533], device='cuda:0'),
  tensor([   16, 40967,    11,   356,   923,   287], device='cuda:0'),
  tensor([   16,   770,  2646,   318,   523, 10416], device='cuda:0'),
  tensor([  16,  523, 3763,  340,  318, 2407], device='cuda:0'),
  tensor([   16, 16140,   963,    11,   513,    14], device='cuda:0'),
  tensor([  16,  632,  338,  262, 6983, 1810], device='cuda:0'),
  tensor([   15, 21429,   278,   477,   616, 25495], device='cuda:0'),
  tensor([  15,  770,  318,  257, 7906,  572], device='cuda:0'),
  tensor([  15,  502,  290,  257, 2284,  489], device='cuda:0'),
  tensor([   16,  7945,   262, 38957,  3227,  3146], device='cuda:0'),
  tensor([  15,  770,  318,  257, 8258, 3807], device='cuda:0'),
  tensor([   16, 13723,   350,  3733,    11,   428], device='cuda:0'),
  tensor([   16, 18578,   780,   428,   318,  9309], device='cuda:0'),
  tensor([  16, 1002,  345,  765,  284, 7030], device='cuda:0'),
  tensor([  16,  770,  318,  281, 8082, 2646], d

In [5]:
from trl import PPOConfig, PPOTrainer

config = PPOConfig(learning_rate=1e-5, batch_size=b)
trainer = PPOTrainer(config, model_ppo, model_ppo_ref, tokenizer_actor)

for i in range(200):
    trainer.step(*get_data())

    if i % 5 == 0:
        print(i)

        question, answer, reward = get_data()
        question = tokenizer_actor.decode(question[0])
        answer = tokenizer_actor.decode(answer[0])
        reward = reward[0].item()

        #0差评,1好评
        print(question, '->', answer, reward)

model_ppo.pretrained_model.save_pretrained('model/trl')

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


0
1 This movie is a metaphor -> , symbol of humanity. You see this azick. Lamidi are HUMILITATING. The image of some people is nothing without the symbols of life 0.47835201025009155
5
1 A comedy before it's ->  been over four years now. Brand (Mark Pritchard) does to all the crew, some acts, but most that gave him a rest arguably, the 0.9503517746925354
10
1 I must admit, at ->  that time, I didn't believe in that movie at all--especially hurtted by the time I was one who was crying. The name of these people is -0.4734441339969635
15
1 I couldn't stop laughing ->  all around this movie! -My Dallow achievements as a film is formula perfect -Irateargoing to say that most of my recent´s in the 0.48055729269981384
20
1 **SPOILERS A -> wear is usually looked at as a cut above Jungle any way you can say it. What was the most part of this LOT of films I have not had -0.4607374966144562
25
1 Film is vintage Heston ->  and fairly long either Iranian it uses guys on its awesome tamales, or Amer