## Testing compilation and goal-state extraction for Lean

In [1]:
import os
import subprocess
import re

def build_str_with_proof(proof: str):
    """Wrap Lean proof text in the NNG4 imports and the `MyNat` namespace.

    Args:
        proof: Lean source for one or more declarations; inserted verbatim.

    Returns:
        str: the full file text — the `import` lines first, then `proof`
        between `namespace MyNat` and `end MyNat`.
    """
    # Keep your original imports - they work in VSCode (each listed once).
    return f"""
        import Game.Levels.Multiplication
        import Game.MyNat.Power
        import Game.Metadata
        import Game.MyNat.Addition
        import Game.Levels.Tutorial
        import Game.Levels.Implication
        import Game.Levels.Algorithm
        import Game.Levels.LessOrEqual
        import Game.MyNat.PeanoAxioms
        import Game.MyNat.LE
        import Game.Tactic.Use
        import Game.Levels.AdvAddition

        namespace MyNat
        {proof}
        end MyNat
    """

In [2]:
def lean_compile(current_proof: str, tactics: str, file_name: str, verbose: bool = False,
                 project_root=None, timeout: float = 30):
    """Write `current_proof` into a Lean file inside the project and compile it.

    The proof is wrapped by `build_str_with_proof`, written to
    ``<project_root>/Game/LevelsClean/lean_files/<file_name>.lean`` (overwriting
    any previous file of that name) and checked with ``lake env lean`` run from
    the project root.

    Args:
        current_proof: Lean source of the theorem/proof to check.
        tactics: Not used to build the compiled file; kept so existing call
            sites (which pass e.g. "temp") keep working.
        file_name: Stem of the generated ``.lean`` file.
        verbose: Print the project root, the file path and the compiler output.
        project_root: Directory of the Lake project (the one containing the
            lakefile). When omitted, the ``LEAN_PROJECT_ROOT`` environment
            variable is used, falling back to the path this notebook was
            originally written against.
        timeout: Seconds to wait for ``lake env lean`` before giving up.

    Returns:
        ``subprocess.CompletedProcess`` when Lean ran (goals and errors are in
        ``.stdout``), or the string "Compilation timed out" / a string starting
        with "Compilation error:" when the process timed out or could not be
        started. Callers needing ``.stdout`` must handle the string case.
    """
    if project_root is None:
        project_root = os.environ.get(
            "LEAN_PROJECT_ROOT", "/Users/arnavmehta/Desktop/LeanTutor/NNG4"
        )
    project_root = os.path.abspath(project_root)

    if verbose:
        print(f"Using project root: {project_root}")

    state = build_str_with_proof(current_proof)

    # The file lives inside the project so `lake env` resolves the Game.* imports.
    temp_file_path = os.path.join(project_root, "Game", "LevelsClean", "lean_files", f"{file_name}.lean")

    try:
        os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)

        with open(temp_file_path, "w", encoding="utf-8") as f:
            f.write(state)

        if verbose:
            print(f"Created file at: {temp_file_path}")

        result = subprocess.run(
            ["lake", "env", "lean", temp_file_path],
            capture_output=True,
            text=True,
            cwd=project_root,  # lake must run from the project root
            timeout=timeout,
            encoding="utf-8",
        )

        if verbose:
            print(f"Return code: {result.returncode}")
            print(f"Stdout: {result.stdout}")
            print(f"Stderr: {result.stderr}")

        return result

    except subprocess.TimeoutExpired:
        return "Compilation timed out"
    except Exception as e:
        # Best-effort by design: report failures (e.g. `lake` missing) as a
        # message string instead of raising into the notebook.
        return f"Compilation error: {str(e)}"

In [3]:
def parse_unsolved_state(compiler_dump):
    """Extract the goal-display lines from a `lean_compile` result.

    Lean prints goals after a diagnostic header such as
    ``file.lean:18:59: error: unsolved goals``. This drops every line that
    contains "error" or "warning" (the headers), plus lines of length <= 2
    (e.g. blank separators), and keeps the rest in order, e.g.
    ``['case succ', 'a d : ℕ', 'hd : a + d = d + a', '⊢ a + succ d = succ d + a']``.

    Args:
        compiler_dump: the object returned by `lean_compile`; read via `.stdout`.

    Returns:
        list[str]: the kept lines; empty when Lean printed no goals.

    Raises:
        RuntimeError: if `compiler_dump` is the message string `lean_compile`
            returns on timeout or launch failure (instead of an opaque
            AttributeError on `.stdout`).
    """
    if isinstance(compiler_dump, str):
        raise RuntimeError(f"Lean compilation did not finish: {compiler_dump}")

    markers = ("error", "warning")
    # Plain substring match: a goal line that happens to contain one of these
    # words is dropped as well.
    return [
        line
        for line in compiler_dump.stdout.split("\n")
        if len(line) > 2 and not any(marker in line for marker in markers)
    ]

In [4]:
# Test with your original proof
# Compiles only the base-case part of the add_comm proof (through the first
# `rfl`). The inductive step is not written, so Lean reports the unsolved
# `case succ` goal in `a.stdout`; `a` is passed to `parse_unsolved_state`
# in the next cell. "temp" fills `lean_compile`'s `tactics` parameter, which is
# not used to build the compiled file.
a = lean_compile("""
theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by
        induction b with d hd
        -- First prove base case. Simplify LHS a + 0 = a.
        rw [add_zero]
        rw [zero_add]
        -- Prove LHS and RHS are equal, a = a, completing the base case.
        rfl
""", "temp", "test", verbose=True)

# The complete annotated add_comm proof (base case and inductive step) kept as
# a reference string; it is not compiled in this cell.
arnav = """ 
    theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by
        induction b with d hd
        -- First prove base case. Simplify LHS a + 0 = a.
        rw [add_zero]
        -- Simplify RHS 0 + a = a
        rw [zero_add]
        -- Prove LHS and RHS are equal, a = a, completing the base case.
        rfl
        -- Now prove the inductive step. Rewrite LHS a + succ (d) = succ (a + d)
        rw [add_succ]
        -- Rewrite RHS succ (d) + a = succ (d + a)
        rw [succ_add]
        -- Rewrite LHS succ (a + d) to succ (d + a) using the inductive hypothesis
        rw [hd]
        -- Prove succ LHS and RHS are equal, (d + a) = succ (d + a), completing the proof
        rfl

"""


Using project root: /Users/arnavmehta/Desktop/LeanTutor/NNG4
Created file at: /Users/arnavmehta/Desktop/LeanTutor/NNG4/Game/LevelsClean/lean_files/test.lean
Return code: 1
Stdout: /Users/arnavmehta/Desktop/LeanTutor/NNG4/Game/LevelsClean/lean_files/test.lean:18:59: error: unsolved goals
case succ
a d : ℕ
hd : a + d = d + a
⊢ a + succ d = succ d + a

Stderr: 


In [5]:
# Show the raw CompletedProcess from the compile above, then the goal lines
# extracted from its stdout.
print("Compilation result:", a)
parsed_goals = parse_unsolved_state(a)
print("Parsed result:", parsed_goals)

Compilation result: CompletedProcess(args=['lake', 'env', 'lean', '/Users/arnavmehta/Desktop/LeanTutor/NNG4/Game/LevelsClean/lean_files/test.lean'], returncode=1, stdout='/Users/arnavmehta/Desktop/LeanTutor/NNG4/Game/LevelsClean/lean_files/test.lean:18:59: error: unsolved goals\ncase succ\na d : ℕ\nhd : a + d = d + a\n⊢ a + succ d = succ d + a\n', stderr='')
Parsed result: ['case succ', 'a d : ℕ', 'hd : a + d = d + a', '⊢ a + succ d = succ d + a']


In [6]:
class induction_block:
    """An `induction <variable> with d hd` step plus the goals it produces.

    After `parse_induction_state` runs:
      * `case_zero` is the base-case goal line (the `⊢ ...` line under `case zero`),
      * `case_succ` is the inductive-step goal line (under `case succ`),
      * `induction_hyp` is the `hd : ...` hypothesis line of the inductive step
        (or, if no such line is shown, the line directly above its goal).
    """

    def __init__(self, variable: str):
        self.variable = variable
        self.case_zero = ""
        self.case_succ = ""
        self.induction_hyp = ""

    # should be called right after we write the induction step down
    def parse_induction_state(self, lean_code: str):
        """Compile `lean_code` and populate this object from Lean's unsolved goals.

        Raises:
            Exception: "Error parsing induction state" if the `case zero` /
                `case succ` sections or their goal lines are missing.
        """
        lines = parse_unsolved_state(lean_compile(lean_code, "temp", "test", verbose=False))
        self._populate_from_lines(lines)

    @staticmethod
    def _case_section(lines, case_name):
        """Return the lines after `case_name` up to the next `case ...` header."""
        if case_name not in lines:
            raise Exception("Error parsing induction state")
        start = lines.index(case_name) + 1
        end = start
        while end < len(lines) and not lines[end].startswith("case "):
            end += 1
        return lines[start:end]

    @staticmethod
    def _goal_index(section):
        """Index of the first `⊢` (goal) line in a case section."""
        for i, line in enumerate(section):
            if line.startswith("⊢"):
                return i
        raise Exception("Error parsing induction state")

    def _populate_from_lines(self, lines):
        """Fill case_zero / case_succ / induction_hyp from parsed goal lines."""
        zero = self._case_section(lines, "case zero")
        succ = self._case_section(lines, "case succ")

        # Search for the goal instead of using fixed offsets: the number of
        # hypothesis lines above it varies (the base case may have none).
        self.case_zero = zero[self._goal_index(zero)]

        succ_goal = self._goal_index(succ)
        self.case_succ = succ[succ_goal]

        hypotheses = succ[:succ_goal]
        named = [line for line in hypotheses if line.startswith("hd :")]
        if named:
            # `to_lean` always names the hypothesis `hd`.
            self.induction_hyp = named[0]
        elif hypotheses:
            self.induction_hyp = hypotheses[-1]
        else:
            raise Exception("Error parsing induction state")

    def to_lean(self):
        return f"induction {self.variable} with d hd"

    def __str__(self):
        return f"{self.variable} \n {self.case_zero} \n {self.case_succ} \n {self.induction_hyp}"
        


