In [None]:
# Delete the aiohttp 3.9.1 package metadata directory before the upgrade cell below.
# NOTE(review): presumably a workaround for a broken/conflicting preinstalled aiohttp on this
# image; the path assumes a conda Python 3.10 environment — confirm it is still needed.
!rm /opt/conda/lib/python3.10/site-packages/aiohttp-3.9.1.dist-info -rdf

In [None]:
!pip install rouge_score evaluate transformers[torch] 'accelerate>=0.26.0' -U

In [None]:
import torch
import numpy as np

import nltk

import transformers
from datasets import load_dataset
import evaluate

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

## Data preprocessing

In [None]:
# Dataset citation (Multi-News):
# @misc{alex2019multinews,
#     title={Multi-News: a Large-Scale Multi-Document Summarization Dataset and Abstractive Hierarchical Model},
#     author={Alexander R. Fabbri and Irene Li and Tianwei She and Suyi Li and Dragomir R. Radev},
#     year={2019},
#     eprint={1906.01749},
#     archivePrefix={arXiv},
#     primaryClass={cs.CL}
# }

# This repo holds the same data as the original Multi-News dataset, repackaged so that
# it loads directly with `load_dataset`.
multi_news_repo = "Awesome075/multi_news_parquet"
ds = load_dataset(multi_news_repo)


