In [1]:
import re
from typing import Tuple, Dict, Any, List
import os
import json 
import numpy as np
# =========
# Parsing
# =========

def parse_tree_string(s: str) -> Dict[str, Any]:
    """
    Parse a single bracketed tree string such as
      "(((1 1) (5 3)) (1 3))"
    into a nested-dict AST.

    Every "(...)" group becomes an internal node {"_children": [child, ...]}
    with children kept in source order. Every integer token (optionally
    prefixed by '+' or '-') becomes a leaf {"symbol": "<token text>"}; the
    token text is stored verbatim as a string, so "+7" and "7" are distinct
    symbols.

    Args:
        s: The tree string. Leading and trailing whitespace is ignored.

    Returns:
        The root node of the AST.

    Raises:
        ValueError: if the input is empty or does not start with '(', if
            something that is neither a group nor an integer token is found,
            if parentheses are unbalanced, or if text follows the root group.
    """
    s = s.strip()
    if not s or s[0] != '(':
        raise ValueError("Input must start with '('")

    # Matched at an explicit offset (token_re.match(s, i)) instead of against
    # s[i:], so no copy of the remaining input is made for every token.
    token_re = re.compile(r'[+-]?\d+')

    i = 0
    n = len(s)

    def skip_ws():
        """Advance the cursor past any run of whitespace."""
        nonlocal i
        while i < n and s[i].isspace():
            i += 1

    def parse_node() -> Dict[str, Any]:
        """Parse one "(...)" group starting at the cursor and return its node."""
        nonlocal i
        if i >= n or s[i] != '(':
            raise ValueError(f"Expected '(', got {s[i:i+10]!r}")
        i += 1  # consume '('
        skip_ws()

        children: List[Dict[str, Any]] = []

        # The cursor always rests on a non-space character here, so each
        # iteration sees either a nested group or the start of a token.
        while i < n and s[i] != ')':
            if s[i] == '(':
                # nested node
                children.append(parse_node())
            else:
                # token (number)
                m = token_re.match(s, i)
                if not m:
                    raise ValueError(f"Invalid token at: {s[i:i+20]!r}")
                i = m.end()
                children.append({"symbol": m.group(0)})
            skip_ws()

        if i >= n or s[i] != ')':
            raise ValueError("Unbalanced parentheses")
        i += 1  # consume ')'
        return {"_children": children}

    root = parse_node()
    skip_ws()
    if i != n:
        raise ValueError(f"Extra text after tree: {s[i:]!r}")
    return root

# =======================================
# Production interning & tree conversion
# =======================================

Signature = Tuple  # (tag, ...), fully immutable/tuplized

def node_signature(node: Dict[str, Any]) -> Signature:
    """
    Return a hashable, order-sensitive signature describing a parsed node.

    A leaf (a dict carrying "symbol") maps to ('T', <symbol>). Any other node
    is treated as internal and maps to ('N', (sig_of_child_0, sig_of_child_1,
    ...)), where the child signatures are computed recursively from
    node["_children"] in their stored order.
    """
    is_leaf = "symbol" in node
    if not is_leaf:
        # Internal node: signature is the tuple of its children's signatures.
        return ("N", tuple(map(node_signature, node["_children"])))
    return ("T", node["symbol"])

def to_production_tree(
    node: Dict[str, Any],
    intern_table: Dict[Signature, int]
) -> Tuple[Dict[str, Any], Signature]:
    """
    Convert a parsed AST node into its output form, numbering internal nodes.

    Leaves become {"symbol": <symbol>}. Internal nodes become
    {"production": <id>, "children": [...]}, where <id> is looked up in
    `intern_table` by the node's signature; an unseen signature is assigned
    the next free id (the table's current size) and recorded. `intern_table`
    is therefore mutated in place.

    Children are converted before their parent, so a child's production id is
    always allocated before the id of the node containing it.

    Returns:
      (converted_json, signature)
    """
    if "symbol" in node:
        # Leaf: nothing to intern, just echo the symbol.
        symbol = node["symbol"]
        return {"symbol": symbol}, ("T", symbol)

    # Post-order: every child is fully converted (and interned) first.
    converted_pairs = [
        to_production_tree(child, intern_table) for child in node["_children"]
    ]
    out_children = [converted for converted, _ in converted_pairs]
    sig = ("N", tuple(child_sig for _, child_sig in converted_pairs))

    # Existing signature -> reuse its id; new signature -> next sequential id.
    prod_id = intern_table.setdefault(sig, len(intern_table))

    return {"production": prod_id, "children": out_children}, sig

