Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi import Depends
from loguru import logger
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import AuthenticationFailedError, AuthenticationRequiredError
Expand Down Expand Up @@ -57,6 +57,10 @@ def fetch_user_or_raise(
return user


LIMIT_DEFAULT = 100
LIMIT_MAX = 1000


class Pagination(BaseModel):
offset: int = 0
limit: int = 100
offset: int = Field(default=0, ge=0)
limit: int = Field(default=LIMIT_DEFAULT, gt=0, le=LIMIT_MAX)
38 changes: 38 additions & 0 deletions tests/dependencies/pagination_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any

import pytest
from pydantic import ValidationError

from routers.dependencies import Pagination


def test_pagination_defaults() -> None:
"""Pagination has expected defaults when no values are provided."""
pagination = Pagination()
assert pagination.offset == 0
assert pagination.limit == 100 # noqa: PLR2004


@pytest.mark.parametrize(
("kwargs", "expected_field"),
[
({"limit": "abc", "offset": 0}, "limit"),
({"limit": -5, "offset": 0}, "limit"),
({"limit": 2000, "offset": 0}, "limit"),
({"limit": 5, "offset": "xyz"}, "offset"),
({"limit": 5, "offset": -5}, "offset"),
],
ids=[
"bad_limit_type",
"negative_limit",
"limit_too_large",
"bad_offset_type",
"negative_offset",
],
)
def test_pagination_invalid_type(kwargs: dict[str, Any], expected_field: str) -> None:
"""Non-integer values for limit or offset raise a ValidationError."""
with pytest.raises(ValidationError) as exc_info:
Pagination(**kwargs)
errors = exc_info.value.errors()
assert any(error["loc"] == (expected_field,) for error in errors)
21 changes: 13 additions & 8 deletions tests/routers/openml/datasets_list_datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from core.errors import NoResultsError
from database.users import User
from routers.dependencies import Pagination
from routers.dependencies import LIMIT_DEFAULT, Pagination
from routers.openml.datasets import DatasetStatusFilter, list_datasets
from tests import constants
from tests.users import ADMIN_USER, DATASET_130_OWNER, SOME_USER, ApiKey
Expand Down Expand Up @@ -57,12 +57,6 @@ async def test_list_data_identical(
api_key = kwargs.pop("api_key")
api_key_query = f"?api_key={api_key}" if api_key else ""

# Pagination parameters are nested in the new query style
# The old style has no `limit` by default, so we mimic this with a high default
new_style = kwargs | {"pagination": {"limit": limit or 1_000_000}}
if offset is not None:
new_style["pagination"]["offset"] = offset

# old style `/data/filter` encodes all filters as a path
query = [
[filter_, value if not isinstance(value, list) else ",".join(str(v) for v in value)]
Expand All @@ -74,13 +68,21 @@ async def test_list_data_identical(
uri += f"/{'/'.join([str(v) for q in query for v in q])}"
uri += api_key_query

# new style just takes the values directly in a JSON body,
# except that the limit and offset parameters are under a pagination field.
if limit is not None:
kwargs.setdefault("pagination", {})["limit"] = limit
if offset is not None:
kwargs.setdefault("pagination", {})["offset"] = offset

py_response, php_response = await asyncio.gather(
py_api.post(f"/datasets/list{api_key_query}", json=new_style),
py_api.post(f"/datasets/list{api_key_query}", json=kwargs),
php_api.get(uri),
)

# Note: RFC 9457 changed some status codes (PRECONDITION_FAILED -> NOT_FOUND for no results)
# and the error response format, so we can't compare error responses directly.
# Validation errors shouldn't occur since the search space doesn't include invalid values
php_is_error = php_response.status_code == HTTPStatus.PRECONDITION_FAILED
py_is_error = py_response.status_code == HTTPStatus.NOT_FOUND

Expand All @@ -105,6 +107,9 @@ async def test_list_data_identical(

# PHP API has a double nested dictionary that never has other entries
php_json = php_response.json()["data"]["dataset"]
# The default limit changed from unbound to 100.
if limit is None:
php_json = php_json[:LIMIT_DEFAULT]
assert len(py_json) == len(php_json)
assert py_json == php_json
return None
Expand Down
5 changes: 3 additions & 2 deletions tests/routers/openml/migration/tasks_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
nested_remove_single_element_list,
nested_remove_values,
)
from routers.dependencies import LIMIT_MAX


@pytest.mark.parametrize(
Expand Down Expand Up @@ -141,8 +142,7 @@ async def test_list_tasks_equal(
- PHP error status is 412 PRECONDITION_FAILED; Python uses 404 NOT_FOUND.
"""
php_path = _build_php_task_list_path(php_params)
# Use a very large limit on Python side to match PHP's unbounded default result count
py_body = {**py_extra, "pagination": {"limit": 1_000_000, "offset": 0}}
py_body = {**py_extra, "pagination": {"limit": LIMIT_MAX, "offset": 0}}
py_response, php_response = await asyncio.gather(
py_api.post("/tasks/list", json=py_body),
php_api.get(php_path),
Expand All @@ -163,6 +163,7 @@ async def test_list_tasks_equal(
php_tasks: list[dict[str, Any]] = (
php_tasks_raw if isinstance(php_tasks_raw, list) else [php_tasks_raw]
)
php_tasks = php_tasks[:LIMIT_MAX]
py_tasks: list[dict[str, Any]] = [_normalize_py_task(t) for t in py_response.json()]

php_ids = {int(t["task_id"]) for t in php_tasks}
Expand Down
66 changes: 0 additions & 66 deletions tests/routers/openml/task_list_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,72 +74,6 @@ async def test_list_tasks_api_happy_path(py_api: httpx.AsyncClient) -> None:
assert "OpenML100" in task["tag"]


@pytest.mark.parametrize(
("limit", "offset", "expected_status", "expected_max_results"),
[
(-10, 0, HTTPStatus.NOT_FOUND, 0), # negative limit clamped to 0 -> No results
(5, -10, HTTPStatus.OK, 5), # negative offset clamped to 0 -> First 5 results
],
ids=["negative_limit", "negative_offset"],
)
async def test_list_tasks_negative_pagination_safely_clamped(
limit: int,
offset: int,
expected_status: int,
expected_max_results: int,
py_api: httpx.AsyncClient,
) -> None:
"""Negative pagination values are safely clamped to 0 instead of causing 500 errors.

A limit clamped to 0 raises NoResultsError, which the API maps to HTTP 404.
An offset clamped to 0 simply returns the first page of results (200 OK).

Note: This remains an HTTP-level (py_api) test to ensure end-to-end safety is
preserved.
"""
response = await py_api.post(
"/tasks/list",
json={"pagination": {"limit": limit, "offset": offset}},
)
assert response.status_code == expected_status
if expected_status == HTTPStatus.OK:
body = response.json()
assert len(body) <= expected_max_results
# Compare to a baseline with offset=0 to prove it was correctly clamped
baseline = await py_api.post(
"/tasks/list",
json={"pagination": {"limit": limit, "offset": 0}},
)
assert baseline.status_code == HTTPStatus.OK
assert [t["task_id"] for t in body] == [t["task_id"] for t in baseline.json()]
else:
error = response.json()
assert error["type"] == NoResultsError.uri


@pytest.mark.parametrize(
("pagination_override", "expected_field"),
[
({"limit": "abc", "offset": 0}, "limit"), # Invalid type
({"limit": 5, "offset": "xyz"}, "offset"), # Invalid type
],
ids=["bad_limit_type", "bad_offset_type"],
)
async def test_list_tasks_invalid_pagination_type(
pagination_override: dict[str, Any], expected_field: str, py_api: httpx.AsyncClient
) -> None:
"""Invalid pagination types return 422 Unprocessable Entity."""
response = await py_api.post(
"/tasks/list",
json={"pagination": pagination_override},
)
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
# Verify that the error points to the correct field
error = response.json()["errors"][0]
assert error["loc"][-2:] == ["pagination", expected_field]
assert error["type"] in {"type_error.integer", "int_parsing", "int_type"}


@pytest.mark.parametrize(
"value",
["1...2", "abc"],
Expand Down
Loading