Skip to content

Commit

Permalink
Merge pull request #620 from materialsproject/enhancement/add_extra_p…
Browse files Browse the repository at this point in the history
…agination_params

Changes to core query operators and API
  • Loading branch information
munrojm committed Apr 14, 2022
2 parents facbc20 + 988bf33 commit 7aefd8b
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 45 deletions.
3 changes: 3 additions & 0 deletions src/maggma/api/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from monty.json import MSONable
from starlette.responses import RedirectResponse

Expand Down Expand Up @@ -88,6 +89,8 @@ def app(self):

app.include_router(main_resource.router, prefix=f"/{prefix}")

app.add_middleware(GZipMiddleware, minimum_size=1000)

@app.get("/heartbeat", include_in_schema=False)
def heartbeat():
"""API Heartbeat for Load Balancing"""
Expand Down
67 changes: 51 additions & 16 deletions src/maggma/api/query_operator/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,74 @@
class PaginationQuery(QueryOperator):
"""Query opertators to provides Pagination"""

def __init__(
self, default_skip: int = 0, default_limit: int = 100, max_limit: int = 1000
):
def __init__(self, default_limit: int = 100, max_limit: int = 1000):
"""
Args:
default_skip: the default number of documents to skip
default_limit: the default number of documents to return
max_limit: max number of documents to return
"""
self.default_skip = default_skip

self.default_limit = default_limit
self.max_limit = max_limit

