In [55]:
import os
import openai
import json
import re
# Read the key from the environment rather than pasting it into the notebook,
# so it is never saved or committed with the file. Export OPENAI_API_KEY before
# starting the kernel; it is only needed when a GPT model is used.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

In [56]:
# Local models served by Ollama. This cell can be skipped only if every judge
# and solver you use is a GPT model; the defaults below use llama3.
import ollama
ollama.pull("llama2")  # listed in LLMS, so it must be available locally too
ollama.pull("llama3")
ollama.pull("mistral")
ollama.pull("mixtral")
ollama.pull("gemma")
ollama.pull("vicuna")

{'status': 'success'}

In [None]:
# Current working directory at the time this cell runs; the dataset paths
# further down are built relative to it.
ROOT_PATH = os.getcwd()
ROOT_PATH

In [58]:
# Model names used by default for grading answers and for solving problems.
DEFAULT_JUDGE_LLM = "llama3"
DEFAULT_SOLVER_LLM = "llama3"

# Solver models compared in the benchmark loops at the end of the notebook.
LLMS = ["llama2", "llama3", "mistral", "mixtral", "gemma", "vicuna"]

In [59]:
# Matches one [asy]...[/asy] block, non-greedily and across newlines.
pattern = re.compile(r'\[asy\].*?\[/asy\]', re.DOTALL)

def remove_asy_tags(text):
    """
    Return `text` with every [asy]...[/asy] block deleted.

    Those blocks hold Asymptote plotting code, which is not relevant here.
    Text without a complete [asy]...[/asy] pair is returned unchanged.
    """
    return pattern.sub('', text)

