In [None]:
import json
import os
import numpy as np
# Location of the raw OMPN trace dump for the stone_pick_static task
file_path = "Traces/stone_pick_static/ompn_raw_stone_pick_static_pixels_big.json"

# Read and decode the whole JSON document into memory
with open(file_path, "r") as fh:
    data = json.load(fh)

# Report the top-level keys when the payload is a dict; otherwise report its type
if isinstance(data, dict):
    summary = data.keys()
else:
    summary = type(data)
print("Top-level keys:", summary)


In [None]:
# Show which fields each OMPN episode record carries (inspecting the first one)
first_episode = data['episode_details'][0]
print(first_episode.keys())

In [None]:

# Match each OMPN episode to its ground-truth action file (exact action
# sequence match) and write the episode's decoded subtasks, one per line,
# to ompn_skills/<action file stem>.
ACTIONS_DIR = "Traces/stone_pick_static/actions"
OMPN_SKILLS_DIR = "Traces/stone_pick_static/ompn_skills"

# List all the .npy files in the actions directory (sorted for a stable order)
action_files = sorted(f for f in os.listdir(ACTIONS_DIR) if f.endswith('.npy'))

# Load each ground-truth action array once, keyed by file stem, instead of
# re-reading every file for every episode. splitext strips only the extension
# (str.replace would also remove '.npy' occurring elsewhere in the name).
gt_actions = {
    os.path.splitext(fname)[0]: np.load(os.path.join(ACTIONS_DIR, fname))
    for fname in action_files
}

os.makedirs(OMPN_SKILLS_DIR, exist_ok=True)

unmatched_eps = []
for ep in data['episode_details']:
    ep_skills = [str(x) for x in ep['decoded_subtask']]
    # A trailing 0 is appended before comparing, as in the original matching.
    # NOTE(review): presumably the saved action arrays end with one extra 0
    # action that the episode record omits — confirm.
    actions = ep['actions'] + [0]

    matched = False
    for file_name, gt_arr in gt_actions.items():
        if np.array_equal(actions, gt_arr):
            matched = True
            with open(os.path.join(OMPN_SKILLS_DIR, file_name), "w") as f:
                f.write("\n".join(ep_skills))
    if not matched:
        unmatched_eps.append(ep['ep_idx'])

# Surface episodes that would otherwise silently produce no skills file.
if unmatched_eps:
    print(f"{len(unmatched_eps)} OMPN episode(s) matched no action file "
          f"(first ep_idx values: {unmatched_eps[:10]})")


# Compile

In [None]:
import json
import os
import numpy as np
# JSON file holding the "compile" predictions for the stone_pick_static traces
file_path = "Traces/stone_pick_static/compile_data.json"

# Decode the whole file into memory
with open(file_path, "r") as fh:
    data = json.load(fh)

# Peek at the fields of the first record; the next cell iterates `data` as a
# sequence of per-episode dicts
first_record = data[0]
first_record.keys()


In [None]:

# Match each compile episode to its ground-truth action file (exact action
# sequence match) and write its predicted static skills, one per line,
# to compile_skills/<action file stem>.
ACTIONS_DIR = "Traces/stone_pick_static/actions"
COMPILE_SKILLS_DIR = "Traces/stone_pick_static/compile_skills"

# List all the .npy files in the actions directory (sorted for a stable order)
action_files = sorted(f for f in os.listdir(ACTIONS_DIR) if f.endswith('.npy'))

# Load each ground-truth action array once, keyed by file stem, instead of
# re-reading every file for every episode. splitext strips only the extension
# (str.replace would also remove '.npy' occurring elsewhere in the name).
gt_actions = {
    os.path.splitext(fname)[0]: np.load(os.path.join(ACTIONS_DIR, fname))
    for fname in action_files
}

os.makedirs(COMPILE_SKILLS_DIR, exist_ok=True)

unmatched_eps = []
for ep_pos, ep in enumerate(data):
    ep_skills = [str(x) for x in ep['predicted_skills_static']]
    # A trailing 0 is appended before comparing, as in the original matching.
    # NOTE(review): presumably the saved action arrays end with one extra 0
    # action that the episode record omits — confirm.
    actions = ep['actions'] + [0]

    matched = False
    for file_name, gt_arr in gt_actions.items():
        if np.array_equal(actions, gt_arr):
            matched = True
            with open(os.path.join(COMPILE_SKILLS_DIR, file_name), "w") as f:
                f.write("\n".join(ep_skills))
    if not matched:
        unmatched_eps.append(ep_pos)

