Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 17 additions & 50 deletions mindee/client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from mindee.error.mindee_http_error_v2 import handle_error_v2
from mindee.input.inference_predict_options import InferencePredictOptions
from mindee.input.local_response import LocalResponse
from mindee.input.page_options import PageOptions
from mindee.input.polling_options_v2 import PollingOptionsV2
from mindee.input.polling_options import PollingOptions
from mindee.input.sources.local_input_source import LocalInputSource
from mindee.logger import logger
from mindee.mindee_http.mindee_api_v2 import MindeeApiV2
Expand Down Expand Up @@ -39,11 +38,7 @@ def __init__(self, api_key: Optional[str] = None) -> None:
self.mindee_api = MindeeApiV2(api_key)

def enqueue(
self,
input_source: LocalInputSource,
options: InferencePredictOptions,
page_options: Optional[PageOptions] = None,
close_file: bool = True,
self, input_source: LocalInputSource, options: InferencePredictOptions
) -> PollingResponse:
"""
Enqueues a document to a given model.
Expand All @@ -52,28 +47,19 @@ def enqueue(
Has to be created beforehand.

:param options: Options for the prediction.

:param close_file: Whether to ``close()`` the file after parsing it.
Set to ``False`` if you need to access the file after this operation.

:param page_options: If set, remove pages from the document as specified.
This is done before sending the file to the server.
It is useful to avoid page limitations.
:return: A valid inference response.
"""
logger.debug("Enqueuing document to '%s'", options.model_id)

if page_options and input_source.is_pdf():
if options.page_options and input_source.is_pdf():
input_source.process_pdf(
page_options.operation,
page_options.on_min_pages,
page_options.page_indexes,
options.page_options.operation,
options.page_options.on_min_pages,
options.page_options.page_indexes,
)

response = self.mindee_api.predict_async_req_post(
input_source=input_source,
options=options,
close_file=close_file,
input_source=input_source, options=options
)
dict_response = response.json()

