In [None]:
import torch
# Pick the compute device by preference: CUDA if present, then Apple-Silicon
# MPS, otherwise plain CPU. The value is a device string consumed by `.to(...)`.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
from datasets import load_dataset

# Fraction of each split kept for this run (0.1 keeps 10% of the rows).
samples_fraction = 0.1


def _take_fraction(split, fraction):
    """Shuffle `split` with seed 42 and keep its first int(len(split) * fraction) rows."""
    keep = int(len(split) * fraction)
    return split.shuffle(seed=42).select(range(keep))


# Load CNN/DailyMail 3.0.0 and rename its text columns to document/summary.
dataset = load_dataset("cnn_dailymail", "3.0.0").rename_columns(
    {"article": "document", "highlights": "summary"}
)
train_data = _take_fraction(dataset["train"], samples_fraction)
val_data = _take_fraction(dataset["validation"], samples_fraction)

# These show the full (un-subsampled) train split and the first validation row.
print(dataset["train"])
print(dataset["validation"][0])

In [None]:
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer


# Checkpoint identifier and the tokenizer loaded from that same checkpoint.
model_name = "google/pegasus-large"
tokenizer = PegasusTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

# preprocessing 
def preprocess_function(examples):
    """Tokenize documents and summaries into model-ready features.

    Args:
        examples: Mapping with "document" and "summary" entries, each either a
            single string or a list of strings (the latter when used with
            ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the documents (padded/truncated to 1024
        tokens) with an added "labels" entry holding the summary token ids
        (padded/truncated to 256 tokens). Padding positions in "labels" are
        set to -100 so the loss does not score the model on pad tokens.
    """
    model_inputs = tokenizer(
        examples["document"],
        padding="max_length",
        truncation=True,
        max_length=1024,
    )
    targets = tokenizer(
        examples["summary"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )

    pad_id = tokenizer.pad_token_id

    def _mask_padding(ids):
        # -100 is the index the cross-entropy loss ignores; without this the
        # (mostly padding) tail of every label sequence would count in the loss.
        return [tok if tok != pad_id else -100 for tok in ids]

    label_ids = targets["input_ids"]
    if label_ids and isinstance(label_ids[0], (list, tuple)):
        # Batched call: one id sequence per example.
        model_inputs["labels"] = [_mask_padding(seq) for seq in label_ids]
    else:
        # Single-example call: a flat list of ids.
        model_inputs["labels"] = _mask_padding(label_ids)
    return model_inputs

# Both splits are tokenized with the same settings: batched mapping, with the
# raw "document"/"summary" text columns removed from the result.
_map_options = dict(batched=True, remove_columns=["document", "summary"])
tokenized_train = train_data.map(preprocess_function, **_map_options)
tokenized_val = val_data.map(preprocess_function, **_map_options)

print(tokenized_train, tokenized_val)

In [None]:
from transformers import PegasusForConditionalGeneration, Trainer, TrainingArguments, BitsAndBytesConfig

# Load the pretrained Pegasus weights, then move the model to the selected device.
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
# Turn on gradient checkpointing before training starts.
model.gradient_checkpointing_enable()


# Fine-tuning configuration.
training_args = TrainingArguments(
    # Output locations; at most two checkpoints are kept on disk.
    output_dir="./pegasus_finetuned",
    logging_dir="./logs",
    save_total_limit=2,
    # One sample per device per step, accumulated over 4 steps before each update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=5e-5,
    # bf16 is used here instead of fp16.
    # NOTE(review): whether bf16 is accepted depends on the hardware and the
    # installed torch/transformers versions; confirm on the target machine.
    bf16=True,
    # Evaluate once per epoch, log every 100 steps, no external reporting.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    logging_steps=100,
    report_to="none",
)


# Bundle the model, its training arguments and both tokenized splits into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=tokenized_val,
    train_dataset=tokenized_train,
)

# Run fine-tuning.
trainer.train()