Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindee/fields/amount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
:param amount_prediction: Amount prediction object from HTTP response
:param value_key: Key to use in the amount_prediction dict
:param reconstructed: Bool for reconstructed object (not extracted in the API)
:param page_n: Page number for multi pages pdf
:param page_n: Page number for multi-page PDF
"""
super().__init__(
amount_prediction,
Expand Down
40 changes: 24 additions & 16 deletions mindee/fields/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from typing import Any, Dict, List, Optional, TypeVar

from mindee.geometry import Polygon, get_bbox_as_polygon

TypePrediction = Dict[str, Any]


class Field:
value: Optional[Any] = None
"""Raw field value"""
confidence: float = 0.0
"""Confidence score"""
bbox: List[List[float]] = []
bbox: Polygon = []
"""Bounding box coordinates containing the field"""
polygon: Polygon = []
"""coordinates of the field"""

def __init__(
self,
abstract_prediction: Dict[str, Any],
abstract_prediction: TypePrediction,
value_key: str = "value",
reconstructed: bool = False,
page_n: Optional[int] = None,
Expand All @@ -25,26 +31,28 @@ def __init__(
:param page_n: Page number for multi-page PDF
"""
self.page_n = page_n
self.reconstructed = reconstructed

if (
value_key not in abstract_prediction
or abstract_prediction[value_key] == "N/A"
):
self.value = None
self.confidence = 0.0
self.bbox = []
else:
self.value = abstract_prediction[value_key]
try:
self.confidence = float(abstract_prediction["confidence"])
except (KeyError, TypeError):
self.confidence = 0.0
try:
self.bbox = abstract_prediction["polygon"]
except KeyError:
self.bbox = []
return

self.reconstructed = reconstructed
self.value = abstract_prediction[value_key]
try:
self.confidence = float(abstract_prediction["confidence"])
except (KeyError, TypeError):
pass
self._set_bbox(abstract_prediction)

def _set_bbox(self, abstract_prediction: TypePrediction) -> None:
    """
    Record the prediction's polygon and derive the bounding box from it.

    When the prediction has no ``"polygon"`` key, ``self.polygon`` is left
    untouched. ``self.bbox`` is only assigned when the resulting polygon
    is non-empty.

    :param abstract_prediction: Prediction object from HTTP response
    """
    try:
        polygon = abstract_prediction["polygon"]
    except KeyError:
        # No polygon in the prediction: fall back to the current value.
        polygon = self.polygon
    else:
        self.polygon = polygon
    if polygon:
        self.bbox = get_bbox_as_polygon(polygon)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Field):
