In [1]:
import transformers
from transformers import (
    pipeline,
    set_seed,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer,
)
import datasets
from datasets import load_dataset, load_from_disk, load_metric

import matplotlib.pyplot as plt
import pandas as pd

import nltk
from nltk.tokenize import sent_tokenize

from tqdm import tqdm
import torch

In [2]:
nltk.download("punkt")

[nltk_data] Downloading package punkt to /home/migue/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
model_ckpt = "google/pegasus-cnn_dailymail"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model_pegasus = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
type(model_pegasus)

transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration

In [13]:
# Original dataset
# dataset_ckpt = "Samsung/samsum"

dataset = load_from_disk("../data/samsum_dataset/")

In [14]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})

In [7]:
print("\nDialogue:")

print(dataset["test"][1]["dialogue"])

print("\nSummary:")

print(dataset["test"][1]["summary"])


Dialogue:
Eric: MACHINE!
Rob: That's so gr8!
Eric: I know! And shows how Americans see Russian ;)
Rob: And it's really funny!
Eric: I know! I especially like the train part!
Rob: Hahaha! No one talks to the machine like that!
Eric: Is this his only stand-up?
Rob: Idk. I'll check.
Eric: Sure.
Rob: Turns out no! There are some of his stand-ups on youtube.
Eric: Gr8! I'll watch them now!
Rob: Me too!
Eric: MACHINE!
Rob: MACHINE!
Eric: TTYL?
Rob: Sure :)

Summary:
Eric and Rob are going to watch a stand-up on youtube.


In [8]:
def convert_examples_to_features(example_batch):
    input_encodings = tokenizer(
        example_batch["dialogue"], max_length=1024, truncation=True
    )

    with tokenizer.as_target_tokenizer():
        target_encodings = tokenizer(
            example_batch["summary"], max_length=128, truncation=True
        )

    return {
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
        "labels": target_encodings["input_ids"],
    }

In [9]:
dataset_pt = dataset.map(convert_examples_to_features, batched=True)

In [10]:
dataset_pt

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 818
    })
})

## **Training**

In [10]:
seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model_pegasus)

In [11]:
args = TrainingArguments(
    output_dir="pegasus-samsum",
    num_train_epochs=1,
    warmup_steps=500,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=1e6,
)



In [14]:
trainer = Trainer(
    model=model_pegasus,
    args=args,
    tokenizer=tokenizer,
    data_collator=seq2seq_data_collator,
    train_dataset=dataset_pt["train"],
    eval_dataset=dataset_pt["validation"],
)

In [15]:
trainer.train()

  0%|          | 0/14732 [00:00<?, ?it/s]

