In [1]:
import os
from dotenv import load_dotenv

# Read key/value pairs from a .env file into the process environment.
load_dotenv()

# Look up the three credentials; each value is None when its variable is unset.
api_key = os.environ.get("API_KEY")
huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
serpapi_api_key = os.environ.get("SERPAPI_API_KEY")

In [2]:
import torch
from datasets import load_dataset

full_dataset = load_dataset("cnn_dailymail", version="3.0.0")

# Work on a small, reproducible subset so the rest of the lab runs quickly.
sample_size = 100

# Keep only training articles with "CNN" in their first 25 characters,
# shuffle them with a fixed seed, then take the first `sample_size` rows.
cnn_articles = full_dataset["train"].filter(
    lambda row: "CNN" in row["article"][:25]
)
shuffled_articles = cnn_articles.shuffle(seed=42)
sample = shuffled_articles.select(range(sample_size))
sample

  table = cls._concat_blocks(blocks, axis=0)


Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 100
})

In [3]:
# Preview the sample as a pandas DataFrame (plain-text print of its columns).
print(sample.to_pandas())

                                              article  \
0   (CNN) -- A magnitude 6.7 earthquake rattled Pa...   
1   (CNN) -- Pakistan took big steps towards level...   
2   (CNN) -- Federal prosecutors are pushing to fo...   
3   Centennial, Colorado (CNN) -- McKayla Hicks sa...   
4   (CNN) -- Double-amputee sprinter Oscar Pistori...   
..                                                ...   
95  (CNN) -- Samuel Eto'o netted a superb hat-tric...   
96  Washington (CNN) -- President Barack Obama's r...   
97  (CNN) -- Violence swept across Syria on Friday...   
98  (CNN) -- New HIV infections have fallen worldw...   
99  CHENGDU, China (CNN) -- Rainy weather and poor...   

                                           highlights  \
0   Papua New Guinea is on the so-called Ring of F...   
1   Australia collapse to 88 all out on opening da...   
2   Jared Loughner is refusing the government's re...   
3   Shooting victim McKayla Hicks went to hearing ...   
4   Oscar Pistorius to become 

In [4]:
# Pull out the first article and its reference highlights as a worked example.
example_article = sample["article"][0]
example_summary = sample["highlights"][0]

# One print call; the "\n" separator reproduces the line break between the two sections.
print(
    f"Article:\n{example_article}\n",
    f"Summary:\n{example_summary}",
    sep="\n",
)

Article:

Summary:
Papua New Guinea is on the so-called Ring of Fire .
It's on an arc of fault lines that is prone to frequent earthquakes .


### Summarization

In [5]:
import pandas as pd
import torch
import gc
from transformers import AutoTokenizer, T5ForConditionalGeneration

In [11]:
def batch_generator(data: list, batch_size: int):
    """
    Yield consecutive batches of at most `batch_size` items from a list.

    :param data: List (or other sliceable sequence) to split; it is only sliced, never modified.
    :param batch_size: Maximum number of items per batch. Must be at least 1.
    :return: Generator of slices of `data`, in order. The final batch may be shorter
             than `batch_size`; an empty `data` yields no batches.
    :raises ValueError: If `batch_size` is less than 1. Because this is a generator,
                        the error surfaces when iteration starts.
    """
    # A non-positive batch size would never advance through `data`
    # (the previous implementation looped forever in that case).
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    for start in range(0, len(data), batch_size):
        # Slicing clamps at the end of `data`, so the last batch needs no special case.
        yield data[start : start + batch_size]


def summarize_with_t5(
    model_checkpoint: str, articles: list, batch_size: int = 8
) -> list:
    """
    Compute summaries using a T5 model.
    This is similar to a `pipeline` for a T5 model but does tokenization manually.

    The model and tokenizer are loaded inside the function and released again before
    it returns, including when inference raises.

    :param model_checkpoint: Name for a model checkpoint in Hugging Face, such as "t5-small" or "t5-base"
    :param articles: List of strings, where each string represents one article.
    :param batch_size: Number of articles tokenized and passed to `generate` per call.
    :return: List of strings, where each string represents one article's generated summary
    """
    # Prefer the first GPU when one is visible; otherwise run on CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = T5ForConditionalGeneration.from_pretrained(model_checkpoint).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)

    def perform_inference(batch: list) -> list:
        # Pad to the longest article in the batch and truncate inputs at 1024 tokens.
        inputs = tokenizer(
            batch, max_length=1024, return_tensors="pt", padding=True, truncation=True
        )

        # Beam search with 2 beams; generated sequences are capped at 40 tokens.
        summary_ids = model.generate(
            inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            num_beams=2,
            min_length=0,
            max_length=40,
        )
        return tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

    res = []
    try:
        # Each article is prefixed with "summarize: " before tokenization.
        summary_articles = ["summarize: " + article for article in articles]
        for batch in batch_generator(summary_articles, batch_size=batch_size):
            res += perform_inference(batch)

            # Release per-batch memory before processing the next batch.
            torch.cuda.empty_cache()
            gc.collect()
    finally:
        # clean up -- runs on success and on failure so the model is not left
        # referenced by this frame after an exception
        del tokenizer
        del model
        torch.cuda.empty_cache()
        gc.collect()
    return res

In [12]:
# Summarize every sampled article with the "t5-small" checkpoint (default batch size).
t5_small_summaries = summarize_with_t5("t5-small", sample["article"])

In [13]:
# The dataset's `highlights` column serves as the reference summaries.
reference_summaries = sample["highlights"]

In [14]:
# Side-by-side view: one row per article, generated summary next to its reference.
summary_comparison = pd.DataFrame(
    {"generated": t5_small_summaries, "reference": reference_summaries}
)
print(summary_comparison)

                                            generated  \
0   a magnitude 6.7 earthquake rattles Papua new G...   
1   the two-Test cricket series is being played in...   
2   federal prosecutors want jared Lee Loughner to...   
3   new: "he tried to kill people," a 17-year-old ...   
4   double-amputee sprinter Oscar Pistorius will c...   
..                                                ...   
95  holders Inter Milan thrash Werder Bremen 4-0 i...   
96  president's re-election campaign raises $71 mi...   
97  at least 75 people were killed in protests, an...   
98  new infections have fallen by 17 percent in th...   
99  nearly 10,000 people died in quake in central ...   

                                            reference  
0   Papua New Guinea is on the so-called Ring of F...  
1   Australia collapse to 88 all out on opening da...  
2   Jared Loughner is refusing the government's re...  
3   Shooting victim McKayla Hicks went to hearing ...  
4   Oscar Pistorius to become first

In [15]:
# Exact-match accuracy: the fraction of generated summaries that are
# string-identical to their reference summary.
exact_matches = sum(
    (
        1.0
        for idx, reference_summary in enumerate(reference_summaries)
        if t5_small_summaries[idx] == reference_summary
    ),
    0.0,
)
accuracy = exact_matches / len(reference_summaries)

print(f"Achieved accuracy {accuracy}!")

Achieved accuracy 0.0!


### ROUGE
Now that we can generate summaries---and we know 0/1 accuracy is useless here---let's look at how we can compute a meaningful metric designed to evaluate summarization: ROUGE.

Recall-Oriented Understudy for Gisting Evaluation (ROUGE) is a set of evaluation metrics for comparing summaries, introduced by Lin et al., 2004. See Wikipedia for more info. Here, we use the Hugging Face Evaluator wrapper to call into the rouge_score package. This package provides 4 scores:

rouge1: ROUGE computed over unigrams (single words or tokens)
rouge2: ROUGE computed over bigrams (pairs of consecutive words or tokens)
rougeL: ROUGE based on the longest common subsequence shared by the summaries being compared
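rougeLsum: like rougeL, but computed at "summary level," i.e., newlines are treated as sentence boundaries and the longest common subsequence is computed sentence by sentence (rougeL itself ignores newlines)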

In [18]:
import evaluate
import nltk
from nltk.tokenize import sent_tokenize

# Download NLTK's "punkt" data (presumably what `sent_tokenize` relies on -- confirm for your NLTK version).
nltk.download("punkt")

# Load the ROUGE metric through Hugging Face `evaluate`; the cells below call `rouge_score.compute`.
rouge_score = evaluate.load("rouge")

[nltk_data] Downloading package punkt to C:\Users\HITESH
[nltk_data]     PATIL\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [19]:
def compute_rouge_score(generated: list, reference: list) -> dict:
    """
    Score a batch of generated summaries against reference summaries with ROUGE.

    Thin wrapper around the Hugging Face `rouge_score` metric, which expects
    sentences to be separated by newlines. Each summary is stripped, split into
    sentences with `sent_tokenize`, and re-joined with one sentence per line
    before scoring. Stemming is enabled.

    :param generated: Summaries (list of strings) produced by the model
    :param reference: Ground-truth summaries (list of strings) for comparison
    :return: The result of `rouge_score.compute` for the whole batch
    """

    def one_sentence_per_line(summaries: list) -> list:
        # Put each tokenized sentence on its own line.
        return ["\n".join(sent_tokenize(text.strip())) for text in summaries]

    return rouge_score.compute(
        predictions=one_sentence_per_line(generated),
        references=one_sentence_per_line(reference),
        use_stemmer=True,
    )

In [20]:
# ROUGE scores for our batch of articles
compute_rouge_score(t5_small_summaries, reference_summaries)

{'rouge1': 0.30974757717934137,
 'rouge2': 0.10631458746437521,
 'rougeL': 0.22119603468138754,
 'rougeLsum': 0.2823100338265827}

In [21]:
# Sanity check: What if our predictions match the references exactly?
compute_rouge_score(reference_summaries, reference_summaries)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}

In [22]:
# And what if we fail to predict anything?
# One empty string per reference summary stands in for "no prediction".
empty_predictions = [""] * len(reference_summaries)
compute_rouge_score(generated=empty_predictions, reference=reference_summaries)

{'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0, 'rougeLsum': 0.0}

In [23]:
# Without stemming, "beat"/"beating" and "record"/"records" are treated as different tokens.
rouge_score.compute(
    predictions=["Large language models beat world record"],
    references=["Large language models beating world records"],
    use_stemmer=False,
)

{'rouge1': 0.6666666666666666,
 'rouge2': 0.4000000000000001,
 'rougeL': 0.6666666666666666,
 'rougeLsum': 0.6666666666666666}

In [24]:
    rouge_score.compute(
    predictions=["Large language models beat world record"],
    references=["Large language models beating world records"],
    use_stemmer=True,
)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}

In [25]:
# What if we predict exactly 1 word correctly?
rouge_score.compute(
    predictions=["Large language models beat world record"],
    references=["Large"],
    use_stemmer=True,
)

{'rouge1': 0.2857142857142857,
 'rouge2': 0.0,
 'rougeL': 0.2857142857142857,
 'rougeLsum': 0.2857142857142857}

In [26]:
# The reported ROUGE scores are symmetric: swapping predictions and references
# gives the same values as the previous cell.
rouge_score.compute(
    predictions=["Large"],
    references=["Large language models beat world record"],
    use_stemmer=True,
)

{'rouge1': 0.2857142857142857,
 'rouge2': 0.0,
 'rougeL': 0.2857142857142857,
 'rougeLsum': 0.2857142857142857}

In [27]:
# What about 2 words?  Note how 'rouge1' and 'rouge2' compare with the case when we predict exactly 1 word correctly.
rouge_score.compute(
    predictions=["Large language"],
    references=["Large language models beat world record"],
    use_stemmer=True,
)

{'rouge1': 0.5, 'rouge2': 0.33333333333333337, 'rougeL': 0.5, 'rougeLsum': 0.5}

In [28]:
# Note how rouge1 differs from the rougeN (N>1) scores when we predict word subsequences correctly.
rouge_score.compute(
    predictions=["Models beat large language world record"],
    references=["Large language models beat world record"],
    use_stemmer=True,
)

{'rouge1': 1.0,
 'rouge2': 0.6,
 'rougeL': 0.6666666666666666,
 'rougeLsum': 0.6666666666666666}

### Compare small and large models
We've been working with the t5-small model so far. Let's compare several models with different architectures in terms of their ROUGE scores and some example generated summaries.

In [29]:
def compute_rouge_per_row(
    generated_summaries: list, reference_summaries: list
) -> pd.DataFrame:
    """
    Build a dataframe of per-summary ROUGE scores.

    Each summary is stripped, split into sentences with `sent_tokenize`, and
    re-joined with one sentence per line before scoring. Scoring uses stemming
    and `use_aggregator=False`, so the metric is not aggregated over the batch.

    :param generated_summaries: Summaries (list of strings) produced by the model
    :param reference_summaries: Ground-truth summaries (list of strings) for comparison
    :return: DataFrame holding the metric's score columns followed by the original
             `generated` and `reference` text columns
    """

    def one_sentence_per_line(summaries: list) -> list:
        # Put each tokenized sentence on its own line.
        return ["\n".join(sent_tokenize(text.strip())) for text in summaries]

    per_row_scores = rouge_score.compute(
        predictions=one_sentence_per_line(generated_summaries),
        references=one_sentence_per_line(reference_summaries),
        use_stemmer=True,
        use_aggregator=False,
    )

    # Append the un-tokenized text so each row shows what was scored.
    per_row_scores["generated"] = generated_summaries
    per_row_scores["reference"] = reference_summaries
    return pd.DataFrame(per_row_scores)

### T5-small
The T5 [[paper](https://arxiv.org/abs/1910.10683)] family of models are text-to-text transformers that have been trained on a multi-task mixture of unsupervised and supervised tasks. They are well suited for tasks such as summarization, translation, text classification, question answering, and more.

The t5-small version of the T5 models has 60 million parameters.

In [30]:
# We computed t5_small_summaries above already.
compute_rouge_score(t5_small_summaries, reference_summaries)

{'rouge1': 0.30974757717934137,
 'rouge2': 0.10631458746437521,
 'rougeL': 0.22119603468138754,
 'rougeLsum': 0.2823100338265827}

In [33]:
# Per-article ROUGE scores for t5-small, shown alongside the generated and reference text.
t5_small_results = compute_rouge_per_row(
    generated_summaries=t5_small_summaries, reference_summaries=reference_summaries
)
display(t5_small_results)

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum,generated,reference
0,0.407407,0.230769,0.296296,0.407407,a magnitude 6.7 earthquake rattles Papua new G...,Papua New Guinea is on the so-called Ring of F...
1,0.239130,0.000000,0.195652,0.217391,the two-Test cricket series is being played in...,Australia collapse to 88 all out on opening da...
2,0.454545,0.156250,0.393939,0.454545,federal prosecutors want jared Lee Loughner to...,Jared Loughner is refusing the government's re...
3,0.373333,0.191781,0.266667,0.346667,"new: ""he tried to kill people,"" a 17-year-old ...",Shooting victim McKayla Hicks went to hearing ...
4,0.263158,0.108108,0.184211,0.210526,double-amputee sprinter Oscar Pistorius will c...,Oscar Pistorius to become first double-amputee...
...,...,...,...,...,...,...
95,0.444444,0.285714,0.250000,0.277778,holders Inter Milan thrash Werder Bremen 4-0 i...,Samuel Eto'o scored a hat-trick as Inter Milan...
96,0.320000,0.082192,0.213333,0.320000,president's re-election campaign raises $71 mi...,Obama raised almost $30 million less than Romn...
97,0.155844,0.026667,0.077922,0.155844,"at least 75 people were killed in protests, an...",NEW: U.N. Secretary-General Ban Ki-moon joins ...
98,0.425000,0.102564,0.300000,0.375000,new infections have fallen by 17 percent in th...,New infections in sub-Saharan Africa 15 percen...
