In [None]:
# Load model directly
from transformers import AutoTokenizer, BartForConditionalGeneration

# Identifier of the pretrained checkpoint; the tokenizer and the model are
# both loaded from this same identifier.
checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BartForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
from datasets import load_dataset
# Load the "3.0.0" configuration of the CNN/DailyMail dataset.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run on this machine; keep it only for sources you trust.
dataset = load_dataset(
    "ccdv/cnn_dailymail",
    "3.0.0",
    trust_remote_code=True,
)

In [None]:
# Show the loaded dataset object as the cell output.
dataset

In [None]:
def count_tokens(example):
    """Count the tokens of one example's article and summary.

    Args:
        example: Mapping with "article" and "highlights" entries; each is
            passed to the module-level ``tokenizer``.

    Returns:
        dict with "article_token_count" and "summary_token_count": the
        lengths of the "input_ids" the tokenizer produced for each field.
    """
    article_ids = tokenizer(example["article"])["input_ids"]
    summary_ids = tokenizer(example["highlights"])["input_ids"]
    return dict(
        article_token_count=len(article_ids),
        summary_token_count=len(summary_ids),
    )

# Run count_tokens over the dataset one example per call (batched=False);
# the returned count fields are read per split by the cells below.
token_counts = dataset.map(count_tokens, batched=False)


In [None]:
import numpy as np
# Compute mean token count
def compute_mean_token_count(token_counts, field_name):
    """Return the arithmetic mean of the values stored in one column.

    Args:
        token_counts: Object indexable by column name (for example one
            dataset split, or a plain dict of lists).
        field_name: Name of the column to average.

    Returns:
        The mean of the column's values, or NaN when the column is empty.
    """
    # Read the column once and divide by its own length: len() of the
    # container is not the number of values for every container type
    # (for a dict it is the number of keys).
    values = token_counts[field_name]
    if len(values) == 0:
        return float("nan")
    return sum(values) / len(values)

# Compute median token count
def compute_median_token_count(token_counts, field_name):
    """Return the median of the values stored in one column.

    Args:
        token_counts: Object indexable by column name.
        field_name: Name of the column whose median is wanted.

    Returns:
        The result of ``np.median`` over the column converted to an array.
    """
    return np.median(np.array(token_counts[field_name]))

# Print mean and median token counts (article | summary) for every split.
for split in token_counts:
    split_counts = token_counts[split]
    mean_count_art, mean_count_sum = (
        compute_mean_token_count(split_counts, column)
        for column in ("article_token_count", "summary_token_count")
    )
    median_count_art, median_count_sum = (
        compute_median_token_count(split_counts, column)
        for column in ("article_token_count", "summary_token_count")
    )
    print(f"Mean token count for {split}: {mean_count_art} | {mean_count_sum}")
    print(f"Median token count for {split}: {median_count_art} | {median_count_sum}")

In [None]:
import matplotlib.pyplot as plt

def _plot_token_count_histogram(counts, value_range, bar_color, xlabel, title):
    """Draw and show one 50-bin histogram of token counts on a new 6x6 figure.

    Args:
        counts: Sequence of token counts to bin.
        value_range: (low, high) tuple passed as the histogram's ``range``.
        bar_color: Fill color of the bars.
        xlabel: Label of the x axis.
        title: Figure title.
    """
    plt.figure(figsize=(6, 6))
    plt.hist(counts, bins=50, color=bar_color, edgecolor='black', range=value_range)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.title(title)
    plt.grid(True)
    plt.show()


# Token-count columns of the training split.
article_token_counts = token_counts['train']["article_token_count"]
summary_token_counts = token_counts['train']["summary_token_count"]

# Articles are binned over 0-2048 tokens, summaries over 0-128 tokens.
_plot_token_count_histogram(
    article_token_counts,
    (0, 2048),
    'skyblue',
    "Article Token Count",
    "Distribution of Article Token Counts",
)
_plot_token_count_histogram(
    summary_token_counts,
    (0, 128),
    'salmon',
    "Summary Token Count",
    "Distribution of Summary Token Counts",
)

