In [1]:
import os
import json
import time
import random
import warnings
import argparse

from numpy.random import seed
from extended_definition import ExtendedDefinition
from logic_parser import LogicParser
from logic_solver import LogicSolver
from func_timeout import func_timeout, FunctionTimedOut
from tqdm import tqdm
from multiprocess import Pool

warnings.filterwarnings('ignore')

seed(0)
random.seed(0)

In [2]:
label = 'final_new'
strategy = 'final'
use_annotated = False
time_limit = 150
num_threads = 20
low_first = True
step_limit = 100
debug_mode = False
enable_round = False
enable_predict = True

In [3]:
def solve_one_problem(text_parser, diagram_parser,order_lst):

    ## Set up the logic parser
    parser = LogicParser(ExtendedDefinition(debug=debug_mode))

    if diagram_parser is not None:
        # Define diagram primitive elements
        parser.logic.point_positions = diagram_parser['point_positions']

        isLetter = lambda ch: ch.upper() and len(ch) == 1
        parser.logic.define_point([_ for _ in parser.logic.point_positions if isLetter(_)])
        if debug_mode:
            print(parser.logic.point_positions)

        lines = diagram_parser['line_instances']  # ['AB', 'AC', 'AD', 'BC', 'BD', 'CD']
        for line in lines:
            line = line.strip()
            if len(line) == 2 and isLetter(line[0]) and isLetter(line[1]):
                parser.logic.define_line(line[0], line[1])

        circles = diagram_parser['circle_instances']  # ['O']
        for point in circles:
            parser.logic.define_circle(point)

        # Parse diagram logic forms
        logic_forms = diagram_parser['diagram_logic_forms']
        logic_forms = sorted(logic_forms, key=lambda x: x.find("Perpendicular") != -1)  # put 'Perpendicular' to the end

        for logic_form in logic_forms:
            if logic_form.strip() != "":
                if debug_mode:
                    print("The diagram logic form is", logic_form)
                try:
                    parse_tree = parser.parse(logic_form) # ['Equals', ['LengthOf', ['Line', 'A', 'C']], '10']
                    parser.dfsParseTree(parse_tree)
                except Exception as e:
                    if debug_mode:
                        print("\033[0;0;41mError:\033[0m", repr(e))

    ## Parse text logic forms
    target = None
    text_logic_forms = text_parser["text_logic_forms"]
    for text in text_logic_forms:
        if debug_mode:
            print("The text logic form is", text)
        if text.find('Find') != -1:
            target = parser.findTarget(parser.parse(text)) # ['Value', 'A', 'C']
        else:
            res = parser.parse(text)
            parser.dfsParseTree(res)

    if debug_mode:
        print("The predicting sequence is", order_lst)

    ## Set up, initialize and run the logic solver
    solver = LogicSolver(parser.logic)
    solver.initSearch()
    answer, steps, step_lst = solver.Search(target=target,
                                            order_list=order_lst,
                                            round_or_step=enable_round,
                                            upper_bound=round_limit if enable_round else step_limit,
                                            enable_low_first=low_first)

    return target, answer, steps, step_lst

In [4]:
def multithread_solve(parameters ):
    text_logic_form, diagram_logic_form,order_lst = parameters
    target, answer, steps, step_lst = None, None, 0, []

    # solve the #index problem
    solve_problem_start = time.time()
    if debug_mode:
        target, answer, steps, step_lst = solve_one_problem(text_logic_form, diagram_logic_form,order_lst)
    else:
        try:
            target, answer, steps, step_lst = func_timeout(time_limit, solve_one_problem,
                                                           kwargs=dict(text_parser=text_logic_form,
                                                                 diagram_parser=diagram_logic_form,
                                                                 order_lst=order_lst
                                                                 ))
        except FunctionTimedOut:
            pass
        except Exception as e:
            if args.debug_mode:
                print("\033[0;0;41mError:\033[0m", repr(e))
    time_interval = time.time() - solve_problem_start

    # solved result
    answer = float(answer) if answer is not None else answer
    entry = {'pid': index, 'target': target, 'guess': answer, 'correctness': 'no',
             'steps': steps, 'step_lst': step_lst, 'time': round(time_interval, 2)}

    # ground truth
    data_json = json.load(open(os.path.join(args.data_path, 'test', str(index), "data.json")))
    value_list = data_json['precise_value']  # [5.0, 12.0, 13.0, 26.0]
    gt_id = ord(data_json['answer']) - 65  # 0

    # validate the predicted answer
    if answer is not None:
        # all choice candidates are valid, and the answer is the closest one to ground truth among choice candidates
        try:
            if all([x is not None for x in value_list]) and \
                    abs(value_list[gt_id] - answer) == min([abs(x - answer) for x in value_list]):
                entry['correctness'] = 'yes'
        except Exception as e:
            if args.debug_mode:
                print("\035[0;0;41mError:\0353[0m", repr(e), value_list, gt_id, answer)
        if entry['correctness'] == 'yes':
            if args.debug_mode:
                print("\033[0;0;42mCorrect_answer:\033[0m ", end="")  # green
        else:
            if args.debug_mode:
                print("\033[0;0;43mWrong_answer:\033[0m ", end="")  # yellow

    if args.debug_mode:
        print(entry)
    return entry

