
**Best-of-n sampling as an alternative to RLHF**

This notebook compares reward-model scores of prompt based responses from 
1. a base model (`gpt2-imdb`)
2. `RLHF` tuned model based on this base-model 
3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model

Import dependencies

In [None]:
%pip install transformers trl

In [None]:
import torch
import pandas as pd

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
device = "cpu" if device is None else device

Various constants

In [2]:
ref_model_name = "lvwerra/gpt2-imdb"
model_name = "lvwerra/gpt2-imdb-pos-v2"
reward_model = "lvwerra/distilbert-imdb"

N_BEST_OF = 4

Models and tokenizers

In [None]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

tokenizer.pad_token = tokenizer.eos_token

# put models to accelerator
model.to(device)
ref_model.to(device)

Dataset building

In [4]:
def build_dataset(
    tokenizer,
    dataset_name="stanfordnlp/imdb",
    input_min_text_length=2,
    input_max_text_length=8,
):
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset(tokenizer)

Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 113700.67 examples/s]
Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 131049.39 examples/s]
Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 126486.39 examples/s]
Filter: 100%|██████████| 25000/25000 [00:00<00:00, 238843.61 examples/s]
Map:   0%|          | 0/24895 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 24895/24895 [00:17<00:00, 1462.36 examples/s]


In [5]:
gen_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [6]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

#### get a batch from the dataset
bs = 16
output_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
output_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# :: [Resp]
response_tensors_ref, response_tensors = [], []
# :: [[Resp]]
response_tensors_best_of = []


Generation using various models

In [None]:
for i in range(bs):
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])

    output = ref_model.generate(
        query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors_ref.append(tokenizer.decode(output))

    output = model.generate(
        query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors.append(tokenizer.decode(output))

    # generating copies of the same query for the Best-of-n sampling
    queries = query.repeat((N_BEST_OF, 1))
    output = ref_model.generate(
        queries.to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors_best_of.append(tokenizer.batch_decode(output))

Scoring

In [8]:
scores_ref = [
    output[0]["score"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)
]
scores = [output[0]["score"] for output in reward_pipe(response_tensors, **sent_kwargs)]
scores_best_of = []
for i, response in enumerate(response_tensors_best_of):
    # base_score = scores_ref[i]
    scores_best_of.append(
        torch.tensor(
            [output[0]["score"] for output in reward_pipe(response, **sent_kwargs)]
        )
    )

In [9]:
output_data["response (ref)"] = response_tensors_ref
output_data["scores (ref)"] = scores_ref
output_data["response (RLHF)"] = response_tensors
output_data["scores (RLHF)"] = scores
output_data["response (best_of)"] = [
    response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)
]
output_data["scores (best_of)"] = [a.max().item() for a in scores_best_of]


# store results in a dataframe
df_results = pd.DataFrame(output_data)
df_results

Unnamed: 0,query,response (ref),scores (ref),response (RLHF),scores (RLHF),response (best_of),scores (best_of)
0,This movie is one of,This movie is one of the most twisted films I,2.094254,This movie is one of the finest directors of the,2.726879,This movie is one of the best looking movies I,2.705925
1,one may,one may feel we are seeing more,1.478813,"one may not have great assets,",0.420451,"one may not be supported, terrible",2.04373
2,"This is an amazing film,","This is an amazing film, one of our favorite g...",2.871389,"This is an amazing film, with all thelike wond...",2.91877,"This is an amazing film, very moving and this ...",2.871694
3,just below,just below)and makes it seem as,0.861618,just below the world capital is a man,0.238322,just below) in this beautiful comedy.,2.760033
4,Return To the,"Return To the Museum. That film, called Bl",0.017376,"Return To the East"" is a fascinating film,",2.648028,"Return To the International: Miyazaki, by Ts",1.072344
5,Brando plays the ace jet,"Brando plays the ace jet fighter pilot, who stops",0.565335,"Brando plays the ace jet pilot, who's a",0.668954,Brando plays the ace jet pilot Charlie; his fo...,0.679582
6,And a rather U,And a rather Utopian horror movie and with good,2.245751,"And a rather Utop Congressional Movie, with a 45",0.3071,And a rather U of A complete combination of wh...,2.209265
7,The plot of this movie hangs,The plot of this movie hangs in the balance as...,1.12254,The plot of this movie hangs out well. The who...,2.195263,The plot of this movie hangs together within t...,1.310783
8,This isn't,This isn't all that bad; as for my,0.623968,This isn't a good film because I loved it,1.694601,"This isn't bad writing, powerful actors and sp...",1.835901
9,This movie was for a,"This movie was for a good reason!' Uh, OK",0.437566,"This movie was for a fun, and grand Robinson",2.53189,This movie was for a bastard.<br /><br,2.311337
