# layout detection

In [1]:
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from PIL import Image
from surya.settings import settings


def load_det_cached():
    """Load the text-detection model together with its processor.

    Both objects are built from ``settings.DETECTOR_MODEL_CHECKPOINT``.
    Despite the name, nothing in this function caches the result: every
    call loads the model and processor again.

    Returns:
        A ``(model, processor)`` tuple.
    """
    ckpt = settings.DETECTOR_MODEL_CHECKPOINT
    model = load_model(checkpoint=ckpt)
    processor = load_processor(checkpoint=ckpt)
    return model, processor


# def load_rec_cached():
#     return load_rec_model(), load_rec_processor()


def load_layout_cached():
    """Load the layout-detection model together with its processor.

    Both objects are built from ``settings.LAYOUT_MODEL_CHECKPOINT`` using
    the same ``load_model`` / ``load_processor`` loaders as the text
    detector. No caching happens here; each call reloads.

    Returns:
        A ``(model, processor)`` tuple.
    """
    ckpt = settings.LAYOUT_MODEL_CHECKPOINT
    model = load_model(checkpoint=ckpt)
    processor = load_processor(checkpoint=ckpt)
    return model, processor

def load_order_cached():
    """Load the reading-order model together with its processor.

    Uses the ordering-specific loaders with their default arguments.
    No caching happens here; each call reloads.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = load_order_model()
    processor = load_order_processor()
    return model, processor


# Load every model once at module level; the detection helpers defined later
# in this file read these names as globals.
det_model, det_processor = load_det_cached()
# rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()




Loaded detection model vikp/surya_det2 on device cuda with dtype torch.float16
Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16
Loaded reading order model vikp/surya_order on device cuda with dtype torch.float16


In [2]:
import pymupdf as fitz
from pymupdf import Document
def get_image_from_page(doc, page, dpi=72):
    """Render one page of a PDF document to an RGB PIL image.

    Args:
        doc: Open PyMuPDF document; it is indexed with ``page``.
        page: Index of the page to render.
        dpi: Rendering resolution. With the default of 72 the zoom factor is
            1, so image pixel coordinates coincide with PDF point
            coordinates. Code that reuses image-space boxes as page
            coordinates relies on that default.

    Returns:
        The rendered page as a ``PIL.Image.Image`` in mode ``"RGB"``.

    Raises:
        ValueError: If ``dpi`` is not strictly positive.
    """
    if dpi <= 0:
        raise ValueError(f"dpi must be positive, got {dpi!r}")

    pdf_page = doc[page]
    zoom = dpi / 72  # 72 is the default DPI
    # alpha=False guarantees 3 bytes per pixel, which is what the "RGB"
    # decode below expects from pix.samples.
    pix = pdf_page.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)
    return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

In [4]:
from typing import List, Tuple
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
from PIL import Image
from surya.ordering import batch_ordering


def texts_detection(images: List[Image.Image], batch_size=10) -> List[TextDetectionResult]:
    """Run text detection over ``images``.

    Thin wrapper around ``batch_text_detection`` that supplies the
    module-level ``det_model`` and ``det_processor``.

    Args:
        images: Page images to analyse.
        batch_size: Forwarded unchanged to ``batch_text_detection``.

    Returns:
        Whatever ``batch_text_detection`` returns for ``images``.
    """
    return batch_text_detection(images, det_model, det_processor, batch_size=batch_size)

def layout_detection(images: List[Image.Image], batch_size=10) -> List[LayoutResult]:
    """Run layout detection over ``images``.

    Text detection is executed first, and its predictions are handed to
    ``batch_layout_detection`` along with the module-level
    ``layout_model`` and ``layout_processor``.

    Args:
        images: Page images to analyse.
        batch_size: Used for both the text-detection and the layout pass.

    Returns:
        Whatever ``batch_layout_detection`` returns for ``images``.
    """
    text_preds = texts_detection(images, batch_size=batch_size)
    return batch_layout_detection(
        images,
        layout_model,
        layout_processor,
        text_preds,
        batch_size=batch_size,
    )

def order_detection(
    images: List[Image.Image], batch_size=10
) -> Tuple[List[Image.Image], List[OrderResult], List[List[str]]]:
    """Detect layout regions on each image and predict their reading order.

    Layout detection runs first; the resulting boxes are passed to
    ``batch_ordering``. For every image, a copy is annotated with the ordered
    polygons, each labelled with its position converted to a string.

    Args:
        images: Page images to analyse.
        batch_size: Used for the layout pass and for the ordering pass.

    Returns:
        A 3-tuple ``(order_imgs, preds, all_labels)``:

        * ``order_imgs`` - one annotated copy per input image, as returned by
          ``draw_polys_on_image``.
        * ``preds`` - the ordering predictions from ``batch_ordering``.
        * ``all_labels`` - per image, the layout label of every box, in
          layout-detection order (not reading order).
    """
    layout_preds = layout_detection(images, batch_size=batch_size)
    all_bboxes = [[box.bbox for box in layout_pred.bboxes] for layout_pred in layout_preds]
    all_labels = [[box.label for box in layout_pred.bboxes] for layout_pred in layout_preds]

    preds = batch_ordering(images, all_bboxes, order_model, order_processor, batch_size=batch_size)

    # Assumes batch_ordering yields one prediction per input image, in order.
    order_imgs = []
    for img, pred in zip(images, preds):
        polys = [box.polygon for box in pred.bboxes]
        positions = [str(box.position) for box in pred.bboxes]
        # Draw on a copy so the caller's image is left untouched.
        order_imgs.append(
            draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
        )
    return order_imgs, preds, all_labels

def process_pdf_texts(doc: Document, batch_size=2):
    """Render every page of ``doc`` and run reading-order detection on it.

    Args:
        doc: Open PyMuPDF document (an instance, not the ``Document`` class).
        batch_size: Forwarded to ``order_detection``.

    Returns:
        The ``(order_imgs, preds, labels)`` tuple produced by
        ``order_detection`` for the rendered pages, one entry per page.
    """
    images = [get_image_from_page(doc, page_number) for page_number in range(len(doc))]
    return order_detection(images, batch_size=batch_size)


# Open the sample PDF and run the full detection + reading-order pipeline on
# every page. The three results are index-aligned with the document's pages.
pdf_path = "./data/UniversScience/revue-decouverte/decouverte_436.pdf"
doc = fitz.open(pdf_path)
result_imgs, result_preds, result_labels = process_pdf_texts(doc, batch_size=8)
# A batch size of 8 is the maximum on my laptop; increase it on the server.


Detecting bboxes: 100%|██████████| 13/13 [00:16<00:00,  1.23s/it]
Detecting bboxes: 100%|██████████| 13/13 [00:12<00:00,  1.01it/s]
Finding reading order: 100%|██████████| 13/13 [00:09<00:00,  1.36it/s]


In [5]:
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import io
import pymupdf as fitz

# For each page, walk the detected boxes in predicted reading order and pull
# the PDF text that lies under each box.
for page_index in range(len(doc)):
    page = doc[page_index]
    preds = result_preds[page_index]
    labels = result_labels[page_index]
    img = result_imgs[page_index]


    # display(img)
    # Sort bounding boxes by position
    sorted_indices = sorted(range(len(preds.bboxes)), key=lambda i: preds.bboxes[i].position)
    # Iterate by index so the box and its layout label stay paired:
    # labels[i] is looked up with the same index as preds.bboxes[i].
    for i in sorted_indices:
        box = preds.bboxes[i]
        label = labels[i]
        
        # NOTE(review): the unpacked coordinates are not used in the visible
        # part of this cell.
        left, top, right, bottom = box.bbox
        # clean text
        # NOTE(review): box.bbox comes from the rendered image; using it as a
        # clip rectangle in page space assumes the page was rendered at zoom 1
        # (dpi 72 in get_image_from_page) — confirm if the DPI ever changes.
        text = page.get_text("text", clip=box.bbox)
        # Collapse line breaks into spaces and trim surrounding whitespace.
        text = text.replace('\n', ' ')
        text = text.strip()

        # NOTE(review): the image/page dimensions below are not used in the
        # visible part of this cell; presumably meant for image-to-page
        # coordinate scaling — confirm or remove.
        width = img.width
        page_rect = page.rect
        page_width = page_rect.width
        page_height = page_rect.height
        # if (len(text) > 2):
        #     print("pos  :", box.position)
        #     print("label:", label)
        #     print("text :", text)
        #     print("\n===================\n")