#### Neural next-step prediction | part 5: theorem proving with context
Tutorial on neural theorem proving\
Author: Sean Welleck

----------------

#### High-level goal

In this notebook, we will train and evaluate theorem provers that use additional context, $p_\theta(y_t|x_t,c_t)$, such as the preceding file contents. \
This will enable support for several real-world scenarios that we discuss next.

*Historical note*: the setting described here was first investigated in [LLMstep: proofstep suggestions in Lean](https://mathai2023.github.io/papers/40.pdf).

### Motivation

Real-world proofs often use theorems and definitions that are unique to the proof development. A key
limitation of the models that we have considered is that they lack context beyond the current proof state. That
is, the model $p_\theta(y_t|x_t)$ receives only the current proof state $x_t$ (transformed into a prompt) as input. 

As a result, the model cannot use newly defined theorems, definitions, and other information from the current proof development
unless the proof development was seen during training. To see this, we will look at a toy example, followed by a real-world example.

#### Toy example

The situation below shows the need for incorporating additional context $c$ beyond the proof state.

Namely, proving the `my_object_sum_nonneg` theorem requires using properties (in pink) of the newly defined `my_object` (in green):

<img src="images/context_1.png" width=500px>


The next-step predictors we have trained will almost certainly fail unless they have seen `my_object_sum_nonneg` in their training set. Not surprisingly, our existing model fails to prove this theorem:


In [1]:
# Set this to the path of the Lean REPL on your machine.
PATH_TO_REPL = '/Users/wellecks/projects/ntptutorial/partI_nextstep/ntp-interact/repl'

import sys
sys.path.append('../ntp-interact/')

import proofsearch
import os
import transformers
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # prevents an annoying warning

In [2]:

MODEL = 'l3lab/ntp-mathlib-st-deepseek-coder-1.3b'
model, tokenizer = proofsearch.load_model(MODEL)


header = """import Mathlib

open BigOperators

variable {Ω : Type*}[Fintype Ω]

structure my_object (Ω : Type*)[Fintype Ω] :=
  (f : Ω → ℝ)
  (cool_property : ∀ x : Ω, 0 ≤ f x)

"""
transformers.set_seed(40)

theorem_statement = """theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by {}"""
proofsearch.best_first_search(
    model, tokenizer, header, theorem_statement, 
    max_iters=16,
    num_samples=8,
    temperatures=[0.5],
    verbose=True,
    path_to_repl=PATH_TO_REPL
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	


100%|██████████| 8/8 [00:14<00:00,  1.87s/it]


--- type-checked candidates:
	(-0.384) rw [add_comm]
	(-0.473) simp
--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	rw [add_comm]


100%|██████████| 8/8 [00:08<00:00,  1.00s/it]


--- type-checked candidates:
	(-0.489) simp
--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	simp


100%|██████████| 5/5 [00:03<00:00,  1.39it/s]


--- type-checked candidates:
	(-0.373) apply add_nonneg
--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	simp
	apply add_nonneg


100%|██████████| 6/6 [00:03<00:00,  1.55it/s]


--- type-checked candidates:
	
--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	rw [add_comm]
	simp


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


--- type-checked candidates:
	(-0.356) rw [add_comm]
	(-0.363) apply add_nonneg
--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	rw [add_comm]
	simp
	apply add_nonneg


100%|██████████| 8/8 [00:14<00:00,  1.75s/it]

--- type-checked candidates:
	





{'theorem_statement': 'theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by {}',
 'success': False}

#### Example from Terence Tao's PFR project

Analogous situations occur in almost every real-world theorem proving scenario.

For instance, consider the beginning of this file from Terence Tao's formalization of the [Polynomial Freiman-Ruzsa Conjecture](https://github.com/teorth/pfr/blob/master/PFR) (PFR):
- [Entropy/Basic.lean](https://github.com/teorth/pfr/blob/master/PFR/ForMathlib/Entropy/Basic.lean)

It has a new definition (green), new notation (orange), and new theorems/lemmas (pink) that are used frequently:

<img src="images/context_2.png" width="650px">

*Exercise*: look through other files in [PFR](https://github.com/teorth/pfr/blob/master/PFR), and consider the following questions:
1. How often do theorems use definitions that are newly defined in PFR (e.g., in the current file or other files)?

2. How often do proofs use new definitions, lemmas, or theorems that are defined in PFR?

3. Suppose that PFR is not in the language model's training or fine-tuning set. \
What are the implications of (1) and (2) for using the language model to suggest proofs in PFR?

4. Navigate to the bottom of some files. \
Can you find a theorem or proof that depends on multiple pieces of context? \
Can you find a theorem or proof that depends on context that is "far away" from the theorem declaration?

5. Suppose that the language model was finetuned with a context length of around 4000 tokens. \
What are the implications of (4) for using the language model to suggest proofs in PFR?
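To get a rough, single-file answer to questions 1, 2, and 4, one could scan a Lean source file for declarations and count how often each declared name reappears later in that file. The sketch below is a crude regex heuristic of our own (not part of the tutorial's tooling): `local_usage_counts` is a hypothetical helper that only recognizes `def`, `theorem`, `lemma`, and `abbrev` at the start of a line, ignores namespaces and modifiers, and counts plain textual occurrences.

```python
import re

# Heuristic: a declaration keyword at the start of a line, followed by its name.
# Ignores modifiers (`private`, `@[simp]`), namespaces, `structure`, and `instance`.
DECL_RE = re.compile(r"^(?:theorem|lemma|def|abbrev)\s+([^\s:({\[]+)", re.MULTILINE)


def local_usage_counts(source: str) -> dict:
    """Map each locally declared name to its number of textual uses after its declaration line."""
    counts = {}
    for match in DECL_RE.finditer(source):
        name = match.group(1)
        # Only look at text after the line that declares `name`.
        line_end = source.find("\n", match.end())
        rest = source[line_end:] if line_end != -1 else ""
        # Require that the name is not part of a longer identifier.
        pattern = r"(?<![\w.'])" + re.escape(name) + r"(?![\w'])"
        counts[name] = len(re.findall(pattern, rest))
    return counts


example_src = """def Superpowered (k : ℕ) : Prop := ∀ n : ℕ, Prime (k ^ k ^ n + 1)

theorem not_superpowered_zero : ¬ Superpowered 0 := by
  intro h
  simp [Superpowered] at h
"""
print(local_usage_counts(example_src))
# {'Superpowered': 2, 'not_superpowered_zero': 0}
```

Running something like this over every file in PFR (and tracking uses across files) would give a first estimate of how much of each proof depends on project-local names.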

### Training a theorem prover with context

To this end, we will train a model of the form:

\begin{align}
p_\theta(y_t|x_t,c_t),
\end{align}

where $x_t$ is the proof state, $c_t$ is the preceding file contents up to the current tactic, and $y_t$ is a next tactic.
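Concretely, $c_t$ and $x_t$ are serialized into a single text prompt, and the model generates $y_t$ as a continuation. The sketch below reconstructs the prompt layout from the example printed in the next cell; `build_context_prompt` is a hypothetical helper, not the function in `instruction_tuning.py`, and the exact whitespace used by the real script may differ.

```python
# Instruction text copied from the prompt printed in the next cell.
INSTRUCTION = """/- You are proving a theorem in Lean 4.
You are given the following information:
- The file contents up to the current tactic, inside [CTX]...[/CTX]
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/"""


def build_context_prompt(context: str, state: str) -> str:
    """Hypothetical helper: format (c_t, x_t) as a prompt; the model then generates y_t."""
    return (
        f"{INSTRUCTION}\n"
        f"[CTX]\n{context}\n[/CTX]\n"
        f"[STATE]\n{state}\n[/STATE]\n"
        "[TAC]\n"
    )


context = "import Mathlib.Data.Nat.Prime\n\ntheorem test_thm (m n : Nat) (h : m.Coprime n) : m.gcd n = 1 := by\n  "
state = "m n : ℕ\nh : Nat.Coprime m n\n⊢ Nat.gcd m n = 1"
print(build_context_prompt(context, state))
```

In the printed example, the training target (completion) is the tactic followed by `[/TAC]`, so the model also learns where to stop.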

Doing so amounts to generating new instruction-tuning data whose prompts include the context $c_t$ alongside the proof state:

In [2]:
import sys
import jsonlines
sys.path.append('../ntp-training-data/scripts')

from instruction_tuning import prompt_context_state_tactic

# Read the first training example; the context manager closes the file afterwards.
with jsonlines.open('../ntp-training-data/example0.jsonl') as reader:
    example = next(iter(reader))
prompt, completion = prompt_context_state_tactic(example, truncation=1024)

print(prompt)
print(completion)

/- You are proving a theorem in Lean 4.
You are given the following information:
- The file contents up to the current tactic, inside [CTX]...[/CTX]
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[CTX]
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.Coprime n) : m.gcd n = 1 := by
  
[/CTX]
[STATE]
m n : ℕ
h : Nat.Coprime m n
⊢ Nat.gcd m n = 1
[/STATE]
[TAC]

rw [Nat.Coprime] at h
[/TAC]
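Because the completion ends with `[/TAC]`, the tactic can be recovered at search time by cutting the generated text at that marker. We have not checked how `proofsearch` does this internally; `extract_tactic` below is a hypothetical helper that only illustrates the idea.

```python
def extract_tactic(generated: str) -> str:
    """Hypothetical post-processing: keep the text before the first [/TAC] marker.

    `generated` is the model's continuation of a prompt that ends with "[TAC]\\n".
    If the model never emits [/TAC] (e.g., it hit the length limit), the whole
    continuation is used.
    """
    tactic, _, _ = generated.partition("[/TAC]")
    return tactic.strip()


print(extract_tactic("rw [Nat.Coprime] at h\n[/TAC]"))  # rw [Nat.Coprime] at h
```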


Here is the command for running the script on the extracted data from Mathlib:

In [3]:
# !cd ../ntp-training-data && python scripts/instruction_tuning.py --context-truncation 1024 --name with_context --output-base-dir instructions/with_context --prompt context_state_tactic state_tactic --mathlib-only
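The `--context-truncation 1024` flag (like `truncation=1024` in the earlier cell) caps how much of the preceding file goes into `[CTX]`. We have not verified whether the script counts characters or tokens, or which part it keeps. The sketch below assumes characters and keeps the end of the file, since the text nearest the current tactic (including the theorem statement) is usually the most relevant; `truncate_context` is a hypothetical name. It also makes exercise 5 concrete: anything further back than the budget is invisible to the model.

```python
def truncate_context(context: str, max_chars: int = 1024) -> str:
    """Keep roughly the last `max_chars` characters of the file prefix (assumed scheme)."""
    if len(context) <= max_chars:
        return context
    if max_chars <= 0:
        return ""
    tail = context[-max_chars:]
    if context[-max_chars - 1] == "\n":
        return tail  # the cut already falls on a line boundary
    # Drop the partial first line so the context starts at a line boundary.
    # If the tail contains no newline at all, keep it as-is.
    newline = tail.find("\n")
    return tail[newline + 1:] if newline != -1 else tail


print(repr(truncate_context("line one\nline two\nline three", max_chars=12)))  # 'line three'
```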

To fine-tune your own model, follow the same procedure as in part 2, using the code in the [ntp-tune](../ntp-tune/) directory. Fine-tuning requires changing the `TRAIN_FILE` and `VALID_FILE` filepaths in the [finetune.sh](../ntp-tune/finetune.sh) script so that they point to the new instruction-tuning data. 


We provide the instruction-tuning data and a fine-tuned model on Hugging Face:

- [`l3lab/ntp-mathlib-instruct-context`](https://huggingface.co/datasets/l3lab/ntp-mathlib-instruct-context) (Data)
- [`l3lab/ntp-mathlib-context-deepseek-coder-1.3b`](https://huggingface.co/l3lab/ntp-mathlib-context-deepseek-coder-1.3b) (Model)

We evaluated this model on miniF2F, where it closed 29.5% (72/244) of the theorems.

### Theorem proving with context

Now let's use the model to prove the theorem in the toy example from above:

In [10]:
MODEL = 'l3lab/ntp-mathlib-context-deepseek-coder-1.3b'

model, tokenizer = proofsearch.load_model(MODEL)

header = """import Mathlib

open BigOperators

variable {Ω : Type*}[Fintype Ω]

structure my_object (Ω : Type*)[Fintype Ω] :=
  (f : Ω → ℝ)
  (cool_property : ∀ x : Ω, 0 ≤ f x)

"""

theorem_statement = """theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by {}"""

transformers.set_seed(42)
proofsearch.best_first_search(
    model, tokenizer, header, theorem_statement, 
    max_iters=16,
    num_samples=8,
    temperatures=[0.5],
    verbose=True,
    # -- KEY DIFFERENCE: now we add the header + statement to the prompt
    prompt_fn=proofsearch.prompt_with_context,
    context=header + theorem_statement.replace("{}", ""),
    # -----
    path_to_repl=PATH_TO_REPL
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	
/- You are proving a theorem in Lean 4.
You are given the following information:
- The file contents up to the current tactic, inside [CTX]...[/CTX]
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[CTX]
import Mathlib

open BigOperators

variable {Ω : Type*}[Fintype Ω]

structure my_object (Ω : Type*)[Fintype Ω] :=
  (f : Ω → ℝ)
  (cool_property : ∀ x : Ω, 0 ≤ f x)

theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
[/CTX]
[STATE]
Ω : Type u_1
inst✝ : Fintype Ω
o1 o2 : my_object Ω
⊢ o1.f + o2.f ≥ 0
[/STATE]
[TAC]



100%|██████████| 8/8 [00:14<00:00,  1.85s/it]


--- type-checked candidates:
	(-0.216) apply add_nonneg
	(-0.206) simp [o1.cool_property, o2.cool_property]
--- current:
	theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by 
	simp [o1.cool_property, o2.cool_property]
/- You are proving a theorem in Lean 4.
You are given the following information:
- The file contents up to the current tactic, inside [CTX]...[/CTX]
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[CTX]
import Mathlib

open BigOperators

variable {Ω : Type*}[Fintype Ω]

structure my_object (Ω : Type*)[Fintype Ω] :=
  (f : Ω → ℝ)
  (cool_property : ∀ x : Ω, 0 ≤ f x)

theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by simp [o1.cool_property, o2.cool_property]
[/CTX]
[STATE]
Ω : Type u_1
inst✝ : Fintype Ω
o1 o2 : my_object Ω
⊢ 0 ≤ o1.f + o2.f
[/STATE]
[TAC]



100%|██████████| 7/7 [00:14<00:00,  2.10s/it]

--- type-checked candidates:
	(-0.136) exact add_nonneg o1.cool_property o2.cool_property





{'theorem_statement': 'theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by {}',
 'proof': ['simp [o1.cool_property, o2.cool_property]',
  'exact add_nonneg o1.cool_property o2.cool_property'],
 'state': {'env': 0},
 'score': 0.3417195677757263,
 'success': True}
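In the second prompt above, the `[CTX]` block ends with `by simp [o1.cool_property, o2.cool_property]`: the tactics of the node being expanded were appended to the context. Below is a sketch of that update, assuming the context is the initial `header + theorem_statement.replace("{}", "")` string followed by the proof so far. With one tactic it reproduces the printed context; the `"\n  "` separator between multiple tactics is our guess, and `context_for_step` is a hypothetical name, not a `proofsearch` function.

```python
def context_for_step(base_context: str, tactics: list) -> str:
    """Hypothetical: context c_t = file prefix + statement + tactics applied so far.

    `base_context` ends with "... := by ", so the first tactic follows directly;
    later tactics are placed on new indented lines (assumed separator).
    """
    return base_context + "\n  ".join(tactics)


base = "theorem my_object_sum_nonneg (o1 o2: my_object Ω) : o1.f + o2.f ≥ 0 := by "
print(context_for_step(base, ["simp [o1.cool_property, o2.cool_property]"]))
```

Because the context grows with the proof, a long proof in a long file eventually runs into the truncation budget as well.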

#### Evaluating on real examples

Finally, let's attempt to prove a few real (context, theorem) examples from [Math2001](https://github.com/hrmacbeth/math2001).



In [None]:
evaluation_theorems = [
    ("""/- Copyright (c) Heather Macbeth, 2023.  All rights reserved. -/
import Mathlib.Data.Real.Basic
import Mathlib
     
open BigOperators

def Superpowered (k : ℕ) : Prop := ∀ n : ℕ, Prime (k ^ k ^ n + 1)

""", 
"""theorem not_superpowered_zero : ¬ Superpowered 0 := by {}"""),

("""import Mathlib.Data.Real.Basic
import Mathlib

open Function
namespace Int


def F : ℕ → ℤ
  | 0 => 1
  | 1 => 1
  | n + 2 => F (n + 1) + F n

def q (x : ℝ) : ℝ := x + 3

""",
"""example : Injective q := by {}"""),

("""import Mathlib.Data.Real.Basic
import Mathlib

open Function
 
def s (x : ℝ) : ℝ := 5 - x

""",
"""example : s ∘ s = id := by {}""")
]



import transformers
transformers.set_seed(40)

results = {True: [], False: []}
for header, theorem in evaluation_theorems:
    result = proofsearch.best_first_search(
        model, tokenizer, header, theorem, 
        max_iters=16,
        temperatures=[0.5],
        num_samples=8,
        verbose=True,
        prompt_fn=proofsearch.prompt_with_context,
        context=header + theorem.replace("{}", ""),
        path_to_repl=PATH_TO_REPL
    )
    print("Success: %s" % result['success'])
    results[result['success']].append(result)

Here are the successfully closed theorems and their generated proofs:

In [7]:
def print_result(result):
    print(result['theorem_statement'].replace('{}', '') + '\n\t' + '\n\t'.join(result['proof']) + '\n')

print("%.3f closed" % (len(results[True])/ (len(results[True])+len(results[False]))))
for result in results[True]:
    print_result(result)

1.000 closed
theorem not_superpowered_zero : ¬ Superpowered 0 := by 
	intro h
	simp [Superpowered] at h
	exact not_prime_one (h 0)

example : Injective q := by 
	intro x y h
	simpa [q] using congr_arg q h

example : s ∘ s = id := by 
	funext x
	simp [s, Function.comp]
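The summary cell above divides by the number of attempted theorems, which raises `ZeroDivisionError` if `evaluation_theorems` is empty. A small helper (hypothetical, not part of `proofsearch`) that also reports raw counts may be more informative, since a fraction over three hand-picked theorems is not comparable to the miniF2F result.

```python
def closed_fraction(results: dict) -> tuple:
    """Return (num_closed, num_attempted, fraction) for a {True: [...], False: [...]} dict."""
    closed = len(results.get(True, []))
    attempted = closed + len(results.get(False, []))
    fraction = closed / attempted if attempted else 0.0
    return closed, attempted, fraction


closed, attempted, fraction = closed_fraction({True: ["a", "b", "c"], False: []})
print("%d/%d closed (%.3f)" % (closed, attempted, fraction))  # 3/3 closed (1.000)
```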



#### Next steps

In the final notebook, we will incorporate our model into a VS Code tool that generates next-step suggestions with the model, enabling a form of "human-machine collaboration". 
Building a tool is also helpful for thinking about practical requirements (e.g., runtime and generalization to different projects).