In [None]:
!pip install transformers[torch] datasets pandas sentencepiece evaluate

In [None]:
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
)
from evaluate import load

In [None]:
class Config:
    """Central place for every tunable used by this notebook.

    All values are plain class attributes, so they can be read either from the
    class itself or from an instance (``Config()``).
    """

    # --- Model / output -----------------------------------------------------
    MODEL_NAME = "google/mt5-base"
    MODEL_OUTPUT_DIR = "./ocr_devanagari_mt5base"

    # --- Data ---------------------------------------------------------------
    DATASET_PATH = "/kaggle/input/ocr-post-correction-dev/training_sample_50k.csv"
    INPUT_COLUMN = "input_text"    # CSV column holding the text fed to the model
    TARGET_COLUMN = "output_text"  # CSV column holding the text the model must produce
    PREFIX = "correct OCR error: "  # task prefix prepended to every input string

    # --- Optimisation -------------------------------------------------------
    NUM_EPOCHS = 4
    BATCH_SIZE = 4          # per-device batch size (train and eval)
    GRAD_ACCUM_STEPS = 4    # gradient accumulation steps
    LEARNING_RATE = 2e-5
    WEIGHT_DECAY = 0.01

    # --- Logging / early stopping --------------------------------------------
    LOGGING_STEPS = 100
    EARLY_STOPPING_PATIENCE = 2

In [None]:
config = Config()

# Rows with an empty source or target cell are read as NaN and reach the
# tokenization step as non-strings, where prefix concatenation would raise.
# Drop them up front and rebuild a clean 0..n-1 index so Dataset.from_pandas
# does not carry the old index along as an extra column.
df = (
    pd.read_csv(config.DATASET_PATH)
    .dropna(subset=[config.INPUT_COLUMN, config.TARGET_COLUMN])
    .reset_index(drop=True)
)

# Fixed seed: the 90/10 split (and therefore the eval metric used for
# best-model selection) is identical on every Restart & Run All.
dataset = Dataset.from_pandas(df).train_test_split(test_size=0.1, seed=42)

# Slow (SentencePiece-based) tokenizer, as requested by use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME, use_fast=False)

def preprocess_function(examples, max_length=128):
    """Tokenize a batch of (input, target) text pairs for seq2seq training.

    Args:
        examples: Batch mapping column name -> list of strings, as supplied by
            ``Dataset.map(..., batched=True)``. The columns named by
            ``config.INPUT_COLUMN`` and ``config.TARGET_COLUMN`` are read.
        max_length: Maximum token length for both inputs and labels; longer
            sequences are truncated and shorter ones padded to this length.

    Returns:
        The tokenizer output for the prefixed inputs, with the tokenized
        targets stored under ``"labels"``.
    """
    sources = [config.PREFIX + doc for doc in examples[config.INPUT_COLUMN]]
    targets = examples[config.TARGET_COLUMN]

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager: one call encodes the inputs and writes the encoded targets
    # to the "labels" key, using the same length/padding settings for both.
    #
    # NOTE(review): labels are padded with the tokenizer's pad id rather than
    # -100, so pad positions are presumably included in the training loss.
    # Masking them would put -100 into the eval labels, which compute_metrics
    # must then map back to the pad id before decoding -- change both together.
    return tokenizer(
        sources,
        text_target=targets,
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

# Tokenize both splits; batched=True hands preprocess_function a dict of lists per call.
tokenized_datasets = dataset.map(
    function=preprocess_function,
    batched=True,
)

In [None]:
# Metric object used by compute_metrics below.
rouge = load("rouge")

# Pretrained seq2seq checkpoint that will be fine-tuned.
model = AutoModelForSeq2SeqLM.from_pretrained(config.MODEL_NAME)
def compute_metrics(eval_preds):
    """Decode generated ids and labels and score them with ROUGE-L.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays as passed
            by the trainer. ``predictions`` may also be a tuple whose first
            element is the id array.

    Returns:
        ``{"rougeL": score}`` -- the key the training arguments use for
        best-model selection.
    """
    import numpy as np  # local import keeps this cell self-contained

    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is the "ignore" padding id; it is not a real vocabulary entry and
    # cannot be decoded, so map it back to the pad token before decoding.
    pad_id = tokenizer.pad_token_id
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = [text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    decoded_labels = [text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    # The default ROUGE tokenizer keeps only ASCII [a-z0-9], so non-Latin text
    # (the output dir name indicates Devanagari) would be scored on empty token
    # lists. Whitespace splitting keeps every word; the English Porter stemmer
    # is dropped because it has no meaning for this script.
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        tokenizer=lambda text: text.split(),
    )
    return {"rougeL": result["rougeL"]}

In [None]:
# T5-family checkpoints pretrained in bfloat16 are prone to overflow (NaN/inf
# loss) under fp16 mixed precision. Use bf16 only where the GPU supports it
# natively (compute capability >= 8); otherwise train in full precision.
use_bf16 = (
    torch.cuda.is_available()
    and torch.cuda.is_bf16_supported()
    and torch.cuda.get_device_capability()[0] >= 8
)

training_args = Seq2SeqTrainingArguments(
    output_dir=config.MODEL_OUTPUT_DIR,
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
    # Evaluation and checkpointing share the same schedule, which
    # load_best_model_at_end requires.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=config.LOGGING_STEPS,
    save_total_limit=2,
    predict_with_generate=True,
    # Without this, evaluation generations fall back to the model's default
    # generation length (short by library default -- TODO confirm for this
    # checkpoint) while labels are tokenized up to 128 tokens, which would
    # truncate predictions and depress the selection metric.
    generation_max_length=128,
    bf16=use_bf16,
    fp16=False,
    load_best_model_at_end=True,
    # Must match the key returned by compute_metrics.
    metric_for_best_model="rougeL",
    greater_is_better=True,
    report_to="none",
    gradient_accumulation_steps=config.GRAD_ACCUM_STEPS
)

# Batches examples for the seq2seq model; the model is passed so the collator
# can use it when preparing decoder inputs.
data_collator = DataCollatorForSeq2Seq(tokenizer, model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Stop when rougeL has not improved for EARLY_STOPPING_PATIENCE evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=config.EARLY_STOPPING_PATIENCE)]
)

In [None]:
# Run fine-tuning.
trainer.train()

# Persist the final model weights and the tokenizer files side by side in the output directory.
trainer.save_model(output_dir=config.MODEL_OUTPUT_DIR)
tokenizer.save_pretrained(save_directory=config.MODEL_OUTPUT_DIR)