In [121]:
from pantograph import Server
from lean_interact import *  # noqa: F403
from lean_interact.interface import LeanError, CommandResponse, Message
import litellm
import re
from dotenv import load_dotenv
import os

# Load .env file - call this FIRST, before accessing env vars
load_dotenv()
api_key = os.getenv('ANTHROPIC_API_KEY')
if api_key is None:
    raise Exception("Failed to load env!")

In [None]:
#how to create pantograph server
#server = await Server.create(imports=['Init', 'Mathlib'], project_path=".")

In [None]:
project = LocalProject(directory=".")
config = LeanREPLConfig(project=project, memory_hard_limit_mb = 8192)
server = LeanServer(config)
imports = ["Mathlib", "LeanSpecproof.Verification"]
def to_import_string(imports: list[str]):
    return "\n".join("import " + i for i in imports)
import_string = to_import_string(imports)

# this takes 1 to 2 minutes
pmath = server.run(Command(cmd=import_string))
server.run(PickleEnvironment(env=pmath.env, pickle_to="clean.olean"))

Lake version 5.0.0-src+d8204c9 (Lean version 4.26.0)
Current branch: HEAD
Using cache (Azure) from origin: leanprover-community/mathlib4
No files to download
Decompressing 7727 file(s)
Unpacked in 18613 ms
Completed successfully!
✔ [7743/7745] Built LeanSpecproof.Verification (37s)
✔ [7744/7745] Built LeanSpecproof (6.9s)
Build completed successfully (7745 jobs).


CommandResponse(env=0)

In [152]:
SYSTEM_MESSAGE = """You are a Lean 4 theorem proving assistant. Your goal is to prove theorems while respecting user-specified constraints.

<capabilities>
- You can write Lean 4 code with tactics and proofs
- Assume Mathlib is imported, no importing anything else
- You can use `sorry` temporarily to see proof states at that point
- You can run Lean commands like `#check`, `#eval` to inspect definitions
- You will receive feedback after each attempt showing:
  * Compilation errors (if any)
  * Proof states at each `sorry`
  * Output from `#check` and other commands
  * Verification results (whether constraints were violated)
</capabilities>

<constraints>
- You MUST NOT leave any `sorry` in your final proof
- You MUST follow the forbidden pattern constraint provided by the user
- You have {max_attempts} attempts to produce a valid proof
</constraints>

<output_format>
Structure your response using these XML tags:

<THOUGHTS>
Your reasoning about the problem and proof strategy.
Explain which theorems/tactics you plan to use and why they don't violate constraints.
</THOUGHTS>

<SOLUTION>
theorem challenge ... := by
  -- your proof here
#check useful_lemma  -- optional: check available theorems
</SOLUTION>

If you encounter an issue you cannot resolve (e.g., malformed input, impossible constraint):
<ERROR>
Brief explanation of the issue
</ERROR>
</output_format>

<example>
User gives: "Prove ∀ n : Nat, n + 0 = n without using theorems matching 'add_zero'"

Good response:
<THOUGHTS>
I need to prove n + 0 = n. The forbidden pattern is 'add_zero', so I can't use Nat.add_zero or similar.
I'll use induction on n. Base case: 0 + 0 = 0 follows from definition.
Inductive case: (n+1) + 0 = (n + 0) + 1 by definition, then use IH.
</THOUGHTS>

<SOLUTION>
theorem challenge (n : Nat) : n + 0 = n := by
  induction n with
  | zero => rfl
  | succ n ih => 
    simp [Nat.add_succ]
    exact ih
#check Nat.add_succ
</SOLUTION>
</example>

Remember: You can use `sorry` while drafting to see proof states, but your final submission must have NO sorry statements."""
def create_user_message(theorem_code: str, forbidden_pattern: str, attempt: int, max_attempts: int) -> str:
    return f"""<task>
Prove the following theorem WITHOUT using any theorem whose name contains "{forbidden_pattern}".

This is attempt {attempt}/{max_attempts}.

{theorem_code}
</task>"""


