In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the model and tokenizer
# Both are fetched by the same identifier so the tokenizer matches the model.
model_name = "deepseek-ai/DeepSeek-Prover-V1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set up the pipeline for text generation
# `generator` is the module-level callable used by test_prompt below.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

def test_prompt(prompt, max_new_tokens=100):
    """
    Test the model with a given prompt and print the output.

    Args:
        prompt: Text passed to the module-level ``generator`` pipeline.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The ``generated_text`` field of the first returned sequence.
    """
    print(f"Testing prompt:\n{prompt}\n")
    # Use max_new_tokens rather than max_length: max_length also counts the
    # prompt's own tokens, so a long prompt would leave little or no room
    # for the model's answer.
    output = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=True,
    )
    generated_text = output[0]["generated_text"]
    print(f"Generated Output:\n{generated_text}\n")
    return generated_text

# Sample prompts for testing
# Three phrasings of the same instruction, from most to least explicit about
# the expected output format (a single tactic inside a code block).
# NOTE(review): these prompts contain only the instruction text; no Lean goal
# or proof state is included in any of them.
prompts = [
    "Please provide the single tactic needed to complete the following step in a Lean proof. Make sure the tactic is enclosed in a Markdown code block using triple backticks (` ``` `). Output only the tactic and nothing else.",
    "Output only the tactic needed for the proof step. Enclose the tactic in a code block.",
    "Provide the tactic necessary for the proof step in a code block format.",
]

# Run tests on all prompts
# Each call prints the prompt and the generated text; the return value is
# discarded here.
for prompt in prompts:
    test_prompt(prompt)

In [1]:
import re

# Default Lean 4 preamble (imports, options and opened namespaces) that
# cot_prompt places before the statement when the input has no 'header' key.
LEAN4_DEFAULT_HEADER = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"

# Fenced code block. An optional "lean" / "lean4"-style tag on the opening
# fence line is skipped so that it is not returned as part of the tactic.
_FENCED_BLOCK_RE = re.compile(
    r'```(?:lean\d*[ \t]*\r?\n)?(.*?)```', re.DOTALL | re.IGNORECASE
)
# Inline code span delimited by single backticks.
_INLINE_CODE_RE = re.compile(r'`([^`]*)`')
_TACTIC_NOT_FOUND = "Tactic not found in expected format."


def extract_tactic(output: str) -> str:
    """Extract the tactic from model output formatted as Markdown code.

    The first triple-backtick block is preferred; if there is none, the
    first single-backtick span is used instead.

    Args:
        output: Raw text produced by the model.

    Returns:
        The whitespace-stripped contents of the first code block or span,
        without a leading ``lean``/``lean4`` fence tag, or the message
        "Tactic not found in expected format." when no backtick-delimited
        text is present.
    """
    # Try the fenced form first: the inline pattern would otherwise match
    # the backticks of the fence itself.
    for pattern in (_FENCED_BLOCK_RE, _INLINE_CODE_RE):
        match = pattern.search(output)
        if match:
            return match.group(1).strip()
    # If no code block is found, return a fallback message
    return _TACTIC_NOT_FOUND

def cot_prompt(data):
    """Build the prompt asking the model to complete commented Lean 4 code.

    The prompt is a fixed instruction followed by an opened ``lean4`` code
    fence containing, in order, the header, the informal prefix and the
    formal statement.

    Args:
        data: Mapping that must contain 'formal_statement'. 'header' falls
            back to LEAN4_DEFAULT_HEADER and 'informal_prefix' to an empty
            string when absent.

    Returns:
        The assembled prompt string. The code fence is left open.
    """
    header = data.get('header', LEAN4_DEFAULT_HEADER)
    informal_prefix = data.get('informal_prefix', '')
    formal_statement = data['formal_statement']
    instruction = (
        "Complete the following Lean 4 code with explanatory comments "
        "preceding each line of code:\n\n```lean4\n"
    )
    return f"{instruction}{header}{informal_prefix}{formal_statement}"

def post_process_output(output):
    """Return the part of ``output`` that precedes the first triple backtick.

    When ``output`` contains no triple backtick it is returned in full.
    """
    before_fence, _fence, _after = output.partition("```")
    return before_fence
