34 changes: 34 additions & 0 deletions src/orcapod/core/function_pod.py
@@ -945,6 +945,40 @@ def as_table(
        )
        return output_table

    async def async_execute(
        self,
        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
        pipeline_config: PipelineConfig | None = None,
    ) -> None:
        """Streaming async execution for FunctionNode."""
        try:
            pipeline_config = pipeline_config or PipelineConfig()
            node_config = (
                self._function_pod.node_config
                if hasattr(self._function_pod, "node_config")
                else NodeConfig()
            )
            max_concurrency = resolve_concurrency(node_config, pipeline_config)
            sem = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None

            async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None:
                try:
                    result_packet = self._packet_function.call(packet)
                    if result_packet is not None:
                        await output.send((tag, result_packet))
                finally:
                    # Release even if the function raises, so the intake
                    # loop below is never starved of permits.
                    if sem is not None:
                        sem.release()

            async with asyncio.TaskGroup() as tg:
                async for tag, packet in inputs[0]:
                    # Acquire before spawning, so at most max_concurrency
                    # worker tasks are in flight at any time.
                    if sem is not None:
                        await sem.acquire()
                    tg.create_task(process_one(tag, packet))
        finally:
            await output.close()

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(packet_function={self._packet_function!r}, "
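The acquire-before-spawn idiom above is easy to get wrong, so here is a minimal standalone sketch of the same Semaphore + TaskGroup pattern (requires Python 3.11+ for asyncio.TaskGroup; bounded_map and slow_double are illustrative names, not orcapod APIs):

import asyncio


async def bounded_map(items, func, max_concurrency=4):
    # Same shape as async_execute above: acquire a permit before spawning
    # a worker, release it when the worker finishes, so at most
    # max_concurrency workers run concurrently.
    sem = asyncio.Semaphore(max_concurrency)
    results = []

    async def worker(item):
        try:
            results.append(await func(item))
        finally:
            sem.release()

    async with asyncio.TaskGroup() as tg:
        for item in items:
            await sem.acquire()
            tg.create_task(worker(item))
    return results


async def main():
    async def slow_double(x):
        await asyncio.sleep(0.01)
        return 2 * x

    print(await bounded_map(range(10), slow_double))


asyncio.run(main())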
12 changes: 11 additions & 1 deletion src/orcapod/core/operator_node.py
@@ -1,9 +1,11 @@
from __future__ import annotations

import logging
from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING, Any

from orcapod.channels import ReadableChannel, WritableChannel

from orcapod import contexts
from orcapod.config import Config
from orcapod.core.streams.base import StreamBase
@@ -156,6 +158,14 @@ def as_table(
        assert self._cached_output_stream is not None
        return self._cached_output_stream.as_table(columns=columns, all_info=all_info)

    async def async_execute(
        self,
        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
    ) -> None:
        """Delegate to the wrapped operator's async_execute."""
        await self._operator.async_execute(inputs, output)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(operator={self._operator!r}, "
19 changes: 18 additions & 1 deletion src/orcapod/core/tracker.py
@@ -1,10 +1,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator, Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, TypeAlias

from orcapod.channels import ReadableChannel, WritableChannel

from orcapod import contexts
from orcapod.config import Config
from orcapod.core.streams import StreamBase
@@ -209,6 +211,21 @@ def as_table(
    def iter_packets(self) -> Iterator[tuple[cp.TagProtocol, cp.PacketProtocol]]:
        return self.stream.iter_packets()

    def run(self) -> None:
        """No-op for source nodes; data is already available."""

    async def async_execute(
        self,
        inputs: Sequence[ReadableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]]],
        output: WritableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]],
    ) -> None:
        """Push all (tag, packet) pairs from the wrapped stream to the output channel."""
        try:
            for tag, packet in self.stream.iter_packets():
                await output.send((tag, packet))
        finally:
            # Close unconditionally so downstream `async for` loops terminate,
            # even if iteration raises.
            await output.close()


GraphNode: TypeAlias = "SourceNode | FunctionNode | OperatorNode"
# Full type once FunctionNode/OperatorNode are imported:
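The try/finally close in async_execute above is what lets downstream `async for` loops finish. A rough stand-in for that producer/consumer contract, using asyncio.Queue with a sentinel in place of the real channel close (illustrative only; the actual Channel API lives in orcapod.channels):

import asyncio

_DONE = object()  # stands in for the channel's close signal


async def produce(queue, rows):
    try:
        for tag_packet in rows:
            await queue.put(tag_packet)
    finally:
        # Mirrors `await output.close()`: without this, the consumer
        # below would wait forever.
        await queue.put(_DONE)


async def consume(queue):
    while (item := await queue.get()) is not _DONE:
        print("got", item)


async def main():
    q = asyncio.Queue(maxsize=8)  # bounded, like channel_buffer_size
    async with asyncio.TaskGroup() as tg:
        tg.create_task(produce(q, [("tag1", {"x": 1}), ("tag2", {"x": 2})]))
        tg.create_task(consume(q))


asyncio.run(main())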
2 changes: 2 additions & 0 deletions src/orcapod/pipeline/__init__.py
@@ -1,7 +1,9 @@
from .graph import Pipeline
from .nodes import PersistentSourceNode
from .orchestrator import AsyncPipelineOrchestrator

__all__ = [
    "AsyncPipelineOrchestrator",
    "Pipeline",
    "PersistentSourceNode",
]
33 changes: 28 additions & 5 deletions src/orcapod/pipeline/graph.py
@@ -267,15 +267,38 @@ def compile(self) -> None:
    # Execution
    # ------------------------------------------------------------------

    def run(self, config: "PipelineConfig | None" = None) -> None:
        """Execute all compiled nodes.

        Args:
            config: Pipeline configuration. When ``config.executor`` is
                ``ExecutorType.ASYNC_CHANNELS``, the pipeline runs
                asynchronously via the orchestrator. Otherwise nodes are
                executed synchronously in topological order.
        """
        from orcapod.types import ExecutorType, PipelineConfig

        config = config or PipelineConfig()

        if not self._compiled:
            self.compile()

        if config.executor == ExecutorType.ASYNC_CHANNELS:
            self._run_async(config)
        else:
            assert self._node_graph is not None
            for node in nx.topological_sort(self._node_graph):
                node.run()

        self.flush()

    def _run_async(self, config: "PipelineConfig") -> None:
        """Run the pipeline asynchronously using the orchestrator."""
        from orcapod.pipeline.orchestrator import AsyncPipelineOrchestrator

        orchestrator = AsyncPipelineOrchestrator()
        orchestrator.run(self, config)

    def flush(self) -> None:
        """Flush all databases."""
        self._pipeline_database.flush()
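A hedged usage sketch for the new run signature; it assumes a compiled Pipeline instance is in scope and that PipelineConfig accepts executor and channel_buffer_size as keyword arguments (the latter appears in the orchestrator below):

from orcapod.types import ExecutorType, PipelineConfig

# `pipeline` is assumed to be an already-constructed Pipeline.
pipeline.run()  # default config: synchronous topological execution

config = PipelineConfig(
    executor=ExecutorType.ASYNC_CHANNELS,  # route through AsyncPipelineOrchestrator
    channel_buffer_size=64,
)
pipeline.run(config)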
174 changes: 174 additions & 0 deletions src/orcapod/pipeline/orchestrator.py
@@ -0,0 +1,174 @@
"""Async pipeline orchestrator for push-based channel execution.

Compiles a ``GraphTracker``'s DAG into channels and launches all nodes
concurrently via ``asyncio.TaskGroup``.
"""

from __future__ import annotations

import asyncio
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any

from orcapod.channels import BroadcastChannel, Channel
from orcapod.core.static_output_pod import StaticOutputPod
from orcapod.core.tracker import GraphTracker, SourceNode
from orcapod.types import PipelineConfig

if TYPE_CHECKING:
import networkx as nx

from orcapod.core.streams.arrow_table_stream import ArrowTableStream
from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol

logger = logging.getLogger(__name__)


class AsyncPipelineOrchestrator:
    """Executes a compiled DAG asynchronously using channels and TaskGroup.

    After ``GraphTracker.compile()``, the orchestrator:

    1. Identifies source, intermediate, and terminal nodes.
    2. Creates bounded channels (or broadcast channels for fan-out) between
       connected nodes.
    3. Launches every node's ``async_execute`` concurrently.
    4. Collects the terminal node's output and materializes it as a stream.
    """

    def run(
        self,
        tracker: GraphTracker,
        config: PipelineConfig | None = None,
    ) -> StreamProtocol:
        """Synchronous entry point: runs the async pipeline and returns the result.

        Args:
            tracker: A compiled ``GraphTracker`` whose ``_node_lut`` and
                ``_graph_edges`` describe the DAG.
            config: Pipeline configuration (buffer sizes, concurrency).

        Returns:
            An ``ArrowTableStream`` containing all (tag, packet) pairs
            produced by the terminal node.
        """
        config = config or PipelineConfig()
        return asyncio.run(self._run_async(tracker, config))

    async def run_async(
        self,
        tracker: GraphTracker,
        config: PipelineConfig | None = None,
    ) -> StreamProtocol:
        """Async entry point for callers already inside an event loop.

        Args:
            tracker: A compiled ``GraphTracker``.
            config: Pipeline configuration.

        Returns:
            An ``ArrowTableStream`` of the terminal node's output.
        """
        config = config or PipelineConfig()
        return await self._run_async(tracker, config)

    async def _run_async(
        self,
        tracker: GraphTracker,
        config: PipelineConfig,
    ) -> StreamProtocol:
        """Core async logic: wire channels, launch tasks, collect results."""
        import networkx as nx

        # Build directed graph from edges
        G = nx.DiGraph()
        for upstream_hash, downstream_hash in tracker._graph_edges:
            G.add_edge(upstream_hash, downstream_hash)

        # Add isolated nodes (sources with no downstream edges)
        for node_hash in tracker._node_lut:
            if node_hash not in G:
                G.add_node(node_hash)

        topo_order = list(nx.topological_sort(G))

        # Identify terminal nodes (no outgoing edges)
        terminal_hashes = [h for h in topo_order if G.out_degree(h) == 0]
        if not terminal_hashes:
            raise ValueError("DAG has no terminal nodes")

        # For multiple terminals, use the last one in topological order
        # (the one furthest downstream).
        terminal_hash = terminal_hashes[-1]

        buf = config.channel_buffer_size

        # Build channel mapping:
        # For each edge (upstream_hash → downstream_hash), create a channel.
        # If an upstream feeds multiple downstreams (fan-out), use BroadcastChannel.

        # Outgoing edges per node
        out_edges: dict[str, list[str]] = defaultdict(list)
        for upstream_hash, downstream_hash in tracker._graph_edges:
            out_edges[upstream_hash].append(downstream_hash)

        # Incoming edges per node (to know how many input channels)
        in_edges: dict[str, list[str]] = defaultdict(list)
        for upstream_hash, downstream_hash in tracker._graph_edges:
            in_edges[downstream_hash].append(upstream_hash)

        # For each upstream node, create either a Channel or a BroadcastChannel.
        node_output_channels: dict[str, Channel | BroadcastChannel] = {}

        # edge (upstream, downstream) → reader
        edge_readers: dict[tuple[str, str], Any] = {}

        for upstream_hash, downstreams in out_edges.items():
            if len(downstreams) == 1:
                # Simple point-to-point channel
                ch = Channel(buffer_size=buf)
                node_output_channels[upstream_hash] = ch
                edge_readers[(upstream_hash, downstreams[0])] = ch.reader
            else:
                # Fan-out: every downstream gets its own reader
                bch = BroadcastChannel(buffer_size=buf)
                node_output_channels[upstream_hash] = bch
                for ds_hash in downstreams:
                    edge_readers[(upstream_hash, ds_hash)] = bch.add_reader()

        # Terminal node output channel
        terminal_ch = Channel(buffer_size=buf)
        node_output_channels[terminal_hash] = terminal_ch

        # Launch all nodes, plus the terminal collector.
        async with asyncio.TaskGroup() as tg:
            for node_hash in topo_order:
                node = tracker._node_lut[node_hash]

                # Gather input readers for this node (from its upstream edges)
                input_readers = []
                for upstream_hash in in_edges.get(node_hash, []):
                    input_readers.append(edge_readers[(upstream_hash, node_hash)])

                # Get the output writer
                output_channel = node_output_channels.get(node_hash)
                if output_channel is None:
                    # A non-collected terminal still needs an output channel;
                    # drain it so its bounded buffer cannot block the writer
                    # and deadlock the TaskGroup.
                    output_channel = Channel(buffer_size=buf)
                    node_output_channels[node_hash] = output_channel
                    tg.create_task(_drain(output_channel.reader))

                tg.create_task(node.async_execute(input_readers, output_channel.writer))

            # Collect inside the TaskGroup: if we waited until after the
            # group exits, a full terminal buffer would block the terminal
            # writer and the group would never finish.
            collect_task = tg.create_task(terminal_ch.reader.collect())

        terminal_rows = collect_task.result()

        # Materialize into a stream
        return StaticOutputPod._materialize_to_stream(terminal_rows)
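The terminal-node selection in _run_async reduces to "last zero-out-degree node in topological order". A small self-contained illustration with networkx (the node names are made up):

import networkx as nx

G = nx.DiGraph()
G.add_edges_from([("src", "map"), ("map", "sink"), ("src", "audit")])

topo = list(nx.topological_sort(G))
terminals = [n for n in topo if G.out_degree(n) == 0]

# Both "sink" and "audit" are terminals; the orchestrator would collect
# from terminals[-1], whichever sorts furthest downstream.
print(topo, terminals)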
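On the fan-out branch: a BroadcastChannel must give every downstream a full copy of the stream, unlike a single shared queue where readers would steal items from each other. A toy model of that semantic (MiniBroadcast is illustrative, not orcapod's implementation):

import asyncio


class MiniBroadcast:
    """Toy fan-out channel: each reader gets its own queue, so every
    downstream sees every item (copy semantics, not work-sharing)."""

    def __init__(self, buffer_size: int = 8) -> None:
        self._buffer_size = buffer_size
        self._queues: list[asyncio.Queue] = []

    def add_reader(self) -> asyncio.Queue:
        q = asyncio.Queue(maxsize=self._buffer_size)
        self._queues.append(q)
        return q

    async def send(self, item) -> None:
        # The slowest reader applies backpressure to the whole broadcast.
        for q in self._queues:
            await q.put(item)


async def main():
    bch = MiniBroadcast()
    r1, r2 = bch.add_reader(), bch.add_reader()
    await bch.send("packet")
    print(await r1.get(), await r2.get())  # both readers see "packet"


asyncio.run(main())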