In [153]:
def format_lean_output(response: CommandResponse | LeanError, base_code: str, header_lines: int) -> str:
    if isinstance(response, LeanError):
        return f"<fatal>\n Fatal Lean Error: {response.message}\n</fatal>"
    parts = []
    errors: list[Message] = []
    warnings: list[Message] = []
    infos: list[Message] = []
    lines = base_code.split("\n")
    for msg in response.messages:
        if msg.severity == 'error':
            errors.append(msg)
        if msg.severity == "info":
            infos.append(msg)
        if msg.severity == "warning":
            warnings.append(msg)
    def format_sl(msg: Message):
        ln = msg.start_pos.line - header_lines
        if ln >= len(lines):
            return [msg.data]
        else:
            return [f"Line {ln}: " +
                    f"{lines[ln - 1].lstrip()} →",
                    msg.data]
    # Extract errors
    if errors:
        parts.append("<errors>")
        for err in errors:
            parts += format_sl(err)
        parts.append("</errors>")
    if warnings:
        parts.append("<warnings>")
        for warning in warnings:
            parts += format_sl(warning)
        parts.append("</warnings>")
    if infos:
        parts.append("<info>")
        for info in infos:
            parts += format_sl(info)
        parts.append("</info>")
        
    
    # Extract sorry states
    if response.sorries:
        parts.append("<proof_states>")
        for i, sorry in enumerate(response.sorries, 1):
            parts.append(f"Sorry #{i} at line {sorry.start_pos.line - header_lines}:")
            parts.append(sorry.goal)
        parts.append("</proof_states>")
    
    return "\n".join(parts)

In [None]:
async def prove_theorem_loop(
    theorem_code: str,
    forbidden_pattern: str,
    max_attempts: int = 5,
    model: str = "anthropic/claude-sonnet-4-20250514",
    imports: list[str] = imports,
    verbose: bool = False
) -> dict:
    """
    Main loop for theorem proving with Claude
    
    Returns dict with:
        - success: bool
        - final_code: str (if success)
        - attempt_count: int
        - messages: list (full conversation history)
    """

    # Initialize messages
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE.format(max_attempts=max_attempts)},
        {"role": "user", "content": create_user_message(theorem_code, forbidden_pattern, 1, max_attempts)}
    ]
    
    # Header for Lean code (import statements)
    
    for attempt in range(1, max_attempts + 1):
        if (verbose):
            print(f"\n=== Attempt {attempt}/{max_attempts} ===")
        
        # Get Claude's response
        response = await litellm.acompletion(
            model=model,
            messages=messages,
            temperature=0.1
        )
        assistant_msg = response.choices[0].message.content
        messages.append({"role": "assistant", "content": assistant_msg})
        
        if (verbose):
            print(f"Claude response:\n{assistant_msg}\n")
        
        # Check if Claude gave up
        if "<ERROR>" in assistant_msg:
            return {
                "success": False,
                "error": "Claude reported an error",
                "attempt_count": attempt,
                "messages": messages
            }
        
        # Extract code from <SOLUTION> tags
        solution_match = re.search(r'<SOLUTION>\s*(.*?)\s*</SOLUTION>', assistant_msg, re.DOTALL)
        if not solution_match:
            feedback = "<feedback>\n<error>No <SOLUTION> tag found in response</error>\n</feedback>"
            messages.append({"role": "user", "content": feedback})
            continue
        
        lean_code = solution_match.group(1).strip()
        final_code = (to_import_string(imports) + "\n" + lean_code + "\n" +
                      f"#verify_solution challenge forbids_pattern {forbidden_pattern}")
        
        # Run through Lean
        lean_response = server.run(Command(cmd=final_code))
        
        # Format the output
        formatted_output = format_lean_output(lean_response, lean_code, len(imports))
        
        if verbose:
            print(f"Lean output:\n{formatted_output}\n")
        
        # Check for success
        valid = lean_response.lean_code_is_valid(allow_sorry=False)
        
        if valid:
            # Success!
            return {
                "success": True,
                "final_code": lean_code,
                "attempt_count": attempt,
                "messages": messages
            }
        
        # Prepare feedback for next iteration
        feedback = f"<feedback>\nAttempt {attempt}/{max_attempts} failed.\n{formatted_output}\n</feedback>"
        messages.append({"role": "user", "content": feedback})
    
    # All attempts exhausted
    return {
        "success": False,
        "error": "Maximum attempts reached",
        "final_code": lean_code,
        "attempt_count": max_attempts,
        "messages": messages
    }