{'loss': 2.7604, 'grad_norm': 30.42899513244629, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.0}
{'loss': 2.4426, 'grad_norm': 23.118608474731445, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0}
{'loss': 3.2571, 'grad_norm': 26.332117080688477, 'learning_rate': 3e-06, 'epoch': 0.0}
{'loss': 2.3593, 'grad_norm': 23.13451385498047, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}
{'loss': 2.9409, 'grad_norm': 31.846843719482422, 'learning_rate': 5e-06, 'epoch': 0.0}
{'loss': 2.8125, 'grad_norm': 24.13215446472168, 'learning_rate': 6e-06, 'epoch': 0.0}
{'loss': 2.7213, 'grad_norm': 28.695457458496094, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.0}
{'loss': 2.3109, 'grad_norm': 62.47607421875, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.01}
{'loss': 3.5241, 'grad_norm': 81.56250762939453, 'learning_rate': 9e-06, 'epoch': 0.01}
{'loss': 2.7298, 'grad_norm': 40.950252532958984, 'learning_rate': 1e-05, 'epoch': 0.01}
{'loss': 2.4596, 'grad_norm': 18.9048881530

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.6349713802337646, 'eval_runtime': 14.9862, 'eval_samples_per_second': 54.584, 'eval_steps_per_second': 54.584, 'epoch': 0.03}
{'loss': 1.9238, 'grad_norm': 12.622907638549805, 'learning_rate': 4.996486790331647e-05, 'epoch': 0.03}
{'loss': 1.8938, 'grad_norm': 9.204163551330566, 'learning_rate': 4.9929735806632945e-05, 'epoch': 0.04}
{'loss': 1.8266, 'grad_norm': 31.66956901550293, 'learning_rate': 4.989460370994941e-05, 'epoch': 0.04}
{'loss': 2.3382, 'grad_norm': 28.213336944580078, 'learning_rate': 4.985947161326589e-05, 'epoch': 0.04}
{'loss': 1.8422, 'grad_norm': 9.633828163146973, 'learning_rate': 4.982433951658235e-05, 'epoch': 0.04}
{'loss': 2.0922, 'grad_norm': 24.05801773071289, 'learning_rate': 4.978920741989882e-05, 'epoch': 0.04}
{'loss': 2.3075, 'grad_norm': 19.188026428222656, 'learning_rate': 4.9754075323215294e-05, 'epoch': 0.04}
{'loss': 1.8832, 'grad_norm': 12.737325668334961, 'learning_rate': 4.971894322653176e-05, 'epoch': 0.04}
{'loss': 1.537, 'gra

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.5441570281982422, 'eval_runtime': 15.2013, 'eval_samples_per_second': 53.811, 'eval_steps_per_second': 53.811, 'epoch': 0.07}
{'loss': 1.8958, 'grad_norm': 8.3252592086792, 'learning_rate': 4.820826306913997e-05, 'epoch': 0.07}
{'loss': 1.7086, 'grad_norm': 18.545446395874023, 'learning_rate': 4.817313097245644e-05, 'epoch': 0.07}
{'loss': 1.6737, 'grad_norm': 10.127016067504883, 'learning_rate': 4.813799887577291e-05, 'epoch': 0.07}
{'loss': 2.2151, 'grad_norm': 10.360091209411621, 'learning_rate': 4.8102866779089375e-05, 'epoch': 0.07}
{'loss': 1.707, 'grad_norm': 10.350958824157715, 'learning_rate': 4.8067734682405846e-05, 'epoch': 0.07}
{'loss': 1.438, 'grad_norm': 8.002367973327637, 'learning_rate': 4.803260258572232e-05, 'epoch': 0.07}
{'loss': 1.5571, 'grad_norm': 6.389379501342773, 'learning_rate': 4.799747048903879e-05, 'epoch': 0.07}
{'loss': 1.5697, 'grad_norm': 12.060235023498535, 'learning_rate': 4.796233839235526e-05, 'epoch': 0.07}
{'loss': 1.5546, 'grad_

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.5153460502624512, 'eval_runtime': 15.1603, 'eval_samples_per_second': 53.957, 'eval_steps_per_second': 53.957, 'epoch': 0.1}
{'loss': 1.488, 'grad_norm': 14.685419082641602, 'learning_rate': 4.645165823496346e-05, 'epoch': 0.1}
{'loss': 1.829, 'grad_norm': 20.140226364135742, 'learning_rate': 4.6416526138279934e-05, 'epoch': 0.1}
{'loss': 1.6868, 'grad_norm': 23.065492630004883, 'learning_rate': 4.6381394041596405e-05, 'epoch': 0.1}
{'loss': 1.4778, 'grad_norm': 6.9647321701049805, 'learning_rate': 4.634626194491287e-05, 'epoch': 0.1}
{'loss': 1.6592, 'grad_norm': 40.97032928466797, 'learning_rate': 4.631112984822935e-05, 'epoch': 0.11}
{'loss': 1.4767, 'grad_norm': 12.060317039489746, 'learning_rate': 4.627599775154581e-05, 'epoch': 0.11}
{'loss': 1.5791, 'grad_norm': 23.1408748626709, 'learning_rate': 4.6240865654862284e-05, 'epoch': 0.11}
{'loss': 1.8688, 'grad_norm': 23.4527530670166, 'learning_rate': 4.6205733558178755e-05, 'epoch': 0.11}
{'loss': 1.433, 'grad_norm

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.489164113998413, 'eval_runtime': 15.3858, 'eval_samples_per_second': 53.166, 'eval_steps_per_second': 53.166, 'epoch': 0.14}
{'loss': 1.5097, 'grad_norm': 10.363081932067871, 'learning_rate': 4.4695053400786964e-05, 'epoch': 0.14}
{'loss': 1.7075, 'grad_norm': 10.962052345275879, 'learning_rate': 4.465992130410343e-05, 'epoch': 0.14}
{'loss': 1.6351, 'grad_norm': 5.936469554901123, 'learning_rate': 4.462478920741991e-05, 'epoch': 0.14}
{'loss': 1.9064, 'grad_norm': 8.854714393615723, 'learning_rate': 4.458965711073637e-05, 'epoch': 0.14}
{'loss': 1.5237, 'grad_norm': 7.240167617797852, 'learning_rate': 4.455452501405284e-05, 'epoch': 0.14}
{'loss': 2.1152, 'grad_norm': 21.8745059967041, 'learning_rate': 4.4519392917369314e-05, 'epoch': 0.14}
{'loss': 1.4136, 'grad_norm': 18.422801971435547, 'learning_rate': 4.448426082068578e-05, 'epoch': 0.14}
{'loss': 2.1547, 'grad_norm': 7.512139320373535, 'learning_rate': 4.444912872400225e-05, 'epoch': 0.14}
{'loss': 1.6908, 'grad_

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4754638671875, 'eval_runtime': 15.2132, 'eval_samples_per_second': 53.769, 'eval_steps_per_second': 53.769, 'epoch': 0.17}
{'loss': 1.5502, 'grad_norm': 8.397332191467285, 'learning_rate': 4.293844856661046e-05, 'epoch': 0.17}
{'loss': 1.5743, 'grad_norm': 12.338878631591797, 'learning_rate': 4.290331646992693e-05, 'epoch': 0.17}
{'loss': 1.9107, 'grad_norm': 26.935522079467773, 'learning_rate': 4.2868184373243395e-05, 'epoch': 0.17}
{'loss': 2.125, 'grad_norm': 15.31768798828125, 'learning_rate': 4.2833052276559866e-05, 'epoch': 0.17}
{'loss': 1.7513, 'grad_norm': 10.279272079467773, 'learning_rate': 4.279792017987634e-05, 'epoch': 0.17}
{'loss': 1.5514, 'grad_norm': 68.79483795166016, 'learning_rate': 4.276278808319281e-05, 'epoch': 0.17}
{'loss': 1.6273, 'grad_norm': 12.721957206726074, 'learning_rate': 4.272765598650928e-05, 'epoch': 0.17}
{'loss': 1.3809, 'grad_norm': 14.986957550048828, 'learning_rate': 4.2692523889825744e-05, 'epoch': 0.18}
{'loss': 1.7273, 'grad

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4643809795379639, 'eval_runtime': 15.2129, 'eval_samples_per_second': 53.77, 'eval_steps_per_second': 53.77, 'epoch': 0.2}
{'loss': 1.9122, 'grad_norm': 10.133795738220215, 'learning_rate': 4.1181843732433954e-05, 'epoch': 0.2}
{'loss': 1.3269, 'grad_norm': 12.699210166931152, 'learning_rate': 4.1146711635750425e-05, 'epoch': 0.2}
{'loss': 1.6998, 'grad_norm': 11.21117115020752, 'learning_rate': 4.111157953906689e-05, 'epoch': 0.21}
{'loss': 1.5243, 'grad_norm': 8.12148666381836, 'learning_rate': 4.107644744238337e-05, 'epoch': 0.21}
{'loss': 1.6248, 'grad_norm': 74.40716552734375, 'learning_rate': 4.104131534569983e-05, 'epoch': 0.21}
{'loss': 1.6041, 'grad_norm': 15.573700904846191, 'learning_rate': 4.10061832490163e-05, 'epoch': 0.21}
{'loss': 1.6096, 'grad_norm': 6.274730682373047, 'learning_rate': 4.0971051152332774e-05, 'epoch': 0.21}
{'loss': 1.82, 'grad_norm': 7.892905235290527, 'learning_rate': 4.093591905564924e-05, 'epoch': 0.21}
{'loss': 1.6378, 'grad_norm':

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4652093648910522, 'eval_runtime': 15.2714, 'eval_samples_per_second': 53.564, 'eval_steps_per_second': 53.564, 'epoch': 0.24}
{'loss': 1.9424, 'grad_norm': 5.099242687225342, 'learning_rate': 3.942523889825745e-05, 'epoch': 0.24}
{'loss': 1.3582, 'grad_norm': 13.760562896728516, 'learning_rate': 3.939010680157392e-05, 'epoch': 0.24}
{'loss': 1.8747, 'grad_norm': 9.760041236877441, 'learning_rate': 3.935497470489039e-05, 'epoch': 0.24}
{'loss': 1.4731, 'grad_norm': 16.74045753479004, 'learning_rate': 3.9319842608206855e-05, 'epoch': 0.24}
{'loss': 1.6868, 'grad_norm': 29.130252838134766, 'learning_rate': 3.928471051152333e-05, 'epoch': 0.24}
{'loss': 1.8924, 'grad_norm': 11.827229499816895, 'learning_rate': 3.92495784148398e-05, 'epoch': 0.24}
{'loss': 1.134, 'grad_norm': 7.222215175628662, 'learning_rate': 3.921444631815627e-05, 'epoch': 0.24}
{'loss': 1.1134, 'grad_norm': 11.198671340942383, 'learning_rate': 3.917931422147274e-05, 'epoch': 0.24}
{'loss': 1.4352, 'grad_

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4400912523269653, 'eval_runtime': 15.3029, 'eval_samples_per_second': 53.454, 'eval_steps_per_second': 53.454, 'epoch': 0.27}
{'loss': 1.7426, 'grad_norm': 13.437082290649414, 'learning_rate': 3.766863406408095e-05, 'epoch': 0.27}
{'loss': 1.6323, 'grad_norm': 12.857672691345215, 'learning_rate': 3.7633501967397414e-05, 'epoch': 0.27}
{'loss': 1.5561, 'grad_norm': 12.326480865478516, 'learning_rate': 3.7598369870713885e-05, 'epoch': 0.27}
{'loss': 1.6477, 'grad_norm': 23.962385177612305, 'learning_rate': 3.7563237774030357e-05, 'epoch': 0.27}
{'loss': 1.5348, 'grad_norm': 22.490114212036133, 'learning_rate': 3.752810567734683e-05, 'epoch': 0.27}
{'loss': 1.4755, 'grad_norm': 7.795043468475342, 'learning_rate': 3.749297358066329e-05, 'epoch': 0.28}
{'loss': 1.9718, 'grad_norm': 8.914705276489258, 'learning_rate': 3.7457841483979764e-05, 'epoch': 0.28}
{'loss': 1.3207, 'grad_norm': 27.37921905517578, 'learning_rate': 3.7422709387296235e-05, 'epoch': 0.28}
{'loss': 1.4846,

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.434429407119751, 'eval_runtime': 15.7721, 'eval_samples_per_second': 51.864, 'eval_steps_per_second': 51.864, 'epoch': 0.31}
{'loss': 1.668, 'grad_norm': 12.269704818725586, 'learning_rate': 3.5912029229904444e-05, 'epoch': 0.31}
{'loss': 1.9484, 'grad_norm': 10.792440414428711, 'learning_rate': 3.587689713322091e-05, 'epoch': 0.31}
{'loss': 1.6985, 'grad_norm': 12.232148170471191, 'learning_rate': 3.584176503653739e-05, 'epoch': 0.31}
{'loss': 1.5197, 'grad_norm': 10.914689064025879, 'learning_rate': 3.580663293985385e-05, 'epoch': 0.31}
{'loss': 1.6674, 'grad_norm': 10.936668395996094, 'learning_rate': 3.577150084317032e-05, 'epoch': 0.31}
{'loss': 1.6202, 'grad_norm': 9.578076362609863, 'learning_rate': 3.5736368746486794e-05, 'epoch': 0.31}
{'loss': 1.4341, 'grad_norm': 10.300699234008789, 'learning_rate': 3.570123664980326e-05, 'epoch': 0.31}
{'loss': 1.7265, 'grad_norm': 16.56783103942871, 'learning_rate': 3.5666104553119736e-05, 'epoch': 0.31}
{'loss': 1.757, 'gr

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4216265678405762, 'eval_runtime': 15.4234, 'eval_samples_per_second': 53.036, 'eval_steps_per_second': 53.036, 'epoch': 0.34}
{'loss': 1.3825, 'grad_norm': 20.852256774902344, 'learning_rate': 3.415542439572794e-05, 'epoch': 0.34}
{'loss': 1.5356, 'grad_norm': 7.149177074432373, 'learning_rate': 3.412029229904441e-05, 'epoch': 0.34}
{'loss': 1.8261, 'grad_norm': 11.773066520690918, 'learning_rate': 3.4085160202360875e-05, 'epoch': 0.34}
{'loss': 1.8618, 'grad_norm': 10.29587173461914, 'learning_rate': 3.405002810567735e-05, 'epoch': 0.34}
{'loss': 1.532, 'grad_norm': 8.864211082458496, 'learning_rate': 3.401489600899382e-05, 'epoch': 0.34}
{'loss': 1.7081, 'grad_norm': 7.761197566986084, 'learning_rate': 3.397976391231029e-05, 'epoch': 0.34}
{'loss': 1.76, 'grad_norm': 15.312347412109375, 'learning_rate': 3.394463181562676e-05, 'epoch': 0.34}
{'loss': 1.3588, 'grad_norm': 11.107444763183594, 'learning_rate': 3.3909499718943224e-05, 'epoch': 0.34}
{'loss': 1.8588, 'grad_

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4167238473892212, 'eval_runtime': 17.129, 'eval_samples_per_second': 47.755, 'eval_steps_per_second': 47.755, 'epoch': 0.37}
{'loss': 1.3667, 'grad_norm': 11.04692268371582, 'learning_rate': 3.2398819561551434e-05, 'epoch': 0.37}
{'loss': 1.5846, 'grad_norm': 11.316242218017578, 'learning_rate': 3.2363687464867905e-05, 'epoch': 0.37}
{'loss': 1.208, 'grad_norm': 11.631349563598633, 'learning_rate': 3.2328555368184376e-05, 'epoch': 0.38}
{'loss': 1.6623, 'grad_norm': 24.077945709228516, 'learning_rate': 3.229342327150085e-05, 'epoch': 0.38}
{'loss': 1.8121, 'grad_norm': 10.018412590026855, 'learning_rate': 3.225829117481731e-05, 'epoch': 0.38}
{'loss': 1.6725, 'grad_norm': 7.872175693511963, 'learning_rate': 3.222315907813378e-05, 'epoch': 0.38}
{'loss': 1.7335, 'grad_norm': 16.946468353271484, 'learning_rate': 3.2188026981450254e-05, 'epoch': 0.38}
{'loss': 1.4057, 'grad_norm': 23.30373191833496, 'learning_rate': 3.2152894884766725e-05, 'epoch': 0.38}
{'loss': 1.4763, '

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4330627918243408, 'eval_runtime': 15.7505, 'eval_samples_per_second': 51.935, 'eval_steps_per_second': 51.935, 'epoch': 0.41}
{'loss': 1.6861, 'grad_norm': 7.492382526397705, 'learning_rate': 3.064221472737493e-05, 'epoch': 0.41}
{'loss': 1.6569, 'grad_norm': 11.177184104919434, 'learning_rate': 3.06070826306914e-05, 'epoch': 0.41}
{'loss': 1.8964, 'grad_norm': 8.1219482421875, 'learning_rate': 3.057195053400787e-05, 'epoch': 0.41}
{'loss': 1.5939, 'grad_norm': 8.72666072845459, 'learning_rate': 3.0536818437324335e-05, 'epoch': 0.41}
{'loss': 1.6982, 'grad_norm': 17.315235137939453, 'learning_rate': 3.0501686340640813e-05, 'epoch': 0.41}
{'loss': 1.4355, 'grad_norm': 6.421512126922607, 'learning_rate': 3.0466554243957278e-05, 'epoch': 0.41}
{'loss': 1.6188, 'grad_norm': 21.543493270874023, 'learning_rate': 3.0431422147273752e-05, 'epoch': 0.41}
{'loss': 1.6937, 'grad_norm': 13.177926063537598, 'learning_rate': 3.039629005059022e-05, 'epoch': 0.41}
{'loss': 1.6236, 'grad

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4076288938522339, 'eval_runtime': 15.3605, 'eval_samples_per_second': 53.253, 'eval_steps_per_second': 53.253, 'epoch': 0.44}
{'loss': 1.978, 'grad_norm': 9.941298484802246, 'learning_rate': 2.888560989319843e-05, 'epoch': 0.44}
{'loss': 1.3771, 'grad_norm': 9.436005592346191, 'learning_rate': 2.8850477796514897e-05, 'epoch': 0.44}
{'loss': 1.753, 'grad_norm': 5.62043571472168, 'learning_rate': 2.881534569983137e-05, 'epoch': 0.44}
{'loss': 1.4757, 'grad_norm': 9.607549667358398, 'learning_rate': 2.8780213603147837e-05, 'epoch': 0.44}
{'loss': 1.7556, 'grad_norm': 6.91405725479126, 'learning_rate': 2.8745081506464304e-05, 'epoch': 0.44}
{'loss': 1.6836, 'grad_norm': 21.628114700317383, 'learning_rate': 2.870994940978078e-05, 'epoch': 0.45}
{'loss': 1.528, 'grad_norm': 12.206637382507324, 'learning_rate': 2.8674817313097247e-05, 'epoch': 0.45}
{'loss': 1.6192, 'grad_norm': 8.293185234069824, 'learning_rate': 2.8639685216413718e-05, 'epoch': 0.45}
{'loss': 1.4817, 'grad_n

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.4054217338562012, 'eval_runtime': 15.3048, 'eval_samples_per_second': 53.447, 'eval_steps_per_second': 53.447, 'epoch': 0.48}
{'loss': 1.472, 'grad_norm': 8.095516204833984, 'learning_rate': 2.712900505902192e-05, 'epoch': 0.48}
{'loss': 1.7549, 'grad_norm': 5.97124719619751, 'learning_rate': 2.7093872962338395e-05, 'epoch': 0.48}
{'loss': 1.5485, 'grad_norm': 13.704919815063477, 'learning_rate': 2.7058740865654863e-05, 'epoch': 0.48}
{'loss': 1.7952, 'grad_norm': 14.41912841796875, 'learning_rate': 2.702360876897133e-05, 'epoch': 0.48}
{'loss': 1.2789, 'grad_norm': 8.733780860900879, 'learning_rate': 2.6988476672287806e-05, 'epoch': 0.48}
{'loss': 1.2626, 'grad_norm': 9.614090919494629, 'learning_rate': 2.6953344575604274e-05, 'epoch': 0.48}
{'loss': 1.4114, 'grad_norm': 6.823118209838867, 'learning_rate': 2.6918212478920745e-05, 'epoch': 0.48}
{'loss': 1.689, 'grad_norm': 30.810956954956055, 'learning_rate': 2.6883080382237213e-05, 'epoch': 0.48}
{'loss': 1.5879, 'gra

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3934987783432007, 'eval_runtime': 15.754, 'eval_samples_per_second': 51.923, 'eval_steps_per_second': 51.923, 'epoch': 0.51}
{'loss': 1.522, 'grad_norm': 22.175180435180664, 'learning_rate': 2.5372400224845422e-05, 'epoch': 0.51}
{'loss': 1.5782, 'grad_norm': 6.322543144226074, 'learning_rate': 2.533726812816189e-05, 'epoch': 0.51}
{'loss': 1.3704, 'grad_norm': 6.8664045333862305, 'learning_rate': 2.5302136031478358e-05, 'epoch': 0.51}
{'loss': 1.4846, 'grad_norm': 8.210290908813477, 'learning_rate': 2.526700393479483e-05, 'epoch': 0.51}
{'loss': 1.2819, 'grad_norm': 8.138832092285156, 'learning_rate': 2.5231871838111297e-05, 'epoch': 0.51}
{'loss': 1.7226, 'grad_norm': 10.529825210571289, 'learning_rate': 2.5196739741427772e-05, 'epoch': 0.51}
{'loss': 1.4228, 'grad_norm': 11.204374313354492, 'learning_rate': 2.516160764474424e-05, 'epoch': 0.51}
{'loss': 1.3732, 'grad_norm': 8.66184139251709, 'learning_rate': 2.5126475548060707e-05, 'epoch': 0.51}
{'loss': 1.5235, 'gr

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3951985836029053, 'eval_runtime': 15.4159, 'eval_samples_per_second': 53.062, 'eval_steps_per_second': 53.062, 'epoch': 0.54}
{'loss': 1.2908, 'grad_norm': 13.321028709411621, 'learning_rate': 2.3615795390668917e-05, 'epoch': 0.54}
{'loss': 1.4951, 'grad_norm': 6.777870178222656, 'learning_rate': 2.3580663293985385e-05, 'epoch': 0.54}
{'loss': 1.8511, 'grad_norm': 9.115062713623047, 'learning_rate': 2.3545531197301856e-05, 'epoch': 0.55}
{'loss': 1.6219, 'grad_norm': 12.388946533203125, 'learning_rate': 2.3510399100618327e-05, 'epoch': 0.55}
{'loss': 1.2529, 'grad_norm': 11.181861877441406, 'learning_rate': 2.3475267003934795e-05, 'epoch': 0.55}
{'loss': 1.7096, 'grad_norm': 13.418147087097168, 'learning_rate': 2.3440134907251266e-05, 'epoch': 0.55}
{'loss': 1.6473, 'grad_norm': 10.218009948730469, 'learning_rate': 2.3405002810567738e-05, 'epoch': 0.55}
{'loss': 1.4403, 'grad_norm': 6.502320289611816, 'learning_rate': 2.3369870713884205e-05, 'epoch': 0.55}
{'loss': 1.63

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3896371126174927, 'eval_runtime': 15.8417, 'eval_samples_per_second': 51.636, 'eval_steps_per_second': 51.636, 'epoch': 0.58}
{'loss': 1.6686, 'grad_norm': 14.586851119995117, 'learning_rate': 2.185919055649241e-05, 'epoch': 0.58}
{'loss': 1.637, 'grad_norm': 5.887360095977783, 'learning_rate': 2.1824058459808883e-05, 'epoch': 0.58}
{'loss': 1.5333, 'grad_norm': 9.759623527526855, 'learning_rate': 2.1788926363125354e-05, 'epoch': 0.58}
{'loss': 1.398, 'grad_norm': 11.893911361694336, 'learning_rate': 2.1753794266441822e-05, 'epoch': 0.58}
{'loss': 1.3564, 'grad_norm': 7.299055099487305, 'learning_rate': 2.1718662169758293e-05, 'epoch': 0.58}
{'loss': 1.3421, 'grad_norm': 6.03339147567749, 'learning_rate': 2.168353007307476e-05, 'epoch': 0.58}
{'loss': 1.3519, 'grad_norm': 6.179615020751953, 'learning_rate': 2.1648397976391232e-05, 'epoch': 0.58}
{'loss': 1.4261, 'grad_norm': 13.28081226348877, 'learning_rate': 2.1613265879707703e-05, 'epoch': 0.58}
{'loss': 1.6164, 'gra

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3832392692565918, 'eval_runtime': 15.1781, 'eval_samples_per_second': 53.893, 'eval_steps_per_second': 53.893, 'epoch': 0.61}
{'loss': 1.5004, 'grad_norm': 9.523614883422852, 'learning_rate': 2.010258572231591e-05, 'epoch': 0.61}
{'loss': 1.2461, 'grad_norm': 7.71245813369751, 'learning_rate': 2.006745362563238e-05, 'epoch': 0.61}
{'loss': 1.6914, 'grad_norm': 7.660722255706787, 'learning_rate': 2.003232152894885e-05, 'epoch': 0.61}
{'loss': 1.3514, 'grad_norm': 22.902362823486328, 'learning_rate': 1.9997189432265316e-05, 'epoch': 0.61}
{'loss': 1.7477, 'grad_norm': 17.01056480407715, 'learning_rate': 1.9962057335581788e-05, 'epoch': 0.61}
{'loss': 1.5219, 'grad_norm': 7.605755805969238, 'learning_rate': 1.992692523889826e-05, 'epoch': 0.61}
{'loss': 1.7382, 'grad_norm': 5.770416736602783, 'learning_rate': 1.989179314221473e-05, 'epoch': 0.62}
{'loss': 1.641, 'grad_norm': 8.951714515686035, 'learning_rate': 1.9856661045531198e-05, 'epoch': 0.62}
{'loss': 1.8329, 'grad_n

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.380301833152771, 'eval_runtime': 15.1678, 'eval_samples_per_second': 53.93, 'eval_steps_per_second': 53.93, 'epoch': 0.64}
{'loss': 1.0, 'grad_norm': 12.246109008789062, 'learning_rate': 1.8345980888139404e-05, 'epoch': 0.65}
{'loss': 1.5037, 'grad_norm': 6.652963161468506, 'learning_rate': 1.8310848791455875e-05, 'epoch': 0.65}
{'loss': 1.6002, 'grad_norm': 7.800164222717285, 'learning_rate': 1.8275716694772347e-05, 'epoch': 0.65}
{'loss': 1.6643, 'grad_norm': 9.64211368560791, 'learning_rate': 1.8240584598088815e-05, 'epoch': 0.65}
{'loss': 1.5285, 'grad_norm': 5.527931213378906, 'learning_rate': 1.8205452501405286e-05, 'epoch': 0.65}
{'loss': 1.5487, 'grad_norm': 6.533152103424072, 'learning_rate': 1.8170320404721754e-05, 'epoch': 0.65}
{'loss': 1.294, 'grad_norm': 6.314151287078857, 'learning_rate': 1.8135188308038225e-05, 'epoch': 0.65}
{'loss': 1.9409, 'grad_norm': 8.290525436401367, 'learning_rate': 1.8100056211354693e-05, 'epoch': 0.65}
{'loss': 1.5698, 'grad_no

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3841122388839722, 'eval_runtime': 15.4768, 'eval_samples_per_second': 52.853, 'eval_steps_per_second': 52.853, 'epoch': 0.68}
{'loss': 1.7244, 'grad_norm': 9.798770904541016, 'learning_rate': 1.6589376053962902e-05, 'epoch': 0.68}
{'loss': 1.5023, 'grad_norm': 22.2455997467041, 'learning_rate': 1.6554243957279373e-05, 'epoch': 0.68}
{'loss': 1.5113, 'grad_norm': 9.485636711120605, 'learning_rate': 1.651911186059584e-05, 'epoch': 0.68}
{'loss': 1.6162, 'grad_norm': 8.35495376586914, 'learning_rate': 1.648397976391231e-05, 'epoch': 0.68}
{'loss': 1.7951, 'grad_norm': 21.871557235717773, 'learning_rate': 1.644884766722878e-05, 'epoch': 0.68}
{'loss': 1.4455, 'grad_norm': 7.533033847808838, 'learning_rate': 1.641371557054525e-05, 'epoch': 0.68}
{'loss': 1.4413, 'grad_norm': 6.312171936035156, 'learning_rate': 1.6378583473861723e-05, 'epoch': 0.68}
{'loss': 1.3749, 'grad_norm': 6.247389316558838, 'learning_rate': 1.634345137717819e-05, 'epoch': 0.68}
{'loss': 1.4959, 'grad_n

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3749583959579468, 'eval_runtime': 15.3625, 'eval_samples_per_second': 53.246, 'eval_steps_per_second': 53.246, 'epoch': 0.71}
{'loss': 1.2952, 'grad_norm': 7.142134666442871, 'learning_rate': 1.4832771219786399e-05, 'epoch': 0.71}
{'loss': 1.4898, 'grad_norm': 65.28485870361328, 'learning_rate': 1.4797639123102868e-05, 'epoch': 0.71}
{'loss': 1.5635, 'grad_norm': 9.714823722839355, 'learning_rate': 1.4762507026419336e-05, 'epoch': 0.71}
{'loss': 1.6233, 'grad_norm': 7.347331523895264, 'learning_rate': 1.4727374929735807e-05, 'epoch': 0.72}
{'loss': 1.1138, 'grad_norm': 10.278096199035645, 'learning_rate': 1.4692242833052278e-05, 'epoch': 0.72}
{'loss': 1.3253, 'grad_norm': 7.031587600708008, 'learning_rate': 1.4657110736368748e-05, 'epoch': 0.72}
{'loss': 1.4899, 'grad_norm': 11.173827171325684, 'learning_rate': 1.4621978639685216e-05, 'epoch': 0.72}
{'loss': 1.4079, 'grad_norm': 9.022311210632324, 'learning_rate': 1.4586846543001687e-05, 'epoch': 0.72}
{'loss': 1.392, 

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.369065761566162, 'eval_runtime': 15.3937, 'eval_samples_per_second': 53.139, 'eval_steps_per_second': 53.139, 'epoch': 0.75}
{'loss': 1.463, 'grad_norm': 6.532029151916504, 'learning_rate': 1.3076166385609895e-05, 'epoch': 0.75}
{'loss': 1.3793, 'grad_norm': 10.452590942382812, 'learning_rate': 1.3041034288926363e-05, 'epoch': 0.75}
{'loss': 1.4757, 'grad_norm': 9.354741096496582, 'learning_rate': 1.3005902192242834e-05, 'epoch': 0.75}
{'loss': 1.5164, 'grad_norm': 10.21220874786377, 'learning_rate': 1.2970770095559303e-05, 'epoch': 0.75}
{'loss': 1.555, 'grad_norm': 10.814753532409668, 'learning_rate': 1.2935637998875775e-05, 'epoch': 0.75}
{'loss': 1.5216, 'grad_norm': 8.117496490478516, 'learning_rate': 1.2900505902192244e-05, 'epoch': 0.75}
{'loss': 1.6403, 'grad_norm': 15.324651718139648, 'learning_rate': 1.2865373805508712e-05, 'epoch': 0.75}
{'loss': 1.2699, 'grad_norm': 6.610034465789795, 'learning_rate': 1.2830241708825183e-05, 'epoch': 0.75}
{'loss': 1.4663, '

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3663820028305054, 'eval_runtime': 15.5509, 'eval_samples_per_second': 52.602, 'eval_steps_per_second': 52.602, 'epoch': 0.78}
{'loss': 1.3347, 'grad_norm': 6.892908096313477, 'learning_rate': 1.131956155143339e-05, 'epoch': 0.78}
{'loss': 1.1538, 'grad_norm': 16.446290969848633, 'learning_rate': 1.128442945474986e-05, 'epoch': 0.78}
{'loss': 1.5946, 'grad_norm': 7.835308074951172, 'learning_rate': 1.124929735806633e-05, 'epoch': 0.78}
{'loss': 1.4682, 'grad_norm': 8.650940895080566, 'learning_rate': 1.12141652613828e-05, 'epoch': 0.78}
{'loss': 1.4421, 'grad_norm': 11.416398048400879, 'learning_rate': 1.117903316469927e-05, 'epoch': 0.78}
{'loss': 1.3983, 'grad_norm': 7.179330348968506, 'learning_rate': 1.114390106801574e-05, 'epoch': 0.78}
{'loss': 1.5032, 'grad_norm': 6.4817376136779785, 'learning_rate': 1.1108768971332208e-05, 'epoch': 0.79}
{'loss': 1.6486, 'grad_norm': 13.420875549316406, 'learning_rate': 1.107363687464868e-05, 'epoch': 0.79}
{'loss': 1.5851, 'grad

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3639354705810547, 'eval_runtime': 15.3098, 'eval_samples_per_second': 53.43, 'eval_steps_per_second': 53.43, 'epoch': 0.81}
{'loss': 1.9913, 'grad_norm': 23.69110870361328, 'learning_rate': 9.562956717256886e-06, 'epoch': 0.82}
{'loss': 1.6271, 'grad_norm': 8.615445137023926, 'learning_rate': 9.527824620573355e-06, 'epoch': 0.82}
{'loss': 1.6363, 'grad_norm': 9.655973434448242, 'learning_rate': 9.492692523889827e-06, 'epoch': 0.82}
{'loss': 1.3573, 'grad_norm': 4.653561592102051, 'learning_rate': 9.457560427206296e-06, 'epoch': 0.82}
{'loss': 1.6547, 'grad_norm': 8.896451950073242, 'learning_rate': 9.422428330522766e-06, 'epoch': 0.82}
{'loss': 1.5475, 'grad_norm': 6.72760534286499, 'learning_rate': 9.387296233839237e-06, 'epoch': 0.82}
{'loss': 1.3901, 'grad_norm': 15.256664276123047, 'learning_rate': 9.352164137155705e-06, 'epoch': 0.82}
{'loss': 1.3045, 'grad_norm': 9.250490188598633, 'learning_rate': 9.317032040472176e-06, 'epoch': 0.82}
{'loss': 1.4314, 'grad_norm'

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3621128797531128, 'eval_runtime': 15.6165, 'eval_samples_per_second': 52.381, 'eval_steps_per_second': 52.381, 'epoch': 0.85}
{'loss': 1.3338, 'grad_norm': 7.9323506355285645, 'learning_rate': 7.806351883080384e-06, 'epoch': 0.85}
{'loss': 1.4512, 'grad_norm': 5.29426908493042, 'learning_rate': 7.771219786396852e-06, 'epoch': 0.85}
{'loss': 1.8202, 'grad_norm': 8.626018524169922, 'learning_rate': 7.736087689713323e-06, 'epoch': 0.85}
{'loss': 1.3858, 'grad_norm': 5.77044677734375, 'learning_rate': 7.700955593029792e-06, 'epoch': 0.85}
{'loss': 0.998, 'grad_norm': 5.917463302612305, 'learning_rate': 7.665823496346262e-06, 'epoch': 0.85}
{'loss': 1.516, 'grad_norm': 7.943124294281006, 'learning_rate': 7.630691399662732e-06, 'epoch': 0.85}
{'loss': 1.5462, 'grad_norm': 6.454356670379639, 'learning_rate': 7.595559302979202e-06, 'epoch': 0.85}
{'loss': 1.7468, 'grad_norm': 7.689917087554932, 'learning_rate': 7.560427206295672e-06, 'epoch': 0.85}
{'loss': 1.2469, 'grad_norm':

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3588777780532837, 'eval_runtime': 15.4386, 'eval_samples_per_second': 52.984, 'eval_steps_per_second': 52.984, 'epoch': 0.88}
{'loss': 1.6254, 'grad_norm': 13.329899787902832, 'learning_rate': 6.049747048903879e-06, 'epoch': 0.88}
{'loss': 1.3164, 'grad_norm': 8.940234184265137, 'learning_rate': 6.014614952220349e-06, 'epoch': 0.88}
{'loss': 1.5619, 'grad_norm': 11.156977653503418, 'learning_rate': 5.979482855536818e-06, 'epoch': 0.88}
{'loss': 1.6532, 'grad_norm': 12.36038589477539, 'learning_rate': 5.944350758853289e-06, 'epoch': 0.89}
{'loss': 1.5165, 'grad_norm': 9.042210578918457, 'learning_rate': 5.909218662169758e-06, 'epoch': 0.89}
{'loss': 1.3742, 'grad_norm': 5.04778528213501, 'learning_rate': 5.874086565486228e-06, 'epoch': 0.89}
{'loss': 1.5173, 'grad_norm': 11.26514720916748, 'learning_rate': 5.838954468802698e-06, 'epoch': 0.89}
{'loss': 1.2491, 'grad_norm': 5.373864650726318, 'learning_rate': 5.803822372119169e-06, 'epoch': 0.89}
{'loss': 1.6585, 'grad_no

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3610390424728394, 'eval_runtime': 15.339, 'eval_samples_per_second': 53.328, 'eval_steps_per_second': 53.328, 'epoch': 0.92}
{'loss': 1.5275, 'grad_norm': 10.283480644226074, 'learning_rate': 4.293142214727376e-06, 'epoch': 0.92}
{'loss': 1.4462, 'grad_norm': 7.083884239196777, 'learning_rate': 4.258010118043845e-06, 'epoch': 0.92}
{'loss': 1.1602, 'grad_norm': 10.217656135559082, 'learning_rate': 4.222878021360315e-06, 'epoch': 0.92}
{'loss': 1.2193, 'grad_norm': 7.029883861541748, 'learning_rate': 4.187745924676785e-06, 'epoch': 0.92}
{'loss': 1.513, 'grad_norm': 8.677861213684082, 'learning_rate': 4.152613827993255e-06, 'epoch': 0.92}
{'loss': 1.2762, 'grad_norm': 12.497991561889648, 'learning_rate': 4.117481731309725e-06, 'epoch': 0.92}
{'loss': 1.2764, 'grad_norm': 10.115565299987793, 'learning_rate': 4.082349634626195e-06, 'epoch': 0.92}
{'loss': 1.5273, 'grad_norm': 42.6080207824707, 'learning_rate': 4.047217537942664e-06, 'epoch': 0.92}
{'loss': 1.4426, 'grad_no

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3591893911361694, 'eval_runtime': 15.3806, 'eval_samples_per_second': 53.184, 'eval_steps_per_second': 53.184, 'epoch': 0.95}
{'loss': 1.6792, 'grad_norm': 6.2728495597839355, 'learning_rate': 2.536537380550871e-06, 'epoch': 0.95}
{'loss': 1.2141, 'grad_norm': 10.303190231323242, 'learning_rate': 2.501405283867341e-06, 'epoch': 0.95}
{'loss': 1.6912, 'grad_norm': 8.646729469299316, 'learning_rate': 2.4662731871838115e-06, 'epoch': 0.95}
{'loss': 1.6004, 'grad_norm': 9.933825492858887, 'learning_rate': 2.431141090500281e-06, 'epoch': 0.95}
{'loss': 1.4376, 'grad_norm': 7.370118618011475, 'learning_rate': 2.396008993816751e-06, 'epoch': 0.95}
{'loss': 1.4138, 'grad_norm': 5.42567253112793, 'learning_rate': 2.360876897133221e-06, 'epoch': 0.95}
{'loss': 1.1861, 'grad_norm': 10.14351749420166, 'learning_rate': 2.325744800449691e-06, 'epoch': 0.96}
{'loss': 1.2058, 'grad_norm': 15.259187698364258, 'learning_rate': 2.290612703766161e-06, 'epoch': 0.96}
{'loss': 1.7817, 'grad_

  0%|          | 0/818 [00:00<?, ?it/s]

{'eval_loss': 1.3590444326400757, 'eval_runtime': 15.3177, 'eval_samples_per_second': 53.402, 'eval_steps_per_second': 53.402, 'epoch': 0.98}
{'loss': 1.0758, 'grad_norm': 6.68756628036499, 'learning_rate': 7.799325463743676e-07, 'epoch': 0.98}
{'loss': 1.4114, 'grad_norm': 9.218562126159668, 'learning_rate': 7.448004496908376e-07, 'epoch': 0.99}
{'loss': 1.6049, 'grad_norm': 8.420698165893555, 'learning_rate': 7.096683530073076e-07, 'epoch': 0.99}
{'loss': 1.7119, 'grad_norm': 7.854781150817871, 'learning_rate': 6.745362563237774e-07, 'epoch': 0.99}
{'loss': 1.2695, 'grad_norm': 10.833246231079102, 'learning_rate': 6.394041596402473e-07, 'epoch': 0.99}
{'loss': 1.4639, 'grad_norm': 8.286672592163086, 'learning_rate': 6.042720629567172e-07, 'epoch': 0.99}
{'loss': 1.4571, 'grad_norm': 27.17167854309082, 'learning_rate': 5.691399662731872e-07, 'epoch': 0.99}
{'loss': 1.3025, 'grad_norm': 9.149006843566895, 'learning_rate': 5.340078695896572e-07, 'epoch': 0.99}
{'loss': 1.3942, 'grad_nor

Non-default generation parameters: {'max_length': 128, 'min_length': 32, 'num_beams': 8, 'length_penalty': 0.8, 'forced_eos_token_id': 1}


{'train_runtime': 2516.3417, 'train_samples_per_second': 5.855, 'train_steps_per_second': 5.855, 'train_loss': 1.5790535163659296, 'epoch': 1.0}


TrainOutput(global_step=14732, training_loss=1.5790535163659296, metrics={'train_runtime': 2516.3417, 'train_samples_per_second': 5.855, 'train_steps_per_second': 5.855, 'total_flos': 5531718781673472.0, 'train_loss': 1.5790535163659296, 'epoch': 1.0})

## **Evaluation**

In [25]:
# load model and tokenizer
model_path = "pegasus-samsum/checkpoint-14732"
trained_model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [23]:
def generate_batch_sized_chuncks(list_of_elements: list, batch_size: int):
    """Split dataset into smaller batches that can be processed simultaneously
    Yield successive batch-sized chunks from list of elements.

    Args:
        list_of_elements (list): List with elements to be split on batches
        batch_size (int): Number of elements per batch

    Yield:
        list: Batches
    """
    for i in range(0, len(list_of_elements), batch_size):
        yield list_of_elements[i : i + batch_size]


def calculate_test_metric(
    dataset: datasets.arrow_dataset.Dataset,
    metric,
    model: transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration,
    tokenizer: transformers.models.pegasus.tokenization_pegasus_fast.PegasusTokenizerFast,
    batch_size: int = 16,
    device: str = "cuda",
    column_text: str = "article",
    column_summary: str = "highlights",
) -> float:
    """_summary_

    Args:
        dataset (datasets.arrow_dataset.Dataset): _description_
        metric (str): _description_
        model (transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration): _description_
        tokenizer (transformers.models.pegasus.tokenization_pegasus_fast.PegasusTokenizerFast): _description_
        batch_size (int): _description_
        column_text (str, optional): _description_. Defaults to "article".
        column_summary (str, optional): _description_. Defaults to "highlights".

    Returns:
        float: _description_
    """
    article_batches = list(
        generate_batch_sized_chuncks(dataset[column_text], batch_size)
    )
    target_batches = list(
        generate_batch_sized_chuncks(dataset[column_summary], batch_size)
    )

    for article_batch, target_batch in tqdm(
        zip(article_batches, target_batches), total=len(article_batches)
    ):

        inputs = tokenizer(
            article_batch,
            max_length=1024,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        summaries = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            length_penalty=0.8,
            num_beams=8,
            max_length=128,  # avoid long sequences
        )

        decoded_summaries = [
            tokenizer.decode(
                s,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for s in summaries
        ]

        decoded_summaries = [d.replace("", " ") for d in decoded_summaries]

        metric.add_batch(predictions=decoded_summaries, references=target_batch)

    return metric.compute()

In [11]:
rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
rouge_metric = load_metric("rouge")

In [26]:
score = calculate_test_metric(
    dataset["test"],
    rouge_metric,
    trained_model,
    tokenizer,
    batch_size=2,
    device="cuda",
    column_text="dialogue",
    column_summary="summary",
)

rouge_dict = dict((rn, score[rn].mid.fmeasure) for rn in rouge_names)

pd.DataFrame(rouge_dict, index=["pegasus"])

100%|██████████| 410/410 [04:25<00:00,  1.55it/s]


Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
pegasus,0.01829,0.000369,0.018192,0.018176


In [27]:
trained_model.save_pretrained("pegasus-samsum-model")
tokenizer.save_pretrained("pegasus-samsum-model")

Non-default generation parameters: {'max_length': 128, 'min_length': 32, 'num_beams': 8, 'length_penalty': 0.8, 'forced_eos_token_id': 1}


('pegasus-samsum-model/tokenizer_config.json',
 'pegasus-samsum-model/special_tokens_map.json',
 'pegasus-samsum-model/spiece.model',
 'pegasus-samsum-model/added_tokens.json',
 'pegasus-samsum-model/tokenizer.json')

## **Prediction**

In [30]:
gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}


sample_text = dataset["test"][5]["dialogue"]

reference = dataset["test"][5]["summary"]

pipe = pipeline("summarization", model="pegasus-samsum-model", tokenizer=tokenizer)

##
print("Dialogue:")
print(sample_text)


print("\nReference Summary:")
print(reference)


print("\nModel Summary:")
print(pipe(sample_text, **gen_kwargs)[5]["summary_text"])

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


Dialogue:
Benjamin: Hey guys, what are we doing with the keys today?
Hilary: I've got them. Whoever wants them can meet me at lunchtime or after
Elliot: I'm ok. We're meeting for the drinks in the evening anyway and I guess we'll be going back to the apartment together?
Hilary: Yeah, I guess so
Daniel: I'm with Hilary atm and won't let go of her for the rest of the day, so any option you guys choose is good for me
Benjamin: Hmm I might actually pass by at lunchtime, take the keys and go take a nap. I'm sooo tired after yesterday
Hilary: Sounds good. We'll be having lunch with some French people (the ones who work on the history of food in colonial Mexico - I already see you yawning your head off)
Benjamin: YAAAAWN 🙊 Where and where are you meeting?
Hilary: So I'm meeting them at the entrance to the conference hall at 2 pm and then we'll head to this place called La Cantina. Italian cuisine, which is quite funny, but that's what they've chosen
Benjamin: Interesting 😱 To be honest, Hilar