diff --git a/mindee/client_v2.py b/mindee/client_v2.py
index 41e10139..6d1e6890 100644
--- a/mindee/client_v2.py
+++ b/mindee/client_v2.py
@@ -6,8 +6,7 @@
 from mindee.error.mindee_http_error_v2 import handle_error_v2
 from mindee.input.inference_predict_options import InferencePredictOptions
 from mindee.input.local_response import LocalResponse
-from mindee.input.page_options import PageOptions
-from mindee.input.polling_options_v2 import PollingOptionsV2
+from mindee.input.polling_options import PollingOptions
 from mindee.input.sources.local_input_source import LocalInputSource
 from mindee.logger import logger
 from mindee.mindee_http.mindee_api_v2 import MindeeApiV2
@@ -39,11 +38,7 @@ def __init__(self, api_key: Optional[str] = None) -> None:
         self.mindee_api = MindeeApiV2(api_key)
 
     def enqueue(
-        self,
-        input_source: LocalInputSource,
-        options: InferencePredictOptions,
-        page_options: Optional[PageOptions] = None,
-        close_file: bool = True,
+        self, input_source: LocalInputSource, options: InferencePredictOptions
     ) -> PollingResponse:
         """
         Enqueues a document to a given model.
@@ -52,28 +47,19 @@
             Has to be created beforehand.
 
         :param options: Options for the prediction.
-
-        :param close_file: Whether to ``close()`` the file after parsing it.
-            Set to ``False`` if you need to access the file after this operation.
-
-        :param page_options: If set, remove pages from the document as specified.
-            This is done before sending the file to the server.
-            It is useful to avoid page limitations.
 
         :return: A valid inference response.
         """
         logger.debug("Enqueuing document to '%s'", options.model_id)
 
-        if page_options and input_source.is_pdf():
+        if options.page_options and input_source.is_pdf():
             input_source.process_pdf(
-                page_options.operation,
-                page_options.on_min_pages,
-                page_options.page_indexes,
+                options.page_options.operation,
+                options.page_options.on_min_pages,
+                options.page_options.page_indexes,
             )
         response = self.mindee_api.predict_async_req_post(
-            input_source=input_source,
-            options=options,
-            close_file=close_file,
+            input_source=input_source, options=options
         )
 
         dict_response = response.json()
@@ -103,12 +89,7 @@ def parse_queued(
         return InferenceResponse(dict_response)
 
     def enqueue_and_parse(
-        self,
-        input_source: LocalInputSource,
-        options: InferencePredictOptions,
-        polling_options: Optional[PollingOptionsV2] = None,
-        page_options: Optional[PageOptions] = None,
-        close_file: bool = True,
+        self, input_source: LocalInputSource, options: InferencePredictOptions
     ) -> InferenceResponse:
         """
         Enqueues to an asynchronous endpoint and automatically polls for a response.
@@ -118,39 +99,25 @@
         :param options: Options for the prediction.
 
-        :param polling_options: Options for polling.
-
-        :param close_file: Whether to ``close()`` the file after parsing it.
-            Set to ``False`` if you need to access the file after this operation.
-
-        :param page_options: If set, remove pages from the document as specified.
-            This is done before sending the file to the server.
-            It is useful to avoid page limitations.
-
         :return: A valid inference response.
         """
-        if not polling_options:
-            polling_options = PollingOptionsV2()
+        if not options.polling_options:
+            options.polling_options = PollingOptions()
         self._validate_async_params(
-            polling_options.initial_delay_sec,
-            polling_options.delay_sec,
-            polling_options.max_retries,
-        )
-        queue_result = self.enqueue(
-            input_source,
-            options,
-            page_options,
-            close_file,
+            options.polling_options.initial_delay_sec,
+            options.polling_options.delay_sec,
+            options.polling_options.max_retries,
         )
+        queue_result = self.enqueue(input_source, options)
         logger.debug(
             "Successfully enqueued document with job id: %s", queue_result.job.id
         )
-        sleep(polling_options.initial_delay_sec)
+        sleep(options.polling_options.initial_delay_sec)
         retry_counter = 1
         poll_results = self.parse_queued(
             queue_result.job.id,
         )
 
-        while retry_counter < polling_options.max_retries:
+        while retry_counter < options.polling_options.max_retries:
             if not isinstance(poll_results, PollingResponse):
                 break
             if poll_results.job.status == "Failed":
@@ -160,7 +127,7 @@
                 queue_result.job.id,
             )
             retry_counter += 1
-            sleep(polling_options.delay_sec)
+            sleep(options.polling_options.delay_sec)
             poll_results = self.parse_queued(queue_result.job.id)
 
         if not isinstance(poll_results, InferenceResponse):
diff --git a/mindee/input/__init__.py b/mindee/input/__init__.py
index b10dece5..ebeffdf2 100644
--- a/mindee/input/__init__.py
+++ b/mindee/input/__init__.py
@@ -1,7 +1,7 @@
 from mindee.input.inference_predict_options import InferencePredictOptions
 from mindee.input.local_response import LocalResponse
 from mindee.input.page_options import PageOptions
-from mindee.input.polling_options_v2 import PollingOptionsV2
+from mindee.input.polling_options import PollingOptions
 from mindee.input.sources.base_64_input import Base64Input
 from mindee.input.sources.bytes_input import BytesInput
 from mindee.input.sources.file_input import FileInput
diff --git a/mindee/input/inference_predict_options.py b/mindee/input/inference_predict_options.py
index df8dce82..6d224168 100644
--- a/mindee/input/inference_predict_options.py
+++ b/mindee/input/inference_predict_options.py
@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 from typing import List, Optional
 
+from mindee.input.page_options import PageOptions
+from mindee.input.polling_options import PollingOptions
+
 
 @dataclass
 class InferencePredictOptions:
@@ -19,3 +22,9 @@ class InferencePredictOptions:
     """Optional alias for the file."""
     webhook_ids: Optional[List[str]] = None
     """IDs of webhooks to propagate the API response to."""
+    page_options: Optional[PageOptions] = None
+    """Options for page-level inference."""
+    polling_options: Optional[PollingOptions] = None
+    """Options for polling."""
+    close_file: bool = True
+    """Whether to close the file after parsing."""
diff --git a/mindee/input/polling_options_v2.py b/mindee/input/polling_options.py
similarity index 95%
rename from mindee/input/polling_options_v2.py
rename to mindee/input/polling_options.py
index bf7ef142..5ffc4c76 100644
--- a/mindee/input/polling_options_v2.py
+++ b/mindee/input/polling_options.py
@@ -1,4 +1,4 @@
-class PollingOptionsV2:
+class PollingOptions:
     """Options for asynchronous polling."""
 
     initial_delay_sec: float
diff --git a/mindee/mindee_http/mindee_api_v2.py b/mindee/mindee_http/mindee_api_v2.py
index 36d35fb1..deb17cba 100644
--- a/mindee/mindee_http/mindee_api_v2.py
+++ b/mindee/mindee_http/mindee_api_v2.py
@@ -68,17 +68,13 @@ def set_from_env(self) -> None:
                 logger.debug("Value was set from env: %s", name)
 
     def predict_async_req_post(
-        self,
-        input_source: LocalInputSource,
-        options: InferencePredictOptions,
-        close_file: bool = True,
+        self, input_source: LocalInputSource, options: InferencePredictOptions
     ) -> requests.Response:
         """
         Make an asynchronous request to POST a document for prediction on the V2 API.
 
         :param input_source: Input object.
         :param options: Options for the enqueueing of the document.
-        :param close_file: Whether to `close()` the file after parsing it.
         :return: requests response.
         """
         data = {"model_id": options.model_id}
@@ -93,7 +89,7 @@ def predict_async_req_post(
         if options.alias and len(options.alias):
             data["alias"] = options.alias
 
-        files = {"file": input_source.read_contents(close_file)}
+        files = {"file": input_source.read_contents(options.close_file)}
         response = requests.post(
             url=url,
             files=files,
diff --git a/mindee/parsing/v2/__init__.py b/mindee/parsing/v2/__init__.py
index 2c855c52..812b3865 100644
--- a/mindee/parsing/v2/__init__.py
+++ b/mindee/parsing/v2/__init__.py
@@ -1,8 +1,12 @@
-from mindee.parsing.v2.base_field import ListField, ObjectField, SimpleField
+from mindee.parsing.v2.base_field import (
+    InferenceFields,
+    ListField,
+    ObjectField,
+    SimpleField,
+)
 from mindee.parsing.v2.common_response import CommonResponse
 from mindee.parsing.v2.error_response import ErrorResponse
 from mindee.parsing.v2.inference import Inference
-from mindee.parsing.v2.inference_fields import InferenceFields
 from mindee.parsing.v2.inference_file import InferenceFile
 from mindee.parsing.v2.inference_model import InferenceModel
 from mindee.parsing.v2.inference_options import InferenceOptions
diff --git a/mindee/parsing/v2/base_field.py b/mindee/parsing/v2/base_field.py
index f199265b..ee75498c 100644
--- a/mindee/parsing/v2/base_field.py
+++ b/mindee/parsing/v2/base_field.py
@@ -29,6 +29,28 @@ def create_field(
         raise MindeeApiV2Error(f"Unrecognized field format {raw_response}.")
 
 
+class InferenceFields(Dict[str, Union["SimpleField", "ObjectField", "ListField"]]):
+    """Inference fields dict."""
+
+    def __init__(self, raw_response: StringDict, indent_level: int = 0) -> None:
+        super().__init__()
+        for key, value in raw_response.items():
+            field_obj = BaseField.create_field(value, indent_level)
+            self[key] = field_obj
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(item) from None
+
+    def __str__(self) -> str:
+        str_fields = ""
+        for field_key, field_value in self.items():
+            str_fields += f":{field_key}: {field_value}"
+        return str_fields
+
+
 class ListField(BaseField):
     """List field containing multiple fields."""
 
@@ -55,21 +77,14 @@ def __str__(self) -> str:
 
 class ObjectField(BaseField):
     """Object field containing multiple fields."""
 
-    fields: Dict[str, Union[ListField, "ObjectField", "SimpleField"]]
+    fields: InferenceFields
     """Fields contained in the object."""
 
     def __init__(self, raw_response: StringDict, indent_level: int = 0):
         super().__init__(indent_level)
         inner_fields = raw_response.get("fields", raw_response)
-        self.fields: Dict[str, Union["ListField", "ObjectField", "SimpleField"]] = {}
-        for field_key, field_value in inner_fields.items():
-            if isinstance(field_value, dict):
-                self.fields[field_key] = BaseField.create_field(
-                    field_value, self._indent_level + 1
-                )
-            else:
-                raise MindeeApiV2Error(f"Unrecognized field format '{field_value}'.")
+        self.fields = InferenceFields(inner_fields, self._indent_level + 1)
 
     def __str__(self) -> str:
         out_str = ""
diff --git a/mindee/parsing/v2/inference_fields.py b/mindee/parsing/v2/inference_fields.py
deleted file mode 100644
index dfbcfb9a..00000000
--- a/mindee/parsing/v2/inference_fields.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from __future__ import annotations
-
-from typing import Dict, Union
-
-from mindee.parsing.common.string_dict import StringDict
-from mindee.parsing.v2.base_field import BaseField, ListField, ObjectField, SimpleField
-
-
-class InferenceFields(Dict[str, Union[SimpleField, ObjectField, ListField]]):
-    """Inference fields dict."""
-
-    def __init__(self, raw_response: StringDict) -> None:
-        super().__init__()
-        for key, value in raw_response.items():
-            field_obj = BaseField.create_field(value, 0)
-            self[key] = field_obj
-
-    def __getattr__(self, item):
-        try:
-            return self[item]
-        except KeyError:
-            raise AttributeError(item) from None
-
-    def __str__(self) -> str:
-        str_fields = ""
-        for field_key, field_value in self.items():
-            str_fields += f":{field_key}: {field_value}"
-        return str_fields
diff --git a/mindee/parsing/v2/inference_result.py b/mindee/parsing/v2/inference_result.py
index 2f6ad911..9fdac6f4 100644
--- a/mindee/parsing/v2/inference_result.py
+++ b/mindee/parsing/v2/inference_result.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 from mindee.parsing.common.string_dict import StringDict
-from mindee.parsing.v2.inference_fields import InferenceFields
+from mindee.parsing.v2.base_field import InferenceFields
 from mindee.parsing.v2.inference_options import InferenceOptions
 
 
diff --git a/mindee/tests/product/fr/__init__.py b/mindee/tests/product/fr/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/mindee/tests/product/ind/__init__.py b/mindee/tests/product/ind/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/mindee/tests/product/us/__init__.py b/mindee/tests/product/us/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/v2/test_inference_response.py b/tests/v2/test_inference_response.py
index 0c10219f..924ecd01 100644
--- a/tests/v2/test_inference_response.py
+++ b/tests/v2/test_inference_response.py
@@ -184,6 +184,10 @@ def test_full_inference_response():
     assert load_response.inference.result.fields.date.value == "2019-11-02"
     assert isinstance(load_response.inference.result.fields.taxes, ListField)
     assert isinstance(load_response.inference.result.fields.taxes.items[0], ObjectField)
+    assert (
+        load_response.inference.result.fields.customer_address.fields.city.value
+        == "New York"
+    )
     assert (
         load_response.inference.result.fields.taxes.items[0].fields["base"].value
         == 31.5
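
Usage sketch (not part of the diff): a minimal example of driving the consolidated options object introduced above. The client class name ClientV2, the FileInput, PageOptions and PollingOptions constructor arguments, and the API key, model id and file path are assumptions or placeholders, not values confirmed by this diff; the attribute access on fields is the behaviour exercised by the updated test.

# Hypothetical usage sketch; names flagged in the comments are assumptions.
from mindee.client_v2 import ClientV2  # assumed class name inside client_v2.py
from mindee.input import InferencePredictOptions, PageOptions, PollingOptions
from mindee.input.sources.file_input import FileInput

client = ClientV2(api_key="my-api-key")  # placeholder key

options = InferencePredictOptions(
    model_id="my-model-id",  # placeholder model id
    # Page trimming, polling cadence and file closing now travel on the options
    # object instead of being separate keyword arguments on the client methods.
    page_options=PageOptions(page_indexes=[0]),  # assumed constructor signature
    polling_options=PollingOptions(
        initial_delay_sec=2, delay_sec=1.5, max_retries=30  # assumed constructor signature
    ),
    close_file=True,
)

with open("invoice.pdf", "rb") as fh:  # placeholder file
    input_source = FileInput(fh)  # assumed constructor signature
    response = client.enqueue_and_parse(input_source, options)

# InferenceFields supports attribute access, as exercised by the updated test.
print(response.inference.result.fields.date.value)

Because the new dataclass fields default to None (and close_file to True), and enqueue_and_parse falls back to a fresh PollingOptions when none is set, callers that only populate model_id keep the previous behaviour.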