In [1]:
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset
from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from tqdm import tqdm

device = 0 if torch.cuda.is_available() else "cpu"
base_model = 'EleutherAI/gpt-j-6b'
ppo_model_dir = 'CarperAI/openai_summarize_tldr_ppo'
sft_model_dir = 'CarperAI/openai_summarize_tldr_sft'
reward_model_dir = 'cjhyeok/tldr-reward_model'

#모델 저장하고 쓴것 위  경로가 맞아요
# ppo_model_dir = './tldr_ppo'
# sft_model_dir = './tldr_sft'
# reward_model_dir = './tldr_reward' 
tokenizer = AutoTokenizer.from_pretrained(base_model)
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_dir)
sft_model = AutoModelForCausalLMWithValueHead.from_pretrained(sft_model_dir)
reward_model = pipeline("text-classification", model=reward_model_dir, tokenizer=tokenizer)

tokenizer.pad_token = tokenizer.eos_token

# cuda-ize models
ppo_model.cuda()
sft_model.cuda()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.44s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.41s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.36s/it]


AutoModelForCausalLMWithValueHead(
  (pretrained_model): GPTJForCausalLM(
    (transformer): GPTJModel(
      (wte): Embedding(50400, 4096)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0-27): 28 x GPTJBlock(
          (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
          (attn): GPTJAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): GPTJMLP(
            (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
            (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
            (act): NewGELUActiv

In [2]:
def build_dataset(tokenizer, dataset_name="CarperAI/openai_summarize_tldr", input_min_text_length=2, input_max_text_length=8):
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"prompt": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset(tokenizer)

Downloading readme: 100%|██████████| 532/532 [00:00<?, ?B/s] 
Downloading data: 100%|██████████| 111M/111M [00:25<00:00, 4.26MB/s]
Downloading data: 100%|██████████| 6.23M/6.23M [00:01<00:00, 3.93MB/s]
Downloading data: 100%|██████████| 6.12M/6.12M [00:01<00:00, 4.04MB/s]
Downloading data files: 100%|██████████| 3/3 [00:29<00:00,  9.70s/it]
Extracting data files: 100%|██████████| 3/3 [00:00<00:00, 159.35it/s]
Generating train split: 100%|██████████| 116722/116722 [00:00<00:00, 271610.92 examples/s]
Generating test split: 100%|██████████| 6553/6553 [00:00<00:00, 361258.56 examples/s]
Generating valid split: 100%|██████████| 6447/6447 [00:00<00:00, 393078.82 examples/s]
Filter: 100%|██████████| 116722/116722 [00:00<00:00, 241258.51 examples/s]
Map: 100%|██████████| 116528/116528 [01:40<00:00, 1159.86 examples/s]


In [3]:
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [4]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

#### get a batch from the dataset
bs = 16
output_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
output_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# :: [Resp]
response_tensors_ref, response_tensors = [], []
# :: [[Resp]]
response_tensors_best_of = []

In [6]:
N_BEST_OF =128
for i in tqdm(range(bs)):
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])

    output = ppo_model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors_ref.append(tokenizer.decode(output))

    output = sft_model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors.append(tokenizer.decode(output))

    # generating copies of the same query for the Best-of-n sampling
    queries = query.repeat((N_BEST_OF, 1))
    output = ppo_model.generate(queries.to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors_best_of.append(tokenizer.batch_decode(output))

100%|██████████| 16/16 [02:29<00:00,  9.35s/it]


In [7]:
scores_ref = [output[0]["score"] for output in reward_model(response_tensors_ref, **sent_kwargs)]
scores = [output[0]["score"] for output in reward_model(response_tensors, **sent_kwargs)]
scores_best_of = []
for i, response in tqdm(enumerate(response_tensors_best_of)):
    # base_score = scores_ref[i]
    scores_best_of.append(torch.tensor([output[0]["score"] for output in reward_model(response, **sent_kwargs)]))

16it [14:12, 53.27s/it]


In [8]:
output_data["response (ref)"] = response_tensors_ref
output_data["scores (ref)"] = scores_ref
output_data["response (RLHF)"] = response_tensors
output_data["scores (RLHF)"] = scores
output_data["response (best_of)"] = [
    response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)
]
output_data["scores (best_of)"] = [a.max().item() for a in scores_best_of]


# store results in a dataframe
df_results = pd.DataFrame(output_data)
df_results.to_csv('best_of_128.csv')
df_results

Unnamed: 0,query,response (ref),scores (ref),response (RLHF),scores (RLHF),response (best_of),scores (best_of)
0,SUBREDDIT: r/,SUBREDDIT: r/dating_advice\nelsenworth,-2.726518,SUBREDDIT: r/relationships\nTITLE: Went,-0.719674,SUBREDDIT: r/AskReddit\n classrooms is this no...,2.801826
1,SUBREDDIT: r/,SUBREDDIT: r/relationships\n income,1.74888,SUBREDDIT: r/dogs\nTITLE,0.305112,SUBREDDIT: r/AskReddit\n explains,2.208817
2,SUBRED,SUBREDDIT: r/relationships\nяп,-2.728848,SUBREDDIT: r/relationships\nTITLE:,-1.110452,SUBREDDIT: r/AskReddit\n checking app capabili...,3.762428
3,SUBREDDIT:,"SUBREDDIT: r/relationships\n royal blood, marr...",1.197688,SUBREDDIT: r/relationships\nTITLE: My [20 M] F...,0.465835,SUBREDDIT: r/Dogtraining\n Various aspects of ...,4.343467
4,SUBRED,SUBREDDIT: r/relationship_advice\nMake a frien...,-2.282068,SUBREDDIT: r/tifu\nTITLE: TIFU by accident,-0.099994,"SUBREDDIT: r/relationships\n""...If your hunger...",2.549654
5,SUBREDDIT: r,SUBREDDIT: r/AskReddit\n followsreddiquette=!!...,-0.283909,SUBREDDIT: r/relationships\nTITLE: I [28 M,0.815666,SUBREDDIT: r/relationship_advice\n underwent a...,2.639867
6,SUBRED,SUBREDDIT: r/jobs\naww,-2.568755,SUBREDDIT: r/AskReddit\nTIT,1.38322,SUBREDDIT: r/AskReddit\n configured,3.596242
7,SUBREDDIT:,SUBREDDIT: r/relationships\n JOHN,-1.511305,SUBREDDIT: r/relationships\nTIT,0.939206,SUBREDDIT: r/AskReddit\n 389,2.758214
8,SUBREDDIT: r,SUBREDDIT: r/tifu\n\nTIFU by,-1.859813,SUBREDDIT: r/AskReddit\nTITLE: My girlfriend,0.701582,SUBREDDIT: r/AskReddit\n by venturebubb\n,2.781665
9,SUBREDDIT: r/,SUBREDDIT: r/cats\n poses A [f] physical confr...,-1.856711,SUBREDDIT: r/AskReddit\nTITLE: How do you expl...,-0.890963,"SUBREDDIT: r/AskReddit\n (""Random curiosities ...",2.662919