In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Baseline check: summarize one training document with the pretrained (not yet
# fine-tuned) BART-large-CNN checkpoint and compare it with the reference summary.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

sample = ds["train"][0]
print(f"The sample: {sample.keys()}")
document_text = sample["document"]

# Encode the source as a PyTorch batch of one, truncated to 1024 tokens.
encoded = tokenizer(document_text, max_length=1024, truncation=True, return_tensors="pt")

# Beam-search settings for this preview generation.
generation_settings = {
    "max_length": 150,
    "min_length": 40,
    "length_penalty": 2.0,
    "num_beams": 4,
    "early_stopping": True,
}
summary_ids = model.generate(encoded["input_ids"], **generation_settings)
generated_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

separator = "-" * 50
print(separator)
print("Original:\n", document_text[:300], "...")
print(separator)
print("Generated:\n", generated_summary)
print(separator)
print("Reference:\n", sample["summary"])
print(separator)


In [None]:
def preprocess_function(examples):
    """Tokenize a batch of Multi-News examples for seq2seq training.

    Args:
        examples: batch mapping from ``Dataset.map(batched=True)`` whose
            "document" column holds source texts and "summary" column holds
            reference summaries.

    Returns:
        The tokenizer encoding of the documents (truncated to 1024 tokens) with
        an extra "labels" entry: summary token ids truncated to 128 tokens.
    """
    features = tokenizer(examples["document"], max_length=1024, truncation=True)

    # Targets go through `text_target` so the tokenizer encodes them as labels.
    # Raise max_length (e.g. to 256) if reference summaries get cut off.
    target_encoding = tokenizer(
        text_target=examples["summary"],
        max_length=128,
        truncation=True,
    )
    features["labels"] = target_encoding["input_ids"]
    return features


# batched=True passes many examples per call, which is much faster for the tokenizer.
tokenized_datasets = ds.map(function=preprocess_function, batched=True)

## Metrics

In [None]:
# Sentence-splitter data for `nltk.sent_tokenize` (used by `compute_metrics`).
# Recent NLTK releases load the "punkt_tab" resource rather than "punkt", so fetch both.
for _punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(_punkt_resource, quiet=True)

# ROUGE scorer from the `evaluate` library.
metric = evaluate.load('rouge')

In [None]:
def compute_metrics(eval_preds):
    """Compute ROUGE scores for one Seq2SeqTrainer evaluation pass.

    Args:
        eval_preds: ``(predictions, label_ids)`` as supplied by the Trainer.
            With ``predict_with_generate=True`` predictions are generated token
            ids; without it they are logits (a floating-point array), which are
            reduced to token ids with ``argmax``.

    Returns:
        The dict returned by the ROUGE ``metric.compute`` call, computed with
        stemming on sentence-split text (one sentence per line for rougeLsum).
    """
    preds, labels = eval_preds

    # Predictions may arrive as a tuple (e.g. logits plus extra outputs); use the first.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)

    # Logits (no predict_with_generate) -> most likely token id at each position.
    if np.issubdtype(preds.dtype, np.floating):
        preds = np.argmax(preds, axis=-1)

    # -100 marks ignored label positions, and the Trainer also uses it to pad
    # prediction batches of different lengths when concatenating them. It is not a
    # valid token id (batch_decode fails on it), so map it to the pad token first.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLsum expects a newline after each sentence.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    return metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

## Model

In [None]:
# Free memory held by earlier objects before (re)building the training setup.
import gc

import torch

# `globals().pop(name, None)` instead of `del name`, so this cell also works on a fresh
# kernel (Restart & Run All), where `trainer` has not been created yet.
for _stale_name in ("trainer", "model"):
    globals().pop(_stale_name, None)

gc.collect()
torch.cuda.empty_cache()

In [None]:
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

# 1. Load a fresh copy of the model to fine-tune.
# 'facebook/bart-large-cnn' is a standard strong baseline for summarization.
model_checkpoint = "facebook/bart-large-cnn"
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# 2. Data collator: pads each batch dynamically. Passing `model` lets it build the
#    decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# 3. Full training run.
training_args = Seq2SeqTrainingArguments(
    output_dir="./bart-large-multi-news",
    # load_best_model_at_end=True requires evaluation and checkpointing on the same
    # schedule (the arguments raise ValueError if eval is "no" while saving per epoch).
    # Evaluating with generation on the full validation split each epoch is slow.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,  # "best" = lowest eval loss unless metric_for_best_model is set
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch of 8 per device
    # Regularization
    weight_decay=0.01,
    save_total_limit=2,           # keep at most 2 recent checkpoints; the best one is also retained
    # Training duration
    num_train_epochs=3,           # 3 epochs is usually a good starting point for summarization
    # Optimization: mixed precision needs a GPU, so fall back to fp32 on CPU instead of failing.
    fp16=torch.cuda.is_available(),
    # Evaluation configuration
    predict_with_generate=True,   # Essential for computing ROUGE scores during evaluation
    # Logging
    logging_dir="./logs",
    logging_steps=50,
    report_to="none",             # or "tensorboard" to log to TensorBoard
)

# Tiny configuration for smoke-testing the train/eval/save loop end to end.
debug_args = Seq2SeqTrainingArguments(
    output_dir="./debug_output",
    max_steps=10,
    eval_steps=5,
    save_steps=5,
    logging_steps=1,
    eval_strategy="steps",
    save_strategy="steps",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=2e-5,
    load_best_model_at_end=True,
    predict_with_generate=True,
    report_to="none",
)

In [None]:
# Set DEBUG = True for a quick smoke test (debug_args: 10 steps, 20 validation samples)
# of the whole train/eval/save loop before launching the full run.
DEBUG = False

train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
small_eval_dataset = eval_dataset.select(range(20))

trainer = Seq2SeqTrainer(
    model=model,
    args=debug_args if DEBUG else training_args,
    train_dataset=train_dataset,
    eval_dataset=small_eval_dataset if DEBUG else eval_dataset,
    # `processing_class` replaces the deprecated `tokenizer` argument (transformers >= 4.46,
    # which the upgrade cell installs).
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
import nltk

# Sentence-splitter data for `nltk.sent_tokenize` in `compute_metrics`;
# newer NLTK versions look for "punkt_tab" rather than "punkt".
for resource_name in ("punkt", "punkt_tab"):
    nltk.download(resource_name)

In [None]:
# Fine-tune (long-running). The value returned by trainer.train() is shown as the cell output.
train_output = trainer.train()
train_output

## Publish to the Hub and run a one-sample prediction

In [None]:
# Authenticate with the Hugging Face Hub so the upload cell below can push.
# NOTE(review): no token is passed, so this is expected to prompt for one.
import huggingface_hub

huggingface_hub.login()

In [None]:
# Trainer.push_to_hub's first positional argument is the *commit message*, not the repo id
# (the Trainer derives its repo from hub_model_id/output_dir). Push the fine-tuned model and
# the tokenizer explicitly so both end up in the same repository.
HUB_REPO_ID = "finetuned-bart-model-1"
trainer.model.push_to_hub(HUB_REPO_ID)
tokenizer.push_to_hub(HUB_REPO_ID)

In [None]:
# Pick the first training document as the example to summarize.
first_train_record = ds["train"][0]
text_example = first_train_record["document"]
print(text_example)

In [None]:
# Encode the example as a PyTorch batch of one (truncated to 1024 tokens) on the compute device.
example_encoding = tokenizer(
    text_example,
    return_tensors="pt",
    max_length=1024,
    truncation=True,
)
input_ids = example_encoding["input_ids"].to(device)

In [None]:
# (batch, sequence_length) of the encoded example.
input_ids.size()

In [None]:
# After trainer.train() the model may still be in training mode, with dropout active
# during generation. Switch to inference mode and make sure the weights sit on the same
# device as the encoded input.
model.to(input_ids.device).eval()

# Beam-search summary of the example document.
summary_text_ids = model.generate(
    input_ids=input_ids,
    bos_token_id=model.config.bos_token_id,
    eos_token_id=model.config.eos_token_id,
    max_length=142,
    min_length=56,
    num_beams=4,
)

In [None]:
# Convert the generated token ids back to text, dropping special tokens.
decoded_text = tokenizer.batch_decode(summary_text_ids, skip_special_tokens=True)[0]
print(decoded_text)