In [60]:
prompt_examples = [
    # these examples come from the training set of MATH
    # we will use the testing set to evaluate the performance
    #
    # Few-shot examples inserted into the judge prompt. Each entry has:
    #   "prob"    - the problem statement (LaTeX)
    #   "sol"     - the reference solution, with the final answer in \boxed{...}
    #   "ans"     - a student's answer to be graded
    #   "correct" - whether "ans" should be accepted (shown to the judge as Yes/No)
    {
        "prob": "A point $(x,y)$ is randomly picked from inside the rectangle with vertices  $(0,0)$, $(3,0)$, $(3,2)$, and $(0,2)$.  What is the probability that  $x < y$?",
        "sol": "The point $(x,y)$ satisfies $x < y$ if and only if it belongs to the shaded triangle bounded by the lines $x=y$, $y=2$, and $x=0$, the area of which is 2.  The rectangle has area 6, so the probability in question is $\\dfrac{2}{6} = \\boxed{\\dfrac{1}{3}}$.\n\n[asy]\ndefaultpen(.7);\ndraw((-1,0)--(5,0),Arrow);\ndraw((0,-1)--(0,3),Arrow);\nfor (int i=1; i<4; ++i) {\ndraw((i,-0.1)--(i,0.1));\n}\nfill((0,0)--(0,2)--(2,2)--cycle,gray(0.7));\ndraw((-0.1,1)--(0.1,1));\ndraw((-.1,2)--(0,2));\ndraw((3,0)--(3,2)--(0,2),linewidth(1.0));\ndraw((-0.5,-0.5)--(2.8,2.8),dashed);\n[/asy]",
        "ans": "1/3",
        "correct": True,
    },

    {
        "prob": "Simplify $\\tan \\frac{\\pi}{24} + \\tan \\frac{7 \\pi}{24}.$",
        "sol": "We can write\n\\[\\tan \\frac{\\pi}{24} + \\tan \\frac{7 \\pi}{24} = \\frac{\\sin \\frac{\\pi}{24}}{\\cos \\frac{\\pi}{24}} + \\frac{\\sin \\frac{7 \\pi}{24}}{\\cos \\frac{7 \\pi}{24}} \n= \\frac{\\sin \\frac{\\pi}{24} \\cos \\frac{7 \\pi}{24} + \\cos \\frac{\\pi}{24} \\sin \\frac{7 \\pi}{24}}{\\cos \\frac{\\pi}{24} \\cos \\frac{7 \\pi}{24}}.\\]By the angle addition formula and the product-to-sum formula,\n\\begin{align*}\n\\frac{\\sin \\frac{\\pi}{24} \\cos \\frac{7 \\pi}{24} + \\cos \\frac{\\pi}{24} \\sin \\frac{7 \\pi}{24}}{\\cos \\frac{\\pi}{24} \\cos \\frac{7 \\pi}{24}} &= \\frac{\\sin (\\frac{\\pi}{24} + \\frac{7 \\pi}{24})}{\\frac{1}{2} (\\cos \\frac{\\pi}{3} + \\cos \\frac{\\pi}{4})} \\\\\n&= \\frac{2 \\sin \\frac{\\pi}{3}}{\\cos \\frac{\\pi}{3} + \\cos \\frac{\\pi}{4}} \\\\\n&= \\frac{\\sqrt{3}}{\\frac{1}{2} + \\frac{\\sqrt{2}}{2}} \\\\\n&= \\frac{2 \\sqrt{3}}{1 + \\sqrt{2}} \\\\\n&= \\frac{2 \\sqrt{3} (\\sqrt{2} - 1)}{(\\sqrt{2} + 1)(\\sqrt{2} - 1)} \\\\\n&= \\boxed{2 \\sqrt{6} - 2 \\sqrt{3}}.\n\\end{align*}",
        "ans": "2*sqrt(2)*(sqrt(3)-1)",
        "correct": False,
    },

    {
        "prob": "For real numbers $a,$ $b,$ and $c,$ the matrix\n\\[\\begin{pmatrix} a & b & c \\\\ b & c & a \\\\ c & a & b \\end{pmatrix}\\]is not invertible.  List all possible values of\n\\[\\frac{a}{b + c} + \\frac{b}{a + c} + \\frac{c}{a + b}.\\]",
        "sol": "Since the matrix is not invertible, its determinant is 0, i.e.\n\\[\\begin{vmatrix} a & b & c \\\\ b & c & a \\\\ c & a & b \\end{vmatrix} = 0.\\]The determinant expands as\n\\begin{align*}\n\\begin{vmatrix} a & b & c \\\\ b & c & a \\\\ c & a & b \\end{vmatrix} &= a \\begin{vmatrix} c & a \\\\ a & b \\end{vmatrix} - b \\begin{vmatrix} b & a \\\\ c & b \\end{vmatrix} + c \\begin{vmatrix} b & c \\\\ c & a \\end{vmatrix} \\\\\n&= a(bc - a^2) - b(b^2 - ac) + c(ab - c^2) \\\\\n&= 3abc - a^3 - b^3 - c^3.\n\\end{align*}This factors as\n\\[3abc - a^3 - b^3 - c^3 = -(a + b + c)(a^2 + b^2 + c^2 - ab - ac - bc),\\]so either $a + b + c = 0$ or $a^2 + b^2 + c^2 - ab - ac - bc = 0.$\n\nIf $a + b + c = 0,$ then\n\\[\\frac{a}{b + c} + \\frac{b}{a + c} + \\frac{c}{a + b} = \\frac{a}{-a} + \\frac{b}{-b} + \\frac{c}{-c} = -3.\\]Now, suppose $a^2 + b^2 + c^2 - ab - ac - bc = 0.$  Then\n\\begin{align*}\n(a - b)^2 + (a - c)^2 + (b - c)^2 &= (a^2 - 2ab + b^2) + (a^2 - 2ac + c^2) + (b^2 - 2bc + c^2) \\\\\n&= 2(a^2 + b^2 + c^2 - ab - ac - bc) \\\\\n&= 0.\n\\end{align*}This forces $a = b = c,$ so\n\\[\\frac{a}{b + c} + \\frac{b}{a + c} + \\frac{c}{a + b} = \\frac{3}{2}.\\]Thus, the possible values of\n\\[\\frac{a}{b + c} + \\frac{b}{a + c} + \\frac{c}{a + b}\\]are $\\boxed{\\frac{3}{2}}$ and $\\boxed{-3}.$",
        "ans": "-3 and 1.5",
        "correct": True,
    },
]

