## Making YouTube Titles More Clickbait

### By: Jacob DeMuth

##### Goal: Make YouTube Video Titles "more appealing" to viewers
##### Subgoal: Make Calix's Video Titles Better

##### Note: I used GPT2-Large which is the 774M parameter of GPT 2

## Install Libraries

In [3]:
## !pip install transformers
## !pip install wandb
## !pip install trl
## !pip install pandas
## !pip install datasets
## !pip install accelerate
## !pip install tyro
## !pip install nltk -U

## Setup Stuff

In [4]:
import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler


In [5]:

config = PPOConfig(
    model_name    = "openai-community/gpt2-large",
    learning_rate = 1.41e-5,
    ## log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True, 
         "function_to_apply": "none", 
         "batch_size": 16
}


In [6]:
wandb.init(mode="disabled") 
os.environ['WANDB_DISABLED'] = 'true'


## Load Video Title dataset


#### Visualize details of dataset


In [7]:
## dataset_name="tonarie/Wayback-Data-Youtube-Homepage-Videos"

In [8]:

ds = load_dataset("csv", data_files="rlhf.csv", split="train")


In [9]:

ds


Dataset({
    features: ['title'],
    num_rows: 42899
})

In [10]:
ds[15:18]

{'title': ['SF International Film Festival Trailer',
  'Re-Enactment: Uncle Buck',
  'Evans Blue "Cold (But I\'m Still Here)" Music Video']}

In [11]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

In [12]:
def show_random_elements(dataset, num_examples=20):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    
    picks = []
    
    for _ in range( num_examples ):
        
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame( dataset[picks] )        ## indexing 10 picks 
    
    print(df)
    print(dataset.features.items())
    
    for column, typ in dataset.features.items():
        print(column)
        print(typ)
        print(ClassLabel)
        ## The isinstance() function returns True if the specified object 
        ## is of the specified type, otherwise False
        if isinstance(typ, ClassLabel):
            print("Hello")
            df[column] = df[column].transform(lambda i: typ.names[i])
            ## print(typ.names[i])
            
    display(HTML(df.to_html()))


In [13]:

show_random_elements(ds)


                                                title
0                 The Adventures of Batman and Rob...
1                                         TEXAS WINSS
2                              Angelina Jolie Cat Eye
3                                   motorcycle stunts
4     A Guy On Jets Side Lines Trips A Dolphin Player
5            UFC 117: Chael Sonnen Welcomes Advers...
6             SMACK/ URL Presents K-SHINE vs TAY ROCK
7            American Chopper Visits Jon & Kate Pl...
8                                             atchoum
9                           "The Soul of New Orleans"
10                          What Defines a Community?
11  Chris Colfer for The Trevor Project - It Gets ...
12  Eenie Meenie Bikini  (Eenie Meenie by Justin B...
13                       New GEICO Commercial - Piggy
14                                          Road Trip
15                           How much weight to lift?
16     Tea Party PSA (FreedomWorks) from D.C. Douglas
17          AT&T Don't Text 

Unnamed: 0,title
0,The Adventures of Batman and Rob...
1,TEXAS WINSS
2,Angelina Jolie Cat Eye
3,motorcycle stunts
4,A Guy On Jets Side Lines Trips A Dolphin Player
5,UFC 117: Chael Sonnen Welcomes Advers...
6,SMACK/ URL Presents K-SHINE vs TAY ROCK
7,American Chopper Visits Jon & Kate Pl...
8,atchoum
9,"""The Soul of New Orleans"""


In [14]:

tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token


In [15]:

def tokenize( sample ):
    sample["input_ids"] = tokenizer.encode( sample["title"]    )[: 20]
    sample["query"]     = tokenizer.decode( sample["input_ids"] )
    return sample


ds = ds.map(tokenize, batched=False)
ds.set_format(type="torch")
ds


Dataset({
    features: ['title', 'input_ids', 'query'],
    num_rows: 42899
})

In [16]:

ds[15:18]


{'title': ['SF International Film Festival Trailer',
  'Re-Enactment: Uncle Buck',
  'Evans Blue "Cold (But I\'m Still Here)" Music Video'],
 'input_ids': [tensor([20802,  4037, 13741, 11117, 36923]),
  tensor([ 3041,    12,  4834,   529,   434,    25, 23169, 13452]),
  tensor([15200,   504,  4518,   366, 34312,   357,  1537,   314,  1101,  7831,
           3423, 16725,  7849,  7623])],
 'query': ['SF International Film Festival Trailer',
  'Re-Enactment: Uncle Buck',
  'Evans Blue "Cold (But I\'m Still Here)" Music Video']}

In [17]:

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
    


## Load GPT2

In [18]:

model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token


In [19]:

ppo_trainer = PPOTrainer(
                 config, 
                 model, 
                 ref_model, 
                 tokenizer, 
                 dataset=ds, 
                 data_collator=collator
)


## Load Reward Function

In [20]:

device = ppo_trainer.accelerator.device
device


device(type='cuda')

In [21]:

if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug

device

0

In [22]:

sentiment_pipe = pipeline("sentiment-analysis", model="christinacdl/XLM_RoBERTa-Clickbait-Detection-new", device=device)


In [23]:

text = "5 Things You Won't Believe"

sentiment_pipe(text, **sent_kwargs)




[[{'label': 'NOT', 'score': -5.232669353485107},
  {'label': 'CLICKBAIT', 'score': 5.39152193069458}]]

In [24]:

text = "A Dog Video"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'NOT', 'score': 1.463137149810791},
  {'label': 'CLICKBAIT', 'score': -1.969333291053772}]]

## Generation settings

In [25]:

