In [None]:
!pip install langfun --pre

In [None]:
import os
import random
import re
from typing import Literal, Annotated

import langfun as lf
import pandas as pd
import pyglove as pg

In [None]:
# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# typed into (and saved with) the notebook. The original placeholder is kept as
# the fallback, so replacing it by hand still works.
openai_key = os.environ.get("OPENAI_API_KEY", "<your OpenAI key>")

In [None]:
def _make_gpt4(temperature):
  """Builds a GPT-4 (0613) client; only the sampling temperature varies."""
  return lf.llms.Gpt4_0613(
      temperature=temperature,
      max_tokens=512,
      timeout=100,
      stop=["```\n"],
      api_key=openai_key,
  )


# One client per role. The proposing roles (generate / verifier) sample at 0.3,
# the judging roles (valid / answer) run at temperature 0.0.
generate_lm = _make_gpt4(0.3)
valid_lm = _make_gpt4(0.0)
verifier_lm = _make_gpt4(0.3)
answer_lm = _make_gpt4(0.0)

# Game24 Solver

## Schemas (41 lines of code)

In [None]:
# One arithmetic move in a solution: the expression that was evaluated and the
# numbers still available afterwards.
# NOTE(review): documented with `#` comments rather than a docstring because
# langfun may render class docstrings into the LM prompt - confirm before
# adding one.
class Step(pg.Object):
  # Expression text, e.g. "1 + 11 = 12" (the format used by the few-shot data).
  expression: str
  # The solver fills this with the unused inputs plus the expression's result.
  left_numbers: list[int]


# Prompt input for proposing a first step: the puzzle numbers plus the
# expressions already tried, which the model is asked to avoid.
class GenerateInput(pg.Object):
  # The numbers of the puzzle being solved.
  input: list[int]
  # The solver appends every proposed expression here before asking again.
  expression_blacklist: Annotated[list[str], "Expressions that should be skipped."]

# Structured output of the first-step proposal: one expression over two of the
# input numbers, with bookkeeping of which numbers were used and which remain.
class Generate(pg.Object):
  expression: Annotated[str, "Expression with the two input numbers and basic arithmetic operations (+ - * /). Calculate the answer correctly. The Expression should be different from the blacklist."]
  # The two inputs consumed by `expression`.
  numbers_from_input: list[int]
  # The inputs not consumed by `expression`.
  remain_numbers_in_input: list[int]
  # The solver discards proposals whose result is a float.
  result_of_expression: int | float

# Prompt input for validating a step: the numbers that were available and the
# step claimed to be built from them.
class ValidationInput(pg.Object):
  # Numbers available before the step was taken.
  input: list[int]
  # The step (expression + resulting numbers) under review.
  step: Step

# Structured output of the validator: three written checks followed by a
# verdict. The solver only reads `judgement`.
class Validation(pg.Object):
  # The expression being checked (the few-shot example omits the "= ..." part).
  expression: str
  input_check: Annotated[list[str], "Check whether the numbers in the expression including repetitions are a part of, or the complete list of, the given input."]
  # Free-text re-computation of the expression.
  math_correctness_check: str
  left_numbers_check: Annotated[str, "Check whether the left numbers of step are the numbers that are not included in the expression, combined with the result of the expression."]
  judgement: Annotated[Literal["Valid", "Invalid"], "Conclude whether the step pass all validations: math check, input check and the verification of left numbers."]

# Prompt input for the verifier: the numbers left after the first step and a
# sample of drafts from earlier attempts that should not be repeated.
class VerifierInput(pg.Object):
  # The solver passes the `left_numbers` of the candidate first step.
  input: list[int]
  expression_blacklist: Annotated[list[str], "Expressions that should be skipped in the drafts."]

