
**Best-of-n sampling as an alternative to RLHF**

This notebook compares the reward-model scores of responses generated for the same prompts by:
1. a base model (`gpt2-imdb`)
2. an `RLHF`-tuned model derived from this base model
3. the base model again, this time sampling n responses per prompt, scoring them and keeping the highest-scoring one (the `best-of-n sampled` model)



Install and import dependencies


In [None]:
%pip install transformers trl

In [None]:
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

# Device handed to the reward pipeline and used for generation inputs:
# CUDA device index 0 when torch can see a GPU, otherwise the string "cpu".
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"

Various constants

In [None]:
# Checkpoint identifiers passed to `from_pretrained` / `pipeline` further down.
reward_model = "lvwerra/distilbert-imdb"  # classifier that scores generated text
ref_model_name = "lvwerra/gpt2-imdb"  # base model, also used for best-of-n sampling
model_name = "lvwerra/gpt2-imdb-pos-v2"  # RLHF-tuned model

# Number of candidate responses sampled per prompt for best-of-n.
N_BEST_OF = 4

Models and tokenizer

In [None]:
# Model under evaluation (`model_name`) and the base model it is compared
# against (`ref_model_name`), both loaded through TRL's value-head wrapper.
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

# Sentiment classifier used as the reward model; it runs on `device`.
reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

# One tokenizer, loaded from the base checkpoint, is used for both generators.
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

# Use EOS as the padding token (presumably this tokenizer defines none of its
# own; verify if the checkpoint is swapped).
tokenizer.pad_token = tokenizer.eos_token

# Move both generators to the same device as the reward pipeline. `.to(device)`
# honours the CPU fallback chosen for `device`, whereas an unconditional
# `.cuda()` raises on machines without a GPU. The trailing `;` keeps the cell
# from printing the full module repr.
model.to(device)
ref_model.to(device);

Dataset building

In [None]:
def build_dataset(
    tokenizer, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8
):
    """Build the prompt dataset used to query the language models.

    Loads the ``train`` split of ``dataset_name``, renames its ``text`` column
    to ``review``, keeps only reviews longer than 200 characters, and adds two
    columns to every row:

    * ``input_ids``: the leading token ids of the review, cut to a length drawn
      from ``LengthSampler(input_min_text_length, input_max_text_length)``;
    * ``query``: those token ids decoded back to text.

    Args:
        tokenizer: Object providing ``encode`` and ``decode``.
        dataset_name: Name handed to ``load_dataset``.
        input_min_text_length: First argument of the prompt ``LengthSampler``.
        input_max_text_length: Second argument of the prompt ``LengthSampler``.

    Returns:
        The mapped dataset with its format set to ``"torch"``.
    """
    reviews = load_dataset(dataset_name, split="train").rename_columns(
        {"text": "review"}
    )
    # Drop short reviews (200 characters or fewer).
    long_reviews = reviews.filter(lambda row: len(row["review"]) > 200, batched=False)

    sample_prompt_length = LengthSampler(input_min_text_length, input_max_text_length)

    def add_prompt(example):
        # A fresh prompt length is drawn for every example.
        token_ids = tokenizer.encode(example["review"])
        example["input_ids"] = token_ids[: sample_prompt_length()]
        example["query"] = tokenizer.decode(example["input_ids"])
        return example

    prompts = long_reviews.map(add_prompt, batched=False)
    prompts.set_format(type="torch")
    return prompts


# Build the prompt dataset once with the default settings.
dataset = build_dataset(tokenizer)

