In [1]:

import hashlib
import os
import re

from graphviz import Digraph

def hash_children(children):
    """Return a hashable key describing a node's children.

    The key is the tuple of the children's ``'symbol'`` values in their
    original left-to-right order. It is NOT sorted, so children ``(A, B)`` and
    ``(B, A)`` produce different keys. A tuple of strings is already hashable,
    so no digest (e.g. from hashlib) is needed to use it as a dict key.

    Args:
        children: iterable of node dicts, each with a ``'symbol'`` entry.

    Returns:
        tuple: the children's symbols, in order (``()`` for no children).
    """
    return tuple(child['symbol'] for child in children)

def relabel_internal_nodes(trees):
    """Give structurally identical internal nodes a shared label across all trees.

    Each tree is walked bottom-up (children before their parent). Every
    non-root node that has children is keyed by ``hash_children`` of its
    already-relabelled children. The first time a key is seen it gets the next
    label "N1", "N2", ...; every later node with the same key reuses that
    label. Roots and leaves keep their existing symbols.

    Args:
        trees: iterable of nested node dicts (``'symbol'`` plus optional
            ``'children'``). They are modified in place.

    Returns:
        None
    """
    shared_labels = {}  # child-symbol tuple -> "N<k>", shared by every tree

    def visit(node, at_root):
        kids = node.get('children')
        if not kids:
            return  # leaves keep their symbol

        # Post-order: the key below must see the children's final labels.
        for kid in kids:
            visit(kid, False)

        if at_root:
            return  # roots keep their own label

        key = hash_children(kids)
        if key not in shared_labels:
            # One new label per new key, numbered from 1.
            shared_labels[key] = f"N{len(shared_labels) + 1}"
        node['symbol'] = shared_labels[key]

    for root in trees:
        visit(root, True)

def count_leaf_nodes(node):
    """Return the number of leaves in the tree rooted at ``node``.

    A node is a leaf when it has no ``'children'`` key or an empty/falsy
    children list. A childless root therefore counts as one leaf.
    """
    total = 0
    pending = [node]
    while pending:
        current = pending.pop()
        kids = current.get('children')
        if kids:
            pending.extend(kids)
        else:
            total += 1
    return total

def find_rightmost_node(node):
    """Return the leaf reached by always descending into the last child.

    Nodes without a ``'children'`` key, or with an empty list, are leaves.
    The returned dict is the node object inside the tree, not a copy.
    """
    current = node
    while current.get('children'):
        current = current['children'][-1]
    return current

def relabel_leaf_nodes(node, subtask_seq):
    """Overwrite each leaf's ``'symbol'`` with ``str()`` of the next subtask label.

    Leaves are visited depth-first, left to right, so the i-th leaf receives
    ``str(subtask_seq[i])``. A node is a leaf when it has no ``'children'`` key
    or an empty children list. Entries of ``subtask_seq`` beyond the number of
    leaves are ignored. The tree is modified in place.

    Args:
        node: root of a nested node dict.
        subtask_seq: iterable of labels, one per leaf.

    Raises:
        ValueError: if the tree has more leaves than ``subtask_seq`` has entries.
    """
    labels = iter(subtask_seq)

    def _relabel(n):
        children = n.get('children')
        if children:
            for child in children:
                _relabel(child)
            return
        try:
            label = next(labels)
        except StopIteration:
            # A bare StopIteration escaping a normal function is easy to
            # misread and is silently treated as "end of data" by consumers
            # such as map(); report the real problem instead.
            raise ValueError("Tree has more leaves than entries in subtask_seq") from None
        n['symbol'] = str(label)

    _relabel(node)

