In [None]:
from graphviz import Digraph

def plot_workflow(graph_data):
    """
    Build a Graphviz ``Digraph`` that draws the workflow described by ``graph_data``.

    Parameters
    ----------
    graph_data : dict
        Mapping with a ``'nodes'`` dict (node id -> info dict; its optional
        ``'name'`` entry is used as the label, falling back to the node id) and an
        ``'edges'`` list of dicts carrying ``'from'``, ``'to'`` and an optional
        ``'label'``.

    Returns
    -------
    Digraph
        Top-to-bottom digraph rendered as PNG with the ``dot`` engine. The node
        labelled ``'instruction'`` is drawn green; every other node light blue.
    """
    # (fill colour, border colour) for the root node and for ordinary steps.
    root_palette = ('lightgreen', 'darkgreen')
    step_palette = ('lightblue', 'navy')

    graph = Digraph(name='workflow', format='png', engine='dot')

    # Layout: spline edges, avoid overlaps, top-down ranks, generous spacing.
    graph.attr(splines='true', overlap='false', rankdir='TB',
               nodesep='0.6', ranksep='1.5')

    # Shared look for every node; fill and border are overridden per node below.
    graph.attr('node', shape='box', style='rounded,filled',
               fontname='Helvetica', fontsize='12', color='navy')

    for node_key, info in graph_data['nodes'].items():
        text = info.get('name', node_key)
        fill, border = root_palette if text == 'instruction' else step_palette
        graph.node(node_key, label=text, fillcolor=fill, color=border)

    for link in graph_data['edges']:
        graph.edge(link['from'], link['to'], label=link.get('label', ''), color='blue')

    return graph


In [None]:
import json
from collections import defaultdict

def detect_spec_version(spec):
    """
    Classify a task spec as ``'old'`` (no ids yet) or ``'new'`` (ids present).

    The spec counts as ``'new'`` when at least one action carries an ``'id'`` key
    or at least one edge's ``'connection'`` carries ``'from_id'`` or ``'to_id'``.
    Missing ``'task'``, ``'actions'``, ``'edges'`` or ``'connection'`` entries are
    treated as empty.
    """
    task = spec.get("task", {})
    action_list = task.get("actions", [])
    edge_list = task.get("edges", [])

    any_action_id = any("id" in action for action in action_list)
    any_edge_id = any(
        any(marker in edge.get("connection", {}) for marker in ("from_id", "to_id"))
        for edge in edge_list
    )

    if any_action_id or any_edge_id:
        return "new"
    return "old"

def ensure_list(x):
    """
    Normalize an edge's ``output`` / ``input`` field to a list of names.

    - list  -> returned unchanged (the same object)
    - tuple -> converted to a list
    - None  -> ``[]`` (a null field names nothing)
    - str   -> split on commas with each piece stripped; a blank string -> ``[]``
    - anything else -> wrapped in a single-element list
    """
    if isinstance(x, list):
        return x
    if x is None:
        # Previously became [None], which later breaks ', '.join(...) on labels.
        return []
    if isinstance(x, tuple):
        return list(x)
    if isinstance(x, str):
        # Splitting a blank string would yield [''] - a bogus empty name.
        if not x.strip():
            return []
        return [p.strip() for p in x.split(",")]
    return [x]

def assign_ids_to_actions(spec):
    """
    Give every action in ``spec["task"]["actions"]`` an ``"id"``, in place.

    Ids that are already present are kept. For each action name, missing ids are
    handed out in list order starting just above the largest existing id for that
    name, or from 0 when the name has no explicit ids. Returns ``spec``.
    """
    actions = spec["task"]["actions"]

    # Largest explicit id per name; -1 stands for "no explicit id seen".
    highest = {}
    for act in actions:
        if "id" in act:
            highest[act["name"]] = max(highest.get(act["name"], -1), act["id"])

    # Next free id per name; names without explicit ids start at 0.
    counters = {label: top + 1 for label, top in highest.items()}
    for act in actions:
        label = act["name"]
        if "id" in act:
            continue
        fresh = counters.get(label, 0)
        act["id"] = fresh
        counters[label] = fresh + 1

    return spec

