DS 266 Final Project for BART model
The Colab Notebook link is https://colab.research.google.com/drive/1sckwqe2xo6B7m6deFPyA2cEKgscaQWeM?authuser=9#scrollTo=p3uZVUKssMhI

# Set up
Install and import libraries/dependencies

In [None]:
# This cell will authenticate you and mount your Drive in the Colab.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install HuggingFace transformers
%pip install transformers



In [None]:
%pip install evaluate



In [None]:
%pip install rouge_score



In [None]:
%pip install prometheus-eval



In [None]:
%pip install triton



In [None]:
%pip install vllm



In [None]:
# Import BART
from transformers import BartForConditionalGeneration, BartTokenizer

In [None]:
import pandas as pd
import numpy as np
import torch
from transformers import set_seed
from transformers import Trainer
from transformers import TrainingArguments
import random
from sklearn.model_selection import train_test_split
import evaluate
import re
from datasets import Dataset
from torch.utils.data import DataLoader, TensorDataset

# LLM as Eval Judge
from prometheus_eval.vllm import VLLM
from prometheus_eval import PrometheusEval
from prometheus_eval.prompts import ABSOLUTE_PROMPT, SCORE_RUBRIC_TEMPLATE

## Set Random Seed
Comment these lines out to get non-deterministic behavior.

In [None]:
# Set seeds for reproducibility
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)
set_seed(seed_value)

# Datasets

In [None]:
# Dataset file paths; these may change per environment, e.g., local vs. Google Colab
ROCStoriesSpring2016_FILE_PATH = './drive/MyDrive/DS 266/Final/ROCStoriesSpring2016.csv'
ROCStoriesWinter2017_FILE_PATH = './drive/MyDrive/DS 266/Final/ROCStoriesWinter2017.csv'
ClozeTestWinter2018ValData_FILE_PATH = './drive/MyDrive/DS 266/Final/ClozeTestWinter2018ValData.csv'
ClozeTestWinter2018TestData_FILE_PATH = './drive/MyDrive/DS 266/Final/ClozeTestWinter2018TestData.csv'
ClozeTestSpring2016ValData_FILE_PATH = './drive/MyDrive/DS 266/Final/ClozeTestSpring2016ValData.csv'

In [None]:
# Load data
ROCStoriesSpring2016 = pd.read_csv(ROCStoriesSpring2016_FILE_PATH)
ROCStoriesWinter2017 = pd.read_csv(ROCStoriesWinter2017_FILE_PATH)
ClozeTestWinter2018ValData = pd.read_csv(ClozeTestWinter2018ValData_FILE_PATH)
ClozeTestWinter2018TestData = pd.read_csv(ClozeTestWinter2018TestData_FILE_PATH)
ClozeTestSpring2016ValData = pd.read_csv(ClozeTestSpring2016ValData_FILE_PATH)

## Inspect Datasets
Check shapes and duplicates

In [None]:
print(ROCStoriesSpring2016.shape)
ROCStoriesSpring2016.head()

(45496, 7)


Unnamed: 0,storyid,storytitle,sentence1,sentence2,sentence3,sentence4,sentence5
0,9a51198e-96f1-42c3-b09d-a3e1e067d803,Overweight Kid,Dan's parents were overweight.,Dan was overweight as well.,The doctors told his parents it was unhealthy.,His parents understood and decided to make a c...,They got themselves and Dan on a diet.
1,617e7ada-3878-488d-bd56-40695b91f053,The Bike Accident,Carrie had just learned how to ride a bike.,She didn't have a bike of her own.,Carrie would sneak rides on her sister's bike.,She got nervous on a hill and crashed into a w...,The bike frame bent and Carrie got a deep gash...
2,79b0da1f-e460-4173-ba58-8c9e2553c53a,Beach,Morgan enjoyed long walks on the beach.,She and her boyfriend decided to go for a long...,"After walking for over a mile, something happe...",Morgan decided to propose to her boyfriend.,Her boyfriend was upset he didn't propose to h...
3,d173b7de-4611-4cdf-934c-912834755e41,The bad customer.,Jane was working at a diner.,"Suddenly, a customer barged up to the counter.",He began yelling about how long his food was t...,Jane didn't know how to react.,"Luckily, her coworker intervened and calmed th..."
4,af0fd5a4-de36-47ba-8aa2-e99d10986d7a,Being Patient,I was talking to my crush today.,She continued to complain about guys flirting ...,I decided to agree with what she says and list...,"After I got home, I got a text from her.",She asked if we can hang out tomorrow.


In [None]:
print(ROCStoriesWinter2017.shape)
ROCStoriesWinter2017.head()

(52665, 7)


Unnamed: 0,storyid,storytitle,sentence1,sentence2,sentence3,sentence4,sentence5
0,8bbe6d11-1e2e-413c-bf81-eaea05f4f1bd,David Drops the Weight,David noticed he had put on a lot of weight re...,He examined his habits to try and figure out t...,He realized he'd been eating too much fast foo...,He stopped going to burger places and started ...,"After a few weeks, he started to feel much bet..."
1,0beabab2-fb49-460e-a6e6-f35a202e3348,Frustration,Tom had a very short temper.,One day a guest made him very angry.,He punched a hole in the wall of his house.,Tom's guest became afraid and left quickly.,Tom sat on his couch filled with regret about ...
2,87da1a22-df0b-410c-b186-439700b70ba6,Marcus Buys Khakis,Marcus needed clothing for a business casual e...,All of his clothes were either too formal or t...,He decided to buy a pair of khakis.,The pair he bought fit him perfectly.,Marcus was happy to have the right clothes for...
3,2d16bcd6-692a-4fc0-8e7c-4a6f81d9efa9,Different Opinions,Bobby thought Bill should buy a trailer and ha...,Bill thought a truck would be better for what ...,Bobby pointed out two vehicles were much more ...,Bill was set in his ways with conventional thi...,He ended up buying the truck he wanted despite...
4,c71bb23b-7731-4233-8298-76ba6886cee1,Overcoming shortcomings,John was a pastor with a very bad memory.,He tried to memorize his sermons many days in ...,He decided to learn to sing to overcome his ha...,He then made all his sermons into music and sa...,His congregation was delighted and so was he.


In [None]:
print(ClozeTestWinter2018ValData.shape)
ClozeTestWinter2018ValData.head()

(1571, 8)


Unnamed: 0,InputStoryid,InputSentence1,InputSentence2,InputSentence3,InputSentence4,RandomFifthSentenceQuiz1,RandomFifthSentenceQuiz2,AnswerRightEnding
0,138d5bfb-05cc-41e3-bf2c-fa85ebad14e2,Rick grew up in a troubled household.,"He never found good support in family, and tur...",It wasn't long before Rick got shot in a robbery.,The incident caused him to turn a new leaf.,He is happy now.,He joined a gang.,1
1,bff9f820-9605-4875-b9af-fe6f14d04256,Laverne needs to prepare something for her fri...,She decides to bake a batch of brownies.,She chooses a recipe and follows it closely.,Laverne tests one of the brownies to make sure...,The brownies are so delicious Laverne eats two...,Laverne doesn't go to her friend's party.,1
2,e8f628d5-9f97-40ed-8611-fc0e774673c4,Sarah had been dreaming of visiting Europe for...,She had finally saved enough for the trip.,She landed in Spain and traveled east across t...,She didn't like how different everything was.,Sarah then decided to move to Europe.,Sarah decided that she preferred her home over...,2
3,f5226bfe-9f26-4377-b05f-3d9568dbdec1,Gina was worried the cookie dough in the tube ...,She was very happy to find she was wrong.,The cookies from the tube were as good as from...,Gina intended to only eat 2 cookies and save t...,Gina liked the cookies so much she ate them al...,Gina gave the cookies away at her church.,1
4,69ac9b05-b956-402f-9fff-1f926ef9176b,It was my final performance in marching band.,I was playing the snare drum in the band.,We played Thriller and Radar Love.,The performance was flawless.,I was very proud of my performance.,I was very ashamed of my performance.,1


In [None]:
print(ClozeTestWinter2018TestData.shape)
ClozeTestWinter2018TestData.head()

(1571, 7)


Unnamed: 0,InputStoryid,InputSentence1,InputSentence2,InputSentence3,InputSentence4,RandomFifthSentenceQuiz1,RandomFifthSentenceQuiz2
0,f6aad64a-e34c-415d-b895-dbfa187ed43e,Bob was bored at his job as a school teacher.,He had been working so hard this past month.,He decided to treat himself with something spe...,He ordered tickets for a weekend snowboarding ...,He was looking forward to getting away.,His boss told him he had to work this weekend.
1,0fedd90d-5295-4b79-b2d0-15a2bad624ee,Olivia went out with Harry on a date.,Harry thought the date was going well.,Olivia thinks he is a complete jerk and never ...,Harry keeps raving about their chemistry.,Olivia is about to leave.,"Olivia had her friend call her, to tell her th..."
2,018152fd-f984-4d05-ad1e-12f1fb7eceb6,Jack and Ferris always fought for headphones.,One day Jack broke Ferris' headphones while jo...,Ferris was furious at Jack.,Their parents yelled at them.,Jack promised Ferris to buy him new headphones.,Jack promised to take Ferris jogging.
3,feef76df-b75a-4501-9c1a-f8a7b6ee442f,I needed someone to help me move a bed across ...,"I called a couple of friends, but they were busy.","Finally, I called my grandson, who came right ...",He helped me move the bed to the right spot.,I made him an ice cream sundae for his efforts.,"He was tired from moving it, so he took a nap ..."
4,929eaf8b-a175-4460-a885-43be8a89ca62,Hannah had a beautiful cat that she loved very...,"However, she noticed that her cat was getting ...","One day, her cat ran away and never came back.",Hannah was devastated.,She never saw her beautiful cat again.,"The next day, she saw the cat walking down the..."


In [None]:
print(ClozeTestSpring2016ValData.shape)
ClozeTestSpring2016ValData.head()

(1871, 8)


Unnamed: 0,InputStoryid,InputSentence1,InputSentence2,InputSentence3,InputSentence4,RandomFifthSentenceQuiz1,RandomFifthSentenceQuiz2,AnswerRightEnding
0,138d5bfb-05cc-41e3-bf2c-fa85ebad14e2,Rick grew up in a troubled household.,"He never found good support in family, and tur...",It wasn't long before Rick got shot in a robbery.,The incident caused him to turn a new leaf.,He is happy now.,He joined a gang.,1
1,bff9f820-9605-4875-b9af-fe6f14d04256,Laverne needs to prepare something for her fri...,She decides to bake a batch of brownies.,She chooses a recipe and follows it closely.,Laverne tests one of the brownies to make sure...,The brownies are so delicious Laverne eats two...,Laverne doesn't go to her friend's party.,1
2,e8f628d5-9f97-40ed-8611-fc0e774673c4,Sarah had been dreaming of visiting Europe for...,She had finally saved enough for the trip.,She landed in Spain and traveled east across t...,She didn't like how different everything was.,Sarah then decided to move to Europe.,Sarah decided that she preferred her home over...,2
3,f5226bfe-9f26-4377-b05f-3d9568dbdec1,Gina was worried the cookie dough in the tube ...,She was very happy to find she was wrong.,The cookies from the tube were as good as from...,Gina intended to only eat 2 cookies and save t...,Gina liked the cookies so much she ate them al...,Gina gave the cookies away at her church.,1
4,69ac9b05-b956-402f-9fff-1f926ef9176b,It was my final performance in marching band.,I was playing the snare drum in the band.,We played Thriller and Radar Love.,The performance was flawless.,I was very proud of my performance.,I was very ashamed of my performance.,1


Note that ClozeTestWinter2018TestData comes without ground-truth labels; grading against it requires asking the owners of ROCStories and the Story Cloze Test. For our final project evaluation we therefore need a different test dataset.

In [None]:
pd.merge(ROCStoriesSpring2016, ROCStoriesWinter2017, how='inner').empty

True

In [None]:
# In case some rows have different story ids but the same sentences
pd.merge(ROCStoriesSpring2016, ROCStoriesWinter2017, how='inner', on=["sentence1", "sentence2"]).empty

True

In [None]:
pd.merge(ClozeTestSpring2016ValData, ClozeTestWinter2018ValData, how='inner', on=["InputSentence1", "InputSentence2"]).empty

False

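The merge of the two Story Cloze validation sets is not empty (.empty is False), so they do overlap; their first five rows are in fact identical. We therefore should not use one for training and the other for testing, and below we use only the Winter 2018 validation set. The two ROCStories datasets have no duplicates or overlap, so we can safely combine them.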
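
As a minimal sketch of how to read this kind of check, the example below runs the same inner merge on hypothetical toy frames (the data and the names `left`, `right`, and `shared` are illustrative, not from the project). The merge result is empty only when the two frames share no rows on the key columns, so `.empty` returning False signals overlap.

```python
import pandas as pd

# Toy stand-ins for two story datasets (hypothetical data).
left = pd.DataFrame({
    "sentence1": ["A went home.", "B ate lunch.", "C ran fast."],
    "sentence2": ["A slept.", "B felt full.", "C won."],
})
right = pd.DataFrame({
    "sentence1": ["B ate lunch.", "D sang."],
    "sentence2": ["B felt full.", "D bowed."],
})

# An inner merge keeps only rows whose key columns match in both frames.
shared = pd.merge(left, right, how="inner", on=["sentence1", "sentence2"])

# .empty is True only when nothing matched, so False means shared rows exist.
has_overlap = not shared.empty
n_shared = len(shared)
print(has_overlap, n_shared)
```

Counting the merged rows, rather than only checking `.empty`, also shows how large the overlap is.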

In [None]:
ROCStories = pd.concat([ROCStoriesSpring2016, ROCStoriesWinter2017])
ROCStories

Unnamed: 0,storyid,storytitle,sentence1,sentence2,sentence3,sentence4,sentence5
0,9a51198e-96f1-42c3-b09d-a3e1e067d803,Overweight Kid,Dan's parents were overweight.,Dan was overweight as well.,The doctors told his parents it was unhealthy.,His parents understood and decided to make a c...,They got themselves and Dan on a diet.
1,617e7ada-3878-488d-bd56-40695b91f053,The Bike Accident,Carrie had just learned how to ride a bike.,She didn't have a bike of her own.,Carrie would sneak rides on her sister's bike.,She got nervous on a hill and crashed into a w...,The bike frame bent and Carrie got a deep gash...
2,79b0da1f-e460-4173-ba58-8c9e2553c53a,Beach,Morgan enjoyed long walks on the beach.,She and her boyfriend decided to go for a long...,"After walking for over a mile, something happe...",Morgan decided to propose to her boyfriend.,Her boyfriend was upset he didn't propose to h...
3,d173b7de-4611-4cdf-934c-912834755e41,The bad customer.,Jane was working at a diner.,"Suddenly, a customer barged up to the counter.",He began yelling about how long his food was t...,Jane didn't know how to react.,"Luckily, her coworker intervened and calmed th..."
4,af0fd5a4-de36-47ba-8aa2-e99d10986d7a,Being Patient,I was talking to my crush today.,She continued to complain about guys flirting ...,I decided to agree with what she says and list...,"After I got home, I got a text from her.",She asked if we can hang out tomorrow.
...,...,...,...,...,...,...,...
52660,134e8636-3617-43d8-ba6a-9a11b3b115b1,Flavor,The man liked the flavor.,He tried to recreate it at home.,He could not get the flavor right.,He asked the owner of the recipe for help.,The owner of the flavor sold him the recipe.
52661,4c317f76-ca42-4024-a4c2-12ec911cf89b,After Death,"After my friend's dad's funeral, I got in trou...",The principal said I wasn't allowed to leave s...,He found out I had my friend sign me out.,He told me I was getting detention.,I skipped detention all week.
52662,a18fd0d2-4d0c-4316-befe-e3d827fe699b,Janice breaks her wrist,Janice was out exercising for her big soccer g...,She was doing some drills with her legs.,While working out and exercising she slips on ...,She falls down and uses her wrist to break her...,She breaks her wrist in the process and goes t...
52663,2c14252b-4080-4fca-8765-537772018508,Jamie marries for love,Jamie is an american girl.,Jamie wants to get married to a mexican man.,Her family assumes it's because the man wants ...,Jamie insist that she is marrying him out of l...,Jamie gets married and they spent the rest of ...


# Utility Functions

In [None]:
def combine_sentences(row, sentence_columns_to_be_combined=["sentence1", "sentence2", "sentence3", "sentence4"]):
  assert len(sentence_columns_to_be_combined) > 0, "Sentence columns to be combined list length has to be larger than 0"
  combined_sentence = row[sentence_columns_to_be_combined[0]]
  for i in range(1, len(sentence_columns_to_be_combined)):
    combined_sentence += ' ' + row[sentence_columns_to_be_combined[i]]
  return combined_sentence

In [None]:
def remove_first_four_sentences(paragraph):
    # Split the paragraph into sentences using regex
    sentences = re.split(r'(?<=[.!?])[\s"]+', paragraph)
    # Keep all sentences except the first 4
    remaining_sentences = sentences[4:]
    # Rejoin the sentences into a paragraph
    return " ".join(remaining_sentences)

def vectorized_remove_first_four_sentences(paragraphs):
    return np.vectorize(remove_first_four_sentences)(np.array(paragraphs))
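
The sketch below illustrates how the sentence-splitting pattern above behaves on a made-up five-sentence story, and what happens when a sentence boundary loses its punctuation. The helper name `split_sentences` and both example strings are hypothetical; only the regular expression is taken from the function above.

```python
import re

def split_sentences(paragraph):
    # Same pattern as the notebook: split after '.', '!' or '?'
    # when followed by whitespace and/or double quotes.
    return re.split(r'(?<=[.!?])[\s"]+', paragraph)

story = "Tom woke up. He ate. He left! Was he late? He made it on time."
parts = split_sentences(story)
# Everything after the first four sentences is treated as the ending.
ending = " ".join(parts[4:])

# If one boundary has no punctuation, only four pieces come back,
# so nothing is left after dropping the first four.
broken = "Tom woke up He ate. He left! Was he late? He made it on time."
broken_ending = " ".join(split_sentences(broken)[4:])

print(repr(ending), repr(broken_ending))
```

The second case shows why an empty string can come back when a generated story does not reproduce the prompt's punctuation exactly.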

In [None]:
def concat_story_body_with_endings(story_body, story_ending):
  return [f"{p1} {p2}" for p1, p2 in zip(story_body, story_ending)]

# Train, Validation, Test Split

In [None]:
ROCStories_Y = ROCStories["sentence5"]
ROCStories_X = ROCStories.drop(columns=["storyid", "storytitle", "sentence5"])
ROCStories_X = ROCStories_X.apply(combine_sentences, axis=1)
ROCStories_X

Unnamed: 0,0
0,Dan's parents were overweight. Dan was overwei...
1,Carrie had just learned how to ride a bike. Sh...
2,Morgan enjoyed long walks on the beach. She an...
3,"Jane was working at a diner. Suddenly, a custo..."
4,I was talking to my crush today. She continued...
...,...
52660,The man liked the flavor. He tried to recreate...
52661,"After my friend's dad's funeral, I got in trou..."
52662,Janice was out exercising for her big soccer g...
52663,Jamie is an american girl. Jamie wants to get ...


In [None]:
ROCStories_X.shape

(98161,)

In [None]:
ROCStories_Y.shape

(98161,)

We decided to combine a held-out portion of the ROCStories data with the Story Cloze 2018 validation set as the final test dataset. We include the Story Cloze data because it provides the correct endings, so it can be used to evaluate both classification and story ending generation. It may also differ somewhat from the ROCStories data that makes up the training set, which increases the variety of the test data and gives a better check of generalization.

In [None]:
ROCX_train_val, ROCX_test, ROCy_train_val, ROCy_test = train_test_split(ROCStories_X, ROCStories_Y, test_size=0.02, random_state=seed_value)

In [None]:
print(ROCX_train_val.shape)
print(ROCX_test.shape)
print(ROCy_train_val.shape)
print(ROCy_test.shape)

(96197,)
(1964,)
(96197,)
(1964,)


In [None]:
Clozey_test = np.where(ClozeTestWinter2018ValData["AnswerRightEnding"]==1, ClozeTestWinter2018ValData["RandomFifthSentenceQuiz1"], ClozeTestWinter2018ValData["RandomFifthSentenceQuiz2"])
Clozey_test = pd.DataFrame(Clozey_test, columns=["sentence5"])
Clozey_test

Unnamed: 0,sentence5
0,He is happy now.
1,The brownies are so delicious Laverne eats two...
2,Sarah decided that she preferred her home over...
3,Gina liked the cookies so much she ate them al...
4,I was very proud of my performance.
...,...
1566,I have very fond memories of checkers.
1567,She loved her new phone.
1568,They were on sale.
1569,She was offered the new job at a higher salary.
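
As a small illustration of the selection step above, the sketch below applies the same `np.where` pattern to a hypothetical two-row frame (the rows and the names `cloze` and `endings` are made up for the example). It assumes, as the code above does, that any label other than 1 means the second candidate is the right ending.

```python
import numpy as np
import pandas as pd

# Hypothetical two-row frame in the Story Cloze column layout.
cloze = pd.DataFrame({
    "RandomFifthSentenceQuiz1": ["He is happy now.", "Sarah moved to Europe."],
    "RandomFifthSentenceQuiz2": ["He joined a gang.", "Sarah preferred her home."],
    "AnswerRightEnding": [1, 2],
})

# Take Quiz1 where the label is 1, otherwise Quiz2.
right_ending = np.where(
    cloze["AnswerRightEnding"] == 1,
    cloze["RandomFifthSentenceQuiz1"],
    cloze["RandomFifthSentenceQuiz2"],
)
endings = pd.DataFrame(right_ending, columns=["sentence5"])
print(endings)
```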


In [None]:
ClozeX_test = ClozeTestWinter2018ValData.drop(columns=["InputStoryid", "RandomFifthSentenceQuiz1", "RandomFifthSentenceQuiz2", "AnswerRightEnding"])
ClozeX_test = ClozeX_test.rename(columns={"InputSentence1": "sentence1", "InputSentence2": "sentence2", "InputSentence3": "sentence3", "InputSentence4": "sentence4"})
ClozeX_test = ClozeX_test.apply(combine_sentences, axis=1)
ClozeX_test

Unnamed: 0,0
0,Rick grew up in a troubled household. He never...
1,Laverne needs to prepare something for her fri...
2,Sarah had been dreaming of visiting Europe for...
3,Gina was worried the cookie dough in the tube ...
4,It was my final performance in marching band....
...,...
1566,When I was a kid I really wanted to play check...
1567,Ivy wanted a cell phone like all her friends. ...
1568,Dave walked into the grocery store. He was goi...
1569,Ramona was very unhappy in her job. She asked ...


In [None]:
X_test = pd.concat([ROCX_test, ClozeX_test])
X_test

Unnamed: 0,0
20391,Evan had been saving for years. He went to the...
32389,Serena was planning a surprise for her husband...
8700,Fred slapped another man's butt. He thought he...
45048,I used to lived in Phoenix Arizona. On my way ...
23146,Tom thought he was really strong. He challenge...
...,...
1566,When I was a kid I really wanted to play check...
1567,Ivy wanted a cell phone like all her friends. ...
1568,Dave walked into the grocery store. He was goi...
1569,Ramona was very unhappy in her job. She asked ...


In [None]:
y_test = pd.concat([ROCy_test, Clozey_test])
y_test

Unnamed: 0,sentence5
20391,Evan knew he looked cool in the new car.
32389,Together they were able to surprise him perfec...
8700,Fred apologized to the man.
45048,Talking with him not only made my day it made ...
23146,Tom's friends thought it was annoying.
...,...
1566,I have very fond memories of checkers.
1567,She loved her new phone.
1568,They were on sale.
1569,She was offered the new job at a higher salary.


In [None]:
ROCX_train, ROCX_val, ROCy_train, ROCy_val = train_test_split(ROCX_train_val, ROCy_train_val, test_size=0.2, random_state=seed_value)

In [None]:
ROCX_train = ROCX_train.to_list()
ROCX_val = ROCX_val.to_list()
ROCy_train = ROCy_train.to_list()
ROCy_val = ROCy_val.to_list()
X_test = X_test.to_list()
y_test = y_test["sentence5"].to_list()

In [None]:
print(len(ROCX_train))
print(len(ROCX_val))
print(len(ROCy_train))
print(len(ROCy_val))
print(len(X_test))
print(len(y_test))

76957
19240
76957
19240
3535
3535
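
The sizes printed above can be checked by hand. The sketch below is plain arithmetic, under the assumption that for a fractional test_size the held-out count is the ceiling of test_size times the number of samples and the remainder goes to the other split; the variable names are illustrative.

```python
import math

n_total = 98161   # rows in the combined ROCStories data
n_cloze = 1571    # rows in the Story Cloze Winter 2018 validation set

# First split: 2% of ROCStories is held out for testing.
n_roc_test = math.ceil(0.02 * n_total)
n_train_val = n_total - n_roc_test

# Second split: 20% of the remainder becomes the validation set.
n_val = math.ceil(0.2 * n_train_val)
n_train = n_train_val - n_val

# The final test set is the ROCStories hold-out plus the Story Cloze rows.
n_test = n_roc_test + n_cloze
print(n_train, n_val, n_test)
```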


# BART Model and Baseline

## Load BART Model

In [None]:
model_name = "facebook/bart-large-cnn"
base_model = BartForConditionalGeneration.from_pretrained(model_name)
base_tokenizer = BartTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
base_model.to("cuda")

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

## Prepare test data for evaluation

In [None]:
# Tokenize the input text
X_test_inputs = base_tokenizer(X_test, padding=True, return_tensors="pt")
X_test_inputs

{'input_ids': tensor([[    0,   717,  9965,  ...,     1,     1,     1],
        [    0,   104,  2816,  ...,     1,     1,     1],
        [    0, 33153, 18361,  ...,     1,     1,     1],
        ...,
        [    0, 33857,  3203,  ...,     1,     1,     1],
        [    0, 32361,  4488,  ...,     1,     1,     1],
        [    0,   100,   770,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

Move the inputs to CUDA so that model.generate(**inputs) runs on the GPU. After doing so, GPU RAM usage rises while CPU RAM stays relatively stable. Without this step only the CPU is used, which is slow: generating IDs for the roughly 3,500 test examples takes about 3 to 4 hours.
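
A device-agnostic variant of this step might look like the sketch below. It is an illustration under assumptions, not the notebook's code: the small `inputs` dict is a hypothetical stand-in for the tokenizer output, and the CPU fallback is added so the same lines run on a machine without a GPU. The model has to be moved to the same device for generation to work.

```python
import torch

# Fall back to CPU when no GPU is visible, so the same code runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the tokenizer output: a dict of equally shaped tensors.
inputs = {
    "input_ids": torch.tensor([[0, 717, 2], [0, 104, 2]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
}

# .to() returns a tensor on the target device; rebuild the dict with the moved tensors.
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Collect the device types to confirm everything ended up in one place.
devices = {tensor.device.type for tensor in inputs.values()}
print(devices)
```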

In [None]:
X_test_inputs["input_ids"] = X_test_inputs["input_ids"].to("cuda")
X_test_inputs["attention_mask"] = X_test_inputs["attention_mask"].to("cuda")

In [None]:
X_test_inputs['input_ids'].device

device(type='cuda', index=0)

In [None]:
X_test_inputs_dataloader = DataLoader(TensorDataset(X_test_inputs["input_ids"], X_test_inputs["attention_mask"]), batch_size=32)
test_original_story_with_endings_dataloader = DataLoader(concat_story_body_with_endings(X_test, y_test), batch_size=32)
original_story_bodies = DataLoader(X_test, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)

## Calculate perplexity, ROUGE, and LLM-as-judge score with basic BART as baseline
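
The evaluation loop in this section accumulates each batch's metric multiplied by len(reference) / len(X_test). The sketch below, on hypothetical scores, shows why that works: a sum of per-batch means weighted by batch share equals the mean over all items. The equality holds only when every item in a batch is actually scored, so it becomes approximate if items are dropped after the weight is computed.

```python
# Hypothetical per-item scores split into batches of unequal size.
batches = [[0.2, 0.4, 0.6], [0.8, 1.0]]
n_items = sum(len(batch) for batch in batches)

weighted = 0.0
for batch in batches:
    batch_mean = sum(batch) / len(batch)
    # Weight each batch mean by that batch's share of all items.
    weighted += (len(batch) / n_items) * batch_mean

# Direct mean over every item, for comparison.
overall = sum(score for batch in batches for score in batch) / n_items
print(weighted, overall)
```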

In [None]:
judge_model = VLLM(model="prometheus-eval/prometheus-7b-v2.0", gpu_memory_utilization=0.5, enforce_eager=True, max_num_seqs=8, max_model_len=2048)
judge = PrometheusEval(model=judge_model, absolute_grade_template=ABSOLUTE_PROMPT)
rubric_data = {
  "criteria":"Does the model successfully generate a response that is a suitable ending to the provided story body, in terms of language fluency, semantics coherence, and story flow",
  "score1_description":"The ending is riddled with language errors, is incoherent or disconnected from the story body, and disrupts the narrative flow. It leaves the reader confused or unsatisfied.",
  "score2_description":"The ending has noticeable language issues, inconsistencies, or rushed transitions. While it ties some loose ends, it feels incomplete or awkward.",
  "score3_description":"The ending is adequately written with minor language errors and mostly logical progression, but it lacks emotional impact or creativity in its resolution.",
  "score4_description":"The ending is well-crafted, fluent, and coherent, with a fitting and satisfying resolution. It enhances the story’s themes and characters, though it might lack exceptional originality or depth.",
  "score5_description":"The ending is flawless in language fluency, beautifully integrates with the story, and delivers a compelling, imaginative, and emotionally resonant conclusion that elevates the entire narrative."
}
score_rubric = SCORE_RUBRIC_TEMPLATE.format(**rubric_data)

INFO 12-08 09:43:34 config.py:350] This model supports multiple tasks: {'embedding', 'generate'}. Defaulting to 'generate'.
INFO 12-08 09:43:34 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post1) with config: model='prometheus-eval/prometheus-7b-v2.0', speculative_config=None, tokenizer='prometheus-eval/prometheus-7b-v2.0', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=prometheus-

Loading safetensors checkpoint shards:   0% Completed | 0/8 [00:00<?, ?it/s]


INFO 12-08 09:43:40 model_runner.py:1077] Loading model weights took 13.4966 GB
INFO 12-08 09:43:40 worker.py:232] Memory profiling results: total_gpu_memory=39.56GiB initial_memory_usage=15.52GiB peak_torch_memory=15.23GiB memory_usage_post_profile=15.54GiB non_torch_memory=0.52GiB kv_cache_size=4.03GiB gpu_memory_utilization=0.50
INFO 12-08 09:43:41 gpu_executor.py:113] # GPU blocks: 2064, # CPU blocks: 2048
INFO 12-08 09:43:41 gpu_executor.py:117] Maximum concurrency for 2048 tokens per request: 16.12x


In [None]:
%%time
final_perplexity_baseline = 0
final_rouge1_baseline = 0
final_rouge2_baseline = 0
final_rougeL_baseline = 0
LLM_as_judge_score = 0

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for input_and_attention_mask, original_stories, reference in zip(X_test_inputs_dataloader, original_story_bodies, original_story_endings_reference):
    generated_ids = base_model.generate(inputs=input_and_attention_mask[0], attention_mask=input_and_attention_mask[1], min_new_tokens=70)
    full_story_with_ending_test = base_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Calculate perplexity
    weight_factor = len(reference) / len(X_test)
    final_perplexity_baseline += weight_factor * perplexity.compute(predictions=full_story_with_ending_test, model_id='facebook/bart-large-cnn')["mean_perplexity"]
    # Calculate rouge
    # Drop items whose extracted ending is empty: in rare cases BART does not reproduce the prompt exactly (e.g., it omits punctuation), so fewer than five sentences are found
    empty_indices = [i for i, s in enumerate(vectorized_remove_first_four_sentences(full_story_with_ending_test)) if s == ""]
    empty_indices.sort(reverse=True)
    for index in empty_indices:
      full_story_with_ending_test.pop(index)
      original_stories.pop(index)
      reference.pop(index)
    generated_endings = vectorized_remove_first_four_sentences(full_story_with_ending_test)
    rouge_results = rouge.compute(predictions=generated_endings, references=reference)
    final_rouge1_baseline += weight_factor * rouge_results["rouge1"]
    final_rouge2_baseline += weight_factor * rouge_results["rouge2"]
    final_rougeL_baseline += weight_factor * rouge_results["rougeL"]
    # LLM as judge
    feedback, score = judge.absolute_grade(
      instructions=original_stories,
      responses=generated_endings,
      rubric=score_rubric,
      reference_answers=reference
    )
    LLM_as_judge_score += sum(score)

  0%|          | 0/2 [00:00<?, ?it/s]

Processed prompts: 100%|██████████| 31/31 [00:06<00:00,  4.68it/s, est. speed input: 2643.43 toks/s, output: 884.87 toks/s]


Processed 31/31 instances.


Finalizing: 100%|██████████| 31/31 [00:00<00:00, 10011.04it/s]


  0%|          | 0/2 [00:00<?, ?it/s]

Processed prompts: 100%|██████████| 32/32 [00:05<00:00,  5.37it/s, est. speed input: 3046.57 toks/s, output: 951.90 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11525.78it/s]


  0%|          | 0/2 [00:00<?, ?it/s]

Processed prompts: 100%|██████████| 31/31 [00:05<00:00,  5.52it/s, est. speed input: 3116.53 toks/s, output: 981.03 toks/s]


Processed 31/31 instances.


Finalizing: 100%|██████████| 31/31 [00:00<00:00, 11076.19it/s]


  0%|          | 0/2 [00:00<?, ?it/s]

TypeError: 'int' object is not iterable

In [None]:
# The error above came from the line `LLM_as_judge_score = sum(LLM_as_judge_score) / len(X_test)`.
# LLM_as_judge_score is already a single integer total, so sum() cannot iterate over it. The correct line is:
LLM_as_judge_score = LLM_as_judge_score / len(X_test)
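
A minimal sketch of why the first version failed, assuming (as the fix implies) that `LLM_as_judge_score` was accumulated as a plain integer total across batches. The batch scores and `n_test` below are made-up values for illustration, not results from this notebook.

```python
# Hypothetical per-batch judge scores; not taken from the actual run.
batch_scores = [[2, 1, 3], [1, 2]]
n_test = sum(len(b) for b in batch_scores)  # stands in for len(X_test)

running_total = 0
for scores in batch_scores:
    running_total += sum(scores)  # sum() works here because scores is a list

# Calling sum() on the integer total reproduces the TypeError.
try:
    sum(running_total)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Dividing the total directly gives the mean score per test example.
mean_score = running_total / n_test
print(error_message)
print(mean_score)
```

The design point is simply that the total is already reduced to one number inside the loop, so only the division by the test-set size remains.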

In [None]:
print(final_perplexity_baseline)
print(final_rouge1_baseline)
print(final_rouge2_baseline)
print(final_rougeL_baseline)
print(LLM_as_judge_score)

502094.9014365276
0.13686567364345759
0.020479408512765515
0.11643881049054727
1.85007072135785


## Calculate the perplexity score of the original stories with the baseline model

In [None]:
%%time
# Load the metric once instead of reloading it on every batch.
perplexity = evaluate.load("perplexity", module_type="metric")
final_perplexity_original_story = 0
for reference in test_original_story_with_endings_dataloader:
  # Weight each batch's mean perplexity by its share of the test set.
  weight_factor = len(reference) / len(X_test)
  batch_result = perplexity.compute(predictions=reference, model_id='facebook/bart-large-cnn')
  final_perplexity_original_story += weight_factor * batch_result["mean_perplexity"]

Downloading builder script:   0%|          | 0.00/8.46k [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

CPU times: user 1min 26s, sys: 8.68 s, total: 1min 35s
Wall time: 2min 31s


In [None]:
final_perplexity_original_story

1384072.037551936
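
The perplexity loop above weights each batch's mean by that batch's share of the test set, so a smaller final batch does not distort the overall mean. The sketch below checks this with made-up numbers; `weighted_mean_of_batches` is a hypothetical helper written for illustration and is not part of the notebook, and the values are not real perplexities.

```python
import math

def weighted_mean_of_batches(batches):
    """Combine per-batch means into one mean, weighting each by batch size."""
    total = sum(len(b) for b in batches)
    result = 0.0
    for b in batches:
        batch_mean = sum(b) / len(b)
        result += (len(b) / total) * batch_mean
    return result

# Two batches of unequal size, mimicking a short last batch.
batches = [[10.0, 20.0, 30.0, 40.0], [50.0]]
flat = [x for b in batches for x in b]
overall_mean = sum(flat) / len(flat)
combined = weighted_mean_of_batches(batches)

# The size-weighted combination matches the plain per-example mean.
assert math.isclose(combined, overall_mean)
```

An unweighted average of the two batch means would give 37.5 here instead of 30, which is why the weight factor matters when the last batch is short.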

## LLM judge score with original story

In [None]:
%%time
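LLM_judge_score_original_story = 0
# Grade each batch of original endings against themselves to get a reference-level judge score.
for original_stories, reference in zip(original_story_bodies, original_story_endings_reference):
  feedback, score = judge.absolute_grade(
      instructions=original_stories,
      responses=reference,
      rubric=score_rubric,
      reference_answers=reference
  )
  # score holds the scores for this batch; add them to the running total.
  LLM_judge_score_original_story += sum(score)
# Average the total over the whole test set.
LLM_judge_score_original_story = LLM_judge_score_original_story / len(X_test)
LLM_judge_score_original_story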

Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s, est. speed input: 847.17 toks/s, output: 270.33 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11756.98it/s]

CPU times: user 35min 18s, sys: 6.4 s, total: 35min 24s
Wall time: 35min 10s





3.551909476661952
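The averaging done in the cell above can be sketched independently of the judge. This is a minimal illustration, not the notebook's code: `mean_judge_score` and `stub_grade_batch` are hypothetical names, and the stub stands in for `judge.absolute_grade` so the example runs without a GPU. It also assumes a judge may occasionally return `None` for an item it could not score, which is an assumption rather than something observed in this run.

```python
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

Batch = Tuple[Sequence[str], Sequence[str], Sequence[str]]


def mean_judge_score(
    batches: Iterable[Batch],
    grade_batch: Callable[[Sequence[str], Sequence[str], Sequence[str]], List[Optional[int]]],
) -> float:
    """Average per-item scores over all batches, skipping items scored as None."""
    total = 0.0
    count = 0
    for instructions, responses, references in batches:
        for score in grade_batch(instructions, responses, references):
            if score is None:  # assumption: an unscorable item is reported as None
                continue
            total += score
            count += 1
    # Divide by the number of scores actually collected, not by the dataset size
    return total / count if count else float("nan")


def stub_grade_batch(instructions, responses, references):
    # Hypothetical stand-in for the real judge: 4 if response equals reference, else 2
    return [4 if resp == ref else 2 for resp, ref in zip(responses, references)]


example_batches = [
    (["story A", "story B"], ["end A", "end B"], ["end A", "other"]),
    (["story C"], ["end C"], ["end C"]),
]
example_mean = mean_judge_score(example_batches, stub_grade_batch)
print(example_mean)
```

Dividing by the count of collected scores gives the same result as dividing by `len(X_test)` whenever every test item receives a score, and avoids deflating the mean if some items are skipped.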

## Check generated outputs by baseline

In [None]:
X_test[0]

'Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town.'

In [None]:
y_test[0]

'Evan knew he looked cool in the new car.'

As we can see, there are slight differences between the two returned candidates, and also across different temperatures. Because we set do_sample=True, the outputs also differ from run to run even with the same temperature.
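To make the effect of temperature concrete, here is a simplified sketch of temperature-scaled sampling over a toy distribution. It is an illustration of the general idea only, not the internals of `generate`, which applies additional decoding logic. The logits are made-up values and the function names are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Turn logits into probabilities after dividing them by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature, rng):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
cool = softmax_with_temperature(toy_logits, 0.5)  # sharper: mass concentrates on the top token
warm = softmax_with_temperature(toy_logits, 1.2)  # flatter: lower-ranked tokens gain probability

rng = random.Random(42)
draws = [sample_token(toy_logits, 1.2, rng) for _ in range(5)]
print(cool, warm, draws)
```

A lower temperature pushes probability toward the most likely token, which is consistent with the two candidates at temperature 0.5 differing by only one word, while sampling itself explains why repeated calls at the same temperature can still disagree.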

In [None]:
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], temperature=0.5, do_sample=True, num_return_sequences=2), skip_special_tokens=True)

['Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. It was a big hit with his friends and family. He was so excited to have a new car to drive.',
 'Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. It was a big hit with his friends and family. He was so happy to have a new car to drive.']

In [None]:
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], temperature=1.2, do_sample=True, num_return_sequences=2), skip_special_tokens=True)

['Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. He said it was the first car he had ever owned that was so fun to drive.',
 'Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. He said it was the first car he had ever owned. He had saved for years to get it.']

In [None]:
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], temperature=1.2, do_sample=True, num_return_sequences=2), skip_special_tokens=True)

['Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. He said he was going to drive it all over the country. He was so excited to have a car of his own.',
 'Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. He said he was going to drive it all over the country. He was so excited to have a new car to drive.']

Now let's see an output with the default config.

In [None]:
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1]), skip_special_tokens=True)

['Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. It was a big hit with his friends and family. He was so excited to have a new car to drive.']
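The decoded output repeats the prompt before the newly generated sentences, whereas the judge calls in this notebook pass only the new sentences as the response. A minimal sketch of separating the two is shown here, assuming the model copies the prompt verbatim; `extract_ending` is a hypothetical helper, and it falls back to returning the full text when the prompt is not reproduced exactly.

```python
def extract_ending(prompt: str, generated: str) -> str:
    """Return the part of `generated` that follows `prompt`, if it starts with it."""
    prompt = prompt.strip()
    generated = generated.strip()
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()
    # Fallback: the prompt was paraphrased or truncated, so keep everything
    return generated


story_body = (
    "Evan had been saving for years. He went to the dealership and bought a "
    "really fancy BMW. Evan was so proud of his new car. He showed it off around town."
)
decoded = (
    story_body
    + " It was a big hit with his friends and family."
    + " He was so excited to have a new car to drive."
)
ending = extract_ending(story_body, decoded)
print(ending)
```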

In [None]:
# Example of calling LLM as judge for a small subset of data

instructions = X_test[0:1]
responses = ["It was a big hit with his friends and family. He was so excited to have a new car to drive."]
reference_answers = y_test[0:1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.91s/it, est. speed input: 188.89 toks/s, output: 41.97 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 1885.93it/s]

Feedback: ['The response is adequate in terms of language fluency and semantics coherence, as it does not introduce any language errors and maintains the logical flow of the story. However, it falls short in delivering a compelling or imaginative resolution to the narrative. It merely states the fact that the new car was well-received without adding any emotional depth or creativity to the story. Therefore, while the response is well-crafted and coherent, it lacks the originality and depth that would elevate the entire narrative. So the overall score is 3. [RESULT] 3']
Score: [3]





Now try with other test data entries

In [None]:
X_test[-2:-1]

['Ramona was very unhappy in her job. She asked for a raise, but was denied. The refusal prompted her to aggressively comb the want ads. She found an interesting new possibility and set up an interview.']

In [None]:
y_test[-2]

'She was offered the new job at a higher salary.'

In [None]:
# Slice the attention mask with the same indices as input_ids so the two stay aligned
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][-2:-1], attention_mask=X_test_inputs["attention_mask"][-2:-1]), skip_special_tokens=True)

['Ramona was very unhappy in her job. She asked for a raise, but was denied. The refusal prompted her to aggressively comb the want ads. She found an interesting new possibility and took it up. She was able to get a raise and a new job in the process.']

In [None]:
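# This random sample produces a spurious "CNN.com" sentence
# Slice the attention mask with the same indices as input_ids so the two stay aligned
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][-2:-1], attention_mask=X_test_inputs["attention_mask"][-2:-1], do_sample=True), skip_special_tokens=True)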

["Ramona was very unhappy in her job. She asked for a raise, but was denied. The refusal prompted her to aggressively comb the want ads. She found an interesting new possibility and took it on. Ramona's story will be featured on CNN.com this week."]

In [None]:
instructions = X_test[-2:-1]
responses = ["She was able to get a raise and a new job in the process."]
reference_answers = y_test[-2:-1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.50s/it, est. speed input: 121.78 toks/s, output: 42.22 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 4429.04it/s]

Feedback: ["This response, while satisfactorily wrapping up the story, does not exhibit the same level of detail and creativity seen in the reference response. It simply states that Ramona received a raise and a new job, without specifying what made the new job interesting or how the interview process unfolded. This lack of elaboration results in a somewhat flat ending that doesn't fully capture the dramatic and emotional journey that Ramona embarked on. Additionally, the language used is straightforward and unadorned, which doesn't necessarily enhance the story's themes and characters. On the other hand, the response maintains coherence and does not disrupt the narrative flow, thereby meeting some criteria outlined in the score rubric. Therefore, while it is not the most captivating or emotionally resonant ending, it is still adequate and coherent. So the overall score is 3. \n[RESULT] 3"]
Score: [3]





In [None]:
print(X_test[1500:1501])
print(y_test[1500:1501])
# Slice the attention mask with the same indices as input_ids so the two stay aligned
base_tokenizer.batch_decode(base_model.generate(X_test_inputs["input_ids"][1500:1501], attention_mask=X_test_inputs["attention_mask"][1500:1501]), skip_special_tokens=True)

['I went to the doctor yesterday today for a check up. The doctor told me I needed a shot. When I got the shot I screamed. The pain was unbearable.']
['The only upside was that the pain was quick.']


["I went to the doctor yesterday today for a check up. The doctor told me I needed a shot. When I got the shot I screamed. The pain was unbearable. I was in so much pain that I had to go home and take a nap. I'm still in a lot of pain."]

In [None]:
instructions = X_test[1500:1501]
responses = ["I was in so much pain that I had to go home and take a nap. I'm still in a lot of pain."]
reference_answers = y_test[1500:1501]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.28s/it, est. speed input: 128.58 toks/s, output: 41.38 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5745.62it/s]

Feedback: ["The response effectively continues the storyline, however, it falls short in providing an emotionally resonant conclusion. It relies heavily on a singular aspect of the narrative – the pain – without incorporating any creative or imaginative elements. Additionally, the ending is somewhat abrupt and does not convey a strong sense of satisfaction or closure to the story. Despite this, the language used is clear and free from errors, and it maintains coherence with the story's semantics. However, it fails to enhance the story’s themes or characters. In terms of story flow, the transition from the doctor's advice to the protagonist's reaction at home is somewhat disjointed, but it does not disrupt the overall narrative. Therefore, the response is deemed adequate, but with room for improvement in emotional resonance and creativity. \n[RESULT] 3"]
Score: [3]
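The feedback strings above end with a `[RESULT] <score>` marker, and the reported score matches that number. When checking judge outputs by hand, a small helper like the one below can pull the score back out of the feedback text. This is only an illustrative sketch: `extract_result_score` is a hypothetical helper, not prometheus-eval's own parser, and the 1 to 5 range is an assumption about the rubric.

```python
import re

def extract_result_score(feedback_text, min_score=1, max_score=5):
    """Return the integer after the last '[RESULT]' marker, or None.

    None is returned when no marker is found or the number falls outside
    the assumed rubric range [min_score, max_score].
    """
    matches = re.findall(r"\[RESULT\]\s*(\d+)", feedback_text)
    if not matches:
        return None
    score = int(matches[-1])
    return score if min_score <= score <= max_score else None

# Shortened version of the feedback format shown above
example_feedback = "Therefore, the response is deemed adequate. \n[RESULT] 3"
print(extract_result_score(example_feedback))
```

Returning `None` instead of raising makes it easy to count how many judgments could not be parsed before averaging.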





## [Archived] Calculate perplexity and ROUGE on a 32-example subset of the test set

Generation is very memory- and time-intensive: running it on the full test set crashes a 32 GB RAM instance, so we have to batch. For now, let's use only a small subset of the test data.

In [None]:
%%time
# Generate the story ending
with torch.no_grad():
  generated_ids = base_model.generate(X_test_inputs["input_ids"][:32], min_new_tokens=70) #Set min_new_tokens to ensure an ending is generated
generated_ids

CPU times: user 1min 54s, sys: 14.2 s, total: 2min 8s
Wall time: 2min 9s


tensor([[    2,     0,   717,  ...,     1,     1,     1],
        [    2,     0,   104,  ...,     1,     1,     1],
        [    2,     0, 33153,  ...,     1,     1,     1],
        ...,
        [    2,     0, 18031,  ...,     1,     1,     1],
        [    2,     0, 23675,  ...,     1,     1,     1],
        [    2,     0, 24021,  ...,     1,     1,     1]])

In [None]:
# Decode the generated ending
# Note that the generated text currently includes both the input (the story beginning) and the generated ending
full_story_with_ending_test = base_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
full_story_with_ending_test[0]

"Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. It was a big hit with his friends and family. He was so excited to have a new car to drive around in. It's been a big success. He's so proud."

In [None]:
perplexity = evaluate.load("perplexity", module_type="metric")
perplexity_results = perplexity.compute(predictions=full_story_with_ending_test, model_id='facebook/bart-large-cnn')

  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
perplexity_results["mean_perplexity"]

208850.79104614258

Let's also compute the perplexity of the original stories (story body plus reference ending) on the same 32-example subset.

In [None]:
original_story_with_endings_test = [f"{p1} {p2}" for p1, p2 in zip(X_test, y_test)]
original_story_with_endings_test[0]

'Evan had been saving for years. He went to the dealership and bought a really fancy BMW. Evan was so proud of his new car. He showed it off around town. Evan knew he looked cool in the new car.'

In [None]:
perplexity.compute(predictions=original_story_with_endings_test[:32], model_id='facebook/bart-large-cnn')["mean_perplexity"]

  0%|          | 0/2 [00:00<?, ?it/s]

580600.5565185547
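To help interpret the two numbers above, recall that perplexity is the exponential of the mean negative log-likelihood per token. The sketch below computes it from a list of per-token log-probabilities; `perplexity_from_log_probs` is a hypothetical helper for illustration and is not the code path used by the `evaluate` metric.

```python
import math

def perplexity_from_log_probs(token_log_probs):
    """Perplexity = exp(-mean(log p(token))) over one sequence (natural log)."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# A model that gives every token probability 0.25 has perplexity 4
uniform_log_probs = [math.log(0.25)] * 8
print(perplexity_from_log_probs(uniform_log_probs))
```

Read in the other direction, a perplexity of about 200,000 corresponds to an average per-token log-probability of roughly -12.2 nats, so values of this size mean the scoring model finds the text very unlikely token by token.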

Now let's also look at ROUGE. We use ROUGE because our problem resembles summarization: it is abstractive and open-ended. When comparing against the reference, we therefore care more about how much of the reference's content is captured than about whether the generated text strictly matches the reference.

In [None]:
empty_indices = [i for i, s in enumerate(vectorized_remove_first_four_sentences(full_story_with_ending_test)) if s == ""]
empty_indices

[20]

In [None]:
X_test[20]

'Bella came to school with an unexpected present on her desk. To her surprise, she was given a bouquet of beautiful flowers. Bella took a look at the note and it said "Secret Admirer". She took a sniff of the flowers as they smelled wonderful.'

In [None]:
full_story_with_ending_test[20]

'Bella came to school with an unexpected present on her desk. To her surprise, she was given a bouquet of beautiful flowers. Bella took a look at the note and it said "Secret Admirer" She took a sniff of the flowers as they smelled wonderful. They were from her secret admirer, who she has never met.'

In [None]:
full_story_with_ending_test.pop(20)

'Bella came to school with an unexpected present on her desk. To her surprise, she was given a bouquet of beautiful flowers. Bella took a look at the note and it said "Secret Admirer" She took a sniff of the flowers as they smelled wonderful. They were from her secret admirer, who she has never met.'

In [None]:
empty_indices = [i for i, s in enumerate(vectorized_remove_first_four_sentences(full_story_with_ending_test)) if s == ""]
empty_indices

[]

In [None]:
ending_ref_test = y_test[:32]
ending_ref_test.pop(20)
len(ending_ref_test)

31

In [None]:
rouge = evaluate.load('rouge')
rouge_results = rouge.compute(predictions=vectorized_remove_first_four_sentences(full_story_with_ending_test), references=ending_ref_test)

In [None]:
rouge_results

{'rouge1': 0.13083555512743467,
 'rouge2': 0.0238949919279419,
 'rougeL': 0.11233191004109086,
 'rougeLsum': 0.11360697325944744}
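For intuition about what `rouge1` measures, here is a simplified unigram-overlap F1 in plain Python. It is a sketch under stated assumptions: it lowercases and splits on whitespace only, whereas the `rouge_score` package applies its own tokenization, so the numbers will not match the library exactly. `rouge1_f1` is a hypothetical name.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Simplified ROUGE-1 F1: clipped unigram overlap between two strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0  # an empty prediction or reference shares no unigrams
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("the pain was quick",
                  "the only upside was that the pain was quick")
print(round(score, 3))
```

A short prediction that reuses reference words gets high precision but low recall, which is why the F1 score is lower than the visible word overlap might suggest.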

# BART Model Fine-Tuning

## Fine-Tuning BART

In [None]:
model_name = "facebook/bart-large-cnn"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = BartTokenizer.from_pretrained(model_name)

In [None]:
train_encodings = tokenizer(ROCX_train, padding=True)
val_encodings = tokenizer(ROCX_val, padding=True)

In [None]:
train_labels = tokenizer(ROCy_train, padding=True, return_tensors="pt")['input_ids']
val_labels = tokenizer(ROCy_val, padding=True, return_tensors="pt")['input_ids']
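The labels built above are padded with the tokenizer's pad token id, so padding positions would contribute to the training loss. A common practice with Hugging Face seq2seq models is to replace pad ids in the labels with -100, the default `ignore_index` of PyTorch's cross-entropy loss. The helper below is a minimal sketch on plain Python lists; `mask_padding_in_labels` is a hypothetical name, and `pad_token_id=1` is taken from the `padding_idx=1` shown in the model printout later in this notebook.

```python
def mask_padding_in_labels(label_ids, pad_token_id=1, ignore_index=-100):
    """Replace pad token ids in each label row with ignore_index."""
    return [
        [ignore_index if token == pad_token_id else token for token in row]
        for row in label_ids
    ]

# Two toy label rows: <s>=0, </s>=2, <pad>=1 (BART's special token ids)
padded_labels = [
    [0, 100, 200, 2, 1, 1],
    [0, 300, 2, 1, 1, 1],
]
print(mask_padding_in_labels(padded_labels))
```

With real tensors, the same effect is usually achieved by boolean indexing on the label tensor before building the dataset. This was not done in the run reported here, so the logged losses include padding positions.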

In [None]:
train_dataset = Dataset.from_dict({
    'input_ids': train_encodings['input_ids'],
    'attention_mask': train_encodings['attention_mask'],
    'labels': train_labels
})

val_dataset = Dataset.from_dict({
    'input_ids': val_encodings['input_ids'],
    'attention_mask': val_encodings['attention_mask'],
    'labels': val_labels
})

In [None]:
training_args = TrainingArguments(
    report_to="none",
    output_dir='./fine_tuned_bart_checkpoints',
    save_strategy="steps",
    eval_strategy="steps",     # Evaluation frequency
    save_steps=2000,
    eval_steps=2000,
    save_total_limit=2,
    load_best_model_at_end=True,
    learning_rate=5e-5,              # Learning rate
    per_device_train_batch_size=8,   # Batch size per GPU for training
    per_device_eval_batch_size=16,   # Batch size per GPU for evaluation
    num_train_epochs=3,              # Number of training epochs
    weight_decay=0.01,               # Strength of weight decay
)

torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
%%time
trainer.train()

Step,Training Loss,Validation Loss
2000,1.2457,1.206727
4000,1.2116,1.165745
6000,1.1912,1.168289
8000,1.1678,1.172269
10000,0.9632,1.179446
12000,0.8894,1.126464
14000,0.8875,1.130638
16000,0.8899,1.109606
18000,0.8828,1.143582
20000,0.6268,1.249945


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


CPU times: user 1h 17min 55s, sys: 1min 40s, total: 1h 19min 35s
Wall time: 1h 20min 35s


TrainOutput(global_step=28860, training_loss=0.9068856665597388, metrics={'train_runtime': 4834.8855, 'train_samples_per_second': 47.751, 'train_steps_per_second': 5.969, 'total_flos': 4.250779274774938e+16, 'train_loss': 0.9068856665597388, 'epoch': 3.0})
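In the logged table above, training loss keeps falling while validation loss is lowest partway through training and rises again by step 20000. With `load_best_model_at_end=True` and no explicit metric configured, the `Trainer` should, to my understanding, select the checkpoint with the lowest evaluation loss. The sketch below simply finds that step among the rows shown above; the values are copied from the table, and the variable names are illustrative.

```python
# (step, training_loss, validation_loss) rows copied from the log above
eval_log = [
    (2000, 1.2457, 1.206727),
    (4000, 1.2116, 1.165745),
    (6000, 1.1912, 1.168289),
    (8000, 1.1678, 1.172269),
    (10000, 0.9632, 1.179446),
    (12000, 0.8894, 1.126464),
    (14000, 0.8875, 1.130638),
    (16000, 0.8899, 1.109606),
    (18000, 0.8828, 1.143582),
    (20000, 0.6268, 1.249945),
]

# Pick the row with the smallest validation loss
best_step, best_train_loss, best_val_loss = min(eval_log, key=lambda row: row[2])
print(best_step, best_val_loss)
```

Only the rows visible in the table are considered here; evaluations after step 20000 are not shown in the log above and are therefore not included.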

In [None]:
model.save_pretrained("./drive/MyDrive/fine_tuned_bart")

## Evaluate the fine-tuned BART model

In [None]:
fine_tuned_model_path = "./drive/MyDrive/fine_tuned_bart"

# Loading the saved model from disk doesn't seem to work, so we reuse the in-memory model instead
#fine_tuned_model = BartForConditionalGeneration.from_pretrained(fine_tuned_model_path, revision="safetensors")
fine_tuned_model=model

In [None]:
fine_tuned_model.to("cuda")

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [None]:
%%time
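# Note: the *_baseline variable names are reused from the baseline evaluation;
# in this cell they accumulate the metrics of the fine-tuned model.
final_perplexity_baseline = 0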
final_rouge1_baseline = 0
final_rouge2_baseline = 0
final_rougeL_baseline = 0

inputs_to_be_generated_data = X_test
inputs_to_be_generated_dataloader = DataLoader(inputs_to_be_generated_data, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)
model_to_use = fine_tuned_model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for inputs_to_be_generated, reference in zip(inputs_to_be_generated_dataloader, original_story_endings_reference):
    input_and_attention_mask = tokenizer(inputs_to_be_generated, padding=True, return_tensors="pt")
    input_and_attention_mask.to("cuda")
    generated_ids = model_to_use.generate(inputs=input_and_attention_mask["input_ids"], attention_mask=input_and_attention_mask["attention_mask"], max_length=20, early_stopping=True)
    ending = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    full_story_with_ending = concat_story_body_with_endings(inputs_to_be_generated, ending)
    # Calculate perplexity
    weight_factor = len(reference) / len(inputs_to_be_generated_data)
    final_perplexity_baseline += weight_factor * perplexity.compute(predictions=full_story_with_ending, model_id='facebook/bart-large-cnn')["mean_perplexity"]
    # Calculate rouge
    rouge_results = rouge.compute(predictions=ending, references=reference)
    final_rouge1_baseline += weight_factor * rouge_results["rouge1"]
    final_rouge2_baseline += weight_factor * rouge_results["rouge2"]
    final_rougeL_baseline += weight_factor * rouge_results["rougeL"]



  0%|          | 0/2 [00:00<?, ?it/s]

CPU times: user 3min 42s, sys: 8.91 s, total: 3min 51s
Wall time: 4min 41s


In [None]:
print(final_perplexity_baseline) #epoch 3
print(final_rouge1_baseline)
print(final_rouge2_baseline)
print(final_rougeL_baseline)

1100737.7915157855
0.22182885500760363
0.05309287064959933
0.19703626011511763
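The evaluation loop above combines per-batch metrics using `weight_factor = len(reference) / len(inputs_to_be_generated_data)`, so the smaller final batch counts proportionally less. For a metric that is a plain mean over examples, this size-weighted sum of batch means equals the mean over the whole test set. The sketch below checks that on toy numbers; the function name is hypothetical, and metrics that aggregate differently (for example with resampling) may only match approximately.

```python
def weighted_mean_of_batch_means(batch_means, batch_sizes):
    """Combine per-batch means into a dataset mean, weighting by batch size."""
    total = sum(batch_sizes)
    return sum(mean * size / total for mean, size in zip(batch_means, batch_sizes))

# Five per-example scores split into batches of size 2, 2 and 1
scores = [0.2, 0.4, 0.6, 0.8, 1.0]
batches = [scores[0:2], scores[2:4], scores[4:5]]
batch_means = [sum(batch) / len(batch) for batch in batches]
batch_sizes = [len(batch) for batch in batches]

combined = weighted_mean_of_batch_means(batch_means, batch_sizes)
overall = sum(scores) / len(scores)
print(combined, overall)
```

An unweighted average of the three batch means would instead overweight the single example in the last batch.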


In [None]:
%%time
LLM_as_judge_score_fine_tuned = 0

X_test_inputs_dataloader = DataLoader(TensorDataset(X_test_inputs["input_ids"], X_test_inputs["attention_mask"]), batch_size=32)
original_story_bodies = DataLoader(X_test, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)
model_to_use = fine_tuned_model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for input_and_attention_mask, original_stories, reference in zip(X_test_inputs_dataloader, original_story_bodies, original_story_endings_reference):
    generated_ids = model_to_use.generate(inputs=input_and_attention_mask[0], attention_mask=input_and_attention_mask[1], max_length=20, early_stopping=True)
    generated_endings = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    feedback, score = judge.absolute_grade(
      instructions=original_stories,
      responses=generated_endings,
      rubric=score_rubric,
      reference_answers=reference
    )
    LLM_as_judge_score_fine_tuned += sum(score)
LLM_as_judge_score_fine_tuned = LLM_as_judge_score_fine_tuned/len(X_test)
LLM_as_judge_score_fine_tuned

Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s, est. speed input: 879.15 toks/s, output: 286.34 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10209.00it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s, est. speed input: 881.72 toks/s, output: 288.92 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11010.48it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.75it/s, est. speed input: 965.92 toks/s, output: 291.75 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11439.34it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.62it/s, est. speed input: 887.96 toks/s, output: 285.37 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11186.68it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.76it/s, est. speed input: 962.56 toks/s, output: 293.88 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10647.98it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 945.84 toks/s, output: 294.28 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10376.32it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 906.33 toks/s, output: 285.39 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11053.09it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 894.61 toks/s, output: 274.73 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11377.28it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.75it/s, est. speed input: 959.95 toks/s, output: 282.83 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11941.08it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s, est. speed input: 921.62 toks/s, output: 280.34 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11926.22it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.75it/s, est. speed input: 962.78 toks/s, output: 298.81 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11538.66it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 923.94 toks/s, output: 287.19 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11307.31it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 922.74 toks/s, output: 288.06 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11523.80it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.66it/s, est. speed input: 915.32 toks/s, output: 291.58 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10972.67it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 907.85 toks/s, output: 260.82 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11845.18it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s, est. speed input: 844.44 toks/s, output: 290.13 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10475.94it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s, est. speed input: 836.91 toks/s, output: 270.52 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10799.62it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 903.58 toks/s, output: 288.46 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10847.63it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 929.48 toks/s, output: 288.31 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10647.98it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 932.21 toks/s, output: 285.81 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11758.01it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.74it/s, est. speed input: 949.85 toks/s, output: 286.93 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11385.00it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 921.56 toks/s, output: 290.47 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11353.22it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.71it/s, est. speed input: 939.40 toks/s, output: 286.40 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10266.00it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.62it/s, est. speed input: 885.84 toks/s, output: 279.62 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11240.07it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s, est. speed input: 871.71 toks/s, output: 272.21 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11474.54it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s, est. speed input: 854.81 toks/s, output: 271.45 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11487.31it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 947.05 toks/s, output: 289.34 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11851.46it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.71it/s, est. speed input: 934.59 toks/s, output: 285.58 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11793.14it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.71it/s, est. speed input: 942.03 toks/s, output: 282.23 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11782.79it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s, est. speed input: 879.39 toks/s, output: 266.71 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12161.81it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.43it/s, est. speed input: 779.81 toks/s, output: 263.40 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10691.23it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 904.80 toks/s, output: 272.43 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10691.23it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s, est. speed input: 876.12 toks/s, output: 278.27 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10932.45it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.62it/s, est. speed input: 888.37 toks/s, output: 277.40 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11132.86it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.54it/s, est. speed input: 845.02 toks/s, output: 279.17 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10027.47it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 853.11 toks/s, output: 254.38 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11091.46it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 907.90 toks/s, output: 290.18 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10546.73it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 924.17 toks/s, output: 289.02 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11501.09it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s, est. speed input: 875.76 toks/s, output: 266.67 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11876.62it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 992.40 toks/s, output: 300.59 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11496.17it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 891.45 toks/s, output: 284.29 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10803.97it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 946.73 toks/s, output: 292.50 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11449.09it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 934.89 toks/s, output: 291.21 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11409.19it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s, est. speed input: 877.79 toks/s, output: 281.32 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11180.15it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.74it/s, est. speed input: 953.17 toks/s, output: 279.95 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12169.53it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.79it/s, est. speed input: 984.99 toks/s, output: 300.41 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11729.24it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.74it/s, est. speed input: 953.66 toks/s, output: 278.96 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12213.83it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 849.23 toks/s, output: 277.49 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10908.46it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.62it/s, est. speed input: 885.79 toks/s, output: 290.99 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9283.28it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s, est. speed input: 879.46 toks/s, output: 282.63 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11370.53it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s, est. speed input: 881.85 toks/s, output: 277.88 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10903.15it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 902.77 toks/s, output: 273.02 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11832.65it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s, est. speed input: 914.08 toks/s, output: 292.12 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10681.87it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 848.48 toks/s, output: 280.03 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10615.97it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s, est. speed input: 886.25 toks/s, output: 289.28 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11109.82it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.58it/s, est. speed input: 866.16 toks/s, output: 265.60 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10293.56it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 846.10 toks/s, output: 266.47 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11034.92it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s, est. speed input: 878.17 toks/s, output: 280.60 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10212.88it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 927.65 toks/s, output: 292.96 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10065.83it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s, est. speed input: 916.97 toks/s, output: 283.41 toks/s]


Processed 32/32 instances.




Finalizing: 100%|██████████| 32/32 [00:00<00:00, 7138.48it/s]
Processed prompts: 100%|██████████| 15/15 [00:09<00:00,  1.57it/s, est. speed input: 860.84 toks/s, output: 256.70 toks/s]


Processed 15/15 instances.


Finalizing: 100%|██████████| 15/15 [00:00<00:00, 7637.12it/s]

CPU times: user 38min 5s, sys: 5.29 s, total: 38min 10s
Wall time: 37min 59s





2.039321074964639

In [None]:
LLM_as_judge_score_fine_tuned

2.039321074964639

In [None]:
instructions = X_test[0:1]
responses = tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True), skip_special_tokens=True)
print(responses)
reference_answers = y_test[0:1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

['Evan was so happy with his new car.   car!   ']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.76s/it, est. speed input: 143.85 toks/s, output: 43.69 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5349.88it/s]

Feedback: ["The generated response is an attempt to capture the excitement of Evan regarding his new car. However, the lack of proper punctuation and sentence structure makes it difficult to read and could be interpreted as an exclamation about the car itself, rather than Evan's happiness. Additionally, the brevity of the response lacks depth and fails to provide a satisfying conclusion to the story, which could potentially leave readers feeling unfulfilled. It also disrupts the narrative flow by not directly addressing Evan's actions and feelings towards the end. Thus, while the response attempts to reflect the initial instruction, it falls short of meeting the expectations of a well-crafted ending that is fluent, coherent, and emotionally resonant. So the overall score is 2. [RESULT] 2"]
Score: [2]





In [None]:
instructions = X_test[1500:1501]
responses = tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][1500:1501], attention_mask=X_test_inputs["attention_mask"][1500:1501], max_length=20, early_stopping=True), skip_special_tokens=True)
print(responses)
reference_answers = y_test[1500:1501]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

['I had to go home and rest for the rest of the day. yelp']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.53s/it, est. speed input: 152.92 toks/s, output: 43.33 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5637.51it/s]

Feedback: ['The response provides a continuation of the story but does so in a manner that is disjointed from the original narrative. It fails to capture the gravity of the situation presented in the instruction and does not demonstrate a coherent understanding of the emotional turmoil experienced by the protagonist. Furthermore, the use of "yelp" at the end appears as an abrupt and irrelevant interjection that does not contribute to the story\'s progression. The lack of context, coupled with the abrupt ending, disrupts the narrative flow and leaves the reader unsatisfied. This response fails to meet the standards of language fluency, semantics coherence, and story flow as required by the score rubric. \n[RESULT] 1']
Score: [1]





In [None]:
instructions = X_test[-2:-1]
responses = tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][-2:-1], attention_mask=X_test_inputs["attention_mask"][-2:-1], max_length=20, early_stopping=True), skip_special_tokens=True)
print(responses)
reference_answers = y_test[-2:-1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

['Ramona was offered the new job and was very happy with it, after all.']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.37s/it, est. speed input: 163.33 toks/s, output: 42.98 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5562.74it/s]

Feedback: ["The response successfully concludes the story with the intended outcome. Ramona finds happiness in a new job, which aligns with the story's arc. However, the language used is too simplistic and does not provide a strong emotional resonance with the reader. The response also lacks a creative or imaginative element that could have further enriched the narrative. Despite these shortcomings, the conclusion is logically connected to the story, ensuring that the reader is left satisfied. Therefore, while the response meets the basic requirements for story flow and coherence, it falls short in providing a compelling ending with rich language and emotional depth. Hence, the overall score is 3. [RESULT] 3"]
Score: [3]
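
Each feedback string printed above ends with a `[RESULT] N` marker, and the returned score matches that number. As a minimal sketch, the marker could be parsed with a regular expression, for example to double-check scores or to spot feedback where no score was produced. The helper name `extract_result_score` is hypothetical and not part of prometheus-eval, and the 1 to 5 range is an assumption about the rubric.

```python
import re
from typing import Optional

# Hypothetical helper, not part of prometheus-eval: pulls the integer that
# follows the final "[RESULT]" marker in a feedback string.
# Assumption: scores are on a 1-5 scale.
_RESULT_PATTERN = re.compile(r"\[RESULT\]\s*([1-5])\b")

def extract_result_score(feedback: str) -> Optional[int]:
    """Return the score after the last [RESULT] marker, or None if absent."""
    matches = _RESULT_PATTERN.findall(feedback)
    return int(matches[-1]) if matches else None

# Shortened versions of the feedback strings shown above, plus one without a marker.
examples = [
    "So the overall score is 2. [RESULT] 2",
    "... as required by the score rubric. \n[RESULT] 1",
    "Hence, the overall score is 3. [RESULT] 3",
    "No marker in this one.",
]
parsed = [extract_result_score(text) for text in examples]
print(parsed)
```

Returning `None` instead of raising makes it easy to count unparsable feedback separately before averaging.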





In [None]:
# Decoding settings to compare; each one is run 10 times on the first test story.
decoding_configs = [
  {},                                       # generation defaults
  {"do_sample": False},                     # explicitly no sampling
  {"temperature": 0.8, "do_sample": True},
  {"temperature": 0.3, "do_sample": True},
  {"temperature": 1.5, "do_sample": True},
  {"top_k": 5, "do_sample": True},
  {"top_k": 20, "do_sample": True},
  {"top_k": 80, "do_sample": True},
  {"top_p": 0.8, "do_sample": True},
  {"top_p": 0.5, "do_sample": True},
  {"top_p": 0.3, "do_sample": True},
]

for decoding_kwargs in decoding_configs:
  for _ in range(10):
    generated_ids = fine_tuned_model.generate(
      X_test_inputs["input_ids"][0:1],
      attention_mask=X_test_inputs["attention_mask"][0:1],
      max_length=20,
      early_stopping=True,
      **decoding_kwargs,
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
    print()
    print("-----------------------------------------------------------------------")
    print()


["Evan's friends and family loved his new car, and he was happy. "]

-----------------------------------------------------------------------

['Evan was so happy with his new car and was excited to drive it home.']

-----------------------------------------------------------------------

["Evan was so happy with his new car and couldn't wait to drive it."]

-----------------------------------------------------------------------

['Evan was very happy with his new car.   it was a new model']

-----------------------------------------------------------------------

['Evan was so happy with his new car.    new car and was']

-----------------------------------------------------------------------

['Evan was so happy with his new car!   car is his favorite!']

-----------------------------------------------------------------------

['Evan was so happy with his new car.   car was very expensive.']

-----------------------------------------------------------------------

['Evan was so happy 
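
The loops above vary `temperature`, `top_k`, and `top_p` with `do_sample=True`. As a rough illustration of what these settings do to a next-token distribution, here is a toy NumPy sketch on made-up logits. It is not the actual `generate` internals (which, for instance, chain the filters on logits), and the function name `filter_next_token_probs` is hypothetical.

```python
import numpy as np

def filter_next_token_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Toy illustration: temperature scaling, then top-k and top-p (nucleus) masks."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())   # numerically stable softmax
    probs /= probs.sum()

    order = np.argsort(-probs)              # token indices, most likely first
    keep = np.ones(len(probs), dtype=bool)

    if top_k > 0:
        keep[order[top_k:]] = False         # drop everything after the k best

    if top_p < 1.0:
        cumulative = np.cumsum(probs[order])
        # Keep the smallest prefix whose cumulative mass reaches top_p.
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1
        keep[order[cutoff:]] = False

    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum()        # renormalize the surviving tokens

toy_logits = [2.0, 1.0, 0.5, 0.0, -1.0]     # made-up scores for five tokens
print(filter_next_token_probs(toy_logits, temperature=0.3).round(3))
print(filter_next_token_probs(toy_logits, top_k=2).round(3))
print(filter_next_token_probs(toy_logits, top_p=0.5).round(3))
```

A lower temperature, a smaller `top_k`, or a smaller `top_p` each concentrate probability on the most likely tokens, so repeated samples tend to look more alike; larger values spread probability out and give more varied endings.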

## Further fine tune the BART Model

Previously, the best checkpoint was selected at around step 19,000, which is roughly 2 epochs. (I inferred this from the training loss, around 0.7, observed when continuing fine-tuning, back when I had not yet changed per_device_train_batch_size.)

In [None]:
training_args = TrainingArguments(
    report_to="none",
    output_dir='./fine_tuned_bart_checkpoints',
    save_strategy="steps",
    eval_strategy="steps",     # Evaluation frequency
    save_steps=2000,
    eval_steps=2000,
    save_total_limit=1,
    learning_rate=5e-5,              # Learning rate
    per_device_train_batch_size=32,   # Batch size per GPU for training
    per_device_eval_batch_size=64,   # Batch size per GPU for evaluation
    num_train_epochs=6,              # Number of training epochs
    weight_decay=0.01,               # Strength of weight decay
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch


In [None]:
%%time
trainer.train()

Step,Training Loss,Validation Loss
2000,0.7393,1.244433
4000,0.4776,1.419745
6000,0.3343,1.604613
8000,0.2139,1.704297
10000,0.1564,1.926931
12000,0.137,1.926873
14000,0.0877,2.039231


CPU times: user 1h 18min 45s, sys: 3min 34s, total: 1h 22min 19s
Wall time: 1h 22min 4s


TrainOutput(global_step=14430, training_loss=0.321843570648426, metrics={'train_runtime': 4923.5888, 'train_samples_per_second': 93.782, 'train_steps_per_second': 2.931, 'total_flos': 8.501558549549875e+16, 'train_loss': 0.321843570648426, 'epoch': 6.0})

The validation loss rises from 1.244 at step 2,000 to 2.039 at step 14,000 while the training loss keeps falling (0.739 to 0.088), so training for more epochs only overfits. The previous model, trained for around 2 epochs, therefore remains our selected model. Out of curiosity, and to check that validation loss does reflect model performance, I still compute the evaluation metrics on the test data so that we can eyeball the generated output and get a glimpse of the metrics. Note, however, that the test-set metrics are by no means used for model selection.
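
As a small sketch of the reasoning above, the logged losses can be scanned for the evaluation step with the lowest validation loss, and steps can be converted to epochs using the `global_step=14430` and `epoch: 6.0` reported by `TrainOutput`. The numbers are copied from the table above; the variable names are illustrative only, and the result applies to this run alone.

```python
# (step, training loss, validation loss), copied from the training table above.
eval_log = [
    (2000, 0.7393, 1.244433),
    (4000, 0.4776, 1.419745),
    (6000, 0.3343, 1.604613),
    (8000, 0.2139, 1.704297),
    (10000, 0.1564, 1.926931),
    (12000, 0.1370, 1.926873),
    (14000, 0.0877, 2.039231),
]

# From TrainOutput above: 14430 optimizer steps covered 6 epochs.
steps_per_epoch = 14430 / 6

# Evaluation point with the lowest validation loss in this run.
best_step, best_train_loss, best_val_loss = min(eval_log, key=lambda row: row[2])
print(f"lowest validation loss {best_val_loss:.4f} at step {best_step} "
      f"(~{best_step / steps_per_epoch:.2f} epochs)")

# A widening gap between validation and training loss is a sign of overfitting.
gaps = [val_loss - train_loss for _, train_loss, val_loss in eval_log]
for (step, _, _), gap in zip(eval_log, gaps):
    print(f"step {step:>5}: validation - training = {gap:.4f}")
```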

In [None]:
fine_tuned_model=model
fine_tuned_model.to("cuda")

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [None]:
%%time
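# Note: the *_baseline variable names are reused from the baseline evaluation,
# but the values computed in this cell are for the fine-tuned model.
final_perplexity_baseline = 0
final_rouge1_baseline = 0
final_rouge2_baseline = 0
final_rougeL_baseline = 0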

inputs_to_be_generated_data = X_test
inputs_to_be_generated_dataloader = DataLoader(inputs_to_be_generated_data, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)
model_to_use = fine_tuned_model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for inputs_to_be_generated, reference in zip(inputs_to_be_generated_dataloader, original_story_endings_reference):
    # Tokenize the batch of story bodies and move the tensors to the GPU
    input_and_attention_mask = tokenizer(inputs_to_be_generated, padding=True, return_tensors="pt")
    input_and_attention_mask = input_and_attention_mask.to("cuda")
    # Generate one ending per story and decode it back to text
    generated_ids = model_to_use.generate(inputs=input_and_attention_mask["input_ids"], attention_mask=input_and_attention_mask["attention_mask"], max_length=20, early_stopping=True)
    ending = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    full_story_with_ending = concat_story_body_with_endings(inputs_to_be_generated, ending)
    # Weight each batch by its share of the test set so the smaller last batch is not over-counted
    weight_factor = len(reference) / len(inputs_to_be_generated_data)
    # Calculate perplexity of the full story (body + generated ending)
    final_perplexity_baseline += weight_factor * perplexity.compute(predictions=full_story_with_ending, model_id='facebook/bart-large-cnn')["mean_perplexity"]
    # Calculate ROUGE of the generated endings against the original endings
    rouge_results = rouge.compute(predictions=ending, references=reference)
    final_rouge1_baseline += weight_factor * rouge_results["rouge1"]
    final_rouge2_baseline += weight_factor * rouge_results["rouge2"]
    final_rougeL_baseline += weight_factor * rouge_results["rougeL"]



  0%|          | 0/2 [00:00<?, ?it/s]


CPU times: user 3min 41s, sys: 8.87 s, total: 3min 50s
Wall time: 4min 20s


In [None]:
print(final_perplexity_baseline)
print(final_rouge1_baseline)
print(final_rouge2_baseline)
print(final_rougeL_baseline)

1065652.2555499696
0.20002404282228245
0.04110460788753432
0.1759953809377672
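
The evaluation loop above accumulates per-batch metric values weighted by `len(reference) / len(inputs_to_be_generated_data)`. As a minimal sketch on made-up numbers, and assuming each per-batch value is a plain mean over the examples in that batch, this accumulation reproduces the mean over the whole test set even when the last batch is smaller. The function name `weighted_mean_of_batch_means` is hypothetical.

```python
def weighted_mean_of_batch_means(values, batch_size):
    """Accumulate per-batch means weighted by each batch's share of the data."""
    total = 0.0
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        weight_factor = len(batch) / len(values)      # same weighting as the loop above
        total += weight_factor * (sum(batch) / len(batch))
    return total

# Made-up per-example scores; 7 examples with batch_size=3 leave a final batch of 1.
toy_scores = [0.10, 0.30, 0.20, 0.40, 0.25, 0.15, 0.35]
weighted = weighted_mean_of_batch_means(toy_scores, batch_size=3)
plain = sum(toy_scores) / len(toy_scores)
print(weighted, plain)
```

Without the weighting, a simple average of the batch values would give the single example in the last batch the same influence as a full batch.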


In [None]:
%%time
LLM_as_judge_score_fine_tuned = 0

X_test_inputs_dataloader = DataLoader(TensorDataset(X_test_inputs["input_ids"], X_test_inputs["attention_mask"]), batch_size=32)
original_story_bodies = DataLoader(X_test, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)
model_to_use = fine_tuned_model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for input_and_attention_mask, original_stories, reference in zip(X_test_inputs_dataloader, original_story_bodies, original_story_endings_reference):
    generated_ids = model_to_use.generate(inputs=input_and_attention_mask[0], attention_mask=input_and_attention_mask[1], max_length=20, early_stopping=True)
    generated_endings = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    feedback, score = judge.absolute_grade(
      instructions=original_stories,
      responses=generated_endings,
      rubric=score_rubric,
      reference_answers=reference
    )
    LLM_as_judge_score_fine_tuned += sum(score)
LLM_as_judge_score_fine_tuned = LLM_as_judge_score_fine_tuned/len(X_test)
LLM_as_judge_score_fine_tuned

Processed prompts: 100%|██████████| 62/62 [00:34<00:00,  1.80it/s, est. speed input: 988.24 toks/s, output: 309.14 toks/s]


Processed 62/32 instances.


Finalizing: 100%|██████████| 62/62 [00:00<00:00, 11784.42it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.66it/s, est. speed input: 911.73 toks/s, output: 267.10 toks/s]


Retrying failed batches: Attempt 1/10


Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.42s/it, est. speed input: 225.98 toks/s, output: 40.97 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12688.38it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 848.96 toks/s, output: 272.33 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11367.64it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.64it/s, est. speed input: 896.25 toks/s, output: 280.21 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11168.06it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.54it/s, est. speed input: 845.94 toks/s, output: 263.04 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11151.36it/s]
Processed prompts: 100%|██████████| 32/32 [00:23<00:00,  1.34it/s, est. speed input: 733.84 toks/s, output: 234.94 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11892.41it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.64it/s, est. speed input: 896.26 toks/s, output: 273.18 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11100.63it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.54it/s, est. speed input: 842.28 toks/s, output: 270.35 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10755.49it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.64it/s, est. speed input: 897.32 toks/s, output: 291.39 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11614.55it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.68it/s, est. speed input: 919.95 toks/s, output: 298.87 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11476.51it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.62it/s, est. speed input: 893.50 toks/s, output: 274.28 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11075.90it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 900.42 toks/s, output: 281.86 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11852.50it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.64it/s, est. speed input: 895.81 toks/s, output: 278.77 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12090.60it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 895.98 toks/s, output: 289.90 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11375.35it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.68it/s, est. speed input: 924.48 toks/s, output: 295.77 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11046.73it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s, est. speed input: 920.15 toks/s, output: 289.33 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 6778.33it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 850.39 toks/s, output: 274.39 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9714.66it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.57it/s, est. speed input: 864.19 toks/s, output: 274.69 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10355.51it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s, est. speed input: 885.16 toks/s, output: 276.14 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10855.53it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.51it/s, est. speed input: 825.02 toks/s, output: 271.90 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9753.49it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.50it/s, est. speed input: 818.40 toks/s, output: 261.11 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11155.06it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s, est. speed input: 924.83 toks/s, output: 283.94 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10046.99it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.58it/s, est. speed input: 863.13 toks/s, output: 282.85 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10181.89it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.66it/s, est. speed input: 909.34 toks/s, output: 276.61 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10436.03it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 934.37 toks/s, output: 280.72 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12467.97it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.68it/s, est. speed input: 919.48 toks/s, output: 293.37 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11356.10it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s, est. speed input: 872.44 toks/s, output: 270.40 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12162.91it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.68it/s, est. speed input: 922.48 toks/s, output: 292.83 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11681.26it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 946.68 toks/s, output: 287.23 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 8806.36it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 894.54 toks/s, output: 286.76 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10558.35it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s, est. speed input: 836.37 toks/s, output: 264.89 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11168.99it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 905.22 toks/s, output: 289.59 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11299.69it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.74it/s, est. speed input: 958.45 toks/s, output: 282.37 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12455.25it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s, est. speed input: 872.71 toks/s, output: 274.29 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11367.64it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 888.68 toks/s, output: 277.32 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11154.14it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s, est. speed input: 859.34 toks/s, output: 270.10 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11642.76it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 896.47 toks/s, output: 272.18 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11260.82it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 898.47 toks/s, output: 282.97 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11181.08it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s, est. speed input: 913.14 toks/s, output: 276.79 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11351.30it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s, est. speed input: 868.45 toks/s, output: 281.73 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11530.73it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.57it/s, est. speed input: 859.38 toks/s, output: 280.68 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11057.65it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s, est. speed input: 837.47 toks/s, output: 267.39 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10197.37it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s, est. speed input: 868.82 toks/s, output: 266.13 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11383.07it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s, est. speed input: 844.01 toks/s, output: 262.39 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11338.83it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s, est. speed input: 874.89 toks/s, output: 268.73 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10942.26it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s, est. speed input: 870.48 toks/s, output: 270.94 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11713.89it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.47it/s, est. speed input: 804.77 toks/s, output: 261.19 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11312.07it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s, est. speed input: 868.78 toks/s, output: 267.45 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12271.90it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.42it/s, est. speed input: 778.15 toks/s, output: 253.38 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10837.99it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.45it/s, est. speed input: 792.52 toks/s, output: 266.36 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10385.15it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s, est. speed input: 873.81 toks/s, output: 269.52 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11255.16it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.52it/s, est. speed input: 833.41 toks/s, output: 259.93 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10813.55it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s, est. speed input: 880.97 toks/s, output: 275.68 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10275.43it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.52it/s, est. speed input: 832.25 toks/s, output: 258.72 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10824.88it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.45it/s, est. speed input: 795.62 toks/s, output: 259.68 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11355.14it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 928.08 toks/s, output: 269.29 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12488.86it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 890.75 toks/s, output: 287.92 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11483.38it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.52it/s, est. speed input: 831.42 toks/s, output: 271.98 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10983.45it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.47it/s, est. speed input: 804.42 toks/s, output: 260.49 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10428.73it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.39it/s, est. speed input: 759.40 toks/s, output: 254.06 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10826.63it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s, est. speed input: 854.99 toks/s, output: 273.66 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 8294.77it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.57it/s, est. speed input: 859.01 toks/s, output: 266.03 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11148.58it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.50it/s, est. speed input: 824.37 toks/s, output: 258.86 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11512.93it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.66it/s, est. speed input: 910.21 toks/s, output: 280.63 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10838.87it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.49it/s, est. speed input: 821.55 toks/s, output: 265.44 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11170.85it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.52it/s, est. speed input: 831.02 toks/s, output: 278.23 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10390.78it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s, est. speed input: 856.03 toks/s, output: 280.19 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11266.49it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.54it/s, est. speed input: 844.16 toks/s, output: 272.85 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10574.99it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.55it/s, est. speed input: 852.07 toks/s, output: 270.03 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 7560.71it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 949.24 toks/s, output: 286.59 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 6915.59it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s, est. speed input: 850.14 toks/s, output: 262.76 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10683.57it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.58it/s, est. speed input: 866.51 toks/s, output: 271.07 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11196.94it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s, est. speed input: 873.15 toks/s, output: 277.54 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11631.66it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.58it/s, est. speed input: 864.28 toks/s, output: 280.10 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9890.04it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s, est. speed input: 875.02 toks/s, output: 279.02 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10631.11it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 932.24 toks/s, output: 278.46 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10877.52it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 935.35 toks/s, output: 280.11 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12349.81it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 902.32 toks/s, output: 271.75 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11417.93it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 931.36 toks/s, output: 281.11 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12432.17it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.54it/s, est. speed input: 848.69 toks/s, output: 283.52 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10692.94it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 947.97 toks/s, output: 286.68 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11296.84it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 905.76 toks/s, output: 278.93 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 8071.79it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 898.36 toks/s, output: 283.32 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11371.49it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s, est. speed input: 877.76 toks/s, output: 279.10 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11133.78it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s, est. speed input: 897.16 toks/s, output: 278.15 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11450.07it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 898.97 toks/s, output: 277.46 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 6525.56it/s]
Processed prompts: 100%|██████████| 32/32 [00:23<00:00,  1.35it/s, est. speed input: 734.05 toks/s, output: 251.47 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9963.46it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.50it/s, est. speed input: 822.61 toks/s, output: 273.15 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9635.85it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 954.08 toks/s, output: 279.20 toks/s]


Retrying failed batches: Attempt 1/10


Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.24s/it, est. speed input: 130.81 toks/s, output: 41.72 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11708.78it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.45it/s, est. speed input: 793.66 toks/s, output: 255.79 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11130.09it/s]
Processed prompts: 100%|██████████| 32/32 [00:22<00:00,  1.44it/s, est. speed input: 786.95 toks/s, output: 253.40 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 5302.74it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.51it/s, est. speed input: 826.72 toks/s, output: 260.43 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10544.25it/s]
Processed prompts: 100%|██████████| 15/15 [00:10<00:00,  1.49it/s, est. speed input: 816.22 toks/s, output: 249.79 toks/s]


Processed 15/15 instances.


Finalizing: 100%|██████████| 15/15 [00:00<00:00, 3877.39it/s]

CPU times: user 39min 48s, sys: 7.83 s, total: 39min 56s
Wall time: 39min 40s





1.8373408769448374
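
The value above appears to be the mean judge score over the test set. The cell that computes it is outside this excerpt, so that reading is an assumption. A minimal sketch of how such an aggregate could be computed from a list of integer scores is shown below; `summarize_scores` is a hypothetical helper (not part of prometheus-eval), and skipping `None` entries is an assumed way of handling scores that are missing.

```python
from collections import Counter


def summarize_scores(scores):
    """Return (mean, histogram) for a list of judge scores, ignoring None entries."""
    valid = [s for s in scores if s is not None]
    if not valid:
        raise ValueError("no valid scores to summarize")
    mean = sum(valid) / len(valid)
    # Sorted histogram: score value -> number of examples that received it.
    histogram = dict(sorted(Counter(valid).items()))
    return mean, histogram


# Made-up scores for illustration only; these are not results from this notebook.
mean, histogram = summarize_scores([1, 1, 3, None, 2])
print(mean, histogram)
```

Looking at the histogram alongside the mean shows whether a low average comes from uniformly weak endings or from a mix of 1s and higher scores.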

In [None]:
# Grade a single example (the first test story) with the Prometheus judge.
instructions = X_test[0:1]
generated_ids = fine_tuned_model.generate(
    X_test_inputs["input_ids"][0:1],
    attention_mask=X_test_inputs["attention_mask"][0:1],
    max_length=20,
    early_stopping=True,
)
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(responses)
reference_answers = y_test[0:1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

['All of his friends complimented him on his new car. krune/k']


Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.08s/it, est. speed input: 132.75 toks/s, output: 42.62 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5769.33it/s]

Feedback: ['The provided response does not meet the expectations of the rubric. It presents a sentence that is disjointed and incoherent with the rest of the story. The language error "krune/k" disrupts the narrative flow and confuses the reader. This sudden and unexplained change in vocabulary does not contribute to a satisfactory conclusion. Furthermore, the sentence lacks the necessary elements to make it a suitable ending, such as detailing Evan\'s emotions or reactions, which would have enhanced the story\'s themes and characters. Overall, the response fails to tie loose ends or provide an imaginative or emotionally resonant conclusion to the story. It does not fulfill the criteria for language fluency, semantics coherence, and story flow, as specified in the score rubric. \n[RESULT] 1']
Score: [1]
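
The judge penalizes the stray tokens (`krune/k`) that follow an otherwise complete sentence. One possible post-processing step, sketched below, keeps only the first sentence of each generated ending before grading. `first_sentence` is a hypothetical helper, and the assumption that every ending is exactly one sentence is mine; the rule would cut too early on abbreviations such as "Mr.", and it hides the symptom rather than explaining why generation continues past the sentence end.

```python
import re

# Shortest prefix that ends in sentence-final punctuation,
# optionally followed by a closing quote.
_SENTENCE_END = re.compile(r"""^.*?[.!?]["']?""", re.DOTALL)


def first_sentence(text):
    """Keep only the first sentence; return the stripped text if no terminator is found."""
    text = text.strip()
    match = _SENTENCE_END.match(text)
    return match.group(0) if match else text


print(first_sentence("All of his friends complimented him on his new car. krune/k"))
```

Applied to the response above, this would pass only the sentence about the new car to `judge.absolute_grade`.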





In [None]:
# Grade a single example (test story at index 1500) with the Prometheus judge.
instructions = X_test[1500:1501]
generated_ids = fine_tuned_model.generate(
    X_test_inputs["input_ids"][1500:1501],
    attention_mask=X_test_inputs["attention_mask"][1500:1501],
    max_length=20,
    early_stopping=True,
)
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(responses)
reference_answers = y_test[1500:1501]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

['I then proceeded to scream for the rest of the visit. B. Belly B']


Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.60s/it, est. speed input: 117.76 toks/s, output: 42.88 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 4369.07it/s]

Feedback: ['The given response falls short of the expectations outlined in the score rubric. Firstly, it lacks language fluency as it does not construct a coherent sentence. The structure "B. Belly B" seems to be a contraction of "Belly B" with a mention of "scream" in the middle, which is neither grammatically correct nor contextually appropriate. This abruptness disrupts the narrative flow, making the reader confused. Additionally, the response does not offer a satisfying resolution to the story, which was expected as per the instruction. The pain and the subsequent scream, while integral to the story, are not developed further in the response. Therefore, the ending feels incomplete and awkward. Based on these aspects, the response fails to meet the criteria of the score rubric for language fluency, semantics coherence, and story flow, hence, it receives a score of 1. \n[RESULT] 1']
Score: [1]





In [None]:
# Grade a single example (the second-to-last test story) with the Prometheus judge.
instructions = X_test[-2:-1]
generated_ids = fine_tuned_model.generate(
    X_test_inputs["input_ids"][-2:-1],
    attention_mask=X_test_inputs["attention_mask"][-2:-1],
    max_length=20,
    early_stopping=True,
)
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(responses)
reference_answers = y_test[-2:-1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

['After the interview, the company offered Ramona a raise and she accepted! \\Pri']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.46s/it, est. speed input: 159.48 toks/s, output: 42.55 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5667.98it/s]

Feedback: ["This response offers a clear and positive resolution to Ramona's story, indicating that she received the raise she desired from her new employer. However, the response could have been more detailed and elaborative in terms of how the situation unfolded after the interview. This omission slightly impacts the depth of the narrative and the emotional resonance of the conclusion. Additionally, while the language is clear and free of errors, the response lacks a certain level of creativity in storytelling. Nevertheless, the overall flow and coherence of the story are maintained, making the conclusion satisfactory. Thus, while this response is competent, it falls short of achieving exceptional originality or depth. \n[RESULT] 3"]
Score: [3]





In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True), skip_special_tokens=True))



["Evan couldn't wait for people to see his new car! \\¢"]
["Evan was so happy he'd saved so much money for the new car.\\"]
["Evan couldn't wait for people to see his new car! krinkle."]
['A neighbor even said he looked just like Evan! k. s. k']
['People were amazed and loved his new car. kr. k. k']
["Evan couldn't wait for everyone to see his new car! k.r"]
['All his friends and family are impressed by his new wheels. kludger.']
["Evan couldn't wait for people to see his new car! \\¢"]
['All his friends and family were jealous of the new car he had bought/ bought.']
["Evan couldn't wait for people to see the new car!\\\\"]
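
The sampling cells in this part of the notebook repeat the same generate-and-decode loop with different decoding arguments. A small wrapper could make those comparisons easier to read. The sketch below is one way to write it; `sample_endings` is a hypothetical helper, and it only assumes that `model.generate` and `tokenizer.batch_decode` behave as they are used in the cells here.

```python
def sample_endings(model, tokenizer, input_ids, attention_mask,
                   n=10, max_length=20, **generate_kwargs):
    """Generate n endings for the same prompt and return them as a flat list of strings."""
    endings = []
    for _ in range(n):
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            **generate_kwargs,  # e.g. do_sample=True, temperature=0.8, top_k=20
        )
        endings.extend(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
    return endings


# Example call with the notebook's objects (not executed here):
# endings = sample_endings(
#     fine_tuned_model, tokenizer,
#     X_test_inputs["input_ids"][0:1], X_test_inputs["attention_mask"][0:1],
#     early_stopping=True, do_sample=True, temperature=0.8,
# )
```

Returning a list instead of printing makes it possible to pass the endings straight to the judge or to count duplicates across runs.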


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, do_sample=False), skip_special_tokens=True))

["Evan was glad he'd waited so long to buy the BMW. kludge"]
["Evan couldn't wait for everyone to see his brand new car! kr."]
['Evan was so excited to have a new car. k. s. kl']
["Evan couldn't wait for people to see his new car! k. s."]
['Evan was glad he had waited so long to buy the fancy car. k.']
['All his friends were impressed with his new wheels. kludge klod.']
['A few people thought it was a new car, and bought a pretty one.\\']
['A lot of people complimented him on his new ride. k. s. k']
["Evan couldn't wait for people to see his new car! klans b"]
['Evan was glad he had waited so long to buy his new car! \\Pri']
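
The outputs above differ from run to run even though `do_sample=False` is set. One possible explanation (an assumption, not verified in this notebook) is that `fine_tuned_model` is still in training mode, so dropout remains active during `generate`. A minimal sketch for checking this is to switch the model to evaluation mode around the call; `generate_in_eval_mode` is a hypothetical helper that relies only on the standard `training`, `eval()` and `train()` members of a PyTorch module.

```python
def generate_in_eval_mode(model, *args, **kwargs):
    """Call model.generate with dropout disabled, then restore the previous mode."""
    was_training = model.training
    model.eval()  # disables dropout and other train-only behavior
    try:
        return model.generate(*args, **kwargs)
    finally:
        model.train(was_training)  # put the model back the way it was


# Example call with the notebook's objects (not executed here):
# ids = generate_in_eval_mode(
#     fine_tuned_model,
#     X_test_inputs["input_ids"][0:1],
#     attention_mask=X_test_inputs["attention_mask"][0:1],
#     max_length=20, do_sample=False,
# )
```

If repeated calls through this wrapper return identical text, dropout was the source of the variation; if they still differ, the cause lies elsewhere.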


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, temperature=0.8, do_sample=True), skip_special_tokens=True))

['A neighbor even offered to drive him home in his new blouse! kB.']
['Evan felt very accomplished to have a car that looked like a Lamborghini.']
['A lot of people were impressed with the new wheels. klaspy kl']
["Evan couldn't believe he'd finally saved enough to buy a new car!\\"]
["Evan was so happy he hadn't expected to be so proud!\\\\"]
["Evan couldn't wait for people to see the new car. k. s."]
['All of his friends thought it was the most beautiful car they had ever seen. \\']
["Evan couldn't wait for everyone to see his new car!\\Bundert"]
["Evan was glad he'd spent so much money on the new car. k."]
["Evan was so happy, he didn't care how old it was even though it"]


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, temperature=0.3, do_sample=True), skip_special_tokens=True))

['Evan was glad he had waited so long to buy the car. k. k']
['All the other residents thought it was a new thing to see! k.d.']
["Evan couldn't wait for all his friends to see the new car! k."]
['Evan was so happy he had waited so long for the car!Prize!']
['A lot of people drove by and thought it was very nice.  It was a']
['A lot of people showed up to see how fancy it was. k. s.']
['A lot of people were impressed with his new wheels! k. s b. k']
['People were amazed at his beauty and he was very proud too. krr.']
["Evan couldn't wait for people to see the new car! kranks k"]
["Evan couldn't wait for people to see the new wheels!\\tB"]


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, temperature=1.5, do_sample=True), skip_special_tokens=True))

["Evan knew people looked at it differently than any other car he'd owned.\\"]
["People were so impressed they couldn't recognize it was a new car. \\Derek"]
['When it was time for dinner, Evan and his friends feigned illness.Burg']
['A famous photographer even snapped a photo. s he got in his new car. k']
["Evan couldn't believe people didn't know it was worth hundreds of thousands.\\"]
['Evan was glad he had waited so long for his new car! k/k']
['A lot of people thought it was a new trend and joined in on the fun!']
["Evan couldn't believe how beautiful it was! kimim bakes ketchup"]
['A lot of people were impressed with his new wheels. k.e. kl']
['A few weeks later, the city took his picture and posted it on FB. k']
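
The three cells above vary `temperature` (0.8, 0.3, 1.5). When sampling, temperature divides the logits before the softmax, so low values concentrate probability on the top token and high values flatten the distribution. The sketch below illustrates that effect on made-up logits in plain Python; it does not use the model, and `softmax_with_temperature` is a hypothetical helper written for this illustration.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing them by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate tokens
for t in (0.3, 0.8, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

This is consistent with the more unusual endings in the `temperature=1.5` run, where less likely tokens are drawn more often.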


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_k=5, do_sample=True), skip_special_tokens=True))

['All his friends gushed over how beautiful it was. s. k. t.']
["Evan couldn't wait for everyone to see his new wheels! B.B."]
['Evan was glad he had waited so long to buy his new car!\\t']
["Evan couldn't wait for people to see his new car!  New York Times"]
['A lot of people commented on the car and said it was very nice. k.']
['Evan was so thrilled with his new car! k.eepley. k']
['Evan wanted to get his girlfriend to buy the car with him! Bump,']
["Evan was happy he'd saved enough to buy such a fancy car. \\¢"]
['All his friends and family were impressed with the new wheels!\\MAYONAY']
['All of his friends and family were impressed with his new ride.\\d BSO']


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_k=20, do_sample=True), skip_special_tokens=True))

["Evan was glad he'd waited so long to buy the new car. k."]
['A few weeks later, he sold it for a five thousand dollar profit! \\K']
['All his friends and girlfriend thought it was a great new car. kranks k']
['A few weeks later, it was sold to someone else for a lot of money.']
["Evan couldn't wait for everyone to see his new wheels! klaspal"]
['All the other residents wanted to take a ride in his new car! kelle']
["Evan couldn't wait for people to see his new car!\\tAAA"]
['All his friends and family fell in love with the new car Evan had bought.\\']
['Evan was so happy he had waited so long to buy the car! k.']
['A lot of people complimented him on his new car. k. klondon']


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_k=80, do_sample=True), skip_special_tokens=True))

["Evan was glad he'd waited so long for the new car! B.B"]
["Evan was glad he'd waited so long to buy his dream car!--K"]
['A few weeks later, the dealership approved his new car and he was so happy.']
['All of his friends thought it was a new piece of art! k. s.']
['Evan was glad he had waited so long to buy his dream car. \\/']
['All the other residents thought it was a new style craze! k. k']
["People were amazed at how expensive it was before Evan's car was sold. Evan was"]
['All of his friends thought he was a real fashion model! kludge kl']
['All his friends thought it was a new car and wanted to buy one, too!']
["Evan couldn't wait for people to see his new wheels! k. d."]


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_p=0.8, do_sample=True), skip_special_tokens=True))

["Evan couldn't wait for everyone to see his new car! k.d."]
['A neighbor even offered to buy it for him, it was such a nice car!']
["Evan couldn't wait for everyone to see his new car! \\ k"]
['All the girls in town wanted him to buy a new car. k. s.']
["Evan couldn't wait for all his friends to see it! k. s."]
["Evan couldn't wait for everyone to see the new wheels he'd bought!\\"]
["A few days later, someone bought Evan's car for $50,000! k"]
['People were so impressed they took him out to a fancy restaurant! \\SURPR']
['All of his friends and family complimented him on his new ride. \\B']
['Evan was so happy he had waited so long to buy the car. \\D']


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_p=0.5, do_sample=True), skip_special_tokens=True))

["Evan couldn't wait for people to see his brand new car! kr."]
['A few weeks later, someone else bought the same car for a lot of money.']
["Evan couldn't wait for people to see his new car!\\\\"]
['Evan felt like he could walk on the clouds with his new car. \\D']
["Evan was glad he'd saved for this kind of car!\\Blessing"]
['Evan felt like a new person!  . . kludge klunch']
['Evan was sure no one would have a chance to appraise it. \\t']
['All the other people in town complimented him on his style. k. s.']
['People were amazed and said he must have a lot of money. \\B']
["Evan couldn't wait for all his friends to see it! \\�"]


In [None]:
for _ in range(10):
  print(tokenizer.batch_decode(fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_p=0.3, do_sample=True), skip_special_tokens=True))

['A lot of people thought he was a real jerk. k. k']
["Evan was glad he'd waited so long to get his new car!\\ k"]
["Evan couldn't wait for people to see his new wheels! k. s."]
['All of his friends and family complimented him on his new car. kludge']
["Evan couldn't wait for people to see his new car! k. kl"]
['A few weeks later, his insurance went up by a foot. B.B.']
['Evan was so happy he had waited so long to buy his dream car!\\']
['All his friends and family were jealous of the new car he had bought. \\D']
['All his friends thought it was a great new car. kranks krune']
["Evan couldn't wait for everyone to see his new car!\\ k"]
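
The cells above vary `temperature`, `top_k`, and `top_p` one at a time. As a rough illustration of what each knob does to the next-token distribution, here is a minimal pure-Python sketch. It is not the Hugging Face implementation; the helper names and the four-token toy logits are made up for this example.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

def top_p_filter(probs, p):
    """Keep the smallest set of tokens whose cumulative probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}

toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for a 4-token vocabulary
sharp = softmax_with_temperature(toy_logits, temperature=0.3)  # peaked distribution
flat = softmax_with_temperature(toy_logits, temperature=1.5)   # flatter distribution
top2 = top_k_filter(softmax_with_temperature(toy_logits), k=2)
nucleus = top_p_filter(softmax_with_temperature(toy_logits), p=0.8)
```

A low temperature concentrates probability on the top token, which fits the more repetitive endings at `temperature=0.3`, while a high temperature spreads it out. Smaller `top_k` or `top_p` values shrink the candidate set before sampling.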


## PEFT with LoRA
On an A100 GPU, LoRA does not improve training time much: a run still takes around 1 hour. On a Colab T4, however, LoRA cuts training time from about 4 hours to about 2 hours 20 minutes.

In [None]:
from peft import get_peft_model, LoraConfig
from peft import TaskType

# Set up LoRA configuration
lora_config = LoraConfig(
    r=4,              # Rank of the low-rank matrices
    lora_alpha=8,    # Scaling factor for the low-rank adaptation
    lora_dropout=0.1, # Dropout rate for LoRA layers
    task_type=TaskType.SEQ_2_SEQ_LM  # Task type for sequence-to-sequence model
)

# Integrate LoRA with the model
lora_model = get_peft_model(model, lora_config)
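
To get a feel for why LoRA trains fewer weights, the following sketch counts parameters for a single 1024 x 1024 projection (the layer width shown in the model printouts in this notebook) with the `r=4` and `lora_alpha=8` values from the config above. The helper names are hypothetical, and the count covers one wrapped layer only, not the whole model.

```python
def full_param_count(in_features, out_features):
    """Weights in the dense matrix alone (bias ignored)."""
    return in_features * out_features

def lora_param_count(in_features, out_features, r):
    """Weights in the two low-rank factors: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r

hidden = 1024      # projection width in bart-large
r, alpha = 4, 8    # values taken from the LoraConfig above

full = full_param_count(hidden, hidden)
lora = lora_param_count(hidden, hidden, r)
scaling = alpha / r  # standard LoRA scales the low-rank update by alpha / r

print(full, lora, lora / full, scaling)
```

For this one layer the adapter holds under 1% of the weights of the dense matrix. The saving in training time is smaller than that ratio suggests, because the forward and backward passes still run through the frozen weights.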

In [None]:
lora_trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
lora_trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Save through the LoRA trainer so the LoRA-wrapped model is the one written to disk
lora_trainer.save_model("./fine_tuned_bart_lora")

# Further Fine Tune the Encoder with Story Cloze

In [None]:
ClozeTestSpring2016ValData.head()

Unnamed: 0,InputStoryid,InputSentence1,InputSentence2,InputSentence3,InputSentence4,RandomFifthSentenceQuiz1,RandomFifthSentenceQuiz2,AnswerRightEnding
0,138d5bfb-05cc-41e3-bf2c-fa85ebad14e2,Rick grew up in a troubled household.,"He never found good support in family, and tur...",It wasn't long before Rick got shot in a robbery.,The incident caused him to turn a new leaf.,He is happy now.,He joined a gang.,1
1,bff9f820-9605-4875-b9af-fe6f14d04256,Laverne needs to prepare something for her fri...,She decides to bake a batch of brownies.,She chooses a recipe and follows it closely.,Laverne tests one of the brownies to make sure...,The brownies are so delicious Laverne eats two...,Laverne doesn't go to her friend's party.,1
2,e8f628d5-9f97-40ed-8611-fc0e774673c4,Sarah had been dreaming of visiting Europe for...,She had finally saved enough for the trip.,She landed in Spain and traveled east across t...,She didn't like how different everything was.,Sarah then decided to move to Europe.,Sarah decided that she preferred her home over...,2
3,f5226bfe-9f26-4377-b05f-3d9568dbdec1,Gina was worried the cookie dough in the tube ...,She was very happy to find she was wrong.,The cookies from the tube were as good as from...,Gina intended to only eat 2 cookies and save t...,Gina liked the cookies so much she ate them al...,Gina gave the cookies away at her church.,1
4,69ac9b05-b956-402f-9fff-1f926ef9176b,It was my final performance in marching band.,I was playing the snare drum in the band.,We played Thriller and Radar Love.,The performance was flawless.,I was very proud of my performance.,I was very ashamed of my performance.,1


In [None]:
ClozeTestSpring2016ValData["Story1"] = combine_sentences(ClozeTestSpring2016ValData, sentence_columns_to_be_combined=["InputSentence1", "InputSentence2", "InputSentence3", "InputSentence4", "RandomFifthSentenceQuiz1"])[0]
ClozeTestSpring2016ValData["Story2"] = combine_sentences(ClozeTestSpring2016ValData, sentence_columns_to_be_combined=["InputSentence1", "InputSentence2", "InputSentence3", "InputSentence4", "RandomFifthSentenceQuiz2"])[0]

In [None]:
ClozeTestTraining = ClozeTestSpring2016ValData.drop(columns=["InputSentence1", "InputSentence2", "InputSentence3", "InputSentence4", "RandomFifthSentenceQuiz1", "RandomFifthSentenceQuiz2", "InputStoryid"])

In [None]:
ClozeTestTraining["AnswerRightEnding"] = ClozeTestTraining["AnswerRightEnding"] - 1
ClozeTestTraining

Unnamed: 0,AnswerRightEnding,Story1,Story2
0,0,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
1,0,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
2,1,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
3,0,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
4,0,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
...,...,...,...
1866,1,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
1867,1,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
1868,1,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...
1869,1,Rick grew up in a troubled household. He never...,Rick grew up in a troubled household. He never...


In [None]:
# Freeze all decoder parameters so that only encoder parameters are updated during training
for param in model.model.decoder.parameters():
    param.requires_grad = False

In [None]:
import torch
import torch.nn as nn

class BartEncoderForClassification(nn.Module):
    def __init__(self, bart_model):
        super().__init__()
        self.bart = bart_model  # This is the BartForConditionalGeneration model
        self.fc1 = nn.Linear(2048, 1)
        self.classifier = nn.Sigmoid()

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, labels=None):
        # Encode the first input
        encoder_outputs_1 = self.bart.model.encoder(
            input_ids=input_ids_1,
            attention_mask=attention_mask_1
        )
        last_hidden_state_1 = encoder_outputs_1[0]  # (batch_size, seq_len, hidden_dim)
        pooled_output_1 = last_hidden_state_1[:, 0, :]  # Take the first token as representation

        # Encode the second input
        encoder_outputs_2 = self.bart.model.encoder(
            input_ids=input_ids_2,
            attention_mask=attention_mask_2
        )
        last_hidden_state_2 = encoder_outputs_2[0]
        pooled_output_2 = last_hidden_state_2[:, 0, :]

        # Combine both representations
        combined_output = torch.cat([pooled_output_1, pooled_output_2], dim=-1)
        logits = self.fc1(combined_output)
        logits = self.classifier(logits).squeeze(-1)

        loss = None
        if labels is not None:
            # Binary cross-entropy on the sigmoid output; BCELoss expects float targets.
            # Casting inside the check avoids an AttributeError when labels is None.
            loss_fct = nn.BCELoss()
            loss = loss_fct(logits, labels.float())

        return {"loss": loss, "logits": logits}
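
The head above applies a sigmoid and then `nn.BCELoss`. As a reference point for reading the loss values, this small sketch computes the same kind of mean binary cross-entropy by hand. It is an illustration with made-up scores and labels, not a replacement for the PyTorch loss (which also clamps the log terms differently).

```python
import math

def sigmoid(z):
    """Logistic function mapping a raw score to a probability."""
    return 1.0 / (1.0 + math.exp(-z))

def binary_cross_entropy(probs, labels):
    """Mean binary cross-entropy over a batch of probabilities and 0/1 labels."""
    eps = 1e-12  # guard against log(0)
    losses = [
        -(y * math.log(max(p, eps)) + (1 - y) * math.log(max(1 - p, eps)))
        for p, y in zip(probs, labels)
    ]
    return sum(losses) / len(losses)

raw_scores = [2.0, -1.0, 0.0]   # hypothetical outputs of the linear layer
labels = [1.0, 0.0, 1.0]        # hypothetical 0/1 targets
probs = [sigmoid(z) for z in raw_scores]

loss = binary_cross_entropy(probs, labels)
chance_loss = binary_cross_entropy([0.5, 0.5], [1.0, 0.0])  # always predicting 0.5
```

A classifier that always outputs 0.5 scores ln 2 (about 0.693) under this loss, so a mean per-example loss well above that value is a sign that something in the pipeline deserves a closer look.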


In [None]:
classification_model = BartEncoderForClassification(model)
classification_model

BartEncoderForClassification(
  (bart): BartForConditionalGeneration(
    (model): BartModel(
      (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
        (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
        (layers): ModuleList(
          (0-11): 12 x BartEncoderLayer(
            (self_attn): BartSdpaAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
    

In [None]:
# Debug check - ensure decoder params are frozen
for name, param in classification_model.named_parameters():
    if "decoder" in name:
        assert not param.requires_grad, f"{name} should be frozen"


In [None]:
class StoryClozeDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, texts_1, texts_2, labels):
        self.tokenizer = tokenizer
        self.texts_1 = texts_1
        self.texts_2 = texts_2
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        enc_1 = self.tokenizer(self.texts_1[idx], padding=True, return_tensors="pt")
        enc_2 = self.tokenizer(self.texts_2[idx], padding=True, return_tensors="pt")

        return {
            "input_ids_1": enc_1["input_ids"].squeeze(0),
            "attention_mask_1": enc_1["attention_mask"].squeeze(0),
            "input_ids_2": enc_2["input_ids"].squeeze(0),
            "attention_mask_2": enc_2["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Example dataset creation
train_dataset = StoryClozeDataset(tokenizer, ClozeTestTraining["Story1"], ClozeTestTraining["Story2"], ClozeTestTraining["AnswerRightEnding"])
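
One thing to keep in mind with `__getitem__` above: calling the tokenizer on a single story with `padding=True` pads only to the longest sequence in that call, so a lone example is not padded at all. Stories of different lengths therefore need to be padded to a common length when they are batched. The sketch below shows that padding step on plain Python lists. The `pad_batch` helper and the toy token ids are hypothetical, and `pad_id=1` is assumed from the `padding_idx=1` shown in the model printouts.

```python
def pad_batch(sequences, pad_id=1):
    """Right-pad token-id lists to a common length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [
        [1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences
    ]
    return input_ids, attention_mask

# Two made-up tokenized stories of different lengths
toy_batch = [[0, 713, 16, 2], [0, 713, 16, 10, 1296, 2]]
ids, mask = pad_batch(toy_batch)
```

In practice the same effect could come from tokenizing with a fixed `max_length`, or from a custom collate function passed to the `Trainer`. Which option fits best depends on how long the stories are.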


In [None]:
from sklearn.metrics import accuracy_score, f1_score
def compute_metrics(p):
    preds = p.predictions >= 0.5
    labels = p.label_ids
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="weighted")
    return {"accuracy": acc, "f1": f1}

# Use the Trainer API

training_args = TrainingArguments(
    report_to="none",
    output_dir="./results",
    per_device_train_batch_size=64,
    num_train_epochs=1,
    save_strategy="no",
)

trainer = Trainer(
    model=classification_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
    compute_metrics=compute_metrics

)

trainer.train()

torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch


Step,Training Loss
500,7.1321
1000,7.1124
1500,7.1001
2000,7.088
2500,7.0762
3000,7.0647
3500,7.0535
4000,7.0425
4500,7.032
5000,7.0217


TrainOutput(global_step=29500, training_loss=6.910061465505827, metrics={'train_runtime': 8961.6425, 'train_samples_per_second': 104.389, 'train_steps_per_second': 3.292, 'total_flos': 0.0, 'train_loss': 6.910061465505827, 'epoch': 500.0})

In [None]:
trainer.evaluate()

{'eval_loss': 6.816177845001221,
 'eval_accuracy': 0.48583645109567075,
 'eval_f1': 0.3177160676589674,
 'eval_runtime': 8.5794,
 'eval_samples_per_second': 218.08,
 'eval_steps_per_second': 27.275,
 'epoch': 500.0}

Now we can extract the BART model wrapped inside the classification model and use it for generation and evaluation. Its encoder has been fine-tuned on the Story Cloze task. Note that much of the evaluation code is copied from the previous fine-tuning sections for comparison. It could be encapsulated in a function for elegance, but keeping it in cells preserves interactivity and some flexibility to engage with it.

In [None]:
fine_tuned_model = classification_model.bart
fine_tuned_model.to("cuda")

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [None]:
%%time
final_perplexity_baseline = 0
final_rouge1_baseline = 0
final_rouge2_baseline = 0
final_rougeL_baseline = 0

inputs_to_be_generated_data = X_test
inputs_to_be_generated_dataloader = DataLoader(inputs_to_be_generated_data, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)
model_to_use = fine_tuned_model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for inputs_to_be_generated, reference in zip(inputs_to_be_generated_dataloader, original_story_endings_reference):
    input_and_attention_mask = tokenizer(inputs_to_be_generated, padding=True, return_tensors="pt")
    input_and_attention_mask.to("cuda")
    generated_ids = model_to_use.generate(inputs=input_and_attention_mask["input_ids"], attention_mask=input_and_attention_mask["attention_mask"], max_length=20, early_stopping=True)
    ending = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    full_story_with_ending = concat_story_body_with_endings(inputs_to_be_generated, ending)
    # Calculate perplexity
    weight_factor = len(reference) / len(inputs_to_be_generated_data)
    final_perplexity_baseline += weight_factor * perplexity.compute(predictions=full_story_with_ending, model_id='facebook/bart-large-cnn')["mean_perplexity"]
    # Calculate rouge
    rouge_results = rouge.compute(predictions=ending, references=reference)
    final_rouge1_baseline += weight_factor * rouge_results["rouge1"]
    final_rouge2_baseline += weight_factor * rouge_results["rouge2"]
    final_rougeL_baseline += weight_factor * rouge_results["rougeL"]



CPU times: user 3min 27s, sys: 8.67 s, total: 3min 36s
Wall time: 4min 6s
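
The evaluation loop above accumulates each metric as `weight_factor * batch_score`, where `weight_factor` is the batch's share of the test set. The sketch below isolates that arithmetic. It assumes each batch score is a plain mean over the examples in that batch; the function name and the numbers are made up for illustration.

```python
def weighted_batch_mean(batch_means, batch_sizes):
    """Combine per-batch means into one overall mean, weighting by batch size."""
    total = sum(batch_sizes)
    return sum(mean * size / total for mean, size in zip(batch_means, batch_sizes))

# Hypothetical per-batch scores; the last batch is smaller than the others
batch_means = [0.30, 0.20, 0.50]
batch_sizes = [32, 32, 16]

overall = weighted_batch_mean(batch_means, batch_sizes)
naive = sum(batch_means) / len(batch_means)  # unweighted average, for contrast
```

Weighting by size keeps a short final batch from counting as much as a full one, which is why the loop uses `len(reference) / len(inputs_to_be_generated_data)` rather than a simple average over batches.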


In [None]:
print(final_perplexity_baseline)
print(final_rouge1_baseline)
print(final_rouge2_baseline)
print(final_rougeL_baseline)

6860325.236742508
0.0
0.0
0.0


In [None]:
%%time
LLM_as_judge_score_fine_tuned = 0

X_test_inputs_dataloader = DataLoader(TensorDataset(X_test_inputs["input_ids"], X_test_inputs["attention_mask"]), batch_size=32)
original_story_bodies = DataLoader(X_test, batch_size=32)
original_story_endings_reference = DataLoader(y_test, batch_size=32)
model_to_use = fine_tuned_model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

perplexity = evaluate.load("perplexity", module_type="metric")
rouge = evaluate.load('rouge')

with torch.no_grad():
  for input_and_attention_mask, original_stories, reference in zip(X_test_inputs_dataloader, original_story_bodies, original_story_endings_reference):
    generated_ids = model_to_use.generate(inputs=input_and_attention_mask[0], attention_mask=input_and_attention_mask[1], max_length=20, early_stopping=True)
    generated_endings = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    feedback, score = judge.absolute_grade(
      instructions=original_stories,
      responses=generated_endings,
      rubric=score_rubric,
      reference_answers=reference
    )
    LLM_as_judge_score_fine_tuned += sum(score)
LLM_as_judge_score_fine_tuned = LLM_as_judge_score_fine_tuned/len(X_test)
print(LLM_as_judge_score_fine_tuned)

Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 972.22 toks/s, output: 279.15 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 6551.04it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 975.03 toks/s, output: 284.97 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12588.42it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.83it/s, est. speed input: 994.08 toks/s, output: 297.56 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12261.81it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.81it/s, est. speed input: 975.92 toks/s, output: 289.65 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10831.87it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.84it/s, est. speed input: 997.59 toks/s, output: 280.93 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11747.72it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.83it/s, est. speed input: 986.43 toks/s, output: 294.96 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11789.00it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s, est. speed input: 1007.73 toks/s, output: 289.55 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12961.63it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.79it/s, est. speed input: 970.55 toks/s, output: 292.85 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11703.67it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.84it/s, est. speed input: 990.06 toks/s, output: 290.56 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 7990.10it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.84it/s, est. speed input: 993.19 toks/s, output: 288.82 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10837.12it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.71it/s, est. speed input: 924.01 toks/s, output: 286.15 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12033.15it/s]
Processed prompts: 100%|██████████| 32/32 [00:16<00:00,  1.90it/s, est. speed input: 1029.61 toks/s, output: 294.81 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12858.57it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.88it/s, est. speed input: 1017.17 toks/s, output: 290.33 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 5507.72it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 907.23 toks/s, output: 271.93 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11098.80it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 927.18 toks/s, output: 293.44 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11101.55it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.82it/s, est. speed input: 982.46 toks/s, output: 275.10 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12984.21it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.71it/s, est. speed input: 922.80 toks/s, output: 282.09 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11829.52it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.75it/s, est. speed input: 939.41 toks/s, output: 261.69 toks/s]


Retrying failed batches: Attempt 1/10


Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.54s/it, est. speed input: 118.36 toks/s, output: 43.64 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12410.33it/s]
Processed prompts: 100%|██████████| 32/32 [00:16<00:00,  1.88it/s, est. speed input: 1021.48 toks/s, output: 280.60 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12236.10it/s]
Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.58it/s, est. speed input: 851.84 toks/s, output: 275.12 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 9514.26it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.83it/s, est. speed input: 987.12 toks/s, output: 282.43 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12991.75it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.68it/s, est. speed input: 906.86 toks/s, output: 270.34 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11960.23it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 969.29 toks/s, output: 279.92 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12800.93it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.79it/s, est. speed input: 961.59 toks/s, output: 281.92 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11913.52it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 928.14 toks/s, output: 276.44 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12760.77it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 885.54 toks/s, output: 274.72 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12176.15it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.81it/s, est. speed input: 974.99 toks/s, output: 294.94 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12414.92it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.76it/s, est. speed input: 952.53 toks/s, output: 266.16 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 13136.71it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 935.33 toks/s, output: 273.33 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10833.62it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.76it/s, est. speed input: 950.33 toks/s, output: 284.31 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12054.76it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s, est. speed input: 1006.89 toks/s, output: 276.67 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 13254.76it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.75it/s, est. speed input: 944.56 toks/s, output: 292.98 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11905.07it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.82it/s, est. speed input: 978.20 toks/s, output: 288.22 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12151.90it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.76it/s, est. speed input: 951.25 toks/s, output: 286.43 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11526.77it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 927.67 toks/s, output: 265.43 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11724.12it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 919.78 toks/s, output: 275.06 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11602.50it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.68it/s, est. speed input: 907.71 toks/s, output: 271.26 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10748.60it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.81it/s, est. speed input: 975.30 toks/s, output: 286.53 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12621.57it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 969.90 toks/s, output: 292.41 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11746.69it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.85it/s, est. speed input: 1001.07 toks/s, output: 284.14 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12913.00it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.88it/s, est. speed input: 1009.38 toks/s, output: 294.56 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12115.70it/s]
Processed prompts: 100%|██████████| 32/32 [00:21<00:00,  1.51it/s, est. speed input: 815.45 toks/s, output: 262.27 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10833.62it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s, est. speed input: 919.64 toks/s, output: 279.04 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11462.78it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 927.07 toks/s, output: 267.45 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12772.91it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.74it/s, est. speed input: 942.27 toks/s, output: 263.32 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 13432.52it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.79it/s, est. speed input: 964.66 toks/s, output: 280.94 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11983.73it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 931.06 toks/s, output: 274.89 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11203.48it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 930.40 toks/s, output: 277.88 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11926.22it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 890.94 toks/s, output: 260.36 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12646.54it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.79it/s, est. speed input: 967.39 toks/s, output: 272.39 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10190.40it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s, est. speed input: 928.55 toks/s, output: 276.41 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12111.33it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s, est. speed input: 912.40 toks/s, output: 267.78 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 7396.55it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.81it/s, est. speed input: 972.81 toks/s, output: 279.21 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12588.42it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 970.10 toks/s, output: 280.43 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12854.87it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.66it/s, est. speed input: 898.19 toks/s, output: 269.93 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11553.56it/s]
Processed prompts: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s, est. speed input: 890.70 toks/s, output: 268.04 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11327.35it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.78it/s, est. speed input: 962.52 toks/s, output: 272.75 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12809.48it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 928.21 toks/s, output: 273.88 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12157.40it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.84it/s, est. speed input: 989.75 toks/s, output: 297.05 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10852.02it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s, est. speed input: 936.54 toks/s, output: 273.24 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 11860.88it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.80it/s, est. speed input: 967.22 toks/s, output: 274.26 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 7311.13it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.85it/s, est. speed input: 990.48 toks/s, output: 289.70 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 12150.80it/s]
Processed prompts: 100%|██████████| 32/32 [00:16<00:00,  1.91it/s, est. speed input: 1027.10 toks/s, output: 290.30 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 13014.42it/s]
Processed prompts: 100%|██████████| 32/32 [00:18<00:00,  1.77it/s, est. speed input: 959.36 toks/s, output: 273.61 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 6621.50it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.87it/s, est. speed input: 1008.82 toks/s, output: 281.78 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 13051.12it/s]
Processed prompts: 100%|██████████| 32/32 [00:16<00:00,  1.89it/s, est. speed input: 1019.85 toks/s, output: 289.19 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10572.49it/s]
Processed prompts: 100%|██████████| 32/32 [00:17<00:00,  1.79it/s, est. speed input: 963.44 toks/s, output: 268.17 toks/s]


Processed 32/32 instances.


Finalizing: 100%|██████████| 32/32 [00:00<00:00, 10247.97it/s]
Processed prompts: 100%|██████████| 15/15 [00:09<00:00,  1.65it/s, est. speed input: 892.20 toks/s, output: 256.45 toks/s]


Processed 15/15 instances.


Finalizing: 100%|██████████| 15/15 [00:00<00:00, 6209.49it/s]

1.0022630834512023
CPU times: user 35min 37s, sys: 5.06 s, total: 35min 43s
Wall time: 35min 31s





In [None]:
# Grade a single test example (index 0) with the LLM judge.
instructions = X_test[0:1]
generated_ids = fine_tuned_model.generate(
    X_test_inputs["input_ids"][0:1],
    attention_mask=X_test_inputs["attention_mask"][0:1],
    max_length=20,
    early_stopping=True,
)
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(responses)
reference_answers = y_test[0:1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

[' respir deterior surviv massac deletion']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.82s/it, est. speed input: 140.26 toks/s, output: 42.99 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 4132.32it/s]

Feedback: ['The response provided is riddled with incoherent and inappropriate language usage that completely disrupts the narrative flow. The model’s language is not only grammatically incorrect but also includes phrases that are irrelevant to the story, such as "surviv massac deletion." This creates a disjointed and nonsensical ending that does not align with the established narrative about Evan\'s pride in his new BMW. The ending is far from satisfactory, and it does not reflect the fluent and coherent language expected in a suitable ending to a story. Therefore, it fails to meet the criteria outlined in the score rubric, and as a result, it does not elevate the narrative but instead leaves the reader confused and unsatisfied. \n[RESULT] 1']
Score: [1]





In [None]:
# Grade a single test example (index 1500) with the LLM judge.
instructions = X_test[1500:1501]
generated_ids = fine_tuned_model.generate(
    X_test_inputs["input_ids"][1500:1501],
    attention_mask=X_test_inputs["attention_mask"][1500:1501],
    max_length=20,
    early_stopping=True,
)
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(responses)
reference_answers = y_test[1500:1501]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

[' Recomm deterior funer ticking respir']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.32s/it, est. speed input: 160.45 toks/s, output: 42.45 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5468.45it/s]

Feedback: ["The provided response fails to meet the requirements outlined in the score rubric. The response does not demonstrate language fluency or semantics coherence. It lacks any semblance of a coherent narrative and doesn't follow the story flow. There is a complete disconnect from the provided story body, and it fails to address the main concern of the individual - the pain of the shot. Furthermore, it does not contribute to a satisfactory resolution or enhance the themes and characters of the narrative. The lack of clarity and coherence renders the reader confused and unsatisfied, which is why the response receives a score of 1. \n[RESULT] 1"]
Score: [1]





In [None]:
# Grade the second-to-last test example with the LLM judge.
instructions = X_test[-2:-1]
generated_ids = fine_tuned_model.generate(
    X_test_inputs["input_ids"][-2:-1],
    attention_mask=X_test_inputs["attention_mask"][-2:-1],
    max_length=20,
    early_stopping=True,
)
responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(responses)
reference_answers = y_test[-2:-1]

feedback, score = judge.absolute_grade(
    instructions=instructions,
    responses=responses,
    rubric=score_rubric,
    reference_answers=reference_answers
)

print("Feedback:", feedback)
print("Score:", score)

[' Recommttes deterior Tactics massac']


Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.45s/it, est. speed input: 157.63 toks/s, output: 42.67 toks/s]


Processed 1/1 instances.


Finalizing: 100%|██████████| 1/1 [00:00<00:00, 5809.29it/s]

Feedback: ["The response is lacking in language fluency, with no context provided. It is incoherent and does not follow the story's narrative. The reader is left confused and with no sense of closure or resolution, which significantly disrupts the story's flow. It fails to answer the question of whether Ramona found a new job and instead presents an unrelated phrase with no explanation or relevance to the story. This makes the ending of the narrative unsatisfactory, leaving the reader without the necessary information to understand the outcome of Ramona's job search. As per the score rubric, the response exhibits major language issues and fails to provide a suitable ending to the story. \n[RESULT] 1"]
Score: [1]
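
All three graded examples score 1, and the generated strings look as if they are drawn from the same small pool of fragments. One rough way to quantify that, sketched below, is to count how many of the outputs each whitespace-separated fragment appears in. The three strings are copied from the outputs above; `fragment_counts` is a hypothetical helper, and splitting on whitespace is only an approximation of the tokenizer's subword units.

```python
from collections import Counter

# The three generations printed in the cells above (copied verbatim).
generations = [
    " respir deterior surviv massac deletion",
    " Recomm deterior funer ticking respir",
    " Recommttes deterior Tactics massac",
]


def fragment_counts(texts):
    """Count in how many outputs each whitespace-separated fragment appears."""
    counts = Counter()
    for text in texts:
        # Use a set so a fragment is counted at most once per output.
        counts.update(set(text.split()))
    return counts


counts = fragment_counts(generations)
shared = sorted(frag for frag, n in counts.items() if n > 1)
print("fragments shared by two or more outputs:", shared)
print("outputs containing 'deterior':", counts["deterior"])
```

If the same fragments recur for unrelated stories, that would suggest the model is largely ignoring its input, in which case changing decoding settings alone may not help.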





In [None]:
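# Default decoding settings, repeated to check whether the output changes between runs.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))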

[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']


In [None]:
# Same check with sampling explicitly disabled (do_sample=False).
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, do_sample=False)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
[' respir deterior surviv massac deletion']
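
Both loops above print the same string ten times. That is what one would expect when sampling is off: decoding then chooses tokens purely by score, so no randomness enters. The sketch below illustrates the difference on a toy example, using the simplest deterministic strategy (greedy argmax) against sampling. The logits are made up for illustration and are not taken from the fine-tuned model; `greedy_pick` and `sampled_pick` are hypothetical helpers, not `transformers` functions.

```python
import numpy as np

# Hypothetical logits for a 5-token toy vocabulary; not taken from the model.
toy_logits = np.array([2.0, 1.5, 0.3, -1.0, 0.9])


def greedy_pick(logits):
    """Deterministic step: always take the highest-scoring token."""
    return int(np.argmax(logits))


def sampled_pick(logits, rng):
    """Sampling step: draw a token from the softmax distribution."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))


rng = np.random.default_rng(0)
greedy_picks = [greedy_pick(toy_logits) for _ in range(10)]
sampled_picks = [sampled_pick(toy_logits, rng) for _ in range(10)]
print("greedy :", greedy_picks)
print("sampled:", sampled_picks)
```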


In [None]:
# Sampling with temperature=0.8.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, temperature=0.8, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' deletion survivolate respir LEDs']
[' surviv Recomm deletion respir deterior']
[' respir deterior deletion flattttes']
[' surviv sket deletion respir massac']
[' massac farewell flatt deletion respir']
[' respir farewell deletion funer Tactics']
[' Tactics surviv deterior deletion respir']
[' deletion survivmAh respir deterior']
[' ticking Recomm surviv massacres respir']
[' surviv respir deleg funer massac']


In [None]:
# Sampling with a lower temperature (0.3).
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, temperature=0.3, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' surviv Recomm deletion funer respir']
[' respir funer surviv deterior massac']
[' deterior respir surviv deletion massac']
[' respir funer massac surviv deterior']
[' respir surviv funer massac deterior']
[' surviv deterior respir Recomm deletion']
[' Recomm deterior deletion massac respir']
[' deterior respir massacres surviv massac']
[' respir deterior funer surviv Recomm']
[' respir deletion massac surviv deterior']


In [None]:
# Sampling with a higher temperature (1.5).
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, temperature=1.5, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' deletion sket ticking flatt massac']
[' delet surviv LEDs massac thresholds']
[' funer respir surviv massacres subsystem']
[' massac sket massacres silhou Recomm']
[' farewell surviv Tactics funer massac']
[' deterior deletionolate respir massacres']
[' Recomm deleg Tactics deterior ticking']
['mAh deterior redund Tactics deletion']
[' sket respir surviv funer massacres']
[' respir Recomm Tactics massac funer']
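
The three temperature runs above (0.8, 0.3, 1.5) differ mainly in how varied the fragments are. A minimal sketch of why: temperature divides the logits before the softmax, so values below 1 concentrate probability on the top token and values above 1 spread it out. The logits here are invented for illustration, and `softmax_with_temperature` is a hypothetical helper rather than the library's internal code.

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    """Softmax over logits / temperature, computed in a numerically stable way."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()  # subtract the max to avoid overflow in exp
    probs = np.exp(scaled)
    return probs / probs.sum()


# Hypothetical logits for a 5-token toy vocabulary; not taken from the model.
toy_logits = [2.0, 1.5, 0.3, -1.0, 0.9]

top_prob = {}
for t in (0.3, 0.8, 1.5):
    probs = softmax_with_temperature(toy_logits, t)
    top_prob[t] = float(probs.max())
    print(f"temperature={t}: probs={probs.round(3)}")
```

With these toy logits, the probability of the most likely token shrinks as the temperature rises, which matches the wider spread of fragments seen at temperature 1.5.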


In [None]:
# Top-k sampling with k=5.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_k=5, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' respir deterior massac surviv deletion']
[' respir deterior massac surviv deletion']
[' respir deterior surviv funer massac']
[' respir deterior Recomm massac surviv']
[' respir deterior massac funer surviv']
[' deterior respir surviv deletion massac']
[' respir massac deterior surviv funer']
[' respir massac deterior surviv deletion']
[' respir surviv deterior massac deletion']
[' deterior respir funer massac surviv']


In [None]:
# Top-k sampling with k=20.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_k=20, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' deterior deleg respir Tactics massac']
[' respir deterior massac massacres surviv']
[' surviv LEDs respir Recomm farewell']
[' respir deletion surviv funer Recomm']
[' surviv deterior Recomm respir massac']
[' deterior respir LEDs Recomm surviv']
[' deterior respir Tactics buds surviv']
[' Recomm deterior deletion respir LEDs']
[' massac deletion LEDs deterior Recomm']
[' deterior deletion respir surviv funer']


In [None]:
# Top-k sampling with k=80.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_k=80, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' respir deletionolateholders funer']
[' deletion LEDs funer deleg deterior']
[' Recomm deterior deleg loudspe funer']
[' flatt ticking massac respir funer']
[' LEDs Tactics deteriorated subsystem surviv']
[' sket Recomm funer farewell deterior']
[' Recomm respir sket Mour surviv']
[' deletion deliveries respirolate surviv']
[' deletion Tactics commem funer Recomm']
[' deletion deterior sket delet massac']


In [None]:
# Nucleus (top-p) sampling with p=0.8.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_p=0.8, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' Tactics Recomm surviv respir funer']
[' massac surviv Tactics deterior deletion']
[' funer deletionttes ticking surviv']
[' flatt ticking massac deterior deletion']
[' respir surviv Recomm massacres massac']
[' Tactics Recomm ticking funer surviv']
[' Recomm deterior respir massacres massac']
[' sket respir deterior surviv LEDs']
[' funer respir deleg surviv massacres']
[' massac surviv Tactics deletion Recomm']


In [None]:
# Nucleus (top-p) sampling with p=0.5.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_p=0.5, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' respir funer Tactics deletion massac']
[' funer respir deterior massacres Recomm']
[' deterior surviv respir LEDs funer']
[' respir deterior surviv flatt massac']
[' surviv respir Tactics deterior deletion']
[' surviv respir funer Tactics deletion']
[' massac deletion Tactics deterior surviv']
[' respir Tactics flatt deterior deletion']
[' deletion surviv respir massac deterior']
[' surviv deterior respir delet funer']


In [None]:
# Nucleus (top-p) sampling with p=0.3.
for _ in range(10):
    output_ids = fine_tuned_model.generate(X_test_inputs["input_ids"][0:1], attention_mask=X_test_inputs["attention_mask"][0:1], max_length=20, early_stopping=True, top_p=0.3, do_sample=True)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

[' respir deterior LEDs surviv Recomm']
[' respir deterior surviv massac deletion']
[' respir massac Recomm deterior deletion']
[' respir deterior Recomm surviv massac']
[' respir deterior surviv funer massac']
[' surviv respir massac deterior funer']
[' deterior respir surviv massac deletion']
[' respir LEDs deletion deterior surviv']
[' respir surviv deterior flatt deletion']
[' respir deterior surviv Recomm massac']
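
The top-k and top-p runs above restrict which tokens can be sampled at each step: top-k keeps the k most probable tokens, and top-p keeps the smallest set of tokens whose cumulative probability reaches p. The sketch below shows both filters on a made-up distribution. It is a simplified illustration under those definitions, not the `transformers` implementation; `top_k_filter`, `top_p_filter`, and the toy probabilities are all hypothetical.

```python
import numpy as np


def top_k_filter(probs, k):
    """Keep the k most probable tokens, zero the rest, then renormalize."""
    probs = np.asarray(probs, dtype=float)
    keep = np.argsort(probs)[::-1][:k]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()


def top_p_filter(probs, p):
    """Keep the smallest high-probability prefix whose mass reaches p."""
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1]           # token indices, most probable first
    cumulative = np.cumsum(probs[order])
    # First position where the cumulative mass is >= p, inclusive.
    cutoff = int(np.searchsorted(cumulative, p)) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()


# Hypothetical next-token distribution over a 5-token toy vocabulary.
toy_probs = [0.5, 0.2, 0.15, 0.1, 0.05]

print("top_k=2  :", top_k_filter(toy_probs, 2).round(3))
print("top_p=0.8:", top_p_filter(toy_probs, 0.8).round(3))
print("top_p=0.3:", top_p_filter(toy_probs, 0.3).round(3))
```

Smaller k or p leaves fewer candidate tokens, which is consistent with the runs above where `top_k=5` and `top_p=0.3` stay close to the deterministic output while `top_k=80` and `top_p=0.8` produce more varied fragments.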
