Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detection module : Preprocessor #20

Merged
merged 18 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doctr/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import utils
from . import preprocessor
93 changes: 93 additions & 0 deletions doctr/models/preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (C) 2021, Mindee.
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import math
import cv2
import json
import os
import numpy as np
from typing import Union, List, Tuple, Optional, Any, Dict

fg-mindee marked this conversation as resolved.
Show resolved Hide resolved

class Preprocessor:
    """Cast document images to model inputs: resizing, normalization and batching.

    Args:
        out_size: (width, height) target size for the model inputs
        normalization: if True, map pixel values from [0, 255] to [-1, 1]
        batch_size: number of pages per batch during inference
    """

    def __init__(
        self,
        out_size: Tuple[int, int],
        normalization: bool = True,
        batch_size: int = 1
    ) -> None:

        self.out_size = out_size
        self.normalization = normalization
        self.batch_size = batch_size

    def __call__(
        self,
        documents: Tuple[List[List[np.ndarray]], List[List[str]], List[List[Tuple[int, int]]]]
    ) -> Tuple[List[List[Any]], List[int], List[int]]:
        """Perform resizing, normalization and batching on documents.

        Args:
            documents: (images, names, shapes) triplet as produced by the document reader
        Returns:
            batched documents, plus the source document index and page index of each
            flattened page
        """
        # NOTE: the original signature omitted `self`, which made every call raise
        # a TypeError; it is a required first parameter of an instance method.
        images, names, shapes = documents
        # Resize first so that normalization operates on fixed-size arrays
        images = resize_documents_imgs(images, out_size=self.out_size)
        if self.normalization:
            images = normalize_documents_imgs(images)
        norm_and_sized_docs = images, names, shapes
        b_docs, docs_indexes, pages_indexes = batch_documents(
            norm_and_sized_docs, batch_size=self.batch_size)

        return b_docs, docs_indexes, pages_indexes

def normalize_documents_imgs(
    documents_imgs: List[List[np.ndarray]],
    mode: str = 'symmetric'
) -> List[List[np.ndarray]]:
    """Normalize document images pixel-wise according to mode.

    Args:
        documents_imgs: nested (documents/pages) list of images
        mode: normalization scheme; 'symmetric' maps [0, 255] to [-1, 1]
    Returns:
        the nested list of normalized images, structure preserved
    Raises:
        ValueError: if mode is not a supported normalization scheme
    """
    if mode == 'symmetric':
        return [[(img - 128) / 128 for img in doc] for doc in documents_imgs]
    # The original implementation silently returned None for any other mode,
    # which would surface later as an obscure error; fail loudly instead.
    raise ValueError(f"unsupported normalization mode: {mode}")

def resize_documents_imgs(
    documents_imgs: List[List[np.ndarray]],
    out_size: Tuple[int, int]
) -> List[List[np.ndarray]]:
    """Resize document images to out_size, the size expected by the model.

    The nested documents/pages structure is preserved.

    Args:
        documents_imgs: nested (documents/pages) list of images
        out_size: (width, height) of the model inputs
    Returns:
        the nested list of resized images
    """
    # cv2.resize's third positional argument is `dst`, not the interpolation
    # flag: the original call passed cv2.INTER_LINEAR as dst, so the flag was
    # ignored. It must be given by keyword to take effect.
    return [
        [cv2.resize(img, out_size, interpolation=cv2.INTER_LINEAR) for img in doc]
        for doc in documents_imgs
    ]

def batch_documents(
    documents: Tuple[List[List[np.ndarray]], List[List[str]], List[List[Tuple[int, int]]]],
    batch_size: int = 1
) -> Tuple[List[List[Any]], List[int], List[int]]:
    """Batch a list of read documents.

    Args:
        documents: (images, names, shapes) read by documents.reader.read_documents
        batch_size: batch size to use during inference, default goes to 1
    Returns:
        batched (images, names, shapes) triplets along with, for each flattened
        page, the index of its source document and its page index within it
    """
    images, names, shapes = documents

    # Track document and page indexes with enumerate: list.index returns the
    # FIRST matching element, which yields wrong indexes as soon as two
    # documents (or two pages) compare equal — e.g. reading the same file twice.
    docs_indexes = [doc_idx for doc_idx, doc in enumerate(images) for _ in doc]
    pages_indexes = [page_idx for doc in images for page_idx, _ in enumerate(doc)]

    # Flatten the documents/pages structure. The original comprehension was
    # `[image for doc in images for raw in doc]`, referencing the undefined
    # name `image` instead of its loop variable — a guaranteed NameError.
    flat_images = [page for doc in images for page in doc]
    flat_names = [name for doc in names for name in doc]
    flat_shapes = [shape for doc in shapes for shape in doc]

    # Number of batches, rounding up so trailing pages form a final partial batch
    range_batch = range((len(flat_shapes) + batch_size - 1) // batch_size)

    b_images = [flat_images[i * batch_size:(i + 1) * batch_size] for i in range_batch]
    b_names = [flat_names[i * batch_size:(i + 1) * batch_size] for i in range_batch]
    b_shapes = [flat_shapes[i * batch_size:(i + 1) * batch_size] for i in range_batch]

    b_docs = [[b_i, b_n, b_s] for b_i, b_n, b_s in zip(b_images, b_names, b_shapes)]

    return b_docs, docs_indexes, pages_indexes
27 changes: 27 additions & 0 deletions test/test_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import requests
from io import BytesIO

from doctr.models.preprocessor import Preprocessor
from doctr import documents

@pytest.fixture(scope="session")
def mock_pdf(tmpdir_factory):
    """Download a sample PDF once per session and persist it to a temp file.

    Returns the path of the written PDF file.
    """
    url = 'https://arxiv.org/pdf/1911.08947.pdf'
    # Without a timeout, requests.get can block forever on a stalled
    # connection and hang the whole test session.
    file = BytesIO(requests.get(url, timeout=30).content)
    fn = tmpdir_factory.mktemp("data").join("mock_pdf_file.pdf")
    with open(fn, 'wb') as f:
        f.write(file.getbuffer())
    return fn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check if this works after removing this definition please?
The fixture is defined at the session level, so I guess this will work as is (or we can import it)



def test_preprocess_documents(mock_pdf, num_docs=10, batch_size=3):
    """End-to-end check of Preprocessor on a batch of identical PDFs."""
    # mock_pdf must be a plain parameter so pytest injects the fixture value;
    # in the original it was absent, so the name resolved to the fixture
    # *function object* instead of the downloaded file path. Default-valued
    # parameters (num_docs, batch_size) are not treated as fixtures by pytest
    # and keep their defaults.
    docs = documents.reader.read_documents(
        filepaths=[mock_pdf for _ in range(num_docs)])
    preprocessor = Preprocessor(out_size=(600, 600), normalization=True, batch_size=batch_size)
    batched_docs, docs_indexes, pages_indexes = preprocessor(docs)
    # One (doc_index, page_index) pair per flattened page
    assert len(docs_indexes) == len(pages_indexes)
    if num_docs > batch_size:
        # Every batch except possibly the last one must be full
        for batch in batched_docs[:-1]:
            for i in range(len(batch)):
                assert len(batch[i]) == batch_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a specific test checking for actual values here, since you know the doc and its number of pages?