def assign_ids_to_edges(spec):
    """
    Fill in missing ``from_id`` / ``to_id`` on every edge's ``connection``.

    Mutates ``spec`` in place and returns it. Every action must already carry an
    ``"id"`` (it is read unconditionally below), so this is meant to run after
    ``assign_ids_to_actions``.

    Matching is a greedy heuristic that depends on edge order:
      - from_id: the lowest-id instance of the source action that still has any
        of this edge's outputs unconsumed; otherwise the highest-id instance.
      - to_id: the lowest-id instance of the target action that still has any of
        this edge's input arguments unconsumed; otherwise the highest-id instance.
    Ids already present on an edge are honored, and they consume the named
    outputs/arguments from that instance.

    Raises:
        ValueError: if a provided from_id / to_id matches no instance of the
            corresponding action.
    """
    actions = spec["task"]["actions"]
    edges   = spec["task"]["edges"]

    # Collect every action name mentioned anywhere (actions and edge endpoints)
    action_names = { act["name"] for act in actions }
    edge_from    = { e["from"] for e in edges }
    edge_to      = { e["to"]   for e in edges }
    all_names    = action_names | edge_from | edge_to

    # Per-instance trackers of input arguments not yet fed by an edge
    input_instances = defaultdict(list)
    for act in actions:
        name, aid = act["name"], act["id"]
        args = set(act["arguments"].keys())
        input_instances[name].append({"id": aid, "args_left": set(args)})
    # Names that appear only on edges (presumably e.g. the "instruction" root)
    # get a single placeholder instance with id 0 and no arguments
    for name in all_names:
        if name not in input_instances:
            input_instances[name].append({"id": 0, "args_left": set()})
    # Sort by id so the matching loops below try the lowest id first
    for name in input_instances:
        input_instances[name].sort(key=lambda x: x["id"])

    # Output names used by each action, gathered from all of its outgoing edges
    output_keys = defaultdict(set)
    for e in edges:
        frm  = e["from"]
        outs = ensure_list(e["connection"].get("output", []))
        output_keys[frm].update(outs)

    # Per-instance trackers of outputs not yet consumed; every instance of a
    # name starts with the same full set of output names
    output_instances = defaultdict(list)
    for name, insts in input_instances.items():
        keys = output_keys.get(name, set())
        for inst in insts:
            output_instances[name].append({
                "id": inst["id"],
                "outs_left": set(keys)
            })
        output_instances[name].sort(key=lambda x: x["id"])

    # Now fill in edge ids, consuming from instances (edge order matters)
    for e in edges:
        frm, to = e["from"], e["to"]
        conn     = e["connection"]
        need_out = set(ensure_list(conn.get("output", [])))
        need_in  = set(ensure_list(conn.get("input",  [])))

        # FROM_ID
        if "from_id" in conn:
            # honor provided from_id and mark its outputs as consumed
            provided = conn["from_id"]
            inst = next((i for i in output_instances[frm] if i["id"] == provided), None)
            if inst is None:
                raise ValueError(f"Edge {e}: invalid from_id={provided} for action '{frm}'")
            inst["outs_left"] -= need_out
        else:
            # assign the lowest-id instance that still has any of these outputs
            for inst in output_instances[frm]:
                if inst["outs_left"] & need_out:
                    conn["from_id"]     = inst["id"]
                    inst["outs_left"]  -= need_out
                    break
            else:
                # no instance matched (or need_out is empty): use the highest id
                conn["from_id"] = output_instances[frm][-1]["id"]

        # TO_ID
        if "to_id" in conn:
            # honor provided to_id and mark its arguments as fed
            provided = conn["to_id"]
            inst = next((i for i in input_instances[to] if i["id"] == provided), None)
            if inst is None:
                raise ValueError(f"Edge {e}: invalid to_id={provided} for action '{to}'")
            inst["args_left"] -= need_in
        else:
            # assign the lowest-id instance still waiting on any of these inputs
            for inst in input_instances[to]:
                if inst["args_left"] & need_in:
                    conn["to_id"]      = inst["id"]
                    inst["args_left"] -= need_in
                    break
            else:
                # no instance matched (or need_in is empty): use the highest id
                conn["to_id"] = input_instances[to][-1]["id"]

    return spec


In [None]:
import json
from collections import defaultdict, deque
from IPython.display import display

def merge_dicts(d1, d2):
    """
    Return a new dict holding the entries of both ``d1`` and ``d2``.

    Neither input is modified. Raises ValueError naming the shared keys if the
    two dicts have any key in common, so no value is ever silently overwritten.
    """
    shared = set(d1) & set(d2)
    if not shared:
        combined = d1.copy()
        combined.update(d2)
        return combined
    raise ValueError(f"Colliding keys: {shared}")

