In [1]:
import re
import json
import numpy as np
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
from dataclasses import asdict, dataclass
from autoformalism_with_llms import prompt
from matplotlib.ticker import PercentFormatter
from autoformalism_with_llms.dataset import MiniF2FMATH
from transformers import AutoModelForCausalLM, AutoTokenizer
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction

from autoformalism_with_llms.dataset import MathQuestion

In [2]:
#model_name = "meta-llama/Meta-Llama-3-70B-Instruct"

In [3]:
@dataclass
class Args:
    """Configuration for one experiment run.

    ``name`` becomes the run's sub-directory under ``artifacts/`` in ``main``;
    ``model`` is the checkpoint passed to ``load_model_and_tokenizer``; the
    remaining fields are forwarded as keyword arguments to
    ``get_question_response``.
    """
    name: str  # run identifier, used as the artifacts sub-directory name
    model: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"  # HF model id / path
    temperature: float = 0.1  # sampling temperature for generation
    max_tokens: int = 2048  # token budget read by get_question_response as "max_tokens"
    top_p: float = 1.0  # nucleus-sampling threshold

@dataclass(frozen=True)
class FEWSHOTIDS:
    """Question numbers of the few-shot examples used in the paper.

    The values are read directly as class attributes (``FEWSHOTIDS.algebra``),
    one tuple of question-number strings per MiniF2F MATH subject.
    """

    # Few-shot examples taken from the algebra subject.
    algebra: tuple[str, ...] = (
        "245", "76", "478", "338", "422",
        "43", "756", "149", "48", "410",
    )

    # Few-shot examples taken from the number-theory subject.
    numbertheory: tuple[str, ...] = (
        "709", "461", "466", "257", "34",
        "780", "233", "764", "345", "227",
    )


In [4]:
def load_model_and_tokenizer(model_name, **model_kwargs):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Model identifier or local directory accepted by
            ``from_pretrained``.
        **model_kwargs: Extra keyword arguments forwarded verbatim to
            ``AutoModelForCausalLM.from_pretrained`` (for example
            ``torch_dtype="auto"`` or ``device_map="auto"`` for large
            checkpoints). When omitted, the library defaults are used, exactly
            as before.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    return model, tokenizer