def query(
skip: int = Query(
default_skip, description="Number of entries to skip in the search"
_page: int = Query(
None,
description="Page number to request (takes precedent over _limit and _skip).",
),
_per_page: int = Query(
default_limit,
description="Number of entries to show per page (takes precedent over _limit and _skip)."
f" Limited to {max_limit}.",
),
limit: int = Query(
_skip: int = Query(
0, description="Number of entries to skip in the search.",
),
_limit: int = Query(
default_limit,
description="Max number of entries to return in a single query."
f" Limited to {max_limit}",
f" Limited to {max_limit}.",
),
) -> STORE_PARAMS:
"""
Pagination parameters for the API Endpoint
"""
if limit > max_limit:
raise HTTPException(
status_code=400,
detail="Requested more data per query than allowed by this endpoint."
f" The max limit is {max_limit} entries",
)
return {"skip": skip, "limit": limit}

if _page is not None:

if _per_page > max_limit:
raise HTTPException(
status_code=400,
detail="Requested more data per query than allowed by this endpoint."
f" The max limit is {max_limit} entries",
)

if _page < 0 or _per_page < 0:
raise HTTPException(
status_code=400,
detail="Cannot request negative _page or _per_page values",
)

return {
"skip": ((_page - 1) * _per_page) if _page >= 1 else 0,
"limit": _per_page,
}

else:
if _limit > max_limit:
raise HTTPException(
status_code=400,
detail="Requested more data per query than allowed by this endpoint."
f" The max limit is {max_limit} entries",
)

if _skip < 0 or _limit < 0:
raise HTTPException(
status_code=400,
detail="Cannot request negative _skip or _limit values",
)

return {"skip": _skip, "limit": _limit}

self.query = query # type: ignore

Expand Down
6 changes: 3 additions & 3 deletions src/maggma/api/query_operator/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SortQuery(QueryOperator):

def query(
self,
sort_fields: Optional[str] = Query(
_sort_fields: Optional[str] = Query(
None,
description="Comma delimited fields to sort with.\
Prefixing '-' to a field will force a sort in descending order.",
Expand All @@ -22,8 +22,8 @@ def query(

sort = {}

if sort_fields:
for sort_field in sort_fields.split(","):
if _sort_fields:
for sort_field in _sort_fields.split(","):
if sort_field[0] == "-":
sort.update({sort_field[1:]: -1})
else:
Expand Down
8 changes: 4 additions & 4 deletions src/maggma/api/query_operator/sparse_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@ def __init__(
)

def query(
fields: str = Query(
_fields: str = Query(
None,
description=f"Fields to project from {str(model_name)} as a list of comma seperated strings.\
Fields include: `{'` `'.join(model_fields)}`",
),
all_fields: bool = Query(False, description="Include all fields."),
_all_fields: bool = Query(False, description="Include all fields."),
) -> STORE_PARAMS:
"""
Pagination parameters for the API Endpoint
"""

properties = (
fields.split(",") if isinstance(fields, str) else self.default_fields
_fields.split(",") if isinstance(_fields, str) else self.default_fields
)
if all_fields:
if _all_fields:
properties = model_fields

return {"properties": properties}
Expand Down
20 changes: 14 additions & 6 deletions src/maggma/api/resource/read_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,20 @@ async def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]:
key for key in request.query_params.keys() if key not in query_params
]
if any(overlap):
raise HTTPException(
status_code=400,
detail="Request contains query parameters which cannot be used: {}".format(
", ".join(overlap)
),
)
if "limit" in overlap or "skip" in overlap:
raise HTTPException(
status_code=400,
detail="'limit' and 'skip' parameters have been renamed. "
"Please update your API client to the newest version.",
)

else:
raise HTTPException(
status_code=400,
detail="Request contains query parameters which cannot be used: {}".format(
", ".join(overlap)
),
)

query: Dict[Any, Any] = merge_queries(list(queries.values())) # type: ignore

Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ def test_cluster_run(owner_store, pet_store):
res, data = search_helper(payload="")
assert res.status_code == 200

payload = {"name": "Person1", "limit": 10, "all_fields": True}
payload = {"name": "Person1", "_limit": 10, "_all_fields": True}
res, data = search_helper(payload=payload, base="/owners/?")
assert res.status_code == 200
assert len(data) == 1
assert data[0]["name"] == "Person1"

payload = {"name": "Pet1", "limit": 10, "all_fields": True}
payload = {"name": "Pet1", "_limit": 10, "_all_fields": True}
res, data = search_helper(payload=payload, base="/pets/?")
assert res.status_code == 200
assert len(data) == 1
Expand Down
43 changes: 36 additions & 7 deletions tests/api/test_query_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,27 @@ def test_pagination_functionality():

op = PaginationQuery()

assert op.query(skip=10, limit=20) == {"limit": 20, "skip": 10}
assert op.query(_skip=10, _limit=20, _page=None, _per_page=None) == {
"limit": 20,
"skip": 10,
}

assert op.query(_skip=None, _limit=None, _page=3, _per_page=23) == {
"limit": 23,
"skip": 46,
}

with pytest.raises(HTTPException):
op.query(_limit=10000, _skip=100, _page=None, _per_page=None)

with pytest.raises(HTTPException):
op.query(_limit=None, _skip=None, _page=5, _per_page=10000)

with pytest.raises(HTTPException):
op.query(limit=10000)
op.query(_limit=-1, _skip=100, _page=None, _per_page=None)

with pytest.raises(HTTPException):
op.query(_page=-1, _per_page=100, _skip=None, _limit=None)


def test_pagination_serialization():
Expand All @@ -42,7 +59,10 @@ def test_pagination_serialization():
with ScratchDir("."):
dumpfn(op, "temp.json")
new_op = loadfn("temp.json")
assert new_op.query(skip=10, limit=20) == {"limit": 20, "skip": 10}
assert new_op.query(_skip=10, _limit=20, _page=None, _per_page=None) == {
"limit": 20,
"skip": 10,
}


def test_sparse_query_functionality():
Expand All @@ -60,7 +80,9 @@ def test_sparse_query_serialization():
with ScratchDir("."):
dumpfn(op, "temp.json")
new_op = loadfn("temp.json")
assert new_op.query() == {"properties": ["name", "age", "weight", "last_updated"]}
assert new_op.query() == {
"properties": ["name", "age", "weight", "last_updated"]
}


def test_numeric_query_functionality():
Expand All @@ -69,7 +91,10 @@ def test_numeric_query_functionality():

assert op.meta() == {}
assert op.query(age_max=10, age_min=1, age_not_eq=[2, 3], weight_min=120) == {
"criteria": {"age": {"$lte": 10, "$gte": 1, "$ne": [2, 3]}, "weight": {"$gte": 120}}
"criteria": {
"age": {"$lte": 10, "$gte": 1, "$ne": [2, 3]},
"weight": {"$gte": 120},
}
}


Expand All @@ -87,7 +112,9 @@ def test_sort_query_functionality():

op = SortQuery()

assert op.query(sort_fields="volume,-density") == {"sort": {"volume": 1, "density": -1}}
assert op.query(_sort_fields="volume,-density") == {
"sort": {"volume": 1, "density": -1}
}


def test_sort_serialization():
Expand All @@ -97,7 +124,9 @@ def test_sort_serialization():
with ScratchDir("."):
dumpfn(op, "temp.json")
new_op = loadfn("temp.json")
assert new_op.query(sort_fields="volume,-density") == {"sort": {"volume": 1, "density": -1}}
assert new_op.query(_sort_fields="volume,-density") == {
"sort": {"volume": 1, "density": -1}
}


@pytest.fixture
Expand Down
14 changes: 7 additions & 7 deletions tests/api/test_read_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,13 @@ def search_helper(payload, base: str = "/?", debug=True) -> Response:
def test_numeric_query_operator():

# Checking int
payload = {"age": 20, "all_fields": True}
payload = {"age": 20, "_all_fields": True}
res, data = search_helper(payload=payload, base="/?", debug=True)
assert res.status_code == 200
assert len(data) == 1
assert data[0]["age"] == 20

payload = {"age_not_eq": 9, "all_fields": True}
payload = {"age_not_eq": 9, "_all_fields": True}
res, data = search_helper(payload=payload, base="/?", debug=True)
assert res.status_code == 200
assert len(data) == 11
Expand All @@ -185,13 +185,13 @@ def test_numeric_query_operator():

def test_string_query_operator():

payload = {"name": "PersonAge9", "all_fields": True}
payload = {"name": "PersonAge9", "_all_fields": True}
res, data = search_helper(payload=payload, base="/?", debug=True)
assert res.status_code == 200
assert len(data) == 1
assert data[0]["name"] == "PersonAge9"

payload = {"name_not_eq": "PersonAge9", "all_fields": True}
payload = {"name_not_eq": "PersonAge9", "_all_fields": True}
res, data = search_helper(payload=payload, base="/?", debug=True)
assert res.status_code == 200
assert len(data) == 12
Expand All @@ -200,7 +200,7 @@ def test_string_query_operator():
def test_resource_compound():
payload = {
"name": "PersonAge20Weight200",
"all_fields": True,
"_all_fields": True,
"weight_min": 199.1,
"weight_max": 201.4,
"age": 20,
Expand All @@ -212,8 +212,8 @@ def test_resource_compound():

payload = {
"name": "PersonAge20Weight200",
"all_fields": False,
"fields": "name,age",
"_all_fields": False,
"_fields": "name,age",
"weight_min": 199.3,
"weight_max": 201.9,
"age": 20,
Expand Down

0 comments on commit 7aefd8b

Please sign in to comment.