diff --git a/mindee/fields/amount.py b/mindee/fields/amount.py index 7e19ddf1..c03fd535 100644 --- a/mindee/fields/amount.py +++ b/mindee/fields/amount.py @@ -19,7 +19,7 @@ def __init__( :param amount_prediction: Amount prediction object from HTTP response :param value_key: Key to use in the amount_prediction dict :param reconstructed: Bool for reconstructed object (not extracted in the API) - :param page_n: Page number for multi pages pdf + :param page_n: Page number for multi-page PDF """ super().__init__( amount_prediction, diff --git a/mindee/fields/base.py b/mindee/fields/base.py index 455c1052..6c0e8e5b 100644 --- a/mindee/fields/base.py +++ b/mindee/fields/base.py @@ -1,17 +1,23 @@ from typing import Any, Dict, List, Optional, TypeVar +from mindee.geometry import Polygon, get_bbox_as_polygon + +TypePrediction = Dict[str, Any] + class Field: value: Optional[Any] = None """Raw field value""" confidence: float = 0.0 """Confidence score""" - bbox: List[List[float]] = [] + bbox: Polygon = [] """Bounding box coordinates containing the field""" + polygon: Polygon = [] + """Polygon coordinates of the field""" def __init__( self, - abstract_prediction: Dict[str, Any], + abstract_prediction: TypePrediction, value_key: str = "value", reconstructed: bool = False, page_n: Optional[int] = None, @@ -25,26 +31,28 @@ def __init__( :param page_n: Page number for multi-page PDF """ self.page_n = page_n + self.reconstructed = reconstructed if ( value_key not in abstract_prediction or abstract_prediction[value_key] == "N/A" ): - self.value = None - self.confidence = 0.0 - self.bbox = [] - else: - self.value = abstract_prediction[value_key] - try: - self.confidence = float(abstract_prediction["confidence"]) - except (KeyError, TypeError): - self.confidence = 0.0 - try: - self.bbox = abstract_prediction["polygon"] - except KeyError: - self.bbox = [] + return - self.reconstructed = reconstructed + self.value = abstract_prediction[value_key] + try: + self.confidence = 
float(abstract_prediction["confidence"]) + except (KeyError, TypeError): + pass + self._set_bbox(abstract_prediction) + + def _set_bbox(self, abstract_prediction: TypePrediction) -> None: + try: + self.polygon = abstract_prediction["polygon"] + except KeyError: + pass + if self.polygon: + self.bbox = get_bbox_as_polygon(self.polygon) def __eq__(self, other: Any) -> bool: if not isinstance(other, Field): diff --git a/mindee/fields/date.py b/mindee/fields/date.py index 7accda38..0558c213 100644 --- a/mindee/fields/date.py +++ b/mindee/fields/date.py @@ -28,7 +28,7 @@ def __init__( :param date_prediction: Date prediction object from HTTP response :param value_key: Key to use in the date_prediction dict :param reconstructed: Bool for reconstructed object (not extracted in the API) - :param page_n: Page number for multi pages pdf + :param page_n: Page number for multi-page PDF """ super().__init__( date_prediction, diff --git a/mindee/fields/locale.py b/mindee/fields/locale.py index 11584a3d..dc102bac 100644 --- a/mindee/fields/locale.py +++ b/mindee/fields/locale.py @@ -24,7 +24,7 @@ def __init__( :param locale_prediction: Locale prediction object from HTTP response :param value_key: Key to use in the locale_prediction dict :param reconstructed: Bool for reconstructed object (not extracted in the API) - :param page_n: Page number for multi pages pdf + :param page_n: Page number for multi-page PDF """ super().__init__( locale_prediction, diff --git a/mindee/fields/orientation.py b/mindee/fields/orientation.py index 6b4d345e..0d0ea127 100644 --- a/mindee/fields/orientation.py +++ b/mindee/fields/orientation.py @@ -18,7 +18,7 @@ def __init__( :param orientation_prediction: Orientation prediction object from HTTP response :param value_key: Key to use in the orientation_prediction dict :param reconstructed: Bool for reconstructed object (not extracted in the API) - :param page_n: Page number for multi pages pdf + :param page_n: Page number for multi-page PDF """ 
super().__init__( orientation_prediction, diff --git a/mindee/fields/payment_details.py b/mindee/fields/payment_details.py index c3d5d7d0..dacc5a3c 100644 --- a/mindee/fields/payment_details.py +++ b/mindee/fields/payment_details.py @@ -36,7 +36,7 @@ def __init__( payment_details_prediction dict :param swift_key: Key to use for getting the SWIFT in the payment_details_prediction dict :param reconstructed: Bool for reconstructed object (not extracted in the API) - :param page_n: Page number for multi pages pdf + :param page_n: Page number for multi-page PDF """ super().__init__( payment_details_prediction, diff --git a/mindee/geometry.py b/mindee/geometry.py new file mode 100644 index 00000000..e5323dad --- /dev/null +++ b/mindee/geometry.py @@ -0,0 +1,33 @@ +"""Pure Python geometry functions for working with polygons.""" + +from typing import Sequence, Tuple + +Point = Tuple[float, float] +Polygon = Sequence[Point] +BoundingBox = Tuple[float, float, float, float] +Quadrilateral = Tuple[Point, Point, Point, Point] + + +def get_bbox_as_polygon(polygon: Polygon) -> Quadrilateral: + """ + Given a sequence of points, calculate their bounding box as a 4-point polygon. + + :param polygon: Sequence of ``Point`` + :return: Quadrilateral + """ + x_min, y_min, x_max, y_max = get_bbox(polygon) + return (x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max) + + +def get_bbox(polygon: Polygon) -> BoundingBox: + """ + Given a sequence of points, calculate a bounding box that encompasses all points. 
+ + :param polygon: Sequence of ``Point`` + :return: BoundingBox + """ + y_min = min(v[1] for v in polygon) + y_max = max(v[1] for v in polygon) + x_min = min(v[0] for v in polygon) + x_max = max(v[0] for v in polygon) + return x_min, y_min, x_max, y_max diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 00000000..018cfc67 --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,48 @@ +import pytest + +from mindee import geometry + + +@pytest.fixture +def polygon_a(): + """90° rectangle, overlaps polygon_b""" + return [(0.123, 0.53), (0.175, 0.53), (0.175, 0.546), (0.123, 0.546)] + + +@pytest.fixture +def polygon_b(): + """90° rectangle, overlaps polygon_a""" + return [(0.124, 0.535), (0.190, 0.535), (0.190, 0.546), (0.124, 0.546)] + + +@pytest.fixture +def polygon_c(): + """not 90° rectangle, doesn't overlap any polygons""" + return [(0.205, 0.407), (0.379, 0.407), (0.381, 0.43), (0.207, 0.43)] + + +def test_bbox(polygon_a, polygon_b, polygon_c): + assert geometry.get_bbox(polygon_a) == (0.123, 0.53, 0.175, 0.546) + assert geometry.get_bbox(polygon_b) == (0.124, 0.535, 0.19, 0.546) + assert geometry.get_bbox(polygon_c) == (0.205, 0.407, 0.381, 0.43) + + +def test_bbox_polygon(polygon_a, polygon_b, polygon_c): + assert geometry.get_bbox_as_polygon(polygon_a) == ( + (0.123, 0.53), + (0.175, 0.53), + (0.175, 0.546), + (0.123, 0.546), + ) + assert geometry.get_bbox_as_polygon(polygon_b) == ( + (0.124, 0.535), + (0.19, 0.535), + (0.19, 0.546), + (0.124, 0.546), + ) + assert geometry.get_bbox_as_polygon(polygon_c) == ( + (0.205, 0.407), + (0.381, 0.407), + (0.381, 0.43), + (0.205, 0.43), + )