In [5]:
text_logic_form_path = "../text_parser/oneprob_text_logic_forms.json"
diagram_logic_form_path = "../diagram_parser/oneprob_diagram_logic_forms.json"
predict_path = "../theorem_predict/results/test/oneprob_pred_seq_test_bart_best.json"

In [6]:
text_logic_table = json.load(open(text_logic_form_path, "r"))
diagram_logic_table = None
if diagram_logic_form_path is not None:
    diagram_logic_table = json.load(open(diagram_logic_form_path, "r"))

predict_table = None
if predict_path is not None:
    predict_table = json.load(open(predict_path, "r"))

In [7]:
print(text_logic_table)
print(diagram_logic_table)
print(predict_table)

{'2409': {'text_logic_forms': ['Find(y)'], 'output_text': '[Find(y)].', 'success': True}}
{'2409': {'log': [], 'point_instances': ['point_0', 'A', 'B', 'C', 'point_4'], 'line_instances': ['point_0C', 'point_0point_4', 'CA', 'CB', 'Cpoint_4', 'point_4B', 'AB'], 'circle_instances': [], 'diagram_logic_forms': ['Perpendicular(Line(C, A), Line(C, B))', 'Equals(LengthOf(Line(C, A)), y)', 'Equals(MeasureOf(Angle(C, B, A)), 60)', 'Equals(LengthOf(Line(A, B)), x)', 'Equals(LengthOf(Line(C, B)), 21)', 'Equals(MeasureOf(Angle(B, A, C)), 30)'], 'point_positions': {'null': [99.48593646591661, 211.24867637326275], 'A': [0.9681032687903439, 1.205912971059746], 'B': [206.28674040780763, 176.5769234855971], 'C': [76.06045424181697, 218.5434201736807]}}}
{'2409': {'id': '2409', 'num_seqs': 5, 'seq': [[], [], [], [], []]}}


In [8]:
str_index = '2409'
text_logic_form, diagram_logic_form, order_lst= None,None,None
if text_logic_table is not None:
    text_logic_form = text_logic_table.get(str_index)

if diagram_logic_table is not None:
    diagram_logic_form = diagram_logic_table.get(str_index)
    
if enable_predict and predict_table is not None:
    if str_index in predict_table:
        order_lst = predict_table[str_index]['seq']
        if isinstance(order_lst[0], list):
            order_lst = order_lst[0]

In [9]:
print(text_logic_form)
print(diagram_logic_form)
print(order_lst)

{'text_logic_forms': ['Find(y)'], 'output_text': '[Find(y)].', 'success': True}
{'log': [], 'point_instances': ['point_0', 'A', 'B', 'C', 'point_4'], 'line_instances': ['point_0C', 'point_0point_4', 'CA', 'CB', 'Cpoint_4', 'point_4B', 'AB'], 'circle_instances': [], 'diagram_logic_forms': ['Perpendicular(Line(C, A), Line(C, B))', 'Equals(LengthOf(Line(C, A)), y)', 'Equals(MeasureOf(Angle(C, B, A)), 60)', 'Equals(LengthOf(Line(A, B)), x)', 'Equals(LengthOf(Line(C, B)), 21)', 'Equals(MeasureOf(Angle(B, A, C)), 30)'], 'point_positions': {'null': [99.48593646591661, 211.24867637326275], 'A': [0.9681032687903439, 1.205912971059746], 'B': [206.28674040780763, 176.5769234855971], 'C': [76.06045424181697, 218.5434201736807]}}
[]


In [10]:
para_lst = []
para_lst.append((text_logic_form, diagram_logic_form,order_lst))

In [11]:
solve_list = []
with Pool(num_threads) as p:
    for answer in p.imap_unordered(multithread_solve, para_lst):
        solve_list.append(answer)

NameError: name 'time' is not defined