In [1]:
!pip install transformers datasets rouge_score absl-py nltk

# Import necessary libraries
from datasets import load_dataset, load_metric
from transformers import BartTokenizer, BartForConditionalGeneration
import random

# Load the CNN/DailyMail dataset (config "3.0.0"); provides train/validation/test splits
dataset = load_dataset("cnn_dailymail", "3.0.0")

# Initialize tokenizer and model from the same pretrained checkpoint
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Few-shot sampling configuration
FEW_SHOT_SEED = 42         # fixed seed so every run selects the same examples
FEW_SHOT_POOL_SIZE = 100   # examples are drawn from the first 100 training rows
NUM_FEW_SHOT_EXAMPLES = 5  # number of worked examples placed in the prompt

# Sample the few-shot examples with a dedicated, seeded generator. Using a
# private random.Random keeps the selection reproducible without touching the
# global `random` state that other code may rely on.
_few_shot_rng = random.Random(FEW_SHOT_SEED)
few_shot_sample_indices = _few_shot_rng.sample(range(FEW_SHOT_POOL_SIZE), NUM_FEW_SHOT_EXAMPLES)

# Each training row is fetched once and reshaped to the {"article", "summary"}
# layout that the prompt builder expects.
few_shot_examples = [
    {
        "article": row['article'],
        "summary": row['highlights']
    }
    for row in (dataset['train'][i] for i in few_shot_sample_indices)
]

# Function to create few-shot prompt
def create_few_shot_prompt(article, examples, max_chars=500):
    """Build a few-shot summarization prompt: worked examples followed by the target article.

    Args:
        article: Text of the article to summarize.
        examples: Iterable of dicts, each with an ``"article"`` and a
            ``"summary"`` string entry.
        max_chars: Number of leading characters kept from every article (the
            examples and the target). Defaults to 500. Example summaries are
            never truncated.

    Returns:
        The prompt as a single string: a header line, one
        ``Article: ... / Summary: ...`` section per example, then the
        instruction and the target article.
    """
    parts = ["Here are some examples of article summaries:\n"]
    # Only the head of each article is kept to bound the prompt length. The
    # "..." marker is appended unconditionally, even when the article is
    # shorter than max_chars.
    parts.extend(
        f"Article: {ex['article'][:max_chars]}...\nSummary: {ex['summary']}\n\n"
        for ex in examples
    )
    parts.append(f"Now, summarize the following article:\n{article[:max_chars]}...")
    # Join once instead of growing the string with repeated `+=`.
    return "".join(parts)

# Function to create zero-shot prompt
def create_zero_shot_prompt(article, max_chars=500):
    """Build a zero-shot summarization prompt for a single article.

    Args:
        article: Text of the article to summarize.
        max_chars: Number of leading characters of ``article`` kept in the
            prompt. Defaults to 500.

    Returns:
        The instruction line followed by the (possibly truncated) article.
        The "..." marker is appended unconditionally, even when the article
        is shorter than ``max_chars``.
    """
    return f"Summarize the following article:\n{article[:max_chars]}..."

def _generate_summary(prompt):
    """Generate one summary string for ``prompt`` using the module-level tokenizer and model.

    The prompt is tokenized with truncation to 1024 tokens and decoded with
    4-beam search, producing between 40 and 150 tokens.
    """
    # NOTE(review): truncation cuts the *end* of the prompt; in the few-shot
    # prompts the target article is placed last, so verify those prompts stay
    # within 1024 tokens or part of the target article will be dropped.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(
        inputs['input_ids'],
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs['attention_mask'],
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def _score_summaries(summaries, references):
    """Return one ROUGE result per (summary, reference) pair, in input order."""
    return [
        rouge.compute(predictions=[summary], references=[reference])
        for summary, reference in zip(summaries, references)
    ]


def _print_results(title, summaries, references, results):
    """Print ``title`` followed by the summary, reference and ROUGE scores of each test article."""
    print(title)
    for i, result in enumerate(results):
        print(f"\nTest Article {i+1}:")
        print(f"Generated Summary: {summaries[i]}")
        print(f"Reference Summary: {references[i]}")
        print(f"ROUGE Scores: {result}")


# Sample test articles and their reference summaries (each row is fetched once)
test_sample_indices = range(10, 15)
_test_rows = [dataset['test'][i] for i in test_sample_indices]
test_articles = [row['article'] for row in _test_rows]
test_references = [row['highlights'] for row in _test_rows]

# Generate few-shot prompts for test articles
few_shot_prompts = [create_few_shot_prompt(article, few_shot_examples) for article in test_articles]
# Generate zero-shot prompts for test articles
zero_shot_prompts = [create_zero_shot_prompt(article) for article in test_articles]

# Generate summaries for both prompt styles with identical decoding settings.
# NOTE(review): the recorded few-shot output appears to summarize the in-prompt
# examples rather than the target article — confirm that this summarization
# checkpoint actually follows the prompt's instructions before comparing scores.
few_shot_summaries = [_generate_summary(prompt) for prompt in few_shot_prompts]
zero_shot_summaries = [_generate_summary(prompt) for prompt in zero_shot_prompts]

# Initialize ROUGE metric. trust_remote_code=True answers the loader's
# interactive "Do you wish to run the custom code? [y/N]" question up front, so
# the notebook can run unattended; it does mean the metric's loading script is
# executed without confirmation.
# NOTE(review): load_metric is deprecated in `datasets`; migrate to the
# `evaluate` package once it is added to this notebook's dependencies.
rouge = load_metric("rouge", trust_remote_code=True)

# Evaluate the generated summaries using ROUGE (one result per test article)
few_shot_results = _score_summaries(few_shot_summaries, test_references)
zero_shot_results = _score_summaries(zero_shot_summaries, test_references)

# Display results
_print_results("Few-shot learning results:", few_shot_summaries, test_references, few_shot_results)
_print_results("Zero-shot learning results:", zero_shot_summaries, test_references, zero_shot_results)


Collecting transformers
  Downloading transformers-4.43.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting absl-py
  Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting nltk
  Downloading nltk-3.8.1-py3-none-any.whl.metadata (2.8 kB)
Collecting huggingface-hub<1.0,>=0.23.2 (from transformers)
  Downloading huggingface_hub-0.24.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Collecting 

Downloading readme:   0%|          | 0.00/15.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/257M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/257M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/259M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

  rouge = load_metric("rouge")


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

The repository for rouge contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/rouge.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
Few-shot learning results:

Test Article 1:
Generated Summary: "Transplant tourists" travel to poor countries to buy organs from the desperate. Pakistan, where trade in human organs is legal, is turning into a "kidney bazaar" O.J. Simpson will be held without bail after his arrest on robbery and assault charges.
Reference Summary: London's Metropolitan Police say the man was arrested at Luton airport after landing on a flight from Istanbul .
He's been charged with terror offenses allegedly committed since the start of November .
ROUGE Scores: {'rouge1': AggregateScore(low=Score(precision=0.11904761904761904, recall=0.15151515151515152, fmeasure=0.13333333333333333), mid=Score(precision=0.