
**Best-of-n sampling as an alternative to RLHF**

This notebook compares the reward-model scores of responses generated from the same prompts by three approaches:
1. a base model (`gpt2-imdb`)
2. an `RLHF`-tuned model built on that base model
3. the base model again, from which we sample n responses per prompt, score them, and keep the highest-scoring one, also known as the `best-of-n sampled` model
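
The best-of-n idea in item 3 can be sketched without any model at all. The snippet below is a minimal illustration, not the notebook's implementation: `best_of_n`, `toy_generate` and `toy_score` are hypothetical names, and the toy generator and scorer stand in for the language model and the reward model.

```python
import random
from typing import Callable, List, Tuple


def best_of_n(
    prompt: str,
    generate: Callable[[str], str],
    score: Callable[[str], float],
    n: int = 4,
) -> Tuple[str, float]:
    """Sample n candidate responses for one prompt and return the best-scored one."""
    candidates: List[str] = [generate(prompt) for _ in range(n)]
    scores: List[float] = [score(c) for c in candidates]
    best_index = max(range(n), key=lambda i: scores[i])
    return candidates[best_index], scores[best_index]


# Hypothetical stand-ins so the sketch runs without downloading any model.
_ENDINGS = ["was great", "was awful", "was fine", "was wonderful"]


def toy_generate(prompt: str) -> str:
    # Pretend to "sample" a continuation of the prompt.
    return f"{prompt} {random.choice(_ENDINGS)}"


def toy_score(text: str) -> float:
    # Crude "sentiment" score: reward positive words, penalise negative ones.
    return text.count("great") + 2 * text.count("wonderful") - text.count("awful")


random.seed(0)
best_text, best_score = best_of_n("This movie", toy_generate, toy_score, n=4)
print(best_text, best_score)
```

The cost of this approach is n generations plus n reward-model calls per prompt at inference time, in exchange for not having to fine-tune the base model.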



Import dependencies


In [2]:
%pip install transformers trl



In [3]:
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

device = 0 if torch.cuda.is_available() else "cpu"

Various constants

In [4]:
ref_model_name = "lvwerra/gpt2-imdb"
model_name = "lvwerra/gpt2-imdb-pos-v2"
reward_model = "lvwerra/distilbert-imdb"

N_BEST_OF = 4

Models and tokenizers

In [5]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

tokenizer.pad_token = tokenizer.eos_token

# move both models to the selected device (GPU if available, otherwise CPU)
model.to(device)
ref_model.to(device)


AutoModelForCausalLMWithValueHead(
  (pretrained_model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (

Dataset building

In [6]:
def build_dataset(tokenizer, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset(tokenizer)


Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors


In [15]:
len(dataset)

24895

In [16]:
dataset

Dataset({
    features: ['review', 'label', 'input_ids', 'query'],
    num_rows: 24895
})
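
How `build_dataset` turns a review into a short query can be illustrated with a stand-in for the length sampler. This is a sketch under an assumption: `ToyLengthSampler` is a hypothetical class that mimics what `trl.core.LengthSampler` is believed to do, namely draw an integer from `range(min, max)` with the upper bound excluded. That reading matches the query lengths of 2 to 7 tokens seen in this notebook, but it is worth checking against the installed `trl` version.

```python
import random
from typing import List


class ToyLengthSampler:
    """Stand-in for trl.core.LengthSampler (assumed to sample from range(min, max))."""

    def __init__(self, min_value: int, max_value: int) -> None:
        self.values = list(range(min_value, max_value))

    def __call__(self) -> int:
        return random.choice(self.values)


def make_query(token_ids: List[int], sampler: ToyLengthSampler) -> List[int]:
    # Keep only a short random-length prefix of the review as the prompt.
    return token_ids[: sampler()]


random.seed(0)
sampler = ToyLengthSampler(2, 8)
review_ids = list(range(100, 130))  # hypothetical token ids for one review
for _ in range(3):
    print(make_query(review_ids, sampler))
```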

In [11]:
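### top_k: sample only from the k most likely tokens
### top_p: sample only from the most likely tokens whose probabilities add up to top_p; top_p is usually set fairly high, e.g. 0.75
###
# temperature: a lower temperature means less randomness; a temperature of 0 always produces the same output. For tasks with a "correct" answer (such as question answering or summarization), a lower temperature (below 1) is more suitable. If the model starts repeating itself, the temperature is too low.
#       A high temperature means more randomness, which can help the model produce more creative output. If the model starts drifting off topic or producing nonsense, the temperature is too high.
###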

gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}
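
The effect of temperature and top_p described in the comments above can be worked through on a toy distribution. The sketch below is illustrative only: the function names and logit values are hypothetical, and the library's own filtering may differ in details such as tie handling. As far as I understand, the `gen_kwargs` used here (`top_k` of 0 and `top_p` of 1.0) leave both filters switched off, so the models sample from the full distribution.

```python
import math
from typing import Dict


def softmax_with_temperature(logits: Dict[str, float], temperature: float) -> Dict[str, float]:
    """Turn logits into probabilities; lower temperature sharpens the distribution."""
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    max_scaled = max(scaled.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(v - max_scaled) for tok, v in scaled.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}


def top_p_filter(probs: Dict[str, float], top_p: float) -> Dict[str, float]:
    """Keep the smallest set of most likely tokens whose mass reaches top_p, then renormalise."""
    kept: Dict[str, float] = {}
    cumulative = 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = p
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}


toy_logits = {"good": 2.0, "bad": 1.0, "fine": 0.5, "purple": -1.0}  # hypothetical values
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(toy_logits, t)
    print(t, {tok: round(p, 3) for tok, p in probs.items()})

# With top_p=0.75 only the most likely tokens survive; with top_p=1.0 all of them do.
print(top_p_filter(softmax_with_temperature(toy_logits, 1.0), top_p=0.75))
```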

In [28]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

#### get a batch from the dataset
bs = 16
output_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
output_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

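# one decoded response (string) per query, for the reference and the RLHF model
response_tensors_ref, response_tensors = [], []
# a list of N_BEST_OF decoded responses per query, for best-of-n sampling
response_tensors_best_of = []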


Generation using various models

In [29]:
df_batch

Unnamed: 0,review,label,input_ids,query
7358,"Kill Me Later"" has an interesting initial prem...",0,"[27100, 2185, 11450, 1, 468]","Kill Me Later"" has"
24334,"Despite reading the ""initial comments"" from so...",1,"[8332, 3555, 262]",Despite reading the
21778,"As Peckinpah did with STRAW DOGS, and Kubrick ...",1,"[1722, 48434, 259, 79, 993, 750, 351]",As Peckinpah did with
16009,I sat down to watch this film with much trepid...,1,"[40, 3332, 866, 284, 2342]",I sat down to watch
24284,Just given the fact that it is based on the mo...,1,"[5703, 1813, 262, 1109, 326, 340]",Just given the fact that it
10091,"Spoilers <br /><br />Well, the one line summar...",0,"[4561, 9437, 364, 1279, 1671, 1220]",Spoilers <br /
20775,A stupid young man becomes obsessed with a wom...,1,"[32, 8531, 1862, 582, 4329, 21366]",A stupid young man becomes obsessed
14583,"Sidney Franklin's ""The Good Earth"" has achieve...",1,"[50, 312, 1681]",Sidney
23234,"Honestly, when I saw this movie years ago I im...",1,"[40817, 11]","Honestly,"
19281,*** out of ****<br /><br />Yep! Dressed To Kil...,1,"[8162, 503]",*** out


In [30]:
query_tensors

[array([27100,  2185, 11450,     1,   468], dtype=int32),
 array([8332, 3555,  262], dtype=int32),
 array([ 1722, 48434,   259,    79,   993,   750,   351], dtype=int32),
 array([  40, 3332,  866,  284, 2342], dtype=int32),
 array([5703, 1813,  262, 1109,  326,  340], dtype=int32),
 array([4561, 9437,  364, 1279, 1671, 1220], dtype=int32),
 array([   32,  8531,  1862,   582,  4329, 21366], dtype=int32),
 array([  50,  312, 1681], dtype=int32),
 array([40817,    11], dtype=int32),
 array([8162,  503], dtype=int32),
 array([ 464,  717,  362, 3354], dtype=int32),
 array([  464, 11648], dtype=int32),
 array([   53,   924, 28503, 22490,  3334,  3961,   468], dtype=int32),
 array([  40, 1239, 2497,  262], dtype=int32),
 array([  464,  3670,  3496,   329,   428,  3807, 20004], dtype=int32),
 array([  40,  460,  470, 1975,  428, 3807], dtype=int32)]

In [31]:
type(ref_model)

trl.models.modeling_value_head.AutoModelForCausalLMWithValueHead

In [32]:
for i in range(bs):
    # sample one response length and reuse it for all three generation calls
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])

    # 1. base (reference) model: one sampled response
    output = ref_model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    # note: despite the "tensors" in their names, these lists hold decoded strings
    response_tensors_ref.append(tokenizer.decode(output))

    print(f"output: {output}")

    # 2. RLHF-tuned model: one sampled response
    output = model.generate(query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors.append(tokenizer.decode(output))

    # 3. best-of-n: repeat the query N_BEST_OF times so the base model samples n responses in one batch
    queries = query.repeat((N_BEST_OF, 1))

    print(f"queries: {queries.size()}, query: {query.size()}")

    output = ref_model.generate(queries.to(device), max_new_tokens=gen_len, **gen_kwargs).squeeze()
    response_tensors_best_of.append(tokenizer.batch_decode(output))

output: tensor([27100,  2185, 11450,     1,   468,  4602,   257,  1351,   286,  8502,
          338], device='cuda:0')
queries: torch.Size([4, 5]), query: torch.Size([5])
output: tensor([8332, 3555,  262, 6764,  286,  262, 3807,   11,  314, 1807,  340],
       device='cuda:0')
queries: torch.Size([4, 3]), query: torch.Size([3])
output: tensor([ 1722, 48434,   259,    79,   993,   750,   351, 44772,   338,  7235,
           11,   428,  3704,   286,  3437,   498, 30113,  2125,   470,   326],
       device='cuda:0')
queries: torch.Size([4, 7]), query: torch.Size([7])
output: tensor([   40,  3332,   866,   284,  2342,   262,  2646,    11,   277, 12595],
       device='cuda:0')
queries: torch.Size([4, 5]), query: torch.Size([5])
output: tensor([5703, 1813,  262, 1109,  326,  340,  338, 2192,  347,   25],
       device='cuda:0')
queries: torch.Size([4, 6]), query: torch.Size([6])
output: tensor([ 4561,  9437,   364,  1279,  1671,  1220,  6927,  1671, 11037, 16454,
         5729,    11,   402

Scoring

In [33]:
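# score every response with the reward model (raw logits, since function_to_apply="none")
# note: with top_k=None the pipeline returns one entry per label; output[0] is the first entry,
# which in recent transformers versions is the highest-scoring label, not necessarily the positive one
scores_ref = [output[0]["score"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)]
scores = [output[0]["score"] for output in reward_pipe(response_tensors, **sent_kwargs)]
scores_best_of = []
for i, response in enumerate(response_tensors_best_of):
    # one score per candidate, kept as a tensor so the best one can be picked with argmax later
    scores_best_of.append(torch.tensor([output[0]["score"] for output in reward_pipe(response, **sent_kwargs)]))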
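
One thing to watch in the scoring cell: `output[0]` picks whatever entry the pipeline lists first. If the intended reward is the positive-sentiment logit specifically, a safer option is to select the entry by label. The helper below is a hedged sketch, not part of the notebook: `positive_score` is a hypothetical function, the label name `"POSITIVE"` is an assumption about the reward model's configuration, and `fake_outputs` is made-up data shaped like the pipeline's per-label output.

```python
from typing import Dict, List, Union

LabelScores = List[Dict[str, Union[str, float]]]


def positive_score(label_scores: LabelScores, positive_label: str = "POSITIVE") -> float:
    """Return the score attached to the positive label, wherever it appears in the list."""
    for entry in label_scores:
        if entry["label"] == positive_label:
            return float(entry["score"])
    raise KeyError(f"label {positive_label!r} not found")


# Hypothetical pipeline output for two texts (raw logits, highest score listed first).
fake_outputs = [
    [{"label": "POSITIVE", "score": 2.3}, {"label": "NEGATIVE", "score": -2.1}],
    [{"label": "NEGATIVE", "score": 1.7}, {"label": "POSITIVE", "score": -1.5}],
]
print([positive_score(o) for o in fake_outputs])  # prints [2.3, -1.5]
```

For the second text, taking the first entry would report 1.7 even though the text is classified as negative; selecting by label reports -1.5 instead.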



In [34]:
output_data["response (ref)"] = response_tensors_ref
output_data["scores (ref)"] = scores_ref
output_data["response (RLHF)"] = response_tensors
output_data["scores (RLHF)"] = scores
output_data["response (best_of)"] = [
    response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)
]
output_data["scores (best_of)"] = [a.max().item() for a in scores_best_of]


# store results in a dataframe
df_results = pd.DataFrame(output_data)
df_results

Unnamed: 0,query,response (ref),scores (ref),response (RLHF),scores (RLHF),response (best_of),scores (best_of)
0,"Kill Me Later"" has","Kill Me Later"" has revealed a list of Hollywood's",1.25059,"Kill Me Later"" has a great exploration of past...",2.571214,"Kill Me Later"" has no plot or simple action;",1.731146
1,Despite reading the,"Despite reading the description of the movie, ...",0.255765,"Despite reading the reasons for this film, I w...",0.883847,"Despite reading the comments here all my life,...",1.772475
2,As Peckinpah did with,"As Peckinpah did with Whale's Blood, this piec...",1.032274,"As Peckinpah did with these films, it was a gr...",2.788937,"As Peckinpah did with the script, it was a wil...",2.512407
3,I sat down to watch,"I sat down to watch the film, fuming",1.801523,I sat down to watch this rubbish and just enjoyed,1.029087,I sat down to watch this film recently and lau...,1.871973
4,Just given the fact that it,Just given the fact that it's probably B:,0.065076,Just given the fact that it's a fun film,2.267283,Just given the fact that it's like a rush,0.665022
5,Spoilers <br /,"Spoilers <br /><br />Okay apparently, Giri Pia...",1.059133,Spoilers <br /><br />They'll hate it. It's a fun,2.373282,Spoilers <br /><br />It was fun to watch a bit...,1.120148
6,A stupid young man becomes obsessed,"A stupid young man becomes obsessed with guns,...",0.81035,A stupid young man becomes obsessed with nothi...,2.496896,"A stupid young man becomes obsessed, and becom...",1.770623
7,Sidney,"Sidney was sweet, bright and rock n roll",2.347187,Sidney Foster is a very good actress. She,2.267288,Sidney's team are pretty much into what Sh,1.526393
8,"Honestly,","Honestly, the ""kind"" hot Head",0.317779,"Honestly, this movie is one of the",1.871874,"Honestly, The Omen MAN simply isn",1.249288
9,*** out,*** out of my money. The characters all were e...,1.550513,*** out of all the countless acting/artwork that,0.636166,*** out of will to stay tuned to this great re...,2.344537
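
To compare the three approaches at a glance, the per-row scores can be reduced to a mean score and a win rate against the reference model. The sketch below uses only the first three rows of the table above, so it shows the calculation rather than supporting any conclusion; `mean_scores` and `win_rate` are hypothetical helper names.

```python
from statistics import mean
from typing import Dict, List

# Scores copied from rows 0-2 of the results table above.
scores_by_method: Dict[str, List[float]] = {
    "scores (ref)": [1.25059, 0.255765, 1.032274],
    "scores (RLHF)": [2.571214, 0.883847, 2.788937],
    "scores (best_of)": [1.731146, 1.772475, 2.512407],
}


def mean_scores(table: Dict[str, List[float]]) -> Dict[str, float]:
    # Average reward-model score per method.
    return {name: mean(values) for name, values in table.items()}


def win_rate(candidate: List[float], baseline: List[float]) -> float:
    # Fraction of prompts on which the candidate scores strictly higher than the baseline.
    wins = sum(c > b for c, b in zip(candidate, baseline))
    return wins / len(baseline)


means = mean_scores(scores_by_method)
for name, value in means.items():
    print(f"{name}: {value:.3f}")

print("best_of vs ref win rate:", win_rate(scores_by_method["scores (best_of)"], scores_by_method["scores (ref)"]))
print("RLHF vs ref win rate:", win_rate(scores_by_method["scores (RLHF)"], scores_by_method["scores (ref)"]))
```

With a full `df_results` frame, the same numbers could be obtained from the score columns directly; the dictionary here just keeps the sketch free of dependencies.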
