# Text Detection Comparison with MMOCR

This notebook uses [MMOCR](https://github.com/open-mmlab/mmocr) to run three text detectors—**CRAFT**, **DBNet** (DBNet++ config), and **PSENet**—and visually compare their predictions against ground-truth polygons, in plain Python (developed in VS Code on Windows).

Steps:
1. Install dependencies
2. Setup paths & imports
3. Utility functions
4. Initialize MMOCR detectors
5. Inference & qualitative comparison
6. Save side-by-side canvases



In [None]:
# 1. Install required packages (run once)
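# Uncomment and run if needed (`%pip` installs into the active kernel's environment;
# pin versions that match your torch / mmcv build):
# %pip install mmocr gevent-websocket munch anyconfig polygon3
# %pip install torch torchvision torchaudio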

# 2. Imports & Global Paths
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
# NOTE(review): MMOCR 0.x `mmocr.apis` exports `init_detector` / `model_inference`, and 1.x
# uses `MMOCRInferencer` instead — confirm this import against the installed version.
from mmocr.apis import init_detector, inference_detector

# Folder layout, resolved against the directory the kernel was started in:
#   data/imgsForAllPages  -> input images
#   data/annotations      -> ground-truth .txt files (same base name as each image)
#   output/comparisons    -> saved side-by-side canvases
base_dir = os.getcwd()

data_dir = os.path.join(base_dir, 'data')
image_dir, gt_dir = (os.path.join(data_dir, sub) for sub in ('imgsForAllPages', 'annotations'))

output_dir = os.path.join(base_dir, 'output')
compare_dir = os.path.join(output_dir, 'comparisons')

# Only the output tree is created; input folders are expected to exist already.
os.makedirs(compare_dir, exist_ok=True)



## 3. Utility Functions

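Helpers to load an image as a BGR array and its ground-truth polygons (one comma-separated list of `x,y` coordinates per line of a `.txt` file).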


In [None]:

def load_image(path):
    """Load an image from disk as a BGR ``uint8`` numpy array.

    The file bytes are read with NumPy and decoded with ``cv2.imdecode`` rather than
    passed to ``cv2.imread``. This matters for two reasons:

    * ``imread`` cannot open paths containing non-ASCII characters on Windows.
    * ``imread`` reports every failure by silently returning ``None``.

    Args:
        path: Path to the image file.

    Returns:
        numpy.ndarray: The decoded image in BGR channel order (``cv2.IMREAD_COLOR``).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file is empty or cannot be decoded as an image.
    """
    buf = np.fromfile(path, dtype=np.uint8)
    if buf.size == 0:
        # imdecode asserts on an empty buffer; give a clearer error instead.
        raise ValueError(f"Image file is empty: {path}")
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Could not decode image: {path}")
    return img


def load_ground_truth(path):
    """Load ground-truth polygons from a comma-separated TXT annotation file.

    Each non-blank line holds one polygon as ``x1,y1,x2,y2,...``, for example 8 values
    for a quadrilateral. Parsing of a line stops at the first non-numeric field, so an
    ICDAR-style trailing transcription such as ``...,x4,y4,###`` is ignored. A dangling
    odd value is dropped, so only complete ``(x, y)`` pairs are kept.

    Args:
        path: Path to the annotation file.

    Returns:
        list[list[list[float]]]: One ``[[x, y], ...]`` point list per polygon. An empty
        list is returned if ``path`` does not exist.
    """
    polys = []
    if not os.path.exists(path):
        return polys
    # 'utf-8-sig' strips a leading BOM (some ICDAR-style GT files start with one), which
    # would otherwise make float() fail on the very first coordinate. An explicit
    # encoding also avoids the platform default (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            coords = []
            for field in line.strip().split(','):
                try:
                    coords.append(float(field))
                except ValueError:
                    break  # reached the transcription / label column
            coords = coords[:len(coords) // 2 * 2]  # keep complete (x, y) pairs only
            if coords:  # skip blank or coordinate-less lines instead of crashing
                polys.append(np.array(coords).reshape(-1, 2).tolist())
    return polys



## 4. Initialize MMOCR Detectors

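Point each detector at its MMOCR config file and a local checkpoint. Note that MMOCR's model zoo does not include a CRAFT config. The CRAFT entry must therefore reference a config you provide, and the original CRAFT-pytorch `craft_mlt_25k.pth` weights are not in MMOCR checkpoint format. The `DBNet` entry actually loads the DBNet++ (`dbnetpp_…`) config.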



In [None]:

# Model registry: display name -> MMOCR config file / local checkpoint.
# Paths are relative to the current working directory.
# NOTE(review): MMOCR's model zoo has no CRAFT config, so the CRAFT entry must point to a
# custom config; the 'DBNet' entry loads the DBNet++ config — confirm before running.
configs = {
    'CRAFT':  'mmocr/configs/textdet/craft/craft_mlt_25k.py',
    'DBNet':  'mmocr/configs/textdet/dbnet/dbnetpp_resnet50_fpnc_1200e_icdar2015.py',
    'PSENet': 'mmocr/configs/textdet/psenet/psenet_resnet50_fpnf_600e_ctw1500.py',
}
checkpoints = {
    'CRAFT':  'CRAFTModel/weights/craft_mlt_25k.pth',
    'DBNet':  'DB/weights/dbnetpp_resnet50_fpnc_1200e_icdar2015.pth',
    'PSENet': 'PSENet/weights/psenet_resnet50_fpnf_600e_ctw1500.pth',
}
DEVICE = 'cpu'  # change to e.g. 'cuda:0' when a GPU build of torch/mmcv is installed

# Fail fast with a single readable error listing every missing file, instead of a
# deep traceback from whichever model happens to be loaded first.
missing = [path
           for name in configs
           for path in (configs[name], checkpoints[name])
           if not os.path.isfile(path)]
if missing:
    raise FileNotFoundError(
        "Missing MMOCR config/checkpoint file(s):\n  " + "\n  ".join(missing))

models = {}
for name, cfg in configs.items():
    print(f"Loading {name} model…")
    models[name] = init_detector(cfg, checkpoints[name], device=DEVICE)
print("All detectors loaded!")



## 5. Sample Inference & Visualization

Run all three detectors on one example image (the first one found) and show the ground truth (green) next to each detector's predictions (blue), side by side.



In [None]:

# Pick one example image. Sorting makes "first" deterministic across OSes and re-runs.
img_paths = sorted(glob(os.path.join(image_dir, '*.png')) + glob(os.path.join(image_dir, '*.jpg')))
if not img_paths:
    raise RuntimeError(f"No images found in {image_dir}")
img_path = img_paths[0]
img = load_image(img_path)
if img is None:  # cv2.imread-style loaders return None instead of raising
    raise RuntimeError(f"Could not read image: {img_path}")
# Ground truth lives in gt_dir under the image's base name with a .txt extension.
stem = os.path.splitext(os.path.basename(img_path))[0]
gt = load_ground_truth(os.path.join(gt_dir, stem + '.txt'))

# Collect predictions as one [[x, y], ...] point list per polygon.
# NOTE(review): assumes a dict result with 'boundary_result', as returned by MMOCR 0.x.
preds = {}
for name, model in models.items():
    result = inference_detector(model, img)
    boundaries = result.get('boundary_result', [])
    # MMOCR 0.x post-processors append a confidence score to each flattened polygon
    # ([x1, y1, ..., xn, yn, score]). Keeping only complete (x, y) pairs stops
    # reshape(-1, 2) from failing on the odd-length list; even-length input is unchanged.
    preds[name] = [np.array(b[:len(b) // 2 * 2]).reshape(-1, 2).tolist() for b in boundaries]

# One panel for GT (green) plus one per loaded detector (blue, in BGR). The panel count
# follows len(models) instead of a hard-coded 4.
panels = [('Ground truth', gt, (0, 255, 0))]
panels += [(name, boxes, (255, 0, 0)) for name, boxes in preds.items()]
h, w = img.shape[:2]
canvas = np.zeros((h, w * len(panels), 3), dtype=np.uint8)
for i, (label, polys, color) in enumerate(panels):
    vis = img.copy()
    for poly in polys:
        if len(poly) == 0:
            continue
        pts = np.rint(np.asarray(poly, dtype=np.float64)).astype(np.int32).reshape(-1, 1, 2)
        cv2.polylines(vis, [pts], True, color, 2)
    canvas[:, i * w:(i + 1) * w] = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

# Display with one x-tick label per panel so each column can be identified.
fig, ax = plt.subplots(figsize=(16, 8))
ax.imshow(canvas)
ax.set_xticks([(i + 0.5) * w for i in range(len(panels))])
ax.set_xticklabels([label for label, _, _ in panels])
ax.set_yticks([])
ax.set_title(os.path.basename(img_path))
plt.show()



## 6. Batch Inference & Save Comparisons

Loop over every image, run all detectors, and save each ground-truth vs. prediction canvas as `<image name>_compare.png` in `output/comparisons`.


In [None]:

# Batch version of the sample cell: build and save one labelled comparison canvas per image.
for img_path in img_paths:
    img = load_image(img_path)
    if img is None:  # cv2.imread-style loaders return None instead of raising
        print(f"Skipping unreadable image: {img_path}")
        continue
    base = os.path.splitext(os.path.basename(img_path))[0]
    # Ground truth lives in gt_dir under the same base name with a .txt extension.
    gt = load_ground_truth(os.path.join(gt_dir, base + '.txt'))

    # Predictions. MMOCR 0.x appends a confidence score to each flattened polygon
    # ([x1, y1, ..., xn, yn, score]), so keep only complete (x, y) pairs before reshaping.
    preds = {}
    for name, model in models.items():
        res = inference_detector(model, img)
        boundaries = res.get('boundary_result', [])
        preds[name] = [np.array(b[:len(b) // 2 * 2]).reshape(-1, 2).tolist() for b in boundaries]

    # Canvas: GT panel (green) + one panel per loaded detector (blue, in BGR).
    panels = [('Ground truth', gt, (0, 255, 0))]
    panels += [(name, boxes, (255, 0, 0)) for name, boxes in preds.items()]
    h, w = img.shape[:2]
    canvas = np.zeros((h, w * len(panels), 3), dtype=np.uint8)
    for i, (label, polys, color) in enumerate(panels):
        vis = img.copy()
        for poly in polys:
            if len(poly) == 0:
                continue
            pts = np.rint(np.asarray(poly, dtype=np.float64)).astype(np.int32).reshape(-1, 1, 2)
            cv2.polylines(vis, [pts], True, color, 2)
        # Burn the panel name into the saved image. White text is drawn over a thicker
        # black stroke so it stays legible on light and dark backgrounds.
        cv2.putText(vis, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0), 4, cv2.LINE_AA)
        cv2.putText(vis, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)
        canvas[:, i * w:(i + 1) * w] = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

    save_path = os.path.join(compare_dir, f"{base}_compare.png")
    Image.fromarray(canvas).save(save_path)  # canvas is already RGB, as PIL expects
    print(f"Saved comparison: {save_path}")
