Skip to content

Commit

Permalink
set_handlers: changed type of models_to_fetch, removed "models_down…
Browse files Browse the repository at this point in the history
…load_params" (#184)

* set_handlers: `models_to_fetch` and `models_download_params` united in
one more flexible parameter.

Signed-off-by: Alexander Piskun <bigcat88@icloud.com>
  • Loading branch information
bigcat88 committed Dec 16, 2023
1 parent 16f44f8 commit 4d27356
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file.
### Changed

- set_handlers: `enabled_handler`, `heartbeat_handler`, `init_handler` now can be async(Coroutines). #175 #181
- set_handlers: `models_to_fetch` and `models_download_params` united in one more flexible parameter. #184
- drop Python 3.9 support. #180
- internal code refactoring and clean-up #177

Expand Down
2 changes: 1 addition & 1 deletion docs/NextcloudTalkBotTransformers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ This library also provides an additional functionality over this endpoint for ea
@asynccontextmanager
async def lifespan(_app: FastAPI):
set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME:{}})
yield
This will automatically download models specified in ``models_to_fetch`` parameter to the application persistent storage.
Expand Down
2 changes: 1 addition & 1 deletion examples/as_app/talk_bot_ai/lib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@asynccontextmanager
async def lifespan(_app: FastAPI):
set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}})
yield


Expand Down
18 changes: 6 additions & 12 deletions nc_py_api/ex_app/integration_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def set_handlers(
enabled_handler: typing.Callable[[bool, AsyncNextcloudApp | NextcloudApp], typing.Awaitable[str] | str],
heartbeat_handler: typing.Callable[[], typing.Awaitable[str] | str] | None = None,
init_handler: typing.Callable[[AsyncNextcloudApp | NextcloudApp], typing.Awaitable[None] | None] | None = None,
models_to_fetch: list[str] | None = None,
models_download_params: dict | None = None,
models_to_fetch: dict[str, dict] | None = None,
map_app_static: bool = True,
):
"""Defines handlers for the application.
Expand All @@ -92,7 +91,6 @@ def set_handlers(
.. note:: ```huggingface_hub`` package should be present for automatic models fetching.
:param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
:param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
.. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
Expand Down Expand Up @@ -140,8 +138,7 @@ async def init_callback(
background_tasks.add_task(
__fetch_models_task,
nc,
models_to_fetch if models_to_fetch else [],
models_download_params if models_download_params else {},
models_to_fetch if models_to_fetch else {},
)
return responses.JSONResponse(content={}, status_code=200)

Expand Down Expand Up @@ -181,8 +178,7 @@ def __map_app_static_folders(fast_api_app: FastAPI):

def __fetch_models_task(
nc: NextcloudApp,
models: list[str],
params: dict[str, typing.Any],
models: dict[str, dict],
) -> None:
if models:
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
Expand All @@ -193,10 +189,8 @@ def display(self, msg=None, pos=None):
nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100))
return super().display(msg, pos)

if "max_workers" not in params:
params["max_workers"] = 2
if "cache_dir" not in params:
params["cache_dir"] = persistent_storage()
for model in models:
snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa
workers = models[model].pop("max_workers", 2)
cache = models[model].pop("cache_dir", persistent_storage())
snapshot_download(model, tqdm_class=TqdmProgress, **models[model], max_workers=workers, cache_dir=cache)
nc.set_init_status(100)
2 changes: 1 addition & 1 deletion tests/_install_init_handler_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@asynccontextmanager
async def lifespan(_app: FastAPI):
ex_app.set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
ex_app.set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}})
yield


Expand Down
2 changes: 1 addition & 1 deletion tests/actual_tests/nc_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ async def test_set_user_same_value_async(anc_app):

def test_set_handlers_invalid_param(nc_any):
with pytest.raises(ValueError):
set_handlers(None, None, init_handler=set_handlers, models_to_fetch=["some"]) # noqa
set_handlers(None, None, init_handler=set_handlers, models_to_fetch={"some": {}}) # noqa

0 comments on commit 4d27356

Please sign in to comment.