From c447ac662db6a73d7532de10b7cb9bc4f4d119bf Mon Sep 17 00:00:00 2001 From: Felix Delattre Date: Thu, 11 Dec 2025 19:59:57 +0100 Subject: [PATCH] Fixed link rewriting for localhost:PORT. --- .../middleware/ProcessLinksMiddleware.py | 61 +++++- tests/test_process_links.py | 177 ++++++++++++++++++ 2 files changed, 230 insertions(+), 8 deletions(-) diff --git a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py index 79b0f21f..2faf8710 100644 --- a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py +++ b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py @@ -17,6 +17,43 @@ logger = logging.getLogger(__name__) +def _extract_hostname(netloc: str) -> str: + """ + Extract hostname from netloc, ignoring port number. + + Args: + netloc: Network location string (e.g., "localhost:8080" or "example.com") + + Returns: + Hostname without port (e.g., "localhost" or "example.com") + + """ + if ":" in netloc: + if netloc.startswith("["): + # IPv6 with port: [::1]:8080 + end_bracket = netloc.rfind("]") + if end_bracket != -1: + return netloc[: end_bracket + 1] + # Regular hostname with port: localhost:8080 + return netloc.split(":", 1)[0] + return netloc + + +def _hostnames_match(hostname1: str, hostname2: str) -> bool: + """ + Check if two hostnames match, ignoring case and port. 
+ + Args: + hostname1: First hostname (may include port) + hostname2: Second hostname (may include port) + + Returns: + True if hostnames match (case-insensitive, ignoring port) + + """ + return _extract_hostname(hostname1).lower() == _extract_hostname(hostname2).lower() + + @dataclass class ProcessLinksMiddleware(JsonResponseMiddleware): """ @@ -70,10 +107,14 @@ def _update_link( parsed_link = urlparse(link["href"]) - if parsed_link.netloc not in [ - request_url.netloc, - upstream_url.netloc, - ]: + link_hostname = _extract_hostname(parsed_link.netloc) + request_hostname = _extract_hostname(request_url.netloc) + upstream_hostname = _extract_hostname(upstream_url.netloc) + + if not ( + _hostnames_match(link_hostname, request_hostname) + or _hostnames_match(link_hostname, upstream_hostname) + ): logger.debug( "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)", link["href"], @@ -94,10 +135,14 @@ def _update_link( return # Replace the upstream host with the client's host - if parsed_link.netloc == upstream_url.netloc: - parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace( - scheme=request_url.scheme - ) + link_matches_upstream = _hostnames_match( + parsed_link.netloc, upstream_url.netloc + ) + parsed_link = parsed_link._replace(netloc=request_url.netloc) + if link_matches_upstream: + # Link hostname matches upstream: also replace scheme with request URL's scheme + parsed_link = parsed_link._replace(scheme=request_url.scheme) + # If link matches request hostname, scheme is preserved (handles https://localhost:443 -> http://localhost) # Remove the upstream prefix from the link path if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path): diff --git a/tests/test_process_links.py b/tests/test_process_links.py index 02dce87c..748a82d6 100644 --- a/tests/test_process_links.py +++ b/tests/test_process_links.py @@ -597,3 +597,180 @@ def test_transform_with_forwarded_headers(headers, expected_base_url): # 
@pytest.mark.parametrize(
    "upstream_url,root_path,request_host,input_links,expected_links",
    [
        # Basic localhost:PORT rewriting (common port 8080)
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "data", "href": "http://localhost:8080/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Standard HTTP port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:80/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # HTTPS port: the link's scheme is kept, only the port is dropped
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "https://localhost:443/collections"},
            ],
            [
                "https://localhost/stac/collections",
            ],
        ),
        # Arbitrary port
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:3000/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Multiple links with different ports
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:8080/collections"},
                {"rel": "root", "href": "http://localhost:80/"},
                {
                    "rel": "items",
                    "href": "https://localhost:443/collections/test/items",
                },
            ],
            [
                "http://localhost/stac/collections",
                "http://localhost/stac/",
                "https://localhost/stac/collections/test/items",
            ],
        ),
        # localhost:PORT with upstream path
        (
            "http://eoapi-stac:8080/api",
            "/stac",
            "localhost",
            [
                {"rel": "self", "href": "http://localhost:8080/api/collections"},
            ],
            [
                "http://localhost/stac/collections",
            ],
        ),
        # Request host with an explicit port: the link's own port (8080) is
        # replaced by the request's netloc, so the request port (80) is kept.
        (
            "http://eoapi-stac:8080",
            "/stac",
            "localhost:80",
            [
                {"rel": "self", "href": "http://localhost:8080/collections"},
            ],
            [
                "http://localhost:80/stac/collections",
            ],
        ),
    ],
)
def test_transform_localhost_with_port(
    upstream_url, root_path, request_host, input_links, expected_links
):
    """Test transforming links with localhost:PORT (any port number)."""
    middleware = ProcessLinksMiddleware(
        app=None, upstream_url=upstream_url, root_path=root_path
    )
    request_scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", request_host.encode()),
            (b"content-type", b"application/json"),
        ],
    }

    data = {"links": input_links}
    transformed = middleware.transform_json(data, Request(request_scope))

    # Compare the whole list so a missing or extra link fails the test
    # instead of being skipped by a per-index loop.
    assert [link["href"] for link in transformed["links"]] == expected_links


def test_localhost_with_port_preserves_other_hostnames():
    """Test that links with other hostnames are not transformed."""
    middleware = ProcessLinksMiddleware(
        app=None,
        upstream_url="http://eoapi-stac:8080",
        root_path="/stac",
    )
    request_scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", b"localhost"),
            (b"content-type", b"application/json"),
        ],
    }

    data = {
        "links": [
            {"rel": "external", "href": "http://example.com:8080/collections"},
            {"rel": "other", "href": "http://other-host:3000/collections"},
        ]
    }

    transformed = middleware.transform_json(data, Request(request_scope))

    # External hostnames should remain unchanged, even when their port equals
    # the upstream port (8080): matching is on hostname, not on port.
    assert transformed["links"][0]["href"] == "http://example.com:8080/collections"
    assert transformed["links"][1]["href"] == "http://other-host:3000/collections"


def test_localhost_with_port_upstream_service_name_still_works():
    """Test that upstream service name matching still works."""
    middleware = ProcessLinksMiddleware(
        app=None,
        upstream_url="http://eoapi-stac:8080",
        root_path="/stac",
    )
    request_scope = {
        "type": "http",
        "path": "/test",
        "headers": [
            (b"host", b"localhost"),
            (b"content-type", b"application/json"),
        ],
    }

    data = {
        "links": [
            {"rel": "self", "href": "http://eoapi-stac:8080/collections"},
        ]
    }

    transformed = middleware.transform_json(data, Request(request_scope))

    # Upstream service name should be rewritten to request hostname
    assert transformed["links"][0]["href"] == "http://localhost/stac/collections"