### $\textbf{References}$:

1. Base model, https://huggingface.co/unsloth/llama-3-8b-bnb-4bit

2. GSM8k dataset, https://github.com/openai/grade-school-math

3. MATH dataset, https://github.com/hendrycks/math

4. Xwin-Math team, 2023, https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math

### $\textbf{Content}$:

In this notebook, we evaluate our model, "acihanckr/erdos_qed_2024", on the GSM8k dataset and obtain an accuracy of 35.2% on 500 randomly sampled problems. This notebook was modified and run using Google Colab. The evaluation functions and related code are from the Xwin-Math team, https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math.

In [None]:
%%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps trl peft accelerate bitsandbytes
!pip install "xformers<0.0.26"
!pip install antlr4-python3-runtime

In [None]:
import json
import math
import random
import re
from copy import deepcopy
from math import isclose
from typing import Any, List, Union

import torch
from pyparsing import *
from tqdm import tqdm
from transformers import StoppingCriteria
from unsloth import FastLanguageModel

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


Load the fine-tuned model in 4-bit precision.

In [None]:
max_seq_length = 512
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="acihanckr/erdos_qed_2024",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)

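# Reuse the tokenizer's pad token during generation.
model.config.pad_token_id = tokenizer.pad_token_id
# Note: left padding is usually recommended for batched generation with
# decoder-only models; right padding is kept here to match the reported results.
tokenizer.padding_side = "right"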

==((====))==  Unsloth: Fast Llama patching release 2024.5
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# @title Configuration

# Prompt for formatting the data
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}"""

# Configuration for answer-checking
config = {
    "extract_pattern": ["####", "The answer is: ", "The answer is"],
    "policy": {
        "GSM8K": {
            "function": "simple_answer_check",
            "extract_policy": "flex",
            "eval_policy": "strict",
        },
        "MATH": {
            "function": "latex_answer_check",
            "extract_policy": "flex",
            "eval_policy": "aggressive",
        },
        "default": {
            "function": "simple_answer_check",
            "extract_policy": "flex",
            "eval_policy": "strict",
        },
    },
}
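
To make the model input concrete, the sketch below fills the same template the way `generate_and_evaluate_batch` does further down: the question goes into the instruction slot and "Let's think step by step" opens the response. `build_prompt` is a hypothetical helper introduced only for illustration, and the template is repeated so the snippet runs on its own.

```python
# Same text as `alpaca_prompt` above, repeated so this snippet is self-contained.
ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}"""


def build_prompt(question, response_prefix="Let's think step by step"):
    """Hypothetical helper: fill the instruction slot and prime the response."""
    return ALPACA_PROMPT.format(question, response_prefix)


print(build_prompt("What is 2 + 3?"))
```

Because the prompt ends inside the `### Response:` section, the model continues the reasoning from "Let's think step by step" rather than starting a fresh answer.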

In [None]:
# @title Utility functions
def load_json(file_path):
    """Load a JSON file and return its contents."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_jsonl(file_path):
    """Load a JSON Lines (jsonl) file and return its contents as a list."""
    data_list = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data_list.append(json.loads(line))
    return data_list

### $\textbf{Evaluation functions}$:

Toolkit developed by the Xwin-Math team (Xwin-Math team, 2023, https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math) for evaluation. In particular, it contains:
* `extract_answer`: Extracts the final answer from a generated solution.
* `simple_answer_check`: Checks correctness for simple math solutions that do not involve LaTeX. It is mainly used for the GSM8k test set.
* `eval_for_single_item`: Given a problem from either the GSM8k or the MATH test set, selects `simple_answer_check` or `latex_answer_check` (according to `config`) to check the generated solution against the ground-truth answer. Note that `latex_answer_check` is not defined in this notebook, so only items that use `simple_answer_check`, such as GSM8k problems, can be evaluated here.
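
The extraction step can be illustrated with a simplified, standalone sketch of the `flex` policy. `extract_final_answer` is a hypothetical helper, not the Xwin-Math code: it uses a plain regular expression instead of pyparsing and skips the `\boxed{}` lookup and the fraction and scientific-notation handling that `extract_answer` performs.

```python
import re


def extract_final_answer(text, markers=("####", "The answer is: ", "The answer is")):
    """Simplified sketch of the 'flex' extraction policy (hypothetical helper).

    1. For every marker present, keep the text after its last occurrence and
       return the shortest such candidate (first one wins on ties).
    2. Otherwise fall back to the last number in the text.
    """
    candidates = [text.split(marker)[-1] for marker in markers if marker in text]
    if candidates:
        return min(candidates, key=len)
    # Fallback: last integer or decimal, with thousands separators removed.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return numbers[-1] if numbers else None


print(extract_final_answer("70 - 27 = 43\n\nThe answer is: 43."))
print(extract_final_answer("So 12 + 30 = 42"))
```

Taking the shortest candidate means the more specific marker "The answer is: " wins over "The answer is", whose remainder still starts with the colon.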

In [None]:
# @title extract_answer
# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Written by Weiqi Wang
# ---------------------------------------------------------

from pyparsing import *
from typing import List


def extract_answer(
    model_generated_answer: str, split: List[str], extract_policy: str
) -> str:
    """
    Extract answer from response.
    """
    model_generated_answer_split_list = []
    for split_item in split:
        if model_generated_answer.find(split_item) != -1:
            model_generated_answer_split_list.append(
                model_generated_answer.split(split_item)[-1]
            )

    shortest = None
    if len(model_generated_answer_split_list) != 0:
        for ans in model_generated_answer_split_list:
            if shortest is None or len(ans) < len(shortest):
                shortest = ans

        return shortest

    if extract_policy == "flex":
        boxes = search_for_boxes(model_generated_answer)
        if len(boxes) != 0:
            return boxes[-1]

        numbers = search_for_numbers(model_generated_answer)
        if len(numbers) != 0:
            return numbers[-1]

    return None


def remove_prefix_and_suffix(string: str) -> str:
    """
    Remove unnecessary prefixes and suffixes from the input strings
    """
    string = string.strip(" \n").rstrip(".").strip(" \n").rstrip(".")
    if string.startswith("\\(") and string.endswith("\\)"):
        string = string[len("\\(") : -len("\\)")]
    if string.startswith("\\[") and string.endswith("\\]"):
        string = string[len("\\[") : -len("\\]")]
    return string.strip(" \n").rstrip(".").strip(" \n").rstrip(".")


def string_normalization(string: str) -> str:
    """
    Normalize strings, convert to lowercase and remove or replace special symbols
    """
    return (
        string.replace("\\$", "")
        .replace("$", "")
        .replace("\\%", "")
        .replace("%", "")
        .replace("^\\circ", "")
        .replace("^{\\circ}", "")
        .replace("\u03c0", "\\pi")
        .replace("{,}", "")
        .replace("\u00b0", "")
        .lower()
    )


def search_for_boxes(input):
    element = originalTextFor(nestedExpr("{", "}"))
    parser = Literal("\\boxed") + element | Literal("\\mbox") + element
    results = parser.searchString(input)
    return [_[1][1:-1] for _ in results]


def search_for_numbers(input):
    integer = Word("-" + nums, nums)
    fraction = Combine(Word("-" + nums, nums) + "/" + Word(nums))
    decimal = Combine(Optional(Word("-" + nums, nums)) + "." + Word(nums))
    scientific = Combine(Word("-" + nums, nums) + "e" + Word("-" + nums, nums))
    latex = Combine(Suppress("$") + SkipTo("$") + Suppress("$"))
    number_with_comma = Combine(
        Optional(Word("+-", exact=1))
        + Word(nums, max=3)
        + OneOrMore(Suppress(",") + Word(nums, min=3, max=3))
        + Suppress(~Literal(","))
    )

    parser = latex | scientific | fraction | decimal | number_with_comma | integer

    return [_[0] for _ in parser.searchString(input)]


def remove_text_box_only(input):
    tex_expr = Literal("\\text") + nestedExpr("{", "}") + Optional(
        "^" + Char(nums)
    ) | Literal("\\mbox") + nestedExpr("{", "}") + Optional("^" + Char(nums))
    return "".join(tex_expr.suppress().transformString(input))


def remove_boxes_keep_content(input_string):
    box_types = ["\\box", "\\boxed", "\\mbox", "\\text"]
    for box in box_types:
        while True:
            start = input_string.find(box)
            if start == -1:  # box type not found
                break
            brace_start = input_string.find("{", start)
            if brace_start == -1:  # "{" not found after box type
                break
            brace_count = 0
            for i in range(brace_start, len(input_string)):
                if input_string[i] == "{":
                    brace_count += 1
                elif input_string[i] == "}":
                    brace_count -= 1
                if brace_count == 0:  # matching "}" found
                    brace_end = i
                    break
            else:  # matching "}" not found
                break
            # remove box type but keep the content inside
            input_string = (
                input_string[:start]
                + input_string[brace_start + 1 : brace_end]
                + input_string[brace_end + 1 :]
            )
    return input_string


def remove_equals(input_string):
    if "=" in input_string and len(input_string.split("=")) == 2:
        left, right = input_string.split("=")
        if (
            right.strip(" \n").rstrip(".").strip(" \n").rstrip(".") == "0"
            and len(left.strip(" \n").rstrip(".").strip(" \n").rstrip(".")) >= 2
        ):
            return left
        else:
            return right
    return input_string

In [None]:
# @title simple_answer_check

# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Written by Weiqi Wang
# ---------------------------------------------------------

# Simple answer check

import re
import math


def cast_to_number(string):
    string = string.replace(":", "/")
    if "." not in string and "/" not in string:
        return int(string)
    if "/" not in string and abs(int(float(string)) - float(string)) < 1e-10:
        return int(float(string))
    if "/" in string:
        return eval(string)
    return float(string)


def get_simple_numbers(string):
    # pattern_for_percentage = r"(\d+(.\d+)*%)"
    pattern_for_integer = r"-?\d+"
    pattern_for_comma = r"(-?\d{1,3}(,\d{3})+(.\d+)?)"
    pattern_for_float_in_fraction = r"(-?\d+(.\d+)?\s*[/:]\s*\d+(.\d+)?)"
    pattern_for_float_or_fraction = r"-?\d+\s*[\./:]\s*\d+"
    pattern_for_sci = r"(-?\d(.\d+)?[eE][+-]?\d+)"

    pattern_list = [
        pattern_for_sci,
        pattern_for_float_in_fraction,
        pattern_for_comma,
        pattern_for_float_or_fraction,
        pattern_for_integer,
    ]

    result_list = []

    for pattern in pattern_list:
        curr_pattern_result = re.findall(pattern, string)
        for item in curr_pattern_result:
            if isinstance(item, tuple) and item[0] != "":
                item = item[0]
            else:
                item = item
            item = item.replace(",", "")
            result_list.append(cast_to_number(item))
        string = re.sub(pattern, "", string)

    return result_list


def compare_numbers(number_1, number_2):
    if isinstance(number_1, int) and isinstance(number_2, int):
        return number_1 == number_2

    return math.isclose(number_1, number_2, rel_tol=1e-3)


def simple_answer_check(
    model_generated_answer, ground_truth, extract_policy, eval_policy, split
):
    model_generated_answer = extract_answer(
        model_generated_answer, split=split, extract_policy=extract_policy
    )
    if model_generated_answer is None:
        return False
    model_generated_answer = (
        model_generated_answer.rstrip(".")
        .rstrip(" ")
        .rstrip("\n")
        .rstrip(".")
        .rstrip(" ")
        .rstrip("\n")
        .replace("$", "")
    )

    try:
        ground_truth_number_list = get_simple_numbers(ground_truth)
        model_generated_answer_number_list = get_simple_numbers(model_generated_answer)
    except Exception as e:
        return False

    if (
        len(ground_truth_number_list) == 0
        and len(model_generated_answer_number_list) == 0
    ):
        return (
            model_generated_answer.replace("\n", "").replace(" ", "").lower()
            == ground_truth.replace("\n", "").replace(" ", "").lower()
        )

    if (
        len(model_generated_answer_number_list) == 0
        or len(ground_truth_number_list) == 0
    ):
        return False

    if eval_policy == "strict" and len(ground_truth_number_list) != len(
        model_generated_answer_number_list
    ):
        return False

    while (
        len(ground_truth_number_list) != 0
        and len(model_generated_answer_number_list) != 0
    ):
        matched = False
        for ground_truth_number in ground_truth_number_list:
            for model_generated_answer_number in model_generated_answer_number_list:
                if compare_numbers(model_generated_answer_number, ground_truth_number):
                    model_generated_answer_number_list.remove(
                        model_generated_answer_number
                    )
                    ground_truth_number_list.remove(ground_truth_number)
                    matched = True
                    break
            if matched:
                break
        if matched:
            continue
        else:
            break

    if eval_policy == "strict":
        if (
            len(ground_truth_number_list) == 0
            and len(model_generated_answer_number_list) == 0
        ):
            return True
        return False
    elif eval_policy == "model_include_gt":
        if len(ground_truth_number_list) == 0:
            return True
        return False
    elif eval_policy == "gt_include_model":
        if len(model_generated_answer_number_list) == 0:
            return True
        return False
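
Under the `strict` policy used for GSM8K, the check above boils down to: extract the numbers on both sides and require a one-to-one match, comparing integers exactly and other values with a relative tolerance of 1e-3. The sketch below is a simplified version under the assumption that only integers, decimals, and thousands separators need to be parsed; `strict_numeric_check` is a hypothetical name, and it omits the plain string comparison that `simple_answer_check` falls back to when neither side contains a number.

```python
import math
import re


def extract_numbers(text):
    """Return the numbers in text as floats (simplified: ints, decimals, commas)."""
    return [float(m) for m in re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))]


def numbers_match(a, b):
    # Integral values must match exactly; others within a 0.1% relative tolerance.
    if a.is_integer() and b.is_integer():
        return a == b
    return math.isclose(a, b, rel_tol=1e-3)


def strict_numeric_check(generated, ground_truth):
    """Strict policy sketch: both sides contain the same numbers, matched one-to-one."""
    remaining = extract_numbers(generated)
    expected = extract_numbers(ground_truth)
    if not remaining or not expected or len(remaining) != len(expected):
        return False
    for target in expected:
        for idx, candidate in enumerate(remaining):
            if numbers_match(candidate, target):
                del remaining[idx]  # each generated number is used at most once
                break
        else:
            return False  # no generated number matches this target
    return not remaining
```

The length check up front is what makes the policy strict: an answer such as "43 and 2" is rejected against "43" even though it contains the right number.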

In [None]:
# @title eval_for_single_item
# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Based on ToRA (https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py)
# Modified by Weiqi Wang
# ---------------------------------------------------------


def eval_for_single_item(data_item, config):
    """
    Check the correctness for single item.

    Parameters:
        data_item: Data to be checked.
            data format:
            {
                "question": "Joan found 70 seashells on the beach . she gave Sam some of her seashells . She has 27 seashell . How many seashells did she give to Sam ? ",
                "type": "MAWPS",
                "subtype": "addsub",
                "answer": "43",
                "level": 0,
                "completion": [
                    "\n\nTo find out how many seashells Joan gave to Sam, we need to subtract the number of seashells she has left from the total number of seashells she found.\n\nJoan found 70 seashells and has 27 seashells left.\n\nSo, we subtract 27 from 70:\n\n70 - 27 = 43\n\nJoan gave 43 seashells to Sam.\n\nThe answer is: 43."
                ]
            }
        config: The policy and settings for evaluations. In this notebook, this is the `config` dict defined in the Configuration cell (originally eval/eval_config.json in the Xwin-Math repository).

    Returns:
        correctness: A list that contains the correctness for all prompts.
    """
    correctness = []
    for steps in data_item["completion"]:
        if data_item["type"] in config["policy"]:
            config_for_type = config["policy"][data_item["type"]]
        else:
            config_for_type = config["policy"]["default"]

        function = eval(config_for_type["function"])
        result = function(
            steps,
            data_item["answer"],
            eval_policy=config_for_type["eval_policy"],
            extract_policy=config_for_type["extract_policy"],
            split=config["extract_pattern"],
        )
        correctness.append(result)
    return correctness

### $\textbf{Load Test Dataset}$:

Load the GSM8k dataset (https://github.com/openai/grade-school-math) for testing. The file must first be uploaded to `/content/gsm8k.json` in the Colab session.
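
The loading cell expects `/content/gsm8k.json` to be a single JSON list of records containing at least the `question`, `answer`, and `type` fields that `eval_for_single_item` reads (the printed output at the end of this notebook also shows `solution`, `subtype`, and `level`). The official GSM8k release instead ships JSON Lines files whose `answer` field holds the worked solution followed by `#### <final answer>`. Assuming that layout, a hypothetical converter could look like the sketch below; it is not necessarily how the file used here was produced.

```python
import json
import re


def convert_gsm8k_record(record):
    """Map one official GSM8k record to the fields used in this notebook (hypothetical helper)."""
    # The official answer field is "<worked solution>\n#### <final answer>".
    solution, _, final_answer = record["answer"].rpartition("####")
    # Drop calculator annotations such as <<48/2=24>>.
    solution = re.sub(r"<<[^>]*>>", "", solution)
    return {
        "question": record["question"],
        "solution": solution,
        "answer": final_answer.strip(),
        "type": "GSM8K",  # selects the GSM8K policy in `config`
        "subtype": "",
        "level": 0,
    }


def convert_gsm8k_jsonl(src_path, dst_path):
    """Convert an official GSM8k .jsonl file into a single JSON list."""
    with open(src_path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    with open(dst_path, "w", encoding="utf-8") as f:
        json.dump([convert_gsm8k_record(r) for r in records], f, ensure_ascii=False)
```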

In [None]:
data_list = load_json("/content/gsm8k.json")

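# Seed Python's random module so that random.sample draws the same problems.
# (This does not seed PyTorch, which controls sampling during generation.)
seed_value = 42
random.seed(seed_value)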

### $\textbf{Hyperparameter Tuning}$:

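The following function generates a solution for each problem in batches, checks it with `eval_for_single_item`, and returns the annotated problems together with the accuracy, i.e. the fraction of problems answered correctly.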

In [None]:
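def generate_and_evaluate_batch(
    model, tokenizer, problems, batch_size=4, eval_config=None, temperature=0.9, top_p=1
):
    """Generate a solution for every problem in batches and score it.

    Returns the problems (annotated with "completion" and "correctness")
    and the accuracy, i.e. the fraction of problems judged correct.
    """
    all_outputs = []
    correct_count = 0

    for i in tqdm(
        range(0, len(problems), batch_size), desc="Generating and Evaluating"
    ):
        batch_problems = problems[i : i + batch_size]

        # Put each question in the instruction slot and prime the response
        # with "Let's think step by step".
        inputs = tokenizer(
            [
                alpaca_prompt.format(problem["question"], "Let's think step by step")
                for problem in batch_problems
            ],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to("cuda")

        # In Hugging Face transformers, temperature and top_p only take effect
        # when sampling is enabled (do_sample=True), either here or in the
        # model's generation config.
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            use_cache=True,
            temperature=temperature,
            top_p=top_p,
            stopping_criteria=[EosListStoppingCriteria()],
        )

        # Each decoded string contains the prompt followed by the completion.
        decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        for problem, generated_output in zip(batch_problems, decoded_outputs):
            problem["completion"] = [generated_output]

            # A problem counts as correct if any of its completions is correct.
            evaluation_result = eval_for_single_item(problem, eval_config)
            if any(evaluation_result):
                correct_count += 1

            problem["correctness"] = evaluation_result
            all_outputs.append(problem)

    accuracy = correct_count / len(problems) if len(problems) > 0 else 0

    return all_outputs, accuracy


class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once a given sequence of token ids has been produced.

    The default ids are the token ids hard-coded in the original notebook.
    """

    def __init__(self, eos_sequence=None):
        # Use None instead of a mutable list as the default argument.
        if eos_sequence is None:
            eos_sequence = [14711, 30151, 25]
        self.eos_sequence = eos_sequence

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Last len(eos_sequence) token ids of every sequence in the batch.
        last_ids = input_ids[:, -len(self.eos_sequence) :].tolist()
        # A single bool is returned, so generation stops for the whole batch
        # as soon as any one sequence ends with eos_sequence.
        return self.eos_sequence in last_ids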

Generation depends on many hyperparameters. As a first, simple investigation, we measure accuracy over a grid of `temperature` and `top_p` values. Each setting is evaluated in two runs of 4 randomly sampled problems each, so these estimates are very noisy.

In [None]:
acc_dict = {}

for i in range(0, 2):
    for temperature in [0.1, 0.4, 0.7, 1]:
        for top_p in [0.1, 0.4, 0.7, 1]:
            responses_batch, accuracy = generate_and_evaluate_batch(
                model,
                tokenizer,
                random.sample(data_list, 4),
                4,
                config,
                temperature,
                top_p,
            )
            acc_dict[(i, temperature, top_p)] = accuracy
print(acc_dict)


{(0, 0.1, 0.1): 0.25, (0, 0.1, 0.4): 0.5, (0, 0.1, 0.7): 0.25, (0, 0.1, 1): 0.25, (0, 0.4, 0.1): 0.25, (0, 0.4, 0.4): 0.5, (0, 0.4, 0.7): 0.25, (0, 0.4, 1): 0.5, (0, 0.7, 0.1): 0.25, (0, 0.7, 0.4): 0.25, (0, 0.7, 0.7): 0.5, (0, 0.7, 1): 0.25, (0, 1, 0.1): 0.75, (0, 1, 0.4): 0.5, (0, 1, 0.7): 0.0, (0, 1, 1): 0.25, (1, 0.1, 0.1): 0.0, (1, 0.1, 0.4): 0.5, (1, 0.1, 0.7): 0.25, (1, 0.1, 1): 0.5, (1, 0.4, 0.1): 0.5, (1, 0.4, 0.4): 0.5, (1, 0.4, 0.7): 0.5, (1, 0.4, 1): 0.5, (1, 0.7, 0.1): 0.25, (1, 0.7, 0.4): 0.0, (1, 0.7, 0.7): 0.5, (1, 0.7, 1): 0.0, (1, 1, 0.1): 0.25, (1, 1, 0.4): 0.5, (1, 1, 0.7): 0.5, (1, 1, 1): 0.0}





In [None]:
for temperature in [0.1, 0.4, 0.7, 1]:
    for top_p in [0.1, 0.4, 0.7, 1]:
        average_accuracy = (
            acc_dict[(0, temperature, top_p)] + acc_dict[(1, temperature, top_p)]
        ) / 2
        print(f"For temperature {temperature} and top_p {top_p}: {average_accuracy}")

For temperature 0.1 and top_p 0.1: 0.125
For temperature 0.1 and top_p 0.4: 0.5
For temperature 0.1 and top_p 0.7: 0.25
For temperature 0.1 and top_p 1: 0.375
For temperature 0.4 and top_p 0.1: 0.375
For temperature 0.4 and top_p 0.4: 0.5
For temperature 0.4 and top_p 0.7: 0.375
For temperature 0.4 and top_p 1: 0.5
For temperature 0.7 and top_p 0.1: 0.25
For temperature 0.7 and top_p 0.4: 0.125
For temperature 0.7 and top_p 0.7: 0.5
For temperature 0.7 and top_p 1: 0.125
For temperature 1 and top_p 0.1: 0.5
For temperature 1 and top_p 0.4: 0.5
For temperature 1 and top_p 0.7: 0.25
For temperature 1 and top_p 1: 0.125
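
Several settings tie for the best average. The sketch below recomputes the averages from the `acc_dict` values printed above and lists every setting that reaches the maximum; `grid_results` and `best_settings` are hypothetical names introduced here so the snippet does not overwrite notebook state.

```python
import math
from collections import defaultdict

# (run, temperature, top_p) -> accuracy, copied from the grid-search output above.
grid_results = {
    (0, 0.1, 0.1): 0.25, (0, 0.1, 0.4): 0.5, (0, 0.1, 0.7): 0.25, (0, 0.1, 1): 0.25,
    (0, 0.4, 0.1): 0.25, (0, 0.4, 0.4): 0.5, (0, 0.4, 0.7): 0.25, (0, 0.4, 1): 0.5,
    (0, 0.7, 0.1): 0.25, (0, 0.7, 0.4): 0.25, (0, 0.7, 0.7): 0.5, (0, 0.7, 1): 0.25,
    (0, 1, 0.1): 0.75, (0, 1, 0.4): 0.5, (0, 1, 0.7): 0.0, (0, 1, 1): 0.25,
    (1, 0.1, 0.1): 0.0, (1, 0.1, 0.4): 0.5, (1, 0.1, 0.7): 0.25, (1, 0.1, 1): 0.5,
    (1, 0.4, 0.1): 0.5, (1, 0.4, 0.4): 0.5, (1, 0.4, 0.7): 0.5, (1, 0.4, 1): 0.5,
    (1, 0.7, 0.1): 0.25, (1, 0.7, 0.4): 0.0, (1, 0.7, 0.7): 0.5, (1, 0.7, 1): 0.0,
    (1, 1, 0.1): 0.25, (1, 1, 0.4): 0.5, (1, 1, 0.7): 0.5, (1, 1, 1): 0.0,
}


def best_settings(results):
    """Average accuracy over runs; return (best mean, all settings reaching it)."""
    per_setting = defaultdict(list)
    for (_run, temperature, top_p), acc in results.items():
        per_setting[(temperature, top_p)].append(acc)
    means = {key: sum(accs) / len(accs) for key, accs in per_setting.items()}
    best = max(means.values())
    ties = sorted(key for key, mean in means.items() if math.isclose(mean, best))
    return best, ties


best, ties = best_settings(grid_results)
print(best, ties)
```

Six settings reach the best average of 0.5. Since each average is based on only 8 problems, its binomial standard error at 50% accuracy is about sqrt(0.25 / 8) ≈ 0.18, so differences of this size between settings are well within noise.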


### $\textbf{Evaluation}$:

Based on this investigation, we choose temperature 0.1 and top_p 0.4, one of several settings tied for the highest average accuracy (0.5). Next, we generate solutions and answers for 500 randomly chosen problems from the GSM8k dataset and check the accuracy.

In [None]:
responses_batch, accuracy = generate_and_evaluate_batch(
    model,
    tokenizer,
    random.sample(data_list, 500),
    4,
    config,
    temperature=0.1,
    top_p=0.4,
)
print()
print()
print(f"Accuracy: {accuracy * 100:.2f}%")
print("Generated Responses:")
for response in responses_batch:
    print(response)




Accuracy: 35.20%
Generated Responses:
{'question': 'Elise is learning to write and decides to keep re-writing the alphabet until she knows it. She writes it in full twice, writes half of it once, then re-writes everything she has already written. How many letters has Elise written in total?', 'solution': 'Elise has written the alphabet twice which is a total of 26 * 2 = 52 letters.\nShe then writes half the alphabet, which is 26 / 2 = 13 letters.\nSo far, this is a total of 52 + 13 = 65 letters.\nWriting this again means she has doubled the number of letters she has written, so she has written a total of 65 * 2 = 130 letters.\n', 'answer': '130', 'type': 'GSM8K', 'subtype': '', 'level': 0, 'completion': ["Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nElise is learning to write and decides to keep re-writing the alphabet until she knows it. She writes it

In conclusion, our model achieves an accuracy of 35.2% on 500 randomly sampled problems from the GSM8k dataset.
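
Because the 35.2% figure comes from a sample of 500 problems, it carries sampling uncertainty. A rough gauge, assuming a normal-approximation (Wald) binomial interval and using 0.352 × 500 = 176 correct answers, is sketched below; `accuracy_confidence_interval` is a hypothetical helper.

```python
import math


def accuracy_confidence_interval(correct, n, z=1.96):
    """Normal-approximation (Wald) confidence interval for a binomial accuracy."""
    p = correct / n
    se = math.sqrt(p * (1 - p) / n)  # standard error of the sample proportion
    return p, se, (p - z * se, p + z * se)


# 35.2% of 500 problems corresponds to 176 correct answers.
p, se, (low, high) = accuracy_confidence_interval(176, 500)
print(f"accuracy={p:.3f}, SE={se:.4f}, 95% CI=({low:.3f}, {high:.3f})")
```

This gives a standard error of about 2.1 percentage points and a 95% interval of roughly 31% to 39%.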