In [61]:
def _parse_yes_no(response:str) -> bool:
    """
    Turn a lower-cased judge reply into a boolean verdict.

    Only the standalone words "yes" and "no" count, so words that merely
    contain those letters ("not", "know", "nothing") do not affect the result.

    return:
        True if the reply contains "yes" and not "no", False for the opposite.
    raises:
        ValueError if the reply contains both words or neither.
    """
    verdicts = set(re.findall(r"\b(yes|no)\b", response))
    if verdicts == {"yes"}:
        return True
    if verdicts == {"no"}:
        return False
    raise ValueError("Invalid response from the LLM")


def judge_correctness(prob:str, sol:str, ans:str,
                      prompt_examples:list=prompt_examples,
                      llm:str="gpt-3.5-turbo"):
    """
    Ask an LLM whether a student's answer matches the reference solution.

    parameters:
    - prob: the problem
    - sol: the solution; any [asy]...[/asy] blocks are removed before it is sent
    - ans: the student's answer
    - prompt_examples: a list of dictionaries, each containing a problem ("prob"), a solution ("sol"), the student's answer ("ans"), and whether the student's answer is correct ("correct")
    - llm: the language model to use (an OpenAI chat model or an Ollama model name).

    return:
        True if the student's answer is correct, False otherwise

    raises:
        NotImplementedError if `llm` is not one of the supported model names
        ValueError if the LLM's response is invalid (rarely happens for GPT but more common for Llama)
    """

    instruct_query = """
Imagine that you are a high school math teacher correcting students' homework. You will be given three things: (1) a math problem in LaTeX, (2) a complete solution to this problem, in which the correct answer is wrapped in a box, i.e., \\boxed{answer} (note that there may be multiple boxes), and (3) the student's answer to this problem. These three items will all be wrapped in three backticks (```). Your task is to determine whether the student's answer is correct. If it is, you should respond with "Yes"; otherwise, you should respond with "No". DO NOT return anything other than "Yes" or "No".

The student's answer is considered correct if it is mathematically equivalent to the final answer in the solution. The correct answer can be a real number, a symbolic formula, or any other mathematical items (e.g., a set, an interval, etc.). For example, if the final answer in the solution is $\\boxed{4}$, then the student's answer is considered correct if it is $4$, $4.0$, $4.00$, $\\frac{8}{2}$, etc, and is considered incorrect if it is $4.1$, $4.01$, $\\frac{1}{4}$, etc; if the final answer in the solution is $\\boxed{4x+5}$, then the student's answer is considered correct if it is $4*x+5$, $4\\times x+5.0$, $(8x+10)/2$, etc, and is considered incorrect if it is $4x+5.1$, $4y+5$, $4X+5$, etc.

Here are some examples:
"""

    # Append the few-shot examples in the same Q/A layout the real query uses.
    for i, ex in enumerate(prompt_examples, start=1):
        instruct_query += f"Example {i}:\nQ: (1) Problem: ```{ex['prob']}``` (2) Solution: ```{remove_asy_tags(ex['sol'])}``` (3) Student's answer: ```{ex['ans']}```.\n A: {'Yes' if ex['correct'] else 'No'}\n\n"

    user_query = (
        "Now, I want you to determine whether the student's answer is correct for the following problem.  Please make sure to ONLY answer in `YES` or `NO` without any mistake: \nQ: (1) Problem: ```"
        + prob
        + "``` (2) Solution: ```"
        # Strip plotting code here as well, so the query matches the examples.
        + remove_asy_tags(sol)
        + "``` (3) Student's answer: ```"
        + ans
        + "```.\nA: "
    )

    messages = [{"role": "system", "content": instruct_query}, {"role": "user", "content": user_query}]

    if llm in ['gpt-3.5',"gpt-3.5-turbo", "gpt-4"]:
        response = openai.ChatCompletion.create(
            model=llm,
            messages=messages,
            temperature=0
            ).choices[0].message["content"].lower()

    # Every Ollama model the notebook benchmarks may also act as the judge.
    elif llm in ['llama','llama2','llama3','codellama', 'vicuna', 'mistral', 'mixtral', 'gemma']:
        response = ollama.chat(model=llm, messages=messages)["message"]["content"].lower()
    else:
        raise NotImplementedError("Calling this LLM is not implemented")

    return _parse_yes_no(response)

