In [16]:
import logging
from transformers import BartForConditionalGeneration, BartTokenizer
from datasets import load_dataset
from rouge_score import rouge_scorer
from bert_score import score
from tqdm import tqdm
import numpy as np

# Adjust the root logger's level (which affects all libraries unless overridden)
logging.basicConfig(level=logging.ERROR)

# Set specific loggers for libraries that are too verbose
datasets_logger = logging.getLogger("datasets")
transformers_logger = logging.getLogger("transformers")
datasets_logger.setLevel(logging.ERROR)
transformers_logger.setLevel(logging.ERROR)

# Load the model and tokenizer
model_path = 'final_model_summarization'
model = BartForConditionalGeneration.from_pretrained(model_path)
tokenizer = BartTokenizer.from_pretrained(model_path)

# Load the dataset
dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", split='validation')
sampled_dataset = dataset.shuffle(seed=42).select(range(int(len(dataset) * 0.1)))

# Initialize ROUGE scorer
rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

# Lists to hold scores and summaries
rouge_scores = []
bert_precisions = []
bert_recalls = []
bert_f1s = []

# Function to generate summaries
def generate_summary(article):
    inputs = tokenizer(article, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs["input_ids"], max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

# Iterate over the sampled dataset
for example in tqdm(sampled_dataset, desc="Processing dataset"):
    article = example['article']
    highlight = example['highlights']
    
    # Generate summary
    generated_summary = generate_summary(article)

    # Compute ROUGE score
    rouge_scores.append(rouge_scorer.score(generated_summary, highlight))
    
    # Compute BERTScore
    P, R, F1 = score([generated_summary], [highlight], lang='en', verbose=False)
    bert_precisions.append(P.numpy())
    bert_recalls.append(R.numpy())
    bert_f1s.append(F1.numpy())

# Calculate average scores
average_rouge = {key: np.mean([score[key].fmeasure for score in rouge_scores]) for key in ['rouge1', 'rouge2', 'rougeL']}
average_bert_precision = np.mean(bert_precisions)
average_bert_recall = np.mean(bert_recalls)
average_bert_f1 = np.mean(bert_f1s)

print("Average ROUGE scores:", average_rouge)
print("Average BERT Precision:", average_bert_precision)
print("Average BERT Recall:", average_bert_recall)
print("Average BERT F1:", average_bert_f1)

INFO:absl:Using default tokenizer.
Processing dataset: 100%|██████████| 1336/1336 [1:57:29<00:00,  5.28s/it] 


Average ROUGE scores: {'rouge1': 0.43579883764465166, 'rouge2': 0.20855940463406772, 'rougeL': 0.30107523353961063}
Average BERT Precision: 0.88010025
Average BERT Recall: 0.87622803
Average BERT F1: 0.8780281
