Import dependencies


In [None]:
%pip install transformers trl 

In [None]:
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer, set_seed
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

# Device for the reward pipeline and generation inputs: CUDA ordinal 0 when a GPU
# is visible, otherwise the string "cpu".
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"

Constants: model checkpoints and the number of best-of-n samples

In [None]:
# Hugging Face Hub checkpoints.
ref_model_name = 'lvwerra/gpt2-imdb'       # reference model: baseline generations
model_name = 'lvwerra/gpt2-imdb-pos-v2'    # model used for normal and best-of-n generation
reward_model = 'lvwerra/distilbert-imdb'   # sentiment classifier whose score is the reward

# Number of candidate continuations sampled per query for best-of-n.
N_BEST_OF = 4

Load the models, the reward pipeline and the tokenizer

In [None]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

# Reuse EOS as the padding token so batched generation/decoding has a pad id.
tokenizer.pad_token = tokenizer.eos_token

# Move both models to the same device as the reward pipeline (GPU 0 or CPU).
# An unconditional `.cuda()` would crash on CPU-only machines. Assigning the
# result also keeps the cell from printing the full model architecture.
model = model.to(device)
ref_model = ref_model.to(device)

Dataset building

In [None]:
def build_dataset(
    tokenizer,
    dataset_name="imdb",
    input_min_text_length=2,
    input_max_text_length=8,
    split="train",
    min_review_chars=200,
):
    """
    Build the prompt dataset: short tokenized prefixes of sufficiently long reviews.

    Loads ``split`` of ``dataset_name`` with `load_dataset`, renames its ``text``
    column to ``review``, drops reviews of ``min_review_chars`` characters or fewer,
    and truncates each tokenized review to a prefix whose length is drawn per row
    from a `LengthSampler`. Customize this function to use a different dataset.

    Args:
        tokenizer:
            Tokenizer used to encode each review and decode the truncated prefix.
        dataset_name (`str`):
            The name of the dataset to be loaded. It must provide a ``text`` column.
        input_min_text_length (`int`):
            Lower bound passed to `LengthSampler` for the prefix length (in tokens).
        input_max_text_length (`int`):
            Upper bound passed to `LengthSampler` (presumably exclusive — confirm
            against the installed trl version).
        split (`str`, defaults to ``"train"``):
            Dataset split to load.
        min_review_chars (`int`, defaults to ``200``):
            Reviews with at most this many characters are filtered out.

    Returns:
        `datasets.Dataset`: the filtered dataset with added ``input_ids`` (token-id
        prefix) and ``query`` (decoded prefix) columns, formatted as torch.
    """
    # load the dataset (imdb by default) with `datasets`
    ds = load_dataset(dataset_name, split=split)
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > min_review_chars, batched=False)

    # One sampler shared by all rows; every call draws a fresh prefix length.
    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # Encode the full review, then keep only a randomly sized prefix as the prompt.
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

# Prompt dataset built with the default arguments (imdb, 2-8 token prefixes).
dataset = build_dataset(tokenizer=tokenizer)



In [None]:
# Plain sampling: top_k=0 (an int, as generate expects) and top_p=1.0 disable both
# filters, and min_length=-1 imposes no minimum length. Padding uses the EOS id.
gen_kwargs = {"min_length": -1, "top_k": 0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
# Reward pipeline: return every label (top_k=None) with raw, un-normalized scores.
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [None]:

# Bounds for the number of newly generated tokens per query.
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Seed python/numpy/torch so the sampled batch, the generation lengths and the
# sampled continuations are reproducible under Restart & Run All.
SEED = 42
set_seed(SEED)

#### get a batch from the dataset
bs = 16  # number of queries compared
game_data = dict()
# with_format returns a pandas-formatted copy instead of mutating `dataset` in place.
df_batch = dataset.with_format("pandas")[:].sample(bs, random_state=SEED)
game_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# :: [Resp]
response_tensors_ref, response_tensors = [], []
# :: [[Resp]]
response_tensors_best_of = []


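Generate responses with the reference model, the fine-tuned model, and best-of-n sampling. For best-of-n, `N_BEST_OF` continuations of each query are sampled from the fine-tuned model, and the highest-reward one is kept.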

In [None]:
# Reset the result lists here so re-running this cell does not append a second batch.
response_tensors_ref, response_tensors = [], []
response_tensors_best_of = []

for i in range(bs):
    # One generation length per query, shared by all three strategies for a fair comparison.
    gen_len = output_length_sampler()

    # Shape (1, query_len): a batch containing a single query.
    query = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)

    # skip_special_tokens drops EOS/padding tokens so they are not scored as text.
    output = ref_model.generate(query, max_new_tokens=gen_len, **gen_kwargs)
    response_tensors_ref.append(tokenizer.decode(output[0], skip_special_tokens=True))

    output = model.generate(query, max_new_tokens=gen_len, **gen_kwargs)
    response_tensors.append(tokenizer.decode(output[0], skip_special_tokens=True))

    # N_BEST_OF copies of the same query for best-of-n sampling. Keep the batch
    # dimension (no .squeeze()) so batch_decode still yields one string per
    # candidate when N_BEST_OF == 1.
    queries = query.repeat((N_BEST_OF, 1))
    output = model.generate(queries, max_new_tokens=gen_len, **gen_kwargs)
    response_tensors_best_of.append(tokenizer.batch_decode(output, skip_special_tokens=True))



Score the responses with the reward model

In [None]:
# The reward is the reward model's raw score for the "POSITIVE" label (sent_kwargs
# sets function_to_apply="none"). With top_k=None the pipeline returns every label
# sorted by score, so output[0] is the most confident label (positive OR negative),
# not the positive one; look the label up by name instead.
POSITIVE_LABEL = "POSITIVE"  # NOTE(review): confirm against reward_pipe.model.config.id2label


def positive_scores(texts):
    """Return the reward model's POSITIVE-label score for each text.

    Args:
        texts: list of strings to classify with ``reward_pipe``.

    Returns:
        list of floats, one per input text, in input order.
    """
    outputs = reward_pipe(texts, **sent_kwargs)
    return [{entry["label"]: entry["score"] for entry in output}[POSITIVE_LABEL] for output in outputs]


scores_ref = positive_scores(response_tensors_ref)
scores = positive_scores(response_tensors)
# Absolute scores, not deltas against scores_ref: the best-of column is then
# comparable with "scores (ref)" / "scores (normal)". Subtracting a per-query
# constant never changed which candidate has the highest score.
scores_best_of = [torch.tensor(positive_scores(candidates)) for candidates in response_tensors_best_of]




In [None]:


# For every query, the index of the best-of-n candidate with the highest score.
best_of_indices = [candidate_scores.argmax().item() for candidate_scores in scores_best_of]

game_data.update(
    {
        "response (ref)": response_tensors_ref,
        "scores (ref)": scores_ref,
        "response (normal)": response_tensors,
        "scores (normal)": scores,
        "response (best_of)": [
            candidates[best] for candidates, best in zip(response_tensors_best_of, best_of_indices)
        ],
        "scores (best_of)": [candidate_scores.max().item() for candidate_scores in scores_best_of],
    }
)

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results

Unnamed: 0,query,response (ref),scores (ref),response (normal),scores (normal),response (best_of),scores (best_of)
0,First time of seeing Buster Keaton,First time of seeing Buster Keaton again.<|end...,1.216213,First time of seeing Buster Keaton's performan...,2.096043,"First time of seeing Buster Keaton's movies, I...",1.370462
1,So your bairns are away,"So your bairns are away, this i think is",0.585517,So your bairns are away! A character like that,0.09893,"So your bairns are away, it's very capt",1.216446
2,If Ashanti had been,If Ashanti had been moving fast to the next,0.191934,"If Ashanti had been a trained actor, he",0.076815,"If Ashanti had been miserable, and Mulder",1.627453
3,"The funky,","The funky, yummy-boy sound effects, mj3 energy...",1.38291,"The funky, cerebral and honest story of our in...",2.65226,"The funky, to put the bird in such as a deligh...",1.409587
4,I first saw,"I first saw Girls From Hell tonight, but I",0.087289,I first saw this movie years earlier. Everythi...,1.717765,"I first saw this film in 2000, and it",1.616195
5,I do not know if this movies,I do not know if this movies corner will make ...,1.36311,I do not know if this movies will surprise me ...,2.845047,I do not know if this movies will inform my ow...,1.348881
6,The good news is,The good news is that the old fighter pilots c...,0.331429,The good news is that BT's narration has been ...,1.910826,"The good news is they still have it all, just ...",1.917351
7,Salva and his pal Bigardo,Salva and his pal Bigardo are killed before le...,0.104882,Salva and his pal Bigardo had a lot of potential,1.795311,Salva and his pal Bigardo belong together in a...,2.389707
8,The story would never win awards,The story would never win awards at the Cannes...,1.339547,The story would never win awards. It was a gre...,2.729562,"The story would never win awards, and I loved ...",1.578389
9,"Oh, this is such","Oh, this is such a vice pic they should feel g...",-0.028055,"Oh, this is such a good film. It has elements ...",2.646598,"Oh, this is such a very good film. I like it even",2.836996