## TESTING INDUCTION BLOCK

In [27]:
# Compile a proof that stops right after the induction tactic and read back the
# base-case goal, inductive-step goal and induction hypothesis Lean reports.
testing_induction = induction_block("b")
testing_induction.parse_induction_state("""
theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by
    induction b with d hd

""")

print(testing_induction)
print(testing_induction.to_lean())

b 
 ⊢ a + 0 = 0 + a 
 ⊢ a + succ d = succ d + a 
 hd : a + d = d + a
induction b with d hd


In [57]:
class global_location:
    """Top-level view of a theorem: its statement and its current goal line.

    `theorem_statement` is the input with a trailing ":= by" removed, so callers
    can re-append ":= by" and tactics. `current_state` is the first `⊢` line Lean
    reports, or "Proof complete" once no goal lines remain.
    """

    def __init__(self, theorem_statement_lean: str):
        self.theorem_statement = self._strip_by_suffix(theorem_statement_lean)
        lines = parse_unsolved_state(lean_compile(theorem_statement_lean, "temp", "test", verbose=False))
        self.current_state = self._goal_line(lines)

    @staticmethod
    def _strip_by_suffix(statement):
        """Drop a trailing ':= by' (and surrounding whitespace) if present."""
        trimmed = statement.rstrip()
        if trimmed.endswith(":= by"):
            return trimmed[: -len(":= by")].rstrip()
        return trimmed

    @staticmethod
    def _goal_line(lines):
        """Return the first `⊢` line; the number of hypothesis lines before it varies."""
        for line in lines:
            if line.startswith("⊢"):
                return line
        # No goal marker found: fall back to the former positional choice
        # (second line; IndexError if there are fewer than two lines).
        return lines[1]

    def update_current_state(self, lean_code: str):
        """Recompile `lean_code` and refresh `current_state` (printing the parsed lines)."""
        lines = parse_unsolved_state(lean_compile(lean_code, "temp", "test", verbose=False))
        print("GLOBAL LOCATION ", lines)
        if len(lines) == 0:
            self.current_state = "Proof complete"
        else:
            self.current_state = self._goal_line(lines)

    def __str__(self):
        return f"{self.theorem_statement} \n {self.current_state}"

## TESTING GLOBAL STATE

In [8]:
# Build the top-level location from the bare statement; compiling
# "... := by" with no tactics is what makes Lean report the initial goal.
g_state = global_location(
    "theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by"
)
print(g_state)

theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a 
 ⊢ a + b = b + a


In [9]:
class base_case_location:
    """Tracks the base-case (`case zero`) goal of an induction proof.

    `current_state` is the `⊢` line of the `case zero` section, or
    "Proof complete" once Lean no longer reports a `case zero` goal.
    """

    def __init__(self):
        self.current_state = ""

    @staticmethod
    def _case_goal(lines, case_name):
        """Return the first `⊢` line between `case_name` and the next `case` header."""
        for line in lines[lines.index(case_name) + 1:]:
            if line.startswith("case "):
                break
            if line.startswith("⊢"):
                return line
        raise Exception("Error parsing induction state")

    def update_current_state(self, lean_code: str):
        """Recompile `lean_code` and refresh `current_state` (printing the parsed lines)."""
        lines = parse_unsolved_state(lean_compile(lean_code, "temp", "test", verbose=False))
        print("BASE CASE COMPILER ", lines)
        self._update_from_lines(lines)

    def _update_from_lines(self, lines):
        """Set `current_state` from already-parsed goal lines."""
        if "case zero" not in lines:
            self.current_state = "Proof complete"
        else:
            # Search for the goal rather than assuming one or two hypothesis
            # lines; also works when `case succ` is absent (e.g. rw failure output).
            self.current_state = self._case_goal(lines, "case zero")
        
    

In [10]:
class inductive_step_location:
    """Tracks the inductive-step (`case succ`) goal of an induction proof.

    `current_state` is the `⊢` line of the `case succ` section, or
    "Proof complete" once Lean no longer reports a `case succ` goal.
    """

    def __init__(self):
        self.current_state = ""

    @staticmethod
    def _case_goal(lines, case_name):
        """Return the first `⊢` line between `case_name` and the next `case` header."""
        for line in lines[lines.index(case_name) + 1:]:
            if line.startswith("case "):
                break
            if line.startswith("⊢"):
                return line
        raise Exception("Error parsing induction state")

    def update_current_state(self, lean_code: str):
        """Recompile `lean_code` and refresh `current_state` (printing the parsed lines)."""
        lines = parse_unsolved_state(lean_compile(lean_code, "temp", "test", verbose=False))

        print("INDUCTIVE STEP COMPILER ", lines)
        self._update_from_lines(lines)

    def _update_from_lines(self, lines):
        """Set `current_state` from already-parsed goal lines."""
        if "case succ" not in lines:
            self.current_state = "Proof complete"
        else:
            # Search for the goal instead of a fixed +3 offset, which assumed
            # exactly two hypothesis lines.
            self.current_state = self._case_goal(lines, "case succ")

In [11]:
class rw_block:
    """Builder for one Lean rewrite step: `rw [...]` or `nth_rewrite n [...]`."""

    def __init__(self, theorem: str):
        self.backarrow = False  # when True, rewrite right-to-left (`← theorem`)
        self.theorem = theorem
        self.inputs = []  # explicit arguments appended after the theorem name
        self.nth = None   # occurrence to rewrite; None means a plain `rw`

    def to_lean(self):
        """Render this step as a single line of Lean tactic code."""
        term = " ".join([self.theorem, *self.inputs])
        if self.backarrow:
            term = f"← {term}"
        if self.nth is not None:
            return f"nth_rewrite {self.nth} [{term}]"
        return f"rw [{term}]"

    def add_input(self, input_value: str):
        """Add an explicit input to the theorem; returns self for chaining."""
        self.inputs.append(input_value)
        return self

    def set_nth_rewrite(self, n: int):
        """Set to use nth_rewrite instead of regular rewrite; returns self for chaining."""
        self.nth = n
        return self

## THE DECISIONS WE NEED TO MAKE
NL (natural-language proof outline):
theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by
        -- Induct on b, with b = 0 as the base case and the inductive hypothesis a + d = d + a. There are now two proof goals: prove the base case a + 0 = 0 + a and the inductive step a + succ (d) = succ (d) + a.

        -- First prove base case. Simplify LHS a + 0 = a.

        -- Simplify RHS 0 + a = a

        -- Prove LHS and RHS are equal, a = a, completing the base case.

        -- Now prove the inductive step. Rewrite LHS a + succ (d) = succ (a + d)

        -- Rewrite RHS succ (d) + a = succ (d + a)

        -- Rewrite LHS succ (a + d) to succ (d + a) using the inductive hypothesis

        -- Prove LHS and RHS are equal, succ (d + a) = succ (d + a), completing the proof


In [12]:
from openai import OpenAI
import os

def get_completion(prompt, model="o1"):
    """Send `prompt` as a single user message and return the reply text.

    Args:
        prompt: the user message.
        model: chat model name (defaults to the previously hard-coded "o1").

    Returns:
        str: content of the first choice.

    The API key is read from the OPENAI_API_KEY environment variable instead of
    being written into the notebook.
    """
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    return client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    ).choices[0].message.content


def get_structured_completion(prompt, structure, model="gpt-4o", system_prompt=None):
    """Ask the model about `prompt` and return its reply parsed into `structure`.

    Args:
        prompt: the user message. It was previously ignored in favour of
            hard-coded sample messages.
        structure: passed as `response_format`; per the OpenAI SDK this is
            presumably a Pydantic model class (confirm with callers).
        model: a model that supports structured outputs (plain "gpt-4" does not).
        system_prompt: optional system message sent before the prompt.

    Returns:
        The `.parsed` object of the first choice.

    The API key is read from the OPENAI_API_KEY environment variable.
    """
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    completion = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=structure,
    )

    return completion.choices[0].message.parsed

In [None]:
# Root node of the proof search: the theorem statement plus the initial goal
# Lean reports for it.
base_node = global_location(
    "theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by"
)
print(base_node)