# Structured output of the verifier: a batch of draft expressions over all the
# given numbers, and the draft (if any) that evaluates to 24.
class Verifier(pg.Object):
  expression_drafts: Annotated[list[str], "Expressions with all input numbers and basic arithmetic operations (+ - * /). Calculate the answer step by step. List ~15 unique drafts."]
  # None when no draft reaches 24; the solver treats any falsy value as "none".
  expression_that_obtains_24: str | None

# Prompt input for composing the final answer: the original puzzle numbers and
# the intermediate step expressions to be merged into one expression.
class Game24SolutionInput(pg.Object):
  # The solver passes the original puzzle numbers here.
  numbers_in_final_expression: list[int]
  # The solver passes [first step expression, final step expression].
  hints: list[str]

# Structured output of the final answer: the written reasoning and the single
# expression that should use the puzzle numbers to make 24.
class Game24Solution(pg.Object):
  # Reasoning steps produced before the final expression.
  figure_out_solution_step_by_step: list[str]
  # None when the model could not produce an expression.
  final_expression: str | None

## One-shot examples (treated as data)

In [None]:
# Few-shot data for the first-step generator (passed as `examples=` when
# querying for `Generate`). The single example shows picking two inputs that
# are not in the blacklist and reporting what remains.
generate_few_shots = [
    lf.structured.mapping.MappingExample(
        input=GenerateInput(
            input=[1, 4, 8, 11], expression_blacklist=["1 + 4 = 5"]
        ),
        output=Generate(
            expression="1 + 11 = 12",
            numbers_from_input=[1, 11],
            remain_numbers_in_input=[4, 8],
            result_of_expression=12,
        ),
    ),
]

# Few-shot data for the validator (passed as `examples=` when querying for
# `Validation`). This is a negative example: the step claims 13 / 4 = 3, so the
# math check and the left-numbers check both fail and the verdict is "Invalid".
# The strings below are prompt text; edit them only deliberately.
valid_few_shots = [
    lf.structured.mapping.MappingExample(
        input=ValidationInput(
            input=[4, 7, 13],
            step=Step(
                expression="7 * (13 / 4) = 7 * 3 = 21",
                left_numbers=[21],
            ),
        ),
        output=Validation(
            expression="7 * (13 / 4)",
            input_check=[
                "1. Extract the numbers from the expression by removing all operators and parentheses from '7 * (13 / 4)'. We get all numbers of the expression are 7, 13, and 4.",
                "2. The first thing is to check the extracted numbers' size is less than or equal to the input numbers' size. The extracted list [7, 13, 4] has 3 number while the input list [4, 7, 13] has 3 numbers. So the size check passes.",
                "3. Then check if all numbers from extracted list [7, 13, 4] are in the input list [4, 7, 13]. First, [7, 13, 4] has 7 and [4, 7, 13] has 7. After removing 7 from both lists, we get [13, 4] and [4, 13]. Then, [13, 4] has 13 and [4, 13] has 13. After removing 13 from both lists, we get [4] and [4]. As 4 is in both lists, all numbers from extracted list [7, 13, 4] are in the input list [4, 7, 13]. So the check passes.",
            ],
            math_correctness_check=(
                "7 * (13 / 4) = 7 * 3.25 = 22.75. So 21 is"
                " incorrect result of the expression. So the check fails."
            ),

            left_numbers_check=(
                "As all input numbers [4, 7, 13] are included in the"
                " expression, the only left numbers are the result of the"
                " expression 22.75. So the correct left numbers are [22.75]."
                " The left numbers [21] is incorrect. So the check fails."
            ),
            judgement="Invalid",
        ),
    ),
]

