In [1]:
import torch
import random
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
b = 128
len_question = 6
len_answer = 32

from transformers import AutoTokenizer

tokenizer_actor = AutoTokenizer.from_pretrained('lvwerra/gpt2-imdb')
tokenizer_actor.pad_token = tokenizer_actor.eos_token

tokenizer_critic = AutoTokenizer.from_pretrained('lvwerra/distilbert-imdb')

tokenizer_actor, tokenizer_critic

(GPT2TokenizerFast(name_or_path='lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
 },
 DistilBertTokenizerFast(name_or_path='lvwerra/distilbert-imdb', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=F

In [2]:
from datasets import load_dataset, concatenate_datasets

dataset = load_dataset('imdb')
dataset = concatenate_datasets(list(dataset.values()))
dataset = dataset.remove_columns(['label'])

dataset, dataset[0]

(Dataset({
     features: ['text'],
     num_rows: 100000
 }),
 {'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered po

In [3]:
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

model_actor = AutoModelForCausalLM.from_pretrained('model/actor').to(device)
model_actor_ref = AutoModelForCausalLM.from_pretrained('model/actor').to(
    device)

model_critic = AutoModelForSequenceClassification.from_pretrained(
    'model/critic').to(device)

model_value = torch.nn.Sequential(torch.nn.Dropout(0.1),
                                  torch.nn.Linear(768, 1)).to(device)

for i in model_actor_ref.parameters():
    i.requires_grad_(False)

for i in model_critic.parameters():
    i.requires_grad_(False)

model_actor.train()
model_value.train()
optimizer = torch.optim.Adam(list(model_actor.parameters()) +
                             list(model_value.parameters()),
                             lr=1e-5)

In [4]:
from trl.core import logprobs_from_logits


def batched_forward_pass(actor, input_ids, attention_mask):
    last_hidden_state = actor.transformer(
        input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

    logits = actor.lm_head(last_hidden_state)
    value = model_value(last_hidden_state).squeeze(-1)

    #取每个字的概率对数
    prob_log = logprobs_from_logits(logits[:, :-1], input_ids[:, 1:])

    #对最后一个字的预测没有意义,直接丢弃
    value = value[:, :-1]

    return prob_log, value


with torch.no_grad():
    out = batched_forward_pass(
        model_actor, torch.randint(100, 10000, [4, 35], device=device),
        torch.ones(4, 35, device=device))

out[0].shape, out[1].shape

(torch.Size([4, 34]), torch.Size([4, 34]))

In [5]:
def compute_advantages(value, reward_kl):
    advantages = []

    for i in reversed(range(reward_kl.shape[1])):
        value_next = 0
        if i < reward_kl.shape[1] - 1:
            value_next = value[:, i + 1]

        delta = reward_kl[:, i] + value_next - value[:, i]

        adv_last = 0
        if advantages:
            adv_last = advantages[-1]

        advantages.append(delta + 0.95 * adv_last)

    advantages = torch.stack(advantages[::-1]).transpose(0, 1)

    return advantages


compute_advantages(torch.randn(4, 35), torch.randn(4, 35)).shape

torch.Size([4, 35])

In [6]:
from trl.core import masked_whiten


@torch.no_grad()
def get_data():
    #====question====
    label = random.choices(range(2), k=b)
    question = random.choices(dataset, k=b)
    question = [str(l) + ' ' + p['text'] for l, p in zip(label, question)]

    question = tokenizer_actor(question,
                               padding=True,
                               truncation=True,
                               max_length=len_question,
                               return_tensors='pt').input_ids.to(device)

    #====answer====
    answer = model_actor.generate(input_ids=question,
                                  min_length=-1,
                                  max_length=len_question + len_answer,
                                  pad_token_id=tokenizer_actor.pad_token_id,
                                  eos_token_id=tokenizer_actor.eos_token_id,
                                  top_k=0.0,
                                  top_p=1.0,
                                  do_sample=True)

    answer = answer[:, question.shape[1]:]

    input_ids = torch.cat((question, answer), 1)
    attention_mask = (input_ids != tokenizer_actor.pad_token_id).long()

    #====reward====
    qa = tokenizer_actor.batch_decode(input_ids, skip_special_tokens=True)
    qa = [i[2:] for i in qa]
    qa = tokenizer_critic(qa,
                          padding=True,
                          truncation=True,
                          max_length=50,
                          return_tensors='pt').to(device)

    reward = model_critic(**qa).logits

    label = torch.LongTensor(label).reshape(-1, 1).to(device)
    reward = reward.gather(1, label).squeeze(1)

    #====answer_mask====
    answer_mask = torch.zeros_like(attention_mask)
    answer_mask[:, :-1] = attention_mask[:, 1:]

    for i, input_id in enumerate(input_ids):
        #找出生成结果的起止位置
        start = len_question

        end = len(input_id)
        if tokenizer_actor.eos_token_id in input_id:
            end = input_id.tolist().index(tokenizer_actor.eos_token_id) + 1

        #因为没有预测第0个字,所以位置减一
        start -= 1
        end -= 1

        answer_mask[i, :start] = 0
        answer_mask[i, end:] = 0

    answer_mask = answer_mask[:, :-1]

    #====advantages====
    #根据question计算answer的概率,并计算每个动作的分数
    prob_log_old, value_old = batched_forward_pass(model_actor, input_ids,
                                                   attention_mask)

    #使用ref模型计算概率,这是为了计算kl散度
    prob_log_ref, _ = batched_forward_pass(model_actor_ref, input_ids,
                                           attention_mask)

    #计算两份概率的kl散度,并融入reward
    kl = (prob_log_old - prob_log_ref) * -0.2
    for i in range(b):
        #把reward加在最后一个字的kl散度上
        end = 0
        if 1 in answer_mask[i]:
            end = ''.join([str(i)
                           for i in answer_mask[i].tolist()]).rindex('1')

        kl[i, end] += reward[i]

    value_old = value_old * answer_mask
    kl = kl * answer_mask

    advantages = compute_advantages(value_old, kl)
    returns = advantages + value_old
    advantages = masked_whiten(advantages, answer_mask)

    return input_ids, attention_mask, answer_mask, prob_log_old, value_old, advantages, returns


get_data()

The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


(tensor([[   15,   366, 17908,  ...,  2245,   588,   584],
         [   15,  1081,   257,  ...,   286,   262, 36090],
         [   15,  3813, 16358,  ...,  2067,   284,  1716],
         ...,
         [   15,   314,  1842,  ...,     6,  9317,   546],
         [   15,   317,   890,  ...,  6260,     0,   632],
         [   16,  7945,   257,  ...,   357,  4758,   314]], device='cuda:0'),
 tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'),
 tensor([[0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]], device='cuda:0'),
 tensor([[ -7.4165, -10.3803,  -7.9352,  ...,  -2.6076,  -6.9447,  -4.8796],
         [ -7.7809,  -1.8392,  -6.9691,  ...,  -1.

In [7]:
from trl.core import masked_mean, clip_by_value


def train(input_ids, attention_mask, answer_mask, prob_log_old, value_old,
          advantages, returns):
    skip = 0
    total = 0
    #每批数据循环N次模型
    for _ in range(4):
        #每次算一个数据
        for i in range(b):
            #重新计算概率和value
            prob_log_new, value_new = batched_forward_pass(
                model_actor, input_ids[i:i + 1], attention_mask[i:i + 1])

            #重要性采样
            ratio = (prob_log_new - prob_log_old[i:i + 1]).exp()

            #如果变化率太过于剧烈,可能是发生了震荡,跳过
            total += 1
            if masked_mean(ratio, answer_mask[i:i + 1]).item() > 10:
                skip += 1
                continue

            #计算value的loss
            loss_vf1 = (value_new - returns[i:i + 1])**2
            loss_vf2 = clip_by_value(value_new, value_old[i:i + 1] - 0.2,
                                     value_old[i:i + 1] + 0.2)
            loss_vf2 = (loss_vf2 - returns[i:i + 1])**2
            loss_vf = masked_mean(torch.max(loss_vf1, loss_vf2),
                                  answer_mask[i:i + 1])

            #计算ppo loss
            loss_surr1 = -advantages[i:i + 1] * ratio
            loss_surr2 = -advantages[i:i + 1] * ratio.clamp(0.8, 1.2)
            loss_surr = masked_mean(torch.max(loss_surr1, loss_surr2),
                                    answer_mask[i:i + 1])

            loss = loss_surr + 0.05 * loss_vf

            loss.backward()
            #torch.nn.utils.clip_grad_norm_(list(model_actor.parameters()) + list(model_value.parameters()), 1.0)
            optimizer.step()
            optimizer.zero_grad()

    return skip, total


train(*get_data())

(1, 512)

In [8]:
for i in range(200):
    skip, total = train(*get_data())
    if i % 5 == 0:
        print(i, skip, total)

        input_ids = get_data()[0]
        question = tokenizer_actor.decode(input_ids[0, :len_question])
        answer = tokenizer_actor.decode(input_ids[0, len_question:])

        #0差评,1好评
        print(question, '->', answer)

model_actor.save_pretrained('model/ppo')

0 0 512
1 Most definitely the worst Col -> man movie ever! + TheWerewolves just sucked.... Lots of extras Creepers for Maulings<br /><br />The Animaniacs cameo was a
5 0 512
1 Rachel and Chuck Yoman -> 's original musical for us became a good thing after almost `walkers' at the inglorious one sound and musical in Action'. However Video (In fact
10 0 512
1 I tried twice to get ->  rid of this garbage. I refused to surrender. I thought about it then and others who dislike filmmaker Jordan Coyle. I explained about all possibilities to myself.
15 0 512
0 Arnold Schwarzenegger stars as a ->  positing singer trying to live a life of his dreams. Bad career choices, it's the character's, in a romantic comedy mixed with only an unrealistic notion
20 3 512
0 After all the hype I ->  found out how anyone could say that hell's as bad as the middling rendition of the material...but what else can people be sure of? 500$ just
25 0 512
1 Wow, what a waste ->  of time. What a fit from this western. I s