def remove_rightmost_node(node):
    """Delete the rightmost leaf of the tree rooted at ``node``, in place.

    The rightmost leaf is found by always following the last child. After it
    is removed, any non-root ancestor left with no children is removed as well.
    Otherwise that ancestor would itself become a leaf, and the tree's leaf
    count would not go down. The root is never removed; if its only child is
    removed it is left with ``'children': []``. If ``node`` has no children
    there is nothing to remove and the tree is unchanged.
    """
    # path[0] is the root; path[-1] ends up being the rightmost leaf.
    path = [node]
    current = node
    while current.get('children'):
        current = current['children'][-1]
        path.append(current)

    if len(path) < 2:
        return  # the root is itself a leaf

    # Drop the leaf from its parent (it is always the parent's last child).
    path.pop()
    path[-1]['children'].pop()

    # Prune ancestors emptied by the removal, stopping at the root.
    while len(path) > 1 and not path[-1]['children']:
        path.pop()
        path[-1]['children'].pop()

def parse_tree_string(s, subtask_seq, ep_idx):
    """Parse a bracketed tree string into a nested-dict tree with subtask leaves.

    The input is an s-expression such as ``"((0 1) (2) (3))"``. Every node is a
    parenthesised mix of integers and child nodes. Each node's provisional
    ``'symbol'`` is:

    * the node's integers joined by single spaces, if it holds any integers;
    * otherwise ``f"H{ep_idx}"`` for the root;
    * otherwise ``"N1"``, ``"N2"``, ... numbered in completion order within
      this call.

    Nodes with children also get a ``'children'`` list.

    If the number of leaves differs from ``len(subtask_seq)``, a single surplus
    trailing leaf is dropped, but only when it holds exactly one integer.
    Finally the leaves, left to right, are relabelled with ``str()`` of the
    entries of ``subtask_seq``.

    Args:
        s: the tree string.
        subtask_seq: sized sequence of leaf labels, one per leaf.
        ep_idx: episode index, used in the root label.

    Returns:
        dict: the root node.

    Raises:
        ValueError: malformed input, trailing text after the tree, or a leaf
            count that cannot be reconciled with ``len(subtask_seq)``.
    """
    # Compiled once; Pattern.match(s, pos) avoids slicing s[idx:] for every
    # token, which made tokenising quadratic in the input length.
    int_pattern = re.compile(r'[+-]?\d+')
    s = s.strip()
    idx = 0
    counter = {'n': 0}  # running number for provisional "N<k>" labels

    def skip_ws():
        nonlocal idx
        while idx < len(s) and s[idx].isspace():
            idx += 1

    def parse_node(is_root=False):
        nonlocal idx
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if idx >= len(s) or s[idx] != '(':
            raise ValueError(f"Expected '(', got {s[idx:]}")
        idx += 1
        vals = []
        kids = []
        while True:
            skip_ws()
            if idx >= len(s):
                raise ValueError("Unmatched '('")
            if s[idx] == '(':
                kids.append(parse_node(is_root=False))
            elif s[idx] == ')':
                idx += 1
                break
            else:
                m = int_pattern.match(s, idx)
                if not m:
                    raise ValueError(f"Invalid token at {s[idx:]}")
                vals.append(int(m.group(0)))
                idx = m.end()

        if vals:
            label = " ".join(str(v) for v in vals)
        elif is_root:
            label = f'H{ep_idx}'
        else:
            # Node without integers (pure-children or empty "()"): provisional name.
            counter['n'] += 1
            label = f"N{counter['n']}"

        node = {'symbol': label}
        if kids:
            node['children'] = kids
        return node

    if not s or s[0] != '(':
        raise ValueError("Input must start with '('")
    tree = parse_node(is_root=True)
    skip_ws()
    if idx != len(s):
        raise ValueError(f"Extra text after tree: {s[idx:]}")

    expected = len(subtask_seq)
    leaf_count = count_leaf_nodes(tree)
    if leaf_count != expected:
        # A predicted tree may end with one surplus leaf. Drop it only when it
        # holds a single integer of any width ("7" or "12", but not "3 4").
        rightmost = find_rightmost_node(tree)
        if leaf_count == expected + 1 and int_pattern.fullmatch(rightmost['symbol']):
            remove_rightmost_node(tree)
            leaf_count = count_leaf_nodes(tree)
        # Re-check: some tree shapes keep the same leaf count after removal.
        # Relabelling a mismatched tree would silently drop labels or run out.
        if leaf_count != expected:
            raise ValueError(f"Leaf count {leaf_count} does not match subtask sequence length {expected}")

    relabel_leaf_nodes(tree, subtask_seq)
    return tree

