diff --git a/mindee/fields/ocr.py b/mindee/fields/ocr.py
new file mode 100644
index 00000000..2f8a7ac8
--- /dev/null
+++ b/mindee/fields/ocr.py
@@ -0,0 +1,131 @@
+from typing import List, Optional
+
+from mindee.documents.base import TypeApiPrediction
+from mindee.geometry import (
+    Polygon,
+    get_centroid,
+    get_min_max_x,
+    get_min_max_y,
+    is_point_in_polygon_y,
+    polygon_from_prediction,
+)
+
+
+class Word:
+    """A single word."""
+
+    confidence: float
+    polygon: Polygon
+    text: str
+
+    def __init__(self, prediction: TypeApiPrediction):
+        self.confidence = prediction["confidence"]
+        self.polygon = polygon_from_prediction(prediction["polygon"])
+        self.text = prediction["text"]
+
+    def __str__(self) -> str:
+        return self.text
+
+
+class OcrLine(List[Word]):
+    """A list of words which are on the same line."""
+
+    def sort_on_x(self) -> None:
+        """Sort the words on the line from left to right."""
+        self.sort(key=lambda item: get_min_max_x(item.polygon).min)
+
+    def __str__(self) -> str:
+        return " ".join([word.text for word in self])
+
+
+class OcrPage:
+    """OCR extraction for a single page."""
+
+    all_words: List[Word]
+    """All the words on the page, in semi-random order."""
+    _lines: Optional[List[OcrLine]]
+
+    def __init__(self, prediction: TypeApiPrediction):
+        self.all_words = [
+            Word(word_prediction) for word_prediction in prediction["all_words"]
+        ]
+        # ``None`` means the lines have not been computed yet.
+        self._lines = None
+
+    @staticmethod
+    def _are_words_on_same_line(current_word: Word, next_word: Word) -> bool:
+        """Determine if two words are on the same line."""
+        current_in_next = is_point_in_polygon_y(
+            get_centroid(current_word.polygon),
+            next_word.polygon,
+        )
+        next_in_current = is_point_in_polygon_y(
+            get_centroid(next_word.polygon), current_word.polygon
+        )
+        # We need to check both to eliminate any issues due to word order.
+        return current_in_next or next_in_current
+
+    def _to_lines(self) -> List[OcrLine]:
+        """
+        Order all the words on the page into lines.
+
+        As a side effect, ``all_words`` is sorted in place from top to bottom.
+        """
+        # Make sure words are sorted from top to bottom.
+        self.all_words.sort(key=lambda item: get_min_max_y(item.polygon).min)
+
+        lines: List[OcrLine] = []
+        # Flags the words which have already been assigned to a line.
+        used = [False] * len(self.all_words)
+        for idx, first_word in enumerate(self.all_words):
+            if used[idx]:
+                continue
+            # The topmost unassigned word starts a new line ...
+            used[idx] = True
+            line = OcrLine([first_word])
+            # ... which collects every later unassigned word level with it.
+            for next_idx, next_word in enumerate(self.all_words[idx + 1 :], idx + 1):
+                if used[next_idx]:
+                    continue
+                if self._are_words_on_same_line(first_word, next_word):
+                    used[next_idx] = True
+                    line.append(next_word)
+            line.sort_on_x()
+            lines.append(line)
+        return lines
+
+    @property
+    def all_lines(self) -> List[OcrLine]:
+        """All the words on the page, ordered in lines."""
+        if self._lines is None:
+            self._lines = self._to_lines()
+        return self._lines
+
+    def __str__(self) -> str:
+        return "\n".join(str(line) for line in self.all_lines) + "\n"
+
+
+class MVisionV1:
+    """Mindee Vision V1."""
+
+    pages: List[OcrPage]
+
+    def __init__(self, prediction: TypeApiPrediction):
+        self.pages = [
+            OcrPage(page_prediction) for page_prediction in prediction["pages"]
+        ]
+
+    def __str__(self) -> str:
+        return "\n".join([str(page) for page in self.pages])
+
+
+class Ocr:
+    """OCR extraction from the entire document."""
+
+    mvision_v1: MVisionV1
+
+    def __init__(self, prediction: TypeApiPrediction):
+        self.mvision_v1 = MVisionV1(prediction["mvision-v1"])
+
+    def __str__(self) -> str:
+        return str(self.mvision_v1)
diff --git a/mindee/response.py b/mindee/response.py
index 39886419..422dd1a1 100644
--- a/mindee/response.py
+++ b/mindee/response.py
@@ -3,8 +3,9 @@
 from enum import Enum
 from typing import Any, Dict, Generic, List, Optional, Union
 
-from mindee.documents.base import TypeDocument
+from mindee.documents.base import TypeApiPrediction, TypeDocument
 from mindee.documents.config import DocumentConfig
+from mindee.fields.ocr import Ocr
 from mindee.input.sources import LocalInputSource, UrlInputSource
 from mindee.logger import logger
 
@@ -75,7 +76,7 @@ class PredictResponse(Generic[TypeDocument]):
     This is a generic class, so certain class properties depend on the document type.
     """
 
-    http_response: Dict[str, Any]
+    http_response: TypeApiPrediction
     """Raw HTTP response JSON"""
     document_type: Optional[str] = None
     """Document type"""
@@ -89,6 +90,8 @@ class PredictResponse(Generic[TypeDocument]):
     """An instance of the ``Document`` class, according to the type given."""
     pages: List[TypeDocument]
     """A list of instances of the ``Document`` class, according to the type given."""
+    ocr: Optional[Ocr]
+    """Full OCR operation results."""
 
     def __init__(
         self,
@@ -116,8 +119,18 @@ def __init__(
 
         if not response_ok:
             self.document = None
+            self.ocr = None
         else:
             self._load_response(doc_config, input_source)
+            self.ocr = self._load_ocr(http_response)
+
+    @staticmethod
+    def _load_ocr(http_response: TypeApiPrediction) -> Optional[Ocr]:
+        """Load the OCR results, or ``None`` when the response has none."""
+        ocr_prediction = http_response["document"].get("ocr")
+        if not ocr_prediction or not ocr_prediction.get("mvision-v1"):
+            return None
+        return Ocr(ocr_prediction)
 
     def _load_response(
         self,
diff --git a/tests/data b/tests/data
index 7c439f63..b4f99571 160000
--- a/tests/data
+++ b/tests/data
@@ -1 +1 @@
-Subproject commit 7c439f6329029a0108eb3158bc1b1ccd7eddd541
+Subproject commit b4f99571aa7969b627cdf8ba630a39609ff11e7e
diff --git a/tests/fields/test_ocr.py b/tests/fields/test_ocr.py
new file mode 100644
index 00000000..0d6649a0
--- /dev/null
+++ b/tests/fields/test_ocr.py
@@ -0,0 +1,13 @@
+import json
+
+from mindee.fields.ocr import Ocr
+
+
+def test_response():
+    with open("./tests/data/ocr/complete_with_ocr.json") as file_handle:
+        json_data = json.load(file_handle)
+    with open("./tests/data/ocr/ocr.txt") as file_handle:
+        expected_text = file_handle.read()
+    ocr = Ocr(json_data["document"]["ocr"])
+    assert str(ocr) == expected_text
+    assert str(ocr.mvision_v1.pages[0]) == expected_text