# Few-shot data for the verifier (passed as `examples=` when querying for
# `Verifier`). The example lists 15 drafts over [1, 2, 11], none of which
# repeats a blacklisted expression, and names the one draft that reaches 24.
verifier_few_shots = [
    lf.structured.mapping.MappingExample(
        input=VerifierInput(
            input=[1, 2, 11],
            # NOTE(review): some blacklist entries below use the Unicode minus
            # sign (U+2212) while the rest of the data uses ASCII '-'. Looks
            # unintentional, but it is prompt text - confirm before normalizing.
            expression_blacklist=[
                "(1 + 11) / 2 = 12 / 2 = 6",
                "1 + 11 - 2 = 12 - 2 = 10",
                "(1 + 2) * 11 = 3 * 11 = 33",
                "(1 + 2) − 11 = 3 - 11 = -8",
                "(1 + 2) / 11 = 3 / 11 = 0.2727",
                "(1 − 11) * 2 = -10 * 2 = -20",
                "(1 − 11) − 2 = -10 - 2 = -12",
            ],
        ),
        output=Verifier(
            expression_drafts=[
                "1 * 2 + 11 = 2 + 11 = 13",
                "1 + 2 + 11 = 3 + 11 = 14",
                "2 * (11 - 1) = 2 * 10 = 20",
                "11 * 2 - 1 = 22 - 1 = 21",
                "2 / 1 * 11 = 2 * 11 = 22",
                "(11 - 1) / 2 = 10 / 2 = 5",
                "2 * (11 + 1) = 2 * 12 = 24",
                "11 - 1 * 2 = 11 - 2 = 9",
                "1 * (2 + 11) = 1 * 13 = 13",
                "1 * (11 - 2) = 1 * 9 = 9",
                "2 * 11 + 1 = 22 + 1 = 23",
                "(1 * 2) / 11 = 2 / 11 = 0.1818",
                "(1 - 11) / 2 = -10 / 2 = -5",
                "11 - (1 + 2) = 11 - 3 = 8",
                "(2 - 1) * 11 = 1 * 11 = 11",
            ],
            expression_that_obtains_24="2 * (11 + 1) = 2 * 12 = 24",
        ),
    ),
]

# Few-shot data for composing the final answer (passed as `examples=` when
# querying for `Game24Solution`). The example substitutes the first hint into
# the second to obtain one expression over the four puzzle numbers.
# NOTE(review): hint 1 is written as "4 - 12 = -8" but the reasoning and hint 2
# use +8 ("12 - 4"). Looks like a sign slip in the example, but it is prompt
# text - confirm before changing.
# NOTE(review): unlike the other few-shot lists, input and output are passed
# positionally here; presumably equivalent to input=/output= - verify against
# MappingExample's field order.
game24_solution_few_shots = [
    lf.structured.mapping.MappingExample(
        Game24SolutionInput(
            numbers_in_final_expression=[4, 4, 8, 12],
            hints=[
                "4 - 12 = -8",
                "4 * 8 - 8 = 32 - 8 = 24",
            ],
        ),
        Game24Solution(
            figure_out_solution_step_by_step=[
                "1. As hint 1 inludes two input numbers and hint 2 includes the other two input numbers plus the result from hint 1, the plan is to bring hint 1 into hint 2 to form the final expression with the 4 input numbers.",
                "2. As hint 2 leads to the final answer, let's start from that. '4 * 8 - 8 = 32 - 8 = 24' suggests that we can use the numbers [4, 8, 8] to get 24.",
                "3. Then, the hint 1 '4 - 12 = -8' suggests that we can replace a 8, which is a number doesn't come from input in hint 2 expression, with the expression '12 - 4', which are two input numbers.",
                "4. Combining all the information, we can substitute 8 with '12 - 4' in the expression '4 * 8 - 8' to get the final expression '4 * (12 - 4) - 8', where 4 numbers in the expression are all from the input.",
            ],
            final_expression="4 * (12 - 4) - 8",
        ),
    ),
]

## GameOf24 Agent (42 lines of code)

