Detection module: Preprocessor (#20)
* feat: ✨ pdf reader

* feat: ✨ add doc_to_string function

* feat: ✨ add inference_utilities + inference for DBnet

* save: saving work before switching to doc reader

* feat: ✨ add model meta class

* add: postprocessor

* feat: ✨ preprocessor

* test: passed test

* test: passed all tests except unittest

* refacto: remove deprecated file

* test: passed all tests

* test: passed all tests

* test: passed tests

* test: passed tests

* test: passed tests

* test: passed tests
charlesmindee committed Jan 18, 2021
1 parent 9abdf11 commit b8eb11b
Showing 3 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions doctr/models/__init__.py
@@ -1 +1,2 @@
from . import utils
from . import preprocessor
106 changes: 106 additions & 0 deletions doctr/models/preprocessor.py
@@ -0,0 +1,106 @@
# Copyright (C) 2021, Mindee.
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import cv2
import numpy as np
from typing import List, Tuple


__all__ = ['Preprocessor']


class Preprocessor:
    """Document preprocessor.

    A preprocessor performs normalization, resizing and batching on the
    pages of the documents it is called on.
    """

    def __init__(
        self,
        out_size: Tuple[int, int],
        normalization: bool = True,
        mode: str = 'symmetric',
        batch_size: int = 1
    ) -> None:

        self.out_size = out_size
        self.normalization = normalization
        self.mode = mode
        self.batch_size = batch_size

    def normalize_documents_imgs(
        self,
        documents_imgs: List[List[np.ndarray]]
    ) -> List[List[np.ndarray]]:
        """Normalize pages according to self.mode.

        In 'symmetric' mode, pixel values in [0, 255] are mapped to roughly [-1, 1].
        """
        if self.mode == 'symmetric':
            # cast to float first: subtracting 128 from a uint8 array would wrap around
            normalized = [[(img.astype(np.float32) - 128) / 128 for img in doc] for doc in documents_imgs]
        else:
            normalized = documents_imgs
        return normalized

    def resize_documents_imgs(
        self,
        documents_imgs: List[List[np.ndarray]]
    ) -> List[List[np.ndarray]]:
        """Resize pages to out_size, the model input size.

        The nested documents/pages structure is preserved.
        """
        # out_size is (width, height) for cv2.resize; the interpolation must be passed
        # by keyword, since the third positional argument of cv2.resize is dst
        return [[cv2.resize(img, self.out_size, interpolation=cv2.INTER_LINEAR) for img in doc]
                for doc in documents_imgs]

    def batch_documents(
        self,
        documents: Tuple[List[List[np.ndarray]], List[List[str]], List[List[Tuple[int, int]]]]
    ) -> Tuple[List[Tuple[List[np.ndarray], List[str], List[Tuple[int, int]]]], List[int], List[int]]:
        """Batch a list of read documents, using self.batch_size pages per batch.

        :param documents: documents read by documents.reader.read_documents
        :returns: the batched pages (images, names, shapes) along with, for each
            flattened page, the index of its document and its index within it
        """
        images, names, shapes = documents

        # keep track of both document and page indexes for each flattened page
        docs_indexes = [i for i, doc in enumerate(images) for _ in doc]
        pages_indexes = [i for doc in images for i, _ in enumerate(doc)]

        # flatten the documents/pages structure
        flat_images = [image for doc in images for image in doc]
        flat_names = [name for doc in names for name in doc]
        flat_shapes = [shape for doc in shapes for shape in doc]

        # ceil division: enough batches to cover every page
        range_batch = range((len(flat_shapes) + self.batch_size - 1) // self.batch_size)

        b_images = [flat_images[i * self.batch_size:(i + 1) * self.batch_size] for i in range_batch]
        b_names = [flat_names[i * self.batch_size:(i + 1) * self.batch_size] for i in range_batch]
        b_shapes = [flat_shapes[i * self.batch_size:(i + 1) * self.batch_size] for i in range_batch]

        b_docs = list(zip(b_images, b_names, b_shapes))

        return b_docs, docs_indexes, pages_indexes

    def __call__(
        self,
        documents: Tuple[List[List[np.ndarray]], List[List[str]], List[List[Tuple[int, int]]]]
    ) -> Tuple[List[Tuple[List[np.ndarray], List[str], List[Tuple[int, int]]]], List[int], List[int]]:
        """Perform resizing, normalization and batching on documents."""
        images, names, shapes = documents
        images = self.resize_documents_imgs(images)
        if self.normalization:
            images = self.normalize_documents_imgs(images)
        norm_and_sized_docs = images, names, shapes
        b_docs, docs_indexes, pages_indexes = self.batch_documents(norm_and_sized_docs)

        return b_docs, docs_indexes, pages_indexes
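
For context, a minimal usage sketch (not part of this commit), feeding the preprocessor two synthetic documents instead of the output of documents.reader.read_documents; the nested images/names/shapes structure is assumed from the type hints above:

import numpy as np
from doctr.models.preprocessor import Preprocessor

# two synthetic "documents": one with 3 pages, one with 2 pages
images = [[np.random.randint(0, 256, (800, 600, 3), dtype=np.uint8) for _ in range(n)] for n in (3, 2)]
names = [[f"doc{d}_page{p}" for p in range(len(doc))] for d, doc in enumerate(images)]
shapes = [[img.shape[:2] for img in doc] for doc in images]

preprocessor = Preprocessor(out_size=(600, 600), batch_size=2)
b_docs, docs_indexes, pages_indexes = preprocessor((images, names, shapes))

# 5 pages with batch_size=2 -> ceil(5 / 2) = 3 batches of 2, 2 and 1 pages
assert len(b_docs) == 3
assert docs_indexes == [0, 0, 0, 1, 1]
assert pages_indexes == [0, 1, 2, 0, 1]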
17 changes: 17 additions & 0 deletions test/test_models.py
@@ -1,11 +1,17 @@
import pytest
import requests
from io import BytesIO

from doctr import models

import sys
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

from doctr.models.preprocessor import Preprocessor
from doctr import documents
from test_documents import mock_pdf


@pytest.fixture(scope="module")
def mock_model():
@@ -42,3 +48,14 @@ def test_quantize_model(mock_model):
def test_export_sizes(test_convert_to_tflite, test_convert_to_fp16, test_quantize_model):
    assert sys.getsizeof(test_convert_to_tflite) > sys.getsizeof(test_convert_to_fp16)
    assert sys.getsizeof(test_convert_to_fp16) > sys.getsizeof(test_quantize_model)


def test_preprocess_documents(mock_pdf, num_docs=10, batch_size=3):  # noqa: F811
    docs = documents.reader.read_documents(
        filepaths=[mock_pdf for _ in range(num_docs)])
    preprocessor = Preprocessor(out_size=(600, 600), normalization=True, mode='symmetric', batch_size=batch_size)
    batched_docs, docs_indexes, pages_indexes = preprocessor(docs)
    assert len(docs_indexes) == len(pages_indexes)
    assert docs_indexes[-1] + 1 == num_docs
    if num_docs > batch_size:
        assert all(len(batch) == batch_size for batches in batched_docs[:-1] for batch in batches)
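
As a side note, the docs_indexes / pages_indexes returned alongside the batches are what lets a caller map flat per-page predictions back to their documents. A hypothetical regrouping helper (an assumption, not part of this commit) could look like:

from collections import defaultdict

def regroup_pages(flat_results, docs_indexes, pages_indexes):
    # rebuild the nested documents/pages structure from flattened per-page results
    grouped = defaultdict(dict)
    for result, doc_i, page_i in zip(flat_results, docs_indexes, pages_indexes):
        grouped[doc_i][page_i] = result
    return [[grouped[d][p] for p in sorted(grouped[d])] for d in sorted(grouped)]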