In [58]:
# THIS ONE WORKS
def autoformalize_induction_proof_whole_addition(nl_statements, theorem):
    theorem_saver = theorem
    """
    Autoformalize an induction proof based on a list of natural language statements,
    using OpenAI to determine the appropriate blocks and theorems.
    
    Args:
        nl_statements: List of natural language statements describing the proof steps
        theorem: The theorem statement to prove
        
    Returns:
        bool: True if the proof is complete, False otherwise
    """
    # Initialize the base node with the theorem statement
    print(theorem)
    base_node = global_location(theorem)
    print(f"Starting proof of: {base_node.theorem_statement}")
    
    # Extract theorem name for output
    theorem_name = theorem.split("theorem ")[1].split(" (")[0] if "theorem " in theorem else "unnamed_theorem"
    
    # Current proof state
    current_proof = f"{base_node.theorem_statement} := by\n"
    
    # Keep track of all lean statements for final output
    lean_statements = []
    
    # Define theorems database with detailed metadata
    theorems_db = [
        {
            "name": "add_succ",
            "description": "For any natural numbers n and m: n + succ(m) = succ(n + m)",
            "pattern": "left side has a sum where the second term is a successor",
            "example": "a + succ(b) → succ(a + b)",
            "left_pattern": "X + succ(Y)",
            "right_pattern": "succ(X + Y)"
        },
        {
            "name": "add_zero",
            "description": "For any natural number n: n + 0 = n",
            "pattern": "left side has a sum where the second term is zero",
            "example": "a + 0 → a",
            "left_pattern": "X + 0",
            "right_pattern": "X"
        },
        {
            "name": "succ_add",
            "description": "For any natural numbers n and m: succ(n) + m = succ(n + m)",
            "pattern": "left side has a sum where the first term is a successor",
            "example": "succ(a) + b → succ(a + b)",
            "left_pattern": "succ(X) + Y",
            "right_pattern": "succ(X + Y)"
        },
        {
            "name": "zero_add",
            "description": "For any natural number n: 0 + n = n",
            "pattern": "left side has a sum where the first term is zero",
            "example": "0 + a → a",
            "left_pattern": "0 + X",
            "right_pattern": "X"
        },
        {
            "name": "add_assoc",
            "description": "For any natural numbers a, b, c: (a + b) + c = a + (b + c)",
            "pattern": "need to rearrange parentheses in a sum of three terms",
            "example": "(a + b) + c → a + (b + c)",
            "left_pattern": "(X + Y) + Z",
            "right_pattern": "X + (Y + Z)"
        },
        {
            "name": "add_comm",
            "description": "For any natural numbers a, b: a + b = b + a",
            "pattern": "need to swap the order of terms in addition",
            "example": "a + b → b + a",
            "left_pattern": "X + Y",
            "right_pattern": "Y + X"
        },
        {
            "name": "rfl",
            "description": "Reflexivity: proves that a = a",
            "pattern": "both sides of the equation are identical",
            "example": "a = a → proved"
        }
    ]
    
    # Process the first statement to determine if we need an induction block
    print(f"\nProcessing statement: {nl_statements[0]}")
    block_type_prompt = f"""
    Here is our current proof state:
    {base_node.current_state}

    You have an option between creating two blocks:
    induction block: where you will input the induction variable
    rewrite block: where you will input the theorem to rewrite with

    What block would you like to create?
    IMPORTANT: Just output the single word "induction" or "rewrite" with no additional text, quotes, or punctuation.
    
    Here is the statement you have to represent using the block:
    {nl_statements[0]}
    """
    
    block_type = get_completion(block_type_prompt).strip().lower()
    print(f"OpenAI suggests block type: {block_type}")
    
    # Initialize induction_block_obj to None
    induction_block_obj = None
    
    if "induction" in block_type:  # More robust check
        # Ask OpenAI for the induction variable
        variable_prompt = f"""
        Here is our current proof state:
        {base_node.current_state}
        We have chosen to create an induction block.

        Please choose the induction variable.
        IMPORTANT: Just output the variable name as a single character or identifier with no additional text, quotes, or punctuation.

        Here is the statement you have to represent using the block:
        {nl_statements[0]}
        """
        
        induction_var = get_completion(variable_prompt).strip()
        print(f"OpenAI suggests induction variable: {induction_var}")
        
        # Create the induction block
        induction_block_obj = induction_block(induction_var)
        lean_statement = induction_block_obj.to_lean()
        current_proof += f"   {lean_statement}\n"
        lean_statements.append(lean_statement)
        
        # Parse the induction state
        induction_block_obj.parse_induction_state(current_proof)
        
        # Update the base node state
        base_node.update_current_state(current_proof)
        
        # Create base case location
        base_case_block = base_case_location()
        base_case_block.update_current_state(current_proof)
        
        # Process statements for the base case
        base_case_index = 1  # Start from the second statement
        while base_case_index < len(nl_statements):
            statement = nl_statements[base_case_index]
            print(f"\nProcessing base case statement: {statement}")
            
            # Add the induction hypothesis to the theorems for this context
            context_theorems = theorems_db.copy()
            
            # Predict the next state based on the natural language statement
            next_state = predict_next_state(statement, base_case_block.current_state, "base case")
            print(f"Predicted next state: {next_state}")
            
            # Try up to 3 times to find a theorem that works
            max_attempts = 3
            for attempt in range(max_attempts):
                print(f"Attempt {attempt+1}/{max_attempts} to find a theorem")
                
                # Get theorem with direction
                theorem_info = determine_theorem_and_direction(
                    base_case_block.current_state, 
                    next_state, 
                    context_theorems, 
                    statement, 
                    attempt
                )
                theorem = theorem_info["theorem"]
                use_backarrow = theorem_info["backarrow"]
                inputs = theorem_info.get("inputs", [])
                nth_rewrite = theorem_info.get("nth_rewrite", None)
                
                print(f"OpenAI suggests theorem for base case: {theorem} (backarrow: {use_backarrow}, inputs: {inputs}, nth: {nth_rewrite})")
                
                # Handle rfl specially
                if "rfl" in theorem.lower():
                    lean_statement = "rfl"
                    current_proof += f"   {lean_statement}\n"
                    lean_statements.append(lean_statement)
                    break  # rfl always works if the states are equal
                else:
                    # Create the rewrite block with backarrow if needed
                    rw_block_obj = rw_block(theorem)
                    rw_block_obj.backarrow = use_backarrow
                    
                    # Add any inputs if specified
                    for input_value in inputs:
                        rw_block_obj.add_input(input_value)
                    
                    # Use nth_rewrite if specified
                    if nth_rewrite is not None:
                        rw_block_obj.set_nth_rewrite(nth_rewrite)
                    
                    # Get the Lean statement
                    lean_statement = rw_block_obj.to_lean()
                    
                    # Create a temporary proof to test if this theorem works
                    temp_proof = current_proof + f"   {lean_statement}\n"
                    
                    # Create a temporary block to test the state
                    temp_block = base_case_location()
                    temp_block.update_current_state(temp_proof)
                    
                    # Check if the theorem application produces the expected state
                    if are_states_equivalent(temp_block.current_state, next_state):
                        print(f"Theorem {theorem} produces the expected state!")
                        current_proof = temp_proof
                        lean_statements.append(lean_statement)
                        break
                    else:
                        print(f"Theorem {theorem} does not produce the expected state.")
                        print(f"Expected: {next_state}")
                        print(f"Actual: {temp_block.current_state}")
                        
                        # Try with the opposite backarrow setting
                        rw_block_obj.backarrow = not use_backarrow
                        lean_statement = rw_block_obj.to_lean()
                        temp_proof = current_proof + f"   {lean_statement}\n"
                        temp_block = base_case_location()
                        temp_block.update_current_state(temp_proof)
                        
                        if are_states_equivalent(temp_block.current_state, next_state):
                            print(f"Theorem {theorem} with opposite backarrow produces the expected state!")
                            current_proof = temp_proof
                            lean_statements.append(lean_statement)
                            break
                        else:
                            print(f"Theorem {theorem} with opposite backarrow also does not work.")
                            # Continue to the next attempt
                
                # If we've tried all attempts and none worked, use the last attempt
                if attempt == max_attempts - 1:
                    print("All attempts failed. Using the last attempt.")
                    rw_block_obj = rw_block(theorem)
                    rw_block_obj.backarrow = use_backarrow
                    
                    # Add any inputs if specified
                    for input_value in inputs:
                        rw_block_obj.add_input(input_value)
                    
                    # Use nth_rewrite if specified
                    if nth_rewrite is not None:
                        rw_block_obj.set_nth_rewrite(nth_rewrite)
                    
                    lean_statement = rw_block_obj.to_lean()
                    current_proof += f"   {lean_statement}\n"
                    lean_statements.append(lean_statement)
            
            # Update the base case state
            base_case_block.update_current_state(current_proof)
            
            # Check if the base case is complete
            if base_case_block.current_state == "Proof complete" or "Goal accomplished" in base_case_block.current_state:
                print("Base case is complete! Moving to inductive step.")
                base_case_index += 1
                break
            
            base_case_index += 1
        
        # Create inductive step location
        induction_case_block = inductive_step_location()
        induction_case_block.update_current_state(current_proof)
        
        # Process statements for the inductive step
        for i in range(base_case_index, len(nl_statements)):
            statement = nl_statements[i]
            print(f"\nProcessing inductive step statement: {statement}")
            
            # Add the induction hypothesis to the theorems for this context
            context_theorems = theorems_db.copy()
            if induction_block_obj:
                context_theorems.append({
                    "name": "hd",
                    "description": f"Induction hypothesis: {induction_block_obj.induction_hyp}",
                    "pattern": "applying the induction hypothesis",
                    "example": f"Use when you need to apply {induction_block_obj.induction_hyp}",
                    "priority": "high",
                    "needs_backarrow": False
                })
            
            # Predict the next state based on the natural language statement
            next_state = predict_next_state(statement, induction_case_block.current_state, "inductive step")
            print(f"Predicted next state: {next_state}")
            
            # Try up to 3 times to find a theorem that works
            max_attempts = 3
            for attempt in range(max_attempts):
                print(f"Attempt {attempt+1}/{max_attempts} to find a theorem")
                
                # Get theorem with direction
                theorem_info = determine_theorem_and_direction(
                    induction_case_block.current_state, 
                    next_state, 
                    context_theorems, 
                    statement,
                    attempt
                )
                theorem = theorem_info["theorem"]
                use_backarrow = theorem_info["backarrow"]
                inputs = theorem_info.get("inputs", [])
                nth_rewrite = theorem_info.get("nth_rewrite", None)
                
                print(f"OpenAI suggests theorem for inductive step: {theorem} (backarrow: {use_backarrow}, inputs: {inputs}, nth: {nth_rewrite})")
                
                # Handle rfl specially
                if "rfl" in theorem.lower():
                    lean_statement = "rfl"
                    current_proof += f"   {lean_statement}\n"
                    lean_statements.append(lean_statement)
                    break  # rfl always works if the states are equal
                else:
                    # Create the rewrite block with backarrow if needed
                    rw_block_obj = rw_block(theorem)
                    rw_block_obj.backarrow = use_backarrow
                    
                    # Add any inputs if specified
                    for input_value in inputs:
                        rw_block_obj.add_input(input_value)
                    
                    # Use nth_rewrite if specified
                    if nth_rewrite is not None:
                        rw_block_obj.set_nth_rewrite(nth_rewrite)
                    
                    # Get the Lean statement
                    lean_statement = rw_block_obj.to_lean()
                    
                    # Create a temporary proof to test if this theorem works
                    temp_proof = current_proof + f"   {lean_statement}\n"
                    
                    # Create a temporary block to test the state
                    temp_block = inductive_step_location()
                    temp_block.update_current_state(temp_proof)
                    
                    # Check if the theorem application produces the expected state
                    if are_states_equivalent(temp_block.current_state, next_state):
                        print(f"Theorem {theorem} produces the expected state!")
                        current_proof = temp_proof
                        lean_statements.append(lean_statement)
                        break
                    else:
                        print(f"Theorem {theorem} does not produce the expected state.")
                        print(f"Expected: {next_state}")
                        print(f"Actual: {temp_block.current_state}")
                        
                        # Try with the opposite backarrow setting
                        rw_block_obj.backarrow = not use_backarrow
                        lean_statement = rw_block_obj.to_lean()
                        temp_proof = current_proof + f"   {lean_statement}\n"
                        temp_block = inductive_step_location()
                        temp_block.update_current_state(temp_proof)
                        
                        if are_states_equivalent(temp_block.current_state, next_state):
                            print(f"Theorem {theorem} with opposite backarrow produces the expected state!")
                            current_proof = temp_proof
                            lean_statements.append(lean_statement)
                            break
                        else:
                            print(f"Theorem {theorem} with opposite backarrow also does not work.")
                            # Continue to the next attempt
                
                # If we've tried all attempts and none worked, use the last attempt
                if attempt == max_attempts - 1:
                    print("All attempts failed. Using the last attempt.")
                    rw_block_obj = rw_block(theorem)
                    rw_block_obj.backarrow = use_backarrow
                    
                    # Add any inputs if specified
                    for input_value in inputs:
                        rw_block_obj.add_input(input_value)
                    
                    # Use nth_rewrite if specified
                    if nth_rewrite is not None:
                        rw_block_obj.set_nth_rewrite(nth_rewrite)
                    
                    lean_statement = rw_block_obj.to_lean()
                    current_proof += f"   {lean_statement}\n"
                    lean_statements.append(lean_statement)
            
            # Update the inductive step state
            induction_case_block.update_current_state(current_proof)
            
            # Also update the base node state
            base_node.update_current_state(current_proof)
            
            # Check if the inductive step is complete
            if induction_case_block.current_state == "Proof complete" or "Goal accomplished" in induction_case_block.current_state:
                print("Inductive step is complete!")
                # Check if the entire proof is complete
                if base_node.current_state == "Proof complete" or "Goal accomplished" in base_node.current_state:
                    print("\nEntire proof is complete!")
                    break
    
    else:  # block_type == "rewrite" or fallback
        # This is a direct proof using rewrite blocks
        for i, statement in enumerate(nl_statements):
            print(f"\nProcessing statement: {statement}")
            
            # Predict the next state based on the natural language statement
            next_state = predict_next_state(statement, base_node.current_state, "direct proof")
            print(f"Predicted next state: {next_state}")
            
            # Try up to 3 times to find a theorem that works
            max_attempts = 3
            for attempt in range(max_attempts):
                print(f"Attempt {attempt+1}/{max_attempts} to find a theorem")
                
                # Get theorem with direction
                theorem_info = determine_theorem_and_direction(
                    base_node.current_state, 
                    next_state, 
                    theorems_db, 
                    statement,
                    attempt
                )
                theorem = theorem_info["theorem"]
                use_backarrow = theorem_info["backarrow"]
                inputs = theorem_info.get("inputs", [])
                nth_rewrite = theorem_info.get("nth_rewrite", None)
                
                print(f"OpenAI suggests theorem: {theorem} (backarrow: {use_backarrow}, inputs: {inputs}, nth: {nth_rewrite})")
                
                # Handle rfl specially
                if "rfl" in theorem.lower():
                    lean_statement = "rfl"
                    current_proof += f"   {lean_statement}\n"
                    lean_statements.append(lean_statement)
                    break  # rfl always works if the states are equal
                else:
                    # Create the rewrite block with backarrow if needed
                    rw_block_obj = rw_block(theorem)
                    rw_block_obj.backarrow = use_backarrow
                    
                    # Add any inputs if specified
                    for input_value in inputs:
                        rw_block_obj.add_input(input_value)
                    
                    # Use nth_rewrite if specified
                    if nth_rewrite is not None:
                        rw_block_obj.set_nth_rewrite(nth_rewrite)
                    
                    # Get the Lean statement
                    lean_statement = rw_block_obj.to_lean()
                    
                    # Create a temporary proof to test if this theorem works
                    temp_proof = current_proof + f"   {lean_statement}\n"

                    print("TEMP PROOF ", temp_proof)
                    
                    # Create a temporary block to test the state
                    print("THEOREM ", theorem)
                    temp_block = global_location(theorem_saver)
                    temp_block.update_current_state(temp_proof)
                    
                    # Check if the theorem application produces the expected state
                    if are_states_equivalent(temp_block.current_state, next_state):
                        print(f"Theorem {theorem} produces the expected state!")
                        current_proof = temp_proof
                        lean_statements.append(lean_statement)
                        break
                    else:
                        print(f"Theorem {theorem} does not produce the expected state.")
                        print(f"Expected: {next_state}")
                        print(f"Actual: {temp_block.current_state}")
                        
                        # Try with the opposite backarrow setting
                        rw_block_obj.backarrow = not use_backarrow
                        lean_statement = rw_block_obj.to_lean()
                        temp_proof = current_proof + f"   {lean_statement}\n"
                        temp_block = global_location(theorem_saver)
                        temp_block.update_current_state(temp_proof)
                        
                        if are_states_equivalent(temp_block.current_state, next_state):
                            print(f"Theorem {theorem} with opposite backarrow produces the expected state!")
                            current_proof = temp_proof
                            lean_statements.append(lean_statement)
                            break
                        else:
                            print(f"Theorem {theorem} with opposite backarrow also does not work.")
                            # Continue to the next attempt
                
                # If we've tried all attempts and none worked, use the last attempt
                if attempt == max_attempts - 1:
                    print("All attempts failed. Using the last attempt.")
                    rw_block_obj = rw_block(theorem)
                    rw_block_obj.backarrow = use_backarrow
                    
                    # Add any inputs if specified
                    for input_value in inputs:
                        rw_block_obj.add_input(input_value)
                    
                    # Use nth_rewrite if specified
                    if nth_rewrite is not None:
                        rw_block_obj.set_nth_rewrite(nth_rewrite)
                    
                    lean_statement = rw_block_obj.to_lean()
                    current_proof += f"   {lean_statement}\n"
                    lean_statements.append(lean_statement)
            
            # Update the base node state
            base_node.update_current_state(current_proof)
            
            # Check if we've completed the proof
            if base_node.current_state == "Proof complete" or "Goal accomplished" in base_node.current_state:
                print("\nProof is complete!")
                break
    
    # Check final state of the proof
    base_node.update_current_state(current_proof)
    is_complete = base_node.current_state == "Proof complete" or "Goal accomplished" in base_node.current_state
    
    print(f"\nFinal proof state: {base_node.current_state}")
    print(f"Proof complete: {is_complete}")
    
    # Format and print the complete Lean proof
    formatted_proof = format_lean_proof(lean_statements)
    print("\n=== COMPLETE LEAN PROOF ===")
    print(formatted_proof)
    
    return is_complete, lean_statements

