In [1]:
# Delete the aiohttp 3.9.1 dist-info metadata directory from the conda site-packages
# (-r recursive, -d also remove empty directories, -f ignore missing paths / never prompt).
# NOTE(review): presumably this works around a stale aiohttp install record before the
# pip install in the next cell — confirm it is still needed.
!rm /opt/conda/lib/python3.10/site-packages/aiohttp-3.9.1.dist-info -rdf

In [2]:
!pip install rouge_score evaluate transformers[torch] 'accelerate>=0.26.0' -U

Defaulting to user installation because normal site-packages is not writeable


In [3]:
import torch
import numpy as np

import nltk

import transformers
from datasets import load_dataset
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

## Data preprocessing

In [14]:
# BibTeX citation for the Multi-News dataset, kept as a bare (unassigned) string literal.
'''@misc{alex2019multinews,
    title={Multi-News: a Large-Scale Multi-Document Summarization Dataset and Abstractive Hierarchical Model},
    author={Alexander R. Fabbri and Irene Li and Tianwei She and Suyi Li and Dragomir R. Radev},
    year={2019},
    eprint={1906.01749},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}'''


# Load the dataset through `load_dataset` using the "Awesome075/multi_news_parquet" identifier.
ds = load_dataset("Awesome075/multi_news_parquet") # Same content as the original Multi-News dataset, repackaged so that it loads easily


