From 7ace2843df2bf503b414804cf56dde9d02ceb7c6 Mon Sep 17 00:00:00 2001 From: Aric Coady Date: Tue, 10 Mar 2020 21:05:36 -0700 Subject: [PATCH] Basic search query with str fields. --- lupyne/engine/documents.py | 2 +- lupyne/server/graphql.py | 28 +++++++++++++++++++++++++++- lupyne/server/rest.py | 10 ++++++++++ tests/test_graphql.py | 12 ++++++++++++ tests/test_rest.py | 10 ++++++++++ 5 files changed, 60 insertions(+), 2 deletions(-) diff --git a/lupyne/engine/documents.py b/lupyne/engine/documents.py index 3452559..b6c63ea 100644 --- a/lupyne/engine/documents.py +++ b/lupyne/engine/documents.py @@ -298,7 +298,7 @@ class Hits: def __init__(self, searcher, scoredocs: Sequence, count=0, fields=None): self.searcher, self.scoredocs = searcher, scoredocs - if hasattr(count, 'relation'): # pragma: no cover + if hasattr(count, 'relation'): cls = int if count.relation == search.TotalHits.Relation.EQUAL_TO else float count = cls(count.value) self.count, self.fields = count, fields diff --git a/lupyne/server/graphql.py b/lupyne/server/graphql.py index c52ed1a..02a12ee 100644 --- a/lupyne/server/graphql.py +++ b/lupyne/server/graphql.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import lucene import strawberry.asgi from starlette.applications import Starlette @@ -41,6 +41,26 @@ class Terms: counts: Counts +@strawberry.type +class Document: + __annotations__ = {name: List[str] for name in root.searcher.fieldinfos} + locals().update(dict.fromkeys(__annotations__, ())) + + +@strawberry.type +class Hit: + id: int + score: Optional[float] + sortkeys: List[str] + doc: Document + + +@strawberry.type +class Hits: + count: int + hits: List[Hit] + + @strawberry.type class Query: @strawberry.field @@ -61,6 +81,12 @@ def terms(self, info) -> Terms: values[name], counts[name] = zip(*root.searcher.terms(name, counts=True)) return Terms(Values(**values), Counts(**counts)) + @strawberry.field + def search(self, info, q: str, count: int = None) -> Hits: + """Run query and return htis.""" + hits = root.searcher.search(q, count) + return Hits(hits.count, hits=[Hit(hit.id, hit.score, hit.sortkeys, Document(**hit)) for hit in hits]) + @strawberry.type class Mutation: diff --git a/lupyne/server/rest.py b/lupyne/server/rest.py index 39f50d9..768078a 100644 --- a/lupyne/server/rest.py +++ b/lupyne/server/rest.py @@ -21,6 +21,16 @@ def terms(name: str, *, counts: bool = False) -> Union[list, dict]: return (dict if counts else list)(terms) +@app.get('/search') +def search(q: str, count: int = None) -> dict: + """Run query and return htis.""" + hits = root.searcher.search(q, count) + return { + 'count': hits.count, + 'hits': [{'id': hit.id, 'score': hit.score, 'sortkeys': hit.sortkeys, 'doc': hit} for hit in hits], + } + + @app.middleware('http') async def headers(request, call_next): start = time.time() diff --git a/tests/test_graphql.py b/tests/test_graphql.py index a6cea1d..bd8ec68 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -38,3 +38,15 @@ def test_terms(client): data = client.execute(query='''{ terms { counts { date } } }''') counts = data['terms']['counts']['date'] assert counts[0] == 10 + + +def test_search(client): + data = client.execute( + query='''{ search(q: "text:right", count: 1) { count hits { id score sortkeys doc { amendment } } } }''' + ) + assert data['search']['count'] == 13 + (hit,) = data['search']['hits'] + assert hit['id'] == 9 + assert hit['score'] > 0 + assert hit['sortkeys'] == [] + assert hit['doc'] == {'amendment': ['2']} diff --git a/tests/test_rest.py b/tests/test_rest.py index 7bc83f8..dd50c30 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -28,3 +28,13 @@ def test_terms(client): assert min(result) == result[0] == '1791-12-15' result = client.get('/terms/date', params={'counts': True}).json() assert result['1791-12-15'] == 10 + + +def test_search(client): + result = client.get('/search', params={'q': "text:right", 'count': 1}).json() + assert result['count'] == 13 + (hit,) = result['hits'] + assert hit['id'] == 9 + assert hit['score'] > 0 + assert hit['sortkeys'] == [] + assert hit['doc'] == {'amendment': ['2'], 'date': ['1791-12-15']}