[Executor] Convert sync generator to async generator #3144

Merged
merged 2 commits on May 13, 2024
13 changes: 13 additions & 0 deletions src/promptflow-core/promptflow/_utils/async_utils.py
@@ -6,6 +6,7 @@
import functools
import signal
import threading
from typing import Iterator

from promptflow.tracing import ThreadPoolExecutorWithContext

@@ -109,3 +110,15 @@ async def wrapper(*args, **kwargs):
        return await asyncio.get_event_loop().run_in_executor(executor, partial_func)

    return wrapper


async def sync_iterator_to_async(g: Iterator):
    with ThreadPoolExecutorWithContext(max_workers=1) as pool:
        loop = asyncio.get_running_loop()
        # Use object() as a default value to distinguish from None
        default_value = object()
        while True:
            resp = await loop.run_in_executor(pool, next, g, default_value)
            if resp is default_value:
                return
            yield resp
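Note (illustrative, not part of this diff): the helper above drains a blocking iterator on a single worker thread so it can be consumed with async for without stalling the event loop. A minimal usage sketch, assuming promptflow-core is installed; slow_numbers and the asyncio.run driver are made up for illustration:

import asyncio
import time

from promptflow._utils.async_utils import sync_iterator_to_async


def slow_numbers():
    # A blocking, synchronous generator; each next() call sleeps.
    for i in range(3):
        time.sleep(0.1)
        yield i


async def main():
    # next() runs on the helper's worker thread, so the event loop stays responsive.
    async for item in sync_iterator_to_async(slow_numbers()):
        print(item)


asyncio.run(main())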
39 changes: 25 additions & 14 deletions src/promptflow-core/promptflow/executor/flow_executor.py
@@ -14,8 +14,7 @@
from contextlib import contextmanager
from pathlib import Path
from threading import current_thread
from types import AsyncGeneratorType, GeneratorType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union

import opentelemetry.trace as otel_trace
from opentelemetry.trace.span import Span, format_trace_id
@@ -29,7 +28,7 @@
from promptflow._core.run_tracker import RunTracker
from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR
from promptflow._core.tools_manager import ToolsManager
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.async_utils import async_run_allowing_running_loop, sync_iterator_to_async
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.execution_utils import (
    apply_default_value_for_input,
@@ -724,6 +723,7 @@ def exec_line(
                node_concurrency,
                allow_generator_output,
                line_timeout_sec,
                sync_iterator_to_async=False,
            )
        # TODO: Call exec_line_async in exec_line when async is mature.
        self._node_concurrency = node_concurrency
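Note (illustrative, not part of this diff): exec_line passes sync_iterator_to_async=False, so synchronous callers keep receiving plain generators in the output and can join them as before. A hedged sketch of such a caller; the executor variable and the "answer" output name are assumptions:

# Sync path: streamed outputs stay ordinary (blocking) iterators.
line_result = executor.exec_line(inputs, allow_generator_output=True)
text = "".join(str(chunk) for chunk in line_result.output["answer"])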
@@ -754,6 +754,7 @@ async def exec_line_async(
        node_concurrency=DEFAULT_CONCURRENCY_FLOW,
        allow_generator_output: bool = False,
        line_timeout_sec: Optional[int] = None,
        sync_iterator_to_async: bool = True,
    ) -> LineResult:
        """Execute a single line of the flow.

@@ -769,6 +770,8 @@ async def exec_line_async(
        :type node_concurrency: int
        :param allow_generator_output: Whether to allow generator output.
        :type allow_generator_output: bool
        :param sync_iterator_to_async: Whether to convert sync iterator output to async iterator.
        :type sync_iterator_to_async: bool
        :return: The result of executing the line.
        :rtype: ~promptflow.executor._result.LineResult
        """
@@ -786,6 +789,8 @@
            validate_inputs=validate_inputs,
            allow_generator_output=allow_generator_output,
        )
        if sync_iterator_to_async:
            line_result.output = self._convert_iterators_to_async(line_result.output)
        # Return line result with index
        if index is not None and isinstance(line_result.output, dict):
            line_result.output[LINE_NUMBER_KEY] = index
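Note (illustrative, not part of this diff): with sync_iterator_to_async=True, the new default for exec_line_async, any synchronous iterator in line_result.output is wrapped by _convert_iterators_to_async, so an async caller can stream flow output without blocking the event loop. A hedged sketch of such a caller; the executor variable and the "answer" output name are assumptions:

async def consume_stream(executor, inputs):
    # exec_line_async now wraps sync iterator outputs as async iterators by default.
    line_result = await executor.exec_line_async(inputs, allow_generator_output=True)
    chunks = []
    async for chunk in line_result.output["answer"]:
        chunks.append(str(chunk))
    return "".join(chunks)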
@@ -890,6 +895,12 @@ def _start_flow_span(self, inputs: Mapping[str, Any]):
            enrich_span_with_input(span, inputs)
            yield span

    def _convert_iterators_to_async(self, output: dict):
        for k, v in output.items():
            if isinstance(v, Iterator):
                output[k] = sync_iterator_to_async(v)
        return output

    async def _exec_inner_with_trace_async(
        self,
        inputs: Mapping[str, Any],
@@ -942,7 +953,7 @@ def _exec_post_process(
        generator_output_nodes = [
            nodename
            for nodename, output in nodes_outputs.items()
            if isinstance(output, GeneratorType) or isinstance(output, AsyncGeneratorType)
            if isinstance(output, Iterator) or isinstance(output, AsyncIterator)
        ]
        # When stream is True, we allow generator output in the flow output
        run_tracker.allow_generator_types = stream
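Note (illustrative, not part of this diff): the checks move from types.GeneratorType / AsyncGeneratorType to the Iterator / AsyncIterator ABCs, so streaming outputs are recognized even when they are hand-rolled iterators rather than real generator objects. A small sketch of the difference; the Stream class is a made-up example:

from types import GeneratorType
from typing import Iterator


class Stream:
    # A hand-rolled iterator, e.g. wrapping a streaming LLM response.
    def __init__(self, chunks):
        self._it = iter(chunks)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)


s = Stream(["a", "b"])
print(isinstance(s, GeneratorType))  # False: only real generator objects match
print(isinstance(s, Iterator))       # True: anything with __iter__/__next__ matches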
@@ -1202,34 +1213,34 @@ async def _traverse_nodes_async(self, inputs, context: FlowExecutionContext) ->
        return outputs, nodes_outputs

    @staticmethod
    async def _merge_async_generator(async_gen: AsyncGeneratorType, outputs: dict, key: str):
    async def _merge_async_iterator(async_it: AsyncIterator, outputs: dict, key: str):
        items = []
        async for item in async_gen:
        async for item in async_it:
            items.append(item)
        outputs[key] = "".join(str(item) for item in items)

    async def _stringify_generator_output_async(self, outputs: dict):
        pool = ThreadPoolExecutorWithContext()
        tasks = []
        for k, v in outputs.items():
            if isinstance(v, AsyncGeneratorType):
                tasks.append(asyncio.create_task(self._merge_async_generator(v, outputs, k)))
            elif isinstance(v, GeneratorType):
            if isinstance(v, AsyncIterator):
                tasks.append(asyncio.create_task(self._merge_async_iterator(v, outputs, k)))
            elif isinstance(v, Iterator):
                loop = asyncio.get_event_loop()
                task = loop.run_in_executor(pool, self._merge_generator, v, outputs, k)
                task = loop.run_in_executor(pool, self._merge_iterator, v, outputs, k)
                tasks.append(task)
        if tasks:
            await asyncio.wait(tasks)
        return outputs

    @staticmethod
    def _merge_generator(gen: GeneratorType, outputs: dict, key: str):
    def _merge_iterator(gen: Iterator, outputs: dict, key: str):
        outputs[key] = "".join(str(item) for item in gen)

    def _stringify_generator_output(self, outputs: dict):
        for k, v in outputs.items():
            if isinstance(v, GeneratorType):
                self._merge_generator(v, outputs, k)
            if isinstance(v, Iterator):
                self._merge_iterator(v, outputs, k)

        return outputs
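Note (illustrative, not part of this diff): _stringify_generator_output_async drains AsyncIterator outputs as asyncio tasks and Iterator outputs on a thread pool, then joins the chunks into strings in place. A standalone sketch of the same pattern using plain concurrent.futures and dummy outputs; none of the names below are promptflow APIs:

import asyncio
from concurrent.futures import ThreadPoolExecutor


async def async_chunks():
    for c in ("Hello", ", ", "world"):
        yield c


def sync_chunks():
    yield from ("!", "!")


async def stringify(outputs: dict) -> dict:
    loop = asyncio.get_running_loop()

    async def merge_async(key, it):
        # Drain an async iterator on the event loop.
        outputs[key] = "".join([str(x) async for x in it])

    def merge_sync(key, it):
        # Drain a blocking iterator on a worker thread.
        outputs[key] = "".join(str(x) for x in it)

    with ThreadPoolExecutor() as pool:
        tasks = [
            asyncio.create_task(merge_async("a", outputs["a"])),
            loop.run_in_executor(pool, merge_sync, "b", outputs["b"]),
        ]
        await asyncio.wait(tasks)
    return outputs


print(asyncio.run(stringify({"a": async_chunks(), "b": sync_chunks()})))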

@@ -1366,7 +1377,7 @@ def _ensure_node_result_is_serializable(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        result = f(*args, **kwargs)
        if isinstance(result, GeneratorType):
        if isinstance(result, Iterator):
            result = "".join(str(trunk) for trunk in result)
        return result
