In [3]:
!pip install evaluate jiwer sacrebleu rich -q

In [4]:
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
import evaluate 
import jiwer
from tqdm import tqdm
import os
import random
from rich.console import Console
from rich.table import Table

In [5]:
class TrainingConfig:

    MODEL_NAME = "t5-base" 
    DATASET_PATH = "/kaggle/input/ocr-post-correction/combined.csv" 
    INPUT_COLUMN = "input_text"  
    TARGET_COLUMN = "output_text" 
    PREFIX = "correct OCR error: " 
    MODEL_OUTPUT_DIR = "./ocr_correction_model" 
    
    # Training hyperparameters
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 10 
    BATCH_SIZE = 16 
    WEIGHT_DECAY = 0.01
    LOGGING_STEPS = 100
    SAVE_STEPS = 500

In [None]:
def main():
    """
    Main function to run the entire training and evaluation pipeline.
    """
    config = TrainingConfig()

    try:
        df = pd.read_csv(config.DATASET_PATH)
        
        if config.INPUT_COLUMN not in df.columns or config.TARGET_COLUMN not in df.columns:
            print(f"Warning: Columns '{config.INPUT_COLUMN}' and '{config.TARGET_COLUMN}' not found.")
            df.columns = [config.INPUT_COLUMN, config.TARGET_COLUMN]
            print(f"Assigned column names: {df.columns.tolist()}")

        dataset = Dataset.from_pandas(df)
        dataset = dataset.train_test_split(test_size=0.1)
        
        # print(f"Dataset loaded. Training examples: {len(dataset['train'])}, Validation examples: {len(dataset['test'])}")
        # print("Sample data:", dataset['train'][0])

    except FileNotFoundError:
        print(f"Error: The file '{config.DATASET_PATH}' was not found.")
        return

    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)

    def preprocess_function(examples):
        # adding prefix on each tokens for our custom OCR task
        inputs = [config.PREFIX + doc for doc in examples[config.INPUT_COLUMN]]
        model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(examples[config.TARGET_COLUMN], max_length=128, truncation=True, padding="max_length")
        
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_datasets = dataset.map(preprocess_function, batched=True)

    
    # Fine-tuning the T5-Base Seq2Seq Model
    model = AutoModelForSeq2SeqLM.from_pretrained(config.MODEL_NAME)
    
    training_args = Seq2SeqTrainingArguments(
        output_dir=config.MODEL_OUTPUT_DIR,
        num_train_epochs=config.NUM_EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        per_device_eval_batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=config.LOGGING_STEPS,
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),
        push_to_hub=False,
        save_total_limit=2,
        report_to="none",
    )
    
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer, 
        model=model
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()

    trainer.save_model(config.MODEL_OUTPUT_DIR)
    tokenizer.save_pretrained(config.MODEL_OUTPUT_DIR)

if __name__ == "__main__":
    main()

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Map:   0%|          | 0/15820 [00:00<?, ? examples/s]



Map:   0%|          | 0/1758 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  trainer = Seq2SeqTrainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,0.4009,0.310984
2,0.3309,0.261883
3,0.2939,0.235851
4,0.2726,0.21929
5,0.2564,0.206588
6,0.245,0.197185
7,0.2372,0.19148
8,0.2307,0.186337
9,0.2279,0.184093
10,0.2275,0.183071




In [7]:
# !zip -r ocr_correction_model.zip /kaggle/working/ocr_correction_model

In [8]:
!pip install evaluate jiwer sacrebleu rich -q

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [16]:
class EvalConfig:
    MODEL_DIR = "/kaggle/working/ocr_correction_model"
    DATASET_PATH = "/kaggle/input/ocr-post-correction/combined.csv" 
    INPUT_COLUMN = "input_text"
    TARGET_COLUMN = "output_text"
    PREFIX = "correct OCR error: "
    BATCH_SIZE = 16

def main_refined_evaluation():
    config = EvalConfig()
    console = Console()
    
    with console.status("[bold green]Loading model and tokenizer...") as status:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(config.MODEL_DIR)
            tokenizer = AutoTokenizer.from_pretrained(config.MODEL_DIR)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            console.log("Model loaded successfully.")
        except OSError:
            console.print(f"Error: Model not found at '{config.MODEL_DIR}'. Please run the training cell first.", style="bold red")
            return

    with console.status("[bold green]Loading and preparing dataset...") as status:
        try:
            full_df = pd.read_csv(config.DATASET_PATH, dtype=str).fillna('')
            test_df = full_df.sample(frac=0.1, random_state=42)
            ocr_texts = test_df[config.INPUT_COLUMN].tolist()
            ground_truth_texts = test_df[config.TARGET_COLUMN].tolist()
        except FileNotFoundError:
            console.print(f"Error: Dataset not found at '{config.DATASET_PATH}'.", style="bold red")
            return
        except KeyError as e:
            console.print(f"Error: Column {e} not found in the dataset.", style="bold red")
            return

    predictions = []
    
    for i in tqdm(range(0, len(ocr_texts), config.BATCH_SIZE), desc="Predicting", leave=False):
        batch_texts = ocr_texts[i:i + config.BATCH_SIZE]
        prefixed_batch = [config.PREFIX + text for text in batch_texts]
        
        inputs = tokenizer(prefixed_batch, return_tensors="pt", padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_length=128, num_beams=5, early_stopping=True)
        
        batch_preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        predictions.extend(batch_preds)

    with console.status("[bold green]Calculating evaluation metrics...") as status:
        cer_score = jiwer.cer(ground_truth_texts, predictions)
        wer_score = jiwer.wer(ground_truth_texts, predictions)
        bleu_metric = evaluate.load("sacrebleu")
        bleu_score = bleu_metric.compute(predictions=predictions, references=[[ref] for ref in ground_truth_texts])
        exact_matches = sum(1 for pred, ref in zip(predictions, ground_truth_texts) if pred.strip() == ref.strip())
        em_score = exact_matches / len(predictions)

    table = Table(title="Model Evaluation Report", show_header=True, header_style="bold magenta")
    table.add_column("Metric", style="dim", width=25)
    table.add_column("Score", justify="right")
 
    table.add_row("Character Error Rate (CER)", f"{cer_score:.4f}")
    table.add_row("Word Error Rate (WER)", f"{wer_score:.4f}")
    table.add_row("BLEU Score", f"{bleu_score['score']:.2f})")
    table.add_row("Exact Match (EM %)", f"{em_score*100:.2f}%")
    
    console.print(table)
    
    console.print("\ Random Sample Predictions ---", style="bold yellow")
    
    random_indices = random.sample(range(len(predictions)), 3)
    
    for i, index in enumerate(random_indices):
        
        sample_table = Table(title=f"Sample {i+1}", show_header=False, box=None, padding=(0,1))
        sample_table.add_column(width=15)
        sample_table.add_column()
        sample_table.add_row("[bold]Input[/bold]", f": '{ocr_texts[index]}'")
        sample_table.add_row("[bold green]Ground Truth[/bold green]", f": '{ground_truth_texts[index]}'")
        sample_table.add_row("[bold cyan]Model Output[/bold cyan]", f": '{predictions[index]}'")
        console.print(sample_table)
        
main_refined_evaluation()


Output()

Output()

                                                             

Output()

