Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,43 @@
logger = logging.getLogger(__name__)


def _extract_hostname(netloc: str) -> str:
"""
Extract hostname from netloc, ignoring port number.

Args:
netloc: Network location string (e.g., "localhost:8080" or "example.com")

Returns:
Hostname without port (e.g., "localhost" or "example.com")

"""
if ":" in netloc:
if netloc.startswith("["):
# IPv6 with port: [::1]:8080
end_bracket = netloc.rfind("]")
if end_bracket != -1:
return netloc[: end_bracket + 1]
# Regular hostname with port: localhost:8080
return netloc.split(":", 1)[0]
return netloc


def _hostnames_match(hostname1: str, hostname2: str) -> bool:
    """
    Compare two hosts for equality, disregarding letter case and any port.

    Args:
        hostname1: First hostname (may include port)
        hostname2: Second hostname (may include port)

    Returns:
        True when both values reduce to the same lower-cased hostname

    """
    first, second = (
        _extract_hostname(candidate).lower() for candidate in (hostname1, hostname2)
    )
    return first == second


@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
"""
Expand Down Expand Up @@ -70,10 +107,14 @@ def _update_link(

parsed_link = urlparse(link["href"])

if parsed_link.netloc not in [
request_url.netloc,
upstream_url.netloc,
]:
link_hostname = _extract_hostname(parsed_link.netloc)
request_hostname = _extract_hostname(request_url.netloc)
upstream_hostname = _extract_hostname(upstream_url.netloc)

if not (
_hostnames_match(link_hostname, request_hostname)
or _hostnames_match(link_hostname, upstream_hostname)
):
logger.debug(
"Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
link["href"],
Expand All @@ -94,10 +135,14 @@ def _update_link(
return

# Replace the upstream host with the client's host
if parsed_link.netloc == upstream_url.netloc:
parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace(
scheme=request_url.scheme
)
link_matches_upstream = _hostnames_match(
parsed_link.netloc, upstream_url.netloc
)
parsed_link = parsed_link._replace(netloc=request_url.netloc)
if link_matches_upstream:
# Link hostname matches upstream: also replace scheme with request URL's scheme
parsed_link = parsed_link._replace(scheme=request_url.scheme)
# If link matches request hostname, only the netloc is replaced and the link's own scheme is preserved (e.g. https://localhost:443/... -> https://localhost/...)

# Remove the upstream prefix from the link path
if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path):
Expand Down
177 changes: 177 additions & 0 deletions tests/test_process_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,180 @@ def test_transform_with_forwarded_headers(headers, expected_base_url):
# but not include the forwarded path in the response URLs
assert transformed["links"][0]["href"] == f"{expected_base_url}/proxy/collections"
assert transformed["links"][1]["href"] == f"{expected_base_url}/proxy"


@pytest.mark.parametrize(
    "upstream_url,root_path,request_host,input_links,expected_links",
    [
        # Basic localhost:PORT rewriting (common port 8080)
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "data", "href": "http://localhost:8080/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Standard HTTP port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:80/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # HTTPS port (the link's own scheme is preserved)
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "https://localhost:443/collections"},
            ],
            [
                "https://localhost/stac/collections",
            ],
        ),
        # Arbitrary port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:3000/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Multiple links with different ports
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:8080/collections"},
                {"rel": "root", "href": "http://localhost:80/"},
                {
                    "rel": "items",
                    "href": "https://localhost:443/collections/test/items",
                },
            ],
            [
                "http://localhost/stac/collections",
                "http://localhost/stac/",
                "https://localhost/stac/collections/test/items",
            ],
        ),
        # localhost:PORT with upstream path
        (
            "http://eoapi-stac:8080/api",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:8080/api/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Request host with port should still work: the link's port is
        # replaced by the request host's netloc, port included
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost:80",
            [
                {"rel": "self", "href": "http://localhost:8080/collections"},
            ],
            [
                "http://localhost:80/stac/collections",
            ],
        ),
    ],
)
def test_transform_localhost_with_port(
    upstream_url, root_path, request_host, input_links, expected_links
):
    """Test transforming links with localhost:PORT (any port number)."""
    middleware = ProcessLinksMiddleware(
        app=None, upstream_url=upstream_url, root_path=root_path
    )
    request_scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", request_host.encode()),
            (b"content-type", b"application/json"),
        ],
    }

    data = {"links": input_links}
    transformed = middleware.transform_json(data, Request(request_scope))

    # Compare the whole list so that dropped or extra links fail the test,
    # not just mismatches at the indices present in expected_links.
    actual_hrefs = [link["href"] for link in transformed["links"]]
    assert actual_hrefs == expected_links


def test_localhost_with_port_preserves_other_hostnames():
    """Links pointing at hosts other than the request host pass through untouched."""
    scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", b"localhost"),
            (b"content-type", b"application/json"),
        ],
    }
    payload = {
        "links": [
            {"rel": "external", "href": "http://example.com:8080/collections"},
            {"rel": "other", "href": "http://other-host:3000/collections"},
        ]
    }
    proxy = ProcessLinksMiddleware(
        app=None,
        upstream_url="http://eoapi-stac:8080",
        root_path="/stac",
    )

    result = proxy.transform_json(payload, Request(scope))

    # Hrefs for foreign hosts must come back exactly as they went in
    external_link = result["links"][0]
    other_link = result["links"][1]
    assert external_link["href"] == "http://example.com:8080/collections"
    assert other_link["href"] == "http://other-host:3000/collections"


def test_localhost_with_port_upstream_service_name_still_works():
    """Links that use the upstream service name are rewritten to the request host."""
    scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", b"localhost"),
            (b"content-type", b"application/json"),
        ],
    }
    payload = {
        "links": [
            {"rel": "self", "href": "http://eoapi-stac:8080/collections"},
        ]
    }
    proxy = ProcessLinksMiddleware(
        app=None,
        upstream_url="http://eoapi-stac:8080",
        root_path="/stac",
    )

    result = proxy.transform_json(payload, Request(scope))

    # The upstream netloc is swapped for the request host and root_path is prefixed
    rewritten_link = result["links"][0]
    assert rewritten_link["href"] == "http://localhost/stac/collections"
Loading