In [None]:
# Text prepended to every article before tokenization.
# NOTE(review): a "summarize: " task prefix is the T5 convention; confirm it
# is intended for the BART checkpoint loaded in this notebook.
prefix = "summarize: "
# NOTE(review): not referenced anywhere in the visible source.
postfix = " </s>"

# Token budgets: articles are truncated to input_length tokens and summaries
# to output_length tokens; output_length is also the generation length cap
# given to the training arguments.
input_length = 512
output_length = 70

def preprocess_function(examples):
    """Tokenize a batch of articles and their reference summaries.

    Args:
        examples: Batch mapping with "article" and "highlights" lists.

    Returns:
        The tokenizer output for the prefixed articles (truncated to
        ``input_length``) with an added "labels" entry holding the input ids
        of the summaries (truncated to ``output_length``).
    """
    batch = tokenizer(
        [prefix + article for article in examples["article"]],
        max_length=input_length,
        truncation=True,
    )
    targets = tokenizer(
        text_target=examples["highlights"],
        max_length=output_length,
        truncation=True,
    )
    batch["labels"] = targets["input_ids"]
    return batch

In [None]:
# Tokenize the dataset; batched=True hands preprocess_function a batch of
# examples per call.
tokenized_dataset = dataset.map(preprocess_function, batched=True)

In [None]:
from transformers import DataCollatorForSeq2Seq

# Collator for the seq2seq batches. `model` must be the model instance, not
# the checkpoint name string: the collator calls the model's
# prepare_decoder_input_ids_from_labels to build decoder_input_ids from the
# labels, which a plain string cannot provide.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
import evaluate
import numpy as np

# Load the "rouge" metric through the evaluate library; compute_metrics
# below calls rouge.compute on the decoded texts.
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for an evaluation.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays.
            ``predictions`` may also be a tuple whose first item is the
            token-id array.

    Returns:
        dict with every ROUGE score scaled to 0-100 and rounded to four
        decimals, plus "gen_len": the mean number of non-pad tokens per
        prediction, reported in tokens (not scaled).
    """
    predictions, labels = eval_pred
    # Some models hand back (generated ids, extra outputs); keep the ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a padding/ignore marker, not a vocabulary id, so it is swapped
    # for the pad token before decoding.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Only the ROUGE scores are fractions that need scaling to percentages.
    metrics = {name: round(score * 100, 4) for name, score in rouge_scores.items()}

    # gen_len is a token count; it is added after the scaling step so it is
    # not multiplied by 100.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    metrics["gen_len"] = round(float(np.mean(prediction_lens)), 4)
    return metrics

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer


# Training hyperparameters.
num_epochs = 4
batch_size = 8
# 2e-5 is in the range normally used to fine-tune pretrained BART weights;
# the previous value of 2e-3 is large enough to destabilise fine-tuning.
learning_rate = 2e-5
learining_rate = learning_rate  # misspelled legacy name, kept as an alias
num_beams = 4

training_args = Seq2SeqTrainingArguments(
    # NOTE(review): the directory name mentions T5 although a BART model is
    # trained; left unchanged so checkpoints already saved there are found.
    output_dir="custom_t5_model_en_cnn",
    # NOTE(review): newer transformers releases rename this to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_epochs,
    # Evaluate with generate() so compute_metrics receives generated ids.
    predict_with_generate=True,
    generation_max_length=output_length,
    # Beam width for evaluation-time generation; num_beams was previously
    # defined but never passed on.
    generation_num_beams=num_beams,
    # NOTE(review): bf16 needs hardware with bfloat16 support.
    bf16=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # Per-epoch evaluation uses the validation split so the test split stays
    # untouched for the final held-out measurement.
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
import os

from transformers.trainer_utils import get_last_checkpoint

# resume_from_checkpoint=True raises ValueError when the output directory
# holds no checkpoint yet (for example on the very first run). Look the
# latest checkpoint up explicitly: resume from it when one exists, otherwise
# pass None so training starts from the loaded pretrained weights.
output_dir = trainer.args.output_dir
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
trainer.train(resume_from_checkpoint=last_checkpoint)