Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions fireflyframework_agentic/pipeline/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ class Send:
payload: dict[str, Any]


def _resolve_node_id(ref: Any) -> str:
"""Turn either a string node id or a function reference into a node id.

Function references use ``fn.__name__``. Anything else raises
:class:`PipelineError`.
"""
if isinstance(ref, str):
return ref
name = getattr(ref, "__name__", None)
if not name:
raise PipelineError(f"Cannot derive node id from {ref!r}")
return name


def _is_send_payload(value: Any) -> bool:
"""True when a node's return value is a single :class:`Send` or a
non-empty ``list[Send]``. Drives the runtime fan-out branch in
Expand Down Expand Up @@ -300,6 +314,7 @@ async def run(
state: BaseModel | None = None,
run_id: str | None = None,
approve_pause: bool = False,
start_at: str | Any = None,
) -> PipelineResult:
"""Execute the pipeline.

Expand Down Expand Up @@ -345,6 +360,15 @@ async def run(
pre_completed = set()
sequence_start = 0
all_results = {}
# Mid-pipeline start: pretend everything not reachable from
# `start_at` already ran. The scheduler then starts at start_at
# because its upstream nodes appear "completed".
if start_at is not None:
start_id = _resolve_node_id(start_at)
if start_id not in self._dag.nodes:
raise PipelineError(f"start_at='{start_id}' not in DAG")
forward = {start_id} | self._dag.transitive_successors(start_id)
pre_completed = {nid for nid in self._dag.nodes if nid not in forward}

if run_id is None:
run_id = uuid.uuid4().hex[:12]
Expand Down
123 changes: 123 additions & 0 deletions tests/unit/pipeline/test_pipeline_engine_start_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Layer 6 of the unification (#245): start_at kwarg for mid-pipeline entry.

PipelineEngine.run() accepts ``start_at=`` (string node id or callable
reference). Execution begins at the named node; everything not reachable
from it is treated as pre-completed and skipped. This is the unified
equivalent of StatePipeline.invoke(state=..., start_at=...).
"""

from __future__ import annotations

import pytest

from fireflyframework_agentic.exceptions import PipelineError
from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode
from fireflyframework_agentic.pipeline.engine import PipelineEngine


class _Counting:
def __init__(self, name: str):
self.name = name
self.calls = 0

async def execute(self, ctx, inputs):
self.calls += 1
return self.name


def _chain(*ids: str) -> tuple[DAG, dict[str, _Counting]]:
dag = DAG("chain")
steps: dict[str, _Counting] = {}
for nid in ids:
s = _Counting(nid)
steps[nid] = s
dag.add_node(DAGNode(node_id=nid, step=s))
for i in range(len(ids) - 1):
dag.add_edge(DAGEdge(source=ids[i], target=ids[i + 1]))
return dag, steps


# ---- baseline -------------------------------------------------------------


async def test_no_start_at_runs_every_node():
dag, steps = _chain("a", "b", "c")
result = await PipelineEngine(dag).run(inputs="x")
assert result.success
assert all(s.calls == 1 for s in steps.values())


# ---- start_at: skip upstream ----------------------------------------------


async def test_start_at_skips_upstream_nodes():
dag, steps = _chain("a", "b", "c", "d")
result = await PipelineEngine(dag).run(inputs="x", start_at="c")
assert result.success
assert steps["a"].calls == 0
assert steps["b"].calls == 0
assert steps["c"].calls == 1
assert steps["d"].calls == 1


async def test_start_at_first_node_is_like_no_start_at():
dag, steps = _chain("a", "b", "c")
result = await PipelineEngine(dag).run(inputs="x", start_at="a")
assert result.success
assert all(s.calls == 1 for s in steps.values())


async def test_start_at_terminal_runs_only_that_node():
dag, steps = _chain("a", "b", "c")
result = await PipelineEngine(dag).run(inputs="x", start_at="c")
assert result.success
assert steps["a"].calls == 0
assert steps["b"].calls == 0
assert steps["c"].calls == 1


# ---- start_at via callable ------------------------------------------------


async def test_start_at_accepts_callable_reference():
async def deploy(ctx, inputs):
return "deployed"

from fireflyframework_agentic.pipeline.steps import CallableStep

dag = DAG("callable")
dag.add_node(DAGNode(node_id="build", step=_Counting("build")))
dag.add_node(DAGNode(node_id="deploy", step=CallableStep(deploy)))
dag.add_edge(DAGEdge(source="build", target="deploy"))
result = await PipelineEngine(dag).run(inputs="x", start_at=deploy)
assert result.success
# Resolves deploy.__name__ -> 'deploy' -> only deploy ran.
assert result.outputs["deploy"].output == "deployed"


# ---- invalid start_at -----------------------------------------------------


async def test_unknown_start_at_raises():
dag, _ = _chain("a", "b")
with pytest.raises(PipelineError, match="start_at"):
await PipelineEngine(dag).run(inputs="x", start_at="ghost")


# ---- branching dag --------------------------------------------------------


async def test_start_at_in_branching_dag():
"""In a branching DAG, start_at picks one branch; the other is skipped entirely."""
dag = DAG("branchy")
for nid in ("root", "left", "right", "leftchild"):
dag.add_node(DAGNode(node_id=nid, step=_Counting(nid)))
dag.add_edge(DAGEdge(source="root", target="left"))
dag.add_edge(DAGEdge(source="root", target="right"))
dag.add_edge(DAGEdge(source="left", target="leftchild"))
result = await PipelineEngine(dag).run(inputs="x", start_at="left")
assert result.success
# 'right' is not downstream of 'left' — it should not run.
assert "right" not in result.outputs or result.outputs["right"].skipped
# 'leftchild' is reachable from 'left'.
assert result.outputs["leftchild"].success