Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improved font resolution and page synthesis #472

Merged
merged 8 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,26 @@ doc = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images()
result = model(doc)
```

To make sense of your model's predictions, you can visualize them as follows:
To make sense of your model's predictions, you can visualize them interactively as follows:

```python
result.show(doc)
```

![DocTR example](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif)
![Visualization sample](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif)

The ocr_predictor returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
Or even rebuild the original document from its predictions:

```python
import matplotlib.pyplot as plt

plt.imshow(result.synthesize()); plt.axis('off'); plt.show()
```

![Synthesis sample](https://github.com/mindee/doctr/releases/download/v0.3.1/synthesized_sample.png)


The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/io.html#document-structure):

You can also export them as a nested dict, more appropriate for JSON format:
Expand Down
2 changes: 2 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Easy-to-use functions to make sense of your model's predictions.

.. autofunction:: visualize_page

.. autofunction:: synthesize_page


.. _metrics:

Expand Down
12 changes: 2 additions & 10 deletions doctr/datasets/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import platform

from doctr.io.image import tensor_from_pil
from doctr.utils.fonts import get_font
from ..datasets import AbstractDataset


Expand All @@ -31,16 +32,7 @@ def synthesize_char_img(char: str, size: int = 32, font_family: Optional[str] =
d = ImageDraw.Draw(img)

# Draw the character
if font_family is None:
try:
font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", size)
except OSError:
font = ImageFont.load_default()
logging.warning("unable to load specific font families. Loading default PIL font,"
"font size issues may be expected."
"To prevent this, it is recommended to specify the value of 'font_family'.")
else:
font = ImageFont.truetype(font_family, size)
font = get_font(font_family, size)
d.text((4, 0), char, font=font, fill=(255, 255, 255))

return img
Expand Down
20 changes: 19 additions & 1 deletion doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Tuple, Dict, List, Any, Optional, Union

from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
from doctr.utils.visualization import visualize_page
from doctr.utils.visualization import visualize_page, synthesize_page
from doctr.utils.common_types import BoundingBox, RotatedBbox
from doctr.utils.repr import NestedObject

Expand Down Expand Up @@ -244,6 +244,15 @@ def show(
visualize_page(self.export(), page, interactive=interactive)
plt.show(**kwargs)

def synthesize(self, **kwargs) -> np.ndarray:
    """Render the predictions of this page as an image

    Args:
        kwargs: keyword arguments forwarded to `synthesize_page`

    Returns:
        synthesized page
    """

    exported_page = self.export()
    return synthesize_page(exported_page, **kwargs)

@classmethod
def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
Expand Down Expand Up @@ -280,6 +289,15 @@ def show(self, pages: List[np.ndarray], **kwargs) -> None:
for img, result in zip(pages, self.pages):
result.show(img, **kwargs)

def synthesize(self, **kwargs) -> List[np.ndarray]:
    """Synthesize all pages from their predictions

    Args:
        kwargs: keyword arguments forwarded to the `synthesize` method of each page

    Returns:
        list of synthesized pages
    """

    # Forward the options to every page, otherwise they would be silently ignored
    return [page.synthesize(**kwargs) for page in self.pages]

@classmethod
def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .sar import *
from .zoo import *

del utils
del utils # type: ignore[name-defined]
28 changes: 28 additions & 0 deletions doctr/utils/fonts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import platform
import logging
from PIL import ImageFont
from typing import Optional

__all__ = ['get_font']


def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
    """Resolve a PIL font to draw text with

    Args:
        font_family: font passed to `ImageFont.truetype`. If None, an OS-dependent default is
            tried ("FreeMono.ttf" on Linux, "Arial.ttf" otherwise)
        font_size: size of the font, forwarded to `ImageFont.truetype`

    Returns:
        the loaded font, or the default PIL font if the OS-dependent default cannot be loaded
    """

    # An explicitly requested font is loaded as is: loading errors are left to the caller
    if font_family is not None:
        return ImageFont.truetype(font_family, font_size)

    # Font selection
    default_family = "FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf"
    try:
        font = ImageFont.truetype(default_family, font_size)
    except OSError:
        # The fallback font is loaded without `font_size`, hence the warning about size issues
        font = ImageFont.load_default()
        logging.warning("unable to load recommended font family. Loading default PIL font, "
                        "font size issues may be expected. "
                        "To prevent this, it is recommended to specify the value of 'font_family'.")

    return font
21 changes: 10 additions & 11 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing import Tuple, List, Dict, Any, Union, Optional

from .common_types import BoundingBox, RotatedBbox
from .fonts import get_font

__all__ = ['visualize_page', 'synthetize_page', 'draw_boxes']
__all__ = ['visualize_page', 'synthesize_page', 'draw_boxes']


def rect_patch(
Expand Down Expand Up @@ -242,21 +243,24 @@ def visualize_page(
return fig


def synthetize_page(
def synthesize_page(
page: Dict[str, Any],
draw_proba: bool = False,
font_size: int = 13,
font_family: Optional[str] = None,
) -> np.ndarray:
"""Draw the content of the element page (OCR response) on a blank page.

Args:
page: exported Page object to represent
draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0
font_size: size of the font, default font = 13
font_family: family of the font

Return:
A np array (drawn page)
the synthesized page
"""

# Draw template
h, w = page["dimensions"]
response = 255 * np.ones((h, w, 3), dtype=np.int32)
Expand All @@ -271,16 +275,11 @@ def synthetize_page(
ymin, ymax = int(h * ymin), int(h * ymax)

# White drawing context adapted to font size, 0.75 factor to convert pts --> pix
h_box, w_box = ymax - ymin, xmax - xmin
h_font, w_font = font_size, int(font_size * w_box / (h_box * 0.75))
img = Image.new('RGB', (w_font, h_font), color=(255, 255, 255))
font = get_font(font_family, int(0.75 * (ymax - ymin)))
img = Image.new('RGB', (xmax - xmin, ymax - ymin), color=(255, 255, 255))
d = ImageDraw.Draw(img)

# Draw in black the value of the word
d.text((0, 0), word["value"], font=ImageFont.load_default(), fill=(0, 0, 0))

# Resize back to box size
img = img.resize((w_box, h_box), Image.NEAREST)
d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))

# Colorize if draw_proba
if draw_proba:
Expand Down
9 changes: 9 additions & 0 deletions test/common/test_io_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def test_page():
# Show
page.show(np.zeros((256, 256, 3), dtype=np.uint8), block=False)

# Synthesize
img = page.synthesize()
assert isinstance(img, np.ndarray)
assert img.shape == (*page_size, 3)


def test_document():
pages = _mock_pages()
Expand All @@ -214,3 +219,7 @@ def test_document():

# Show
doc.show([np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(len(pages))], block=False)

# Synthesize
img_list = doc.synthesize()
assert isinstance(img_list, list) and len(img_list) == len(pages)
11 changes: 11 additions & 0 deletions test/common/test_utils_fonts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from PIL.ImageFont import ImageFont

from doctr.utils.fonts import get_font


def test_get_font():
    # Local import: `ImageFont.truetype` returns this class, which is not an `ImageFont` subclass
    from PIL.ImageFont import FreeTypeFont

    # Attempts to load recommended OS font
    font = get_font()

    # Either the recommended font was loaded, or `get_font` fell back to the default PIL font
    assert isinstance(font, (ImageFont, FreeTypeFont))
8 changes: 5 additions & 3 deletions test/common/test_utils_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ def test_visualize_page():
visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100))


def test_draw_page():
def test_synthesize_page():
pages = _mock_pages()
visualization.synthetize_page(pages[0].export(), draw_proba=True)
visualization.synthetize_page(pages[0].export(), draw_proba=False)
visualization.synthesize_page(pages[0].export(), draw_proba=False)
render = visualization.synthesize_page(pages[0].export(), draw_proba=True)
assert isinstance(render, np.ndarray)
assert render.shape == (*pages[0].dimensions, 3)


def test_draw_boxes():
Expand Down