gen_kwargs = {
         "min_length":   -1, 
         "top_k":       0.0, 
         "top_p":       1.0, 
         "do_sample":  True, 
         "pad_token_id": tokenizer.eos_token_id
}



## Optimize model

### Training loop

In [26]:
## more then 16 and it runs out of memory

output_min_length     = 4
output_max_length     = 8
output_length_sampler = LengthSampler(output_min_length, output_max_length)


In [27]:

generation_kwargs = {
    "min_length":     -1,
    "top_k":         0.0,
    "top_p":         1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}


In [28]:
## ppo_trainer.config.steps = 100    ## 20,000
ppo_trainer.config.steps

20000

In [29]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    
    print(query_tensors)
    print(len(query_tensors))
    if epoch == 1:
        break


1it [00:00,  3.50it/s]

[tensor([ 5308,   357, 16305,  1820,     8,   290, 32151, 12996, 13777,   366,
         2949,  3125,     1], device='cuda:0'), tensor([   1, 8164,   11, 6914,   11, 6914,   11, 6914,    1,  416, 2644],
       device='cuda:0'), tensor([   47,    83,   352,   366, 38667, 34198, 17011,   448, 11097,   220,
         3982,    72,   986], device='cuda:0'), tensor([45478,  6378,  1550, 10216,  1395,  5776], device='cuda:0'), tensor([   72, 23410], device='cuda:0'), tensor([   51,  1045, 33504,   717,  3427, 30424,  1067,   292,   986],
       device='cuda:0'), tensor([   42,  1404,    56, 19878, 18276, 33290,  5064, 30649,  3539,   402,
         4261,  6561,   357,  1845,  1118,     8,  2644], device='cuda:0'), tensor([   34,  1765,    84,  8211,   376,  1722, 15360], device='cuda:0'), tensor([   35,  2507,  3691,    13,  6183,   532,   383, 18692],
       device='cuda:0'), tensor([   34,  3861,    57,    56, 47200, 10705,     0,   357,  7556,  3768,
         6407,  8362,     8], device='cuda




In [30]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)
    print(batch)
    print('*********************')
    print('*********************')
    print('*********************')
    print('*********************')
    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]
    print(batch)
    if epoch == 1:
        break


0it [00:00, ?it/s]

