diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f23c37ae..c77c5100 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,8 @@ v4.31.0 (2024-06-??) Added ^^^^^ +- Support async functions and methods in ``CLI`` (`#531 + `__). - Support for ``Protocol`` types only accepting exact matching signature of public methods (`#526 `__). diff --git a/jsonargparse/_cli.py b/jsonargparse/_cli.py index 43a08bab..18053476 100644 --- a/jsonargparse/_cli.py +++ b/jsonargparse/_cli.py @@ -92,8 +92,8 @@ def CLI( deprecation_warning_cli_return_parser() return parser cfg = parser.parse_args(args) - cfg_init = parser.instantiate_classes(cfg) - return _run_component(components, cfg_init) + init = parser.instantiate_classes(cfg) + return _run_component(components, init) elif isinstance(components, list): components = {c.__name__: c for c in components} @@ -192,12 +192,13 @@ def _add_component_to_parser( def _run_component(component, cfg): cfg.pop("config", None) - if not inspect.isclass(component): - return component(**cfg) subcommand = cfg.pop("subcommand") - if not subcommand: - return component(**cfg) - subcommand_cfg = cfg.pop(subcommand, {}) - subcommand_cfg.pop("config", None) - component_obj = component(**cfg) - return getattr(component_obj, subcommand)(**subcommand_cfg) + if inspect.isclass(component) and subcommand: + subcommand_cfg = cfg.pop(subcommand, {}) + subcommand_cfg.pop("config", None) + component_obj = component(**cfg) + component = getattr(component_obj, subcommand) + cfg = subcommand_cfg + if inspect.iscoroutinefunction(component): + return __import__("asyncio").run(component(**cfg)) + return component(**cfg) diff --git a/jsonargparse_tests/test_cli.py b/jsonargparse_tests/test_cli.py index 06d48791..ae518307 100644 --- a/jsonargparse_tests/test_cli.py +++ b/jsonargparse_tests/test_cli.py @@ -1,12 +1,13 @@ from __future__ import annotations +import asyncio import os import sys from contextlib import redirect_stderr, redirect_stdout, suppress from dataclasses import asdict, dataclass from io import StringIO from pathlib import Path -from typing import Optional +from typing import Callable, Optional from unittest.mock import patch import pytest @@ -535,3 +536,44 @@ def test_final_and_subclass_type_config_file(tmp_cwd): out = CLI(run_bf, args=["--config=config.yaml"]) assert "a yaml" == out + + +# async tests + + +async def run_async(time: float = 0.1): + await asyncio.sleep(time) + return "done" + + +def test_async_function(): + assert "done" == CLI(run_async, args=["--time=0.0"]) + + +class AsyncMethod: + def __init__(self, time: float = 0.1, require_async: bool = False): + self.time = time + if require_async: + self.loop = asyncio.get_event_loop() + + async def run(self): + await asyncio.sleep(self.time) + return "done" + + +def test_async_method(): + assert "done" == CLI(AsyncMethod, args=["--time=0.0", "run"]) + + +async def run_async_instance(cls: Callable[[], AsyncMethod]): + return await cls().run() + + +def test_async_instance(): + config = { + "cls": { + "class_path": f"{__name__}.AsyncMethod", + "init_args": {"time": 0.0, "require_async": True}, + } + } + assert "done" == CLI(run_async_instance, args=[f"--config={config}"])