Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions mindee/documents/custom/line_items.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Dict, List, Sequence

from mindee.documents.custom.custom_v1_fields import ListField, ListFieldValue
from mindee.geometry import (
Quadrilateral,
get_bounding_box,
get_min_max_y,
is_point_in_y,
merge_polygons,
)


def _array_product(array: Sequence[float]) -> float:
"""
Get the product of a sequence of floats.

:array: List of floats
"""
product = 1.0
for k in array:
product = product * k
return product


def _find_best_anchor(anchors: Sequence[str], fields: Dict[str, ListField]) -> str:
"""
Find the anchor with the most rows, in the order specified by `anchors`.

Anchor will be the name of the field.
"""
anchor = ""
anchor_rows = 0
for field in anchors:
values = fields[field].values
if len(values) > anchor_rows:
anchor_rows = len(values)
anchor = field
return anchor


def _get_empty_field() -> ListFieldValue:
"""Return sample field with empty values."""
return ListFieldValue({"content": "", "polygon": [], "confidence": 0.0})


class Line:
"""Represent a single line."""

row_number: int
fields: Dict[str, ListFieldValue]
bounding_box: Quadrilateral


def get_line_items(
anchors: Sequence[str], columns: Sequence[str], fields: Dict[str, ListField]
) -> List[Line]:
"""
Reconstruct line items from fields.

:anchors: Possible fields to use as an anchor
:columns: All fields which are columns
:fields: List of field names to reconstruct table with
"""
line_items: List[Line] = []
anchor = _find_best_anchor(anchors, fields)
if not anchor:
print(Warning("Could not find an anchor!"))
return line_items

# Loop on anchor items and create an item for each anchor item.
# This will create all rows with just the anchor column value.
for item in fields[anchor].values:
line_item = Line()
line_item.fields = {f: _get_empty_field() for f in columns}
line_item.fields[anchor] = item
line_items.append(line_item)

# Loop on all created rows
for idx, line in enumerate(line_items):
# Compute sliding window between anchor item and the next
min_y, _ = get_min_max_y(line.fields[anchor].polygon)
if idx != len(line_items) - 1:
max_y, _ = get_min_max_y(line_items[idx + 1].fields[anchor].polygon)
else:
max_y = 1.0 # bottom of page
# Get candidates of each field included in sliding window and add it in line item
for field in columns:
field_words = [
word
for word in fields[field].values
if is_point_in_y(word.polygon.centroid, min_y, max_y)
]
line.fields[field].content = " ".join([v.content for v in field_words])
try:
line.fields[field].polygon = merge_polygons(
[v.polygon for v in field_words]
)
except ValueError:
pass
line.fields[field].confidence = _array_product(
[v.confidence for v in field_words]
)
all_polygons = [line.fields[anchor].polygon]
for field in columns:
try:
all_polygons.append(line.fields[field].polygon)
except IndexError:
pass
line.bounding_box = get_bounding_box(merge_polygons(all_polygons))
line.row_number = idx
return line_items
26 changes: 19 additions & 7 deletions mindee/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class Quadrilateral(NamedTuple):
bottom_left: Point
"""Bottom left Point"""

@property
def centroid(self) -> Point:
"""The central point (centroid) of the quadrilateral."""
return get_centroid(self)


class BBox(NamedTuple):
"""Contains exactly 4 coordinates."""
Expand Down Expand Up @@ -73,6 +78,11 @@ class Polygon(list):
Inherits from base class ``list`` so is compatible with type ``Points``.
"""

@property
def centroid(self) -> Point:
"""The central point (centroid) of the polygon."""
return get_centroid(self)


Points = Sequence[Point]

Expand Down Expand Up @@ -132,9 +142,9 @@ def get_bbox(points: Points) -> BBox:
return BBox(x_min, y_min, x_max, y_max)


def get_bounding_box_for_polygons(vertices: Sequence[Polygon]) -> Quadrilateral:
def merge_polygons(vertices: Sequence[Polygon]) -> Polygon:
"""
Given a sequence of polygons, calculate a bounding box that encompasses all polygons.
Given a sequence of polygons, calculate a polygon box that encompasses all polygons.

:param vertices: List of polygons
:return: A bounding box that encompasses all polygons
Expand All @@ -143,11 +153,13 @@ def get_bounding_box_for_polygons(vertices: Sequence[Polygon]) -> Quadrilateral:
y_max = max(y for v in vertices for _, y in v)
x_min = min(x for v in vertices for x, _ in v)
x_max = max(x for v in vertices for x, _ in v)
return Quadrilateral(
Point(x_min, y_min),
Point(x_max, y_min),
Point(x_max, y_max),
Point(x_min, y_max),
return Polygon(
[
Point(x_min, y_min),
Point(x_max, y_min),
Point(x_max, y_max),
Point(x_min, y_max),
]
)


Expand Down
31 changes: 31 additions & 0 deletions tests/documents/test_custom_v1_line_items.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import json

from mindee.documents import CustomV1
from mindee.documents.custom.line_items import get_line_items
from tests import CUSTOM_DATA_DIR


def test_single_table_01():
json_data_path = f"{CUSTOM_DATA_DIR}/response_v1/line_items/single_table_01.json"
json_data = json.load(open(json_data_path, "r"))
doc = CustomV1(
"field_test", api_prediction=json_data["document"]["inference"], page_n=None
)
anchors = ["beneficiary_birth_date"]
columns = [
"beneficiary_name",
"beneficiary_birth_date",
"beneficiary_rank",
"beneficiary_number",
]
line_items = get_line_items(anchors, columns, doc.fields)
assert len(line_items) == 3
assert line_items[0].fields["beneficiary_name"].content == "JAMES BOND 007"
assert line_items[0].fields["beneficiary_birth_date"].content == "1970-11-11"
assert line_items[0].row_number == 0
assert line_items[1].fields["beneficiary_name"].content == "HARRY POTTER"
assert line_items[1].fields["beneficiary_birth_date"].content == "2010-07-18"
assert line_items[1].row_number == 1
assert line_items[2].fields["beneficiary_name"].content == "DRAGO MALFOY"
assert line_items[2].fields["beneficiary_birth_date"].content == "2015-07-05"
assert line_items[2].row_number == 2
12 changes: 11 additions & 1 deletion tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,19 @@ def test_get_centroid(rectangle_a):


def test_bounding_box_several_polygons(rectangle_b, quadrangle_a):
assert geometry.get_bounding_box_for_polygons((rectangle_b, quadrangle_a)) == (
merged = geometry.merge_polygons((rectangle_b, quadrangle_a))
assert geometry.get_bounding_box(merged) == (
(0.124, 0.407),
(0.381, 0.407),
(0.381, 0.546),
(0.124, 0.546),
)


def test_polygon_merge(rectangle_b, quadrangle_a):
assert geometry.merge_polygons((rectangle_b, quadrangle_a)) == [
(0.124, 0.407),
(0.381, 0.407),
(0.381, 0.546),
(0.124, 0.546),
]