Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/stac_auth_proxy/utils/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class JsonResponseMiddleware(ABC):

app: ASGIApp

# Expected data type for JSON responses. Only responses matching this type will be transformed.
# If None, all JSON responses will be transformed regardless of type.
expected_data_type: Optional[type] = dict

@abstractmethod
def should_transform_response(
self, request: Request, scope: Scope
Expand Down Expand Up @@ -97,8 +101,21 @@ async def transform_response(message: Message) -> None:
)
await response(scope, receive, send)
return
transformed = self.transform_json(data, request=request)
body = json.dumps(transformed).encode()

if self.expected_data_type is None or isinstance(
data, self.expected_data_type
):
transformed = self.transform_json(data, request=request)
body = json.dumps(transformed).encode()
else:
logger.warning(
"Received JSON response with unexpected data type %r from upstream server (%r %r), "
"skipping transformation (expected: %r)",
type(data).__name__,
request.method,
request.url,
self.expected_data_type.__name__,
)

# Update content-length header
headers["content-length"] = str(len(body))
Expand Down
189 changes: 187 additions & 2 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tests for middleware utilities."""

from typing import Any
from unittest.mock import patch

import pytest
from fastapi import FastAPI, Response
from starlette.datastructures import Headers
from starlette.requests import Request
Expand All @@ -17,18 +19,73 @@ class ExampleJsonResponseMiddleware(JsonResponseMiddleware):
def __init__(self, app: ASGIApp):
"""Initialize the middleware."""
self.app = app
# Use default expected_data_type (dict)

def should_transform_response(self, request: Request, scope: Scope) -> bool:
"""Transform JSON responses based on content type."""
return Headers(scope=scope).get("content-type", "") == "application/json"

def transform_json(self, data: Any, request: Request) -> Any:
"""Add a test field to the response."""
if isinstance(data, dict):
data["transformed"] = True
data["transformed"] = True
return data


class ExampleStringJsonResponseMiddleware(JsonResponseMiddleware):
"""Example implementation that expects string JSON responses."""

def __init__(self, app: ASGIApp):
"""Initialize the middleware."""
self.app = app
self.expected_data_type = str

def should_transform_response(self, request: Request, scope: Scope) -> bool:
"""Transform JSON responses based on content type."""
return Headers(scope=scope).get("content-type", "") == "application/json"

def transform_json(self, data: Any, request: Request) -> Any:
"""Transform string responses by adding a prefix."""
if isinstance(data, str):
return f"transformed: {data}"
return data


class ExampleListJsonResponseMiddleware(JsonResponseMiddleware):
"""Example implementation that expects list JSON responses."""

def __init__(self, app: ASGIApp):
"""Initialize the middleware."""
self.app = app
self.expected_data_type = list

def should_transform_response(self, request: Request, scope: Scope) -> bool:
"""Transform JSON responses based on content type."""
return Headers(scope=scope).get("content-type", "") == "application/json"

def transform_json(self, data: Any, request: Request) -> Any:
"""Transform list responses by adding a new item."""
if isinstance(data, list):
return data + ["transformed"]
return data


class ExampleAnyJsonResponseMiddleware(JsonResponseMiddleware):
"""Example implementation that transforms any JSON response type."""

def __init__(self, app: ASGIApp):
"""Initialize the middleware."""
self.app = app
self.expected_data_type = None # Transform any JSON type

def should_transform_response(self, request: Request, scope: Scope) -> bool:
"""Transform JSON responses based on content type."""
return Headers(scope=scope).get("content-type", "") == "application/json"

def transform_json(self, data: Any, request: Request) -> Any:
"""Transform any JSON response by wrapping it."""
return {"transformed": True, "data": data}


def test_json_response_middleware():
"""Test that JSON responses are properly transformed."""
app = FastAPI()
Expand Down Expand Up @@ -119,3 +176,131 @@ async def test_endpoint():
assert response.headers["content-type"] == "application/json"
data = response.json()
assert data == {"error": "Received invalid JSON from upstream server"}


@pytest.mark.parametrize(
"content,expected_data",
[
('"hello world"', "hello world"),
('[1, 2, 3, "test"]', [1, 2, 3, "test"]),
("42", 42),
("true", True),
("null", None),
],
)
def test_json_response_middleware_non_dict_json(content, expected_data):
"""Test that non-dict JSON responses are not transformed by default middleware."""
app = FastAPI()
app.add_middleware(ExampleJsonResponseMiddleware)

@app.get("/test")
async def test_endpoint():
return Response(content=content, media_type="application/json")

client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
data = response.json()
assert data == expected_data # Should remain unchanged


@pytest.mark.parametrize(
"middleware_class, test_data, expected_result, should_transform",
[
# String middleware tests
(
ExampleStringJsonResponseMiddleware,
"this is a string",
"transformed: this is a string",
True,
),
(
ExampleStringJsonResponseMiddleware,
{"message": "not a string"},
{"message": "not a string"},
False,
),
# List middleware tests
(
ExampleListJsonResponseMiddleware,
[1, 2, 3],
[1, 2, 3, "transformed"],
True,
),
(
ExampleListJsonResponseMiddleware,
"not a list",
"not a list",
False,
),
# Dict middleware tests (default)
(
ExampleJsonResponseMiddleware,
{"message": "test"},
{"message": "test", "transformed": True},
True,
),
(
ExampleJsonResponseMiddleware,
"not a dict",
"not a dict",
False,
),
],
)
def test_json_response_middleware_type_specific(
middleware_class, test_data, expected_result, should_transform
):
"""Test that middleware transforms only expected data types."""
with patch.object(
middleware_class, "transform_json", return_value=expected_result
) as mock_method:
app = FastAPI()
app.add_middleware(middleware_class)

@app.get("/test")
async def test_endpoint():
return test_data

client = TestClient(app)
response = client.get("/test")

data = response.json()
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
assert mock_method.call_count == (1 if should_transform else 0)
if should_transform:
assert mock_method.call_args[0][0] == test_data
assert data == expected_result


@pytest.mark.parametrize(
"test_data",
[
{"message": "test"},
"hello world",
[1, 2, 3],
42,
True,
None,
],
)
def test_json_response_middleware_expected_none_type(test_data):
"""Test that middleware with expected_data_type=None transforms all JSON response types."""
app = FastAPI()
app.add_middleware(ExampleAnyJsonResponseMiddleware)

@app.get("/test")
async def test_endpoint():
return test_data

client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
data = response.json()

# Verify the simplified transformation behavior
assert data["transformed"] is True
assert data["data"] == test_data
Loading