#### Neural next-step prediction | part 2: learning
Tutorial on neural theorem proving\
Author: Sean Welleck

----------------

#### High-level goal

Our goal is to train a next-step generator $p_\theta(y_t|x_t)$ with the dataset that we collected in the previous notebook.

To do so, we will fine-tune a pretrained language model with the dataset $\mathcal{D}=\{(x_t,y_t)\}$ using standard supervised fine-tuning:

$$
\arg\max_\theta \sum_{(x_t,y_t)\in \mathcal{D}}\log p_\theta(y_t|x_t).
$$

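That is, we maximize the conditional likelihood of a completion $y_t$ (which contains the next tactic) given the prompt $x_t$ (which contains a proof state).

Because the model generates the completion one token at a time, the log-likelihood factorizes as $\log p_\theta(y_t|x_t)=\sum_{\ell=1}^{|y_t|}\log p_\theta(y_t^\ell|x_t,y_t^{<\ell})$. Maximizing it is therefore equivalent to minimizing the cross-entropy loss summed over the positions of the completion, $\sum_{\ell=1}^{|y_t|}-\log p_\theta(y_t^\ell|x_t,y_t^{<\ell})$. The prompt tokens serve only as context and are not themselves scored.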
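To make the per-position loss concrete, here is a minimal pure-Python sketch. The `model` argument is a hypothetical stand-in for $p_\theta$: it maps a context (the prompt followed by the completion prefix) to a next-token distribution. A real language model would produce these probabilities with a softmax over its logits. Only completion positions are scored.

```python
import math
from typing import Callable, Dict, Sequence

# Hypothetical stand-in for p_theta: maps a context to {next_token: probability}.
NextTokenModel = Callable[[Sequence[str]], Dict[str, float]]

def completion_nll(model: NextTokenModel, prompt: Sequence[str], completion: Sequence[str]) -> float:
    """Sum of -log p(y^l | x, y^{<l}) over the completion tokens only."""
    loss = 0.0
    for ell, token in enumerate(completion):
        # The context is the prompt x_t followed by the completion prefix y_t^{<l}.
        context = list(prompt) + list(completion[:ell])
        loss -= math.log(model(context)[token])
    return loss

VOCAB = ["rw", "simp", "exact", "[/TAC]"]

def uniform_model(context: Sequence[str]) -> Dict[str, float]:
    # Toy model that ignores the context and spreads mass uniformly over VOCAB.
    return {tok: 1.0 / len(VOCAB) for tok in VOCAB}

prompt = ["[STATE]", "...", "[/STATE]", "[TAC]"]  # toy, already-tokenized prompt
completion = ["simp", "[/TAC]"]
print(completion_nll(uniform_model, prompt, completion))  # equals 2 * log(4), about 2.7726
```

In an actual transformer, all completion positions are scored in a single forward pass (teacher forcing) rather than one call per position, but the summed loss is the same.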

### Implementation

In the previous notebook, we saw how to use [instruction_tuning.py](../ntp-training-data/scripts/instruction_tuning.py) to format the extracted Mathlib data into (prompt, completion) examples.

We provide formatted fine-tuning data on HuggingFace:

- [`l3lab/ntp-mathlib-instruct-st`](https://huggingface.co/datasets/l3lab/ntp-mathlib-instruct-st)

*If you use this data or code, we kindly ask that you cite this neural theorem proving tutorial*.
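The formatting step can be sketched as follows. This is not the code from `instruction_tuning.py`; it is a hypothetical helper that reuses the prompt template from the inference example later in this notebook. It assumes the completion is the tactic followed by a closing `[/TAC]` tag, which matches the model output shown at the end of this notebook. The `prompt`/`completion` field names are also an assumption.

```python
# Prompt template, copied from the inference example later in this notebook.
PROMPT_TEMPLATE = """/- You are proving a theorem in Lean 4.
You are given the following information:
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[STATE]
%s
[/STATE]
[TAC]
"""

def format_example(state: str, tactic: str) -> dict:
    """Hypothetical helper: build one (prompt, completion) training example.

    Assumes the completion is the tactic followed by a closing [/TAC] tag.
    """
    return {
        "prompt": PROMPT_TEMPLATE % state,
        "completion": tactic + "\n[/TAC]",
    }

example = format_example(
    "m n : ℕ\nh : Nat.Coprime m n\n⊢ Nat.gcd m n = 1",
    "exact h",  # illustrative tactic, not taken from the dataset
)
print(example["prompt"] + example["completion"])
```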


We will fine-tune on Mathlib (`ntp-lean-mathlib-tactic-instruct`). First, we download the training and validation sets:




```
!cd ../ntp-tune && bash prepare_data.sh
```

Let's look at the first two training examples:

```
!head -n 2 ../ntp-tune/data/state_tactic_mathlib_only_train.jsonl
```
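The same inspection can be done from Python. A JSON Lines file stores one JSON object per line, so a minimal sketch needs only the standard `json` module. Since the field names are not shown in this notebook, the sketch prints each record's keys rather than assuming them.

```python
import json
import os

def head_jsonl(path: str, n: int = 2) -> list:
    """Return the first n records of a JSON Lines file, skipping blank lines."""
    if n <= 0:
        return []
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            records.append(json.loads(line))
            if len(records) == n:
                break
    return records

TRAIN_FILE = "../ntp-tune/data/state_tactic_mathlib_only_train.jsonl"
if os.path.exists(TRAIN_FILE):  # only present after prepare_data.sh has run
    for record in head_jsonl(TRAIN_FILE):
        print(sorted(record.keys()))
```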

#### Fine-tuning

We can now use an off-the-shelf fine-tuning script. We minimally adapt a standard language-model fine-tuning script from [open-instruct](https://github.com/allenai/open-instruct). 

You can check out the full script at [ntp-tune/finetune.py](../ntp-tune/finetune.py). \
See [ntp-tune/finetune.sh](../ntp-tune/finetune.sh) for a command that fine-tunes a `deepseek-coder-1.3b-base` model on 4 GPUs. 

Please see the [ntp-tune](../ntp-tune) directory for setup instructions.
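In a prompt/completion setup, the objective above is typically implemented by concatenating the prompt and completion tokens and masking the prompt positions in the labels. The sketch below assumes the common Hugging Face convention that label `-100` is ignored (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) and that causal language models shift labels by one position internally, so labels stay aligned with `input_ids`. We believe open-instruct's prompt/completion encoding works along these lines, but check `finetune.py` for the exact details; `build_example` is a hypothetical helper that operates on already-tokenized ids.

```python
from typing import List, Optional, Sequence, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_example(
    prompt_ids: Sequence[int],
    completion_ids: Sequence[int],
    eos_id: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and completion token ids and mask the prompt in the labels.

    Only completion (and optional EOS) positions keep their token ids as labels,
    so only they contribute to the cross-entropy loss.
    """
    input_ids = list(prompt_ids) + list(completion_ids)
    if eos_id is not None:
        input_ids.append(eos_id)  # lets the model learn where the completion ends
    labels = [IGNORE_INDEX] * len(prompt_ids) + input_ids[len(prompt_ids):]
    return input_ids, labels

input_ids, labels = build_example([11, 12, 13], [21, 22], eos_id=2)
print(input_ids)  # [11, 12, 13, 21, 22, 2]
print(labels)     # [-100, -100, -100, 21, 22, 2]
```

Positions labeled `-100` still act as context through attention; they are only excluded from the loss, which matches conditioning on $x_t$ without scoring it.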

#### After training

We have fine-tuned a `deepseek-coder-1.3b-base` model that can be accessed through HuggingFace:
- [`l3lab/ntp-mathlib-st-deepseek-coder-1.3b`](https://huggingface.co/l3lab/ntp-mathlib-st-deepseek-coder-1.3b)

```python
import transformers

# Load the fine-tuned model and its tokenizer from the HuggingFace Hub.
MODEL = 'l3lab/ntp-mathlib-st-deepseek-coder-1.3b'
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
```



To use your own fine-tuned model instead, set `MODEL = "/path/to/checkpoint-{BEST_STEP}"`, pointing to a checkpoint saved during fine-tuning.
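One common way to choose `BEST_STEP` is to take the checkpoint with the lowest validation loss. The helper below is a hypothetical sketch of that choice; it assumes you have already collected per-step validation losses (for example, from your training logs), and the numbers in the usage line are made up for illustration.

```python
from typing import Dict

def best_checkpoint(val_losses: Dict[int, float], output_dir: str = "/path/to") -> str:
    """Hypothetical helper: return the checkpoint path with the lowest validation loss."""
    if not val_losses:
        raise ValueError("no validation losses recorded")
    best_step = min(val_losses, key=val_losses.get)
    return f"{output_dir}/checkpoint-{best_step}"

# Made-up validation losses, for illustration only.
path = best_checkpoint({500: 0.91, 1000: 0.74, 1500: 0.78})
print(path)  # /path/to/checkpoint-1000
```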

Let's generate a next-step suggestion for the proof state from our original example:

```lean
    theorem test_thm (m n : Nat) (h : m.Coprime n) : m.gcd n = 1
```
Recall from the previous notebook that the initial proof state $x_0$ is:

    m n : ℕ
    h : Nat.Coprime m n
    ⊢ Nat.gcd m n = 1

```python
state = """m n : ℕ
h : Nat.Coprime m n
⊢ Nat.gcd m n = 1"""

# Wrap the proof state in the instruction-style prompt; the model continues after [TAC].
prompt = """/- You are proving a theorem in Lean 4.
You are given the following information:
- The current proof state, inside [STATE]...[/STATE]

Your task is to generate the next tactic in the proof.
Put the next tactic inside [TAC]...[/TAC]
-/
[STATE]
%s
[/STATE]
[TAC]
""" % state

model.eval()
inputs = tokenizer(prompt, return_tensors='pt')
# Generate up to 256 new tokens after the prompt.
out = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_new_tokens=256,
    pad_token_id=tokenizer.eos_token_id,
)
# Decode only the newly generated tokens (everything after the prompt).
for item in out:
    text = tokenizer.decode(item[inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    print(text)
```

Output:

```
rw [← Nat.gcd_comm, Nat.gcd_eq_one_iff_coprime]
[/TAC]
```
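The decoded text still contains the closing `[/TAC]` marker, and a longer generation could contain text after it. Before sending a suggestion to Lean, one would typically keep only the text before the first `[/TAC]`. Here is a minimal sketch; `extract_tactic` is a hypothetical helper, not part of the tutorial's code.

```python
def extract_tactic(generated: str) -> str:
    """Return the text before the first [/TAC] marker, with surrounding whitespace removed.

    If the marker is missing (e.g., generation stopped at max_new_tokens),
    the whole text is returned, stripped.
    """
    return generated.split("[/TAC]", 1)[0].strip()

print(extract_tactic("rw [← Nat.gcd_comm, Nat.gcd_eq_one_iff_coprime]\n[/TAC]\n"))
# rw [← Nat.gcd_comm, Nat.gcd_eq_one_iff_coprime]
```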


### Next steps

In the next notebook, we will prove theorems with the trained model by interacting with the Lean proof assistant.

This will let us automatically check whether a generated proof (e.g., one containing the step above) is correct.

Later on, we will train a language model that uses additional file context, and then build a VS Code plugin that returns next-step suggestions from the language model.