# Process the generated queries and code data

In [9]:
import os 
import sys
import numpy as np
import json 
import tqdm
import re 

sys.path.append("../../") # add package root to path

## Process queries

In [13]:
def load_queries(task_queries_file):
    """
    Load task queries from a txt file.

    Every line is matched independently. A line is kept only if it starts with
    ``objects`` and contains a ``;`` followed (after optional whitespace) by ``#``
    and the query text, e.g.::

        objects = [table, cabinet, cabinet.drawer0, cabinet.drawer1, cabinet.drawer2, cabinet.drawer3, panda_robot] ; # open cabinet.drawer0

    Lines that do not match (headers, blank lines, malformed output) are skipped.

    Args:
        task_queries_file: path to the text file to read.

    Returns:
        list[str]: one ``"<context>; #<query>"`` string per valid line, in file
        order. Empty if no line matches.
    """
    # context: from "objects" up to the last ";" that is followed by optional
    # whitespace and "#"; query: the rest of the line ("." never matches "\n").
    valid_line_pattern = re.compile(r'(?P<context>objects.*);\s*#(?P<query>.*)')

    task_queries = []
    # Explicit UTF-8 so the result does not depend on the platform's locale encoding.
    with open(task_queries_file, 'r', encoding='utf-8') as f:
        for line in f:
            match = valid_line_pattern.match(line)
            if match:
                # Re-join with a normalized "; #" separator.
                task_queries.append(match.group('context') + "; #" + match.group('query'))

    return task_queries

In [17]:
task_query_dir = "../../data/task_queries/"

# Count the valid queries (as parsed by load_queries) in each table_cabinet_<i>.txt.
# Indices 0..99 inclusive, the same range scanned when counting processed data below.
num_task_queries = {}
for i in range(0, 100):
    task_query_file = f"table_cabinet_{i}.txt"
    task_query_path = os.path.join(task_query_dir, task_query_file)
    if not os.path.exists(task_query_path):
        # Missing files are skipped entirely (not reported as corrupted).
        continue
    task_queries = load_queries(task_query_path)
    num_task_queries[task_query_file] = len(task_queries)

# A file that exists but yields no valid query line is treated as corrupted.
corrupted_task_query_files = [k for k, v in num_task_queries.items() if v == 0]
print(f"corrupted_task_query_files: {corrupted_task_query_files}")
print(f"Number of corrupted task query files: {len(corrupted_task_query_files)}")
print(f"Number of corrupted task query files: {len(corrupted_task_query_files)}")


corrupted_task_query_files: ['table_cabinet_5.txt', 'table_cabinet_7.txt', 'table_cabinet_8.txt', 'table_cabinet_12.txt', 'table_cabinet_18.txt', 'table_cabinet_19.txt', 'table_cabinet_20.txt', 'table_cabinet_22.txt', 'table_cabinet_23.txt', 'table_cabinet_24.txt', 'table_cabinet_25.txt', 'table_cabinet_27.txt', 'table_cabinet_32.txt', 'table_cabinet_35.txt', 'table_cabinet_36.txt', 'table_cabinet_37.txt', 'table_cabinet_40.txt', 'table_cabinet_43.txt', 'table_cabinet_53.txt', 'table_cabinet_59.txt', 'table_cabinet_61.txt', 'table_cabinet_64.txt', 'table_cabinet_68.txt', 'table_cabinet_71.txt', 'table_cabinet_73.txt', 'table_cabinet_75.txt', 'table_cabinet_77.txt', 'table_cabinet_78.txt', 'table_cabinet_79.txt', 'table_cabinet_80.txt', 'table_cabinet_81.txt', 'table_cabinet_82.txt', 'table_cabinet_85.txt', 'table_cabinet_88.txt', 'table_cabinet_89.txt', 'table_cabinet_90.txt', 'table_cabinet_92.txt', 'table_cabinet_96.txt', 'table_cabinet_98.txt']
Number of corrupted task query files: 

In [11]:
# Display the per-file count of valid queries (filled by the counting cell).
# NOTE(review): the saved output `{}` is stale — this cell ran as In [11], before the
# In [17] counting cell populated the dict. Re-run the notebook top-to-bottom to refresh.
num_task_queries

{}

## Process code

Code generation stopped at `table_cabinet_48`, so only the raw outputs for indices 0–47 are converted below.

