diff --git a/.github/workflows/analysis-coverage.yml b/.github/workflows/analysis-coverage.yml index 7cd080d7..4346eaae 100644 --- a/.github/workflows/analysis-coverage.yml +++ b/.github/workflows/analysis-coverage.yml @@ -725,7 +725,7 @@ jobs: run: | php occ app:enable app_api cd nc_py_api - coverage run --data-file=.coverage.ci_install tests/_install.py & + coverage run --data-file=.coverage.ci_install tests/_install_async.py & echo $! > /tmp/_install.pid python3 tests/_install_wait.py http://127.0.0.1:$APP_PORT/heartbeat "\"status\":\"ok\"" 15 0.5 python3 tests/_app_security_checks.py http://127.0.0.1:$APP_PORT diff --git a/CHANGELOG.md b/CHANGELOG.md index 52a35425..02092849 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to this project will be documented in this file. +## [0.6.1 - 202x-xx-xx] + +### Added + +- set_handlers: `enabled_handler`, `heartbeat_handler` now can be async(Coroutines). #175 + ## [0.6.0 - 2023-12-06] ### Added diff --git a/nc_py_api/ex_app/integration_fastapi.py b/nc_py_api/ex_app/integration_fastapi.py index c64498a2..0a60199b 100644 --- a/nc_py_api/ex_app/integration_fastapi.py +++ b/nc_py_api/ex_app/integration_fastapi.py @@ -54,8 +54,8 @@ def talk_bot_app(request: Request) -> TalkBotMessage: def set_handlers( fast_api_app: FastAPI, - enabled_handler: typing.Callable[[bool, NextcloudApp], str], - heartbeat_handler: typing.Optional[typing.Callable[[], str]] = None, + enabled_handler: typing.Callable[[bool, NextcloudApp], typing.Union[str, typing.Awaitable[str]]], + heartbeat_handler: typing.Optional[typing.Callable[[], typing.Union[str, typing.Awaitable[str]]]] = None, init_handler: typing.Optional[typing.Callable[[NextcloudApp], None]] = None, models_to_fetch: typing.Optional[list[str]] = None, models_download_params: typing.Optional[dict] = None, @@ -81,50 +81,40 @@ def set_handlers( .. note:: First, presence of these directories in the current working dir is checked, then one directory higher. """ - def fetch_models_task(nc: NextcloudApp, models: list[str]) -> None: - if models: - from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401 - from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401 - - class TqdmProgress(tqdm): - def display(self, msg=None, pos=None): - if init_handler is None: - nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100)) - return super().display(msg, pos) - - params = models_download_params if models_download_params else {} - if "max_workers" not in params: - params["max_workers"] = 2 - if "cache_dir" not in params: - params["cache_dir"] = persistent_storage() - for model in models: - snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa - if init_handler is None: - nc.set_init_status(100) - else: - init_handler(nc) - @fast_api_app.put("/enabled") - def enabled_callback( + async def enabled_callback( enabled: bool, nc: typing.Annotated[NextcloudApp, Depends(nc_app)], ): - r = enabled_handler(enabled, nc) + if asyncio.iscoroutinefunction(heartbeat_handler): + r = await enabled_handler(enabled, nc) # type: ignore + else: + r = enabled_handler(enabled, nc) return responses.JSONResponse(content={"error": r}, status_code=200) @fast_api_app.get("/heartbeat") - def heartbeat_callback(): - return_status = "ok" + async def heartbeat_callback(): if heartbeat_handler is not None: - return_status = heartbeat_handler() + if asyncio.iscoroutinefunction(heartbeat_handler): + return_status = await heartbeat_handler() + else: + return_status = heartbeat_handler() + else: + return_status = "ok" return responses.JSONResponse(content={"status": return_status}, status_code=200) @fast_api_app.post("/init") - def init_callback( + async def init_callback( background_tasks: BackgroundTasks, nc: typing.Annotated[NextcloudApp, Depends(nc_app)], ): - background_tasks.add_task(fetch_models_task, nc, models_to_fetch if models_to_fetch else []) + background_tasks.add_task( + __fetch_models_task, + nc, + init_handler, + models_to_fetch if models_to_fetch else [], + models_download_params if models_download_params else {}, + ) return responses.JSONResponse(content={}, status_code=200) if map_app_static: @@ -139,3 +129,31 @@ def __map_app_static_folders(fast_api_app: FastAPI): mnt_dir_path = os.path.join(os.path.dirname(os.getcwd()), mnt_dir) if os.path.exists(mnt_dir_path): fast_api_app.mount(f"/{mnt_dir}", staticfiles.StaticFiles(directory=mnt_dir_path), name=mnt_dir) + + +def __fetch_models_task( + nc: NextcloudApp, + init_handler: typing.Optional[typing.Callable[[NextcloudApp], None]], + models: list[str], + params: dict[str, typing.Any], +) -> None: + if models: + from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401 + from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401 + + class TqdmProgress(tqdm): + def display(self, msg=None, pos=None): + if init_handler is None: + nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100)) + return super().display(msg, pos) + + if "max_workers" not in params: + params["max_workers"] = 2 + if "cache_dir" not in params: + params["cache_dir"] = persistent_storage() + for model in models: + snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa + if init_handler is None: + nc.set_init_status(100) + else: + init_handler(nc) diff --git a/tests/_install.py b/tests/_install.py index fd60cc5a..cb069961 100644 --- a/tests/_install.py +++ b/tests/_install.py @@ -19,10 +19,9 @@ async def lifespan(_app: FastAPI): @APP.put("/sec_check") def sec_check( value: int, - nc: Annotated[NextcloudApp, Depends(ex_app.nc_app)], + _nc: Annotated[NextcloudApp, Depends(ex_app.nc_app)], ): - print(value) - _ = nc + print(value, flush=True) return JSONResponse(content={"error": ""}, status_code=200) diff --git a/tests/_install_async.py b/tests/_install_async.py new file mode 100644 index 00000000..5bd603d6 --- /dev/null +++ b/tests/_install_async.py @@ -0,0 +1,46 @@ +from contextlib import asynccontextmanager +from typing import Annotated + +from fastapi import Depends, FastAPI +from fastapi.responses import JSONResponse + +from nc_py_api import NextcloudApp, ex_app + + +@asynccontextmanager +async def lifespan(_app: FastAPI): + ex_app.set_handlers(APP, enabled_handler, heartbeat_callback, init_handler=init_handler) + yield + + +APP = FastAPI(lifespan=lifespan) + + +@APP.put("/sec_check") +async def sec_check( + value: int, + _nc: Annotated[NextcloudApp, Depends(ex_app.nc_app)], +): + print(value, flush=True) + return JSONResponse(content={"error": ""}, status_code=200) + + +async def enabled_handler(enabled: bool, nc: NextcloudApp) -> str: + print(f"enabled_handler: enabled={enabled}", flush=True) + if enabled: + nc.log(ex_app.LogLvl.WARNING, f"Hello from {nc.app_cfg.app_name} :)") + else: + nc.log(ex_app.LogLvl.WARNING, f"Bye bye from {nc.app_cfg.app_name} :(") + return "" + + +def init_handler(nc: NextcloudApp): + nc.set_init_status(100) + + +async def heartbeat_callback(): + return "ok" + + +if __name__ == "__main__": + ex_app.run_app("_install_async:APP", log_level="trace")