In [30]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    pipeline,
)
import evaluate
import numpy as np

# Load Ukrainian dataset
dataset = load_dataset("csebuetnlp/xlsum", "ukrainian")

# Load the tokenizer and model
checkpoint = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [31]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'url', 'title', 'summary', 'text'],
        num_rows: 43201
    })
    test: Dataset({
        features: ['id', 'url', 'title', 'summary', 'text'],
        num_rows: 5399
    })
    validation: Dataset({
        features: ['id', 'url', 'title', 'summary', 'text'],
        num_rows: 5399
    })
})

In [32]:
def count_tokens(dataset):
    # Tokenize the text
    tokenized_text = tokenizer(dataset["text"]) 
    tokenized_summary = tokenizer(dataset["summary"])
    
    # Update the example with the token count
    dataset["article_token_count"] = len(tokenized_text["input_ids"])
    dataset["summary_token_count"] = len(tokenized_summary["input_ids"])
    
    return dataset

# Apply the function and calculate the mean
token_counts = dataset.map(count_tokens, batched=False)

Map:   0%|          | 0/43201 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

article_token_counts = token_counts['train']["article_token_count"]
summary_token_counts = token_counts['train']["summary_token_count"]

# Create a histogram for article token counts
plt.figure(figsize=(6, 6))
plt.hist(article_token_counts, bins=50, color='skyblue', edgecolor='black', range = (0, 2048))
plt.xlabel("Article Token Count")
plt.ylabel("Frequency")
plt.title("Distribution of Article Token Counts")
plt.grid(True)
plt.show()

# Create a histogram for summary token counts
plt.figure(figsize=(6, 6))
plt.hist(summary_token_counts, bins=50, color='salmon', edgecolor='black', range = (0, 128))
plt.xlabel("Summary Token Count")
plt.ylabel("Frequency")
plt.title("Distribution of Summary Token Counts")
plt.grid(True)
plt.show()

In [None]:
# Function to count tokens in the dataset
def count_tokens(example):
    tokenized_text = tokenizer(example["text"])
    tokenized_summary = tokenizer(example["summary"])
    return {
        "article_token_count": len(tokenized_text["input_ids"]),
        "summary_token_count": len(tokenized_summary["input_ids"]),
    }

# Apply the function to the dataset
token_counts = dataset.map(count_tokens, batched=False)



In [None]:
# Function to compute mean token count
def compute_mean_token_count(token_counts, field_name):
    total_token_count = sum(token_counts[field_name])
    return total_token_count / len(token_counts)

# Function to compute median token count
def compute_median_token_count(token_counts, field_name):
    token_counts_array = np.array(token_counts[field_name])
    return np.median(token_counts_array)

# Calculate and print mean and median token count for each split
for split in token_counts:
    mean_count_art = compute_mean_token_count(token_counts[split], "article_token_count")
    mean_count_sum = compute_mean_token_count(token_counts[split], "summary_token_count")
    median_count_art = compute_median_token_count(token_counts[split], "article_token_count")
    median_count_sum = compute_median_token_count(token_counts[split], "summary_token_count")
    print(f"Mean token count for {split}: {mean_count_art} | {mean_count_sum}")
    print(f"Median token count for {split}: {median_count_art} | {median_count_sum}")

In [None]:
# Preprocessing function for the dataset
prefix = "summarize: "
postfix = " </s>"
def preprocess_function(examples):
    inputs = [prefix +  doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    outputs = [doc + postfix for doc in examples["summary"]]
    labels = tokenizer(outputs, max_length=64, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize the dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True)

In [None]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [None]:
import evaluate
import numpy as np

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [36]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="custom_mt5_model_uk",
    evaluation_strategy="epoch",
    learning_rate=2e-3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=8,
    predict_with_generate=True,
    generation_max_length=64,
    bf16=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)



In [None]:
evaluation_results = trainer.evaluate()
print("Evaluation results:", evaluation_results)

In [34]:
import torch
torch.cuda.empty_cache()

In [37]:
trainer.train(resume_from_checkpoint=True)

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
6,1.8695,2.185976,0.0361,0.0058,0.036,0.036,40.5475
7,1.7034,2.194303,0.0365,0.0066,0.0363,0.0362,38.8611
8,1.5745,2.224762,0.0381,0.0063,0.0378,0.0377,38.8163


TrainOutput(global_step=43208, training_loss=0.5334326580395811, metrics={'train_runtime': 6929.4289, 'train_samples_per_second': 49.875, 'train_steps_per_second': 6.235, 'total_flos': 1.8270586080531456e+17, 'train_loss': 0.5334326580395811, 'epoch': 8.0})