Skip to content

Commit

Permalink
Fix serialization of numpy arrays and pandas dataframes in REST API (#…
Browse files Browse the repository at this point in the history
…2838)

* correct serialization of numpy arrays and pandas dataframes

* Update Documentation & Code Style

* set additional json_encoders globally

* Update Documentation & Code Style

* add tests for non primitive return types

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
tstadel and github-actions[bot] committed Aug 2, 2022
1 parent 86d56b4 commit 2c56305
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 66 deletions.
13 changes: 3 additions & 10 deletions docs/_src/api/openapi/openapi-1.6.1rc0.json
Expand Up @@ -123,15 +123,7 @@
"content": {
"application/json": {
"schema": {
"title": "Feedback",
"anyOf": [
{
"$ref": "#/components/schemas/Label"
},
{
"$ref": "#/components/schemas/CreateLabelSerialized"
}
]
"$ref": "#/components/schemas/CreateLabelSerialized"
}
}
},
Expand Down Expand Up @@ -698,7 +690,8 @@
]
}
}
}
},
"additionalProperties": false
},
"HTTPValidationError": {
"title": "HTTPValidationError",
Expand Down
13 changes: 3 additions & 10 deletions docs/_src/api/openapi/openapi.json
Expand Up @@ -123,15 +123,7 @@
"content": {
"application/json": {
"schema": {
"title": "Feedback",
"anyOf": [
{
"$ref": "#/components/schemas/Label"
},
{
"$ref": "#/components/schemas/CreateLabelSerialized"
}
]
"$ref": "#/components/schemas/CreateLabelSerialized"
}
}
},
Expand Down Expand Up @@ -698,7 +690,8 @@
]
}
}
}
},
"additionalProperties": false
},
"HTTPValidationError": {
"title": "HTTPValidationError",
Expand Down
5 changes: 3 additions & 2 deletions rest_api/controller/document.py
Expand Up @@ -4,10 +4,11 @@

from fastapi import FastAPI, APIRouter
from haystack.document_stores import BaseDocumentStore
from haystack.schema import Document

from rest_api.utils import get_app, get_pipelines
from rest_api.config import LOG_LEVEL
from rest_api.schema import FilterRequest, DocumentSerialized
from rest_api.schema import FilterRequest


logging.getLogger("haystack").setLevel(LOG_LEVEL)
Expand All @@ -19,7 +20,7 @@
document_store: BaseDocumentStore = get_pipelines().get("document_store", None)


@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized], response_model_exclude_none=True)
@router.post("/documents/get_by_filters", response_model=List[Document], response_model_exclude_none=True)
def get_documents(filters: FilterRequest):
"""
This endpoint allows you to retrieve documents contained in your document store.
Expand Down
6 changes: 3 additions & 3 deletions rest_api/controller/feedback.py
Expand Up @@ -6,7 +6,7 @@
from fastapi import FastAPI, APIRouter
from haystack.schema import Label
from haystack.document_stores import BaseDocumentStore
from rest_api.schema import FilterRequest, LabelSerialized, CreateLabelSerialized
from rest_api.schema import FilterRequest, CreateLabelSerialized
from rest_api.utils import get_app, get_pipelines


Expand All @@ -18,7 +18,7 @@


@router.post("/feedback")
def post_feedback(feedback: Union[LabelSerialized, CreateLabelSerialized]):
def post_feedback(feedback: CreateLabelSerialized):
"""
This endpoint allows the API user to submit feedback on an answer for a particular query.
Expand All @@ -35,7 +35,7 @@ def post_feedback(feedback: Union[LabelSerialized, CreateLabelSerialized]):
document_store.write_labels([label])


