Skip to content

Commit

Permalink
Refactor package structure, names of things
Browse files Browse the repository at this point in the history
- Define "router" in package init, register view functions defined in views
  module
- Add test for check with custom name
- use better name for check operations
  • Loading branch information
grahamalama committed Aug 2, 2023
1 parent 7ef60d5 commit 40e67c3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 21 deletions.
17 changes: 14 additions & 3 deletions src/dockerflow/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from .router import dockerflow_router
from .middleware import MozlogRequestSummaryLogger
from .checks import check
from fastapi import APIRouter
from fastapi.routing import APIRoute

from .checks import register_heartbeat_check # noqa
from .views import heartbeat, lbheartbeat, version

router = APIRouter(
tags=["Dockerflow"],
routes=[
APIRoute("/__lbheartbeat__", endpoint=lbheartbeat, methods=["GET", "HEAD"]),
APIRoute("/__heartbeat__", endpoint=heartbeat, methods=["GET", "HEAD"]),
APIRoute("/__version__", endpoint=version, methods=["GET"]),
],
)
9 changes: 6 additions & 3 deletions src/dockerflow/fastapi/checks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
import functools
import logging
from dataclasses import dataclass
from typing import Dict, List, Tuple

from ..checks import level_to_text
Expand All @@ -18,7 +18,10 @@ class CheckDetail:
registered_checks = dict()


def check(func, name=None):
def register_heartbeat_check(func=None, *, name=None):
if func is None:
return functools.partial(register_heartbeat_check, name=name)

if name is None:
name = func.__name__

Expand All @@ -41,7 +44,7 @@ def _heartbeat_check_detail(check):
)


def run_checks():
def run_heartbeat_checks():
check_details: List[Tuple[str, CheckDetail]] = []
for name, check in registered_checks.items():
detail = _heartbeat_check_detail(check)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,19 @@
import os

from fastapi import Request, Response
from fastapi.routing import APIRouter


from ..version import get_version
from .checks import run_checks
from dockerflow import checks

dockerflow_router = APIRouter(tags=["Dockerflow"])
from ..version import get_version
from .checks import run_heartbeat_checks


@dockerflow_router.get("/__lbheartbeat__")
@dockerflow_router.head("/__lbheartbeat__")
def lbheartbeat():
return {"status": "ok"}


@dockerflow_router.get("/__heartbeat__")
@dockerflow_router.head("/__heartbeat__")
def heartbeat(response: Response):
check_results = run_checks()
check_results = run_heartbeat_checks()
details = {}
statuses = {}
level = 0
Expand All @@ -43,7 +36,6 @@ def heartbeat(response: Response):
}


@dockerflow_router.get("/__version__")
def version(request: Request):
if getattr(request.app.state, "APP_DIR", None):
root = request.app.state.APP_DIR
Expand Down
19 changes: 15 additions & 4 deletions tests/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient

from dockerflow.fastapi import MozlogRequestSummaryLogger, check, dockerflow_router
from dockerflow.checks import Error
from dockerflow.fastapi import register_heartbeat_check
from dockerflow.fastapi import router as dockerflow_router
from dockerflow.fastapi.middleware import MozlogRequestSummaryLogger


def create_app():
Expand Down Expand Up @@ -85,7 +87,7 @@ def test_version_env_var(client, tmp_path, monkeypatch):

def test_version_default(client, mocker):
mock_get_version = mocker.MagicMock(return_value=VERSION_CONTENT)
mocker.patch("dockerflow.fastapi.router.get_version", mock_get_version)
mocker.patch("dockerflow.fastapi.views.get_version", mock_get_version)

response = client.get("/__version__")
assert response.status_code == 200
Expand All @@ -94,7 +96,7 @@ def test_version_default(client, mocker):


def test_heartbeat_get(client):
@check
@register_heartbeat_check
def return_error():
return [Error("BOOM", id="foo")]

Expand All @@ -114,9 +116,18 @@ def return_error():


def test_heartbeat_head(client):
@check
@register_heartbeat_check
def return_error():
return [Error("BOOM", id="foo")]

response = client.head("/__heartbeat__")
assert response.content == b""


def test_heartbeat_custom_name(client):
@register_heartbeat_check(name="my_check_name")
def return_error():
return [Error("BOOM", id="foo")]

response = client.get("/__heartbeat__")
assert response.json()["checks"]["my_check_name"]

0 comments on commit 40e67c3

Please sign in to comment.