In [None]:
!pip install datasets transformers evaluate sacrebleu bert_score

In [None]:
!pip install rouge_score

In [None]:
import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset, DatasetDict
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
import sacrebleu

In [None]:
# Seed every random number generator the notebook uses so runs are repeatable.
seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
# Seed all CUDA devices as well when a GPU is available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [None]:
# Load the Multi-News dataset by its Hugging Face identifier and show the splits it provides.
# NOTE(review): trust_remote_code=True presumably lets the dataset's own loading script
# execute locally — confirm the source is trusted before running.
dataset = load_dataset("alexfabbri/multi_news", trust_remote_code=True)
print("Original splits:", dataset)

In [None]:
# Report how many examples each split contains.
for split, split_data in dataset.items():
    print(f"Split: {split}, Number of samples: {len(split_data)}")

In [None]:
# Carve a validation split out of train only when the dataset ships without one.
has_validation = "validation" in dataset
if not has_validation:
    # Hold out 10% of the training data; the seed keeps the split repeatable.
    train_valid = dataset["train"].train_test_split(test_size=0.1, seed=seed)
    split_map = {
        "train": train_valid["train"],
        "validation": train_valid["test"],
        "test": dataset["test"],
    }
    dataset = DatasetDict(split_map)
    print("After splitting, splits:", dataset)

In [None]:
# Pretrained checkpoint shared by the tokenizer, the data collator and the fine-tuned model.
model_checkpoint = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_checkpoint)

In [None]:
# Token-length caps: documents are truncated to max_input_length when tokenized;
# summaries are truncated to max_target_length, which is also the generation length limit.
max_input_length = 1024
max_target_length = 256

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of Multi-News examples for seq2seq training.

    Args:
        examples: Batch mapping with "document" and "summary" entries (lists of
            strings when called through ``Dataset.map(batched=True)``).

    Returns:
        The tokenizer output for the documents, truncated to ``max_input_length``
        tokens, with an added "labels" entry holding the summary token ids
        truncated to ``max_target_length`` tokens.
    """
    model_inputs = tokenizer(examples["document"], max_length=max_input_length, truncation=True)
    # `text_target` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize every split in batches; the raw text columns are dropped from the result.
raw_text_columns = ["document", "summary"]
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_text_columns,
)


In [None]:
# Persist the tokenized splits once; an already existing path is left untouched.
tokenized_dataset_path = "./tokenized_multinews"
needs_save = not os.path.exists(tokenized_dataset_path)
if needs_save:
    tokenized_datasets.save_to_disk(tokenized_dataset_path)
    print(f"Tokenized dataset saved to {tokenized_dataset_path}")

In [None]:
# Cap the train/validation splits at a fixed number of leading examples.
# Despite the "min_" prefix, these values act as upper bounds on the split sizes.
min_train_samples = 1000
min_valid_samples = 100
for split_name, cap in (("train", min_train_samples), ("validation", min_valid_samples)):
    if len(tokenized_datasets[split_name]) > cap:
        tokenized_datasets[split_name] = tokenized_datasets[split_name].select(range(cap))

In [None]:
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose evaluation step generates with fixed beam-search settings."""

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one evaluation step: compute the loss and, unless only the loss is wanted, generate.

        Args:
            model: Model to evaluate.
            inputs: Batch mapping containing "input_ids" and "attention_mask",
                and optionally "labels".
            prediction_loss_only: When true, skip generation and return only the loss.
            ignore_keys: Accepted for interface compatibility; not used here.

        Returns:
            tuple: ``(loss, generated_tokens, labels)``. ``loss`` is ``None`` when the
            batch has no labels; ``generated_tokens`` and ``labels`` are ``None`` when
            ``prediction_loss_only`` is true.
        """
        inputs = self._prepare_inputs(inputs)
        labels = inputs.get("labels")

        # The loss only needs a forward pass, so compute it whenever labels are present —
        # including the loss-only case, where it is the one value being asked for.
        loss = None
        if labels is not None:
            with torch.no_grad():
                loss = model(**inputs).loss
            if loss is not None:
                loss = loss.mean().detach()

        # Beam search is the expensive part; skip it when the caller only wants the loss.
        if prediction_loss_only:
            return (loss, None, None)

        generated_tokens = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_target_length,
            num_beams=8,
            length_penalty=1.5,
            no_repeat_ngram_size=3,
            min_length=50
        )
        return (loss, generated_tokens, labels)

In [None]:
# Load the evaluation metrics through the `evaluate` library.
rouge_metric = evaluate.load("rouge")
# NOTE(review): sacrebleu_metric is not referenced by the other visible cells
# (compute_metrics calls the sacrebleu package directly) — confirm whether it is still needed.
sacrebleu_metric = evaluate.load("sacrebleu")
bertscore_metric = evaluate.load("bertscore")

In [None]:
def safe_decode(batch_ids):
    """Decode a batch of token-id sequences into plain-text strings.

    Ids for which the tokenizer returns no token (``None``) are dropped instead of
    breaking the decode.

    Args:
        batch_ids: Iterable of id sequences; each may be a list or a NumPy array.

    Returns:
        list[str]: One decoded string per input sequence, with special tokens removed.
    """
    decoded_batch = []
    for ids in batch_ids:
        if isinstance(ids, np.ndarray):
            ids = ids.tolist()
        tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
        tokens = [t for t in tokens if t is not None]
        # Plain "".join() of sub-word tokens leaves the tokenizer's byte-level markers
        # (e.g. "Ġ" standing for a space) in the text; convert_tokens_to_string maps
        # them back to real characters so the metrics score actual text.
        decoded_batch.append(tokenizer.convert_tokens_to_string(tokens))
    return decoded_batch

In [None]:
def compute_metrics(eval_pred):
    """Score generated summaries against the references with ROUGE, BLEU and BERTScore.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays. Label positions
            equal to -100 are replaced by the pad token id before decoding.

    Returns:
        dict: "rouge1", "rouge2" and "rougeL" as returned by the ROUGE metric, "bleu"
        (the sacreBLEU corpus score) and "bertscore" (mean BERTScore F1 over the examples).
    """
    predictions, labels = eval_pred
    decoded_preds = safe_decode(predictions)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = safe_decode(labels)

    rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)

    # sacrebleu.corpus_bleu expects a list of reference *streams*, each aligned one-to-one
    # with the hypotheses. With a single reference per example that is one stream holding
    # every reference, not one single-item list per example.
    bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels], smooth_method="exp")
    bleu_score = bleu.score

    bertscore_result = bertscore_metric.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    # Plain float keeps the metrics dict free of NumPy scalar types.
    avg_bertscore = float(np.mean(bertscore_result["f1"]))

    return {
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
        "bleu": bleu_score,
        "bertscore": avg_bertscore
    }

In [None]:
# Hyperparameter sets to run; each entry names the experiment and gives its
# learning rate, per-device batch size and number of training epochs.
experiment_configs = [
    {"name": "config3", "learning_rate": 5e-5, "train_batch_size": 4, "num_train_epochs": 3},
]

In [None]:
# Seq2seq collator that pads each batch to the longest sequence in that batch.
# NOTE(review): this loads a separate copy of the checkpoint just for the collator,
# distinct from the model instance that is trained later — consider sharing one instance.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=BartForConditionalGeneration.from_pretrained(model_checkpoint),
    padding="longest"
)

In [None]:
print("\nStarting Experiment: config3")
config3 = experiment_configs[0]
model_config3 = BartForConditionalGeneration.from_pretrained(model_checkpoint)

# Request bf16 mixed precision only when the GPU supports it, so the cell
# also runs (in full precision) on hardware without bf16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args_config3 = Seq2SeqTrainingArguments(
    output_dir=f"multinews_bart_base_{config3['name']}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=config3["train_batch_size"],
    # Evaluation reuses the training batch size.
    per_device_eval_batch_size=config3["train_batch_size"],
    learning_rate=config3["learning_rate"],
    num_train_epochs=config3["num_train_epochs"],
    bf16=use_bf16,
    logging_dir=f'./logs_{config3["name"]}',
    logging_steps=50,
    predict_with_generate=True,
    report_to=[],
)
trainer_config3 = CustomSeq2SeqTrainer(
    model=model_config3,
    args=training_args_config3,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)
trainer_config3.train()

# Trainer has no save_pretrained(); save_model() is the Trainer method that writes
# the fine-tuned model to a directory.
trainer_config3.save_model("./multinews_model")
tokenizer.save_pretrained("./multinews_model")


In [None]:
# Write the fine-tuned model and its tokenizer to the directory the Gradio app loads from.
trainer_config3.model.save_pretrained("./multinews_model")
tokenizer.save_pretrained("./multinews_model")

In [None]:
# Collect each logged loss together with the optimizer step at which it was recorded,
# so the training and validation curves share one x-axis and validation points sit
# where they were actually measured.
train_steps, train_losses = [], []
eval_steps, eval_losses = [], []
for position, log in enumerate(trainer_config3.state.log_history):
    # Fall back to the entry's position if a log record carries no "step" field.
    step = log.get("step", position)
    if "loss" in log:
        train_steps.append(step)
        train_losses.append(log["loss"])
    if "eval_loss" in log:
        eval_steps.append(step)
        eval_losses.append(log["eval_loss"])

plt.figure(figsize=(8, 5))
plt.plot(train_steps, train_losses, label="Train Loss")
plt.plot(eval_steps, eval_losses, marker="o", label="Validation Loss")
plt.xlabel("Training Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss for config3")
plt.legend()
plt.show()

In [None]:
# Evaluate the fine-tuned model on the validation split and print the metrics.
validation_split = tokenized_datasets["validation"]
eval_results_config3 = trainer_config3.evaluate(eval_dataset=validation_split)
print("Validation results for config3:", eval_results_config3, sep="\n")

In [None]:
# Evaluate on the first 100 test examples and print the metrics.
test_subset = tokenized_datasets["test"].select(range(100))
test_metrics = trainer_config3.evaluate(eval_dataset=test_subset)
print("Test metrics for Multi-News config3:", test_metrics, sep="\n")

In [None]:
# (label, metric key, scale): BLEU keeps its own scale, the other scores are multiplied by 100.
metric_specs = [
    ("ROUGE-1", "eval_rouge1", 100),
    ("ROUGE-2", "eval_rouge2", 100),
    ("ROUGE-L", "eval_rougeL", 100),
    ("BLEU", "eval_bleu", 1),
    ("BERTScore", "eval_bertscore", 100),
]
metrics_to_plot = {label: test_metrics[key] * scale for label, key, scale in metric_specs}

# Bar chart of the test metrics, one bar per metric.
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(list(metrics_to_plot), list(metrics_to_plot.values()))
ax.set_title("Test Metrics (%) for Best Model")
ax.set_xlabel("Metric")
ax.set_ylabel("Score")
plt.show()

Below is the code of `app.py`, which I used to deploy the model on Hugging Face. The cell below is not meant to be run inside this notebook: it is a Gradio app intended to run on Hugging Face.

In [None]:
import gradio as gr
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

# Load the fine-tuned checkpoint and its tokenizer from the local model directory.
model_path = "./multinews_model"
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# Inference only: switch to eval mode and use the GPU when one is available.
model.eval()
if torch.cuda.is_available():
    model.to("cuda")

def summarize(text):
    """Generate a summary of ``text`` with the fine-tuned BART model.

    Args:
        text: Article text to summarize; input beyond 1024 tokens is truncated.

    Returns:
        str: The decoded summary, or an empty string when ``text`` is empty,
        ``None`` or only whitespace.
    """
    # Nothing to summarize: skip generation, which would otherwise be forced by
    # min_length to emit text for an empty input.
    if not text or not text.strip():
        return ""

    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    # Place the inputs on whichever device the model lives on.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    summary_ids = model.generate(
        **inputs,
        max_length=256,
        num_beams=8,
        length_penalty=1.5,
        no_repeat_ngram_size=3,
        min_length=50,
        early_stopping=True
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Gradio UI: a multi-line text box feeds `summarize`, and the result is shown in a second text box.
interface = gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=10, label="Enter Text to Summarize"),
    outputs=gr.Textbox(label="Summary"),
    title="Multi-News Summarization",
    description="Enter a news article to get a summary generated by the fine-tuned BART model."
)

# Start the web app.
interface.launch()

## Multi-News Config3 Training & Evaluation

**Data Prep & Tokenization:**  
- We loaded the Multi-News dataset (using the "document" and "summary" columns), created train/validation/test splits, and tokenized the documents (max 1024 tokens) and the summaries (max 256 tokens).

**Model Fine-Tuning:**  
- We fine-tuned the facebook/bart-base model with config3 (LR = 5e-5, batch size = 4, 3 epochs) using a custom trainer.

**Evaluation Metrics (Epoch 3 & Test):**  
  - **Final Training Loss:** 2.43  
  - **Validation Loss:** 2.62  
  - **Rouge-1:** 0.42, **Rouge-2:** 0.14, **Rouge-L:** 0.21  
  - **BLEU:** 17.48  
  - **BERTScore:** 0.89

**Observations:**
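- The loss decreased during training, which indicates the model is learning; the validation loss (2.62) stays close to the final training loss (2.43), so there is no sign of strong overfitting.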
- The ROUGE scores are lower than on BillSum because of the Multi-News dataset's higher complexity (longer, multi-document inputs and diverse writing styles).
- The high BLEU and BERTScore values show that the model still generates coherent and semantically accurate summaries.


## References

- alexfabbri/multi_news on Hugging Face: https://huggingface.co/datasets/alexfabbri/multi_news
- Facebook/bart-base on Hugging Face: https://huggingface.co/facebook/bart-base
- Transformers Documentation: https://huggingface.co/docs/transformers/ 
- Datasets Documentation: https://huggingface.co/docs/datasets/
- rouge_score: https://huggingface.co/spaces/evaluate-metric/rouge
- sacreBLEU: https://pypi.org/project/sacreBLEU/
- BLEU: https://huggingface.co/spaces/evaluate-metric/bleu  
- bert_score: https://huggingface.co/spaces/evaluate-metric/bertscore
- Python os module Documentation: https://docs.python.org/3/library/os.html
- Python random module Documentation: https://docs.python.org/3/library/random.html
- NumPy Documentation: https://numpy.org/doc/stable/user/index.html#user
- PyTorch Documentation: https://pytorch.org/docs/stable/index.html
- Matplotlib Documentation: https://matplotlib.org/stable/users/index.html
- Hugging Face Evaluate Documentation: https://huggingface.co/docs/evaluate/
- Gradio: https://www.gradio.app/docs