# Lean GRPO Tutorial

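This notebook demonstrates how to use Lean GRPO (Group Relative Policy Optimization) to train LLMs to generate Lean 4 proofs. The mock-based sections need neither Lean 4 nor a GPU. Section 2 needs a local Lean 4 toolchain, and sections 6–9 need a CUDA GPU; both are skipped when the requirement is missing.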

## Setup

First, let's import the necessary modules and set up our environment.

In [None]:
import asyncio
import json
import os
import shutil

from lean_grpo import (
    LeanGRPOConfig,
    LeanGRPOTrainer,
    LeanInterface,
    MockLeanInterface,
    MockInferenceClient,
    ProofRolloutGenerator,
    VLLMClient,
    REWARD_CONFIG_SHAPED,
)

# For Jupyter async support: the Jupyter kernel already runs an asyncio event
# loop, and asyncio.run() refuses to start while another loop is running in the
# same thread. nest_asyncio.apply() patches asyncio to allow nested loops, so
# the asyncio.run(...) calls in the cells below work inside the notebook.
import nest_asyncio
nest_asyncio.apply()

## 1. Basic Usage: Mock Mode

Let's start with mock interfaces that don't require Lean 4 or a GPU.

In [None]:
# Both stand-ins are mocks: no Lean toolchain and no GPU are needed here.
lean = MockLeanInterface()
inference = MockInferenceClient()

# The rollout generator wires the inference client to the Lean interface;
# max_steps presumably caps the number of tactic steps per rollout.
generator = ProofRolloutGenerator(
    max_steps=5,
    lean_interface=lean,
    inference_client=inference,
)

# Generate one rollout for a simple Nat identity.
rollout_kwargs = {
    "theorem_name": "add_zero",
    "theorem_statement": "theorem add_zero (n : Nat) : n + 0 = n",
}
trajectory = asyncio.run(generator.generate_rollout(**rollout_kwargs))

# Summarize the trajectory; attributes are read one at a time, in display order.
for label, attr_name in (
    ("Reward", "reward"),
    ("Steps", "num_steps"),
    ("Complete", "is_complete"),
):
    print(f"{label}: {getattr(trajectory, attr_name)}")
print("\nTactics:")
print(trajectory.get_tactics_text())

## 2. Lean 4 Integration

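Now let's use the real Lean 4 interface. This requires a Lean 4 installation with the `lake` build tool available; if the interface cannot be initialized, the next cell is skipped.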

In [None]:
# Check if Lean is available. Constructing LeanInterface may not touch the
# filesystem at all, so also confirm the `lake` executable is on PATH;
# otherwise the next cell would only fail later, inside execute_tactic,
# and break a Restart-and-Run-All of this notebook.
LEAN_CMD = "lake"
LEAN_TIMEOUT_S = 10.0

lean_real = None
if shutil.which(LEAN_CMD) is None:
    print(f"Could not initialize Lean: '{LEAN_CMD}' was not found on PATH")
else:
    try:
        lean_real = LeanInterface(lean_cmd=LEAN_CMD, timeout=LEAN_TIMEOUT_S)
        print("Lean 4 interface initialized successfully!")
    except Exception as e:  # broad on purpose: any failure just disables the Lean demo
        print(f"Could not initialize Lean: {e}")
        lean_real = None

In [None]:
# Only run when the previous cell produced a working interface (None otherwise).
if lean_real is not None:
    from lean_grpo.lean_interface import LeanProofState

    # Proof state for a theorem named "test" with signature `(n : Nat) : n = n`
    # (presumably elaborated as `theorem test (n : Nat) : n = n`). Because `n`
    # is already a binder of the theorem, the goal to prove is just `n = n`.
    state = LeanProofState(
        theorem_name="test",
        theorem_statement="(n : Nat) : n = n",
        imports=["Mathlib"],
    )

    # `intro n` would fail on this goal (there is no binder left to
    # introduce); `rfl` is the tactic that closes `n = n`.
    tactic = "rfl"
    result = asyncio.run(lean_real.execute_tactic(state, tactic))

    print(f"Tactic: {tactic}")
    print(f"Status: {result.status}")
    print(f"Goals remaining: {result.num_goals}")
    print(f"Is valid: {result.is_valid}")

## 3. Reward Calculation

Calculate rewards for proof trajectories.

In [None]:
from lean_grpo.reward import LeanRewardCalculator

# Score the mock rollout from section 1 with the shaped-reward preset,
# using the mock Lean interface for any validation the calculator performs.
calculator = LeanRewardCalculator(
    config=REWARD_CONFIG_SHAPED,
    lean_interface=lean,
)
reward_coro = calculator.calculate_reward(trajectory)
reward, metrics = asyncio.run(reward_coro)

print(f"Reward: {reward:.3f}")
print("\nMetrics:")
for metric_name, metric_value in metrics.items():
    print(f"  {metric_name}: {metric_value}")

## 4. Training Setup

Configure and set up the GRPO trainer.

In [None]:
# GRPO training configuration, kept small for a tutorial run.
config = LeanGRPOConfig(
    output_dir="outputs/tutorial",
    base_model="Qwen/Qwen2.5-0.5B-Instruct",  # Small model for testing
    lora_rank=8,
    num_generations=4,
    learning_rate=5e-6,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    max_proof_steps=10,
    reward_config=REWARD_CONFIG_SHAPED,
)

# Echo the most relevant settings, read back from the config object.
print("Configuration:")
for label, attr_name in (
    ("Base model", "base_model"),
    ("LoRA rank", "lora_rank"),
    ("Num generations", "num_generations"),
    ("Learning rate", "learning_rate"),
):
    print(f"  {label}: {getattr(config, attr_name)}")

## 5. Prepare Dataset

Create a dataset of theorems to train on.

In [None]:
# Sample theorems: each entry has a name, the full Lean statement, and the
# modules to import. A fresh ["Mathlib"] list is built for every entry.
theorems = [
    {"name": name, "statement": statement, "imports": ["Mathlib"]}
    for name, statement in (
        ("refl", "theorem refl (n : Nat) : n = n"),
        ("succ_zero", "theorem succ_zero : Nat.succ 0 = 1"),
    )
]

print(f"Loaded {len(theorems)} theorems")
for entry in theorems:
    print(f"  - {entry['name']}: {entry['statement']}")

## 6. Initialize Trainer

Set up the trainer. This requires a CUDA GPU; without one, this cell prints a warning and sections 7–9 are skipped.

In [None]:
# Training needs CUDA. Without it, only a warning is printed and `trainer`
# is left undefined (the later cells re-check CUDA before using it).
import torch

if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")

    # Build the trainer on the mock Lean interface and mock inference client
    # created in section 1.
    trainer = LeanGRPOTrainer(config, lean, inference)
    trainer.setup()

    print("\nTrainer initialized!")
    for label, attr_name in (("Model", "model"), ("Tokenizer", "tokenizer")):
        print(f"{label}: {getattr(trainer, attr_name).__class__.__name__}")
else:
    print("WARNING: No GPU available. Training will not work.")

## 7. Prepare Training Dataset

Convert the theorems from section 5 into the trainer's chat-prompt format (GPU only).

In [None]:
# Only meaningful when the trainer was created in the previous cell (GPU present).
if torch.cuda.is_available():
    dataset = trainer.prepare_dataset(theorems, trainer.tokenizer)
    print(f"\nDataset prepared with {len(dataset)} examples")

    # Preview the first example's chat prompt. Guard against an empty dataset,
    # which the training cell below also checks for with len(dataset) > 0.
    PREVIEW_CHARS = 50
    if len(dataset) > 0:
        example = dataset[0]
        print("\nExample prompt:")
        for msg in example['prompt']:
            content = msg['content']
            # Only append an ellipsis when the message was actually truncated.
            if len(content) > PREVIEW_CHARS:
                content = content[:PREVIEW_CHARS] + "..."
            print(f"  {msg['role']}: {content}")
    else:
        print("\nDataset is empty; nothing to preview.")

## 8. Train

Run training (GPU only; this can take a while). When training finishes, the model is saved to `config.output_dir`.

In [None]:
# `dataset` is only defined when the GPU branch of the previous cell ran,
# so the CUDA check must come first.
if torch.cuda.is_available():
    if len(dataset) > 0:
        print("Starting training...")
        trainer.train(dataset)
        print("Training complete!")

        # Persist the trained model; the message reports config.output_dir.
        trainer.save_model()
        print(f"Model saved to {config.output_dir}")

## 9. Generate Proofs

Use the trained model to generate proofs.

In [None]:
# Requires the trainer from section 6, so only run with a GPU.
if torch.cuda.is_available():
    # add_one is not among the training theorems defined in section 5.
    theorem = "theorem add_one (n : Nat) : n + 1 = Nat.succ n"
    for line in (f"Theorem: {theorem}", "\nGenerating proof..."):
        print(line)

    sampling_temperature = 0.7
    proof = trainer.generate_proof(theorem, temperature=sampling_temperature)

    print("\nGenerated proof:", proof, sep="\n")

## 10. Advanced: Custom Reward Function

Define a custom reward function and register it in a `RewardConfig` via `custom_rewards`.

In [None]:
from lean_grpo.reward import RewardConfig
from lean_grpo.trajectory import ProofTrajectory

def prefer_short_proofs(
    trajectory: ProofTrajectory,
    max_bonus: float = 0.5,
    penalty_per_step: float = 0.05,
) -> float:
    """Reward shorter proofs.

    A completed proof earns ``max_bonus`` minus ``penalty_per_step`` for each
    step it took, floored at 0.0. An incomplete proof earns nothing.

    Args:
        trajectory: The rollout to score; only ``is_complete`` and
            ``num_steps`` are read.
        max_bonus: Bonus for a completed proof with zero steps.
        penalty_per_step: Amount subtracted from the bonus per proof step.

    Returns:
        A non-negative float bonus.
    """
    if not trajectory.is_complete:
        return 0.0
    # Use a float floor: max(0, ...) returned the int 0 once the penalty
    # exceeded the bonus (more than 10 steps with the defaults).
    return max(0.0, max_bonus - trajectory.num_steps * penalty_per_step)

# Create custom config: a completion reward plus the short-proof bonus above,
# with Lean validation enabled.
custom_config = RewardConfig(
    completion_reward=1.0,
    custom_rewards=[prefer_short_proofs],
    use_lean_validation=True,
)

# Try it on the mock rollout from section 1, using the same calculator API as
# section 3 (LeanRewardCalculator is imported there).
custom_calculator = LeanRewardCalculator(lean_interface=lean, config=custom_config)
custom_reward, custom_metrics = asyncio.run(custom_calculator.calculate_reward(trajectory))
print(f"Reward with custom config: {custom_reward:.3f}")

# To train with this config, build the LeanGRPOConfig in section 4 with
# reward_config=custom_config instead of REWARD_CONFIG_SHAPED.

## Conclusion

This tutorial covered:

1. Basic usage with mock interfaces
2. Lean 4 integration
3. Reward calculation
4. Training setup
5. Dataset preparation
6. Model training
7. Proof generation
8. Custom reward functions

For more details, see the documentation and examples in the repository.