Skip to content

Commit

Permalink
feat: add bounding box info to IngredientPredictionAggregatedEntity
Browse files: browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Nov 13, 2023
1 parent 2b91a29 commit f45cd39
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
12 changes: 1 addition & 11 deletions robotoff/app/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -647,17 +647,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
],
model_version=model_version,
)

output_dict = dataclasses.asdict(output)

if aggregation_strategy != "NONE":
# Add bounding boxes to entities
for entity in output_dict["entities"]:
entity["bounding_boxes"] = ocr_result.get_match_bounding_box(
entity["start"], entity["end"]
)

resp.media = output_dict
resp.media = dataclasses.asdict(output)


class UpdateDatasetResource:
Expand Down
18 changes: 15 additions & 3 deletions robotoff/prediction/ingredient_list/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import functools
from pathlib import Path
from typing import Optional, Union
from typing import Union

import numpy as np
from openfoodfacts.ocr import OCRResult
Expand Down Expand Up @@ -39,7 +39,10 @@ class IngredientPredictionAggregatedEntity:
# entity text (without organic or allergen mentions)
text: str
# language prediction of the entity text
lang: Optional[LanguagePrediction] = None
lang: LanguagePrediction | None = None
# the bounding box of the entity in absolute coordinates
# (y_min, x_min, y_max, x_max), or None if not available
bounding_box: tuple[int, int, int, int] | None = None


@dataclasses.dataclass
Expand Down Expand Up @@ -102,7 +105,16 @@ def predict_from_ocr(
predictions = predict_batch(
[text], aggregation_strategy, predict_lang, model_version
)
return predictions[0]
prediction = predictions[0]

for entity in prediction.entities:
if isinstance(entity, IngredientPredictionAggregatedEntity):
# Add the bounding box to the entity
entity.bounding_box = ocr_result.get_match_bounding_box(
entity.start, entity.end
)

return prediction


@functools.cache
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/workers/tasks/test_import_image.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_extract_ingredients_job(mocker, peewee_db):
score=0.9,
text="water, salt, sugar.",
lang=LanguagePrediction(lang="en", confidence=0.9),
bounding_box=(0, 0, 100, 100),
)
]
parsed_ingredients = [
Expand Down Expand Up @@ -112,6 +113,7 @@ def test_extract_ingredients_job(mocker, peewee_db):
{"in_taxonomy": True, **ingredient}
for ingredient in parsed_ingredients
],
"bounding_box": [0, 0, 100, 100],
}
],
}
Expand Down

0 comments on commit f45cd39

Please sign in to comment.