From 9a900d16793092b388030dc7b247e98247821ab7 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 13 Mar 2025 22:27:03 -0700 Subject: [PATCH] feat: check upstream API health at startup --- src/stac_auth_proxy/app.py | 53 ++++++++++++----- src/stac_auth_proxy/config.py | 3 + .../lifespan/LifespanManager.py | 37 ++++++++++++ .../lifespan/ServerHealthCheck.py | 57 +++++++++++++++++++ src/stac_auth_proxy/lifespan/__init__.py | 9 +++ 5 files changed, 146 insertions(+), 13 deletions(-) create mode 100644 src/stac_auth_proxy/lifespan/LifespanManager.py create mode 100644 src/stac_auth_proxy/lifespan/ServerHealthCheck.py create mode 100644 src/stac_auth_proxy/lifespan/__init__.py diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index b9b7a793..4f2096c2 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -12,6 +12,7 @@ from .config import Settings from .handlers import HealthzHandler, ReverseProxyHandler +from .lifespan import LifespanManager, ServerHealthCheck from .middleware import ( AddProcessTimeHeaderMiddleware, ApplyCql2FilterMiddleware, @@ -27,12 +28,44 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI: """FastAPI Application Factory.""" settings = settings or Settings() + # + # Application + # + upstream_urls = [ + settings.upstream_url, + settings.oidc_discovery_internal_url or settings.oidc_discovery_url, + ] + lifespan = LifespanManager( + on_startup=( + [ServerHealthCheck(url=url) for url in upstream_urls] + if settings.wait_for_upstream + else [] + ) + ) + app = FastAPI( openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema + lifespan=lifespan, ) - app.add_middleware(AddProcessTimeHeaderMiddleware) + # + # Handlers (place catch-all proxy handler last) + # + if settings.healthz_prefix: + app.include_router( + HealthzHandler(upstream_url=str(settings.upstream_url)).router, + prefix=settings.healthz_prefix, + ) + + app.add_api_route( + "/{path:path}", + ReverseProxyHandler(upstream=str(settings.upstream_url)).stream, + methods=["GET", "POST", "PUT", "PATCH", "DELETE"], + ) + # + # Middleware (order is important, last added = first to run) + # if settings.openapi_spec_endpoint: app.add_middleware( OpenApiMiddleware, @@ -44,10 +77,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI: ) if settings.items_filter: - app.add_middleware(ApplyCql2FilterMiddleware) + app.add_middleware( + ApplyCql2FilterMiddleware, + ) app.add_middleware( BuildCql2FilterMiddleware, - # collections_filter=settings.collections_filter, items_filter=settings.items_filter(), ) @@ -57,18 +91,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI: private_endpoints=settings.private_endpoints, default_public=settings.default_public, oidc_config_url=settings.oidc_discovery_url, + oidc_config_internal_url=settings.oidc_discovery_internal_url, ) - if settings.healthz_prefix: - healthz_handler = HealthzHandler(upstream_url=str(settings.upstream_url)) - app.include_router(healthz_handler.router, prefix="/healthz") - - # Catchall for any endpoint - proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url)) - app.add_api_route( - "/{path:path}", - proxy_handler.stream, - methods=["GET", "POST", "PUT", "PATCH", "DELETE"], + app.add_middleware( + AddProcessTimeHeaderMiddleware, ) return app diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index 3738b037..d37ec5b7 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -34,6 +34,9 @@ class Settings(BaseSettings): # External URLs upstream_url: HttpUrl oidc_discovery_url: HttpUrl + oidc_discovery_internal_url: Optional[HttpUrl] = None + + wait_for_upstream: bool = True # Endpoints healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz") diff --git a/src/stac_auth_proxy/lifespan/LifespanManager.py b/src/stac_auth_proxy/lifespan/LifespanManager.py new file mode 100644 index 00000000..725af986 --- /dev/null +++ b/src/stac_auth_proxy/lifespan/LifespanManager.py @@ -0,0 +1,37 @@ +"""Lifespan manager for FastAPI applications.""" + +import logging +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import AsyncGenerator, Awaitable, Callable, List + +from fastapi import FastAPI + +logger = logging.getLogger(__name__) + + +@dataclass +class LifespanManager: + """Manager for FastAPI lifespan events.""" + + on_startup: List[Callable[[], Awaitable[None]]] = field(default_factory=list) + on_teardown: List[Callable[[], Awaitable[None]]] = field(default_factory=list) + + @asynccontextmanager + async def __call__(self, app: FastAPI) -> AsyncGenerator[None, None]: + """FastAPI lifespan event handler.""" + for i, task in enumerate(self.on_startup): + logger.debug(f"Executing startup task {i+1}/{len(self.on_startup)}") + await task() + + logger.debug("All startup tasks completed successfully") + + yield + + # Execute teardown tasks + for i, task in enumerate(self.on_teardown): + try: + logger.debug(f"Executing teardown task {i+1}/{len(self.on_teardown)}") + await task() + except Exception as e: + logger.error(f"Teardown task failed: {e}") diff --git a/src/stac_auth_proxy/lifespan/ServerHealthCheck.py b/src/stac_auth_proxy/lifespan/ServerHealthCheck.py new file mode 100644 index 00000000..196f1eeb --- /dev/null +++ b/src/stac_auth_proxy/lifespan/ServerHealthCheck.py @@ -0,0 +1,57 @@ +"""Health check implementations for lifespan events.""" + +import asyncio +import logging +from dataclasses import dataclass + +import httpx +from pydantic import HttpUrl + +logger = logging.getLogger(__name__) + + +@dataclass +class ServerHealthCheck: + """Health check for upstream API.""" + + url: str | HttpUrl + max_retries: int = 5 + retry_delay: float = 0.25 + retry_delay_max: float = 10.0 + timeout: float = 5.0 + + def __post_init__(self): + """Convert url to string if it's a HttpUrl.""" + if isinstance(self.url, HttpUrl): + self.url = str(self.url) + + async def _check_health(self) -> bool: + """Check if upstream API is responding.""" + try: + async with httpx.AsyncClient() as client: + response = await client.get( + self.url, timeout=self.timeout, follow_redirects=True + ) + response.raise_for_status() + return True + except Exception as e: + logger.warning(f"Upstream health check for {self.url!r} failed: {e}") + return False + + async def __call__(self) -> None: + """Wait for upstream API to become available.""" + for attempt in range(self.max_retries): + if await self._check_health(): + logger.info(f"Upstream API {self.url!r} is healthy") + return + + retry_in = min(self.retry_delay * (2**attempt), self.retry_delay_max) + logger.warning( + f"Upstream API {self.url!r} not healthy, retrying in {retry_in:.1f}s " + f"(attempt {attempt + 1}/{self.max_retries})" + ) + await asyncio.sleep(retry_in) + + raise RuntimeError( + f"Upstream API {self.url!r} failed to respond after {self.max_retries} attempts" + ) diff --git a/src/stac_auth_proxy/lifespan/__init__.py b/src/stac_auth_proxy/lifespan/__init__.py new file mode 100644 index 00000000..40b1c8a3 --- /dev/null +++ b/src/stac_auth_proxy/lifespan/__init__.py @@ -0,0 +1,9 @@ +"""Lifespan event handlers for the STAC Auth Proxy.""" + +from .LifespanManager import LifespanManager +from .ServerHealthCheck import ServerHealthCheck + +__all__ = [ + "ServerHealthCheck", + "LifespanManager", +]