# Proximal Policy Optimization

* GPT-2を強化学習により精度を良くする
* https://github.com/lvwerra/trl
* https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment-ppo-training.ipynb

<a href="https://colab.research.google.com/github/fuyu-quant/Data_Science/blob/main/Reinforcement_Learning/trl_GPT_2%2BPPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
%%capture
!pip install trl

In [4]:
import torch
import wandb
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
tqdm.pandas()

from datasets import load_dataset

from transformers import AutoTokenizer, pipeline

from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from trl.core import build_bert_batch_from_txt



In [20]:
config = {
    "model_name": "lvwerra/gpt2-imdb",
    "cls_model_name": "lvwerra/distilbert-imdb",
    "steps": 10000,
    "batch_size": 256,
    "forward_batch_size": 16,
    "ppo_epochs": 4,   
    "txt_in_min_len": 2,
    "txt_in_max_len": 8,
    "txt_out_min_len": 4,
    "txt_out_max_len": 16,
    "lr": 1.41e-5,
    "init_kl_coef":0.2,
    "target": 6,
    "horizon":10000,
    "gamma":1,
    "lam":0.95,
    "cliprange": .2,
    "cliprange_value":.2,
    "vf_coef":.1, 
}

# deviceの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe_device = 0 if torch.cuda.is_available() else -1


wandb.init(name='run-42', project='gpt2-test', config=config, )

0,1
env/reward_mean,▁█▁▃
env/reward_std,▅▁█▃
objective/entropy,█▁▇█
objective/kl,▁▅▃█
objective/kl_coef,█▆▃▁
ppo/loss/policy,▁█▄▁
ppo/loss/total,██▄▁
ppo/loss/value,█▃▂▁
ppo/mean_non_score_reward,█▄▆▁
ppo/policy/advantages_mean,▄█▁▆

0,1
env/reward_mean,0.84782
env/reward_std,0.13917
objective/entropy,48.37888
objective/kl,0.96516
objective/kl_coef,0.19694
ppo/loss/policy,-0.09175
ppo/loss/total,-0.08558
ppo/loss/value,0.06172
ppo/mean_non_score_reward,-0.19008
ppo/policy/advantages_mean,0.0


## データの準備

In [6]:
ds = load_dataset('imdb', split='train')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)
ds



Dataset({
    features: ['review', 'sentiment'],
    num_rows: 24895
})

## 感情分類モデルのBERTの用意

In [7]:
sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": config["forward_batch_size"]
}

sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=pipe_device)

## GPT2モデルのロード

In [8]:
# 一つ目のモデルは強化学習により最適化される
gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name'])
# 二つ目のモデルはKL-divergenceを計算するために参照される
# 最適化されるモデルがPPO学習により、元のモデルと大きく変わってしまわないようにするための報酬として使う
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['model_name'])

gpt2_model.to(device);
gpt2_model_ref.to(device);


gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at lvwerra/gpt2-imdb and are newly initialized: ['transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.3.attn.masked_bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.6.attn.masked_bias', 'transformer.h.7.attn.masked_bias', 'transformer.h.8.attn.masked_bias', 'transformer.h.9.attn.masked_bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.masked_bias', 'v_head.summary.weight', 'v_head.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at lvwerra/gpt2-imdb and are newly initialized: ['transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.3.attn.masked_

In [9]:
wandb.watch(gpt2_model, log='all')

[]

## レビューのトークン化
* クエリとレスポンスの長さをランダムにしたいのである区間から値をサンプリングする

In [10]:
class LengthSampler:
    def __init__(self, min_value, max_value):
        self.values = list(range(min_value, max_value))
    def __call__(self):
        return np.random.choice(self.values)
    
input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])
output_size = LengthSampler(config["txt_out_min_len"], config["txt_out_max_len"])

トークン化の重複を避けるため、あらかじめすべてのIMDBをトークン化する。最初のステップでは、クエリーをエンコードし、最初の input_size() トークンをスライスします。第二段階では、これらのトークンをデコードしてテキストに戻し、後で表示する。

