Skip to content

Commit f69dd4d

Browse files
committed
✨ add support for handling OCR return
1 parent c5316f8 commit f69dd4d

File tree

4 files changed

+159
-3
lines changed

4 files changed

+159
-3
lines changed

mindee/fields/ocr.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import List, Optional
2+
3+
from mindee.documents.base import TypeApiPrediction
4+
from mindee.geometry import (
5+
Polygon,
6+
get_centroid,
7+
get_min_max_x,
8+
get_min_max_y,
9+
is_point_in_polygon_y,
10+
polygon_from_prediction,
11+
)
12+
13+
14+
class Word:
    """
    One word found on a page.

    Built from a single word entry of the API prediction, which must provide
    the ``confidence``, ``polygon`` and ``text`` keys.
    """

    # Value of the prediction's ``confidence`` entry, stored unchanged.
    confidence: float
    # Polygon built from the prediction's ``polygon`` entry.
    polygon: Polygon
    # Value of the prediction's ``text`` entry, stored unchanged.
    text: str

    def __init__(self, prediction: TypeApiPrediction):
        """
        :param prediction: Raw prediction for a single word.
        """
        self.confidence = prediction["confidence"]
        raw_polygon = prediction["polygon"]
        self.polygon = polygon_from_prediction(raw_polygon)
        self.text = prediction["text"]

    def __str__(self) -> str:
        """Return the text of the word."""
        return self.text
28+
29+
30+
class OcrLine(List[Word]):
    """Words sharing the same line, stored as a list."""

    def sort_on_x(self) -> None:
        """Reorder the words in place, from left to right."""

        def left_edge(word: Word):
            # Smallest X value of the word's polygon.
            return get_min_max_x(word.polygon).min

        self.sort(key=left_edge)

    def __str__(self) -> str:
        """Return the text of every word, joined by single spaces."""
        return " ".join(word.text for word in self)
39+
40+
41+
class OcrPage:
    """OCR extraction for a single page."""

    all_words: List[Word]
    """All the words on the page, in semi-random order."""
    _lines: List[OcrLine]

    def __init__(self, prediction: TypeApiPrediction):
        """
        :param prediction: Raw prediction for one page; must contain an
            ``all_words`` list of word predictions.
        """
        self.all_words = [
            Word(word_prediction) for word_prediction in prediction["all_words"]
        ]
        # Filled on first access of ``all_lines``.
        self._lines = []

    @staticmethod
    def _are_words_on_same_line(current_word: Word, next_word: Word) -> bool:
        """
        Determine if two words are on the same line.

        :param current_word: Word used as the reference for the line.
        :param next_word: Candidate word to compare against the reference.
        :return: ``True`` when the centroid of either word passes the
            ``is_point_in_polygon_y`` check against the other word's polygon.
        """
        current_in_next = is_point_in_polygon_y(
            get_centroid(current_word.polygon),
            next_word.polygon,
        )
        next_in_current = is_point_in_polygon_y(
            get_centroid(next_word.polygon), current_word.polygon
        )
        # We need to check both to eliminate any issues due to word order.
        return current_in_next or next_in_current

    def _to_lines(self) -> List[OcrLine]:
        """
        Order all the words on the page into lines.

        Side effect: ``all_words`` is sorted in place, from top to bottom.

        :return: The lines, each one sorted from left to right.
        """
        lines: List[OcrLine] = []

        # make sure words are sorted from top to bottom
        self.all_words.sort(key=lambda item: get_min_max_y(item.polygon).min)

        # One flag per word instead of a list of visited indexes, so that the
        # "already on a line?" check is constant-time.
        placed = [False] * len(self.all_words)
        remaining = len(self.all_words)

        # Each pass builds exactly one line, so stop as soon as every word
        # has been placed rather than always looping once per word.
        while remaining:
            line = OcrLine()
            # The first unplaced word of the pass; every other unplaced word
            # is compared against it.
            anchor: Optional[Word] = None
            for idx, word in enumerate(self.all_words):
                if placed[idx]:
                    continue
                if anchor is None:
                    anchor = word
                elif not self._are_words_on_same_line(anchor, word):
                    continue
                line.append(word)
                placed[idx] = True
                remaining -= 1
            line.sort_on_x()
            lines.append(line)
        return lines

    @property
    def all_lines(self) -> List[OcrLine]:
        """All the words on the page, ordered in lines."""
        # Computed on first access and reused while the result is non-empty.
        if not self._lines:
            self._lines = self._to_lines()
        return self._lines

    def __str__(self) -> str:
        """Return one line of text per OCR line, with a trailing newline."""
        return "\n".join(str(line) for line in self.all_lines) + "\n"
