<a href="https://colab.research.google.com/github/fabiodr/colabs/blob/main/finetune_LightOnOCR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune LightOnOCR on OCR Tasks

In this notebook, we will fine-tune [LightOnOCR-1B](https://huggingface.co/lightonai/LightOnOCR-1B-1025) on a custom OCR dataset. LightOnOCR is a vision-language model specifically designed for OCR tasks.

## Why Fine-tune LightOnOCR? 🎯

LightOnOCR is an **end-to-end trainable model**, which makes it straightforward to adapt to specific use cases. Unlike traditional OCR pipelines that require complex multi-stage processing, LightOnOCR can be fine-tuned directly on:

- **Specific domains** 📄: medical records, legal documents, receipts, forms, etc.
- **Different languages** 🌍: enhance performance on low-resource languages or specialized scripts
- **Custom writing styles** ✍️: historical documents, handwriting, or stylized fonts
- **Domain-specific vocabulary** 🏢: technical jargon, product names, or industry terminology

This end-to-end approach means you can optimize the entire model for your task with a relatively small training set, without retraining separate detection, recognition, or layout analysis components.

## Getting Started

We use the IAM handwritten text dataset from [HuggingFaceM4/FineVision](https://huggingface.co/datasets/HuggingFaceM4/FineVision) as an example, but you can easily adapt this notebook to your own dataset.

**Note:** This notebook supports multiple FineVision subsets (olmOCR-mix-0225-books, olmOCR-mix-0225-documents, and iam) 📚✍️

For more details about the model, see the [LightOnOCR blog post](https://huggingface.co/blog/lightonai/lightonocr).

## Installation

First, let's install the necessary libraries including the transformers fork with LightOnOCR support.

In [None]:
!pip install -q -U datasets accelerate peft
!pip install -q -U "pillow>=12.0.0"
!pip install -q -U git+https://github.com/baptiste-aubertin/transformers.git@main
!pip install -q huggingface-hub==1.0.0
# quote the version specifier so the shell does not treat ">" as an output redirect
!pip install -q -U "bitsandbytes>=0.46.1"
!pip install -q jiwer

## Load Dataset

For this example, we'll use the **IAM handwriting dataset** from [HuggingFaceM4/FineVision](https://huggingface.co/datasets/HuggingFaceM4/FineVision) to finetune the model on handwritten text recognition ✍️.

**FineVision Dataset Structure:**
- `images`: List of PIL Images
- `texts`: List of conversation dicts with:
  - `user`: Question/prompt (ignored in this notebook)
  - `assistant`: OCR ground truth text
- `source`: Dataset source identifier

**Other available subsets:**
- `olmOCR-mix-0225-books` - Books subset of olmOCR-mix-0225 📚
- `olmOCR-mix-0225-documents` - Documents subset of olmOCR-mix-0225 📄
- And many more! See the [FineVision dataset page](https://huggingface.co/datasets/HuggingFaceM4/FineVision) for all available datasets.
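To adapt the notebook to your own data, one option is to convert each (image, transcription) pair into the same record layout the collator reads: a one-element `images` list and a one-element `texts` list whose dict has an `assistant` key. The helper below is a hypothetical sketch of that conversion; the function name and the `"custom"` source value are placeholders, not part of FineVision or the `datasets` library.

```python
from typing import Any


def to_finevision_record(image: Any, transcription: str, source: str = "custom") -> dict:
    """Wrap one (image, transcription) pair in a FineVision-style record.

    Hypothetical helper: `image` would normally be a PIL image, but any object works here.
    The `user` prompt is left empty because this notebook's collator ignores it.
    """
    return {
        "images": [image],  # the collator expects exactly one image per sample
        "texts": [{"user": "", "assistant": transcription}],  # and exactly one text
        "source": source,
    }


# Placeholder strings stand in for PIL images in this example.
pairs = [("<image 0>", "first line of text"), ("<image 1>", "second line of text")]
records = [to_finevision_record(img, text) for img, text in pairs]
print(records[0])
```

With real PIL images, a list of such records should convert with `datasets.Dataset.from_list(records)`, after which it can be split and passed to `collate_fn` like the FineVision subsets (the collator calls `.convert("RGB")`, so the images must be PIL images in practice).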

In [None]:
import torch
from datasets import load_dataset

# choose dataset subset
finevision_subset = "iam"
# finevision_subset = "olmOCR-mix-0225-books"
# finevision_subset = "olmOCR-mix-0225-documents"

train_ds = load_dataset('HuggingFaceM4/FineVision', finevision_subset, split='train[:85%]')
val_ds = load_dataset('HuggingFaceM4/FineVision', finevision_subset, split='train[85%:95%]')
test_ds = load_dataset('HuggingFaceM4/FineVision', finevision_subset, split='train[95%:]')

print(f"Training samples: {len(train_ds)}, Validation samples: {len(val_ds)}, Test samples: {len(test_ds)}")

## Load Model and Processor

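We'll load the processor and the LightOnOCR model for full fine-tuning, optionally freezing parts of the model to reduce memory requirements. As written, the model cell freezes the language model's parameters and leaves the vision encoder and projector trainable; the LoRA option further below is an alternative for smaller GPUs.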

In [None]:
from transformers import AutoProcessor
import torch

model_id = "lightonai/LightOnOCR-1B-1025"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = 'left'

print(f"Using device: {device}")
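Left padding matters for batched generation: when every prompt in a batch ends at the same position, the newly generated tokens for all rows start at index `input_ids.shape[1]`, which is what the slicing `outputs[:, input_length:]` in the evaluation code relies on. The toy sketch below illustrates this with plain lists; the `pad_left` helper and the token values are made up for illustration and are not the tokenizer's API.

```python
def pad_left(sequences: list[list[int]], pad_id: int) -> list[list[int]]:
    """Left-pad token sequences to a common length (toy stand-in for the tokenizer)."""
    width = max(len(seq) for seq in sequences)
    return [[pad_id] * (width - len(seq)) + seq for seq in sequences]


PAD = 0
prompts = [[5, 6, 7], [8, 9]]       # two prompts of different lengths
batch = pad_left(prompts, PAD)      # [[5, 6, 7], [0, 8, 9]]
input_length = len(batch[0])

# Pretend the model appended the same two new tokens to every row.
outputs = [row + [42, 43] for row in batch]

# Every prompt ends at index input_length - 1, so one slice works for all rows.
generated = [row[input_length:] for row in outputs]
print(generated)  # [[42, 43], [42, 43]]
```

With right padding, the shorter prompt would be followed by pad tokens, so the model would have to start generating after padding instead of directly after its prompt.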

### Option 1: Full Fine-tuning

> Note: the following cell occasionally fails. If it does, restart the notebook runtime and run it again.

In [None]:
from transformers import LightOnOCRForConditionalGeneration

model = LightOnOCRForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
).to(device)

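# # freeze vision encoder, projector or language model to reduce memory requirements on Colab
# for param in model.model.vision_encoder.parameters():
#     param.requires_grad = False
# print(f"Vision encoder frozen: {not param.requires_grad}")
# for param in model.model.vision_projection.parameters():
#     param.requires_grad = False
# print(f"Vision projection frozen: {not param.requires_grad}")
for param in model.model.language_model.parameters():
    param.requires_grad = False
print(f"Language model frozen: {not param.requires_grad}")

# report how many parameters will actually be updated during training
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {num_trainable:,} / {num_total:,}")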

### Option 2: LoRA Fine-tuning

Uncomment this cell to use LoRA on a 4-bit quantized model (QLoRA-style) instead of full fine-tuning. This uses much less GPU memory and is recommended for limited GPU resources.
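For reference, LoRA keeps each targeted weight matrix $W \in \mathbb{R}^{d_\text{out} \times d_\text{in}}$ frozen and learns a low-rank update $BA$. In PEFT's default (non-rsLoRA) formulation the update is scaled by $\alpha / r$, so the config in the cell below (`r=8`, `lora_alpha=16`) uses a scaling factor of 2, and each adapted linear layer adds $r(d_\text{in} + d_\text{out})$ trainable parameters:

$$
h = W x + \frac{\alpha}{r} B A x, \qquad A \in \mathbb{R}^{r \times d_\text{in}}, \quad B \in \mathbb{R}^{d_\text{out} \times r}
$$

$$
\frac{\alpha}{r} = \frac{16}{8} = 2, \qquad \text{params per adapted layer} = r\,(d_\text{in} + d_\text{out})
$$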

In [None]:
# from transformers import LightOnOCRForConditionalGeneration, BitsAndBytesConfig
# from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
# import torch

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4"
# )

# model = LightOnOCRForConditionalGeneration.from_pretrained(
#     model_id,
#     device_map="auto",
#     torch_dtype=torch.bfloat16,
#     quantization_config=bnb_config
# )

# model = prepare_model_for_kbit_training(model)

# # # convert vision encoder to bfloat16 to match input dtype
# # if hasattr(model, 'model') and hasattr(model.model, 'vision_encoder'):
# #     model.model.vision_encoder = model.model.vision_encoder.to(torch.bfloat16)

# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=16,
#     target_modules=["o_proj", "gate_proj", "up_proj", "down_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM"
# )

# model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()

# print("Model loaded with 4-bit quantization + LoRA")
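As a back-of-the-envelope check on why 4-bit loading helps, weight memory scales with the number of bits stored per parameter. The sketch below assumes roughly 1e9 parameters (taken from the "1B" in the model name) and ignores activations, gradients, optimizer state, quantization constants, and any layers that stay in higher precision, so treat the results as rough lower bounds for the weights alone.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 1e9  # assumption: about 1B parameters, per the model name
for label, bits in [("bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: ~{weight_memory_gb(NUM_PARAMS, bits):.2f} GB")
```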

## Prepare Data Collator

The data collator prepares batches for training. It:
1. Formats prompts with image tokens using chat template
2. Processes images and text through the processor
3. Creates labels for training (masking prompt tokens, only training on assistant response)

In [None]:
# assistant start pattern: <|im_end|>\n<|im_start|>assistant\n
ASSISTANT_START_PATTERN = [151645, 1699, 151644, 77091, 1699]
MAX_LENGTH = 1024
LONGEST_EDGE = 700

def collate_fn(examples):
    batch_messages = []
    batch_images = []

    for example in examples:
        example_images = example["images"]
        example_texts = example["texts"]

        assert len(example_images) == 1, f"Expected 1 image per sample, got {len(example_images)}"
        assert len(example_texts) == 1, f"Expected 1 text per sample, got {len(example_texts)}"

        image = example_images[0].convert("RGB")
        batch_images.append(image)

        conversation = example_texts[0]
        # strip extra whitespaces and newlines to avoid tokenization issues
        assistant_text = conversation.get("assistant", "").strip()

        messages = [
            {"role": "user", "content": [{"type": "image"}]},
            {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]}
        ]
        batch_messages.append(messages)

    if len(batch_images) == 0:
        return None

    texts = [
        processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        for messages in batch_messages
    ]

    inputs = processor(
        text=texts,
        images=batch_images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        size={"longest_edge": LONGEST_EDGE} # reduce due to memory requirements
    )

    labels = inputs["input_ids"].clone()
    pad_token_id = processor.tokenizer.pad_token_id

    for i in range(len(labels)):
        full_ids = inputs["input_ids"][i].tolist()

        # find where assistant content starts (after the assistant marker)
        assistant_content_start = None
        for idx in range(len(full_ids) - len(ASSISTANT_START_PATTERN) + 1):
            if full_ids[idx:idx+len(ASSISTANT_START_PATTERN)] == ASSISTANT_START_PATTERN:
                assistant_content_start = idx + len(ASSISTANT_START_PATTERN)
                break

        if assistant_content_start is None:
            print(f"Warning: Could not find assistant marker in sample {i}")
            print(f"Sample {i} failed. Text: {texts[i]}")
            labels[i, :] = -100
        else:
            # mask everything first
            labels[i, :] = -100

            # unmask from assistant content start to end
            # this trains on: assistant text + EOS
            for idx in range(assistant_content_start, len(full_ids)):
                if full_ids[idx] == pad_token_id:
                    break
                labels[i, idx] = inputs["input_ids"][i, idx]

        # mask padding tokens
        labels[i, inputs["input_ids"][i] == pad_token_id] = -100

    inputs["labels"] = labels

    # cast pixel values to bfloat16 to match the model weights (the Trainer moves tensors to the device)
    inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

    return inputs

## Test the Collator

Let's test the collator with a sample batch to ensure everything works correctly.

In [None]:
# test with a small batch
test_batch = collate_fn([train_ds[0], train_ds[1]])
print("Input shape:", test_batch["input_ids"].shape)
print("Labels shape:", test_batch["labels"].shape)
print("Pixel values shape:", test_batch["pixel_values"].shape)

## Test Model Before Fine-tuning

Let's run inference with the base model first to see how it performs on our dataset before fine-tuning.

In [None]:
def run_inference(image):
    """run inference on a single image"""
    messages = [
        {"role": "user", "content": [{"type": "image"}]}
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = processor(
        text=[text],
        images=[[image.convert("RGB")]],  # match the RGB conversion used in collate_fn
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        size={"longest_edge": LONGEST_EDGE},
    ).to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding for reproducible OCR output
    )

    input_length = inputs['input_ids'].shape[1]
    generated_ids = outputs[0, input_length:]
    generated_text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

    return generated_text.strip()

# test on a few samples
print("Testing base model on test samples:\n")
for idx in range(3):
    sample = test_ds[idx]
    image = sample["images"][0]
    ground_truth = sample["texts"][0].get("assistant", "").strip()

    prediction = run_inference(image)

    print(f"Sample {idx + 1}:")
    print(f"Prediction  : {prediction}")
    print(f"Ground truth: {ground_truth}")
    print("-" * 50)

    # display image
    display(image)

The model performs well overall but still makes some mistakes on handwritten crops ✍️, since it was primarily trained on full-page PDF documents 📄, not paragraph-level crops like these. However, we can finetune it to boost handwritten text recognition! 🚀✨

## Evaluate the Model Before Training

In [None]:
from jiwer import cer, wer
import torch
from tqdm import tqdm

def evaluate_model(model, dataset, num_samples=50, batch_size=8, description="Model"):
    model.eval()

    predictions = []
    ground_truths = []

    print(f"\nEvaluating {description} on {num_samples} samples...")

    for start_idx in tqdm(range(0, min(num_samples, len(dataset)), batch_size)):
        end_idx = min(start_idx + batch_size, num_samples, len(dataset))
        batch_samples = [dataset[i] for i in range(start_idx, end_idx)]

        batch_images = [[s["images"][0].convert("RGB")] for s in batch_samples]
        batch_ground_truths = [s["texts"][0]["assistant"].strip() for s in batch_samples]

        messages = [{"role": "user", "content": [{"type": "image"}]}]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        texts = [text] * len(batch_images)

        inputs = processor(text=texts,
                           images=batch_images,
                           return_tensors="pt",
                           padding=True,
                           truncation=True,
                           max_length=MAX_LENGTH,
                           size={"longest_edge": LONGEST_EDGE},
                           ).to(device)
        inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)  # greedy decoding for reproducible metrics

        input_length = inputs['input_ids'].shape[1]
        generated_ids = outputs[:, input_length:]
        batch_predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)
        batch_predictions = [p.strip() for p in batch_predictions]

        predictions.extend(batch_predictions)
        ground_truths.extend(batch_ground_truths)

    cer_score = cer(ground_truths, predictions) * 100
    wer_score = wer(ground_truths, predictions) * 100
    perfect_matches = sum(1 for pred, gt in zip(predictions, ground_truths) if pred == gt)

    print(f"CER: {cer_score:.2f}% | WER: {wer_score:.2f}% | Perfect: {perfect_matches}/{len(predictions)}")

    for i in range(min(3, len(predictions))):
        match = "✅" if predictions[i] == ground_truths[i] else "❌"
        print(f"{match} Sample {i+1}: '{predictions[i]}' vs '{ground_truths[i]}'")

    return {"cer": cer_score, "wer": wer_score, "perfect_matches": perfect_matches}
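For intuition about the numbers `evaluate_model` prints: CER is the character-level edit distance (substitutions + deletions + insertions) divided by the number of reference characters, and WER is the same ratio at the word level. As far as we know, jiwer aggregates these counts over the whole list rather than averaging per sample. The standard-library sketch below mirrors that definition; it skips jiwer's default text transformations, so small differences (for example around whitespace) are possible.

```python
from typing import Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance between two sequences (characters or words)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (free if the units match)
            ))
        prev = curr
    return prev[-1]


def error_rate(refs: list[str], hyps: list[str], words: bool = False) -> float:
    """Corpus-level CER (or WER if words=True): total edits / total reference units."""
    split = str.split if words else list
    edits = sum(edit_distance(split(r), split(h)) for r, h in zip(refs, hyps))
    total = sum(len(split(r)) for r in refs)
    return edits / total


refs = ["the cat sat", "on the mat"]
hyps = ["the bat sat", "on the mat"]
print(f"CER: {error_rate(refs, hyps) * 100:.2f}%")              # 1 edit / 21 chars
print(f"WER: {error_rate(refs, hyps, words=True) * 100:.2f}%")  # 1 edit / 6 words
```

The example also shows why WER is usually higher than CER for the same output: a single wrong character makes the whole word count as an error.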

In [None]:
from peft import PeftModel

print("\n" + "="*80)
print("BEFORE TRAINING")
print("="*80)

if isinstance(model, PeftModel):
    with model.disable_adapter():
        base_results = evaluate_model(model, test_ds, num_samples=100, batch_size=4, description="Base")
else:
    base_results = evaluate_model(model, test_ds, num_samples=100, batch_size=4, description="Base")

torch.cuda.empty_cache()

## Configure Training Arguments

Set up the training configuration. Adjust based on your hardware and requirements.

In [None]:
from transformers import TrainingArguments
import os

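# only takes effect if set before CUDA is first initialized (e.g. in the very first cell of the notebook)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"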
torch.set_float32_matmul_precision('high')

output_dir = f"lightonocr-ft-{finevision_subset}"
use_bf16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    # max_steps=100,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=4,
    learning_rate=6e-5,
    weight_decay=0.0,
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    bf16=use_bf16,
    fp16=False,
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    warmup_steps=10,
    lr_scheduler_type="linear",
)

print(f"Output directory: {output_dir}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
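To sanity-check the schedule before launching, it helps to estimate how many optimizer steps one epoch gives and how many evaluations that triggers. The sketch below assumes a single GPU and only approximates the Trainer's step count (how a partial final accumulation window is rounded can vary across `transformers` versions); the `check_schedule` helper is hypothetical. It also checks a real constraint: with `load_best_model_at_end=True` and step-based strategies, `save_steps` must be a multiple of `eval_steps` (500 and 50 satisfy this).

```python
import math


def optimizer_steps_per_epoch(num_samples: int, batch_size: int, grad_accum: int) -> int:
    """Approximate optimizer updates per epoch on one device."""
    batches = math.ceil(num_samples / batch_size)
    return max(batches // grad_accum, 1)


def check_schedule(num_samples: int, batch_size: int = 4, grad_accum: int = 4,
                   eval_steps: int = 50, save_steps: int = 500) -> dict:
    """Summarize the schedule; defaults match the TrainingArguments above."""
    if save_steps % eval_steps != 0:
        raise ValueError("save_steps must be a multiple of eval_steps for load_best_model_at_end")
    steps = optimizer_steps_per_epoch(num_samples, batch_size, grad_accum)
    return {
        "effective_batch_size": batch_size * grad_accum,
        "steps_per_epoch": steps,
        "num_evaluations": steps // eval_steps,
    }


# Hypothetical training-set size; in the notebook, pass len(train_ds) instead.
print(check_schedule(num_samples=5000))
```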

## Initialize Trainer and Start Training

In [None]:
from transformers import Trainer

# use a smaller validation set on Colab
val_ds_small = val_ds.select(range(min(100, len(val_ds))))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds_small,
    data_collator=collate_fn,
)

print("Starting training...")
print(f"Number of training samples: {len(train_ds)}")
print(f"Number of validation samples: {len(val_ds_small)}")

In [None]:
trainer.train()

In [None]:
import matplotlib.pyplot as plt

train_steps = []
train_losses = []
eval_steps = []
eval_losses = []

for entry in trainer.state.log_history:
    if 'loss' in entry:
        train_steps.append(entry['step'])
        train_losses.append(entry['loss'])
    if 'eval_loss' in entry:
        eval_steps.append(entry['step'])
        eval_losses.append(entry['eval_loss'])

plt.figure(figsize=(10, 6))
plt.plot(train_steps, train_losses, label='Training Loss', marker='o', linewidth=2)
plt.plot(eval_steps, eval_losses, label='Validation Loss', marker='s', linewidth=2)
plt.xlabel('Steps', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Training and Validation Loss', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# evaluate the model after training
print("\n" + "="*80)
print("AFTER TRAINING")
print("="*80)
finetuned_results = evaluate_model(model, test_ds, num_samples=100, batch_size=4, description="Finetuned")

In [None]:
# comparison (a positive Change means improvement for every metric)
print("\n" + "="*80)
print("COMPARISON")
print(f"{'Metric':<20} {'Base':<12} {'Finetuned':<12} {'Change':<12}")
print("-" * 56)
print(f"{'CER (%)':<20} {base_results['cer']:<12.2f} {finetuned_results['cer']:<12.2f} {base_results['cer']-finetuned_results['cer']:+.2f}")
print(f"{'WER (%)':<20} {base_results['wer']:<12.2f} {finetuned_results['wer']:<12.2f} {base_results['wer']-finetuned_results['wer']:+.2f}")
print(f"{'Perfect':<20} {base_results['perfect_matches']:<12} {finetuned_results['perfect_matches']:<12} {finetuned_results['perfect_matches']-base_results['perfect_matches']:+d}")
print("="*80)

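With 2 epochs of training, the results improve further. Note that this output comes from evaluating more than the 100 test samples used above, since it reports 166 perfect matches:

```text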
=========================================================
COMPARISON
Metric               Base         Finetuned    Change
--------------------------------------------------------
CER (%)              37.15        1.95         +35.20
WER (%)              41.89        5.07         +36.81
Perfect              54           166          +112
==========================================================
```



## Save and Push Model

Save the fine-tuned model and optionally push to Hugging Face Hub.

In [None]:
# save model and processor
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")

## Authentication

Authenticate to access the model and push your fine-tuned model to the Hub.

In [None]:
# from huggingface_hub import notebook_login, logout
# logout()
# notebook_login()
# # optional: push to Hub
# hub_model_id = "staghado/LightOnOCR-1B-1025-ft-iam"  # replace with your own namespace/repo name
# trainer.push_to_hub(hub_model_id)
# processor.push_to_hub(hub_model_id)

## Inference

Test the fine-tuned model on held-out validation images.

In [None]:
# test on a validation sample
test_idx = 0
test_sample = val_ds[test_idx]
test_image = test_sample["images"][0]

print("Running inference...")
result = run_inference(test_image)

print("\n=== Generated Text ===")
print(result)

print("\n=== Ground Truth ===")
print(test_sample["texts"][0].get("assistant", ""))

display(test_image)

In [None]:
# Test on multiple validation samples
print("="*50)
print("Testing Finetuned Model")
print("="*50)

from jiwer import cer

num_samples = 5
for test_idx in range(num_samples):
    test_sample = val_ds[test_idx]
    test_image = test_sample["images"][0]
    # strip whitespace so the CER matches how predictions are post-processed
    ground_truth = test_sample["texts"][0]["assistant"].strip()

    print(f"\n{'='*50}")
    print(f"Sample {test_idx + 1}:")

    # Run inference
    result = run_inference(test_image)

    print(f"Prediction  : {result}")
    print(f"Ground truth: {ground_truth}")

    # Calculate per-sample CER
    error_rate = cer([ground_truth], [result]) * 100
    print(f"CER: {error_rate:.2f}%")

    # Flag exact matches (to measure improvement, compare against the base model's outputs)
    if result == ground_truth:
        print("✅ Perfect match!")

    # Optional: display image
    # display(test_image)

print("\n" + "="*50)
print("Testing complete! 🚀")

## Convert Model for vLLM Compatibility

If you want to serve the model with vLLM afterwards, update the config to use the following model types. The current vLLM implementation of LightOnOCR relies on these names; this is expected to be fixed soon so that Transformers and vLLM use the same names.

In [None]:
# import json

# config_path = f"{output_dir}/config.json"
# with open(config_path, 'r') as f:
#     config = json.load(f)

# # update model types for vLLM compatibility
# config['model_type'] = 'mistral3'
# config['text_config']['model_type'] = 'qwen3'
# config['vision_config']['model_type'] = 'pixtral'

# with open(config_path, 'w') as f:
#     json.dump(config, f, indent=2)