In [9]:

import hashlib
from graphviz import Digraph

def parse_tree_string(s, ep_idx):
    """
    Parse a parenthesized tree string into a nested dict.

    A node is written as '(' followed by any mix of signed integer tokens and
    nested nodes (optionally separated by whitespace), then ')'.

    s: the tree string, e.g. "((1 2) (3))"; leading/trailing whitespace is ignored
    ep_idx: value interpolated into the root label "H{ep_idx}"

    Returns a nested dict:
      - node containing nested nodes → {'symbol': label, 'children': [ … ]}
        where label is "H{ep_idx}" for the outermost node and "N<k>" for any
        other one (k counts such nodes in the order their ')' is reached)
      - node without nested nodes → {'symbol': "<its tokens joined by spaces>"}
        (this also applies to the outermost node, which then gets no "H" label)
    Integer tokens that sit directly inside a node which also contains nested
    nodes are not kept in the result.

    Raises ValueError when the input does not start with '(', contains a token
    that is not a signed integer, lacks a closing ')', or has text after the tree.
    """
    import re

    s = s.strip()
    if not s or s[0] != '(':
        raise ValueError("Input must start with '('")

    # Compiled once and matched in place (via the pos argument) so that no
    # substring copy is made for every token.
    int_token = re.compile(r'[+-]?\d+')
    idx = 0
    counter = {'n': 0}

    def skip_ws():
        nonlocal idx
        while idx < len(s) and s[idx].isspace():
            idx += 1

    def parse_node(is_root=False):
        nonlocal idx
        # Every call site has already checked that s[idx] == '('.
        idx += 1
        skip_ws()

        tokens = []
        children = []

        while idx < len(s) and s[idx] != ')':
            if s[idx] == '(':
                children.append(parse_node(False))
            else:
                m = int_token.match(s, idx)
                if not m:
                    raise ValueError(f"Invalid token at {s[idx:]}")
                tokens.append(m.group(0))
                idx = m.end()
            skip_ws()

        # Running off the end means this node was never closed; report that
        # directly instead of letting idx overshoot len(s).
        if idx >= len(s):
            raise ValueError("Unbalanced parentheses: missing ')'")
        idx += 1  # consume ')'

        if children:
            # Internal node: the counter advances for the root as well, but the
            # root is closed last, so the N-labels stay consecutive.
            counter['n'] += 1
            label = f"H{ep_idx}" if is_root else f"N{counter['n']}"
            return {'symbol': label, 'children': children}

        # Leaf node: its label is the token sequence.
        return {'symbol': " ".join(tokens)}

    tree = parse_node(is_root=True)
    skip_ws()
    if idx != len(s):
        raise ValueError(f"Extra text after tree: {s[idx:]}")

    return tree


def visualize_tree(tree_dict, filename_base, fmt="png", root_label=None, label_mapping=None):
    """
    Render a nested-dict tree to an image file with Graphviz.

    tree_dict: nested dict with keys
      - 'production' → int
      - OR 'symbol' → str
      and optional 'children': [ … ]
    filename_base: e.g. "seq_tree_0"  (no extension)
    fmt: "png" or "pdf"
    root_label: if given, use this string as the label for the very top node
    label_mapping: optional dict mapping grammar labels to readable strings;
      it is consulted only for 'symbol' nodes and falls back to the symbol itself

    Returns the path of the rendered file, as returned by Digraph.render.
    """
    dot = Digraph(format=fmt)
    dot.attr(rankdir="TB")     # top→bottom
    dot.attr("node", shape="oval", fontsize="12")

    next_id = 0

    def recurse(node, parent_id=None, is_root=False):
        nonlocal next_id
        # One fresh id per visit: a dict object that is referenced from several
        # places is still drawn once per occurrence, and the emitted DOT is the
        # same from run to run (id(node) gave neither guarantee).
        nid = f"n{next_id}"
        next_id += 1

        # A custom label, when supplied, overrides whatever the root node carries.
        if is_root and root_label is not None:
            label = root_label
        elif "production" in node:
            label = f"R{node['production']}"
        elif label_mapping:
            label = label_mapping.get(node["symbol"], node["symbol"])
        else:
            label = node["symbol"]

        # str() so that non-string labels (e.g. an int symbol) are passed as text.
        dot.node(nid, str(label))
        if parent_id is not None:
            dot.edge(parent_id, nid)

        for child in node.get("children", []):
            recurse(child, nid)

    # kick off recursion marking the first call as the root
    recurse(tree_dict, parent_id=None, is_root=True)

    # cleanup=True removes the intermediate DOT source once the image is written.
    return dot.render(filename_base, cleanup=True)



In [10]:
import json
import os
import re

from tree.structure_metrics import compute_average_structure_metrics
from tree.structure_metrics import compute_structure_metrics
from tree.structure_metrics import count_unique_trees

# Experiment folders to process; each must contain
# episode_data/data_details_<folder name>.json.
data_folders = [
    'experiment/wsws_static_pixels_big',
    # 'ompn_paper_runs/stone_pick_static_pixels_big',
    # 'ompn_paper_runs/wsws_random_pixels_big',
    # 'ompn_paper_runs/wsws_static_pixels_big'
]

for data_folder in data_folders:
    exp_name = data_folder.split('/')[-1]

    # Read the episode details; the context manager closes the file handle.
    with open(f'{data_folder}/episode_data/data_details_{exp_name}.json') as f:
        episode_data = json.load(f)

    # One parsed tree per episode, in file order.
    trees = [
        parse_tree_string(ep['predicted_tree'], ep['ep_idx'])
        for ep in episode_data['episode_details']
    ]

    # Create the output folder up front so the metrics file below can be
    # written even when there are no trees to render.
    trees_dir = f'{data_folder}/trees'
    os.makedirs(trees_dir, exist_ok=True)

    # Note: files and metric keys are numbered by position in the list,
    # not by the episode's own ep_idx.
    for i, tree in enumerate(trees):
        visualize_tree(tree, f"{trees_dir}/seq_tree_{i}", fmt="png")

    per_tree_metrics = {
        f"tree_{i}": compute_structure_metrics(tree)
        for i, tree in enumerate(trees)
    }

    # Save per-tree metrics
    with open(f'{trees_dir}/structure_metrics.json', 'w') as f:
        json.dump(per_tree_metrics, f, indent=2)

    avg_metrics = compute_average_structure_metrics(trees)

    print(f"Average metrics for {exp_name}:")

    print("Average Structure Metrics:")
    for key, value in avg_metrics.items():
        print(f"{key}: {value}")

    print("unique trees:", count_unique_trees(trees))
    print("============================================")

Average metrics for wsws_static_pixels_big:
Average Structure Metrics:
depth: 2.0
size: 4.136
avg_branching: 3.136
max_branching: 3.136
reuse: 0.9943000000000002
modularity: 1.0
unique trees: 493
