4 changes: 4 additions & 0 deletions changelog/7475-dataset-validation-error-handling.yaml
@@ -0,0 +1,4 @@
type: Fixed
description: Replaced 500 errors with 422 responses for dataset validation failures and added skip_validation query param for troubleshooting
pr: 7475
labels: []
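
To illustrate the new behavior from a client's perspective, here is a minimal sketch. The hostname, token, and use of the requests library are illustrative assumptions, not part of this PR; only the endpoint paths, the skip_validation parameter, and the 422 response shape come from the changes below.

# Hypothetical client sketch: host, token, and dataset contents are
# illustrative; only the endpoints and response shapes come from this PR.
import requests

BASE = "http://localhost:8080/api/v1"
HEADERS = {"Authorization": "Bearer <access-token>"}  # needs the dataset read scope

resp = requests.get(f"{BASE}/dataset", headers=HEADERS)
if resp.status_code == 422:
    # Previously this surfaced as an opaque 500; now the body names the
    # offending fields.
    for err in resp.json()["errors"]:
        print(err["loc"], err["msg"], err["type"])
    # Retrieve the raw stored records anyway to inspect and repair them.
    raw = requests.get(
        f"{BASE}/dataset",
        headers=HEADERS,
        params={"skip_validation": "true"},
    )
    print(raw.json())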
84 changes: 72 additions & 12 deletions src/fides/api/api/v1/endpoints/generic_overrides.py
@@ -8,6 +8,7 @@
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate as async_paginate
from fideslang.models import Dataset as FideslangDataset
from loguru import logger
from pydantic import ValidationError as PydanticValidationError
from sqlalchemy import not_, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -71,6 +72,32 @@
data_subject_router = APIRouter(tags=["DataSubject"], prefix=V1_URL_PREFIX)


def _dataset_validation_error_response(
exc: PydanticValidationError,
) -> JSONResponse:
"""Return a structured 422 response for dataset validation failures.

Scoped to dataset endpoints to avoid masking unrelated pydantic errors
as 422s across the application.
"""
errors = exc.errors()
field_errors = [
{"loc": e.get("loc"), "msg": e.get("msg"), "type": e.get("type")}
for e in errors
]
logger.error(
"Dataset validation error: "
f"{len(field_errors)} validation error(s): {field_errors}"
)
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={
"detail": "The requested dataset contains data that fails validation.",
"errors": field_errors,
},
)


@dataset_router.post(
"/dataset",
dependencies=[Security(verify_oauth_client, scopes=[CTL_DATASET_CREATE])],
@@ -147,7 +174,12 @@ async def list_dataset_paginated(
only_unlinked_datasets: Optional[bool] = Query(False),
connection_type: Optional[ConnectionType] = Query(None),
minimal: Optional[bool] = Query(False),
) -> Union[Page[DatasetResponse], List[DatasetResponse]]:
skip_validation: bool = Query(
False,
description="[Troubleshooting only] Skip pydantic response validation. "
"Use this to retrieve datasets that contain data failing validation.",
),
) -> Union[Page[DatasetResponse], List[DatasetResponse], Response]:
"""
Get a list of all of the Datasets.
If any pagination parameters (size or page) are provided, then the response will be paginated.
@@ -212,21 +244,41 @@ async def list_dataset_paginated(

if not page and not size:
results = await list_resource_query(db, filtered_query, CtlDataset)
response = [
DatasetResponse.model_validate(result.__dict__) for result in results
]
if skip_validation:
return JSONResponse(content=jsonable_encoder(results))
try:
response = [
DatasetResponse.model_validate(result.__dict__) for result in results
]
except PydanticValidationError as exc:
return _dataset_validation_error_response(exc)
return response

pagination_params = Params(page=page or 1, size=size or 50)
results = await async_paginate(db, filtered_query, pagination_params)

validated_items = []
for result in results.items: # type: ignore[attr-defined]
# run pydantic validation in a threadpool to avoid blocking the main thread
validated_item = await run_in_threadpool(
partial(DatasetResponse.model_validate, result.__dict__)
if skip_validation:
return JSONResponse(
content={
"items": jsonable_encoder(
list(results.items) # type: ignore[attr-defined]
),
"total": results.total, # type: ignore[attr-defined]
"page": results.page, # type: ignore[attr-defined]
"size": results.size, # type: ignore[attr-defined]
"pages": results.pages, # type: ignore[attr-defined]
}
)
validated_items.append(validated_item)

try:
validated_items = []
for result in results.items: # type: ignore[attr-defined]
validated_item = await run_in_threadpool(
partial(DatasetResponse.model_validate, result.__dict__)
)
validated_items.append(validated_item)
except PydanticValidationError as exc:
return _dataset_validation_error_response(exc)

results.items = validated_items # type: ignore[attr-defined]

@@ -242,15 +294,23 @@ async def list_dataset_paginated(
async def get_dataset(
fides_key: str,
dataset_service: DatasetService = Depends(get_dataset_service),
) -> Dict:
skip_validation: bool = Query(
False,
description="[Troubleshooting only] Skip pydantic response validation. "
"Use this to retrieve a dataset that contains data failing validation.",
),
) -> Union[Dict, Response]:
"""Get a single dataset by fides key"""
try:
return dataset_service.get_dataset(fides_key)
result = dataset_service.get_dataset(fides_key)
except DatasetNotFoundException as e:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail=str(e),
)
if skip_validation:
return JSONResponse(content=jsonable_encoder(result))
return result


@dataset_router.delete(
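For reference, a 422 produced by _dataset_validation_error_response would look roughly like the sketch below. The detail string and the error keys come from the handler above; the specific loc/msg/type values are illustrative, since they depend on fideslang's validators for the dataset in question.

# Illustrative payload only; exact loc/msg/type values depend on
# fideslang's validators for the dataset being retrieved.
EXAMPLE_422_BODY = {
    "detail": "The requested dataset contains data that fails validation.",
    "errors": [
        {
            "loc": ["collections", 0, "fields", 1],
            "msg": "Fields with subfields must have data_type 'object'",  # illustrative
            "type": "value_error",
        }
    ],
}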
39 changes: 38 additions & 1 deletion src/fides/api/api/v1/exception_handlers.py
@@ -1,12 +1,49 @@
from typing import Callable, List

from fastapi import Request
from fastapi.exceptions import ResponseValidationError
from fastapi.responses import JSONResponse
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from loguru import logger
from starlette.status import (
HTTP_422_UNPROCESSABLE_ENTITY,
HTTP_500_INTERNAL_SERVER_ERROR,
)

from fides.api.common_exceptions import RedisNotConfigured


async def response_validation_error_handler(
request: Request, exc: ResponseValidationError
) -> JSONResponse:
"""Handle ResponseValidationError raised during FastAPI response serialization.

This occurs when data read from the database no longer passes the response
model's validators — for example, a dataset field with data_type='string'
that has sub-fields (should be 'object').
"""
errors = exc.errors()
field_errors = [
{
"loc": error.get("loc"),
"msg": error.get("msg"),
"type": error.get("type"),
}
for error in errors
]
logger.error(
f"Response validation error for {request.method} {request.url.path}: "
f"{len(field_errors)} validation error(s): {field_errors}"
)
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={
"detail": "The requested resource contains data that fails validation "
"when serializing the response.",
"errors": field_errors,
},
)


class ExceptionHandlers:
@staticmethod
def redis_not_configured_handler(
11 changes: 10 additions & 1 deletion src/fides/api/app_setup.py
@@ -7,6 +7,7 @@
from typing import AsyncGenerator, List

from fastapi import FastAPI
from fastapi.exceptions import ResponseValidationError
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.routing import APIRoute
from loguru import logger
@@ -26,7 +27,10 @@
from fides.api.api.v1.endpoints.admin import ADMIN_ROUTER
from fides.api.api.v1.endpoints.generic_overrides import GENERIC_OVERRIDES_ROUTER
from fides.api.api.v1.endpoints.health import HEALTH_ROUTER
from fides.api.api.v1.exception_handlers import ExceptionHandlers
from fides.api.api.v1.exception_handlers import (
ExceptionHandlers,
response_validation_error_handler,
)
from fides.api.asgi_middleware import (
AnalyticsLoggingMiddleware,
AuditLogMiddleware,
@@ -104,6 +108,11 @@ def create_fides_app(
# Starlette bug causing this to fail mypy
fastapi_app.add_exception_handler(RedisNotConfigured, handler) # type: ignore

fastapi_app.add_exception_handler(
ResponseValidationError,
response_validation_error_handler, # type: ignore[arg-type]
)

if is_rate_limit_enabled:
# Validate header before SlowAPI processes the request
fastapi_app.add_middleware(RateLimitIPValidationMiddleware)
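As a standalone illustration of the mechanism registered here: in a minimal FastAPI app (the route and model below are hypothetical, not Fides code), a response that fails its response_model raises ResponseValidationError during serialization, and a registered handler can convert it to a structured 422 in the same shape as exception_handlers.py.

# Minimal sketch, not Fides code: demonstrates ResponseValidationError
# being converted to a 422 the way exception_handlers.py does.
from fastapi import FastAPI, Request
from fastapi.exceptions import ResponseValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    name: str
    price: float


@app.exception_handler(ResponseValidationError)
async def handle_response_validation(
    request: Request, exc: ResponseValidationError
) -> JSONResponse:
    # Same shape as the Fides handler: a detail message plus per-field errors.
    return JSONResponse(
        status_code=422,
        content={
            "detail": "The requested resource contains data that fails validation "
            "when serializing the response.",
            "errors": [
                {"loc": e.get("loc"), "msg": e.get("msg"), "type": e.get("type")}
                for e in exc.errors()
            ],
        },
    )


@app.get("/items/{item_id}", response_model=Item)
async def read_item(item_id: int) -> dict:
    # "price" cannot validate as a float, so FastAPI raises
    # ResponseValidationError while serializing the response.
    return {"name": "widget", "price": "not-a-number"}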
161 changes: 161 additions & 0 deletions tests/ops/api/v1/endpoints/test_generic_overrides.py
@@ -7,6 +7,28 @@
from fides.api.models.sql_models import Dataset as CtlDataset
from fides.common.api.scope_registry import CTL_DATASET_READ

INVALID_FIELD_COLLECTIONS = [
{
"name": "test_collection",
"fields": [
{
"name": "valid_field",
"fides_meta": {"data_type": "string"},
},
{
"name": "bad_field",
"fides_meta": {"data_type": "string"},
"fields": [
{
"name": "child",
"fides_meta": {"data_type": "string"},
}
],
},
],
}
]


@pytest.fixture
def minimal_dataset(db: Session) -> Generator[CtlDataset, None, None]:
@@ -233,3 +255,142 @@ def test_list_dataset_paginated_connection_type(
assert (
response.status_code == 422
) # Unprocessable Entity for invalid enum value


@pytest.fixture
def dataset_with_invalid_field(db: Session) -> Generator[CtlDataset, None, None]:
"""A dataset whose field has data_type='string' but also has subfields.

This is invalid per fideslang validation and will trigger a ValidationError
when deserialized through pydantic, but can exist in the DB because the
collections column is unvalidated JSON.
"""
dataset = CtlDataset(
fides_key="invalid_field_dataset",
name="Invalid Field Dataset",
organization_fides_key="default_organization",
collections=INVALID_FIELD_COLLECTIONS,
)
db.add(dataset)
db.commit()
yield dataset

db.delete(dataset)
db.commit()
Comment on lines +276 to +279 (review comment from a Contributor):

Manual record deletion in fixture teardown

The dataset_with_invalid_field fixture manually deletes records in its teardown. Per repository convention, the database is automatically cleared between test runs, making this cleanup unnecessary and potentially error-prone if the test fails before reaching the yield.

Suggested change (drop the manual teardown so the fixture ends at the yield):

    yield dataset
-   db.delete(dataset)
-   db.commit()

Context Used: Rule from dashboard - Do not manually delete database records in test fixtures or at the end of tests, as the database is ... (source)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


class TestDatasetSkipValidation:
"""Tests for the skip_validation query parameter on dataset endpoints."""

def test_list_datasets_skip_validation_returns_invalid_data(
self,
api_client: TestClient,
dataset_with_invalid_field: CtlDataset,
generate_auth_header,
):
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get(
"/api/v1/dataset?skip_validation=true", headers=auth_header
)
assert response.status_code == 200
data = response.json()
keys = [d["fides_key"] for d in data]
assert "invalid_field_dataset" in keys

def test_list_datasets_paginated_skip_validation(
self,
api_client: TestClient,
dataset_with_invalid_field: CtlDataset,
generate_auth_header,
):
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get(
"/api/v1/dataset?page=1&size=50&skip_validation=true",
headers=auth_header,
)
assert response.status_code == 200
data = response.json()
keys = [item["fides_key"] for item in data["items"]]
assert "invalid_field_dataset" in keys

def test_get_dataset_skip_validation_returns_invalid_data(
self,
api_client: TestClient,
dataset_with_invalid_field: CtlDataset,
generate_auth_header,
):
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get(
"/api/v1/dataset/invalid_field_dataset?skip_validation=true",
headers=auth_header,
)
assert response.status_code == 200
data = response.json()
assert data["fides_key"] == "invalid_field_dataset"
bad_field = data["collections"][0]["fields"][1]
assert bad_field["name"] == "bad_field"
assert bad_field["fides_meta"]["data_type"] == "string"
assert len(bad_field["fields"]) == 1

def test_skip_validation_defaults_to_false(
self,
api_client: TestClient,
minimal_dataset: CtlDataset,
generate_auth_header,
):
"""Normal requests without skip_validation still validate."""
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get("/api/v1/dataset", headers=auth_header)
assert response.status_code == 200

response = api_client.get(
f"/api/v1/dataset/{minimal_dataset.fides_key}", headers=auth_header
)
assert response.status_code == 200


class TestDatasetValidationErrorHandlers:
"""Tests that validation errors on dataset retrieval return 422, not 500."""

def test_list_datasets_returns_422_for_invalid_data(
self,
api_client: TestClient,
dataset_with_invalid_field: CtlDataset,
generate_auth_header,
):
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get("/api/v1/dataset", headers=auth_header)
assert response.status_code == 422
body = response.json()
assert "detail" in body
assert "errors" in body
assert any("bad_field" in err.get("msg", "") for err in body["errors"])

def test_list_datasets_paginated_returns_422_for_invalid_data(
self,
api_client: TestClient,
dataset_with_invalid_field: CtlDataset,
generate_auth_header,
):
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get("/api/v1/dataset?page=1&size=50", headers=auth_header)
assert response.status_code == 422
body = response.json()
assert "detail" in body
assert "errors" in body

def test_get_dataset_returns_422_for_invalid_data(
self,
api_client: TestClient,
dataset_with_invalid_field: CtlDataset,
generate_auth_header,
):
auth_header = generate_auth_header([CTL_DATASET_READ])
response = api_client.get(
"/api/v1/dataset/invalid_field_dataset", headers=auth_header
)
assert response.status_code == 422
body = response.json()
assert "detail" in body
assert "errors" in body
assert any("bad_field" in err.get("msg", "") for err in body["errors"])