$\textbf{References}$:

1. Base model, https://huggingface.co/unsloth/llama-3-8b-bnb-4bit

2. GSM8k dataset, https://github.com/openai/grade-school-math

3. MATH dataset, https://github.com/hendrycks/math

4. Xwin-Math team, 2023,  https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math

$\textbf{Content}$:

In this notebook, we evaluate the base model `unsloth/llama-3-8b-bnb-4bit` on 500 randomly sampled questions from the MATH test set and find an accuracy of 7.0%. The notebook was modified and run on Google Colab. The evaluation functions and related code are from the Xwin-Math team, https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math.

<a name="start"></a>
### Packages and Libraries

In [None]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps trl peft accelerate bitsandbytes
!pip install "xformers<0.0.26"
!pip install antlr4-python3-runtime

In [None]:
import json
import random
import re
from copy import deepcopy
from math import isclose
from typing import Any, List, Union

import torch
from pyparsing import *
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from tqdm import tqdm
from transformers import StoppingCriteria
from unsloth import FastLanguageModel

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


### Load the base model in 4-bit

In [None]:
max_seq_length = 512
dtype = None
load_in_4bit = True


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)

model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "right"

==((====))==  Unsloth: Fast Llama patching release 2024.5
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# @title Configuration


alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}"""

config = {
    "extract_pattern": ["####", "The answer is: ", "The answer is"],
    "policy": {
        "GSM8K": {
            "function": "simple_answer_check",
            "extract_policy": "flex",
            "eval_policy": "strict",
        },
        "MATH": {
            "function": "latex_answer_check",
            "extract_policy": "flex",
            "eval_policy": "aggressive",
        },
        "default": {
            "function": "simple_answer_check",
            "extract_policy": "flex",
            "eval_policy": "strict",
        },
    },
}
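
As a minimal sketch of how the template above is filled in, the snippet below formats one prompt the same way the generation loop later does: the question goes into the instruction slot and the phrase "Let's think step by step" seeds the response. The question here is a made-up placeholder, not an item from the dataset.

```python
# Same template text as the notebook's `alpaca_prompt`.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}"""

# Hypothetical question, for illustration only.
question = "What is $\\frac{1}{2} + \\frac{1}{2}$?"

# Braces inside the question are safe: they are passed as an argument,
# so `str.format` does not treat them as replacement fields.
prompt = alpaca_prompt.format(question, "Let's think step by step")
print(prompt)
```

Because the response slot is pre-filled, the model continues the text after "Let's think step by step" rather than starting a response from scratch.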

### Utility functions
We define two functions `load_json` and `load_jsonl` for loading json and jsonl files.

In [None]:
def load_json(file_path):
    """Load a JSON file and return its contents."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_jsonl(file_path):
    """Load a JSON Lines (jsonl) file and return its contents as a list."""
    data_list = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data_list.append(json.loads(line))
    return data_list
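
The difference between the two file formats can be illustrated with a small round trip. This is only a sketch: the records and file names are hypothetical, and the loading code is inlined so that the snippet runs on its own.

```python
import json
import os
import tempfile

# Hypothetical records shaped like the test items used later in the notebook.
records = [
    {"question": "What is 2 + 3?", "answer": "5", "type": "GSM8K"},
    {"question": "Simplify $\\frac{6}{8}$.", "answer": "\\frac{3}{4}", "type": "MATH"},
]

with tempfile.TemporaryDirectory() as tmp:
    json_path = os.path.join(tmp, "sample.json")
    jsonl_path = os.path.join(tmp, "sample.jsonl")

    # .json: a single document holding the whole list.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(records, f)

    # .jsonl: one JSON object per line.
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")

    # Read both back, mirroring what `load_json` and `load_jsonl` do.
    with open(json_path, "r", encoding="utf-8") as f:
        from_json = json.load(f)
    with open(jsonl_path, "r", encoding="utf-8") as f:
        from_jsonl = [json.loads(line) for line in f]

print(from_json == from_jsonl == records)
```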

### Evaluation function
We use the evaluation toolkit developed by the Xwin-Math team, available at [https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math/eval](https://github.com/Xwin-LM/Xwin-LM/tree/main/Xwin-Math/eval). In particular, it contains:
* `extract_answer`: Extracts the final answer from a generated solution.
* `simple_answer_check`: Checks correctness for simple math solutions that do not involve LaTeX. It is mainly used for the GSM8k test dataset.
* `latex_answer_check`: Checks correctness for more complicated math solutions that use LaTeX. It is mainly used for the MATH test dataset.
* `eval_for_single_item`: Given a problem from either the GSM8k or the MATH test dataset, selects `simple_answer_check` or `latex_answer_check` according to the problem's `type` and checks the generated solution against the ground-truth answer.
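
To give a feel for the first step, here is a simplified stand-in for the marker-based part of answer extraction. It is not the toolkit's implementation: `extract_by_pattern` is a hypothetical helper that only covers the case where one of the configured markers appears in the response, and it returns the shortest text that follows any marker.

```python
from typing import List, Optional


def extract_by_pattern(response: str, patterns: List[str]) -> Optional[str]:
    """Return the shortest text that follows the last occurrence of any marker."""
    candidates = [response.split(p)[-1] for p in patterns if p in response]
    if not candidates:
        return None  # no marker found; a fallback strategy would be needed
    return min(candidates, key=len)


# Same markers as the "extract_pattern" entry in the configuration cell.
patterns = ["####", "The answer is: ", "The answer is"]
response = "So, we subtract 27 from 70:\n\n70 - 27 = 43\n\nThe answer is: 43."

print(repr(extract_by_pattern(response, patterns)))
print(extract_by_pattern("No marker here", patterns))
```

Both "The answer is: " and "The answer is" match this response; picking the shortest suffix keeps `43.` and drops the candidate that still carries the leading colon.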

In [None]:
# @title extract_answer

# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Written by Weiqi Wang
# ---------------------------------------------------------
"""
This cell defines `extract_answer` function which is used to extract the final answer from a solution.
"""


def extract_answer(
    model_generated_answer: str, split: List[str], extract_policy: str
) -> str:
    """
    Extract answer from response.
    """
    model_generated_answer_split_list = []
    for split_item in split:
        if model_generated_answer.find(split_item) != -1:
            model_generated_answer_split_list.append(
                model_generated_answer.split(split_item)[-1]
            )

    shortest = None
    if len(model_generated_answer_split_list) != 0:
        for ans in model_generated_answer_split_list:
            if shortest is None or len(ans) < len(shortest):
                shortest = ans

        return shortest

    if extract_policy == "flex":
        boxes = search_for_boxes(model_generated_answer)
        if len(boxes) != 0:
            return boxes[-1]

        numbers = search_for_numbers(model_generated_answer)
        if len(numbers) != 0:
            return numbers[-1]

    return None


def remove_prefix_and_suffix(string: str) -> str:
    """
    Remove unnecessary prefixes and suffixes from the input strings
    """
    string = string.strip(" \n").rstrip(".").strip(" \n").rstrip(".")
    if string.startswith("\\(") and string.endswith("\\)"):
        string = string[len("\\(") : -len("\\)")]
    if string.startswith("\\[") and string.endswith("\\]"):
        string = string[len("\\[") : -len("\\]")]
    return string.strip(" \n").rstrip(".").strip(" \n").rstrip(".")


def string_normalization(string: str) -> str:
    """
    Normalize strings, convert to lowercase and remove or replace special symbols
    """
    return (
        string.replace("\\$", "")
        .replace("$", "")
        .replace("\\%", "")
        .replace("%", "")
        .replace("^\\circ", "")
        .replace("^{\\circ}", "")
        .replace("\u03c0", "\\pi")
        .replace("{,}", "")
        .replace("\u00b0", "")
        .lower()
    )


def search_for_boxes(input):
    element = originalTextFor(nestedExpr("{", "}"))
    parser = Literal("\\boxed") + element | Literal("\\mbox") + element
    results = parser.searchString(input)
    return [_[1][1:-1] for _ in results]


def search_for_numbers(input):
    integer = Word("-" + nums, nums)
    fraction = Combine(Word("-" + nums, nums) + "/" + Word(nums))
    decimal = Combine(Optional(Word("-" + nums, nums)) + "." + Word(nums))
    scientific = Combine(Word("-" + nums, nums) + "e" + Word("-" + nums, nums))
    latex = Combine(Suppress("$") + SkipTo("$") + Suppress("$"))
    number_with_comma = Combine(
        Optional(Word("+-", exact=1))
        + Word(nums, max=3)
        + OneOrMore(Suppress(",") + Word(nums, min=3, max=3))
        + Suppress(~Literal(","))
    )

    parser = latex | scientific | fraction | decimal | number_with_comma | integer

    return [_[0] for _ in parser.searchString(input)]


def remove_text_box_only(input):
    tex_expr = Literal("\\text") + nestedExpr("{", "}") + Optional(
        "^" + Char(nums)
    ) | Literal("\\mbox") + nestedExpr("{", "}") + Optional("^" + Char(nums))
    return "".join(tex_expr.suppress().transformString(input))


def remove_boxes_keep_content(input_string):
    box_types = ["\\box", "\\boxed", "\\mbox", "\\text"]
    for box in box_types:
        while True:
            start = input_string.find(box)
            if start == -1:  # box type not found
                break
            brace_start = input_string.find("{", start)
            if brace_start == -1:  # "{" not found after box type
                break
            brace_count = 0
            for i in range(brace_start, len(input_string)):
                if input_string[i] == "{":
                    brace_count += 1
                elif input_string[i] == "}":
                    brace_count -= 1
                if brace_count == 0:  # matching "}" found
                    brace_end = i
                    break
            else:  # matching "}" not found
                break
            # remove box type but keep the content inside
            input_string = (
                input_string[:start]
                + input_string[brace_start + 1 : brace_end]
                + input_string[brace_end + 1 :]
            )
    return input_string


def remove_equals(input_string):
    if "=" in input_string and len(input_string.split("=")) == 2:
        left, right = input_string.split("=")
        if (
            right.strip(" \n").rstrip(".").strip(" \n").rstrip(".") == "0"
            and len(left.strip(" \n").rstrip(".").strip(" \n").rstrip(".")) >= 2
        ):
            return left
        else:
            return right
    return input_string
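
When no marker is present, the "flex" policy falls back to looking for boxed expressions. The sketch below shows the underlying idea with plain brace counting from the standard library instead of `pyparsing`; `find_boxed` is a hypothetical helper and only handles `\boxed`, so treat it as an illustration rather than a replacement for `search_for_boxes`.

```python
from typing import List


def find_boxed(text: str) -> List[str]:
    """Return the contents of every \\boxed{...}, allowing nested braces."""
    results = []
    marker = "\\boxed{"
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1  # we are already inside the opening brace
        i = start
        while i < len(text) and depth > 0:
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                depth -= 1
            i += 1
        if depth != 0:  # unbalanced braces: stop scanning
            break
        results.append(text[start : i - 1])
        pos = text.find(marker, i)
    return results


solution = "Thus, $a+b=2+\\frac23=\\boxed{\\frac{8}{3}}$."
print(find_boxed(solution))
```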

In [None]:
# @title simple_answer_check

# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Written by Weiqi Wang
# ---------------------------------------------------------

"""
This cell defines the `simple_answer_check` function, which is used for the GSM8k test dataset.
"""


def simple_answer_check(
    model_generated_answer, ground_truth, extract_policy, eval_policy, split
):
    model_generated_answer = extract_answer(
        model_generated_answer, split=split, extract_policy=extract_policy
    )
    if model_generated_answer is None:
        return False
    model_generated_answer = (
        model_generated_answer.rstrip(".")
        .rstrip(" ")
        .rstrip("\n")
        .rstrip(".")
        .rstrip(" ")
        .rstrip("\n")
        .replace("$", "")
    )

    try:
        ground_truth_number_list = get_simple_numbers(ground_truth)
        model_generated_answer_number_list = get_simple_numbers(model_generated_answer)
    except Exception as e:
        return False

    if (
        len(ground_truth_number_list) == 0
        and len(model_generated_answer_number_list) == 0
    ):
        return (
            model_generated_answer.replace("\n", "").replace(" ", "").lower()
            == ground_truth.replace("\n", "").replace(" ", "").lower()
        )

    if (
        len(model_generated_answer_number_list) == 0
        or len(ground_truth_number_list) == 0
    ):
        return False

    if eval_policy == "strict" and len(ground_truth_number_list) != len(
        model_generated_answer_number_list
    ):
        return False

    while (
        len(ground_truth_number_list) != 0
        and len(model_generated_answer_number_list) != 0
    ):
        matched = False
        for ground_truth_number in ground_truth_number_list:
            for model_generated_answer_number in model_generated_answer_number_list:
                if compare_numbers(model_generated_answer_number, ground_truth_number):
                    model_generated_answer_number_list.remove(
                        model_generated_answer_number
                    )
                    ground_truth_number_list.remove(ground_truth_number)
                    matched = True
                    break
            if matched:
                break
        if matched:
            continue
        else:
            break

    if eval_policy == "strict":
        if (
            len(ground_truth_number_list) == 0
            and len(model_generated_answer_number_list) == 0
        ):
            return True
        return False
    elif eval_policy == "model_include_gt":
        if len(ground_truth_number_list) == 0:
            return True
        return False
    elif eval_policy == "gt_include_model":
        if len(model_generated_answer_number_list) == 0:
            return True
        return False


def cast_to_number(string):
    string = string.replace(":", "/")
    if "." not in string and "/" not in string:
        return int(string)
    if "/" not in string and abs(int(float(string)) - float(string)) < 1e-10:
        return int(float(string))
    if "/" in string:
        return eval(string)
    return float(string)


def get_simple_numbers(string):
    # pattern_for_percentage = r"(\d+(.\d+)*%)"
    pattern_for_integer = r"-?\d+"
    pattern_for_comma = r"(-?\d{1,3}(,\d{3})+(.\d+)?)"
    pattern_for_float_in_fraction = r"(-?\d+(.\d+)?\s*[/:]\s*\d+(.\d+)?)"
    pattern_for_float_or_fraction = r"-?\d+\s*[\./:]\s*\d+"
    pattern_for_sci = r"(-?\d(.\d+)?[eE][+-]?\d+)"

    pattern_list = [
        pattern_for_sci,
        pattern_for_float_in_fraction,
        pattern_for_comma,
        pattern_for_float_or_fraction,
        pattern_for_integer,
    ]

    result_list = []

    for pattern in pattern_list:
        curr_pattern_result = re.findall(pattern, string)
        for item in curr_pattern_result:
            if isinstance(item, tuple) and item[0] != "":
                item = item[0]
            else:
                item = item
            item = item.replace(",", "")
            result_list.append(cast_to_number(item))
        string = re.sub(pattern, "", string)

    return result_list


def compare_numbers(number_1, number_2):
    if isinstance(number_1, int) and isinstance(number_2, int):
        return number_1 == number_2

    # `isclose` is imported directly (`from math import isclose`); the `math` module itself is not.
    return isclose(number_1, number_2, rel_tol=1e-3)
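
The comparison rule at the end of this cell is worth seeing on concrete values: two integers must match exactly, while anything involving a float is compared with a relative tolerance of `1e-3`. The snippet below is a self-contained sketch of that rule; the function name `compare` is ours.

```python
from math import isclose


def compare(n1, n2, rel_tol=1e-3):
    """Exact match for two ints, relative-tolerance match otherwise."""
    if isinstance(n1, int) and isinstance(n2, int):
        return n1 == n2
    return isclose(n1, n2, rel_tol=rel_tol)


print(compare(1000, 1001))      # False: both ints, so no tolerance applies
print(compare(1000.0, 1000.5))  # True: relative difference is about 5e-4
print(compare(0.3333, 1 / 3))   # True: relative difference is about 1e-4
print(compare(0.33, 1 / 3))     # False: relative difference is about 1e-2
```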

In [None]:
# @title latex_answer_check

# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Based on ToRA (https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py)
# Modified by Weiqi Wang
# ---------------------------------------------------------

"""
This cell defines the `latex_answer_check` function, which is used for the MATH test dataset.
"""


def latex_answer_check(model_ans, gt, split, extract_policy: str, eval_policy: str):
    # Step 1: Extract answer from response
    if split is not None:
        model_ans = extract_answer(model_ans, split, extract_policy=extract_policy)
    if model_ans is None:
        return False

    # Step 2: Remove boxes and perform literal check
    # Compare strings character by character after simple processing including remove $%.
    # First we remove the boxes in the string but keeps the content
    # \boxed{\frac{13}{4}} --> \frac{13}{4}
    model_ans_norm = string_normalization(model_ans)
    model_ans_norm_wo_boxes = remove_boxes_keep_content(model_ans_norm)
    gt_norm = string_normalization(gt)
    gt_norm_wo_boxes = remove_boxes_keep_content(gt_norm)

    literal_check_result = literal_check(
        remove_prefix_and_suffix(model_ans_norm_wo_boxes),
        remove_prefix_and_suffix(gt_norm_wo_boxes),
    )
    if literal_check_result is not None:
        return literal_check_result

    # Step 3: Attempt to parse -- single
    # Treat a string as a single number/extract a single number from a string and then compare.
    #
    # If we can accept a few mistakes, we try to extract numbers from the answers and compare them
    if eval_policy == "aggressive":
        # We want to use the raw model_ans to keep the $...$ delimiters
        # $13$ meters --> $13$ --> 13
        model_ans_num_lst = search_for_numbers(model_ans)

        # We want the ground truth wrapped in $...$ as well,
        # so that the answer is considered as a whole.
        # We don't want \frac{13}{4} --> [13, 4] to be considered as 2 numbers
        if gt[0] != "$" or gt[-1] != "$":
            gt_num_lst = search_for_numbers("$" + gt + "$")
        else:
            gt_num_lst = search_for_numbers(gt)

        # We want to judge only those answers that contain only one number that represents the full meaning of the original string.
        # If the string still has LaTeX components or variables in addition to this number, then we believe that this number may not represent the meaning of the answer.
        # Here we must be really really careful.
        # x \\leq -5 vs. x \\geq -5
        # (-\\infty, 5) vs. (5, +\\infty)
        # TODO: We may have better methods to check if the numbers are simple enough
        if (
            len(model_ans_num_lst) == 1
            and len(gt_num_lst) == 1
            and not has_structure(model_ans.replace(model_ans_num_lst[0], ""))
            and not has_structure(gt.replace(gt_num_lst[0], ""))
        ):
            model_num = remove_prefix_and_suffix(
                remove_boxes_keep_content(remove_text_box_only(model_ans_num_lst[0]))
            )
            gt_num = remove_prefix_and_suffix(
                remove_boxes_keep_content(remove_text_box_only(gt_num_lst[0]))
            )
            parse_result = number_check(model_num, gt_num)

            # As an additional method of judgment, even if it returns False we can't say that the answer is wrong, it could be caused by an unreasonable extraction of numbers
            if parse_result is True:
                return True

    # Here we do the same thing to the whole string
    model_wo_text = remove_prefix_and_suffix(model_ans_norm)
    gt_wo_text = remove_prefix_and_suffix(gt_norm)
    parse_result = number_check(model_wo_text, gt_wo_text)
    if parse_result is not None:
        return parse_result

    # If none of the above ways can determine whether the answer is correct or incorrect, then return incorrect
    return False


def has_numbers(input_string: str) -> bool:
    """
    Checks if a string contains a number.
    """
    return any(char.isdigit() for char in input_string)


def has_structure(input_string: str) -> bool:
    """
    Checks if a string contains structured content.
    """
    if (
        "(" in input_string
        or ")" in input_string
        or "[" in input_string
        or "]" in input_string
        or "\\" in input_string
        or "<" in input_string
        or ">" in input_string
        or "," in input_string
        or "x" in input_string
        or "y" in input_string
        or "z" in input_string
    ):
        return True
    return False


def sympy_parse(input_string: str) -> Any:
    """
    Parsing strings into mathematical expressions using sympy
    """
    for f in [parse_latex, parse_expr]:
        try:
            return f(input_string)
        except:
            pass
    return input_string


def symbolic_equal(a: str, b: str) -> Union[bool, None]:
    """
    Check if two strings are symbolic equal.
    """
    a = sympy_parse(a)
    b = sympy_parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except:
        pass

    try:
        # Both sides are plain numbers and `simplify` did not prove them equal.
        if isclose(N(a), float(N(a)), rel_tol=1e-9) and isclose(
            N(b), float(N(b)), rel_tol=1e-9
        ):
            return False
    except:
        pass

    try:
        if isclose(N(a), N(b), rel_tol=1e-3):
            return True
    except:
        pass
    return None


def convert_to_int(input_string: str) -> Union[int, None]:
    """
    Try to convert a string into int. Return `None` if an error occurs.
    """
    try:
        float_s = float(input_string)
        int_s = int(float_s)

        # If a floating-point number is converted to an integer that is very close to itself, then we consider it to be an integer.
        if isclose(int_s, float_s, rel_tol=1e-9):
            return int_s
        return None
    except:
        return None


def convert_to_float(input_string: str) -> Union[float, None]:
    """
    Try to convert a string into float. Return `None` if an error occurs.
    """
    try:
        float_s = float(input_string)
        return float_s
    except:
        return None


def numerical_equal(a: str, b: str) -> Union[bool, None]:
    """
    Check if two strings are numerical equal.
    """
    a_int = convert_to_int(a)
    b_int = convert_to_int(b)

    if a_int is not None and b_int is not None:
        return a_int == b_int

    a_float = convert_to_float(a)
    b_float = convert_to_float(b)

    if a_float is not None and b_float is not None:
        return isclose(a_float, b_float, rel_tol=1e-3)

    return None


def literal_check(model_generated_answer: str, ground_truth: str) -> Union[bool, None]:
    """
    Check if two strings are the same character by character
    """
    model_remove = (
        deepcopy(model_generated_answer)
        .replace(",", " ")
        .replace(" ", "")
        .replace(" ", "")
    )
    gt_remove = (
        deepcopy(ground_truth).replace(",", " ").replace(" ", "").replace(" ", "")
    )

    if model_remove == gt_remove:
        return True

    if (
        has_numbers(model_generated_answer) == False
        and has_numbers(ground_truth) == False
    ):
        model_generated_answer = model_remove.strip("[]() ")
        ground_truth = gt_remove.strip("[]() ")
        if model_generated_answer == ground_truth:
            return True

    return None


def number_check(model_generated_answer: str, ground_truth: str) -> Union[bool, None]:
    """
    Check if two strings have the same mathematical meaning.
    """
    if "," in model_generated_answer or "," in ground_truth:
        return None

    model_generated_answer = remove_prefix_and_suffix(
        remove_equals(model_generated_answer)
    )
    ground_truth = remove_prefix_and_suffix(remove_equals(ground_truth))

    numerical_equal_result = numerical_equal(model_generated_answer, ground_truth)
    if numerical_equal_result is not None:
        return numerical_equal_result

    symbolic_equal_result = symbolic_equal(model_generated_answer, ground_truth)

    if symbolic_equal_result is not None:
        return symbolic_equal_result

    return None
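
The checks in this cell share a three-valued convention: `True` or `False` when a check can decide, and `None` when it cannot, in which case the next check is tried and the final default is "incorrect". The sketch below illustrates that cascade with two deliberately simplified checks; `literal`, `numeric`, and `check` are hypothetical names and are much cruder than the functions above.

```python
from math import isclose
from typing import Callable, List, Optional

Check = Callable[[str, str], Optional[bool]]


def literal(a: str, b: str) -> Optional[bool]:
    # Only ever confirms a match; a mismatch is "undecided", not "wrong".
    return True if a.replace(" ", "") == b.replace(" ", "") else None


def numeric(a: str, b: str) -> Optional[bool]:
    try:
        return isclose(float(a), float(b), rel_tol=1e-3)
    except ValueError:
        return None  # not plain numbers: let a later check decide


def check(a: str, b: str, checks: List[Check]) -> bool:
    for fn in checks:
        verdict = fn(a, b)
        if verdict is not None:
            return verdict
    return False  # nothing could decide: count the answer as incorrect


checks = [literal, numeric]
print(check("x + 1", "x+1", checks))  # True, decided by the literal check
print(check("0.50", "0.5", checks))   # True, decided by the numeric check
print(check("2", "3", checks))        # False, decided by the numeric check
print(check("x", "y", checks))        # False, no check could decide
```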

In [None]:
# @title eval_for_single_item

# ---------------------------------------------------------
# Xwin-Math
# Copyright (c) 2023 Xwin-Math Team
# Licensed under The MIT License [see LICENSE for details]
# Based on ToRA (https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py)
# Modified by Weiqi Wang
# ---------------------------------------------------------


def eval_for_single_item(data_item, config):
    """
    Check the correctness for single item.

    Parameters:
        data_item: Data to be checked.
            data format:
            {
                "question": "Joan found 70 seashells on the beach . she gave Sam some of her seashells . She has 27 seashell . How many seashells did she give to Sam ? ",
                "type": "MAWPS",
                "subtype": "addsub",
                "answer": "43",
                "level": 0,
                "completion": [
                    "\n\nTo find out how many seashells Joan gave to Sam, we need to subtract the number of seashells she has left from the total number of seashells she found.\n\nJoan found 70 seashells and has 27 seashells left.\n\nSo, we subtract 27 from 70:\n\n70 - 27 = 43\n\nJoan gave 43 seashells to Sam.\n\nThe answer is: 43."
                ]
            }
        config: The policy and settings for evaluations. Please refer to eval/eval_config.json for more details.

    Returns:
        correctness: A list that contains the correctness for all prompts.
    """
    correctness = []
    for steps in data_item["completion"]:
        if data_item["type"] in config["policy"]:
            config_for_type = config["policy"][data_item["type"]]
        else:
            config_for_type = config["policy"]["default"]

        function = eval(config_for_type["function"])
        result = function(
            steps,
            data_item["answer"],
            eval_policy=config_for_type["eval_policy"],
            extract_policy=config_for_type["extract_policy"],
            split=config["extract_pattern"],
        )
        correctness.append(result)
    return correctness
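
The selection logic in `eval_for_single_item` amounts to a lookup by problem type with a fallback to the `default` policy. As an alternative sketch, the checker can be resolved through an explicit dictionary instead of `eval`. The two checker functions below are hypothetical stand-ins, not the real `simple_answer_check` and `latex_answer_check`.

```python
from typing import Callable, Dict


def exact_check(answer: str, truth: str) -> bool:
    """Hypothetical stand-in for `simple_answer_check`."""
    return answer.strip() == truth.strip()


def loose_check(answer: str, truth: str) -> bool:
    """Hypothetical stand-in for `latex_answer_check`."""
    return answer.replace(" ", "").lower() == truth.replace(" ", "").lower()


# Explicit registry: maps the names used in the config to callables.
CHECKERS: Dict[str, Callable[[str, str], bool]] = {
    "simple_answer_check": exact_check,
    "latex_answer_check": loose_check,
}

policy = {
    "GSM8K": {"function": "simple_answer_check"},
    "MATH": {"function": "latex_answer_check"},
    "default": {"function": "simple_answer_check"},
}


def pick_checker(item_type: str) -> Callable[[str, str], bool]:
    # Unknown types (e.g. "MAWPS") fall back to the default policy.
    entry = policy.get(item_type, policy["default"])
    return CHECKERS[entry["function"]]


print(pick_checker("MATH").__name__)
print(pick_checker("MAWPS").__name__)
```

A registry like this fails with a `KeyError` on an unknown function name, whereas `eval` would execute whatever string the config contains.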

<a name="dataset"></a>
### Load test dataset
The MATH test dataset is available at [https://github.com/meta-math/MetaMath/tree/main/data/test](https://github.com/meta-math/MetaMath/tree/main/data/test).

In [None]:
# Upload the dataset file to Google Colab before running this cell
math_data_list = load_json("/content/math.json")

# Set a seed for reproducibility
seed_value = 42
random.seed(seed_value)

# Randomly choose 500 questions from the MATH test dataset
math_500_data_list = random.sample(math_data_list, 500)
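
Seeding the generator before sampling is what makes the 500-question subset repeatable across runs. A minimal sketch, using a hypothetical list of integer ids in place of the loaded questions:

```python
import random

# Hypothetical stand-in for the loaded test set.
population = list(range(5000))

random.seed(42)
first = random.sample(population, 500)

random.seed(42)
second = random.sample(population, 500)

print(first == second)   # same seed, same subset in the same order
print(len(set(first)))   # sampling is without replacement
```

Note that the subset depends on the order of the input list as well as on the seed, so the data file has to be loaded the same way each time.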

<a name="accuracy"></a>
### Compute accuracy
We compute the accuracy of the generated answers to problems from the MATH test dataset.
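
The bookkeeping behind this computation is simple arithmetic, sketched here on its own. With 500 problems and a batch size of 4 there are 125 batches, and accuracy is the number of correct answers divided by the number of problems. The list of outcomes is synthetic: 35 correct answers are used because that is the count that corresponds to 7.00% of 500.

```python
import math

num_problems = 500
batch_size = 4

# Start index of each batch, as in `range(0, len(problems), batch_size)`.
batch_starts = list(range(0, num_problems, batch_size))
print(len(batch_starts))
print(math.ceil(num_problems / batch_size))

# Synthetic per-problem outcomes: 35 correct, 465 incorrect.
correctness = [True] * 35 + [False] * 465
accuracy = sum(correctness) / len(correctness) if correctness else 0
print(f"Accuracy: {accuracy * 100:.2f}%")
```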

In [None]:
"""
The following function and stopping criterion are used to compute the test accuracy.
"""


def generate_and_evaluate_batch(
    model, tokenizer, problems, batch_size=4, eval_config=None, temperature=0.9, top_p=1
):
    all_outputs = []
    correct_count = 0

    for i in tqdm(
        range(0, len(problems), batch_size), desc="Generating and Evaluating"
    ):
        batch_problems = problems[i : i + batch_size]

        inputs = tokenizer(
            [
                alpaca_prompt.format(problem["question"], "Let's think step by step")
                for problem in batch_problems
            ],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to("cuda")

        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            use_cache=True,
            temperature=temperature,
            top_p=top_p,
            stopping_criteria=[EosListStoppingCriteria()],
        )

        decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        for problem, generated_output in zip(batch_problems, decoded_outputs):
            problem["completion"] = [generated_output]

            evaluation_result = eval_for_single_item(problem, eval_config)
            if any(evaluation_result):
                correct_count += 1

            problem["correctness"] = evaluation_result

            all_outputs.append(problem)

    accuracy = correct_count / len(problems) if len(problems) > 0 else 0

    return all_outputs, accuracy


class EosListStoppingCriteria(StoppingCriteria):
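    def __init__(self, eos_sequence=(14711, 30151, 25)):
        # Use an immutable default and store a list, so that it can be
        # compared with the `tolist()` output in `__call__`.
        self.eos_sequence = list(eos_sequence)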

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        last_ids = input_ids[:, -len(self.eos_sequence) :].tolist()
        return self.eos_sequence in last_ids
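
The stopping test itself does not need `torch` to understand: it takes the last few token ids of every row in the batch and asks whether any of those tails equals the stop sequence. The sketch below reproduces that check on plain Python lists; `should_stop` is a hypothetical name and the token ids other than the stop sequence are arbitrary.

```python
from typing import List, Sequence


def should_stop(batch_token_ids: List[List[int]], stop_sequence: Sequence[int]) -> bool:
    """True if any row of the batch ends with `stop_sequence`."""
    stop = list(stop_sequence)
    tails = [row[-len(stop):] for row in batch_token_ids]
    return stop in tails


stop_sequence = [14711, 30151, 25]
batch = [
    [1, 2, 3, 4],
    [9, 14711, 30151, 25],
]

print(should_stop(batch, stop_sequence))           # True: the second row matches
print(should_stop([[1, 2, 3, 4]], stop_sequence))  # False: no row matches
```

The check yields a single boolean for the whole batch, so it becomes true as soon as one row ends with the sequence, regardless of the other rows.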

We now measure the accuracy on the 500 randomly selected questions from the MATH dataset.

In [None]:
# Testing our model on 500 questions from math.json

responses_batch, accuracy = generate_and_evaluate_batch(
    model, tokenizer, math_500_data_list, 4, config, temperature=0.1, top_p=0.4
)
print()
print()
print(f"Accuracy: {accuracy * 100:.2f}%")
print("Generated Responses:")
for response in responses_batch:
    print(response)




Accuracy: 7.00%
Generated Responses:
{'question': 'If $a$ and $b$ are real numbers, $a^2b^3=\\frac{32}{27}$, and $\\frac{a}{b^3}=\\frac{27}{4}$, what is $a+b$?', 'solution': 'Rearranging the second equation, we have that $b^3=\\frac{4}{27}a$. If we substitute this into the original equation, we get $\\frac{4}{27}a^3=\\frac{32}{27}$; after multiplying each side by $\\frac{27}{4}$ and taking the cube root, we see that $a=2$. Substituting $a$ into the first equation, we get that $b^3=\\frac{8}{27}$ or $b=\\frac23$. Thus, $a+b=2+\\frac23=\\boxed{\\frac83}$.', 'answer': '\\frac83', 'type': 'MATH', 'subtype': 'Algebra', 'level': 4, 'completion': ["Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nIf $a$ and $b$ are real numbers, $a^2b^3=\\frac{32}{27}$, and $\\frac{a}{b^3}=\\frac{27}{4}$, what is $a+b$?\n\n### Response:\nLet's think step by step, starting with th




In conclusion, the base model `unsloth/llama-3-8b-bnb-4bit` reaches an accuracy of 7.0% on 500 randomly sampled questions from the MATH test dataset.