# Smoke test: grade the answer "$-pi/6$" against a reference solution whose
# boxed answer is -\frac{\pi}{6}, using the default judge model.
judge_correctness(
    prob="Find the phase shift of the graph of $y = 2 \\sin \\left( 2x + \\frac{\\pi}{3} \\right).$",
    sol="Since the graph of $y = 2 \\sin \\left( 2x + \\frac{\\pi}{3} \\right)$ is the same as the graph of $y = 2 \\sin 2x$ shifted $\\frac{\\pi}{6}$ units to the left, the phase shift is $\\boxed{-\\frac{\\pi}{6}}.$\n\n[asy]import TrigMacros;\n\nsize(400);\n\nreal g(real x)\n{\n\treturn 2*sin(2*x + pi/3);\n}\n\nreal f(real x)\n{\n\treturn 2*sin(2*x);\n}\n\ndraw(graph(g,-2*pi,2*pi,n=700,join=operator ..),red);\ndraw(graph(f,-2*pi,2*pi,n=700,join=operator ..));\ntrig_axes(-2*pi,2*pi,-3,3,pi/2,1);\nlayer();\nrm_trig_labels(-4,4, 2);\n[/asy]",
    ans="$-pi/6$",
    llm=DEFAULT_JUDGE_LLM
)

True

In [62]:
def llm_solver(problem:str, data_class, solver_llm):
    """
    Ask an LLM to write a complete worked solution to a math problem.

    parameters:
    - problem: the problem to solve (LaTeX); it is placed in the prompt between triple backticks
    - data_class: the subject area named in the system prompt, e.g. 'algebra'
    - solver_llm: the language model to use (an OpenAI chat model or an Ollama model name).

    return:
        the solution to the problem, exactly as the model wrote it

    raises:
        NotImplementedError if `solver_llm` is not one of the supported model names
    """
    instruct_query = f"""Imagine you are a math expert especially in the field of {data_class}, who is solving a math problem in all seriousness during the exam. You will be given a math problem in LaTeX format, and your task is to provide a complete solution to this problem and show all the calculation steps. The problem will be wrapped in three backticks (```). Your solution should be detailed and clear, and it should be written in LaTeX format. You can assume that the problem is well-posed and has a unique solution. Be extremely accurate. No mistakes allowed."""
    user_query = f"""Here is the problem you need to solve:\n\n```\n{problem}\n```\n\n"""
    messages = [{"role": "system", "content": instruct_query}, {"role": "user", "content": user_query}]

    if solver_llm in ['gpt-3.5',"gpt-3.5-turbo", "gpt-4"]:
        # Keep the reply's original case: LaTeX commands and variable names are
        # case-sensitive, so lower-casing would change the mathematics.
        response = openai.ChatCompletion.create(
            model=solver_llm,
            messages=messages,
            temperature=0
            ).choices[0].message["content"]
    # Covers every Ollama model listed in LLMS, including gemma and vicuna.
    elif solver_llm in ['llama','llama2','llama3','codellama', 'mistral', 'mixtral', 'gemma', 'vicuna']:
        response = ollama.chat(model=solver_llm, messages=messages)["message"]["content"]
    else:
        # Fail with a clear message instead of hitting an unbound `response`.
        raise NotImplementedError("Calling this LLM is not implemented")
    return response

# Smoke test: solve one problem with the default solver model, so that changing
# DEFAULT_SOLVER_LLM actually takes effect here.
llm_solver(
    problem="Find the phase shift of the graph of $y = 2 \\sin \\left( 2x + \\frac{\\pi}{3} \\right).$",
    data_class="algebra",
    solver_llm=DEFAULT_SOLVER_LLM
)