107+
108+
109+
class MVisionV1:
    """Mindee Vision V1."""

    # One ``OcrPage`` per entry of the prediction's ``pages`` list.
    pages: List[OcrPage]

    def __init__(self, prediction: TypeApiPrediction):
        """
        :param prediction: Raw prediction; must contain a ``pages`` list.
        """
        page_predictions = prediction["pages"]
        self.pages = list(map(OcrPage, page_predictions))

    def __str__(self) -> str:
        """Return the string form of every page, joined by newlines."""
        return "\n".join(map(str, self.pages))
121+
122+
123+
class Ocr:
    """OCR extraction from the entire document."""

    # Built from the prediction's ``mvision-v1`` entry.
    mvision_v1: MVisionV1

    def __init__(self, prediction: TypeApiPrediction):
        """
        :param prediction: Raw OCR prediction; must contain a ``mvision-v1`` key.
        """
        mvision_prediction = prediction["mvision-v1"]
        self.mvision_v1 = MVisionV1(mvision_prediction)

    def __str__(self) -> str:
        """Return the string form of the ``mvision_v1`` results."""
        return str(self.mvision_v1)

mindee/response.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from enum import Enum
44
from typing import Any, Dict, Generic, List, Optional, Union
55

6-
from mindee.documents.base import TypeDocument
6+
from mindee.documents.base import TypeApiPrediction, TypeDocument
77
from mindee.documents.config import DocumentConfig
8+
from mindee.fields.ocr import Ocr
89
from mindee.input.sources import LocalInputSource, UrlInputSource
910
from mindee.logger import logger
1011

@@ -75,7 +76,7 @@ class PredictResponse(Generic[TypeDocument]):
7576
This is a generic class, so certain class properties depend on the document type.
7677
"""
7778

78-
http_response: Dict[str, Any]
79+
http_response: TypeApiPrediction
7980
"""Raw HTTP response JSON"""
8081
document_type: Optional[str] = None
8182
"""Document type"""
@@ -89,6 +90,8 @@ class PredictResponse(Generic[TypeDocument]):
8990
"""An instance of the ``Document`` class, according to the type given."""
9091
pages: List[TypeDocument]
9192
"""A list of instances of the ``Document`` class, according to the type given."""
93+
ocr: Optional[Ocr]
94+
"""Full OCR operation results."""
9295

9396
def __init__(
9497
self,
@@ -116,8 +119,17 @@ def __init__(
116119

117120
if not response_ok:
118121
self.document = None
122+
self.ocr = None
119123
else:
120124
self._load_response(doc_config, input_source)
125+
self.ocr = self._load_ocr(http_response)
126+
127+
@staticmethod
def _load_ocr(http_response: TypeApiPrediction) -> Optional[Ocr]:
    """
    Build the OCR results from the raw HTTP response, when present.

    :param http_response: Raw HTTP response JSON; must contain a
        ``document`` key.
    :return: An ``Ocr`` instance, or ``None`` when the ``ocr`` entry or its
        ``mvision-v1`` entry is missing or empty.
    """
    ocr_prediction = http_response["document"].get("ocr")
    if not ocr_prediction or not ocr_prediction.get("mvision-v1"):
        return None
    return Ocr(ocr_prediction)
121133

122134
def _load_response(
123135
self,

tests/fields/test_ocr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import json
2+
3+
from mindee.fields.ocr import Ocr
4+
5+
6+
def test_response():
    """Check the string output of the OCR results against the reference text."""
    # Open both files through context managers so the handles are always
    # closed, and read them with an explicit encoding.
    with open(
        "./tests/data/ocr/complete_with_ocr.json", encoding="utf-8"
    ) as json_handle:
        json_data = json.load(json_handle)
    with open("./tests/data/ocr/ocr.txt", encoding="utf-8") as file_handle:
        expected_text = file_handle.read()
    ocr = Ocr(json_data["document"]["ocr"])
    assert str(ocr) == expected_text
    assert str(ocr.mvision_v1.pages[0]) == expected_text

0 commit comments

Comments
 (0)