In [2]:
def process_raw_output(raw_path, processed_path):
    """
    Convert a raw code-generation output json into a list of {query, code} pairs.

    Raw output json (a list of records):
    [{
        "context": context,
        "query": use_query,
        "src_fs": src_fs,
        "code_str": code_str,
        "gvars": list(gvars.keys()),
        "lvars": list(lvars.keys()),
    },
    ...
    ]

    Only "context", "query", "src_fs" and "code_str" are read; "gvars" and
    "lvars" are not carried over. Each output record is
    {"query": context + query, "code": code}, where code is code_str, preceded
    by the values of src_fs (one per line) when src_fs is non-empty.

    Args:
        raw_path: path of the raw json to read.
        processed_path: path of the processed json to (over)write, indent=4.

    Returns:
        list[dict]: the processed records that were written. Existing callers
        that ignore the return value are unaffected.

    Raises:
        KeyError: if a record lacks one of the keys read above.
        OSError / json.JSONDecodeError: on unreadable or malformed input.
    """
    # JSON is UTF-8 by spec; don't rely on the platform's locale default.
    with open(raw_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)

    processed_data = []
    for data in raw_data:
        # NOTE(review): context and query are concatenated with no separator, whereas
        # load_queries joins them with "; #". Confirm the raw "context" already ends
        # with its separator.
        query = data['context'] + data['query']

        code = data['code_str']
        src_fs = data['src_fs']
        if src_fs:
            # Prepend the function sources stored in src_fs (presumably helpers the
            # generated code calls — confirm) so each snippet carries its definitions.
            code = '\n'.join(src_fs.values()) + '\n' + code

        processed_data.append({
            "query": query,
            "code": code
        })

    with open(processed_path, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, indent=4)

    return processed_data

In [5]:
gen_code_dir = "../../data/generated_code"

# Convert every raw_table_cabinet_<i>.json to processed_table_cabinet_<i>.json.
# Code generation stopped at table_cabinet_48, hence indices 0..47 inclusive.
idx_start = 0
idx_end = 47
raw_files_idx_not_exist = []
raw_files_idx_process_failed = []
for i in range(idx_start, idx_end + 1):
    raw_path = os.path.join(gen_code_dir, f"raw_table_cabinet_{i}.json")
    processed_path = os.path.join(gen_code_dir, f"processed_table_cabinet_{i}.json")
    if not os.path.exists(raw_path):
        raw_files_idx_not_exist.append(i)
        continue
    try:
        process_raw_output(raw_path, processed_path)
    except Exception as err:
        # Catch Exception (not BaseException) so KeyboardInterrupt still stops the
        # loop; record the index and print the cause so failures are diagnosable.
        raw_files_idx_process_failed.append(i)
        print(f"Failed to process {raw_path}: {err!r}")
        continue
    print(f"Finish processing {raw_path} to {processed_path}")

Finish processing ../../data/generated_code/raw_table_cabinet_0.json to ../../data/generated_code/processed_table_cabinet_0.json
Finish processing ../../data/generated_code/raw_table_cabinet_1.json to ../../data/generated_code/processed_table_cabinet_1.json
Finish processing ../../data/generated_code/raw_table_cabinet_2.json to ../../data/generated_code/processed_table_cabinet_2.json
Finish processing ../../data/generated_code/raw_table_cabinet_3.json to ../../data/generated_code/processed_table_cabinet_3.json
Finish processing ../../data/generated_code/raw_table_cabinet_4.json to ../../data/generated_code/processed_table_cabinet_4.json
Finish processing ../../data/generated_code/raw_table_cabinet_6.json to ../../data/generated_code/processed_table_cabinet_6.json
Finish processing ../../data/generated_code/raw_table_cabinet_9.json to ../../data/generated_code/processed_table_cabinet_9.json
Finish processing ../../data/generated_code/raw_table_cabinet_10.json to ../../data/generated_cod

In [6]:
# First line: indices with no raw file; second line: indices whose processing raised.
print(raw_files_idx_not_exist, raw_files_idx_process_failed, sep="\n")

[5, 7, 8, 12, 18, 19, 20, 22, 23, 24, 25, 27, 32, 35, 36, 37, 40, 43]
[]


In [18]:
# Count processed {query, code} records across every processed_table_cabinet_<i>.json
# for i in 0..99 inclusive; missing files are skipped.
idx_start = 0
idx_end = 99
num_processed_data_list = []
for i in range(idx_start, idx_end + 1):
    processed_path = os.path.join(gen_code_dir, f"processed_table_cabinet_{i}.json")
    if not os.path.exists(processed_path):
        continue
    with open(processed_path, 'r', encoding='utf-8') as f:
        processed_data = json.load(f)
    num_processed_data_list.append(len(processed_data))

# Built-in sum keeps this an int even when no file exists (np.sum([]) gives float 0.0).
num_processed_data = sum(num_processed_data_list)
print(f"Number of processed data: {num_processed_data}")

Number of processed data: 4303
