In [None]:
# --- Imports and Environment Setup ---
import sys
import math
import random
import re
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import time

import torch
import torch.nn.functional as F
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
import pytesseract
import matplotlib.pyplot as plt
from tqdm import tqdm
import psutil

In [None]:
# --- Tweakable Parameters ---
# Page detection (read by detect_page_and_crop).
# A 4-corner contour is accepted as the page only if its area exceeds
# MIN_AREA_FRAC * image_height * image_width.
MIN_AREA_FRAC = 0.08
# Lower / upper thresholds passed to cv2.Canny.
CANNY_LOW = 50
CANNY_HIGH = 150
# Kernel size of the Gaussian blur applied before edge detection.
GAUSSIAN_BLUR_KERNEL = (3, 3)
# Rectangular structuring element size and iteration count used to dilate the edge map.
DILATE_KERNEL_SIZE = (3, 3)
DILATE_ITERATIONS = 1

# --- OCR preprocessing ---
# Arguments of cv2.bilateralFilter (d, sigmaColor, sigmaSpace) in preprocess_for_ocr.
BILATERAL_D = 5
BILATERAL_SIGMA_COLOR = 50
BILATERAL_SIGMA_SPACE = 30
# CLAHE settings applied to the L channel of the LAB image in preprocess_for_ocr.
CLAHE_CLIP_LIMIT = 3.0
CLAHE_TILE_GRID = (8, 8)
# The image is rescaled so that its shorter side has this many pixels; a value <= 0 skips the resize.
TARGET_SHORT_SIDE = 1320

# Invert the input image first (only read by the CPU path, preprocess_for_ocr).
DO_INVERT = False
# Apply cv2.adaptiveThreshold (mean, block size 15, C = 10) to produce a binary image.
DO_BINARIZE = False
# ERODE_ITERATIONS > 0 enables the erode-stage branch of preprocess_for_ocr,
# which builds its kernel from ERODE_KERNEL_SIZE.
ERODE_KERNEL_SIZE = (2, 2)
ERODE_ITERATIONS = 3
# Kernel size / iteration count of the text dilation step; 0 iterations disables it.
DILATE_KERNEL_SIZE_TEXT = (3, 3)
DILATE_ITERATIONS_TEXT = 0
# Run deskew_image before the remaining preprocessing steps.
DO_DESKEW = False
# Fallback crop in detect_page_and_crop: margin removed on every side,
# as a fraction of the shorter image side.
REMOVE_BORDER_FRAC = 0.02
# Width in pixels of the border added around the final image; 0 disables it.
ADD_BORDER_PADDING = 10
# Border value handed to cv2.copyMakeBorder in the CPU path (the GPU path pads with 1.0).
ADD_BORDER_COLOR = (255, 255, 255)

# Passed to pytesseract.image_to_string as `lang` and `config` respectively.
# NOTE(review): "--psm 6" is Tesseract's "single uniform block of text" page
# segmentation mode; "Latin" and "osd" must exist as installed traineddata — confirm.
OCR_LANG = "Latin+osd"
OCR_CONFIG = "--psm 6"

In [None]:
# --- Utility Functions ---
def load_img(path):
    """Read the image stored at ``path`` via ``cv2.imread``.

    Args:
        path: Location of the image file; converted with ``str()`` before reading.

    Returns:
        The array returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` yields ``None`` for the path.
    """
    image = cv2.imread(str(path))
    if image is not None:
        return image
    raise FileNotFoundError(f"Cannot read {path}")

def save_img(path, img):
    """Encode ``img`` in memory and write the encoded bytes to ``path``.

    The encoder is selected from the suffix of ``path``; when the path has no
    suffix, ``".png"`` is used.

    Args:
        path: Destination file path.
        img: Image array handed to ``cv2.imencode``.

    Raises:
        IOError: If ``cv2.imencode`` reports failure.
    """
    suffix = Path(path).suffix
    if not suffix:
        suffix = ".png"

    encoded_ok, encoded = cv2.imencode(suffix, img)
    if encoded_ok:
        # Dump the encoded buffer straight to disk.
        encoded.tofile(str(path))
        return
    raise IOError("Could not encode image")

