In [8]:
# dagflow_minimal_auto_daft.py
from __future__ import annotations

import inspect
import re
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, get_type_hints

from pydantic import BaseModel

# --------------------------
# Small domain (schemas + ops)
# --------------------------


def normalize(text: str) -> set[str]:
    toks = re.findall(r"[a-zA-Z]+", text.lower())
    return {t[:-1] if t.endswith("s") else t for t in toks}


class Query(BaseModel):
    query_uuid: str
    text: str


class RetrievalHit(BaseModel):
    doc_id: str
    score: float


class RetrievalResult(BaseModel):
    query_uuid: str
    hits: List[RetrievalHit]


class RerankedHit(BaseModel):
    query_uuid: str
    doc_id: str
    score: float


class Retriever(Protocol):
    def retrieve(self, query: Query, top_k: int) -> RetrievalResult: ...


class Reranker(Protocol):
    def rerank(
        self, query: Query, hits: RetrievalResult, top_k: int
    ) -> List[RerankedHit]: ...


class ToyRetriever(Retriever):
    def __init__(self, corpus: Dict[str, str]):
        self._doc_tokens = {doc_id: normalize(txt) for doc_id, txt in corpus.items()}

    def retrieve(self, query: Query, top_k: int) -> RetrievalResult:
        q = normalize(query.text)
        scored = [
            RetrievalHit(doc_id=d, score=float(len(q & toks)))
            for d, toks in self._doc_tokens.items()
        ]
        scored.sort(key=lambda h: h.score, reverse=True)
        return RetrievalResult(query_uuid=query.query_uuid, hits=scored[:top_k])


class IdentityReranker(Reranker):
    def rerank(
        self, query: Query, hits: RetrievalResult, top_k: int
    ) -> List[RerankedHit]:
        return [
            RerankedHit(query_uuid=query.query_uuid, doc_id=h.doc_id, score=h.score)
            for h in hits.hits[:top_k]
        ]


# --------------------------
# Generic DAG decorator + registry
# --------------------------


@dataclass(frozen=True)
class NodeMeta:
    output_name: str
    # For auto-batching: which parameter is the per-item "map axis"?
    map_axis: Optional[str] = None  # e.g., "query"
    # If nodes align by a key (same item across nodes), what attribute on the map_axis carries it?
    key_attr: Optional[str] = None  # e.g., "query_uuid"


@dataclass(frozen=True)
class NodeDef:
    fn: Callable
    meta: NodeMeta
    params: Tuple[str, ...]  # ordered parameter names (from signature)


class DagRegistry:
    def __init__(self):
        self.nodes: List[NodeDef] = []
        self.by_output: Dict[str, NodeDef] = {}

    def add(self, fn: Callable, meta: NodeMeta):
        sig = inspect.signature(fn)
        params = tuple(sig.parameters.keys())
        node = NodeDef(fn=fn, meta=meta, params=params)
        self.nodes.append(node)
        self.by_output[meta.output_name] = node

    def topo(self, initial_inputs: Dict[str, Any]) -> List[NodeDef]:
        """Very small topo-sort by parameter availability."""
        available = set(initial_inputs.keys())
        ordered: List[NodeDef] = []
        remaining = set(self.nodes)
        while remaining:
            progress = False
            for node in list(remaining):
                needed = {p for p in node.params if p not in ("self",)}
                if needed.issubset(available):
                    ordered.append(node)
                    available.add(node.meta.output_name)
                    remaining.remove(node)
                    progress = True
            if not progress:
                raise RuntimeError(
                    f"Cannot resolve dependencies; remaining: {[n.fn.__name__ for n in remaining]}"
                )
        return ordered


_REG = DagRegistry()


def dagflow(
    *, output: str, map_axis: Optional[str] = None, key_attr: Optional[str] = None
):
    """
    Register a DAG node.
    - output: name of the produced value (binds it into the DAG namespace)
    - map_axis: name of the parameter that carries the per-item object (for multi-input runs)
    - key_attr: attribute on the map_axis object that uniquely identifies items (alignment)
    """
    meta = NodeMeta(output_name=output, map_axis=map_axis, key_attr=key_attr)

    def deco(fn: Callable):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

        wrapper._dagflow_meta = meta
        _REG.add(wrapper, meta)
        return wrapper

    return deco


