Visit the Unsloth docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

### Unsloth

Call `PatchFastRL` before running any other cell so that GRPO and the other RL algorithms are patched!

In [None]:
# Patch GRPO (and the other RL trainers) for Unsloth. As the note above this
# cell says, this must run before the rest of the notebook.
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

Load the `kings-crown/Isabelle_FVELer_SFT` checkpoint and set the LoRA parameters.

In [None]:
from unsloth import is_bfloat16_supported
import torch
import time  # used by Checker to time individual proof steps

max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 128 # Larger rank = smarter, but slower

# Load the checkpoint in 16-bit (load_in_4bit=False) with vLLM fast inference
# enabled for generating GRPO rollouts.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "kings-crown/Isabelle_FVELer_SFT",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank, # presumably the largest adapter rank vLLM must serve
    gpu_memory_utilization = 0.8, # Reduce if out of memory
)

# Attach LoRA adapters to all attention and MLP projections.
# NOTE(review): r is hard-coded to 64 while lora_alpha (and max_lora_rank) use
# lora_rank = 128, so the LoRA scale alpha/r is 2. Confirm this is intended;
# otherwise use r = lora_rank.
model = FastLanguageModel.get_peft_model(
    model,
    r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

In [None]:
class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Proof steps that contain `normalhammer` or `sledgehammer` are first
    retried with each tactic in `HAMMER_HEURISTICS` substituted for the hammer
    call; only if every heuristic fails is the step run with `sledgehammer`.
    """

    # Tactics substituted for a hammer call, tried in this order.
    HAMMER_HEURISTICS = (
        'by auto', 'by simp', 'by blast', 'by fastforce',
        'by force', 'by eval', 'by presburger', 'by sos',
        'by arith', 'by linarith', 'by (auto simp: field_simps)',
    )

    def __init__(self, working_dir, isa_path, theory_file_path, port=9000):
        """Store PISA connection settings and import the PISA client.

        Args:
            working_dir: Working directory passed to PISA.
            isa_path: Isabelle installation directory passed to PISA.
            theory_file_path: Theory file passed to PISA.
            port: Port of the PISA server.

        A leading ``~`` in the paths and in ``$PISA_PATH`` is expanded, because
        ``sys.path`` entries are used literally. If ``pisa_client`` cannot be
        imported, a hint is printed and `_initialize` raises later.
        """
        pisa_path = os.path.expanduser(os.environ.get('PISA_PATH', ''))
        # Reward functions may build a Checker per call; don't let sys.path grow.
        if pisa_path not in sys.path:
            sys.path.append(pisa_path)
        self.initialise_env = None
        try:
            from pisa_client import initialise_env
            self.initialise_env = initialise_env
        except ImportError:
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = os.path.expanduser(working_dir)
        self.isa_path = os.path.expanduser(isa_path)
        self.theory_file_path = os.path.expanduser(theory_file_path)
        self.port = port

    def _initialize(self):
        """Create a PISA environment for this checker's paths and port.

        Raises:
            ImportError: if ``pisa_client`` could not be imported.
        """
        initialise_env = getattr(self, 'initialise_env', None)
        if initialise_env is None:
            raise ImportError(
                "pisa_client is not importable; set $PISA_PATH to "
                "/yourpath/to/Portal-to-ISAbelle/src/main/python"
            )
        return initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file_path,
            working_directory=self.working_dir,
        )

    def _exit(self, env):
        """Shut down *env* and kill leftover Isabelle/Poly processes.

        WARNING: the kill commands target every process whose ``ps aux`` line
        contains "Isabelle" or "poly", not only those started by this checker.
        """
        try:
            env.post('exit')
        except Exception:
            # Best effort: a failed 'exit' request must not skip the cleanup below.
            pass
        os.system("ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
        os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")

    def _parse_output(self, obs):
        """Return the tactic recorded before '<hammer>' in *obs*, or '' if none."""
        return obs.split('<hammer>')[0] if '<hammer>' in obs else ''

    def _run_step(self, step, i, tls_name, env):
        """Run one proof step from state *tls_name*, storing it as ``default_{i}``.

        Returns:
            ``(obs, reward, done, metadata, error)``; ``error`` is None on
            success, else the exception text (which may be empty).
        """
        try:
            obs, reward, done, metadata = env.step_to_top_level_state(
                action=step,
                tls_name=tls_name,
                new_name=f'default_{i}'
            )
            return obs, reward, done, metadata, None
        except Exception as e:
            return '', 0, False, None, str(e)

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Try heuristic tactics in place of the hammer call, then sledgehammer.

        On heuristic success, ``obs`` is prefixed with ``'<tactic> <hammer> '``
        so that `_parse_output` can recover the tactic that closed the step.
        """
        for heuristic in self.HAMMER_HEURISTICS:
            # Substitute whichever hammer keyword the step uses.
            step_ = step.replace('normalhammer', heuristic).replace('sledgehammer', heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = f'{heuristic} <hammer> {obs}'
                return obs, reward, done, metadata, error
        return self._run_step(step.replace("normalhammer", "sledgehammer"), i, tls_name, env)

    def check(self, statement_and_proof):
        """Check *statement_and_proof* in a fresh PISA environment.

        The environment is always torn down via `_exit`, even when parsing or
        checking raises.

        Returns:
            The result dict produced by `_check`.
        """
        env = self._initialize()
        try:
            env.initialise()
            theory = self.wrap_theorem(statement_and_proof)
            steps = self.get_parsed(env, theory)
            return self._check(env, steps)
        finally:
            self._exit(env)

    def _check(self, env, steps):
        """Run *steps* in order and summarise the outcome.

        Execution stops at the first step that raises, even if its error
        message is empty. The proof succeeds only if the last executed step
        reports ``done`` with reward 1.0.

        Returns:
            dict with ``success``, ``reason`` (error text of the failing step,
            or ''), ``num_steps``, ``last_step`` (number of steps attempted),
            ``step_results`` and ``theorem_and_proof`` (the reconstructed
            proof on success, else '').
        """
        reason, done, reward = '', False, 0
        step_results = []
        tls_name = 'default'

        for i, step in enumerate(steps):
            start = time.perf_counter()
            if 'normalhammer' in step or 'sledgehammer' in step:
                obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
            else:
                obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)

            step_results.append({
                'index': i,
                'step': step,
                'output': self._parse_output(obs),
                'step_time': time.perf_counter() - start,
            })

            if error is not None:
                reason = error
                break
            # Continue from the state this step produced.
            tls_name = f'default_{i}'

        success = done and reward == 1.0
        return {
            'success': success,
            'reason': reason,
            'num_steps': len(steps),
            'last_step': len(step_results),
            'step_results': step_results,
            'theorem_and_proof': self.reconstruct(step_results) if success else ''
        }

    @staticmethod
    def reconstruct(step_results):
        """Rebuild the proof text, skipping the first (theory header) step.

        A step closed by a heuristic is replaced by that tactic.
        """
        return '\n'.join(
            step_result['output'].strip() if step_result['output'] else step_result['step'].strip()
            for step_result in step_results[1:]
        )

    @staticmethod
    def wrap_theorem(theorem):
        """Wrap *theorem* in an ``Interactive`` theory with the standard imports."""
        return (
            'theory Interactive imports HOL.HOL Complex_Main '
            '"HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" '
            '"Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" '
            '"HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem
        )

    @staticmethod
    def get_parsed(env, theory):
        """Ask PISA to split *theory* into proof steps.

        Empty fragments and bare ``$`` fragments are dropped. A ``then`` is
        also dropped when it is the first step or directly follows a
        ``proof`` step.
        """
        # NOTE(review): the literal '$' before the theory text is presumably
        # part of PISA's request format - confirm before changing it.
        raw_steps = env.post(f"<parse text> ${theory}")
        steps = [s.strip() for s in raw_steps.split('<SEP>')]
        # Compare after stripping so that ' $' is dropped too.
        steps = [s for s in steps if s and s != '$']
        processed_steps = []
        for i, step in enumerate(steps):
            if step.lower() == "then" and (i == 0 or steps[i - 1].startswith("proof")):
                continue
            processed_steps.append(step)
        return processed_steps

Follow this guide to set up PISA (Portal to ISAbelle): https://github.com/haoxiongliu/Portal-to-ISAbelle/tree/dev_lhx

In [None]:
import sys
import os

sys.path.append('../')
# sys.path entries are not tilde-expanded, so expand `~` before Checker puts
# $PISA_PATH on sys.path.
os.environ['PISA_PATH'] = os.path.expanduser('~/Portal-to-ISAbelle/src/main/python')

# Checker for interactive use; checker_reward_func constructs its own.
checker = Checker(
    working_dir=os.path.expanduser('~/Isabelle2022/src/HOL/Examples'),
    isa_path=os.path.expanduser('~/Isabelle2022'),
    theory_file_path=os.path.expanduser('~/Isabelle2022/src/HOL/Examples/Interactive.thy'),
    port=9000
)


### Data Prep
<a name="Data"></a>

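The prompt format and the format-based reward functions are adapted from [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb); the dataset preparation and the Isabelle-checker reward are specific to this notebook. You are free to create your own!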

In [None]:
import re
import sys
import os
from datasets import load_dataset, Dataset

# Load and prep dataset
# The answer fence must be written as "```isabelle" on one line: that is the
# form extract_isabelle_snippet looks for when scoring completions.
SYSTEM_PROMPT = """Write a proof in Isabelle that appropriately proves the given statement in natural language.
Make sure to wrap the proof within ```isabelle and ``` tags inside answer.
Respond in the following format:
<reasoning>
[Your explanation or chain of thought]
</reasoning>
<answer>
```isabelle
[Your formal Isabelle code]
```
</answer>
"""

# Template for a reasoning + answer response; fill with .format(reasoning=..., answer=...).
XML_COT_FORMAT = "\n".join([
    "<reasoning>",
    "{reasoning}",
    "</reasoning>",
    "<answer>",
    "{answer}",
    "</answer>",
    "",
])

def extract_xml_answer(text: str) -> str | None:
    """
    Extracts the content inside <answer>...</answer> tags in XML format.
    Handles missing tags, extra whitespace, and multiple matches.
    """
    if not isinstance(text, str) or not text.strip():
        return None  # Return None for empty or invalid input

    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def extract_isabelle_snippet(text: str) -> str | None:

    if not isinstance(text, str) or not text.strip():
        return None  # Handle empty input safely

    if ":" in text and "```isabelle" in text:
        text = text.split("```isabelle", 1)[1]  # Take the part after first occurrence
        text = "```isabelle" + text  # Ensure consistency

    code_match = re.search(r"```isabelle\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
    if code_match:
        return code_match.group(1).strip()  # Return stripped match
   
    inline_match = re.search(r"(lemma.*?proof.*?qed)", text, re.DOTALL | re.IGNORECASE)
    if inline_match:
        return inline_match.group(1).strip()

    return None  


def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one split of ``kings-crown/FVELer_PISA_Proven`` as GRPO prompts.

    The "gsm8k" name appears to be a leftover from the template this notebook
    was adapted from; the data is the FVELer Isabelle dataset.

    Each row is mapped to:
        prompt: the system prompt plus the row's ``natural_language_statement``.
        answer: the Isabelle snippet extracted from ``formal_proof`` (may be None).
    """
    def _to_example(row):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['natural_language_statement']},
        ]
        reference = extract_isabelle_snippet(row['formal_proof'])
        return {'prompt': messages, 'answer': reference}

    raw = load_dataset('kings-crown/FVELer_PISA_Proven', 'default')[split]
    return raw.map(_to_example)  # type: ignore


dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the extracted Isabelle snippet exactly equals the reference.

    Args:
        prompts: Chat prompts for the batch (unused; part of the reward-function
            call signature used by GRPOTrainer).
        completions: One single-message chat completion per sample.
        answer: Reference snippets (the dataset's ``answer`` column), aligned
            with ``completions``.

    Returns:
        One reward per completion: 2.0 on an exact string match, else 0.0.
        A completion with no extractable snippet never earns the reward, even
        when the reference is also missing (otherwise ``None == None``).
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_isabelle_snippet(r) for r in responses]
    return [
        2.0 if r is not None and r == a else 0.0
        for r, a in zip(extracted_responses, answer)
    ]

def checker_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 for each completion whose Isabelle snippet the checker proves.

    Each snippet is checked exactly once with a single Checker. Completions
    without an extractable snippet get 0.0 without contacting Isabelle.

    Args:
        prompts: Chat prompts (unused; part of the reward-function signature).
        completions: One single-message chat completion per sample.
        answer: Reference snippets (unused here).

    Returns:
        One reward per completion (empty list for an empty batch).
    """
    responses = [completion[0]['content'] for completion in completions]
    snippets = [extract_isabelle_snippet(r) for r in responses]

    checker = Checker(
        working_dir=os.path.expanduser('~/Isabelle2022/src/HOL/Examples'),
        isa_path=os.path.expanduser('~/Isabelle2022'),
        theory_file_path=os.path.expanduser('~/Isabelle2022/src/HOL/Examples/Interactive.thy'),
        port=9000
    )
    rewards = []
    for snippet in snippets:
        if snippet is None:
            # Nothing to verify; skip the Isabelle round-trip.
            rewards.append(0.0)
            continue
        result = checker.check(snippet)
        rewards.append(2.0 if result.get("success", False) else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.25 when a completion follows the exact response layout.

    The completion must consist of ``<reasoning>``, the reasoning,
    ``</reasoning>``, ``<answer>``, the answer, ``</answer>``, each on its own
    line, ending with a newline. ``re.DOTALL`` lets the reasoning and answer
    span several lines, which a fenced ``isabelle`` code block requires.

    Args:
        completions: One single-message chat completion per sample.

    Returns:
        0.25 for each matching completion, else 0.0.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.25 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.25 when a completion starts with a reasoning block followed by an answer block.

    Whitespace between the blocks is free, and ``re.DOTALL`` lets both blocks
    span several lines. The match is anchored at the start of the completion.

    Args:
        completions: One single-message chat completion per sample.

    Returns:
        0.25 for each matching completion, else 0.0.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.25 if match else 0.0 for match in matches]

<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!

In [None]:
from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 1e-5,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    # NOTE(review): newer TRL releases require the global batch size to be
    # divisible by num_generations (2 vs 6 here) - confirm for your version.
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4, # Increase for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    # Prompt + completion must fit in the context the model was loaded with
    # (max_seq_length).
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 3, # Ignored while max_steps > 0 (max_steps takes precedence)
    max_steps = 300,
    save_steps = 25,
    max_grad_norm = 0.1,
    report_to = "wandb", # Can use Weights & Biases
    output_dir = "output_RL",
)

And let's run the trainer! Once training starts, a table of rewards is printed in the cell output. The goal is to see the `reward` column increase!

You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |


In [None]:
sys.path.append('../')
# Expand `~`: sys.path entries (built from $PISA_PATH by Checker) are not
# tilde-expanded.
os.environ['PISA_PATH'] = os.path.expanduser('~/Portal-to-ISAbelle/src/main/python')


# Each reward function returns one score per completion; TRL combines them
# (summed by default) into the reward used for GRPO advantages.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        soft_format_reward_func,
        strict_format_reward_func,
        correctness_reward_func,
        checker_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

<a name="Save"></a>
### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge the LoRA weights into a 16-bit model. Both steps are disabled; change
# a `False` to `True` to run that step.
if False:
    # Save the merged model to the local directory "model".
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False:
    # Upload to the Hub: replace "hf/model" with your repo id and supply a token.
    model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")