diff --git a/mindee/documents/custom/line_items.py b/mindee/documents/custom/line_items.py new file mode 100644 index 00000000..3989614e --- /dev/null +++ b/mindee/documents/custom/line_items.py @@ -0,0 +1,111 @@ +from typing import Dict, List, Sequence + +from mindee.documents.custom.custom_v1_fields import ListField, ListFieldValue +from mindee.geometry import ( + Quadrilateral, + get_bounding_box, + get_min_max_y, + is_point_in_y, + merge_polygons, +) + + +def _array_product(array: Sequence[float]) -> float: + """ + Get the product of a sequence of floats. + + :array: List of floats + """ + product = 1.0 + for k in array: + product = product * k + return product + + +def _find_best_anchor(anchors: Sequence[str], fields: Dict[str, ListField]) -> str: + """ + Find the anchor with the most rows, in the order specified by `anchors`. + + Anchor will be the name of the field. + """ + anchor = "" + anchor_rows = 0 + for field in anchors: + values = fields[field].values + if len(values) > anchor_rows: + anchor_rows = len(values) + anchor = field + return anchor + + +def _get_empty_field() -> ListFieldValue: + """Return sample field with empty values.""" + return ListFieldValue({"content": "", "polygon": [], "confidence": 0.0}) + + +class Line: + """Represent a single line.""" + + row_number: int + fields: Dict[str, ListFieldValue] + bounding_box: Quadrilateral + + +def get_line_items( + anchors: Sequence[str], columns: Sequence[str], fields: Dict[str, ListField] +) -> List[Line]: + """ + Reconstruct line items from fields. + + :anchors: Possible fields to use as an anchor + :columns: All fields which are columns + :fields: List of field names to reconstruct table with + """ + line_items: List[Line] = [] + anchor = _find_best_anchor(anchors, fields) + if not anchor: + print(Warning("Could not find an anchor!")) + return line_items + + # Loop on anchor items and create an item for each anchor item. + # This will create all rows with just the anchor column value. + for item in fields[anchor].values: + line_item = Line() + line_item.fields = {f: _get_empty_field() for f in columns} + line_item.fields[anchor] = item + line_items.append(line_item) + + # Loop on all created rows + for idx, line in enumerate(line_items): + # Compute sliding window between anchor item and the next + min_y, _ = get_min_max_y(line.fields[anchor].polygon) + if idx != len(line_items) - 1: + max_y, _ = get_min_max_y(line_items[idx + 1].fields[anchor].polygon) + else: + max_y = 1.0 # bottom of page + # Get candidates of each field included in sliding window and add it in line item + for field in columns: + field_words = [ + word + for word in fields[field].values + if is_point_in_y(word.polygon.centroid, min_y, max_y) + ] + line.fields[field].content = " ".join([v.content for v in field_words]) + try: + line.fields[field].polygon = merge_polygons( + [v.polygon for v in field_words] + ) + except ValueError: + pass + line.fields[field].confidence = _array_product( + [v.confidence for v in field_words] + ) + all_polygons = [line.fields[anchor].polygon] + for field in columns: + try: + all_polygons.append(line.fields[field].polygon) + except IndexError: + pass + line.bounding_box = get_bounding_box(merge_polygons(all_polygons)) + line.row_number = idx + return line_items diff --git a/mindee/geometry.py b/mindee/geometry.py index b1b16a85..7f76149b 100644 --- a/mindee/geometry.py +++ b/mindee/geometry.py @@ -28,6 +28,11 @@ class Quadrilateral(NamedTuple): bottom_left: Point """Bottom left Point""" + @property + def centroid(self) -> Point: + """The central point (centroid) of the quadrilateral.""" + return get_centroid(self) + class BBox(NamedTuple): """Contains exactly 4 coordinates.""" @@ -73,6 +78,11 @@ class Polygon(list): Inherits from base class ``list`` so is compatible with type ``Points``. """ + @property + def centroid(self) -> Point: + """The central point (centroid) of the polygon.""" + return get_centroid(self) + Points = Sequence[Point] @@ -132,9 +142,9 @@ def get_bbox(points: Points) -> BBox: return BBox(x_min, y_min, x_max, y_max) -def get_bounding_box_for_polygons(vertices: Sequence[Polygon]) -> Quadrilateral: +def merge_polygons(vertices: Sequence[Polygon]) -> Polygon: """ - Given a sequence of polygons, calculate a bounding box that encompasses all polygons. + Given a sequence of polygons, calculate a polygon box that encompasses all polygons. :param vertices: List of polygons :return: A bounding box that encompasses all polygons @@ -143,11 +153,13 @@ def get_bounding_box_for_polygons(vertices: Sequence[Polygon]) -> Quadrilateral: y_max = max(y for v in vertices for _, y in v) x_min = min(x for v in vertices for x, _ in v) x_max = max(x for v in vertices for x, _ in v) - return Quadrilateral( - Point(x_min, y_min), - Point(x_max, y_min), - Point(x_max, y_max), - Point(x_min, y_max), + return Polygon( + [ + Point(x_min, y_min), + Point(x_max, y_min), + Point(x_max, y_max), + Point(x_min, y_max), + ] ) diff --git a/tests/documents/test_custom_v1_line_items.py b/tests/documents/test_custom_v1_line_items.py new file mode 100644 index 00000000..d592ffe8 --- /dev/null +++ b/tests/documents/test_custom_v1_line_items.py @@ -0,0 +1,31 @@ +import json + +from mindee.documents import CustomV1 +from mindee.documents.custom.line_items import get_line_items +from tests import CUSTOM_DATA_DIR + + +def test_single_table_01(): + json_data_path = f"{CUSTOM_DATA_DIR}/response_v1/line_items/single_table_01.json" + json_data = json.load(open(json_data_path, "r")) + doc = CustomV1( + "field_test", api_prediction=json_data["document"]["inference"], page_n=None + ) + anchors = ["beneficiary_birth_date"] + columns = [ + "beneficiary_name", + "beneficiary_birth_date", + "beneficiary_rank", + "beneficiary_number", + ] + line_items = get_line_items(anchors, columns, doc.fields) + assert len(line_items) == 3 + assert line_items[0].fields["beneficiary_name"].content == "JAMES BOND 007" + assert line_items[0].fields["beneficiary_birth_date"].content == "1970-11-11" + assert line_items[0].row_number == 0 + assert line_items[1].fields["beneficiary_name"].content == "HARRY POTTER" + assert line_items[1].fields["beneficiary_birth_date"].content == "2010-07-18" + assert line_items[1].row_number == 1 + assert line_items[2].fields["beneficiary_name"].content == "DRAGO MALFOY" + assert line_items[2].fields["beneficiary_birth_date"].content == "2015-07-05" + assert line_items[2].row_number == 2 diff --git a/tests/test_geometry.py b/tests/test_geometry.py index b6aee4ca..4c98d19b 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -87,9 +87,19 @@ def test_get_centroid(rectangle_a): def test_bounding_box_several_polygons(rectangle_b, quadrangle_a): - assert geometry.get_bounding_box_for_polygons((rectangle_b, quadrangle_a)) == ( + merged = geometry.merge_polygons((rectangle_b, quadrangle_a)) + assert geometry.get_bounding_box(merged) == ( (0.124, 0.407), (0.381, 0.407), (0.381, 0.546), (0.124, 0.546), ) + + +def test_polygon_merge(rectangle_b, quadrangle_a): + assert geometry.merge_polygons((rectangle_b, quadrangle_a)) == [ + (0.124, 0.407), + (0.381, 0.407), + (0.381, 0.546), + (0.124, 0.546), + ]