From b59a04839eacea9632ce8d50e3b247d4f59c0794 Mon Sep 17 00:00:00 2001 From: LJ Date: Wed, 2 Apr 2025 22:51:29 -0700 Subject: [PATCH] Make `@main_fn` decorator support async functions. --- examples/gdrive_text_embedding/main.py | 7 ++-- python/cocoindex/lib.py | 52 ++++++++++++++++++-------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/examples/gdrive_text_embedding/main.py b/examples/gdrive_text_embedding/main.py index 9c78e0f8..ce0353be 100644 --- a/examples/gdrive_text_embedding/main.py +++ b/examples/gdrive_text_embedding/main.py @@ -49,9 +49,9 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY) @cocoindex.main_fn() -def _run(): +async def _run(): # Use a `FlowLiveUpdater` to keep the flow data updated. - with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow): + async with cocoindex.FlowLiveUpdater(gdrive_text_embedding_flow): # Run queries in a loop to demonstrate the query capabilities. while True: try: @@ -70,4 +70,5 @@ def _run(): if __name__ == "__main__": load_dotenv(override=True) - _run() + import asyncio + asyncio.run(_run()) diff --git a/python/cocoindex/lib.py b/python/cocoindex/lib.py index d385e677..44525f23 100644 --- a/python/cocoindex/lib.py +++ b/python/cocoindex/lib.py @@ -1,10 +1,12 @@ """ Library level functions and states. """ -import json import os import sys -from typing import Callable, Self +import functools +import inspect +import asyncio +from typing import Callable, Self, Any from dataclasses import dataclass from . import _engine @@ -78,20 +80,40 @@ def main_fn( If the settings are not provided, they are loaded from the environment variables. """ - def _main_wrapper(fn: Callable) -> Callable: - def _inner(*args, **kwargs): - effective_settings = settings or Settings.from_env() - init(effective_settings) - try: - if len(sys.argv) > 1 and sys.argv[1] == cocoindex_cmd: - return cli.cli.main(sys.argv[2:], prog_name=f"{sys.argv[0]} {sys.argv[1]}") - else: - return fn(*args, **kwargs) - finally: - stop() + def _pre_init() -> None: + effective_settings = settings or Settings.from_env() + init(effective_settings) + + def _should_run_cli() -> bool: + return len(sys.argv) > 1 and sys.argv[1] == cocoindex_cmd - _inner.__name__ = fn.__name__ - return _inner + def _run_cli(): + return cli.cli.main(sys.argv[2:], prog_name=f"{sys.argv[0]} {sys.argv[1]}") + + def _main_wrapper(fn: Callable) -> Callable: + if inspect.iscoroutinefunction(fn): + @functools.wraps(fn) + async def _inner(*args, **kwargs): + _pre_init() + try: + if _should_run_cli(): + # Schedule to a separate thread as it invokes nested event loop. + return await asyncio.to_thread(_run_cli) + return await fn(*args, **kwargs) + finally: + stop() + return _inner + else: + @functools.wraps(fn) + def _inner(*args, **kwargs): + _pre_init() + try: + if _should_run_cli(): + return _run_cli() + return fn(*args, **kwargs) + finally: + stop() + return _inner return _main_wrapper