diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index a770a45..dba12d8 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -3,7 +3,7 @@ from fastapi import Depends from loguru import logger -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncConnection from core.errors import AuthenticationFailedError, AuthenticationRequiredError @@ -57,6 +57,10 @@ def fetch_user_or_raise( return user +LIMIT_DEFAULT = 100 +LIMIT_MAX = 1000 + + class Pagination(BaseModel): - offset: int = 0 - limit: int = 100 + offset: int = Field(default=0, ge=0) + limit: int = Field(default=LIMIT_DEFAULT, gt=0, le=LIMIT_MAX) diff --git a/tests/dependencies/pagination_test.py b/tests/dependencies/pagination_test.py new file mode 100644 index 0000000..0de2d69 --- /dev/null +++ b/tests/dependencies/pagination_test.py @@ -0,0 +1,38 @@ +from typing import Any + +import pytest +from pydantic import ValidationError + +from routers.dependencies import Pagination + + +def test_pagination_defaults() -> None: + """Pagination has expected defaults when no values are provided.""" + pagination = Pagination() + assert pagination.offset == 0 + assert pagination.limit == 100 # noqa: PLR2004 + + +@pytest.mark.parametrize( + ("kwargs", "expected_field"), + [ + ({"limit": "abc", "offset": 0}, "limit"), + ({"limit": -5, "offset": 0}, "limit"), + ({"limit": 2000, "offset": 0}, "limit"), + ({"limit": 5, "offset": "xyz"}, "offset"), + ({"limit": 5, "offset": -5}, "offset"), + ], + ids=[ + "bad_limit_type", + "negative_limit", + "limit_too_large", + "bad_offset_type", + "negative_offset", + ], +) +def test_pagination_invalid_type(kwargs: dict[str, Any], expected_field: str) -> None: + """Invalid types or out-of-range values for limit or offset raise a ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + Pagination(**kwargs) + errors = exc_info.value.errors() + assert any(error["loc"] == (expected_field,) for error 
in errors) diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index be08927..7460fb5 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -11,7 +11,7 @@ from core.errors import NoResultsError from database.users import User -from routers.dependencies import Pagination +from routers.dependencies import LIMIT_DEFAULT, Pagination from routers.openml.datasets import DatasetStatusFilter, list_datasets from tests import constants from tests.users import ADMIN_USER, DATASET_130_OWNER, SOME_USER, ApiKey @@ -57,12 +57,6 @@ async def test_list_data_identical( api_key = kwargs.pop("api_key") api_key_query = f"?api_key={api_key}" if api_key else "" - # Pagination parameters are nested in the new query style - # The old style has no `limit` by default, so we mimic this with a high default - new_style = kwargs | {"pagination": {"limit": limit or 1_000_000}} - if offset is not None: - new_style["pagination"]["offset"] = offset - # old style `/data/filter` encodes all filters as a path query = [ [filter_, value if not isinstance(value, list) else ",".join(str(v) for v in value)] @@ -74,13 +68,21 @@ async def test_list_data_identical( uri += f"/{'/'.join([str(v) for q in query for v in q])}" uri += api_key_query + # new style just takes the values directly in a JSON body, + # except that the limit and offset parameters are under a pagination field. 
+ if limit is not None: + kwargs.setdefault("pagination", {})["limit"] = limit + if offset is not None: + kwargs.setdefault("pagination", {})["offset"] = offset + py_response, php_response = await asyncio.gather( - py_api.post(f"/datasets/list{api_key_query}", json=new_style), + py_api.post(f"/datasets/list{api_key_query}", json=kwargs), php_api.get(uri), ) # Note: RFC 9457 changed some status codes (PRECONDITION_FAILED -> NOT_FOUND for no results) # and the error response format, so we can't compare error responses directly. + # Validation errors shouldn't occur since the search space doesn't include invalid values php_is_error = php_response.status_code == HTTPStatus.PRECONDITION_FAILED py_is_error = py_response.status_code == HTTPStatus.NOT_FOUND @@ -105,6 +107,9 @@ async def test_list_data_identical( # PHP API has a double nested dictionary that never has other entries php_json = php_response.json()["data"]["dataset"] + # The default limit changed from unbounded to LIMIT_DEFAULT. + if limit is None: + php_json = php_json[:LIMIT_DEFAULT] assert len(py_json) == len(php_json) assert py_json == php_json return None diff --git a/tests/routers/openml/migration/tasks_migration_test.py b/tests/routers/openml/migration/tasks_migration_test.py index a11f1a5..ea3226b 100644 --- a/tests/routers/openml/migration/tasks_migration_test.py +++ b/tests/routers/openml/migration/tasks_migration_test.py @@ -11,6 +11,7 @@ nested_remove_single_element_list, nested_remove_values, ) +from routers.dependencies import LIMIT_MAX @pytest.mark.parametrize( @@ -141,8 +142,7 @@ async def test_list_tasks_equal( - PHP error status is 412 PRECONDITION_FAILED; Python uses 404 NOT_FOUND. 
""" php_path = _build_php_task_list_path(php_params) - # Use a very large limit on Python side to match PHP's unbounded default result count - py_body = {**py_extra, "pagination": {"limit": 1_000_000, "offset": 0}} + py_body = {**py_extra, "pagination": {"limit": LIMIT_MAX, "offset": 0}} py_response, php_response = await asyncio.gather( py_api.post("/tasks/list", json=py_body), php_api.get(php_path), @@ -163,6 +163,7 @@ async def test_list_tasks_equal( php_tasks: list[dict[str, Any]] = ( php_tasks_raw if isinstance(php_tasks_raw, list) else [php_tasks_raw] ) + php_tasks = php_tasks[:LIMIT_MAX] py_tasks: list[dict[str, Any]] = [_normalize_py_task(t) for t in py_response.json()] php_ids = {int(t["task_id"]) for t in php_tasks} diff --git a/tests/routers/openml/task_list_test.py b/tests/routers/openml/task_list_test.py index 67d8539..78eb5ec 100644 --- a/tests/routers/openml/task_list_test.py +++ b/tests/routers/openml/task_list_test.py @@ -74,72 +74,6 @@ async def test_list_tasks_api_happy_path(py_api: httpx.AsyncClient) -> None: assert "OpenML100" in task["tag"] -@pytest.mark.parametrize( - ("limit", "offset", "expected_status", "expected_max_results"), - [ - (-10, 0, HTTPStatus.NOT_FOUND, 0), # negative limit clamped to 0 -> No results - (5, -10, HTTPStatus.OK, 5), # negative offset clamped to 0 -> First 5 results - ], - ids=["negative_limit", "negative_offset"], -) -async def test_list_tasks_negative_pagination_safely_clamped( - limit: int, - offset: int, - expected_status: int, - expected_max_results: int, - py_api: httpx.AsyncClient, -) -> None: - """Negative pagination values are safely clamped to 0 instead of causing 500 errors. - - A limit clamped to 0 raises NoResultsError, which the API maps to HTTP 404. - An offset clamped to 0 simply returns the first page of results (200 OK). - - Note: This remains an HTTP-level (py_api) test to ensure end-to-end safety is - preserved. 
- """ - response = await py_api.post( - "/tasks/list", - json={"pagination": {"limit": limit, "offset": offset}}, - ) - assert response.status_code == expected_status - if expected_status == HTTPStatus.OK: - body = response.json() - assert len(body) <= expected_max_results - # Compare to a baseline with offset=0 to prove it was correctly clamped - baseline = await py_api.post( - "/tasks/list", - json={"pagination": {"limit": limit, "offset": 0}}, - ) - assert baseline.status_code == HTTPStatus.OK - assert [t["task_id"] for t in body] == [t["task_id"] for t in baseline.json()] - else: - error = response.json() - assert error["type"] == NoResultsError.uri - - -@pytest.mark.parametrize( - ("pagination_override", "expected_field"), - [ - ({"limit": "abc", "offset": 0}, "limit"), # Invalid type - ({"limit": 5, "offset": "xyz"}, "offset"), # Invalid type - ], - ids=["bad_limit_type", "bad_offset_type"], -) -async def test_list_tasks_invalid_pagination_type( - pagination_override: dict[str, Any], expected_field: str, py_api: httpx.AsyncClient -) -> None: - """Invalid pagination types return 422 Unprocessable Entity.""" - response = await py_api.post( - "/tasks/list", - json={"pagination": pagination_override}, - ) - assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY - # Verify that the error points to the correct field - error = response.json()["errors"][0] - assert error["loc"][-2:] == ["pagination", expected_field] - assert error["type"] in {"type_error.integer", "int_parsing", "int_type"} - - @pytest.mark.parametrize( "value", ["1...2", "abc"],