def order_points(pts):
    """Arrange four 2-D points into a fixed corner order.

    The row with the smallest coordinate sum goes to index 0 and the one with
    the largest sum to index 2; the row with the smallest (second - first)
    coordinate difference goes to index 1 and the largest to index 3.
    ``four_point_transform`` unpacks the result as (tl, tr, br, bl).

    Args:
        pts: Array of shape (4, 2).

    Returns:
        A ``float32`` array of shape (4, 2) holding the reordered points.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)

    # Source row for each output slot, in output order.
    source_rows = (
        np.argmin(coord_sum),
        np.argmin(coord_diff),
        np.argmax(coord_sum),
        np.argmax(coord_diff),
    )

    ordered = np.zeros((4, 2), dtype="float32")
    for slot, row in enumerate(source_rows):
        ordered[slot] = pts[row]
    return ordered

def four_point_transform(image, pts, dst_size=None):
    """Warp the quadrilateral ``pts`` of ``image`` onto an upright rectangle.

    Args:
        image: Source image passed to ``cv2.warpPerspective``.
        pts: Four corner points, in any order (they are sorted by ``order_points``).
        dst_size: Optional ``(width, height)`` of the output. When omitted the
            size is the truncated length of the longer of each pair of
            opposite edges.

    Returns:
        The warped image; pixels outside the source are filled with (255, 255, 255).
    """
    corners = order_points(pts)
    top_left, top_right, bottom_right, bottom_left = corners

    if dst_size is None:
        out_w = int(max(np.linalg.norm(bottom_right - bottom_left),
                        np.linalg.norm(top_right - top_left)))
        out_h = int(max(np.linalg.norm(top_right - bottom_right),
                        np.linalg.norm(top_left - bottom_left)))
    else:
        out_w, out_h = dst_size

    target = np.array(
        [
            [0, 0],
            [out_w - 1, 0],
            [out_w - 1, out_h - 1],
            [0, out_h - 1],
        ],
        dtype="float32",
    )
    homography = cv2.getPerspectiveTransform(corners, target)
    return cv2.warpPerspective(
        image,
        homography,
        (out_w, out_h),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255),
    )

In [None]:
# --- Page Detection & Cropping ---
def detect_page_and_crop(img_bgr):
    """Locate the page outline in a photo and return the rectified page.

    Edges are found with blur + Canny + dilation; contours are visited from
    largest to smallest area and the first one whose polygon approximation has
    four corners and an area above ``MIN_AREA_FRAC`` of the image is warped
    with ``four_point_transform``. If none qualifies, a copy of the image with
    a margin of ``REMOVE_BORDER_FRAC`` of the shorter side removed on every
    side is returned instead.

    Args:
        img_bgr: Input image; converted with ``cv2.COLOR_BGR2GRAY``.

    Returns:
        The perspective-corrected page, or the centrally cropped copy.
    """
    img_h, img_w = img_bgr.shape[:2]

    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, GAUSSIAN_BLUR_KERNEL, 0)
    edge_map = cv2.Canny(smoothed, CANNY_LOW, CANNY_HIGH)
    structuring = cv2.getStructuringElement(cv2.MORPH_RECT, DILATE_KERNEL_SIZE)
    edge_map = cv2.dilate(edge_map, structuring, iterations=DILATE_ITERATIONS)

    found, _ = cv2.findContours(edge_map.copy(), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    for contour in sorted(found, key=cv2.contourArea, reverse=True):
        perimeter = cv2.arcLength(contour, True)
        polygon = cv2.approxPolyDP(contour, 0.02 * perimeter, True)
        if len(polygon) != 4:
            continue
        if not cv2.contourArea(polygon) > MIN_AREA_FRAC * img_h * img_w:
            continue
        corners = polygon.reshape(4, 2).astype("float32")
        return four_point_transform(img_bgr, corners)

    # Fallback: central crop
    margin = int(REMOVE_BORDER_FRAC * min(img_w, img_h))
    return img_bgr[margin:img_h - margin, margin:img_w - margin].copy()

In [None]:
# --- Deskew CPU ---
def _skew_correction_angle(rect_angle):
    """Convert a ``cv2.minAreaRect`` angle into a correction rotation in degrees.

    The rectangle angle is only meaningful modulo 90 degrees (a rectangle at
    angle ``a`` with sides (w, h) equals one at ``a + 90`` with sides (h, w)),
    and OpenCV versions differ in which representative they report: older
    releases use [-90, 0), newer ones (0, 90]. Folding the angle modulo 90
    makes the result independent of that convention and always within
    (-45, 45].
    """
    folded = rect_angle % 90.0
    if folded >= 45.0:
        folded -= 90.0
    return -folded


def deskew_image(image):
    """Rotate ``image`` so that its content is axis-aligned.

    The skew is estimated from the minimum-area rectangle around the
    foreground pixels of an Otsu-thresholded, inverted grayscale copy.

    Args:
        image: 2-D grayscale array or 3-D array (converted with
            ``cv2.COLOR_BGR2GRAY``). Non-``uint8`` input is min-max scaled to
            ``uint8`` for the skew estimate only; the returned image keeps the
            dtype of the input.

    Returns:
        The input itself when it is empty, has no foreground pixels, or needs
        less than 0.5 degrees of correction; otherwise a rotated copy of the
        same size (cubic interpolation, replicated borders).
    """
    # Nothing to analyse in an empty array.
    if image.size == 0:
        return image

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image

    # Convert to uint8 BEFORE inverting: bitwise_not on a float array flips the
    # raw bits of each value instead of inverting intensities, and OTSU needs
    # an 8-bit image anyway.
    if gray.dtype != np.uint8:
        gray = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    gray = cv2.bitwise_not(gray)

    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    coords = np.column_stack(np.where(thresh > 0))
    if len(coords) == 0:
        # No foreground pixels, so there is no rectangle to fit.
        return image

    angle = _skew_correction_angle(cv2.minAreaRect(coords)[-1])

    if abs(angle) < 0.5:
        return image

    (h, w) = image.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    return cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

In [None]:
# --- CPU Preprocessing for OCR ---
def preprocess_for_ocr(img_bgr):
    """Prepare a cropped page for Tesseract using OpenCV on the CPU.

    Steps, each controlled by the module-level parameters:
    optional inversion and deskew, bilateral filtering, CLAHE on the L channel
    of the LAB image, grayscale conversion (or adaptive-threshold binarization
    when ``DO_BINARIZE`` is set), optional erosion and dilation, rescaling so
    the shorter side equals ``TARGET_SHORT_SIDE``, and a constant border.

    Args:
        img_bgr: 3-channel image; converted with ``cv2.COLOR_BGR2RGB``.

    Returns:
        A single-channel ``PIL.Image`` built from the processed array.
    """
    if DO_INVERT:
        img_bgr = cv2.bitwise_not(img_bgr)
    if DO_DESKEW:
        img_bgr = deskew_image(img_bgr)

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_rgb = cv2.bilateralFilter(img_rgb, BILATERAL_D, BILATERAL_SIGMA_COLOR, BILATERAL_SIGMA_SPACE)

    # Contrast equalisation on lightness only, leaving the colour channels alone.
    lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=CLAHE_TILE_GRID)
    img_rgb = cv2.cvtColor(cv2.merge((clahe.apply(l), a, b)), cv2.COLOR_LAB2RGB)

    img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    if DO_BINARIZE:
        # THRESH_BINARY: pixels above the local mean minus 10 become 255, the rest 0.
        img_gray = cv2.adaptiveThreshold(img_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 15, 10)

    if ERODE_ITERATIONS > 0:
        # Erosion takes the local minimum, so dark strokes on a light page get thicker.
        # (This branch previously called cv2.dilate with DILATE_ITERATIONS_TEXT,
        # so the configured erosion never ran.)
        kernel = np.ones(ERODE_KERNEL_SIZE, np.uint8)
        img_gray = cv2.erode(img_gray, kernel, iterations=ERODE_ITERATIONS)

    if DILATE_ITERATIONS_TEXT > 0:
        kernel = np.ones(DILATE_KERNEL_SIZE_TEXT, np.uint8)
        img_gray = cv2.dilate(img_gray, kernel, iterations=DILATE_ITERATIONS_TEXT)

    if TARGET_SHORT_SIDE > 0:
        h, w = img_gray.shape[:2]
        scale = TARGET_SHORT_SIDE / min(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        img_gray = cv2.resize(img_gray, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    if ADD_BORDER_PADDING > 0:
        img_gray = cv2.copyMakeBorder(
            img_gray,
            ADD_BORDER_PADDING, ADD_BORDER_PADDING, ADD_BORDER_PADDING, ADD_BORDER_PADDING,
            cv2.BORDER_CONSTANT,
            value=ADD_BORDER_COLOR,
        )
    return Image.fromarray(img_gray)

In [None]:
# --- GPU Preprocessing for OCR ---
def preprocess_for_ocr_gpu(img_bgr, device):
    """Prepare a cropped page for Tesseract, doing resize/dilate/pad in PyTorch.

    Steps: grayscale conversion, optional deskew (on the CPU), scaling to
    [0, 1] on ``device``, bilinear resize so the shorter side equals
    ``TARGET_SHORT_SIDE``, optional adaptive-threshold binarization (on the
    CPU), optional dilation, and a white border. Unlike ``preprocess_for_ocr``
    this path does not apply ``DO_INVERT``, the bilateral filter, CLAHE or the
    erosion step.

    Args:
        img_bgr: 3-channel image; converted with ``cv2.COLOR_BGR2GRAY``.
        device: Torch device on which the tensor operations run.

    Returns:
        A single-channel ``PIL.Image`` built from the resulting ``uint8`` array.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Deskew on the 8-bit image before normalisation. Previously a [0, 1] float
    # image was deskewed and then divided by 255 a second time, which collapsed
    # the page to near-black.
    if DO_DESKEW:
        gray = deskew_image(gray)

    # Shape (1, 1, H, W), values in [0, 1].
    tensor = torch.from_numpy(gray).unsqueeze(0).unsqueeze(0).float() / 255.0
    tensor = tensor.to(device)

    # Resize
    if TARGET_SHORT_SIDE > 0:
        _, _, h, w = tensor.shape
        scale = TARGET_SHORT_SIDE / min(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        tensor = F.interpolate(tensor, size=(new_h, new_w), mode='bilinear', align_corners=False)

    # Binarization (adaptive thresholding) round-trips through OpenCV on the CPU.
    if DO_BINARIZE:
        page_u8 = (tensor[0, 0].cpu().numpy() * 255).astype(np.uint8)
        page_u8 = cv2.adaptiveThreshold(page_u8, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 15, 10)
        tensor = torch.from_numpy(page_u8.astype(np.float32) / 255.0).unsqueeze(0).unsqueeze(0).to(device)

    # Morphological dilation = sliding-window maximum, the same operation
    # cv2.dilate performs with the all-ones kernel used in the CPU path.
    # (Summing the window with conv2d and clamping saturates a grayscale page
    # to white, and padding=0 shrank the image on every iteration.)
    if DILATE_ITERATIONS_TEXT > 0:
        k_h, k_w = DILATE_KERNEL_SIZE_TEXT
        # Pad so the output keeps its size; -inf never wins the maximum.
        pad_left, pad_top = k_w // 2, k_h // 2
        pad_right, pad_bottom = k_w - 1 - pad_left, k_h - 1 - pad_top
        for _ in range(DILATE_ITERATIONS_TEXT):
            padded = F.pad(tensor, (pad_left, pad_right, pad_top, pad_bottom),
                           mode='constant', value=float('-inf'))
            tensor = F.max_pool2d(padded, kernel_size=(k_h, k_w), stride=1)

    # Add a white (1.0) border if needed
    if ADD_BORDER_PADDING > 0:
        tensor = F.pad(tensor, (ADD_BORDER_PADDING,) * 4, mode='constant', value=1.0)

    # Convert back to PIL Image
    page = (tensor.squeeze(0).squeeze(0) * 255).byte().cpu().numpy()
    return Image.fromarray(page)

In [None]:
# --- OCR ---
def ocr_infer_pil(image_pil):
    """Run Tesseract on ``image_pil`` and return the recognised text.

    The language and extra command-line options come from the module-level
    ``OCR_LANG`` and ``OCR_CONFIG`` settings.
    """
    recognised = pytesseract.image_to_string(
        image_pil,
        lang=OCR_LANG,
        config=OCR_CONFIG,
    )
    return recognised

In [None]:
def get_processing_profile():
    """Pick a device, batch size and resize value from the available memory.

    With CUDA available the tier is chosen from the total memory of GPU 0;
    otherwise it is chosen from the total system RAM.

    Returns:
        A dict with the keys ``"device"``, ``"batch_size"`` and ``"resize"``.
    """
    bytes_per_gib = 1024**3
    ram_gib = psutil.virtual_memory().total / bytes_per_gib

    if torch.cuda.is_available():
        device_name = "cuda"
        available_gib = torch.cuda.get_device_properties(0).total_memory / bytes_per_gib
        # (minimum GiB, batch size, resize), highest tier first.
        tiers = ((8, 8, 1800), (4, 4, 1500))
        batch_size, resize = 2, 1200
    else:
        device_name = "cpu"
        available_gib = ram_gib
        tiers = ((16, 4, 1500), (8, 2, 1200))
        batch_size, resize = 1, 1000

    for minimum_gib, tier_batch, tier_resize in tiers:
        if available_gib >= minimum_gib:
            batch_size, resize = tier_batch, tier_resize
            break

    return {"device": device_name, "batch_size": batch_size, "resize": resize}

In [None]:
import random
import matplotlib.pyplot as plt
from IPython.display import display

def detect_device(prefer_gpu=True):
    """Return the torch device to run on and print which one was chosen.

    Args:
        prefer_gpu: When true, CUDA is used if ``torch.cuda.is_available()``.

    Returns:
        ``torch.device("cuda")`` or ``torch.device("cpu")``.
    """
    use_cuda = prefer_gpu and torch.cuda.is_available()
    if not use_cuda:
        print("Using CPU")
        return torch.device("cpu")

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    return torch.device("cuda")

def preview_random_page(files, device):
    """Preview the pipeline on one randomly chosen page and ask for approval.

    The page is cropped and preprocessed (GPU path when ``device.type`` is
    ``'cuda'``, CPU path otherwise), shown side by side with the cropped
    original, and the first 1000 characters of its OCR text are printed.

    Args:
        files: Sequence of image paths to choose from.
        device: Torch device selecting the preprocessing path.

    Returns:
        ``True`` only if the user answers ``y`` (case-insensitive, surrounding
        whitespace ignored) at the prompt.
    """
    chosen = random.choice(files)
    print(f"[Preview] Showing random file: {chosen.name}")

    page = detect_page_and_crop(load_img(chosen))
    if device.type == 'cuda':
        ocr_input = preprocess_for_ocr_gpu(page, device)
    else:
        ocr_input = preprocess_for_ocr(page)

    # Left: cropped page as photographed. Right: what Tesseract will see.
    _, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(12, 8))
    before_ax.imshow(cv2.cvtColor(page, cv2.COLOR_BGR2RGB))
    before_ax.set_title("Original Cropped Page")
    before_ax.axis("off")

    after_ax.imshow(ocr_input, cmap='gray')
    after_ax.set_title("Preprocessed for OCR")
    after_ax.axis("off")

    plt.tight_layout()
    plt.show()

    preview_text = ocr_infer_pil(ocr_input)
    print(f"--- Preview OCR ---\n{preview_text[:1000]}")

    answer = input("Proceed with OCR using these settings? [y/n]: ")
    return answer.strip().lower() == "y"