def predict_next_state(statement, current_state, context_type=""):
    """
    Ask the language model what the proof state should look like after one
    natural-language proof step.

    Args:
        statement: Natural language statement describing the proof step.
        current_state: Current state of the proof (interpolated verbatim into
            the prompt).
        context_type: Label for the proof context (base case, inductive step,
            "direct proof", ...); only used as text inside the prompt.

    Returns:
        str: The predicted next state, with any surrounding ``` fence,
        quotes, backticks and whitespace removed.
    """
    # Debug trace (was previously placed above the docstring, which turned
    # the docstring into a no-op string expression).
    print("CURRENT STATE", current_state)

    prompt = f"""
    Here is the current proof state in the {context_type}:
    {current_state}
    
    I need to predict what the state will be after applying this step:
    "{statement}"
    
    INSTRUCTIONS:
    1. Analyze the current state and the natural language description
    2. Predict what the mathematical expression will look like after this step
    3. Output ONLY the predicted state with no additional text or explanation
    4. Make sure to include brackets when necessary
    """

    response = get_completion(prompt).strip()

    # Models sometimes wrap the answer in a multi-line ``` fence despite
    # instruction 3; keep only the fenced content.
    if response.startswith("```") and "\n" in response:
        body = response.splitlines()[1:]
        if body and body[-1].strip().startswith("```"):
            body = body[:-1]
        response = "\n".join(body).strip()

    # Clean up surrounding quotes / inline backticks.
    predicted_state = response.strip('"').strip("'").strip("`").strip()

    return predicted_state

