diff --git a/cli/pyproject.toml b/cli/pyproject.toml index 66f8147..bc2cb05 100644 --- a/cli/pyproject.toml +++ b/cli/pyproject.toml @@ -11,6 +11,7 @@ requires-python = ">=3.12" dependencies = [ "typer>=0.12.5", "pyyaml>=6.0.2", + "rich>=14.2.0", "websockets>=15.0", "flowmesh-sdk>=0.1.0", ] diff --git a/cli/src/flowmesh_cli/commands/__init__.py b/cli/src/flowmesh_cli/commands/__init__.py index 0146a55..900f8dd 100644 --- a/cli/src/flowmesh_cli/commands/__init__.py +++ b/cli/src/flowmesh_cli/commands/__init__.py @@ -8,6 +8,7 @@ from .ssh import app as ssh_app from .system import app as system_app from .task import app as task_app +from .trace import app as trace_app from .worker import app as worker_app from .workflow import app as workflow_app @@ -22,3 +23,4 @@ def register(app: typer.Typer) -> None: app.add_typer(node_app, name="node") app.add_typer(worker_app, name="worker") app.add_typer(ssh_app, name="ssh") + app.add_typer(trace_app, name="trace") diff --git a/cli/src/flowmesh_cli/commands/trace.py b/cli/src/flowmesh_cli/commands/trace.py new file mode 100644 index 0000000..36c2f07 --- /dev/null +++ b/cli/src/flowmesh_cli/commands/trace.py @@ -0,0 +1,382 @@ +"""``flowmesh trace`` — fetch raw rows or run the analyzer.""" + +import json +from collections import defaultdict +from enum import StrEnum +from pathlib import Path + +import typer +from flowmesh.exceptions import FlowMeshError +from flowmesh.models.traces import ( + EventSummary, + ProfileSummary, + TaskTiming, +) +from flowmesh.resources.traces import TraceType +from pydantic import BaseModel +from rich.box import SIMPLE +from rich.console import Console +from rich.json import JSON as RichJSON +from rich.panel import Panel +from rich.rule import Rule +from rich.table import Table +from rich.tree import Tree + +from ..core import logging +from ..core.runtime import flowmesh_client_from_config +from ..core.typer import get_typer + +app = get_typer(help="Workflow trace: fetch raw rows or run the analyzer.") +console = Console() + + +class _AnalyzeView(StrEnum): + RICH = "rich" + CRITICAL_PATH = "critical-path" + CP = "cp" + END_TO_END = "end-to-end" + E2E = "e2e" + QUEUING = "queuing" + LINEAGE = "lineage" + JSON = "json" + + +_ANALYZE_VIEW_ALIAS: dict[_AnalyzeView, _AnalyzeView] = { + _AnalyzeView.CP: _AnalyzeView.CRITICAL_PATH, + _AnalyzeView.E2E: _AnalyzeView.END_TO_END, +} + + +class _EventRow(BaseModel): + """One row of an event-summary table — a per-event-type breakdown.""" + + event_type: str + count: int + total_seconds: float + avg_seconds: float + min_seconds: float + max_seconds: float + + +def _event_rows(summary: EventSummary) -> list[_EventRow]: + """Reshape an ``EventSummary``'s parallel lists into validated rows.""" + return [ + _EventRow( + event_type=event_type, + count=count, + total_seconds=total, + avg_seconds=avg, + min_seconds=min, + max_seconds=max, + ) + for event_type, count, total, avg, min, max in zip( + summary.event_type, + summary.count, + summary.total_seconds, + summary.avg_seconds, + summary.min_seconds, + summary.max_seconds, + strict=True, + ) + ] + + +def _compute_table(hw: EventSummary, title: str) -> Table: + """Per-event-type compute-time table; total cells colored by share of max.""" + table = Table(title=title, box=SIMPLE, header_style="bold cyan", title_style="bold") + table.add_column("event_type", style="cyan", no_wrap=True) + table.add_column("n", justify="right") + table.add_column("total_sec", justify="right", style="bold") + table.add_column("avg", justify="right", style="dim") + table.add_column("min", justify="right", style="dim") + table.add_column("max", justify="right", style="dim") + rows = sorted(_event_rows(hw), key=lambda r: r.total_seconds, reverse=True) + max_total = max((r.total_seconds for r in rows), default=0.0) + for row in rows: + total_str = f"{row.total_seconds:.3f}" + if row.total_seconds > 0 and max_total > 0: + ratio = row.total_seconds / max_total + color = "red" if ratio > 0.5 else "yellow" if ratio > 0.1 else "green" + total_str = f"[{color}]{total_str}[/{color}]" + table.add_row( + row.event_type, + str(row.count), + total_str, + f"{row.avg_seconds:.3f}", + f"{row.min_seconds:.3f}", + f"{row.max_seconds:.3f}", + ) + return table + + +def _network_table(net: EventSummary, title: str) -> Table: + """Per-event-type network-time table.""" + table = Table( + title=title, box=SIMPLE, header_style="bold magenta", title_style="bold" + ) + table.add_column("event_type", style="magenta", no_wrap=True) + table.add_column("n", justify="right") + table.add_column("active_sec", justify="right", style="bold") + table.add_column("avg", justify="right", style="dim") + table.add_column("min", justify="right", style="dim") + table.add_column("max", justify="right", style="dim") + rows = sorted(_event_rows(net), key=lambda r: r.total_seconds, reverse=True) + for row in rows: + table.add_row( + row.event_type, + str(row.count), + f"{row.total_seconds:.3f}", + f"{row.avg_seconds:.3f}", + f"{row.min_seconds:.3f}", + f"{row.max_seconds:.3f}", + ) + return table + + +def _short_data_id(value: str) -> str: + """Compact a ``tsk-`` for narrow tables: keep the first 8 hex chars.""" + if not value: + return value + body = value[len("tsk-") :] if value.startswith("tsk-") else value + return f"tsk-{body[:8]}" if body else value + + +def _queuing_delay_table( + timings: list[TaskTiming], cp_set: set[str], title: str +) -> Table: + """Per-data_id queuing delay, sorted by wait time descending.""" + table = Table( + title=title, box=SIMPLE, header_style="bold yellow", title_style="bold" + ) + table.add_column("data_id", style="cyan", no_wrap=True) + table.add_column("duration_sec", justify="right", style="green") + table.add_column("wait_sec", justify="right", style="bold yellow") + table.add_column("blocked_by", style="cyan", no_wrap=True) + table.add_column("cp", justify="center", no_wrap=True) + rows = sorted( + timings, + key=lambda t: (t.queuing_delay_seconds, t.duration_seconds), + reverse=True, + ) + for t in rows: + blocker = ( + _short_data_id(t.blocking_parent_data_id) + if t.blocking_parent_data_id + else "—" + ) + cp_marker = "[bold red]◆[/bold red]" if t.data_id in cp_set else "" + table.add_row( + _short_data_id(t.data_id), + f"{t.duration_seconds:.3f}", + f"{t.queuing_delay_seconds:.3f}", + blocker, + cp_marker, + ) + return table + + +def _lineage_tree(summary: ProfileSummary) -> Tree: + """Lineage DAG as a Rich tree rooted at sinks.""" + children_of: dict[str, list[str]] = defaultdict(list) + parents_of: dict[str, list[str]] = defaultdict(list) + for edge in summary.lineage: + children_of[edge.source_data_id].append(edge.data_id) + parents_of[edge.data_id].append(edge.source_data_id) + + cp_set: set[str] = ( + set(summary.critical_path.path) if summary.critical_path else set() + ) + sinks = [d for d in summary.data_ids if d not in children_of] + + root = Tree("[bold]lineage DAG[/bold]") + + def _label(data_id: str) -> str: + marker = " [bold red]◆ critical path[/bold red]" if data_id in cp_set else "" + return f"[cyan]{data_id}[/cyan]{marker}" + + def _walk(parent_node: Tree, data_id: str, branch: frozenset[str]) -> None: + node = parent_node.add(_label(data_id)) + for upstream in parents_of.get(data_id, []): + if upstream in branch: + node.add(f"[dim]↺ {upstream} (cycle skipped)[/dim]") + continue + _walk(node, upstream, branch | {upstream}) + + if not sinks: + return root.add("[dim](no events with timestamps)[/dim]") + for sink in sinks: + _walk(root, sink, frozenset({sink})) + in_any_edge: set[str] = set(children_of) | set(parents_of) + for orphan in summary.data_ids: + if orphan not in in_any_edge: + root.add(_label(orphan)) + return root + + +def _critical_path_tree(summary: ProfileSummary) -> Tree: + cp = summary.critical_path + if cp is None: + tree = Tree("[bold]critical path[/bold]") + tree.add("[dim](no path: workflow has no events with timestamps)[/dim]") + return tree + + tree = Tree( + f"[bold]critical path[/bold] " + f"[green]{cp.critical_path_seconds:.3f}s[/green]" + f" network=[magenta]{cp.total_network_seconds:.3f}s[/magenta]" + f" length=[cyan]{len(cp.path)}[/cyan]" + ) + awb = cp.active_wait_breakdown + cursor = tree + for data_id, active, wait in zip( + awb.data_id, awb.active_seconds, awb.wait_seconds, strict=True + ): + wait_part = f" [yellow]wait {wait:.3f}s[/yellow]" if wait > 0 else "" + cursor = cursor.add( + f"[bold red]◆[/bold red] [cyan]{data_id}[/cyan]" + f" [green]active {active:.3f}s[/green]{wait_part}" + ) + return tree + + +def _print_header(summary: ProfileSummary) -> None: + e2e = summary.e2e_breakdown + cp = summary.critical_path + lines = [ + f"[bold]workflow:[/bold] {summary.workflow_id or '(unnamed)'}", + ( + f"e2e=[bold green]{e2e.workflow_duration_seconds:.3f}s[/bold green]" + f" network=[bold magenta]{e2e.total_network_seconds:.3f}s[/bold magenta]" + f" data_ids=[bold cyan]{len(summary.data_ids)}[/bold cyan]" + f" events=[bold cyan]{summary.event_count}[/bold cyan]" + f" assets=[bold cyan]{len(summary.assets)}[/bold cyan]" + ), + ] + if cp is not None: + lines.append( + f"cp=[bold green]{cp.critical_path_seconds:.3f}s[/bold green]" + f" network=[bold magenta]{cp.total_network_seconds:.3f}s[/bold magenta]" + f" length=[bold cyan]{len(cp.path)}[/bold cyan]" + ) + console.print(Panel("\n".join(lines), title="trace", border_style="cyan")) + + +def _print_critical_path(summary: ProfileSummary) -> None: + if summary.critical_path is None: + console.print("[dim](no critical path: no events with timestamps)[/dim]") + return + console.print(_critical_path_tree(summary)) + console.print( + _compute_table( + summary.critical_path.hardware_summary, "Compute time (critical path)" + ) + ) + console.print( + _network_table( + summary.critical_path.network_summary, "Network time (critical path)" + ) + ) + + +def _print_e2e(summary: ProfileSummary) -> None: + e2e = summary.e2e_breakdown + console.print(_compute_table(e2e.hardware_summary, "Compute time (end-to-end)")) + console.print(_network_table(e2e.network_summary, "Network time (end-to-end)")) + + +def _print_queuing(summary: ProfileSummary) -> None: + if not summary.per_data_id: + return + cp_set: set[str] = ( + set(summary.critical_path.path) if summary.critical_path else set() + ) + console.print( + _queuing_delay_table(summary.per_data_id, cp_set, "Queuing delay (per data_id)") + ) + + +def _print_lineage(summary: ProfileSummary) -> None: + console.print(_lineage_tree(summary)) + + +@app.command("fetch") +def fetch( + trace_type: TraceType = typer.Argument( + ..., help="One of: spans, assets, lineage", metavar="TYPE" + ), + workflow_id: str = typer.Argument(..., help="Workflow identifier"), + output: Path | None = typer.Option( + None, "--out", "-o", help="Write rows to this JSONL file (default: stdout)" + ), +) -> None: + """Fetch JSONL rows for a workflow's spans / assets / lineage.""" + client = flowmesh_client_from_config() + try: + rows = client.traces.fetch(workflow_id, trace_type) + except FlowMeshError as exc: + logging.error(str(exc)) + raise typer.Exit(code=1) + + if output is None: + for row in rows: + logging.log(json.dumps(row, ensure_ascii=False)) + return + + output.parent.mkdir(parents=True, exist_ok=True) + count = 0 + with output.open("w", encoding="utf-8") as fh: + for row in rows: + fh.write(json.dumps(row, ensure_ascii=False) + "\n") + count += 1 + logging.log(f"Wrote {count} {trace_type} rows to {output}") + + +@app.command("analyze") +def analyze( + workflow_id: str = typer.Argument(..., help="Workflow identifier"), + fmt: _AnalyzeView = typer.Option( + _AnalyzeView.RICH, + "--format", + "-f", + help=( + "Output view: rich, critical-path (cp), end-to-end (e2e), " + "queuing, lineage, json." + ), + case_sensitive=True, + ), +) -> None: + """Run the trace analyzer on a workflow and render the result.""" + fmt = _ANALYZE_VIEW_ALIAS.get(fmt, fmt) + + client = flowmesh_client_from_config() + try: + summary = client.traces.analyze(workflow_id) + except FlowMeshError as exc: + logging.error(str(exc)) + raise typer.Exit(code=1) + + if fmt is _AnalyzeView.JSON: + console.print(RichJSON.from_data(summary.model_dump(mode="json"))) + return + + _print_header(summary) + + match fmt: + case _AnalyzeView.RICH: + if summary.critical_path is not None: + _print_critical_path(summary) + console.print(Rule(style="dim")) + _print_e2e(summary) + if summary.per_data_id: + console.print(Rule(style="dim")) + _print_queuing(summary) + console.print(Rule(style="dim")) + _print_lineage(summary) + case _AnalyzeView.CRITICAL_PATH: + _print_critical_path(summary) + case _AnalyzeView.END_TO_END: + _print_e2e(summary) + case _AnalyzeView.QUEUING: + _print_queuing(summary) + case _AnalyzeView.LINEAGE: + _print_lineage(summary) diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 2aa6828..5812478 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -10,6 +10,7 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "httpx>=0.27.0", + "pandas>=2.3.3", "pydantic>=2.0.0", "pyyaml>=6.0.0", ] diff --git a/sdk/src/flowmesh/async_client.py b/sdk/src/flowmesh/async_client.py index c00aa6f..4c270d1 100644 --- a/sdk/src/flowmesh/async_client.py +++ b/sdk/src/flowmesh/async_client.py @@ -9,6 +9,7 @@ from .resources.ssh import AsyncSSH from .resources.system import AsyncSystem from .resources.tasks import AsyncTasks +from .resources.traces import AsyncTraces from .resources.workers import AsyncWorkers from .resources.workflows import AsyncWorkflows @@ -39,6 +40,7 @@ class AsyncFlowMesh(BaseAsyncClient): nodes: AsyncNodes ssh: AsyncSSH system: AsyncSystem + traces: AsyncTraces def __init__( self, @@ -61,3 +63,4 @@ def __init__( self.nodes = AsyncNodes(self) self.ssh = AsyncSSH(self) self.system = AsyncSystem(self) + self.traces = AsyncTraces(self) diff --git a/sdk/src/flowmesh/client.py b/sdk/src/flowmesh/client.py index 39e3f01..d871e1d 100644 --- a/sdk/src/flowmesh/client.py +++ b/sdk/src/flowmesh/client.py @@ -9,6 +9,7 @@ from .resources.ssh import SSH from .resources.system import System from .resources.tasks import Tasks +from .resources.traces import Traces from .resources.workers import Workers from .resources.workflows import Workflows @@ -49,6 +50,7 @@ class FlowMesh(BaseClient): nodes: Nodes ssh: SSH system: System + traces: Traces def __init__( self, @@ -71,3 +73,4 @@ def __init__( self.nodes = Nodes(self) self.ssh = SSH(self) self.system = System(self) + self.traces = Traces(self) diff --git a/sdk/src/flowmesh/models/__init__.py b/sdk/src/flowmesh/models/__init__.py index 09eccf9..0111ead 100644 --- a/sdk/src/flowmesh/models/__init__.py +++ b/sdk/src/flowmesh/models/__init__.py @@ -20,6 +20,16 @@ ) from .results import PathResponse from .tasks import HardwareUsage, TaskInfo, TaskUsage +from .traces import ( + ActiveWaitBreakdown, + AssetSummary, + CriticalPathSummary, + E2EBreakdown, + EventSummary, + LineageEdge, + ProfileSummary, + TaskTiming, +) from .workers import ( CPUInfo, GpuInfo, @@ -41,11 +51,17 @@ ) __all__ = [ + "ActiveWaitBreakdown", + "AssetSummary", "CPUInfo", + "CriticalPathSummary", + "E2EBreakdown", + "EventSummary", "GpuInfo", "GpuPlatformInfo", "HardwareUsage", "HostInfo", + "LineageEdge", "LogEntry", "LogEvent", "LogLevel", @@ -58,9 +74,11 @@ "Node", "NodeRegisterResponse", "NodeWorkerInfo", + "ProfileSummary", "StorageInfo", "TaskInfo", "TaskStatus", + "TaskTiming", "TaskType", "TaskUsage", "Worker", diff --git a/sdk/src/flowmesh/models/traces.py b/sdk/src/flowmesh/models/traces.py new file mode 100644 index 0000000..9fef596 --- /dev/null +++ b/sdk/src/flowmesh/models/traces.py @@ -0,0 +1,82 @@ +"""Trace analyzer response payload types as seen by the SDK. + +These describe the wire shape returned by +``GET /traces/workflows/analyze/{workflow_id}``. +""" + +from datetime import datetime + +from pydantic import BaseModel, ConfigDict + + +class _ProfileBase(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class AssetSummary(_ProfileBase): + asset_guid: str + latest_data_id: str + latest_version: int + user_id: str + versions: int + created_at: str | None = None + + +class LineageEdge(_ProfileBase): + data_id: str + source_data_id: str + created_at: str | None = None + + +class EventSummary(_ProfileBase): + """Per-event-type duration aggregates as parallel lists.""" + + event_type: list[str] + count: list[int] + total_seconds: list[float] + avg_seconds: list[float] + min_seconds: list[float] + max_seconds: list[float] + + +class E2EBreakdown(_ProfileBase): + hardware_summary: EventSummary + network_summary: EventSummary + workflow_duration_seconds: float + total_network_seconds: float + + +class ActiveWaitBreakdown(_ProfileBase): + data_id: list[str] + active_seconds: list[float] + wait_seconds: list[float] + + +class TaskTiming(_ProfileBase): + data_id: str + start_time: datetime + end_time: datetime + duration_seconds: float + queuing_delay_seconds: float + parent_data_ids: list[str] + blocking_parent_data_id: str | None = None + + +class CriticalPathSummary(_ProfileBase): + path: list[str] + critical_path_seconds: float + active_wait_breakdown: ActiveWaitBreakdown + hardware_summary: EventSummary + network_summary: EventSummary + total_network_seconds: float + + +class ProfileSummary(_ProfileBase): + workflow_id: str | None = None + event_count: int + data_ids: list[str] + assets: list[AssetSummary] + lineage: list[LineageEdge] + e2e_breakdown: E2EBreakdown + per_data_id: list[TaskTiming] + critical_path: CriticalPathSummary | None = None diff --git a/sdk/src/flowmesh/profile_views.py b/sdk/src/flowmesh/profile_views.py new file mode 100644 index 0000000..bdcfd9f --- /dev/null +++ b/sdk/src/flowmesh/profile_views.py @@ -0,0 +1,96 @@ +"""Stringification / dataframe helpers for `ProfileSummary`. + +The SDK leaves rendering to the caller, but exposes a few convenience +adapters so downstream Python tools (notebooks, lumilake, ad-hoc scripts) +don't have to reimplement them. +""" + +import re +from typing import Any + +import pandas as pd + +from .models.traces import EventSummary, ProfileSummary + +_MERMAID_SAFE = re.compile(r"[^A-Za-z0-9_]+") + + +def _mermaid_node_id(value: str) -> str: + cleaned = _MERMAID_SAFE.sub("_", value).strip("_") + return cleaned or "node" + + +def _event_summary_dataframe(summary: EventSummary) -> pd.DataFrame: + df = pd.DataFrame( + { + "event_type": summary.event_type, + "count": summary.count, + "total_seconds": summary.total_seconds, + "avg_seconds": summary.avg_seconds, + "min_seconds": summary.min_seconds, + "max_seconds": summary.max_seconds, + } + ) + return df.sort_values("total_seconds", ascending=False).reset_index(drop=True) + + +def hardware_dataframe( + summary: ProfileSummary, *, on_critical_path: bool = False +) -> pd.DataFrame: + """Compute-time breakdown as a DataFrame, restricted to the CP if requested.""" + hw = ( + summary.critical_path.hardware_summary + if on_critical_path and summary.critical_path is not None + else summary.e2e_breakdown.hardware_summary + ) + return _event_summary_dataframe(hw) + + +def network_dataframe( + summary: ProfileSummary, *, on_critical_path: bool = False +) -> pd.DataFrame: + """Network-active-time breakdown as a DataFrame; restrict to CP if requested.""" + net = ( + summary.critical_path.network_summary + if on_critical_path and summary.critical_path is not None + else summary.e2e_breakdown.network_summary + ) + return _event_summary_dataframe(net) + + +def critical_path_dataframe(summary: ProfileSummary) -> pd.DataFrame: + """Per-node active vs wait on the critical path.""" + if summary.critical_path is None: + return pd.DataFrame(columns=["data_id", "active_seconds", "wait_seconds"]) + awb = summary.critical_path.active_wait_breakdown + return pd.DataFrame( + { + "data_id": awb.data_id, + "active_seconds": awb.active_seconds, + "wait_seconds": awb.wait_seconds, + } + ) + + +def to_mermaid(summary: ProfileSummary | dict[str, Any]) -> str: + """Lineage DAG as Mermaid ``graph TD`` source.""" + if isinstance(summary, dict): + summary = ProfileSummary.model_validate(summary) + lines = ["graph TD"] + seen: set[str] = set() + for edge in summary.lineage: + src = _mermaid_node_id(edge.source_data_id) + dst = _mermaid_node_id(edge.data_id) + if src not in seen: + lines.append(f' {src}["{edge.source_data_id}"]') + seen.add(src) + if dst not in seen: + lines.append(f' {dst}["{edge.data_id}"]') + seen.add(dst) + lines.append(f" {src} --> {dst}") + for data_id in summary.data_ids: + node = _mermaid_node_id(data_id) + if node not in seen: + lines.append(f' {node}["{data_id}"]') + seen.add(node) + return "\n".join(lines) diff --git a/sdk/src/flowmesh/resources/traces.py b/sdk/src/flowmesh/resources/traces.py new file mode 100644 index 0000000..6eb1699 --- /dev/null +++ b/sdk/src/flowmesh/resources/traces.py @@ -0,0 +1,84 @@ +"""Workflow trace resource — fetch raw rows or run the analyzer.""" + +import json +from collections.abc import AsyncIterator, Iterator +from enum import StrEnum + +import httpx + +from .._base_client import ( + _make_url, + _raise_for_stream_status, + _raise_for_stream_status_async, +) +from ..exceptions import FlowMeshConnectionError +from ..models.traces import ProfileSummary +from ._base import AsyncResource, SyncResource + + +class TraceType(StrEnum): + """Trace row type. Members serialize as their values.""" + + SPANS = "spans" + ASSETS = "assets" + LINEAGE = "lineage" + + +class Traces(SyncResource): + """Synchronous workflow trace operations.""" + + def fetch(self, workflow_id: str, trace_type: TraceType) -> Iterator[dict]: + """Yield JSONL rows for `spans`, `assets`, or `lineage`.""" + url = _make_url( + self._client.base_url, f"/traces/workflows/{workflow_id}/{trace_type}" + ) + try: + with self._client._http.stream("GET", url) as response: + _raise_for_stream_status(response, "GET") + for line in response.iter_lines(): + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + except httpx.ConnectError as exc: + raise FlowMeshConnectionError(f"Failed to connect to {url}: {exc}") + + def analyze(self, workflow_id: str) -> ProfileSummary: + """Run the trace analyzer and return a parsed `ProfileSummary`.""" + return ProfileSummary.model_validate( + self._client._request("GET", f"/traces/workflows/analyze/{workflow_id}") + ) + + +class AsyncTraces(AsyncResource): + """Asynchronous workflow trace operations.""" + + async def fetch( + self, workflow_id: str, trace_type: TraceType + ) -> AsyncIterator[dict]: + url = _make_url( + self._client.base_url, f"/traces/workflows/{workflow_id}/{trace_type}" + ) + try: + async with self._client._http.stream("GET", url) as response: + await _raise_for_stream_status_async(response, "GET") + async for line in response.aiter_lines(): + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + except httpx.ConnectError as exc: + raise FlowMeshConnectionError(f"Failed to connect to {url}: {exc}") + + async def analyze(self, workflow_id: str) -> ProfileSummary: + return ProfileSummary.model_validate( + await self._client._request( + "GET", f"/traces/workflows/analyze/{workflow_id}" + ) + ) diff --git a/src/server/governance/__init__.py b/src/server/governance/__init__.py new file mode 100644 index 0000000..da78de6 --- /dev/null +++ b/src/server/governance/__init__.py @@ -0,0 +1,28 @@ +from .analyzer import ( + ActiveWaitBreakdown, + AssetSummary, + CriticalPathSummary, + E2EBreakdown, + EventSummary, + LineageEdge, + ProfileSummary, + TaskTiming, + analyze, +) +from .spans import Span, SpanAttributes, SpanContext, SpanStatus + +__all__ = [ + "ActiveWaitBreakdown", + "AssetSummary", + "CriticalPathSummary", + "E2EBreakdown", + "EventSummary", + "LineageEdge", + "ProfileSummary", + "Span", + "SpanAttributes", + "SpanContext", + "SpanStatus", + "TaskTiming", + "analyze", +] diff --git a/src/server/governance/analyzer.py b/src/server/governance/analyzer.py new file mode 100644 index 0000000..5c7f6eb --- /dev/null +++ b/src/server/governance/analyzer.py @@ -0,0 +1,497 @@ +"""Trace analyzer over (spans, assets, lineage) JSONL rows. + +Per-data_id timing comes from the ``"task"`` root span; ``"dump to storage"`` +end_time is the data-ready timestamp; ``queuing_delay`` = +task.start - max(parent.dump_to_storage.end). +""" + +from collections import defaultdict +from collections.abc import Iterable +from datetime import datetime, timedelta +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from shared.schemas.governance import ( + READY_SPAN_NAME, + TASK_SPAN_NAME, + SpanType, +) + +from .spans import Span + + +class _ProfileBase(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class AssetSummary(_ProfileBase): + asset_guid: str + latest_data_id: str + latest_version: int + user_id: str + versions: int + created_at: str | None = None + + +class LineageEdge(_ProfileBase): + data_id: str + source_data_id: str + created_at: str | None = None + + +class EventSummary(_ProfileBase): + """Per-event-type duration aggregates as parallel lists. + + For ``compute`` spans, ``total_seconds[i]`` is the per-batch sum (so + parallel spans within a batch collapse). For ``network`` spans it is the + merged-interval wall-clock time (so overlapping reads/writes collapse). + """ + + event_type: list[str] + count: list[int] + total_seconds: list[float] + avg_seconds: list[float] + min_seconds: list[float] + max_seconds: list[float] + + +class E2EBreakdown(_ProfileBase): + hardware_summary: EventSummary + network_summary: EventSummary + workflow_duration_seconds: float + total_network_seconds: float + + +class ActiveWaitBreakdown(_ProfileBase): + data_id: list[str] + active_seconds: list[float] + wait_seconds: list[float] + + +class TaskTiming(_ProfileBase): + data_id: str + start_time: datetime + end_time: datetime + duration_seconds: float + queuing_delay_seconds: float + parent_data_ids: list[str] + blocking_parent_data_id: str | None = None + + +class CriticalPathSummary(_ProfileBase): + path: list[str] + critical_path_seconds: float + active_wait_breakdown: ActiveWaitBreakdown + hardware_summary: EventSummary + network_summary: EventSummary + total_network_seconds: float + + +class ProfileSummary(_ProfileBase): + workflow_id: str | None = None + event_count: int + data_ids: list[str] + assets: list[AssetSummary] + lineage: list[LineageEdge] + e2e_breakdown: E2EBreakdown + per_data_id: list[TaskTiming] + critical_path: CriticalPathSummary | None = None + + +def analyze( + spans: Iterable[dict[str, Any]], + assets: Iterable[dict[str, Any]], + lineage: Iterable[dict[str, Any]], + workflow_id: str | None = None, +) -> ProfileSummary: + """Build a :class:`ProfileSummary` from raw JSONL rows for a single workflow. + + ``spans`` rows are parsed via :class:`Span`; malformed entries are dropped. + ``assets`` and ``lineage`` rows are passed straight through as dicts. The + returned summary contains the asset rollup, full DAG edges, an end-to-end + breakdown, per-data_id timings (with queuing delay + blocking parent), + and a critical-path subset. + """ + parsed: list[Span] = [] + for raw in spans: + if not isinstance(raw, dict): + continue + try: + parsed.append(Span.parse_otel_json(raw)) + except (ValueError, TypeError): + continue + + asset_rows = [a for a in assets if isinstance(a, dict)] + lineage_rows = [le for le in lineage if isinstance(le, dict)] + + asset_summaries = _asset_summaries(asset_rows) + data_ids = sorted({did for s in parsed if (did := s.attributes.data_id)}) + dep_map = _dep_map(lineage_rows) + lineage_edges = [ + LineageEdge( + data_id=str(le.get("data_id") or ""), + source_data_id=str(le.get("source_data_id") or ""), + created_at=str(le.get("created_at") or "") or None, + ) + for le in lineage_rows + if le.get("data_id") and le.get("source_data_id") + ] + + grouped = _group_spans(parsed) + e2e_breakdown = _obtain_breakdown(grouped) + per_data_id = _per_data_id_timings(grouped, dep_map, data_ids) + critical_path = ( + _compute_critical_path(grouped, dep_map, per_data_id) if data_ids else None + ) + + return ProfileSummary( + workflow_id=workflow_id, + event_count=len(parsed), + data_ids=data_ids, + assets=asset_summaries, + lineage=lineage_edges, + e2e_breakdown=E2EBreakdown.model_validate(e2e_breakdown), + per_data_id=per_data_id, + critical_path=( + CriticalPathSummary.model_validate(critical_path) + if critical_path is not None + else None + ), + ) + + +def _asset_summaries(rows: list[dict[str, Any]]) -> list[AssetSummary]: + """Group asset rows by ``asset_guid`` and emit one summary per asset. + + Each summary points at the highest-version row (``latest_*``) and reports + the total version count. + """ + asset_versions: dict[str, list[dict[str, Any]]] = defaultdict(list) + for row in rows: + guid = str(row.get("asset_guid") or "") + if not guid: + continue + asset_versions[guid].append(row) + summaries: list[AssetSummary] = [] + for guid, items in asset_versions.items(): + items_sorted = sorted(items, key=lambda r: int(r.get("version") or 0)) + latest = items_sorted[-1] + summaries.append( + AssetSummary( + asset_guid=guid, + latest_data_id=str(latest.get("data_id") or ""), + latest_version=int(latest.get("version") or 0), + user_id=str(latest.get("user_id") or ""), + versions=len(items_sorted), + created_at=str(latest.get("created_at") or "") or None, + ) + ) + return summaries + + +def _dep_map(lineage_rows: list[dict[str, Any]]) -> dict[str, list[str]]: + """Build a ``data_id -> [source_data_id, ...]`` adjacency from lineage edges.""" + dep_map: dict[str, list[str]] = defaultdict(list) + for row in lineage_rows: + target = str(row.get("data_id") or "") + source = str(row.get("source_data_id") or "") + if target and source: + dep_map[target].append(source) + return dict(dep_map) + + +def _group_spans(spans: list[Span]) -> dict[str, list[Span]]: + """Bucket spans by ``attributes.data_id`` and sort each bucket by start time. + + Spans without a ``data_id`` (e.g. third-party telemetry) are dropped here. + """ + grouped: dict[str, list[Span]] = defaultdict(list) + for span in spans: + if data_id := span.attributes.data_id: + grouped[data_id].append(span) + for data_id, items in grouped.items(): + items.sort(key=lambda s: s.start_time) + return dict(grouped) + + +def _ready_finish(spans: list[Span]) -> datetime | None: + """Latest ``"dump to storage"`` span ``end_time`` — the data-ready boundary + for a task. ``None`` if no such span exists (e.g. failed task).""" + ready_times = [s.end_time for s in spans if s.name == READY_SPAN_NAME] + return max(ready_times) if ready_times else None + + +def _task_span(spans: list[Span]) -> Span | None: + """Pick the root ``"task"`` span, preferring one without a parent.""" + for span in spans: + if span.name == TASK_SPAN_NAME and span.parent_id is None: + return span + for span in spans: + if span.name == TASK_SPAN_NAME: + return span + return None + + +def _per_data_id_timings( + grouped: dict[str, list[Span]], + dep_map: dict[str, list[str]], + data_ids: list[str], +) -> list[TaskTiming]: + """Per-data_id start/end timestamps + queuing delay against parents. + + ``queuing_delay = task.start - max(parent.dump_to_storage.end)``. Falls + back to ``min(span.start) / max(span.end)`` when a data_id has no root + ``"task"`` span (e.g. merged children that only emit per-task markers). + """ + finish_ts: dict[str, datetime] = {} + for data_id, spans in grouped.items(): + ready = _ready_finish(spans) + if ready is not None: + finish_ts[data_id] = ready + + timings: list[TaskTiming] = [] + for data_id in data_ids: + spans = grouped.get(data_id) or [] + if not spans: + continue + task = _task_span(spans) + if task is not None: + start = task.start_time + end = task.end_time + else: + start = min(s.start_time for s in spans) + end = max(s.end_time for s in spans) + + parents = dep_map.get(data_id) or [] + eligible = [(p, finish_ts[p]) for p in parents if p in finish_ts] + if eligible: + blocking_parent, blocking_finish = max(eligible, key=lambda x: x[1]) + wait = max((start - blocking_finish).total_seconds(), 0.0) + else: + blocking_parent = None + wait = 0.0 + + timings.append( + TaskTiming( + data_id=data_id, + start_time=start, + end_time=end, + duration_seconds=(end - start).total_seconds(), + queuing_delay_seconds=wait, + parent_data_ids=parents.copy(), + blocking_parent_data_id=blocking_parent, + ) + ) + return timings + + +def _merge_intervals( + intervals: list[tuple[datetime, datetime]], +) -> list[tuple[datetime, datetime]]: + """Collapse overlapping ``(start, end)`` intervals so concurrent network + spans count once toward total active time.""" + if not intervals: + return [] + sorted_ivl = sorted(intervals, key=lambda x: x[0]) + merged: list[tuple[datetime, datetime]] = [] + cur_start, cur_end = sorted_ivl[0] + for start, end in sorted_ivl[1:]: + if start <= cur_end: + cur_end = max(cur_end, end) + else: + merged.append((cur_start, cur_end)) + cur_start, cur_end = start, end + merged.append((cur_start, cur_end)) + return merged + + +def _avg_min_max(values: list[float]) -> tuple[float, float, float]: + """``(avg, min, max)`` over ``values``; zeros when empty.""" + if not values: + return 0.0, 0.0, 0.0 + return sum(values) / len(values), min(values), max(values) + + +def _obtain_breakdown( + grouped: dict[str, list[Span]], +) -> dict[str, Any]: + """Aggregate per-event-type compute / network / wall stats over a span set. + + Compute totals are summed per ``batch_id`` then across batches (parallel + spans within a batch collapse). Network totals use merged-interval + wall-clock so concurrent reads/writes count once. The ``"task"`` root + span and ``MARKER`` kind spans are excluded from event-type aggregates; + ``"task"`` start/end bound ``workflow_duration_seconds``. + """ + by_type: dict[str, list[float]] = defaultdict(list) + by_type_batch: dict[str, dict[str, timedelta]] = defaultdict( + lambda: defaultdict(timedelta) + ) + network_intervals: list[tuple[datetime, datetime]] = [] + network_intervals_by_type: dict[str, list[tuple[datetime, datetime]]] = defaultdict( + list + ) + network_active_seconds: dict[str, list[float]] = defaultdict(list) + all_starts: list[datetime] = [] + all_ends: list[datetime] = [] + + for spans in grouped.values(): + for span in spans: + span_type = span.attributes.flowmesh_type + if span_type == SpanType.MARKER: + continue + if span.name == TASK_SPAN_NAME: + all_starts.append(span.start_time) + all_ends.append(span.end_time) + continue + + duration = span.duration_seconds + + if span_type == SpanType.NETWORK: + interval = (span.start_time, span.end_time) + network_intervals.append(interval) + network_intervals_by_type[span.name].append(interval) + network_active_seconds[span.name].append(duration) + elif span_type == SpanType.COMPUTE: + by_type[span.name].append(duration) + if batch_id := span.attributes.batch_id: + by_type_batch[span.name][batch_id] += span.end_time - ( + span.start_time + ) + + if all_starts and all_ends: + workflow_duration = max(all_ends) - min(all_starts) + else: + workflow_duration = timedelta(0) + total_network = sum( + (end - start for start, end in _merge_intervals(network_intervals)), + timedelta(0), + ) + + hw_event_types = list(by_type.keys()) + hardware_summary = { + "event_type": hw_event_types, + "count": [len(by_type[t]) for t in hw_event_types], + "total_seconds": [ + sum((d for d in by_type_batch[t].values()), timedelta(0)).total_seconds() + for t in hw_event_types + ], + "avg_seconds": [_avg_min_max(by_type[t])[0] for t in hw_event_types], + "min_seconds": [_avg_min_max(by_type[t])[1] for t in hw_event_types], + "max_seconds": [_avg_min_max(by_type[t])[2] for t in hw_event_types], + } + + net_event_types = list(network_intervals_by_type.keys()) + network_summary = { + "event_type": net_event_types, + "count": [len(network_active_seconds[t]) for t in net_event_types], + "total_seconds": [ + sum( + ( + end - start + for start, end in _merge_intervals(network_intervals_by_type[t]) + ), + timedelta(0), + ).total_seconds() + for t in net_event_types + ], + "avg_seconds": [ + _avg_min_max(network_active_seconds[t])[0] for t in net_event_types + ], + "min_seconds": [ + _avg_min_max(network_active_seconds[t])[1] for t in net_event_types + ], + "max_seconds": [ + _avg_min_max(network_active_seconds[t])[2] for t in net_event_types + ], + } + + return { + "hardware_summary": hardware_summary, + "network_summary": network_summary, + "workflow_duration_seconds": workflow_duration.total_seconds(), + "total_network_seconds": total_network.total_seconds(), + } + + +def _compute_critical_path( + grouped: dict[str, list[Span]], + dep_map: dict[str, list[str]], + per_data_id: list[TaskTiming], +) -> dict[str, Any] | None: + """Walk back from the latest-finishing data_id, picking the slowest parent + at each hop, to surface the bottleneck chain. + + ``critical_path_seconds`` is the sum of active + wait along the chain. + The CP-restricted breakdown also pulls in spans from any merge-parent + ``batch_id`` referenced by CP nodes — otherwise shared work (model load, + generation) emitted under the merge parent's data_id would be missed + when a merged-child branch is on the path. + """ + by_id: dict[str, TaskTiming] = {t.data_id: t for t in per_data_id} + finish_times: dict[str, datetime] = {} + for data_id, spans in grouped.items(): + ready = _ready_finish(spans) + if ready is not None: + finish_times[data_id] = ready + elif data_id in by_id: + finish_times[data_id] = by_id[data_id].end_time + + if not finish_times: + return None + + sink = max(finish_times, key=lambda k: finish_times[k]) + path_rev: list[str] = [sink] + cursor = sink + while parents := dep_map.get(cursor): + eligible = [(p, finish_times[p]) for p in parents if p in finish_times] + if not eligible: + break + latest_parent = max(eligible, key=lambda x: x[1])[0] + path_rev.append(latest_parent) + cursor = latest_parent + critical_path = list(reversed(path_rev)) + + actives: list[float] = [] + waits: list[float] = [] + cp_duration = timedelta(0) + for nid in critical_path: + timing = by_id.get(nid) + if timing is None: + actives.append(0.0) + waits.append(0.0) + continue + active = timing.duration_seconds + wait = timing.queuing_delay_seconds + cp_duration += timedelta(seconds=active + wait) + actives.append(active) + waits.append(wait) + + # Expand CP membership through merged-execution batch_ids so the breakdown + # captures shared work (model load, generation) emitted under the merge + # parent's data_id when a merged-child branch lands on the path. + cp_data_ids: set[str] = set(critical_path) + for nid in critical_path: + for span in grouped.get(nid, []): + batch_id = span.attributes.batch_id + if batch_id and batch_id != nid and batch_id in grouped: + cp_data_ids.add(batch_id) + + cp_breakdown = _obtain_breakdown( + {nid: grouped[nid] for nid in cp_data_ids if nid in grouped} + ) + cp_breakdown.pop("workflow_duration_seconds", None) + + return { + "path": critical_path, + "critical_path_seconds": cp_duration.total_seconds(), + "active_wait_breakdown": { + "data_id": critical_path, + "active_seconds": actives, + "wait_seconds": waits, + }, + "hardware_summary": cp_breakdown["hardware_summary"], + "network_summary": cp_breakdown["network_summary"], + "total_network_seconds": cp_breakdown["total_network_seconds"], + } diff --git a/src/server/governance/spans.py b/src/server/governance/spans.py new file mode 100644 index 0000000..efcf50d --- /dev/null +++ b/src/server/governance/spans.py @@ -0,0 +1,74 @@ +"""Server-side parser for OTel-shape span rows in ``spans.jsonl``. + +Producers (workers) emit one ``ReadableSpan.to_json()`` row per line; the +analyzer reads them as :class:`Span` instances. The shared wire-contract enum +:class:`shared.schemas.governance.SpanType` lives in +``src/shared/schemas/governance.py`` because workers also need it. +""" + +from datetime import datetime +from typing import Annotated, Any + +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field + +from shared.schemas.governance import SpanType + + +def _strip_hex_prefix(value: Any) -> Any: + """Trim the leading ``0x`` from OTel hex ids; pass non-strings through.""" + if isinstance(value, str) and value.startswith("0x"): + return value[2:] + return value + + +HexId = Annotated[str, BeforeValidator(_strip_hex_prefix)] +OptionalHexId = Annotated[str | None, BeforeValidator(_strip_hex_prefix)] + + +class SpanContext(BaseModel): + """The ``context`` sub-object of ``ReadableSpan.to_json()``.""" + + model_config = ConfigDict(extra="allow") + trace_id: HexId + span_id: HexId + + +class SpanStatus(BaseModel): + """The ``status`` sub-object of ``ReadableSpan.to_json()``.""" + + model_config = ConfigDict(extra="allow") + status_code: str = "UNSET" + description: str | None = None + + +class SpanAttributes(BaseModel): + """FlowMesh-required span attributes; arbitrary extras are preserved.""" + + model_config = ConfigDict(extra="allow", populate_by_name=True) + data_id: str | None = None + batch_id: str | None = None + flowmesh_type: SpanType | None = Field(default=None, alias="flowmesh.type") + + +class Span(BaseModel): + """Parsed OTel JSON span row; ids stripped of ``0x``, times as ``datetime``.""" + + model_config = ConfigDict(extra="allow") + + name: str + context: SpanContext + parent_id: OptionalHexId = None + start_time: datetime + end_time: datetime + status: SpanStatus = Field(default_factory=SpanStatus) + attributes: SpanAttributes = Field(default_factory=SpanAttributes) + + @property + def duration_seconds(self) -> float: + return (self.end_time - self.start_time).total_seconds() + + @classmethod + def parse_otel_json(cls, raw: str | dict[str, Any]) -> "Span": + if isinstance(raw, str): + return cls.model_validate_json(raw) + return cls.model_validate(raw) diff --git a/src/server/main.py b/src/server/main.py index 1b5a306..6c0f4d3 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -372,6 +372,7 @@ async def _lifespan(_: FastAPI): app.include_router(v1.results.router, prefix=v1_prefix) app.include_router(v1.ssh.router, prefix=v1_prefix) app.include_router(v1.system.router, prefix=v1_prefix) + app.include_router(v1.traces.router, prefix=v1_prefix) # Routers — supervisor (any node with worker management) if config.worker_management.enabled: diff --git a/src/server/routers/v1/__init__.py b/src/server/routers/v1/__init__.py index f36bde0..e597051 100644 --- a/src/server/routers/v1/__init__.py +++ b/src/server/routers/v1/__init__.py @@ -5,6 +5,7 @@ stack, system, tasks, + traces, workers, workflows, ) @@ -16,6 +17,7 @@ "stack", "system", "tasks", + "traces", "workers", "workflows", ] diff --git a/src/server/routers/v1/traces.py b/src/server/routers/v1/traces.py new file mode 100644 index 0000000..db5330d --- /dev/null +++ b/src/server/routers/v1/traces.py @@ -0,0 +1,115 @@ +"""Trace endpoints — per-task upload, workflow-level read + analyzer.""" + +from collections.abc import Iterable, Iterator +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi.responses import StreamingResponse + +from shared.utils.json import encode_jsonl_bytes, read_jsonl + +from ...app_state import get_results_dir, get_workflow_registry +from ...governance import ProfileSummary, analyze +from ...registries.workflow import WorkflowRegistry +from ...schemas.common import PathResponse +from ...schemas.result import result_file_path + +router = APIRouter(prefix="/traces", tags=["Traces"]) + +_TYPE_TO_FILENAME: dict[str, str] = { + "spans": "spans.jsonl", + "assets": "assets.jsonl", + "lineage": "lineage.jsonl", +} + + +def _logs_dir_for_task(results_dir: Path, task_id: str) -> Path: + """Per-task ``logs/`` directory holding the trace JSONL artifacts.""" + return result_file_path(results_dir, task_id).parent / "logs" + + +def _iter_workflow_jsonl( + results_dir: Path, task_ids: Iterable[str], filename: str +) -> Iterator[dict[str, Any]]: + for task_id in task_ids: + yield from read_jsonl(_logs_dir_for_task(results_dir, task_id) / filename) + + +async def _resolve_task_ids(workflow_id: str, registry: WorkflowRegistry) -> list[str]: + workflow = await registry.get_workflow_async(workflow_id) + if not workflow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workflow '{workflow_id}' not found", + ) + return workflow.task_ids + + +@router.get( + "/workflows/analyze/{workflow_id}", + summary="Run the trace analyzer; return ProfileSummary", + response_model=ProfileSummary, +) +async def analyze_workflow_trace( + workflow_id: str, + registry: WorkflowRegistry = Depends(get_workflow_registry), + results_dir: Path = Depends(get_results_dir), +) -> ProfileSummary: + task_ids = await _resolve_task_ids(workflow_id, registry) + spans = list(_iter_workflow_jsonl(results_dir, task_ids, "spans.jsonl")) + assets = list(_iter_workflow_jsonl(results_dir, task_ids, "assets.jsonl")) + lineage = list(_iter_workflow_jsonl(results_dir, task_ids, "lineage.jsonl")) + return analyze(spans, assets, lineage, workflow_id=workflow_id) + + +@router.get( + "/workflows/{workflow_id}/{trace_type}", + summary="Stream JSONL rows (spans / assets / lineage)", +) +async def get_workflow_trace( + workflow_id: str, + trace_type: str, + registry: WorkflowRegistry = Depends(get_workflow_registry), + results_dir: Path = Depends(get_results_dir), +) -> StreamingResponse: + filename = _TYPE_TO_FILENAME.get(trace_type) + if filename is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"unknown type '{trace_type}'; expected spans, assets, or lineage", + ) + task_ids = await _resolve_task_ids(workflow_id, registry) + return StreamingResponse( + encode_jsonl_bytes(_iter_workflow_jsonl(results_dir, task_ids, filename)), + media_type="application/x-ndjson", + ) + + +@router.post( + "/tasks/{task_id}/{trace_type}", + summary="Upload a per-task trace JSONL file (spans / assets / lineage)", +) +async def upload_task_trace( + task_id: str, + trace_type: str, + file: UploadFile = File(...), + results_dir: Path = Depends(get_results_dir), +) -> PathResponse: + filename = _TYPE_TO_FILENAME.get(trace_type) + if filename is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"unknown type '{trace_type}'; expected spans, assets, or lineage", + ) + target_path = _logs_dir_for_task(results_dir, task_id) / filename + target_path.parent.mkdir(parents=True, exist_ok=True) + try: + with target_path.open("wb") as out: + out.write(await file.read()) + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to store trace: {exc}", + ) from exc + return PathResponse(ok=True, path=target_path.as_posix()) diff --git a/src/server/utils/time.py b/src/server/utils/time.py index a8ffc29..8f530f8 100644 --- a/src/server/utils/time.py +++ b/src/server/utils/time.py @@ -1,19 +1,15 @@ -import datetime import time -from shared.utils.time import now_iso +from shared.utils.time import now_iso, parse_iso_datetime def parse_iso_ts(value: str | None) -> float: - if not value: - return time.time() + """ISO 8601 → Unix timestamp; ``time.time()`` on missing / malformed.""" try: - v = value - if v.endswith("Z"): - v = v[:-1] + "+00:00" - return datetime.datetime.fromisoformat(v).timestamp() - except Exception: + dt = parse_iso_datetime(value) + except ValueError: return time.time() + return dt.timestamp() if dt else time.time() __all__ = ["now_iso", "parse_iso_ts"] diff --git a/src/shared/schemas/governance.py b/src/shared/schemas/governance.py new file mode 100644 index 0000000..7a4d853 --- /dev/null +++ b/src/shared/schemas/governance.py @@ -0,0 +1,24 @@ +"""Wire-contract names + type enum for the OTel-shape span rows in ``spans.jsonl``. + +Both the worker (producer, ``src/worker/executors/mixins/governance.py``) and +the server analyzer (consumer, ``src/server/governance/analyzer.py``) read +these. Pydantic parsing models for span rows live server-side at +``src/server/governance/spans.py``. +""" + +from enum import StrEnum + + +class SpanType(StrEnum): + """Producer-side type in ``attributes["flowmesh.type"]``.""" + + COMPUTE = "compute" + NETWORK = "network" + MARKER = "marker" + + +# Wire-contract span names. The worker emits these (`_task_span` opens a +# ``"task"`` span; ``_record_output`` opens a ``"dump to storage"`` span); +# the analyzer reads them to identify the root span and the data-ready boundary. +TASK_SPAN_NAME = "task" +READY_SPAN_NAME = "dump to storage" diff --git a/src/shared/tasks/specs/common.py b/src/shared/tasks/specs/common.py index 825b710..d6cb61d 100644 --- a/src/shared/tasks/specs/common.py +++ b/src/shared/tasks/specs/common.py @@ -69,7 +69,6 @@ class TaskSpecStrictBase(StrictBaseModel): output: OutputSpec | None = None dependsOn: list[str] | None = None condition: ConditionSpec | None = None - governance: dict[str, Any] | None = None shard: ShardSpec | None = None # Server-injected stage context (reserve the user-facing key `_upstreamResults`) @@ -97,7 +96,6 @@ class TaskSpecTemplateBase(TemplateBaseModel): output: OutputSpecTemplate | None = None dependsOn: list[str] | None = None condition: ConditionSpec | None = None - governance: dict[str, Any] | None = None shard: ShardSpecTemplate | None = None upstreamResults: dict[str, Any] | None = Field( diff --git a/src/shared/utils/json.py b/src/shared/utils/json.py index 799a6f7..56cb357 100644 --- a/src/shared/utils/json.py +++ b/src/shared/utils/json.py @@ -1,5 +1,8 @@ import datetime +import json import uuid +from collections.abc import AsyncIterator, Iterable, Iterator +from pathlib import Path from typing import Any @@ -116,3 +119,43 @@ def normalize_numbers(value: Any) -> Any: if isinstance(value, float) and value.is_integer(): return int(value) return value + + +def parse_jsonl_lines(lines: Iterable[str]) -> Iterator[dict[str, Any]]: + """Yield decoded dict rows from text lines; skip empty / malformed.""" + for line in lines: + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + + +async def aparse_jsonl_lines( + lines: AsyncIterator[str], +) -> AsyncIterator[dict[str, Any]]: + """Async variant of :func:`parse_jsonl_lines` for streamed responses.""" + async for line in lines: + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + + +def read_jsonl(path: Path) -> Iterator[dict[str, Any]]: + """Read a JSONL file and yield decoded dict rows. Missing file → empty.""" + if not path.exists() or not path.is_file(): + return + with path.open(encoding="utf-8") as fh: + yield from parse_jsonl_lines(fh) + + +def encode_jsonl_bytes(rows: Iterable[dict[str, Any]]) -> Iterator[bytes]: + """Encode dict rows as JSONL bytes: one row per line, trailing newline.""" + for row in rows: + yield (json.dumps(row, ensure_ascii=False) + "\n").encode("utf-8") diff --git a/src/shared/utils/time.py b/src/shared/utils/time.py index 430de62..37447d6 100644 --- a/src/shared/utils/time.py +++ b/src/shared/utils/time.py @@ -3,3 +3,10 @@ def now_iso() -> str: return datetime.datetime.now(datetime.UTC).isoformat() + + +def parse_iso_datetime(value: str | None) -> datetime.datetime | None: + """Parse ISO 8601 → ``datetime``; ``None`` if empty, raises on malformed.""" + if not value: + return None + return datetime.datetime.fromisoformat(value.replace("Z", "+00:00")) diff --git a/src/worker/executors/data_retrieval_executor.py b/src/worker/executors/data_retrieval_executor.py index 62a7296..6fed564 100644 --- a/src/worker/executors/data_retrieval_executor.py +++ b/src/worker/executors/data_retrieval_executor.py @@ -16,7 +16,11 @@ from ..utils.serialization import serialize_dataframe from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import DataMixin -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import ( + artifact_ref, + maybe_upload_artifacts, + maybe_upload_traces, +) from .utils.graph_templates import _render_template, _resolve_columns logger = logging.getLogger(__name__) @@ -28,37 +32,38 @@ class DataRetrievalExecutor(DataMixin, Executor): def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: spec = self.require_spec(task, DataRetrievalSpecStrict) task_id = task.task_id - data_cfg = spec.data - if not isinstance(data_cfg, dict): - raise ExecutionError("spec.data must be a mapping for data_retrieval.") - retrieval_type = data_cfg.get("type") - if retrieval_type not in {"sql", "s3"}: - raise ExecutionError( - "spec.data.type must be either 'sql' or 's3' for data_retrieval." - ) - context = spec.upstreamResults or {} - - if retrieval_type == "sql": - result = self._run_sql(data_cfg, context) - elif retrieval_type == "s3": - result = self._run_s3(data_cfg, context, out_dir) - else: - raise ExecutionError( - f"Unsupported data_retrieval type: {retrieval_type!r}." - ) - - maybe_upload_artifacts(task, out_dir, logger=logger) + with self._task_span( + task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + data_cfg = spec.data + if not isinstance(data_cfg, dict): + raise ExecutionError("spec.data must be a mapping for data_retrieval.") + retrieval_type = data_cfg.get("type") + if retrieval_type not in {"sql", "s3"}: + raise ExecutionError( + "spec.data.type must be either 'sql' or 's3' for data_retrieval." + ) + context = spec.upstreamResults or {} + + if retrieval_type == "sql": + result = self._run_sql(data_cfg, context) + elif retrieval_type == "s3": + result = self._run_s3(data_cfg, context, out_dir) + else: + raise ExecutionError( + f"Unsupported data_retrieval type: {retrieval_type!r}." + ) - if governance_spec := spec.governance: deps = self._extract_source_data_ids(spec) dependencies_by_task = {task_id: deps} self._dump_to_governance( - governance_spec=governance_spec, task_id=task_id, result=result, dependencies_by_task=dependencies_by_task, ) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) return result def _run_sql( diff --git a/src/worker/executors/diffusers_executor.py b/src/worker/executors/diffusers_executor.py index fa2b070..1216acb 100644 --- a/src/worker/executors/diffusers_executor.py +++ b/src/worker/executors/diffusers_executor.py @@ -21,7 +21,11 @@ from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import DataMixin -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import ( + artifact_ref, + maybe_upload_artifacts, + maybe_upload_traces, +) try: import torch @@ -244,6 +248,20 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: configure_hf_library_logging() spec = self.require_spec(task, DiffusionSpecStrict) task_id = task.task_id.strip() + with self._task_span( + task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + response = self._run_inner(spec, task_id, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return response + + def _run_inner( + self, + spec: DiffusionSpecStrict, + task_id: str, + out_dir: Path, + ) -> dict[str, Any]: self._ensure_pipeline(spec) assert self._pipe is not None @@ -333,15 +351,11 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: "images": generated_images, } - maybe_upload_artifacts(task, out_dir, logger=logger) - - if governance_spec := spec.governance: - self._dump_to_governance( - governance_spec=governance_spec, - task_id=task_id, - result=response, - dependencies_by_task=dependencies_by_task, - ) + self._dump_to_governance( + task_id=task_id, + result=response, + dependencies_by_task=dependencies_by_task, + ) return response diff --git a/src/worker/executors/echo_executor.py b/src/worker/executors/echo_executor.py index 1126fc1..480938a 100644 --- a/src/worker/executors/echo_executor.py +++ b/src/worker/executors/echo_executor.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Any @@ -5,8 +6,11 @@ from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import DataMixin +from .utils.checkpoints import maybe_upload_traces from .utils.graph_templates import _evaluate_expr +logger = logging.getLogger(__name__) + type EchoItem = str | dict[str, str] @@ -54,37 +58,41 @@ def _resolve_item(self, item: EchoItem, context: dict[str, Any]) -> Any: def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: spec = self.require_spec(task, EchoSpecStrict) task_id = task.task_id.strip() - data_cfg = spec.data - context = spec.upstreamResults or {} + with self._task_span( + task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + data_cfg = spec.data + context = spec.upstreamResults or {} - if not isinstance(data_cfg, dict): - raise ExecutionError("echo executor requires spec.data to be a mapping") - items_cfg = data_cfg.get("items") - if not isinstance(items_cfg, list): - raise ExecutionError("echo executor requires spec.data.items to be a list") - if not isinstance(context, dict): - raise ExecutionError( - "echo executor requires spec._upstreamResults to be a mapping" - ) + if not isinstance(data_cfg, dict): + raise ExecutionError("echo executor requires spec.data to be a mapping") + items_cfg = data_cfg.get("items") + if not isinstance(items_cfg, list): + raise ExecutionError( + "echo executor requires spec.data.items to be a list" + ) + if not isinstance(context, dict): + raise ExecutionError( + "echo executor requires spec._upstreamResults to be a mapping" + ) - merged_items: list[dict[str, Any]] = [] - for item in items_cfg: - resolved = self._resolve_item(item, context) - self._append_outputs(merged_items, resolved) + merged_items: list[dict[str, Any]] = [] + for item in items_cfg: + resolved = self._resolve_item(item, context) + self._append_outputs(merged_items, resolved) - payload: dict[str, Any] = { - "ok": True, - "items": merged_items, - "count": len(merged_items), - } - deps = self._extract_source_data_ids(spec) - dependencies_by_task = {task_id: deps} + payload: dict[str, Any] = { + "ok": True, + "items": merged_items, + "count": len(merged_items), + } + deps = self._extract_source_data_ids(spec) + dependencies_by_task = {task_id: deps} - if governance_spec := spec.governance: self._dump_to_governance( - governance_spec=governance_spec, task_id=task_id, result=payload, dependencies_by_task=dependencies_by_task, ) + maybe_upload_traces(task, out_dir, logger=logger) return payload diff --git a/src/worker/executors/mixins/_otel.py b/src/worker/executors/mixins/_otel.py new file mode 100644 index 0000000..b3ddb0d --- /dev/null +++ b/src/worker/executors/mixins/_otel.py @@ -0,0 +1,164 @@ +"""OpenTelemetry tracing wiring for worker executors. + +Sets up a single process-wide ``TracerProvider`` with a JSONL exporter that +appends ``ReadableSpan.to_json()`` to ``/logs/spans.jsonl`` +for whichever task is currently executing. The current path is held in a +module-level slot updated by ``_task_span`` on enter / exit; the worker is +single-threaded for executor work so there's no contention. + +The ``trace_id`` is pinned to the workflow id via a custom ``IdGenerator`` +that reads the active workflow id from a ``ContextVar`` populated by +``_task_span``. Sub-spans inherit the trace id from the OTel parent context +automatically. +""" + +import re +import threading +from collections.abc import Callable, Iterator, Sequence +from contextlib import contextmanager +from contextvars import ContextVar +from pathlib import Path +from typing import Any + +from opentelemetry import trace +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export import ( + SimpleSpanProcessor, + SpanExporter, + SpanExportResult, +) +from opentelemetry.sdk.trace.id_generator import IdGenerator, RandomIdGenerator + +from shared.schemas.governance import SpanType +from shared.utils.ids import PREFIX_WORKFLOW + +_HEX_ONLY = re.compile(r"[^0-9a-f]") +_TRACER_NAME = "flowmesh.worker" +_SERVICE_NAME = "flowmesh-worker" + +_workflow_id_var: ContextVar[str | None] = ContextVar( + "flowmesh_workflow_id", default=None +) +_lock = threading.Lock() +_current_spans_path: Path | None = None + + +def workflow_to_trace_id_int(workflow_id: str) -> int: + """Stable 128-bit trace id derived from the workflow id. + + Strips the ``wfl-`` prefix before hex extraction so the prefix's ``f`` + doesn't shift the bit pattern. + """ + body = workflow_id.lower().removeprefix(f"{PREFIX_WORKFLOW}-") + hex_only = _HEX_ONLY.sub("", body) + if not hex_only: + return 0 + return int(hex_only.zfill(32)[:32], 16) + + +class _FlowMeshIdGenerator(IdGenerator): + """Pin trace_id to the active workflow id; random span_ids.""" + + def __init__(self) -> None: + self._fallback = RandomIdGenerator() + + def generate_span_id(self) -> int: + return self._fallback.generate_span_id() + + def generate_trace_id(self) -> int: + workflow_id = _workflow_id_var.get() + if workflow_id: + value = workflow_to_trace_id_int(workflow_id) + if value != 0: + return value + return self._fallback.generate_trace_id() + + +class _JSONLSpanExporter(SpanExporter): + """Append each completed span to the active task's spans.jsonl file.""" + + def __init__(self, path_provider: Callable[[], Path | None]) -> None: + self._path_provider = path_provider + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + path = self._path_provider() + if path is None: + return SpanExportResult.SUCCESS + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as fh: + for span in spans: + fh.write(span.to_json(indent=None) + "\n") + return SpanExportResult.SUCCESS + + def shutdown(self) -> None: + return None + + +def _resolve_path() -> Path | None: + return _current_spans_path + + +_PROVIDER_INITIALIZED = False + + +def _ensure_tracer_provider() -> None: + global _PROVIDER_INITIALIZED + if _PROVIDER_INITIALIZED: + return + with _lock: + if _PROVIDER_INITIALIZED: + return + provider = TracerProvider( + resource=Resource.create({"service.name": _SERVICE_NAME}), + id_generator=_FlowMeshIdGenerator(), + ) + provider.add_span_processor( + SimpleSpanProcessor(_JSONLSpanExporter(_resolve_path)) + ) + trace.set_tracer_provider(provider) + _PROVIDER_INITIALIZED = True + + +def get_tracer(): + _ensure_tracer_provider() + return trace.get_tracer(_TRACER_NAME) + + +def _set_active_spans_path(path: Path | None) -> None: + global _current_spans_path + _current_spans_path = path + + +@contextmanager +def task_trace_context(workflow_id: str, spans_path: Path) -> Iterator[None]: + """Bind trace_id and span exporter destination for the duration of a task. + + Pins trace_id derivation to ``workflow_id`` and routes the JSONL exporter + to ``spans_path`` while the block is active. Restores the previous state + on exit. + """ + token = _workflow_id_var.set(workflow_id) + _set_active_spans_path(spans_path) + try: + yield + finally: + _set_active_spans_path(None) + _workflow_id_var.reset(token) + + +def attributes_with_type( + span_type: SpanType, + *, + data_id: str | None, + extra: dict[str, Any] | None = None, +) -> dict[str, Any]: + attrs: dict[str, Any] = {"flowmesh.type": span_type.value} + if data_id is not None: + attrs["data_id"] = data_id + if extra: + for key, value in extra.items(): + if value is None: + continue + attrs[key] = value + return attrs diff --git a/src/worker/executors/mixins/data.py b/src/worker/executors/mixins/data.py index 49264e9..cf6c694 100644 --- a/src/worker/executors/mixins/data.py +++ b/src/worker/executors/mixins/data.py @@ -1,19 +1,10 @@ -"""Data mixin helpers used by executors. - -This module contains small utilities encapsulated in the ``DataMixin`` -class that help executors interact with data sources (via connectors), -normalize results into Pandas DataFrames. Only docstrings and explanatory comments -live here — the implementation delegates actual I/O to connector objects obtained via -``get_connector_from_spec`` which provide the runtime behavior (``execute``, -``get_schema``, etc.). -""" +"""Worker mixin: data prep helpers (prompts, images, params, dataset shards).""" import copy import datetime import io import logging from collections.abc import Mapping, Sequence -from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path from typing import Any, cast @@ -59,6 +50,12 @@ class InferenceEntry: class DataMixin(GovernanceMixin): + """Data prep helpers (prompts, images, params, dataset shards). + + Inherits :class:`GovernanceMixin` so every data-prep executor also gets + the trace + lineage emission surface (``_task_span`` / ``_span`` / + ``_log_event`` / ``_record_output`` / ``_dump_to_governance``). + """ _TEMPLATE_TYPE_MAP: dict[str, type] = { "str": str, @@ -77,11 +74,6 @@ class DataMixin(GovernanceMixin): "timestamp": datetime.datetime, } - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._upstream_deps_cache: dict[str, dict[str, Any]] = {} - self.io_executor = ThreadPoolExecutor(max_workers=32) - @classmethod def _resolve_param_type(cls, type_spec: str) -> type | None: normalized = type_spec.strip().lower() @@ -314,20 +306,6 @@ def _spec_inference_cfg(spec: TaskSpecStrictBase) -> dict[str, Any]: raise ExecutionError("spec.inference must be a mapping.") return inference - @staticmethod - def _spec_upstream_results(spec: TaskSpecStrictBase) -> dict[str, Any]: - context = spec.upstreamResults or {} - if not isinstance(context, dict): - raise ExecutionError("spec._upstreamResults must be a mapping.") - return context - - @staticmethod - def _spec_governance_cfg(spec: TaskSpecStrictBase) -> dict[str, Any]: - governance = spec.governance or {} - if not isinstance(governance, dict): - raise ExecutionError("spec.governance must be a mapping.") - return governance - def _collect_prompts_for_spec( self, spec: TaskSpecStrictBase, task_id: str, fetch_images: bool = False ) -> InferenceEntry: @@ -540,29 +518,23 @@ def _collect_prompts_for_spec( ) metadata_raw = list(raw_meta) elif dtype == "graph_template": - upstream_results = self._fetch_upstream_results_from_storage(spec) + upstream_results = self._spec_upstream_results(spec) logger.debug( "Task %s graph_template upstream keys: %s", task_id, list(upstream_results.keys()), ) - self._log_event(data_id=task_id, event_type="upstream fetch") - prompts = build_prompts_from_graph_template(data, spec) - self._log_event( - data_id=task_id, - event_type="build prompt from graph template", - event_data=f"Building prompts from graph_template for task {task_id}", - ) + with self._span("build prompt from graph template", data_id=task_id): + prompts = build_prompts_from_graph_template(data, spec) template_cfg = data.get("template") or {} append_system_prompt = bool(template_cfg.get("append_system_prompt", False)) elif dtype == "dataframe": - upstream_results = self._fetch_upstream_results_from_storage(spec) + upstream_results = self._spec_upstream_results(spec) logger.debug( "Task %s dataframe upstream keys: %s", task_id, list(upstream_results.keys()), ) - self._log_event(data_id=task_id, event_type="upstream fetch") df_columns_cfg = data.get("columns") if df_columns_cfg is None: raise ExecutionError( @@ -750,67 +722,6 @@ def _populate_table( payload["items"] = grouped_items return payload - def _fetch_upstream_results_from_storage( - self, spec: TaskSpecStrictBase - ) -> dict[str, Any]: - """ - Fetch upstream results from GovernanceRelay. - - For each upstream result reference, fetches from GovernanceRelay, - verifies integrity against the server copy. Appends retrieval timestamp - to the event list as metadata. - - Args: - spec: The task specification that may contain upstream results - - Returns: - Dict of upstream results (empty if none found) - """ - upstream_refs = self._spec_upstream_results(spec) - if not upstream_refs: - return {} - - governance_spec = self._spec_governance_cfg(spec) - if not governance_spec: - logger.info("Governance not configured; returning upstream results as-is") - return upstream_refs - - task_ids_to_fetch: set[str] = set( - upstream_spec["task_id"] for upstream_spec in upstream_refs.values() - ) - if len(task_ids_to_fetch) == 1: - task_id = task_ids_to_fetch.pop() - self._upstream_deps_cache[task_id] = self._fetch_data( - task_id, governance_spec - ) - else: - logger.info( - "Fetching %d upstream results in parallel", - len(task_ids_to_fetch), - ) - future_map = { - self.io_executor.submit( - self._fetch_data, task_id, governance_spec - ): task_id - for task_id in task_ids_to_fetch - } - for future in as_completed(future_map): - task_id = future_map[future] - try: - self._upstream_deps_cache[task_id] = future.result() - except Exception as exc: - raise ExecutionError( - f"Failed to fetch upstream result for task {task_id}: {exc}" - ) from exc - - fetched_results = { - graph_node_name: upstream_spec - | {"result": self._upstream_deps_cache[upstream_spec["task_id"]]} - for graph_node_name, upstream_spec in upstream_refs.items() - } - - return fetched_results - def _maybe_apply_dataset_shard(self, dataset, spec: TaskSpecStrictBase): shard_cfg = spec.shard if shard_cfg is None: @@ -824,105 +735,3 @@ def _maybe_apply_dataset_shard(self, dataset, spec: TaskSpecStrictBase): return dataset.shard(num_shards=total, index=index, contiguous=contiguous) except Exception as exc: raise ExecutionError(f"Failed to shard dataset ({index}/{total}): {exc}") - - def _extract_source_data_ids(self, spec: TaskSpecStrictBase) -> list[str]: - """Extract upstream task/data IDs from _upstreamResults for governance.""" - upstream_refs = self._spec_upstream_results(spec) - ids: list[str] = [] - for upstream in upstream_refs.values(): - if not isinstance(upstream, dict): - continue - candidate = upstream.get("task_id") or upstream.get("data_id") - if candidate: - ids.append(str(candidate)) - # Keep order but drop duplicates - seen: set[str] = set() - deduped: list[str] = [] - for ident in ids: - if ident in seen: - continue - seen.add(ident) - deduped.append(ident) - return deduped - - def _dump_to_governance( - self, - governance_spec: dict[str, Any], - task_id: str, - result: dict[str, Any], - dependencies_by_task: dict[str, list[str]], - ) -> None: - """ - Dump execution result and metadata to GovernanceRelay. - - Keeps the object slim by excluding full nested structures where possible. - Logs the size of the dumped object. - - Args: - governance_spec: GovernanceRelay specification dictionary - task_id: Task identifier - result: Execution result to dump - dependencies_by_task: Mapping of task_id -> list of source data dependencies - """ - - parent_deps = dependencies_by_task.get(task_id, []) - parent_events = self._events_for([task_id, *parent_deps]) - children_payload = result.get("children", {}) - - collection_jobs: list[dict[str, Any]] = [ - { - "task_id": task_id, - "result": result, - "deps": parent_deps, - "events": parent_events, - "is_parent": True, - } - ] - for child_id, child_result in children_payload.items(): - child_deps = dependencies_by_task.get(child_id, []) - child_events = self._events_for([child_id, *child_deps]) - collection_jobs.append( - { - "task_id": child_id, - "result": child_result, - "deps": child_deps, - "events": child_events, - "is_parent": False, - } - ) - - if len(collection_jobs) == 1: - job = collection_jobs[0] - self._write_data( - data_id=job["task_id"], - data=job["result"], - source_data_ids=job["deps"], - governance_spec=governance_spec, - events=job["events"], - ) - else: - logger.info( - "Writing data for %d merged tasks in parallel", - len(collection_jobs), - ) - future_map = { - self.io_executor.submit( - self._write_data, - data_id=job["task_id"], - data=job["result"], - source_data_ids=job["deps"], - governance_spec=governance_spec, - events=job["events"], - ): job - for job in collection_jobs - } - for future in as_completed(future_map): - job = future_map[future] - try: - future.result() - except Exception as exc: - if job["is_parent"]: - raise - raise ExecutionError( - f"Failed to write merged child task {job['task_id']}: {exc}" - ) from exc diff --git a/src/worker/executors/mixins/governance.py b/src/worker/executors/mixins/governance.py index 628f751..30adfcc 100644 --- a/src/worker/executors/mixins/governance.py +++ b/src/worker/executors/mixins/governance.py @@ -1,357 +1,295 @@ +"""Worker mixin: OTel span emission + asset/lineage JSONL row writes.""" + +import contextvars import json import logging -import tempfile import threading -from collections import defaultdict +import uuid +from collections.abc import Iterator, Sequence +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from contextlib import contextmanager from pathlib import Path -from threading import Lock from typing import Any -import requests +from opentelemetry.trace import Span as OTelSpan -from shared.utils.json import dedup_json, restore_json +from shared.schemas.governance import ( + READY_SPAN_NAME, + TASK_SPAN_NAME, + SpanType, +) +from shared.tasks.specs import TaskSpecStrictBase from shared.utils.time import now_iso from ..base_executor import ExecutionError +from ._otel import attributes_with_type, get_tracer, task_trace_context logger = logging.getLogger(__name__) class GovernanceMixin: - """ - Mixin for governance-related operations in FlowMesh. - - This class provides functionality for tracking events, caching governance data, - and interfacing with external governance APIs for data read/write operations. - It maintains an event log for audit trails and implements caching mechanisms - to optimize repeated data fetch operations. - - Attributes: - _event_log: Thread-safe dictionary storing event logs keyed by data_id - _current_batch_id: Current batch identifier for grouping related operations - _event_lock: Threading lock for synchronizing event log access - _cache_dir: Directory path for caching governance API responses - """ + """OTel span emission + asset / lineage JSONL row writes.""" def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._task_id: str | None = None - self._event_log: dict[str, list[dict[str, Any]]] = defaultdict(list) + self._task_out_dir: Path | None = None self._current_batch_id: str | None = None + self._task_owner_id: str = "" self._event_lock = threading.Lock() - self._cache_dir = Path(tempfile.gettempdir()) / "flowmesh_governance_cache" - self._cache_dir.mkdir(parents=True, exist_ok=True) - self._cache_dir_lock = Lock() - - # events related methods - def _clear_events(self) -> None: - """ - Clear all events from the event log. - - Thread-safe operation that removes all logged events from memory. - This is typically called when starting a new batch of operations. - """ - with self._event_lock: - self._event_log.clear() + self.io_executor = ThreadPoolExecutor(max_workers=32) + + def _submit_in_context(self, fn: Any, *args: Any, **kwargs: Any) -> Future[Any]: + """``io_executor.submit`` that carries the caller's ContextVars across.""" + ctx = contextvars.copy_context() + return self.io_executor.submit(ctx.run, fn, *args, **kwargs) + + # ------------------------------------------------------------------ # + # Span emission — context managers driven by the OTel SDK # + # ------------------------------------------------------------------ # + @contextmanager + def _task_span( + self, + task_id: str, + workflow_id: str, + out_dir: Path, + *, + owner_id: str = "", + ) -> Iterator[OTelSpan]: + """Root span for a task — wraps the executor's ``run()`` body.""" + self._task_id = task_id + self._current_batch_id = task_id + self._task_out_dir = Path(out_dir) + self._task_owner_id = owner_id + spans_path = self._lineage_dir() / "spans.jsonl" + spans_path.parent.mkdir(parents=True, exist_ok=True) + with task_trace_context(workflow_id, spans_path): + with get_tracer().start_as_current_span( + TASK_SPAN_NAME, + attributes=attributes_with_type( + SpanType.COMPUTE, + data_id=task_id, + extra={ + "batch_id": task_id, + "workflow_id": workflow_id, + "user_id": owner_id, + "executor.name": getattr(self, "name", None), + }, + ), + ) as span: + yield span + + @contextmanager + def _span( + self, + name: str, + *, + span_type: SpanType = SpanType.COMPUTE, + data_id: str | None = None, + attributes: dict[str, Any] | None = None, + ) -> Iterator[OTelSpan]: + """Child span recording start at __enter__, end at __exit__.""" + attrs = attributes_with_type( + span_type, + data_id=data_id if data_id is not None else self._task_id, + extra={"batch_id": self._current_batch_id, **(attributes or {})}, + ) + with get_tracer().start_as_current_span(name, attributes=attrs) as span: + yield span def _log_event( self, - data_id: str = "", - event_type: str = "", - event_data: str = "", - timestamp: str | None = None, + name: str, + *, + span_type: SpanType = SpanType.MARKER, + data_id: str | None = None, + attributes: dict[str, Any] | None = None, ) -> None: - """ - Log an event to the event log for tracking and audit purposes. - - Args: - data_id: Identifier for the data associated with this event. - If empty, uses the current task_id (must be set). - event_type: Type/category of the event (required). - event_data: Additional data or context for the event. - timestamp: ISO format timestamp for the event. If None, uses current - UTC time. - - Raises: - ExecutionError: If event_type is not provided. - AssertionError: If data_id is not provided and task_id is not set. - - The event is stored thread-safely in the event log with the current batch_id. - """ - if not data_id: - assert ( - self._task_id is not None - ), "data_id must be provided if task_id is not set" - data_id = self._task_id - if not event_type: - raise ExecutionError("event_type must be provided") - ts_value = timestamp or now_iso() - event_entry = { - "event_type": event_type, - "event_data": event_data, - "timestamp": ts_value, - "batch_id": self._current_batch_id, - } - with self._event_lock: - self._event_log[data_id].append(event_entry) - logger.debug("Logged event for data_id=%s: %s", data_id, event_entry) - - def _get_events(self) -> dict[str, list[dict[str, Any]]]: - """ - Get a copy of all events in the event log. - - Returns: - A dictionary containing all logged events, keyed by data_id. - Each value is a list of event entries with event_type, event_data, - timestamp, and batch_id information. - """ - with self._event_lock: - return dict(self._event_log) - - def _events_for( - self, data_ids: list[str] | set[str] - ) -> dict[str, list[dict[str, Any]]]: - """ - Return a filtered view of the event log for the given data ids. - - Args: - data_ids: List or set of data identifiers to filter events by. - If empty or None, returns an empty dictionary. - - Returns: - A dictionary containing only the events for the specified data_ids, - with the same structure as the full event log. - """ - wanted = {str(x) for x in (data_ids or [])} - if not wanted: - return {} - with self._event_lock: - return {k: v for k, v in self._event_log.items() if k in wanted} - - def _parse_spec(self, governance_spec: dict[str, Any]) -> tuple[str, str, str]: - """ - Parse and validate governance specification dictionary. - - Args: - governance_spec: Dictionary containing governance configuration with - required fields: 'url', 'user_id', and 'trace_id'. - - Returns: - A tuple of (governance_url, user_id, trace_id) extracted from the spec. - - Raises: - ExecutionError: If any required field (url, user_id, trace_id) is missing. - """ - governance_url = governance_spec.get("url") - user_id = governance_spec.get("user_id") - trace_id = governance_spec.get("trace_id") - if not governance_url or not user_id or not trace_id: + """Record a moment-in-time checkpoint as a zero-duration span.""" + attrs = attributes_with_type( + span_type, + data_id=data_id if data_id is not None else self._task_id, + extra={"batch_id": self._current_batch_id, **(attributes or {})}, + ) + with get_tracer().start_as_current_span(name, attributes=attrs): + pass + + # ------------------------------------------------------------------ # + # Asset / lineage rows — keep their own JSONL files # + # ------------------------------------------------------------------ # + def _lineage_dir(self) -> Path: + """Per-task ``logs/`` directory; requires an active ``_task_span``.""" + if self._task_out_dir is None: raise ExecutionError( - f"Governance spec missing required fields: {governance_spec}" + "Lineage directory accessed before _task_span entered; " + "wrap executor work in `with self._task_span(...)`." ) - return governance_url, user_id, trace_id - - def _fetch_data( - self, data_id: str, governance_spec: dict[str, Any] - ) -> dict[str, Any]: - """ - Fetch data from governance API with caching support. - - This method attempts to retrieve data from a local cache first, and if not found - or expired, makes an HTTP request to the governance API. It logs various events - throughout the process for audit trails. - - Args: - data_id: Unique identifier for the data to fetch. - governance_spec: Dictionary containing governance configuration with - 'url', 'user_id', and 'trace_id' fields. - - Returns: - The retrieved data as a dictionary, restored from JSON format. - - Raises: - ExecutionError: If the API request fails, response parsing fails, - or trace_id validation fails. - - Notes: - - Caches responses to avoid redundant API calls - - Validates that trace_id matches to prevent cross-workflow data access - - Logs multiple events: request initiation, cache hits, transfers, - decoding, and caching - """ - governance_url, user_id, trace_id = self._parse_spec(governance_spec) - try: - self._log_event( - data_id=data_id, - event_type="read request initiated", - timestamp=now_iso(), - ) - cache_path = self._cache_dir / f"{data_id}-{user_id}-{trace_id}.json" - logger.debug( - "Governance read: data_id=%s user_id=%s trace_id=%s cache=%s", - data_id, - user_id, - trace_id, - cache_path.as_posix(), - ) - with self._cache_dir_lock: - if cache_path.exists(): - with open(cache_path, encoding="utf-8") as f: - cached = json.load(f) - self._log_event( - data_id=data_id, - event_type="read cache hit", - timestamp=now_iso(), - event_data="Using cached upstream result", - ) - return cached - - api_url = governance_url.rstrip("/") + "/api/read" - params = { - "data_id": data_id, - "user_id": user_id, - } - logger.debug( - "Governance read request: url=%s params=%s", - api_url, - params, - ) - response = requests.get(api_url, params=params, timeout=30) - if response.status_code >= 400: - logger.warning( - "Governance read failed (status %s): %s", - response.status_code, - response.text[:200], - ) - response.raise_for_status() - - self._log_event( - data_id=data_id, - event_type="read response transfer", - timestamp=now_iso(), - ) - - read_response = response.json() - # Extract required fields from ReadResponse - retrieved_data = restore_json(json.loads(read_response["data"])) - assert ( - trace_id == read_response["trace_id"] - ), "One workflow should not access data from another workflow" + return self._task_out_dir / "logs" - self._log_event( - data_id=data_id, - event_type="read response decoding", - timestamp=now_iso(), - ) + def _append_jsonl(self, filename: str, row: dict[str, Any]) -> None: + target_dir = self._lineage_dir() + target_dir.mkdir(parents=True, exist_ok=True) + line = json.dumps(row, ensure_ascii=False, default=str) + path = target_dir / filename + with self._event_lock: + with path.open("a", encoding="utf-8") as fh: + fh.write(line + "\n") - with self._cache_dir_lock: - with open(cache_path, "w", encoding="utf-8") as f: - json.dump(retrieved_data, f, ensure_ascii=False) + def _record_asset( + self, + data_id: str, + asset_guid: str, + version: int = 1, + user_id: str = "", + created_at: str | None = None, + ) -> None: + row = { + "data_id": data_id, + "asset_guid": asset_guid, + "version": version, + "user_id": user_id, + "created_at": created_at or now_iso(), + } + self._append_jsonl("assets.jsonl", row) - self._log_event( - data_id=data_id, - event_type="read response cache write", - timestamp=now_iso(), - ) - logger.info( - "Written data %s to cache at %s upon first read", - data_id, - cache_path.as_posix(), + def _record_lineage( + self, + data_id: str, + source_data_ids: Sequence[str], + created_at: str | None = None, + ) -> None: + ts = created_at or now_iso() + for source_data_id in source_data_ids: + self._append_jsonl( + "lineage.jsonl", + { + "data_id": data_id, + "source_data_id": source_data_id, + "created_at": ts, + }, ) - except Exception as exc: - raise ExecutionError( - f"Error fetching upstream result {data_id}: {exc}" - ) from exc - return retrieved_data - - def _write_data( + def _record_output( self, data_id: str, data: Any, source_data_ids: list[str], - governance_spec: dict[str, Any], - events: dict[str, list[dict[str, Any]]] | None = None, ) -> None: - """ - Write data to governance API with event tracking. - - This method prepares and sends data to the governance API, including - associated events and metadata. It handles data deduplication and - provides comprehensive logging for the operation. - - Args: - data_id: Unique identifier for the data being written. - data: The data to write (will be JSON serialized). - source_data_ids: List of data IDs that this data depends on or sources from. - governance_spec: Dictionary containing governance configuration with - 'url', 'user_id', and 'trace_id' fields. - events: Optional dictionary of events to include. If None, automatically - collects events for data_id and source_data_ids. - - Returns: - None. Logs success or warning messages based on API response. - - Notes: - - Automatically deduplicates JSON data before sending - - Includes relevant events for audit trails - - Handles 4xx/5xx responses gracefully with warnings - - Logs request preparation and success with data size metrics - """ - governance_url, user_id, trace_id = self._parse_spec(governance_spec) - cache_path = self._cache_dir / f"{data_id}-{user_id}-{trace_id}.json" - request_data = { - "data_id": data_id, - "user_id": user_id, - "trace_id": trace_id, - "data": json.dumps(dedup_json(data), ensure_ascii=False), - "source_data_ids": source_data_ids or [], - "events": ( - events - if events is not None - else self._events_for([data_id, *(source_data_ids or [])]) - ), - "batch_id": self._current_batch_id, - } - self._log_event( - data_id=data_id, - event_type="write request preparation", - timestamp=now_iso(), - ) - - # Write to cache - with self._cache_dir_lock: - assert not cache_path.exists(), "Cache path should not exist before writing" - with open(cache_path, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False) - self._log_event( - data_id=data_id, - event_type="write request cache write", - timestamp=now_iso(), - ) + """Emit asset + lineage rows; ``data`` is only serialized to size the + ``"dump to storage"`` span (runtime does not upload payloads).""" + with self._span( + READY_SPAN_NAME, span_type=SpanType.NETWORK, data_id=data_id + ) as dump_span: + try: + payload = json.dumps(data, ensure_ascii=False, default=str) + except (TypeError, ValueError) as exc: + raise ExecutionError( + f"Failed to serialize data {data_id}: {exc}" + ) from exc + + payload_bytes = len(payload.encode("utf-8")) + dump_span.set_attribute("payload_bytes", payload_bytes) + + asset_guid = ( + source_data_ids[0] if len(source_data_ids) == 1 else str(uuid.uuid4()) + ) + self._record_asset( + data_id=data_id, + asset_guid=asset_guid, + version=1, + user_id=self._task_owner_id, + ) + if source_data_ids: + self._record_lineage(data_id=data_id, source_data_ids=source_data_ids) logger.info( - "Written data %s to cache at %s during governance write", + "Wrote lineage for %s (size: %d bytes, sources: %d)", data_id, - cache_path.as_posix(), + payload_bytes, + len(source_data_ids), ) - # Send to governance API - api_url = governance_url.rstrip("/") + "/api/write" - response = requests.post(api_url, json=request_data, timeout=300) - - if response.status_code >= 400: - logger.warning( - "Governance dump failed for data %s (status %s): %s", - data_id, - response.status_code, - response.text[:200], + @staticmethod + def _spec_upstream_results(spec: TaskSpecStrictBase) -> dict[str, Any]: + """Validated ``spec._upstreamResults`` (server-injected stage context).""" + context = spec.upstreamResults or {} + if not isinstance(context, dict): + raise ExecutionError("spec._upstreamResults must be a mapping.") + return context + + def _extract_source_data_ids(self, spec: TaskSpecStrictBase) -> list[str]: + """Extract upstream task/data IDs from ``_upstreamResults`` for lineage.""" + seen: set[str] = set() + ids: list[str] = [] + for upstream in self._spec_upstream_results(spec).values(): + if not isinstance(upstream, dict): + continue + candidate = upstream.get("task_id") or upstream.get("data_id") + if candidate is None: + continue + sid = str(candidate) + if sid in seen: + continue + seen.add(sid) + ids.append(sid) + return ids + + def _dump_to_governance( + self, + task_id: str, + result: dict[str, Any], + dependencies_by_task: dict[str, list[str]], + ) -> None: + """Write parent + merged-child results and emit asset/lineage rows.""" + parent_deps = dependencies_by_task.get(task_id, []) + children_payload = result.get("children", {}) + + collection_jobs: list[dict[str, Any]] = [ + { + "task_id": task_id, + "result": result, + "deps": parent_deps, + "is_parent": True, + } + ] + for child_id, child_result in children_payload.items(): + child_deps = dependencies_by_task.get(child_id, []) + collection_jobs.append( + { + "task_id": child_id, + "result": child_result, + "deps": child_deps, + "is_parent": False, + } ) - return - logger.info( - "Dumped execution result for data %s to governance " - "(size: %d bytes, items: %d)", - data_id, - len(json.dumps(request_data, ensure_ascii=False).encode("utf-8")), - len(data.get("items", [])), - ) + if len(collection_jobs) == 1: + job = collection_jobs[0] + self._record_output( + data_id=job["task_id"], + data=job["result"], + source_data_ids=job["deps"], + ) + else: + logger.info( + "Recording lineage for %d merged tasks in parallel", + len(collection_jobs), + ) + future_map = { + self._submit_in_context( + self._record_output, + data_id=job["task_id"], + data=job["result"], + source_data_ids=job["deps"], + ): job + for job in collection_jobs + } + for future in as_completed(future_map): + job = future_map[future] + try: + future.result() + except Exception as exc: + if job["is_parent"]: + raise + raise ExecutionError( + f"Failed to write merged child task {job['task_id']}: {exc}" + ) from exc diff --git a/src/worker/executors/mixins/inference.py b/src/worker/executors/mixins/inference.py index 9d8cf74..fb7ae7a 100644 --- a/src/worker/executors/mixins/inference.py +++ b/src/worker/executors/mixins/inference.py @@ -10,6 +10,7 @@ import torch from PIL import Image +from shared.schemas.governance import SpanType from shared.tasks.specs import InferenceSpecStrict from shared.utils.json import to_json_serializable @@ -330,56 +331,58 @@ def _prepare_inference_entry( self, entry: InferenceEntry, *, has_images: bool = False ) -> PreparedInferenceEntry: task_id = entry.task_id - inference_cfg = entry.inference_cfg - append_system_prompt = entry.append_system_prompt - system_prompt = inference_cfg.get("system_prompt") - metadata_raw = entry.metadata_raw - prompts = entry.prompts - - apply_chat_template = bool( - inference_cfg.get("apply_chat_template", self._should_apply_chat_template()) - ) - chat_template_kwargs = self._extract_chat_template_kwargs(inference_cfg) - metadata_prompts: Sequence[MetadataPrompt] - - if apply_chat_template: - metadata_prompts, rendered_prompts = self._apply_chat_template( - prompts, - system_prompt if append_system_prompt else None, - has_images=has_images, - chat_template_kwargs=chat_template_kwargs, + with self._span( + "prompt postprocessing", + span_type=SpanType.COMPUTE, + data_id=task_id, + ): + inference_cfg = entry.inference_cfg + append_system_prompt = entry.append_system_prompt + system_prompt = inference_cfg.get("system_prompt") + metadata_raw = entry.metadata_raw + prompts = entry.prompts + + apply_chat_template = bool( + inference_cfg.get( + "apply_chat_template", self._should_apply_chat_template() + ) ) - else: - if prompts and not isinstance(prompts[0], str): - raise ExecutionError( - "Chat-style prompts require apply_chat_template=true and a " - "tokenizer with a chat template." + chat_template_kwargs = self._extract_chat_template_kwargs(inference_cfg) + metadata_prompts: Sequence[MetadataPrompt] + + if apply_chat_template: + metadata_prompts, rendered_prompts = self._apply_chat_template( + prompts, + system_prompt if append_system_prompt else None, + has_images=has_images, + chat_template_kwargs=chat_template_kwargs, ) - prompts_as_text = cast(list[str], prompts) - metadata_prompts = prompts_as_text - if system_prompt and append_system_prompt: - rendered_prompts = [ - f"{system_prompt}\n{prompt}" for prompt in prompts_as_text - ] else: - rendered_prompts = prompts_as_text.copy() - - metadata_rows = self._build_metadata_rows(metadata_raw, metadata_prompts) - - self._log_event( - data_id=task_id, - event_type="prompt postprocessing", - event_data=f"Formed metadata for {len(metadata_rows)} prompts", - ) - return PreparedInferenceEntry( - task_id=task_id, - prompts=rendered_prompts, - inference_cfg=inference_cfg, - data_cfg=entry.data_cfg, - metadata=metadata_rows, - images=entry.images, - image_group_sizes=entry.image_group_sizes, - image_embedding_path=entry.image_embedding_path, - tables=entry.tables, - applied_chat_template=apply_chat_template, - ) + if prompts and not isinstance(prompts[0], str): + raise ExecutionError( + "Chat-style prompts require apply_chat_template=true and a " + "tokenizer with a chat template." + ) + prompts_as_text = cast(list[str], prompts) + metadata_prompts = prompts_as_text + if system_prompt and append_system_prompt: + rendered_prompts = [ + f"{system_prompt}\n{prompt}" for prompt in prompts_as_text + ] + else: + rendered_prompts = prompts_as_text.copy() + + metadata_rows = self._build_metadata_rows(metadata_raw, metadata_prompts) + + return PreparedInferenceEntry( + task_id=task_id, + prompts=rendered_prompts, + inference_cfg=inference_cfg, + data_cfg=entry.data_cfg, + metadata=metadata_rows, + images=entry.images, + image_group_sizes=entry.image_group_sizes, + image_embedding_path=entry.image_embedding_path, + tables=entry.tables, + applied_chat_template=apply_chat_template, + ) diff --git a/src/worker/executors/transformers_executor.py b/src/worker/executors/transformers_executor.py index a40d1e0..c968c5c 100644 --- a/src/worker/executors/transformers_executor.py +++ b/src/worker/executors/transformers_executor.py @@ -56,6 +56,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from shared.schemas.governance import SpanType from shared.tasks.specs import ( EmbeddingSpecStrict, InferenceSpecStrict, @@ -67,7 +68,11 @@ from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import InferenceEntry from .mixins.inference import InferenceMixin -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import ( + artifact_ref, + maybe_upload_artifacts, + maybe_upload_traces, +) try: import torch @@ -392,7 +397,22 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ign f"{spec.__class__.__name__}" ) task_id = task.task_id - self._ensure_model(spec) + with self._task_span( + task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + result = self._run_inner(spec, task_id, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result + + def _run_inner( + self, + spec: "InferenceSpecStrict | EmbeddingSpecStrict", + task_id: str, + out_dir: Path, + ) -> dict[str, Any]: + with self._span("model load", span_type=SpanType.COMPUTE): + self._ensure_model(spec) deps = self._extract_source_data_ids(spec) dependencies_by_task = {task_id: deps} @@ -473,15 +493,11 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ign if image_group_sizes is not None: result["image_group_sizes"] = image_group_sizes - maybe_upload_artifacts(task, out_dir, logger=logger) - - if governance_spec := spec.governance: - self._dump_to_governance( - governance_spec=governance_spec, - task_id=task_id, - result=result, - dependencies_by_task=dependencies_by_task, - ) + self._dump_to_governance( + task_id=task_id, + result=result, + dependencies_by_task=dependencies_by_task, + ) return result @@ -513,23 +529,28 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ign enc = {k: v.to(device) for k, v in enc.items()} # type: ignore[arg-type] t0 = time.time() - with torch.no_grad(): - try: - outputs = self._model.generate( # type: ignore - **enc, generation_config=gen_cfg - ) - except ValueError as exc: - if not (stops and "stop" in str(exc).lower()): - raise - logger.warning( - "Falling back to decoded stop-string truncation after native " - "stop configuration failed: %s", - exc, - ) - gen_cfg = self._build_generation_config(self._inf, stop_strings=[]) - outputs = self._model.generate( # type: ignore - **enc, generation_config=gen_cfg - ) + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"prompt_count": len(self._prompts)}, + ): + with torch.no_grad(): + try: + outputs = self._model.generate( # type: ignore + **enc, generation_config=gen_cfg + ) + except ValueError as exc: + if not (stops and "stop" in str(exc).lower()): + raise + logger.warning( + "Falling back to decoded stop-string truncation after " + "native stop configuration failed: %s", + exc, + ) + gen_cfg = self._build_generation_config(self._inf, stop_strings=[]) + outputs = self._model.generate( # type: ignore + **enc, generation_config=gen_cfg + ) latency = time.time() - t0 items: list[dict[str, Any]] = [] @@ -591,15 +612,11 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ign if isinstance(spec, InferenceSpecStrict): self._maybe_export_jsonl(spec, task_id, result, out_dir) - maybe_upload_artifacts(task, out_dir, logger=logger) - # Dump execution result to GovernanceRelay - if governance_spec := spec.governance: - self._dump_to_governance( - governance_spec=governance_spec, - task_id=task_id, - result=result, - dependencies_by_task=dependencies_by_task, - ) + self._dump_to_governance( + task_id=task_id, + result=result, + dependencies_by_task=dependencies_by_task, + ) return result diff --git a/src/worker/executors/utils/checkpoints.py b/src/worker/executors/utils/checkpoints.py index 30af0b9..96188ce 100644 --- a/src/worker/executors/utils/checkpoints.py +++ b/src/worker/executors/utils/checkpoints.py @@ -457,3 +457,56 @@ def maybe_upload_artifacts( uploaded.append(rel_name) return uploaded + + +def maybe_upload_traces( + task: TaskReference, + out_dir: Path, + logger: logging.Logger | None = None, + skip_errors: bool = False, +) -> list[str]: + """Upload trace JSONL files under `out_dir/logs/` to + `/traces/tasks/{task_id}/{trace_type}` when the task has an HTTP + destination; no-op otherwise. The destination's `/results` artifact base + is swapped for `/traces`. Returns uploaded trace types.""" + if not task.task_id: + raise ExecutionError("Task id missing; cannot upload traces") + out_dir = Path(out_dir).resolve() + logs_dir = out_dir / "logs" + destination = get_http_destination(task.spec) + if destination is None or not logs_dir.is_dir(): + return [] + + upload_base = destination.url.rstrip("/").removesuffix("/results") + "/traces" + uploaded: list[str] = [] + + for trace_type in ("spans", "assets", "lineage"): + file_path = logs_dir / f"{trace_type}.jsonl" + if not file_path.is_file(): + continue + upload_url = f"{upload_base}/tasks/{task.task_id}/{trace_type}" + try: + with file_path.open("rb") as fh: + response = requests.request( + destination.method, + upload_url, + files={"file": (file_path.name, fh, "application/octet-stream")}, + headers=destination.headers, + timeout=destination.timeout, + ) + response.raise_for_status() + except Exception as exc: + if not skip_errors: + raise ExecutionError( + f"Trace upload failed for {file_path}: {exc}" + ) from exc + if logger: + logger.warning("Failed to upload trace %s: %s", trace_type, exc) + continue + if logger: + logger.info( + "Uploaded trace %s (%d bytes)", trace_type, file_path.stat().st_size + ) + uploaded.append(trace_type) + + return uploaded diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py index 350e770..239cc06 100644 --- a/src/worker/executors/vllm_executor.py +++ b/src/worker/executors/vllm_executor.py @@ -66,6 +66,7 @@ _HAS_VLLM = False StructuredOutputsParams = None # type: ignore +from shared.schemas.governance import SpanType from shared.tasks.specs import InferenceSpecStrict from worker.config import WorkerConfig from worker.lifecycle import Lifecycle @@ -73,7 +74,11 @@ from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import InferenceEntry from .mixins.inference import InferenceMixin, PreparedInferenceEntry -from .utils.checkpoints import maybe_upload_artifacts, resolve_checkpoint_load +from .utils.checkpoints import ( + maybe_upload_artifacts, + maybe_upload_traces, + resolve_checkpoint_load, +) logger = logging.getLogger(__name__) @@ -415,12 +420,6 @@ def _ensure_llm( attempt_kwargs = dict(kwargs) attempt_kwargs["gpu_memory_utilization"] = util self._llm_kwargs = dict(attempt_kwargs) - if task_ids: - self._log_event_batch( - task_ids=task_ids, event_type="model pre-initialization setup" - ) - else: - self._log_event(event_type="model pre-initialization setup") logger.info( "Initializing vLLM (TP candidate %d/%d, attempt %d/%d) " "with tensor_parallel_size=%d, gpu_memory_utilization=%.3f", @@ -432,15 +431,18 @@ def _ensure_llm( util, ) try: - self._llm = LLM(**attempt_kwargs) # type: ignore[call-arg] + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={ + "task_ids": list(task_ids or ()), + "tensor_parallel_size": tp_value, + "gpu_memory_utilization": util, + }, + ): + self._llm = LLM(**attempt_kwargs) # type: ignore[call-arg] chosen_kwargs = dict(attempt_kwargs) success = True - if task_ids: - self._log_event_batch( - task_ids=task_ids, event_type="model initialization" - ) - else: - self._log_event(event_type="model initialization") break except TypeError as exc: last_exc = exc @@ -651,16 +653,6 @@ def _build_sampling_params( **optional_sampling_fields, ) - def _log_event_batch( - self, task_ids: Iterable[str], event_type: str, message: str | None = None - ) -> None: - for task_id in task_ids: - self._log_event( - data_id=task_id, - event_type=event_type, - event_data=message or str(message), - ) - def _remap_grouped_outputs( self, *, @@ -866,17 +858,21 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ign if not task_id: raise ExecutionError("task_id is required for inference execution") - # Start a fresh event log per run - self._clear_events() - self._upstream_deps_cache.clear() - self._current_batch_id = task_id - self._task_id = task_id - self._log_event( - data_id=task_id, - event_type="queuing for execution", - event_data="vLLM execution started", - ) + with self._task_span( + task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + result = self._run_inner(task, spec, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result + def _run_inner( + self, + task: ExecutorTask, + spec: InferenceSpecStrict, + out_dir: Path, + ) -> dict[str, Any]: + task_id = task.task_id.strip() merge_children = task.merged_children or [] entries: list[PreparedInferenceEntry] = [] collection_jobs: list[dict[str, Any]] = [ @@ -890,11 +886,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ign raise ExecutionError( "Merged child spec must be inference for merged vLLM execution" ) - self._log_event( - data_id=child_id, - event_type="queuing for execution", - event_data=f"Merged child {child_id} ready for execution", - ) + self._log_event("queuing for execution", data_id=child_id) collection_jobs.append( {"task_id": child_id, "spec": child_spec, "is_parent": False} ) @@ -927,7 +919,7 @@ def _collect( len(collection_jobs), ) future_map = { - self.io_executor.submit(_collect, job): job for job in collection_jobs + self._submit_in_context(_collect, job): job for job in collection_jobs } for future in as_completed(future_map): job = future_map[future] @@ -954,12 +946,6 @@ def _collect( dependencies_by_task[child_id] = child_deps entry_by_task_id[child_id] = child_entry - self._log_event_batch( - task_ids, - "prompt synchronization", - f"Prepared prompts for parent and {len(merge_children)} children", - ) - self._batched_inputs = [] self._prompt_owners = [] self._batched_metadata = [] @@ -1023,191 +1009,195 @@ def _collect( generate_kwargs = self._build_generate_kwargs(spec, out_dir) t0 = time.time() - self._log_event_batch( - task_ids, - "generation-preprocessing", - f"Starting LLM generation for {len(self._batched_inputs)} prompts", - ) - outputs = self._llm.generate( - self._batched_inputs, - sampling_params=sampling_params, - **generate_kwargs, - ) # type: ignore[attr-defined] + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={ + "task_ids": task_ids, + "prompt_count": len(self._batched_inputs), + }, + ): + outputs = self._llm.generate( + self._batched_inputs, + sampling_params=sampling_params, + **generate_kwargs, + ) # type: ignore[attr-defined] latency = time.time() - t0 - self._log_event_batch( - task_ids, "generation", f"LLM generation completed ({latency:.2f}s)" - ) - per_task_items: dict[str, list[dict[str, Any]]] = {} - usage_by_task: dict[str, dict[str, int | float]] = {} - counts_by_task: dict[str, int] = {} + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_ids": task_ids}, + ): + per_task_items: dict[str, list[dict[str, Any]]] = {} + usage_by_task: dict[str, dict[str, int | float]] = {} + counts_by_task: dict[str, int] = {} + + total_prompt_tokens = 0 + total_completion_tokens = 0 + + for idx, out in enumerate(outputs): + owner = ( + self._prompt_owners[idx] + if idx < len(self._prompt_owners) + else task_id + ) + owner_items = per_task_items.setdefault(owner, []) + local_index = len(owner_items) + prompt_text: str = "" + if idx < len(self._batched_inputs): + prompt_payload = self._batched_inputs[idx] + if isinstance(prompt_payload, dict): + prompt_text = prompt_payload["prompt"] + else: + prompt_text = prompt_payload + metadata_entry = ( + self._batched_metadata[idx] + if idx < len(self._batched_metadata) + else {"prompt": prompt_text} + ) + + out_outputs = getattr(out, "outputs", None) + if not out_outputs: + payload = { + "index": local_index, + "prompt": prompt_text, + "output": "", + "finish_reason": None, + } + if metadata_entry: + payload["metadata"] = metadata_entry + owner_items.append(payload) + usage_by_task.setdefault( + owner, {"prompt_tokens": 0, "completion_tokens": 0} + ) + counts_by_task[owner] = counts_by_task.get(owner, 0) + 1 + continue - total_prompt_tokens = 0 - total_completion_tokens = 0 + best = out_outputs[0] + text = getattr(best, "text", "") or "" - for idx, out in enumerate(outputs): - owner = ( - self._prompt_owners[idx] if idx < len(self._prompt_owners) else task_id - ) - owner_items = per_task_items.setdefault(owner, []) - local_index = len(owner_items) - prompt_text: str = "" - if idx < len(self._batched_inputs): - prompt_payload = self._batched_inputs[idx] - if isinstance(prompt_payload, dict): - prompt_text = prompt_payload["prompt"] - else: - prompt_text = prompt_payload - metadata_entry = ( - self._batched_metadata[idx] - if idx < len(self._batched_metadata) - else {"prompt": prompt_text} - ) + output_value: Any = text + if template_param_schema: + try: + output_value = json.loads(text) + except json.JSONDecodeError: + logger.warning( + "Task %s: failed to parse structured output as JSON: %r", + owner, + text, + ) - out_outputs = getattr(out, "outputs", None) - if not out_outputs: + finish_reason = getattr(best, "finish_reason", None) payload = { "index": local_index, "prompt": prompt_text, - "output": "", - "finish_reason": None, + "output": output_value, + "finish_reason": finish_reason, } if metadata_entry: payload["metadata"] = metadata_entry owner_items.append(payload) - usage_by_task.setdefault( + + prompt_token_ids = getattr(out, "prompt_token_ids", None) or [] + best_token_ids = getattr(best, "token_ids", None) or [] + prompt_len = len(prompt_token_ids) + completion_len = len(best_token_ids) + + total_prompt_tokens += prompt_len + total_completion_tokens += completion_len + + usage_entry = usage_by_task.setdefault( owner, {"prompt_tokens": 0, "completion_tokens": 0} ) + usage_entry["prompt_tokens"] += prompt_len + usage_entry["completion_tokens"] += completion_len counts_by_task[owner] = counts_by_task.get(owner, 0) + 1 - continue - - best = out_outputs[0] - text = getattr(best, "text", "") or "" - output_value: Any = text - if template_param_schema: - try: - output_value = json.loads(text) - except json.JSONDecodeError: - logger.warning( - "Task %s: failed to parse structured output as JSON: %r", - owner, - text, + for owner, entry in entry_by_task_id.items(): + group_sizes: list[int] | None = entry.image_group_sizes + if group_sizes is None: + continue + base_prompts = entry.image_group_base_prompts + base_metadata = entry.image_group_base_metadata + if base_prompts is None or base_metadata is None: + raise ExecutionError( + "Grouped image outputs require base prompts and metadata " + f"(task={owner})." ) - - finish_reason = getattr(best, "finish_reason", None) - payload = { - "index": local_index, - "prompt": prompt_text, - "output": output_value, - "finish_reason": finish_reason, - } - if metadata_entry: - payload["metadata"] = metadata_entry - owner_items.append(payload) - - prompt_token_ids = getattr(out, "prompt_token_ids", None) or [] - best_token_ids = getattr(best, "token_ids", None) or [] - prompt_len = len(prompt_token_ids) - completion_len = len(best_token_ids) - - total_prompt_tokens += prompt_len - total_completion_tokens += completion_len - - usage_entry = usage_by_task.setdefault( - owner, {"prompt_tokens": 0, "completion_tokens": 0} - ) - usage_entry["prompt_tokens"] += prompt_len - usage_entry["completion_tokens"] += completion_len - counts_by_task[owner] = counts_by_task.get(owner, 0) + 1 - - for owner, entry in entry_by_task_id.items(): - group_sizes: list[int] | None = entry.image_group_sizes - if group_sizes is None: - continue - base_prompts = entry.image_group_base_prompts - base_metadata = entry.image_group_base_metadata - if base_prompts is None or base_metadata is None: - raise ExecutionError( - "Grouped image outputs require base prompts and metadata " - f"(task={owner})." - ) - owner_items = per_task_items.get(owner, []) - per_task_items[owner] = self._remap_grouped_outputs( - task_id=owner, - items=owner_items, - group_sizes=self._validate_image_group_sizes( - group_sizes, + owner_items = per_task_items.get(owner, []) + per_task_items[owner] = self._remap_grouped_outputs( task_id=owner, - ), - base_prompts=base_prompts, - base_metadata=base_metadata, - ) - - for owner, usage in usage_by_task.items(): - usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"] - usage["latency_sec"] = latency - usage["num_requests"] = counts_by_task.get(owner, 0) - - parent_usage = { - "prompt_tokens": total_prompt_tokens, - "completion_tokens": total_completion_tokens, - "total_tokens": total_prompt_tokens + total_completion_tokens, - "latency_sec": latency, - "num_requests": len(self._batched_inputs), - } + items=owner_items, + group_sizes=self._validate_image_group_sizes( + group_sizes, + task_id=owner, + ), + base_prompts=base_prompts, + base_metadata=base_metadata, + ) - result: dict[str, Any] = { - "ok": True, - "model": self._model_name, - "items": per_task_items.get(task_id, []), - "usage": parent_usage, - } + for owner, usage in usage_by_task.items(): + usage["total_tokens"] = ( + usage["prompt_tokens"] + usage["completion_tokens"] + ) + usage["latency_sec"] = latency + usage["num_requests"] = counts_by_task.get(owner, 0) + + parent_usage = { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_prompt_tokens + total_completion_tokens, + "latency_sec": latency, + "num_requests": len(self._batched_inputs), + } - child_results: dict[str, Any] = {} - for child in merge_children: - child_id = child.task_id.strip() - if not child_id: - continue - child_payload: dict[str, Any] = { - "items": per_task_items.get(child_id, []), + result: dict[str, Any] = { + "ok": True, + "model": self._model_name, + "items": per_task_items.get(task_id, []), + "usage": parent_usage, } - maybe_usage = usage_by_task.get(child_id) - if maybe_usage: - child_payload["usage"] = maybe_usage - child_results[child_id] = child_payload - - self._log_event_batch(task_ids, "output postprocessing") - - if parent_tables := parent_entry.tables: - result = self._populate_table(result, parent_tables) - if child_results: - for child_id, child_payload in list(child_results.items()): - if (child_entry := entry_by_task_id.get(child_id)) and ( - child_tables := child_entry.tables - ): - child_results[child_id] = self._populate_table( - child_payload, child_tables - ) - if child_results: - result["children"] = child_results + child_results: dict[str, Any] = {} + for child in merge_children: + child_id = child.task_id.strip() + if not child_id: + continue + child_payload: dict[str, Any] = { + "items": per_task_items.get(child_id, []), + } + maybe_usage = usage_by_task.get(child_id) + if maybe_usage: + child_payload["usage"] = maybe_usage + child_results[child_id] = child_payload + + if parent_tables := parent_entry.tables: + result = self._populate_table(result, parent_tables) + if child_results: + for child_id, child_payload in child_results.items(): + if (child_entry := entry_by_task_id.get(child_id)) and ( + child_tables := child_entry.tables + ): + child_results[child_id] = self._populate_table( + child_payload, child_tables + ) - self._maybe_export_jsonl(spec, task_id, result, out_dir) - self._log_event_batch( - task_ids, "JSONL export", "vLLM execution completed successfully" - ) + if child_results: + result["children"] = child_results - maybe_upload_artifacts(task, out_dir, logger=logger) + with self._span( + "JSONL export", + span_type=SpanType.COMPUTE, + attributes={"task_ids": task_ids}, + ): + self._maybe_export_jsonl(spec, task_id, result, out_dir) - # Dump execution result to GovernanceRelay - if governance_spec := spec.governance: - self._dump_to_governance( - governance_spec=governance_spec, - task_id=task_id, - result=result, - dependencies_by_task=dependencies_by_task, - ) + self._dump_to_governance( + task_id=task_id, + result=result, + dependencies_by_task=dependencies_by_task, + ) return result @@ -1220,8 +1210,6 @@ def cleanup_after_run(self) -> None: self._batched_metadata = [] self._base_inference = {} - self._clear_events() - logger.debug("Shutting down I/O executor") self.io_executor.shutdown(wait=True) logger.debug("I/O executor shut down successfully") diff --git a/tests/sdk/test_models.py b/tests/sdk/test_models.py index 4edea5c..841c7b4 100644 --- a/tests/sdk/test_models.py +++ b/tests/sdk/test_models.py @@ -1,12 +1,22 @@ """Model validation and round-trip tests for SDK Pydantic models.""" +from datetime import UTC, datetime + import pytest from flowmesh.models import ( + ActiveWaitBreakdown, + AssetSummary, + CriticalPathSummary, + E2EBreakdown, + EventSummary, + LineageEdge, LogQueryResponse, Node, NodeWorkerInfo, OkResponse, + ProfileSummary, TaskInfo, + TaskTiming, TaskUsage, WorkerHardware, WorkerInfo, @@ -17,6 +27,14 @@ from flowmesh.models.ssh import SSHConnectionInfo from pydantic import BaseModel +from server.governance.analyzer import ActiveWaitBreakdown as SrvActiveWaitBreakdown +from server.governance.analyzer import AssetSummary as SrvAssetSummary +from server.governance.analyzer import CriticalPathSummary as SrvCriticalPathSummary +from server.governance.analyzer import E2EBreakdown as SrvE2EBreakdown +from server.governance.analyzer import EventSummary as SrvEventSummary +from server.governance.analyzer import LineageEdge as SrvLineageEdge +from server.governance.analyzer import ProfileSummary as SrvProfileSummary +from server.governance.analyzer import TaskTiming as SrvTaskTiming from server.registries.node import Node as SrvNode from server.registries.worker import Worker as SrvWorker from server.registries.worker import WorkerInfo as SrvWorkerInfo @@ -162,6 +180,82 @@ stale=False, ) +_SRV_EVENT_SUMMARY = SrvEventSummary( + event_type=["model load", "generation"], + count=[1, 2], + total_seconds=[53.39, 1.23], + avg_seconds=[53.39, 0.62], + min_seconds=[53.39, 0.39], + max_seconds=[53.39, 0.84], +) + +_SRV_NETWORK_SUMMARY = SrvEventSummary( + event_type=["dump to storage"], + count=[2], + total_seconds=[0.001], + avg_seconds=[0.001], + min_seconds=[0.000], + max_seconds=[0.001], +) + +_SRV_TASK_TIMING = SrvTaskTiming( + data_id="tsk-a", + start_time=datetime(2026, 4, 30, 14, 0, 1, tzinfo=UTC), + end_time=datetime(2026, 4, 30, 14, 0, 55, tzinfo=UTC), + duration_seconds=54.0, + queuing_delay_seconds=0.5, + parent_data_ids=["tsk-up-a"], + blocking_parent_data_id="tsk-up-a", +) + +_SRV_ACTIVE_WAIT = SrvActiveWaitBreakdown( + data_id=["tsk-a", "tsk-b"], + active_seconds=[54.0, 0.84], + wait_seconds=[0.0, 0.5], +) + +_SRV_E2E = SrvE2EBreakdown( + hardware_summary=_SRV_EVENT_SUMMARY, + network_summary=_SRV_NETWORK_SUMMARY, + workflow_duration_seconds=55.05, + total_network_seconds=0.001, +) + +_SRV_CP = SrvCriticalPathSummary( + path=["tsk-a", "tsk-b"], + critical_path_seconds=55.05, + active_wait_breakdown=_SRV_ACTIVE_WAIT, + hardware_summary=_SRV_EVENT_SUMMARY, + network_summary=_SRV_NETWORK_SUMMARY, + total_network_seconds=0.001, +) + +_SRV_PROFILE = SrvProfileSummary( + workflow_id="wfl-abc", + event_count=18, + data_ids=["tsk-a", "tsk-b"], + assets=[ + SrvAssetSummary( + asset_guid="g-1", + latest_data_id="tsk-a", + latest_version=1, + user_id="alice", + versions=1, + created_at="2026-04-30T14:00:55Z", + ) + ], + lineage=[ + SrvLineageEdge( + data_id="tsk-b", + source_data_id="tsk-a", + created_at="2026-04-30T14:00:55Z", + ) + ], + e2e_breakdown=_SRV_E2E, + per_data_id=[_SRV_TASK_TIMING], + critical_path=_SRV_CP, +) + # ------------------------------------------------------------------ # # Helpers # ------------------------------------------------------------------ # @@ -308,3 +402,84 @@ def test_ssh_connection_info(self) -> None: ) r = SSHConnectionInfo.model_validate(_dump(server)) assert r.access_mode == "proxy" + + +class TestTraceModels: + def test_asset_summary(self) -> None: + server = _SRV_PROFILE.assets[0] + r = AssetSummary.model_validate(_dump(server)) + assert r.asset_guid == "g-1" + assert r.latest_data_id == "tsk-a" + assert r.latest_version == 1 + assert r.user_id == "alice" + assert r.versions == 1 + + def test_lineage_edge(self) -> None: + server = _SRV_PROFILE.lineage[0] + r = LineageEdge.model_validate(_dump(server)) + assert r.data_id == "tsk-b" + assert r.source_data_id == "tsk-a" + + def test_event_summary_parallel_lists_align(self) -> None: + r = EventSummary.model_validate(_dump(_SRV_EVENT_SUMMARY)) + n = len(r.event_type) + assert n == 2 + assert all( + len(field) == n + for field in ( + r.count, + r.total_seconds, + r.avg_seconds, + r.min_seconds, + r.max_seconds, + ) + ) + assert r.event_type[0] == "model load" + assert r.total_seconds[0] == pytest.approx(53.39) + + def test_e2e_breakdown(self) -> None: + r = E2EBreakdown.model_validate(_dump(_SRV_E2E)) + assert r.workflow_duration_seconds == pytest.approx(55.05) + assert "model load" in r.hardware_summary.event_type + assert "dump to storage" in r.network_summary.event_type + + def test_active_wait_breakdown(self) -> None: + r = ActiveWaitBreakdown.model_validate(_dump(_SRV_ACTIVE_WAIT)) + assert r.data_id == ["tsk-a", "tsk-b"] + assert r.wait_seconds[1] == pytest.approx(0.5) + + def test_task_timing_datetime_round_trip(self) -> None: + r = TaskTiming.model_validate(_dump(_SRV_TASK_TIMING)) + assert r.data_id == "tsk-a" + assert r.start_time == _SRV_TASK_TIMING.start_time + assert r.end_time == _SRV_TASK_TIMING.end_time + assert r.queuing_delay_seconds == pytest.approx(0.5) + assert r.blocking_parent_data_id == "tsk-up-a" + + def test_critical_path_summary(self) -> None: + r = CriticalPathSummary.model_validate(_dump(_SRV_CP)) + assert r.path == ["tsk-a", "tsk-b"] + assert r.critical_path_seconds == pytest.approx(55.05) + assert r.active_wait_breakdown.data_id == ["tsk-a", "tsk-b"] + + def test_profile_summary(self) -> None: + r = ProfileSummary.model_validate(_dump(_SRV_PROFILE)) + assert r.workflow_id == "wfl-abc" + assert r.event_count == 18 + assert r.data_ids == ["tsk-a", "tsk-b"] + assert len(r.assets) == 1 + assert len(r.lineage) == 1 + assert r.e2e_breakdown.workflow_duration_seconds == pytest.approx(55.05) + assert r.critical_path is not None + assert r.critical_path.path == ["tsk-a", "tsk-b"] + + def test_profile_summary_critical_path_optional(self) -> None: + server = _SRV_PROFILE.model_copy(update={"critical_path": None}) + r = ProfileSummary.model_validate(_dump(server)) + assert r.critical_path is None + + def test_profile_summary_rejects_extra_fields(self) -> None: + payload = _dump(_SRV_PROFILE) + payload["unexpected"] = 1 + with pytest.raises(Exception): + ProfileSummary.model_validate(payload) diff --git a/tests/sdk/test_profile_views.py b/tests/sdk/test_profile_views.py new file mode 100644 index 0000000..ebeb4ae --- /dev/null +++ b/tests/sdk/test_profile_views.py @@ -0,0 +1,51 @@ +"""SDK profile-view helpers (mermaid renderer).""" + +from flowmesh.profile_views import to_mermaid + + +def test_to_mermaid_renders_lineage_edges() -> None: + summary = { + "workflow_id": "wfl-1", + "event_count": 0, + "data_ids": ["tsk-1", "tsk-2", "tsk-3"], + "assets": [], + "lineage": [ + { + "data_id": "tsk-3", + "source_data_id": "tsk-1", + "created_at": "2026-04-29T00:00:00+00:00", + }, + { + "data_id": "tsk-3", + "source_data_id": "tsk-2", + "created_at": "2026-04-29T00:00:00+00:00", + }, + ], + "e2e_breakdown": { + "hardware_summary": { + "event_type": [], + "count": [], + "total_seconds": [], + "avg_seconds": [], + "min_seconds": [], + "max_seconds": [], + }, + "network_summary": { + "event_type": [], + "count": [], + "total_seconds": [], + "avg_seconds": [], + "min_seconds": [], + "max_seconds": [], + }, + "workflow_duration_seconds": 0.0, + "total_network_seconds": 0.0, + }, + "per_data_id": [], + "critical_path": None, + } + rendered = to_mermaid(summary) + assert rendered.startswith("graph TD") + assert "tsk_1" in rendered + assert "tsk_3" in rendered + assert "-->" in rendered diff --git a/tests/server/test_governance_schemas.py b/tests/server/test_governance_schemas.py new file mode 100644 index 0000000..55575a2 --- /dev/null +++ b/tests/server/test_governance_schemas.py @@ -0,0 +1,45 @@ +from server.governance import Span +from shared.schemas.governance import SpanType + + +def test_span_otel_round_trip() -> None: + raw = { + "name": "model load", + "context": { + "trace_id": "0xfbad6be5c4434181a2d394eac830dea1", + "span_id": "0xa3f1e9d2c5b40678", + }, + "parent_id": "0x1b2c3d4e5f6a7b8c", + "start_time": "2026-04-30T14:00:01.000000Z", + "end_time": "2026-04-30T14:00:55.000000Z", + "status": {"status_code": "OK"}, + "attributes": { + "data_id": "tsk-1", + "batch_id": "tsk-1", + "flowmesh.type": "compute", + }, + } + span = Span.parse_otel_json(raw) + assert span.name == "model load" + assert span.context.trace_id == "fbad6be5c4434181a2d394eac830dea1" + assert span.context.span_id == "a3f1e9d2c5b40678" + assert span.parent_id == "1b2c3d4e5f6a7b8c" + assert span.attributes.data_id == "tsk-1" + assert span.attributes.batch_id == "tsk-1" + assert span.attributes.flowmesh_type == SpanType.COMPUTE + assert span.duration_seconds == 54.0 + + +def test_span_marker_zero_duration() -> None: + raw = { + "name": "dump to storage", + "context": {"trace_id": "0x" + "a" * 32, "span_id": "0x" + "b" * 16}, + "parent_id": "0x" + "c" * 16, + "start_time": "2026-04-30T14:00:01.500000Z", + "end_time": "2026-04-30T14:00:01.500000Z", + "status": {"status_code": "OK"}, + "attributes": {"data_id": "tsk-2", "flowmesh.type": "marker"}, + } + span = Span.parse_otel_json(raw) + assert span.duration_seconds == 0.0 + assert span.attributes.flowmesh_type == SpanType.MARKER diff --git a/tests/server/test_profile_analyzer.py b/tests/server/test_profile_analyzer.py new file mode 100644 index 0000000..bef4286 --- /dev/null +++ b/tests/server/test_profile_analyzer.py @@ -0,0 +1,216 @@ +from typing import Any + +from server.governance import analyze + + +def _span( + name: str, + *, + data_id: str, + start: str, + end: str, + span_type: str, + parent_id: str | None = None, + span_id: str = "0xa3f1e9d2c5b40678", + batch_id: str | None = None, +) -> dict[str, Any]: + attributes: dict[str, Any] = {"data_id": data_id, "flowmesh.type": span_type} + if batch_id: + attributes["batch_id"] = batch_id + return { + "name": name, + "context": { + "trace_id": "0xfbad6be5c4434181a2d394eac830dea1", + "span_id": span_id, + }, + "parent_id": parent_id, + "start_time": start, + "end_time": end, + "status": {"status_code": "OK"}, + "attributes": attributes, + } + + +def _spans() -> list[dict[str, Any]]: + """ + Two parallel branches (tsk-1, tsk-2) feeding a synthesis (tsk-3). + Same shape the original event-fixture exercised, now as spans. + """ + return [ + _span( + "task", + data_id="tsk-1", + start="2026-04-29T00:00:00+00:00", + end="2026-04-29T00:00:02+00:00", + span_type="compute", + span_id="0x1111111111111111", + ), + _span( + "model load", + data_id="tsk-1", + start="2026-04-29T00:00:00+00:00", + end="2026-04-29T00:00:01+00:00", + span_type="compute", + parent_id="0x1111111111111111", + span_id="0x1111000000000001", + batch_id="tsk-1", + ), + _span( + "dump to storage", + data_id="tsk-1", + start="2026-04-29T00:00:01+00:00", + end="2026-04-29T00:00:02+00:00", + span_type="network", + parent_id="0x1111111111111111", + span_id="0x1111000000000002", + ), + _span( + "task", + data_id="tsk-2", + start="2026-04-29T00:00:00+00:00", + end="2026-04-29T00:00:04+00:00", + span_type="compute", + span_id="0x2222222222222222", + ), + _span( + "model load", + data_id="tsk-2", + start="2026-04-29T00:00:00+00:00", + end="2026-04-29T00:00:03+00:00", + span_type="compute", + parent_id="0x2222222222222222", + span_id="0x2222000000000001", + batch_id="tsk-2", + ), + _span( + "dump to storage", + data_id="tsk-2", + start="2026-04-29T00:00:03+00:00", + end="2026-04-29T00:00:04+00:00", + span_type="network", + parent_id="0x2222222222222222", + span_id="0x2222000000000002", + ), + _span( + "task", + data_id="tsk-3", + start="2026-04-29T00:00:05+00:00", + end="2026-04-29T00:00:06+00:00", + span_type="compute", + span_id="0x3333333333333333", + ), + _span( + "read", + data_id="tsk-3", + start="2026-04-29T00:00:05+00:00", + end="2026-04-29T00:00:05.500000+00:00", + span_type="network", + parent_id="0x3333333333333333", + span_id="0x3333000000000001", + ), + _span( + "dump to storage", + data_id="tsk-3", + start="2026-04-29T00:00:05.500000+00:00", + end="2026-04-29T00:00:06+00:00", + span_type="network", + parent_id="0x3333333333333333", + span_id="0x3333000000000002", + ), + ] + + +def _assets() -> list[dict]: + return [ + { + "data_id": "tsk-1", + "asset_guid": "g-1", + "version": 1, + "user_id": "alice", + "created_at": "2026-04-29T00:00:02+00:00", + }, + { + "data_id": "tsk-2", + "asset_guid": "g-2", + "version": 1, + "user_id": "alice", + "created_at": "2026-04-29T00:00:04+00:00", + }, + { + "data_id": "tsk-3", + "asset_guid": "g-3", + "version": 1, + "user_id": "alice", + "created_at": "2026-04-29T00:00:06+00:00", + }, + ] + + +def _lineage() -> list[dict]: + return [ + { + "data_id": "tsk-3", + "source_data_id": "tsk-1", + "created_at": "2026-04-29T00:00:06+00:00", + }, + { + "data_id": "tsk-3", + "source_data_id": "tsk-2", + "created_at": "2026-04-29T00:00:06+00:00", + }, + ] + + +def test_e2e_breakdown_workflow_duration_and_network_union() -> None: + summary = analyze(_spans(), _assets(), _lineage()) + e2e = summary.e2e_breakdown + assert e2e.workflow_duration_seconds == 6.0 + assert e2e.total_network_seconds > 0 + + +def test_e2e_hardware_summary_lists_compute_spans() -> None: + summary = analyze(_spans(), _assets(), _lineage()) + hw = summary.e2e_breakdown.hardware_summary + types = set(hw.event_type) + assert "model load" in types + assert "read" not in types + assert "dump to storage" not in types + assert "task" not in types + + +def test_e2e_network_summary_includes_transfers() -> None: + summary = analyze(_spans(), _assets(), _lineage()) + net = summary.e2e_breakdown.network_summary + types = set(net.event_type) + assert "dump to storage" in types + assert "read" in types + + +def test_critical_path_picks_synthesis_chain() -> None: + summary = analyze(_spans(), _assets(), _lineage()) + cp = summary.critical_path + assert cp is not None + assert cp.path == ["tsk-2", "tsk-3"] + awb = cp.active_wait_breakdown + assert awb.data_id == ["tsk-2", "tsk-3"] + assert awb.active_seconds[1] == 1.0 + assert awb.wait_seconds[1] == 1.0 + + +def test_per_data_id_queuing_delays() -> None: + summary = analyze(_spans(), _assets(), _lineage()) + per_id = {t.data_id: t for t in summary.per_data_id} + assert per_id["tsk-1"].queuing_delay_seconds == 0.0 + assert per_id["tsk-2"].queuing_delay_seconds == 0.0 + assert per_id["tsk-3"].queuing_delay_seconds == 1.0 + assert per_id["tsk-3"].blocking_parent_data_id == "tsk-2" + assert per_id["tsk-3"].duration_seconds == 1.0 + + +def test_analyze_handles_empty_spans() -> None: + summary = analyze([], _assets(), _lineage()) + assert summary.event_count == 0 + assert summary.data_ids == [] + assert summary.per_data_id == [] + assert summary.e2e_breakdown.workflow_duration_seconds == 0.0 + assert summary.critical_path is None diff --git a/tests/server/test_traces_router.py b/tests/server/test_traces_router.py new file mode 100644 index 0000000..291fe28 --- /dev/null +++ b/tests/server/test_traces_router.py @@ -0,0 +1,278 @@ +"""Tests for the workflow traces router.""" + +import json +from io import BytesIO +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from fastapi import HTTPException, UploadFile +from fastapi.responses import StreamingResponse + +from server.routers.v1 import traces as traces_router + + +def _otel_span( + name: str, + *, + data_id: str, + start: str, + end: str, + span_type: str, + span_id: str = "0xa3f1e9d2c5b40678", + parent_id: str | None = None, + batch_id: str | None = None, +) -> dict[str, Any]: + attributes: dict[str, Any] = {"data_id": data_id, "flowmesh.type": span_type} + if batch_id: + attributes["batch_id"] = batch_id + return { + "name": name, + "context": { + "trace_id": "0xfbad6be5c4434181a2d394eac830dea1", + "span_id": span_id, + }, + "parent_id": parent_id, + "start_time": start, + "end_time": end, + "status": {"status_code": "OK"}, + "attributes": attributes, + } + + +def _seed_task_logs( + base: Path, + task_id: str, + spans: list[dict[str, Any]] | None = None, + assets: list[dict[str, Any]] | None = None, + lineage: list[dict[str, Any]] | None = None, +) -> None: + logs_dir = base / task_id / "logs" + logs_dir.mkdir(parents=True) + if spans is not None: + (logs_dir / "spans.jsonl").write_text( + "\n".join(json.dumps(row) for row in spans) + "\n" + ) + if assets is not None: + (logs_dir / "assets.jsonl").write_text( + "\n".join(json.dumps(row) for row in assets) + "\n" + ) + if lineage is not None: + (logs_dir / "lineage.jsonl").write_text( + "\n".join(json.dumps(row) for row in lineage) + "\n" + ) + + +def _registry(task_ids: list[str]): + workflow = type("WF", (), {"task_ids": task_ids})() + registry = AsyncMock() + registry.get_workflow_async.return_value = workflow + return registry + + +async def _collect_streamed_lines(response: StreamingResponse) -> list[str]: + chunks: list[bytes] = [] + async for chunk in response.body_iterator: + if isinstance(chunk, bytes): + chunks.append(chunk) + elif isinstance(chunk, memoryview): + chunks.append(bytes(chunk)) + else: + chunks.append(chunk.encode("utf-8")) + text = b"".join(chunks).decode("utf-8") + return [line for line in text.split("\n") if line] + + +@pytest.mark.anyio +async def test_get_workflow_trace_concats_across_tasks(tmp_path: Path) -> None: + _seed_task_logs( + tmp_path, + "tsk-a", + spans=[ + _otel_span( + "write", + data_id="tsk-a", + start="2026-04-29T00:00:00+00:00", + end="2026-04-29T00:00:01+00:00", + span_type="network", + span_id="0xaaaa000000000001", + ) + ], + assets=[{"data_id": "tsk-a", "asset_guid": "g-a", "version": 1}], + ) + _seed_task_logs( + tmp_path, + "tsk-b", + spans=[ + _otel_span( + "read", + data_id="tsk-a", + start="2026-04-29T00:00:01+00:00", + end="2026-04-29T00:00:02+00:00", + span_type="network", + span_id="0xbbbb000000000001", + ) + ], + lineage=[{"data_id": "tsk-b", "source_data_id": "tsk-a"}], + ) + + response = await traces_router.get_workflow_trace( + workflow_id="wfl-1", + trace_type="spans", + registry=_registry(["tsk-a", "tsk-b"]), + results_dir=tmp_path, + ) + lines = await _collect_streamed_lines(response) + parsed = [json.loads(line) for line in lines] + assert [row["name"] for row in parsed] == ["write", "read"] + + +@pytest.mark.anyio +async def test_get_workflow_trace_skips_missing_files(tmp_path: Path) -> None: + _seed_task_logs( + tmp_path, + "tsk-a", + assets=[{"data_id": "tsk-a", "asset_guid": "g-a", "version": 1}], + ) + response = await traces_router.get_workflow_trace( + workflow_id="wfl-1", + trace_type="assets", + registry=_registry(["tsk-a", "tsk-b-missing"]), + results_dir=tmp_path, + ) + parsed = [json.loads(line) for line in await _collect_streamed_lines(response)] + assert len(parsed) == 1 + assert parsed[0]["data_id"] == "tsk-a" + + +@pytest.mark.anyio +async def test_get_workflow_trace_unknown_type(tmp_path: Path) -> None: + with pytest.raises(HTTPException) as excinfo: + await traces_router.get_workflow_trace( + workflow_id="wfl-1", + trace_type="bogus", + registry=_registry(["tsk-a"]), + results_dir=tmp_path, + ) + assert excinfo.value.status_code == 400 + + +@pytest.mark.anyio +async def test_analyze_workflow_trace_runs_analyzer(tmp_path: Path) -> None: + _seed_task_logs( + tmp_path, + "tsk-a", + spans=[ + _otel_span( + "task", + data_id="tsk-a", + start="2026-04-29T00:00:00+00:00", + end="2026-04-29T00:00:01+00:00", + span_type="compute", + span_id="0xaaaa000000000001", + ), + _otel_span( + "write", + data_id="tsk-a", + start="2026-04-29T00:00:00.500000+00:00", + end="2026-04-29T00:00:01+00:00", + span_type="network", + parent_id="0xaaaa000000000001", + span_id="0xaaaa000000000002", + ), + _otel_span( + "dump to storage", + data_id="tsk-a", + start="2026-04-29T00:00:01+00:00", + end="2026-04-29T00:00:01+00:00", + span_type="marker", + parent_id="0xaaaa000000000001", + span_id="0xaaaa000000000003", + ), + ], + assets=[ + { + "data_id": "tsk-a", + "asset_guid": "g-a", + "version": 1, + "user_id": "alice", + "created_at": "2026-04-29T00:00:00+00:00", + } + ], + ) + + summary = await traces_router.analyze_workflow_trace( + workflow_id="wfl-1", + registry=_registry(["tsk-a"]), + results_dir=tmp_path, + ) + assert summary.event_count == 3 + assert len(summary.assets) == 1 + assert summary.workflow_id == "wfl-1" + assert summary.critical_path is not None + assert summary.critical_path.path == ["tsk-a"] + assert summary.assets[0].asset_guid == "g-a" + assert "write" in summary.e2e_breakdown.network_summary.event_type + + +@pytest.mark.anyio +async def test_workflow_not_found_raises_404(tmp_path: Path) -> None: + registry = AsyncMock() + registry.get_workflow_async.return_value = None + + with pytest.raises(HTTPException) as excinfo: + await traces_router.get_workflow_trace( + workflow_id="wfl-missing", + trace_type="spans", + registry=registry, + results_dir=tmp_path, + ) + assert excinfo.value.status_code == 404 + + +def _upload(content: bytes, filename: str = "ignored.jsonl") -> UploadFile: + return UploadFile(file=BytesIO(content), filename=filename) + + +@pytest.mark.anyio +@pytest.mark.parametrize("trace_type", ["spans", "assets", "lineage"]) +async def test_upload_task_trace_writes_named_file( + tmp_path: Path, trace_type: str +) -> None: + payload = b'{"name":"task"}\n' + response = await traces_router.upload_task_trace( + task_id="tsk-up", + trace_type=trace_type, + file=_upload(payload), + results_dir=tmp_path, + ) + target = tmp_path / "tsk-up" / "logs" / f"{trace_type}.jsonl" + assert target.is_file() + assert target.read_bytes() == payload + assert response.path == target.as_posix() + + +@pytest.mark.anyio +async def test_upload_task_trace_ignores_client_filename(tmp_path: Path) -> None: + """Server filename comes from {trace_type}, never from the multipart name.""" + await traces_router.upload_task_trace( + task_id="tsk-x", + trace_type="spans", + file=_upload(b"row\n", filename="../../escape.jsonl"), + results_dir=tmp_path, + ) + assert (tmp_path / "tsk-x" / "logs" / "spans.jsonl").is_file() + assert not (tmp_path / "escape.jsonl").exists() + + +@pytest.mark.anyio +async def test_upload_task_trace_unknown_type_400(tmp_path: Path) -> None: + with pytest.raises(HTTPException) as excinfo: + await traces_router.upload_task_trace( + task_id="tsk-x", + trace_type="bogus", + file=_upload(b""), + results_dir=tmp_path, + ) + assert excinfo.value.status_code == 400 diff --git a/tests/worker/test_data_mixin_lineage.py b/tests/worker/test_data_mixin_lineage.py new file mode 100644 index 0000000..9a92dc3 --- /dev/null +++ b/tests/worker/test_data_mixin_lineage.py @@ -0,0 +1,130 @@ +"""DataMixin tests: span emission + asset/lineage row JSONL writes.""" + +import json +from pathlib import Path +from typing import Any + +from worker.executors.mixins.data import DataMixin + + +class _Mixin(DataMixin): + """Bare-bones DataMixin instance for unit testing.""" + + +def _read_jsonl(path: Path) -> list[dict[str, Any]]: + return [json.loads(line) for line in path.read_text().splitlines() if line.strip()] + + +def _spans_for_task(out_dir: Path) -> list[dict[str, Any]]: + return _read_jsonl(out_dir / "logs" / "spans.jsonl") + + +def test_task_span_emits_root_with_compute_kind(tmp_path: Path) -> None: + mixin = _Mixin() + + out_dir = tmp_path / "task" + with mixin._task_span("tsk-1", "wfl-fbad6be5c4434181a2d394eac830dea1", out_dir): + mixin._log_event("queuing for execution", data_id="tsk-1") + + spans = _spans_for_task(out_dir) + names = [s["name"] for s in spans] + assert "task" in names + assert "queuing for execution" in names + task_row = next(s for s in spans if s["name"] == "task") + assert task_row["attributes"]["data_id"] == "tsk-1" + assert task_row["attributes"]["flowmesh.type"] == "compute" + assert {s["context"]["trace_id"] for s in spans} == { + "0xfbad6be5c4434181a2d394eac830dea1" + } + + +def test_record_asset_and_lineage(tmp_path: Path) -> None: + mixin = _Mixin() + out_dir = tmp_path / "task" + with mixin._task_span("tsk-1", "wfl-1", out_dir): + mixin._record_asset( + data_id="tsk-1", asset_guid="g-1", version=1, user_id="alice" + ) + mixin._record_lineage("tsk-1", ["upstream-a", "upstream-b"]) + + base = out_dir / "logs" + assets = _read_jsonl(base / "assets.jsonl") + assert len(assets) == 1 + assert assets[0]["asset_guid"] == "g-1" + assert assets[0]["user_id"] == "alice" + + lineage = _read_jsonl(base / "lineage.jsonl") + assert len(lineage) == 2 + assert {row["source_data_id"] for row in lineage} == { + "upstream-a", + "upstream-b", + } + + +def test_record_output_emits_dump_span_and_rows(tmp_path: Path) -> None: + mixin = _Mixin() + out_dir = tmp_path / "task-up" + with mixin._task_span("tsk-up", "wfl-1", out_dir, owner_id="alice"): + mixin._record_output( + data_id="tsk-up", + data={"items": [{"output": "ok"}]}, + source_data_ids=["tsk-source-a"], + ) + + base = out_dir / "logs" + assets = _read_jsonl(base / "assets.jsonl") + assert assets and assets[0]["data_id"] == "tsk-up" + assert assets[0]["user_id"] == "alice" + + lineage = _read_jsonl(base / "lineage.jsonl") + assert len(lineage) == 1 + assert lineage[0]["data_id"] == "tsk-up" + assert lineage[0]["source_data_id"] == "tsk-source-a" + + spans = _spans_for_task(out_dir) + dump = [s for s in spans if s["name"] == "dump to storage"] + assert dump + assert dump[0]["attributes"].get("data_id") == "tsk-up" + assert dump[0]["attributes"]["flowmesh.type"] == "network" + assert dump[0]["attributes"].get("payload_bytes", 0) > 0 + + +def test_dump_to_governance_with_merged_children(tmp_path: Path) -> None: + mixin = _Mixin() + out_dir = tmp_path / "task" + with mixin._task_span("tsk-parent", "wfl-1", out_dir, owner_id="alice"): + result = { + "ok": True, + "items": [{"output": "p"}], + "children": { + "tsk-c1": {"items": [{"output": "c1"}]}, + "tsk-c2": {"items": [{"output": "c2"}]}, + }, + } + deps = { + "tsk-parent": ["tsk-up-a"], + "tsk-c1": ["tsk-up-b"], + "tsk-c2": ["tsk-up-c"], + } + mixin._dump_to_governance( + task_id="tsk-parent", + result=result, + dependencies_by_task=deps, + ) + + base = out_dir / "logs" + assets = _read_jsonl(base / "assets.jsonl") + assert {row["data_id"] for row in assets} == { + "tsk-parent", + "tsk-c1", + "tsk-c2", + } + assert all(row["user_id"] == "alice" for row in assets) + + lineage = _read_jsonl(base / "lineage.jsonl") + edges = {(row["data_id"], row["source_data_id"]) for row in lineage} + assert edges == { + ("tsk-parent", "tsk-up-a"), + ("tsk-c1", "tsk-up-b"), + ("tsk-c2", "tsk-up-c"), + } diff --git a/uv.lock b/uv.lock index a1c678c..88931b0 100644 --- a/uv.lock +++ b/uv.lock @@ -1910,6 +1910,7 @@ source = { editable = "cli" } dependencies = [ { name = "flowmesh-sdk" }, { name = "pyyaml" }, + { name = "rich" }, { name = "typer" }, { name = "websockets" }, ] @@ -1924,6 +1925,7 @@ requires-dist = [ { name = "flowmesh-cli-stack", marker = "extra == 'stack'", editable = "cli/stack" }, { name = "flowmesh-sdk", editable = "sdk" }, { name = "pyyaml", specifier = ">=6.0.2" }, + { name = "rich", specifier = ">=14.2.0" }, { name = "typer", specifier = ">=0.12.5" }, { name = "websockets", specifier = ">=15.0" }, ] @@ -1954,6 +1956,7 @@ version = "0.1.0" source = { editable = "sdk" } dependencies = [ { name = "httpx" }, + { name = "pandas" }, { name = "pydantic" }, { name = "pyyaml" }, ] @@ -1961,6 +1964,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "httpx", specifier = ">=0.27.0" }, + { name = "pandas", specifier = ">=2.3.3" }, { name = "pydantic", specifier = ">=2.0.0" }, { name = "pyyaml", specifier = ">=6.0.0" }, ]