In [157]:
import ast
import json
from pathlib import Path

import networkx as nx


In [158]:
# Load one tool-call graph for inspection. `Path.iterdir()` yields entries in an
# arbitrary, filesystem-dependent order, so sort the candidates to make the choice
# of "the first graph" deterministic across machines and runs.
graphs = sorted(p for p in Path('eval/graphs').iterdir() if p.suffix == '.gexf')
if not graphs:
    raise FileNotFoundError("No '.gexf' files found in 'eval/graphs'")
G = nx.read_gexf(graphs[0])
G

<networkx.classes.digraph.DiGraph at 0x7c69e17d5130>

In [159]:
def load_tool_schemas(path):
    """Load tool signatures from a JSONL file of function-calling tool schemas.

    Each non-blank line is parsed as a JSON object whose ``function`` entry holds a
    ``name`` and, optionally, JSON-Schema ``parameters`` (``required`` list and
    ``properties`` mapping).

    Args:
        path: path to the ``.jsonl`` schema file.

    Returns:
        dict mapping tool name -> ``{'required': set of required parameter names,
        'types': {parameter name: declared ``type``, defaulting to 'string'}}``.
    """
    schemas = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank/trailing lines instead of failing in json.loads
            fn = json.loads(line)['function']
            # Tools without arguments may omit (or null out) the parameter schema.
            params = fn.get('parameters') or {}
            properties = params.get('properties', {})
            schemas[fn['name']] = {
                'required': set(params.get('required', [])),
                'types': {k: v.get('type', 'string') for k, v in properties.items()},
            }
    return schemas


tools = load_tool_schemas('frankenstein/tools/tool_schema.jsonl')


In [160]:
def report_missing_arg_edges(G, tools):
    """Print every node's attributes and flag tool arguments that have no incoming edge.

    For each node whose ``label`` is a tool name present in ``tools``, the argument
    names are the node attributes prefixed ``arg_`` (prefix stripped). An argument
    counts as covered when some incoming edge has a ``label`` equal to the argument
    name or of the form ``"arg_name=value"``. Each uncovered argument gets a warning line.

    Args:
        G: graph exposing ``nodes(data=True)`` and ``in_edges(node, data=True)``
            (e.g. a networkx ``DiGraph``). Node ids are assumed to be strings.
        tools: mapping whose keys are the known tool names.

    Returns:
        None; all output goes to stdout.
    """
    for node, data in G.nodes(data=True):
        short_node_name = node[-5:]  # short suffix keeps the log readable
        print(f'Node {short_node_name} has attributes {data}.')

        action = data.get('label')
        if action not in tools:
            continue  # unlabelled nodes and labels that are not known tools are not checked

        # Argument names supplied by incoming edges: 'arg=value' -> 'arg', a bare 'arg' as-is.
        incoming_labels = {
            edge_data['label'].partition('=')[0]
            for _src, _dst, edge_data in G.in_edges(node, data=True)
            if edge_data.get('label')
        }

        for key in data:
            if not key.startswith('arg_'):
                continue
            arg = key.removeprefix('arg_')
            if arg not in incoming_labels:
                print(f'⚠️ Node {short_node_name} ({action}) is missing incoming edge for argument: {arg}')


report_missing_arg_edges(G, tools)