Expand Down Expand Up @@ -103,12 +89,7 @@ def parse_queued(
return InferenceResponse(dict_response)

def enqueue_and_parse(
self,
input_source: LocalInputSource,
options: InferencePredictOptions,
polling_options: Optional[PollingOptionsV2] = None,
page_options: Optional[PageOptions] = None,
close_file: bool = True,
self, input_source: LocalInputSource, options: InferencePredictOptions
) -> InferenceResponse:
"""
Enqueues to an asynchronous endpoint and automatically polls for a response.
Expand All @@ -118,39 +99,25 @@ def enqueue_and_parse(

:param options: Options for the prediction.

:param polling_options: Options for polling.

:param close_file: Whether to ``close()`` the file after parsing it.
Set to ``False`` if you need to access the file after this operation.

:param page_options: If set, remove pages from the document as specified.
This is done before sending the file to the server.
It is useful to avoid page limitations.

:return: A valid inference response.
"""
if not polling_options:
polling_options = PollingOptionsV2()
if not options.polling_options:
options.polling_options = PollingOptions()
self._validate_async_params(
polling_options.initial_delay_sec,
polling_options.delay_sec,
polling_options.max_retries,
)
queue_result = self.enqueue(
input_source,
options,
page_options,
close_file,
options.polling_options.initial_delay_sec,
options.polling_options.delay_sec,
options.polling_options.max_retries,
)
queue_result = self.enqueue(input_source, options)
logger.debug(
"Successfully enqueued document with job id: %s", queue_result.job.id
)
sleep(polling_options.initial_delay_sec)
sleep(options.polling_options.initial_delay_sec)
retry_counter = 1
poll_results = self.parse_queued(
queue_result.job.id,
)
while retry_counter < polling_options.max_retries:
while retry_counter < options.polling_options.max_retries:
if not isinstance(poll_results, PollingResponse):
break
if poll_results.job.status == "Failed":
Expand All @@ -160,7 +127,7 @@ def enqueue_and_parse(
queue_result.job.id,
)
retry_counter += 1
sleep(polling_options.delay_sec)
sleep(options.polling_options.delay_sec)
poll_results = self.parse_queued(queue_result.job.id)

if not isinstance(poll_results, InferenceResponse):
Expand Down
2 changes: 1 addition & 1 deletion mindee/input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mindee.input.inference_predict_options import InferencePredictOptions
from mindee.input.local_response import LocalResponse
from mindee.input.page_options import PageOptions
from mindee.input.polling_options_v2 import PollingOptionsV2
from mindee.input.polling_options import PollingOptions
from mindee.input.sources.base_64_input import Base64Input
from mindee.input.sources.bytes_input import BytesInput
from mindee.input.sources.file_input import FileInput
Expand Down
9 changes: 9 additions & 0 deletions mindee/input/inference_predict_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import List, Optional

from mindee.input.page_options import PageOptions
from mindee.input.polling_options import PollingOptions


@dataclass
class InferencePredictOptions:
Expand All @@ -19,3 +22,9 @@ class InferencePredictOptions:
"""Optional alias for the file."""
webhook_ids: Optional[List[str]] = None
"""IDs of webhooks to propagate the API response to."""
page_options: Optional[PageOptions] = None
"""Options for page-level inference."""
polling_options: Optional[PollingOptions] = None
"""Options for polling."""
close_file: bool = True
"""Whether to close the file after parsing."""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class PollingOptionsV2:
class PollingOptions:
"""Options for asynchronous polling."""

initial_delay_sec: float
Expand Down
8 changes: 2 additions & 6 deletions mindee/mindee_http/mindee_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,13 @@ def set_from_env(self) -> None:
logger.debug("Value was set from env: %s", name)

def predict_async_req_post(
self,
input_source: LocalInputSource,
options: InferencePredictOptions,
close_file: bool = True,
self, input_source: LocalInputSource, options: InferencePredictOptions
) -> requests.Response:
"""
Make an asynchronous request to POST a document for prediction on the V2 API.

:param input_source: Input object.
:param options: Options for the enqueueing of the document.
:param close_file: Whether to `close()` the file after parsing it.
:return: requests response.
"""
data = {"model_id": options.model_id}
Expand All @@ -93,7 +89,7 @@ def predict_async_req_post(
if options.alias and len(options.alias):
data["alias"] = options.alias

files = {"file": input_source.read_contents(close_file)}
files = {"file": input_source.read_contents(options.close_file)}
response = requests.post(
url=url,
files=files,
Expand Down
8 changes: 6 additions & 2 deletions mindee/parsing/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from mindee.parsing.v2.base_field import ListField, ObjectField, SimpleField
from mindee.parsing.v2.base_field import (
InferenceFields,
ListField,
ObjectField,
SimpleField,
)
from mindee.parsing.v2.common_response import CommonResponse
from mindee.parsing.v2.error_response import ErrorResponse
from mindee.parsing.v2.inference import Inference
from mindee.parsing.v2.inference_fields import InferenceFields
from mindee.parsing.v2.inference_file import InferenceFile
from mindee.parsing.v2.inference_model import InferenceModel
from mindee.parsing.v2.inference_options import InferenceOptions
Expand Down
33 changes: 24 additions & 9 deletions mindee/parsing/v2/base_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,28 @@ def create_field(
raise MindeeApiV2Error(f"Unrecognized field format {raw_response}.")


class InferenceFields(Dict[str, Union["SimpleField", "ObjectField", "ListField"]]):
"""Inference fields dict."""

def __init__(self, raw_response: StringDict, indent_level: int = 0) -> None:
super().__init__()
for key, value in raw_response.items():
field_obj = BaseField.create_field(value, indent_level)
self[key] = field_obj

def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(item) from None

def __str__(self) -> str:
str_fields = ""
for field_key, field_value in self.items():
str_fields += f":{field_key}: {field_value}"
return str_fields


class ListField(BaseField):
"""List field containing multiple fields."""

Expand All @@ -55,21 +77,14 @@ def __str__(self) -> str:
class ObjectField(BaseField):
"""Object field containing multiple fields."""

fields: Dict[str, Union[ListField, "ObjectField", "SimpleField"]]
fields: InferenceFields
"""Fields contained in the object."""

def __init__(self, raw_response: StringDict, indent_level: int = 0):
super().__init__(indent_level)
inner_fields = raw_response.get("fields", raw_response)

self.fields: Dict[str, Union["ListField", "ObjectField", "SimpleField"]] = {}
for field_key, field_value in inner_fields.items():
if isinstance(field_value, dict):
self.fields[field_key] = BaseField.create_field(
field_value, self._indent_level + 1
)
else:
raise MindeeApiV2Error(f"Unrecognized field format '{field_value}'.")
self.fields = InferenceFields(inner_fields, self._indent_level + 1)

def __str__(self) -> str:
out_str = ""
Expand Down
28 changes: 0 additions & 28 deletions mindee/parsing/v2/inference_fields.py

This file was deleted.

2 changes: 1 addition & 1 deletion mindee/parsing/v2/inference_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from mindee.parsing.common.string_dict import StringDict
from mindee.parsing.v2.inference_fields import InferenceFields
from mindee.parsing.v2.base_field import InferenceFields
from mindee.parsing.v2.inference_options import InferenceOptions


Expand Down
Empty file.
Empty file.
Empty file.
4 changes: 4 additions & 0 deletions tests/v2/test_inference_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def test_full_inference_response():
assert load_response.inference.result.fields.date.value == "2019-11-02"
assert isinstance(load_response.inference.result.fields.taxes, ListField)
assert isinstance(load_response.inference.result.fields.taxes.items[0], ObjectField)
assert (
load_response.inference.result.fields.customer_address.fields.city.value
== "New York"
)
assert (
load_response.inference.result.fields.taxes.items[0].fields["base"].value
== 31.5
Expand Down