In [None]:
# -*- coding: utf-8 -*-
"""
03_model_comparison.ipynb

Comparison of OCR Pipelines:
- YOLO + CRNN (Ours)
- YOLO + TrOCR
- EasyOCR (End-to-End)
"""

# `os` is imported here because this cell runs before the notebook's main
# import cell, and the non-Colab branch below needs os.path.
import os

# Toggle between running on Google Colab (project stored on Google Drive)
# and running locally from the current working directory.
USE_COLAB = True

if USE_COLAB:
    # google.colab only exists inside the Colab runtime, so import it lazily.
    from google.colab import drive
    drive.mount("/content/drive")
    PROJECT_ROOT = "/content/drive/MyDrive/Information-Extraction-from-Image"
else:
    PROJECT_ROOT = os.path.abspath(".")

print("PROJECT_ROOT:", PROJECT_ROOT)


In [None]:
# ===============================
# Install dependencies
# ===============================
!pip install -r {PROJECT_ROOT}/requirements.txt


In [None]:
# ===============================
# Standard Library
# ===============================
import os
import sys
import time
import random
import pickle
import json

# ===============================
# Third-party Libraries
# ===============================
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torchvision import transforms
from ultralytics import YOLO
import ultralytics

import easyocr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel


In [None]:
# Dataset, cache and model folders, all resolved relative to PROJECT_ROOT.
DATASET_DIR, CACHE_DIR, MODEL_DIR = (
    os.path.join(PROJECT_ROOT, sub_dir)
    for sub_dir in ("datasets/SceneTrialTrain", "cache", "model")
)

# Put the project root on sys.path (once) so `src.*` modules can be imported.
if sys.path.count(PROJECT_ROOT) == 0:
    sys.path.append(PROJECT_ROOT)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Print the ultralytics environment report.
ultralytics.checks()

In [None]:
# Load the cached validation data from CACHE_DIR.
# NOTE(review): pickle.load can execute arbitrary code; only load cache files
# that were produced by this project.
val_cache_path = os.path.join(CACHE_DIR, "val_data.pkl")
with open(val_cache_path, "rb") as cache_file:
    cache = pickle.load(cache_file)

# Pull the four cached entries out of the cache by key.
val_yolo_data, image_paths, image_labels, bounding_boxes = (
    cache[key]
    for key in ("val_yolo_data", "image_paths", "image_labels", "bounding_boxes")
)

In [None]:
from src.pipeline import (
    inference_yolo_crnn,
    inference_yolo_trocr,
    inference_easyocr,
)

from src.evaluation import (
    evaluate_model,
)

Comparison of 3 models on the test dataset:
- **YOLO + CRNN (Ours)**
- **YOLO + TrOCR**
- **EasyOCR (End-to-End)**

Evaluation based on:

- Character Accuracy

- Word Accuracy

- Inference Speed

With a single confidence threshold: **0.3**

# 1. Vocabulary & CRNN Config

In [None]:
# Character set handled by the CRNN: digits, lowercase letters, and "-",
# which is also named as the blank character.
BLANK_CHAR = "-"
CHARS = "0123456789abcdefghijklmnopqrstuvwxyz-"

# Characters are numbered from 1 in sorted order, so "-" (which sorts before
# the digits) gets index 1 and index 0 is not assigned to any character.
# NOTE(review): presumably this must match the mapping used at training time — confirm.
char_to_idx = dict(zip(sorted(CHARS), range(1, len(CHARS) + 1)))
idx_to_char = dict(zip(char_to_idx.values(), char_to_idx.keys()))

VOCAB_SIZE = len(CHARS)

# Keyword arguments forwarded to the CRNN constructor.
CRNN_CONFIG = dict(
    hidden_size=256,
    n_layers=3,
    dropout=0.2,
    unfreeze_layers=3,
)


# 2. Load YOLO + CRNN Models

In [None]:
from src.recognition import CRNN

# Weight files for the detector and the recognizer, both under MODEL_DIR.
yolo_model_path = os.path.join(MODEL_DIR, "yolo/best.pt")
crnn_model_path = os.path.join(MODEL_DIR, "cnn/ocr_crnn.pt")