def ensure_list(x):
    """
    Normalize an edge's ``output`` / ``input`` field to a list of names.

    - list  -> returned unchanged (the same object)
    - tuple -> converted to a list
    - None  -> ``[]`` (a null field names nothing)
    - str   -> split on commas with each piece stripped; a blank string -> ``[]``
    - anything else -> wrapped in a single-element list
    """
    if isinstance(x, list):
        return x
    if x is None:
        # Previously became [None], which later breaks ', '.join(...) on labels.
        return []
    if isinstance(x, tuple):
        return list(x)
    if isinstance(x, str):
        # Splitting a blank string would yield [''] - a bogus empty name.
        if not x.strip():
            return []
        return [p.strip() for p in x.split(",")]
    return [x]

def validate_task_and_build_graph_data(json_path):
    """
    Load a task spec, check that its data flow is consistent, and return graph data.

    Steps:
      1. Read ``json_path`` as UTF-8 JSON. For old-style specs without ids, assign
         action ids and edge ``from_id`` / ``to_id`` values.
      2. Index every action instance under the key ``"<name>_<id>"`` and add a
         pseudo-root ``"instruction_0"`` with no arguments.
      3. Simulate the data flow breadth-first from the pseudo-root. A node is
         visited once the edges into it have delivered as many values as it has
         arguments. Its received argument names must then equal its declared ones.
      4. Check three things:
         - the number of traversed edges against ``task["num_edges"]`` (if given);
         - that every action was reached;
         - that each edge's source action is listed before its target action.

    Parameters
    ----------
    json_path : str or path-like
        Path to the task specification JSON file.

    Returns
    -------
    dict
        ``{'nodes': {node_key: {'name': ..., 'arguments': ...}},
           'edges': [{'from': u, 'to': v, 'output': [...], 'input': [...],
                      'label': ...}, ...]}``
        Only nodes and edges reached by the simulation are included.

    Raises
    ------
    ValueError
        On any of:
        - duplicate action keys;
        - mismatched edge output/input lengths;
        - an edge pointing at an unknown action;
        - an argument mismatch or merge conflict;
        - an edge-count mismatch;
        - an unreachable action;
        - an action-order violation.
    """
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's default
    # encoding, which is not UTF-8 everywhere (e.g. Windows).
    with open(json_path, encoding="utf-8") as f:
        spec = json.load(f)

    version = detect_spec_version(spec)
    print(f"Detected spec version: {version}")

    # Old-style specs carry no ids at all, so derive them. New-style specs are
    # used as-is; any action/edge without an id is treated as id 0 below.
    if version == "old":
        spec = assign_ids_to_actions(spec)
        spec = assign_ids_to_edges(spec)
    else:
        print("Spec already contains ids; skipping version change.")

    task        = spec["task"]
    actions     = task["actions"]
    edges_input = task["edges"]

    # Index action instances by "<name>_<id>". Non-zero ids are appended to the
    # display name so repeated actions can be told apart in the plot.
    function_info = {}
    for act in actions:
        act_id = act.get("id", 0)
        base   = act["name"]
        name   = f"{base} ({act_id})" if act_id != 0 else base
        key    = f"{base}_{act_id}"
        if key in function_info:
            raise ValueError(
                f"Action with same ID already exists: '{key}'"
            )
        function_info[key] = {
            "name":      name,
            "arguments": act["arguments"].copy(),
            "args_left": len(act["arguments"]),  # values still to be delivered
            "outputs":   act.get("outputs", {}),
            "received":  {},                      # input name -> upstream output name
        }
    # Pseudo-root the simulation starts from; it needs no arguments.
    # NOTE(review): an action named "instruction" with id 0 would be silently
    # replaced by this entry.
    function_info["instruction_0"] = {
        "name":      "instruction",
        "arguments": {},
        "args_left": 0,
        "received":  {},
    }

    # Adjacency list: source key -> outgoing links. Outputs are paired with the
    # target's inputs by position, so both lists must have the same length.
    function_connections = defaultdict(list)
    for edge in edges_input:
        c    = edge["connection"]
        src  = f"{edge['from']}_{c.get('from_id', 0)}"
        tgt  = f"{edge['to']}_{c.get('to_id', 0)}"
        outs = ensure_list(c.get("output", []))
        ins  = ensure_list(c.get("input", []))
        if len(outs) != len(ins):
            raise ValueError(
                f"Edge's input and output length do not match:\n"
                f"Source: {src}\n"
                f"Target: {tgt}\n"
                f"Outputs: {outs}\n"
                f"Inputs: {ins}\n"
            )
        function_connections[src].append({
            "to":     tgt,
            "output": outs,
            "input":  ins,
        })

    # Breadth-first simulation of the data flow. A node is enqueued once the
    # edges into it have delivered as many values as it has arguments.
    queue      = deque([("instruction_0", 0)])
    visited    = {"instruction_0"}
    nodes_dict = {}
    edges_list = []
    levels     = {}  # node key -> BFS depth

    while queue:
        src, lvl = queue.popleft()
        levels[src] = lvl
        info          = function_info[src]
        fn            = info["name"]
        args_dict     = info["arguments"]
        received_args = info["received"]

        # The delivered values must cover exactly the declared argument names.
        if fn != "instruction":
            if set(received_args.keys()) != set(args_dict.keys()):
                print(json.dumps({
                    "function": src,
                    "expected_args": list(args_dict.keys()),
                    "received_args": list(received_args.keys())
                }, indent=2))
                raise ValueError(f"Arg mismatch at {src}: "
                                 f"expected {list(args_dict.keys())}, "
                                 f"got {list(received_args.keys())}")
            print(f"{fn} ({', '.join(args_dict.keys())})")

        # record node
        nodes_dict[src] = {
            "name":      fn,
            "arguments": args_dict
        }

        # Follow outgoing edges, delivering values to their targets.
        for link in function_connections[src]:
            tgt     = link["to"]
            outputs = link["output"]
            inputs  = link["input"]

            # Report a dangling edge clearly instead of a bare KeyError below.
            if tgt not in function_info:
                raise ValueError(
                    f"Edge {src} → {tgt} targets an unknown action '{tgt}'"
                )

            label = f"{', '.join(outputs)} → {', '.join(inputs)}"
            edges_list.append({
                "from":   src,
                "to":     tgt,
                "output": outputs,
                "input":  inputs,
                "label":  label
            })

            args_map = dict(zip(inputs, outputs))
            try:
                function_info[tgt]["received"] = merge_dicts(
                    function_info[tgt]["received"], args_map
                )
            except ValueError as e:
                print(f"Merge conflict on edge {src} → {tgt}: mapping={args_map}, existing={function_info[tgt]['received']}")
                raise ValueError(f"Argument merge conflict at edge {src}→{tgt}") from e

            function_info[tgt]["args_left"] -= len(outputs)
            if function_info[tgt]["args_left"] < 0:
                print(f"Too many arguments received for {tgt}. Please check the edges.")
            if function_info[tgt]["args_left"] == 0 and tgt not in visited:
                visited.add(tgt)
                queue.append((tgt, lvl + 1))

    # Edge count check (only when the spec declares num_edges)
    actual   = len(edges_list)
    expected = task.get("num_edges", actual)
    if actual != expected:
        print(f"Edge list: {edges_list}")
        raise ValueError(f"Edge count mismatch: expected {expected}, got {actual}")

    # Verify no actions were left unaccessed by the BFS
    for act in task["actions"]:
        key = f"{act['name']}_{act.get('id', 0)}"
        if key not in levels:
            print(function_info[key]["received"])
            raise ValueError(f"Action {key} was never reached in BFS.")

    # Verify that each edge's source action appears before its target action in
    # the original list
    action_positions = {
        f"{act['name']}_{act.get('id', 0)}": idx
        for idx, act in enumerate(actions)
    }

    for edge in edges_list:
        src, tgt = edge['from'], edge['to']
        # skip the pseudo-root, which isn't in actions
        if src == "instruction_0":
            continue
        # if either endpoint isn't actually an action (just in case), skip
        if src not in action_positions or tgt not in action_positions:
            continue
        if action_positions[src] >= action_positions[tgt]:
            raise ValueError(
                f"Action order violation: '{src}' (pos {action_positions[src]}) "
                f"must come before '{tgt}' (pos {action_positions[tgt]})"
            )

    return {"nodes": nodes_dict, "edges": edges_list}


In [None]:

if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so this runs when the cell is
    # executed; it is skipped if the code is exported to a module and imported.
    # NOTE(review): the spec path is hard-coded relative to the working directory.
    # task_path, graph_data and dot stay bound as notebook globals.
    task_path = "./test.json"
    graph_data = validate_task_and_build_graph_data(task_path)
    dot = plot_workflow(graph_data)
    display(dot)  # render the Digraph inline in the notebook