def visualize_tree(tree_dict, filename_base, fmt="png", root_label=None, label_mapping=None):
    """
    Render a nested-dict tree with Graphviz and return the rendered file's path.

    tree_dict: nested dict; each node has either
      - 'production' -> int   (drawn as "R<production>")
      - OR 'symbol' -> str
      and optional 'children': [ ... ]
    filename_base: output path without extension, e.g. "seq_tree_0"
    fmt: Graphviz output format, e.g. "png" or "pdf"
    root_label: if given, use this string as the label for the very top node
    label_mapping: optional dict mapping symbols to readable strings; symbols
      missing from the mapping are drawn unchanged ('production' nodes are
      not looked up)

    Returns:
        The path returned by ``Digraph.render``. ``cleanup=True`` deletes the
        intermediate DOT source file.
    """
    dot = Digraph(format=fmt)
    dot.attr(rankdir="TB")  # top -> bottom
    dot.attr("node", shape="oval", fontsize="12")

    next_id = 0

    def recurse(node, parent_id=None, is_root=False):
        nonlocal next_id
        # Sequential ids instead of id(node) give the same DOT source on every
        # run. A dict object reused in two places is also drawn twice rather
        # than collapsed into one node.
        nid = f"n{next_id}"
        next_id += 1

        if is_root and root_label is not None:
            label = root_label
        elif "production" in node:
            label = f"R{node['production']}"
        elif label_mapping:
            label = label_mapping.get(node["symbol"], node["symbol"])
        else:
            label = node["symbol"]

        dot.node(nid, label)
        if parent_id is not None:
            dot.edge(parent_id, nid)

        for child in node.get("children", []):
            recurse(child, nid)

    # Only the top-level call is marked as the root.
    recurse(tree_dict, parent_id=None, is_root=True)

    return dot.render(filename_base, cleanup=True)



In [6]:
import json
import re
from tree.structure_metrics import compute_average_structure_metrics
from tree.structure_metrics import compute_structure_metrics
from tree.structure_metrics import count_unique_trees

# Run directory to analyse. Its last path component (exp_name) is part of the
# episode-data file name.
data_folder = 'ompn_paper_runs/wsws_static_pixels_big'
exp_name = os.path.basename(data_folder.rstrip('/'))
trees_dir = f'{data_folder}/trees'
# Create the output directory up front: structure_metrics.json is written
# there even if no tree gets rendered.
os.makedirs(trees_dir, exist_ok=True)

# `with` closes the file instead of leaking the handle.
with open(f'{data_folder}/episode_data/data_details_{exp_name}.json') as fh:
    episode_data = json.load(fh)

# Parse each episode's predicted tree; leaves receive the episode's subtask labels.
trees = [
    parse_tree_string(ep['predicted_tree'], ep['subtask_order'], ep['ep_idx'])
    for ep in episode_data['episode_details']
]

# Share labels between structurally identical internal nodes across all trees.
relabel_internal_nodes(trees)

for i, tree in enumerate(trees):
    visualize_tree(tree, f"{trees_dir}/seq_tree_{i}", fmt="png")

per_tree_metrics = {
    f"tree_{i}": compute_structure_metrics(tree) for i, tree in enumerate(trees)
}

# Save per-tree metrics
with open(f'{trees_dir}/structure_metrics.json', 'w') as f:
    json.dump(per_tree_metrics, f, indent=2)

avg_metrics = compute_average_structure_metrics(trees)

print("Average Structure Metrics:")
for key, value in avg_metrics.items():
    print(f"{key}: {value}")

print("unique trees:", count_unique_trees(trees))

Average Structure Metrics:
depth: 2.0
size: 4.986
avg_branching: 3.986
max_branching: 3.986
reuse: 0.6023999999999998
modularity: 1.0
unique trees: 3