def are_states_equivalent(state1, state2):
    """
    Check whether two proof states are the same up to formatting.

    A cheap textual comparison of ``normalize_state(state1)`` and
    ``normalize_state(state2)`` runs first. Only if that fails is the
    model-backed ``check_state_equivalence`` consulted.

    Args:
        state1: First state to compare.
        state2: Second state to compare.

    Returns:
        bool: True if the states are judged equivalent, False otherwise.
    """
    # Debug trace (previously placed above the docstring, so the function
    # had no __doc__).
    print("THE STATES ARE ", state1, state2)

    # Fast path: identical after normalization.
    if normalize_state(state1) == normalize_state(state2):
        return True

    # Otherwise defer to the language-model comparison.
    return check_state_equivalence(state1, state2)

def normalize_state(state):
    """
    Normalize a proof-state string so that purely cosmetic differences do not
    affect comparison.

    Removed differences: all whitespace, surrounding quotes, the spelling
    ``succ`` (rewritten to ``S``), and parentheses that wrap a single token
    (so ``succ(a)`` and Lean's ``succ a`` agree).

    Parentheses around compound expressions are KEPT. ``succ (a + b)`` and
    ``succ a + b``, or ``a + (b + c)`` and ``a + b + c``, are different goals.
    Stripping every parenthesis made them compare equal, which let a wrong
    rewrite pass the equivalence check.

    Args:
        state: State string to normalize. If it contains the goal marker
            "⊢", only the text after the first marker is kept. ``None`` is
            treated as an empty state.

    Returns:
        str: "COMPLETE" if the state reports a finished proof, otherwise the
        normalized string.
    """
    if state is None:
        return ""

    # A finished proof normalizes to one sentinel, checked on the full text.
    if "Goal accomplished" in state or "Proof complete" in state:
        return "COMPLETE"

    # Keep only the goal (drops hypotheses printed before "⊢").
    if "⊢" in state:
        state = state.split("⊢")[1]

    # Remove whitespace, then surrounding quotes.
    normalized = re.sub(r"\s+", "", state).strip('"').strip("'")

    # Standardize successor notation.
    normalized = normalized.replace("succ", "S")

    # Drop parentheses that wrap a single token ("S(a)" -> "Sa"); repeat until
    # stable so nested forms like "S(S(a))" collapse fully.
    previous = None
    while previous != normalized:
        previous = normalized
        normalized = re.sub(r"\((\w+)\)", r"\1", normalized)

    return normalized

def check_state_equivalence(state1, state2):
    """
    Ask the language model whether two proof states are the same statement.

    The prompt tells the model to tolerate renamed variables and formatting.
    It must reject any other difference, including swapped sides or
    unsimplified terms such as ``a + 0`` vs ``a``.

    Args:
        state1: First state to compare.
        state2: Second state to compare.

    Returns:
        bool: True only if the model's reply is affirmative.
    """
    # Debug trace (previously placed above the docstring).
    print("THE STATES ARE ", state1, state2)

    prompt = f"""
    I need to determine if these two mathematical expressions are exactly the same:
    
    Expression 1: {state1}
    
    Expression 2: {state2}
    
    INSTRUCTIONS:
    1. Analyze both expressions carefully
    2. Determine if they represent exactly the same mathematical statement with just variables changed, accounting for:
       - Different variable names (these are fine as long as they represent the same concept)
       - Different formattings (like whitespace, parentheses, etc.)
       - But if anything else is different even a + 0 vs a, they are not equivalent
       - Even if LHS and RHS are switched, they are not equivalent
    3. Output ONLY "yes" if they are equivalent or "no" if they are not
    4. We are in a peano arithmetic context, so any arithmetic simplication should not be assumed
    """

    response = get_completion(prompt).strip().lower()

    # Decide from the leading word when the model follows the yes/no format.
    # The old substring test ("yes"/"equivalent"/"same" anywhere) also
    # accepted negative replies like "no, they are not the same".
    if re.match(r"\W*yes\b", response):
        return True
    if re.match(r"\W*no\b", response):
        return False

    # Free-form reply: accept positive wording only if nothing negates it.
    if re.search(r"\b(?:not|isn't|aren't|no)\b", response):
        return False
    return "equivalent" in response or "same" in response

def _clean_field(value):
    """Strip whitespace plus the markdown/quote/bracket wrappers a model adds around a field value."""
    return value.strip().strip("[]`*\"'").strip().rstrip(".").strip()


def _parse_theorem_response(response):
    """Parse the THEOREM / DIRECTION / INPUTS / NTH_REWRITE / EXPLANATION lines of a model reply."""
    parsed = {
        "theorem": "rfl",  # default fallback when no THEOREM line is found
        "backarrow": False,
        "inputs": [],
        "nth_rewrite": None,
        "explanation": "",
    }
    for raw_line in response.splitlines():
        # Drop list/markdown prefixes such as "- ", "1. ", "**", "> ".
        line = re.sub(r"^(?:[-*>#]+|\d+[.)])\s*", "", raw_line.strip())
        key, sep, value = line.partition(":")
        if not sep:
            continue
        # Match on the exact field name. A substring test would let e.g.
        # "EXPLANATION: the theorem: ..." overwrite the THEOREM field.
        key = key.strip().strip("*`").strip().upper()
        if key == "THEOREM":
            parsed["theorem"] = _clean_field(value) or parsed["theorem"]
        elif key == "DIRECTION":
            direction = value.strip().lower()
            parsed["backarrow"] = "back" in direction or "reverse" in direction or "←" in direction
        elif key == "INPUTS":
            inputs_part = _clean_field(value)
            # Compare case-insensitively but keep the original case:
            # Lean identifiers are case-sensitive.
            if inputs_part and inputs_part.lower() != "none":
                parsed["inputs"] = [tok for tok in re.split(r"[\s,]+", inputs_part) if tok]
        elif key == "NTH_REWRITE":
            # Accept "2", "[2]" or "2 (second occurrence)"; "none" -> None.
            match = re.search(r"\d+", value)
            parsed["nth_rewrite"] = int(match.group()) if match else None
        elif key == "EXPLANATION":
            parsed["explanation"] = value.strip()
    return parsed


def determine_theorem_and_direction(current_state, next_state, theorems, statement, attempt=1):
    """
    Determine which theorem and direction to use to connect the current state to the next state.
    Also determines if explicit inputs are needed and if nth_rewrite should be used.

    Args:
        current_state: Current state of the proof.
        next_state: Predicted next state of the proof.
        theorems: Iterable of dicts with keys 'name', 'description',
            'example' and optionally 'priority' ("high" marks a preferred
            theorem in the prompt).
        statement: Natural language statement describing the proof step.
        attempt: Attempt number. Values >= 2 make the prompt demand explicit
            inputs, and trigger a follow-up request for inputs when none were
            given. NOTE(review): the retry loop in this notebook passes a
            0-based index (``range(max_attempts)``), so with that caller the
            explicit-input behaviour only starts on the third try. Confirm
            which numbering is intended.

    Returns:
        dict: {"theorem": str, "backarrow": bool, "inputs": list[str],
        "nth_rewrite": int | None}. The theorem defaults to "rfl" when the
        reply has no THEOREM line, and "rfl" never gets a backarrow, inputs
        or nth_rewrite.
    """
    # Format theorems for the prompt
    theorem_descriptions = []
    for t in theorems:
        desc = f"- {t['name']}: {t['description']} (Example: {t['example']})"
        if "priority" in t and t["priority"] == "high":
            desc += " [PREFERRED CHOICE when applicable]"
        theorem_descriptions.append(desc)

    # Add emphasis on explicit inputs for later attempts
    explicit_inputs_instruction = ""
    if attempt >= 2:
        explicit_inputs_instruction = f"""
        IMPORTANT: This is attempt #{attempt}. Previous attempts failed.
        You MUST provide EXPLICIT INPUTS for the theorem when applicable.
        For example, instead of just using 'add_comm', specify 'add_comm a b' with the exact variables.
        Be very specific about which variables to use based on the current state.
        """

    # Create a prompt that asks for the theorem and direction
    prompt = f"""
    I need to determine which theorem and direction to use to transform:
    
    Current state: {current_state}
    
    Into:
    
    Next state: {next_state}
    
    The natural language description of this step is:
    "{statement}"
    
    Available theorems:
    {chr(10).join(theorem_descriptions)}
    {explicit_inputs_instruction}
    
    IMPORTANT RULES:
    1. If the induction hypothesis (hd) can be applied, it should be preferred over other theorems
    2. The "rfl" theorem NEVER needs a backarrow - it's always applied directly
    3. Only use a backarrow (←) when you need to apply a theorem in reverse direction
    4. Be very careful with similar theorems:
       - add_succ: applies to "X + succ(Y)" → "succ(X + Y)"
       - succ_add: applies to "succ(X) + Y" → "succ(X + Y)"
       - add_zero: applies to "X + 0" → "X"
       - zero_add: applies to "0 + X" → "X"
    5. Some theorems need explicit inputs (0, 1, or 2 inputs). Determine if any are needed.
    6. If multiple instances of a pattern exist, determine if we need nth_rewrite and which occurrence to target.
    
    INSTRUCTIONS:
    1. Analyze the current state and next state
    2. Determine which theorem can transform one into the other
    3. Decide if the theorem should be applied forward or backward (with a backarrow ←)
    4. Determine if explicit inputs are needed for the theorem
    5. Determine if we need to target a specific occurrence with nth_rewrite
    6. Output your answer in this exact format:
       THEOREM: [theorem name]
       DIRECTION: [forward/backward]
       INPUTS: [none/input1/input1 input2] (specify the actual values when its ambiguous)
       NTH_REWRITE: [none/occurrence number] (specify which occurrence to target, if needed)
       EXPLANATION: [brief explanation of why this theorem and direction]
    """

    response = get_completion(prompt).strip()

    parsed = _parse_theorem_response(response)
    theorem = parsed["theorem"]
    use_backarrow = parsed["backarrow"]
    inputs = parsed["inputs"]
    nth_rewrite = parsed["nth_rewrite"]

    # Special case for rfl: it takes no direction, inputs or occurrence.
    if theorem.lower() == "rfl":
        use_backarrow = False
        inputs = []
        nth_rewrite = None

    # For later attempts, if no inputs were provided for a theorem that might
    # need them, make an additional request specifically for inputs.
    if attempt >= 2 and not inputs and theorem.lower() != "rfl":
        inputs_prompt = f"""
        We need to provide explicit inputs for the theorem '{theorem}' to transform:
        
        Current state: {current_state}
        
        Into:
        
        Next state: {next_state}
        
        Based on the current state, what specific variables should be used as inputs to this theorem?
        IMPORTANT: Just list the variables separated by spaces, with no additional text.
        For example: "a b" or "n m" or "x y z"
        """

        # Strip quotes/brackets the model may echo from the example ("a b").
        inputs_response = _clean_field(get_completion(inputs_prompt))
        if inputs_response:
            inputs = [tok for tok in re.split(r"[\s,]+", inputs_response) if tok]

    print(f"Explanation: {parsed['explanation']}")

    return {
        "theorem": theorem,
        "backarrow": use_backarrow,
        "inputs": inputs,
        "nth_rewrite": nth_rewrite
    }

