Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions mindee/fields/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import List, Optional

from mindee.documents.base import TypeApiPrediction
from mindee.geometry import (
Polygon,
get_centroid,
get_min_max_x,
get_min_max_y,
is_point_in_polygon_y,
polygon_from_prediction,
)


class Word:
    """One word found by the OCR, with its text, polygon and confidence."""

    confidence: float
    polygon: Polygon
    text: str

    def __init__(self, prediction: TypeApiPrediction):
        """
        Populate the word from a raw prediction.

        The prediction is read for its ``confidence``, ``polygon`` and
        ``text`` keys; the raw polygon is converted with
        ``polygon_from_prediction``.
        """
        confidence = prediction["confidence"]
        raw_polygon = prediction["polygon"]
        self.confidence = confidence
        self.polygon = polygon_from_prediction(raw_polygon)
        self.text = prediction["text"]

    def __str__(self) -> str:
        """Return the text of the word."""
        return self.text


class OcrLine(List[Word]):
    """Words located on the same line, kept in a plain list."""

    def sort_on_x(self) -> None:
        """Reorder the words of the line in place, from left to right."""

        def left_edge(word: Word):
            # Smallest X coordinate of the word's polygon.
            return get_min_max_x(word.polygon).min

        self.sort(key=left_edge)

    def __str__(self) -> str:
        """Return the text of every word, separated by single spaces."""
        texts = []
        for word in self:
            texts.append(word.text)
        return " ".join(texts)


class OcrPage:
    """OCR extraction for a single page."""

    all_words: List[Word]
    """All the words on the page, in semi-random order."""
    _lines: List[OcrLine]

    def __init__(self, prediction: TypeApiPrediction):
        """
        Build the page from its raw prediction.

        One ``Word`` is created for each entry under the ``all_words`` key.
        Lines are not computed here, only on first access to ``all_lines``.
        """
        self.all_words = [
            Word(word_prediction) for word_prediction in prediction["all_words"]
        ]
        self._lines = []

    @staticmethod
    def _are_words_on_same_line(current_word: Word, next_word: Word) -> bool:
        """Determine if two words are on the same line."""
        current_in_next = is_point_in_polygon_y(
            get_centroid(current_word.polygon),
            next_word.polygon,
        )
        next_in_current = is_point_in_polygon_y(
            get_centroid(next_word.polygon), current_word.polygon
        )
        # We need to check both to eliminate any issues due to word order.
        return current_in_next or next_in_current

    def _to_lines(self) -> List[OcrLine]:
        """
        Order all the words on the page into lines.

        ``all_words`` is sorted in place, top to bottom, as a side effect.
        Each pass takes the first word not yet placed as the reference for a
        new line, then adds every remaining unplaced word that is on the same
        line as that reference. Each line is sorted left to right before
        being stored.
        """
        current: Optional[Word] = None
        lines: List[OcrLine] = []

        # make sure words are sorted from top to bottom
        self.all_words.sort(
            key=lambda item: get_min_max_y(item.polygon).min, reverse=False
        )

        # One flag per word: a constant-time "already placed" check, instead
        # of searching a growing list of indexes for every word visited.
        placed = [False] * len(self.all_words)

        for _ in self.all_words:
            line: OcrLine = OcrLine()
            for idx, word in enumerate(self.all_words):
                if placed[idx]:
                    continue
                if current is None:
                    # First unplaced word: it becomes the line's reference.
                    current = word
                    placed[idx] = True
                    line.append(word)
                elif self._are_words_on_same_line(current, word):
                    line.append(word)
                    placed[idx] = True
            current = None
            if not line:
                # An empty pass means every word is already placed, so any
                # further pass would be empty as well.
                break
            line.sort_on_x()
            lines.append(line)
        return lines

    @property
    def all_lines(self) -> List[OcrLine]:
        """All the words on the page, ordered in lines."""
        # Computed on first access and reused while the result is non-empty.
        if not self._lines:
            self._lines = self._to_lines()
        return self._lines

    def __str__(self) -> str:
        """Return one text line per OCR line, with a trailing newline."""
        return "\n".join(str(line) for line in self.all_lines) + "\n"


class MVisionV1:
    """Mindee Vision V1: the OCR results, page by page."""

    pages: List[OcrPage]

    def __init__(self, prediction: TypeApiPrediction):
        """Create one ``OcrPage`` per entry under the ``pages`` key."""
        built_pages: List[OcrPage] = []
        for page_prediction in prediction["pages"]:
            built_pages.append(OcrPage(page_prediction))
        self.pages = built_pages

    def __str__(self) -> str:
        """Return the text of every page, joined with newlines."""
        return "\n".join(map(str, self.pages))


class Ocr:
    """OCR results for the whole document."""

    mvision_v1: MVisionV1

    def __init__(self, prediction: TypeApiPrediction):
        """Wrap the content of the ``mvision-v1`` key in a ``MVisionV1``."""
        mvision_prediction = prediction["mvision-v1"]
        self.mvision_v1 = MVisionV1(mvision_prediction)

    def __str__(self) -> str:
        """Return the string form of the ``mvision_v1`` results."""
        text = str(self.mvision_v1)
        return text
16 changes: 14 additions & 2 deletions mindee/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, Union

from mindee.documents.base import TypeDocument
from mindee.documents.base import TypeApiPrediction, TypeDocument
from mindee.documents.config import DocumentConfig
from mindee.fields.ocr import Ocr
from mindee.input.sources import LocalInputSource, UrlInputSource
from mindee.logger import logger

Expand Down Expand Up @@ -75,7 +76,7 @@ class PredictResponse(Generic[TypeDocument]):
This is a generic class, so certain class properties depend on the document type.
"""

http_response: Dict[str, Any]
http_response: TypeApiPrediction
"""Raw HTTP response JSON"""
document_type: Optional[str] = None
"""Document type"""
Expand All @@ -89,6 +90,8 @@ class PredictResponse(Generic[TypeDocument]):
"""An instance of the ``Document`` class, according to the type given."""
pages: List[TypeDocument]
"""A list of instances of the ``Document`` class, according to the type given."""
ocr: Optional[Ocr]
"""Full OCR operation results."""

def __init__(
self,
Expand Down Expand Up @@ -116,8 +119,17 @@ def __init__(

if not response_ok:
self.document = None
self.ocr = None
else:
self._load_response(doc_config, input_source)
self.ocr = self._load_ocr(http_response)

@staticmethod
def _load_ocr(http_response: TypeApiPrediction):
    """
    Build the ``Ocr`` object from the raw HTTP response, when present.

    Returns ``None`` when the ``document`` entry has no ``ocr`` value, or
    when that value has no ``mvision-v1`` content.
    """
    ocr_prediction = http_response["document"].get("ocr")
    if ocr_prediction and ocr_prediction.get("mvision-v1"):
        return Ocr(ocr_prediction)
    return None

def _load_response(
self,
Expand Down
12 changes: 12 additions & 0 deletions tests/fields/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import json

from mindee.fields.ocr import Ocr


def test_response():
    """Check the string output of ``Ocr`` against the expected text file."""
    # Both fixtures are opened in ``with`` blocks so the handles are closed,
    # and with an explicit encoding so the comparison does not depend on the
    # platform's default.
    with open("./tests/data/ocr/complete_with_ocr.json", encoding="utf-8") as handle:
        json_data = json.load(handle)
    with open("./tests/data/ocr/ocr.txt", encoding="utf-8") as file_handle:
        expected_text = file_handle.read()
    ocr = Ocr(json_data["document"]["ocr"])
    assert str(ocr) == expected_text
    assert str(ocr.mvision_v1.pages[0]) == expected_text