# DAG Visualization with daft-func and anywidget

This notebook demonstrates how to build a daft-func pipeline and visualize it as an interactive DAG using anywidget.

## 1. Build a Simple Pipeline with daft-func

In [1]:
from daft_func import Pipeline, func
from typing import Dict, List
from pydantic import BaseModel

In [2]:
# Define data models
class Query(BaseModel):
    id: str
    text: str


class Document(BaseModel):
    id: str
    content: str
    score: float = 0.0


class Result(BaseModel):
    query_id: str
    documents: List[Document]
    final_score: float

In [3]:
# Define pipeline functions
@func(output="corpus_indexed", cache=True)
def index_corpus(corpus: Dict[str, str]) -> bool:
    """Index the document corpus."""
    print(f"Indexing {len(corpus)} documents...")
    return True


@func(output="embeddings", cache=True)
def compute_embeddings(
    corpus: Dict[str, str], corpus_indexed: bool
) -> Dict[str, List[float]]:
    """Compute embeddings for documents."""
    print("Computing embeddings...")
    return {doc_id: [0.1, 0.2, 0.3] for doc_id in corpus.keys()}


@func(output="query_embedding", map_axis="query", key_attr="id", cache=True)
def embed_query(query: Query) -> List[float]:
    """Embed a query."""
    return [0.15, 0.25, 0.35]


@func(output="retrieved_docs", map_axis="query", key_attr="id")
def retrieve(
    query: Query,
    embeddings: Dict[str, List[float]],
    query_embedding: List[float],
    top_k: int,
) -> List[Document]:
    """Retrieve top-k documents for a query."""
    docs = [
        Document(id=doc_id, content=f"Content {doc_id}", score=0.8)
        for doc_id in list(embeddings.keys())[:top_k]
    ]
    return docs


@func(output="reranked_docs", map_axis="query", key_attr="id")
def rerank(query: Query, retrieved_docs: List[Document]) -> List[Document]:
    """Rerank retrieved documents."""
    for doc in retrieved_docs:
        doc.score *= 1.2
    return sorted(retrieved_docs, key=lambda d: d.score, reverse=True)


@func(output="final_result", map_axis="query", key_attr="id")
def generate_result(query: Query, reranked_docs: List[Document]) -> Result:
    """Generate final result."""
    avg_score = (
        sum(d.score for d in reranked_docs) / len(reranked_docs)
        if reranked_docs
        else 0.0
    )
    return Result(query_id=query.id, documents=reranked_docs, final_score=avg_score)

In [4]:
# Create the pipeline
pipeline = Pipeline(
    functions=[
        index_corpus,
        compute_embeddings,
        embed_query,
        retrieve,
        rerank,
        generate_result,
    ]
)

## 2. Extract DAG Structure from Pipeline

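Next, we extract the nodes and edges from the daft-func pipeline so they can be visualized. Each function becomes a node keyed by its output name, and an edge is added whenever one of a function's parameters matches another function's output.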

In [5]:
def extract_dag_structure(pipeline: Pipeline):
    """Extract nodes and edges from a daft-func pipeline."""
    nodes = []
    edges = []

    for node_def in pipeline.nodes:
        # Create node
        node_id = node_def.meta.output_name
        label = node_def.fn.__name__

        # Add metadata
        metadata = {
            "cached": node_def.meta.cache,
            "map_axis": node_def.meta.map_axis,
        }

        nodes.append({"id": node_id, "label": label, "metadata": metadata})

        # Create edges based on dependencies
        for param in node_def.params:
            if param in pipeline.by_output:
                # This parameter is an output from another node
                edges.append({"from": param, "to": node_id})

    return {"nodes": nodes, "edges": edges}

In [6]:
# Extract the DAG
dag_data = extract_dag_structure(pipeline)
print(f"Nodes: {len(dag_data['nodes'])}")
print(f"Edges: {len(dag_data['edges'])}")
print("\nNodes:", [n["id"] for n in dag_data["nodes"]])
print("\nEdges:", [(e["from"], e["to"]) for e in dag_data["edges"]])

Nodes: 6
Edges: 5

Nodes: ['corpus_indexed', 'embeddings', 'query_embedding', 'retrieved_docs', 'reranked_docs', 'final_result']

Edges: [('corpus_indexed', 'embeddings'), ('embeddings', 'retrieved_docs'), ('query_embedding', 'retrieved_docs'), ('retrieved_docs', 'reranked_docs'), ('reranked_docs', 'final_result')]
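
As a quick sanity check on the extracted structure, the edge list can be grouped into a predecessor map. The sketch below is self-contained: it hard-codes the node ids and edges printed above instead of calling daft-func, and `predecessors` is a hypothetical helper name, not part of the library.

```python
from collections import defaultdict

# Node ids and edges copied from the output above (not fetched from daft-func).
node_ids = [
    "corpus_indexed", "embeddings", "query_embedding",
    "retrieved_docs", "reranked_docs", "final_result",
]
edge_pairs = [
    ("corpus_indexed", "embeddings"),
    ("embeddings", "retrieved_docs"),
    ("query_embedding", "retrieved_docs"),
    ("retrieved_docs", "reranked_docs"),
    ("reranked_docs", "final_result"),
]


def predecessors(node_ids, edge_pairs):
    """Map each node id to the sorted list of node ids it depends on."""
    preds = defaultdict(list)
    for src, dst in edge_pairs:
        preds[dst].append(src)
    return {node_id: sorted(preds[node_id]) for node_id in node_ids}


preds = predecessors(node_ids, edge_pairs)
# Nodes without predecessors depend only on external inputs.
roots = [node_id for node_id, deps in preds.items() if not deps]
print(roots)
```

Nodes with no predecessors (`corpus_indexed` and `query_embedding` here) depend only on external inputs such as `corpus` and `query`, which is why those inputs do not appear as nodes in the graph.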


## 3. Create Interactive DAG Visualization Widget

We'll use anywidget to create an interactive DAG visualization with:
- Automatic layered layout (topological sorting)
- Color-coded nodes (cached vs non-cached, mapped vs single)
- SVG-based rendering with smooth animations
- Interactive hover effects
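
The color coding follows a fixed precedence, implemented in the widget's JavaScript `getNodeColor` function in the cell below. As a reading aid, here is a Python restatement of that precedence. `node_color` is a hypothetical helper written for this illustration; the widget itself does not call it.

```python
def node_color(cached: bool, map_axis, status: str = "idle") -> str:
    """Return the fill color, mirroring the precedence used by the widget."""
    # Execution status overrides the metadata-based color.
    if status == "running":
        return "#fbbf24"  # amber
    if status == "completed":
        return "#86efac"  # green
    # Otherwise the color depends on the node's metadata.
    if cached and map_axis:
        return "#a5f3fc"  # cyan: cached and mapped
    if cached:
        return "#93c5fd"  # blue: cached only
    if map_axis:
        return "#fde68a"  # yellow: mapped only
    return "#d1d5db"  # gray: regular


# `embed_query` is both cached and mapped over "query".
print(node_color(cached=True, map_axis="query"))
```

Note that a node that is both cached and mapped, such as `embed_query`, gets its own cyan fill rather than the blue or yellow shown in the legend.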

In [7]:
import asyncio  # used by the widget's simulation coroutine

import anywidget
import traitlets

In [None]:
class DAGVisualizationWidget(anywidget.AnyWidget):
    _esm = """
    function render({ model, el }) {
        el.innerHTML = `
            <style>
                .dag-container {
                    width: 100%;
                    background: #f9fafb;
                    border-radius: 8px;
                    border: 1px solid #e5e7eb;
                    padding: 20px;
                    overflow-x: auto;
                }
                .dag-controls {
                    margin-bottom: 15px;
                    display: flex;
                    gap: 10px;
                    align-items: center;
                }
                .run-button {
                    padding: 8px 16px;
                    background: #3b82f6;
                    color: white;
                    border: none;
                    border-radius: 6px;
                    font-weight: 600;
                    cursor: pointer;
                    transition: background 0.2s;
                }
                .run-button:hover {
                    background: #2563eb;
                }
                .run-button:disabled {
                    background: #9ca3af;
                    cursor: not-allowed;
                }
                .reset-button {
                    padding: 8px 16px;
                    background: #6b7280;
                    color: white;
                    border: none;
                    border-radius: 6px;
                    font-weight: 600;
                    cursor: pointer;
                    transition: background 0.2s;
                }
                .reset-button:hover {
                    background: #4b5563;
                }
                .status-text {
                    font-size: 14px;
                    color: #4b5563;
                    font-weight: 500;
                }
                .dag-svg {
                    display: block;
                    margin: 0 auto;
                }
                .node-group:hover .node-rect {
                    stroke-width: 3;
                    filter: brightness(1.1);
                }
                .node-rect {
                    transition: all 0.2s ease;
                    cursor: pointer;
                }
                .node-rect-running {
                    animation: pulse-node 1.5s ease-in-out infinite;
                }
                @keyframes pulse-node {
                    0%, 100% { opacity: 1; }
                    50% { opacity: 0.6; }
                }
                .node-text {
                    pointer-events: none;
                    user-select: none;
                }
                .edge-path {
                    fill: none;
                    stroke: #6b7280;
                    stroke-width: 2;
                    marker-end: url(#arrowhead);
                }
                .legend {
                    margin-top: 20px;
                    padding: 15px;
                    background: white;
                    border-radius: 6px;
                    border: 1px solid #e5e7eb;
                }
                .legend-title {
                    font-weight: 600;
                    margin-bottom: 10px;
                    color: #1f2937;
                }
                .legend-items {
                    display: flex;
                    gap: 20px;
                    flex-wrap: wrap;
                }
                .legend-item {
                    display: flex;
                    align-items: center;
                    gap: 8px;
                    font-size: 14px;
                    color: #4b5563;
                }
                .legend-box {
                    width: 30px;
                    height: 20px;
                    border-radius: 4px;
                    border: 2px solid #374151;
                }
            </style>
            <div class="dag-container">
                <div class="dag-controls">
                    <button class="run-button" id="run-button">▶ Run Simulation</button>
                    <button class="reset-button" id="reset-button">↻ Reset</button>
                    <span class="status-text" id="status-text"></span>
                </div>
                <svg class="dag-svg" id="dag-svg"></svg>
                <div class="legend">
                    <div class="legend-title">Legend</div>
                    <div class="legend-items">
                        <div class="legend-item">
                            <div class="legend-box" style="background: #93c5fd;"></div>
                            <span>Cached Node</span>
                        </div>
                        <div class="legend-item">
                            <div class="legend-box" style="background: #d1d5db;"></div>
                            <span>Regular Node</span>
                        </div>
                        <div class="legend-item">
                            <div class="legend-box" style="background: #fde68a; border-style: dashed;"></div>
                            <span>Mapped Node</span>
                        </div>
                        <div class="legend-item">
                            <div class="legend-box" style="background: #fbbf24;"></div>
                            <span>Running</span>
                        </div>
                        <div class="legend-item">
                            <div class="legend-box" style="background: #86efac;"></div>
                            <span>Completed</span>
                        </div>
                    </div>
                </div>
            </div>
        `;
        
        const runButton = el.querySelector('#run-button');
        const resetButton = el.querySelector('#reset-button');
        const statusText = el.querySelector('#status-text');
        
        // Handle run button click
        runButton.addEventListener('click', () => {
            model.set('trigger_run', model.get('trigger_run') + 1);
            model.save_changes();
        });
        
        // Handle reset button click
        resetButton.addEventListener('click', () => {
            model.set('node_status', {});
            model.save_changes();
            statusText.textContent = '';
        });
        
        function computeLayout(nodes, edges) {
            // Build adjacency list and in-degree map
            const adjList = {};
            const inDegree = {};
            const nodeMap = {};
            
            nodes.forEach(node => {
                nodeMap[node.id] = node;
                adjList[node.id] = [];
                inDegree[node.id] = 0;
            });
            
            edges.forEach(edge => {
                if (adjList[edge.from] && inDegree[edge.to] !== undefined) {
                    adjList[edge.from].push(edge.to);
                    inDegree[edge.to]++;
                }
            });
            
            // Topological sort to determine layers
            const layers = [];
            const queue = [];
            const nodeLayer = {};
            
            // Start with nodes that have no dependencies
            Object.keys(inDegree).forEach(nodeId => {
                if (inDegree[nodeId] === 0) {
                    queue.push(nodeId);
                    nodeLayer[nodeId] = 0;
                }
            });
            
            while (queue.length > 0) {
                const nodeId = queue.shift();
                const layer = nodeLayer[nodeId];
                
                if (!layers[layer]) layers[layer] = [];
                layers[layer].push(nodeId);
                
                adjList[nodeId].forEach(nextId => {
                    inDegree[nextId]--;
                    if (inDegree[nextId] === 0) {
                        queue.push(nextId);
                        nodeLayer[nextId] = layer + 1;
                    }
                });
            }
            
            // Compute positions
            const positions = {};
            const nodeWidth = 150;
            const nodeHeight = 60;
            const layerSpacing = 200;
            const nodeSpacing = 100;
            
            layers.forEach((layer, layerIdx) => {
                const layerHeight = layer.length * (nodeHeight + nodeSpacing);
                layer.forEach((nodeId, idx) => {
                    positions[nodeId] = {
                        x: layerIdx * layerSpacing,
                        y: idx * (nodeHeight + nodeSpacing) - layerHeight / 2 + nodeHeight / 2
                    };
                });
            });
            
            return { positions, nodeWidth, nodeHeight };
        }
        
        function getNodeColor(node, status) {
            // Status takes precedence
            if (status === 'running') {
                return '#fbbf24'; // amber for running
            } else if (status === 'completed') {
                return '#86efac'; // green for completed
            }
            
            // Otherwise use metadata-based colors
            if (node.metadata.cached && node.metadata.map_axis) {
                return '#a5f3fc'; // cyan for cached + mapped
            } else if (node.metadata.cached) {
                return '#93c5fd'; // blue for cached
            } else if (node.metadata.map_axis) {
                return '#fde68a'; // yellow for mapped
            }
            return '#d1d5db'; // gray for regular
        }
        
        function getBorderStyle(node) {
            return node.metadata.map_axis ? '4, 4' : '0';
        }
        
        function updateStatusText() {
            const nodeStatus = model.get('node_status') || {};
            const statusValues = Object.values(nodeStatus);
            const runningCount = statusValues.filter(s => s === 'running').length;
            const completedCount = statusValues.filter(s => s === 'completed').length;
            
            if (runningCount > 0) {
                statusText.textContent = `Running... (${completedCount} completed, ${runningCount} in progress)`;
                runButton.disabled = true;
            } else if (completedCount > 0) {
                statusText.textContent = `✓ Completed (${completedCount} nodes)`;
                runButton.disabled = false;
            } else {
                statusText.textContent = '';
                runButton.disabled = false;
            }
        }
        
        function renderDAG() {
            const data = model.get('dag_data');
            if (!data || !data.nodes || !data.edges) return;
            
            const { nodes, edges } = data;
            const nodeStatus = model.get('node_status') || {};
            const { positions, nodeWidth, nodeHeight } = computeLayout(nodes, edges);
            
            updateStatusText();
            
            // Calculate SVG dimensions
            const margin = 50;
            let minX = Infinity, maxX = -Infinity;
            let minY = Infinity, maxY = -Infinity;
            
            Object.values(positions).forEach(pos => {
                minX = Math.min(minX, pos.x);
                maxX = Math.max(maxX, pos.x);
                minY = Math.min(minY, pos.y);
                maxY = Math.max(maxY, pos.y);
            });
            
            const width = maxX - minX + nodeWidth + 2 * margin;
            const height = maxY - minY + nodeHeight + 2 * margin;
            const offsetX = -minX + margin;
            const offsetY = -minY + margin;
            
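            // Scope the lookup to this widget's element so several widgets
            // on one page do not render into each other's SVG
            const svg = el.querySelector('#dag-svg');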
            svg.setAttribute('width', width);
            svg.setAttribute('height', height);
            
            // Clear previous content
            svg.innerHTML = '';
            
            // Add arrow marker definition
            const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs');
            defs.innerHTML = `
                <marker id="arrowhead" markerWidth="10" markerHeight="10" 
                        refX="9" refY="3" orient="auto">
                    <polygon points="0 0, 10 3, 0 6" fill="#6b7280" />
                </marker>
            `;
            svg.appendChild(defs);
            
            // Draw edges first (so they're behind nodes)
            edges.forEach(edge => {
                const fromPos = positions[edge.from];
                const toPos = positions[edge.to];
                
                if (!fromPos || !toPos) return;
                
                const x1 = fromPos.x + offsetX + nodeWidth;
                const y1 = fromPos.y + offsetY + nodeHeight / 2;
                const x2 = toPos.x + offsetX;
                const y2 = toPos.y + offsetY + nodeHeight / 2;
                
                // Create curved path
                const midX = (x1 + x2) / 2;
                const path = `M ${x1} ${y1} Q ${midX} ${y1}, ${midX} ${(y1 + y2) / 2} Q ${midX} ${y2}, ${x2} ${y2}`;
                
                const pathEl = document.createElementNS('http://www.w3.org/2000/svg', 'path');
                pathEl.setAttribute('d', path);
                pathEl.setAttribute('class', 'edge-path');
                svg.appendChild(pathEl);
            });
            
            // Draw nodes
            nodes.forEach(node => {
                const pos = positions[node.id];
                if (!pos) return;
                
                const x = pos.x + offsetX;
                const y = pos.y + offsetY;
                const status = nodeStatus[node.id] || 'idle';
                
                const group = document.createElementNS('http://www.w3.org/2000/svg', 'g');
                group.setAttribute('class', 'node-group');
                
                // Node rectangle
                const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
                rect.setAttribute('x', x);
                rect.setAttribute('y', y);
                rect.setAttribute('width', nodeWidth);
                rect.setAttribute('height', nodeHeight);
                rect.setAttribute('rx', 8);
                rect.setAttribute('fill', getNodeColor(node, status));
                rect.setAttribute('stroke', '#374151');
                rect.setAttribute('stroke-width', 2);
                rect.setAttribute('stroke-dasharray', getBorderStyle(node));
                rect.setAttribute('class', status === 'running' ? 'node-rect node-rect-running' : 'node-rect');
                group.appendChild(rect);
                
                // Node label (function name)
                const text = document.createElementNS('http://www.w3.org/2000/svg', 'text');
                text.setAttribute('x', x + nodeWidth / 2);
                text.setAttribute('y', y + nodeHeight / 2 - 5);
                text.setAttribute('text-anchor', 'middle');
                text.setAttribute('dominant-baseline', 'middle');
                text.setAttribute('class', 'node-text');
                text.setAttribute('font-weight', '600');
                text.setAttribute('font-size', '14');
                text.setAttribute('fill', '#1f2937');
                text.textContent = node.label;
                group.appendChild(text);
                
                // Output name (smaller, below label)
                const outputText = document.createElementNS('http://www.w3.org/2000/svg', 'text');
                outputText.setAttribute('x', x + nodeWidth / 2);
                outputText.setAttribute('y', y + nodeHeight / 2 + 12);
                outputText.setAttribute('text-anchor', 'middle');
                outputText.setAttribute('dominant-baseline', 'middle');
                outputText.setAttribute('class', 'node-text');
                outputText.setAttribute('font-size', '11');
                outputText.setAttribute('fill', '#6b7280');
                outputText.textContent = node.id;
                group.appendChild(outputText);
                
                // Add tooltip on hover
                const title = document.createElementNS('http://www.w3.org/2000/svg', 'title');
                const tooltipText = [
                    `Function: ${node.label}`,
                    `Output: ${node.id}`,
                    node.metadata.cached ? 'Cached: Yes' : '',
                    node.metadata.map_axis ? `Map Axis: ${node.metadata.map_axis}` : '',
                    status !== 'idle' ? `Status: ${status}` : ''
                ].filter(Boolean).join('\\n');
                title.textContent = tooltipText;
                group.appendChild(title);
                
                svg.appendChild(group);
            });
        }
        
        renderDAG();
        model.on('change:dag_data', renderDAG);
        model.on('change:node_status', renderDAG);
    }
    export default { render };
    """

    dag_data = traitlets.Dict({}).tag(sync=True)
    node_status = traitlets.Dict({}).tag(sync=True)
    trigger_run = traitlets.Int(0).tag(sync=True)

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(**kwargs)
        self._pipeline = pipeline
        self._running = False
        self.observe(self._on_trigger_run, names=["trigger_run"])

    async def _simulate_execution(self):
        """Simulate pipeline execution with real-time status updates."""
        if self._running or not self._pipeline:
            return

        self._running = True
        try:
            # Get execution order via topological sort
            dummy_inputs = {"corpus": {}, "query": [], "top_k": 2, "source": "dummy"}

            try:
                ordered_nodes = self._pipeline.topo(dummy_inputs)
            except RuntimeError:
                # If topo fails, just use the order they were added
                ordered_nodes = self._pipeline.nodes

            # Reset all nodes to idle
            self.node_status = {}
            await asyncio.sleep(0.5)

            # Execute each node in order
            for node in ordered_nodes:
                node_id = node.meta.output_name

                # Mark as running
                self.node_status = {**self.node_status, node_id: "running"}
                await asyncio.sleep(0.8)  # Simulate execution time

                # Mark as completed
                self.node_status = {**self.node_status, node_id: "completed"}
                await asyncio.sleep(0.2)  # Brief pause between nodes
        finally:
            # Always clear the flag so a failed run does not block later runs
            self._running = False

    def _on_trigger_run(self, change):
        """Handle run button click."""
        import asyncio

        # Keep a reference so the task is not garbage-collected mid-run
        self._task = asyncio.create_task(self._simulate_execution())

## 4. Visualize the Pipeline DAG

**Interactive Controls**: Click the **"▶ Run Simulation"** button to watch a simulated run of the pipeline. The widget does not execute the functions; it steps through the nodes in topological order with fixed delays, updating in place to show which node is running (amber, pulsing) and which have completed (green).
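
The sequence of `node_status` values the simulation steps through can be sketched without the widget. The generator below is an illustration only (`status_snapshots` is a hypothetical name); it builds a fresh dict at each step, mirroring how the widget reassigns `node_status` rather than mutating it, since an in-place change to a dict would not be picked up as a trait change.

```python
from typing import Dict, Iterator, List


def status_snapshots(order: List[str]) -> Iterator[Dict[str, str]]:
    """Yield the successive `node_status` values for a simulated run."""
    status: Dict[str, str] = {}
    yield dict(status)  # initial state: every node idle
    for node_id in order:
        # Replace the dict instead of mutating it, as the widget does.
        status = {**status, node_id: "running"}
        yield dict(status)
        status = {**status, node_id: "completed"}
        yield dict(status)


snapshots = list(status_snapshots(["raw_data", "cleaned_data"]))
for snapshot in snapshots:
    print(snapshot)
```

Each node contributes two snapshots (running, then completed), so a run over `n` nodes produces `2n + 1` states including the initial reset.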

In [23]:
# Create and display the widget with built-in simulation capability
dag_widget = DAGVisualizationWidget(pipeline=pipeline)
dag_widget.dag_data = dag_data
dag_widget

DAGVisualizationWidget(dag_data={'nodes': [{'id': 'corpus_indexed', 'label': 'index_corpus', 'metadata': {'cac…

## 5. Run the Pipeline

Now let's actually execute the pipeline to see it in action.

In [10]:
from daft_func import Runner, CacheConfig, DiskCache

# Create runner
runner = Runner(
    pipeline=pipeline,
    mode="auto",
    cache_config=CacheConfig(enabled=True, backend=DiskCache(cache_dir=".cache")),
)

In [11]:
# Define inputs
corpus = {
    "d1": "machine learning is fascinating",
    "d2": "deep learning models are powerful",
    "d3": "natural language processing is important",
}

queries = [
    Query(id="q1", text="machine learning"),
    Query(id="q2", text="deep learning"),
]

inputs = {
    "corpus": corpus,
    "query": queries,
    "top_k": 2,
}

In [12]:
# Run the pipeline
result = runner.run(inputs=inputs)
print("\nFinal results:")
for res in result["final_result"]:
    print(
        f"  Query {res.query_id}: {len(res.documents)} docs, score={res.final_score:.2f}"
    )

Indexing 3 documents...
Computing embeddings...

Final results:
  Query q1: 2 docs, score=0.96
  Query q2: 2 docs, score=0.96
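
The reported score of 0.96 follows directly from the placeholder values in the functions above: `retrieve` assigns every document a score of 0.8, `rerank` multiplies it by 1.2, and `generate_result` averages the results. A minimal check of that arithmetic, independent of daft-func:

```python
retrieve_score = 0.8  # score assigned to every document in `retrieve`
rerank_factor = 1.2   # multiplier applied in `rerank`
top_k = 2             # number of documents kept per query

scores = [retrieve_score * rerank_factor for _ in range(top_k)]
final_score = sum(scores) / len(scores)
print(f"{final_score:.2f}")
```

Because every document gets the same score, the average is the same for both queries regardless of `top_k`.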


## 6. Try a Different Pipeline Structure

Let's create a simpler pipeline to see how the visualization adapts.

In [13]:
@func(output="raw_data")
def load_data(source: str) -> List[int]:
    """Load raw data from source."""
    return [1, 2, 3, 4, 5]


@func(output="cleaned_data")
def clean(raw_data: List[int]) -> List[int]:
    """Clean the data."""
    return [x for x in raw_data if x > 0]


@func(output="features")
def extract_features(cleaned_data: List[int]) -> List[float]:
    """Extract features."""
    return [float(x) * 2.0 for x in cleaned_data]


@func(output="scaled_features")
def scale(features: List[float]) -> List[float]:
    """Scale features."""
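    # Guard against an empty list or an all-zero maximum
    max_val = max(features, default=0.0)
    if max_val == 0:
        return [0.0 for _ in features]
    return [f / max_val for f in features]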


@func(output="model")
def train_model(scaled_features: List[float]) -> str:
    """Train a model."""
    return "trained_model"


@func(output="predictions")
def predict(model: str, scaled_features: List[float]) -> List[int]:
    """Make predictions."""
    return [1 if f > 0.5 else 0 for f in scaled_features]

In [24]:
# Create simple pipeline
simple_pipeline = Pipeline(
    functions=[load_data, clean, extract_features, scale, train_model, predict]
)

# Extract and visualize with built-in simulation
simple_dag_data = extract_dag_structure(simple_pipeline)
simple_dag_widget = DAGVisualizationWidget(pipeline=simple_pipeline)
simple_dag_widget.dag_data = simple_dag_data
simple_dag_widget

DAGVisualizationWidget(dag_data={'nodes': [{'id': 'raw_data', 'label': 'load_data', 'metadata': {'cached': Fal…

## Summary

We've successfully created an interactive DAG visualization for daft-func pipelines using anywidget!

### Key Features:
- ✓ **Automatic Layout**: Uses topological sorting for layered layout
- ✓ **Color Coding**: Node fill reflects metadata, and execution status overrides it during a simulation
  - 🟡 Amber = Currently running (with pulsing animation)
  - 🟢 Green = Completed
  - 🔵 Blue = Cached nodes
  - 🟡 Yellow = Mapped nodes (dashed border)
  - Cyan = Cached and mapped nodes (dashed border; not listed in the widget legend)
  - Gray = Regular nodes
- ✓ **Built-in Simulation Controls**:
  - Run button triggers the pipeline simulation
  - Reset button clears execution state
  - Status indicator shows progress
- ✓ **Hover Effects**: Tooltips show node metadata
- ✓ **SVG Rendering**: Smooth curves for edges with proper arrows
- ✓ **Adaptive Sizing**: The SVG dimensions are computed from the layout, so the view adapts to different DAG structures

### Layout Algorithm:
The visualization uses a simple but effective layered layout algorithm:
1. **Topological Sort**: Assigns each node to a layer based on its dependencies
2. **Layer-wise Positioning**: Nodes in the same layer are vertically distributed
3. **Horizontal Spacing**: Layers are spaced horizontally for clarity

This approach works well for DAGs and doesn't require external libraries like dagre or d3-dag!
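
The layering step can be restated in Python to make it easier to follow. This is a sketch of the same idea as the widget's JavaScript `computeLayout` (Kahn's topological sort, with each node placed one layer after the last of its dependencies to finish); `assign_layers` is a hypothetical helper, and the node and edge lists are copied from the first pipeline's extracted DAG.

```python
from collections import deque
from typing import Dict, List, Tuple


def assign_layers(node_ids: List[str], edges: List[Tuple[str, str]]) -> Dict[str, int]:
    """Assign each node to a layer using Kahn's topological sort."""
    successors = {n: [] for n in node_ids}
    in_degree = {n: 0 for n in node_ids}
    for src, dst in edges:
        # Ignore edges that reference unknown nodes, as the widget does.
        if src in successors and dst in in_degree:
            successors[src].append(dst)
            in_degree[dst] += 1

    # Nodes without dependencies start in layer 0.
    layer = {n: 0 for n in node_ids if in_degree[n] == 0}
    queue = deque(layer)
    while queue:
        node = queue.popleft()
        for nxt in successors[node]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                # All dependencies are placed; go one layer further right.
                layer[nxt] = layer[node] + 1
                queue.append(nxt)
    return layer


node_ids = [
    "corpus_indexed", "embeddings", "query_embedding",
    "retrieved_docs", "reranked_docs", "final_result",
]
edges = [
    ("corpus_indexed", "embeddings"),
    ("embeddings", "retrieved_docs"),
    ("query_embedding", "retrieved_docs"),
    ("retrieved_docs", "reranked_docs"),
    ("reranked_docs", "final_result"),
]
layers = assign_layers(node_ids, edges)
# Horizontal position per node, using the widget's 200px layer spacing.
x_positions = {n: layers[n] * 200 for n in node_ids}
print(layers)
```

Because the queue is processed in first-in, first-out order, nodes are dequeued in non-decreasing layer order, so the dependency that finally releases a node is always one of its deepest. Nodes that sit on a cycle never reach an in-degree of zero and are simply left out of the result.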

### Usage:
Simply click the **"▶ Run Simulation"** button in the widget to see the pipeline execution animated in real-time. The visualization updates in-place, showing which nodes are currently running and which have completed.

You can also manually control execution status via Python:
```python
dag_widget.node_status = {"retrieved_docs": "running"}  # keys are node ids; values are "running" or "completed"
```