def ocr_worker(path, device):
    """OCR a single page image and return its text.

    Any exception raised while loading, cropping, preprocessing or recognising
    the page is caught and reported as an ``"Error processing ..."`` string in
    place of the text, so one bad page does not stop a batch.

    Args:
        path: Path of the page image.
        device: Torch device; ``'cuda'`` selects the GPU preprocessing path.
    """
    try:
        page = detect_page_and_crop(load_img(path))
        if device.type == 'cuda':
            ocr_input = preprocess_for_ocr_gpu(page, device)
        else:
            ocr_input = preprocess_for_ocr(page)
        return ocr_infer_pil(ocr_input)
    except Exception as exc:
        return f"Error processing {path}: {exc}"

def batch_process_book(book_dir, out_dir, max_workers=4, preview=True):
    """OCR every page image in ``book_dir`` and write one combined text file.

    Pages are processed in sorted filename order by a thread pool and their
    texts are joined with newlines into ``<out_dir>/<book_dir name>_ocr.txt``.

    Args:
        book_dir: Folder containing the page images
            (.jpg, .jpeg, .png, .tif, .tiff; case-insensitive).
        out_dir: Folder for the output file; created if missing.
        max_workers: Number of worker threads for the pool.
        preview: When true, show ``preview_random_page`` first and stop without
            running OCR unless the user approves.

    Returns:
        None. Nothing is written when the folder holds no images or the user
        declines the preview.

    Raises:
        RuntimeError: Re-raised when it is not a CUDA out-of-memory error or
            the worker count cannot be reduced any further.
    """
    device = detect_device()
    profile = get_processing_profile()
    print(f"Running on {device}, batch size {profile['batch_size']}, resize {profile['resize']}")

    # Loop-invariant setup: do it once rather than on every retry.
    book_dir = Path(book_dir)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    image_suffixes = {'.jpg', '.jpeg', '.png', '.tif', '.tiff'}
    files = sorted(p for p in book_dir.iterdir() if p.suffix.lower() in image_suffixes)

    if not files:
        # Avoid producing an empty output file that looks like a successful run.
        print(f"[WARN] No image files found in {book_dir}; nothing to OCR.")
        return

    # Ask for approval once, up front, so a retry never re-prompts the user.
    if preview:
        approved = preview_random_page(files, device)
        if not approved:
            print("[INFO] Aborted by user. Please tweak preprocessing parameters and re-run.")
            return  # exit before OCR

    workers = max_workers
    while True:
        try:
            with ThreadPoolExecutor(max_workers=workers) as executor:
                page_texts = executor.map(lambda p: ocr_worker(p, device), files)
                results = list(tqdm(page_texts, total=len(files)))
            break  # success
        except RuntimeError as e:
            # ocr_worker turns per-page failures into error strings, so this
            # only fires for errors raised outside the workers. The retry
            # lowers the thread count, which is the knob that actually
            # controls how many pages are in flight at once.
            if "CUDA out of memory" in str(e) and isinstance(workers, int) and workers > 1:
                workers = max(1, workers // 2)
                torch.cuda.empty_cache()
                print(f"[WARN] CUDA OOM - reducing worker count to {workers} and retrying...")
                continue
            raise

    out_path = out_dir / f"{book_dir.name}_ocr.txt"
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(results))
    print("OCR complete:", out_path)

In [None]:
# --- Main Execution ---
# Runs the whole pipeline on one book folder and reports the elapsed time.
if __name__ == "__main__":
    # ----- Change these values to match your book images folder and output folder ------
    # Name of the folder that contains the page images of the book.
    BOOK_NAME = "Bhala_Mabhalana" 
    # The page images are read from '<current working directory>/data/raw data/<BOOK_NAME>'.
    book_dir = Path.cwd() / "data" / "raw data" / BOOK_NAME
    # The result is written to '<current working directory>/data/output'
    # as '<BOOK_NAME>_ocr.txt'.
    out_dir = Path.cwd() / "data" / "output" 

    print(f"Starting OCR pipeline on: {book_dir}")
    start_time = time.time()

    # The profile's batch size is used as the number of worker threads.
    profile = get_processing_profile()
    batch_process_book(book_dir, out_dir, max_workers=profile["batch_size"], preview=True)

    print(f"OCR Pipeline complete in {time.time() - start_time:.2f} sec")