# --------------------------
# DAG nodes (generic; tiny)
# --------------------------


@dagflow(output="hits", map_axis="query", key_attr="query_uuid")
def retrieve(retriever: Retriever, query: Query, top_k: int) -> RetrievalResult:
    return retriever.retrieve(query, top_k=top_k)


@dagflow(output="reranked_hits", map_axis="query", key_attr="query_uuid")
def rerank(
    reranker: Reranker, query: Query, hits: RetrievalResult, top_k: int
) -> List[RerankedHit]:
    return reranker.rerank(query, hits, top_k=top_k)


# --------------------------
# Generic Runner
# --------------------------


class Runner:
    """
    - Single item: executes locally (pure Python), returns [[RerankedHit]] for uniformity.
    - Multi items: automatically uses Daft for vectorized execution (falls back to Python if Daft not available).
    """

    def __init__(self, mode: str = "auto", batch_threshold: int = 2):
        self.mode = mode
        self.batch_threshold = batch_threshold

    def run(self, *, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the DAG given initial inputs.
        The only special key is the map_axis object (e.g., "query"); it may be a single item or a list of items.
        """
        # Determine if we are batching based on any node's map_axis param presence + list input
        map_axes = {n.meta.map_axis for n in _REG.nodes if n.meta.map_axis}
        if len(map_axes) > 1:
            raise ValueError(
                f"This minimal runner supports one map axis; found: {map_axes}"
            )
        map_axis = next(iter(map_axes)) if map_axes else None

        batching_requested = False
        if map_axis and map_axis in inputs and isinstance(inputs[map_axis], list):
            batching_requested = True

        if self.mode == "local":
            batching = False
        elif self.mode == "daft":
            batching = True
        else:
            batching = (
                batching_requested and len(inputs[map_axis]) >= self.batch_threshold
                if map_axis
                else False
            )

        return (
            self._run_batch(inputs, map_axis) if batching else self._run_single(inputs)
        )

    # ---- single item path (pure Python) ----
    def _run_single(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        outputs = dict(inputs)
        order = _REG.topo(inputs)
        for node in order:
            kwargs = {p: outputs[p] for p in node.params if p in outputs}
            res = node.fn(**kwargs)
            outputs[node.meta.output_name] = res
        # normalize to batched shape if final output is present
        if "reranked_hits" in outputs and isinstance(outputs["reranked_hits"], list):
            outputs["reranked_hits"] = [outputs["reranked_hits"]]
        return outputs

    # ---- multi item path (Daft if available; else Python loop) ----
    def _run_batch(
        self, inputs: Dict[str, Any], map_axis: Optional[str]
    ) -> Dict[str, Any]:
        # Safety: require a single map axis for this minimal demo
        assert map_axis, "No map axis configured but batching was requested."

        items: List[BaseModel] = inputs[map_axis]  # e.g., List[Query]
        constants = {k: v for k, v in inputs.items() if k != map_axis}

        try:
            import daft
        except Exception:
            # Fallback to Python loop using the same single-run logic per item
            aggregated: List[Dict[str, Any]] = []
            for it in items:
                per_inputs = {**constants, map_axis: it}
                per_out = self._run_single(per_inputs)
                aggregated.append(per_out)
            # Merge structure: final outputs to lists
            merged: Dict[str, Any] = dict(constants)
            # For final output example, ensure list-of-lists
            merged["reranked_hits"] = [o["reranked_hits"][0] for o in aggregated]
            return merged

        # Build initial Daft DF with one column for the map_axis (dictified Pydantic)
        df = daft.from_pylist([{map_axis: it.model_dump()} for it in items])

        order = _REG.topo(
            {**constants, map_axis: items[0]}
        )  # topo over a "sample" shape

        # We'll iteratively add columns to df for each node's output.
        for node in order:
            # Collect columns for this node call: mapped param -> series; constants -> literals; deps -> series
            arg_names = node.params

            # Build a list of expressions to feed into the batch UDF in the right order
            series_args: List[Any] = []
            serializers: List[Callable[[Any], Dict]] = []
            deserializers: List[Callable[[Dict], Any]] = []

            # Figure out types from hints (for Pydantic reconstruction)
            hints = get_type_hints(node.fn)

            def _mk_deser(py_type):
                # Minimal: if subclass of BaseModel, rehydrate; else, passthrough
                try:
                    if isinstance(py_type, type) and issubclass(py_type, BaseModel):
                        return lambda d: py_type.model_validate(d)
                except Exception:
                    pass
                return lambda x: x

            for name in arg_names:
                if name == node.meta.map_axis:
                    series_args.append(df[map_axis])
                    deserializers.append(_mk_deser(hints.get(name, Any)))
                    serializers.append(
                        lambda x: x.model_dump() if isinstance(x, BaseModel) else x
                    )
                elif name in constants:
                    # constants captured via closure (not as series)
                    series_args.append(None)
                    deserializers.append(lambda x: x)  # not used
                    serializers.append(lambda x: x)  # not used
                elif name in df.column_names:
                    series_args.append(df[name])
                    deserializers.append(_mk_deser(hints.get(name, Any)))
                    serializers.append(
                        lambda x: x.model_dump() if isinstance(x, BaseModel) else x
                    )
                else:
                    raise RuntimeError(
                        f"Parameter '{name}' not found among inputs/columns for node {node.fn.__name__}"
                    )

            # Define a batch UDF that rebuilds per-row kwargs, calls node.fn, returns per-row result (dict)
            @daft.func.batch(return_dtype=daft.DataType.python())
            def _apply_batch(*cols):
                # cols only includes series (constants were filtered out)
                # Convert Series to pylist per column (Daft passes daft.Series)
                ser_lists: List[List[Any]] = []
                for c in cols:
                    ser_lists.append(c.to_pylist())

                out_list: List[Any] = []
                rows = len(ser_lists[0]) if ser_lists else len(items)
                for i in range(rows):
                    kwargs: Dict[str, Any] = {}
                    arg_idx = 0
                    for idx, name in enumerate(arg_names):
                        if name == node.meta.map_axis:
                            raw = ser_lists[arg_idx][i]
                            kwargs[name] = deserializers[idx](raw)
                            arg_idx += 1
                        elif name in constants:
                            kwargs[name] = constants[name]
                            # Don't increment arg_idx for constants
                        else:
                            # This is a dependent column from df
                            raw = ser_lists[arg_idx][i]
                            kwargs[name] = deserializers[idx](raw)
                            arg_idx += 1
                    # call the original node function
                    res = node.fn(**kwargs)
                    # store as dict (Pydantic -> dict; list[Pydantic] -> list[dict])
                    if isinstance(res, BaseModel):
                        out_list.append(res.model_dump())
                    elif (
                        isinstance(res, list) and res and isinstance(res[0], BaseModel)
                    ):
                        out_list.append([r.model_dump() for r in res])
                    else:
                        out_list.append(res)
                return out_list

            # Build a new DataFrame with the new column appended
            # We must pass only the Series columns to _apply_batch
            call_series = [c for c in series_args if c is not None]
            new_col_expr = _apply_batch(*call_series)
            # Keep existing columns + add the new one
            keep_cols = [df[c].alias(c) for c in df.column_names]
            df = df.select(*keep_cols, new_col_expr.alias(node.meta.output_name))

        # Finalize: collect wanted outputs
        out_py = df.to_pylist()  # list of dicts per row, includes all columns
        merged: Dict[str, Any] = dict(constants)
        # normalize final output to List[List[RerankedHit]]
        final_name = order[-1].meta.output_name
        merged[final_name] = [
            [RerankedHit.model_validate(d) for d in row[final_name]] for row in out_py
        ]
        return merged


# --------------------------
# Demo
# --------------------------

if __name__ == "__main__":
    corpus = {
        "d1": "a quick brown fox jumps",
        "d2": "brown dog sleeps",
        "d3": "five boxing wizards jump quickly",
    }
    retriever = ToyRetriever(corpus)
    reranker = IdentityReranker()
    runner = Runner(mode="auto", batch_threshold=2)

    # Single input (local)
    single_inputs = {
        "retriever": retriever,
        "reranker": reranker,
        "query": Query(query_uuid="q1", text="quick brown"),
        "top_k": 2,  # shared by both nodes
    }
    single_out = runner.run(inputs=single_inputs)
    print("SINGLE (uniform batched shape):", single_out["reranked_hits"])

    # Multi input (auto → Daft if available, else Python fallback)
    multi_inputs = {
        "retriever": retriever,
        "reranker": reranker,
        "query": [
            Query(query_uuid="q1", text="quick brown"),
            Query(query_uuid="q2", text="wizards jump"),
            Query(query_uuid="q3", text="brown dog"),
        ],
        "top_k": 2,
    }
    multi_out = runner.run(inputs=multi_inputs)
    print("MULTI:", multi_out["reranked_hits"])


SINGLE (uniform batched shape): [[RerankedHit(query_uuid='q1', doc_id='d1', score=2.0), RerankedHit(query_uuid='q1', doc_id='d2', score=1.0)]]
                                                      d
                                                      d
[A
[A
[A
[A
[A

Error when running pipeline node UDF _apply_batch-05e1bacb-3579-49c0-ac0c-cd769e92815a


IndexError: list index out of range

In [9]:
# Debug: Let's trace the issue
corpus = {
    "d1": "a quick brown fox jumps",
    "d2": "brown dog sleeps",
    "d3": "five boxing wizards jump quickly",
}
retriever = ToyRetriever(corpus)
reranker = IdentityReranker()

# Check what parameters each node has
for node in _REG.nodes:
    print(f"Node: {node.fn.__name__}")
    print(f"  Output: {node.meta.output_name}")
    print(f"  Map axis: {node.meta.map_axis}")
    print(f"  Params: {node.params}")
    print()

Node: retrieve
  Output: hits
  Map axis: query
  Params: ('retriever', 'query', 'top_k')

Node: rerank
  Output: reranked_hits
  Map axis: query
  Params: ('reranker', 'query', 'hits', 'top_k')



In [10]:
# More detailed debugging - let's manually trace what should happen
import daft

# Simulate first node (retrieve)
items_test = [
    Query(query_uuid="q1", text="quick brown"),
    Query(query_uuid="q2", text="wizards jump"),
]
df_test = daft.from_pylist([{"query": it.model_dump()} for it in items_test])
print("Initial df columns:", df_test.column_names)
print()

# After first node, we'd have query + hits columns
# Let's check indexing for second node (rerank)
node = _REG.by_output["reranked_hits"]
constants_test = {"retriever": retriever, "reranker": reranker, "top_k": 2}

print(f"Node {node.fn.__name__} params: {node.params}")
print(f"Map axis: {node.meta.map_axis}")
print()

# Simulate what series_args would look like
for idx, name in enumerate(node.params):
    if name == node.meta.map_axis:
        print(f"{idx}. {name} -> map_axis (SERIES)")
    elif name in constants_test:
        print(f"{idx}. {name} -> constant (NO SERIES)")
    else:
        print(f"{idx}. {name} -> dependent column (SERIES)")

print()
print("So call_series should have 2 items: [query_series, hits_series]")
print("And when iterating with enumerate:")
print("  idx=0, name='reranker' -> constant, don't use arg_idx")
print("  idx=1, name='query' -> map_axis, use arg_idx=0, then increment to 1")
print("  idx=2, name='hits' -> dependent, use arg_idx=1, then increment to 2")
print("  idx=3, name='top_k' -> constant, don't use arg_idx")
print()
print("But arg_idx=2 is out of range for a 2-item list!")

Initial df columns: ['query']

Node rerank params: ('reranker', 'query', 'hits', 'top_k')
Map axis: query

0. reranker -> constant (NO SERIES)
1. query -> map_axis (SERIES)
2. hits -> dependent column (SERIES)
3. top_k -> constant (NO SERIES)

So call_series should have 2 items: [query_series, hits_series]
And when iterating with enumerate:
  idx=0, name='reranker' -> constant, don't use arg_idx
  idx=1, name='query' -> map_axis, use arg_idx=0, then increment to 1
  idx=2, name='hits' -> dependent, use arg_idx=1, then increment to 2
  idx=3, name='top_k' -> constant, don't use arg_idx

But arg_idx=2 is out of range for a 2-item list!