# Detector: YOLO is constructed directly from the weights file.
yolo_det = YOLO(yolo_model_path)
print("YOLO loaded")

# Recognizer: build the CRNN from the notebook's config, move it to DEVICE,
# then load the saved state dict (mapped onto DEVICE) into it.
crnn_model = CRNN(vocab_size=VOCAB_SIZE, **CRNN_CONFIG)
crnn_model = crnn_model.to(DEVICE)
crnn_model.load_state_dict(torch.load(crnn_model_path, map_location=DEVICE))

# Switch to evaluation mode for inference.
crnn_model.eval()
print("CRNN loaded")


# 3. Load TrOCR

In [None]:
# The processor and the model are loaded from the same pretrained checkpoint.
TROCR_CHECKPOINT = "microsoft/trocr-base-printed"

print("Loading TrOCR...")
trocr_processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)

trocr_model = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT)
trocr_model = trocr_model.to(DEVICE)
# Switch to evaluation mode for inference.
trocr_model.eval()
print("TrOCR loaded")

# 4. Load EasyOCR


In [None]:
# English-only EasyOCR reader; the GPU is requested only when CUDA is available.
easyocr_languages = ["en"]
use_gpu = torch.cuda.is_available()
easyocr_reader = easyocr.Reader(easyocr_languages, gpu=use_gpu)
print("EasyOCR loaded")

# 5. Build Test Dataset (from YOLO Val)

In [None]:
# Number of validation images used to build the evaluation set.
NUM_TEST_IMAGES = 50
test_samples = val_yolo_data[:NUM_TEST_IMAGES]

# Index the annotations by image path in a single pass, so each sampled image
# is a dictionary lookup instead of a scan over every annotated image.
# setdefault keeps the FIRST entry for a path, i.e. the first match wins.
annotations_by_path = {}
for p, labels, bbs in zip(image_paths, image_labels, bounding_boxes):
    annotations_by_path.setdefault(p, (labels, bbs))

# One test_data entry per annotated text region:
#   image_path - full path of the image under DATASET_DIR
#   bbox       - the four box values cast to int, unpacked as (x, y, w, h)
#   label      - the label paired with that box
test_data = []
images_without_annotations = []

for img_rel_path, _ in test_samples:
    if img_rel_path not in annotations_by_path:
        # No annotation entry for this image: it contributes no regions.
        images_without_annotations.append(img_rel_path)
        continue

    img_path = os.path.join(DATASET_DIR, img_rel_path)
    labels, bbs = annotations_by_path[img_rel_path]

    for bb, label in zip(bbs, labels):
        x, y, w, h = map(int, bb)
        test_data.append(
            {
                "image_path": img_path,
                "bbox": (x, y, w, h),
                "label": label,
            }
        )

print(
    f"Prepared {len(test_data)} text regions "
    f"from {len(test_samples)} images"
)

# Surface skipped images so a path mismatch does not silently shrink the test set.
if images_without_annotations:
    print(
        f"Warning: {len(images_without_annotations)} sampled images "
        f"had no matching annotations and were skipped"
    )

# 6. CRNN Transform

In [None]:
# Size every crop is resized to; torchvision's Resize reads a 2-tuple as (height, width).
CRNN_INPUT_SIZE = (100, 420)