def format_lean_proof(lean_statements, indent="  "):
    """
    Format a list of Lean statements into an indented proof script.

    Rules:
      * a statement starting with "theorem" is emitted flush-left and puts
        the following tactic block one level in;
      * a bullet "·" is emitted at that base level, and the statements after
        it go one level deeper until the next bullet;
      * any other statement is emitted at the current level.

    Bullets are treated as siblings. Previously each bullet indented further
    than the last, which nested the second induction case inside the first.
    This function cannot distinguish a genuinely nested case split from a
    sequential one.

    Args:
        lean_statements: List of Lean statements (one tactic/line each).
        indent: String used for one indentation level (default two spaces).

    Returns:
        str: The statements joined with newlines; "" for an empty list.
    """
    lines = []
    base_level = 0  # level of the tactic block (1 once a theorem header is seen)
    level = 0       # level for ordinary tactics

    for statement in lean_statements:
        if statement.startswith("theorem"):
            lines.append(statement)
            base_level = level = 1
        elif statement == "·":
            # Every case bullet returns to the tactic-block level.
            lines.append(f"{indent * base_level}{statement}")
            level = base_level + 1
        else:
            lines.append(f"{indent * level}{statement}")

    return "\n".join(lines)


# Natural-language proof scripts used as autoformalizer inputs. The comment
# above each step is the Lean tactic that step is meant to correspond to.

# succ_add : succ a + b = succ (a + b)
NL_Complete_succ_add = [
    # Lean: induction b with n hn
    "-- Initiate induction on b, the base case (b=0) succ(a) + 0 = succ(a + 0)",
    # Lean: rw [← add_succ]
    "-- We start by proving the base case using the fact that succ(a+b) = a + succ(b) and setting b = 0 and substituting on the RHS",
    # Lean: rw [add_zero]
    "-- We use the fact that c + 0 = c ∀ c ∈ ℕ and set c := succ(a) to get succ(a) = a + succ(0)",
    # Lean: rw [add_succ]
    "-- Now use the fact that a + succ(b) = succ(a+b) and set a := a and b := 0 to get a + succ(0) = succ(a+0) on the RHS",
    # Lean: rw [add_zero]
    "-- Now on the RHS we use the fact that c + 0 = c ∀ c ∈ ℕ and set c := a to get succ(a) = succ(a)",
    # Lean: rfl
    "-- since succ(a) = succ(a) we are done with the base case",
    # Lean: rw [add_succ]
    "-- Now to prove the induction case, we use the fact that a + succ(b) = succ(a + b) ∀ a, b ∈ ℕ and set a := succ(a) and b := n giving us succ(succ(a) + n) = succ(a+succ(n))",
    # Lean: rw [add_succ]
    "-- We again use the fact that a + succ(b) = succ(a + b) ∀ a, b ∈ ℕ on the right hand side and set a := a and b := n giving us succ(succ(a) + n) = succ(succ(a+n))",
    # Lean: rw [← hn]
    "-- Rewrite the right hand side using the hypothesis giving us succ(succ(a) + n) = succ(succ(a) + n)",
    "-- Hence we are done."
]

# theorem succ_add_logical_deviation_1 (a b : ℕ) : succ a + b = succ (a + b) := by
NL_Complete_succ_add_1 = [
    # Lean: induction b with n hn
    "-- Initiate induction on b, the base case (b=0) succ(a) + 0 = succ(a + 0)",
    # Lean: rw [add_zero]
    "-- We start by proving the base case using the fact that c + 0 = c ∀ c ∈ ℕ and setting c := succ(a) giving us succ(a) = succ(a + 0)",
    # Lean: rw [add_zero]
    "-- Now we can set c = a and use c + 0 = c ∀ c ∈ ℕ again to get succ(a) = succ(a)",
    # Lean: rfl
    "-- Since we have succ(a) = succ(a) we are done with the base case",
    # Lean: rw [add_succ]
    "-- Now to prove the induction case, we use the fact that a + succ(b) = succ(a + b) ∀ a, b ∈ ℕ and set a := succ(a) and b := n giving us succ(succ(a) + n) = succ(a+succ(n))",
    # Lean: rw [add_succ]
    "-- We again use the fact that a + succ(b) = succ(a + b) ∀ a, b ∈ ℕ on the right hand side and set a := a and b := n giving us succ(succ(a) + n) = succ(succ(a+n))",
    # Lean: rw [← hn]
    "-- Rewrite the right hand side using the hypothesis giving us succ(succ(a + n)) = succ(succ(a) + n)",
    # Lean: rfl
    "-- Hence we are done.",
]

# theorem succ_add_logical_deviation_2 (a b : ℕ) : succ a + b = succ (a + b)  := by
NL_Complete_succ_add_2 = [
    # Lean: induction b with n hn
    # NOTE(review): this step text is missing a closing ")" after "succ(a + 0";
    # kept verbatim — confirm whether that is intentional.
    "-- Initiate induction on b, the base case (b=0) succ(a) + 0 = succ(a + 0",
    # Lean: rw [← add_succ]
    "-- We start by proving the base case using the fact that succ(a+b) = a + succ(b) and setting b = 0 and substituting on the RHS",
    # Lean: rw [add_zero]
    "-- We use the fact that c + 0 = c ∀ c ∈ ℕ and set c := succ(a) to get succ(a) = a + succ(0)",
    # Lean: rw [add_succ]
    "-- Now use the fact that a + succ(b) = succ(a+b) and set a := a and b := 0 to get a + succ(0) = succ(a+0) on the RHS",
    # Lean: rw [add_zero]
    "-- Now on the RHS we use the fact that c + 0 = c ∀ c ∈ ℕ and set c := a to get succ(a) = succ(a)",
    # Lean: rfl
    "-- since succ(a) = succ(a) we are done with the base case",
    # Lean: rw [add_succ]
    "-- Now to prove the induction case, we use the fact that a + succ(b) = succ(a + b) ∀ a, b ∈ ℕ and set a := succ(a) and b := n giving us succ(succ(a) + n) = succ(a+succ(n))",
    # Lean: rw [add_succ]
    "-- We again use the fact that a + succ(b) = succ(a + b) ∀ a, b ∈ ℕ on the right hand side and set a := a and b := n giving us succ(succ(a) + n) = succ(succ(a+n))",
    # Lean: rw [← hn]
    "-- Rewrite the right hand side using the hypothesis giving us succ(succ(a) + n) = succ(succ(a) + n)",
    # Lean: rfl
    "-- Hence we are done.",
]

# Same proof as NL_Complete_succ_add, phrased in a terser "persona".
NL_Complete_succ_add_persona_2 = [
    # Lean: induction b with n hn
    "-- Initiate induction on b.",
    # Lean: rw [← add_succ]
    "-- We start by proving the base case using properties of succession, succ(a+0) = a + succ(0) on RHS",
    # Lean: rw [add_zero]
    "-- Now using properties of addition by 0, we can rewrite succ(a) + 0 to succ(a) on the LHS",
    # Lean: rw [add_succ]
    "-- Now using properties of succession, we can rewrite succ(a) + 0 to succ(a+0) on the RHS",
    # Lean: rw [add_zero]
    "-- Now using properties of addition by 0, we can rewrite a + 0 to a on the RHS",
    # Lean: rfl
    "-- since succ(a) = succ(a), we are done with the base case",
    # Lean: rw [add_succ]
    "-- Now to prove the induction case, we use properties of succession substituting succ(a) + succ(n) = succ(succ(a) + n) on LHS",
    # Lean: rw [add_succ]
    "-- Now again using properties of succession, we substitute succ(a + succ(n)) to succ(succ(a + n)) on the RHS",
    # Lean: rw [← hn]
    "-- Using the induction hypothesis giving us succ(succ(a) + n) = succ(succ(a) + n) on the LHS",
    # Lean: rfl
    "-- both sides are equal, hence we are done",
]

