Skip to content

Commit

Permalink
Basic search query with str fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
coady committed Mar 11, 2020
1 parent d40ccda commit 7ace284
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lupyne/engine/documents.py
Expand Up @@ -298,7 +298,7 @@ class Hits:

def __init__(self, searcher, scoredocs: Sequence, count=0, fields=None):
self.searcher, self.scoredocs = searcher, scoredocs
if hasattr(count, 'relation'): # pragma: no cover
if hasattr(count, 'relation'):
cls = int if count.relation == search.TotalHits.Relation.EQUAL_TO else float
count = cls(count.value)
self.count, self.fields = count, fields
Expand Down
28 changes: 27 additions & 1 deletion lupyne/server/graphql.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
import lucene
import strawberry.asgi
from starlette.applications import Starlette
Expand Down Expand Up @@ -41,6 +41,26 @@ class Terms:
counts: Counts


@strawberry.type
class Document:
__annotations__ = {name: List[str] for name in root.searcher.fieldinfos}
locals().update(dict.fromkeys(__annotations__, ()))


@strawberry.type
class Hit:
id: int
score: Optional[float]
sortkeys: List[str]
doc: Document


@strawberry.type
class Hits:
count: int
hits: List[Hit]


@strawberry.type
class Query:
@strawberry.field
Expand All @@ -61,6 +81,12 @@ def terms(self, info) -> Terms:
values[name], counts[name] = zip(*root.searcher.terms(name, counts=True))
return Terms(Values(**values), Counts(**counts))

@strawberry.field
def search(self, info, q: str, count: int = None) -> Hits:
"""Run query and return htis."""
hits = root.searcher.search(q, count)
return Hits(hits.count, hits=[Hit(hit.id, hit.score, hit.sortkeys, Document(**hit)) for hit in hits])


@strawberry.type
class Mutation:
Expand Down
10 changes: 10 additions & 0 deletions lupyne/server/rest.py
Expand Up @@ -21,6 +21,16 @@ def terms(name: str, *, counts: bool = False) -> Union[list, dict]:
return (dict if counts else list)(terms)


@app.get('/search')
def search(q: str, count: int = None) -> dict:
"""Run query and return htis."""
hits = root.searcher.search(q, count)
return {
'count': hits.count,
'hits': [{'id': hit.id, 'score': hit.score, 'sortkeys': hit.sortkeys, 'doc': hit} for hit in hits],
}


@app.middleware('http')
async def headers(request, call_next):
start = time.time()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_graphql.py
Expand Up @@ -38,3 +38,15 @@ def test_terms(client):
data = client.execute(query='''{ terms { counts { date } } }''')
counts = data['terms']['counts']['date']
assert counts[0] == 10


def test_search(client):
data = client.execute(
query='''{ search(q: "text:right", count: 1) { count hits { id score sortkeys doc { amendment } } } }'''
)
assert data['search']['count'] == 13
(hit,) = data['search']['hits']
assert hit['id'] == 9
assert hit['score'] > 0
assert hit['sortkeys'] == []
assert hit['doc'] == {'amendment': ['2']}
10 changes: 10 additions & 0 deletions tests/test_rest.py
Expand Up @@ -28,3 +28,13 @@ def test_terms(client):
assert min(result) == result[0] == '1791-12-15'
result = client.get('/terms/date', params={'counts': True}).json()
assert result['1791-12-15'] == 10


def test_search(client):
result = client.get('/search', params={'q': "text:right", 'count': 1}).json()
assert result['count'] == 13
(hit,) = result['hits']
assert hit['id'] == 9
assert hit['score'] > 0
assert hit['sortkeys'] == []
assert hit['doc'] == {'amendment': ['2'], 'date': ['1791-12-15']}

0 comments on commit 7ace284

Please sign in to comment.