diff --git a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
index 79b0f21f..adc9a274 100644
--- a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
+++ b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
@@ -17,6 +17,49 @@
 logger = logging.getLogger(__name__)
 
 
+def _extract_hostname(netloc: str) -> str:
+    """
+    Extract the hostname from a netloc, dropping any userinfo and port.
+
+    Bracketed IPv6 literals keep their brackets, e.g. ``[::1]:8080`` yields
+    ``[::1]``.
+    """
+    # Userinfo ("user:pass@") may itself contain ":", so strip it before
+    # looking for the port separator.
+    host_port = netloc.rpartition("@")[2]
+    if host_port.startswith("["):
+        end_bracket = host_port.rfind("]")
+        if end_bracket != -1:
+            return host_port[: end_bracket + 1]
+    return host_port.split(":", 1)[0]
+
+
+def _netlocs_match(netloc1: str, scheme1: str, netloc2: str, scheme2: str) -> bool:
+    """
+    Check if two netlocs refer to the same host and port.
+
+    Hostnames are compared case-insensitively and userinfo is ignored. Ports
+    must match exactly, but a missing, empty, or unparsable port is assumed to
+    be the scheme's standard port (443 for https, 80 otherwise).
+    """
+    if _extract_hostname(netloc1).lower() != _extract_hostname(netloc2).lower():
+        return False
+
+    def _get_port(netloc: str, scheme: str) -> int:
+        # In a well-formed netloc, the hostname is followed by "" or ":PORT".
+        host_port = netloc.rpartition("@")[2]
+        remainder = host_port[len(_extract_hostname(netloc)) :]
+        if remainder.startswith(":"):
+            try:
+                return int(remainder[1:])
+            except ValueError:
+                # "host:" or a malformed port: use the scheme default.
+                pass
+        return 443 if scheme == "https" else 80
+
+    return _get_port(netloc1, scheme1) == _get_port(netloc2, scheme2)
+
+
 @dataclass
 class ProcessLinksMiddleware(JsonResponseMiddleware):
     """
@@ -70,10 +113,20 @@ def _update_link(
 
         parsed_link = urlparse(link["href"])
 
-        if parsed_link.netloc not in [
-            request_url.netloc,
-            upstream_url.netloc,
-        ]:
+        if not (
+            _netlocs_match(
+                parsed_link.netloc,
+                parsed_link.scheme,
+                request_url.netloc,
+                request_url.scheme,
+            )
+            or _netlocs_match(
+                parsed_link.netloc,
+                parsed_link.scheme,
+                upstream_url.netloc,
+                upstream_url.scheme,
+            )
+        ):
             logger.debug(
                 "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
                 link["href"],
@@ -94,10 +147,18 @@ def _update_link(
             return
 
         # Replace the upstream host with the client's host
-        if parsed_link.netloc == upstream_url.netloc:
-            parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace(
-                scheme=request_url.scheme
-            )
+        link_matches_upstream = _netlocs_match(
+            parsed_link.netloc,
+            parsed_link.scheme,
+            upstream_url.netloc,
+            upstream_url.scheme,
+        )
+        parsed_link = parsed_link._replace(netloc=request_url.netloc)
+        if link_matches_upstream:
+            # Link hostname matches upstream: also replace scheme with request URL's scheme
+            parsed_link = parsed_link._replace(scheme=request_url.scheme)
+        # Otherwise the link matched the request host: only its netloc is normalized
+        # (e.g. http://localhost:80 -> http://localhost) and its scheme is preserved.
 
         # Remove the upstream prefix from the link path
         if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path):
diff --git a/tests/test_process_links.py b/tests/test_process_links.py
index 02dce87c..748a82d6 100644
--- a/tests/test_process_links.py
+++ b/tests/test_process_links.py
@@ -597,3 +597,188 @@ def test_transform_with_forwarded_headers(headers, expected_base_url):
     # but not include the forwarded path in the response URLs
     assert transformed["links"][0]["href"] == f"{expected_base_url}/proxy/collections"
     assert transformed["links"][1]["href"] == f"{expected_base_url}/proxy"
+
+
+@pytest.mark.parametrize(
+    "upstream_url,root_path,request_host,input_links,expected_links",
+    [
+        # Basic localhost:PORT rewriting (common port 8080)
+        (
+            "http://eoapi-stac:8080",
+            "/stac",
+            "localhost",
+            [
+                {"rel": "data", "href": "http://localhost:8080/collections"},
+            ],
+            [
+                "http://localhost/stac/collections",
+            ],
+        ),
+        # Standard HTTP port
+        (
+            "http://eoapi-stac:8080",
+            "/stac",
+            "localhost",
+            [
+                {"rel": "self", "href": "http://localhost:80/collections"},
+            ],
+            [
+                "http://localhost/stac/collections",
+            ],
+        ),
+        # HTTPS port
+        (
+            "http://eoapi-stac:8080",
+            "/stac",
+            "localhost",
+            [
+                {"rel": "self", "href": "https://localhost:443/collections"},
+            ],
+            [
+                "https://localhost/stac/collections",
+            ],
+        ),
+        # Arbitrary port
+        (
+            "http://eoapi-stac:8080",
+            "/stac",
+            "localhost",
+            [
+                {"rel": "self", "href": "http://localhost:3000/collections"},
+            ],
+            [
+                "http://localhost/stac/collections",
+            ],
+        ),
+        # Multiple links with different ports
+        (
+            "http://eoapi-stac:8080",
+            "/stac",
+            "localhost",
+            [
+                {"rel": "self", "href": "http://localhost:8080/collections"},
+                {"rel": "root", "href": "http://localhost:80/"},
+                {
+                    "rel": "items",
+                    "href": "https://localhost:443/collections/test/items",
+                },
+            ],
+            [
+                "http://localhost/stac/collections",
+                "http://localhost/stac/",
+                "https://localhost/stac/collections/test/items",
+            ],
+        ),
+        # localhost:PORT with upstream path
+        (
+            "http://eoapi-stac:8080/api",
+            "/stac",
+            "localhost",
+            [
+                {"rel": "self", "href": "http://localhost:8080/api/collections"},
+            ],
+            [
+                "http://localhost/stac/collections",
+            ],
+        ),
+        # Request host with an explicit port: the link's own port is replaced by the request netloc
+        (
+            "http://eoapi-stac:8080",
+            "/stac",
+            "localhost:80",
+            [
+                {"rel": "self", "href": "http://localhost:8080/collections"},
+            ],
+            [
+                "http://localhost:80/stac/collections",
+            ],
+        ),
+    ],
+)
+def test_transform_localhost_with_port(
+    upstream_url, root_path, request_host, input_links, expected_links
+):
+    """Test transforming links with localhost:PORT (any port number)."""
+    middleware = ProcessLinksMiddleware(
+        app=None, upstream_url=upstream_url, root_path=root_path
+    )
+    request_scope = {
+        "type": "http",
+        "path": "/test",
+        "headers": [
+            (b"host", request_host.encode()),
+            (b"content-type", b"application/json"),
+        ],
+    }
+
+    data = {"links": input_links}
+    transformed = middleware.transform_json(data, Request(request_scope))
+
+    for i, expected in enumerate(expected_links):
+        assert transformed["links"][i]["href"] == expected
+
+
+def test_localhost_with_port_preserves_other_hostnames():
+    """Test that links with other hostnames are not transformed."""
+    middleware = ProcessLinksMiddleware(
+        app=None,
+        upstream_url="http://eoapi-stac:8080",
+        root_path="/stac",
+    )
+    request_scope = {
+        "type": "http",
+        "path": "/test",
+        "headers": [
+            (b"host", b"localhost"),
+            (b"content-type", b"application/json"),
+        ],
+    }
+
+    data = {
+        "links": [
+            {"rel": "external", "href": "http://example.com:8080/collections"},
+            {"rel": "other", "href": "http://other-host:3000/collections"},
+        ]
+    }
+
+    transformed = middleware.transform_json(data, Request(request_scope))
+
+    # External hostnames should remain unchanged
+    assert transformed["links"][0]["href"] == "http://example.com:8080/collections"
+    assert transformed["links"][1]["href"] == "http://other-host:3000/collections"
+
+
+def test_localhost_with_port_upstream_service_name_still_works():
+    """Test that upstream service name matching still works."""
+    middleware = ProcessLinksMiddleware(
+        app=None,
+        upstream_url="http://eoapi-stac:8080",
+        root_path="/stac",
+    )
+    request_scope = {
+        "type": "http",
+        "path": "/test",
+        "headers": [
+            (b"host", b"localhost"),
+            (b"content-type", b"application/json"),
+        ],
+    }
+
+    data = {
+        "links": [
+            {"rel": "self", "href": "http://eoapi-stac:8080/collections"},
+        ]
+    }
+
+    transformed = middleware.transform_json(data, Request(request_scope))
+
+    # Upstream service name should be rewritten to request hostname
+    assert transformed["links"][0]["href"] == "http://localhost/stac/collections"
+
+
+def test_netlocs_match_ignores_userinfo():
+    """Test that userinfo in a netloc does not affect host/port matching."""
+    from stac_auth_proxy.middleware.ProcessLinksMiddleware import _netlocs_match
+
+    assert _netlocs_match("user:pw@localhost:8080", "http", "localhost:8080", "http")
+    assert not _netlocs_match("user:pw@localhost:8080", "http", "localhost", "http")