In [None]:
import json
import glob


def _has_aligned_history(traj: dict) -> bool:
    """Return True when the trace records exactly one inventory per action.

    Trajectories whose inventory and action histories differ in length are
    dropped by the loader (they are taken to contain impossible actions).
    """
    trace = traj["model_trace"]
    return len(trace["inventory_history"]) == len(trace["action_history"])


def load_oracle_trajectories(pattern: str) -> list:
    """Load every JSON trajectory matching ``pattern`` whose histories align.

    Files are visited in sorted order so the result, and every dataset derived
    from it, is reproducible. ``glob.glob`` itself returns matches in arbitrary,
    filesystem-dependent order.

    Args:
        pattern: glob pattern for the trajectory ``.json`` files.

    Returns:
        List of parsed trajectory dicts that pass ``_has_aligned_history``.
    """
    trajectories = []
    for filename in sorted(glob.glob(pattern)):
        with open(filename, "r", encoding="utf-8") as file:
            traj = json.load(file)
        if _has_aligned_history(traj):
            trajectories.append(traj)
    return trajectories


# One entry per data split: glob pattern -> filtered trajectories.
# The split name is later read from the third-to-last "/"-separated component
# of each key (".../train/0/*.json" -> "train"), so keep this layout.
oracle_results = {
    pattern: load_oracle_trajectories(pattern)
    for pattern in (
        "/plancraft/outputs/oracle_symb/train/0/*.json",
        "/plancraft/outputs/oracle_symb/val/0/*.json",
    )
}

In [None]:
from plancraft.models.prompts import get_system_prompt
from collections import defaultdict

# System prompt shared by every dialogue example. The `actions` argument
# presumably limits the prompt to these action types — confirm in
# plancraft.models.prompts. The bare name below echoes it as the cell output.
SYSTEM_PROMPT = get_system_prompt(
    actions=["move", "smelt"],
)
SYSTEM_PROMPT

In [None]:
import os
from plancraft.environments.actions import convert_from_slot_index
from plancraft.models.utils import objective_and_inventory_to_str

def convert_action_to_text(action: dict):
    """Render one recorded action as the assistant's text reply.

    Reads ``action_type``, ``slot_from``, ``slot_to`` and ``quantity`` from
    ``action``. Both slot indices are passed through ``convert_from_slot_index``
    before formatting.

    Returns:
        A string of the form "<type>: from <src> to <dst> with quantity <q>".
    """
    source, target = (
        convert_from_slot_index(action[key]) for key in ("slot_from", "slot_to")
    )
    return "{}: from {} to {} with quantity {}".format(
        action["action_type"], source, target, action["quantity"]
    )


# convert action and inventory to dialogue history
def convert_trajectory_to_base_dialogue(traj: dict):
    """Turn one oracle trajectory into a chat-style training example.

    The dialogue opens with ``SYSTEM_PROMPT``. For each (action, inventory)
    pair in the trace, it then appends:

    * a user turn: objective plus inventory, via ``objective_and_inventory_to_str``;
    * an assistant turn: the action, via ``convert_action_to_text``.

    Pairs are formed with ``zip``, so surplus entries in the longer history
    are dropped.

    Returns:
        ``{"messages": [...], "example_id": traj["example_id"]}``
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    trace = traj["model_trace"]
    goal = trace["objective"]
    for step_action, step_inventory in zip(
        trace["action_history"], trace["inventory_history"]
    ):
        messages += [
            {
                "role": "user",
                "content": objective_and_inventory_to_str(goal, step_inventory),
            },
            {"role": "assistant", "content": convert_action_to_text(step_action)},
        ]
    return {"messages": messages, "example_id": traj["example_id"]}

In [None]:
# Group the dialogue examples by data split. The split name is the
# third-to-last "/"-separated component of each glob pattern
# (".../train/0/*.json" -> "train"). Splits with no trajectories never get a key.
text_data = defaultdict(list)
for pattern, trajectories in oracle_results.items():
    split_name = pattern.split("/")[-3]
    if trajectories:
        text_data[split_name].extend(
            map(convert_trajectory_to_base_dialogue, trajectories)
        )

In [None]:
# Root of the exported dialogue dataset, relative to the notebook's working directory.
ORACLE_DATA_DIR = "../data/oracle"

# Save each example as {ORACLE_DATA_DIR}/{split}/oa/{example_id}.json, pretty-printed.
for split, examples in text_data.items():
    for example in examples:
        example_id = example["example_id"]
        example_path = os.path.join(ORACLE_DATA_DIR, split, "oa", f"{example_id}.json")
        # Create the parent directory per file rather than once per split, so an
        # example_id that contains a path separator still gets its directory.
        os.makedirs(os.path.dirname(example_path), exist_ok=True)
        with open(example_path, "w", encoding="utf-8") as f:
            # Stream straight to the file; the output is identical to json.dumps(..., indent=2).
            json.dump(example, f, indent=2)