In [16]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Sanity check: summarize one training example with the pretrained checkpoint.
demo_checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(demo_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(demo_checkpoint)

# Take the first training example and show which fields it carries.
sample = ds["train"][0]
print(f"The sample: {sample.keys()}")
document_text = sample["document"]

# Encode the document as PyTorch tensors, with truncation enabled at max_length=1024.
inputs = tokenizer(document_text, return_tensors="pt", truncation=True, max_length=1024)

# Beam-search settings used for this demo summary.
generation_options = {
    "max_length": 150,
    "min_length": 40,
    "length_penalty": 2.0,
    "num_beams": 4,
    "early_stopping": True,
}
summary_ids = model.generate(inputs["input_ids"], **generation_options)

# Turn the first generated sequence back into text, dropping special tokens.
generated_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Print the (clipped) source, the generated summary and the reference side by side.
separator = "-" * 50
print(separator)
print("Original:\n", document_text[:300], "...")
print(separator)
print("Generated:\n", generated_summary)
print(separator)
print("Reference:\n", sample["summary"])
print(separator)


The sample: dict_keys(['document', 'summary'])
--------------------------------------------------
Original:
 National Archives 
 
 Yes, it’s that time again, folks. It’s the first Friday of the month, when for one ever-so-brief moment the interests of Wall Street, Washington and Main Street are all aligned on one thing: Jobs. 
 
 A fresh update on the U.S. employment situation for January hits the wires at ...
--------------------------------------------------
Generated:
 A fresh update on the U.S. employment situation for January hits the wires at 8:30 a.m. New York time. Expectations are for 203,000 new jobs to be created, according to economists polled by Dow Jones Newswires. The unemployment rate is expected to hold steady at 8.3%.
--------------------------------------------------
Reference:
 – The unemployment rate dropped to 8.2% last month, but the economy only added 120,000 jobs, when 203,000 new jobs had been predicted, according to today's jobs report. Reaction on the Wall S

In [17]:
def preprocess_function(examples, max_input_length=1024, max_target_length=128):
    """Tokenize a batch of documents and their reference summaries.

    Args:
        examples: Batch mapping with a "document" field (source texts) and a
            "summary" field (reference summaries).
        max_input_length: Value passed as ``max_length`` when tokenizing the
            documents (truncation is enabled). Defaults to 1024.
        max_target_length: Value passed as ``max_length`` when tokenizing the
            summaries (truncation is enabled). Defaults to 128; raise it
            (e.g. to 256) if summaries are long.

    Returns:
        The tokenizer output for the documents, with the summaries'
        ``input_ids`` stored under the "labels" key.
    """
    # Process inputs: the 'document' field contains the source text
    model_inputs = tokenizer(
        examples["document"],
        max_length=max_input_length,
        truncation=True,
    )

    # Process targets: the 'summary' field contains the reference summary
    labels = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Apply the preprocessing function to the entire dataset object `ds`
# (the result is indexed by split name further down, e.g. "train" and "validation").
# 'batched=True' hands the function batches of examples instead of one example at a time,
# which makes the mapping faster.
tokenized_datasets = ds.map(preprocess_function, batched=True)

Map: 100%|██████████| 5622/5622 [00:19<00:00, 287.28 examples/s]
Map: 100%|██████████| 5622/5622 [00:20<00:00, 274.10 examples/s]


## Metrics

In [8]:
# Sentence-tokenizer data for nltk.sent_tokenize, which compute_metrics calls.
# Both resources are fetched here so this cell is self-contained: 'punkt' and
# 'punkt_tab' (the resource newer NLTK releases look up for sentence splitting).
for resource in ('punkt', 'punkt_tab'):
    nltk.download(resource, quiet=True)

# ROUGE scorer loaded through the `evaluate` library.
metric = evaluate.load('rouge')

In [42]:
def compute_metrics(eval_preds):
    """Decode predictions and labels and score them with the ROUGE metric.

    Args:
        eval_preds: A ``(preds, labels)`` pair of token-id arrays. ``preds``
            may itself be a tuple, in which case its first element is used.

    Returns:
        The result of ``metric.compute`` on the decoded, sentence-per-line
        predictions and references (with ``use_stemmer=True``).
    """
    preds, labels = eval_preds

    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is a padding/ignore marker and not a real token id, so it cannot be decoded.
    # Replace it with the pad token in BOTH arrays: generated predictions can be padded
    # with -100 as well, not only the labels.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return result


# Disabled snippet kept as an inert string literal (it is never executed). Because it is
# the last expression of the cell, the notebook echoes it back as the cell output.
# The code inside would replace floating-point `preds` arrays with their argmax over
# the last axis.
'''
    if isinstance(preds, np.ndarray) and np.issubdtype(preds.dtype, np.floating):
        preds = np.argmax(preds, axis=-1)
'''

'\n    if isinstance(preds, np.ndarray) and np.issubdtype(preds.dtype, np.floating):\n        preds = np.argmax(preds, axis=-1)\n'

## Model

In [43]:
# Clear up memory before training
import torch
import gc

# Drop references to any previously created trainer/model so their memory can be reclaimed.
# On a fresh kernel `trainer` does not exist yet (it is only built further down), so the
# names are removed only if present; a plain `del` would raise NameError.
for stale_name in ("trainer", "model"):
    globals().pop(stale_name, None)

# Collect the now-unreferenced objects, then hand cached CUDA memory back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

# 1. Load the model
# Using 'facebook/bart-large-cnn' as it is a standard strong baseline for summarization
model_checkpoint = "facebook/bart-large-cnn"
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# 2. Data Collator
# Built from the tokenizer and the model; it is handed to the trainer to assemble batches.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# 3. Define Training Arguments
# Settings for the full run: evaluate and checkpoint every 500 steps for 3 epochs.
training_args = Seq2SeqTrainingArguments(
    output_dir="./bart-large-multi-news",
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    # Regularization
    weight_decay=0.01,
    save_total_limit=2,           # Only keep the last 2 checkpoints to save disk space
    # Training duration
    num_train_epochs=3,           # 3 epochs is usually a good starting point for summarization
    # Optimization
    # Mixed precision saves memory and speeds up training, but it needs a GPU. Enable it
    # only when CUDA is available so this cell still runs on the CPU fallback.
    fp16=torch.cuda.is_available(),
    # Evaluation configuration
    predict_with_generate=True,   # Essential for computing ROUGE scores during evaluation
    # Logging
    logging_dir="./logs",
    logging_steps=50,
    report_to="none"       # Disable external logging integrations; use "tensorboard" to log to TensorBoard
)

# Settings for a short smoke test: 10 optimizer steps, evaluating and saving every 5 steps.
debug_args = Seq2SeqTrainingArguments(
    output_dir="./debug_output",
    max_steps=10,
    eval_steps=5,
    save_steps=5,
    logging_steps=1,
    eval_strategy="steps",
    save_strategy="steps",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=2e-5,
    load_best_model_at_end=True,
    predict_with_generate=True,
    report_to="none",
)

In [45]:
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
# Only the first 20 validation examples are scored during the debug run.
small_eval_dataset = eval_dataset.select(range(20))

# Debug training
trainer = Seq2SeqTrainer(
    model=model,
    args=debug_args,
    train_dataset=train_dataset,
    eval_dataset=small_eval_dataset,
    # `tokenizer=` is deprecated in recent transformers releases;
    # `processing_class=` is its replacement and takes the same tokenizer object.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# NOTE(review): the disabled "Full training" snippet below is an inert string literal.
# It still passes debug_args and small_eval_dataset, so swap in training_args and
# eval_dataset (and processing_class=tokenizer) before enabling it.
'''
# Full training
trainer = Seq2SeqTrainer(
    model=model,
    args=debug_args,
    train_dataset=train_dataset,
    eval_dataset=small_eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)'''

  trainer = Seq2SeqTrainer(


'\n# Full training\ntrainer = Seq2SeqTrainer(\n    model=model,\n    args=debug_args,\n    train_dataset=train_dataset,\n    eval_dataset=small_eval_dataset,\n    tokenizer=tokenizer,\n    data_collator=data_collator,\n    compute_metrics=compute_metrics\n)'

In [27]:
import nltk

# Fetch the Punkt sentence-tokenizer resources; compute_metrics calls nltk.sent_tokenize.
# NOTE(review): 'punkt_tab' is presumably the resource newer NLTK versions require — confirm.
nltk.download('punkt')
nltk.download('punkt_tab') 

[nltk_data] Downloading package punkt to
[nltk_data]     /home/yf2782_columbia_edu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/yf2782_columbia_edu/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [46]:
# Run training with the trainer built from debug_args (max_steps=10, evaluation every 5 steps).
trainer.train()

Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
5,2.787,2.798402,0.333061,0.095875,0.18169,0.281439
10,3.3357,2.72598,0.341597,0.093564,0.182849,0.286182


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=10, training_loss=2.6510828852653505, metrics={'train_runtime': 169.1816, 'train_samples_per_second': 0.118, 'train_steps_per_second': 0.059, 'total_flos': 42876503162880.0, 'train_loss': 2.6510828852653505, 'epoch': 0.00044472115983278484})

## One sample prediction

In [48]:
# Use the "document" field of the first training example as input for a single prediction.
text_example = ds["train"][0]["document"]
print(text_example)

National Archives 
 
 Yes, it’s that time again, folks. It’s the first Friday of the month, when for one ever-so-brief moment the interests of Wall Street, Washington and Main Street are all aligned on one thing: Jobs. 
 
 A fresh update on the U.S. employment situation for January hits the wires at 8:30 a.m. New York time offering one of the most important snapshots on how the economy fared during the previous month. Expectations are for 203,000 new jobs to be created, according to economists polled by Dow Jones Newswires, compared to 227,000 jobs added in February. The unemployment rate is expected to hold steady at 8.3%. 
 
 Here at MarketBeat HQ, we’ll be offering color commentary before and after the data crosses the wires. Feel free to weigh-in yourself, via the comments section. And while you’re here, why don’t you sign up to follow us on Twitter. 
 
 Enjoy the show. ||||| Employers pulled back sharply on hiring last month, a reminder that the U.S. economy may not be growing fas

In [49]:
# Encode the example as a PyTorch tensor (truncation enabled, max_length=1024),
# then move the tensor to `device`.
encoded_example = tokenizer.encode(
    text_example,
    truncation=True,
    max_length=1024,
    return_tensors="pt",
)
input_ids = encoded_example.to(device)

In [50]:
# Display the shape of the encoded input tensor.
input_ids.shape

torch.Size([1, 396])

In [51]:
# Generate a summary for the single example with 4-beam search, passing the BOS/EOS
# token ids from the model config and the length bounds explicitly.
generation_settings = dict(
    bos_token_id=model.config.bos_token_id,
    eos_token_id=model.config.eos_token_id,
    max_length=142,
    min_length=56,
    num_beams=4,
)
summary_text_ids = model.generate(input_ids=input_ids, **generation_settings)

In [52]:
# Decode the first generated sequence back to text, dropping special tokens, and print it.
decoded_text = tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
print(decoded_text)

The U.S. employment situation for January hits the wires at 8:30 a.m. New York time on Friday. Expectations are for 203,000 new jobs to be created, according to economists polled by Dow Jones Newswires, compared to 227,000 jobs added in February. The unemployment rate is expected to hold steady at 8.3%.
