diff --git a/changelog/7475-dataset-validation-error-handling.yaml b/changelog/7475-dataset-validation-error-handling.yaml new file mode 100644 index 0000000000..fc035274e9 --- /dev/null +++ b/changelog/7475-dataset-validation-error-handling.yaml @@ -0,0 +1,4 @@ +type: Fixed +description: Replaced 500 errors with 422 responses for dataset validation failures and added skip_validation query param for troubleshooting +pr: 7475 +labels: [] diff --git a/src/fides/api/api/v1/endpoints/generic_overrides.py b/src/fides/api/api/v1/endpoints/generic_overrides.py index 55249adf7d..3cf79d44a4 100644 --- a/src/fides/api/api/v1/endpoints/generic_overrides.py +++ b/src/fides/api/api/v1/endpoints/generic_overrides.py @@ -8,6 +8,7 @@ from fastapi_pagination import Page, Params from fastapi_pagination.ext.sqlalchemy import paginate as async_paginate from fideslang.models import Dataset as FideslangDataset +from loguru import logger from pydantic import ValidationError as PydanticValidationError from sqlalchemy import not_, select from sqlalchemy.ext.asyncio import AsyncSession @@ -71,6 +72,32 @@ data_subject_router = APIRouter(tags=["DataSubject"], prefix=V1_URL_PREFIX) +def _dataset_validation_error_response( + exc: PydanticValidationError, +) -> JSONResponse: + """Return a structured 422 response for dataset validation failures. + + Scoped to dataset endpoints to avoid masking unrelated pydantic errors + as 422s across the application. + """ + errors = exc.errors() + field_errors = [ + {"loc": e.get("loc"), "msg": e.get("msg"), "type": e.get("type")} + for e in errors + ] + logger.error( + "Dataset validation error: " + f"{len(field_errors)} validation error(s): {field_errors}" + ) + return JSONResponse( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, + content={ + "detail": "The requested dataset contains data that fails validation.", + "errors": field_errors, + }, + ) + + @dataset_router.post( "/dataset", dependencies=[Security(verify_oauth_client, scopes=[CTL_DATASET_CREATE])], @@ -147,7 +174,12 @@ async def list_dataset_paginated( only_unlinked_datasets: Optional[bool] = Query(False), connection_type: Optional[ConnectionType] = Query(None), minimal: Optional[bool] = Query(False), -) -> Union[Page[DatasetResponse], List[DatasetResponse]]: + skip_validation: bool = Query( + False, + description="[Troubleshooting only] Skip pydantic response validation. " + "Use this to retrieve datasets that contain data failing validation.", + ), +) -> Union[Page[DatasetResponse], List[DatasetResponse], Response]: """ Get a list of all of the Datasets. If any pagination parameters (size or page) are provided, then the response will be paginated. @@ -212,21 +244,41 @@ async def list_dataset_paginated( if not page and not size: results = await list_resource_query(db, filtered_query, CtlDataset) - response = [ - DatasetResponse.model_validate(result.__dict__) for result in results - ] + if skip_validation: + return JSONResponse(content=jsonable_encoder(results)) + try: + response = [ + DatasetResponse.model_validate(result.__dict__) for result in results + ] + except PydanticValidationError as exc: + return _dataset_validation_error_response(exc) return response pagination_params = Params(page=page or 1, size=size or 50) results = await async_paginate(db, filtered_query, pagination_params) - validated_items = [] - for result in results.items: # type: ignore[attr-defined] - # run pydantic validation in a threadpool to avoid blocking the main thread - validated_item = await run_in_threadpool( - partial(DatasetResponse.model_validate, result.__dict__) + if skip_validation: + return JSONResponse( + content={ + "items": jsonable_encoder( + list(results.items) # type: ignore[attr-defined] + ), + "total": results.total, # type: ignore[attr-defined] + "page": results.page, # type: ignore[attr-defined] + "size": results.size, # type: ignore[attr-defined] + "pages": results.pages, # type: ignore[attr-defined] + } ) - validated_items.append(validated_item) + + try: + validated_items = [] + for result in results.items: # type: ignore[attr-defined] + validated_item = await run_in_threadpool( + partial(DatasetResponse.model_validate, result.__dict__) + ) + validated_items.append(validated_item) + except PydanticValidationError as exc: + return _dataset_validation_error_response(exc) results.items = validated_items # type: ignore[attr-defined] @@ -242,15 +294,23 @@ async def list_dataset_paginated( async def get_dataset( fides_key: str, dataset_service: DatasetService = Depends(get_dataset_service), -) -> Dict: + skip_validation: bool = Query( + False, + description="[Troubleshooting only] Skip pydantic response validation. " + "Use this to retrieve a dataset that contains data failing validation.", + ), +) -> Union[Dict, Response]: """Get a single dataset by fides key""" try: - return dataset_service.get_dataset(fides_key) + result = dataset_service.get_dataset(fides_key) except DatasetNotFoundException as e: raise HTTPException( status_code=HTTP_404_NOT_FOUND, detail=str(e), ) + if skip_validation: + return JSONResponse(content=jsonable_encoder(result)) + return result @dataset_router.delete( diff --git a/src/fides/api/api/v1/exception_handlers.py b/src/fides/api/api/v1/exception_handlers.py index c2ecd0f1a2..16f019bb03 100644 --- a/src/fides/api/api/v1/exception_handlers.py +++ b/src/fides/api/api/v1/exception_handlers.py @@ -1,12 +1,49 @@ from typing import Callable, List from fastapi import Request +from fastapi.exceptions import ResponseValidationError from fastapi.responses import JSONResponse -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from loguru import logger +from starlette.status import ( + HTTP_422_UNPROCESSABLE_ENTITY, + HTTP_500_INTERNAL_SERVER_ERROR, +) from fides.api.common_exceptions import RedisNotConfigured +async def response_validation_error_handler( + request: Request, exc: ResponseValidationError +) -> JSONResponse: + """Handle ResponseValidationError raised during FastAPI response serialization. + + This occurs when data read from the database no longer passes the response + model's validators — for example, a dataset field with data_type='string' + that has sub-fields (should be 'object'). + """ + errors = exc.errors() + field_errors = [ + { + "loc": error.get("loc"), + "msg": error.get("msg"), + "type": error.get("type"), + } + for error in errors + ] + logger.error( + f"Response validation error for {request.method} {request.url.path}: " + f"{len(field_errors)} validation error(s): {field_errors}" + ) + return JSONResponse( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, + content={ + "detail": "The requested resource contains data that fails validation " + "when serializing the response.", + "errors": field_errors, + }, + ) + + class ExceptionHandlers: @staticmethod def redis_not_configured_handler( diff --git a/src/fides/api/app_setup.py b/src/fides/api/app_setup.py index 40da829b9b..b7b5e976d9 100644 --- a/src/fides/api/app_setup.py +++ b/src/fides/api/app_setup.py @@ -7,6 +7,7 @@ from typing import AsyncGenerator, List from fastapi import FastAPI +from fastapi.exceptions import ResponseValidationError from fastapi.middleware.gzip import GZipMiddleware from fastapi.routing import APIRoute from loguru import logger @@ -26,7 +27,10 @@ from fides.api.api.v1.endpoints.admin import ADMIN_ROUTER from fides.api.api.v1.endpoints.generic_overrides import GENERIC_OVERRIDES_ROUTER from fides.api.api.v1.endpoints.health import HEALTH_ROUTER -from fides.api.api.v1.exception_handlers import ExceptionHandlers +from fides.api.api.v1.exception_handlers import ( + ExceptionHandlers, + response_validation_error_handler, +) from fides.api.asgi_middleware import ( AnalyticsLoggingMiddleware, AuditLogMiddleware, @@ -104,6 +108,11 @@ def create_fides_app( # Starlette bug causing this to fail mypy fastapi_app.add_exception_handler(RedisNotConfigured, handler) # type: ignore + fastapi_app.add_exception_handler( + ResponseValidationError, + response_validation_error_handler, # type: ignore[arg-type] + ) + if is_rate_limit_enabled: # Validate header before SlowAPI processes the request fastapi_app.add_middleware(RateLimitIPValidationMiddleware) diff --git a/tests/ops/api/v1/endpoints/test_generic_overrides.py b/tests/ops/api/v1/endpoints/test_generic_overrides.py index 86017b239b..17a16d9855 100644 --- a/tests/ops/api/v1/endpoints/test_generic_overrides.py +++ b/tests/ops/api/v1/endpoints/test_generic_overrides.py @@ -7,6 +7,28 @@ from fides.api.models.sql_models import Dataset as CtlDataset from fides.common.api.scope_registry import CTL_DATASET_READ +INVALID_FIELD_COLLECTIONS = [ + { + "name": "test_collection", + "fields": [ + { + "name": "valid_field", + "fides_meta": {"data_type": "string"}, + }, + { + "name": "bad_field", + "fides_meta": {"data_type": "string"}, + "fields": [ + { + "name": "child", + "fides_meta": {"data_type": "string"}, + } + ], + }, + ], + } +] + @pytest.fixture def minimal_dataset(db: Session) -> Generator[CtlDataset, None, None]: @@ -233,3 +255,142 @@ def test_list_dataset_paginated_connection_type( assert ( response.status_code == 422 ) # Unprocessable Entity for invalid enum value + + +@pytest.fixture +def dataset_with_invalid_field(db: Session) -> Generator[CtlDataset, None, None]: + """A dataset whose field has data_type='string' but also has subfields. + + This is invalid per fideslang validation and will trigger a ValidationError + when deserialized through pydantic, but can exist in the DB because the + collections column is unvalidated JSON. + """ + dataset = CtlDataset( + fides_key="invalid_field_dataset", + name="Invalid Field Dataset", + organization_fides_key="default_organization", + collections=INVALID_FIELD_COLLECTIONS, + ) + db.add(dataset) + db.commit() + yield dataset + + db.delete(dataset) + db.commit() + + +class TestDatasetSkipValidation: + """Tests for the skip_validation query parameter on dataset endpoints.""" + + def test_list_datasets_skip_validation_returns_invalid_data( + self, + api_client: TestClient, + dataset_with_invalid_field: CtlDataset, + generate_auth_header, + ): + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get( + "/api/v1/dataset?skip_validation=true", headers=auth_header + ) + assert response.status_code == 200 + data = response.json() + keys = [d["fides_key"] for d in data] + assert "invalid_field_dataset" in keys + + def test_list_datasets_paginated_skip_validation( + self, + api_client: TestClient, + dataset_with_invalid_field: CtlDataset, + generate_auth_header, + ): + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get( + "/api/v1/dataset?page=1&size=50&skip_validation=true", + headers=auth_header, + ) + assert response.status_code == 200 + data = response.json() + keys = [item["fides_key"] for item in data["items"]] + assert "invalid_field_dataset" in keys + + def test_get_dataset_skip_validation_returns_invalid_data( + self, + api_client: TestClient, + dataset_with_invalid_field: CtlDataset, + generate_auth_header, + ): + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get( + "/api/v1/dataset/invalid_field_dataset?skip_validation=true", + headers=auth_header, + ) + assert response.status_code == 200 + data = response.json() + assert data["fides_key"] == "invalid_field_dataset" + bad_field = data["collections"][0]["fields"][1] + assert bad_field["name"] == "bad_field" + assert bad_field["fides_meta"]["data_type"] == "string" + assert len(bad_field["fields"]) == 1 + + def test_skip_validation_defaults_to_false( + self, + api_client: TestClient, + minimal_dataset: CtlDataset, + generate_auth_header, + ): + """Normal requests without skip_validation still validate.""" + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get("/api/v1/dataset", headers=auth_header) + assert response.status_code == 200 + + response = api_client.get( + f"/api/v1/dataset/{minimal_dataset.fides_key}", headers=auth_header + ) + assert response.status_code == 200 + + +class TestDatasetValidationErrorHandlers: + """Tests that validation errors on dataset retrieval return 422, not 500.""" + + def test_list_datasets_returns_422_for_invalid_data( + self, + api_client: TestClient, + dataset_with_invalid_field: CtlDataset, + generate_auth_header, + ): + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get("/api/v1/dataset", headers=auth_header) + assert response.status_code == 422 + body = response.json() + assert "detail" in body + assert "errors" in body + assert any("bad_field" in err.get("msg", "") for err in body["errors"]) + + def test_list_datasets_paginated_returns_422_for_invalid_data( + self, + api_client: TestClient, + dataset_with_invalid_field: CtlDataset, + generate_auth_header, + ): + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get("/api/v1/dataset?page=1&size=50", headers=auth_header) + assert response.status_code == 422 + body = response.json() + assert "detail" in body + assert "errors" in body + + def test_get_dataset_returns_422_for_invalid_data( + self, + api_client: TestClient, + dataset_with_invalid_field: CtlDataset, + generate_auth_header, + ): + auth_header = generate_auth_header([CTL_DATASET_READ]) + response = api_client.get( + "/api/v1/dataset/invalid_field_dataset", headers=auth_header + ) + assert response.status_code == 422 + body = response.json() + assert "detail" in body + assert "errors" in body + assert any("bad_field" in err.get("msg", "") for err in body["errors"])