In [None]:
import os
# Step up one directory before the project imports below.
# NOTE(review): presumably the notebook sits one level below the project root
# so that `vh_eval` becomes importable — confirm.
# This is a relative move, so re-running the cell without restarting the
# kernel climbs one more level each time.
os.chdir("..")
from vh_eval import *
import json
# Test-case index. The evaluation loop reads entry[0] as the script path and
# entry[1] as the graph path of each case.
VH_TEST_LIST_PATH = "/home/lch/desktop/ToolkenGPT-main/data/vh/legal_test_v2.json"
with open(VH_TEST_LIST_PATH) as test_list_file:
    vh_file_list = json.load(test_list_file)

In [None]:
import re
from copy import deepcopy
from collections import Counter
from tqdm import *

In [None]:
# Name-equivalence data that check_results hands to ScriptExecutor and
# EnvironmentState.
# NOTE(review): `utils` is not imported explicitly in this notebook; it
# presumably arrives through `from vh_eval import *` — confirm.
name_equivalence = utils.load_name_equivalence()

In [None]:
# Known actions. check_results tests every bracketed token of a script line
# (e.g. "[WALK]") for membership in this mapping.
# Opened with a context manager so the file handle is closed deterministically.
with open("/home/lch/desktop/ToolkenGPT-main/data/vh/func_dict.json") as func_dict_file:
    func_dict = json.load(func_dict_file)

In [None]:
def check_results(script_path, graph_path, name_equivalence, script, verbose=False, prefix_match=False):
    """Run a generated script against a graph pair and classify the outcome.

    The graph file is read as JSON and must provide "init_graph" and
    "final_graph"; the "class_name" of every node in the initial graph forms
    the set of objects a script line may reference.

    Args:
        script_path: Path of the reference script. Only forwarded to
            ``get_desc`` when building verbose diagnostics.
        graph_path: Path of the JSON file holding the initial/final graphs.
        name_equivalence: Forwarded unchanged to ``ScriptExecutor`` and
            ``EnvironmentState``.
        script: List of script lines (strings) to evaluate.
        verbose: When true, print diagnostics for every decision.
        prefix_match: When false, only the complete script is evaluated.
            When true, every prefix of length 1..len(script) is tried in
            order and the first prefix that verifies counts as a success.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)`` where reason
        is one of "FAIL-EMPTY-STATE", "FAIL-ACTION-NOT-FOUND",
        "FAIL-OBJ-NOT-FOUND", "FAIL-INTERNAL-ERROR-!", "FAIL-NUM-PARAMS",
        "FAIL-PARSE-ACTION", "FAIL-UNKOWN", "FAIL-EXCEPTION-!",
        "FAIL-WRONG-STATE" or "FAIL-MISSED-!".
    """
    # Context manager: the graph file is closed as soon as it has been parsed.
    with open(graph_path) as graph_file:
        graph = json.load(graph_file)
    graph_init = EnvironmentGraph(graph["init_graph"])
    graph_final = EnvironmentGraph(graph["final_graph"])
    executor = ScriptExecutor(graph_init, name_equivalence)
    init_state = EnvironmentState(graph_init, name_equivalence)
    final_state = EnvironmentState(graph_final, name_equivalence)
    obj_list = {n["class_name"] for n in graph["init_graph"]["nodes"]}
    if len(obj_list) == 0:
        return False, "FAIL-EMPTY-STATE"

    def describe():
        # Only evaluated on verbose paths, exactly as before.
        return get_desc(graph_file_name=graph_path, script_file_name=script_path)

    if verbose:
        print("input script len: ", len(script))

    # Stays None when the loop body never runs (prefix_match with an empty
    # script); previously that case crashed with an unbound local below.
    c = None
    first_length = 1 if prefix_match else len(script)
    for cur_length in range(first_length, len(script) + 1):
        if verbose:
            print("current length: ", cur_length)
        try:
            lines = [
                parse_script_line(line, line_idx + 1)
                for line_idx, line in enumerate(script[:cur_length])
            ]

            # Only the newest line of the prefix is vetted for unknown
            # actions/objects; earlier lines were vetted as part of shorter
            # prefixes (prefix mode) or are left to the executor (strict mode).
            last_line = script[cur_length - 1]
            actions = re.findall(r"\[.*?\]", last_line)
            objs = re.findall(r"<.*?>", last_line)

            if any(a not in func_dict for a in actions):
                if verbose:
                    print("action not found", script, describe())
                return False, "FAIL-ACTION-NOT-FOUND"

            # o[1:-1] strips the surrounding angle brackets.
            if any(o[1:-1] not in obj_list for o in objs):
                if verbose:
                    print("obj not found", script, describe())
                return False, "FAIL-OBJ-NOT-FOUND"

            s = Script(lines)
            # The executor is copied so every prefix starts from a fresh one.
            c = verify_script(deepcopy(executor), s, init_state, final_state)

            if c['state']:
                return True, ""

            if "failed" in c['desc']:
                if verbose:
                    print(c['desc'], script, describe())
                return False, "FAIL-INTERNAL-ERROR-!"

        except Exception as e:
            message = str(e)

            if "Wrong number of parameters" in message:
                if verbose:
                    print(e, script, describe())
                return False, "FAIL-NUM-PARAMS"

            if "Cannot parse action" in message:
                if verbose:
                    print(script)
                    # Was `script[stop_idx-1]`: `stop_idx` is undefined, so the
                    # verbose path raised NameError instead of reporting.
                    print(e, script[cur_length - 1], describe())
                return False, "FAIL-PARSE-ACTION"

            if "Unknown" in message:
                if verbose:
                    print(e, script, describe())
                return False, "FAIL-UNKOWN"

            if verbose:
                print(e, script, describe())
            return False, "FAIL-EXCEPTION-!"

    # Reached only when the last evaluated prefix neither succeeded nor
    # reported an execution failure (or when nothing was evaluated at all).
    if c is not None and "not correct" in c['desc']:
        if verbose:
            print(c['desc'], script, describe())
        return False, "FAIL-WRONG-STATE"

    return False, "FAIL-MISSED-!"

