# OCR Model Fine-tuning for Marker

This notebook demonstrates how to fine-tune the OCR models used in the Marker project. It covers the complete workflow from data preparation to model evaluation and integration.

## Overview

The Marker project uses Surya's `RecognitionPredictor` for OCR (Optical Character Recognition) to extract text from document images. In this notebook, we'll fine-tune this model on domain-specific data to improve accuracy for specialized documents.

### What we'll cover:
1. Understanding the current OCR pipeline in Marker
2. Preparing data for fine-tuning
3. Setting up the training environment
4. Fine-tuning the OCR model
5. Evaluating the fine-tuned model
6. Integrating the model with Marker

Let's get started!

## 1. Understanding the Current OCR Pipeline

The Marker OCR pipeline consists of several components working together:

### Architecture Overview

1. **Document Loading**: PDF documents are loaded using the `PdfProvider`
2. **Layout Detection**: The layout model identifies regions in the document
3. **OCR Processing**: The OCR model extracts text from these regions
4. **Block Construction**: Text is organized into a hierarchical document structure

The OCR component specifically relies on Surya's `RecognitionPredictor` model, which is initialized in `marker/models.py` and used by the `OcrBuilder` class in `marker/builders/ocr.py`.

In [None]:
# Import necessary libraries
import os
import sys
import json
from pathlib import Path

# Add marker to path for imports
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
sys.path.insert(0, project_root)

# Import marker modules
try:
    from marker.builders.ocr import OcrBuilder
    from marker.models import create_model_dict
    from marker.providers.pdf import PdfProvider
    from surya.recognition import RecognitionPredictor
except ImportError:
    print("Error: Could not import marker modules. Make sure the project root is correct.")
    raise

### Examining the OCR Components

Let's take a closer look at how OCR works in Marker:

In [None]:
# Initialize model dictionary (with CPU for exploration)
models = create_model_dict(device="cpu")

# Extract the OCR model
recognition_model = models["recognition_model"]
print(f"Recognition model type: {type(recognition_model)}")

# Display model properties
print(f"Model architecture: {recognition_model.model.__class__.__name__}")
print(f"Model parameters: {sum(p.numel() for p in recognition_model.model.parameters() if p.requires_grad):,}")

### OCR Process Flow

Here's how the OCR process works in Marker:

1. The `OcrBuilder` extracts line regions from document pages
2. It processes images for these line regions
3. The `RecognitionPredictor` model converts these line images to text
4. Results are integrated back into the document structure

Let's visualize a sample of this process:

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

# Function to visualize OCR process
def visualize_ocr_process(pdf_path, page_idx=0):
    # Initialize PDF provider
    provider = PdfProvider(pdf_path)
    
    # Get page image
    page_id = str(page_idx)
    page_image = provider.get_page_image(page_id)
    
    # Get line regions (simulated)
    page_lines = provider.get_page_lines(page_id)
    
    # Display original image
    plt.figure(figsize=(12, 16))
    plt.subplot(1, 2, 1)
    plt.imshow(page_image)
    plt.title("Original Page")
    plt.axis('off')
    
    # Create image with highlighted regions
    highlighted = page_image.copy()
    draw = ImageDraw.Draw(highlighted)
    
    for i, line in enumerate(page_lines):
        # Draw rectangle around line
        if hasattr(line, 'line') and hasattr(line.line, 'polygon'):
            bbox = line.line.polygon.bbox
            draw.rectangle(bbox, outline="red", width=2)
            
            # Add line number
            draw.text((bbox[0], bbox[1]), str(i), fill="blue")
    
    # Display highlighted image
    plt.subplot(1, 2, 2)
    plt.imshow(highlighted)
    plt.title("OCR Line Regions")
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Display sample of extracted text
    print("Sample of extracted text:")
    for i, line in enumerate(page_lines[:5]):
        if hasattr(line, 'spans'):
            line_text = "".join([span.text for span in line.spans])
            print(f"Line {i}: {line_text}")

# Uncomment to visualize
# visualize_ocr_process('../data/input/sample.pdf')
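The visualization above reads `line.line.polygon.bbox` and passes it straight to `draw.rectangle`. As a minimal sketch of what such a bounding box represents, assuming it is the axis-aligned `[x0, y0, x1, y1]` box around the polygon's corner points, the helpers below compute it by hand and optionally pad it. The names `polygon_to_bbox` and `pad_bbox` are hypothetical and are not part of Marker or Surya.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def polygon_to_bbox(polygon: Sequence[Point]) -> List[float]:
    """Return the axis-aligned [x0, y0, x1, y1] box enclosing a polygon."""
    xs = [p[0] for p in polygon]
    ys = [p[1] for p in polygon]
    return [min(xs), min(ys), max(xs), max(ys)]


def pad_bbox(bbox: Sequence[float], pad: float, width: float, height: float) -> List[float]:
    """Grow a box by `pad` pixels on every side, clamped to the page size."""
    x0, y0, x1, y1 = bbox
    return [
        max(0, x0 - pad),
        max(0, y0 - pad),
        min(width, x1 + pad),
        min(height, y1 + pad),
    ]


# A slightly skewed text line described by its four corners (made-up values)
quad = [(10, 20), (110, 22), (109, 40), (9, 38)]
box = polygon_to_bbox(quad)
padded = pad_bbox(box, pad=2, width=112, height=200)
print(box, padded)
```

A small padding can help when line crops cut off ascenders or descenders, but whether it improves recognition on your documents is something to check empirically.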

## 2. Data Preparation for Fine-tuning

Fine-tuning an OCR model requires domain-specific data. We'll prepare this data in the format expected by the Surya framework.

### Dataset Requirements

The OCR fine-tuning dataset consists of:
- **Line images**: Cropped images of text lines
- **Ground truth text**: Correct transcriptions for each line

We'll use the `prepare_data.py` script to extract this data from PDFs:

In [None]:
# Data preparation parameters
input_dir = "../data/input"  # Directory with PDFs for fine-tuning
output_dir = "../data/ocr_data"  # Output directory for OCR data

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Import prepare_data module
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../scripts')))
from prepare_data import prepare_ocr_data, setup_data_directories

# Setup directories
data_dirs = setup_data_directories(output_dir, mode="ocr")
print(f"Created data directories: {data_dirs}")

In [None]:
# Run data preparation
# Note: `prepare_ocr_data` is imported from the scripts directory in the previous cell.
# Uncomment the call below to use it instead of the manual implementation.
# prepare_ocr_data(input_dir, data_dirs)

# Alternative manual preparation code
import shutil
from tqdm.notebook import tqdm

def manually_prepare_ocr_data(input_dir, output_dirs):
    """Manual implementation of OCR data preparation."""
    line_crops_dir = output_dirs["line_crops"]
    transcriptions_dir = output_dirs["transcriptions"]
    splits_dir = output_dirs["splits"]
    
    # List PDF files
    pdf_files = [f for f in os.listdir(input_dir) if f.lower().endswith('.pdf')]
    
    if not pdf_files:
        print(f"No PDF files found in {input_dir}")
        return
    
    all_lines = []
    
    for pdf_file in tqdm(pdf_files, desc="Processing PDFs"):
        pdf_path = os.path.join(input_dir, pdf_file)
        
        try:
            # Initialize PDF provider
            provider = PdfProvider(pdf_path)
            doc_name = Path(pdf_file).stem
            
            # Process each page
            for page_idx in range(provider.num_pages):
                page_id = f"{page_idx}"
                page_image = provider.get_page_image(page_id)
                page_lines = provider.get_page_lines(page_id)
                
                # Process each line
                for line_idx, line in enumerate(page_lines):
                    if not hasattr(line, 'line') or not hasattr(line.line, 'polygon'):
                        continue
                    
                    line_polygon = line.line.polygon
                    
                    # Crop line image
                    line_bbox = line_polygon.bbox
                    line_crop = page_image.crop(line_bbox)
                    
                    # Save line crop
                    line_id = f"{doc_name}_page_{page_idx}_line_{line_idx}"
                    crop_path = os.path.join(line_crops_dir, f"{line_id}.png")
                    line_crop.save(crop_path)
                    
                    # Extract text
                    if hasattr(line, 'spans'):
                        line_text = "".join([span.text for span in line.spans])
                    else:
                        line_text = "" # Placeholder, would need manual correction
                    
                    # Save transcription
                    trans_path = os.path.join(transcriptions_dir, f"{line_id}.txt")
                    with open(trans_path, 'w', encoding='utf-8') as f:
                        f.write(line_text)
                    
                    # Store line information
                    all_lines.append({
                        "doc_name": doc_name,
                        "page_id": page_id,
                        "line_id": line_id,
                        "image_path": crop_path,
                        "text_path": trans_path,
                        "text": line_text,
                    })
        except Exception as e:
            print(f"Error processing {pdf_file}: {e}")
    
    # Create train/val/test splits (fixed seed so the split is reproducible)
    import random
    random.Random(42).shuffle(all_lines)
    n_lines = len(all_lines)
    
    train_lines = all_lines[:int(n_lines * 0.8)]
    val_lines = all_lines[int(n_lines * 0.8):int(n_lines * 0.9)]
    test_lines = all_lines[int(n_lines * 0.9):]
    
    # Save splits
    for split_name, lines in [("train", train_lines), ("val", val_lines), ("test", test_lines)]:
        split_file = os.path.join(splits_dir, f"{split_name}.json")
        with open(split_file, 'w') as f:
            json.dump({
                "lines": lines
            }, f, indent=2)
    
    print(f"Prepared OCR data: {len(all_lines)} lines")
    print(f"  Train: {len(train_lines)}, Val: {len(val_lines)}, Test: {len(test_lines)}")
    print(f"Line crops saved in: {line_crops_dir}")
    print(f"Transcriptions saved in: {transcriptions_dir}")

# Uncomment the following line to run data preparation
# manually_prepare_ocr_data(input_dir, data_dirs)
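The function above shuffles individual lines, so lines from the same PDF can land in both the training and test splits. If you would rather keep each document entirely within one split, a document-level split is one option. The sketch below is an assumption about how that could look, not part of `prepare_data.py`: the helper name `split_by_document` is hypothetical, and it relies only on the `doc_name` key that `all_lines` entries already carry.

```python
import random
from collections import defaultdict


def split_by_document(lines, train_frac=0.8, val_frac=0.1, seed=42):
    """Split line records into train/val/test so no document spans two splits."""
    # Group line records by their source document
    by_doc = defaultdict(list)
    for line in lines:
        by_doc[line["doc_name"]].append(line)

    # Shuffle document names deterministically
    doc_names = sorted(by_doc)
    random.Random(seed).shuffle(doc_names)

    n_train = int(len(doc_names) * train_frac)
    n_val = int(len(doc_names) * val_frac)
    groups = {
        "train": doc_names[:n_train],
        "val": doc_names[n_train:n_train + n_val],
        "test": doc_names[n_train + n_val:],
    }
    # Expand each group of documents back into its line records
    return {
        split: [line for doc in docs for line in by_doc[doc]]
        for split, docs in groups.items()
    }


# Ten made-up documents with three lines each
demo_lines = [
    {"doc_name": f"doc{d}", "line_id": f"doc{d}_page_0_line_{i}"}
    for d in range(10)
    for i in range(3)
]
demo_splits = split_by_document(demo_lines)
print({name: len(items) for name, items in demo_splits.items()})
```

With only a handful of PDFs the validation or test split may end up empty, in which case the line-level split above is the more practical choice.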

### Manual Correction of Ground Truth

Fine-tuning is only as good as its ground truth. After running the data preparation step, manually check and correct the files in the transcriptions directory (`transcriptions_dir`). Lines that had no spans were written as empty files and need to be transcribed by hand.

Let's create a simple tool to assist with this process:

In [None]:
from IPython.display import display, Image as IPImage
import ipywidgets as widgets

def correction_tool(line_crops_dir, transcriptions_dir, limit=50):
    """Simple tool to help with ground truth correction."""
    # Get list of image files
    image_files = sorted([f for f in os.listdir(line_crops_dir) if f.lower().endswith('.png')])[:limit]
    
    if not image_files:
        print("No image files found in the directory.")
        return
    
    # Selected image index
    current_idx = 0
    
    # Define UI components
    image_widget = widgets.Output()
    text_widget = widgets.Textarea(description="Text:", layout=widgets.Layout(width='80%', height='100px'))
    save_button = widgets.Button(description="Save")
    next_button = widgets.Button(description="Next")
    prev_button = widgets.Button(description="Previous")
    status_widget = widgets.Label("")
    
    def update_display():
        nonlocal current_idx
        image_widget.clear_output()
        
        with image_widget:
            image_path = os.path.join(line_crops_dir, image_files[current_idx])
            display(IPImage(filename=image_path))
            print(f"Image {current_idx + 1} of {len(image_files)}: {image_files[current_idx]}")
        
        # Load text
        base_name = os.path.splitext(image_files[current_idx])[0]
        text_path = os.path.join(transcriptions_dir, f"{base_name}.txt")
        
        if os.path.exists(text_path):
            with open(text_path, 'r', encoding='utf-8') as f:
                text_widget.value = f.read()
        else:
            text_widget.value = ""
    
    def on_save_clicked(b):
        base_name = os.path.splitext(image_files[current_idx])[0]
        text_path = os.path.join(transcriptions_dir, f"{base_name}.txt")
        
        with open(text_path, 'w', encoding='utf-8') as f:
            f.write(text_widget.value)
        
        status_widget.value = f"Saved text for {base_name}"
    
    def on_next_clicked(b):
        nonlocal current_idx
        if current_idx < len(image_files) - 1:
            current_idx += 1
            update_display()
    
    def on_prev_clicked(b):
        nonlocal current_idx
        if current_idx > 0:
            current_idx -= 1
            update_display()
    
    # Connect events
    save_button.on_click(on_save_clicked)
    next_button.on_click(on_next_clicked)
    prev_button.on_click(on_prev_clicked)
    
    # Create layout
    button_box = widgets.HBox([prev_button, save_button, next_button])
    app = widgets.VBox([image_widget, text_widget, button_box, status_widget])
    
    # Initial display
    update_display()
    
    return app

# Uncomment to run the correction tool
# correction_widget = correction_tool(data_dirs["line_crops"], data_dirs["transcriptions"])
# display(correction_widget)
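Reviewing every line by hand can be slow, so it may help to look first at transcriptions that are likely to be wrong. The sketch below flags a few obvious cases (empty text, the Unicode replacement character, control characters, stray surrounding whitespace). The function `flag_transcription` and the specific checks are illustrative assumptions. Adapt them to your data.

```python
import unicodedata


def flag_transcription(text):
    """Return a list of reasons why a transcription deserves manual review."""
    reasons = []
    if not text.strip():
        reasons.append("empty")
    if "\ufffd" in text:
        # U+FFFD usually indicates a decoding problem upstream
        reasons.append("replacement character")
    if any(unicodedata.category(ch) == "Cc" for ch in text if ch != "\t"):
        reasons.append("control character")
    if text != text.strip():
        reasons.append("leading/trailing whitespace")
    return reasons


# Made-up transcriptions keyed by line ID
demo_transcriptions = {
    "doc_page_0_line_0": "Total: 42",
    "doc_page_0_line_1": "",
    "doc_page_0_line_2": " abc\n",
    "doc_page_0_line_3": "caf\ufffd",
}
for line_id, text in demo_transcriptions.items():
    reasons = flag_transcription(text)
    if reasons:
        print(f"{line_id}: {', '.join(reasons)}")
```

To run this over the real data, read each `.txt` file in the transcriptions directory and pass its contents to the function.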

### Analyzing the Prepared Dataset

Let's explore our prepared dataset to understand its characteristics:

In [None]:
import random
from collections import Counter

def analyze_dataset(transcriptions_dir, line_crops_dir, sample_size=10):
    """Analyze the prepared dataset."""
    # Get all transcription files
    text_files = [f for f in os.listdir(transcriptions_dir) if f.lower().endswith('.txt')]
    
    if not text_files:
        print("No transcription files found.")
        return
    
    # Read all text
    all_text = ""
    text_lengths = []
    
    for text_file in text_files:
        text_path = os.path.join(transcriptions_dir, text_file)
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()
            all_text += text
            text_lengths.append(len(text))
    
    # Analyze characters
    char_counts = Counter(all_text)
    total_chars = len(all_text)
    num_unique_chars = len(char_counts)
    
    # Display statistics
    print("Dataset Statistics:")
    print(f"Total lines: {len(text_files)}")
    print(f"Total characters: {total_chars}")
    print(f"Unique characters: {num_unique_chars}")
    print(f"Average line length: {sum(text_lengths) / len(text_lengths):.2f} chars")
    print(f"Shortest line: {min(text_lengths)} chars")
    print(f"Longest line: {max(text_lengths)} chars")
    
    # Display most common characters
    print("\nMost common characters:")
    for char, count in char_counts.most_common(20):
        print(f"'{char}': {count} ({count/total_chars*100:.2f}%)")
    
    # Display sample lines with images
    print("\nRandom samples:")
    sample_files = random.sample(text_files, min(sample_size, len(text_files)))
    
    # Size the figure to the number of samples actually drawn,
    # which can be smaller than sample_size for small datasets
    n_samples = len(sample_files)
    plt.figure(figsize=(15, n_samples * 2))
    
    for i, text_file in enumerate(sample_files):
        base_name = os.path.splitext(text_file)[0]
        image_path = os.path.join(line_crops_dir, f"{base_name}.png")
        text_path = os.path.join(transcriptions_dir, text_file)
        
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()
        
        if os.path.exists(image_path):
            image = Image.open(image_path)
            plt.subplot(n_samples, 1, i+1)
            plt.imshow(image)
            plt.title(f"Text: {text}")
            plt.axis('off')
    
    plt.tight_layout()
    plt.show()

# Uncomment to analyze the dataset
# analyze_dataset(data_dirs["transcriptions"], data_dirs["line_crops"])

## 3. Setting Up the Training Environment

Now that we have prepared our dataset, let's set up the environment for fine-tuning the OCR model. We'll use Unsloth for efficient 4-bit QLoRA fine-tuning, and fall back to standard LoRA through PEFT if Unsloth is not installed.

### Dependencies Installation

First, let's install the required dependencies:

In [None]:
# Install required packages
!pip install unsloth
!pip install huggingface_hub
!pip install datasets
!pip install accelerate
!pip install bitsandbytes
!pip install wandb
# Also imported later in this notebook
!pip install peft
!pip install Levenshtein

### Loading the Dataset

Let's load our prepared dataset using the utility functions from Marker's fine-tuning tools:

In [None]:
# Import dataset utility
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../unsloth')))
from utils import prepare_ocr_dataset

# Initialize OCR model to get tokenizer
recognition_model = RecognitionPredictor(device="cpu")
tokenizer = recognition_model.model.tokenizer

# Load dataset
dataset = prepare_ocr_dataset(
    data_dir=output_dir,
    tokenizer=tokenizer,
    max_length=512
)

# Display dataset info
for split in dataset:
    print(f"Split {split}: {len(dataset[split])} examples")
    
# Display sample
if 'train' in dataset and len(dataset['train']) > 0:
    sample = dataset['train'][0]
    print("\nSample example:")
    print(f"Line ID: {sample['line_id']}")
    print(f"Text: {sample['text']}")
    if 'input_ids' in sample:
        print(f"Input IDs shape: {len(sample['input_ids'])}")
    
    # Display image
    plt.figure(figsize=(10, 2))
    plt.imshow(sample['image'])
    plt.title(sample['text'])
    plt.axis('off')
    plt.show()

### Setting up Fine-tuning Configuration

Now, let's configure the fine-tuning process using Unsloth's optimizations:

In [None]:
# Import adapters and trainers
from adapters import get_ocr_lora_config, create_qlora_model
from trainers import SuryaTrainingArguments, SuryaOCRTrainer

# Setup training parameters
output_model_dir = "../models/custom_ocr_model"
os.makedirs(output_model_dir, exist_ok=True)

# Configuration
training_config = {
    "batch_size": 8,
    "learning_rate": 5e-5,
    "num_train_epochs": 3,
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "use_wandb": False,  # Set to True if you want to use W&B for tracking
}
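Before training, it helps to know how many optimizer steps this configuration implies. The arithmetic below is a rough sketch. The dataset size of 10,000 lines is a made-up number, the gradient accumulation of 4 matches the value used later in this notebook, and the exact step count depends on how the trainer rounds the last partial batch.

```python
import math


def training_steps(num_examples, batch_size, grad_accum, epochs, num_devices=1):
    """Estimate effective batch size and optimizer steps for a training run."""
    effective_batch = batch_size * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch, steps_per_epoch * epochs


effective, per_epoch, total = training_steps(
    num_examples=10_000,  # hypothetical number of training lines
    batch_size=8,         # training_config["batch_size"]
    grad_accum=4,         # gradient_accumulation_steps used later
    epochs=3,             # training_config["num_train_epochs"]
)
print(f"Effective batch size: {effective}")
print(f"Steps per epoch: ~{per_epoch}, total steps: ~{total}")
```

Under these assumptions, evaluating every 100 steps gives roughly nine evaluations over the whole run. For a much smaller dataset you may want a smaller `eval_steps`.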

### Initializing the OCR Model

Let's initialize our OCR model with LoRA adapters for fine-tuning:

In [None]:
# Configure LoRA
lora_config = get_ocr_lora_config(
    r=training_config["lora_r"],
    lora_alpha=training_config["lora_alpha"],
    lora_dropout=training_config["lora_dropout"],
)

# Initialize model
import torch  # needed for the CUDA availability check below

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    # Try to use Unsloth for optimized QLoRA
    from unsloth import FastLanguageModel
    
    # Quantization config
    quantization_config = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": "float16",
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
    }
    
    # Create QLoRA model
    model, training_args = create_qlora_model(
        base_model=recognition_model.model,
        lora_config=lora_config,
        quantization_config=quantization_config
    )
    print("Successfully initialized model with Unsloth optimizations")
    
except ImportError:
    print("Unsloth not available. Using standard fine-tuning.")
    # Standard fine-tuning (without Unsloth)
    from peft import get_peft_model
    
    # Move model to target device
    model = recognition_model.model.to(device)
    
    # Add LoRA adapters
    model = get_peft_model(model, lora_config)
    
    # Default training args
    training_args = {
        "learning_rate": training_config["learning_rate"],
        "num_train_epochs": training_config["num_train_epochs"],
        "per_device_train_batch_size": training_config["batch_size"],
        "gradient_accumulation_steps": 4,
    }

## 4. Fine-tuning the OCR Model

Now we're ready to fine-tune the OCR model on our custom dataset.

In [None]:
# Set up training arguments
training_args = SuryaTrainingArguments(
    output_dir=output_model_dir,
    model_type="ocr",
    per_device_train_batch_size=training_config["batch_size"],
    per_device_eval_batch_size=training_config["batch_size"],
    gradient_accumulation_steps=4,
    learning_rate=training_config["learning_rate"],
    num_train_epochs=training_config["num_train_epochs"],
    weight_decay=0.01,
    logging_dir=os.path.join(output_model_dir, "logs"),
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    push_to_hub=False,
    remove_unused_columns=False,
    report_to="wandb" if training_config["use_wandb"] else "none",
    save_pretrained_merged=True,
    early_stopping_patience=3,
    fp16=(device == "cuda"),  # mixed precision needs a GPU; `device` was set when initializing the model
)

In [None]:
# Create trainer
trainer = SuryaOCRTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"] if "val" in dataset else None,
)

# Start training
print("Starting fine-tuning...")
trainer.train()

### Saving the Fine-tuned Model

Let's save our fine-tuned model:

In [None]:
# Save the final model
trainer.save_model(output_model_dir)
print(f"Model saved to {output_model_dir}")

# If you want to save the merged model (adapter + base model)
merged_dir = os.path.join(output_model_dir, "merged")
os.makedirs(merged_dir, exist_ok=True)
    
# Merge adapter with base model
if hasattr(model, "merge_and_unload"):
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(merged_dir)
    print(f"Merged model saved to {merged_dir}")

## 5. Evaluating the Fine-tuned Model

Now let's evaluate our model on the test set to see how it performs:

In [None]:
# Evaluate model on test set
if "test" in dataset:
    test_results = trainer.evaluate(dataset["test"])
    print("Test Results:")
    for key, value in test_results.items():
        print(f"{key}: {value}")
    
    # Save test results
    results_path = os.path.join(output_model_dir, "test_results.json")
    with open(results_path, 'w') as f:
        json.dump(test_results, f, indent=2)

### Comparing Original vs. Fine-tuned Model

Let's compare the performance of our fine-tuned model with the original model:

In [None]:
from Levenshtein import distance as levenshtein_distance

def compare_models(original_model, finetuned_model_path, test_samples):
    """Compare original and fine-tuned models."""
    # Load fine-tuned model
    finetuned_model = RecognitionPredictor(model_path=finetuned_model_path)
    
    # Prepare sample images
    sample_images = []
    ground_truths = []
    
    for sample in test_samples:
        image_path = sample["image_path"]
        text_path = sample["text_path"]
        
        # Load image
        image = Image.open(image_path)
        sample_images.append(image)
        
        # Load ground truth
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read().strip()
            ground_truths.append(text)
    
    # Run inference with original model
    original_results = original_model(images=[sample_images], sort_lines=False)[0]
    original_texts = [line.text for line in original_results.text_lines]
    
    # Run inference with fine-tuned model
    finetuned_results = finetuned_model(images=[sample_images], sort_lines=False)[0]
    finetuned_texts = [line.text for line in finetuned_results.text_lines]
    
    # Calculate metrics
    original_distances = [levenshtein_distance(pred, gt) for pred, gt in zip(original_texts, ground_truths)]
    finetuned_distances = [levenshtein_distance(pred, gt) for pred, gt in zip(finetuned_texts, ground_truths)]
    
    original_accuracy = sum(1 for d in original_distances if d == 0) / len(original_distances) if original_distances else 0
    finetuned_accuracy = sum(1 for d in finetuned_distances if d == 0) / len(finetuned_distances) if finetuned_distances else 0
    
    original_avg_distance = sum(original_distances) / len(original_distances) if original_distances else 0
    finetuned_avg_distance = sum(finetuned_distances) / len(finetuned_distances) if finetuned_distances else 0
    
    # Display results
    print("Model Comparison:")
    print(f"Original model - Exact match accuracy: {original_accuracy:.2%}, Avg Levenshtein distance: {original_avg_distance:.2f}")
    print(f"Fine-tuned model - Exact match accuracy: {finetuned_accuracy:.2%}, Avg Levenshtein distance: {finetuned_avg_distance:.2f}")
    
    # Visualize samples
    num_samples = min(5, len(sample_images))
    plt.figure(figsize=(15, num_samples * 3))
    
    for i in range(num_samples):
        plt.subplot(num_samples, 1, i+1)
        plt.imshow(sample_images[i])
        plt.title(f"Ground truth: {ground_truths[i]}\nOriginal: {original_texts[i]}\nFine-tuned: {finetuned_texts[i]}")
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

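# Get test samples (indexing one row at a time yields per-sample dicts
# for both plain lists and Hugging Face datasets)
# test_samples = [dataset["test"][i] for i in range(min(20, len(dataset["test"])))] if "test" in dataset else []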

# Uncomment to run comparison
# compare_models(recognition_model, os.path.join(output_model_dir, "merged"), test_samples)
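The comparison above reports exact-match accuracy and the average raw Levenshtein distance. Raw distances are hard to compare across lines of different lengths, so OCR results are often also reported as character error rate (CER): total edit distance divided by total reference length. The sketch below is a dependency-free illustration with its own small edit-distance routine. The function names are ours, and in the notebook you could pass `levenshtein_distance` instead.

```python
def edit_distance(a, b):
    """Levenshtein distance between two strings (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            ))
        prev = curr
    return prev[-1]


def character_error_rate(predictions, references):
    """Total edit distance divided by total reference length."""
    total_edits = sum(edit_distance(p, r) for p, r in zip(predictions, references))
    total_chars = sum(len(r) for r in references)
    return total_edits / total_chars if total_chars else 0.0


# Made-up predictions: one substituted character in ten reference characters
demo_predictions = ["he1lo", "world"]
demo_references = ["hello", "world"]
cer = character_error_rate(demo_predictions, demo_references)
print(f"CER: {cer:.2%}")
```

Lower is better, and CER can exceed 100% when a prediction is much longer than its reference.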

## 6. Integrating with Marker

Finally, let's see how to integrate our fine-tuned model back into the Marker pipeline:

In [None]:
import torch  # used for the device check inside integrate_custom_model
from marker.models import create_model_dict
from marker.converters.pdf import PdfConverter
from marker.config.parser import ParserConfig

def integrate_custom_model(custom_model_path, pdf_path, output_dir="../output"):
    """Integrate custom model with Marker."""
    # Create configuration
    config = ParserConfig(
        model_list=["surya_det", "surya_rec"],
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    
    # Create models with custom OCR model
    models = create_model_dict(
        device=config.device,
        dtype=None,
        custom_model_paths={
            "recognition_model": custom_model_path
        }
    )
    
    # Create converter
    converter = PdfConverter(
        artifact_dict=models,
        config=config
    )
    
    # Process PDF
    print(f"Processing {pdf_path} with custom OCR model...")
    result = converter(pdf_path)
    
    # Save output
    os.makedirs(output_dir, exist_ok=True)
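    # Path.stem handles any extension casing (".pdf", ".PDF") and avoids replacing ".pdf" mid-name
    output_file = os.path.join(output_dir, Path(pdf_path).stem + ".md")
    
    # Assumes the converter returns Markdown text; if it returns a rendered
    # object instead, extract its text before writing
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(result)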
    
    print(f"Output saved to {output_file}")
    return output_file

# Example usage
# sample_pdf = "../data/input/sample.pdf"
# custom_model_path = os.path.join(output_model_dir, "merged")
# output_file = integrate_custom_model(custom_model_path, sample_pdf)

### Comparing Original vs. Custom Model Output

Let's compare the output of the original and fine-tuned models on a sample document:

In [None]:
def compare_outputs(original_output, custom_output):
    """Compare outputs from original and custom models."""
    # Read outputs
    with open(original_output, 'r', encoding='utf-8') as f:
        original_text = f.read()
    
    with open(custom_output, 'r', encoding='utf-8') as f:
        custom_text = f.read()
    
    # Compare text
    import difflib
    from IPython.display import HTML
    
    differ = difflib.HtmlDiff()
    html_diff = differ.make_file(original_text.splitlines(), custom_text.splitlines(),
                                 "Original Model Output", "Fine-tuned Model Output")
    
    # Display diff
    return HTML(html_diff)

# Example usage
# original_output = "../output/sample_original.md"
# custom_output = "../output/sample.md"
# diff_view = compare_outputs(original_output, custom_output)
# display(diff_view)
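The HTML diff is convenient for reading, but a couple of numbers can make it easier to compare several documents at once. As a minimal sketch using only the standard library's `difflib`, the helper below (the name `summarize_diff` is ours) returns a line-level similarity ratio and the number of changed lines.

```python
import difflib


def summarize_diff(original_text, custom_text):
    """Return (similarity ratio, changed line count, unified diff lines)."""
    original_lines = original_text.splitlines()
    custom_lines = custom_text.splitlines()

    # Ratio of matching lines: 1.0 means identical, 0.0 means nothing in common
    ratio = difflib.SequenceMatcher(None, original_lines, custom_lines).ratio()

    diff = list(difflib.unified_diff(
        original_lines, custom_lines,
        fromfile="original", tofile="fine-tuned", lineterm="",
    ))
    # The first two entries are the file headers, so skip them when counting
    changed = sum(1 for line in diff[2:] if line[:1] in ("+", "-"))
    return ratio, changed, diff


# Made-up outputs that differ in one line
demo_ratio, demo_changed, demo_diff = summarize_diff("a\nb\nc", "a\nB\nc")
print(f"Similarity: {demo_ratio:.2f}, changed lines: {demo_changed}")
print("\n".join(demo_diff))
```

A changed line counts twice here, once as a removal and once as an addition, which matches how a unified diff presents it.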

## Conclusion

In this notebook, we've covered the complete process of fine-tuning an OCR model for the Marker project:

1. We explored the current OCR pipeline and its components
2. We prepared domain-specific data for fine-tuning
3. We set up the training environment with Unsloth optimizations
4. We fine-tuned the OCR model using QLoRA
5. We evaluated the model and compared it with the original
6. We integrated the fine-tuned model back into Marker

### Next Steps

To further improve OCR performance, consider:

1. **Data augmentation**: Apply transforms like rotation, scaling, and noise to training data
2. **Hyperparameter tuning**: Experiment with learning rates, batch sizes, and LoRA parameters
3. **Domain-specific training**: Fine-tune on more documents from your specific domain
4. **Ensemble methods**: Combine predictions from multiple models
5. **Post-processing**: Add domain-specific correction rules for common errors
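As one illustration of the post-processing idea, the sketch below applies a short list of regular-expression rules to OCR output. The rules are hypothetical examples (letter `O` or `l` between digits, repeated spaces). Real rules should come from the errors you actually observe in your domain.

```python
import re

# Each rule is (compiled pattern, replacement); order matters
CORRECTION_RULES = [
    (re.compile(r"(?<=\d)[Oo](?=\d)"), "0"),  # letter O between digits -> zero
    (re.compile(r"(?<=\d)[lI](?=\d)"), "1"),  # letter l or I between digits -> one
    (re.compile(r" {2,}"), " "),              # collapse repeated spaces
]


def apply_corrections(text, rules=CORRECTION_RULES):
    """Apply each correction rule in order and return the cleaned text."""
    for pattern, replacement in rules:
        text = pattern.sub(replacement, text)
    return text


corrected = apply_corrections("Invoice 2O24  total 1l5")
print(corrected)
```

Rules like these can also introduce errors, for example in product codes that legitimately mix letters and digits, so it is worth measuring their effect on the test set before enabling them.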

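By fine-tuning the OCR model on data from your own domain, you can improve the accuracy of text extraction in the Marker pipeline. How much it improves depends on the quality and size of your dataset, so measure the gain on a held-out test set as in Section 5.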