# zero_add : 0 + n = n
NL_easy = [
    # Lean: induction n with d hd
    "-- Induct on n, with d = 0 as the base case and the inductive hypothesis 0 + d = d. There are now two proof goals, prove base case: 0 + 0 = 0, and inductive step: 0 + succ (d) = succ (d)",
    # Lean: rw [add_zero]
    "-- First prove base case. Reduce LHS 0 + 0 = 0.",
    # Lean: rfl
    "-- Prove LHS and RHS are equal, 0 = 0, completing base case",
    # Lean: rw [add_succ]
    "-- Now prove inductive step. Rewrite 0 + succ d = succ (0 + d)",
    # Lean: rw [hd]
    "-- Simplify RHS succ (0 + d) = succ(d) using the inductive hypothesis.",
    # Lean: rfl
    "-- Prove LHS and RHS are equal, succ(d) = succ(d), completing the proof",
]

# A deliberately flawed argument for a + b = b + a (no tactic mapping).
NL_wrong = [
    "-- Initiate induction on a",
    "-- We start by proving the base case we use properties of addtion by 0 to to simplify LHS to b",
    "-- Now we can use properties of addition by 0 to simplify the RHS to b",
    "-- since b = b we are done with the base case",
    "-- Now to prove the induction case, we use successor properties to make LHS = succ(n + b)",
    "-- Now we can use successor properties to make RHS = succ(b + n)",
    "-- We make RHS = b + succ(n)",
    "-- We make LHS = succ(n) + b",
    "-- cancel out the b's to get succ(n) = succ(n) which is true",
    "-- Hence we are done."
]

# autoformalize_induction_proof_whole_addition(NL_Complete_succ_add_1, "theorem succ_add_logical_deviation_1 (a b : ℕ) : succ a + b = succ (a + b) := by")
# autoformalize_induction_proof_whole_addition(NL_Complete_succ_add_2, "theorem succ_add_logical_deviation_2 (a b : ℕ) : succ a + b = succ (a + b) := by")
# a, b = autoformalize_induction_proof_whole_addition(NL_Complete_succ_add_persona_2, "theorem succ_add_logical_deviation_2 (a b : ℕ) : succ a + b = succ (a + b) := by")

#a, b = autoformalize_induction_proof_whole_addition(NL_easy, "theorem zero_add_test (n : ℕ) : 0 + n = n := by")

# autoformalize_induction_proof_whole_addition(NL_Complete_list_3, "theorem add_comm_logically_eq (a b : ℕ) : a + b = b + a := by")

# a, b = autoformalize_induction_proof_whole_addition(NL_wrong, "theorem wrong_add (a b : ℕ) : a + b = b + a := by")



In [15]:
# `a, b` are produced by an autoformalize_induction_proof_whole_addition(...)
# call in the cell above, which is currently commented out. Guard so a fresh
# "Restart & Run All" does not stop here with a NameError.
if "a" in globals() and "b" in globals():
    print(a, b)
else:
    print("`a, b` are undefined: uncomment and run an autoformalize_induction_proof_whole_addition(...) call above first.")

True ['induction b with d hd', 'rw [← add_succ a 0]', 'rw [add_zero]', 'rw [add_succ]', 'rw [add_zero a]', 'rfl', 'rw [add_succ]', 'rw [add_succ]', 'rw [hd]', 'rfl']


In [None]:
## TESTING USING STATE

In [64]:
def _proof_header(theorem_statement):
    """Return `theorem_statement` followed by exactly one ':= by' and a newline (any trailing ':= by' is removed first)."""
    stmt = re.sub(r":=\s*by\s*$", "", theorem_statement.rstrip()).rstrip()
    return f"{stmt} := by\n"


def test_autoformalizer(autoformalizer_fn, theorem_statement, nl_statements, true_lean_statements):
    """
    Test an autoformalizer by comparing Lean proof states after each tactic.

    The autoformalizer runs once on `nl_statements`. Then, for every step
    index both proofs have:
      1. the true and predicted tactic prefixes are compiled with `lean_compile`;
      2. the unsolved goals are extracted with `parse_unsolved_state`;
      3. the two goal lists are compared with `are_states_equivalent`.
    Finally both complete proofs are compiled and their final states
    compared. `lean_compile` is called with file names true_step_<i>,
    pred_step_<i>, true_final and pred_final.

    Args:
        autoformalizer_fn: Callable `(nl_statements, theorem_statement)`
            returning `(is_complete, predicted_lean_statements)`.
        theorem_statement: Theorem header to prove. A trailing ":= by" is
            optional. The test cases in this notebook include it, and
            appending a second one produced "... := by := by", which Lean
            rejects, so every compared state was an error.
        nl_statements: Natural-language proof steps.
        true_lean_statements: Reference Lean tactics, one per step.

    Returns:
        dict: The result has these keys:
            - "is_complete", "total_steps", "predicted_steps", "state_matches";
            - "detailed_comparison": one entry per compared step;
            - "success_rate": state_matches / total_steps, or 0 with no
              reference steps;
            - when both statement lists are non-empty, also
              "final_states_match", "true_final_state" and "pred_final_state".
    """

    print(nl_statements, theorem_statement)
    # Run the autoformalizer to get predicted Lean statements
    is_complete, predicted_lean_statements = autoformalizer_fn(nl_statements, theorem_statement)

    # Initialize results
    results = {
        "is_complete": is_complete,
        "total_steps": len(true_lean_statements),
        "predicted_steps": len(predicted_lean_statements),
        "state_matches": 0,
        "detailed_comparison": []
    }

    # Build up the proof step by step for both true and predicted statements,
    # starting from a header with exactly one ":= by".
    header = _proof_header(theorem_statement)
    true_proof = header
    pred_proof = header

    # Compare states after each step both proofs have
    max_steps = min(len(true_lean_statements), len(predicted_lean_statements))

    for i in range(max_steps):
        # Add the next true statement and compile
        true_proof += f"   {true_lean_statements[i]}\n"
        true_result = lean_compile(true_proof, "", f"true_step_{i}", verbose=False)
        true_state = parse_unsolved_state(true_result)

        # Add the next predicted statement and compile
        pred_proof += f"   {predicted_lean_statements[i]}\n"
        pred_result = lean_compile(pred_proof, "", f"pred_step_{i}", verbose=False)
        pred_state = parse_unsolved_state(pred_result)

        # Compare the states
        states_match = are_states_equivalent(" ".join(true_state), " ".join(pred_state))

        # Record the comparison
        results["detailed_comparison"].append({
            "step": i + 1,
            "true_statement": true_lean_statements[i],
            "predicted_statement": predicted_lean_statements[i],
            "states_match": states_match,
            "true_state": true_state,
            "predicted_state": pred_state
        })

        if states_match:
            results["state_matches"] += 1

    # Calculate success rate (relative to the number of reference steps)
    results["success_rate"] = results["state_matches"] / results["total_steps"] if results["total_steps"] > 0 else 0

    # Check if the final states match (proof completion)
    if len(true_lean_statements) > 0 and len(predicted_lean_statements) > 0:
        # Compile the complete proofs
        final_true_proof = header + "\n".join([f"   {stmt}" for stmt in true_lean_statements])
        final_pred_proof = header + "\n".join([f"   {stmt}" for stmt in predicted_lean_statements])

        true_final_result = lean_compile(final_true_proof, "", "true_final", verbose=False)
        pred_final_result = lean_compile(final_pred_proof, "", "pred_final", verbose=False)

        true_final_state = parse_unsolved_state(true_final_result)
        pred_final_state = parse_unsolved_state(pred_final_result)

        results["final_states_match"] = are_states_equivalent(" ".join(true_final_state), " ".join(pred_final_state))
        results["true_final_state"] = true_final_state
        results["pred_final_state"] = pred_final_state

    return results

def run_test_suite(autoformalizer_fn, test_cases):
    """
    Run `test_autoformalizer` on every test case and aggregate the results.

    Args:
        autoformalizer_fn: Autoformalizer under test, forwarded unchanged.
        test_cases: Sequence of dicts with keys "theorem_statement",
            "nl_statements", "true_lean_statements" and optionally
            "theorem_name".

    Returns:
        dict: "total_tests", "successful_tests" (cases whose final states
        match), "average_success_rate" (mean per-case success_rate, 0 when
        there are no cases) and "test_results" (per-case result dicts, each
        tagged with "test_case_id" and "theorem_name").
    """
    n_cases = len(test_cases)
    summary = {
        "total_tests": n_cases,
        "successful_tests": 0,
        "average_success_rate": 0,
        "test_results": [],
    }
    rate_sum = 0

    for case_no, case in enumerate(test_cases, start=1):
        print(f"\n=== Running Test Case {case_no}/{n_cases} ===")

        outcome = test_autoformalizer(
            autoformalizer_fn,
            case["theorem_statement"],
            case["nl_statements"],
            case["true_lean_statements"],
        )
        outcome["test_case_id"] = case_no
        outcome["theorem_name"] = case.get("theorem_name", f"Test Case {case_no}")

        # A case counts as successful only when the final proof states agree.
        final_ok = outcome.get("final_states_match", False)
        if final_ok:
            summary["successful_tests"] += 1

        rate_sum += outcome["success_rate"]
        summary["test_results"].append(outcome)

        print(f"Test {case_no} - {outcome['theorem_name']}:")
        print(f"  Success Rate: {outcome['success_rate'] * 100:.2f}%")
        print(f"  State Matches: {outcome['state_matches']}/{outcome['total_steps']}")
        print(f"  Final States Match: {final_ok}")

    summary["average_success_rate"] = rate_sum / n_cases if test_cases else 0

    print("\n=== Overall Test Results ===")
    print(f"Total Tests: {summary['total_tests']}")
    print(f"Successful Tests: {summary['successful_tests']}")
    print(f"Average Success Rate: {summary['average_success_rate'] * 100:.2f}%")

    return summary