# ==========================
# Public API for many trees
# ==========================

def trees_to_productions(tree_strings: List[str]) -> List[Dict[str, Any]]:
    """
    Convert a batch of bracketed tree strings into production-tree dicts.

    Each string is parsed with parse_tree_string and converted with
    to_production_tree. One interning table is shared by the whole batch, so
    structurally identical subtrees receive the same production id no matter
    which input tree they occur in. Results are returned in input order.
    """
    intern_table: Dict[Signature, int] = {}
    # [0] keeps the converted tree and drops the root signature.
    return [
        to_production_tree(parse_tree_string(s), intern_table)[0]
        for s in tree_strings
    ]

from graphviz import Digraph

def visualize_tree(tree_dict, filename_base, fmt="pdf", root_label=None, label_mapping=None):
    """
    Draw a converted tree with Graphviz and return the rendered file's path.

    tree_dict: nested dict; each node has either
      - 'production' → int   (drawn as "R<id>")
      - or 'symbol' → str    (drawn as the symbol, optionally remapped)
      plus an optional 'children': [ … ] list of nodes of the same shape
    filename_base: output name without extension, e.g. "seq_tree_0"
    fmt: output format passed to Digraph, "pdf" by default
    root_label: if not None, replaces the label of the top node only
    label_mapping: optional dict; a symbol found in it is replaced by the
      mapped value, stringified, with underscores turned into spaces and
      title-cased. Symbols not in the mapping are shown unchanged.
    """
    dot = Digraph(format=fmt)
    dot.attr(rankdir="TB")     # lay the tree out top→bottom
    dot.attr("node", shape="oval", fontsize="12")

    def label_for(node, is_root):
        """Pick the text shown inside the drawn node."""
        if is_root and root_label is not None:
            return root_label
        if "production" in node:
            return f"R{node['production']}"
        sym = node["symbol"]
        if label_mapping and sym in label_mapping:
            return str(label_mapping[sym]).replace("_", " ").title()
        return sym

    def draw(node, parent_id, is_root):
        """Emit this node, the edge from its parent, then its subtree."""
        # The dict's identity is the graph node name, so equal-looking
        # nodes at different positions are still drawn separately.
        node_id = str(id(node))
        dot.node(node_id, label_for(node, is_root))

        if parent_id is not None:
            dot.edge(parent_id, node_id)

        for child in node.get("children", []):
            draw(child, node_id, False)

    draw(tree_dict, None, True)

    # render writes the file; its return value is handed back as the path
    return dot.render(filename_base, cleanup=True)






In [3]:
import os, json, re
import numpy as np

file_path = "Traces/stone_pick_static/ompn_raw_stone_pick_static_pixels_big.json"

with open(file_path, "r") as f:
    data = json.load(f)

# Sorted so the action files are always visited in the same order
action_dir = "Traces/stone_pick_static/actions"
action_files = sorted(name for name in os.listdir(action_dir) if name.endswith(".npy"))

# Build a lookup: action_sequence_tuple -> index parsed from filename,
# and track the largest index so the result list can be pre-sized.
action_to_index = {}
max_index = -1
for fname in action_files:
    arr = np.load(os.path.join(action_dir, fname))
    # Hashable dict key; assumes each file holds a flat sequence of
    # integer-like actions — TODO confirm against the trace writer.
    key = tuple(map(int, arr.tolist()))
    # Filename is split on '_' and '.'; the piece right before the extension
    # is taken as the index (..._<index>.npy).
    idx = int(re.split(r"[_.]", fname)[-2])
    if key in action_to_index:
        # Two files share one action sequence. Only one index can stay in the
        # lookup, so the other can never be matched to an episode later; say
        # so here instead of leaving an unexplained gap.
        print(
            f"[WARN] {fname} has the same action sequence as index "
            f"{action_to_index[key]}; index {idx} replaces it in the lookup."
        )
    action_to_index[key] = idx
    max_index = max(max_index, idx)

