src/stac_auth_proxy/__main__.py (16 changes: 10 additions & 6 deletions)

```diff
@@ -3,16 +3,20 @@
 import uvicorn
 from uvicorn.config import LOGGING_CONFIG
 
-LOGGING_CONFIG["loggers"][__package__] = {
-    "level": "DEBUG",
-    "handlers": ["default"],
-}
-
 uvicorn.run(
     f"{__package__}.app:create_app",
     host="0.0.0.0",
     port=8000,
-    log_config=LOGGING_CONFIG,
+    log_config={
+        **LOGGING_CONFIG,
+        "loggers": {
+            **LOGGING_CONFIG["loggers"],
+            __package__: {
+                "level": "DEBUG",
+                "handlers": ["default"],
+            },
+        },
+    },
     reload=True,
     factory=True,
 )
```
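The new `log_config` builds a merged copy rather than mutating uvicorn's module-level `LOGGING_CONFIG` in place, so other importers of that dict see it unchanged. A minimal sketch of the merge pattern (the `my_package` logger name is a hypothetical stand-in, not from this PR):

```python
base = {
    "version": 1,
    "loggers": {"uvicorn": {"level": "INFO", "handlers": ["default"]}},
}

# Two-level merge: a new top-level dict and a new "loggers" dict,
# leaving `base` untouched. "my_package" is a hypothetical logger name.
merged = {
    **base,
    "loggers": {
        **base["loggers"],
        "my_package": {"level": "DEBUG", "handlers": ["default"]},
    },
}

assert "my_package" not in base["loggers"]              # the shared global stays pristine
assert merged["loggers"]["uvicorn"]["level"] == "INFO"  # existing loggers carry over
```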
src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py (103 changes: 74 additions & 29 deletions)

```diff
@@ -4,13 +4,14 @@
 import re
 from dataclasses import dataclass
 from typing import Any, Optional
-from urllib.parse import urlparse, urlunparse
+from urllib.parse import ParseResult, urlparse, urlunparse
 
 from starlette.datastructures import Headers
 from starlette.requests import Request
 from starlette.types import ASGIApp, Scope
 
 from ..utils.middleware import JsonResponseMiddleware
+from ..utils.requests import get_base_url
 from ..utils.stac import get_links
 
 logger = logging.getLogger(__name__)
@@ -40,37 +41,81 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:
 
     def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
         """Update links in the response to include root_path."""
-        for link in get_links(data):
-            href = link.get("href")
-            if not href:
-                continue
-
+        # Get the client's actual base URL (accounting for load balancers/proxies)
+        req_base_url = get_base_url(request)
+        parsed_req_url = urlparse(req_base_url)
+        parsed_upstream_url = urlparse(self.upstream_url)
+
+        for link in get_links(data):
             try:
-                parsed_link = urlparse(href)
-
-                # Ignore links that are not for this proxy
-                if parsed_link.netloc != request.headers.get("host"):
-                    continue
-
-                # Remove the upstream_url path from the link if it exists
-                parsed_upstream_url = urlparse(self.upstream_url)
-                if parsed_upstream_url.path != "/" and parsed_link.path.startswith(
-                    parsed_upstream_url.path
-                ):
-                    parsed_link = parsed_link._replace(
-                        path=parsed_link.path[len(parsed_upstream_url.path) :]
-                    )
-
-                # Add the root_path to the link if it exists
-                if self.root_path:
-                    parsed_link = parsed_link._replace(
-                        path=f"{self.root_path}{parsed_link.path}"
-                    )
-
-                link["href"] = urlunparse(parsed_link)
+                self._update_link(link, parsed_req_url, parsed_upstream_url)
             except Exception as e:
                 logger.error(
-                    "Failed to parse link href %r, (ignoring): %s", href, str(e)
+                    "Failed to parse link href %r, (ignoring): %s",
+                    link.get("href"),
+                    str(e),
                 )
 
         return data
+
+    def _update_link(
+        self, link: dict[str, Any], request_url: ParseResult, upstream_url: ParseResult
+    ) -> None:
+        """
+        Ensure that link hrefs that are local to upstream url are rewritten as local to
+        the proxy.
+        """
+        if "href" not in link:
+            logger.warning("Link %r has no href", link)
+            return
+
+        parsed_link = urlparse(link["href"])
+
+        if parsed_link.netloc not in [
+            request_url.netloc,
+            upstream_url.netloc,
+        ]:
+            logger.debug(
+                "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
+                link["href"],
+                request_url.netloc,
+                upstream_url.netloc,
+            )
+            return
+
+        # If the link path is not a descendant of the upstream path, don't transform it
+        if upstream_url.path != "/" and not parsed_link.path.startswith(
+            upstream_url.path
+        ):
+            logger.debug(
+                "Ignoring link %s because it is not descendant of upstream path (%s)",
+                link["href"],
+                upstream_url.path,
+            )
+            return
+
+        # Replace the upstream host with the client's host
+        if parsed_link.netloc == upstream_url.netloc:
+            parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace(
+                scheme=request_url.scheme
+            )
+
+        # Rewrite the link path
+        if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path):
+            parsed_link = parsed_link._replace(
+                path=parsed_link.path[len(upstream_url.path) :]
+            )
+
+        # Add the root_path to the link if it exists
+        if self.root_path:
+            parsed_link = parsed_link._replace(
+                path=f"{self.root_path}{parsed_link.path}"
+            )
+
+        logger.debug(
+            "Rewriting %r link %r to %r",
+            link.get("rel"),
+            link["href"],
+            urlunparse(parsed_link),
+        )
+
+        link["href"] = urlunparse(parsed_link)
```
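To make the rewrite concrete, here is a hedged walkthrough of the three transformations `_update_link` applies to an upstream-local link; the upstream URL, client host, and `root_path` below are illustrative stand-ins, not values from this PR:

```python
from urllib.parse import urlparse, urlunparse

upstream = urlparse("http://stac-api:8080/api")  # assumed upstream behind the proxy
client = urlparse("https://example.com/")        # assumed client-facing base URL
root_path = "/proxy"                             # assumed proxy mount point

link = urlparse("http://stac-api:8080/api/collections/sentinel-2")

# 1. The link's host matches the upstream, so adopt the client's scheme and host.
link = link._replace(scheme=client.scheme, netloc=client.netloc)
# 2. Strip the upstream's path prefix ("/api").
link = link._replace(path=link.path[len(upstream.path):])
# 3. Prepend the proxy's root_path.
link = link._replace(path=f"{root_path}{link.path}")

print(urlunparse(link))  # https://example.com/proxy/collections/sentinel-2
```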
src/stac_auth_proxy/utils/requests.py (114 changes: 113 additions & 1 deletion)

```diff
@@ -1,13 +1,18 @@
 """Utility functions for working with HTTP requests."""
 
 import json
+import logging
 import re
 from dataclasses import dataclass, field
-from typing import Optional, Sequence
+from typing import Dict, Optional, Sequence
+from urllib.parse import urlparse
 
 from starlette.requests import Request
 
 from ..config import EndpointMethods
 
+logger = logging.getLogger(__name__)
+
+
 def extract_variables(url: str) -> dict:
     """
@@ -90,3 +95,110 @@ def build_server_timing_header(
     if current_value:
         return f"{current_value}, {metric}"
     return metric
+
+
+def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
+    """
+    Parse the Forwarded header according to RFC 7239.
+
+    Args:
+        forwarded_header: The Forwarded header value
+
+    Returns:
+        Dictionary containing parsed forwarded information (proto, host, for, by, etc.)
+
+    Example:
+        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
+        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}
+
+    """
+    # Forwarded header format: "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=example.com"
+    # The format is: for=value1, for=value2; by=value; proto=value; host=value
+    # We need to parse all the key=value pairs, taking the first 'for' value
+    forwarded_info = {}
+
+    try:
+        # Parse all key=value pairs separated by semicolons
+        for pair in forwarded_header.split(";"):
+            pair = pair.strip()
+            if "=" in pair:
+                key, value = pair.split("=", 1)
+                key = key.strip()
+                value = value.strip().strip('"')
+
+                # For 'for' field, only take the first value if there are multiple
+                if key == "for" and key not in forwarded_info:
+                    # Extract the first for value (before comma if present)
+                    first_for_value = value.split(",")[0].strip()
+                    forwarded_info[key] = first_for_value
+                elif key != "for":
+                    # For other fields, just use the value as-is
+                    forwarded_info[key] = value
+    except Exception as e:
+        logger.warning(f"Failed to parse Forwarded header '{forwarded_header}': {e}")
+        return {}
+
+    return forwarded_info
+
+
+def get_base_url(request: Request) -> str:
+    """
+    Get the request's base URL, accounting for forwarded headers from load balancers/proxies.
+
+    This function handles both the standard Forwarded header (RFC 7239) and legacy
+    X-Forwarded-* headers to reconstruct the original client URL when the service
+    is deployed behind load balancers or reverse proxies.
+
+    Args:
+        request: The Starlette request object
+
+    Returns:
+        The reconstructed client base URL
+
+    Example:
+        >>> # With Forwarded header
+        >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"}
+        >>> get_base_url(request)
+        "https://api.example.com/"
+
+        >>> # With X-Forwarded-* headers
+        >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"}
+        >>> get_base_url(request)
+        "https://api.example.com/"
+
+    """
+    # Check for standard Forwarded header first (RFC 7239)
+    forwarded_header = request.headers.get("Forwarded")
+    if forwarded_header:
+        try:
+            forwarded_info = parse_forwarded_header(forwarded_header)
+            # Only use Forwarded header if we successfully parsed it and got useful info
+            if forwarded_info and (
+                "proto" in forwarded_info or "host" in forwarded_info
+            ):
+                scheme = forwarded_info.get("proto", request.url.scheme)
+                host = forwarded_info.get("host", request.url.netloc)
+                # Note: Forwarded header doesn't include path, so we use request.base_url.path
+                path = request.base_url.path
+                return f"{scheme}://{host}{path}"
+        except Exception as e:
+            logger.warning(f"Failed to parse Forwarded header: {e}")
+
+    # Fall back to legacy X-Forwarded-* headers
+    forwarded_host = request.headers.get("X-Forwarded-Host")
+    forwarded_proto = request.headers.get("X-Forwarded-Proto")
+    forwarded_path = request.headers.get("X-Forwarded-Path")
+
+    if forwarded_host:
+        # Use forwarded headers to reconstruct the original client URL
+        scheme = forwarded_proto or request.url.scheme
+        netloc = forwarded_host
+        # Use forwarded path if available, otherwise use request base URL path
+        path = forwarded_path or request.base_url.path
+    else:
+        # Fall back to the request's base URL if no forwarded headers
+        scheme = request.url.scheme
+        netloc = request.url.netloc
+        path = request.base_url.path
+
+    return f"{scheme}://{netloc}{path}"
```
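A quick sketch of the new helper on a multi-hop header, which the docstring example doesn't cover; per the parsing rules above, only the first `for` value is kept:

```python
parsed = parse_forwarded_header(
    "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=api.example.com"
)

assert parsed["for"] == "192.0.2.43"        # first hop only; later hops are dropped
assert parsed["proto"] == "https"
assert parsed["host"] == "api.example.com"
```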