In [65]:
# Each case: the theorem to prove, the natural-language proof steps given to
# the autoformalizer, and the reference Lean tactics (one per NL step).
test_cases = [
    dict(
        theorem_name="Addition Commutativity",
        theorem_statement="theorem add_comm_random (a b : ℕ) : a + b = b + a := by",
        nl_statements=[
            "We'll prove this by induction on a",
            "By the definition of addition, 0 + b = b",
            "By the definition of addition, b + 0 = b",
            "Therefore, we are done with the base case",
            "By the definition of addition, succ(n) + b = succ(n+b) on the LHS",
            "By the definition of addition again, b + succ(n) = succ(b+n) on the RHS",
            "By the induction hypothesis, n + b = b + n",
            "Hence, we are done with the inductive case",
        ],
        true_lean_statements=[
            "induction a with d hd",
            "rw [zero_add]",
            "rw [add_zero]",
            "rfl",
            "rw [succ_add]",
            "rw [add_succ]",
            "rw [hd]",
            "rfl",
        ],
    ),
    # More test cases...
]

# Run every case through the autoformalizer and compare against the reference proofs.
results = run_test_suite(autoformalize_induction_proof_whole_addition, test_cases)


=== Running Test Case 1/1 ===
["We'll prove this by induction on a", 'By the definition of addition, 0 + b = b', 'By the definition of addition, b + 0 = b', 'Therefore, we are done with the base case', 'By the definition of addition, succ(n) + b = succ(n+b) on the LHS', 'By the definition of addition again, b + succ(n) = succ(b+n) on the RHS', 'By the induction hypothesis, n + b = b + n', 'Hence, we are done with the inductive case'] theorem add_comm_random (a b : ℕ) : a + b = b + a := by
theorem add_comm_random (a b : ℕ) : a + b = b + a := by
Starting proof of: theorem add_comm_random (a b : ℕ) : a + b = b + a

Processing statement: We'll prove this by induction on a
OpenAI suggests block type: induction
OpenAI suggests induction variable: a
GLOBAL LOCATION  ['case zero', 'b : ℕ', '⊢ 0 + b = b + 0', 'case succ', 'b d : ℕ', 'hd : d + b = b + d', '⊢ succ d + b = b + succ d']
BASE CASE COMPILER  ['case zero', 'b : ℕ', '⊢ 0 + b = b + 0', 'case succ', 'b d : ℕ', 'hd : d + b = b + d', '⊢ s

## TESTING BEST SYSTEM ON 14 theorems

In [37]:
import json
# load json
def output_checker(list, true_list):
    """Count the positions where a generated Lean proof matches the reference.

    Lines are compared pairwise by position. Each mismatching pair is printed
    with its index so it can be inspected.

    Args:
        list: Generated Lean tactic lines. ``None`` is treated as an empty
            proof. The parameter name is kept for backward compatibility with
            existing keyword callers, even though it shadows the builtin.
        true_list: Reference Lean tactic lines.

    Returns:
        The number of positions whose lines are identical. Lines beyond the end
        of the shorter sequence cannot match and are not counted.
    """
    # Guard against a failed generation returning None instead of a list.
    predicted = [] if list is None else list
    expected = [] if true_list is None else true_list

    counter = 0
    # zip stops at the shorter sequence. The previous index loop raised
    # IndexError whenever the prediction was longer than the reference.
    for i, (pred_line, true_line) in enumerate(zip(predicted, expected)):
        if pred_line != true_line:
            print(f"Index: {i} \n{pred_line} \n{true_line}")
        else:
            counter += 1

    # Make missing or extra lines visible instead of silently ignoring them.
    if len(predicted) != len(expected):
        print(f"Length mismatch: got {len(predicted)} lines, expected {len(expected)}")
    return counter


# Load the evaluation set. Each entry carries a theorem statement, its
# natural-language proof steps (`theorem_NLs`), and the reference Lean tactic
# lines (`theorem_FLs`).
THEOREMS_PATH = '../Datasets/o1_test/test_2.json'

# The file contains non-ASCII Lean symbols (e.g. ℕ, ⊢). Read it as UTF-8
# explicitly rather than relying on the platform's default locale encoding.
with open(THEOREMS_PATH, 'r', encoding='utf-8') as f:
    theorems = json.load(f)

# Echo the dataset after the file handle is closed.
for theorem in theorems:
    print(f"Theorem: {theorem['theorem_statement']}")
    print(f"NLs: {theorem['theorem_NLs']}")
    print(f"FLs: {theorem['theorem_FLs']}")
    print()
print(len(theorems))

Theorem: theorem random_theorem (n : ℕ) : 0 + n = n := by
NLs: ['-- Induct on n', '-- substitute 0 -> 0 + 0 into the RHS giving us 0 + 0 = 0 + 0', '-- 0 + 0 = 0 + 0, completing base case', '-- 0 + succ d -> succ (0 + d) on LHS giving us succ (0 + d) = succ d', '-- 0 + d -> d on LHS -> succ d = succ d', '-- succ d = succ d, QED']
FLs: ['induction n with d hd', 'nth_rewrite 3 [← add_zero 0]', 'rfl', 'rw [add_succ]', 'rw [hd]', 'rfl']

Theorem: theorem random_theorem (a b c : ℕ) : a + b + c = a + c + b := by
NLs: ['-- Apply the associative property of addition to rewrite the LHS of the equation, changing a + b + c to a + (b + c)', '-- Rewrite the LHS of the equation by applying the commutative property of addition to b and c, LHS is now a + (c + b)', '-- Rewrite the RHS using the associative property: a + c + b to a + (c + b).', '-- Prove LHS and RHS are equal, a + (c + b) = a + (c + b), completing the proof']
FLs: ['rw [add_assoc]', 'rw [add_comm b]', 'rw [add_assoc]', 'rfl']

Theorem: t

In [38]:
# Count the reference Lean tactic lines (`theorem_FLs`) per theorem and in total.
# Line-level scoring compares generated tactics against `theorem_FLs`, so the
# total must count those lines, not the natural-language steps.
count_thm_lines = 0
for thm in theorems:
    n_lean_lines = len(thm['theorem_FLs'])
    print(n_lean_lines)
    count_thm_lines += n_lean_lines
print(count_thm_lines)

6
4
4
4
4
4
26


In [41]:
# Evaluate the autoformalizer and score it at two granularities:
#   * theorems it reports as correct (`is_correct`), and
#   * generated tactic lines identical to the reference proof (`theorem_FLs`).
# NOTE: only the last theorem is evaluated (debugging subset). Set
# `theorems_to_eval = theorems` for the full evaluation.
theorems_to_eval = theorems[-1:]

total_thm_counter = 0
thm_lines_counter = 0
for theorem in theorems_to_eval:
    statement = theorem['theorem_statement'].strip()
    is_correct, lean_output = autoformalize_induction_proof_whole_addition(theorem['theorem_NLs'], statement)
    total_thm_counter += is_correct
    thm_lines_counter += output_checker(lean_output, theorem['theorem_FLs'])
    print(thm_lines_counter)  # running total of matched lines

# Denominators cover only the theorems actually evaluated. Reporting against
# the whole dataset would understate accuracy when running a subset.
eval_line_total = sum(len(t['theorem_FLs']) for t in theorems_to_eval)
print(f"Total theorems: {len(theorems_to_eval)} \nTotal theorems correct: {total_thm_counter}")
print(f"Total lines correct: {thm_lines_counter} \nTotal lines: {eval_line_total}")


theorem random_theorem (a b c : ℕ) : a + b + c = a + c + b := by
Starting proof of: theorem random_theorem (a b c : ℕ) : a + b + c = a + c + b

Processing statement: -- Apply the associative property of addition to rewrite the LHS: a + b + c to a + (b + c).
OpenAI suggests block type: rewrite

Processing statement: -- Apply the associative property of addition to rewrite the LHS: a + b + c to a + (b + c).
CURRENT STATE ⊢ a + b + c = a + c + b
Predicted next state: ⊢ a + (b + c) = a + c + b
Attempt 1/3 to find a theorem
Explanation: We apply associativity in the forward direction to rewrite (a + b) + c to a + (b + c).
OpenAI suggests theorem: add_assoc (backarrow: False, inputs: [], nth: None)
TEMP PROOF  theorem random_theorem (a b c : ℕ) : a + b + c = a + c + b := by
   rw [add_assoc]

THEOREM  add_assoc
GLOBAL LOCATION  ['a b c : ℕ', '⊢ a + (b + c) = a + c + b']
THE STATES ARE  ⊢ a + (b + c) = a + c + b ⊢ a + (b + c) = a + c + b
Theorem add_assoc produces the expected state!
GLOBAL L

## End of testing