# One slot per file index so trees can be stored by direct indexing; a slot
# still holding None afterwards never received a tree.
all_trees = [None for _ in range(max_index + 1)]

mismatches = 0
unmatched_eps = 0

for ep in data["episode_details"]:
    # A trailing 0 is appended before the lookup. Keep the "+ [0]" only if the
    # saved .npy sequences end with that terminal 0; otherwise remove it
    # (print a couple of sequences once to check).
    ep_actions = tuple(int(a) for a in ep["actions"] + [0])

    if ep_actions not in action_to_index:
        unmatched_eps += 1
        # Debug aid: report the episode and the length of its action list
        print(f"[WARN] No action match for ep {ep['ep_idx']} (len={len(ep['actions'])}).")
        continue

    idx = action_to_index[ep_actions]
    if all_trees[idx] is not None:
        # Slot already filled: more than one episode maps to this sequence
        mismatches += 1
        print(f"[WARN] Duplicate match for index {idx} (ep {ep['ep_idx']}). Overwriting.")
    all_trees[idx] = ep["predicted_tree"]

print(f"Placed trees: {len([t for t in all_trees if t is not None])} / {len(all_trees)}")
if mismatches > 0:
    print(f"Duplicates encountered: {mismatches}")
if unmatched_eps > 0:
    print(f"Episodes with no match: {unmatched_eps}")

# Sanity check: list every index that is still empty
missing_indices = [i for i in range(len(all_trees)) if all_trees[i] is None]
if len(missing_indices) > 0:
    print(f"[WARN] Missing tree entries at indices: {missing_indices}")

# Convert and save
# all_trees can still hold None at indices no episode matched. Only the real
# tree strings are handed to the converter (it calls .strip() on every entry,
# so a None would abort the whole run), and each result is put back at its
# original index so the output numbering keeps matching the action files.
present_indices = [i for i, t in enumerate(all_trees) if t is not None]
converted = [None] * len(all_trees)
for i, tree in zip(
    present_indices,
    trees_to_productions([all_trees[i] for i in present_indices]),
):
    converted[i] = tree

out_dir = "Traces/stone_pick_static/hierarchy_data/ompn_hierarchy"
os.makedirs(out_dir, exist_ok=True)

for i, tree in enumerate(converted):
    if tree is None:
        continue  # no tree was placed at this index; nothing to write
    with open(os.path.join(out_dir, f"seq_tree_{i}.json"), "w") as f:
        json.dump(tree, f)

Placed trees: 500 / 500


In [4]:
from typing import List, Dict, Any

def get_node_sequence(tree: Dict[str, Any]) -> List[str]:
    """
    Return, for every primitive leaf in left-to-right order, the production id
    of the node that directly parents it.

    The input is one converted tree: internal nodes look like
    {"production": int, "children": [...]} and leaves like {"symbol": "..."}.
    Each entry of the result is the parent's production id as a plain string
    with no prefix (e.g. "3", not "N3"), and there is exactly one entry per
    leaf, so a parent with k leaf children contributes k entries.

    If the root itself directly parents leaves, its production id is used.
    If the input is just a leaf (no parent), or an internal node with no
    children, the result is [].
    """
    seq: List[str] = []

    def walk(node: Dict[str, Any]) -> None:
        # A leaf has no children to report a parent for.
        if "symbol" in node:
            return

        # Same text for every leaf child of this node, so format it once.
        prod = str(node["production"])
        for ch in node.get("children", []):
            if "symbol" in ch:
                # ch is a primitive leaf: record *this* node as its parent
                seq.append(prod)
            else:
                # ch is an internal node: descend to find its leaves' parents
                walk(ch)

    walk(tree)
    return seq

In [5]:
# Write one skill-sequence file per converted tree: one parent production id
# per line, in leaf order, named after the tree's index.
skills_dir = "Traces/stone_pick_static/ompn_hierarchy_skills"
os.makedirs(skills_dir, exist_ok=True)

for i, tree in enumerate(converted):
    if tree is None:
        # Same rule as the JSON export loop: an index without a tree is
        # skipped rather than passed to get_node_sequence, which cannot
        # handle None.
        continue
    seq = get_node_sequence(tree)
    with open(f"{skills_dir}/craftax_{i}", "w") as f:
        f.write("\n".join(seq))