In [1]:
import json
import os
import re

In [29]:
def parse_pddl_sections(text, section_name):
    """Extract the ground atoms listed in a ``(:<section_name> ...)`` block of a PDDL problem.

    The section is located by its ``(:<section_name>`` opener and delimited by
    balanced-parenthesis matching. This works whether the section spans many lines
    (one atom per line, closing paren on its own line) or is written on a single
    line. Every innermost parenthesised group inside the section is returned as
    one predicate string, with the parentheses removed and internal whitespace
    collapsed. Connective wrappers such as ``(and ...)`` are therefore skipped
    automatically. ``;`` line comments are ignored.

    Limitation: negated literals are not supported. ``(not (on a b))`` would
    yield ``"on a b"`` without the negation.

    Args:
        text: Full PDDL problem text.
        section_name: Section keyword without the ``(:`` prefix, e.g. ``"init"``
            or ``"goal"``.

    Returns:
        List of predicate strings such as ``"on a b"``. The list is empty if the
        section is missing or its parentheses never balance.
    """
    # Drop ';' line comments so parentheses inside comments cannot affect matching.
    text = re.sub(r";[^\n]*", "", text)

    match = re.search(r"\(\s*:" + re.escape(section_name) + r"\b", text)
    if match is None:
        return []

    # Walk forward from the section's opening paren until the depth returns to zero;
    # that position is the section's own closing paren.
    depth = 0
    end = None
    for i in range(match.start(), len(text)):
        ch = text[i]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                end = i
                break
    if end is None:
        return []

    # Everything after the "(:<section_name>" keyword, excluding the closing paren.
    body = text[match.end():end]
    atoms = re.findall(r"\(([^()]*)\)", body)
    return [" ".join(atom.split()) for atom in atoms if atom.strip()]

def parse_pddl(text):
    """Split a PDDL problem into its initial-state and goal predicates.

    Args:
        text: Full PDDL problem text.

    Returns:
        ``(initial_state, goal_state)``: two lists of predicate strings, as
        produced by ``parse_pddl_sections`` for the ``init`` and ``goal`` sections.
    """
    initial_state = parse_pddl_sections(text, "init")
    # parse_pddl_sections already strips the surrounding parentheses, including the
    # extra ')' that closes an (and ...) goal, so no further cleanup is needed here.
    goal_state = parse_pddl_sections(text, "goal")
    return initial_state, goal_state

In [48]:
# Convert the planning instances listed in data/step_{step}.json into one JSON file
# per problem under data_parsed/step_{step}/, holding the parsed init/goal predicates
# and the ground-truth action sequence. Change `step` (e.g. 2, 4, 6) to convert
# another file.
step = 6

with open(f"data/step_{step}.json") as f:
    data = json.load(f)

out_dir = f"data_parsed/step_{step}"
# Create the output directory on a fresh checkout instead of failing on the first write.
os.makedirs(out_dir, exist_ok=True)

for elem in data:
    # Each entry is (problem path, plan text, <third field, ignored here>).
    path, gt_plan_text, _ = elem
    filename = os.path.basename(path)
    # Problem files are read from data/ regardless of the directory stored in `path`.
    real_path = f"data/{filename}"

    with open(real_path, 'r') as file:
        file_contents = file.read()
    init_state, goal_state = parse_pddl(file_contents)

    # One action per non-blank line. Strip whitespace (including '\r') before the
    # parentheses so "(pick-up a)\r" becomes "pick-up a" rather than "pick-up a)\r".
    action_seq = [
        line.strip().strip("()")
        for line in gt_plan_text.splitlines()
        if line.strip()
    ]

    parsed_file = {
        "init_state": init_state,
        "goal_state": goal_state,
        "action_seq": action_seq
    }

    new_path = f"{out_dir}/{filename}"

    with open(new_path, 'w') as json_file:
        json.dump(parsed_file, json_file, indent=4)

In [33]:
from env import BlocksWorld
from world_model import BlocksWorldModelInit
import random
from collections import defaultdict

# Build single-step transition test cases from random rollouts in BlocksWorld.
# Each rollout starts from the init_state of a randomly chosen parsed problem and takes
# `ep_len` actions, each picked uniformly from the dummy world model's suggestions.
# Every (state, action, next_state) triple is bucketed by its action type.
seed = 0                  # fixed so the sampled test cases can be regenerated
num_levels = 100          # number of random rollouts
ep_len = 8                # actions taken per rollout
tc_per_action_type = 100  # cap on saved test cases per action type
data_dir = "data_parsed/step_6"  # renamed from `dir`, which shadowed the builtin
save_path = "data_parsed/test_cases.json"

random.seed(seed)
dummy_model = BlocksWorldModelInit()
env = BlocksWorld()
transitions = defaultdict(list)

# Sort so a given seed picks the same files regardless of the filesystem's listing
# order. Skip sub-directories such as .ipynb_checkpoints, which open() cannot read.
files = sorted(
    name for name in os.listdir(data_dir)
    if os.path.isfile(os.path.join(data_dir, name))
)

for _ in range(num_levels):
    rand_file_path = os.path.join(data_dir, random.choice(files))
    with open(rand_file_path) as f:
        data = json.load(f)
    state = data["init_state"]

    for _ in range(ep_len):
        rand_action = random.choice(dummy_model.suggest_actions(state))
        action_type = rand_action.split(" ")[0]
        next_state = env.state_transition(set(state), rand_action)

        # NOTE(review): if state_transition returns a set, list() follows string-hash
        # order, which varies between interpreter runs unless PYTHONHASHSEED is fixed.
        transitions[action_type].append({
            "state": list(state),
            "action": rand_action,
            "next_state": list(next_state),
        })
        state = next_state

for key, value in transitions.items():
    print(key, len(value))

# Keep at most tc_per_action_type cases per action type (the earliest collected).
for key in transitions:
    transitions[key] = transitions[key][:tc_per_action_type]

with open(save_path, "w") as json_file:
    json.dump(transitions, json_file, indent=1)

pick-up 157
stack 253
unstack 243
put-down 147


unstack 170
stack 115
pick-up 130
put-down 85