Node _root has attributes {'type': 'question_param', 'call_index': 0, 'slot_property': 'SP.DYN.LE00.FE.IN', 'slot_subject': 'SAU', 'slot_region': 'Western Asia', 'slot_year': '2018', 'slot_property_original': 'Life expectancy at birth, female (years)', 'slot_subject_name': 'Saudi Arabia', 'label': 'What proportion of the total expected lifespan of females at birth was contributed by Saudi Arabia for the countries in Western Asia in 2018?'}.
Node 0b4c6 has attributes {'call_index': 1, 'result': 'SP.DYN.LE00.FE.IN', 'arg_indicator_name': 'Life expectancy at birth, female (years)', 'label': 'get_indicator_code_from_name'}.
Node df9a0 has attributes {'call_index': 2, 'result': 'SAU', 'arg_country_name': 'Saudi Arabia', 'label': 'get_country_code_from_name'}.
Node d3f91 has attributes {'call_index': 3, 'result': '80.543', 'arg_country_code': 'SAU', 'arg_indicator_code': 'SP.DYN.LE00.FE.IN', 'arg_year': '2018', 'label': 'retrieve_value'}.
Node 879f7 has attributes {'call_index': 4, 'result':

In [161]:
def print_graph_markdown_incoming(G, tools):
    """Print a Markdown outline of ``G`` with the source node of every tool argument.

    Args:
        G: graph exposing ``edges(data=True)`` and ``nodes(data=True)``. Node
            attributes prefixed ``arg_`` are tool arguments. An edge whose
            ``label`` is ``"name=value"`` (or just ``"name"``) marks its source
            node as the origin of argument ``name`` on its destination node.
        tools: accepted for call compatibility; not read here.

    Returns:
        None; the Markdown text is printed to stdout.
    """
    # (destination node, argument name) -> source node; a later edge for the same
    # pair overwrites an earlier one.
    provenance = {
        (dst, tag.split('=')[0]): src
        for src, dst, attrs in G.edges(data=True)
        if (tag := attrs.get('label', None))
    }

    out = ['# Graph Structure', '']
    for node, attrs in G.nodes(data=True):
        out.append(f'## Node `{node}`')
        if not attrs:
            out += ['This node has no attributes.', '']
            continue

        call_index = attrs.get('call_index', None)
        label = attrs.get('label', None)
        result = attrs.get('result', None)
        arguments = {key.removeprefix('arg_'): val for key, val in attrs.items() if key.startswith('arg_')}

        if call_index is not None:
            out.append(f'- **call_index:** `{call_index}`')
        if label is not None:
            # The root node's label carries the user question instead of a tool name.
            out.append(f'- **original question:** {label}' if node == 'question_root' else f'- **tool_name:** `{label}`')
        if arguments:
            out.append('- **arguments:**')
            for name, val in arguments.items():
                source = provenance.get((node, name))
                out.append(
                    f'    - `{name}` = `{val}` _(from node `{source}`)_'
                    if source
                    else f'    - `{name}` = `{val}` ⚠️ **No incoming edge for `{name}`: not derived from a previous tool call.**'
                )
        if result is not None:
            out.append(f'- **result:** `{result}`')
        out.append('')

    print('\n'.join(out))


# Render the Markdown provenance report for the loaded graph.
print_graph_markdown_incoming(G=G, tools=tools)

# Graph Structure

## Node `question_root`
- **call_index:** `0`
- **original question:** What proportion of the total expected lifespan of females at birth was contributed by Saudi Arabia for the countries in Western Asia in 2018?

## Node `chatcmpl-tool-446b1ed010ba47ba828047e6fa30b4c6`
- **call_index:** `1`
- **tool_name:** `get_indicator_code_from_name`
- **arguments:**
    - `indicator_name` = `Life expectancy at birth, female (years)` _(from node `question_root`)_
- **result:** `SP.DYN.LE00.FE.IN`

## Node `chatcmpl-tool-cb28c312a2c241e3a4cb4ec6205df9a0`
- **call_index:** `2`
- **tool_name:** `get_country_code_from_name`
- **arguments:**
    - `country_name` = `Saudi Arabia` _(from node `question_root`)_
- **result:** `SAU`

## Node `chatcmpl-tool-013450c450734cc29d8a2f8ea2cd3f91`
- **call_index:** `3`
- **tool_name:** `retrieve_value`
- **arguments:**
    - `country_code` = `SAU` _(from node `chatcmpl-tool-e223a6e9e96a4a9fb6388b42457879f7`)_
    - `indicator_code` = `SP.DYN.LE00

In [169]:
def print_graph_markdown_incoming(G, tools):
    """Print a Markdown report of a tool-call graph with per-argument provenance.

    The report has three parts:
    - a description of each node: call index, tool name or original question,
      arguments annotated with the node(s) whose edge feeds them, and result;
    - a flat list of all edges;
    - a summary of arguments that have no incoming edge.

    Args:
        G: graph exposing ``nodes(data=True)`` and ``edges(data=True)`` (e.g. a
            networkx ``DiGraph`` loaded from GEXF). Tool arguments are node
            attributes prefixed ``arg_``. An edge ``label`` of the form
            ``"arg_name=value"`` (or a bare ``"arg_name"``) links its source node
            to that argument on the destination node.
        tools: accepted for call compatibility; not read here.

    Returns:
        None; the Markdown text is printed to stdout.
    """
    lines = ['# Graph Structure', '', '## Node List', '']

    # (destination node, argument name) -> every source node whose edge feeds that
    # argument, in edge order. Several earlier calls can feed one argument, so keeping
    # only a single source (last one wins) would silently hide the others.
    arg_sources = {}
    issues = []
    for src, dst, edata in G.edges(data=True):
        slot = edata.get('label', None)
        if slot:
            arg = slot.partition('=')[0]  # 'arg=value' -> 'arg'; a bare 'arg' is kept as-is
            arg_sources.setdefault((dst, arg), []).append(src)

    # Node descriptions
    for node, data in G.nodes(data=True):
        lines.append(f'### Node `{node}`')
        if data:
            call_index = data.get('call_index')
            tool_name = data.get('label')
            result = data.get('result')
            tool_args = {k.removeprefix('arg_'): v for k, v in data.items() if k.startswith('arg_')}

            # Header
            if call_index is not None:
                lines.append(f'- **call_index:** `{call_index}`')
            if tool_name is not None:
                # The root node's label holds the user question rather than a tool name.
                if node == 'question_root':
                    lines.append(f'- **original question:** {tool_name}')
                else:
                    lines.append(f'- **tool_name:** `{tool_name}`')

            # Arguments, each annotated with where its value came from
            if tool_args:
                lines.append('- **arguments:**')
                for arg, v in tool_args.items():
                    origins = arg_sources.get((node, arg), [])
                    if len(origins) == 1:
                        lines.append(f'    - `{arg}` = `{v}` from node `{origins[0]}`')
                    elif origins:
                        sources = ', '.join(f'`{o}`' for o in origins)
                        lines.append(f'    - `{arg}` = `{v}` from nodes {sources}')
                    else:
                        lines.append(
                            f'    - `{arg}` = `{v}` ⚠️ **No incoming edge. This argument is not derived from any previous node, so its provenance is unclear.**'
                        )
                        issues.append(f'Node `{node}` arg `{arg}` = `{v}` has no incoming edge.')

            # Result
            if result is not None:
                lines.append(f'- **result:** `{result}`')
        else:
            lines.append('This node has no attributes.')
        lines.append('')

    # Explicit Edges section
    lines.append('---')
    lines.append('## Edges')
    for src, dst, edata in G.edges(data=True):
        label = edata.get('label', '')
        if label:
            lines.append(f'- `{src}` → `{dst}` (arg: `{label}`)')
        else:
            lines.append(f'- `{src}` → `{dst}`')

    # Issues summary
    if issues:
        lines.append('---')
        lines.append('## ⚠️ Issues')
        for w in issues:
            lines.append(f'- {w}')

    print('\n'.join(lines))


# Render the extended Markdown report (nodes, edges and issue summary).
print_graph_markdown_incoming(G=G, tools=tools)


# Graph Structure

## Node List

### Node `question_root`
- **call_index:** `0`
- **original question:** What proportion of the total expected lifespan of females at birth was contributed by Saudi Arabia for the countries in Western Asia in 2018?

### Node `chatcmpl-tool-446b1ed010ba47ba828047e6fa30b4c6`
- **call_index:** `1`
- **tool_name:** `get_indicator_code_from_name`
- **arguments:**
    - `indicator_name` = `Life expectancy at birth, female (years)` from node `question_root`
- **result:** `SP.DYN.LE00.FE.IN`

### Node `chatcmpl-tool-cb28c312a2c241e3a4cb4ec6205df9a0`
- **call_index:** `2`
- **tool_name:** `get_country_code_from_name`
- **arguments:**
    - `country_name` = `Saudi Arabia` from node `question_root`
- **result:** `SAU`

### Node `chatcmpl-tool-013450c450734cc29d8a2f8ea2cd3f91`
- **call_index:** `3`
- **tool_name:** `retrieve_value`
- **arguments:**
    - `country_code` = `SAU` from node `chatcmpl-tool-e223a6e9e96a4a9fb6388b42457879f7`
    - `indicator_code` = `SP.DY

In [None]:
import yaml


def generate_graph_yaml(G, tools):
    """Generate a YAML-style report of a tool-use graph, with explicit argument provenance.

    Args:
        G: graph exposing ``nodes(data=True)`` and ``edges(data=True)`` (e.g. a
            networkx ``DiGraph``). Tool arguments are node attributes prefixed
            ``arg_``. An edge ``label`` of the form ``"arg_name=value"`` (or a bare
            ``"arg_name"``) connects its source node to that argument on the
            destination node.
        tools: accepted for call compatibility; not read here.

    Returns:
        The value of ``yaml.dump`` for a mapping with keys ``nodes``, ``edges``
        and ``issues`` (in that order).
    """
    # (destination node, argument name) -> source node. If several edges feed the
    # same argument, the last one in edge order wins.
    arg_sources = {}
    for src, dst, edata in G.edges(data=True):
        slot = edata.get('label', None)
        if slot:
            arg_sources[(dst, slot.partition('=')[0])] = src

    # Nodes
    issues = []
    nodes_list = []
    for node, data in G.nodes(data=True):
        node_dict = {'id': node}
        if data:
            call_index = data.get('call_index')
            tool_name = data.get('label')
            result = data.get('result')
            tool_args = {k.removeprefix('arg_'): v for k, v in data.items() if k.startswith('arg_')}

            if call_index is not None:
                node_dict['call_index'] = call_index
            if tool_name is not None:
                # The root node's label is the user's question, not a tool name.
                is_root = node == 'question_root'
                node_dict['tool_name'] = None if is_root else tool_name
                if is_root:
                    node_dict['original_question'] = tool_name

            # Arguments with provenance (source_node is None when no edge feeds them)
            if tool_args:
                args_list = []
                for arg, v in tool_args.items():
                    origin = arg_sources.get((node, arg))
                    if origin is None:
                        issues.append(
                            f'Node `{node}` arg `{arg}` = `{v}` has no incoming edge, indicating that it is not derived from a previous tool call and so its provenance is unclear.'
                        )
                    args_list.append({'name': arg, 'value': v, 'source_node': origin})
                node_dict['arguments'] = args_list

            if result is not None:
                node_dict['result'] = result
        else:
            node_dict['note'] = 'No attributes'

        nodes_list.append(node_dict)

    # Edges
    edges_list = []
    for src, dst, edata in G.edges(data=True):
        label = edata.get('label', None)
        edge_dict = {'from': src, 'to': dst, 'arg': None}
        if label:
            # partition (not split + unpack): a bare 'arg' label or a value that itself
            # contains '=' must not raise ValueError.
            arg, sep, value = label.partition('=')
            edge_dict['arg'] = {arg: value if sep else None}
        edges_list.append(edge_dict)

    report = {'nodes': nodes_list, 'edges': edges_list, 'issues': issues}
    return yaml.dump(report, sort_keys=False, allow_unicode=True)


# Build the YAML report, persist it to graph_report.yaml, and echo it below.
yaml_report = generate_graph_yaml(G, tools)
Path('graph_report.yaml').write_text(yaml_report, encoding='utf-8')
print(yaml_report)


nodes:
- id: question_root
  call_index: 0
  tool_name: null
  original_question: What proportion of the total expected lifespan of females at
    birth was contributed by Saudi Arabia for the countries in Western Asia in 2018?
- id: chatcmpl-tool-446b1ed010ba47ba828047e6fa30b4c6
  call_index: 1
  tool_name: get_indicator_code_from_name
  arguments:
  - name: indicator_name
    value: Life expectancy at birth, female (years)
    source_node: question_root
  result: SP.DYN.LE00.FE.IN
- id: chatcmpl-tool-cb28c312a2c241e3a4cb4ec6205df9a0
  call_index: 2
  tool_name: get_country_code_from_name
  arguments:
  - name: country_name
    value: Saudi Arabia
    source_node: question_root
  result: SAU
- id: chatcmpl-tool-013450c450734cc29d8a2f8ea2cd3f91
  call_index: 3
  tool_name: retrieve_value
  arguments:
  - name: country_code
    value: SAU
    source_node: chatcmpl-tool-e223a6e9e96a4a9fb6388b42457879f7
  - name: indicator_code
    value: SP.DYN.LE00.FE.IN
    source_node: chatcmpl-tool-4