In [None]:
def _search_final_step(example, first_step, num_verification):
  """Tries to finish the puzzle from the numbers left by `first_step`.

  Runs up to `num_verification` verifier queries. Each query sees a random
  sample (at most 15) of the drafts produced by earlier queries so it is
  nudged towards new expressions.

  Returns:
    A `Game24Solution` when a draft reaching 24 is validated and merged into a
    final expression, otherwise None.
  """
  seen_drafts = []
  for _ in range(num_verification):
    # Shuffle, then take a random-length prefix as this round's blacklist.
    random.shuffle(seen_drafts)
    sample_size = random.randint(0, min(15, len(seen_drafts)))
    verification = lf.query(
        VerifierInput(
            input=first_step.left_numbers,
            expression_blacklist=seen_drafts[:sample_size],
        ),
        Verifier,
        lm=verifier_lm,
        examples=verifier_few_shots,
        default=None,
    )
    if verification is None:
      continue

    candidate = verification.expression_that_obtains_24
    if candidate:
      final_step = Step(expression=candidate, left_numbers=[24])
      # The closing step must be explicitly judged "Valid"; a failed query
      # (None) is not accepted here.
      final_check = lf.query(
          ValidationInput(input=first_step.left_numbers, step=final_step),
          Validation,
          lm=valid_lm,
          examples=valid_few_shots,
          default=None,
      )
      if isinstance(final_check, Validation) and final_check.judgement == "Valid":
        # Merge the two steps into one expression over the puzzle numbers.
        solution = lf.query(
            Game24SolutionInput(
                numbers_in_final_expression=example,
                hints=[first_step.expression, final_step.expression],
            ),
            Game24Solution,
            lm=answer_lm,
            examples=game24_solution_few_shots,
            default=None,
        )
        if isinstance(solution, Game24Solution):
          return solution

    # Remember this round's drafts so later rounds can be steered away.
    seen_drafts.extend(verification.expression_drafts)

  return None


def game_24_solver(example, num_iter=50, num_verification=2):
  """Searches for a Game-of-24 solution for the numbers in `example`.

  Each iteration proposes a first step, validates it, then asks the verifier
  whether the remaining numbers can reach 24.

  Args:
    example: The puzzle numbers.
    num_iter: Maximum number of first-step proposals to try.
    num_verification: Verifier queries attempted per accepted first step.

  Returns:
    A `Game24Solution` on success, or None when every attempt fails.
  """
  tried_first_steps = []
  for _ in range(num_iter):
    # Propose a first step that has not been tried before.
    proposal = lf.query(
        GenerateInput(input=example, expression_blacklist=tried_first_steps),
        Generate,
        lm=generate_lm,
        examples=generate_few_shots,
        default=None,
    )
    if proposal is None:
      continue
    tried_first_steps.append(proposal.expression)
    # Float results are dropped (Step.left_numbers is declared list[int]).
    if isinstance(proposal.result_of_expression, float):
      continue
    first_step = Step(
        expression=proposal.expression,
        left_numbers=proposal.remain_numbers_in_input + [proposal.result_of_expression],
    )

    # Only an explicit "Invalid" verdict rejects the step; a failed
    # validation query (None) lets it through.
    first_check = lf.query(
        ValidationInput(input=example, step=first_step),
        Validation,
        lm=valid_lm,
        examples=valid_few_shots,
        default=None,
    )
    if first_check and first_check.judgement == "Invalid":
      continue

    solution = _search_final_step(example, first_step, num_verification)
    if solution is not None:
      return solution

  return None

# Evaluation

## Load Eval Data

In [None]:
# Download the Game-of-24 puzzle table from the tree-of-thought-llm repository.
df = pd.read_csv(
    'https://raw.githubusercontent.com/princeton-nlp/tree-of-thought-llm/master/src/tot/data/24/24.csv'
)
# Evaluate on rows 900-999. Each 'Puzzles' cell is a space-separated string of
# numbers, parsed here into a list of ints.
examples = [
    [int(token) for token in puzzle.split(' ')]
    for puzzle in df[900:1000]['Puzzles'].tolist()
]

## Utils