# Surface episodes that would otherwise silently produce no skills file.
if unmatched_eps:
    print(f"{len(unmatched_eps)} compile episode(s) matched no action file "
          f"(first positions in data: {unmatched_eps[:10]})")

# Fix OMPN Trees -> ASOT Format

In [6]:
import re
from typing import Tuple, Dict, Any, List

# =========
# Parsing
# =========

# Optionally signed integer token. Compiled once and matched at an offset
# (Pattern.match(s, pos)) so tokenizing does not copy the rest of the input
# for every token, which made the original parse quadratic in input length.
_TOKEN_RE = re.compile(r'[+-]?\d+')

def parse_tree_string(s: str) -> Dict[str, Any]:
    """
    Parse a single bracketed tree string such as "(((1 1) (5 3)) (1 3))" into an AST.

    Every ``(...)`` group becomes an internal node ``{"_children": [child, ...]}``
    and every integer token inside a group becomes a leaf ``{"symbol": "<token>"}``.
    The token text is kept verbatim, including any leading sign. An empty group
    ``()`` becomes a node with no children.

    Args:
        s: The tree string; surrounding whitespace is ignored.

    Returns:
        The root node of the parsed AST.

    Raises:
        ValueError: if the input is empty or does not start with '(', contains a
            non-integer token, has unbalanced parentheses, or has trailing text
            after the root group.
    """
    s = s.strip()
    if not s or s[0] != '(':
        raise ValueError("Input must start with '('")

    i = 0
    n = len(s)

    def skip_ws():
        nonlocal i
        while i < n and s[i].isspace():
            i += 1

    def parse_node() -> Dict[str, Any]:
        nonlocal i
        if i >= n or s[i] != '(':
            raise ValueError(f"Expected '(', got {s[i:i+10]!r}")
        i += 1  # consume '('
        skip_ws()

        children: List[Dict[str, Any]] = []

        while i < n and s[i] != ')':
            skip_ws()
            if i < n and s[i] == '(':
                # nested node
                children.append(parse_node())
            else:
                # token (number), matched in place without slicing
                m = _TOKEN_RE.match(s, i)
                if not m:
                    raise ValueError(f"Invalid token at: {s[i:i+20]!r}")
                tok = m.group(0)
                i = m.end()
                children.append({"symbol": tok})
            skip_ws()

        # Running off the end without a ')' means an unclosed group.
        if i >= n or s[i] != ')':
            raise ValueError("Unbalanced parentheses")
        i += 1  # consume ')'
        return {"_children": children}

    root = parse_node()
    skip_ws()
    if i != n:
        raise ValueError(f"Extra text after tree: {s[i:]!r}")
    return root

# =======================================
# Production interning & tree conversion
# =======================================

Signature = Tuple  # immutable, fully tuplized key of the form (tag, ...)

def node_signature(node: Dict[str, Any]) -> Signature:
    """
    Return a hashable, canonical key describing ``node`` by the ordered
    sequence of its children.

    Leaf nodes (those with a "symbol") map to ("T", symbol); internal nodes
    map to ("N", (signature of each child, in order)).
    """
    if "symbol" not in node:
        return ("N", tuple(map(node_signature, node["_children"])))
    return ("T", node["symbol"])

def to_production_tree(
    node: Dict[str, Any],
    intern_table: Dict[Signature, int]
) -> Tuple[Dict[str, Any], Signature]:
    """
    Convert a parsed AST node into production-tree JSON.

    Leaves become ``{"symbol": <token>}``. Internal nodes become
    ``{"production": <id>, "children": [...]}``, where ``<id>`` is looked up
    in ``intern_table`` by the node's signature. A new signature receives the
    next sequential id (``len(intern_table)``), so structurally identical
    subtrees share an id.

    ``intern_table`` is mutated in place.

    Returns:
      (converted_json, signature)
    """
    if "symbol" in node:
        symbol = node["symbol"]
        return {"symbol": symbol}, ("T", symbol)

    # Children are converted first (post-order), so their ids are assigned
    # before the parent's.
    converted = [to_production_tree(child, intern_table) for child in node["_children"]]
    out_children = [child_json for child_json, _ in converted]
    signature = ("N", tuple(child_sig for _, child_sig in converted))

    # setdefault inserts the next sequential id only the first time this
    # signature is seen; len() is evaluated before any insertion.
    production_id = intern_table.setdefault(signature, len(intern_table))

    return {"production": production_id, "children": out_children}, signature

# ==========================
# Public API for many trees
# ==========================