In [155]:
ans = await prove_theorem_loop(theorem_code = "∀ {a b c : ℕ}, a + (b + c) = (a + b) + c",
                         forbidden_pattern = "add_assoc",
                         verbose = True)


=== Attempt 1/5 ===
Claude response:
<THOUGHTS>
I need to prove the associativity of addition for natural numbers without using any theorem containing "add_assoc". This is a fundamental property that I'll need to prove from first principles.

I'll use induction on one of the variables. Let me try induction on `a`. The base case will be `0 + (b + c) = (0 + b) + c`, and the inductive step will be `(a + 1) + (b + c) = ((a + 1) + b) + c`.

For natural numbers, addition is defined recursively:
- `0 + n = n`
- `(m + 1) + n = (m + n) + 1`

I should be able to use `Nat.add_succ` and `Nat.zero_add` and similar basic properties, as long as they don't contain "add_assoc" in their names.

Let me start with induction on `a`.
</THOUGHTS>

<SOLUTION>
theorem challenge {a b c : ℕ} : a + (b + c) = (a + b) + c := by
  induction a with
  | zero => 
    sorry
  | succ a ih => 
    sorry

#check Nat.zero_add
#check Nat.add_succ
#check Nat.succ_add
</SOLUTION>

Lean output:
Line 1: theorem challenge {a b c

In [158]:
print(ans["final_code"])

theorem challenge {a b c : ℕ} : a + (b + c) = (a + b) + c := by
  induction a with
  | zero => 
    rw [Nat.zero_add, Nat.zero_add]
  | succ a ih => 
    rw [Nat.succ_add, Nat.succ_add, ih, Nat.succ_add]

#check Nat.zero_add
#check Nat.add_succ
#check Nat.succ_add


In [144]:
print(ans["messages"][3]["content"])

<feedback>
Attempt 1/5 failed.
<errors>
Line -1: #check Nat.succ_add →
unexpected identifier; expected command
Line 1: theorem challenge {a b c : ℕ} : a + (b + c) = (a + b) + c := by →
failed to synthesize
  HAdd ℕ ℕ ?m.4

Hint: Additional diagnostic information may be available using the `set_option diagnostics true` command.
Line 2: induction a with →
Tactic `induction` failed: major premise type is not an inductive type
  ℕ

Explanation: the `induction` tactic is for constructor-based reasoning as well as for applying custom induction principles with a 'using' clause or a registered '@[induction_eliminator]' theorem. The above type neither is an inductive type nor has a registered theorem.

ℕ : Type u_1
a b c : ℕ
⊢ sorry
unexpected token '#'; expected command
</errors>
<info>
Line 9: #check Nat.succ_add →
Nat.succ_add (n m : Nat) : n.succ + m = (n + m).succ
Nat.zero_add (n : Nat) : 0 + n = n
</info>
</feedback>


In [145]:
solution_match = re.search(r'<SOLUTION>\s*(.*?)\s*</SOLUTION>', ans["messages"][2]["content"], re.DOTALL)

lean_code = solution_match.group(1).strip()
final_code = ("\n".join(imports) + "\n" + lean_code + "\n" +
                f"#verify_solution challenge forbids_pattern add_assoc")

In [146]:
final_code

'Mathlib\nLeanSpecproof.Verification\ntheorem challenge {a b c : ℕ} : a + (b + c) = (a + b) + c := by\n  induction a with\n  | zero => \n    simp [Nat.zero_add]\n  | succ a ih =>\n    rw [Nat.succ_add, Nat.succ_add, Nat.succ_add]\n    rw [ih]\n\n#check Nat.succ_add\n#check Nat.zero_add\n#verify_solution challenge forbids_pattern add_assoc'

In [150]:
to_import_string

<function __main__.to_import_string(imports: list[str])>