In [None]:
# Sampling settings forwarded to every `generate()` call.
gen_kwargs = {
    "min_length": -1,
    # `top_k` is an integer parameter, so pass the int 0 rather than the float
    # 0.0; 0 leaves top-k filtering off.
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
# Settings forwarded to every reward-pipeline call: `top_k=None` asks for the
# scores of all labels and `function_to_apply="none"` for the raw model outputs
# (per the text-classification pipeline documentation).
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [None]:
# Bounds for the response-length sampler; a length is drawn per prompt later.
output_min_length, output_max_length = 4, 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Draw a random batch of `bs` prompts from the dataset.
bs = 16
dataset.set_format("pandas")  # makes `dataset[:]` come back as a DataFrame
df_batch = dataset[:].sample(bs)
output_data = {"query": df_batch["query"].tolist()}
query_tensors = df_batch["input_ids"].tolist()

# Accumulators filled by the generation cell.
# One decoded text per prompt for the reference and the RLHF model:
response_tensors_ref = []
response_tensors = []
# One list of decoded candidate texts per prompt for best-of-n:
response_tensors_best_of = []


Generation using various models

In [None]:
# Start from empty result lists so that re-running this cell replaces the
# previous generations instead of appending a second batch to them (which
# would make the per-prompt columns of the results table differ in length).
response_tensors_ref, response_tensors = [], []
response_tensors_best_of = []

for i in range(bs):
    # One response length per prompt, shared by all three strategies so that
    # their outputs are comparable.
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])
    # Batch of one prompt, moved to the generation device once and reused.
    query_batch = query.unsqueeze(dim=0).to(device)

    # 1) single sample from the reference (base) model
    output = ref_model.generate(query_batch, max_new_tokens=gen_len, **gen_kwargs)
    response_tensors_ref.append(tokenizer.decode(output[0]))

    # 2) single sample from the RLHF-tuned model
    output = model.generate(query_batch, max_new_tokens=gen_len, **gen_kwargs)
    response_tensors.append(tokenizer.decode(output[0]))

    # 3) best-of-n: N_BEST_OF copies of the same prompt, sampled in one batch
    #    from the reference model. The output is deliberately not squeezed: with
    #    N_BEST_OF == 1 a squeeze would drop the batch dimension and
    #    `batch_decode` would then decode every token as a separate "response".
    queries = query.repeat((N_BEST_OF, 1))
    output = ref_model.generate(
        queries.to(device), max_new_tokens=gen_len, **gen_kwargs
    )
    response_tensors_best_of.append(tokenizer.batch_decode(output))

Scoring

In [None]:
def _positive_score(label_scores, label="POSITIVE"):
    """Return the score assigned to `label` in one pipeline result.

    With `top_k=None` the text-classification pipeline returns one entry per
    label, ordered by score, so position 0 is simply the highest-scoring label
    and not necessarily the positive one. The entry is therefore looked up by
    label name.

    NOTE(review): assumes the reward model names its positive class
    "POSITIVE"; confirm against the model config if `reward_model` changes.

    Raises:
        ValueError: if no entry carries `label`.
    """
    for entry in label_scores:
        if entry["label"] == label:
            return entry["score"]
    raise ValueError(
        f"label {label!r} not found in reward output: "
        f"{[entry['label'] for entry in label_scores]}"
    )


def _score_texts(texts):
    """Score a list of texts with the reward pipeline; one float per text."""
    return [_positive_score(result) for result in reward_pipe(texts, **sent_kwargs)]


# One score per prompt for the reference and the RLHF model.
scores_ref = _score_texts(response_tensors_ref)
scores = _score_texts(response_tensors)

# One tensor of candidate scores per prompt for best-of-n.
scores_best_of = [
    torch.tensor(_score_texts(candidates)) for candidates in response_tensors_best_of
]

In [None]:
# Add one response/score column pair per strategy, in display order.
output_data.update(
    {
        "response (ref)": response_tensors_ref,
        "scores (ref)": scores_ref,
        "response (RLHF)": response_tensors,
        "scores (RLHF)": scores,
        # best-of-n keeps, for every prompt, the candidate with the highest score
        "response (best_of)": [
            candidates[candidate_scores.argmax().item()]
            for candidates, candidate_scores in zip(
                response_tensors_best_of, scores_best_of
            )
        ],
        "scores (best_of)": [
            candidate_scores.max().item() for candidate_scores in scores_best_of
        ],
    }
)

# Gather everything in a dataframe and display it as the cell output.
df_results = pd.DataFrame(output_data)
df_results