@router.get("/feedback", response_model=List[LabelSerialized])
@router.get("/feedback", response_model=List[Label])
def get_feedback():
"""
This endpoint allows the API user to retrieve all the feedback that has been submitted
Expand Down
6 changes: 0 additions & 6 deletions rest_api/controller/search.py
Expand Up @@ -4,7 +4,6 @@
import logging
import time
import json
from numpy import ndarray

from pydantic import BaseConfig
from fastapi import FastAPI, APIRouter
Expand Down Expand Up @@ -84,11 +83,6 @@ def _process_request(pipeline, request) -> Dict[str, Any]:
if not "answers" in result:
result["answers"] = []

# if any of the documents contains an embedding as an ndarray the latter needs to be converted to list of float
for document in result["documents"]:
if isinstance(document.embedding, ndarray):
document.embedding = document.embedding.tolist()

logger.info(
json.dumps({"request": request, "response": result, "time": f"{(time.time() - start_time):.2f}"}, default=str)
)
Expand Down
46 changes: 17 additions & 29 deletions rest_api/schema.py
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd

try:
from typing import Literal
Expand All @@ -10,64 +12,50 @@
from pydantic import BaseModel, Field, Extra
from pydantic import BaseConfig

from haystack.schema import Answer, Document, Label
from haystack.schema import Answer, Document


BaseConfig.arbitrary_types_allowed = True
BaseConfig.json_encoders = {np.ndarray: lambda x: x.tolist(), pd.DataFrame: lambda x: x.to_dict(orient="records")}

PrimitiveType = Union[str, int, float, bool]

PrimitiveType = Union[str, int, float, bool]

class QueryRequest(BaseModel):
query: str
params: Optional[dict] = None
debug: Optional[bool] = False

class RequestBaseModel(BaseModel):
class Config:
# Forbid any extra fields in the request to avoid silent failures
extra = Extra.forbid


class FilterRequest(BaseModel):
filters: Optional[Dict[str, Union[PrimitiveType, List[PrimitiveType], Dict[str, PrimitiveType]]]] = None


class AnswerSerialized(Answer):
context: Optional[str] = None


class DocumentSerialized(Document):
content: str
embedding: Optional[List[float]] # type: ignore
class QueryRequest(RequestBaseModel):
query: str
params: Optional[dict] = None
debug: Optional[bool] = False


class LabelSerialized(Label, BaseModel):
document: DocumentSerialized
answer: Optional[AnswerSerialized] = None
class FilterRequest(RequestBaseModel):
filters: Optional[Dict[str, Union[PrimitiveType, List[PrimitiveType], Dict[str, PrimitiveType]]]] = None


class CreateLabelSerialized(BaseModel):
class CreateLabelSerialized(RequestBaseModel):
id: Optional[str] = None
query: str
document: DocumentSerialized
document: Document
is_correct_answer: bool
is_correct_document: bool
origin: Literal["user-feedback", "gold-label"]
answer: Optional[AnswerSerialized] = None
answer: Optional[Answer] = None
no_answer: Optional[bool] = None
pipeline_id: Optional[str] = None
created_at: Optional[str] = None
updated_at: Optional[str] = None
meta: Optional[dict] = None
filters: Optional[dict] = None

class Config:
# Forbid any extra fields in the request to avoid silent failures
extra = Extra.forbid


class QueryResponse(BaseModel):
query: str
answers: List[AnswerSerialized] = []
documents: List[DocumentSerialized] = []
answers: List[Answer] = []
documents: List[Document] = []
debug: Optional[Dict] = Field(None, alias="_debug")
66 changes: 60 additions & 6 deletions rest_api/test/test_rest_api.py
Expand Up @@ -5,6 +5,8 @@
from textwrap import dedent
from unittest import mock
from unittest.mock import MagicMock
import numpy as np
import pandas as pd

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -125,7 +127,7 @@ def get_all_documents_generator(self, *args, **kwargs) -> Generator[Document, No
pass

def get_all_labels(self, *args, **kwargs) -> List[Label]:
self.mocker.get_all_labels(*args, **kwargs)
return self.mocker.get_all_labels(*args, **kwargs)

def get_document_by_id(self, *args, **kwargs) -> Optional[Document]:
pass
Expand Down Expand Up @@ -176,7 +178,7 @@ def feedback():
"score": None,
"id": "fc18c987a8312e72a47fb1524f230bb0",
"meta": {},
"embedding": None,
"embedding": [0.1, 0.2, 0.3],
},
"answer": {
"answer": "Adobe Systems",
Expand Down Expand Up @@ -366,6 +368,57 @@ def test_query_with_bool_in_params(client):
assert response_json["answers"] == []


def test_query_with_embeddings(client):
with mock.patch("rest_api.controller.search.query_pipeline") as mocked_pipeline:
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {
"query": TEST_QUERY,
"documents": [
Document(
content="test",
content_type="text",
score=0.9,
meta={"test_key": "test_value"},
embedding=np.array([0.1, 0.2, 0.3]),
)
],
}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert len(response.json()["documents"]) == 1
assert response.json()["documents"][0]["content"] == "test"
assert response.json()["documents"][0]["content_type"] == "text"
assert response.json()["documents"][0]["embedding"] == [0.1, 0.2, 0.3]
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)


def test_query_with_dataframe(client):
with mock.patch("rest_api.controller.search.query_pipeline") as mocked_pipeline:
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {
"query": TEST_QUERY,
"documents": [
Document(
content=pd.DataFrame.from_records([{"col1": "text_1", "col2": 1}, {"col1": "text_2", "col2": 2}]),
content_type="table",
score=0.9,
meta={"test_key": "test_value"},
)
],
}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert len(response.json()["documents"]) == 1
assert response.json()["documents"][0]["content"] == [
{"col1": "text_1", "col2": 1},
{"col1": "text_2", "col2": 2},
]
assert response.json()["documents"][0]["content_type"] == "table"
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)


def test_write_feedback(client, feedback):
response = client.post(url="/feedback", json=feedback)
assert 200 == response.status_code
Expand All @@ -376,9 +429,8 @@ def test_write_feedback(client, feedback):
assert len(labels) == 1
# Ensure all the items that were in `feedback` are also part of
# the stored label (which has several more keys)
label = labels[0].to_dict()
for k, v in feedback.items():
assert label[k] == v
label = labels[0]
assert label == Label.from_dict(feedback)


def test_write_feedback_without_id(client, feedback):
Expand All @@ -395,9 +447,11 @@ def test_write_feedback_without_id(client, feedback):
assert label["id"]


def test_get_feedback(client):
def test_get_feedback(client, feedback):
MockDocumentStore.mocker.get_all_labels.return_value = [Label.from_dict(feedback)]
response = client.get("/feedback")
assert response.status_code == 200
assert Label.from_dict(response.json()[0]) == Label.from_dict(feedback)
MockDocumentStore.mocker.get_all_labels.assert_called_once()


Expand Down

0 comments on commit 2c56305

Please sign in to comment.