In [2]:
import json
from models.symbolic_solvers.clingo_solver import ClingoSolver

In [50]:
def _parse_asp(asp_raw: str) -> str:
    if "###ASP_START###" in asp_raw:
        return _parse_asp_start_end(asp_raw)

    possible_ends = ["\n\n#", "\n\n`", "\n`", "\n#"]

    # find first % symbol index
    start = asp_raw.find("%")
    asp_raw = asp_raw[start:]
    for end in possible_ends:
        if end in asp_raw:
            asp_raw = asp_raw.split(end)[0]
            break
    asp_raw = asp_raw.replace("`", "")

    return asp_raw.strip()

def _parse_asp_start_end(asp_raw: str) -> str:
    START = "###ASP_START###"
    END = "###ASP_END###"

    return asp_raw.split(START)[1].split(END)[0]


def get_answer_letter(answer, options):
    """Return the first character of the first option that contains ``answer``.

    Args:
        answer: Text to look for inside each option (substring match).
        options: Iterable of option strings.

    Returns:
        The leading character of the first matching option, or ``None`` when
        no option contains ``answer``.
    """
    matching_letters = (option[0] for option in options if answer in option)
    return next(matching_letters, None)

def full_evaluation(result_file):
    """Score solver-based predictions stored in a JSON results file.

    For each sample, the first entry of ``raw_logic_programs`` is parsed with
    ``_parse_asp`` and passed to ``ClingoSolver.solve``. The solver result is
    mapped to an option letter with ``get_answer_letter`` and compared with the
    sample's ``answer``. Mismatches, the overall accuracy, and the per-sample
    correctness list are printed.

    Args:
        result_file: Path to the JSON file. If the path contains "validity"
            the task is treated as a validity task (solver result mapped to
            "Yes"/"No"); otherwise it is treated as "fill_in" and the solver
            result is used as the answer directly.

    Raises:
        Exception: If ``_parse_asp`` fails for a sample (the sample id is
            printed first and the original error is chained as the cause).
    """
    with open(result_file, "r") as f:
        all_samples = json.load(f)

    # str() so that pathlib.Path arguments work with the substring check.
    task_type = "validity" if "validity" in str(result_file) else "fill_in"
    solver = lambda x: ClingoSolver.solve(x, task_type == "validity")
    answer_is_correct = []

    for sample in all_samples:
        asp_raw = sample["raw_logic_programs"][0].strip()
        try:
            asp = _parse_asp(asp_raw)
        except Exception as exc:
            print(sample["id"])
            # Chain the cause so the real parsing failure stays visible.
            raise Exception("Error parsing ASP") from exc
        try:
            if task_type == "validity":
                answer = "Yes" if solver(asp) else "No"
            else:
                answer = solver(asp)
        except Exception:
            # Best effort: a solver failure is logged and scored as wrong.
            # `except Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit still stop a long evaluation run.
            print(sample["id"])
            print(asp)
            print("###################")
            answer = None
        correct_answer = sample["answer"]
        options = sample["options"]
        if answer:
            choice_answer = get_answer_letter(answer, options)
            if choice_answer != correct_answer:
                print(sample["id"], answer, choice_answer, correct_answer)
            answer_is_correct.append(choice_answer == correct_answer)
        else:
            # No usable answer (solver failed or returned something falsy).
            answer_is_correct.append(False)

    if answer_is_correct:
        print("Accuracy:", sum(answer_is_correct) / len(answer_is_correct))
    else:
        # Avoid ZeroDivisionError when the results file holds no samples.
        print("Accuracy: n/a (no samples)")
    print(answer_is_correct)

# Run the solver-based evaluation on the graph fill-in outputs. The path does
# not contain "validity", so full_evaluation treats this as a "fill_in" task.
file = "outputs/logic_programs/graph_fill_in_data_gemini-1.5-pro-preview-0409.json"
full_evaluation(file)

