-
Notifications
You must be signed in to change notification settings - Fork 203
Initial BenchKit backend implementation #1018
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| .*/ | ||
| build | ||
| dist | ||
| venv*/ | ||
| .coverage | ||
| tests/ | ||
| testkit/ | ||
| testkitbackend/ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| FROM python:3.12 | ||
|
|
||
| WORKDIR /driver | ||
|
|
||
| COPY . /driver | ||
|
|
||
| # Install dependencies | ||
| RUN pip install -U pip && \ | ||
| pip install -Ur requirements-dev.txt | ||
|
|
||
| ENTRYPOINT ["python", "-m", "benchkit"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from __future__ import annotations |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from sanic import Sanic | ||
| from sanic.worker.loader import AppLoader | ||
|
|
||
| from .app import create_app | ||
| from .env import env | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| loader = AppLoader(factory=create_app) | ||
| app = loader.load() | ||
|
|
||
| # For local development: | ||
| # app.prepare(port=env.backend_port, debug=True, workers=1, dev=True) | ||
|
|
||
| # For production: | ||
| app.prepare(host="0.0.0.0", port=env.backend_port, workers=1) | ||
|
|
||
| Sanic.serve(primary=app, app_loader=loader) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from contextlib import contextmanager | ||
| from multiprocessing import Semaphore | ||
|
|
||
| import typing_extensions as te | ||
| from sanic import Sanic | ||
| from sanic.config import Config | ||
| from sanic.exceptions import ( | ||
| BadRequest, | ||
| NotFound, | ||
| ) | ||
| from sanic.request import Request | ||
| from sanic.response import ( | ||
| empty, | ||
| HTTPResponse, | ||
| text, | ||
| ) | ||
|
|
||
| from .context import BenchKitContext | ||
| from .env import env | ||
| from .workloads import Workload | ||
|
|
||
|
|
||
| T_App: te.TypeAlias = "Sanic[Config, BenchKitContext]" | ||
|
|
||
|
|
||
| def create_app() -> T_App: | ||
| app: T_App = Sanic("Python_BenchKit", ctx=BenchKitContext()) | ||
|
|
||
| @app.main_process_start | ||
| async def main_process_start(app: T_App) -> None: | ||
| app.shared_ctx.running = Semaphore(1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm curious about the function of this semaphore.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to make sure that only one worker process is started. The backend shouldn't use multiple drivers and each process would get its own driver instance. That's basically just a sanity check that the server is started with |
||
|
|
||
| @app.before_server_start | ||
| async def before_server_start(app: T_App) -> None: | ||
| if env.driver_debug: | ||
| from neo4j.debug import watch | ||
| watch("neo4j") | ||
|
|
||
| running = app.shared_ctx.running | ||
| acquired = running.acquire(block=False) | ||
| if not acquired: | ||
| raise RuntimeError( | ||
| "The server does not support multiple worker processes" | ||
| ) | ||
|
|
||
| @app.after_server_stop | ||
| async def after_server_stop(app: T_App) -> None: | ||
| await app.ctx.shutdown() | ||
| running = app.shared_ctx.running | ||
| running.release() | ||
|
|
||
| @contextmanager | ||
| def _loading_workload(): | ||
| try: | ||
| yield | ||
| except (ValueError, TypeError) as e: | ||
| print(e) | ||
| raise BadRequest(str(e)) | ||
|
|
||
| def _get_workload(app: T_App, name: str) -> Workload: | ||
| try: | ||
| workload = app.ctx.workloads[name] | ||
| except KeyError: | ||
| raise NotFound(f"Workload {name} not found") | ||
| return workload | ||
|
|
||
| @app.get("/ready") | ||
| async def ready(_: Request) -> HTTPResponse: | ||
| await app.ctx.get_db() # check that the database is available | ||
| return empty() | ||
|
|
||
| @app.post("/workload") | ||
| async def post_workload(request: Request) -> HTTPResponse: | ||
| data = request.json | ||
| with _loading_workload(): | ||
| name = app.ctx.workloads.store_workload(data) | ||
| location = f"/workload/{name}" | ||
| return text(f"created at {location}", | ||
| status=204, | ||
| headers={"location": location}) | ||
|
|
||
| @app.put("/workload") | ||
| async def put_workload(request: Request) -> HTTPResponse: | ||
| data = request.json | ||
| with _loading_workload(): | ||
| workload = app.ctx.workloads.parse_workload(data) | ||
| driver = await app.ctx.get_db() | ||
| await workload(driver) | ||
| return empty() | ||
|
|
||
| @app.get("/workload/<name>") | ||
| async def get_workload(_: Request, name: str) -> HTTPResponse: | ||
| workload = _get_workload(app, name) | ||
| driver = await app.ctx.get_db() | ||
| await workload(driver) | ||
| return empty() | ||
|
|
||
| @app.patch("/workload/<name>") | ||
| async def patch_workload(request: Request, name: str) -> HTTPResponse: | ||
| data = request.json | ||
| workload = _get_workload(app, name) | ||
| with _loading_workload(): | ||
| workload.patch(data) | ||
| return empty() | ||
|
|
||
| @app.delete("/workload/<name>") | ||
| async def delete_workload(_: Request, name: str) -> HTTPResponse: | ||
| _get_workload(app, name) | ||
| del app.ctx.workloads[name] | ||
| return empty() | ||
|
|
||
| return app | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import typing as t | ||
|
|
||
| import neo4j | ||
| from neo4j import ( | ||
| AsyncDriver, | ||
| AsyncGraphDatabase, | ||
| ) | ||
|
|
||
| from .env import env | ||
| from .workloads import Workloads | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "BenchKitContext", | ||
| ] | ||
|
|
||
|
|
||
| class BenchKitContext: | ||
| _db: t.Optional[AsyncDriver] | ||
| workloads: Workloads | ||
|
|
||
| def __init__(self) -> None: | ||
| self._db = None | ||
| self.workloads = Workloads() | ||
|
|
||
| async def get_db(self) -> AsyncDriver: | ||
| if self._db is None: | ||
| url = f"{env.neo4j_scheme}://{env.neo4j_host}:{env.neo4j_port}" | ||
| auth = (env.neo4j_user, env.neo4j_pass) | ||
| self._db = AsyncGraphDatabase.driver(url, auth=auth) | ||
| try: | ||
| await self._db.verify_connectivity() | ||
| except Exception: | ||
| db = self._db | ||
| self._db = None | ||
| await db.close() | ||
| raise | ||
| return self._db | ||
|
|
||
| async def shutdown(self) -> None: | ||
| if self._db is not None: | ||
| await self._db.close() | ||
| self._db = None |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import typing as t | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "Env", | ||
| "env", | ||
| ] | ||
|
|
||
|
|
||
| class Env(t.NamedTuple): | ||
| backend_port: int | ||
| neo4j_host: str | ||
| neo4j_port: int | ||
| neo4j_scheme: str | ||
| neo4j_user: str | ||
| neo4j_pass: str | ||
| driver_debug: bool | ||
|
|
||
|
|
||
| env = Env( | ||
| backend_port=int(os.environ.get("TEST_BACKEND_PORT", "9000")), | ||
| neo4j_host=os.environ.get("TEST_NEO4J_HOST", "localhost"), | ||
| neo4j_port=int(os.environ.get("TEST_NEO4J_PORT", "7687")), | ||
| neo4j_scheme=os.environ.get("TEST_NEO4J_SCHEME", "neo4j"), | ||
| neo4j_user=os.environ.get("TEST_NEO4J_USER", "neo4j"), | ||
| neo4j_pass=os.environ.get("TEST_NEO4J_PASS", "password"), | ||
| driver_debug=os.environ.get("TEST_DRIVER_DEBUG", "").lower() in ( | ||
| "y", "yes", "true", "1", "on" | ||
| ) | ||
| ) |
Uh oh!
There was an error while loading. Please reload this page.