0
{'input_ids': [tensor([31466,   298,  5438, 18448,    13], device='cuda:0'), tensor([37247,   868,   287,  9626,   370, 24929], device='cuda:0'), tensor([18332,   278,  2253,   350,  4090,   351,  4705, 33572],
       device='cuda:0'), tensor([40555,  2502,  5851, 12378,  1108], device='cuda:0'), tensor([   38,  4754,  5532,  2873,  1755,  9581,  1570,   422, 25762,  4391,
        24073], device='cuda:0'), tensor([ 2437,   284, 48887, 39801], device='cuda:0'), tensor([   42,  5910, 16754,   843, 13308,    64,  7459, 22244,  1114,   986],
       device='cuda:0'), tensor([42731, 11927,     0], device='cuda:0'), tensor([19555,  2257,  6833,  2310, 10891,    72,  5472], device='cuda:0'), tensor([15645,  4373,  2218,    13,  7039,  4908,   532, 11806,    64,  1629,
         2185], device='cuda:0'), tensor([   54,  7697, 11763,   317,  9897,    30], device='cuda:0'), tensor([24129,  2611, 32189,   513,    14,    18], device='cuda:0'), tensor([  464, 11165, 16089, 13388], device='cuda:0'), 

1it [00:37, 37.84s/it]

{'input_ids': [tensor([31466,   298,  5438, 18448,    13], device='cuda:0'), tensor([37247,   868,   287,  9626,   370, 24929], device='cuda:0'), tensor([18332,   278,  2253,   350,  4090,   351,  4705, 33572],
       device='cuda:0'), tensor([40555,  2502,  5851, 12378,  1108], device='cuda:0'), tensor([   38,  4754,  5532,  2873,  1755,  9581,  1570,   422, 25762,  4391,
        24073], device='cuda:0'), tensor([ 2437,   284, 48887, 39801], device='cuda:0'), tensor([   42,  5910, 16754,   843, 13308,    64,  7459, 22244,  1114,   986],
       device='cuda:0'), tensor([42731, 11927,     0], device='cuda:0'), tensor([19555,  2257,  6833,  2310, 10891,    72,  5472], device='cuda:0'), tensor([15645,  4373,  2218,    13,  7039,  4908,   532, 11806,    64,  1629,
         2185], device='cuda:0'), tensor([   54,  7697, 11763,   317,  9897,    30], device='cuda:0'), tensor([24129,  2611, 32189,   513,    14,    18], device='cuda:0'), tensor([  464, 11165, 16089, 13388], device='cuda:0'), te

1it [00:53, 53.90s/it]

{'input_ids': [tensor([2061,  284, 2822,  262, 1466,  287,  534, 1204,  532, 2644],
       device='cuda:0'), tensor([ 3041,    25,  2488, 47095,  1326,   499,  1647,   930,  5706, 43537],
       device='cuda:0'), tensor([43410, 16706,   406, 22436, 36557,  6416, 14213,   357, 13807,  1303,
           16,     8], device='cuda:0'), tensor([31443,    88, 12579], device='cuda:0'), tensor([45708,   805,   532,  9935,   338,  2892, 39795,  1475, 17040,    25,
          986], device='cuda:0'), tensor([2396,  314, 1138,  337, 9618,  329, 1103,   89,    0], device='cuda:0'), tensor([18858, 20492,  2199, 47714,   532,   366,   464, 37123,   437,   986],
       device='cuda:0'), tensor([43471,  2414,    25,   383, 27330,  4631, 10243], device='cuda:0'), tensor([ 9915, 26123,    25,  3232,   329,  1439,   352,  1222,   362],
       device='cuda:0'), tensor([   41, 29309, 19013,  3691,  4380,   319,   262,  8511],
       device='cuda:0'), tensor([46678, 43236,    42,   532, 31900,  7623,  3776, 440




In [31]:
batch.keys()

dict_keys(['input_ids', 'query', 'response'])

#### Compute Score

In [32]:

batch["query"]


['What to buy the women in your life -...',
 'Re: @knitmeapony | Old Spice',
 'Lindsay Lohan Needs Real Friends (Ep #1)',
 'Rainy Days',
 "Letterman - Dave's Monologue Excerpt:...",
 'So I met Miley for realz!',
 'Ann Marie Calhoun - "The Pretend...',
 'Mega64: The Beatles Rock Band',
 'Black Ops: Free for All 1 & 2',
 'Jumbo Jet vs People on the Beach',
 'STAR TREK - Angry Video Game Nerd',
 'Why do you watch the videos I ma...',
 'There Will Be Bud',
 'Actuación Insula Kampa, Pub Ágora 3.12.2010',
 'The Adventures of Batman and Rob...',
 'Float Like A Butterfly',
 "The Cacaman - Ain't Got Much of...",
 'Paladino threatens New York Post Editor',
 'kaylee',
 "ITN: Japan radiation 'causes no concern'",
 'The Hauntening',
 'Chicken!Chicken!',
 'PALE BLUE DOT',
 'Make a difference',
 'In the car with James May - Top Gear - BBC',
 'History Will Be Made - Radulov/Arnott',
 'Polish Police SCBA - Chopper',
 'IMAX in a basement',
 'Ned 1 Den 0',
 'BUTT MUSCLES & HEAVY BREATHING',
 'Giant House

In [33]:

batch["response"]


[' More\n\nBackground\n',
 ' Hothouse\n\nSubject:',
 ' Kristen Bell and her',
 '" was my favorite "Community"',
 ' Free View in iTunes\n',
 ' Lyrics by Hva Toll',
 '" January 8,',
 '). This Mesmorizer',
 ' Weapon DLC\n\nDecember 22',
 ' –\n\nPir',
 '\n\nYes, you read',
 '\n\nI found',
 's", and "Wra',
 '\n\nPico-Cola /',
 ' Free View in iTunes\n\n',
 ' 5.69 574 5',
 '\n\n30.46',
 "'s Journal\n\n",
 '.com\n\nWanna Know',
 ' says health ministry Read more',
 ' conjurers are a small',
 'Every time you make a "re',
 ' FOR ADULT M',
 ' with your Ether.\n\n',
 ' Radio 4 and BBC the Magazine',
 ' Would Be The Mess Hallers -',
 ' Rather at Risk - Cubism',
 ' office.\n\nHe projected the',
 ' Adam 4 Wisconsin Rlw 5',
 '\n\nThough we do not recommend',
 '粉SpiderWarrior',
 'this must be your spark',
 ' After Being Robbed, Buddy',
 ' Adult Ballshops\n\nW',
 ' Between now and November, more than',
 'igga"!Our exclusive',
 ' Freeling, Double Doctors, Transfer',
 ' Pokémon FireRed - Part 18 Pokémon',


In [34]:

texts = [ q + r for q, r in zip(batch["query"], batch["response"]) ]


In [35]:

texts


['What to buy the women in your life -... More\n\nBackground\n',
 'Re: @knitmeapony | Old Spice Hothouse\n\nSubject:',
 'Lindsay Lohan Needs Real Friends (Ep #1) Kristen Bell and her',
 'Rainy Days" was my favorite "Community"',
 "Letterman - Dave's Monologue Excerpt:... Free View in iTunes\n",
 'So I met Miley for realz! Lyrics by Hva Toll',
 'Ann Marie Calhoun - "The Pretend..." January 8,',
 'Mega64: The Beatles Rock Band). This Mesmorizer',
 'Black Ops: Free for All 1 & 2 Weapon DLC\n\nDecember 22',
 'Jumbo Jet vs People on the Beach –\n\nPir',
 'STAR TREK - Angry Video Game Nerd\n\nYes, you read',
 'Why do you watch the videos I ma...\n\nI found',
 'There Will Be Buds", and "Wra',
 'Actuación Insula Kampa, Pub Ágora 3.12.2010\n\nPico-Cola /',
 'The Adventures of Batman and Rob... Free View in iTunes\n\n',
 'Float Like A Butterfly 5.69 574 5',
 "The Cacaman - Ain't Got Much of...\n\n30.46",
 "Paladino threatens New York Post Editor's Journal\n\n",
 'kaylee.com\n\nWanna Know',
 "ITN

In [36]:

pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
pipe_outputs


[[{'label': 'NOT', 'score': 2.490553617477417},
  {'label': 'CLICKBAIT', 'score': -2.9494712352752686}],
 [{'label': 'NOT', 'score': 0.8059258460998535},
  {'label': 'CLICKBAIT', 'score': -1.302541971206665}],
 [{'label': 'NOT', 'score': -0.4256044924259186},
  {'label': 'CLICKBAIT', 'score': -0.1774040162563324}],
 [{'label': 'NOT', 'score': 1.0809476375579834},
  {'label': 'CLICKBAIT', 'score': -1.4847525358200073}],
 [{'label': 'NOT', 'score': 3.8315069675445557},
  {'label': 'CLICKBAIT', 'score': -4.299412727355957}],
 [{'label': 'NOT', 'score': 2.443516492843628},
  {'label': 'CLICKBAIT', 'score': -3.1101608276367188}],
 [{'label': 'NOT', 'score': 2.957875967025757},
  {'label': 'CLICKBAIT', 'score': -3.311002492904663}],
 [{'label': 'NOT', 'score': 3.8029301166534424},
  {'label': 'CLICKBAIT', 'score': -4.127135276794434}],
 [{'label': 'NOT', 'score': 2.576737880706787},
  {'label': 'CLICKBAIT', 'score': -3.088524580001831}],
 [{'label': 'NOT', 'score': 3.8692023754119873},
  {'l

In [37]:

rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]
rewards


[tensor(-2.9495),
 tensor(-1.3025),
 tensor(-0.1774),
 tensor(-1.4848),
 tensor(-4.2994),
 tensor(-3.1102),
 tensor(-3.3110),
 tensor(-4.1271),
 tensor(-3.0885),
 tensor(-4.2969),
 tensor(-1.0586),
 tensor(-3.7888),
 tensor(-3.8298),
 tensor(-4.2025),
 tensor(-3.9643),
 tensor(-1.1186),
 tensor(-3.8399),
 tensor(-4.5365),
 tensor(-3.2125),
 tensor(-4.4947),
 tensor(-3.8024),
 tensor(2.6146),
 tensor(-3.1890),
 tensor(-3.9331),
 tensor(-4.2898),
 tensor(-0.4648),
 tensor(-5.0327),
 tensor(-2.4057),
 tensor(-4.1132),
 tensor(4.0770),
 tensor(-3.2643),
 tensor(-2.7014),
 tensor(-4.4017),
 tensor(-3.9607),
 tensor(2.6887),
 tensor(2.3842),
 tensor(-4.0638),
 tensor(-2.1604),
 tensor(3.2921),
 tensor(-2.6237),
 tensor(-4.0694),
 tensor(-3.8018),
 tensor(-1.4327),
 tensor(-4.2847),
 tensor(2.5921),
 tensor(-3.6806),
 tensor(-4.0512),
 tensor(0.5040),
 tensor(-3.4761),
 tensor(-3.4257),
 tensor(-1.7957),
 tensor(-1.0368),
 tensor(-4.3142),
 tensor(-1.7293),
 tensor(-4.7236),
 tensor(-3.0547),

In [38]:

len(rewards)


128

In [39]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(
                     query_tensors, 
                     response_tensors, 
                     rewards
    )
    ppo_trainer.log_stats(stats, batch, rewards)
    


0it [00:00, ?it/s]

0


1it [02:11, 131.94s/it]

1


2it [05:03, 155.05s/it]

2


3it [07:55, 162.91s/it]

3


4it [10:47, 166.55s/it]

4


5it [13:34, 166.78s/it]

5


6it [16:28, 169.03s/it]

6


7it [19:20, 170.06s/it]

7


8it [22:12, 170.59s/it]

8


9it [25:05, 171.40s/it]

9


10it [27:56, 171.49s/it]

10


11it [30:49, 171.80s/it]

11


12it [33:32, 169.29s/it]

12


13it [36:25, 170.16s/it]

13


14it [39:18, 171.14s/it]

14


15it [42:11, 171.59s/it]

15


16it [45:04, 172.04s/it]

16


17it [47:57, 172.43s/it]

17


18it [50:50, 172.68s/it]

18


19it [53:44, 172.87s/it]

19


20it [56:38, 173.33s/it]

20


21it [59:31, 173.23s/it]

21


22it [1:02:16, 170.72s/it]

22


23it [1:05:03, 169.64s/it]

23


24it [1:07:51, 169.03s/it]

24


25it [1:10:38, 168.45s/it]

25


26it [1:13:26, 168.33s/it]

26


27it [1:16:14, 168.42s/it]

27


28it [1:19:01, 167.95s/it]

28


29it [1:21:49, 167.98s/it]

29


30it [1:24:37, 167.82s/it]

30


31it [1:27:30, 169.33s/it]

31


32it [1:30:22, 170.25s/it]

32


33it [1:33:14, 170.82s/it]

33


34it [1:36:08, 171.60s/it]

34


35it [1:39:01, 172.03s/it]

35


36it [1:41:54, 172.41s/it]

36


37it [1:44:47, 172.57s/it]

37


38it [1:47:36, 171.50s/it]

38


39it [1:50:22, 170.03s/it]

39


40it [1:53:16, 171.11s/it]

40


41it [1:56:03, 169.97s/it]

41


42it [1:58:57, 171.00s/it]

42


43it [2:01:51, 171.86s/it]

43


44it [2:04:43, 172.11s/it]

44


45it [2:07:31, 170.62s/it]

45


46it [2:10:18, 169.56s/it]

46


47it [2:13:04, 168.73s/it]

47


48it [2:15:51, 168.13s/it]

48


49it [2:18:39, 168.07s/it]

49


50it [2:21:33, 169.93s/it]

50


51it [2:24:27, 170.92s/it]

51


52it [2:27:14, 169.96s/it]

52


53it [2:30:01, 169.14s/it]

53


54it [2:32:55, 170.47s/it]

54


55it [2:35:49, 171.45s/it]

55


56it [2:38:42, 171.98s/it]

56


57it [2:41:35, 172.36s/it]

57


58it [2:44:22, 170.74s/it]

58


59it [2:47:15, 171.40s/it]

59


60it [2:50:09, 171.98s/it]

60


61it [2:53:01, 172.26s/it]

61


62it [2:55:54, 172.50s/it]

62


63it [2:58:48, 172.69s/it]

63


64it [3:01:41, 172.94s/it]

64


65it [3:04:34, 172.96s/it]

65


66it [3:07:22, 171.36s/it]

66


67it [3:10:15, 171.95s/it]

67


68it [3:13:08, 172.28s/it]

68


69it [3:16:01, 172.47s/it]

69


70it [3:18:49, 171.02s/it]

70


71it [3:21:42, 171.72s/it]

71


72it [3:24:29, 170.42s/it]

72


73it [3:27:23, 171.47s/it]

73


74it [3:30:10, 170.15s/it]

74


75it [3:33:05, 171.44s/it]

75


76it [3:35:59, 172.13s/it]

76


77it [3:38:46, 170.81s/it]

77


78it [3:41:39, 171.34s/it]

78


79it [3:44:32, 171.82s/it]

79


80it [3:47:29, 173.28s/it]

80


81it [3:50:27, 174.70s/it]

81


82it [3:53:22, 175.04s/it]

82


83it [3:56:31, 179.24s/it]

83


84it [3:59:32, 179.61s/it]

84


85it [4:02:35, 180.53s/it]

85


86it [4:05:28, 178.49s/it]

86


87it [4:08:16, 175.26s/it]

87


88it [4:11:04, 172.95s/it]

88


89it [4:13:49, 170.62s/it]

89


90it [4:16:36, 169.49s/it]

90


91it [4:19:29, 170.63s/it]

91


92it [4:22:23, 171.56s/it]

92


93it [4:25:10, 170.18s/it]

93


94it [4:28:04, 171.34s/it]

94


95it [4:30:51, 170.17s/it]

95


96it [4:33:39, 169.42s/it]

96


97it [4:36:32, 170.49s/it]

97


98it [4:39:26, 171.49s/it]

98


99it [4:42:13, 170.41s/it]

99


100it [4:45:01, 169.45s/it]

100


101it [4:47:54, 170.55s/it]

101


102it [4:50:47, 171.42s/it]

102


103it [4:53:40, 171.80s/it]

103


104it [4:56:32, 171.91s/it]

104


105it [4:59:26, 172.48s/it]

105


106it [5:02:19, 172.62s/it]

106


107it [5:05:11, 172.46s/it]

107


108it [5:08:04, 172.72s/it]

108


109it [5:10:51, 170.90s/it]

109


110it [5:13:38, 169.80s/it]

110


111it [5:16:25, 169.02s/it]

111


112it [5:19:18, 170.16s/it]

112


113it [5:22:11, 171.09s/it]

113


114it [5:25:04, 171.59s/it]

114


115it [5:27:51, 170.18s/it]

115


116it [5:30:44, 170.99s/it]

116


117it [5:33:31, 169.77s/it]

117


118it [5:36:24, 170.64s/it]

118


119it [5:39:16, 171.23s/it]

119


120it [5:42:03, 169.85s/it]

120


121it [5:44:55, 170.59s/it]

121


122it [5:47:48, 171.17s/it]

122


123it [5:50:40, 171.67s/it]

123


124it [5:53:33, 171.93s/it]

124


125it [5:56:19, 170.29s/it]

125


126it [5:59:12, 170.89s/it]

126


127it [6:01:59, 169.75s/it]

127


128it [6:04:51, 170.56s/it]

128


129it [6:07:38, 169.43s/it]

129


130it [6:10:25, 168.77s/it]

130


131it [6:13:12, 168.25s/it]

131


132it [6:16:00, 167.95s/it]

132


133it [6:18:53, 169.53s/it]

133


134it [6:21:39, 168.63s/it]

134


135it [6:24:32, 169.98s/it]

135


136it [6:27:19, 169.05s/it]

136


137it [6:30:06, 168.36s/it]

137


138it [6:32:59, 169.85s/it]

138


139it [6:35:46, 168.89s/it]

139


140it [6:38:30, 167.51s/it]

140


141it [6:41:24, 169.32s/it]

141


142it [6:44:17, 170.32s/it]

142


143it [6:47:09, 170.95s/it]

143


144it [6:49:55, 169.50s/it]

144


145it [6:52:48, 170.53s/it]

145


146it [6:55:41, 171.29s/it]

146


147it [6:58:33, 171.61s/it]

147


148it [7:01:27, 172.07s/it]

148


149it [7:04:20, 172.42s/it]

149


150it [7:07:13, 172.75s/it]

150


151it [7:10:05, 172.49s/it]

151


152it [7:12:52, 170.83s/it]

152


153it [7:15:39, 169.63s/it]

153


154it [7:18:25, 168.46s/it]

154


155it [7:21:17, 169.70s/it]

155


156it [7:24:04, 168.78s/it]

156


157it [7:26:51, 168.20s/it]

157


158it [7:29:53, 172.32s/it]

158


159it [7:32:47, 172.99s/it]

159


160it [7:35:42, 173.50s/it]

160


161it [7:38:35, 173.45s/it]

161


162it [7:41:27, 172.80s/it]

162


163it [7:44:21, 173.30s/it]

163


164it [7:47:15, 173.59s/it]

164


165it [7:50:09, 173.56s/it]

165


166it [7:53:02, 173.44s/it]

166


167it [7:55:56, 173.63s/it]

167


168it [7:58:50, 173.80s/it]

168


169it [8:01:37, 171.78s/it]

169


170it [8:04:29, 171.65s/it]

170


171it [8:07:22, 172.18s/it]

171


172it [8:10:14, 172.14s/it]

172


173it [8:13:08, 172.54s/it]

173


174it [8:15:59, 172.08s/it]

174


175it [8:18:51, 172.14s/it]

175


176it [8:21:43, 172.07s/it]

176


177it [8:24:37, 172.59s/it]

177


178it [8:27:30, 172.86s/it]

178


179it [8:30:24, 173.19s/it]

179


180it [8:33:18, 173.46s/it]

180


181it [8:36:10, 172.94s/it]

181


182it [8:39:04, 173.19s/it]

182


183it [8:41:59, 173.70s/it]

183


184it [8:44:51, 173.28s/it]

184


185it [8:47:44, 173.32s/it]

185


186it [8:50:37, 173.19s/it]

186


187it [8:53:29, 172.88s/it]

187


188it [8:56:22, 172.87s/it]

188


189it [8:59:16, 173.09s/it]

189


190it [9:02:10, 173.31s/it]

190


191it [9:05:03, 173.20s/it]

191


192it [9:08:19, 180.17s/it]

192


193it [9:11:12, 177.92s/it]

193


194it [9:14:27, 183.27s/it]

194


195it [9:17:21, 180.33s/it]

195


196it [9:20:13, 178.01s/it]

196


197it [9:23:07, 176.80s/it]

197


198it [9:26:03, 176.37s/it]

198


199it [9:28:57, 175.76s/it]

199


200it [9:31:53, 175.71s/it]

200


201it [9:34:47, 175.21s/it]

201


202it [9:37:40, 174.56s/it]

202


203it [9:40:34, 174.30s/it]

203


204it [9:43:29, 174.55s/it]

204


205it [9:46:23, 174.44s/it]

205


206it [9:49:17, 174.35s/it]

206


207it [9:52:11, 174.12s/it]

207


208it [9:55:05, 174.21s/it]

208


209it [9:57:59, 174.18s/it]

209


210it [10:00:53, 174.15s/it]

210


211it [10:03:47, 173.97s/it]

211


212it [10:06:40, 173.83s/it]

212


213it [10:09:34, 173.85s/it]

213


214it [10:12:29, 174.28s/it]

214


215it [10:15:40, 179.14s/it]

215


216it [10:18:47, 181.65s/it]

216


217it [10:21:50, 182.06s/it]

217


218it [10:24:43, 179.18s/it]

218


219it [10:27:35, 177.11s/it]

219


220it [10:30:23, 174.48s/it]

220


221it [10:33:11, 172.52s/it]

221


222it [10:36:06, 173.23s/it]

222


223it [10:38:59, 173.21s/it]

223


224it [10:41:53, 173.19s/it]

224


225it [10:44:51, 174.66s/it]

225


226it [10:47:40, 173.00s/it]

226


227it [10:50:34, 173.29s/it]

227


228it [10:53:27, 173.29s/it]

228


229it [10:56:19, 172.97s/it]

229


230it [10:59:14, 173.39s/it]

230


231it [11:02:08, 173.59s/it]

231


232it [11:05:00, 173.24s/it]

232


233it [11:07:55, 173.57s/it]

233


234it [11:10:48, 173.66s/it]

234


235it [11:13:43, 173.87s/it]

235


236it [11:16:36, 173.66s/it]

236


237it [11:19:30, 173.78s/it]

237


238it [11:22:24, 173.73s/it]

238


239it [11:25:14, 172.85s/it]

239


240it [11:28:09, 173.25s/it]

240


241it [11:31:03, 173.53s/it]

241


242it [11:33:56, 173.53s/it]

242


243it [11:36:50, 173.70s/it]

243


244it [11:39:44, 173.60s/it]

244


245it [11:42:37, 173.45s/it]

245


246it [11:45:32, 173.98s/it]

246


247it [11:49:06, 186.06s/it]

247


248it [11:53:14, 204.62s/it]

248


249it [11:57:45, 224.40s/it]

249


250it [12:01:28, 224.05s/it]

250


251it [12:06:09, 241.19s/it]

251


252it [12:10:10, 241.14s/it]

252


253it [12:14:06, 239.47s/it]

253


254it [12:18:43, 250.78s/it]

254


255it [12:21:54, 232.72s/it]

255


256it [12:24:50, 215.80s/it]

256


257it [12:27:45, 203.69s/it]

257


258it [12:30:48, 197.33s/it]

258


259it [12:33:47, 191.80s/it]

259


260it [12:36:45, 187.66s/it]

260


261it [12:39:42, 184.67s/it]

261


262it [12:42:39, 182.32s/it]

262


263it [12:45:37, 181.01s/it]

263


264it [12:48:33, 179.56s/it]

264


265it [12:51:31, 178.87s/it]

265


266it [12:54:28, 178.30s/it]

266


267it [12:57:23, 177.33s/it]

267


268it [13:00:20, 177.32s/it]

268


269it [13:03:17, 177.27s/it]

269


270it [13:06:15, 177.32s/it]

270


271it [13:09:10, 176.66s/it]

271


272it [13:12:07, 176.87s/it]

272


273it [13:15:05, 177.29s/it]

273


274it [13:18:02, 177.13s/it]

274


275it [13:21:00, 177.45s/it]

275


276it [13:23:58, 177.69s/it]

276


277it [13:26:57, 177.95s/it]

277


278it [13:30:00, 179.31s/it]

278


279it [13:32:57, 178.74s/it]

279


280it [13:35:55, 178.61s/it]

280


281it [13:38:52, 178.15s/it]

281


282it [13:41:51, 178.45s/it]

282


283it [13:44:50, 178.52s/it]

283


284it [13:47:46, 177.82s/it]

284


285it [13:50:43, 177.59s/it]

285


286it [13:53:39, 177.01s/it]

286


287it [13:56:35, 176.79s/it]

287


288it [13:59:35, 177.68s/it]

288


289it [14:02:30, 177.00s/it]

289


290it [14:05:29, 177.33s/it]

290


291it [14:08:25, 176.94s/it]

291


292it [14:11:16, 175.23s/it]

292


293it [14:14:07, 174.09s/it]

293


294it [14:17:06, 175.49s/it]

294


295it [14:19:57, 174.25s/it]

295


296it [14:22:53, 174.77s/it]

296


297it [14:25:52, 175.86s/it]

297


298it [14:28:50, 176.54s/it]

298


299it [14:31:49, 177.24s/it]

299


300it [14:34:47, 177.41s/it]

300


301it [14:37:43, 177.11s/it]

301


302it [14:40:33, 174.84s/it]

302


303it [14:43:30, 175.53s/it]

303


304it [14:46:28, 176.41s/it]

304


305it [14:49:27, 177.01s/it]

305


306it [14:52:24, 177.23s/it]

306


307it [14:55:22, 177.25s/it]

307


308it [14:58:19, 177.42s/it]

308


309it [15:01:18, 177.78s/it]

309


310it [15:04:16, 177.72s/it]

310


311it [15:07:14, 177.93s/it]

311


312it [15:10:12, 178.04s/it]

312


313it [15:13:09, 177.71s/it]

313


314it [15:16:05, 177.11s/it]

314


315it [15:19:02, 177.21s/it]

315


316it [15:22:00, 177.32s/it]

316


317it [15:24:57, 177.33s/it]

317


318it [15:27:49, 175.73s/it]

318


319it [15:30:47, 176.20s/it]

319


320it [15:33:44, 176.47s/it]

320


321it [15:36:39, 176.24s/it]

321


322it [15:39:38, 176.81s/it]

322


323it [15:42:37, 177.47s/it]

323


324it [15:45:35, 177.63s/it]

324


325it [15:48:31, 177.15s/it]

325


326it [15:51:28, 177.13s/it]

326


327it [15:54:13, 173.50s/it]

327


328it [15:57:10, 174.71s/it]

328


329it [16:00:08, 175.67s/it]

329


330it [16:03:06, 176.38s/it]

330


331it [16:06:06, 177.30s/it]

331


332it [16:09:06, 178.15s/it]

332


333it [16:12:05, 178.47s/it]

333


334it [16:15:01, 177.67s/it]

334


335it [16:18:01, 175.17s/it]


In [40]:

torch.cuda.get_device_name(0)


'NVIDIA GeForce RTX 3080'

In [41]:

#### get a batch from the dataset
bs                 = 16
game_data          = dict()


In [42]:

game_data 


{}

In [43]:
dataset = ds

In [44]:

dataset.set_format("pandas")


In [45]:

df_batch           = dataset[:].sample(bs)
df_batch 


Unnamed: 0,title,input_ids,query
19960,SEXY DANCE TIME!,"[5188, 34278, 360, 19240, 20460, 0]",SEXY DANCE TIME!
4,SF International Film Festival Trailer,"[20802, 4037, 13741, 11117, 36923]",SF International Film Festival Trailer
38045,UNE GAFE DU GARDIEN DE TWENTE,"[41884, 402, 8579, 36, 35480, 402, 9795, 40, 1...",UNE GAFE DU GARDIEN DE TWENTE
39430,"GEICO's R. Lee Ermey, appearing on behalf of T...","[8264, 22707, 338, 371, 13, 5741, 5256, 1326, ...","GEICO's R. Lee Ermey, appearing on behalf of T..."
16686,"Michael Jackson's Neverland ""Ghost"" M...","[13256, 6612, 338, 7236, 1044, 366, 32001, 1, ...","Michael Jackson's Neverland ""Ghost"" M..."
18434,Anderson Cooper Destroys GOP Head Ove...,"[42991, 10382, 8145, 305, 893, 6796, 7123, 440...",Anderson Cooper Destroys GOP Head Ove...
27778,Annoying Orange: Back to the Fruiture,"[18858, 726, 278, 11942, 25, 5157, 284, 262, 2...",Annoying Orange: Back to the Fruiture
36617,Verloren - Luc Weegels Ft. Sandra Reemer,"[13414, 75, 29578, 532, 7598, 775, 1533, 1424,...",Verloren - Luc Weegels Ft. Sandra Reemer
27623,Lil Wayne - Light Up (Freestyle),"[43, 346, 13329, 532, 4401, 3205, 357, 20366, ...",Lil Wayne - Light Up (Freestyle)
5081,Front flips over 9 people Battle of the Eleme...,"[25886, 45971, 625, 860, 661, 220, 5838, 286, ...",Front flips over 9 people Battle of the Eleme...


In [46]:

game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()


In [47]:

response_tensors_ref, response_tensors = [], []


In [48]:

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    
    
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)


In [49]:

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]


In [50]:

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]




In [51]:

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]


In [52]:

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results


Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,SEXY DANCE TIME!,\n\nWarp,There's something that,-2.633893,4.688572
1,SF International Film Festival Trailer,| UNC Filmworks\n\n,\n\nHere's the trailer for,-4.350378,3.547578
2,UNE GAFE DU GARDIEN DE TWENTE,", ES BC.",This is what sort,-2.777253,3.222639
3,"GEICO's R. Lee Ermey, appearing on behalf of T...",Free View in iTunes\n\n,\n\nHere's the actual,-4.501187,4.406254
4,"Michael Jackson's Neverland ""Ghost"" M...","monument, Better",\n\nHere's,-1.522031,5.148306
5,Anderson Cooper Destroys GOP Head Ove...,Obama Is an Everyman,\n\nAnd here's,-2.775688,5.173042
6,Annoying Orange: Back to the Fruiture,ist\n\nThere is,", this is what they",-3.776184,4.450605
7,Verloren - Luc Weegels Ft. Sandra Reemer,- Il Divo 4 Remix,- here's the official cover,-3.925521,4.250594
8,Lil Wayne - Light Up (Freestyle),[VMAFT],. This song had been,-3.872023,2.891146
9,Front flips over 9 people Battle of the Eleme...,#2003Full video of Found,. Here's her response to,-4.306053,4.674588


In [53]:

print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())



mean:


rewards (before)   -3.028568
rewards (after)     3.801119
dtype: float64


median:


rewards (before)   -3.537325
rewards (after)     4.428430
dtype: float64

In [54]:

## model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=True)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=True)

## model.save_pretrained(    "gpt2-imdb-pos-v2-jd", push_to_hub=False)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2-jd", push_to_hub=False)



## Custom Querys

In [55]:
game_data["query"] = "Training a deep learning model with PyTorch for the Iris dataset (part 1)"

In [56]:
response_tensors_ref, response_tensors = [], []

In [57]:
#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    
    
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,Training a deep learning model with PyTorch fo...,\n\nOur hilarious,\n\nHere's,-2.797822,1.193287
1,Training a deep learning model with PyTorch fo...,":""Right now the world is at",. Here's the trailer:\n,2.515629,4.598798
2,Training a deep learning model with PyTorch fo...,ENTH PARTNUM,. This is actually,-3.473917,3.42802
3,Training a deep learning model with PyTorch fo...,18:02 < CodeShark,Here's what we happen to know,-4.319618,4.508544
4,Training a deep learning model with PyTorch fo...,\n\nPitch Perfect 2:,\n\nHere's What Happened,-3.657798,4.927411
5,Training a deep learning model with PyTorch fo...,MSNBC 4 hrs ago 0 149,There are one key reason why,-3.879203,4.19648
6,Training a deep learning model with PyTorch fo...,-Garden Contact This Page 13,. Here's the deal:\n,-4.06006,0.095456
7,Training a deep learning model with PyTorch fo...,- In All This Light (K,"!"" Here's the first twist.",-3.903644,4.397
8,Training a deep learning model with PyTorch fo...,\n\nBillie,Here is his response,-1.879839,4.630973
9,Training a deep learning model with PyTorch fo...,. 5+ WAS,had this theory on,-3.194084,0.274144


In [58]:
game_data["query"] = "VPN Tun/Tap and sockets, routing, tunnels and TLS"

In [59]:
response_tensors_ref, response_tensors = [], []

In [60]:
#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    
    
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,"VPN Tun/Tap and sockets, routing, tunnels and TLS",Inspired by the Kyle and,These are actually how you should,-4.190799,4.860742
1,"VPN Tun/Tap and sockets, routing, tunnels and TLS",\n\nThe project opens,And Remember... Here's,-3.86709,4.867324
2,"VPN Tun/Tap and sockets, routing, tunnels and TLS",ISLES 249.,"-- ""JUST LOOK",-3.397221,-1.632957
3,"VPN Tun/Tap and sockets, routing, tunnels and TLS",‎ Appears in 20 books from,this is what it says…,-2.552716,4.267446
4,"VPN Tun/Tap and sockets, routing, tunnels and TLS",Free View in iTunes\n\n32,\nHere's one to take home,-4.506935,4.556669
5,"VPN Tun/Tap and sockets, routing, tunnels and TLS","\n\nCNN: ""#",\n\nHere's how,-2.520619,3.692731
6,"VPN Tun/Tap and sockets, routing, tunnels and TLS",t area in Pleasant,. This is just,-4.319515,0.147061
7,"VPN Tun/Tap and sockets, routing, tunnels and TLS",- Cobras And Frizz,\n\nHere's what he,-1.844394,4.618574
8,"VPN Tun/Tap and sockets, routing, tunnels and TLS",(Toure Remix,"""'It means he",-3.788618,1.738622
9,"VPN Tun/Tap and sockets, routing, tunnels and TLS",Pensacola 1928 Fighting,. This happened to a,-4.784885,4.367789


In [61]:
game_data["query"] = "Basic Ideas of Reinforcement Learning Through Human Feedbacks (RLHF)"

In [62]:
response_tensors_ref, response_tensors = [], []

In [63]:
#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    
    
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,Basic Ideas of Reinforcement Learning Through ...,This dance has very beautiful,""" We could be talking",0.224746,0.272968
1,Basic Ideas of Reinforcement Learning Through ...,Thumbs • TF,\n\nHere's your,-4.561722,-2.895056
2,Basic Ideas of Reinforcement Learning Through ...,EN KITCHEN Mich,". Come on, here",-2.794768,-1.882573
3,Basic Ideas of Reinforcement Learning Through ...,05:08\n\nJUSTICE,\n\nThis is the most obvious,-3.484382,4.189409
4,Basic Ideas of Reinforcement Learning Through ...,Free View in iTunes\n,\n\nHere's What,-4.29535,3.203177
5,Basic Ideas of Reinforcement Learning Through ...,John Dickerson and Maria,But he doesn't have,-4.254372,4.81649
6,Basic Ideas of Reinforcement Learning Through ...,ous Land Trapper |,", where you can find",-3.881571,-1.823045
7,Basic Ideas of Reinforcement Learning Through ...,- Do What U Want (,. This is the version of,-3.161539,-3.52761
8,Basic Ideas of Reinforcement Learning Through ...,[Drawcad,Here's what ended,-4.389244,4.763288
9,Basic Ideas of Reinforcement Learning Through ...,. The creator of the battle of,. Here's how it aired:,-3.230633,4.745613
