In [3]:
import datasets
import evaluate
import torch
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

from src import SCRATCH_CACHE_DIR

# Dataset and checkpoint identifiers. MODEL_NAME is the single source of truth
# for the checkpoint so the tokenizer and the model cannot drift apart.
CNNDM = "cnn_dailymail"
MODEL_NAME = "google-t5/t5-small"

# Load configuration "3.0.0" of the dataset; files are cached under SCRATCH_CACHE_DIR.
cnn_dailymail = datasets.load_dataset(CNNDM, "3.0.0", cache_dir=SCRATCH_CACHE_DIR)

tokenizer = T5Tokenizer.from_pretrained(
    MODEL_NAME, cache_dir=SCRATCH_CACHE_DIR, legacy=False
)
model = T5ForConditionalGeneration.from_pretrained(
    MODEL_NAME, cache_dir=SCRATCH_CACHE_DIR
)
# Test:
# Articles are the model inputs; highlights are the reference summaries.
test_source = cnn_dailymail["test"]["article"]
test_reference = cnn_dailymail["test"]["highlights"]
# Prepend the "summarize: " task prefix to every article and tokenize the whole
# test split as one padded, truncated batch of PyTorch tensors.
# NOTE(review): this tokenizes every test article up front even though the
# cells below only consume small slices -- confirm the memory cost is acceptable.
tokenizer_source = tokenizer(
    ["summarize: " + s for s in test_source],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
source_input_ids = tokenizer_source.input_ids


  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Beam search over the first 16 test articles, keeping 4 hypotheses per article.
# The slice is named once so the input ids and the attention mask always cover
# the same rows.
beam_rows = slice(0, 16)
output_beam = model.generate(
    source_input_ids[beam_rows],
    # The batch is padded, so hand over the tokenizer's mask explicitly rather
    # than relying on generate() to infer it from the pad token id.
    attention_mask=tokenizer_source.attention_mask[beam_rows],
    max_length=150,
    num_beams=4,
    num_return_sequences=4,
    early_stopping=True,
    # Return a dict-like output that also carries the scores, not just token ids.
    output_scores=True,
    return_dict_in_generate=True,
)

In [5]:
# Decode every generated beam-search sequence back to text, dropping special
# tokens from the decoded strings.
predictions = tokenizer.batch_decode(
    output_beam.sequences,
    skip_special_tokens=True,
)

In [6]:
# Inspect the raw fields stored on the beam-search output object.
vars(output_beam)

{'sequences': tensor([[    0,     8,  7692,  ...,     0,     0,     0],
         [    0,     8,  7692,  ...,     0,     0,     0],
         [    0,     8,  7692,  ...,     0,     0,     0],
         ...,
         [    0, 20723,    31,  ...,     0,     0,     0],
         [    0, 20723,    31,  ...,     0,     0,     0],
         [    0, 20723,    31,  ...,     0,     0,     0]]),
 'sequences_scores': tensor([-0.3166, -0.3170, -0.3212, -0.3273, -0.3188, -0.3355, -0.3517, -0.3543,
         -0.2504, -0.2543, -0.2629, -0.2650, -0.2873, -0.2912, -0.2966, -0.3059,
         -0.2323, -0.2349, -0.2691, -0.2712, -0.3990, -0.4024, -0.4041, -0.4620,
         -0.2384, -0.2414, -0.2497, -0.2525, -0.2510, -0.2703, -0.2887, -0.2894,
         -0.3178, -0.3243, -0.3324, -0.3325, -0.2032, -0.2107, -0.2286, -0.2342,
         -0.3042, -0.3123, -0.3174, -0.3290, -0.3191, -0.3212, -0.3248, -0.3294,
         -0.2452, -0.2528, -0.2666, -0.2703, -0.2449, -0.2531, -0.2532, -0.2595,
         -0.3635, -0.3878, -0.

In [35]:
rouge = evaluate.load("rouge")
# Reference summaries for the 16 articles that were summarised with beam search.
beam_references = test_reference[0:16]
# Several hypotheses were generated per article. Derive how many from the data
# instead of repeating the generation setting here, and fail loudly if the
# predictions cannot be split evenly across the references.
hyps_per_article, leftover = divmod(len(predictions), len(beam_references))
assert leftover == 0, "predictions must hold the same number of hypotheses for every article"
# Repeat each reference once per hypothesis so predictions and references line
# up one-to-one; use_aggregator=False keeps one score per pair.
rouge.compute(
    predictions=predictions,
    references=[ref for ref in beam_references for _ in range(hyps_per_article)],
    use_aggregator=False,
)

{'rouge1': [0.3661971830985916,
  0.33333333333333337,
  0.35616438356164387,
  0.3661971830985916,
  0.4444444444444444,
  0.4444444444444444,
  0.3950617283950617,
  0.3950617283950617,
  0.07228915662650602,
  0.07228915662650602,
  0.075,
  0.075,
  0.2222222222222222,
  0.25641025641025644,
  0.2368421052631579,
  0.21333333333333335,
  0.4705882352941176,
  0.5,
  0.4705882352941176,
  0.5,
  0.30769230769230765,
  0.30303030303030304,
  0.30303030303030304,
  0.39344262295081966,
  0.6538461538461539,
  0.45714285714285713,
  0.5833333333333334,
  0.3917525773195876,
  0.3404255319148936,
  0.3404255319148936,
  0.41758241758241754,
  0.3218390804597701,
  0.29999999999999993,
  0.32653061224489793,
  0.32,
  0.3103448275862069,
  0.3492063492063492,
  0.34375,
  0.3492063492063492,
  0.34375,
  0.4210526315789474,
  0.36363636363636365,
  0.3928571428571428,
  0.47457627118644075,
  0.0923076923076923,
  0.08695652173913043,
  0.0909090909090909,
  0.09375000000000001,
  0.0563

In [None]:
# Sampling is stochastic: seed torch's global RNG so re-running this cell
# reproduces the same four summaries.
torch.manual_seed(0)
# Only the first test article is sampled; the slice is shared by ids and mask.
sample_rows = slice(0, 1)
output_sampling = model.generate(
    source_input_ids[sample_rows],
    # The row is padded, so pass the tokenizer's mask explicitly.
    attention_mask=tokenizer_source.attention_mask[sample_rows],
    max_length=150,
    do_sample=True,
    num_return_sequences=4,
    # early_stopping is a beam-search-only option, so it is not set for
    # single-beam sampling.
    output_scores=True,
    return_dict_in_generate=True,
)

In [23]:
# Decode the sampled sequences to text. The field is addressed by name (as the
# beam-search decoding cell does) rather than by tuple position, so the lookup
# does not depend on which optional outputs happen to be present.
tokenizer.batch_decode(output_sampling["sequences"], skip_special_tokens=True)

['court given the 123rd member of the ICC to alleged crimes. the deal opens a preliminary examination into alleged crimes in occupied palestinians. the international criminal court will be held in the mid-tomorrday evening (dursday) the ICC opened the first international inquiry into the situation in the occupied territories.',
 'the 123rd international member of the international Criminal court was announced on Wednesday. the ICC already agreed on its founding Rome Statute in January. "the world is also a step closer," the ICC official says.',
 'ipcc has become the 123rd member of the international criminal court. the legal agreement gives the international court jurisdiction on alleged crimes in their territories. a preliminary examination into the situation in other palestinians paved the way for possible war crimes. a human rights activist said the request was a move toward greater justice.',
 "the ICC became 123rd member of the international criminal court on the rights. the Pales

In [None]:
# Batch
# Generate beam-search summaries batch by batch over the tokenized test split.
batch_size = 8
num_rows = source_input_ids.shape[0]
# Ceiling division so a final, smaller batch is counted instead of dropped.
num_batches = (num_rows + batch_size - 1) // batch_size
# Cap on how many batches are actually generated; 1 keeps this a quick smoke
# test. Raise it (up to num_batches) for a full run.
max_batches = 1
outputs = []
for i in tqdm(range(min(num_batches, max_batches))):
    # Same row window for the ids and their attention mask.
    rows = slice(i * batch_size, (i + 1) * batch_size)
    outputs.append(
        model.generate(
            source_input_ids[rows],
            # Padded batch: pass the tokenizer's mask explicitly.
            attention_mask=tokenizer_source.attention_mask[rows],
            max_new_tokens=50,
            num_beams=5,
            num_return_sequences=5,
            early_stopping=True,
            # NOTE(review): without return_dict_in_generate=True this flag
            # appears to have no effect (only token ids are returned) --
            # confirm whether scores are needed downstream before changing
            # the return type.
            output_scores=True,
        )
    )