# Preprocessing applied to each crop before it is passed to the CRNN:
# resize -> single-channel grayscale -> tensor -> normalize with mean 0.5 / std 0.5.
crnn_transform = transforms.Compose([
    transforms.Resize(CRNN_INPUT_SIZE),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# 7. Evaluation (Confidence = 0.3)

In [None]:
from functools import partial

# Test with confidence threshold 0.3
conf_threshold = 0.3
print(f"\n\n{'#'*80}")
print(f"# EVALUATION WITH CONFIDENCE THRESHOLD = {conf_threshold}")
print(f"{'#'*80}\n")


def _evaluate(infer_fn, model_name):
    """Run evaluate_model for one pipeline on test_data at the shared threshold."""
    return evaluate_model(
        infer_fn,
        test_data,
        model_name,
        conf_threshold=conf_threshold,
    )


# Each pipeline is wrapped with functools.partial so its model objects are
# pre-bound before the callable is handed to evaluate_model.

# YOLO detector + our CRNN recognizer
yolo_crnn_infer = partial(
    inference_yolo_crnn,
    yolo_det=yolo_det,
    crnn_transform=crnn_transform,
    crnn_inference=crnn_model,
    idx_to_char=idx_to_char,
)
results_yolo_crnn = _evaluate(yolo_crnn_infer, "YOLO + CRNN (Ours)")

# YOLO detector + TrOCR recognizer
yolo_trocr_infer = partial(
    inference_yolo_trocr,
    yolo_det=yolo_det,
    trocr_processor=trocr_processor,
    trocr_model=trocr_model,
)
results_yolo_trocr = _evaluate(yolo_trocr_infer, "YOLO + TrOCR")

# EasyOCR (end-to-end)
easyocr_infer = partial(inference_easyocr, easyocr_reader=easyocr_reader)
results_easyocr = _evaluate(easyocr_infer, "EasyOCR (End-to-End)")

# Collected in a fixed order for the comparison cells below.
results_03 = [results_yolo_crnn, results_yolo_trocr, results_easyocr]

# 8. Results Table

In [None]:
import pandas as pd

# Display comparison results for confidence threshold 0.3
banner = "=" * 100
print("\n" + banner)
print("COMPARISON RESULTS - CONFIDENCE THRESHOLD = 0.3")
print(banner)

# Keep only the models that returned a non-empty result.
results = list(filter(None, results_03))
if results:
    # Result key -> column title shown in the table (also fixes the column order).
    column_titles = {
        'model': 'Model',
        'char_acc': 'Char Acc (%)',
        'word_acc': 'Word Acc (%)',
        'avg_time': 'Speed (s/img)',
        'matched_regions': 'Matched Regions',
    }
    df_03 = pd.DataFrame(results)[list(column_titles)].rename(columns=column_titles)
    print(df_03.to_string(index=False))

    # Row labels of the highest character / word accuracy.
    top_char_row = df_03['Char Acc (%)'].idxmax()
    top_word_row = df_03['Word Acc (%)'].idxmax()
    best_char_03 = df_03.loc[top_char_row, 'Model']
    best_word_03 = df_03.loc[top_word_row, 'Model']

    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    print(f"Confidence 0.3 - Best Char Acc: {best_char_03} | Best Word Acc: {best_word_03}")

In [None]:
# results_03 always holds one entry per model, so its length says nothing
# about whether any evaluation succeeded. df_03 is only built when at least
# one entry is non-empty, so that is the condition checked here.
if any(results_03):
    # Visualize results for confidence threshold 0.3
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    colors = ['#2ecc71', '#3498db', '#e74c3c']
    df = df_03

    # Character and word accuracy share the same layout: 0-100 % axis with
    # the value printed above each bar.
    accuracy_panels = (
        (axes[0], 'Char Acc (%)', 'Character Accuracy (Conf=0.3)'),
        (axes[1], 'Word Acc (%)', 'Word Accuracy (Conf=0.3)'),
    )
    for ax, column, panel_title in accuracy_panels:
        ax.bar(df['Model'], df[column], color=colors)
        ax.set_ylabel('Accuracy (%)', fontsize=12)
        ax.set_title(panel_title, fontsize=14, fontweight='bold')
        ax.set_ylim([0, 100])
        ax.grid(axis='y', alpha=0.3)
        ax.tick_params(axis='x', rotation=45)
        for i, v in enumerate(df[column]):
            ax.text(i, v + 2, f'{v:.1f}%', ha='center', fontweight='bold', fontsize=10)

    # Speed
    axes[2].bar(df['Model'], df['Speed (s/img)'], color=colors)
    axes[2].set_ylabel('Time (seconds)', fontsize=12)
    axes[2].set_title('Speed (Conf=0.3)', fontsize=14, fontweight='bold')
    axes[2].grid(axis='y', alpha=0.3)
    axes[2].tick_params(axis='x', rotation=45)

    # Leave 30 % headroom above the tallest bar for the value labels.
    # Skipped when every time is 0 to avoid a zero-height axis.
    max_speed = float(df['Speed (s/img)'].max())
    if max_speed > 0:
        axes[2].set_ylim(0, max_speed * 1.3)

    for i, v in enumerate(df['Speed (s/img)']):
        axes[2].text(i, v + v * 0.05, f'{v:.3f}s', ha='center', fontweight='bold', fontsize=10)

    plt.tight_layout()
    plt.show()
else:
    print("No results to visualize!")

In [None]:
# Visualize predictions from 3 models on 2 sample images
NUM_SAMPLE_IMAGES = 2
SAMPLE_CONF_THRESHOLD = 0.3
SAMPLE_SEED = 42  # fixed seed so a re-run shows the same sample images

print("\n" + "="*100)
print("SAMPLE PREDICTIONS VISUALIZATION (2 Sample Images)")
print("="*100)

# test_data has one entry per text REGION, so sampling its indices could pick
# two regions of the same image. Sample from the distinct image paths instead
# (dict.fromkeys de-duplicates while keeping first-seen order).
unique_img_paths = list(dict.fromkeys(item['image_path'] for item in test_data))
sample_img_paths = random.Random(SAMPLE_SEED).sample(
    unique_img_paths, min(NUM_SAMPLE_IMAGES, len(unique_img_paths))
)

for sample_no, sample_img_path in enumerate(sample_img_paths, start=1):
    print(f"\n{'─'*100}")
    print(f"SAMPLE IMAGE {sample_no}: {sample_img_path}")
    print(f"{'─'*100}")

    # cv2.imread returns None (it does not raise) when the file cannot be read.
    base_img = cv2.imread(sample_img_path)
    if base_img is None:
        print(f"Could not read image, skipping: {sample_img_path}")
        continue
    base_img = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB)

    # Get ground truth labels for this image
    gt_labels = [item['label'] for item in test_data if item['image_path'] == sample_img_path]
    print(f"Ground Truth Labels: {gt_labels}")
    print(f"Total regions in this image: {len(gt_labels)}\n")

    # Get predictions from each pipeline at the same confidence threshold
    preds_crnn = inference_yolo_crnn(
        sample_img_path, yolo_det, crnn_transform, crnn_model, idx_to_char,
        conf_threshold=SAMPLE_CONF_THRESHOLD,
    )
    preds_trocr = inference_yolo_trocr(
        sample_img_path, yolo_det, trocr_processor, trocr_model,
        conf_threshold=SAMPLE_CONF_THRESHOLD,
    )
    preds_easyocr = inference_easyocr(
        sample_img_path, easyocr_reader,
        conf_threshold=SAMPLE_CONF_THRESHOLD,
    )

    # (plot title, name used in the text summary, predictions)
    model_preds = [
        ('YOLO + CRNN (Ours)', 'YOLO+CRNN', preds_crnn),
        ('YOLO + TrOCR', 'YOLO+TrOCR', preds_trocr),
        ('EasyOCR (End-to-End)', 'EasyOCR', preds_easyocr),
    ]

    # Visualize: one panel per pipeline, each drawn on its own copy of the image
    fig, axes = plt.subplots(1, 3, figsize=(20, 8))

    for ax, (title, _, preds) in zip(axes, model_preds):
        img = base_img.copy()

        for pred in preds:
            # Cast to int because OpenCV drawing calls need integer pixel coordinates.
            # NOTE(review): assumes pred['bbox'] is (x, y, w, h), as the original
            # unpacking did — confirm against src.pipeline.
            x, y, w, h = map(int, pred['bbox'])
            cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 2)

            # Add text
            text = pred['text']
            conf = pred['confidence']
            label = f"{text} ({conf:.2f})"

            # Background for text
            (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            cv2.rectangle(img, (x, y-text_h-5), (x+text_w, y), (0, 255, 0), -1)
            cv2.putText(img, label, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)

        ax.imshow(img)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(fig)

    # Print detection results (blank line between models, none before the first)
    for position, (_, short_name, preds) in enumerate(model_preds):
        lead = "" if position == 0 else "\n"
        print(f"{lead}  {short_name} detected: {len(preds)} regions")
        for i, pred in enumerate(preds, start=1):
            print(f"    {i}. {pred['text']} (conf: {pred['confidence']:.3f})")

    print("\n")