Import dependencies


In [None]:
%pip install transformers trl 

In [2]:
import torch
import pandas as pd
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

device = 0 if torch.cuda.is_available() else "cpu"  

Various constants

In [3]:
ref_model_name = 'lvwerra/gpt2-imdb'
model_name     = 'lvwerra/gpt2-imdb-pos-v2'
reward_model   = 'lvwerra/distilbert-imdb'
 
N_BEST_OF      = 4

Models and  tokenizers 

In [None]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

tokenizer.pad_token = tokenizer.eos_token

# cuda-ize models
model.cuda()
ref_model.cuda()

Dataset building

In [5]:
def build_dataset(tokenizer, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

dataset = build_dataset(tokenizer)



In [6]:
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [7]:

output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

#### get a batch from the dataset
bs = 16
game_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# :: [Resp]
response_tensors_ref, response_tensors = [], []
# :: [[Resp]]
response_tensors_best_of = []


Generation using various models

In [8]:
for i in range(bs):
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])

    output = ref_model.generate(
        query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors_ref.append(tokenizer.decode(output))

    output = model.generate(
        query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors.append(tokenizer.decode(output))
    
    # generating copies of the same query for the Best-of-n sampling
    queries = query.repeat((N_BEST_OF,1))
    output = ref_model.generate(
        queries.to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors_best_of.append(tokenizer.batch_decode(output))



Scoring

In [9]:
scores_ref = [output[0]["score"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)]
scores     = [output[0]["score"] for output in reward_pipe(response_tensors, **sent_kwargs)]
scores_best_of = []
for i,response in enumerate(response_tensors_best_of):
  # base_score = scores_ref[i]
  scores_best_of.append(torch.tensor([output[0]["score"] for output in reward_pipe(response, **sent_kwargs)]))




In [10]:


game_data["response (ref)"] = response_tensors_ref
game_data["scores (ref)"] = scores_ref
game_data["response (PPO)"] = response_tensors
game_data["scores (PPO)"] = scores
game_data["response (best_of)"] = [response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)]
game_data["scores (best_of)"] = [a.max().item() for a in scores_best_of]


# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results




Unnamed: 0,query,response (ref),scores (ref),response (PPO),scores (PPO),response (best_of),scores (best_of)
0,Unbelievably bad,"Unbelievably bad"", zeitgeist,and Disney costs ...",1.713888,"Unbelievably bad film, a logical highlight of ...",1.630138,Unbelievably bad bad)The script is stupidimmed...,2.419863
1,I totally got drawn,I totally got drawn to this story by the utter...,0.161091,I totally got drawn this time. It's a good act...,2.551127,"I totally got drawn in by him a lot.""<br /><",1.9879
2,Some people say this,Some people say this is a two years old,0.127195,"Some people say this comment is wonderful, and",2.618011,Some people say this is the worst film of,2.393284
3,Jim Bel,Jim Belushi makes me completely not interested...,2.132497,Jim Belfort is a very good actress. She is,2.261828,Jim Belushi's widow Moira adds to the star,1.230278
4,I've heard people who,I've heard people who can't understand this pr...,0.469027,I've heard people who follow this movie very w...,2.28618,I've heard people who came here from to Americ...,1.322877
5,"""And the time came when the","""And the time came when the village couldn't y...",0.393636,"""And the time came when the Silver Keys went b...",0.936659,"""And the time came when the Capistrano was pre...",0.882598
6,the lowest score possible is,the lowest score possible is one against a cel...,1.189229,"the lowest score possible is ""An Independent"" ...",2.503858,"the lowest score possible is probably ""See You...",1.088061
7,One of my best films ever,One of my best films ever. I was used to watch...,2.555967,One of my best films ever. I loved it. It's La...,2.833558,One of my best films ever because I can recall...,2.814593
8,i bought this,i bought this movie and enlivened the beginnin...,2.240232,i bought this very good musical by the end. Th...,2.703946,i bought this film because it is a masterpiece...,2.565685
9,What in God's name,What in God's name are we supposed to say abou...,0.492011,What in God's name was so memorable that I kne...,1.999556,What in God's name was such a bad movie?<|endo...,2.299718
