Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .config import Settings
from .handlers import HealthzHandler, ReverseProxyHandler
from .lifespan import LifespanManager, ServerHealthCheck
from .middleware import (
AddProcessTimeHeaderMiddleware,
ApplyCql2FilterMiddleware,
Expand All @@ -27,12 +28,44 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
"""FastAPI Application Factory."""
settings = settings or Settings()

#
# Application
#
upstream_urls = [
settings.upstream_url,
settings.oidc_discovery_internal_url or settings.oidc_discovery_url,
]
lifespan = LifespanManager(
on_startup=(
[ServerHealthCheck(url=url) for url in upstream_urls]
if settings.wait_for_upstream
else []
)
)

app = FastAPI(
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
lifespan=lifespan,
)

app.add_middleware(AddProcessTimeHeaderMiddleware)
#
# Handlers (place catch-all proxy handler last)
#
if settings.healthz_prefix:
app.include_router(
HealthzHandler(upstream_url=str(settings.upstream_url)).router,
prefix=settings.healthz_prefix,
)

app.add_api_route(
"/{path:path}",
ReverseProxyHandler(upstream=str(settings.upstream_url)).stream,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)

#
# Middleware (order is important, last added = first to run)
#
if settings.openapi_spec_endpoint:
app.add_middleware(
OpenApiMiddleware,
Expand All @@ -44,10 +77,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
)

if settings.items_filter:
app.add_middleware(ApplyCql2FilterMiddleware)
app.add_middleware(
ApplyCql2FilterMiddleware,
)
app.add_middleware(
BuildCql2FilterMiddleware,
# collections_filter=settings.collections_filter,
items_filter=settings.items_filter(),
)

Expand All @@ -57,18 +91,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
private_endpoints=settings.private_endpoints,
default_public=settings.default_public,
oidc_config_url=settings.oidc_discovery_url,
oidc_config_internal_url=settings.oidc_discovery_internal_url,
)

if settings.healthz_prefix:
healthz_handler = HealthzHandler(upstream_url=str(settings.upstream_url))
app.include_router(healthz_handler.router, prefix="/healthz")

# Catchall for any endpoint
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
app.add_api_route(
"/{path:path}",
proxy_handler.stream,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
app.add_middleware(
AddProcessTimeHeaderMiddleware,
)

return app
3 changes: 3 additions & 0 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Settings(BaseSettings):
# External URLs
upstream_url: HttpUrl
oidc_discovery_url: HttpUrl
oidc_discovery_internal_url: Optional[HttpUrl] = None

wait_for_upstream: bool = True

# Endpoints
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
Expand Down
37 changes: 37 additions & 0 deletions src/stac_auth_proxy/lifespan/LifespanManager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Lifespan manager for FastAPI applications."""

import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import AsyncGenerator, Awaitable, Callable, List

from fastapi import FastAPI

logger = logging.getLogger(__name__)


@dataclass
class LifespanManager:
"""Manager for FastAPI lifespan events."""

on_startup: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
on_teardown: List[Callable[[], Awaitable[None]]] = field(default_factory=list)

@asynccontextmanager
async def __call__(self, app: FastAPI) -> AsyncGenerator[None, None]:
"""FastAPI lifespan event handler."""
for i, task in enumerate(self.on_startup):
logger.debug(f"Executing startup task {i+1}/{len(self.on_startup)}")
await task()

logger.debug("All startup tasks completed successfully")

yield

# Execute teardown tasks
for i, task in enumerate(self.on_teardown):
try:
logger.debug(f"Executing teardown task {i+1}/{len(self.on_teardown)}")
await task()
except Exception as e:
logger.error(f"Teardown task failed: {e}")
57 changes: 57 additions & 0 deletions src/stac_auth_proxy/lifespan/ServerHealthCheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Health check implementations for lifespan events."""

import asyncio
import logging
from dataclasses import dataclass

import httpx
from pydantic import HttpUrl

logger = logging.getLogger(__name__)


@dataclass
class ServerHealthCheck:
"""Health check for upstream API."""

url: str | HttpUrl
max_retries: int = 5
retry_delay: float = 0.25
retry_delay_max: float = 10.0
timeout: float = 5.0

def __post_init__(self):
"""Convert url to string if it's a HttpUrl."""
if isinstance(self.url, HttpUrl):
self.url = str(self.url)

async def _check_health(self) -> bool:
"""Check if upstream API is responding."""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
self.url, timeout=self.timeout, follow_redirects=True
)
response.raise_for_status()
return True
except Exception as e:
logger.warning(f"Upstream health check for {self.url!r} failed: {e}")
return False

async def __call__(self) -> None:
"""Wait for upstream API to become available."""
for attempt in range(self.max_retries):
if await self._check_health():
logger.info(f"Upstream API {self.url!r} is healthy")
return

retry_in = min(self.retry_delay * (2**attempt), self.retry_delay_max)
logger.warning(
f"Upstream API {self.url!r} not healthy, retrying in {retry_in:.1f}s "
f"(attempt {attempt + 1}/{self.max_retries})"
)
await asyncio.sleep(retry_in)

raise RuntimeError(
f"Upstream API {self.url!r} failed to respond after {self.max_retries} attempts"
)
9 changes: 9 additions & 0 deletions src/stac_auth_proxy/lifespan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Lifespan event handlers for the STAC Auth Proxy."""

from .LifespanManager import LifespanManager
from .ServerHealthCheck import ServerHealthCheck

__all__ = [
"ServerHealthCheck",
"LifespanManager",
]