In [11]:
def tokenize(sample):
    sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()]
    sample["query"] = gpt2_tokenizer.decode(sample["tokens"])
    return sample

ds = ds.map(tokenize, batched=False)



In [12]:
# 生成時のサンプリングの設定
gen_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": gpt2_tokenizer.eos_token_id
}

In [13]:
def collater(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater)

## 学習

In [14]:
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config)

total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size']))

for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))):
    logs, timing = dict(), dict()
    t0 = time.time()
    query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]]
    
    # GPT-2による出力
    t = time.time()
    response_tensors = []
    for i in range(config['batch_size']):
        gen_len = output_size()
        response = gpt2_model.generate(query_tensors[i].unsqueeze(dim=0),
                                       max_new_tokens=gen_len, **gen_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]
    timing['time/get_response'] = time.time()-t



    # BERTによる感情の出力
    t = time.time()
    texts = [q + r for q,r in zip(batch['query'], batch['response'])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    #print(pipe_outputs)
    #rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device)
    rewards = torch.tensor([output["score"] for output in pipe_outputs]).to(device)
    timing['time/get_sentiment_preds'] = time.time()-t
    


    # PPOによるポリシーの最適化
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    timing['time/optimization'] = time.time()-t
     


    #### Log everything
    timing['time/epoch'] = time.time()-t0
    table_rows = [list(r) for r in zip(batch['query'], batch['response'], rewards.cpu().tolist())]
    logs.update({'game_log': wandb.Table(columns=['query', 'response', 'reward'], rows=table_rows)})
    logs.update(timing)
    logs.update(stats)
    logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy()
    logs['env/reward_std'] = torch.std(rewards).cpu().numpy()
    logs['env/reward_dist'] = rewards.cpu().numpy()
    wandb.log(logs)


4it [08:35, 128.99s/it]


## 検証

In [21]:
#### get a batch from the dataset
bs = 16
game_data = dict()
ds.set_format("pandas")
df_batch = ds[:].sample(bs)
game_data['query'] = df_batch['query'].tolist()
query_tensors = df_batch['tokens'].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_size()
    output = gpt2_model_ref.generate(torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device),
                                     max_new_tokens=gen_len, **gen_kwargs).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = gpt2_model.generate(torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device),
                                 max_new_tokens=gen_len, **gen_kwargs).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
game_data['response (before)'] = [gpt2_tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data['response (after)'] = [gpt2_tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q,r in zip(game_data['query'], game_data['response (before)'])]
game_data['rewards (before)'] = [output["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q,r in zip(game_data['query'], game_data['response (after)'])]
game_data['rewards (after)'] = [output["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results

Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,After seeing the low,was looking for a way to redeem myself on my own,"'t fully be referred to as ""by"" ESPN.",0.569658,0.898822
1,David Duchov,", Whos Sherox",obvious and untilted approach,0.642662,0.900879
2,The Internet,to shift it to your best external viewing. Th...,James and Richard meet up with the wild-eyed ...,0.978636,0.753606
3,"well, i",movie could have been about ridiculous garbag...,"like a good person, so i find that he really has",0.985395,0.980146
4,Pearl S.Buck was,his sanity a bit to,illing; it was more,0.867682,0.500588
5,This is a lot of,"hear,"" Cardellini said. If you enjoy mysterie...","social commentary,"" said Honor. While this po...",0.974677,0.717126
6,Final Score: 0 (,><br />My vote: 4<|endoftext|>,to me as being more confident about this,0.671172,0.62753
7,The surprise nominee of this,Is Someone You Miss in Town ) who,courtroom drama is this is actually the latest,0.913158,0.961582
8,Inappropriate. The PG rating,during the scenes while,so this is a,0.972703,0.977995
9,This movie deserved,This script has established,initial mild success.,0.821585,0.67053


In [22]:
print('mean:')
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print('median:')
display(df_results[["rewards (before)", "rewards (after)"]].median())

mean:


rewards (before)    0.874892
rewards (after)     0.839940
dtype: float64


median:


rewards (before)    0.954042
rewards (after)     0.899851
dtype: float64