diff --git a/docs/pipeline.md b/docs/pipeline.md index 57b0d152..88570307 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -7,6 +7,17 @@ workflows. It supports parallel execution, conditional branching, retries, timeo and fan-out/fan-in patterns -- everything needed to model real-world enterprise processing pipelines. +`PipelineBuilder` has two modes: + +* **Port-based** (legacy, parallel) — nodes communicate via `output_key` / + `input_key` edge ports and run concurrently within each topological level. + Best for ETL-shaped DAGs. Documented in the bulk of this guide. +* **State-based** — opt-in via `PipelineBuilder("name", state=SomeModel)`. + Nodes become `async (state) -> dict` over a typed shared state. One + `.branch(source, router)` call covers conditional routing; `Send(target, payload)` + covers runtime fan-out; a `Checkpointer` enables resume after failure. Best for + agentic workflows and ReAct-style loops. See [State-Based Pipelines](#state-based-pipelines). + --- ## Concepts @@ -88,15 +99,318 @@ The framework provides these built-in executors: - **CallableStep** -- Wraps any `async` function `(context, inputs) -> output`. - **BatchLLMStep** -- Processes multiple prompts concurrently through an agent for cost optimization. See [Batch Processing](#batch-processing-batchllmstep) below. -- **BranchStep** -- Routes execution to one of several downstream paths based on - a predicate (see [Conditional Branching](#conditional-branching-branchstep) below). -- **FanOutStep** -- Splits input into a list for parallel downstream processing. +- **BranchStep** _(deprecated)_ -- Routes execution to one of several downstream paths based on + a predicate. Use `.branch(...)` in [State-Based Pipelines](#state-based-pipelines) instead. +- **FanOutStep** _(deprecated)_ -- Splits input into a list for parallel downstream processing. + Use `Send` in [State-Based Pipelines](#runtime-fan-out-via-send) instead. 
- **FanInStep** -- Merges outputs from multiple upstream nodes. --- +## State-Based Pipelines + +Set `state=` on `PipelineBuilder` to switch to a declarative API designed for +agentic workflows. Nodes become `async (state) -> dict | None` functions over +a typed shared-state object; the engine reduces each node's partial-update +dict back into the state. + +```python +from typing import Annotated +from pydantic import BaseModel +from fireflyframework_agentic.pipeline import PipelineBuilder, append + + +class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] # reducer: append + intent: str | None = None # default reducer: replace + answer: str | None = None + + +async def classify(state: AgentState) -> dict: + return {"intent": "complaint" if "refund" in state.messages[-1] else "general"} + + +async def answer(state: AgentState) -> dict: + return {"answer": "Here is your answer."} + + +async def escalate(state: AgentState) -> dict: + return {"answer": "Escalated to human."} + + +def route(state: AgentState) -> str: + return "escalate" if state.intent == "complaint" else "answer" + + +pipeline = ( + PipelineBuilder("support-agent", state=AgentState) + .add_node(classify) # node id derived from fn.__name__ + .add_node(answer) + .add_node(escalate) + .branch(classify, route) # router returns target node id + .build() +) +result = await pipeline.invoke(AgentState(messages=["I want a refund"])) +print(result.state.answer) +``` + +### Reducers + +Reducers are declared as `Annotated[T, reducer_fn]` on the state schema. The +built-ins live in `fireflyframework_agentic.pipeline.reducers`: + +| Reducer | Semantics | +|---------------|-------------------------------------------------| +| `replace` | Last-write-wins (the default for any field). | +| `append` | Append a single item to a list. | +| `extend` | Concatenate two iterables. | +| `merge_dict` | Shallow-merge two dicts; update wins on conflict. 
| + +Custom reducers are any callable `(current, update) -> merged`. + +### Branching + +`.branch(source, router, mapping=None)` registers a synchronous +`(state) -> str | Send | list[Send]` router on `source`: + +* Returning a node id (string) routes to that node directly. +* Passing `mapping={"label": target_node, ...}` lets the router return an + abstract label instead of a node id. +* Returning a `Send` or `list[Send]` triggers runtime fan-out (see below). + +### Checkpoint + Resume + +Pass a `Checkpointer` to persist state after each successful node. Three +backends ship out of the box, all conforming to the same `Checkpointer` +Protocol so they're swappable without code changes. + +| Backend | Use when | Trade-off | Install | +|---|---|---|---| +| `FileCheckpointer` | Dev, single-host, ephemeral | No cross-process / cross-host sharing | (default — no extra) | +| `RedisCheckpointer` | Multi-worker, sub-day-scale runs | TTL eviction; not durable forever | `pip install fireflyframework-agentic[redis]` | +| `PostgresCheckpointer` | Long-lived runs, compliance, audit-friendly | Operational overhead of a DB | `pip install fireflyframework-agentic[postgres]` | + +```python +from fireflyframework_agentic.pipeline import FileCheckpointer # or Redis / Postgres + +pipeline = ( + PipelineBuilder("software-factory", state=BuildState, + checkpointer=FileCheckpointer("./checkpoints")) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() +) + +# Fresh run +result = await pipeline.invoke(BuildState(requirements="user-mgmt service")) + +# Resume after crash — picks up at the failed node, skips completed ones +result = await pipeline.invoke(run_id=result.run_id) + +# Or jump into a specific node with explicit state +result = await pipeline.invoke(state=loaded_state, start_at=deployer) +``` + +Swapping backends is a one-line change. 
Redis uses a TTL on each checkpoint +key (default 30 days) plus a sorted-set index of run IDs; Postgres uses a +single `firefly_checkpoints` table created idempotently on first save: + +```python +from fireflyframework_agentic.pipeline import RedisCheckpointer, PostgresCheckpointer + +# Either a URL/DSN (backend constructs its own client) or a pre-built client +# (lets you share a connection pool across many pipelines). +checkpointer = RedisCheckpointer(url="redis://localhost:6379/0", ttl_seconds=86400 * 30) +checkpointer = RedisCheckpointer(client=my_existing_redis) +checkpointer = PostgresCheckpointer(dsn="postgresql://user:pw@host/db") +checkpointer = PostgresCheckpointer(connection=my_existing_psycopg_connection) +``` + +### Cycles and `recursion_limit` + +State pipelines permit cycles for ReAct loops and retry-with-critique patterns. +The builder accepts `recursion_limit` (default 25) as a safety net — a runaway +loop surfaces as `result.success=False` with a clean error, not an infinite hang. + +```python +def route(state): + return "done" if state.counter >= 3 else "step" + +pipeline = (PipelineBuilder("loop", state=LoopState, recursion_limit=25) + .add_node(step).add_node(done).branch(step, route).build()) +``` + +### Runtime Fan-Out via `Send` + +A router may return a list of `Send(target, payload)` objects to dispatch multiple +invocations of the same (or different) workers concurrently. Each Send's +payload is applied to a copy of the current state before its target runs; +results reduce back into shared state. Replaces the legacy `FanOutStep`. 
+ +```python +from fireflyframework_agentic.pipeline import Send + +def dispatch(state): + return [Send("worker", {"item": x}) for x in state.items] + +pipeline = (PipelineBuilder("mapreduce", state=MapReduceState) + .add_node(planner).add_node(worker).add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build()) +``` + +When all worker targets share a common successor, the engine continues there +once the fan-out completes; the aggregator runs once with all results in +shared state. + +### Observability + +State pipelines emit lifecycle callbacks and OTel spans so ops can see what +an agent workflow is doing in real time. + +`StatePipelineEventHandler` mirrors the legacy `PipelineEventHandler` but +every callback carries the `run_id` (so events can be correlated across +resumes) and `on_node_start` carries a per-node visit counter (so cyclic +graphs and `Send` fan-outs are distinguishable). Implement any subset of +methods; missing ones are no-ops. + +```python +from fireflyframework_agentic.pipeline import PipelineBuilder, StatePipelineEventHandler + + +class ProgressHandler: + async def on_pipeline_start(self, name, run_id): + print(f"▶ [{name}] run {run_id} starting") + + async def on_node_start(self, name, run_id, node_id, visit): + print(f" ▶ {node_id} (visit #{visit})") + + async def on_node_complete(self, name, run_id, node_id, latency_ms): + print(f" ✔ {node_id} ({latency_ms:.0f}ms)") + + async def on_node_error(self, name, run_id, node_id, error): + print(f" ✗ {node_id}: {error}") + + async def on_pipeline_complete(self, name, run_id, success, duration_ms): + status = "OK" if success else "FAILED" + print(f"═ [{name}] {status} in {duration_ms:.0f}ms") + + +pipeline = ( + PipelineBuilder("agent", state=AgentState, event_handler=ProgressHandler()) + .add_node(classify).add_node(answer).add_node(escalate) + .branch(classify, route) + .build() +) +``` + +In parallel, the pipeline emits OTel spans automatically when +`observability_enabled` is 
True and `opentelemetry` is installed: + +- One pipeline-level span `pipeline.state.<name>` around each `invoke`, + attributes `firefly.pipeline`, `firefly.run_id`. +- One per-node span `pipeline.state.node.<node_id>` for each `fn(state)` + call, parented under the pipeline span, attributes `firefly.node`, + `firefly.visit`. +- For `Send` fan-out: one per-Send span as a sibling under the pipeline span. + +Handler exceptions are swallowed — observability never breaks business logic. + +### Human-in-the-loop (Pause) + +Any node may return ``Pause(reason="...")`` instead of a state update to halt +the pipeline cleanly. The current state is checkpointed with a paused marker; +``invoke`` returns with ``result.paused=True`` and ``result.success=False``. + +```python +from fireflyframework_agentic.pipeline import Pause + +async def await_deploy_approval(state: DeployState) -> Pause: + return Pause(reason="awaiting human approval to deploy to production") +``` + +To resume after the external approval comes in, call ``invoke`` with the same +``run_id`` and ``approve_pause=True``. Without ``approve_pause=True``, the +resume raises a ``PipelineError`` — the pause is sticky until explicitly +released. The successor of the paused node runs next; the pause node itself +is not re-executed. + +```python +first = await pipeline.invoke(DeployState(...)) +assert first.paused +# ...later, after approval... +done = await pipeline.invoke(run_id=first.run_id, approve_pause=True) +assert done.success +``` + +The configured ``StatePipelineEventHandler`` receives an ``on_node_pause`` +callback when a node pauses (the callback is optional — partial handlers +without it continue to work). + +### Audit Log + +Distinct from the ``Checkpointer`` (which stores the *latest* state for +crash recovery), an ``AuditLog`` is an append-only record of *every* node +visit for compliance, debugging, and replay. 
Wire one in via the +``audit_log`` kwarg: + +```python +from fireflyframework_agentic.pipeline import ( + PipelineBuilder, FileAuditLog, PostgresAuditLog, LoggingAuditLog, OtelAuditLog, +) + +PipelineBuilder("agent", state=AgentState, audit_log=FileAuditLog("./audit")) +``` + +Four backends ship, each conforming to the ``AuditLog`` Protocol: + +| Backend | Use when | Read API | Trace-correlated | Install | +|---|---|---|---|---| +| ``FileAuditLog`` | Dev / single-host | yes | no | (default) | +| ``PostgresAuditLog`` | Compliance, retention, cross-run queries | yes | no | ``[postgres]`` | +| ``LoggingAuditLog`` | Generic log stacks (Splunk-HEC, Loki, JSON-logging) | no (write-only) | no | (default — stdlib) | +| ``OtelAuditLog`` | OTel-native stacks (Application Insights, Datadog APM, OTel Collector) | no (write-only) | **yes** | ``opentelemetry-sdk`` | + +``FileAuditLog`` and ``PostgresAuditLog`` also implement +``QueryableAuditLog`` with ``list_entries(pipeline_name, run_id)``. The +write-only backends delegate query/search to the user's existing +observability stack. + +Audit-log write failures are non-fatal — logged but never abort the +pipeline. + +### Mermaid Export + +`StatePipeline.to_mermaid()` and `DAG.to_mermaid()` render the topology as a +Mermaid flowchart. Branch edges declared with an explicit mapping show their +label; dynamic routers are noted as such. 
+ +### When to use which mode + +| Use port-based when… | Use state-based when… | +|----------------------|------------------------| +| Pure ETL: parallel, fan-out/fan-in, no shared state | Agentic workflow: classify → branch → respond / loop / retry | +| Each step's input is a single value from the previous step | Multiple agents reading/writing different fields of a shared object | +| You want the engine to run independent nodes concurrently | You want resume-after-failure and start-from-middle semantics | +| You're happy with `BranchStep` + per-node `condition` lambdas | You want one `.branch(...)` call and inspectable routing | + +See [`examples/pipeline_state.py`](../examples/pipeline_state.py) for a +runnable demo covering branching, map-reduce fan-out via `Send`, and a +human-in-the-loop `Pause` gate with an audit log; checkpoint/resume is shown in `examples/software_factory/`. + +--- + ## Parallel Execution (Fan-Out / Fan-In) +> **`FanOutStep` is deprecated.** For runtime fan-out (one dispatch per item, +> arbitrary count), prefer `Send` from [State-Based Pipelines](#runtime-fan-out-via-send). +> `FanOutStep` still works for now (it emits a `DeprecationWarning` on +> construction); `FanInStep` is not deprecated. + ```mermaid graph TD SPLIT[Fan-Out] --> W1[Worker 1] @@ -248,6 +562,12 @@ dag.add_node(DAGNode( ### Conditional Branching (BranchStep) +> **Deprecated.** Prefer [State-Based Pipelines](#state-based-pipelines) with +> `.branch(source, router)` — one call instead of `BranchStep` + per-node +> `condition` lambdas, and the topology becomes inspectable as data. +> `BranchStep` still works (it emits a `DeprecationWarning` on construction); +> removal will be tracked in a follow-up issue once internal callers migrate. + `BranchStep` provides router-based conditional branching. The router callable receives the node's input and returns a string key. Downstream nodes use condition gates to check the branch key and execute only the matching path. 
diff --git a/examples/README.md b/examples/README.md index cd790e4f..ef28af32 100644 --- a/examples/README.md +++ b/examples/README.md @@ -56,6 +56,8 @@ If `OPENAI_API_KEY` is not set, each script will prompt you interactively. ## Pipeline Examples - **`pipeline_branching.py`** — `BranchStep` for conditional routing in a DAG, `PipelineEventHandler` for live progress, and `DAGNode.backoff_factor` for exponential retry backoff. **No API key required.** +- **`pipeline_state.py`** — Three short scenarios with the state-based `PipelineBuilder` (`state=` mode): sentiment branching with `.branch()`, map-reduce with `Send` fan-out, and a HITL deploy gate using `Pause` plus `FileAuditLog`. **No API key required.** +- **`software_factory/`** — Self-contained example package showing a state-mode agentic SDLC pipeline (`architect → codegen → builder → qa → stable_release`) with the QA feedback loop (`recursion_limit=3`), checkpoint + resume on a transient `builder` failure, and a `StatePipelineEventHandler` printing progress. Includes plug-and-play `Checkpointer` Protocol implementations for Postgres and Redis under `checkpointers/`, and a `QueryableAuditLog` Postgres template under `audit/`. **No API key required.** ## Complex Examples diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py new file mode 100644 index 00000000..2bb3c27c --- /dev/null +++ b/examples/pipeline_state.py @@ -0,0 +1,244 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-based PipelineBuilder quick-start: branching, Send fan-out, HITL Pause. + +Three short scenarios in one file: + +1. **Branching** — sentiment-classification workflow with one ``.branch(...)`` + call (vs ``BranchStep`` + per-node ``condition`` lambdas in port-based mode). + +2. **Map-reduce via ``Send``** — a planner dispatches one ``Send`` per work + item; workers run concurrently; an aggregator runs once with all results + merged via the ``extend`` reducer. + +3. **HITL Pause + audit log** — a deploy gate that returns ``Pause(...)`` to + wait for human approval; resume with ``approve_pause=True``; a + ``FileAuditLog`` captures every node visit with its status. + +For the deeper software-factory walkthrough (QA feedback loop, checkpoint + +resume, Postgres / Redis checkpointer templates), see the self-contained +example package ``examples/software_factory/``. + +Usage:: + + uv run python examples/pipeline_state.py + +.. note:: No OpenAI API key required — all "agents" are plain Python stubs. +""" + +from __future__ import annotations + +import asyncio +import logging +import tempfile +from pathlib import Path +from typing import Annotated + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + FileAuditLog, + FileCheckpointer, + Pause, + PipelineBuilder, + Send, + extend, +) + +# Quiet the pipeline's own logger.exception() when we deliberately exercise +# a node failure — the failure is part of the demo, not a bug. 
+logging.getLogger("fireflyframework_agentic.pipeline").setLevel(logging.CRITICAL) + + +# ============================================================================= +# Scenario 1 — Branching +# ============================================================================= + + +class SentimentState(BaseModel): + text: str + sentiment: str | None = None + response: str | None = None + + +async def classify_sentiment(state: SentimentState) -> dict: + text = state.text.lower() + positive = {"good", "great", "love", "amazing", "wonderful", "happy", "excellent"} + negative = {"bad", "terrible", "hate", "awful", "horrible", "sad", "poor"} + pos = sum(1 for w in text.split() if w in positive) + neg = sum(1 for w in text.split() if w in negative) + return {"sentiment": "positive" if pos >= neg else "negative"} + + +async def positive_reply(state: SentimentState) -> dict: + return {"response": "😊 Thank you for your kind words!"} + + +async def negative_reply(state: SentimentState) -> dict: + return {"response": "😟 We're sorry to hear that. We'll improve!"} + + +def route_by_sentiment(state: SentimentState) -> str: + # The router returns the node id directly — no mapping needed. + return "positive_reply" if state.sentiment == "positive" else "negative_reply" + + +async def run_branching() -> None: + print("=== 1. 
Branching (state mode) ===\n") + + pipeline = ( + PipelineBuilder("sentiment", state=SentimentState) + .add_node(classify_sentiment) + .add_node(positive_reply) + .add_node(negative_reply) + .branch(classify_sentiment, route_by_sentiment) + .build() + ) + + for text in ["This product is great and amazing!", "The service was terrible and awful."]: + result = await pipeline.invoke(SentimentState(text=text)) + print(f" input: {text!r}") + print(f" output: {result.state.response}\n") + + +# ============================================================================= +# Scenario 2 — Map-reduce via Send +# +# (The software-factory scenario that used to live here has its own folder +# now: ``examples/software_factory/``. It exercises the QA feedback loop, +# checkpoint + resume, and includes plug-and-play Postgres / Redis templates.) +# ============================================================================= + + +class MapReduceState(BaseModel): + items: list[str] = [] + processed: Annotated[list[str], extend] = [] + summary: str | None = None + # Per-Send payload field — each worker receives its own item here. + item: str | None = None + + +async def plan(state: MapReduceState) -> dict: + # No state mutation; the dispatch router below decides what runs next. + return {} + + +async def process_item(state: MapReduceState) -> dict: + assert state.item is not None + return {"processed": [f"processed:{state.item}"]} + + +async def aggregate(state: MapReduceState) -> dict: + return {"summary": f"Processed {len(state.processed)} items: {state.processed}"} + + +def dispatch(state: MapReduceState) -> list[Send]: + # One Send per item — workers run concurrently. The ``extend`` reducer on + # ``processed`` merges all worker outputs into one list. + return [Send("process_item", {"item": x}) for x in state.items] + + +async def run_map_reduce() -> None: + print("=== 2. 
Map-reduce via Send ===\n") + + pipeline = ( + PipelineBuilder("mapreduce", state=MapReduceState) + .add_node(plan) + .add_node(process_item) + .add_node(aggregate) + .add_edge(process_item, aggregate) + .branch(plan, dispatch) + .build() + ) + result = await pipeline.invoke(MapReduceState(items=["alpha", "beta", "gamma", "delta"])) + print(f" summary: {result.state.summary}") + + +# ============================================================================= +# Scenario 3 — HITL Pause + audit log +# ============================================================================= + + +class HitlState(BaseModel): + """State threaded through a deploy pipeline gated by human approval.""" + + target_env: str + artifact: str | None = None + deployed_to: str | None = None + + +async def build_artifact(state: HitlState) -> dict: + return {"artifact": f"build-{state.target_env}.tar.gz"} + + +async def await_approval(state: HitlState) -> Pause: + return Pause(reason=f"awaiting human approval to deploy {state.artifact} to {state.target_env}") + + +async def deploy_artifact(state: HitlState) -> dict: + return {"deployed_to": f"https://{state.target_env}.example.com"} + + +async def run_hitl_with_audit() -> None: + print("=== 3. Human-in-the-loop deploy gate with audit log ===\n") + + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + ckpt = FileCheckpointer(root / "ckpt") + audit = FileAuditLog(root / "audit") + pipeline = ( + PipelineBuilder( + "hitl-deploy", + state=HitlState, + checkpointer=ckpt, + audit_log=audit, + ) + .add_node(build_artifact) + .add_node(await_approval) + .add_node(deploy_artifact) + .chain(build_artifact, await_approval, deploy_artifact) + .build() + ) + + # First run halts at the approval gate. + first = await pipeline.invoke(HitlState(target_env="prod")) + print(f" first run: paused={first.paused}, paused_node={first.paused_node}") + print(f" reason: {first.pause_reason}") + print(f" run_id: {first.run_id}\n") + + # ...time passes; a human reviews and approves... 
+ print(" (human reviews and approves)\n") + + # Resume with explicit approval. + done = await pipeline.invoke(run_id=first.run_id, approve_pause=True) + print(f" resumed: success={done.success}, deployed_to={done.state.deployed_to}") + print(f" completed: {done.completed_nodes}\n") + + # Audit log captures every node visit with its status. + entries = audit.list_entries("hitl-deploy", first.run_id) + print(" audit trail:") + for e in entries: + extra = f" reason={e.pause_reason!r}" if e.pause_reason else "" + print(f" seq={e.sequence} node={e.node_id} status={e.status}{extra}") + + +async def main() -> None: + await run_branching() + await run_map_reduce() + await run_hitl_with_audit() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/software_factory/README.md b/examples/software_factory/README.md new file mode 100644 index 00000000..e92a53a6 --- /dev/null +++ b/examples/software_factory/README.md @@ -0,0 +1,137 @@ +# `software_factory/` — a state-based agentic SDLC pipeline + +A small, self-contained example that shows the headline features of +`PipelineBuilder` in state mode: + +- **State + reducers** — one Pydantic model carries everything the agents read or write; `extend` accumulates QA feedback across loop iterations. +- **Branching** — one `.branch("qa", qa_router)` call gives both the success terminus and the QA cycle. +- **Cycle with `recursion_limit`** — the QA fail → codegen loop is something port-based DAGs cannot express. +- **Checkpoint + resume** — `builder` raises a simulated transient error on its first call; `invoke(run_id=...)` resumes from the checkpoint. +- **Observability handler** — a `StatePipelineEventHandler` prints per-node progress. + +No LLM calls. All agents are deterministic stubs so the example runs offline and the smoke test is stable. 
+ +## Run it + +```bash +source ~/.venvs/firefly/bin/activate +python -m examples.software_factory +``` + +Expected output: + +``` +▶ [software-factory] run abc123ef… starting + ▶ architect (visit #1) + ✔ architect (0ms) + ▶ codegen (visit #1) + ✔ codegen (0ms) + ▶ builder (visit #1) + ✗ builder: dep install timed out +═ [software-factory] FAILED in 1ms + +first run: success=False failed_node=builder run_id=abc123ef… + +▶ [software-factory] run abc123ef… starting + ▶ builder (visit #1) + ✔ builder (0ms) + ▶ qa (visit #1) + ✔ qa (0ms) + ▶ codegen (visit #2) + ✔ codegen (0ms) + ▶ builder (visit #2) + ✔ builder (0ms) + ▶ qa (visit #2) + ✔ qa (0ms) + ▶ stable_release (visit #1) + ✔ stable_release (0ms) +═ [software-factory] OK in 2ms + +resumed: success=True release=v2026.05.28 iteration=2 +qa_feedback: ['missing PSD2 strong-auth flow'] +``` + +## The DAG + +``` + ┌─────────── qa_status == 'fail' → codegen (recursion_limit=3) ─────────┐ + │ │ + ▼ │ +architect → codegen → builder → qa ──(qa_router)──▶ stable_release │ + │ │ + └───────────────────────────────────────────────────┘ +``` + +| Node | What it does | +|---|---| +| `architect` | Writes a stub ADR string into `state.adr`. | +| `codegen` | Bumps `state.iteration`, writes `state.code = "v{iteration} (addresses: ...)"`. Iteration 2+ visibly incorporates `qa_feedback`. | +| `builder` | **Transient failure** on the first call across the process (`raise RuntimeError("dep install timed out")`). Succeeds on every subsequent call. | +| `qa` | **Substantive failure** on iteration 1 (`qa_status="fail"`, appends to `qa_feedback`). Passes on iteration 2. | +| `stable_release` | Sets `release_tag`. Terminal. | + +### Why are `codegen` and `builder` separate nodes? + +In stub form they look redundant. 
They're kept distinct because they model **two different failure-recovery patterns** the state-mode API supports: + +| Failure mode | Meaning | How the pipeline recovers | +|---|---|---| +| `builder` raises | Transient (network blip, dep flake) — same code, just retry | The engine catches the exception, checkpoints the failure, returns `success=False`. `invoke(run_id=...)` resumes by re-running `builder` in place. **No cycle.** | +| `qa` returns `"fail"` | Substantive (tests don't pass) — the code itself needs to change | `qa_router` returns `"codegen"`; the cycle re-enters `codegen` which writes v2 informed by `qa_feedback`. | + +One pipeline, two recovery patterns. Collapsing the nodes loses one of them. + +## Swapping the checkpointer + +The example defaults to `FileCheckpointer`. To run against a real Redis or Postgres: + +```bash +FIREFLY_CKPT=postgres PG_DSN="postgresql://localhost:5432/firefly" python -m examples.software_factory +FIREFLY_CKPT=redis REDIS_URL="redis://localhost:6379/0" python -m examples.software_factory +``` + +The Postgres and Redis backends live in this folder as **plug-and-play templates**, not framework code: + +- `checkpointers/postgres.py` — implements the framework's `Checkpointer` Protocol against a caller-supplied `psycopg.Connection`. +- `checkpointers/redis.py` — same idea against a caller-supplied `redis.Redis` client. +- `audit/postgres.py` — implements `QueryableAuditLog` against a caller-supplied `psycopg.Connection`. + +Each file is a flat ~50-LOC class. The framework no longer ships these — copy whichever you need into your project, adapt the table name or key prefix, and pass your own connection. The framework's `Checkpointer` and `AuditLog` Protocols are the only contract you need to match. + +## When to use Redis vs Postgres + +Both implement the same `Checkpointer` Protocol. 
The choice is about durability, latency, and inspection: + +| | Redis | Postgres | +|---|---|---| +| Durability | RDB + AOF; can lose the tail on crash unless `appendfsync always` (slow). | WAL-fsynced; survives crashes cleanly. | +| Latency | Sub-millisecond writes. | Single-digit ms. | +| TTL | Native per-key (`EX` on `SET`). Old checkpoints disappear automatically. | Manual (cron, partition drop). | +| Inspection | `KEYS` / `GET`; no SQL, no joins. | Full SQL — joinable with the app's domain tables. | +| Footprint | Often already in the stack as a cache. | Often already in the stack as the app DB. | + +Rule of thumb: + +- **Redis** for short-lived workflows (minutes to a few hours), high throughput, where you're OK losing the last few checkpoints on a hard crash and want automatic TTL cleanup. +- **Postgres** for long-running workflows (hours to days, anything that uses `Pause` for human approval), compliance/audit needs, or when you want to query checkpoint history with SQL. + +For most Signature client apps already running on PostgreSQL Flexible Server, Postgres is the default; Redis is the choice when latency matters more than durability. + +## File layout + +``` +software_factory/ +├── README.md +├── __main__.py # entry point — crash, then resume +├── state.py # BuildState pydantic model + extend reducer +├── agents.py # 5 stub agents (architect, codegen, builder, qa, stable_release) +├── pipeline.py # build_pipeline(); qa_router +├── progress.py # StatePipelineEventHandler implementation +├── checkpointers/ +│ ├── postgres.py # Checkpointer Protocol impl (psycopg) +│ └── redis.py # Checkpointer Protocol impl (redis-py) +└── audit/ + └── postgres.py # QueryableAuditLog Protocol impl (psycopg) +``` + +The end-to-end smoke test lives at `tests/examples/software_factory/test_pipeline.py` — same shape as the other example tests in this repo. 
diff --git a/examples/software_factory/__init__.py b/examples/software_factory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/software_factory/__main__.py b/examples/software_factory/__main__.py new file mode 100644 index 00000000..a2edb9f0 --- /dev/null +++ b/examples/software_factory/__main__.py @@ -0,0 +1,75 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Entry point: ``python -m examples.software_factory``. + +Demonstrates the full QA-loop + checkpoint-resume flow: + +1. First ``invoke`` runs architect → codegen → builder. The builder raises on + its first call (simulated transient ``dep install`` failure); the engine + checkpoints the failure and returns ``success=False``. +2. Second ``invoke(run_id=...)`` resumes from the checkpoint. The builder + succeeds, QA fails with PSD2 feedback, the cycle re-enters codegen, the + rewritten code passes QA, ``stable_release`` runs, and the pipeline + finishes with ``success=True``. + +By default uses :class:`FileCheckpointer` over a tmp directory. Set +``FIREFLY_CKPT=postgres`` (with ``PG_DSN``) or ``FIREFLY_CKPT=redis`` (with +``REDIS_URL``) to swap in the templates under ``checkpointers/``. 
+""" + +from __future__ import annotations + +import asyncio +import os +import tempfile +from pathlib import Path + +from examples.software_factory.pipeline import build_pipeline +from examples.software_factory.state import BuildState +from fireflyframework_agentic.pipeline import Checkpointer, FileCheckpointer + + +def _resolve_checkpointer(default_dir: Path) -> Checkpointer: + backend = os.environ.get("FIREFLY_CKPT", "file").lower() + if backend == "postgres": + import psycopg + + from examples.software_factory.checkpointers.postgres import ( + PostgresCheckpointer, + ) + + dsn = os.environ["PG_DSN"] + return PostgresCheckpointer(psycopg.connect(dsn, autocommit=True)) + + if backend == "redis": + import redis + + from examples.software_factory.checkpointers.redis import RedisCheckpointer + + url = os.environ["REDIS_URL"] + return RedisCheckpointer(redis.Redis.from_url(url, decode_responses=True)) + + return FileCheckpointer(default_dir) + + +async def main() -> None: + with tempfile.TemporaryDirectory() as ckpt_dir: + checkpointer = _resolve_checkpointer(Path(ckpt_dir)) + pipeline = build_pipeline(checkpointer) + + first = await pipeline.invoke(BuildState(request="payments microservice")) + print(f"\nfirst run: success={first.success} failed_node={first.failed_node} run_id={first.run_id}\n") + + resumed = await pipeline.invoke(run_id=first.run_id) + print( + f"\nresumed: success={resumed.success} " + f"release={resumed.state.release_tag} iteration={resumed.state.iteration}" + ) + print(f"qa_feedback: {resumed.state.qa_feedback}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/software_factory/agents.py b/examples/software_factory/agents.py new file mode 100644 index 00000000..44f82efe --- /dev/null +++ b/examples/software_factory/agents.py @@ -0,0 +1,68 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
# Process-wide call counter for ``builder``; the single "global" slot makes the
# very first call in the process fail and every later call succeed.
_BUILDER_ATTEMPTS: dict[str, int] = {}


async def architect(state: BuildState) -> dict:
    """Return the ``adr`` text derived from ``state.request``."""
    opening = f"ADR for '{state.request}': split into api / domain / data modules; "
    closing = "use idiomatic Firefly patterns; PSD2 strong-auth required for payments."
    return {"adr": opening + closing}


async def codegen(state: BuildState) -> dict:
    """Bump ``iteration`` and emit a ``code`` label, citing any QA feedback received so far."""
    version = state.iteration + 1
    label = f"v{version}"
    if state.qa_feedback:
        label += f" (addresses: {'; '.join(state.qa_feedback)})"
    return {"iteration": version, "code": label}


async def builder(state: BuildState) -> dict:
    """Simulate a build: raise on the first call in this process, succeed afterwards.

    Raises:
        RuntimeError: On the first call only (simulated transient failure that
            exercises checkpoint + resume).
    """
    attempt = _BUILDER_ATTEMPTS.get("global", 0) + 1
    _BUILDER_ATTEMPTS["global"] = attempt
    if attempt == 1:
        raise RuntimeError("dep install timed out")
    return {"build_status": "ok"}


async def qa(state: BuildState) -> dict:
    """Fail QA while ``iteration <= 1`` (with feedback), pass afterwards."""
    if state.iteration > 1:
        return {"qa_status": "pass"}
    # Iteration 1 lacks the PSD2 flow; the feedback steers the next codegen pass.
    return {
        "qa_status": "fail",
        "qa_feedback": ["missing PSD2 strong-auth flow"],
    }


async def stable_release(state: BuildState) -> dict:
    """Return the fixed release tag."""
    return {"release_tag": "v2026.05.28"}
class PostgresAuditLog:
    """Audit log that appends one row per entry to the ``firefly_audit`` table.

    Provides :meth:`record` (insert one entry) and :meth:`list_entries` (read a
    run's entries back ordered by ``sequence``) over a caller-supplied
    connection. The table and its index are created on construction if missing.
    """

    def __init__(self, connection: Any) -> None:
        self._conn = connection
        ddl = """
                CREATE TABLE IF NOT EXISTS firefly_audit (
                    pipeline_name TEXT NOT NULL,
                    run_id TEXT NOT NULL,
                    sequence INT NOT NULL,
                    visit INT NOT NULL,
                    node_id TEXT NOT NULL,
                    started_at TIMESTAMPTZ NOT NULL,
                    completed_at TIMESTAMPTZ NOT NULL,
                    latency_ms DOUBLE PRECISION NOT NULL,
                    status TEXT NOT NULL,
                    inputs_snapshot JSONB NOT NULL,
                    outputs_snapshot JSONB NOT NULL,
                    error_message TEXT,
                    pause_reason TEXT,
                    PRIMARY KEY (pipeline_name, run_id, sequence)
                );
                CREATE INDEX IF NOT EXISTS firefly_audit_run_idx
                    ON firefly_audit (pipeline_name, run_id);
                """
        with connection.cursor() as cursor:
            cursor.execute(ddl)

    @staticmethod
    def _as_json(value: Any) -> Any:
        """Decode *value* when the driver returned JSON as text; otherwise pass it through."""
        if isinstance(value, str):
            return json.loads(value)
        return value

    def record(self, entry: AuditEntry) -> None:
        """Insert *entry*; a row with the same (pipeline, run, sequence) key is left untouched."""
        params = (
            entry.pipeline_name,
            entry.run_id,
            entry.sequence,
            entry.visit,
            entry.node_id,
            entry.started_at,
            entry.completed_at,
            entry.latency_ms,
            entry.status,
            json.dumps(entry.inputs_snapshot),
            json.dumps(entry.outputs_snapshot),
            entry.error_message,
            entry.pause_reason,
        )
        statement = (
            "INSERT INTO firefly_audit "
            "(pipeline_name, run_id, sequence, visit, node_id, started_at, completed_at, "
            " latency_ms, status, inputs_snapshot, outputs_snapshot, error_message, pause_reason) "
            "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) "
            "ON CONFLICT (pipeline_name, run_id, sequence) DO NOTHING"
        )
        with self._conn.cursor() as cursor:
            cursor.execute(statement, params)

    def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]:
        """Return every entry stored for the run, ordered by ``sequence``."""
        query = (
            "SELECT pipeline_name, run_id, sequence, visit, node_id, started_at, "
            " completed_at, latency_ms, status, inputs_snapshot, outputs_snapshot, "
            " error_message, pause_reason "
            "FROM firefly_audit WHERE pipeline_name = %s AND run_id = %s "
            "ORDER BY sequence"
        )
        with self._conn.cursor() as cursor:
            cursor.execute(query, (pipeline_name, run_id))
            fetched = cursor.fetchall()

        entries: list[AuditEntry] = []
        for record in fetched:
            entries.append(
                AuditEntry(
                    pipeline_name=record[0],
                    run_id=record[1],
                    sequence=record[2],
                    visit=record[3],
                    node_id=record[4],
                    started_at=record[5],
                    completed_at=record[6],
                    latency_ms=record[7],
                    status=record[8],
                    inputs_snapshot=self._as_json(record[9]),
                    outputs_snapshot=self._as_json(record[10]),
                    error_message=record[11],
                    pause_reason=record[12],
                )
            )
        return entries
class PostgresCheckpointer:
    """Checkpoint store backed by the ``firefly_checkpoints`` table.

    Exposes :meth:`save`, :meth:`load_latest` and :meth:`list_runs` as sync
    methods over a caller-supplied connection. The table and its index are
    created on construction if missing.
    """

    def __init__(self, connection: Any) -> None:
        self._conn = connection
        ddl = """
                CREATE TABLE IF NOT EXISTS firefly_checkpoints (
                    pipeline_name TEXT NOT NULL,
                    run_id TEXT NOT NULL,
                    sequence INT NOT NULL,
                    node_id TEXT NOT NULL,
                    state JSONB NOT NULL,
                    completed_nodes JSONB NOT NULL,
                    paused BOOLEAN NOT NULL DEFAULT FALSE,
                    pause_reason TEXT,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                    PRIMARY KEY (pipeline_name, run_id, sequence)
                );
                CREATE INDEX IF NOT EXISTS firefly_checkpoints_run_idx
                    ON firefly_checkpoints (pipeline_name, run_id);
                """
        with connection.cursor() as cursor:
            cursor.execute(ddl)

    @staticmethod
    def _as_json(value: Any) -> Any:
        """Decode *value* when the driver returned JSON as text; otherwise pass it through."""
        if isinstance(value, str):
            return json.loads(value)
        return value

    def save(self, record: CheckpointRecord) -> None:
        """Upsert *record*, keyed by (pipeline_name, run_id, sequence)."""
        params = (
            record.pipeline_name,
            record.run_id,
            record.sequence,
            record.node_id,
            json.dumps(record.state),
            json.dumps(record.completed_nodes),
            record.paused,
            record.pause_reason,
        )
        statement = (
            "INSERT INTO firefly_checkpoints "
            "(pipeline_name, run_id, sequence, node_id, state, completed_nodes, paused, pause_reason) "
            "VALUES (%s, %s, %s, %s, %s, %s, %s, %s) "
            "ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET "
            "node_id=EXCLUDED.node_id, state=EXCLUDED.state, "
            "completed_nodes=EXCLUDED.completed_nodes, "
            "paused=EXCLUDED.paused, pause_reason=EXCLUDED.pause_reason"
        )
        with self._conn.cursor() as cursor:
            cursor.execute(statement, params)

    def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None:
        """Return the run's highest-sequence checkpoint, or ``None`` if it has none."""
        query = (
            "SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes, "
            " paused, pause_reason "
            "FROM firefly_checkpoints "
            "WHERE pipeline_name = %s AND run_id = %s "
            "ORDER BY sequence DESC LIMIT 1"
        )
        with self._conn.cursor() as cursor:
            cursor.execute(query, (pipeline_name, run_id))
            found = cursor.fetchone()

        if found is None:
            return None
        return CheckpointRecord(
            pipeline_name=found[0],
            run_id=found[1],
            sequence=found[2],
            node_id=found[3],
            state=self._as_json(found[4]),
            completed_nodes=self._as_json(found[5]),
            paused=found[6],
            pause_reason=found[7],
        )

    def list_runs(self, pipeline_name: str) -> list[str]:
        """Return the distinct run ids stored for *pipeline_name*, sorted by run id."""
        with self._conn.cursor() as cursor:
            cursor.execute(
                "SELECT DISTINCT run_id FROM firefly_checkpoints WHERE pipeline_name = %s ORDER BY run_id",
                (pipeline_name,),
            )
            fetched = cursor.fetchall()
        return [run[0] for run in fetched]
class RedisCheckpointer:
    """Stores checkpoints as TTL'd JSON keys, indexed by a per-pipeline ZSET.

    Key layout:

    * ``<prefix>:<pipeline>:<run_id>:<sequence>_<node_id>`` -> JSON record (TTL);
      ``<sequence>`` is zero-padded to at least six digits.
    * ``<prefix>:<pipeline>:runs`` -> ZSET of run ids scored by save time (no TTL).

    Exposes :meth:`save`, :meth:`load_latest` and :meth:`list_runs` as sync
    methods over a caller-supplied client.
    """

    _PREFIX = "firefly:ckpt"

    def __init__(self, client: Any, *, ttl_seconds: int = 30 * 24 * 3600) -> None:
        """Keep the caller's *client* and the TTL applied to every checkpoint key."""
        self._client = client
        self._ttl = ttl_seconds

    def _run_prefix(self, pipeline_name: str, run_id: str) -> str:
        """Return the key prefix shared by every checkpoint of one run."""
        return f"{self._PREFIX}:{pipeline_name}:{run_id}:"

    @staticmethod
    def _sequence_of(key: Any, prefix: str) -> int | None:
        """Parse the sequence number from a checkpoint *key*.

        Returns ``None`` when the text after *prefix* does not start with a
        digits-only sequence. That happens for keys of a different run whose id
        merely begins with ``"<run_id>:"``. Bytes keys (client created without
        ``decode_responses``) are decoded first.
        """
        text = key.decode("utf-8") if isinstance(key, bytes) else key
        if not text.startswith(prefix):
            return None
        digits = text[len(prefix):].partition("_")[0]
        return int(digits) if digits.isdigit() else None

    def save(self, record: CheckpointRecord) -> None:
        """Write *record* under its sequence key and refresh the run in the index."""
        prefix = self._run_prefix(record.pipeline_name, record.run_id)
        key = f"{prefix}{record.sequence:06d}_{record.node_id}"
        self._client.set(key, record.model_dump_json(), ex=self._ttl)
        self._client.zadd(
            f"{self._PREFIX}:{record.pipeline_name}:runs",
            {record.run_id: time.time()},
        )

    def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None:
        """Return the run's highest-sequence checkpoint, or ``None`` if none is stored."""
        prefix = self._run_prefix(pipeline_name, run_id)
        # SCAN walks the keyspace incrementally; KEYS would block the server
        # for the duration of a full scan.
        ranked = [
            (sequence, key)
            for key in self._client.scan_iter(match=f"{prefix}*")
            if (sequence := self._sequence_of(key, prefix)) is not None
        ]
        if not ranked:
            return None
        # Compare numerically: padding stops ordering lexically past 999999.
        _, latest_key = max(ranked, key=lambda pair: pair[0])
        payload = self._client.get(latest_key)
        if payload is None:
            # The key expired between the scan and the read.
            return None
        return CheckpointRecord.model_validate(json.loads(payload))

    def list_runs(self, pipeline_name: str) -> list[str]:
        """Return every indexed run id for *pipeline_name*, in ZSET order."""
        return list(self._client.zrange(f"{self._PREFIX}:{pipeline_name}:runs", 0, -1))
def qa_router(state: BuildState) -> str:
    """Pick the node after ``qa``: ``stable_release`` on a pass, otherwise back to ``codegen``."""
    if state.qa_status == "pass":
        return "stable_release"
    return "codegen"


def build_pipeline(checkpointer: Checkpointer) -> PipelineEngine:
    """Assemble the software-factory pipeline around *checkpointer*.

    Registers the five agent functions as nodes, wires the linear
    architect -> codegen -> builder -> qa edges, and attaches ``qa_router`` as
    the branch out of ``qa``.
    """
    factory = PipelineBuilder(
        "software-factory",
        state=BuildState,
        checkpointer=checkpointer,
        recursion_limit=3,
        event_handler=ProgressHandler(),
    )
    for agent in (architect, codegen, builder, qa, stable_release):
        factory = factory.add_node(agent)
    linear_edges = (
        ("architect", "codegen"),
        ("codegen", "builder"),
        ("builder", "qa"),
    )
    for upstream, downstream in linear_edges:
        factory = factory.add_edge(upstream, downstream)
    return factory.branch("qa", qa_router).build()
class ProgressHandler:
    """Event handler that prints one console line per pipeline or node event."""

    async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None:
        """Announce the run, showing only the first eight characters of its id."""
        short_id = run_id[:8]
        print(f"▶ [{pipeline_name}] run {short_id}… starting")

    async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None:
        """Print the node id and its visit number."""
        print(f" ▶ {node_id} (visit #{visit})")

    async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None:
        """Print the node id with its latency rounded to whole milliseconds."""
        print(f" ✔ {node_id} ({latency_ms:.0f}ms)")

    async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None:
        """Print the node id and the error text."""
        print(f" ✗ {node_id}: {error}")

    async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None:
        """Print the node id and the pause reason."""
        print(f" ⏸ {node_id}: {reason}")

    async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None:
        """Print the final OK/FAILED verdict and the total duration."""
        if success:
            verdict = "OK"
        else:
            verdict = "FAILED"
        print(f"═ [{pipeline_name}] {verdict} in {duration_ms:.0f}ms")
class BuildState(BaseModel):
    """Shared state passed to every software-factory node.

    Each node returns a partial-update dict whose keys are fields of this
    model. All fields use the default (replace) reducer except ``qa_feedback``.
    """

    # Free-text description of what to build; the only required field.
    request: str
    # Code-generation pass counter; ``codegen`` writes ``iteration + 1``.
    iteration: int = 0
    # Architecture decision record text written by ``architect``.
    adr: str | None = None
    # Generated-code label written by ``codegen``.
    code: str | None = None
    # Written by ``builder`` ("ok" on success).
    build_status: str | None = None
    # Written by ``qa`` ("pass" / "fail"); read by the QA router.
    qa_status: str | None = None
    # Annotated with the ``extend`` reducer so feedback accumulates across
    # QA-loop iterations instead of being overwritten.
    qa_feedback: Annotated[list[str], extend] = []
    # Written by ``stable_release``.
    release_tag: str | None = None
""" +from fireflyframework_agentic.pipeline.audit import ( + AuditEntry, + AuditLog, + FileAuditLog, + LoggingAuditLog, + OtelAuditLog, + QueryableAuditLog, +) from fireflyframework_agentic.pipeline.builder import PipelineBuilder +from fireflyframework_agentic.pipeline.checkpoint import ( + Checkpointer, + CheckpointRecord, + FileCheckpointer, +) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine, PipelineEventHandler +from fireflyframework_agentic.pipeline.engine import ( + EventHandler, + Pause, + PipelineEngine, + PipelineEventHandler, + Send, +) +from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict, replace from fireflyframework_agentic.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult from fireflyframework_agentic.pipeline.steps import ( AgentStep, @@ -38,25 +67,41 @@ ) __all__ = [ + "DAG", "AgentStep", + "AuditEntry", + "AuditLog", "BatchLLMStep", "BranchStep", "CallableStep", - "DAG", + "CheckpointRecord", + "Checkpointer", "DAGEdge", "DAGNode", "EmbeddingStep", "ExecutionTraceEntry", "FailureStrategy", + "EventHandler", "FanInStep", "FanOutStep", + "FileAuditLog", + "FileCheckpointer", + "LoggingAuditLog", "NodeResult", + "OtelAuditLog", + "Pause", "PipelineBuilder", "PipelineContext", "PipelineEngine", "PipelineEventHandler", "PipelineResult", + "QueryableAuditLog", "ReasoningStep", "RetrievalStep", + "Send", "StepExecutor", + "append", + "extend", + "merge_dict", + "replace", ] diff --git a/fireflyframework_agentic/pipeline/audit.py b/fireflyframework_agentic/pipeline/audit.py new file mode 100644 index 00000000..fbb02156 --- /dev/null +++ b/fireflyframework_agentic/pipeline/audit.py @@ -0,0 +1,225 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Append-only audit logs for state-pipeline node visits. + +Distinct from :mod:`fireflyframework_agentic.pipeline.checkpoint` — the +checkpointer stores the *latest* state for crash recovery; the audit log +stores *every* node visit for compliance, debugging, and replay. + +Three backends ship in the framework: + +* :class:`FileAuditLog` — one JSONL file per ``(pipeline_name, run_id)``. + Best for dev / single-host audit trails. Implements :class:`QueryableAuditLog`. +* :class:`LoggingAuditLog` — stdlib ``logging``; pairs with whatever log + aggregation pipeline (Splunk-HEC, Loki, Datadog, OTel-LoggingHandler-bridge) + the host application already runs. Write-only. +* :class:`OtelAuditLog` — direct OTel logs API; attaches trace correlation + (``trace_id``/``span_id``) automatically. Best for OTel-native stacks + (Application Insights, Datadog APM, OTel-Collector). Write-only. + +For a Postgres-backed queryable audit log, see the plug-and-play template at +``examples/software_factory/audit/postgres.py`` — a ~80 LOC class implementing +:class:`QueryableAuditLog` against a caller-supplied ``psycopg.Connection``. 
# Closed set of outcomes an audit entry can carry.
AuditStatus = Literal["success", "error", "paused"]


class AuditEntry(BaseModel):
    """A single audit record — one per node visit (success, error, or pause)."""

    # Identity of the visit.
    pipeline_name: str
    run_id: str
    node_id: str
    # Backends key on ``sequence`` per run (the Postgres template uses it in
    # its primary key and ORDER BY).
    sequence: int
    visit: int
    # Timing of the visit.
    started_at: datetime
    completed_at: datetime
    latency_ms: float
    # Outcome plus the snapshots captured for the visit.
    status: AuditStatus
    inputs_snapshot: dict[str, Any]
    outputs_snapshot: dict[str, Any]
    # Optional details; both default to ``None``.
    error_message: str | None = None
    pause_reason: str | None = None


@runtime_checkable
class AuditLog(Protocol):
    """Write-only audit log. Every backend implements this method.

    Implementations must be safe to call from async code (called inside the
    state pipeline's executor) but the method itself is sync.
    """

    # Persist or emit one entry; no return value.
    def record(self, entry: AuditEntry) -> None: ...


@runtime_checkable
class QueryableAuditLog(AuditLog, Protocol):
    """Audit log that also supports reading back recorded entries.

    File and Postgres backends implement this. Logging and OTel backends do
    not — query your observability stack (Splunk / Datadog / Loki / etc.)
    instead.
    """

    # Return the entries recorded for one run of one pipeline.
    def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: ...
class FileAuditLog:
    """Audit log stored as JSON Lines files under a root directory.

    Layout: ``<root>/<pipeline_name>/<run_id>.jsonl``, one serialized
    :class:`AuditEntry` per line. Each entry is appended with a single
    ``write`` call.
    """

    def __init__(self, root: str | Path) -> None:
        """Create *root* (and any missing parents) and remember it."""
        base = Path(root)
        base.mkdir(parents=True, exist_ok=True)
        self._root = base

    def _path(self, pipeline_name: str, run_id: str) -> Path:
        """Return the JSONL file path for one run of one pipeline."""
        return self._root.joinpath(pipeline_name, f"{run_id}.jsonl")

    def record(self, entry: AuditEntry) -> None:
        """Append *entry* as one JSON line, creating the pipeline directory on demand."""
        target = self._path(entry.pipeline_name, entry.run_id)
        target.parent.mkdir(parents=True, exist_ok=True)
        serialized = entry.model_dump_json() + "\n"
        with target.open("a", encoding="utf-8") as sink:
            sink.write(serialized)

    def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]:
        """Return the run's entries in file order; an absent file yields ``[]``."""
        source = self._path(pipeline_name, run_id)
        if not source.exists():
            return []
        with source.open(encoding="utf-8") as handle:
            # Blank lines are skipped rather than parsed.
            stripped = (raw.strip() for raw in handle)
            return [AuditEntry.model_validate(json.loads(text)) for text in stripped if text]
class LoggingAuditLog:
    """Audit log that emits each entry through Python's stdlib ``logging``.

    Every entry becomes one log record on the configured logger. The message is
    a short summary and the full entry (``model_dump(mode="json")``) is attached
    as ``record.firefly_audit``. Write-only: entries cannot be read back.
    """

    def __init__(self, logger_name: str = "firefly.audit", level: int = logging.INFO) -> None:
        """Resolve the target logger and remember the level to log at."""
        self._level = level
        self._logger = logging.getLogger(logger_name)

    def record(self, entry: AuditEntry) -> None:
        """Log a one-line summary of *entry* with the full entry under ``extra``."""
        summary_args = (
            entry.pipeline_name,
            entry.run_id,
            entry.node_id,
            entry.status,
            entry.latency_ms,
        )
        payload = {"firefly_audit": entry.model_dump(mode="json")}
        self._logger.log(
            self._level,
            "firefly_audit: pipeline=%s run=%s node=%s status=%s latency_ms=%.1f",
            *summary_args,
            extra=payload,
        )


class OtelAuditLog:
    """Audit log that emits each entry through the OpenTelemetry logs API.

    Each entry becomes one OTel log record whose attributes mirror the entry's
    identifying fields. The logger is obtained from the OTel logs API, so the
    host application must have configured a ``LoggerProvider``. Write-only.
    """

    def __init__(self, logger_name: str = "fireflyframework_agentic.audit") -> None:
        """Obtain the OTel logger.

        Raises:
            ImportError: If the optional OTel logs imports failed at module load.
        """
        if _otel_get_logger is None:
            raise ImportError(
                "OtelAuditLog requires the 'opentelemetry-sdk' package "
                "(with the logs API). Install: pip install opentelemetry-sdk"
            )
        self._otel_logger = _otel_get_logger(logger_name)

    def record(self, entry: AuditEntry) -> None:
        """Emit *entry* as an OTel log record (ERROR severity for errors, INFO otherwise)."""
        # __init__ already refused to run without the OTel imports, so these
        # module-level names are populated; the assert narrows them for type checkers.
        assert _OtelLogRecord is not None and _OtelSeverityNumber is not None

        attrs: dict[str, Any] = {
            "firefly.pipeline": entry.pipeline_name,
            "firefly.run_id": entry.run_id,
            "firefly.node": entry.node_id,
            "firefly.sequence": entry.sequence,
            "firefly.visit": entry.visit,
            "firefly.latency_ms": entry.latency_ms,
            "firefly.status": entry.status,
        }
        # Optional details are attached only when non-empty.
        optional_attrs = (
            ("firefly.error", entry.error_message),
            ("firefly.pause_reason", entry.pause_reason),
        )
        for attr_name, attr_value in optional_attrs:
            if attr_value:
                attrs[attr_name] = attr_value

        if entry.status == "error":
            severity = _OtelSeverityNumber.ERROR
        else:
            severity = _OtelSeverityNumber.INFO

        # OTel timestamps are integer nanoseconds since the epoch.
        completed_ns = int(entry.completed_at.timestamp() * 1_000_000_000)
        self._otel_logger.emit(
            _OtelLogRecord(
                timestamp=completed_ns,
                severity_number=severity,
                severity_text=severity.name,
                body=f"[{entry.pipeline_name}] {entry.node_id} {entry.status} ({entry.latency_ms:.0f}ms)",
                attributes=attrs,
            )
        )
Branching + is one ``.branch(source, router)`` call; function references work as + node ids. Optional checkpointing supports resume after failure and + mid-pipeline start:: + + pipeline = ( + PipelineBuilder("agent", state=AgentState, checkpointer=FileCheckpointer("./ckpt")) + .add_node(classify) + .add_node(answer) + .add_node(escalate) + .branch(classify, route) + .build() + ) """ from __future__ import annotations import asyncio -from collections.abc import Callable +import inspect +from collections.abc import Awaitable, Callable from typing import Any +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import AuditLog +from fireflyframework_agentic.pipeline.checkpoint import Checkpointer +from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy -from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.engine import ( + EventHandler, + PipelineEngine, + PipelineEventHandler, + RouterFn, + Send, # noqa: F401 re-exported via pipeline/__init__.py +) from fireflyframework_agentic.pipeline.steps import AgentStep, CallableStep, StepExecutor +StateNodeFn = Callable[[Any], Awaitable[Any]] +"""Signature for a state-mode node: ``async (state) -> dict | None | Pause | Send | list[Send]``.""" + + +class _StateStepAdapter: + """Adapts a state-mode node fn into the :class:`StepExecutor` shape so it + can ride through :meth:`PipelineEngine._execute_node`. + + State-mode functions take ``(state)`` and return a state update (or one + of the control sentinels :class:`Pause` / :class:`Send`). The engine + calls ``step.execute(context, inputs)``; this adapter forwards + ``context.state`` to the wrapped fn and returns its value verbatim so + PipelineEngine's existing dict/Pause/Send handling fires. 
+ """ + + def __init__(self, fn: Callable[..., Any]) -> None: + self._fn = _coerce_state_node_fn(fn) + + async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: # noqa: ARG002 + return await self._fn(context.state) + + +def _coerce_state_node_fn(fn: Callable[..., Any]) -> StateNodeFn: + """Turn user-supplied state-mode callables into the standard ``async (state) -> Any`` shape. + + Accepts: + * ``async def f(state)`` — used as-is. + * ``def f(state)`` — wrapped to run on a worker thread. + * Object with ``async run(state)`` (e.g. a FireflyAgent) — adapter calls ``.run(state)``. + """ + if inspect.iscoroutinefunction(fn): + return fn # type: ignore[return-value] + + run = getattr(fn, "run", None) + if not callable(fn) and run is not None and callable(run): + + async def _agent_wrap(state: Any) -> Any: + if inspect.iscoroutinefunction(run): + return await run(state) + return await asyncio.get_running_loop().run_in_executor(None, run, state) + + return _agent_wrap + + if callable(fn): + + async def _async_wrap(state: Any) -> Any: + return await asyncio.get_running_loop().run_in_executor(None, fn, state) + + return _async_wrap + + raise PipelineError(f"Cannot adapt {fn!r} as a state node function") + class PipelineBuilder: - """Fluent builder for constructing a :class:`DAG` and :class:`PipelineEngine`. + """Fluent builder for pipelines. Parameters: name: Human-readable name for the pipeline. + state: Optional Pydantic model class for typed shared state. + When set, the builder produces a state-aware + :class:`PipelineEngine` and nodes are expected to be + ``async (state) -> dict | None | Pause | Send | list[Send]``. + checkpointer: Optional :class:`Checkpointer` for resume. + recursion_limit: Max visits per node in cycle-aware runs. + event_handler: Optional :class:`EventHandler` (or legacy + :class:`PipelineEventHandler`). + audit_log: Optional :class:`AuditLog`. 
""" - def __init__(self, name: str = "pipeline") -> None: - self._dag = DAG(name=name) + def __init__( + self, + name: str = "pipeline", + *, + state: type[BaseModel] | None = None, + checkpointer: Checkpointer | None = None, + recursion_limit: int = 25, + event_handler: EventHandler | PipelineEventHandler | None = None, + audit_log: AuditLog | None = None, + ) -> None: + # State-aware pipelines may have cycles (ReAct loops, retry-with-critique). + self._dag = DAG(name=name, allow_cycles=state is not None) + self._name = name + self._state_schema = state + self._checkpointer = checkpointer + self._recursion_limit = recursion_limit + self._event_handler = event_handler + self._audit_log = audit_log self._pending_nodes: list[DAGNode] = [] self._pending_edges: list[DAGEdge] = [] + # Routers + mappings drive the cyclic scheduler's next-step pick. + self._routers: dict[str, RouterFn] = {} + self._router_mappings: dict[str, dict[str, str]] = {} def add_node( self, - node_id: str, - step: Any, + node_id_or_fn: str | Callable[..., Any], + step: Any = None, *, condition: Callable[..., bool] | None = None, retry_max: int = 0, timeout_seconds: float = 0, failure_strategy: FailureStrategy = FailureStrategy.SKIP_DOWNSTREAM, ) -> PipelineBuilder: - """Add a node to the pipeline. + """Add a node. - *step* can be: - - A :class:`StepExecutor` (AgentStep, CallableStep, etc.) - - A :class:`FireflyAgent` (auto-wrapped in :class:`AgentStep`) - - An async callable (auto-wrapped in :class:`CallableStep`) + Two signatures: - Returns *self* for chaining. + * ``add_node(fn)`` — state-based mode. ``fn`` is a callable; the node + id is taken from ``fn.__name__``. Requires the builder was constructed + with ``state=...``. + * ``add_node(node_id, step)`` — port-based mode. ``step`` is a + :class:`StepExecutor`, an agent-like, or an async callable. 
""" + if step is None and callable(node_id_or_fn) and not isinstance(node_id_or_fn, str): + if self._state_schema is None: + raise PipelineError( + "Function-reference add_node(fn) requires PipelineBuilder(state=...). " + "Use add_node('id', step) for port-based pipelines." + ) + fn = node_id_or_fn + node_id = getattr(fn, "__name__", None) or repr(fn) + self._pending_nodes.append( + DAGNode( + node_id=node_id, + step=_StateStepAdapter(fn), + condition=condition, + retry_max=retry_max, + timeout_seconds=timeout_seconds, + failure_strategy=failure_strategy, + ) + ) + return self + + if not isinstance(node_id_or_fn, str): + raise PipelineError("add_node(node_id, step) expects a string node id when a step is provided.") + node_id = node_id_or_fn + + if self._state_schema is not None and step is not None: + run_method = getattr(step, "run", None) + if not callable(step) and not callable(run_method): + raise PipelineError( + f"State pipeline node '{node_id}' must be a callable or expose async run(state); " + f"got {type(step).__name__}" + ) + self._pending_nodes.append( + DAGNode( + node_id=node_id, + step=_StateStepAdapter(step), + condition=condition, + retry_max=retry_max, + timeout_seconds=timeout_seconds, + failure_strategy=failure_strategy, + ) + ) + return self + + if step is None: + raise PipelineError(f"add_node('{node_id}', step=...) requires a step.") + executor = self._resolve_step(step) self._pending_nodes.append( DAGNode( @@ -84,50 +240,84 @@ def add_node( def add_edge( self, - source: str, - target: str, + source: str | Callable[..., Any], + target: str | Callable[..., Any], *, output_key: str = "output", input_key: str = "input", ) -> PipelineBuilder: """Add a directed edge from *source* to *target*. - Returns *self* for chaining. + Both endpoints may be node ids (str) or function references (in which + case ``fn.__name__`` is used). 
""" self._pending_edges.append( DAGEdge( - source=source, - target=target, + source=_id(source), + target=_id(target), output_key=output_key, input_key=input_key, ) ) return self - def chain(self, *node_ids: str) -> PipelineBuilder: - """Connect nodes in sequence: A -> B -> C -> ... + def chain(self, *nodes: str | Callable[..., Any]) -> PipelineBuilder: + """Connect nodes in sequence: A -> B -> C -> ...""" + ids = [_id(n) for n in nodes] + for i in range(len(ids) - 1): + self.add_edge(ids[i], ids[i + 1]) + return self - All referenced nodes must already have been added via :meth:`add_node`. - Returns *self* for chaining. + def branch( + self, + source: str | Callable[..., Any], + router: RouterFn, + mapping: dict[str, str | Callable[..., Any]] | None = None, + ) -> PipelineBuilder: + """Register a runtime router on ``source``. + + ``router`` is a synchronous ``(state) -> str | Send | list[Send]`` + callable. Behaviour: + + * If ``mapping`` is None, the router must return the **id of an + existing node** that will run next. + * If ``mapping`` is provided, the router returns an abstract label + that is looked up in ``mapping`` to find the target node id. + + State-aware pipelines only. """ - for i in range(len(node_ids) - 1): - self.add_edge(node_ids[i], node_ids[i + 1]) + if self._state_schema is None: + raise PipelineError(".branch(...) requires PipelineBuilder(state=...)") + source_id = _id(source) + if mapping is not None: + resolved_mapping = {label: _id(target) for label, target in mapping.items()} + self._router_mappings[source_id] = resolved_mapping + # Materialize each label's edge so topology stays inspectable. + for target_id in resolved_mapping.values(): + self._pending_edges.append(DAGEdge(source=source_id, target=target_id)) + self._routers[source_id] = router return self def build(self) -> PipelineEngine: - """Build the DAG, validate it, and return a :class:`PipelineEngine`. - - Raises: - PipelineError: If the graph is invalid (cycles, missing nodes). 
- """ + """Build the DAG and return a :class:`PipelineEngine`.""" for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: self._dag.add_edge(edge) - return PipelineEngine(self._dag) + + return PipelineEngine( + self._dag, + event_handler=self._event_handler, + checkpointer=self._checkpointer, + audit_log=self._audit_log, + state_schema=self._state_schema, + recursion_limit=self._recursion_limit, + routers=self._routers, + router_mappings=self._router_mappings, + ) def build_dag(self) -> DAG: - """Build and return just the :class:`DAG` (for inspection or custom engines).""" + """Build and return just the :class:`DAG` (for inspection).""" for node in self._pending_nodes: self._dag.add_node(node) for edge in self._pending_edges: @@ -139,12 +329,21 @@ def _resolve_step(step: Any) -> Any: """Wrap non-executor objects in the appropriate step type.""" if isinstance(step, StepExecutor): return step - # Duck-type check for agent-like objects if hasattr(step, "run") and callable(step.run): return AgentStep(step) - # Async callable - if callable(step) and asyncio.iscoroutinefunction(step): + if callable(step) and inspect.iscoroutinefunction(step): return CallableStep(step) raise TypeError( - f"Cannot resolve {type(step).__name__} as a pipeline step. Must be StepExecutor, agent-like, or async callable." + f"Cannot resolve {type(step).__name__} as a pipeline step. " + f"Must be StepExecutor, agent-like, or async callable." 
) + + +def _id(ref: str | Callable[..., Any]) -> str: + """Coerce a string id or function reference into a node id string.""" + if isinstance(ref, str): + return ref + name = getattr(ref, "__name__", None) + if not name: + raise PipelineError(f"Cannot derive node id from {ref!r}") + return name diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py new file mode 100644 index 00000000..b3cf6b94 --- /dev/null +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -0,0 +1,109 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline state checkpointing for failure recovery and resumable runs. + +A :class:`Checkpointer` persists state after each successful node, keyed by +``(pipeline_name, run_id, node_id)``. On resume the engine loads the latest +checkpoint and skips nodes that already completed in that run. + +The framework ships :class:`FileCheckpointer` for dev / single-host work. +For Postgres- or Redis-backed checkpointing, see the plug-and-play templates +under ``examples/software_factory/checkpointers/``: each is a ~50 LOC class +that implements the :class:`Checkpointer` Protocol against a caller-supplied +connection. Copy whichever you need into your project and adapt. 
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +from pydantic import BaseModel + + +class CheckpointRecord(BaseModel): + """One saved checkpoint. + + ``paused`` and ``pause_reason`` are set when a node returns + :class:`fireflyframework_agentic.pipeline.engine.Pause`. Default + to ``False`` / ``None`` so existing records from earlier phases load + cleanly under the new schema. + """ + + pipeline_name: str + run_id: str + node_id: str + sequence: int + state: dict[str, Any] + completed_nodes: list[str] + paused: bool = False + pause_reason: str | None = None + + +@runtime_checkable +class Checkpointer(Protocol): + """Persists pipeline state after each successful node. + + Implementations must be safe to call from async code (the engine awaits + save() inside its task loop) but the methods themselves may be sync. + """ + + def save(self, record: CheckpointRecord) -> None: + """Persist a checkpoint. Overwrites if (pipeline, run_id, node_id) exists.""" + ... + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + """Return the most recent checkpoint for ``run_id`` or ``None`` if no run exists.""" + ... + + def list_runs(self, pipeline_name: str) -> list[str]: + """Return all known run IDs for ``pipeline_name``.""" + ... + + +class FileCheckpointer: + """Filesystem-backed checkpointer. Layout:: + + {root}/{pipeline_name}/{run_id}/{sequence:06d}_{node_id}.json + + The zero-padded ``sequence`` prefix gives a natural sort order for ``load_latest``. 
+ """ + + def __init__(self, root: str | Path) -> None: + self._root = Path(root) + self._root.mkdir(parents=True, exist_ok=True) + + def save(self, record: CheckpointRecord) -> None: + run_dir = self._root / record.pipeline_name / record.run_id + run_dir.mkdir(parents=True, exist_ok=True) + path = run_dir / f"{record.sequence:06d}_{record.node_id}.json" + path.write_text(record.model_dump_json(indent=2)) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + run_dir = self._root / pipeline_name / run_id + if not run_dir.exists(): + return None + files = sorted(run_dir.glob("*.json")) + if not files: + return None + latest = files[-1] + return CheckpointRecord.model_validate(json.loads(latest.read_text())) + + def list_runs(self, pipeline_name: str) -> list[str]: + pipeline_dir = self._root / pipeline_name + if not pipeline_dir.exists(): + return [] + return sorted(d.name for d in pipeline_dir.iterdir() if d.is_dir()) diff --git a/fireflyframework_agentic/pipeline/context.py b/fireflyframework_agentic/pipeline/context.py index c79e10d9..8af0c998 100644 --- a/fireflyframework_agentic/pipeline/context.py +++ b/fireflyframework_agentic/pipeline/context.py @@ -42,11 +42,17 @@ def __init__( metadata: dict[str, Any] | None = None, correlation_id: str | None = None, memory: MemoryManager | None = None, + state: Any = None, ) -> None: self.inputs = inputs self.metadata: dict[str, Any] = metadata or {} self.correlation_id = correlation_id or uuid.uuid4().hex self.memory: MemoryManager | None = memory + # Shared typed state for state-aware pipelines. None for legacy + # port-based runs. Engine reassigns after each node's reducer-merged + # update — readers within a single in-flight task see the snapshot + # they were scheduled with. 
+ self.state: Any = state self._results: dict[str, Any] = {} # node_id -> NodeResult def set_node_result(self, node_id: str, result: Any) -> None: diff --git a/fireflyframework_agentic/pipeline/dag.py b/fireflyframework_agentic/pipeline/dag.py index 2e4fef9b..1656cc42 100644 --- a/fireflyframework_agentic/pipeline/dag.py +++ b/fireflyframework_agentic/pipeline/dag.py @@ -15,12 +15,14 @@ """Directed Acyclic Graph (DAG) model for pipeline topology. :class:`DAG` holds :class:`DAGNode` and :class:`DAGEdge` objects, validates -acyclicity, computes topological sort, and identifies independent execution -levels for parallel scheduling. +acyclicity (unless ``allow_cycles=True``), computes topological sort, and +identifies independent execution levels for parallel scheduling. Also renders +itself as Mermaid or JSON for inspection / docs / Studio. """ from __future__ import annotations +import json from collections import defaultdict, deque from collections.abc import Callable from enum import StrEnum @@ -56,12 +58,28 @@ class DAGEdge(BaseModel): target: ID of the downstream node. output_key: Which output from the source to pass (default ``"output"``). input_key: Which input key on the target receives the value (default ``"input"``). + condition: Optional predicate ``(PipelineContext) -> bool`` that gates + edge traversal. When False (or when the callable raises), the + edge is treated as inactive: it neither delivers a signal to + the target nor contributes to scheduling readiness. If every + incoming edge of a target is inactive (and all of them are + resolved), the target is skipped — the same SKIP_DOWNSTREAM + cascade as an upstream failure. ``None`` (default) means + "always traverse". + + Conditions live on edges rather than on ``DAGNode`` because + branching is a routing decision, not a node-internal predicate. + The legacy :attr:`DAGNode.condition` field is preserved for + backward compatibility but is the wrong layer. 
""" + model_config = {"arbitrary_types_allowed": True} + source: str target: str output_key: str = "output" input_key: str = "input" + condition: Callable[..., bool] | None = None class DAGNode(BaseModel): @@ -98,8 +116,9 @@ class DAG: name: A human-readable name for the pipeline. """ - def __init__(self, name: str = "pipeline") -> None: + def __init__(self, name: str = "pipeline", *, allow_cycles: bool = False) -> None: self._name = name + self._allow_cycles = allow_cycles self._nodes: dict[str, DAGNode] = {} self._edges: list[DAGEdge] = [] # Adjacency and reverse adjacency for topo-sort @@ -136,9 +155,9 @@ def add_edge(self, edge: DAGEdge) -> None: self._edges.append(edge) self._adj[edge.source].append(edge.target) self._in_degree[edge.target] = self._in_degree.get(edge.target, 0) + 1 - # Incremental cycle check - if self._has_cycle(): - # Rollback + # Cycle check is skipped when the DAG was constructed with allow_cycles=True + # (state-based pipelines opt into this for ReAct-style loops). + if not self._allow_cycles and self._has_cycle(): self._edges.pop() self._adj[edge.source].pop() self._in_degree[edge.target] -= 1 @@ -147,7 +166,16 @@ def add_edge(self, edge: DAGEdge) -> None: # -- Query ------------------------------------------------------------- def topological_sort(self) -> list[str]: - """Return node IDs in topological order (Kahn's algorithm).""" + """Return node IDs in topological order (Kahn's algorithm). + + Raises :class:`PipelineError` if the DAG contains a cycle. Cyclic + graphs have no topological order; the caller should branch on + :meth:`is_cyclic` first (or use the engine's cycle-aware scheduler). + """ + if self._has_cycle(): + raise PipelineError( + "topological_sort() is not defined on cyclic graphs; use is_cyclic() to branch before calling." 
+ ) in_deg = dict(self._in_degree) for nid in self._nodes: in_deg.setdefault(nid, 0) @@ -162,16 +190,19 @@ def topological_sort(self) -> list[str]: if in_deg[neighbour] == 0: queue.append(neighbour) - if len(order) != len(self._nodes): - raise PipelineError("DAG contains a cycle (should not reach here)") return order def execution_levels(self) -> list[list[str]]: """Group nodes into levels for parallel execution. Nodes at the same level have no inter-dependencies and can be - executed concurrently. + executed concurrently. Raises :class:`PipelineError` on cyclic + DAGs — levels are undefined when cycles exist. """ + if self._has_cycle(): + raise PipelineError( + "execution_levels() is not defined on cyclic graphs; use is_cyclic() to branch before calling." + ) in_deg = dict(self._in_degree) for nid in self._nodes: in_deg.setdefault(nid, 0) @@ -236,5 +267,68 @@ def _has_cycle(self) -> bool: queue.append(neighbour) return count != len(self._nodes) + def is_cyclic(self) -> bool: + """True if the graph contains at least one cycle.""" + return self._has_cycle() + + # -- Export ------------------------------------------------------------ + + def to_mermaid(self) -> str: + """Render the topology as a Mermaid flowchart. + + Edges with ``input_key`` other than the default ``"input"`` are + labelled with that key so port wiring is visible. Conditional + edges are prefixed ``if?`` so branches stand out. + """ + lines = ["flowchart TD"] + for node_id in self._nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + for edge in self._edges: + parts: list[str] = [] + if edge.condition is not None: + parts.append("if?") + if edge.input_key and edge.input_key != "input": + parts.append(edge.input_key) + label = " · ".join(parts) if parts else None + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + return "\n".join(lines) + + def to_json(self) -> str: + """Render the topology as a JSON document. 
+ + Schema:: + + {"name": str, "nodes": [str], "edges": [{"source", "target", "output_key", "input_key"}]} + """ + doc = { + "name": self._name, + "nodes": list(self._nodes.keys()), + "edges": [ + { + "source": e.source, + "target": e.target, + "output_key": e.output_key, + "input_key": e.input_key, + } + for e in self._edges + ], + } + return json.dumps(doc, indent=2) + def __repr__(self) -> str: return f"DAG(name={self._name!r}, nodes={len(self._nodes)}, edges={len(self._edges)})" + + +def _mermaid_id(node_id: str) -> str: + """Sanitize a node id for use as a Mermaid identifier.""" + out = [] + for ch in node_id: + if ch.isalnum() or ch == "_": + out.append(ch) + else: + out.append("_") + sanitized = "".join(out) + if sanitized and sanitized[0].isdigit(): + sanitized = "n_" + sanitized + return sanitized or "anon" diff --git a/fireflyframework_agentic/pipeline/engine.py b/fireflyframework_agentic/pipeline/engine.py index 82010a9d..de644ad8 100644 --- a/fireflyframework_agentic/pipeline/engine.py +++ b/fireflyframework_agentic/pipeline/engine.py @@ -18,21 +18,31 @@ import asyncio import contextlib +import inspect import logging import random import time +import uuid +from collections.abc import Callable +from dataclasses import dataclass from datetime import UTC, datetime -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, cast, runtime_checkable try: from opentelemetry import trace as otel_trace except ImportError: # pragma: no cover - optional dep otel_trace = None # type: ignore[assignment] +from pydantic import BaseModel + from fireflyframework_agentic.config import get_config +from fireflyframework_agentic.exceptions import PipelineError from fireflyframework_agentic.observability.usage import default_usage_tracker +from fireflyframework_agentic.pipeline.audit import AuditEntry, AuditLog, AuditStatus +from fireflyframework_agentic.pipeline.checkpoint import Checkpointer, CheckpointRecord from 
fireflyframework_agentic.pipeline.context import PipelineContext -from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy +from fireflyframework_agentic.pipeline.dag import DAG, FailureStrategy, _mermaid_id +from fireflyframework_agentic.pipeline.reducers import Reducer, apply_update, discover_reducers from fireflyframework_agentic.pipeline.result import ( ExecutionTraceEntry, NodeResult, @@ -42,33 +52,177 @@ logger = logging.getLogger(__name__) +@runtime_checkable +class EventHandler(Protocol): + """Pipeline event handler used by :class:`PipelineEngine`. + + Implement any subset of these methods; missing ones are no-ops. Exceptions + raised in callbacks are swallowed by the engine so observability never + breaks business logic. + + The engine dispatches events by parameter name. If your method signature + omits a parameter — e.g. legacy implementations that don't accept + ``run_id`` or ``visit`` — the engine simply drops it from the call. + That keeps the legacy :class:`PipelineEventHandler` shape working + transparently alongside this unified one. + + Parameter conventions: + + * ``pipeline_name`` — DAG name, always present. + * ``run_id`` — opaque identifier for a single invocation; lets ops + correlate events across resumes and across multiple parallel runs. + * ``visit`` — re-entry counter on cyclic graphs and fan-out. Starts at + 1 and increments each time a node is re-entered. + * ``latency_ms`` — node wall-clock time, captured at the engine level. + * ``reason`` — human-readable string; for skips and pauses. + """ + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: ... + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: ... + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms: float) -> None: ... + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: ... 
+ + async def on_node_skip(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... + + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: ... + + async def on_pipeline_complete( + self, pipeline_name: str, run_id: str, success: bool, duration_ms: float + ) -> None: ... + + @runtime_checkable class PipelineEventHandler(Protocol): - """Protocol for pipeline progress callbacks. + """Legacy port-based event handler protocol. Use :class:`EventHandler`. + + Kept for backward compatibility. The engine inspects each callback's + signature and only passes parameters the method declares — so existing + implementations of this protocol continue to work unchanged. New code + should implement :class:`EventHandler` so it receives ``run_id`` and + ``visit`` too. + """ + + async def on_node_start(self, node_id: str, pipeline_name: str) -> None: ... + async def on_node_complete(self, node_id: str, pipeline_name: str, latency_ms: float) -> None: ... + async def on_node_error(self, node_id: str, pipeline_name: str, error: str) -> None: ... + async def on_node_skip(self, node_id: str, pipeline_name: str, reason: str) -> None: ... + async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: ... + + +RouterFn = Callable[[Any], "str | Send | list[Send]"] +"""Signature for a runtime branch router: receives the current state, returns +the next-step instruction — either a target node id, a single Send, or a +list of Sends for fan-out. +""" + - Implement any subset of these methods to receive notifications - when pipeline nodes start, complete, or fail. +@dataclass +class Pause: + """Human-in-the-loop sentinel returned by a node to halt the pipeline. + + A node returns ``Pause(reason="...")`` when external approval is required + before the pipeline may continue. The engine then: + + 1. Writes a checkpoint with ``paused=True`` and the reason set. + 2. 
Emits ``on_node_pause`` on the configured event handler. + 3. Returns a :class:`PipelineResult` with ``paused=True`` and + ``success=False`` — the run is not finished, but it did not fail + either. + + Resume after approval:: + + result = await engine.run(run_id=paused_run_id, approve_pause=True) + + The successor of the paused node runs next — the pause node itself is + not re-executed. Without ``approve_pause=True``, resuming a paused run + raises :class:`PipelineError`. """ - async def on_node_start(self, node_id: str, pipeline_name: str) -> None: - """Called when a node begins execution.""" - ... + reason: str + - async def on_node_complete(self, node_id: str, pipeline_name: str, latency_ms: float) -> None: - """Called when a node completes successfully.""" - ... +@dataclass +class Send: + """Runtime fan-out dispatch: run ``target`` with ``payload`` merged into state. - async def on_node_error(self, node_id: str, pipeline_name: str, error: str) -> None: - """Called when a node fails (after all retries exhausted).""" - ... + A node may return a single ``Send`` or ``list[Send]`` to dispatch one or + more targets concurrently. Each Send's payload is applied to a *copy* of + the current state before its target runs; the target's return is then + merged back into shared state via reducers. - async def on_node_skip(self, node_id: str, pipeline_name: str, reason: str) -> None: - """Called when a node is skipped.""" - ... + Replaces the legacy ``FanOutStep`` pattern with a first-class primitive. + """ - async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: - """Called when the entire pipeline finishes.""" - ... + target: str + payload: dict[str, Any] + + +def _resolve_node_id(ref: Any) -> str: + """Turn either a string node id or a function reference into a node id. + + Function references use ``fn.__name__``. Anything else raises + :class:`PipelineError`. 
+ """ + if isinstance(ref, str): + return ref + name = getattr(ref, "__name__", None) + if not name: + raise PipelineError(f"Cannot derive node id from {ref!r}") + return name + + +def _is_send_payload(value: Any) -> bool: + """True when a node's return value is a single :class:`Send` or a + non-empty ``list[Send]``. Drives the runtime fan-out branch in + :meth:`PipelineEngine.run`. + """ + if isinstance(value, Send): + return True + return isinstance(value, list) and bool(value) and all(isinstance(s, Send) for s in value) + + +def _serialize_value(value: Any) -> Any: + """Best-effort conversion of arbitrary values into JSON-safe form. + + Pydantic models go through ``model_dump(mode="json")``. Primitives, + lists, and dicts pass through. Anything else falls back to ``str()`` + so the serialization layer (checkpoint, audit) doesn't blow up on + exotic objects. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_serialize_value(v) for v in value] + if hasattr(value, "model_dump"): + try: + return value.model_dump(mode="json") + except Exception: + return str(value) + return str(value) + + +def start_otel_span(name: str, **attributes: Any) -> Any: + """Start an OTel span if observability is enabled, else return ``None``. + + Module-level helper for :class:`PipelineEngine`. 
+ """ + try: + if not get_config().observability_enabled: + return None + if otel_trace is None: + return None + return otel_trace.get_tracer("fireflyframework_agentic").start_span( + name, + attributes={f"firefly.{k}": str(v) for k, v in attributes.items()}, + ) + except Exception: # noqa: BLE001 + return None class PipelineEngine: @@ -83,41 +237,229 @@ def __init__( self, dag: DAG, *, - event_handler: PipelineEventHandler | None = None, + event_handler: EventHandler | PipelineEventHandler | None = None, + checkpointer: Checkpointer | None = None, + audit_log: AuditLog | None = None, + state_schema: type[BaseModel] | None = None, + recursion_limit: int = 25, + routers: dict[str, RouterFn] | None = None, + router_mappings: dict[str, dict[str, str]] | None = None, ) -> None: self._dag = dag self._event_handler = event_handler + self._checkpointer = checkpointer + self._audit_log = audit_log + # Optional shared-state overlay. When set, nodes returning a dict + # have it merged into the state via reducers; non-dict returns + # continue to flow through edges as port outputs. Both can coexist. + self._state_schema = state_schema + self._reducers: dict[str, Reducer] = discover_reducers(state_schema) if state_schema is not None else {} + # Max visits per node for cycle-aware runs. Matches StatePipeline's default. + self._recursion_limit = recursion_limit + # Optional runtime routers: source_id -> router(state) -> str | Send | list[Send]. + # When a source node has a router, the cyclic scheduler consults it + # instead of (or in addition to) the source's outgoing edges. With + # an accompanying mapping the router returns an abstract label that + # is looked up in the mapping. + self._routers: dict[str, RouterFn] = dict(routers or {}) + self._router_mappings: dict[str, dict[str, str]] = dict(router_mappings or {}) + # Per-method signature cache for legacy-vs-unified dispatch. 
+ self._handler_params: dict[str, set[str]] = {} + + def to_mermaid(self) -> str: + """Render the pipeline as a Mermaid flowchart, labelling branch edges. + + When the builder called ``.branch(source, router, mapping={...})`` + the resulting edges carry abstract labels (``yes``/``no``/etc). + This view threads those labels back into the diagram so the routing + is visible alongside the topology. + """ + lines = ["flowchart TD"] + for node_id in self._dag.nodes: + lines.append(f" {_mermaid_id(node_id)}[{node_id}]") + for edge in self._dag.edges: + label: str | None = None + mapping = self._router_mappings.get(edge.source) + if mapping: + for lbl, tgt in mapping.items(): + if tgt == edge.target: + label = lbl + break + if label is None and edge.condition is not None: + label = "if?" + arrow = f"-->|{label}|" if label else "-->" + lines.append(f" {_mermaid_id(edge.source)} {arrow} {_mermaid_id(edge.target)}") + return "\n".join(lines) + + async def invoke( + self, + state: Any = None, + *, + run_id: str | None = None, + start_at: Any = None, + approve_pause: bool = False, + ) -> PipelineResult: + """Shorthand for state-aware runs: ``await pipeline.invoke(state)``. + + Mirrors the legacy ``StatePipeline.invoke`` signature so callers that + treat the first positional as the state object keep working. New code + should call :meth:`run` directly with explicit kwargs. + """ + return await self.run( + state=state, + run_id=run_id, + start_at=start_at, + approve_pause=approve_pause, + ) + + async def _dispatch(self, method_name: str, /, **kwargs: Any) -> None: + """Invoke ``event_handler.method_name`` with the subset of ``kwargs`` + the method's signature actually declares. + + Lets the engine emit events using the unified :class:`EventHandler` + convention while still supporting legacy + :class:`PipelineEventHandler` implementations whose methods don't + accept ``run_id`` or ``visit``. 
Missing methods and raised + exceptions are silently swallowed — observability never breaks the + pipeline. + """ + if self._event_handler is None: + return + method = getattr(self._event_handler, method_name, None) + if method is None: + return + if method_name not in self._handler_params: + try: + params = inspect.signature(method).parameters + self._handler_params[method_name] = { + name + for name, p in params.items() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + except (TypeError, ValueError): + self._handler_params[method_name] = set(kwargs) + accepted = self._handler_params[method_name] + call_kwargs = {k: v for k, v in kwargs.items() if k in accepted} + with contextlib.suppress(Exception): + await method(**call_kwargs) async def run( self, context: PipelineContext | None = None, *, inputs: Any = None, + state: BaseModel | None = None, + run_id: str | None = None, + approve_pause: bool = False, + start_at: str | Any = None, ) -> PipelineResult: """Execute the pipeline. Parameters: context: Pre-built context, or *None* to create one automatically. inputs: Initial inputs (used if *context* is not provided). + state: Optional shared state object for engines configured with + ``state_schema=``. When omitted, the engine instantiates the + schema with its defaults. + run_id: Identifier for this run. When given alone (no ``context`` + and no ``inputs``), the engine loads the latest checkpoint for + that run and resumes from after the last completed node. + Requires a checkpointer to be configured. Returns: - A :class:`PipelineResult` with all node outputs and trace. + A :class:`PipelineResult` with all node outputs, trace, ``run_id`` + (use to resume later), and ``final_state`` for state-aware runs. 
""" - if context is None: - context = PipelineContext(inputs=inputs) + if run_id is not None and context is None and inputs is None and state is None: + resume_run_id: str = run_id + context, pre_completed_list, sequence_start = self._load_for_resume( + resume_run_id, approve_pause=approve_pause + ) + pre_completed = set(pre_completed_list) + # Preserve original completion order so PipelineResult.completed_nodes + # reflects the run's actual sequence after resume. + all_results: dict[str, NodeResult] = {} + for nid in pre_completed_list: + nr = context.get_node_result(nid) + if isinstance(nr, NodeResult): + all_results[nid] = nr + # Synthesize trace entries for the pre-completed nodes so the + # resumed result's completed_nodes reflects the full history. + now = datetime.now(UTC) + resume_trace_seed: list[ExecutionTraceEntry] = [ + ExecutionTraceEntry(node_id=nid, started_at=now, completed_at=now, status="success") + for nid in pre_completed_list + ] + else: + if context is None: + context = PipelineContext(inputs=inputs) + # Initialize shared state if configured. + if self._state_schema is not None and context.state is None: + if state is not None: + if not isinstance(state, self._state_schema): + state = self._state_schema.model_validate(state) + context.state = state + else: + try: + context.state = self._state_schema() + except Exception as exc: + raise PipelineError( + f"state required for pipeline with state_schema {self._state_schema.__name__}: {exc}" + ) from exc + pre_completed = set() + sequence_start = 0 + all_results = {} + resume_trace_seed: list[ExecutionTraceEntry] = [] + # Mid-pipeline start: pretend everything not reachable from + # `start_at` already ran. The scheduler then starts at start_at + # because its upstream nodes appear "completed". 
+ if start_at is not None: + start_id = _resolve_node_id(start_at) + if start_id not in self._dag.nodes: + raise PipelineError(f"start_at='{start_id}' not in DAG") + forward = {start_id} | self._dag.transitive_successors(start_id) + pre_completed = {nid for nid in self._dag.nodes if nid not in forward} - # Observability: pipeline-level span + if run_id is None: + run_id = uuid.uuid4().hex[:12] + + # Observability: pipeline-level span + start event + # State-aware runs use the "pipeline.state.*" span prefix to match + # the legacy StatePipeline taxonomy that observability dashboards + # already key on. + span_prefix = "pipeline.state" if self._state_schema is not None else "pipeline" _pipeline_span = self._start_otel_span( - f"pipeline.{self._dag.name}", + f"{span_prefix}.{self._dag.name}", pipeline=self._dag.name, + run_id=run_id, ) + await self._dispatch("on_pipeline_start", pipeline_name=self._dag.name, run_id=run_id) + + # Cycle-aware mode: a separate sequential frontier-following scheduler + # that respects ``recursion_limit``. The topological scheduler below + # cannot run cyclic graphs because execution_levels()/topological_sort() + # are undefined on them. Runtime routers also force this mode because + # they make routing a function of state, not topology. + if self._dag.is_cyclic() or self._routers: + return await self._run_cyclic( + context=context, + run_id=run_id, + all_results=all_results, + pre_completed=pre_completed, + sequence_start=sequence_start, + pipeline_span=_pipeline_span, + resume_trace_seed=resume_trace_seed, + ) # Topological levels ensure that all upstream dependencies of a node # complete before the node itself executes. Nodes within the same # level are independent and run concurrently via asyncio.gather. 
levels = self._dag.execution_levels() - trace_entries: list[ExecutionTraceEntry] = [] - all_results: dict[str, NodeResult] = {} + trace_entries: list[ExecutionTraceEntry] = list(resume_trace_seed) pipeline_start = time.perf_counter() failed_nodes: set[str] = set() @@ -129,23 +471,105 @@ async def run( pending: set[str] = set() for level in levels: pending.update(level) + pending -= pre_completed # resume: don't re-run nodes already completed - completed: set[str] = set() + completed: set[str] = set(pre_completed) running: dict[str, asyncio.Task[NodeResult]] = {} + inputs_by_node: dict[str, dict[str, Any]] = {} + sequence = sequence_start abort = False + pending_pause: tuple[str, str] | None = None # (node_id, reason) if Pause + + def _edge_alive(edge: Any) -> bool: + """An edge is alive if it has no condition, or its condition returns True. + + Raises in the condition itself are treated as False — fail + closed so a broken predicate kills the branch instead of + silently waking up the wrong target. + """ + if edge.condition is None: + return True + try: + return bool(edge.condition(context)) + except Exception: + return False def _ready(nid: str) -> bool: - """A node is ready when all its upstream deps have completed.""" + """A node is ready when: + (1) every incoming edge's source has completed, AND + (2) at least one of those edges is alive (or it has no edges). + + Entry nodes (no incoming) are always ready once scheduled. + """ + edges = self._dag.incoming_edges(nid) + if not edges: + return True + if not all(e.source in completed for e in edges): + return False + return any(_edge_alive(e) for e in edges) + + def _is_dead(nid: str) -> bool: + """A node is dead when every incoming edge has resolved but + none of them is alive. Cascades via the SKIP_DOWNSTREAM + mechanism so transitive successors are skipped without + being scheduled. 
+ """ edges = self._dag.incoming_edges(nid) - return all(e.source in completed for e in edges) + if not edges: + return False + if not all(e.source in completed for e in edges): + return False + return not any(_edge_alive(e) for e in edges) + + async def _record_skip(nid: str) -> None: + """Mark a node as skipped without scheduling it. Mirrors the + handling of in-flight skips returned by ``_execute_node``. + """ + nonlocal sequence + nr = NodeResult(node_id=nid, skipped=True, error="No alive incoming edge") + all_results[nid] = nr + context.set_node_result(nid, nr) + completed.add(nid) + failed_nodes.add(nid) + failed_nodes.update(self._dag.transitive_successors(nid)) + await self._emit_node_result(nr, run_id) + sequence += 1 + self._record_audit( + run_id=run_id, + node_id=nid, + sequence=sequence, + nr=nr, + inputs_snapshot={}, + trace_entries=trace_entries, + ) while pending or running: # Schedule all ready nodes that aren't already running. if not abort: for nid in list(pending): if _ready(nid) and nid not in running: + # Gather inputs outside _execute_node so we can stash + # them for the audit snapshot. + gathered = self._gather_inputs(nid, context) + inputs_by_node[nid] = gathered + # Emit start event here (visit=1 in the acyclic scheduler; + # the cyclic scheduler and Send fan-out emit their own). 
+ await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=nid, + visit=1, + ) task = asyncio.create_task( - self._execute_node(nid, context, trace_entries, failed_nodes), + self._execute_node( + nid, + context, + trace_entries, + failed_nodes, + inputs=gathered, + run_id=run_id, + ), ) running[nid] = task pending.discard(nid) @@ -178,7 +602,79 @@ def _ready(nid: str) -> bool: context.set_node_result(node_id, nr) # Emit event callbacks - await self._emit_node_result(nr) + await self._emit_node_result(nr, run_id) + + # Persist lifecycle: audit every executed visit; checkpoint only + # successful completions (failed nodes must re-run on resume). + sequence += 1 + paused_now = nr.success and isinstance(nr.output, Pause) + self._record_audit( + run_id=run_id, + node_id=node_id, + sequence=sequence, + nr=nr, + inputs_snapshot=inputs_by_node.get(node_id, {}), + trace_entries=trace_entries, + status_override="paused" if paused_now else None, + pause_reason=nr.output.reason if paused_now else None, + ) + # HITL: a node returned Pause(reason=...). Halt cleanly, save + # a paused checkpoint, and surface the pause in the result. + if paused_now: + pause_reason = nr.output.reason + await self._dispatch( + "on_node_pause", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + reason=pause_reason, + ) + self._save_checkpoint( + run_id=run_id, + node_id=node_id, + sequence=sequence, + context=context, + all_results=all_results, + paused=True, + pause_reason=pause_reason, + ) + pending_pause = (node_id, pause_reason) + abort = True + continue + + # Runtime fan-out: a node returned Send / list[Send]. 
+ if nr.success and _is_send_payload(nr.output): + sends: list[Send] = cast( + "list[Send]", + list(nr.output) if isinstance(nr.output, list) else [nr.output], + ) + ok = await self._run_sends( + sends=sends, + context=context, + run_id=run_id, + all_results=all_results, + trace_entries=trace_entries, + completed=completed, + pending=pending, + ) + if not ok: + abort = True + # Successors of the worker targets are picked up by the + # normal readiness sweep on the next loop iteration. + continue + + if nr.success and not nr.skipped: + # State overlay: a dict return from the node is a state + # update; non-dict returns flow through edges as ports. + if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict): + context.state = apply_update(context.state, nr.output, self._reducers) + self._save_checkpoint( + run_id=run_id, + node_id=node_id, + sequence=sequence, + context=context, + all_results=all_results, + ) # Handle failure strategies if not nr.success and not nr.skipped: @@ -190,6 +686,14 @@ def _ready(nid: str) -> bool: failed_nodes.add(node_id) failed_nodes.update(self._dag.transitive_successors(node_id)) + # Sweep pending for nodes whose incoming edges have resolved + # but none is alive. Mark them skipped and cascade — this is + # what makes DAGEdge.condition a usable branching primitive. 
+ for nid in list(pending): + if _is_dead(nid): + await _record_skip(nid) + pending.discard(nid) + if abort: # Cancel remaining tasks for t in running.values(): @@ -209,13 +713,13 @@ def _ready(nid: str) -> bool: success = all(r.success or r.skipped for r in all_results.values()) # Emit pipeline complete event - if self._event_handler is not None and hasattr(self._event_handler, "on_pipeline_complete"): - with contextlib.suppress(Exception): - await self._event_handler.on_pipeline_complete( - self._dag.name, - success, - pipeline_elapsed, - ) + await self._dispatch( + "on_pipeline_complete", + pipeline_name=self._dag.name, + run_id=run_id, + success=success, + duration_ms=pipeline_elapsed, + ) # Aggregate usage across all nodes for this pipeline run usage_summary = self._aggregate_usage(context.correlation_id) @@ -223,6 +727,12 @@ def _ready(nid: str) -> bool: if _pipeline_span is not None: _pipeline_span.end() + paused_node = pending_pause[0] if pending_pause else None + pause_reason_final = pending_pause[1] if pending_pause else None + if pending_pause is not None: + # A paused run is not "successful" — it didn't finish. + success = False + return PipelineResult( pipeline_name=self._dag.name, outputs=all_results, @@ -231,6 +741,401 @@ def _ready(nid: str) -> bool: total_duration_ms=pipeline_elapsed, success=success, usage=usage_summary, + run_id=run_id, + final_state=context.state, + paused=pending_pause is not None, + paused_node=paused_node, + pause_reason=pause_reason_final, + ) + + async def _run_sends( + self, + *, + sends: list[Send], + context: PipelineContext, + run_id: str, + all_results: dict[str, NodeResult], + trace_entries: list[ExecutionTraceEntry], + completed: set[str], + pending: set[str], + visit_counts: dict[str, int] | None = None, + ) -> bool: + """Dispatch a list of :class:`Send` workers concurrently. + + Each Send's payload is applied to a copy of the current state before + its target runs. Results merge back into shared state via reducers. 
+ Targets are added to ``completed`` and removed from ``pending`` so + the main scheduler does not re-execute them. + + Returns ``True`` on success, ``False`` if any worker failed (the + caller treats this as an abort signal). + """ + # Validate targets up front so unknown ones fail loud, not after gather(). + for send in sends: + if send.target not in self._dag.nodes: + nr = NodeResult( + node_id=send.target, + success=False, + error=f"Send dispatches to unknown target '{send.target}'", + ) + all_results[send.target] = nr + return False + + async def _run_one(send: Send, visit_n: int) -> tuple[Send, NodeResult]: + await self._dispatch( + "on_node_start", + pipeline_name=self._dag.name, + run_id=run_id, + node_id=send.target, + visit=visit_n, + ) + # Per-worker context: own state copy with payload applied so + # workers don't race on the shared state object. + worker_context = PipelineContext(inputs=context.inputs) + if self._state_schema is not None and context.state is not None: + worker_context.state = apply_update(context.state, send.payload, self._reducers) + for nid, prev in context.results.items(): + worker_context.set_node_result(nid, prev) + nr = await self._execute_node( + send.target, + worker_context, + trace_entries, + None, + inputs={"input": send.payload}, + run_id=run_id, + ) + return send, nr + + # Per-Send visit numbers: increment per dispatched target. The + # caller may seed counts (cyclic scheduler tracks them globally); + # otherwise each fan-out batch starts at 1. 
async def _run_cyclic(
    self,
    *,
    context: PipelineContext,
    run_id: str,
    all_results: dict[str, NodeResult],
    pre_completed: set[str],
    sequence_start: int,
    pipeline_span: Any,
    resume_trace_seed: list[ExecutionTraceEntry] | None = None,
) -> PipelineResult:
    """Sequential frontier-following scheduler for cyclic DAGs.

    Walks the graph one step at a time. A step is either a single node id or
    a ``list[Send]`` fan-out batch; the next step comes from
    :meth:`_cyclic_next` (single node) or :meth:`_common_successor` (after a
    fan-out). Visit counts are tracked per node and bounded by
    ``self._recursion_limit``.

    Parameters:
        context: Pipeline context holding inputs, node results and shared state.
        run_id: Identifier of this run, forwarded to events, audit entries
            and checkpoints.
        all_results: Result map, mutated in place as nodes complete.
        pre_completed: Node ids already completed (resume / ``start_at``);
            each is seeded with a visit count of 1.
        sequence_start: Sequence number to continue from.
        pipeline_span: Span opened by the caller, or *None*; it is ended here.
        resume_trace_seed: Trace entries to prepend to this run's trace.

    Returns:
        A :class:`PipelineResult`. ``success`` is ``False`` for a paused run.

    Raises:
        PipelineError: If the DAG has no nodes.
    """
    trace_entries: list[ExecutionTraceEntry] = list(resume_trace_seed or [])
    pipeline_start = time.perf_counter()
    visit_counts: dict[str, int] = dict.fromkeys(pre_completed, 1)
    sequence = sequence_start

    # Entry node: insertion order, matching StatePipeline.
    nodes_in_order = list(self._dag.nodes)
    if not nodes_in_order:
        # The caller opened the span before delegating here; close it so the
        # error path does not leave it dangling.
        if pipeline_span is not None:
            with contextlib.suppress(Exception):
                pipeline_span.end()
        raise PipelineError("Pipeline has no nodes")
    next_step: str | list[Send] | None = nodes_in_order[0]

    # Skip past anything already completed during this resumed run. Nothing
    # changes the state during this walk, so arriving at a node a second time
    # means the completed nodes form a loop and following edges would never
    # terminate. Stop there and resume by executing the revisited node.
    skipped: set[str] = set()
    while isinstance(next_step, str) and next_step in pre_completed and next_step not in skipped:
        skipped.add(next_step)
        next_step = self._cyclic_next(next_step, context)

    pending_pause: tuple[str, str] | None = None
    try:
        while next_step is not None:
            # --- Fan-out (list[Send]) ---------------------------------
            if isinstance(next_step, list):
                sends = next_step
                # Preview the per-target visit numbers to enforce the
                # recursion limit; _run_sends does the real increment.
                over_limit: str | None = None
                preview = dict(visit_counts)
                for send in sends:
                    preview[send.target] = preview.get(send.target, 0) + 1
                    if preview[send.target] > self._recursion_limit:
                        over_limit = send.target
                        break
                if over_limit is not None:
                    msg = (
                        f"Recursion limit ({self._recursion_limit}) exceeded at node '{over_limit}' during fan-out."
                    )
                    logger.error(msg)
                    all_results[over_limit] = NodeResult(node_id=over_limit, success=False, error=msg)
                    break
                completed_set: set[str] = set(all_results)
                pending_set: set[str] = set()
                ok = await self._run_sends(
                    sends=sends,
                    context=context,
                    run_id=run_id,
                    all_results=all_results,
                    trace_entries=trace_entries,
                    completed=completed_set,
                    pending=pending_set,
                    visit_counts=visit_counts,
                )
                if not ok:
                    break
                # Continue from the common successor of all workers, if any.
                next_step = self._common_successor([s.target for s in sends])
                continue

            # --- Single-node step -------------------------------------
            current = next_step
            visit_counts[current] = visit_counts.get(current, 0) + 1
            visit_n = visit_counts[current]
            if visit_n > self._recursion_limit:
                msg = (
                    f"Recursion limit ({self._recursion_limit}) exceeded at node "
                    f"'{current}'. Raise recursion_limit= or fix the routing logic."
                )
                logger.error(msg)
                nr_over = NodeResult(node_id=current, success=False, error=msg)
                all_results[current] = nr_over
                sequence += 1
                self._record_audit(
                    run_id=run_id,
                    node_id=current,
                    sequence=sequence,
                    nr=nr_over,
                    inputs_snapshot={},
                    trace_entries=trace_entries,
                    visit=visit_n,
                )
                break

            gathered = self._gather_inputs(current, context)
            await self._dispatch(
                "on_node_start",
                pipeline_name=self._dag.name,
                run_id=run_id,
                node_id=current,
                visit=visit_n,
            )
            nr = await self._execute_node(
                current,
                context,
                trace_entries,
                None,
                inputs=gathered,
                run_id=run_id,
            )
            all_results[current] = nr
            context.set_node_result(current, nr)
            await self._emit_node_result(nr, run_id)

            sequence += 1
            paused_now = nr.success and isinstance(nr.output, Pause)
            self._record_audit(
                run_id=run_id,
                node_id=current,
                sequence=sequence,
                nr=nr,
                inputs_snapshot=gathered,
                trace_entries=trace_entries,
                visit=visit_n,
                status_override="paused" if paused_now else None,
                pause_reason=nr.output.reason if paused_now else None,
            )

            if not nr.success and not nr.skipped:
                break

            # HITL: node returned Pause — checkpoint paused and halt.
            if paused_now:
                pause_reason = nr.output.reason
                await self._dispatch(
                    "on_node_pause",
                    pipeline_name=self._dag.name,
                    run_id=run_id,
                    node_id=current,
                    reason=pause_reason,
                )
                self._save_checkpoint(
                    run_id=run_id,
                    node_id=current,
                    sequence=sequence,
                    context=context,
                    all_results=all_results,
                    paused=True,
                    pause_reason=pause_reason,
                )
                pending_pause = (current, pause_reason)
                break

            # Fan-out: node returned Send / list[Send].
            if nr.success and _is_send_payload(nr.output):
                next_step = cast(
                    "list[Send]",
                    list(nr.output) if isinstance(nr.output, list) else [nr.output],
                )
                continue

            if not nr.skipped:
                if self._state_schema is not None and context.state is not None and isinstance(nr.output, dict):
                    context.state = apply_update(context.state, nr.output, self._reducers)
                self._save_checkpoint(
                    run_id=run_id,
                    node_id=current,
                    sequence=sequence,
                    context=context,
                    all_results=all_results,
                )

            try:
                next_step = self._cyclic_next(current, context)
            except PipelineError as exc:
                # A routing failure is attributed to the node that was routed from.
                all_results[current] = NodeResult(node_id=current, success=False, error=str(exc), output=nr.output)
                break
    finally:
        elapsed = (time.perf_counter() - pipeline_start) * 1000
        success = False if pending_pause is not None else all(r.success or r.skipped for r in all_results.values())
        await self._dispatch(
            "on_pipeline_complete",
            pipeline_name=self._dag.name,
            run_id=run_id,
            success=success,
            duration_ms=elapsed,
        )
        if pipeline_span is not None:
            with contextlib.suppress(Exception):
                pipeline_span.end()

    paused_node = pending_pause[0] if pending_pause else None
    pause_reason_final = pending_pause[1] if pending_pause else None
    return PipelineResult(
        pipeline_name=self._dag.name,
        outputs=all_results,
        final_output=None,
        execution_trace=trace_entries,
        total_duration_ms=elapsed,
        success=success,
        # Aggregate usage the same way the topological scheduler in run() does.
        usage=self._aggregate_usage(context.correlation_id),
        run_id=run_id,
        final_state=context.state,
        paused=pending_pause is not None,
        paused_node=paused_node,
        pause_reason=pause_reason_final,
    )
+ """ + if decision is None: + return None + if isinstance(decision, list): + if not decision: + return None + for s in decision: + if not isinstance(s, Send): + raise PipelineError( + f"Router for '{source}' returned a list containing non-Send element {s!r}; expected list[Send]." + ) + if s.target not in self._dag.nodes: + raise PipelineError(f"Router for '{source}' fans out to unknown target '{s.target}'") + return decision + if isinstance(decision, Send): + if decision.target not in self._dag.nodes: + raise PipelineError(f"Router for '{source}' dispatched to unknown target '{decision.target}'") + return [decision] + if isinstance(decision, str): + mapping = self._router_mappings.get(source) + if mapping is not None: + if decision not in mapping: + raise PipelineError( + f"Router for '{source}' returned label '{decision}' not in mapping {list(mapping)}" + ) + return mapping[decision] + if decision not in self._dag.nodes: + raise PipelineError( + f"Router for '{source}' returned '{decision}' " + f"which is not a registered node id; pass an explicit " + f"mapping if you want abstract labels." + ) + return decision + raise PipelineError( + f"Router for '{source}' returned unsupported type {type(decision).__name__}; " + f"expected str, Send, list[Send], or None." ) async def _execute_node( @@ -239,8 +1144,15 @@ async def _execute_node( context: PipelineContext, trace_entries: list[ExecutionTraceEntry], failed_nodes: set[str] | None = None, + *, + inputs: dict[str, Any] | None = None, + run_id: str = "", ) -> NodeResult: - """Execute a single node with retries and condition gating.""" + """Execute a single node with retries and condition gating. + + ``inputs`` may be pre-gathered by the caller so the same dict can be + used both for execution and for the audit log's inputs snapshot. 
+ """ # Skip if an upstream node failed with SKIP_DOWNSTREAM strategy if failed_nodes and node_id in failed_nodes: logger.debug("Node '%s' skipped (upstream failure)", node_id) @@ -259,18 +1171,20 @@ async def _execute_node( logger.debug("Node '%s' skipped (condition not met)", node_id) return NodeResult(node_id=node_id, skipped=True) - # Gather inputs from upstream edges - inputs = self._gather_inputs(node_id, context) + # Gather inputs from upstream edges (unless caller already did) + if inputs is None: + inputs = self._gather_inputs(node_id, context) + node_prefix = "pipeline.state.node" if self._state_schema is not None else "pipeline.node" _node_span = self._start_otel_span( - f"pipeline.node.{node_id}", + f"{node_prefix}.{node_id}", node=node_id, + visit=1, ) - # Emit node start event - if self._event_handler is not None and hasattr(self._event_handler, "on_node_start"): - with contextlib.suppress(Exception): - await self._event_handler.on_node_start(node_id, self._dag.name) + # Note: on_node_start is now emitted by the caller (run / _run_cyclic / + # _run_sends) so the cyclic and fan-out paths can supply the right + # ``visit`` number. Emitting here would duplicate the event. 
max_retries = node.retry_max backoff_factor = node.backoff_factor @@ -347,19 +1261,8 @@ async def _execute_node( @staticmethod def _start_otel_span(name: str, **attributes: Any) -> Any: - """Start an OTel span if observability is enabled, else return *None*.""" - try: - if not get_config().observability_enabled: - return None - if otel_trace is None: - return None - - return otel_trace.get_tracer("fireflyframework_agentic").start_span( - name, - attributes={f"firefly.{k}": str(v) for k, v in attributes.items()}, - ) - except Exception: # noqa: BLE001 - return None + """Backwards-compatible wrapper around the module-level :func:`start_otel_span`.""" + return start_otel_span(name, **attributes) @staticmethod def _aggregate_usage(correlation_id: str) -> Any: @@ -373,31 +1276,147 @@ def _aggregate_usage(correlation_id: str) -> Any: except Exception: # noqa: BLE001 return None - async def _emit_node_result(self, nr: NodeResult) -> None: - """Emit event handler callbacks for a completed node.""" + async def _emit_node_result(self, nr: NodeResult, run_id: str) -> None: + """Emit handler callbacks for a completed node via :meth:`_dispatch`.""" if self._event_handler is None: return + common = { + "pipeline_name": self._dag.name, + "run_id": run_id, + "node_id": nr.node_id, + } + if nr.skipped: + await self._dispatch("on_node_skip", reason=nr.error or "skipped", **common) + elif nr.success: + await self._dispatch("on_node_complete", latency_ms=nr.latency_ms or 0.0, **common) + else: + await self._dispatch("on_node_error", error=nr.error or "unknown", **common) + + def _load_for_resume(self, run_id: str, *, approve_pause: bool = False) -> tuple[PipelineContext, list[str], int]: + """Rebuild context + completed-set from the latest checkpoint. + + Resuming a paused run (checkpoint.paused=True) requires + ``approve_pause=True``; otherwise a :class:`PipelineError` halts the + attempt and surfaces the pause reason. 
+ """ + if self._checkpointer is None: + raise PipelineError("Cannot resume: pipeline has no checkpointer configured") + record = self._checkpointer.load_latest(self._dag.name, run_id) + if record is None: + raise PipelineError(f"No checkpoint found for run_id='{run_id}'") + if record.paused and not approve_pause: + raise PipelineError( + f"Run '{run_id}' is paused at node '{record.node_id}' " + f"(reason: {record.pause_reason!r}). Pass approve_pause=True to resume." + ) + context = PipelineContext(inputs=record.state.get("inputs")) + for nid, nr_dict in record.state.get("results", {}).items(): + try: + context.set_node_result(nid, NodeResult.model_validate(nr_dict)) + except Exception: + logger.warning("Could not restore NodeResult for '%s' on resume", nid) + # Restore shared state if the run was state-aware. + saved_state = record.state.get("shared_state") + if self._state_schema is not None and isinstance(saved_state, dict): + try: + context.state = self._state_schema.model_validate(saved_state) + except Exception: + logger.warning("Could not restore shared state on resume for run '%s'", run_id) + return context, list(record.completed_nodes), record.sequence + + def _save_checkpoint( + self, + *, + run_id: str, + node_id: str, + sequence: int, + context: PipelineContext, + all_results: dict[str, NodeResult], + paused: bool = False, + pause_reason: str | None = None, + ) -> None: + """Persist state after a successful node. No-op if no checkpointer. + + Only successful (non-skipped) nodes go into ``completed_nodes`` so + that resume re-attempts the failures. 
+ """ + if self._checkpointer is None: + return + completed_successful = [nid for nid, nr in all_results.items() if nr.success and not nr.skipped] + state = { + "inputs": _serialize_value(context.inputs), + "results": {nid: all_results[nid].model_dump(mode="json") for nid in completed_successful}, + "shared_state": _serialize_value(context.state), + } try: - if nr.skipped and hasattr(self._event_handler, "on_node_skip"): - await self._event_handler.on_node_skip( - nr.node_id, - self._dag.name, - nr.error or "skipped", - ) - elif nr.success and hasattr(self._event_handler, "on_node_complete"): - await self._event_handler.on_node_complete( - nr.node_id, - self._dag.name, - nr.latency_ms or 0.0, - ) - elif not nr.success and hasattr(self._event_handler, "on_node_error"): - await self._event_handler.on_node_error( - nr.node_id, - self._dag.name, - nr.error or "unknown", + self._checkpointer.save( + CheckpointRecord( + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + state=state, + completed_nodes=completed_successful, + paused=paused, + pause_reason=pause_reason, ) - except Exception: # noqa: BLE001 - pass + ) + except Exception: + logger.exception("Checkpoint save failed for run '%s' at '%s'", run_id, node_id) + + def _record_audit( + self, + *, + run_id: str, + node_id: str, + sequence: int, + nr: NodeResult, + inputs_snapshot: dict[str, Any], + trace_entries: list[ExecutionTraceEntry], + visit: int = 1, + status_override: AuditStatus | None = None, + pause_reason: str | None = None, + ) -> None: + """Write an audit entry for a node visit. No-op if no audit log. + + Skipped nodes are not recorded — they represent work that did NOT + happen and would clutter the trail. ``status_override`` lets the + cyclic scheduler tag a Pause-returning node with ``"paused"`` + instead of the default ``"success"`` derived from ``nr.success``. 
+ """ + if self._audit_log is None or nr.skipped: + return + # Pull timing from the trace entry the node just wrote. + started_at = completed_at = datetime.now(UTC) + for te in reversed(trace_entries): + if te.node_id == node_id: + started_at = te.started_at + completed_at = te.completed_at + break + if status_override is not None: + status: AuditStatus = status_override + else: + status = "success" if nr.success else "error" + outputs: dict[str, Any] = {"output": _serialize_value(nr.output)} if nr.success else {} + entry = AuditEntry( + pipeline_name=self._dag.name, + run_id=run_id, + node_id=node_id, + sequence=sequence, + visit=visit, + started_at=started_at, + completed_at=completed_at, + latency_ms=nr.latency_ms or 0.0, + status=status, + inputs_snapshot={k: _serialize_value(v) for k, v in inputs_snapshot.items()}, + outputs_snapshot=outputs, + error_message=nr.error if not nr.success else None, + pause_reason=pause_reason, + ) + try: + self._audit_log.record(entry) + except Exception: + logger.exception("Audit log write failed for run '%s' at '%s'", run_id, node_id) def _gather_inputs(self, node_id: str, context: PipelineContext) -> dict[str, Any]: """Collect inputs for a node from its upstream edges.""" diff --git a/fireflyframework_agentic/pipeline/reducers.py b/fireflyframework_agentic/pipeline/reducers.py new file mode 100644 index 00000000..989e1697 --- /dev/null +++ b/fireflyframework_agentic/pipeline/reducers.py @@ -0,0 +1,108 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-merge reducers for typed pipeline state. + +Reducers are functions ``(current, update) -> merged`` declared on a state +field via :class:`typing.Annotated`. The pipeline engine inspects +``typing.get_type_hints(state_schema, include_extras=True)`` for each field +and applies the relevant reducer when a node returns a partial state dict. + +Fields without an annotated reducer use :func:`replace` (last-write-wins). + +Example:: + + class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] + intent: str | None = None # uses replace by default +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Any, get_type_hints + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +Reducer = Callable[[Any, Any], Any] + + +def replace(current: Any, update: Any) -> Any: # noqa: ARG001 + """Last-write-wins: the update value replaces the current value.""" + return update + + +def append(current: Any, update: Any) -> list[Any]: + """Append a single item to a list. ``current`` is treated as ``[]`` if ``None``.""" + base = list(current) if current else [] + base.append(update) + return base + + +def extend(current: Any, update: Any) -> list[Any]: + """Concatenate two iterables. ``update`` must be iterable.""" + base = list(current) if current else [] + base.extend(update) + return base + + +def merge_dict(current: Any, update: Any) -> dict[Any, Any]: + """Shallow-merge two dicts; keys in ``update`` win.""" + base = dict(current) if current else {} + base.update(update or {}) + return base + + +def discover_reducers(state_schema: type) -> dict[str, Reducer]: + """Inspect ``Annotated[T, reducer_fn]`` annotations on the schema. + + Only ``Annotated[...]`` metadata is consulted — not generic origins like + ``list[...]`` or unions. 
Fields without an annotated reducer are absent + from the returned dict; callers should treat absence as :func:`replace`. + """ + out: dict[str, Reducer] = {} + try: + hints = get_type_hints(state_schema, include_extras=True) + except Exception: + return out + for field_name, hint in hints.items(): + metadata = getattr(hint, "__metadata__", None) + if not metadata: + continue + for meta in metadata: + if callable(meta): + out[field_name] = meta + break + return out + + +def apply_update(state: BaseModel, update: dict[str, Any], reducers: dict[str, Reducer]) -> BaseModel: + """Return a new state object with ``update`` merged into ``state`` via reducers. + + Keys present in ``update`` but missing from the schema are logged and + ignored — incremental schema evolution stays painless. + """ + if not update: + return state + new_values = state.model_dump() + for key, value in update.items(): + if key not in new_values: + logger.warning("State update key '%s' not in schema %s; ignored.", key, type(state).__name__) + continue + reducer = reducers.get(key, replace) + new_values[key] = reducer(new_values[key], value) + return type(state).model_validate(new_values) diff --git a/fireflyframework_agentic/pipeline/result.py b/fireflyframework_agentic/pipeline/result.py index dff1ffff..ade1ccf8 100644 --- a/fireflyframework_agentic/pipeline/result.py +++ b/fireflyframework_agentic/pipeline/result.py @@ -68,6 +68,7 @@ class PipelineResult(BaseModel): total_duration_ms: End-to-end pipeline execution time. success: Whether all nodes completed successfully. usage: Aggregated token usage across all pipeline nodes. + run_id: Identifier for this run; resume with ``engine.run(run_id=...)``. """ pipeline_name: str = "" @@ -77,7 +78,49 @@ class PipelineResult(BaseModel): total_duration_ms: float = 0.0 success: bool = True usage: UsageSummary | None = None + run_id: str = "" + # Final shared state for pipelines configured with state_schema. None + # when the engine had no state overlay. 
+ final_state: Any = None + # HITL: a node returned :class:`Pause` and the run halted cleanly. + # Resume via ``engine.run(run_id=..., approve_pause=True)``. + paused: bool = False + paused_node: str | None = None + pause_reason: str | None = None @property def failed_nodes(self) -> list[str]: return [nid for nid, r in self.outputs.items() if not r.success and not r.skipped] + + # -- State-mode convenience aliases --------------------------------- + + @property + def state(self) -> Any: + """Final shared state. Alias of :attr:`final_state` for state-aware + pipelines built via ``PipelineBuilder(state=...)``.""" + return self.final_state + + @property + def completed_nodes(self) -> list[str]: + """IDs of every successful node visit, in completion order. + + Derived from :attr:`execution_trace` so each cyclic re-entry of a + node appears as its own entry (matches StatePipeline's semantics). + """ + return [e.node_id for e in self.execution_trace if e.status == "success"] + + @property + def failed_node(self) -> str | None: + """First node that failed, if any. 
``None`` when the run succeeded.""" + for nid, r in self.outputs.items(): + if not r.success and not r.skipped: + return nid + return None + + @property + def error(self) -> str | None: + """Error message from the first failed node, if any.""" + for r in self.outputs.values(): + if not r.success and not r.skipped and r.error: + return r.error + return None diff --git a/fireflyframework_agentic/pipeline/steps.py b/fireflyframework_agentic/pipeline/steps.py index 47d12220..68e2bb29 100644 --- a/fireflyframework_agentic/pipeline/steps.py +++ b/fireflyframework_agentic/pipeline/steps.py @@ -20,6 +20,7 @@ import asyncio import logging +import warnings from collections.abc import Callable, Coroutine from typing import Any, Protocol, runtime_checkable @@ -148,6 +149,12 @@ def classify(inputs): """ def __init__(self, router: Callable[[dict[str, Any]], str]) -> None: + warnings.warn( + "BranchStep is deprecated; use PipelineBuilder(state=...).branch(source, router) " + "for first-class declarative branching.", + DeprecationWarning, + stacklevel=2, + ) self._router = router async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: @@ -162,6 +169,12 @@ class FanOutStep: """ def __init__(self, split_fn: Callable[[Any], list[Any]]) -> None: + warnings.warn( + "FanOutStep is deprecated; use PipelineBuilder(state=...) 
with a router returning " + "list[Send(target, payload)] for first-class runtime fan-out.", + DeprecationWarning, + stacklevel=2, + ) self._split_fn = split_fn async def execute(self, context: PipelineContext, inputs: dict[str, Any]) -> Any: diff --git a/tests/examples/software_factory/__init__.py b/tests/examples/software_factory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/examples/software_factory/test_pipeline.py b/tests/examples/software_factory/test_pipeline.py new file mode 100644 index 00000000..ff3c8d01 --- /dev/null +++ b/tests/examples/software_factory/test_pipeline.py @@ -0,0 +1,48 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""End-to-end smoke test for the software_factory example. + +Runs the pipeline against a tmp-dir FileCheckpointer; asserts that the +transient builder failure triggers a checkpointed failure on the first +invoke, and that resuming via ``run_id`` walks through the QA loop and +finishes with a stable release. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from examples.software_factory import agents +from examples.software_factory.pipeline import build_pipeline +from examples.software_factory.state import BuildState +from fireflyframework_agentic.pipeline import FileCheckpointer + + +@pytest.fixture(autouse=True) +def _reset_builder_attempts() -> None: + """The builder stub uses a process-wide counter to simulate a one-shot + transient failure. Reset it so each test sees the same starting state. 
+ """ + agents._BUILDER_ATTEMPTS.clear() + + +async def test_factory_end_to_end(tmp_path: Path) -> None: + pipeline = build_pipeline(FileCheckpointer(tmp_path)) + + first = await pipeline.invoke(BuildState(request="payments microservice")) + assert first.success is False + assert first.failed_node == "builder" + assert first.state.iteration == 1 + assert first.state.build_status is None + + resumed = await pipeline.invoke(run_id=first.run_id) + assert resumed.success is True + assert resumed.state.release_tag == "v2026.05.28" + assert resumed.state.qa_status == "pass" + assert resumed.state.iteration == 2 # QA fail on iter 1 → loop → iter 2 passes + assert resumed.state.qa_feedback == ["missing PSD2 strong-auth flow"] diff --git a/tests/unit/pipeline/test_audit_log.py b/tests/unit/pipeline/test_audit_log.py new file mode 100644 index 00000000..6ccb1fde --- /dev/null +++ b/tests/unit/pipeline/test_audit_log.py @@ -0,0 +1,226 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Audit-log tests — File / Logging / OTel backends + pipeline wiring. + +PostgresAuditLog used to live in the framework and was tested here with mocks; +it moved to ``examples/software_factory/audit/postgres.py`` as a plug-and-play +template. 
+""" + +from __future__ import annotations + +import json +import logging +from datetime import UTC, datetime +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.audit as audit_module +from fireflyframework_agentic.pipeline import ( + AuditEntry, + FileAuditLog, + LoggingAuditLog, + OtelAuditLog, + Pause, + PipelineBuilder, +) + + +def _entry(**overrides: Any) -> AuditEntry: + defaults = { + "pipeline_name": "p", + "run_id": "r", + "node_id": "n", + "sequence": 1, + "visit": 1, + "started_at": datetime(2026, 5, 27, tzinfo=UTC), + "completed_at": datetime(2026, 5, 27, 0, 0, 1, tzinfo=UTC), + "latency_ms": 100.0, + "status": "success", + "inputs_snapshot": {"x": 1}, + "outputs_snapshot": {"y": 2}, + } + defaults.update(overrides) + return AuditEntry(**defaults) # type: ignore[arg-type] + + +# ============================================================================= +# FileAuditLog +# ============================================================================= + + +def test_file_audit_log_writes_jsonl_per_run(tmp_path: Path) -> None: + log = FileAuditLog(tmp_path) + log.record(_entry(sequence=1, node_id="a")) + log.record(_entry(sequence=2, node_id="b")) + + path = tmp_path / "p" / "r.jsonl" + assert path.exists() + lines = path.read_text().strip().splitlines() + assert len(lines) == 2 + assert json.loads(lines[0])["node_id"] == "a" + assert json.loads(lines[1])["node_id"] == "b" + + +def test_file_audit_log_list_entries_round_trips(tmp_path: Path) -> None: + log = FileAuditLog(tmp_path) + for seq, node in [(1, "a"), (2, "b"), (3, "c")]: + log.record(_entry(sequence=seq, node_id=node)) + entries = log.list_entries("p", "r") + assert [e.node_id for e in entries] == ["a", "b", "c"] + + +def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None: + assert FileAuditLog(tmp_path).list_entries("p", "missing") == [] + + +# 
============================================================================= +# Optional-dep stubs for OTel +# ============================================================================= + + +@pytest.fixture(autouse=True) +def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: + """Stub OTel symbols so OtelAuditLog can be constructed with mocks.""" + if audit_module._otel_get_logger is None: + monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory")) + monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord")) + sev = MagicMock(name="SeverityNumber") + sev.ERROR = MagicMock(name="ERROR") + sev.ERROR.name = "ERROR" + sev.INFO = MagicMock(name="INFO") + sev.INFO.name = "INFO" + monkeypatch.setattr(audit_module, "_OtelSeverityNumber", sev) + + +# ============================================================================= +# LoggingAuditLog +# ============================================================================= + + +def test_logging_audit_emits_record_with_firefly_audit_extra( + caplog: pytest.LogCaptureFixture, +) -> None: + log = LoggingAuditLog(logger_name="firefly.test_audit") + with caplog.at_level(logging.INFO, logger="firefly.test_audit"): + log.record(_entry(node_id="z", status="success")) + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert "firefly_audit" in rec.__dict__ + assert rec.__dict__["firefly_audit"]["node_id"] == "z" + assert rec.__dict__["firefly_audit"]["status"] == "success" + + +# ============================================================================= +# OtelAuditLog +# ============================================================================= + + +def test_otel_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(audit_module, "_otel_get_logger", None) + with pytest.raises(ImportError, match="opentelemetry-sdk"): + OtelAuditLog() + + +def test_otel_audit_emits_log_record_via_otel_logger(monkeypatch: 
pytest.MonkeyPatch) -> None: + mock_logger = MagicMock(name="otel_logger") + factory = MagicMock(name="get_logger", return_value=mock_logger) + monkeypatch.setattr(audit_module, "_otel_get_logger", factory) + + log = OtelAuditLog() + log.record(_entry(node_id="a", status="success")) + + factory.assert_called_once() + assert mock_logger.emit.called, "OtelAuditLog should call logger.emit() with a LogRecord" + + +# ============================================================================= +# Pipeline wiring — audit fires for every node visit +# ============================================================================= + + +class S(BaseModel): + log: str = "" + + +@pytest.mark.asyncio +async def test_pipeline_writes_one_audit_entry_per_node_visit(tmp_path: Path) -> None: + async def a(state: S) -> dict: + return {"log": "a"} + + async def b(state: S) -> dict: + return {"log": "b"} + + audit = FileAuditLog(tmp_path) + pipeline = PipelineBuilder("audit-test", state=S, audit_log=audit).add_node(a).add_node(b).chain(a, b).build() + result = await pipeline.invoke(S()) + entries = audit.list_entries("audit-test", result.run_id) + assert [e.node_id for e in entries] == ["a", "b"] + assert all(e.status == "success" for e in entries) + + +@pytest.mark.asyncio +async def test_pipeline_audit_captures_error_status(tmp_path: Path) -> None: + async def boom(state: S) -> dict: + raise RuntimeError("nope") + + audit = FileAuditLog(tmp_path) + pipeline = PipelineBuilder("audit-err", state=S, audit_log=audit).add_node(boom).build() + result = await pipeline.invoke(S()) + entries = audit.list_entries("audit-err", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "error" + assert "nope" in (entries[0].error_message or "") + + +@pytest.mark.asyncio +async def test_pipeline_audit_captures_paused_status(tmp_path: Path) -> None: + async def gate(state: S) -> Pause: + return Pause(reason="approval please") + + audit = FileAuditLog(tmp_path / "audit") + from 
fireflyframework_agentic.pipeline import FileCheckpointer + + pipeline = ( + PipelineBuilder( + "audit-pause", + state=S, + audit_log=audit, + checkpointer=FileCheckpointer(tmp_path / "ckpt"), + ) + .add_node(gate) + .build() + ) + result = await pipeline.invoke(S()) + entries = audit.list_entries("audit-pause", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "paused" + assert entries[0].pause_reason == "approval please" + + +@pytest.mark.asyncio +async def test_audit_write_failure_does_not_abort_pipeline(tmp_path: Path) -> None: + """A broken audit log shouldn't kill business logic.""" + + class CrashyAudit: + def record(self, entry: AuditEntry) -> None: + raise RuntimeError("audit storage offline") + + async def step(state: S) -> dict: + return {"log": "ran"} + + pipeline = ( + PipelineBuilder("crashy", state=S, audit_log=CrashyAudit()) # type: ignore[arg-type] + .add_node(step) + .build() + ) + result = await pipeline.invoke(S()) + assert result.success is True + assert result.state.log == "ran" diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py new file mode 100644 index 00000000..51e42906 --- /dev/null +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -0,0 +1,160 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Tests for the framework's File Checkpointer. + +The Postgres and Redis backends used to live in the framework and were +exercised here with mocks; both moved to plug-and-play templates under +``examples/software_factory/checkpointers/`` (apps that need them copy the +file into their repo and test it against their own infra). 
+""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + CheckpointRecord, + FileCheckpointer, + PipelineBuilder, +) + +# ============================================================================= +# FileCheckpointer +# ============================================================================= + + +def test_file_checkpointer_save_and_load_latest(tmp_path) -> None: + ckpt = FileCheckpointer(tmp_path / "ckpt") + + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=1, + node_id="a", + state={"k": 1}, + completed_nodes=["a"], + ) + ) + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=2, + node_id="b", + state={"k": 2}, + completed_nodes=["a", "b"], + ) + ) + + latest = ckpt.load_latest("p", "r") + assert latest is not None + assert latest.node_id == "b" + assert latest.state == {"k": 2} + assert latest.completed_nodes == ["a", "b"] + + +def test_file_checkpointer_load_latest_unknown_run_returns_none(tmp_path) -> None: + assert FileCheckpointer(tmp_path).load_latest("p", "missing") is None + + +def test_file_checkpointer_list_runs(tmp_path) -> None: + ckpt = FileCheckpointer(tmp_path / "ckpt") + for run_id in ("rA", "rB"): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id=run_id, + sequence=1, + node_id="a", + state={}, + completed_nodes=["a"], + ) + ) + assert ckpt.list_runs("p") == ["rA", "rB"] + assert ckpt.list_runs("missing") == [] + + +def test_file_checkpointer_paused_record_round_trips(tmp_path) -> None: + ckpt = FileCheckpointer(tmp_path) + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=1, + node_id="await_approval", + state={"x": 1}, + completed_nodes=["a", "await_approval"], + paused=True, + pause_reason="waiting on human", + ) + ) + latest = ckpt.load_latest("p", "r") + assert latest is not None + assert latest.paused is True + assert latest.pause_reason == "waiting on 
human" + + +# ============================================================================= +# Protocol conformance — software-factory scenario against File backend +# ============================================================================= + + +class FactoryState(BaseModel): + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +def _build_factory(checkpointer): + """Construct the canonical 4-step agent pipeline that fails on first deploy.""" + state_flag = {"failed_once": False} + + async def architect(state: FactoryState) -> dict: + return {"spec": f"spec for {state.requirements}"} + + async def python_dev(state: FactoryState) -> dict: + return {"code": f"# code for {state.spec}"} + + async def deployer(state: FactoryState) -> dict: + if not state_flag["failed_once"]: + state_flag["failed_once"] = True + raise RuntimeError("blip") + return {"deploy_url": "https://app"} + + async def evaluator(state: FactoryState) -> dict: + return {"evaluation": f"PASS {state.deploy_url}"} + + pipeline = ( + PipelineBuilder("factory", state=FactoryState, checkpointer=checkpointer) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + return pipeline + + +@pytest.mark.asyncio +async def test_file_backend_supports_fail_and_resume(tmp_path) -> None: + """Deployer fails on its first call → run is checkpointed → resume completes.""" + backend = FileCheckpointer(tmp_path / "ckpt") + pipeline = _build_factory(backend) + + first = await pipeline.invoke(FactoryState(requirements="users service")) + assert not first.success + assert first.failed_node == "deployer" + assert first.completed_nodes == ["architect", "python_dev"] + + second = await pipeline.invoke(run_id=first.run_id) + assert second.success + assert second.completed_nodes == ["architect", "python_dev", "deployer", "evaluator"] + 
assert second.state.evaluation == "PASS https://app" diff --git a/tests/unit/pipeline/test_pipeline_engine_cycles.py b/tests/unit/pipeline/test_pipeline_engine_cycles.py new file mode 100644 index 00000000..d45b1c16 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_cycles.py @@ -0,0 +1,167 @@ +"""Layer 4 of the unification (#245): cycle-aware scheduler. + +PipelineEngine accepts ``recursion_limit=`` and, when the DAG is cyclic +(allow_cycles=True and a cycle is actually present), switches to a +sequential frontier-following scheduler. Each node visit increments a +per-node counter; exceeding ``recursion_limit`` halts the run with an +explanatory failure. + +This also patches the silent-corruption hazard in :meth:`DAG.topological_sort` +and :meth:`DAG.execution_levels` — both now raise on cyclic DAGs instead +of producing partial / wrong output. +""" + +from __future__ import annotations + +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.reducers import append + +# ---- topology-API safety --------------------------------------------------- + + +def test_topological_sort_raises_on_cyclic_dag(): + dag = DAG("cyclic", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=None)) + dag.add_node(DAGNode(node_id="b", step=None)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="a")) + with pytest.raises(PipelineError, match="cyclic"): + dag.topological_sort() + + +def test_execution_levels_raises_on_cyclic_dag(): + dag = DAG("cyclic-lev", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=None)) + dag.add_node(DAGNode(node_id="b", step=None)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", 
target="a")) + with pytest.raises(PipelineError, match="cyclic"): + dag.execution_levels() + + +# ---- cyclic execution ------------------------------------------------------ + + +class _CounterState(BaseModel): + counter: int = 0 + log: Annotated[list[str], append] = [] + + +def _bump(label: str, by: int = 1): + """Return a step that records its label and bumps counter by `by`.""" + + class _Step: + def __init__(self): + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return {"counter": ctx.state.counter + by, "log": label} + + return _Step() + + +async def test_cyclic_dag_loops_until_condition_fails(): + """Loop: incrementer -> guard. Guard's outgoing edge back to incrementer + is alive while counter < 3. Loop exits when guard's continue edge dies.""" + inc = _bump("inc", by=1) + # guard is a no-op pass-through. + + class _Pass: + calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return None + + guard = _Pass() + dag = DAG("loop", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=guard)) + dag.add_edge(DAGEdge(source="inc", target="guard")) + # Continue edge: re-enter inc while counter < 3. + dag.add_edge(DAGEdge(source="guard", target="inc", condition=lambda ctx: ctx.state.counter < 3)) + engine = PipelineEngine(dag, state_schema=_CounterState, recursion_limit=10) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.counter == 3 + assert inc.calls == 3 + # guard runs after each inc. 
+ assert guard.calls == 3 + + +async def test_recursion_limit_halts_runaway_cycle(): + inc = _bump("inc") + + class _Pass: + async def execute(self, ctx, inputs): + return None + + dag = DAG("infinite", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=_Pass())) + dag.add_edge(DAGEdge(source="inc", target="guard")) + dag.add_edge(DAGEdge(source="guard", target="inc")) # always alive — runaway + engine = PipelineEngine(dag, state_schema=_CounterState, recursion_limit=5) + result = await engine.run(inputs="") + assert not result.success + assert ( + "recursion" in (result.outputs.get("inc") and result.outputs["inc"].error or "").lower() + or "recursion" in (result.outputs.get("guard") and result.outputs["guard"].error or "").lower() + ) + + +async def test_recursion_limit_default_is_25(): + """The engine's default recursion_limit matches StatePipeline's (25).""" + engine = PipelineEngine(DAG("x")) + assert engine._recursion_limit == 25 # noqa: SLF001 + + +async def test_audit_records_visit_per_iteration(tmp_path): + """Each iteration of a cycle gets its own audit entry with incrementing visit.""" + from fireflyframework_agentic.pipeline.audit import FileAuditLog + + inc = _bump("inc") + + class _Pass: + async def execute(self, ctx, inputs): + return None + + dag = DAG("audited-loop", allow_cycles=True) + dag.add_node(DAGNode(node_id="inc", step=inc)) + dag.add_node(DAGNode(node_id="guard", step=_Pass())) + dag.add_edge(DAGEdge(source="inc", target="guard")) + dag.add_edge(DAGEdge(source="guard", target="inc", condition=lambda ctx: ctx.state.counter < 2)) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, state_schema=_CounterState, audit_log=al, recursion_limit=10) + result = await engine.run(inputs="") + assert result.success + entries = al.list_entries("audited-loop", result.run_id) + inc_visits = sorted([e.visit for e in entries if e.node_id == "inc"]) + assert inc_visits == [1, 2] + + +# 
---- acyclic still works --------------------------------------------------- + + +async def test_acyclic_dag_with_allow_cycles_true_runs_normally(): + """allow_cycles=True doesn't force cyclic mode if there are no cycles.""" + a = _bump("a") + b = _bump("b") + dag = DAG("ac", allow_cycles=True) + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_CounterState) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.counter == 2 + assert a.calls == 1 + assert b.calls == 1 diff --git a/tests/unit/pipeline/test_pipeline_engine_edge_condition.py b/tests/unit/pipeline/test_pipeline_engine_edge_condition.py new file mode 100644 index 00000000..39fe7878 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_edge_condition.py @@ -0,0 +1,219 @@ +"""Layer 2 of the unification (#245): branching as DAGEdge.condition. + +DAGEdge now carries an optional predicate that gates traversal. When a +source completes, each outgoing edge's condition is evaluated against the +current PipelineContext. Targets whose incoming edges all evaluate False +are marked skipped (no execution, no result, transitive downstream cascade +via SKIP_DOWNSTREAM). + +This unifies the legacy ``BranchStep`` + ``DAGNode.condition`` machinery +into a single property of the DAG. ``.branch(source, router, mapping)`` — +which today lives in StatePipeline — will be reframed as sugar that adds +conditional edges in a later layer. 
+""" + +from __future__ import annotations + +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _Echo: + """Step that returns its input verbatim, tagged with a node prefix.""" + + def __init__(self, prefix: str = "") -> None: + self.prefix = prefix + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return f"{self.prefix}{inputs.get('input', '')}" + + +# ---- baseline: edge without condition is unchanged ------------------------ + + +async def test_edge_without_condition_is_unchanged(): + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("plain") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert a.calls == 1 and b.calls == 1 + + +# ---- single conditional edge ---------------------------------------------- + + +async def test_true_condition_lets_target_run(): + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("true-cond") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b", condition=lambda ctx: True)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert b.calls == 1 + + +async def test_false_condition_skips_target(): + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("false-cond") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b", condition=lambda ctx: False)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success # the pipeline as a whole still succeeds + assert a.calls == 1 + assert b.calls == 0 + assert result.outputs["b"].skipped + + +# ---- branching: one source, two conditional targets ----------------------- + + +async def 
test_branch_chooses_one_of_two_targets(): + """Classic if/else branching via two conditional edges from the same source.""" + a = _Echo("a:") + yes, no = _Echo("yes:"), _Echo("no:") + dag = DAG("if-else") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="yes", step=yes)) + dag.add_node(DAGNode(node_id="no", step=no)) + dag.add_edge( + DAGEdge( + source="a", + target="yes", + condition=lambda ctx: "good" in str(ctx.get_node_result("a").output), + ) + ) + dag.add_edge( + DAGEdge( + source="a", + target="no", + condition=lambda ctx: "good" not in str(ctx.get_node_result("a").output), + ) + ) + result = await PipelineEngine(dag).run(inputs="good run") + assert result.success + assert yes.calls == 1 + assert no.calls == 0 + assert result.outputs["no"].skipped + + +# ---- cascading skip -------------------------------------------------------- + + +async def test_skipped_target_cascades_to_its_downstream(): + a, b, c = _Echo("a:"), _Echo("b:"), _Echo("c:") + dag = DAG("cascade") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="b", condition=lambda ctx: False)) + dag.add_edge(DAGEdge(source="b", target="c")) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert a.calls == 1 + assert b.calls == 0 + assert c.calls == 0 + assert result.outputs["b"].skipped + assert result.outputs["c"].skipped + + +# ---- fan-in with mixed conditions: OR semantics --------------------------- + + +async def test_fanin_runs_if_any_incoming_edge_alive(): + """Two upstreams, one edge False, one edge True → target runs.""" + a, b, c = _Echo("a:"), _Echo("b:"), _Echo("c:") + dag = DAG("fanin") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="c", condition=lambda ctx: False)) + 
dag.add_edge(DAGEdge(source="b", target="c", condition=lambda ctx: True)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert c.calls == 1 + assert not result.outputs["c"].skipped + + +async def test_fanin_skipped_when_all_incoming_edges_dead(): + a, b, c = _Echo("a:"), _Echo("b:"), _Echo("c:") + dag = DAG("fanin-dead") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="c", condition=lambda ctx: False)) + dag.add_edge(DAGEdge(source="b", target="c", condition=lambda ctx: False)) + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert c.calls == 0 + assert result.outputs["c"].skipped + + +# ---- condition can read upstream output ----------------------------------- + + +async def test_condition_sees_completed_upstream_output(): + """The condition gets a PipelineContext and can inspect prior node results.""" + + class _Number: + async def execute(self, ctx, inputs): + return 42 + + a = _Number() + b = _Echo("big:") + dag = DAG("cond-reads-upstream") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge( + DAGEdge( + source="a", + target="b", + condition=lambda ctx: ctx.get_node_result("a").output > 10, + ) + ) + result = await PipelineEngine(dag).run(inputs="") + assert result.success + assert b.calls == 1 + + +# ---- condition exception is treated as False ------------------------------ + + +async def test_raising_condition_treated_as_false(): + """If the condition itself raises, the edge is dead — fail closed.""" + a, b = _Echo("a:"), _Echo("b:") + dag = DAG("raising-cond") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + + def raiser(ctx): + raise RuntimeError("oops") + + dag.add_edge(DAGEdge(source="a", target="b", condition=raiser)) + result = await 
PipelineEngine(dag).run(inputs="x") + assert result.success + assert b.calls == 0 + assert result.outputs["b"].skipped + + +# ---- to_mermaid renders conditional edges --------------------------------- + + +def test_mermaid_marks_conditional_edges(): + dag = DAG("viz") + dag.add_node(DAGNode(node_id="a", step=_Echo())) + dag.add_node(DAGNode(node_id="b", step=_Echo())) + dag.add_node(DAGNode(node_id="c", step=_Echo())) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="a", target="c", condition=lambda ctx: True)) + mermaid = dag.to_mermaid() + # Unconditional edge: plain arrow. + assert "a --> b" in mermaid + # Conditional edge: labelled distinctively (we use "if?"). + assert "a -->|if?| c" in mermaid or "a -.->|if?| c" in mermaid diff --git a/tests/unit/pipeline/test_pipeline_engine_event_dispatch.py b/tests/unit/pipeline/test_pipeline_engine_event_dispatch.py new file mode 100644 index 00000000..86125515 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_event_dispatch.py @@ -0,0 +1,163 @@ +"""Layer 1B of the unification (#245): unified EventHandler protocol. + +PipelineEngine now uses a single :class:`EventHandler` protocol that +includes ``run_id`` and ``visit`` on every callback, plus +``on_pipeline_start`` and ``on_node_pause``. Dispatch is by parameter name +via signature inspection, so legacy :class:`PipelineEventHandler` +implementations (port-based, run_id-unaware) still receive the events +they declared. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _Echo: + async def execute(self, ctx, inputs): + return inputs.get("input", "") + + +def _two_node_dag() -> DAG: + dag = DAG("dispatch") + dag.add_node(DAGNode(node_id="a", step=_Echo())) + dag.add_node(DAGNode(node_id="b", step=_Echo())) + dag.add_edge(DAGEdge(source="a", target="b")) + return dag + + +# ---- Unified (rich) handler ------------------------------------------------ + + +@dataclass +class _UnifiedHandler: + """Implements the full EventHandler shape (run_id + visit aware).""" + + started: list[tuple[str, str]] = field(default_factory=list) # (pipeline, run_id) + node_starts: list[tuple[str, int]] = field(default_factory=list) # (node_id, visit) + node_completes: list[str] = field(default_factory=list) + completed: list[tuple[str, bool]] = field(default_factory=list) # (run_id, success) + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.started.append((pipeline_name, run_id)) + + async def on_node_start(self, pipeline_name, run_id, node_id, visit): + self.node_starts.append((node_id, visit)) + + async def on_node_complete(self, pipeline_name, run_id, node_id, latency_ms): + self.node_completes.append(node_id) + + async def on_pipeline_complete(self, pipeline_name, run_id, success, duration_ms): + self.completed.append((run_id, success)) + + +async def test_unified_handler_receives_pipeline_start_with_run_id(): + handler = _UnifiedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert handler.started == [("dispatch", result.run_id)] + + +async def test_unified_handler_receives_visit_on_node_start(): + handler = _UnifiedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + await 
engine.run(inputs="x") + # Port-based pipelines always emit visit=1 until cycles arrive in a later layer. + assert handler.node_starts == [("a", 1), ("b", 1)] + + +async def test_unified_handler_receives_pipeline_complete_with_run_id(): + handler = _UnifiedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert handler.completed == [(result.run_id, True)] + + +# ---- Legacy PipelineEventHandler (run_id-unaware) -------------------------- + + +@dataclass +class _LegacyHandler: + """Implements the legacy PipelineEventHandler signatures (no run_id).""" + + starts: list[str] = field(default_factory=list) + completes: list[str] = field(default_factory=list) + pipeline_done: list[tuple[str, bool]] = field(default_factory=list) + + async def on_node_start(self, node_id: str, pipeline_name: str) -> None: + self.starts.append(node_id) + + async def on_node_complete(self, node_id: str, pipeline_name: str, latency_ms: float) -> None: + self.completes.append(node_id) + + async def on_pipeline_complete(self, pipeline_name: str, success: bool, duration_ms: float) -> None: + self.pipeline_done.append((pipeline_name, success)) + + +async def test_legacy_handler_still_works_without_run_id(): + """The engine drops run_id/visit when the handler doesn't declare them.""" + handler = _LegacyHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert result.success + assert handler.starts == ["a", "b"] + assert handler.completes == ["a", "b"] + assert handler.pipeline_done == [("dispatch", True)] + + +async def test_legacy_handler_without_on_pipeline_start_is_fine(): + """Legacy handlers don't have on_pipeline_start; engine just skips it.""" + handler = _LegacyHandler() + assert not hasattr(handler, "on_pipeline_start") + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + # Should not raise — missing methods are no-ops. 
+ await engine.run(inputs="x") + + +# ---- Mixed handler (some legacy methods, some new) ------------------------ + + +@dataclass +class _MixedHandler: + """Some methods unified-signature, some legacy. Both should fire.""" + + pipeline_starts_with_run_id: list[str] = field(default_factory=list) + legacy_node_starts: list[str] = field(default_factory=list) + + # New (rich) signature + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.pipeline_starts_with_run_id.append(run_id) + + # Legacy signature — engine should still call it without run_id/visit + async def on_node_start(self, node_id: str, pipeline_name: str) -> None: + self.legacy_node_starts.append(node_id) + + +async def test_mixed_handler_dispatches_correctly(): + handler = _MixedHandler() + engine = PipelineEngine(_two_node_dag(), event_handler=handler) + result = await engine.run(inputs="x") + assert handler.pipeline_starts_with_run_id == [result.run_id] + assert handler.legacy_node_starts == ["a", "b"] + + +# ---- Exception safety ------------------------------------------------------ + + +async def test_handler_exception_does_not_break_pipeline(): + class _Broken: + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + raise RuntimeError("boom in start") + + async def on_node_start(self, pipeline_name, run_id, node_id, visit): + raise RuntimeError("boom in node start") + + async def on_pipeline_complete(self, pipeline_name, run_id, success, duration_ms): + raise RuntimeError("boom in complete") + + engine = PipelineEngine(_two_node_dag(), event_handler=_Broken()) + result = await engine.run(inputs="x") + assert result.success diff --git a/tests/unit/pipeline/test_pipeline_engine_lifecycle.py b/tests/unit/pipeline/test_pipeline_engine_lifecycle.py new file mode 100644 index 00000000..157dc431 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_lifecycle.py @@ -0,0 +1,252 @@ +"""Layer 1 of the unification (#245): PipelineEngine gains 
checkpoint, audit, and resume. + +These tests pin the contract for port-based pipelines to opt into the same +checkpointing + audit machinery that StatePipeline already has — without +becoming state-based. Resume via ``run(run_id=...)`` is the headline feature. +""" + +from __future__ import annotations + +import pytest + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.audit import FileAuditLog +from fireflyframework_agentic.pipeline.checkpoint import FileCheckpointer +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _CountingStep: + """Step that records how many times its .execute() was called.""" + + def __init__(self, prefix: str = "") -> None: + self._prefix = prefix + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + val = inputs.get("input", "") + return f"{self._prefix}{val}" + + +class _FailOnceStep: + """Step that raises on the first call and succeeds afterward.""" + + def __init__(self) -> None: + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + if self.calls == 1: + raise RuntimeError("flake") + return "b:done" + + +def _chain_dag(*node_ids: str) -> tuple[DAG, dict[str, _CountingStep]]: + dag = DAG("chain") + steps: dict[str, _CountingStep] = {} + for nid in node_ids: + step = _CountingStep(f"{nid}:") + steps[nid] = step + dag.add_node(DAGNode(node_id=nid, step=step)) + for i in range(len(node_ids) - 1): + dag.add_edge(DAGEdge(source=node_ids[i], target=node_ids[i + 1])) + return dag, steps + + +# ---- run_id ----------------------------------------------------------------- + + +async def test_run_returns_non_empty_run_id(): + dag, _ = _chain_dag("a", "b") + engine = PipelineEngine(dag) + result = await engine.run(inputs="x") + assert result.success + assert result.run_id # non-empty + + +async def 
test_explicit_run_id_is_preserved(): + dag, _ = _chain_dag("a") + engine = PipelineEngine(dag) + result = await engine.run(inputs="x", run_id="manual-id") + assert result.run_id == "manual-id" + + +# ---- checkpointing --------------------------------------------------------- + + +async def test_checkpoint_written_per_successful_node(tmp_path): + dag, _ = _chain_dag("a", "b", "c") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + result = await engine.run(inputs="x") + assert result.success + files = sorted((tmp_path / "chain" / result.run_id).glob("*.json")) + assert len(files) == 3 + # Sequence prefix preserves completion order. + assert files[0].name.endswith("_a.json") + assert files[1].name.endswith("_b.json") + assert files[2].name.endswith("_c.json") + + +async def test_checkpoint_omitted_when_no_checkpointer(tmp_path): + dag, _ = _chain_dag("a", "b") + engine = PipelineEngine(dag) # no checkpointer + result = await engine.run(inputs="x") + assert result.success + # tmp_path should still be empty since no checkpointer was wired. + assert not any(tmp_path.iterdir()) + + +async def test_checkpoint_records_completed_nodes(tmp_path): + dag, _ = _chain_dag("a", "b") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + result = await engine.run(inputs="x") + record = cp.load_latest("chain", result.run_id) + assert record is not None + assert record.completed_nodes == ["a", "b"] + assert record.node_id == "b" + + +# ---- resume ---------------------------------------------------------------- + + +async def test_resume_completed_run_is_a_noop(tmp_path): + dag, steps = _chain_dag("a", "b", "c") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + result = await engine.run(inputs="x") + assert all(s.calls == 1 for s in steps.values()) + # All nodes are completed; resume should not re-execute anything. 
+ result2 = await engine.run(run_id=result.run_id) + assert result2.success + assert all(s.calls == 1 for s in steps.values()) + + +async def test_resume_after_failure_skips_completed_and_finishes(tmp_path): + a_step = _CountingStep("a:") + b_step = _FailOnceStep() + c_step = _CountingStep("c:") + dag = DAG("recoverable") + dag.add_node(DAGNode(node_id="a", step=a_step)) + dag.add_node(DAGNode(node_id="b", step=b_step, failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_node(DAGNode(node_id="c", step=c_step)) + dag.add_edge(DAGEdge(source="a", target="b")) + dag.add_edge(DAGEdge(source="b", target="c")) + + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + + result1 = await engine.run(inputs="x") + assert not result1.success + assert a_step.calls == 1 + assert b_step.calls == 1 + assert c_step.calls == 0 + + result2 = await engine.run(run_id=result1.run_id) + assert result2.success + # 'a' was already done — must not be re-executed on resume. + assert a_step.calls == 1 + # 'b' is re-executed (its second attempt succeeds via _FailOnceStep). + assert b_step.calls == 2 + # 'c' runs once, after b succeeds on resume. 
+ assert c_step.calls == 1 + + +async def test_resume_without_checkpointer_raises(): + dag, _ = _chain_dag("a") + engine = PipelineEngine(dag) + with pytest.raises(PipelineError, match="checkpoint"): + await engine.run(run_id="anything") + + +async def test_resume_unknown_run_id_raises(tmp_path): + dag, _ = _chain_dag("a") + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp) + with pytest.raises(PipelineError, match="No checkpoint"): + await engine.run(run_id="missing") + + +# ---- audit log ------------------------------------------------------------- + + +async def test_audit_log_writes_entry_per_node(tmp_path): + dag, _ = _chain_dag("a", "b") + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al) + result = await engine.run(inputs="x") + entries = al.list_entries("chain", result.run_id) + assert len(entries) == 2 + assert [e.node_id for e in entries] == ["a", "b"] + assert all(e.status == "success" for e in entries) + assert all(e.visit == 1 for e in entries) + assert all(e.latency_ms >= 0 for e in entries) + + +async def test_audit_log_captures_failure(tmp_path): + class _Bad: + async def execute(self, ctx, inputs): + raise RuntimeError("boom") + + dag = DAG("fail") + dag.add_node(DAGNode(node_id="bad", step=_Bad(), failure_strategy=FailureStrategy.FAIL_PIPELINE)) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al) + result = await engine.run(inputs="x") + assert not result.success + entries = al.list_entries("fail", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "error" + assert entries[0].error_message is not None + assert "boom" in entries[0].error_message + + +async def test_audit_skipped_nodes_not_recorded_as_success(tmp_path): + """Skipped nodes (condition gate) shouldn't show up as successful audits.""" + step = _CountingStep("a:") + dag = DAG("skipping") + dag.add_node(DAGNode(node_id="skipped", step=step, condition=lambda ctx: False)) + al = 
FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al) + await engine.run(inputs="x") + entries = al.list_entries("skipping", _last_run_id(al, "skipping")) + # Skipped nodes are not work that happened — leave them out. + assert entries == [] or all(e.status != "success" for e in entries) + + +def _last_run_id(al: FileAuditLog, pipeline: str) -> str: + pipeline_dir = al._root / pipeline + if not pipeline_dir.exists(): + return "" + files = list(pipeline_dir.glob("*.jsonl")) + return files[0].stem if files else "" + + +# ---- combined checkpoint + audit + resume ---------------------------------- + + +async def test_full_stack_resume_with_audit(tmp_path): + cp_dir = tmp_path / "cp" + al_dir = tmp_path / "al" + a_step = _CountingStep("a:") + b_step = _FailOnceStep() + dag = DAG("full") + dag.add_node(DAGNode(node_id="a", step=a_step)) + dag.add_node(DAGNode(node_id="b", step=b_step, failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b")) + cp = FileCheckpointer(cp_dir) + al = FileAuditLog(al_dir) + engine = PipelineEngine(dag, checkpointer=cp, audit_log=al) + + r1 = await engine.run(inputs="x") + assert not r1.success + r2 = await engine.run(run_id=r1.run_id) + assert r2.success + entries = al.list_entries("full", r1.run_id) + # Three entries: a-success, b-error (first attempt), b-success (resume). + assert len(entries) == 3 + assert entries[0].node_id == "a" and entries[0].status == "success" + assert entries[1].node_id == "b" and entries[1].status == "error" + assert entries[2].node_id == "b" and entries[2].status == "success" diff --git a/tests/unit/pipeline/test_pipeline_engine_pause_send.py b/tests/unit/pipeline/test_pipeline_engine_pause_send.py new file mode 100644 index 00000000..32e56c85 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_pause_send.py @@ -0,0 +1,201 @@ +"""Layer 5 of the unification (#245): Pause and Send in PipelineEngine. 
+ +PipelineEngine recognizes the same control sentinels that StatePipeline +uses today: + +- A node returning :class:`Pause` halts the pipeline cleanly; the run + resumes with ``engine.run(run_id=..., approve_pause=True)``. +- A node returning :class:`Send` or ``list[Send]`` triggers a parallel + fan-out where each Send's target runs concurrently with the supplied + payload merged into a per-worker state copy. Reducers merge worker + outputs back into shared state. + +Both sentinels work in the acyclic and cyclic schedulers. +""" + +from __future__ import annotations + +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.checkpoint import FileCheckpointer +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode + +# Pause and Send live in pipeline.engine now (moved from state_pipeline in +# this layer); the public re-export from pipeline/__init__.py is unchanged. 
+from fireflyframework_agentic.pipeline.engine import Pause, PipelineEngine, Send +from fireflyframework_agentic.pipeline.reducers import extend + +# ---- shared state --------------------------------------------------------- + + +class _LoopState(BaseModel): + items: Annotated[list[str], extend] = [] + approved: bool = False + deployed_to: str = "" + + +# ---- Pause ---------------------------------------------------------------- + + +def _step_pause(reason: str = "human gate"): + class _Step: + async def execute(self, ctx, inputs): + return Pause(reason=reason) + + return _Step() + + +def _step_record(label: str): + class _Step: + async def execute(self, ctx, inputs): + return {"items": [label]} + + return _Step() + + +async def test_pause_halts_pipeline_and_records_state(tmp_path): + """A node returning Pause halts: result.paused=True, success=False, checkpoint with paused=True.""" + build = _step_record("build") + gate = _step_pause("awaiting approval") + deploy = _step_record("deploy") + dag = DAG("hitl") + dag.add_node(DAGNode(node_id="build", step=build)) + dag.add_node(DAGNode(node_id="gate", step=gate)) + dag.add_node(DAGNode(node_id="deploy", step=deploy)) + dag.add_edge(DAGEdge(source="build", target="gate")) + dag.add_edge(DAGEdge(source="gate", target="deploy")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, state_schema=_LoopState, checkpointer=cp) + result = await engine.run(inputs="") + assert result.paused is True + assert result.paused_node == "gate" + assert result.pause_reason == "awaiting approval" + assert not result.success + # State so far: only 'build' contributed. + assert result.final_state.items == ["build"] + # Checkpoint reflects the paused state. 
+ record = cp.load_latest("hitl", result.run_id) + assert record is not None + assert record.paused is True + assert record.pause_reason == "awaiting approval" + + +async def test_resume_paused_run_requires_approve_pause(tmp_path): + build = _step_record("build") + gate = _step_pause("awaiting approval") + dag = DAG("needs-approve") + dag.add_node(DAGNode(node_id="build", step=build)) + dag.add_node(DAGNode(node_id="gate", step=gate)) + dag.add_edge(DAGEdge(source="build", target="gate")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, state_schema=_LoopState, checkpointer=cp) + paused = await engine.run(inputs="") + assert paused.paused is True + # Without approve_pause: error. + with pytest.raises(PipelineError, match="paused"): + await engine.run(run_id=paused.run_id) + + +async def test_resume_with_approve_pause_continues_from_successor(tmp_path): + build = _step_record("build") + gate = _step_pause("awaiting approval") + deploy = _step_record("deploy") + dag = DAG("approved") + dag.add_node(DAGNode(node_id="build", step=build)) + dag.add_node(DAGNode(node_id="gate", step=gate)) + dag.add_node(DAGNode(node_id="deploy", step=deploy)) + dag.add_edge(DAGEdge(source="build", target="gate")) + dag.add_edge(DAGEdge(source="gate", target="deploy")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, state_schema=_LoopState, checkpointer=cp) + paused = await engine.run(inputs="") + assert paused.paused is True + resumed = await engine.run(run_id=paused.run_id, approve_pause=True) + assert resumed.success + # 'gate' is NOT re-executed; only 'deploy' adds to items. 
+ assert resumed.final_state.items == ["build", "deploy"] + + +# ---- Send ------------------------------------------------------------------ + + +def _step_emit_sends(targets: list[str]): + class _Step: + async def execute(self, ctx, inputs): + return [Send(target=t, payload={"items": [f"sent-{t}"]}) for t in targets] + + return _Step() + + +def _step_consume_payload(suffix: str): + """A worker that turns its inbound items into a state update.""" + + class _Step: + async def execute(self, ctx, inputs): + seen = list(ctx.state.items) + return {"items": [f"{s}+{suffix}" for s in seen]} + + return _Step() + + +async def test_send_dispatches_workers_concurrently(): + """One node returns list[Send]; targets run concurrently and their outputs merge.""" + planner = _step_emit_sends(["a", "b"]) + worker_a = _step_consume_payload("A") + worker_b = _step_consume_payload("B") + dag = DAG("fanout") + dag.add_node(DAGNode(node_id="planner", step=planner)) + dag.add_node(DAGNode(node_id="a", step=worker_a)) + dag.add_node(DAGNode(node_id="b", step=worker_b)) + dag.add_edge(DAGEdge(source="planner", target="a")) + dag.add_edge(DAGEdge(source="planner", target="b")) + engine = PipelineEngine(dag, state_schema=_LoopState) + result = await engine.run(inputs="") + assert result.success + # Each worker sees its own payload (a sees "sent-a", b sees "sent-b"). + assert sorted(result.final_state.items) == sorted(["sent-a+A", "sent-b+B"]) + + +async def test_single_send_is_treated_as_list_of_one(): + # A planner step that emits one Send directly (not wrapped in a list). 
+ class _Solo: + async def execute(self, ctx, inputs): + return Send(target="a", payload={"items": ["just-a"]}) + + worker_a = _step_consume_payload("X") + dag = DAG("solo") + dag.add_node(DAGNode(node_id="planner", step=_Solo())) + dag.add_node(DAGNode(node_id="a", step=worker_a)) + dag.add_edge(DAGEdge(source="planner", target="a")) + engine = PipelineEngine(dag, state_schema=_LoopState) + result = await engine.run(inputs="") + assert result.success + assert result.final_state.items == ["just-a+X"] + + +async def test_send_to_unknown_target_raises(): + class _Bad: + async def execute(self, ctx, inputs): + return [Send(target="ghost", payload={})] + + dag = DAG("unknown-send") + dag.add_node(DAGNode(node_id="planner", step=_Bad())) + engine = PipelineEngine(dag, state_schema=_LoopState) + result = await engine.run(inputs="") + # The fan-out fails; the pipeline reports failure. + assert not result.success + + +# ---- Pause exports --------------------------------------------------------- + + +def test_pause_and_send_reexported_from_pipeline_package(): + from fireflyframework_agentic.pipeline import Pause as PausePkg + from fireflyframework_agentic.pipeline import Send as SendPkg + + assert PausePkg is Pause + assert SendPkg is Send diff --git a/tests/unit/pipeline/test_pipeline_engine_start_at.py b/tests/unit/pipeline/test_pipeline_engine_start_at.py new file mode 100644 index 00000000..81003ec3 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_start_at.py @@ -0,0 +1,123 @@ +"""Layer 6 of the unification (#245): start_at kwarg for mid-pipeline entry. + +PipelineEngine.run() accepts ``start_at=`` (string node id or callable +reference). Execution begins at the named node; everything not reachable +from it is treated as pre-completed and skipped. This is the unified +equivalent of StatePipeline.invoke(state=..., start_at=...). 
+""" + +from __future__ import annotations + +import pytest + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine + + +class _Counting: + def __init__(self, name: str): + self.name = name + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + return self.name + + +def _chain(*ids: str) -> tuple[DAG, dict[str, _Counting]]: + dag = DAG("chain") + steps: dict[str, _Counting] = {} + for nid in ids: + s = _Counting(nid) + steps[nid] = s + dag.add_node(DAGNode(node_id=nid, step=s)) + for i in range(len(ids) - 1): + dag.add_edge(DAGEdge(source=ids[i], target=ids[i + 1])) + return dag, steps + + +# ---- baseline ------------------------------------------------------------- + + +async def test_no_start_at_runs_every_node(): + dag, steps = _chain("a", "b", "c") + result = await PipelineEngine(dag).run(inputs="x") + assert result.success + assert all(s.calls == 1 for s in steps.values()) + + +# ---- start_at: skip upstream ---------------------------------------------- + + +async def test_start_at_skips_upstream_nodes(): + dag, steps = _chain("a", "b", "c", "d") + result = await PipelineEngine(dag).run(inputs="x", start_at="c") + assert result.success + assert steps["a"].calls == 0 + assert steps["b"].calls == 0 + assert steps["c"].calls == 1 + assert steps["d"].calls == 1 + + +async def test_start_at_first_node_is_like_no_start_at(): + dag, steps = _chain("a", "b", "c") + result = await PipelineEngine(dag).run(inputs="x", start_at="a") + assert result.success + assert all(s.calls == 1 for s in steps.values()) + + +async def test_start_at_terminal_runs_only_that_node(): + dag, steps = _chain("a", "b", "c") + result = await PipelineEngine(dag).run(inputs="x", start_at="c") + assert result.success + assert steps["a"].calls == 0 + assert steps["b"].calls == 0 + assert steps["c"].calls == 1 + + +# ---- 
start_at via callable ------------------------------------------------ + + +async def test_start_at_accepts_callable_reference(): + async def deploy(ctx, inputs): + return "deployed" + + from fireflyframework_agentic.pipeline.steps import CallableStep + + dag = DAG("callable") + dag.add_node(DAGNode(node_id="build", step=_Counting("build"))) + dag.add_node(DAGNode(node_id="deploy", step=CallableStep(deploy))) + dag.add_edge(DAGEdge(source="build", target="deploy")) + result = await PipelineEngine(dag).run(inputs="x", start_at=deploy) + assert result.success + # Resolves deploy.__name__ -> 'deploy' -> only deploy ran. + assert result.outputs["deploy"].output == "deployed" + + +# ---- invalid start_at ----------------------------------------------------- + + +async def test_unknown_start_at_raises(): + dag, _ = _chain("a", "b") + with pytest.raises(PipelineError, match="start_at"): + await PipelineEngine(dag).run(inputs="x", start_at="ghost") + + +# ---- branching dag -------------------------------------------------------- + + +async def test_start_at_in_branching_dag(): + """In a branching DAG, start_at picks one branch; the other is skipped entirely.""" + dag = DAG("branchy") + for nid in ("root", "left", "right", "leftchild"): + dag.add_node(DAGNode(node_id=nid, step=_Counting(nid))) + dag.add_edge(DAGEdge(source="root", target="left")) + dag.add_edge(DAGEdge(source="root", target="right")) + dag.add_edge(DAGEdge(source="left", target="leftchild")) + result = await PipelineEngine(dag).run(inputs="x", start_at="left") + assert result.success + # 'right' is not downstream of 'left' — it should not run. + assert "right" not in result.outputs or result.outputs["right"].skipped + # 'leftchild' is reachable from 'left'. 
+ assert result.outputs["leftchild"].success diff --git a/tests/unit/pipeline/test_pipeline_engine_state_overlay.py b/tests/unit/pipeline/test_pipeline_engine_state_overlay.py new file mode 100644 index 00000000..f503a2cd --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_engine_state_overlay.py @@ -0,0 +1,264 @@ +"""Layer 3 of the unification (#245): state as optional overlay on PipelineEngine. + +PipelineEngine now accepts ``state_schema=`` and ``state=`` arguments. When +configured, nodes that return a dict have it merged into a shared Pydantic +state object via reducers (replace, append, extend, merge_dict). Non-dict +returns continue to flow as port outputs — both modes coexist on the same +node. + +This reclaims parallelism for state-aware pipelines: nodes that write +disjoint state fields can run concurrently via the existing topological +scheduler. Concurrent writes to the same field are merged by the reducer +declared on that field (commutative reducers like ``append`` are safe; +``replace`` is last-write-wins). 
+""" + +from __future__ import annotations + +from typing import Annotated + +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline.audit import FileAuditLog +from fireflyframework_agentic.pipeline.checkpoint import FileCheckpointer +from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode +from fireflyframework_agentic.pipeline.engine import PipelineEngine +from fireflyframework_agentic.pipeline.reducers import append, extend, merge_dict + +# ---- Schemas --------------------------------------------------------------- + + +class _SimpleState(BaseModel): + counter: int = 0 + note: str = "" + + +class _ListState(BaseModel): + items: Annotated[list[str], append] = [] + batch: Annotated[list[str], extend] = [] + + +class _MergeState(BaseModel): + bag: Annotated[dict[str, int], merge_dict] = {} + + +# ---- Step helpers ---------------------------------------------------------- + + +def _step_returns(value): + """Build a step whose execute() always returns `value`.""" + + class _Step: + async def execute(self, ctx, inputs): + return value + + return _Step() + + +def _step_reads_state(field: str): + """Build a step that returns the current state's `field` as a port output.""" + + class _Step: + async def execute(self, ctx, inputs): + return getattr(ctx.state, field) + + return _Step() + + +# ---- baseline: no state_schema = unchanged -------------------------------- + + +async def test_engine_without_state_schema_is_unchanged(): + dag = DAG("plain") + dag.add_node(DAGNode(node_id="a", step=_step_returns("port-value"))) + engine = PipelineEngine(dag) # no state_schema + result = await engine.run(inputs="x") + assert result.success + assert result.final_state is None + assert result.outputs["a"].output == "port-value" + + +# ---- engine instantiates state from defaults when none passed ------------- + + +async def test_state_schema_with_defaults_is_auto_instantiated(): + dag = DAG("auto-state") + dag.add_node(DAGNode(node_id="a", 
step=_step_returns(None))) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert isinstance(result.final_state, _SimpleState) + assert result.final_state.counter == 0 + + +# ---- explicit state passed via run() -------------------------------------- + + +async def test_state_arg_is_used_when_passed(): + dag = DAG("explicit-state") + dag.add_node(DAGNode(node_id="a", step=_step_returns(None))) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x", state=_SimpleState(counter=42, note="hi")) + assert result.success + assert result.final_state.counter == 42 + assert result.final_state.note == "hi" + + +# ---- node returning dict merges into state via reducer -------------------- + + +async def test_dict_return_is_state_update_under_replace(): + dag = DAG("dict-replace") + dag.add_node(DAGNode(node_id="a", step=_step_returns({"counter": 7}))) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.counter == 7 + + +async def test_append_reducer_accumulates_across_nodes(): + a, b = _step_returns({"items": "first"}), _step_returns({"items": "second"}) + dag = DAG("appender") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_ListState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.items == ["first", "second"] + + +async def test_extend_reducer_concatenates(): + a, b = _step_returns({"batch": ["x", "y"]}), _step_returns({"batch": ["z"]}) + dag = DAG("extender") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_ListState) + result = await 
engine.run(inputs="x") + assert result.success + assert result.final_state.batch == ["x", "y", "z"] + + +async def test_merge_dict_reducer_merges(): + a, b = _step_returns({"bag": {"k1": 1}}), _step_returns({"bag": {"k2": 2}}) + dag = DAG("merger") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_MergeState) + result = await engine.run(inputs="x") + assert result.success + assert result.final_state.bag == {"k1": 1, "k2": 2} + + +# ---- non-dict return still flows as a port output ------------------------- + + +async def test_non_dict_return_is_still_a_port_output(): + """A node can write state OR emit a port value — its return type decides.""" + a = _step_returns("port-value") # str, not dict → port output + b = _step_reads_state("note") + dag = DAG("mixed") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge(DAGEdge(source="a", target="b")) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert result.outputs["a"].output == "port-value" # port preserved + assert result.final_state.note == "" # state untouched by 'a' + + +# ---- conditions can read state -------------------------------------------- + + +async def test_edge_condition_reads_ctx_state(): + a = _step_returns({"counter": 5}) + b = _step_returns(None) + dag = DAG("cond-on-state") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_edge( + DAGEdge( + source="a", + target="b", + condition=lambda ctx: ctx.state.counter > 3, + ) + ) + engine = PipelineEngine(dag, state_schema=_SimpleState) + result = await engine.run(inputs="x") + assert result.success + assert not result.outputs["b"].skipped + + +# ---- parallelism: same field, commutative reducer -------------------------- + + +async def
test_parallel_nodes_with_commutative_reducer_accumulate(): + """Two nodes at the same level both append to items; both contributions land.""" + a = _step_returns({"items": "from-a"}) + b = _step_returns({"items": "from-b"}) + c = _step_returns(None) + dag = DAG("parallel-append") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b)) + dag.add_node(DAGNode(node_id="c", step=c)) + dag.add_edge(DAGEdge(source="a", target="c")) + dag.add_edge(DAGEdge(source="b", target="c")) + engine = PipelineEngine(dag, state_schema=_ListState) + result = await engine.run(inputs="x") + assert result.success + assert sorted(result.final_state.items) == ["from-a", "from-b"] + + +# ---- checkpoint + resume restores state ----------------------------------- + + +async def test_resume_restores_shared_state(tmp_path): + """Run with state, fail mid-pipeline, resume — state survives.""" + + class _FailOnce: + def __init__(self): + self.calls = 0 + + async def execute(self, ctx, inputs): + self.calls += 1 + if self.calls == 1: + raise RuntimeError("flake") + return {"counter": ctx.state.counter + 100} + + from fireflyframework_agentic.pipeline.dag import FailureStrategy + + a = _step_returns({"counter": 1, "note": "from-a"}) + b = _FailOnce() + dag = DAG("resume-state") + dag.add_node(DAGNode(node_id="a", step=a)) + dag.add_node(DAGNode(node_id="b", step=b, failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b")) + cp = FileCheckpointer(tmp_path) + engine = PipelineEngine(dag, checkpointer=cp, state_schema=_SimpleState) + r1 = await engine.run(inputs="x") + assert not r1.success + # After 'a' succeeds, the checkpoint should contain state.counter=1. + assert r1.final_state.counter == 1 + + r2 = await engine.run(run_id=r1.run_id) + assert r2.success + # On resume: state.counter restored to 1, then b adds 100. 
+ assert r2.final_state.counter == 101 + assert r2.final_state.note == "from-a" # state preserved + + +# ---- audit log works alongside state -------------------------------------- + + +async def test_audit_records_under_state_overlay(tmp_path): + dag = DAG("audit-state") + dag.add_node(DAGNode(node_id="a", step=_step_returns({"counter": 3}))) + al = FileAuditLog(tmp_path) + engine = PipelineEngine(dag, audit_log=al, state_schema=_SimpleState) + result = await engine.run(inputs="x") + entries = al.list_entries("audit-state", result.run_id) + assert len(entries) == 1 + assert entries[0].status == "success" diff --git a/tests/unit/pipeline/test_state_pipeline.py b/tests/unit/pipeline/test_state_pipeline.py new file mode 100644 index 00000000..f0d1095d --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline.py @@ -0,0 +1,350 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Tests for the state-based pipeline API (issue #147 phase 1). + +Covers the canonical agentic-pipeline shape: + * Typed shared state via a Pydantic model. + * Reducers via ``Annotated[T, reducer_fn]``. + * Function references as node ids. + * Auto-entry detection. + * ``.branch(source, router)`` with and without an explicit mapping. + * Checkpoint + resume after failure (the software-factory scenario). + * ``start_at`` to jump into the middle of a pipeline with explicit state. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + append, +) + + +class AgentState(BaseModel): + messages: Annotated[list[str], append] = [] + intent: str | None = None + answer: str | None = None + + +# --- linear pipeline ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_linear_pipeline_runs_all_nodes(): + """Three nodes in sequence; each writes to state; final state has all updates.""" + + async def step_a(state: AgentState) -> dict: + return {"messages": "a"} + + async def step_b(state: AgentState) -> dict: + return {"messages": "b"} + + async def step_c(state: AgentState) -> dict: + return {"messages": "c", "answer": "done"} + + pipeline = ( + PipelineBuilder("linear", state=AgentState) + .add_node(step_a) + .add_node(step_b) + .add_node(step_c) + .chain(step_a, step_b, step_c) + .build() + ) + result = await pipeline.invoke(AgentState(messages=["start"])) + assert result.success + assert result.completed_nodes == ["step_a", "step_b", "step_c"] + assert result.state.messages == ["start", "a", "b", "c"] + assert result.state.answer == "done" + + +@pytest.mark.asyncio +async def test_returning_none_or_empty_dict_keeps_state(): + """A node that returns None or {} should leave state unchanged.""" + + async def noop(state: AgentState) -> None: + return None + + async def writer(state: AgentState) -> dict: + return {"answer": "ok"} + + pipeline = PipelineBuilder("noop", state=AgentState).add_node(noop).add_node(writer).chain(noop, writer).build() + result = await pipeline.invoke(AgentState(messages=["x"])) + assert result.success + assert result.state.messages == ["x"] # unchanged by noop + assert result.state.answer == "ok" + + +# --- branching 
------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_branch_without_mapping_router_returns_node_id(): + """Router returns the target node id directly; no mapping needed.""" + + async def classify(state: AgentState) -> dict: + return {"intent": "complaint" if "refund" in " ".join(state.messages) else "general"} + + async def answer(state: AgentState) -> dict: + return {"answer": "Here is your answer."} + + async def escalate(state: AgentState) -> dict: + return {"answer": "Escalated."} + + def route(state: AgentState) -> str: + return "escalate" if state.intent == "complaint" else "answer" + + pipeline = ( + PipelineBuilder("branch", state=AgentState) + .add_node(classify) + .add_node(answer) + .add_node(escalate) + .branch(classify, route) + .build() + ) + complaint = await pipeline.invoke(AgentState(messages=["I want a refund"])) + assert complaint.state.answer == "Escalated." + assert complaint.completed_nodes == ["classify", "escalate"] + + general = await pipeline.invoke(AgentState(messages=["hello"])) + assert general.state.answer == "Here is your answer." 
+ assert general.completed_nodes == ["classify", "answer"] + + +@pytest.mark.asyncio +async def test_branch_with_explicit_mapping_uses_abstract_labels(): + async def start(state: AgentState) -> dict: + return {"intent": "x"} + + async def left(state: AgentState) -> dict: + return {"answer": "L"} + + async def right(state: AgentState) -> dict: + return {"answer": "R"} + + def route(state: AgentState) -> str: + return "go_left" if state.intent == "x" else "go_right" + + pipeline = ( + PipelineBuilder("mapped", state=AgentState) + .add_node(start) + .add_node(left) + .add_node(right) + .branch(start, route, {"go_left": left, "go_right": right}) + .build() + ) + result = await pipeline.invoke(AgentState()) + assert result.state.answer == "L" + + +@pytest.mark.asyncio +async def test_router_returning_unknown_label_raises(): + async def start(state: AgentState) -> dict: + return {} + + async def target(state: AgentState) -> dict: + return {"answer": "ok"} + + def bad_router(state: AgentState) -> str: + return "nonexistent_node" + + pipeline = ( + PipelineBuilder("bad", state=AgentState).add_node(start).add_node(target).branch(start, bad_router).build() + ) + result = await pipeline.invoke(AgentState()) + assert not result.success + assert "nonexistent_node" in (result.error or "") + + +# --- reducers -------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_append_reducer_accumulates_across_nodes(): + """The default test schema uses append on messages; each node adds one.""" + + async def a(state: AgentState) -> dict: + return {"messages": "from_a"} + + async def b(state: AgentState) -> dict: + return {"messages": "from_b"} + + pipeline = PipelineBuilder("acc", state=AgentState).add_node(a).add_node(b).chain(a, b).build() + result = await pipeline.invoke(AgentState(messages=["initial"])) + assert result.state.messages == ["initial", "from_a", "from_b"] + + +@pytest.mark.asyncio +async def 
test_replace_reducer_is_default_for_unannotated_field(): + async def a(state: AgentState) -> dict: + return {"answer": "first"} + + async def b(state: AgentState) -> dict: + return {"answer": "second"} + + pipeline = PipelineBuilder("rep", state=AgentState).add_node(a).add_node(b).chain(a, b).build() + result = await pipeline.invoke(AgentState()) + assert result.state.answer == "second" + + +# --- checkpoint + resume --------------------------------------------------- + + +class BuildState(BaseModel): + """Software-factory scenario state.""" + + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +@pytest.mark.asyncio +async def test_checkpoint_resume_after_failure(tmp_path: Path): + """Run a 4-step agent factory; deployer fails the first time; resume succeeds.""" + + failed_once = {"deploy": False} + + async def architect(state: BuildState) -> dict: + return {"spec": "architecture spec for: " + state.requirements} + + async def python_dev(state: BuildState) -> dict: + return {"code": f"# code implementing {state.spec}"} + + async def deployer(state: BuildState) -> dict: + if not failed_once["deploy"]: + failed_once["deploy"] = True + raise RuntimeError("network glitch") + return {"deploy_url": "https://app.example.com"} + + async def evaluator(state: BuildState) -> dict: + return {"evaluation": f"PASS: {state.deploy_url}"} + + ckpt = FileCheckpointer(tmp_path / "ckpt") + pipeline = ( + PipelineBuilder("software-factory", state=BuildState, checkpointer=ckpt) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + + # First run: deployer fails. 
+ first = await pipeline.invoke(BuildState(requirements="user-mgmt service")) + assert not first.success + assert first.failed_node == "deployer" + assert first.completed_nodes == ["architect", "python_dev"] + assert first.state.code is not None # python_dev did persist + + # Resume: should skip architect/python_dev, retry deployer, then evaluator. + second = await pipeline.invoke(run_id=first.run_id) + assert second.success + assert second.completed_nodes == ["architect", "python_dev", "deployer", "evaluator"] + assert second.state.evaluation == "PASS: https://app.example.com" + + +@pytest.mark.asyncio +async def test_start_at_jumps_to_middle_with_explicit_state(tmp_path: Path): + """Caller supplies state + start_at to run only from deployer onwards.""" + + async def architect(state: BuildState) -> dict: + raise AssertionError("should not run") + + async def python_dev(state: BuildState) -> dict: + raise AssertionError("should not run") + + async def deployer(state: BuildState) -> dict: + return {"deploy_url": "https://app.example.com"} + + async def evaluator(state: BuildState) -> dict: + return {"evaluation": "PASS"} + + pipeline = ( + PipelineBuilder("factory", state=BuildState) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + pre_built = BuildState(requirements="x", spec="precomputed", code="precomputed code") + result = await pipeline.invoke(pre_built, start_at=deployer) + assert result.success + assert result.completed_nodes == ["deployer", "evaluator"] + assert result.state.deploy_url == "https://app.example.com" + + +@pytest.mark.asyncio +async def test_resume_without_checkpointer_raises(): + async def a(state: AgentState) -> dict: + return {} + + pipeline = PipelineBuilder("nockpt", state=AgentState).add_node(a).build() + with pytest.raises(PipelineError, match="no checkpointer"): + await pipeline.invoke(run_id="anything") + + +# --- validation 
/ errors --------------------------------------------------- + + +@pytest.mark.asyncio +async def test_default_entry_is_first_node_added(): + """When no inbound edges disambiguate, the first add_node call is the entry.""" + + async def first_one(state: AgentState) -> dict: + return {"answer": "first ran"} + + async def second_one(state: AgentState) -> dict: + raise AssertionError("not reached without an edge") + + pipeline = PipelineBuilder("order", state=AgentState).add_node(first_one).add_node(second_one).build() + result = await pipeline.invoke(AgentState()) + assert result.completed_nodes == ["first_one"] + assert result.state.answer == "first ran" + + +def test_function_ref_without_state_raises(): + async def step(state): + return {} + + with pytest.raises(PipelineError, match="state=..."): + PipelineBuilder("nostate").add_node(step) + + +def test_branch_without_state_raises(): + builder = PipelineBuilder("nostate") + with pytest.raises(PipelineError, match="state=..."): + builder.branch("x", lambda s: "y") + + +# --- agent-shape adapter --------------------------------------------------- + + +@pytest.mark.asyncio +async def test_agent_like_object_adapts_via_run_method(): + """Object exposing async run(state) is accepted as a node.""" + + class MockAgent: + __name__ = "mock_agent" # required for function-ref node id derivation + + async def run(self, state: AgentState) -> dict: + return {"answer": "from mock agent"} + + pipeline = PipelineBuilder("agent", state=AgentState).add_node("mock_agent", MockAgent()).build() + result = await pipeline.invoke(AgentState()) + assert result.success + assert result.state.answer == "from mock agent" diff --git a/tests/unit/pipeline/test_state_pipeline_hitl.py b/tests/unit/pipeline/test_state_pipeline_hitl.py new file mode 100644 index 00000000..f880c459 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_hitl.py @@ -0,0 +1,206 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-3c HITL tests: Pause + approve_pause resume + on_node_pause event.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.exceptions import PipelineError +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + Pause, + PipelineBuilder, + extend, +) + + +class DeployState(BaseModel): + requirements: str = "" + spec: str | None = None + approved: Annotated[list[str], extend] = [] + deployed: bool = False + + +@dataclass +class PauseRecorder: + events: list[tuple] = field(default_factory=list) + + async def on_node_pause(self, pipeline_name: str, run_id: str, node_id: str, reason: str) -> None: + self.events.append(("pause", node_id, reason)) + + +# --- core pause/resume ------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_node_returning_pause_halts_pipeline(tmp_path: Path) -> None: + async def architect(state: DeployState) -> dict: + return {"spec": "v1"} + + async def gate(state: DeployState) -> Pause: + return Pause(reason="awaiting deploy approval") + + async def deploy(state: DeployState) -> dict: + return {"deployed": True} + + ckpt = FileCheckpointer(tmp_path) + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=ckpt) + .add_node(architect) + .add_node(gate) + .add_node(deploy) + .chain(architect, gate, deploy) + .build() + ) + + result = await pipeline.invoke(DeployState(requirements="user-mgmt")) + assert result.paused is True + assert result.paused_node == "gate" + assert result.pause_reason == "awaiting deploy approval" + assert result.success is False # paused != success + assert result.state.deployed is False # deploy did NOT run + assert result.completed_nodes == ["architect", "gate"] + + 
+@pytest.mark.asyncio +async def test_resume_without_approve_pause_raises(tmp_path: Path) -> None: + async def gate(state: DeployState) -> Pause: + return Pause(reason="block here") + + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path)).add_node(gate).build() + ) + first = await pipeline.invoke(DeployState()) + assert first.paused is True + + with pytest.raises(PipelineError, match="approve_pause=True"): + await pipeline.invoke(run_id=first.run_id) + + +@pytest.mark.asyncio +async def test_resume_with_approve_pause_continues_from_successor(tmp_path: Path) -> None: + fail_once = {"flag": False} + + async def architect(state: DeployState) -> dict: + if fail_once["flag"]: + raise AssertionError("architect should NOT re-run on resume") + return {"spec": "v1"} + + async def gate(state: DeployState) -> Pause: + if fail_once["flag"]: + raise AssertionError("gate should NOT re-run on resume") + return Pause(reason="approve please") + + async def deploy(state: DeployState) -> dict: + return {"deployed": True} + + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path)) + .add_node(architect) + .add_node(gate) + .add_node(deploy) + .chain(architect, gate, deploy) + .build() + ) + first = await pipeline.invoke(DeployState(requirements="x")) + assert first.paused is True + fail_once["flag"] = True # ensure neither architect nor gate re-runs + + second = await pipeline.invoke(run_id=first.run_id, approve_pause=True) + assert second.success is True + assert second.state.deployed is True + assert second.completed_nodes == ["architect", "gate", "deploy"] + + +@pytest.mark.asyncio +async def test_on_node_pause_event_fires(tmp_path: Path) -> None: + async def gate(state: DeployState) -> Pause: + return Pause(reason="hold") + + handler = PauseRecorder() + pipeline = ( + PipelineBuilder( + "hitl", + state=DeployState, + checkpointer=FileCheckpointer(tmp_path), + event_handler=handler, # type: 
ignore[arg-type] + ) + .add_node(gate) + .build() + ) + await pipeline.invoke(DeployState()) + assert handler.events == [("pause", "gate", "hold")] + + +# --- backward compat ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_paused_checkpoint_loads_when_pause_fields_missing(tmp_path: Path) -> None: + """An existing checkpoint without paused/pause_reason fields still loads.""" + from fireflyframework_agentic.pipeline.checkpoint import CheckpointRecord + + # Round-trip a record produced from a dict that omits the new fields — + # mirrors what an existing on-disk checkpoint from a pre-3c version looks like. + raw = { + "pipeline_name": "old", + "run_id": "legacy", + "node_id": "n", + "sequence": 1, + "state": {"requirements": "x"}, + "completed_nodes": ["n"], + } + record = CheckpointRecord.model_validate(raw) + assert record.paused is False + assert record.pause_reason is None + + +# --- a paused pipeline can still resume after non-pause failures ----------- + + +@pytest.mark.asyncio +async def test_pause_then_resume_then_error_then_resume(tmp_path: Path) -> None: + """End-to-end: pause, approve, then a subsequent failure, then retry.""" + counters = {"deploy_fail": False} + + async def gate(state: DeployState) -> Pause: + return Pause(reason="approve") + + async def deploy(state: DeployState) -> dict: + if not counters["deploy_fail"]: + counters["deploy_fail"] = True + raise RuntimeError("flaky") + return {"deployed": True} + + pipeline = ( + PipelineBuilder("hitl", state=DeployState, checkpointer=FileCheckpointer(tmp_path)) + .add_node(gate) + .add_node(deploy) + .chain(gate, deploy) + .build() + ) + paused = await pipeline.invoke(DeployState()) + assert paused.paused + + failed = await pipeline.invoke(run_id=paused.run_id, approve_pause=True) + assert not failed.success + assert failed.failed_node == "deploy" + + succeeded = await pipeline.invoke(run_id=paused.run_id, approve_pause=True) + # The deploy checkpoint at 
this point isn't marked paused, but the gate + # checkpoint still is. Both backends sort by sequence, so load_latest + # should return the newer failed-deploy record rather than the older + # paused one. approve_pause=True is passed anyway so the call is valid + # whichever record comes back: if the latest is paused, approve_pause + # is required; if it is not, approve_pause is ignored. + assert succeeded.success is True + assert succeeded.state.deployed is True diff --git a/tests/unit/pipeline/test_state_pipeline_observability.py b/tests/unit/pipeline/test_state_pipeline_observability.py new file mode 100644 index 00000000..ac79233a --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_observability.py @@ -0,0 +1,324 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-3b tests: StatePipelineEventHandler callbacks + OTel spans for state pipelines.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Annotated, Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.engine as engine_module +from fireflyframework_agentic.pipeline import ( + FileCheckpointer, + PipelineBuilder, + Send, + extend, +) + + +@dataclass +class RecordingHandler: + """A test handler that captures every callback in order.""" + + events: list[tuple] = field(default_factory=list) + + async def on_pipeline_start(self, pipeline_name: str, run_id: str) -> None: + self.events.append(("pipeline_start", pipeline_name, run_id)) + + async def on_node_start(self, pipeline_name: str, run_id: str, node_id: str, visit: int) -> None: + self.events.append(("node_start", node_id, visit)) + + async def on_node_complete(self, pipeline_name: str, run_id: str, node_id: str, latency_ms:
float) -> None: + self.events.append(("node_complete", node_id)) + + async def on_node_error(self, pipeline_name: str, run_id: str, node_id: str, error: str) -> None: + self.events.append(("node_error", node_id, error)) + + async def on_pipeline_complete(self, pipeline_name: str, run_id: str, success: bool, duration_ms: float) -> None: + self.events.append(("pipeline_complete", success)) + + +class LinearState(BaseModel): + log: Annotated[list[str], extend] = [] + + +class LoopState(BaseModel): + counter: int = 0 + + +# --- linear pipeline event ordering ----------------------------------------- + + +@pytest.mark.asyncio +async def test_linear_pipeline_emits_events_in_order() -> None: + async def a(state: LinearState) -> dict: + return {"log": ["a"]} + + async def b(state: LinearState) -> dict: + return {"log": ["b"]} + + async def c(state: LinearState) -> dict: + return {"log": ["c"]} + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("linear", state=LinearState, event_handler=handler) + .add_node(a) + .add_node(b) + .add_node(c) + .chain(a, b, c) + .build() + ) + await pipeline.invoke(LinearState()) + + event_kinds = [e[0] for e in handler.events] + assert event_kinds == [ + "pipeline_start", + "node_start", + "node_complete", + "node_start", + "node_complete", + "node_start", + "node_complete", + "pipeline_complete", + ] + assert handler.events[-1] == ("pipeline_complete", True) + + +# --- failure path ---------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_failure_emits_node_error_and_pipeline_complete_false() -> None: + async def boom(state: LinearState) -> dict: + raise RuntimeError("nope") + + handler = RecordingHandler() + pipeline = PipelineBuilder("fail", state=LinearState, event_handler=handler).add_node(boom).build() + result = await pipeline.invoke(LinearState()) + assert not result.success + + assert ("node_error", "boom", "nope") in handler.events + assert ("pipeline_complete", False) in 
handler.events + # node_error fires BEFORE pipeline_complete + err_idx = handler.events.index(("node_error", "boom", "nope")) + done_idx = handler.events.index(("pipeline_complete", False)) + assert err_idx < done_idx + + +# --- cyclic graph visit count ---------------------------------------------- + + +@pytest.mark.asyncio +async def test_cyclic_graph_increments_visit_count() -> None: + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1} + + async def done(state: LoopState) -> dict: + return {} + + def route(state: LoopState) -> str: + return "done" if state.counter >= 3 else "step" + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("loop", state=LoopState, event_handler=handler) + .add_node(step) + .add_node(done) + .branch(step, route) + .build() + ) + await pipeline.invoke(LoopState()) + + step_starts = [e for e in handler.events if e[0] == "node_start" and e[1] == "step"] + assert [e[2] for e in step_starts] == [1, 2, 3] + + +# --- fan-out via Send ------------------------------------------------------- + + +class FanOutState(BaseModel): + items: list[str] = [] + results: Annotated[list[str], extend] = [] + item: str | None = None + + +@pytest.mark.asyncio +async def test_fanout_emits_per_send_node_events() -> None: + async def planner(state: FanOutState) -> dict: + return {} + + async def worker(state: FanOutState) -> dict: + return {"results": [f"r:{state.item}"]} + + async def collect(state: FanOutState) -> dict: + return {} + + def dispatch(state: FanOutState) -> list[Send]: + return [Send("worker", {"item": x}) for x in state.items] + + handler = RecordingHandler() + pipeline = ( + PipelineBuilder("fanout", state=FanOutState, event_handler=handler) + .add_node(planner) + .add_node(worker) + .add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() + ) + await pipeline.invoke(FanOutState(items=["a", "b", "c"])) + + worker_starts = [e for e in handler.events if e[0] == "node_start" 
and e[1] == "worker"] + worker_completes = [e for e in handler.events if e[0] == "node_complete" and e[1] == "worker"] + assert len(worker_starts) == 3 + assert len(worker_completes) == 3 + # Visits are 1, 2, 3 across the three Sends. + assert sorted(e[2] for e in worker_starts) == [1, 2, 3] + + +# --- resume from a checkpoint ---------------------------------------------- + + +class BuildState(BaseModel): + requirements: str + spec: str | None = None + code: str | None = None + deploy: str | None = None + + +@pytest.mark.asyncio +async def test_resume_emits_events_only_for_remaining_nodes(tmp_path: Path) -> None: + fail_once = {"flag": False} + + async def arch(state: BuildState) -> dict: + return {"spec": "s"} + + async def dev(state: BuildState) -> dict: + return {"code": "c"} + + async def deploy(state: BuildState) -> dict: + if not fail_once["flag"]: + fail_once["flag"] = True + raise RuntimeError("blip") + return {"deploy": "ok"} + + handler1 = RecordingHandler() + handler2 = RecordingHandler() + ckpt = FileCheckpointer(tmp_path) + + # First run uses handler1; deploy fails. + p1 = ( + PipelineBuilder("factory", state=BuildState, checkpointer=ckpt, event_handler=handler1) + .add_node(arch) + .add_node(dev) + .add_node(deploy) + .chain(arch, dev, deploy) + .build() + ) + first = await p1.invoke(BuildState(requirements="x")) + assert not first.success + + # Second run uses handler2 and resumes; only deploy should run. 
+ p2 = ( + PipelineBuilder("factory", state=BuildState, checkpointer=ckpt, event_handler=handler2) + .add_node(arch) + .add_node(dev) + .add_node(deploy) + .chain(arch, dev, deploy) + .build() + ) + second = await p2.invoke(run_id=first.run_id) + assert second.success + + nodes_started_on_resume = [e[1] for e in handler2.events if e[0] == "node_start"] + assert nodes_started_on_resume == ["deploy"] + + +# --- partial handler -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_partial_handler_only_implementing_node_error_works() -> None: + captured: list[str] = [] + + class JustErrors: + async def on_node_error(self, pipeline_name, run_id, node_id, error): + captured.append(f"{node_id}:{error}") + + async def boom(state: LinearState) -> dict: + raise RuntimeError("kaboom") + + pipeline = PipelineBuilder("partial", state=LinearState, event_handler=JustErrors()).add_node(boom).build() + result = await pipeline.invoke(LinearState()) + assert not result.success + assert captured == ["boom:kaboom"] + + +# --- handler exception is swallowed ---------------------------------------- + + +@pytest.mark.asyncio +async def test_handler_exception_does_not_break_pipeline() -> None: + class CrashyHandler: + async def on_node_complete(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError("handler crashed") + + async def step(state: LinearState) -> dict: + return {"log": ["ran"]} + + pipeline = PipelineBuilder("crash", state=LinearState, event_handler=CrashyHandler()).add_node(step).build() + result = await pipeline.invoke(LinearState()) + assert result.success + assert result.state.log == ["ran"] + + +# --- OTel spans ------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_otel_spans_emitted_per_pipeline_and_per_node(monkeypatch: pytest.MonkeyPatch) -> None: + # Force the observability path to fire even if the config disables it by default. 
+ mock_config = MagicMock(observability_enabled=True) + monkeypatch.setattr(engine_module, "get_config", lambda: mock_config) + + # Stub the OTel tracer so we can record span starts. + started: list[tuple[str, dict]] = [] + mock_tracer = MagicMock() + + def fake_start_span(name: str, attributes: dict | None = None) -> MagicMock: + started.append((name, attributes or {})) + return MagicMock() + + mock_tracer.start_span.side_effect = fake_start_span + mock_trace = MagicMock() + mock_trace.get_tracer.return_value = mock_tracer + monkeypatch.setattr(engine_module, "otel_trace", mock_trace) + + async def a(state: LinearState) -> dict: + return {} + + async def b(state: LinearState) -> dict: + return {} + + pipeline = PipelineBuilder("otel", state=LinearState).add_node(a).add_node(b).chain(a, b).build() + await pipeline.invoke(LinearState()) + + names = [n for n, _ in started] + assert "pipeline.state.otel" in names + assert "pipeline.state.node.a" in names + assert "pipeline.state.node.b" in names + + # Spot-check attributes: pipeline span carries run_id, node span carries visit. + pipeline_attrs = next(attrs for name, attrs in started if name == "pipeline.state.otel") + assert "firefly.run_id" in pipeline_attrs + node_a_attrs = next(attrs for name, attrs in started if name == "pipeline.state.node.a") + assert node_a_attrs.get("firefly.visit") == "1" diff --git a/tests/unit/pipeline/test_state_pipeline_phase2.py b/tests/unit/pipeline/test_state_pipeline_phase2.py new file mode 100644 index 00000000..c3862fa4 --- /dev/null +++ b/tests/unit/pipeline/test_state_pipeline_phase2.py @@ -0,0 +1,228 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Phase-2 tests: cycles + recursion_limit, Send fan-out, Mermaid export, +soft-deprecation of BranchStep / FanOutStep. 
+""" + +from __future__ import annotations + +import json +import warnings +from typing import Annotated + +import pytest +from pydantic import BaseModel + +from fireflyframework_agentic.pipeline import ( + DAG, + BranchStep, + DAGEdge, + DAGNode, + FailureStrategy, + FanOutStep, + PipelineBuilder, + Send, + extend, +) +from fireflyframework_agentic.pipeline.steps import CallableStep + + +class LoopState(BaseModel): + counter: int = 0 + log: Annotated[list[str], extend] = [] + + +# --- cycles + recursion_limit ---------------------------------------------- + + +@pytest.mark.asyncio +async def test_simple_cycle_with_exit_router(): + """A node loops back to itself N times, then a router exits to END.""" + + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1, "log": [f"step#{state.counter + 1}"]} + + async def done(state: LoopState) -> dict: + return {"log": ["done"]} + + def route(state: LoopState) -> str: + return "done" if state.counter >= 3 else "step" + + pipeline = PipelineBuilder("loop", state=LoopState).add_node(step).add_node(done).branch(step, route).build() + result = await pipeline.invoke(LoopState()) + assert result.success + assert result.state.counter == 3 + assert "done" in result.state.log + # 3 step visits + 1 done = 4 entries before done's own log entry. 
+ assert result.completed_nodes == ["step", "step", "step", "done"] + + +@pytest.mark.asyncio +async def test_recursion_limit_aborts_infinite_loop(): + """A router that never exits triggers the recursion_limit safety net.""" + + async def step(state: LoopState) -> dict: + return {"counter": state.counter + 1} + + def never_exits(state: LoopState) -> str: + return "step" + + pipeline = ( + PipelineBuilder("inf", state=LoopState, recursion_limit=5).add_node(step).branch(step, never_exits).build() + ) + result = await pipeline.invoke(LoopState()) + assert not result.success + assert "Recursion limit" in (result.error or "") + assert result.failed_node == "step" + # The node ran exactly recursion_limit times before the guard fired. + assert result.state.counter == 5 + + +# --- Send fan-out ----------------------------------------------------------- + + +class FanOutState(BaseModel): + items: list[str] = [] + results: Annotated[list[str], extend] = [] + item: str | None = None # filled per-Send via payload + + +@pytest.mark.asyncio +async def test_send_fans_out_to_multiple_workers_and_merges_results(): + """Router returns list[Send]; workers run concurrently; reducer merges.""" + + async def planner(state: FanOutState) -> dict: + return {} # passthrough; could populate items if not preset + + async def worker(state: FanOutState) -> dict: + # Each worker sees its own copy of state with the Send payload applied. 
+ assert state.item is not None + return {"results": [f"processed:{state.item}"]} + + async def collect(state: FanOutState) -> dict: + return {"results": ["collected"]} + + def dispatch(state: FanOutState) -> list[Send]: + return [Send("worker", {"item": x}) for x in state.items] + + pipeline = ( + PipelineBuilder("mapreduce", state=FanOutState) + .add_node(planner) + .add_node(worker) + .add_node(collect) + .add_edge(worker, collect) + .branch(planner, dispatch) + .build() + ) + result = await pipeline.invoke(FanOutState(items=["a", "b", "c"])) + assert result.success + processed = sorted(r for r in result.state.results if r.startswith("processed:")) + assert processed == ["processed:a", "processed:b", "processed:c"] + assert "collected" in result.state.results + # Each worker counts as a completed node visit; planner once, three workers, then collect. + assert result.completed_nodes.count("worker") == 3 + assert result.completed_nodes[-1] == "collect" + + +@pytest.mark.asyncio +async def test_send_to_unknown_target_fails_cleanly(): + async def planner(state: FanOutState) -> dict: + return {} + + async def worker(state: FanOutState) -> dict: + return {} + + def bad_dispatch(state: FanOutState) -> list[Send]: + return [Send("ghost", {})] + + pipeline = ( + PipelineBuilder("bad", state=FanOutState) + .add_node(planner) + .add_node(worker) + .branch(planner, bad_dispatch) + .build() + ) + result = await pipeline.invoke(FanOutState()) + assert not result.success + assert "ghost" in (result.error or "") + + +# --- Mermaid + JSON export -------------------------------------------------- + + +def test_dag_to_mermaid_renders_topology(): + dag = DAG(name="example") + dag.add_node(DAGNode(node_id="a", step=CallableStep(_noop_async))) + dag.add_node(DAGNode(node_id="b", step=CallableStep(_noop_async))) + dag.add_edge(DAGEdge(source="a", target="b")) + out = dag.to_mermaid() + assert out.startswith("flowchart TD") + assert "a[a]" in out + assert "b[b]" in out + assert "a --> 
b" in out + + +def test_dag_to_json_round_trips_via_pydantic(): + dag = DAG(name="example") + dag.add_node(DAGNode(node_id="a", step=CallableStep(_noop_async))) + dag.add_node(DAGNode(node_id="b", step=CallableStep(_noop_async), failure_strategy=FailureStrategy.FAIL_PIPELINE)) + dag.add_edge(DAGEdge(source="a", target="b", input_key="payload")) + doc = json.loads(dag.to_json()) + assert doc["name"] == "example" + assert doc["nodes"] == ["a", "b"] + assert doc["edges"] == [{"source": "a", "target": "b", "output_key": "output", "input_key": "payload"}] + + +def test_state_pipeline_to_mermaid_labels_branch_edges(): + async def start(state: LoopState) -> dict: + return {} + + async def left(state: LoopState) -> dict: + return {} + + async def right(state: LoopState) -> dict: + return {} + + def route(state: LoopState) -> str: + return "left_path" + + pipeline = ( + PipelineBuilder("branched", state=LoopState) + .add_node(start) + .add_node(left) + .add_node(right) + .branch(start, route, {"left_path": left, "right_path": right}) + .build() + ) + mermaid = pipeline.to_mermaid() + assert "start -->|left_path| left" in mermaid + assert "start -->|right_path| right" in mermaid + + +# --- soft-deprecation ------------------------------------------------------ + + +def test_branch_step_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + BranchStep(router=lambda _: "x") + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert any("branch(" in str(w.message) for w in caught) + + +def test_fan_out_step_emits_deprecation_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + FanOutStep(split_fn=lambda x: [x]) + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert any("Send" in str(w.message) for w in caught) + + +# --- helpers --------------------------------------------------------------- + + +async def 
_noop_async(ctx, inputs): + return None