"To find the phase shift, we need to identify the constant term within the argument of the sine function. In this case, it's $\\frac{\\pi}{3}$.\n\nThe general form of a sinusoidal function is $y = A \\sin(B(x - C))$, where $A$ is the amplitude, $B$ is the frequency, and $C$ is the phase shift.\n\nComparing our given function with the general form, we can see that:\n\n* $A = 2$\n* $B = 2$\n* $C = \\frac{\\pi}{3}$\n\nThe phase shift, $C$, is the constant term within the argument of the sine function. Therefore, the phase shift of the graph of $y = 2 \\sin \\left( 2x + \\frac{\\pi}{3} \\right)$ is $\\boxed{\\frac{\\pi}{3}}$.\n\nHere's the solution in LaTeX format:\n\n```latex\nThe phase shift, C, is the constant term within the argument of the sine function.\nTherefore, the phase shift of the graph of $y = 2 \\sin \\left( 2x + \\frac{\\pi}{3} \\right)$ is $\\frac{\\pi}{3}$.\n```"

In [63]:
import random

def test(solver,
         data_class:str,
         dataset:list,
         solver_llm=None,
         levels:list=[1,2,3,4,5],
         judging_llm:str='gpt-3.5-turbo',
         sample_size=None,
         test_mode=False,
    ):
    """
    parameters:
    - solver: the solver to be tested, returns a **string** of the answer
    - data_class: the class of the data, must be in ['algebra','counting_and_probability','geometry','intermediate_algebra','number_theory','prealgebra','precalculus']
    - args_solver: the arguments for the solver (if any)
    - levels: the levels of the data, default is all 5 levels
    - target_dir: the directory of the data (TODO: change to your own!)

    return:
    a list of the accuracy at all levels, whose first position is dummy. e.g., [0,0.9,0.8,0.7,0.6,0.5]
    """
    assert data_class in ['algebra','counting_and_probability','geometry','intermediate_algebra','number_theory','prealgebra','precalculus'], "Invalid data class"
    assert all([level in [1,2,3,4,5] for level in levels]), "Levels must be in 1 to 5"

    print(f'Testing dataset {data_class} with levels {str(sorted(set(levels)))}')

    prob_num=[0, 0, 0, 0, 0] # num of problems at each level
    correct_num=[0, 0, 0, 0, 0] # num of correctly-solved problems at each level

    # retry files that had errors in the first round
    retry_files = []

    # Randomly choose unique elements
    selected_dataset = random.sample(dataset, sample_size)

    for i in range(len(selected_dataset)):
        if test_mode:
            print(f"Testing problem {selected_dataset[i]['filename']} of {data_class}")

        level = selected_dataset[i]["level"]
        if level not in levels:
            continue
        prob_num[level-1]+=1
        prob = selected_dataset[i]["problem"]
        sol = remove_asy_tags(selected_dataset[i]["solution"])
        ans=solver(prob,data_class,solver_llm)

        try:
            is_correct=judge_correctness(prob, sol, ans, llm=judging_llm)
            if is_correct:
                correct_num[level-1]+=1
        except ValueError:
            print(f"Error in judging correctness for problem {files[i]}. Will retry later.")
            retry_files.append(i)
            prob_num[level-1]-=1

        if test_mode:
            print(f"Problem numbers at each level attempted to solve so far: {prob_num}")
            print(f"Correctly-solved problem numbers at each level so far: {correct_num}")

    # Retry the files that have errors
    if len(retry_files)>0:
        print(f"Retrying {len(retry_files)} files")
        for i in range(len(retry_files)):
            if test_mode:
                print(f"Again testing problem {retry_files[i]} of {data_class}")

            level = selected_dataset[retry_files[i]]["level"]
            prob_num[level-1]+=1
            prob = selected_dataset[retry_files[i]]["problem"]
            sol = remove_asy_tags(selected_dataset[retry_files[i]]["solution"])
            try:
                is_correct=judge_correctness(prob, sol, ans, llm=judging_llm)
                if is_correct:
                    correct_num[level-1]+=1
            except ValueError:
                print(f"Error in judging correctness for problem {retry_files[i]}")
                prob_num[level-1]-=1

        if test_mode:
            print(f"Problem numbers at each level attempted to solve so far: {prob_num}")
            print(f"Correctly-solved problem numbers at each level so far: {correct_num}")

    return list(map(lambda x,y: x/y if y!=0 else 0, correct_num, prob_num))

