In [19]:
# =============================================================================
# TrOCR Test - IMPORTANT: TrOCR is for SINGLE LINE text, not full documents!
# =============================================================================
# 
# TrOCR expects cropped text lines, not full document images.
# For documents, you must:
# 1. Detect text regions/lines first (using OpenCV or PaddleOCR detection)
# 2. Crop each line
# 3. OCR each cropped line with TrOCR
#
# This notebook demonstrates both the WRONG and RIGHT way to use TrOCR.
# =============================================================================

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import cv2
import numpy as np

# Load the TrOCR processor and model from a single shared checkpoint name
TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"

print("Loading TrOCR model...")
processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT)
print("Model loaded!")

Loading TrOCR model...


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded!


In [24]:
# =============================================================================
# ❌ WRONG WAY: Pass entire document to TrOCR (this will fail!)
# =============================================================================

path = "../data/samples/prescriptions/ab36c7-20061128-oldprescrip.jpg"

# Open through a context manager; convert() hands back a separate RGB image.
with Image.open(path) as source_image:
    image = source_image.convert("RGB")

print(f"Image size: {image.size}")
print("Passing ENTIRE document to TrOCR (WRONG!)...")

# Feed the whole page to the processor/model, then decode the first sequence.
encoded = processor(image, return_tensors="pt")
pixel_values = encoded.pixel_values
generated_ids = model.generate(pixel_values)
decoded_batch = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = decoded_batch[0]

print(f"Result: '{generated_text}'")
print("^ This fails because TrOCR expects SINGLE LINE text, not full documents!")

Image size: (852, 602)
Passing ENTIRE document to TrOCR (WRONG!)...
Result: '0 1'
^ This fails because TrOCR expects SINGLE LINE text, not full documents!


In [25]:
# =============================================================================
# ✅ RIGHT WAY: Detect text lines first, then OCR each line with TrOCR
# =============================================================================

