In [1]:
import ast
import json
import math
import os


# this function is used to round the result to 2 decimal places
# e.g. 52.3523 -> 52.35, 52.0011 -> 52, 0.00000233 -> 0.0000023
def custom_round(x, decimal_places=2):
    str_x = f"{x:.10f}"
    before_decimal = str_x.split(".")[0]
    after_decimal = str_x.split(".")[1]
    leading_zeros = len(after_decimal) - len(after_decimal.lstrip("0"))

    if leading_zeros >= 1 and before_decimal == "0":
        return round(x, leading_zeros + 2)
    else:
        return round(x, decimal_places)


# this function converts a number in scientific notation to decimal notation
def scito_decimal(sci_str):
    def split_exponent(number_str):
        parts = number_str.split("e")
        coefficient = parts[0]
        exponent = int(parts[1]) if len(parts) == 2 else 0
        return coefficient, exponent

    def multiplyby_10(number_str, exponent):
        if exponent == 0:
            return number_str

        if exponent > 0:
            index = number_str.index(".") if "." in number_str else len(number_str)
            number_str = number_str.replace(".", "")
            new_index = index + exponent
            number_str += "0" * (new_index - len(number_str))
            if new_index < len(number_str):
                number_str = number_str[:new_index] + "." + number_str[new_index:]
            return number_str

        if exponent < 0:
            index = number_str.index(".") if "." in number_str else len(number_str)
            number_str = number_str.replace(".", "")
            new_index = index + exponent
            number_str = "0" * (-new_index) + number_str
            number_str = "0." + number_str
            return number_str

    coefficient, exponent = split_exponent(sci_str)
    decimal_str = multiplyby_10(coefficient, exponent)

    # remove trailing zeros
    if "." in decimal_str:
        decimal_str = decimal_str.rstrip("0")

    return decimal_str


# normalize the result to 2 decimal places and remove trailing zeros
def normalize(res, round_to=2):
    # we round the result to 2 decimal places
    res = custom_round(res, round_to)
    res = str(res)
    if "." in res:
        while res[-1] == "0":
            res = res[:-1]
        res = res.strip(".")

    # scientific notation
    if "e" in res:
        res = scito_decimal(res)

    return res


# 1. add
def add_(*args):
    return normalize(sum(args))


# 2. subtract
def subtract_(*args):
    res = args[0]
    for arg in args[1:]:
        res -= arg
    return normalize(res)


# 3. multiply
def multiply_(*args):
    res = args[0]
    for arg in args[1:]:
        res *= arg
    return normalize(res)


# 4. divide
def divide_(*args):
    res = args[0]
    for arg in args[1:]:
        res /= arg
    return normalize(res)


# 5. power
def power_(*args):
    res = args[0]
    for arg in args[1:]:
        res **= arg
    return normalize(res)


# 6. square root
def sqrt_(*args):
    res = args[0]
    return normalize(math.sqrt(res))


# 7. 10th log
def log_(*args):
    # if only one argument is passed, it is 10th log
    if len(args) == 1:
        res = args[0]
        return normalize(math.log10(res))
    # if two arguments are passed, it is log with base as the second argument
    elif len(args) == 2:
        res = args[0]
        base = args[1]
        return normalize(math.log(res, base))
    else:
        raise Exception("Invalid number of arguments passed to log function")


# 8. natural log
def ln_(*args):
    res = args[0]
    return normalize(math.log(res))


# 9. choose
def choose_(*args):
    n = int(args[0])
    r = int(args[1])
    return normalize(math.comb(n, r))


# 10. permutation
def permutate_(*args):
    n = int(args[0])
    r = int(args[1])
    return normalize(math.perm(n, r))


# 11. greatest common divisor
def gcd_(*args):
    res = args[0]
    for arg in args[1:]:
        res = math.gcd(int(res), int(arg))
    return normalize(res)


# 12. least common multiple
def lcm_(*args):
    res = int(args[0])
    for arg in args[1:]:
        res = res * int(arg) // math.gcd(res, int(arg))
    return normalize(res)


# 13. remainder
def remainder_(*args):
    dividend = args[0]
    divisor = args[1]
    return normalize(dividend % divisor)


OPERATORS = {
    "add": add_,
    "subtract": subtract_,
    "multiply": multiply_,
    "divide": divide_,
    "power": power_,
    "sqrt": sqrt_,
    "log": log_,
    "ln": ln_,
    "choose": choose_,
    "permutate": permutate_,
    "gcd": gcd_,
    "lcm": lcm_,
    "remainder": remainder_,
}


import re
def check_answer(answer):
    try: 
        toolcall = re.search(r'<(.*?)>\((.*?)\)', answer).group(1)
        answer_partial = re.search(r'<(.*?)<eoe>', answer).group(1)
        arguments = re.search(r'\((.*?)\)', answer_partial).group(1)
        result = re.search(r'=(.*?)<eoe>', answer).group(1).split("=")[-1].strip()
    except:
        return False
    OPERATORS = {
    "add": add_,
    "subtract": subtract_,
    "multiply": multiply_,
    "divide": divide_,
    "power": power_,
    "sqrt": sqrt_,
    "log": log_,
    "ln": ln_,
    "choose": choose_,
    "permutate": permutate_,
    "gcd": gcd_,
    "lcm": lcm_,
    "remainder": remainder_,
    }
    try:
        ground_truth = OPERATORS[toolcall](*[float(x) for x in arguments.split(",")])
    except:
        return False
    ground_truth = custom_round(float(ground_truth))
    result = custom_round(float(result))
    print("ground_truth: ", ground_truth)
    print("result: ", result)
    
    return ground_truth == result




In [2]:
dir_path = "../data/funcqa/training_data"
total_line = 0
total_correct = 0

for file_path in os.listdir(dir_path):
    if file_path.endswith(".jsonl"):
        with open(os.path.join(dir_path, file_path), "r") as f:
            lines = f.readlines()
        with open(os.path.join(dir_path, file_path), "w") as f:
            for line in lines:
                if line.startswith("{"):
                    sample = json.loads(line)
                    total_line += 1
                    if check_answer(sample["answer"]):
                        total_correct += 1
                        f.write(line)

print(f"Total correct: {total_correct}/{total_line} = {total_correct/total_line*100:.2f}%")

ground_truth:  4.0
result:  4.0
ground_truth:  1.0
result:  1.0
ground_truth:  5.0
result:  5.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  4.0
result:  4.0
ground_truth:  3.0
result:  3.0
ground_truth:  18.0
result:  18.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  3.0
result:  3.0
ground_truth:  2.0
result:  2.0
ground_truth:  8.0
result:  8.0
ground_truth:  38.0
result:  38.0
ground_truth:  4.0
result:  4.0
ground_truth:  1.0
result:  1.0
ground_truth:  9.0
result:  9.0
ground_truth:  1.0
result:  1.0
ground_truth:  2.0
result:  2.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
ground_truth:  1.0
result:  1.0
grou