def trees_to_productions(tree_strings: List[str]) -> List[Dict[str, Any]]:
    """
    Parse and convert several tree strings into production-tree JSON.

    All trees share one intern table, so identical subtrees anywhere in the
    forest receive the same production id. Ids are assigned in input order
    (each string is fully parsed and converted before the next).

    Returns the converted trees in the same order as ``tree_strings``.
    """
    shared_table: Dict[Signature, int] = {}
    return [
        to_production_tree(parse_tree_string(text), shared_table)[0]
        for text in tree_strings
    ]

from graphviz import Digraph

def visualize_tree(tree_dict, filename_base, fmt="pdf", root_label=None, label_mapping=None):
    """
    Render a production tree with Graphviz and write it to disk.

    tree_dict: nested dict where each node has either
      - 'production' → int   (internal node, drawn as "R<id>")
      - OR 'symbol' → str    (leaf, drawn via label_mapping when given)
      and optional 'children': [ … ]
    filename_base: output path without extension, e.g. "seq_tree_0"
    fmt: Graphviz output format, "pdf" by default
    root_label: if given, used as the label of the top node instead of its own label
    label_mapping: optional dict mapping leaf symbols to readable strings. A mapped
      value has "_" replaced by spaces and is title-cased. Production nodes are
      not mapped.

    Returns the path returned by ``Digraph.render`` (it includes the extension).
    """
    dot = Digraph(format=fmt)
    dot.attr(rankdir="TB")     # top→bottom
    dot.attr("node", shape="oval", fontsize="12")

    # Graph node ids come from a sequential counter rather than id(node).
    # With id(node), a dict object reused at several positions collapsed into
    # one graph node, misdrawing the tree. The ids also varied between runs.
    counter = 0

    def pretty_map(sym: str) -> str:
        if label_mapping and sym in label_mapping:
            # ensure string, then prettify: underscores → spaces, Title Case
            return str(label_mapping[sym]).replace("_", " ").title()
        return sym  # unmapped leaves keep their raw symbol

    def recurse(node, parent_id=None, is_root=False):
        nonlocal counter
        nid = f"n{counter}"
        counter += 1

        # label
        if is_root and root_label is not None:
            label = root_label
        elif "production" in node:
            label = f"R{node['production']}"
        else:
            label = pretty_map(node["symbol"])

        dot.node(nid, label)

        if parent_id is not None:
            dot.edge(parent_id, nid)

        for child in node.get("children", []):
            recurse(child, parent_id=nid, is_root=False)

    # kick off recursion, marking this call as the root
    recurse(tree_dict, parent_id=None, is_root=True)

    # write the file; returns the full path with extension
    outpath = dot.render(filename_base, cleanup=True)
    return outpath



# Example forest of bracketed trees; the last two strings are identical.
trees = [
    "((((1 1) (5 3)) (1 3) (5 3) (8 ((1 1) 1) (4 5 3)) (2 (2 11) ((3 3) 3 3 3 3 3 3 2 3 5 2) (2 3) 3 (2 3) (3 3 2 2 2 5 1))) (1 1 (4 (4 4) 1 4 (4 1 1 4 4) (4 4) 4 4 4 12)))",
    "((((3 2) 5 2) (3 (5 3)) (8 (2 5) 1) (3 (11 (2 4) 5 1) 1 1 (23) (1 5 3))) (2 2 2 2 12))",
    "((((3 2) 5 2) (3 (5 3)) (8 (2 5) 1) (3 (11 (2 4) 5 1) 1 1 (23) (1 5 3))) (2 2 2 2 12))"
]

converted = trees_to_productions(trees)

# Sanity check: all trees share one intern table, so equal input strings must
# convert to equal production trees (same production ids).
for i in range(len(trees)):
    for j in range(i):
        if trees[i] == trees[j]:
            assert converted[i] == converted[j], f"trees {j} and {i} interned differently"

# Pretty print the second converted tree (index 1)
import json
print(json.dumps(converted[1], indent=2))

# visualize_tree(converted[0], "example_tree_2", root_label="Root")


{
  "production": 31,
  "children": [
    {
      "production": 29,
      "children": [
        {
          "production": 20,
          "children": [
            {
              "production": 19,
              "children": [
                {
                  "symbol": "3"
                },
                {
                  "symbol": "2"
                }
              ]
            },
            {
              "symbol": "5"
            },
            {
              "symbol": "2"
            }
          ]
        },
        {
          "production": 21,
          "children": [
            {
              "symbol": "3"
            },
            {
              "production": 1,
              "children": [
                {
                  "symbol": "5"
                },
                {
                  "symbol": "3"
                }
              ]
            }
          ]
        },
        {
          "production": 23,
          "children": [
            {
              