# Fine-tune dots.ocr for Custom OCR Tasks

This notebook shows how to fine-tune the dots.ocr model for custom OCR tasks on Google Colab.

> **Note:** This notebook makes use of wjbmattingly's [dots.ocr training repo](https://github.com/wjbmattingly/dots.ocr).  

## What is in this notebook:
- Autolabel images using the base dots.ocr model
- Prepare training data from your custom images
- Fine-tune the model for better OCR on your specific content
- Test and evaluate your fine-tuned model

## Requirements:
- GPU: A100 or L4 recommended; a T4 also works, but without Flash Attention 2, so training is slower
- Images you want to train on (upload to Google Drive)

## Workflow:
1. Setup - Install dependencies and download the base model
2. Auto-label - Generate initial OCR predictions (skip if you have prepared data)
3. Correct - Manually fix the generated labels (skip if you have prepared data)
4. Train - Fine-tune the model on your corrected data


In [None]:
## Environment Setup

import os
import subprocess

# Set memory allocation for better GPU usage
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Check GPU availability
print("Checking GPU...")
!nvidia-smi

# Clone repository
if not os.path.exists("dots.ocr"):
    !git clone https://github.com/knightofcookies/dots.ocr.git

%cd dots.ocr

# Install base dependencies first
!pip install -q -r requirements.txt

# Install training requirements (flash-attn is installed conditionally below)
# Version specifiers are quoted so the shell does not treat ">" as an output redirect
!pip install -q "torch>=2.0.0" "transformers>=4.37.0" "accelerate>=0.25.0" "datasets>=2.16.0"
!pip install -q "peft>=0.8.0" wandb "Pillow>=8.0.0" tqdm qwen-vl-utils
!pip install -q deepspeed bitsandbytes

# Detect GPU and install flash-attn only if supported
import torch
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability(0)  # (major, minor) tuple
    compute = f'{capability[0]}.{capability[1]}'
    gpu_name = torch.cuda.get_device_name(0)
    print(f'\nDetected GPU: {gpu_name}')
    print(f'Compute Capability: {compute}')
    
    # Compare the (major, minor) tuple directly instead of round-tripping through a float
    if capability >= (8, 0):
        print('Flash Attention 2 is supported - installing...')
        !pip install -q flash-attn --no-build-isolation
    else:
        print('⚠️  Flash Attention 2 NOT supported on this GPU (requires compute capability >= 8.0)')
        print('Training will use standard PyTorch attention (slower but compatible with T4)')
        print('This is fine for fine-tuning! Just expect longer training times.')

!pip install -e .

# Download base model weights (~6GB)
!python tools/download_model.py

print("\n✅ Setup complete")
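
The capability check in the setup cell can also be written as a small reusable helper. The sketch below is illustrative only: `supports_flash_attention_2` is a hypothetical name that is not part of the dots.ocr repo, and the (8, 0) threshold is simply the one stated in the cell above.

```python
from typing import Tuple


# Hypothetical helper (not part of the dots.ocr repo): decides whether
# Flash Attention 2 can be used, given a CUDA compute capability.
def supports_flash_attention_2(capability: Tuple[int, int]) -> bool:
    """Return True for compute capability 8.0 or newer."""
    # Tuples compare element by element, so (8, 9) >= (8, 0) and (7, 5) < (8, 0).
    return tuple(capability) >= (8, 0)


# Compute capabilities of the GPUs mentioned in this notebook:
# T4 is 7.5, A100 is 8.0, L4 is 8.9.
for name, cap in [("T4", (7, 5)), ("A100", (8, 0)), ("L4", (8, 9))]:
    print(f"{name}: Flash Attention 2 supported = {supports_flash_attention_2(cap)}")
```

On a live runtime, the argument would typically be `torch.cuda.get_device_capability(0)`.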


In [None]:
## Mount Google Drive

from google.colab import drive

# Mount Google Drive
drive.mount("/content/drive")

# Set up paths
IMAGES_DIR = "/content/drive/MyDrive/images"  # Source images uploaded to Drive
AUTOLABEL_DIR = "/content/drive/MyDrive/autolabel"  # Auto-labeled results, kept on Drive so they can be edited there

print(f"Images directory: {IMAGES_DIR}")
print(f"Autolabel directory: {AUTOLABEL_DIR}")

Mounted at /content/drive
Images directory: /content/drive/MyDrive/images
Autolabel directory: /content/drive/MyDrive/autolabel


## Upload Your Images

**Before proceeding:**
1. Go to Google Drive and create a folder called 'images' in your My Drive
2. Upload the images you want to train on
3. Supported formats: .jpg, .jpeg, .png, .pdf
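
File extensions on Drive are often upper-case (for example `.JPG` from a camera), and case-sensitive glob patterns can miss them. The following is a minimal sketch of a case-insensitive file collector; `find_input_files` and `SUPPORTED_EXTENSIONS` are hypothetical names introduced here for illustration.

```python
from pathlib import Path
from typing import List

# Extensions taken from the supported-formats list above.
SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".pdf"}


def find_input_files(images_dir: str) -> List[str]:
    """Return supported files directly inside images_dir, sorted by path."""
    root = Path(images_dir)
    if not root.is_dir():
        return []  # a missing folder simply yields no inputs
    return sorted(
        str(p)
        for p in root.iterdir()
        # Lower-casing the suffix makes ".JPG" and ".jpg" equivalent.
        if p.is_file() and p.suffix.lower() in SUPPORTED_EXTENSIONS
    )
```

Calling `find_input_files(IMAGES_DIR)` after mounting Drive gives a quick count of what will be processed.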



Note: If you already have prepared training data, skip to the training section.


In [None]:
## Autolabel Your Images

import os
import glob
from tqdm import tqdm

# Create output directory
os.makedirs(AUTOLABEL_DIR, exist_ok=True)

# Find all image files
image_files = []
for ext in ["*.jpg", "*.jpeg", "*.png", "*.pdf"]:
    image_files.extend(glob.glob(os.path.join(IMAGES_DIR, ext)))

print(f"Found {len(image_files)} images to process")

if len(image_files) == 0:
    print("No images found! Please upload images to Google Drive first.")
    print(f"Expected location: {IMAGES_DIR}")
else:
    # Process each image
    successful = 0
    failed = 0

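    for img_path in tqdm(image_files, desc="Autolabeling"):
        # A "!" shell command never raises on failure, so run the parser through
        # subprocess (imported in the setup cell) and check its exit code instead.
        cmd = ["python", "-m", "dots_ocr.parser", img_path,
               "--output", AUTOLABEL_DIR, "--prompt", "prompt_ocr", "--use_hf", "true"]
        try:
            subprocess.run(cmd, check=True)
            successful += 1
        except subprocess.CalledProcessError as e:
            print(f"Failed to process {os.path.basename(img_path)}: exit code {e.returncode}")
            failed += 1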

    print(f"\nAuto-labeling completed!")
    print(f"Successful: {successful}")
    print(f"Failed: {failed}")
    print(f"Results saved to: {AUTOLABEL_DIR}")


## Manual Correction


### How to correct your labels:

Since autolabel results are saved directly to your Google Drive, you can edit them there:

1. Open Google Drive and navigate to 'MyDrive/autolabel'
2. For each sample folder, open the '.md' file
3. Edit the text to correct any OCR errors
4. Save the file directly in Google Drive

That's it! No need to download/upload. The training script will read the corrected files from Drive.

Skip this section if you already have prepared training data.

## Prepare Training Data

Expected data format: Your data should be in '/content/drive/MyDrive/autolabel/' with this structure:

    autolabel/
    ├── sample1/
    │   ├── sample1.jpg
    │   ├── sample1.md
    │   └── sample1.json
    ├── sample2/
    │   ├── sample2.jpg
    │   ├── sample2.md
    │   └── sample2.json
    └── ...
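
Before preparing training data, it can help to confirm that every sample folder matches this layout. The sketch below is an assumption-laden check, not part of the dots.ocr repo: `find_incomplete_samples` is a hypothetical helper, and it only looks for an image and a `.md` file, on the assumption that the `.json` file is not needed for the training text.

```python
from pathlib import Path
from typing import Dict, List

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def find_incomplete_samples(autolabel_dir: str) -> Dict[str, List[str]]:
    """Map each incomplete sample folder name to the items it is missing."""
    problems: Dict[str, List[str]] = {}
    for sample in sorted(Path(autolabel_dir).iterdir()):
        if not sample.is_dir():
            continue  # ignore stray files at the top level
        files = [p for p in sample.iterdir() if p.is_file()]
        missing = []
        if not any(p.suffix.lower() in IMAGE_SUFFIXES for p in files):
            missing.append("image")
        if not any(p.suffix.lower() == ".md" for p in files):
            missing.append(".md")
        if missing:
            problems[sample.name] = missing
    return problems
```

An empty result means every sample folder has both an image and a label file.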


### Training JSONL Format

Your corrected data will be converted into a `.jsonl` file for training (`train_ocr_resized.jsonl`).  
Each line in this file is one training sample in JSON format.
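
The conversion itself is not shown in this notebook, so the following is only a sketch of how such a file could be written. `build_record` and `write_jsonl` are hypothetical helpers, and the record layout follows the structure given in this section; check it against what `train_simple.py` actually expects before relying on it.

```python
import json
from typing import Dict, Iterable, Tuple


def build_record(image_path: str, text: str, prompt: str = "prompt_ocr") -> Dict:
    """Build one training sample: a user turn (image + prompt) and the label."""
    return {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": prompt},
                ],
            },
            {"role": "assistant", "content": text},
        ]
    }


def write_jsonl(pairs: Iterable[Tuple[str, str]], out_path: str) -> int:
    """Write one JSON object per line; return the number of samples written."""
    count = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for image_path, text in pairs:
            text = text.strip()
            if not text:
                continue  # skip samples whose corrected label is empty
            record = build_record(image_path, text)
            # json.dumps escapes newlines inside the label, so each sample
            # stays on a single line; ensure_ascii=False keeps non-ASCII text readable.
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            count += 1
    return count
```

Each pair is `(image_path, corrected_text)`, for example the path of a resized image and the contents of its corrected `.md` file.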

**Expected JSONL structure:**

```json
{"messages":[
  {"role":"user","content":[
    {"type":"image","image":"/content/resized_images/sample1.jpg"},
    {"type":"text","text":"prompt_ocr"}
  ]},
  {"role":"assistant","content":"This is the corrected OCR text for sample 1."}
]}
{"messages":[
  {"role":"user","content":[
    {"type":"image","image":"/content/resized_images/sample2.jpg"},
    {"type":"text","text":"prompt_ocr"}
  ]},
  {"role":"assistant","content":"This is the corrected OCR text for sample 2."}
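]}
```

The samples above are wrapped across several lines for readability; in the actual `.jsonl` file, each sample must occupy exactly one line.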


In [None]:
## Fine-tune the Model

TRAINING_JSONL = "/content/drive/MyDrive/train_ocr_resized.jsonl"

# Check if training data exists
if not os.path.exists(TRAINING_JSONL):
    print(f"Training data not found: {TRAINING_JSONL}")
    print("Please run the previous step to prepare training data first.")
else:
    print("Starting finetuning...")
    
    # Report Flash Attention compatibility (train_simple.py auto-detects it)
    import torch
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability(0)  # (major, minor) tuple
        if capability < (8, 0):
            print(f"⚠️  Detected GPU with compute capability {capability[0]}.{capability[1]} (T4 or older)")
            print("Flash Attention will be automatically disabled for compatibility.")

    # Train the model
    # Flash Attention is auto-detected; pass --use_flash_attention to force-enable it
    # or --no_flash_attention to force-disable it
    !python train_simple.py \
        --data "{TRAINING_JSONL}" \
        --epochs 15 \
        --batch_size 1 \
        --learning_rate 3e-4 \
        --max_length 1024 \
        --gradient_checkpointing \
        --output_dir "/content/local_checkpoints"

    print("Training completed")


In [None]:
## Set Up the Fine-tuned Model

print("Copying configuration files...")

# Ensure we have base model files
!python tools/download_model.py

# Copy missing configuration files from base model
!cp ./weights/DotsOCR/configuration_dots.py /content/local_checkpoints/final_model/
!cp ./weights/DotsOCR/modeling_*.py /content/local_checkpoints/final_model/

print("Replacing base model with finetuned model...")
!rm -rf ./weights/DotsOCR
!cp -r /content/local_checkpoints/final_model ./weights/DotsOCR

print("Verifying model setup...")
!python -c "import json; json.load(open('./weights/DotsOCR/config.json')); print('Model setup complete')"

print("\nYour finetuned model is ready for inference")
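
Because this cell deletes the base weights before copying the checkpoint in, it is worth confirming first that the checkpoint folder is complete. The sketch below checks only for the files named in the cell above (`config.json`, `configuration_dots.py`, and at least one `modeling_*.py`); `missing_checkpoint_files` is a hypothetical helper, and a real checkpoint may need further files such as weights and tokenizer data.

```python
import glob
import os
from typing import List


def missing_checkpoint_files(model_dir: str) -> List[str]:
    """Return the expected file names or patterns that are absent from model_dir."""
    missing = []
    for name in ("config.json", "configuration_dots.py"):
        if not os.path.isfile(os.path.join(model_dir, name)):
            missing.append(name)
    # The cell copies modeling_*.py, so at least one such file should exist.
    if not glob.glob(os.path.join(model_dir, "modeling_*.py")):
        missing.append("modeling_*.py")
    return missing
```

An empty list from `missing_checkpoint_files("/content/local_checkpoints/final_model")` suggests it is safe to proceed with the replacement.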



In [None]:
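## Image Sanitization Helper

import os
from PIL import Image, ImageFile

# Let Pillow load images whose files are slightly truncated instead of raising
ImageFile.LOAD_TRUNCATED_IMAGES = True

def sanitize(img_path, dst_dir):
    """Convert an image to RGB, cap its longest side at 1024 px, and save it as a JPEG in dst_dir.

    Returns the path of the cleaned copy, or None if the image could not be processed.
    """
    try:
        os.makedirs(dst_dir, exist_ok=True)
        with Image.open(img_path) as im:
            im = im.convert("RGB")
            w, h = im.size
            if max(w, h) > 1024:
                s = 1024 / max(w, h)
                im = im.resize((int(w*s), int(h*s)), Image.Resampling.LANCZOS)
            # Use a .jpg extension so the file name matches the JPEG encoding
            stem = os.path.splitext(os.path.basename(img_path))[0]
            clean_path = os.path.join(dst_dir, stem + ".jpg")
            im.save(clean_path, format="JPEG", quality=92, optimize=True)
            return clean_path
    except Exception as e:
        print(f"[skip bad image] {img_path} -> {e}")
        return None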
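
The resize step in `sanitize` scales both sides by the same factor so that the longest side becomes 1024 px. The arithmetic can be isolated as below; `fit_within` is a hypothetical helper that mirrors the same computation without needing Pillow.

```python
from typing import Tuple


def fit_within(width: int, height: int, max_side: int = 1024) -> Tuple[int, int]:
    """Return the size after scaling so the longest side is at most max_side."""
    longest = max(width, height)
    if longest <= max_side:
        return width, height  # small images are left unchanged
    scale = max_side / longest
    # int() truncates, matching the computation used in sanitize
    return int(width * scale), int(height * scale)


print(fit_within(2048, 1024))  # (1024, 512)
print(fit_within(800, 600))    # (800, 600)
print(fit_within(1024, 4096))  # (256, 1024)
```

The aspect ratio is preserved in each case, which matters for OCR because stretched glyphs are harder to recognize.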


In [None]:
## Test OCR
# Set your file path here:
IMAGE_PATH = "/content/resized_images/SCR-20250715-iho.jpg"

from dots_ocr.parser import DotsOCRParser
import os

# max_completion_tokens limits how much text is generated; increase it if the output is cut off
parser = DotsOCRParser(use_hf=True, max_completion_tokens=128)
result = parser.parse_file(IMAGE_PATH, prompt_mode="prompt_ocr")

if not result:
    print("[no result]")
else:
    info = result[0]
    text = None

    md_path = info.get("md_content_path")
    if md_path and os.path.exists(md_path):
        with open(md_path, "r", encoding="utf-8") as f:
            text = f.read()

    if text is None and isinstance(info.get("content"), str):
        text = info["content"]

    print(text or "[empty]")
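
To go beyond eyeballing the output, one common way to score OCR is the character error rate (CER): the edit distance between the prediction and the corrected label, divided by the label length. The sketch below is a plain-Python illustration of that metric, not something provided by dots.ocr; `levenshtein` and `character_error_rate` are names introduced here.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) between a and b."""
    if len(a) < len(b):
        a, b = b, a  # keep the inner row as short as possible
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            current.append(min(
                previous[j] + 1,         # deletion
                current[j - 1] + 1,      # insertion
                previous[j - 1] + cost,  # substitution or match
            ))
        previous = current
    return previous[-1]


def character_error_rate(reference: str, prediction: str) -> float:
    """Edit distance divided by the reference length (0.0 means a perfect match)."""
    if not reference:
        return 0.0 if not prediction else 1.0
    return levenshtein(reference, prediction) / len(reference)


print(character_error_rate("kitten", "sitting"))  # 0.5
```

Comparing the CER of the base model and the fine-tuned model on a few held-out pages gives a rough sense of whether training helped.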


In [None]:
## Save the Fine-tuned Model
print("Saving fine-tuned model to Google Drive...")

!cp -r /content/local_checkpoints/final_model /content/drive/MyDrive/dots_ocr_finetuned

print("\nFine-tuned model saved to Google Drive")
print("Location: /content/drive/MyDrive/dots_ocr_finetuned")
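
One caveat with `cp -r`: if the destination folder already exists, the source is copied inside it as a nested `final_model` folder. A sketch of a rerun-safe alternative using the standard library follows; `save_model_copy` is a hypothetical helper, and `dirs_exist_ok` requires Python 3.8 or newer.

```python
import shutil
from pathlib import Path


def save_model_copy(src_dir: str, dst_dir: str) -> str:
    """Copy src_dir's contents into dst_dir, overwriting files on reruns."""
    src = Path(src_dir)
    if not src.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {src}")
    # dirs_exist_ok=True merges into an existing destination instead of failing
    # or nesting the source folder inside it.
    shutil.copytree(src, dst_dir, dirs_exist_ok=True)
    return str(dst_dir)
```

For this notebook, the call would be `save_model_copy("/content/local_checkpoints/final_model", "/content/drive/MyDrive/dots_ocr_finetuned")`.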





### Resources:
- [dots.ocr GitHub](https://github.com/rednote-hilab/dots.ocr)
- [Training Documentation](https://github.com/wjbmattingly/dots.ocr/blob/main/README_model_training.md)

Happy training!