In [None]:
# Model generations in JSON-lines form: one JSON object per line. The
# evaluation loop reads the "func_calls" list of each record.
# NOTE(review): the path is an empty placeholder and must be set before this
# cell can run.
EMBEDDING_OUTPUT_PATH = ""
with open(EMBEDDING_OUTPUT_PATH) as f:
    # Stream the file line by line and skip blank lines, which json.loads
    # would otherwise reject.
    embedding_outputs = [json.loads(line) for line in f if line.strip()]

In [None]:
# Tallies for the strict (whole script) and relaxed (any prefix) evaluations.
correct = 0
correct_relax = 0
fail = Counter()
fail_relax = Counter()

for test_idx in trange(len(embedding_outputs)):
    case_files = vh_file_list[test_idx]
    script_path = case_files[0]
    graph_path = case_files[1]

    # Normalise the generated calls: put a space between "]" and "<", append
    # " (1)" after every ">", and keep at most the first 8 lines.
    joined_calls = "\n".join(embedding_outputs[test_idx]["func_calls"])
    normalised_calls = joined_calls.replace("]<", "] <").replace(">", "> (1)")
    embedding_script = normalised_calls.split("\n")[:8]

    # Drop the first line containing "[END]" and everything after it.
    end_pos = next(
        (pos for pos, step in enumerate(embedding_script) if "[END]" in step),
        None,
    )
    if end_pos is not None:
        embedding_script = embedding_script[:end_pos]

    cor, reason = check_results(
        script_path, graph_path, name_equivalence, embedding_script,
        verbose=False, prefix_match=False,
    )
    cor_relax, reason_relax = check_results(
        script_path, graph_path, name_equivalence, embedding_script,
        verbose=False, prefix_match=True,
    )

    if cor_relax:
        correct_relax += 1
    else:
        fail_relax[reason_relax] += 1

    if cor:
        correct += 1
    else:
        fail[reason] += 1

In [None]:
# Print count, total and ratio for the strict run, a separator line, then the
# same figures for the relaxed (prefix-match) run.
total = len(embedding_outputs)
for section, (n_correct, failures) in enumerate([(correct, fail), (correct_relax, fail_relax)]):
    if section:
        print("=" * 60)
    print("correct", n_correct, total, n_correct / total)
    for label, count in failures.items():
        print(label, count, total, count / total)