In [5]:
def get_question_response(prompt_str, model, tokenizer, **kwargs):
    """Generate a completion for ``prompt_str`` and return only the new text.

    Args:
        prompt_str: Fully formatted prompt string.
        model: Causal-LM object exposing ``generate`` (e.g. a Hugging Face
            ``AutoModelForCausalLM``).
        tokenizer: Matching tokenizer (callable, with ``decode``).
        **kwargs: Optional generation settings:
            temperature (float, default 0.2): sampling temperature; a value
                that is ``None`` or ``<= 0`` switches to greedy decoding
                (sampling with temperature 0 is rejected by ``generate``).
            max_tokens (int, default 512): maximum number of *new* tokens.
                The prompt length is not counted, so long few-shot prompts
                do not eat the generation budget.
            top_p (float, default 1.0): nucleus-sampling threshold (only used
                when sampling).

    Returns:
        str: The decoded completion, without the echoed prompt.
    """
    temperature = kwargs.get("temperature", 0.2)
    max_tokens = kwargs.get("max_tokens", 512)
    top_p = kwargs.get("top_p", 1.0)

    inputs = tokenizer(prompt_str, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs.get("attention_mask")

    # Keep inputs on the same device as the model (e.g. when loaded on GPU).
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

    # max_new_tokens (not max_length) so the budget excludes the prompt.
    gen_kwargs = {"max_new_tokens": max_tokens}
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    output = model.generate(input_ids, **gen_kwargs)

    # generate() returns prompt + continuation; decode only the continuation.
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

In [6]:
def make_fewshot_prompt(dataset, question_ids):
    """Build the few-shot chat history for the translation task.

    The history starts with the system instruction and is followed by the
    example messages that ``prompt.informal_to_formal_messages`` produces for
    the questions identified by ``question_ids``.

    Args:
        dataset: Object exposing ``get_question(qid)``.
        question_ids: Iterable of few-shot question ids.

    Returns:
        list[dict]: Chat messages with ``role``/``content`` keys.
    """
    examples = [dataset.get_question(example_id) for example_id in question_ids]
    history = [system_message()]
    history += prompt.informal_to_formal_messages(examples)
    return history

In [7]:
def system_message():
    """Return the system chat message instructing the model to formalize (not prove)."""
    instruction = (
        "Translate the following natural language math problem to the Isabelle "
        "theorem proving language. Do not provide a proof of the statement. "
        "Use diligence when translating the problem and make certain you "
        "capture all the necessary assumptions as hypotheses."
    )
    return dict(role="system", content=instruction)

In [8]:
def convert_messages_to_llama3(messages: list[dict]) -> str:
    """Render chat messages as a single Llama-3 prompt string.

    Each message becomes ``<|start_header_id|>ROLE<|end_header_id|>\\n\\n``
    followed by its content and ``<|eot_id|>``. The whole string starts with
    ``<|begin_of_text|>`` and ends with an open assistant header so the model
    continues as the assistant.

    See:
        https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/

    Args:
        messages (list[dict]): Messages with ``role`` and ``content`` keys
            (``content`` must be a str).

    Returns:
        str: The Llama-3 formatted prompt.
    """

    def header(role) -> str:
        return "<|start_header_id|>{role}<|end_header_id|>\n\n".format(role=role)

    pieces = ["<|begin_of_text|>"]
    pieces.extend(
        header(msg["role"]) + msg["content"] + "<|eot_id|>" for msg in messages
    )
    pieces.append(header("assistant"))
    return "".join(pieces)


def informal_to_formal_messages(questions: list[MathQuestion]) -> list[dict]:
    """Convert few-shot MathQuestions into alternating chat messages.

    For every question, a user message (natural-language statement) is
    followed by an assistant message (formal Isabelle statement).

    Args:
        questions (list[MathQuestion]): Example questions, in prompt order.

    Returns:
        list[dict]: ``2 * len(questions)`` message dictionaries.
    """
    messages = []
    for question in questions:
        # (Previously an unused ``make_example`` result was computed here.)
        messages.append(get_natural_language_message(question))
        messages.append(get_formal_language_message(question))
    return messages


def get_natural_language_message(question: MathQuestion, role: str = "user") -> dict:
    """Build the chat message carrying a question's informal statement.

    The content is the ``natural_question`` produced by ``make_example``:
    the informal problem text plus its final answer.

    Args:
        question (MathQuestion): Question holding informal and formal statements.
        role (str, optional): Speaker role for the message. Defaults to "user".

    Returns:
        dict: ``{"role": role, "content": <natural-language prompt>}``.
    """
    natural_prompt = make_example(question)["natural_question"]
    return dict(role=role, content=natural_prompt)


def get_formal_language_message(
    question: MathQuestion, role: str = "assistant"
) -> dict:
    """Build the chat message carrying a question's formal Isabelle statement.

    The content is the ``formal_question`` produced by ``make_example``: the
    theorem header and hypotheses up to the ``shows`` line, without a proof.

    Args:
        question (MathQuestion): Question holding informal and formal statements.
        role (str, optional): Speaker role for the message. Defaults to "assistant".

    Returns:
        dict: ``{"role": role, "content": <formal statement>}``.
    """
    formal_statement = make_example(question)["formal_question"]
    return dict(role=role, content=formal_statement)


def make_example(question: MathQuestion) -> dict[str, str]:
    """Turn a MathQuestion into one informal/formal translation example.

    Args:
        question (MathQuestion): Question holding informal and formal statements.

    Returns:
        dict[str, str]: ``natural_question`` (informal statement plus final
            answer) and ``formal_question`` (the Isabelle theorem statement,
            trimmed to the ``shows`` line and stripped of its name). These are
            used to construct few-shot examples for the translation task.
    """
    natural = question_with_answer_prompt(
        question.informal_statement, question.informal_solution
    )

    # Trim the formal statement step by step, in this fixed order.
    formal = question.formal_statement
    for trim in (
        remove_content_after_theorem_shows,
        remove_content_before_theorem,
        remove_theorem_name,
    ):
        formal = trim(formal)

    return {"natural_question": natural, "formal_question": formal.strip()}


def make_question(question: MathQuestion) -> str:
    """Return the ``natural_question`` part of ``make_example(question)``."""
    example = make_example(question)
    natural_question = example["natural_question"]
    return natural_question


def question_with_answer_prompt(question: str, solution: str) -> str:
    r"""Append the final (boxed) answer from ``solution`` to ``question``.

    Args:
        question (str): The informal problem statement.
        solution (str): The informal solution, expected to contain a
            ``\boxed{...}`` final answer.

    Returns:
        str: ``"<question> The final answer is $<answer>$."``. If the solution
        has no well-formed ``\boxed{...}``, the question is returned unchanged
        rather than emitting the misleading text ``$None$``.
    """
    final_answer = get_boxed_answer(solution)
    if final_answer is None:
        return question
    return f"{question} The final answer is ${final_answer}$."


def remove_content_after_theorem_shows(formal_statement: str) -> str:
    """Truncate the statement right after its first ``shows`` line.

    Lines are re-joined with ``\\n``. If no line starts (after optional
    whitespace) with ``shows``, the input is returned untouched.

    Note:
        This is not applicable to metamath or hollight datasets.
    """
    lines = formal_statement.splitlines()
    shows_at = next(
        (idx for idx, text in enumerate(lines) if re.search(r"^\s*shows", text)),
        None,
    )
    if shows_at is None:
        return formal_statement
    return "\n".join(lines[: shows_at + 1])


def remove_content_before_theorem(formal_statement: str) -> str:
    """Drop every line before the first line that starts with ``theorem``.

    Leading whitespace before ``theorem`` is allowed. If no such line exists,
    the input is returned untouched.
    """
    lines = formal_statement.splitlines()
    for start, text in enumerate(lines):
        if not re.search(r"^\s*theorem", text):
            continue
        return "\n".join(lines[start:])
    return formal_statement


def remove_theorem_name(formal_statement: str) -> str:
    """Remove the lemma name from the first ``theorem <name>:`` header.

    ``theorem mathd_algebra_478:`` becomes ``theorem``. Anything after the
    colon on the same line (e.g. a one-line statement) is preserved, and other
    lines are never touched. A header without ``<name>:`` is left unchanged.

    The previous implementation passed ``re.M`` positionally as ``count``. Its
    pattern ``(.*theorem).*`` also truncated every line containing "theorem",
    which deleted one-line statements entirely.
    """
    return re.sub(
        r"^([ \t]*theorem)[ \t]+[\w.']+[ \t]*:",
        r"\1",
        formal_statement,
        count=1,
        flags=re.MULTILINE,
    )


def get_boxed_answer(question: str) -> str | None:
    r"""Extract the boxed answer from the string.

    We assume the question has a latex boxed answer in the form `\boxed{answer}`.

    Args:
        question (str): The question string.

    Returns:
        str: The boxed answer string.

    """
    phrase = r"\boxed{"
    try:
        index = question.index(phrase) + len(phrase)
    except ValueError:
        return None
    open_count = 1  # since we start after \boxed{ we have one open brace
    close_count = 0
    end_index = None
    for i, c in enumerate(question[index:]):
        if c == "{":
            open_count += 1
        elif c == "}":
            close_count += 1
        if open_count == close_count:
            end_index = i
            break
    if end_index is None:
        return None
    return question[index : index + end_index]

In [9]:
def run_experiment(dataset, fewshot_ids, log_dir, model, tokenizer, **kwargs):
    """Translate every non-few-shot question in ``dataset`` and log the results.

    One JSON file ``<question_number>.json`` is written per question into
    ``log_dir``. It holds the model response, the question metadata and the
    exact prompt. Questions whose file already exists are skipped, so an
    interrupted run can be resumed by calling this again.

    Args:
        dataset: Iterable of questions that also exposes ``get_question``
            (used to build the few-shot prompt).
        fewshot_ids: Question numbers used as few-shot examples; these are
            excluded from evaluation.
        log_dir: Directory (str or Path) where the JSON files are written.
        model: Causal-LM passed through to ``get_question_response``.
        tokenizer: Tokenizer passed through to ``get_question_response``.
        **kwargs: Generation settings forwarded to ``get_question_response``.
    """
    log_dir = Path(log_dir)
    messages = make_fewshot_prompt(dataset, fewshot_ids)
    # Compare ids as strings: FEWSHOTIDS stores them as str, and the id is
    # also formatted into the filename below. NOTE(review): the type of
    # ``question_number`` comes from MathQuestion; str() keeps the few-shot
    # exclusion working even if it is an int.
    excluded_ids = {str(qid) for qid in fewshot_ids}

    for question in dataset:
        qid = str(question.question_number)
        if qid in excluded_ids:
            continue

        fname = log_dir / f"{qid}.json"
        if fname.exists():
            continue  # already produced by a previous (possibly interrupted) run

        try:
            prompt_str = get_prompt_str(question, messages)
            response = get_question_response(prompt_str, model, tokenizer, **kwargs)

            # If the prompt was echoed back, keep only the final assistant turn.
            marker = "assistant<|end_header_id|>"
            if marker in response:
                response = response.split(marker)[-1].strip()

            data = {
                "response": response,
                "metadata": asdict(question),
                "prompt": prompt_str,
            }
            # Serialize before opening the file: a serialization error must not
            # leave a partial file behind that future runs would skip.
            payload = json.dumps(data)
            with open(fname, "w") as f:
                f.write(payload)
        except Exception as e:
            # Best-effort: report and continue so one failure doesn't stop the run.
            print(f"Error processing {question.question_number}: {e}")



def get_prompt_str(question, messages):
    """Append ``question`` as a new user turn and render the Llama-3 prompt.

    NOTE(review): ``Args.model`` defaults to a DeepSeek-R1-Distill-Qwen
    checkpoint, whose chat template is not the Llama-3 format used here —
    confirm this is intentional.
    """
    new_turn = prompt.get_natural_language_message(question)
    return convert_messages_to_llama3(messages + [new_turn])

In [10]:
def main():
    """Run the few-shot autoformalization baseline on algebra and number theory.

    For each subject, writes ``params.json`` plus one JSON result per
    question under ``artifacts/<run name>/<subject>/``.
    """
    args = Args(name="DeepSeek_R1_Qwen_32B_baseline")

    print("Loading model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(args.model)
    print("Loading dataset...")
    full_dataset = MiniF2FMATH()
    # Keep the full dataset and the per-subject subsets under distinct names.
    subjects = [
        ("algebra", full_dataset.get_subject("algebra"), FEWSHOTIDS.algebra),
        (
            "numbertheory",
            full_dataset.get_subject("numbertheory"),
            FEWSHOTIDS.numbertheory,
        ),
    ]

    params = asdict(args)
    # "name" and "model" are bookkeeping, not generation settings, so they are
    # not forwarded to run_experiment.
    generation_kwargs = {
        key: value for key, value in params.items() if key not in ("name", "model")
    }

    for subject_name, subject_questions, fewshot_ids in subjects:
        log_dir = Path("artifacts") / args.name / subject_name
        log_dir.mkdir(parents=True, exist_ok=True)

        # Save the full run configuration next to the results.
        with open(log_dir / "params.json", "w") as f:
            json.dump(params, f)

        run_experiment(
            subject_questions, fewshot_ids, log_dir, model, tokenizer,
            **generation_kwargs,
        )

    print("Experiment complete!")

In [None]:
if __name__ == "__main__":
    # Inside a Jupyter kernel ``__name__`` is "__main__", so executing this
    # cell runs the full experiment (model load + generation for every question).
    main()