Expand Down
2 changes: 1 addition & 1 deletion mindee/fields/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
:param date_prediction: Date prediction object from HTTP response
:param value_key: Key to use in the date_prediction dict
:param reconstructed: Bool for reconstructed object (not extracted in the API)
:param page_n: Page number for multi pages pdf
:param page_n: Page number for multi-page PDF
"""
super().__init__(
date_prediction,
Expand Down
2 changes: 1 addition & 1 deletion mindee/fields/locale.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
:param locale_prediction: Locale prediction object from HTTP response
:param value_key: Key to use in the locale_prediction dict
:param reconstructed: Bool for reconstructed object (not extracted in the API)
:param page_n: Page number for multi pages pdf
:param page_n: Page number for multi-page PDF
"""
super().__init__(
locale_prediction,
Expand Down
2 changes: 1 addition & 1 deletion mindee/fields/orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
:param orientation_prediction: Orientation prediction object from HTTP response
:param value_key: Key to use in the orientation_prediction dict
:param reconstructed: Bool for reconstructed object (not extracted in the API)
:param page_n: Page number for multi pages pdf
:param page_n: Page number for multi-page PDF
"""
super().__init__(
orientation_prediction,
Expand Down
2 changes: 1 addition & 1 deletion mindee/fields/payment_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
payment_details_prediction dict
:param swift_key: Key to use for getting the SWIFT in the payment_details_prediction dict
:param reconstructed: Bool for reconstructed object (not extracted in the API)
:param page_n: Page number for multi pages pdf
:param page_n: Page number for multi-page PDF
"""
super().__init__(
payment_details_prediction,
Expand Down
33 changes: 33 additions & 0 deletions mindee/geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Pure Python geometry functions for working with polygons."""

from typing import Sequence, Tuple

Point = Tuple[float, float]
Polygon = Sequence[Point]
BoundingBox = Tuple[float, float, float, float]
Quadrilateral = Tuple[Point, Point, Point, Point]


def get_bbox_as_polygon(polygon: Polygon) -> Quadrilateral:
    """
    Build the four-cornered bounding polygon that encloses every given point.

    The corners are returned in the order ``(x_min, y_min)``,
    ``(x_max, y_min)``, ``(x_max, y_max)``, ``(x_min, y_max)``.

    :param polygon: Sequence of ``Point``
    :return: Quadrilateral
    """
    low_x, low_y, high_x, high_y = get_bbox(polygon)
    first_corner = (low_x, low_y)
    second_corner = (high_x, low_y)
    third_corner = (high_x, high_y)
    fourth_corner = (low_x, high_y)
    return first_corner, second_corner, third_corner, fourth_corner


def get_bbox(polygon: Polygon) -> BoundingBox:
    """
    Given a sequence of points, calculate a bounding box that encompasses all points.

    The points are read exactly once, so any iterable of points is accepted,
    including a one-shot iterator or generator.

    :param polygon: Sequence of ``Point``
    :return: BoundingBox as ``(x_min, y_min, x_max, y_max)``
    :raises ValueError: if ``polygon`` contains no points
    """
    # Materialize first: an iterator would be exhausted by the first min() call,
    # leaving nothing for the remaining min()/max() passes.
    points = list(polygon)
    if not points:
        raise ValueError("Cannot compute the bounding box of an empty polygon")
    x_coords = [point[0] for point in points]
    y_coords = [point[1] for point in points]
    return min(x_coords), min(y_coords), max(x_coords), max(y_coords)
48 changes: 48 additions & 0 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

from mindee import geometry


@pytest.fixture
def polygon_a():
    """Axis-aligned rectangle whose area overlaps ``polygon_b``."""
    x_low, x_high = 0.123, 0.175
    y_low, y_high = 0.53, 0.546
    return [(x_low, y_low), (x_high, y_low), (x_high, y_high), (x_low, y_high)]


@pytest.fixture
def polygon_b():
    """Axis-aligned rectangle whose area overlaps ``polygon_a``."""
    x_low, x_high = 0.124, 0.190
    y_low, y_high = 0.535, 0.546
    return [(x_low, y_low), (x_high, y_low), (x_high, y_high), (x_low, y_high)]


@pytest.fixture
def polygon_c():
    """Slanted quadrilateral (no 90° corners) that overlaps neither other fixture."""
    y_low, y_high = 0.407, 0.43
    # The second edge is shifted along x relative to the first, so the sides are not vertical.
    first_edge = [(0.205, y_low), (0.379, y_low)]
    second_edge = [(0.381, y_high), (0.207, y_high)]
    return first_edge + second_edge


def test_bbox(polygon_a, polygon_b, polygon_c):
    """Check ``get_bbox`` against the known extents of each fixture polygon."""
    cases = (
        (polygon_a, (0.123, 0.53, 0.175, 0.546)),
        (polygon_b, (0.124, 0.535, 0.19, 0.546)),
        (polygon_c, (0.205, 0.407, 0.381, 0.43)),
    )
    for polygon, expected_bbox in cases:
        assert geometry.get_bbox(polygon) == expected_bbox


def test_bbox_polygon(polygon_a, polygon_b, polygon_c):
    """Check ``get_bbox_as_polygon`` returns the expected corners for each fixture."""
    corners_for_a = ((0.123, 0.53), (0.175, 0.53), (0.175, 0.546), (0.123, 0.546))
    corners_for_b = ((0.124, 0.535), (0.19, 0.535), (0.19, 0.546), (0.124, 0.546))
    corners_for_c = ((0.205, 0.407), (0.381, 0.407), (0.381, 0.43), (0.205, 0.43))

    cases = (
        (polygon_a, corners_for_a),
        (polygon_b, corners_for_b),
        (polygon_c, corners_for_c),
    )
    for polygon, expected_corners in cases:
        assert geometry.get_bbox_as_polygon(polygon) == expected_corners