In [None]:
def _extract_numbers(expression: str) -> list[int]:
  """Returns the integer literals found in `expression`, in order.

  A '-' attached to a literal is kept as its sign only when it cannot be a
  binary minus, i.e. when the preceding non-space character is neither a digit
  nor ')'. So '12-4' yields [12, 4] while '3 * -2' yields [3, -2].
  """
  numbers = []
  for match in re.finditer(r'-?\d+', expression):
    token = match.group()
    if token.startswith('-'):
      before = expression[:match.start()].rstrip()
      if before and (before[-1].isdigit() or before[-1] == ')'):
        # The '-' follows an operand, so it is subtraction, not a sign.
        token = token[1:]
    numbers.append(int(token))
  return numbers


def answer(output: 'Game24Solution', example) -> 'int | float':
  """Evaluate the numerical answer of the final expression from the LLM agent.

  Args:
    output: The solver's answer; only `final_expression` is read.
    example: The puzzle numbers the expression must use exactly once each.

  Returns:
    The evaluated value of the expression, or 0 when the expression is
    missing, does not use exactly the given numbers, or fails to evaluate.
  """
  if not output.final_expression:
    return 0

  # Sometimes the expression includes " = 24" which should be removed.
  exp = output.final_expression.split('=')[0]

  # Make sure the expression use the exact 4 given numbers.
  if sorted(_extract_numbers(exp)) != sorted(example):
    return 0

  try:
    # NOTE(review): this evaluates model-generated text; confirm that
    # lf.coding.evaluate sandboxes execution before using untrusted models.
    return lf.coding.evaluate(exp)
  except Exception:
    # Best-effort scoring: an expression that cannot be evaluated counts as wrong.
    return 0

## Run the Eval

In [None]:
# Num of workers working in parallel.
max_workers = 10


def _print_summary(success, total, parse_error, failed_to_solve, wrong_solution):
  """Prints the metrics line and returns the four rates.

  Rates are reported as 0.0 while `total` is 0, so an empty run does not raise
  ZeroDivisionError.

  Returns:
    (acc, parse_error_rate, failed_to_solve_rate, wrong_solution_rate)
  """
  def rate(count):
    return count / total if total else 0.0

  acc = rate(success)
  parse_error_rate = rate(parse_error)
  failed_to_solve_rate = rate(failed_to_solve)
  wrong_solution_rate = rate(wrong_solution)
  print(f'acc: {acc}, total: {total}, parse_error_rate: {parse_error_rate}, failed_to_solve_rate: {failed_to_solve_rate}, wrong_solution_rate: {wrong_solution_rate}')
  return acc, parse_error_rate, failed_to_solve_rate, wrong_solution_rate


success = 0
total = 0
# Never incremented in this cell; kept so the reported metrics keep their shape.
parse_error = 0
failed_to_solve = 0
wrong_solution = 0
failed_to_solve_list = []
for e, ans, error in lf.concurrent_map(game_24_solver, examples, max_workers=max_workers, show_progress=True):
  total += 1

  # A job that raised has no usable answer, so it is counted with the unsolved
  # puzzles instead of being scored.
  # NOTE(review): what `ans` holds for an errored job is not visible here.
  if error is not None or ans is None:
    failed_to_solve += 1
    if error is not None:
      print(f'[FAILED TO SOLVE] {e} (error: {error!r})')
    else:
      print(f'[FAILED TO SOLVE] {e}')
    failed_to_solve_list.append(e)
    continue

  if answer(ans, e) == 24:
    success += 1
    print(f'[SUCCESS] {e} {ans.final_expression}')
  else:
    wrong_solution += 1
    print(f'[FAILED] {e} {ans}')

  # Running metrics after each scored example.
  acc, parse_error_rate, failed_to_solve_rate, wrong_solution_rate = _print_summary(
      success, total, parse_error, failed_to_solve, wrong_solution)

# Final metrics over the whole run.
acc, parse_error_rate, failed_to_solve_rate, wrong_solution_rate = _print_summary(
    success, total, parse_error, failed_to_solve, wrong_solution)

print(f'failed_to_solve_list: {failed_to_solve_list}')