Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ v4.31.0 (2024-06-??)

Added
^^^^^
- Support async functions and methods in ``CLI`` (`#531
<https://github.com/omni-us/jsonargparse/pull/531>`__).
- Support for ``Protocol`` types only accepting exact matching signature of
public methods (`#526
<https://github.com/omni-us/jsonargparse/pull/526>`__).
Expand Down
21 changes: 11 additions & 10 deletions jsonargparse/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def CLI(
deprecation_warning_cli_return_parser()
return parser
cfg = parser.parse_args(args)
cfg_init = parser.instantiate_classes(cfg)
return _run_component(components, cfg_init)
init = parser.instantiate_classes(cfg)
return _run_component(components, init)

elif isinstance(components, list):
components = {c.__name__: c for c in components}
Expand Down Expand Up @@ -192,12 +192,13 @@ def _add_component_to_parser(

def _run_component(component, cfg):
cfg.pop("config", None)
if not inspect.isclass(component):
return component(**cfg)
subcommand = cfg.pop("subcommand")
if not subcommand:
return component(**cfg)
subcommand_cfg = cfg.pop(subcommand, {})
subcommand_cfg.pop("config", None)
component_obj = component(**cfg)
return getattr(component_obj, subcommand)(**subcommand_cfg)
if inspect.isclass(component) and subcommand:
subcommand_cfg = cfg.pop(subcommand, {})
subcommand_cfg.pop("config", None)
component_obj = component(**cfg)
component = getattr(component_obj, subcommand)
cfg = subcommand_cfg
if inspect.iscoroutinefunction(component):
return __import__("asyncio").run(component(**cfg))
return component(**cfg)
44 changes: 43 additions & 1 deletion jsonargparse_tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import asyncio
import os
import sys
from contextlib import redirect_stderr, redirect_stdout, suppress
from dataclasses import asdict, dataclass
from io import StringIO
from pathlib import Path
from typing import Optional
from typing import Callable, Optional
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -535,3 +536,44 @@ def test_final_and_subclass_type_config_file(tmp_cwd):

out = CLI(run_bf, args=["--config=config.yaml"])
assert "a yaml" == out


# async tests


async def run_async(time: float = 0.1):
await asyncio.sleep(time)
return "done"


def test_async_function():
assert "done" == CLI(run_async, args=["--time=0.0"])


class AsyncMethod:
def __init__(self, time: float = 0.1, require_async: bool = False):
self.time = time
if require_async:
self.loop = asyncio.get_event_loop()

async def run(self):
await asyncio.sleep(self.time)
return "done"


def test_async_method():
assert "done" == CLI(AsyncMethod, args=["--time=0.0", "run"])


async def run_async_instance(cls: Callable[[], AsyncMethod]):
return await cls().run()


def test_async_instance():
config = {
"cls": {
"class_path": f"{__name__}.AsyncMethod",
"init_args": {"time": 0.0, "require_async": True},
}
}
assert "done" == CLI(run_async_instance, args=[f"--config={config}"])