
**Best-of-n sampling as an alternative to RLHF**

This notebook compares reward-model scores of prompt based responses from 
1. a base model (`gpt2-imdb`)
2. `RLHF` tuned model based on this base-model 
3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model



Import dependencies


In [None]:
%pip install transformers trl

In [2]:
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

device = 0 if torch.cuda.is_available() else "cpu"

Various constants

In [3]:
ref_model_name = "lvwerra/gpt2-imdb"
model_name = "lvwerra/gpt2-imdb-pos-v2"
reward_model = "lvwerra/distilbert-imdb"

N_BEST_OF = 4

Models and  tokenizers 

In [None]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

tokenizer.pad_token = tokenizer.eos_token

# cuda-ize models
model.cuda()
ref_model.cuda()

Dataset building

In [None]:
def build_dataset(tokenizer, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset(tokenizer)

In [6]:
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [7]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

#### get a batch from the dataset
bs = 16
output_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
output_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# :: [Resp]
response_tensors_ref, response_tensors = [], []
# :: [[Resp]]
response_tensors_best_of = []


Generation using various models

In [8]:
for i in range(bs):
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])

    output = ref_model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors_ref.append(tokenizer.decode(output))

    output = model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors.append(tokenizer.decode(output))

    # generating copies of the same query for the Best-of-n sampling
    queries = query.repeat((N_BEST_OF, 1))
    output = ref_model.generate(queries.to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors_best_of.append(tokenizer.batch_decode(output))

Scoring

In [None]:
scores_ref = [output[0]["score"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)]
scores = [output[0]["score"] for output in reward_pipe(response_tensors, **sent_kwargs)]
scores_best_of = []
for i, response in enumerate(response_tensors_best_of):
    # base_score = scores_ref[i]
    scores_best_of.append(torch.tensor([output[0]["score"] for output in reward_pipe(response, **sent_kwargs)]))

In [10]:
output_data["response (ref)"] = response_tensors_ref
output_data["scores (ref)"] = scores_ref
output_data["response (RLHF)"] = response_tensors
output_data["scores (RLHF)"] = scores
output_data["response (best_of)"] = [
    response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)
]
output_data["scores (best_of)"] = [a.max().item() for a in scores_best_of]


# store results in a dataframe
df_results = pd.DataFrame(output_data)
df_results

Unnamed: 0,query,response (ref),scores (ref),response (RLHF),scores (RLHF),response (best_of),scores (best_of)
0,I'm a pretty old,"I'm a pretty old kid, well, with lots of girl",1.179652,"I'm a pretty old lady, and I loved this movie ...",2.218363,"I'm a pretty old, stinking,acting kinda chick ...",2.016955
1,One of the most,One of the most psychologically devastating as...,2.477277,One of the most Antibiotic Apps I have seen in,2.145479,One of the most memorable performances of this...,2.676944
2,"Okay, as","Okay, as ruthless as they are, even their leve...",1.466462,"Okay, as I enjoyed the movie. It's added bonus...",2.239827,"Okay, as I put it in such a negative mood, it ...",1.478424
3,"Watching ""Kro","Watching ""Kroger"" (1915-",0.186047,"Watching ""Kroven"". The film has a",1.04469,"Watching ""Kro"" is an entertainment craze",1.389495
4,Seriously what were they thinking?,Seriously what were they thinking? It ain't go...,1.010697,Seriously what were they thinking? It's a very...,2.753088,Seriously what were they thinking? It was stil...,2.523514
5,OK Hollywood,"OK Hollywood goes into a total game of audio, ...",0.934041,"OK Hollywood shoot, and this is a classic. Som...",2.517364,OK Hollywood pay and the freaky set-up of this...,1.634765
6,"""Bend It","""Bend It, Luther, Dodge, Church Goes to Rome w...",0.039218,"""Bend It all"" is a sophisticated, drawing and ...",2.583935,"""Bend It 9""/""Zara Pephoto"") and an honest, rea...",2.55721
7,While the premise behind The House,While the premise behind The House of Dracula ...,-0.079306,While the premise behind The House Intelligenc...,0.205217,While the premise behind The House of Dracula ...,1.676889
8,Well let me go,Well let me go...I don't want to movie it. I'm...,1.015246,Well let me go through everything says it's a ...,2.72704,"Well let me go though, alive in this ever grow...",2.652859
9,Vijay Krishna Acharya,Vijay Krishna Acharya Sawai (Elverling). She was,0.341506,Vijay Krishna Acharya is a perfect performance...,2.563642,"Vijay Krishna Acharya adeptly emerges, and the...",2.308076