def detect_text_lines(image_path, kernel_size=(50, 3), min_width=50, min_height=10):
    """
    Detect text lines using OpenCV morphological operations.

    The image is binarized with an inverted Otsu threshold, dilated with a
    wide rectangular kernel so characters on the same line merge into one
    blob, and the external contours of those blobs become line candidates.

    Args:
        image_path: Path of the image file to analyse.
        kernel_size: (width, height) of the rectangular dilation kernel.
            A wider kernel merges characters that are further apart.
        min_width: A candidate box is kept only if its width is strictly
            greater than this many pixels.
        min_height: A candidate box is kept only if its height is strictly
            greater than this many pixels.

    Returns:
        A tuple ``(boxes, img)`` where ``boxes`` is a list of
        ``(x, y, w, h)`` bounding boxes sorted by ``y`` (top to bottom) and
        ``img`` is the image array exactly as returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    import os

    # cv2.imread does not raise on a bad path; it silently returns None,
    # which would otherwise surface as an opaque cvtColor error below.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # Read image
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"OpenCV could not decode image: {image_path}")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Binarize (inverted so that ink becomes foreground for the dilation)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Dilate horizontally to connect text into lines
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, tuple(kernel_size))
    dilated = cv2.dilate(binary, kernel, iterations=1)

    # Find contours (text lines)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Get bounding boxes and drop blobs too small to be a text line
    boxes = [
        (x, y, w, h)
        for x, y, w, h in map(cv2.boundingRect, contours)
        if w > min_width and h > min_height
    ]

    # Sort by y-position (top to bottom)
    boxes.sort(key=lambda box: box[1])

    return boxes, img


def ocr_with_trocr(image, boxes, processor, model):
    """
    Recognise the text inside every bounding box with TrOCR.

    Each ``(x, y, w, h)`` box is grown by a 5-pixel margin (clipped to the
    image bounds), cropped out of ``image``, converted with
    ``cv2.COLOR_BGR2RGB`` and passed through ``processor`` and ``model``.
    The stripped text of each line is printed as soon as it is decoded.

    Returns a list of dicts, one per box, with the original box under
    ``'bbox'`` and the stripped text under ``'text'``.
    """
    margin = 5
    recognised = []

    for line_no, box in enumerate(boxes, start=1):
        x, y, w, h = box

        # Padded crop window, clamped so it never leaves the image
        top = max(0, y - margin)
        bottom = min(image.shape[0], y + h + margin)
        left = max(0, x - margin)
        right = min(image.shape[1], x + w + margin)

        # Crop, swap channel order, and wrap as a PIL image for the processor
        crop_rgb = cv2.cvtColor(image[top:bottom, left:right], cv2.COLOR_BGR2RGB)
        inputs = processor(Image.fromarray(crop_rgb), return_tensors="pt")

        # Generate token ids (capped at 64) and decode the single sequence
        token_ids = model.generate(inputs.pixel_values, max_length=64)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        line_text = decoded[0].strip()

        recognised.append({'bbox': (x, y, w, h), 'text': line_text})
        print(f"Line {line_no}: '{line_text}'")

    return recognised


# Step 1: locate candidate text lines in the sample image
print("Detecting text lines...")
boxes, img = detect_text_lines(path)
line_count = len(boxes)
print(f"Found {line_count} text lines\n")

# Step 2: recognise every detected line (each result is printed as it is decoded)
print("OCR each line with TrOCR:")
separator = "-" * 50
print(separator)
results = ocr_with_trocr(img, boxes, processor, model)

Detecting text lines...
Found 20 text lines

OCR each line with TrOCR:
--------------------------------------------------
Line 1: '0 1'
Line 2: '2 Primary Care Center .'
Line 3: '6 Clinic 3A. Phillips-Wangensteen Building'
Line 4: '516 Delaware Street Southeast . Minneapolis , MN 55455.'
Line 5: 'Ritzen " Mickey Mouse -'
Line 6: 'assis relief #'
Line 7: 'Address that has been'
Line 8: 'Attendal 25mg.'
Line 9: 'ago'
Line 10: 'r.'
Line 11: 'VOID'
Line 12: 'Although it'
Line 13: 'ii'
Line 14: '" ( ( see a refill phone numbers on tracks'
Line 15: 'ex #'
Line 16: 'exemptment'
Line 17: 'threat #'
Line 18: 'Health-administration Estimates .0000008'
Line 19: 'staff physician disappointment by'
Line 20: '166103. 3. 4.4'


In [26]:
# =============================================================================
# Visualize detected text lines
# =============================================================================
import matplotlib.pyplot as plt

def visualize_detections(image_path, boxes):
    """Draw numbered bounding boxes on the image and display it.

    Args:
        image_path: Path of the image file to load and annotate.
        boxes: Iterable of ``(x, y, w, h)`` boxes; each is outlined and
            labelled with its 1-based position in the iterable.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    import os

    # cv2.imread returns None instead of raising, so fail early and clearly.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"OpenCV could not decode image: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    for i, (x, y, w, h) in enumerate(boxes):
        cv2.rectangle(img_rgb, (x, y), (x+w, y+h), (255, 0, 0), 2)
        # Clamp the label baseline so boxes touching the top edge still
        # get a visible number instead of text drawn above the image.
        label_y = max(y - 5, 15)
        cv2.putText(img_rgb, str(i+1), (x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    plt.figure(figsize=(12, 10))
    plt.imshow(img_rgb)
    plt.title(f"Detected {len(boxes)} text lines")
    plt.axis('off')
    plt.show()

# Overlay the line boxes returned by detect_text_lines on the sample image.
visualize_detections(path, boxes)

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# =============================================================================
# 🔑 KEY INSIGHT: TrOCR vs PaddleOCR
# =============================================================================
#
# | Model     | Detection | Recognition | Best For                        |
# |-----------|-----------|-------------|----------------------------------|
# | PaddleOCR | ✅ Yes    | ✅ Yes      | Clean printed text, fast         |
# | TrOCR     | ❌ No     | ✅ Yes      | Handwriting, degraded/noisy docs |
#
# For best results on prescriptions:
# 1. Use PaddleOCR for text DETECTION (bounding boxes)
# 2. Use TrOCR for text RECOGNITION (especially handwriting)
#
# This is what our ocr_router.py does:
# - region_detector.py detects text regions
# - ocr_router.py sends each region to TrOCR
# =============================================================================

# Recap of the two experiments above and of the project pipeline stages.
print("Summary:")
print("-" * 50)
print("❌ TrOCR on full document → '0 1' (fails)")
print("✅ TrOCR on detected lines → works!")
print()
print("The preprocessing pipeline handles this automatically:")
print("  1. document_pipeline.py preprocesses image")
print("  2. region_detector.py detects text regions")
print("  3. ocr_router.py sends each region to TrOCR")

In [None]:
# =============================================================================
# ✅ BEST WAY: Use the DocumentPipeline (production approach)
# =============================================================================

import sys
sys.path.insert(0, '..')

from src.medical_ingestion.core.document_pipeline import DocumentPipeline
import asyncio

async def process_with_pipeline(image_path):
    """Run ``DocumentPipeline.process_image`` on one image and print a summary.

    The pipeline is built with preprocessing and region detection switched
    on and VLM classification switched off. When the pipeline returns a
    non-empty result, the first entry's full text, region count, average
    confidence and (if present) per-region text previews are printed.

    Returns whatever ``pipeline.process_image`` returned, unchanged.
    """
    pipeline_config = {
        'enable_preprocessing': True,
        'enable_region_detection': True,
        'use_vlm_classification': False,  # Don't need PaliGemma
    }
    pipeline = DocumentPipeline(pipeline_config)

    # Process the image
    results = await pipeline.process_image(image_path)

    # Nothing to report when the pipeline produced no results
    if not results:
        return results

    first = results[0]
    print(f"Full text extracted:\n{first.full_text}")
    print(f"\nRegions detected: {first.total_regions}")
    print(f"Average confidence: {first.average_confidence:.2f}")

    region_texts = first.region_texts
    if region_texts:
        print("\nText by region type:")
        for region_type, text in region_texts.items():
            # Show at most the first 100 characters of each region
            preview = text[:100]
            print(f"  {region_type}: {preview}...")

    return results

# Run the pipeline
# NOTE: a notebook kernel already runs an asyncio event loop, so calling
# asyncio.run() from a cell raises "RuntimeError: asyncio.run() cannot be
# called from a running event loop". Use top-level await in notebooks and
# keep asyncio.run() for plain Python scripts.
# results = await process_with_pipeline(path)
print("To use the full pipeline, run:")
print("  results = await process_with_pipeline(path)  # in a notebook cell")
print("  results = asyncio.run(process_with_pipeline(path))  # in a plain script")