# Smoke test: run the whole pipeline on two randomly sampled algebra problems.
ALGEBRA_DATASET_PATH = os.path.join(ROOT_PATH, 'merged_dataset', 'train', 'algebra', 'merged.json')
# Use a context manager so the file handle is closed after loading.
with open(ALGEBRA_DATASET_PATH) as dataset_file:
    ALGEBRA_DATASET = json.load(dataset_file)
test(solver=llm_solver,
     data_class='algebra',
     dataset=ALGEBRA_DATASET,
     levels=[1,2,3,4,5],
     solver_llm=DEFAULT_SOLVER_LLM,
     judging_llm=DEFAULT_JUDGE_LLM,
     sample_size=2,
     test_mode=True
)

Testing dataset algebra with levels [1, 2, 3, 4, 5]
Testing problem 1469 of algebra
Problem numbers at each level attempted to solve so far: [0, 1, 0, 0, 0]
Correctly-solved problem numbers at each level so far: [0, 1, 0, 0, 0]
Testing problem 402 of algebra
Problem numbers at each level attempted to solve so far: [0, 2, 0, 0, 0]
Correctly-solved problem numbers at each level so far: [0, 2, 0, 0, 0]


[0, 1.0, 0, 0, 0]

TESTING ENDS

For Algebra

In [None]:
ALGEBRA_DATASET_PATH = os.path.join(ROOT_PATH, 'merged_dataset', 'train', 'algebra', 'merged.json')
# Use a context manager so the file handle is closed after loading.
with open(ALGEBRA_DATASET_PATH) as dataset_file:
    ALGEBRA_DATASET = json.load(dataset_file)

for llm in LLMS:
    acc = test(solver=llm_solver,
                data_class='algebra',
                dataset=ALGEBRA_DATASET,
                levels=[1,2,3,4,5],
                solver_llm=llm,
                judging_llm=DEFAULT_JUDGE_LLM,
                sample_size=10
            )
    # Weighted average of accuracy with most weightage to the highest level.
    # acc[0] is level 1 and acc[4] is level 5, so each level's weight is its
    # own number (index + 1); level 1 must not be dropped.
    level_weights = range(1, len(acc) + 1)
    mean_acc = sum(a * w for a, w in zip(acc, level_weights)) / sum(level_weights)
    print(f"Weighted Mean Accuracy for {llm}: {mean_acc}")

For Precalculus

In [None]:
PRECALCULUS_DATASET_PATH = os.path.join(ROOT_PATH, 'merged_dataset', 'train', 'precalculus', 'merged.json')
# Use a context manager so the file handle is closed after loading.
with open(PRECALCULUS_DATASET_PATH) as dataset_file:
    PRECALCULUS_DATASET = json.load(dataset_file)

for llm in LLMS:
    acc = test(solver=llm_solver,
                data_class='precalculus',
                dataset=PRECALCULUS_DATASET,
                levels=[1,2,3,4,5],
                solver_llm=llm,
                judging_llm=DEFAULT_JUDGE_LLM,
                sample_size=10
            )
    # Weighted average of accuracy with most weightage to the highest level.
    # acc[0] is level 1 and acc[4] is level 5, so each level's weight is its
    # own number (index + 1); level 1 must not be dropped.
    level_weights = range(1, len(acc) + 1)
    mean_acc = sum(a * w for a, w in zip(acc, level_weights)) / sum(level_weights)
    print(f"Weighted Mean Accuracy for {llm}: {mean_acc}")