Skip to content

Commit

Permalink
feat: add language predictor for product
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 1, 2023
1 parent 619c477 commit c77f049
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 1 deletion.
79 changes: 79 additions & 0 deletions doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,85 @@ paths:
"400":
description: "An HTTP 400 is returned if the provided parameters are invalid"

/predict/lang/product:
get:
tags:
- Predict
summary: Predict the languages of the product
description: |
Return the most common languages present on the product images, based on word-level
language detection from product images.
Language detection is not performed on the fly, but is based on predictions of type
`image_lang` stored in the `prediction` table.
parameters:
- $ref: "#/components/parameters/barcode"
- $ref: "#/components/parameters/server_type"
in: query
required: false
description: |
the minimum probability for a language to be returned
schema:
type: number
default: 0.01
minimum: 0
maximum: 1
responses:
"200":
description: |
The predicted languages, sorted by descending probability.
content:
application/json:
schema:
type: object
properties:
counts:
type: array
description: |
the number of words detected for each language, over all images,
sorted by descending count
items:
type: object
properties:
lang:
type: string
description: the predicted language (2-letter code). `null` if the language could not be detected.
example: "en"
count:
type: number
description: the number of words for which this language was detected over all images
example: 10
percent:
type: array
description: |
the percentage of words detected for each language, over all images,
sorted by descending percentage
items:
type: object
properties:
lang:
type: string
description: the predicted language (2-letter code). `null` if the language could not be detected.
example: "en"
percent:
type: number
description: the percentage of words for which the language was detected over all images
minimum: 0
maximum: 100
example: 80.5
image_ids:
type: array
description: |
the IDs of the images that were used to generate the predictions
items:
type: number
example: 1
description: the ID of an image
"400":
description: "An HTTP 400 is returned if the provided parameters are invalid"


components:
schemas:
LogoANNSearchResponse:
Expand Down
40 changes: 40 additions & 0 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import re
import tempfile
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Literal, Optional, cast

import falcon
Expand Down Expand Up @@ -685,6 +687,43 @@ def on_post(self, req: falcon.Request, resp: falcon.Response):
self._on_get_post(req, resp)


class ProductLanguagePredictorResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
"""Predict the languages displayed on the product images, using
`image_lang` predictions as input."""
barcode = req.get_param("barcode", required=True)
server_type = get_server_type_from_req(req)
counts: dict[str, int] = defaultdict(int)
image_ids: list[int] = []

for prediction_data, source_image in (
Prediction.select(Prediction.data, Prediction.source_image)
.where(
Prediction.barcode == barcode,
Prediction.server_type == server_type.name,
Prediction.type == PredictionType.image_lang.name,
)
.tuples()
.iterator()
):
image_ids.append(int(Path(source_image).stem))
for lang, lang_count in prediction_data["count"].items():
counts[lang] += lang_count

words_n = counts.pop("words")
sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True)
counts_list = [{"count": count, "lang": lang} for lang, count in sorted_counts]
percent_list = [
{"percent": (count * 100 / words_n), "lang": lang}
for lang, count in sorted_counts
]
resp.media = {
"counts": counts_list,
"percent": percent_list,
"image_ids": sorted(image_ids),
}


class UpdateDatasetResource:
def on_post(self, req: falcon.Request, resp: falcon.Response):
"""Re-import the Product Opener product dump."""
Expand Down Expand Up @@ -1796,6 +1835,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
api.add_route("/api/v1/predict/category", CategoryPredictorResource())
api.add_route("/api/v1/predict/ingredient_list", IngredientListPredictorResource())
api.add_route("/api/v1/predict/lang", LanguagePredictorResource())
api.add_route("/api/v1/predict/lang/product", ProductLanguagePredictorResource())
api.add_route("/api/v1/products/dataset", UpdateDatasetResource())
api.add_route("/api/v1/webhook/product", WebhookProductResource())
api.add_route("/api/v1/images", ImageCollection())
Expand Down
43 changes: 42 additions & 1 deletion tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from robotoff.models import AnnotationVote, LogoAnnotation, ProductInsight
from robotoff.off import OFFAuthentication
from robotoff.prediction.langid import LanguagePrediction
from robotoff.types import ProductIdentifier, ServerType
from robotoff.types import PredictionType, ProductIdentifier, ServerType

from .models_utils import (
AnnotationVoteFactory,
Expand Down Expand Up @@ -1250,3 +1250,44 @@ def test_predict_lang(client, mocker):
)
assert result.status_code == 200
assert result.json == {"predictions": expected_predictions}


def test_predict_product_language(client, peewee_db):
barcode = "123456789"
prediction_data_1 = {"count": {"en": 10, "fr": 5, "es": 3, "words": 18}}
prediction_data_2 = {"count": {"en": 2, "fr": 3, "words": 5}}

with peewee_db:
PredictionFactory(
barcode=barcode,
server_type=ServerType.off.name,
type=PredictionType.image_lang.name,
data=prediction_data_1,
source_image="/123/45678/2.jpg",
)
PredictionFactory(
barcode=barcode,
server_type=ServerType.off.name,
type=PredictionType.image_lang.name,
data=prediction_data_2,
source_image="/123/45678/4.jpg",
)

# Send GET request to the API endpoint
result = client.simulate_get(f"/api/v1/predict/lang/product?barcode={barcode}")

# Assert the response
assert result.status_code == 200
assert result.json == {
"counts": [
{"count": 12, "lang": "en"},
{"count": 8, "lang": "fr"},
{"count": 3, "lang": "es"},
],
"percent": [
{"percent": 12 * 100 / 23, "lang": "en"},
{"percent": 8 * 100 / 23, "lang": "fr"},
{"percent": 3 * 100 / 23, "lang": "es"},
],
"image_ids": [2, 4],
}

0 comments on commit c77f049

Please sign in to comment.