diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py
index 1a932694..7fa5ca51 100644
--- a/src/orcapod/core/function_pod.py
+++ b/src/orcapod/core/function_pod.py
@@ -945,6 +945,42 @@ def as_table(
         )
         return output_table
 
+    async def async_execute(
+        self,
+        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
+        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
+        pipeline_config: PipelineConfig | None = None,
+    ) -> None:
+        """Streaming async execution for FunctionNode (consumes the single input channel)."""
+        try:
+            pipeline_config = pipeline_config or PipelineConfig()
+            node_config = (
+                self._function_pod.node_config
+                if hasattr(self._function_pod, "node_config")
+                else NodeConfig()
+            )
+            max_concurrency = resolve_concurrency(node_config, pipeline_config)
+            sem = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None
+
+            async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None:
+                try:
+                    # NOTE: .call() runs synchronously on the event loop, so the
+                    # semaphore bounds in-flight tasks rather than true parallelism.
+                    result_packet = self._packet_function.call(packet)
+                    if result_packet is not None:
+                        await output.send((tag, result_packet))
+                finally:
+                    if sem is not None:
+                        sem.release()
+
+            async with asyncio.TaskGroup() as tg:
+                async for tag, packet in inputs[0]:
+                    if sem is not None:
+                        await sem.acquire()
+                    tg.create_task(process_one(tag, packet))
+        finally:
+            await output.close()
+
     def __repr__(self) -> str:
         return (
             f"{type(self).__name__}(packet_function={self._packet_function!r}, "
diff --git a/src/orcapod/core/operator_node.py b/src/orcapod/core/operator_node.py
index 5aaab461..3bf87485 100644
--- a/src/orcapod/core/operator_node.py
+++ b/src/orcapod/core/operator_node.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import logging
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 from typing import TYPE_CHECKING, Any
 
+from orcapod.channels import ReadableChannel, WritableChannel
+
 from orcapod import contexts
 from orcapod.config import Config
 from orcapod.core.streams.base import StreamBase
@@ -156,6 +158,14 @@ def as_table(
         assert self._cached_output_stream is not None
         return self._cached_output_stream.as_table(columns=columns, all_info=all_info)
 
+    async def async_execute(
+        self,
+        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
+        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
+    ) -> None:
+        """Delegate to the wrapped operator's async_execute."""
+        await self._operator.async_execute(inputs, output)
+
     def __repr__(self) -> str:
         return (
             f"{type(self).__name__}(operator={self._operator!r}, "
diff --git a/src/orcapod/core/tracker.py b/src/orcapod/core/tracker.py
index 7ed70a7f..d7b221b1 100644
--- a/src/orcapod/core/tracker.py
+++ b/src/orcapod/core/tracker.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterator
+from collections.abc import Generator, Iterator, Sequence
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, TypeAlias
 
+from orcapod.channels import ReadableChannel, WritableChannel
+
 from orcapod import contexts
 from orcapod.config import Config
 from orcapod.core.streams import StreamBase
@@ -209,6 +211,21 @@ def as_table(
     def iter_packets(self) -> Iterator[tuple[cp.TagProtocol, cp.PacketProtocol]]:
         return self.stream.iter_packets()
 
+    def run(self) -> None:
+        """No-op for source nodes — data is already available."""
+
+    async def async_execute(
+        self,
+        inputs: Sequence[ReadableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]]],
+        output: WritableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]],
+    ) -> None:
+        """Push all (tag, packet) pairs from the wrapped stream to the output channel."""
+        try:
+            for tag, packet in self.stream.iter_packets():
+                await output.send((tag, packet))
+        finally:
+            await output.close()
+
 
 GraphNode: TypeAlias = "SourceNode | FunctionNode | OperatorNode"
 # Full type once FunctionNode/OperatorNode are imported:
diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py
index 472ee287..7495be39 100644
--- a/src/orcapod/pipeline/__init__.py
+++ b/src/orcapod/pipeline/__init__.py
@@ -1,7 +1,9 @@
 from .graph import Pipeline
 from .nodes import PersistentSourceNode
+from .orchestrator import AsyncPipelineOrchestrator
 
 __all__ = [
+    "AsyncPipelineOrchestrator",
     "Pipeline",
     "PersistentSourceNode",
 ]
diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py
index 71207a7d..1764e297 100644
--- a/src/orcapod/pipeline/graph.py
+++ b/src/orcapod/pipeline/graph.py
@@ -267,15 +267,38 @@ def compile(self) -> None:
     # Execution
     # ------------------------------------------------------------------
 
-    def run(self) -> None:
-        """Execute all compiled nodes in topological order."""
+    def run(self, config: "PipelineConfig | None" = None) -> None:
+        """Execute all compiled nodes.
+
+        Args:
+            config: Pipeline configuration. When ``config.executor`` is
+                ``ExecutorType.ASYNC_CHANNELS``, the pipeline runs
+                asynchronously via the orchestrator. Otherwise nodes are
+                executed synchronously in topological order.
+        """
+        from orcapod.types import ExecutorType, PipelineConfig
+
+        config = config or PipelineConfig()
+
         if not self._compiled:
             self.compile()
-        assert self._node_graph is not None
-        for node in nx.topological_sort(self._node_graph):
-            node.run()
+
+        if config.executor == ExecutorType.ASYNC_CHANNELS:
+            self._run_async(config)
+        else:
+            assert self._node_graph is not None
+            for node in nx.topological_sort(self._node_graph):
+                node.run()
+
         self.flush()
 
+    def _run_async(self, config: "PipelineConfig") -> None:
+        """Run the pipeline asynchronously using the orchestrator."""
+        from orcapod.pipeline.orchestrator import AsyncPipelineOrchestrator
+
+        orchestrator = AsyncPipelineOrchestrator()
+        orchestrator.run(self, config)
+
     def flush(self) -> None:
         """Flush all databases."""
         self._pipeline_database.flush()
diff --git a/src/orcapod/pipeline/orchestrator.py b/src/orcapod/pipeline/orchestrator.py
index e69de29b..17de6743 100644
--- a/src/orcapod/pipeline/orchestrator.py
+++ b/src/orcapod/pipeline/orchestrator.py
@@ -0,0 +1,193 @@
+"""Async pipeline orchestrator for push-based channel execution.
+
+Compiles a ``GraphTracker``'s DAG into channels and launches all nodes
+concurrently via ``asyncio.TaskGroup``.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any
+
+from orcapod.channels import BroadcastChannel, Channel
+from orcapod.core.static_output_pod import StaticOutputPod
+from orcapod.core.tracker import GraphTracker
+from orcapod.types import PipelineConfig
+
+if TYPE_CHECKING:
+    import networkx as nx
+
+    from orcapod.core.streams.arrow_table_stream import ArrowTableStream
+    from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncPipelineOrchestrator:
+    """Executes a compiled DAG asynchronously using channels and TaskGroup.
+
+    After ``GraphTracker.compile()``, the orchestrator:
+
+    1. Identifies source, intermediate, and terminal nodes.
+    2. Creates bounded channels (or broadcast channels for fan-out) between
+       connected nodes.
+    3. Launches every node's ``async_execute`` concurrently.
+    4. Collects the terminal node's output and materializes it as a stream.
+    """
+
+    def run(
+        self,
+        tracker: GraphTracker,
+        config: PipelineConfig | None = None,
+    ) -> StreamProtocol:
+        """Synchronous entry point — runs the async pipeline and returns the result.
+
+        Args:
+            tracker: A compiled ``GraphTracker`` whose ``_node_lut`` and
+                ``_graph_edges`` describe the DAG.
+            config: Pipeline configuration (buffer sizes, concurrency).
+
+        Returns:
+            An ``ArrowTableStream`` containing all (tag, packet) pairs
+            produced by the terminal node.
+        """
+        config = config or PipelineConfig()
+        return asyncio.run(self._run_async(tracker, config))
+
+    async def run_async(
+        self,
+        tracker: GraphTracker,
+        config: PipelineConfig | None = None,
+    ) -> StreamProtocol:
+        """Async entry point for callers already inside an event loop.
+
+        Args:
+            tracker: A compiled ``GraphTracker``.
+            config: Pipeline configuration.
+
+        Returns:
+            An ``ArrowTableStream`` of the terminal node's output.
+        """
+        config = config or PipelineConfig()
+        return await self._run_async(tracker, config)
+
+    async def _run_async(
+        self,
+        tracker: GraphTracker,
+        config: PipelineConfig,
+    ) -> StreamProtocol:
+        """Core async logic: wire channels, launch tasks, collect results."""
+        import networkx as nx
+
+        # Build directed graph from edges
+        G = nx.DiGraph()
+        for upstream_hash, downstream_hash in tracker._graph_edges:
+            G.add_edge(upstream_hash, downstream_hash)
+
+        # Add isolated nodes (sources with no downstream edges)
+        for node_hash in tracker._node_lut:
+            if node_hash not in G:
+                G.add_node(node_hash)
+
+        topo_order = list(nx.topological_sort(G))
+
+        # Identify terminal nodes (no outgoing edges)
+        terminal_hashes = [h for h in topo_order if G.out_degree(h) == 0]
+        if not terminal_hashes:
+            raise ValueError("DAG has no terminal nodes")
+
+        # For multiple terminals, we use the last one in topological order
+        # (the one furthest downstream)
+        terminal_hash = terminal_hashes[-1]
+
+        buf = config.channel_buffer_size
+
+        # Build channel mapping:
+        # For each edge (upstream_hash → downstream_hash), create a channel.
+        # If an upstream feeds multiple downstreams (fan-out), use BroadcastChannel.
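+        #
+        # Illustrative fan-out wiring for A → {B, C} (the literal "A"/"B"/"C"
+        # keys are hypothetical stand-ins for real node hashes):
+        #
+        #     bch = BroadcastChannel(buffer_size=buf)
+        #     edge_readers[("A", "B")] = bch.add_reader()
+        #     edge_readers[("A", "C")] = bch.add_reader()
+        #
+        # Each added reader sees every item the writer sends, whereas a plain
+        # Channel delivers each item to its single reader.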
+
+        # Count outgoing edges per node
+        out_edges: dict[str, list[str]] = defaultdict(list)
+        for upstream_hash, downstream_hash in tracker._graph_edges:
+            out_edges[upstream_hash].append(downstream_hash)
+
+        # Count incoming edges per node (to know how many input channels)
+        in_edges: dict[str, list[str]] = defaultdict(list)
+        for upstream_hash, downstream_hash in tracker._graph_edges:
+            in_edges[downstream_hash].append(upstream_hash)
+
+        # For each upstream node, create either a Channel or BroadcastChannel
+        # upstream_hash → Channel or BroadcastChannel
+        node_output_channels: dict[str, Channel | BroadcastChannel] = {}
+
+        # edge (upstream, downstream) → reader
+        edge_readers: dict[tuple[str, str], Any] = {}
+
+        for upstream_hash, downstreams in out_edges.items():
+            if len(downstreams) == 1:
+                # Simple channel
+                ch = Channel(buffer_size=buf)
+                node_output_channels[upstream_hash] = ch
+                edge_readers[(upstream_hash, downstreams[0])] = ch.reader
+            else:
+                # Fan-out: use BroadcastChannel
+                bch = BroadcastChannel(buffer_size=buf)
+                node_output_channels[upstream_hash] = bch
+                for ds_hash in downstreams:
+                    edge_readers[(upstream_hash, ds_hash)] = bch.add_reader()
+
+        # Terminal node output channel
+        terminal_ch = Channel(buffer_size=buf)
+        node_output_channels[terminal_hash] = terminal_ch
+
+        async def _drain(reader: Any) -> None:
+            # Consume and discard everything so a dead-end writer never
+            # blocks on a full buffer.
+            async for _ in reader:
+                pass
+
+        # Launch all nodes, plus a collector for the terminal output. The
+        # collector must run inside the TaskGroup: with a bounded buffer,
+        # the terminal node would otherwise block on ``send`` once the
+        # buffer fills and the group would never finish.
+        async with asyncio.TaskGroup() as tg:
+            collect_task = tg.create_task(terminal_ch.reader.collect())
+
+            for node_hash in topo_order:
+                node = tracker._node_lut[node_hash]
+
+                # Gather input readers for this node (from its upstream edges)
+                input_readers = []
+                for upstream_hash in in_edges.get(node_hash, []):
+                    reader = edge_readers[(upstream_hash, node_hash)]
+                    input_readers.append(reader)
+
+                # Get the output writer
+                output_channel = node_output_channels.get(node_hash)
+                if output_channel is None:
+                    # A node with no downstream edges that is not the chosen
+                    # terminal: give it a drained channel so its output is
+                    # discarded instead of filling up and blocking.
+                    output_channel = Channel(buffer_size=buf)
+                    node_output_channels[node_hash] = output_channel
+                    tg.create_task(_drain(output_channel.reader))
+
+                writer = output_channel.writer
+
+                tg.create_task(node.async_execute(input_readers, writer))
+
+        # Materialize the collected (tag, packet) pairs into a stream
+        terminal_rows = collect_task.result()
+        return StaticOutputPod._materialize_to_stream(terminal_rows)
diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py
new file mode 100644
index 00000000..f6b5fb35
--- /dev/null
+++ b/tests/test_pipeline/test_orchestrator.py
@@ -0,0 +1,392 @@
+"""
+Tests for the async pipeline orchestrator.
+
+Covers:
+- Linear pipeline: Source → FunctionPod
+- Operator pipeline: Source → Operator → FunctionPod
+- Diamond DAG: two sources → Join → FunctionPod
+- Fan-out: one source feeds multiple downstream nodes
+- Results match synchronous execution
+- SourceNode.async_execute pushes all rows
+- OperatorNode.async_execute delegates correctly
+- FunctionNode.async_execute works in streaming mode
+- run_async entry point and PipelineConfig integration
+"""
+
+from __future__ import annotations
+
+import pyarrow as pa
+import pytest
+
+from orcapod.channels import Channel
+from orcapod.core.function_pod import FunctionNode, FunctionPod
+from orcapod.core.operator_node import OperatorNode
+from orcapod.core.operators import SelectPacketColumns
+from orcapod.core.operators.join import Join
+from orcapod.core.operators.mappers import MapPackets
+from orcapod.core.packet_function import PythonPacketFunction
+from orcapod.core.sources import ArrowTableSource
+from orcapod.core.tracker import GraphTracker, SourceNode
+from orcapod.pipeline.orchestrator import AsyncPipelineOrchestrator
+from orcapod.types import ExecutorType, PipelineConfig
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_source(
+    tag_col: str,
+    packet_col: str,
+    data: dict,
+) -> ArrowTableSource:
+    table = pa.table(
+        {
+            tag_col: pa.array(data[tag_col], type=pa.large_string()),
+            packet_col: pa.array(data[packet_col], type=pa.int64()),
+        }
+    )
+    return ArrowTableSource(table, tag_columns=[tag_col])
+
+
+def _make_two_sources():
+    src_a = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]})
+    src_b = _make_source("key", "score", {"key": ["a", "b"], "score": [100, 200]})
+    return src_a, src_b
+
+
+def double_value(value: int) -> int:
+    return value * 2
+
+
+def add_values(value: int, score: int) -> int:
+    return value + score
+
+
+# ===========================================================================
+# 1. SourceNode.async_execute
+# ===========================================================================
+
+
+class TestSourceNodeAsyncExecute:
+    @pytest.mark.asyncio
+    async def test_pushes_all_rows_to_output(self):
+        src = _make_source("key", "value", {"key": ["a", "b", "c"], "value": [1, 2, 3]})
+        node = SourceNode(src)
+
+        output_ch = Channel(buffer_size=16)
+        await node.async_execute([], output_ch.writer)
+
+        rows = await output_ch.reader.collect()
+        assert len(rows) == 3
+
+    @pytest.mark.asyncio
+    async def test_closes_channel_on_completion(self):
+        src = _make_source("key", "value", {"key": ["a"], "value": [1]})
+        node = SourceNode(src)
+
+        output_ch = Channel(buffer_size=4)
+        await node.async_execute([], output_ch.writer)
+
+        rows = await output_ch.reader.collect()
+        assert len(rows) == 1
+
+
+# ===========================================================================
+# 2. OperatorNode.async_execute
+# ===========================================================================
+
+
+class TestOperatorNodeAsyncExecute:
+    @pytest.mark.asyncio
+    async def test_delegates_to_operator(self):
+        src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]})
+        op = SelectPacketColumns(columns=["value"])
+        op_node = OperatorNode(op, input_streams=[src])
+
+        input_ch = Channel(buffer_size=16)
+        output_ch = Channel(buffer_size=16)
+
+        # Feed source rows into input channel
+        for tag, packet in src.iter_packets():
+            await input_ch.writer.send((tag, packet))
+        await input_ch.writer.close()
+
+        await op_node.async_execute([input_ch.reader], output_ch.writer)
+
+        rows = await output_ch.reader.collect()
+        assert len(rows) == 2
+
+
+# ===========================================================================
+# 3. FunctionNode.async_execute
+# ===========================================================================
+
+
+class TestFunctionNodeAsyncExecute:
+    @pytest.mark.asyncio
+    async def test_processes_packets(self):
+        src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]})
+        pf = PythonPacketFunction(double_value, output_keys="result")
+        pod = FunctionPod(pf)
+        node = FunctionNode(pod, src)
+
+        input_ch = Channel(buffer_size=16)
+        output_ch = Channel(buffer_size=16)
+
+        for tag, packet in src.iter_packets():
+            await input_ch.writer.send((tag, packet))
+        await input_ch.writer.close()
+
+        await node.async_execute([input_ch.reader], output_ch.writer)
+
+        rows = await output_ch.reader.collect()
+        assert len(rows) == 2
+
+        values = sorted([pkt.as_dict()["result"] for _, pkt in rows])
+        assert values == [20, 40]
+
+
+# ===========================================================================
+# 4. Orchestrator: linear pipeline
+# ===========================================================================
+
+
+class TestOrchestratorLinearPipeline:
+    """Source → FunctionPod (linear pipeline)."""
+
+    def test_linear_source_to_function_pod(self):
+        src = _make_source("key", "value", {"key": ["a", "b", "c"], "value": [1, 2, 3]})
+        pf = PythonPacketFunction(double_value, output_keys="result")
+        pod = FunctionPod(pf)
+
+        tracker = GraphTracker()
+        with tracker:
+            result_stream = pod(src)
+
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        result = orchestrator.run(tracker)
+
+        rows = list(result.iter_packets())
+        assert len(rows) == 3
+
+        values = sorted([pkt.as_dict()["result"] for _, pkt in rows])
+        assert values == [2, 4, 6]
+
+    def test_matches_sync_execution(self):
+        """Async results should match synchronous execution."""
+        src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]})
+        pf = PythonPacketFunction(double_value, output_keys="result")
+        pod = FunctionPod(pf)
+
+        # Sync execution
+        sync_result = pod.process(src)
+        sync_rows = list(sync_result.iter_packets())
+        sync_values = sorted([pkt.as_dict()["result"] for _, pkt in sync_rows])
+
+        # Async execution
+        tracker = GraphTracker()
+        with tracker:
+            _ = pod(src)
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        async_result = orchestrator.run(tracker)
+        async_rows = list(async_result.iter_packets())
+        async_values = sorted([pkt.as_dict()["result"] for _, pkt in async_rows])
+
+        assert sync_values == async_values
+
+
+# ===========================================================================
+# 5. Orchestrator: operator pipeline
+# ===========================================================================
+
+
+class TestOrchestratorOperatorPipeline:
+    """Source → Operator → FunctionPod."""
+
+    def test_source_to_operator_to_function_pod(self):
+        src = _make_source("key", "value", {"key": ["a", "b", "c"], "value": [1, 2, 3]})
+        op = MapPackets(name_map={"value": "val"})
+
+        # The pod's function takes 'val' instead of 'value' to match the rename
+        def double_val(val: int) -> int:
+            return val * 2
+
+        pf = PythonPacketFunction(double_val, output_keys="result")
+        pod = FunctionPod(pf)
+
+        tracker = GraphTracker()
+        with tracker:
+            mapped = op(src)
+            result_stream = pod(mapped)
+
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        result = orchestrator.run(tracker)
+
+        rows = list(result.iter_packets())
+        assert len(rows) == 3
+        values = sorted([pkt.as_dict()["result"] for _, pkt in rows])
+        assert values == [2, 4, 6]
+
+
+# ===========================================================================
+# 6. Orchestrator: diamond DAG (fan-in join)
+# ===========================================================================
+
+
+class TestOrchestratorDiamondDag:
+    """Two sources → Join → FunctionPod."""
+
+    def test_two_sources_join_function_pod(self):
+        src_a, src_b = _make_two_sources()
+
+        pf = PythonPacketFunction(add_values, output_keys="total")
+        pod = FunctionPod(pf)
+
+        tracker = GraphTracker()
+        with tracker:
+            joined = Join()(src_a, src_b)
+            result_stream = pod(joined)
+
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        result = orchestrator.run(tracker)
+
+        rows = list(result.iter_packets())
+        assert len(rows) == 2
+
+        values = sorted([pkt.as_dict()["total"] for _, pkt in rows])
+        assert values == [110, 220]
+
+    def test_diamond_matches_sync(self):
+        """Diamond DAG async results should match sync execution."""
+        src_a, src_b = _make_two_sources()
+        pf = PythonPacketFunction(add_values, output_keys="total")
+        pod = FunctionPod(pf)
+
+        # Sync
+        sync_joined = Join()(src_a, src_b)
+        sync_result = pod.process(sync_joined)
+        sync_values = sorted([pkt.as_dict()["total"] for _, pkt in sync_result.iter_packets()])
+
+        # Async
+        tracker = GraphTracker()
+        with tracker:
+            joined = Join()(src_a, src_b)
+            _ = pod(joined)
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        async_result = orchestrator.run(tracker)
+        async_values = sorted(
+            [pkt.as_dict()["total"] for _, pkt in async_result.iter_packets()]
+        )
+
+        assert sync_values == async_values
+
+
+# ===========================================================================
+# 7. Orchestrator: fan-out (one source feeds multiple nodes)
+# ===========================================================================
+
+
+class TestOrchestratorFanOut:
+    """One source feeds two different function pods via fan-out."""
+
+    def test_fan_out_source_feeds_two_branches(self):
+        src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]})
+
+        # Two function pods: one doubles, one triples
+        def double(value: int) -> int:
+            return value * 2
+
+        def triple(value: int) -> int:
+            return value * 3
+
+        pf_double = PythonPacketFunction(double, output_keys="doubled")
+        pf_triple = PythonPacketFunction(triple, output_keys="tripled")
+        pod_double = FunctionPod(pf_double)
+        pod_triple = FunctionPod(pf_triple)
+
+        tracker = GraphTracker()
+        with tracker:
+            doubled = pod_double(src)
+            tripled = pod_triple(src)
+            result = Join()(doubled, tripled)
+
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        result_stream = orchestrator.run(tracker)
+
+        rows = list(result_stream.iter_packets())
+        assert len(rows) == 2
+
+        for _, pkt in rows:
+            d = pkt.as_dict()
+            assert "doubled" in d
+            assert "tripled" in d
+
+
+# ===========================================================================
+# 8. run_async entry point (for callers inside event loop)
+# ===========================================================================
+
+
+class TestOrchestratorRunAsync:
+    @pytest.mark.asyncio
+    async def test_run_async_from_event_loop(self):
+        """run_async should work when called from inside an event loop."""
+        src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]})
+        pf = PythonPacketFunction(double_value, output_keys="result")
+        pod = FunctionPod(pf)
+
+        tracker = GraphTracker()
+        with tracker:
+            _ = pod(src)
+        tracker.compile()
+
+        orchestrator = AsyncPipelineOrchestrator()
+        result = await orchestrator.run_async(tracker)
+
+        rows = list(result.iter_packets())
+        assert len(rows) == 2
+        values = sorted([pkt.as_dict()["result"] for _, pkt in rows])
+        assert values == [2, 4]
+
+
+# ===========================================================================
+# 9. PipelineConfig integration
+# ===========================================================================
+
+
+class TestPipelineConfigIntegration:
+    def test_custom_buffer_size(self):
+        """Pipeline should work with custom buffer sizes."""
+        src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]})
+        pf = PythonPacketFunction(double_value, output_keys="result")
+        pod = FunctionPod(pf)
+
+        tracker = GraphTracker()
+        with tracker:
+            _ = pod(src)
+        tracker.compile()
+
+        config = PipelineConfig(
+            executor=ExecutorType.ASYNC_CHANNELS,
+            channel_buffer_size=4,
+        )
+
+        orchestrator = AsyncPipelineOrchestrator()
+        result = orchestrator.run(tracker, config=config)
+
+        rows = list(result.iter_packets())
+        assert len(rows) == 2
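
End-to-end usage sketch for the new executor. It mirrors the APIs exercised in the tests above; the table contents and the buffer size are illustrative only.

import pyarrow as pa

from orcapod.core.function_pod import FunctionPod
from orcapod.core.packet_function import PythonPacketFunction
from orcapod.core.sources import ArrowTableSource
from orcapod.core.tracker import GraphTracker
from orcapod.pipeline.orchestrator import AsyncPipelineOrchestrator
from orcapod.types import ExecutorType, PipelineConfig


def double(value: int) -> int:
    return value * 2


# Build a one-step pipeline: double every value.
table = pa.table({"key": ["a", "b"], "value": [1, 2]})
source = ArrowTableSource(table, tag_columns=["key"])
pod = FunctionPod(PythonPacketFunction(double, output_keys="result"))

tracker = GraphTracker()
with tracker:
    _ = pod(source)
tracker.compile()

# Run with the push-based channel executor and an explicit buffer size.
config = PipelineConfig(executor=ExecutorType.ASYNC_CHANNELS, channel_buffer_size=64)
result = AsyncPipelineOrchestrator().run(tracker, config=config)
for tag, packet in result.iter_packets():
    print(tag, packet.as_dict())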