problem_0 red C B
problem_4 green C A
problem_15 green C B
problem_29 yellow D C
problem_37 yellow C A
problem_66 blue B A
problem_68 green B C
problem_72 red A C
problem_78 green C A
problem_80 green A D
problem_83 blue C A
problem_84 yellow D B
problem_111 blue C A
problem_128 blue B D
problem_140 green B A
problem_150 yellow D C
problem_151 purple C B
problem_163 green A B
problem_169 blue C A
problem_173 red B A
problem_180 yellow B A
problem_188 blue C D
problem_189 blue B D
problem_192 yellow B D
problem_197 blue B D
Accuracy: 0.62
[False, True, True, True, False, True, True, False, True, True, True, False, True, True, False, False, False, True, False, False, True, False, True, True, True, True, True, True, True, False, True, False, True, True, True, True, False, False, False, True, True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, False, True, True, False, True, False, True, False, False, True, True, F

In [47]:
import re

def _parse_asp(asp_raw: str):
    possible_ends = ["\n\n#", "\n\n`", "\n`", "\n#"]

    # find first % symbol index
    start = asp_raw.find("%")
    asp_raw = asp_raw[start:]
    for end in possible_ends:
        if end in asp_raw:
            asp_raw = asp_raw.split(end)[0]
            break
    asp_raw = asp_raw.replace("`", "")

    return asp_raw.strip()


def get_answer_letter(answer, options):
    """Map an answer string to the leading character of the option holding it.

    Options are scanned in order; the first one in which ``answer`` occurs as
    a substring wins. ``None`` is returned when nothing matches.
    """
    for candidate in options:
        if answer not in candidate:
            continue
        return candidate[0]
    return None

def extract_answer_letter(answer: str) -> str:
    """Return the first capital letter of a free-text answer.

    The substrings "Correct" and "The" are removed first so their capital
    letters are not mistaken for an option letter.

    Args:
        answer: Raw answer text.

    Returns:
        The first ``A``-``Z`` character left after the removals, or an empty
        string when there is none (previously this case raised
        ``AttributeError`` because ``re.search`` returned ``None``).
    """
    cleaned = answer.replace("Correct", "").replace("The", "")
    # find first capital letter
    match = re.search("[A-Z]", cleaned)
    if match is None:
        return ""
    return match.group(0)


def full_evaluation_direct(result_file):
    """Score direct (solver-free) answers stored in a JSON results file.

    For each sample, the first entry of ``raw_logic_programs`` is treated as
    the answer text; its option letter is extracted with
    ``extract_answer_letter`` and compared with the sample's ``answer``.
    Mismatches, the overall accuracy, and the per-sample correctness list are
    printed.

    Args:
        result_file: Path to the JSON file holding the list of samples.
    """
    answer_is_correct = []

    with open(result_file, "r") as f:
        all_samples = json.load(f)

    for sample in all_samples:
        output = sample["raw_logic_programs"][0].strip()
        correct_answer = sample["answer"]

        choice_answer = extract_answer_letter(output)
        if choice_answer != correct_answer:
            print(sample["id"], choice_answer, correct_answer, output)
        answer_is_correct.append(choice_answer == correct_answer)

    if answer_is_correct:
        print("Accuracy:", sum(answer_is_correct) / len(answer_is_correct))
    else:
        # Avoid ZeroDivisionError when the results file holds no samples.
        print("Accuracy: n/a (no samples)")
    print(answer_is_correct)

# Run the direct-answer evaluation (option letter read straight from the
# stored output text, no solver call) on the sudoku validity outputs.
file = "outputs/logic_programs/sudoku_validity_direct_data_gemini-1.5-pro-preview-0409.json"
full_evaluation_direct(file)

problem_0 B A Correct answer: 
B
problem_4 A B Correct answer: 
A
problem_9 A B Correct answer: 
A
problem_11 A B Correct answer: 
A
problem_12 A B Correct answer: 
A
problem_13 A B Correct answer: 
A
problem_14 A B Correct answer: 
A
problem_19 A B Correct answer: 
A
problem_21 A B Correct answer: 
A
problem_22 A B Correct answer: 
A
problem_23 A B Correct answer: 
A
problem_29 A B Correct answer: 
A
problem_32 A B Correct answer: 
A
problem_35 A B Correct answer: 
A
problem_36 A B Correct answer: 
A
problem_38 A B Correct answer:
A
problem_39 A B Correct answer: 
A
problem_40 A B Correct answer: 
A
problem_42 A B Correct answer: 
A
problem_46 A B Correct answer: 
A
problem_48 A B Correct answer: 
A
problem_49 A B Correct answer: 
A
problem_50 A B Correct answer: 
A
problem_53 A B A) Yes 

This Sudoku board is correctly solved. Each row, column, and 3x3 square contains all the digits from 1 to 9 without repetition.
problem_54 A B Correct answer: 
A
problem_55 A B Correct answer: 
A
pr