# Load Model

In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoModel,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
)
import json
from pathlib import Path
import os, sys
import torch
from tqdm import tqdm
from datasets import load_dataset
import re
from fractions import Fraction
from decimal import Decimal, InvalidOperation
from typing import Optional, Sequence, Tuple, List
import pprint

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "Qwen/Qwen2.5-Math-7B-Instruct"  # "mistralai/Mathstral-7B-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:24<00:00,  6.07s/it]


# Utils

## Step Parser

In [63]:
import re
from typing import List, Optional, Tuple, Dict, Union, Iterable
import torch

def build_chat_messages(question: str,tokenizer,dataset: str, shots: Optional[List[tuple[str, str, str]]] = None,) -> str:
    system_prompt = (
        "You are an **expert mathematical‑reasoning assistant**.\n\n"
        "## Format rules\n"
        "1. Begin *every* reasoning line with the exact prefix `Step k:` where `k = 1, 2, …`. No other prefix is allowed.\n"
        "2. Show *all* intermediate calculations using standard symbols (×, ÷, ±, √).\n"
        "3. Conclude with **one** line of the form `Answer: <final numeric result>` and **stop immediately** - no explanations, no closing remarks.\n"
        "4. Each step must be concise *yet mathematically rigorous*.\n"
        "5. Avoid markdown bullet lists or narrative words such as ‘First’,  ‘Next’, ‘Finally’.\n\n"
        "Follow these rules exactly - evaluations are case- and format‑sensitive.\n"
        "Respond *only* in the specified format."
    )
    default_shots: List[tuple[str, str, str]] = [
        (
            "gsm8k, math",
            "Problem: What is the next number in the sequence 2, 4, 8, 16?",
            "Step 1: Identify the pattern – each term is multiplied by 2.\n"
            "Step 2: 16 × 2 = 32\n"
            "Answer: 32",
        ),
        (
            "gsm8k, math",
            "Problem: Solve for x: 3x + 7 = 22",
            "Step 1: Subtract 7 from both sides: 3x = 15\n"
            "Step 2: Divide by 3: x = 5\n"
            "Answer: 5",
        ),
        (
            "olympiad, omni",
            "Problem: Determine whether v₁ = [1,2] and v₂ = [3,6] are linearly independent.",
            "Step 1: Observe v₂ = 3 · v₁, so v₂ is a scalar multiple of v₁.\n"
            "Step 2: Therefore the vectors are linearly dependent.\n"
            "Answer: Dependent",
        ),
    ]

    if shots is None:
        shots = default_shots

    messages = [{"role": "system", "content": system_prompt}]
    for tag, q, a in shots:
        if dataset.lower() in tag.lower():
            messages.append({"role": "user", "content": q})
            messages.append({"role": "assistant", "content": a})

    messages.append({"role": "user", "content": f"Problem: {question}"})
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )

# class StepParser:
#     """Extract reasoning steps from raw LLM output – robust to messy formats."""
#     _STEP_RE      = re.compile(r"^\s*Step\s*(\d+)\s*[:\-]", re.I)
#     _ENUM_RE      = re.compile(r"^\s*(\d+)[.)]\s+")
#     _NARRATIVE_RE = re.compile(r"^\s*(First|Second|Third|Fourth|Fifth|Next|Then|After that|Therefore|However|Since|Thus|Finally|Lastly)\b[,:]?", re.I)

#     def _split(self, text: str) -> List[str]:
#         lines = text.splitlines()
#         blocks: List[str] = []
#         buf: List[str] = []

#         def flush():
#             if buf:
#                 blocks.append(" ".join(buf).strip())
#                 buf.clear()

#         for ln in lines:
#             if any(p.match(ln) for p in (self._STEP_RE, self._ENUM_RE, self._NARRATIVE_RE)):
#                 flush()
#             buf.append(ln.strip())
#         flush()
#         return [b for b in blocks if b]

#     # public method ------------------------------------------------------
#     def parse(self, text: str) -> List[str]:
#         steps = self._split(text)
#         if len(steps) >= 2:
#             return steps
#         # first fallback: coarse split on blank lines (\n\n)
#         paras = re.split(r"\n\s*\n", text)
#         steps = [p.strip() for p in paras if p.strip()]
#         if len(steps) >= 2:
#             return steps
#         # second fallback: line‑by‑line
#         return [ln.strip() for ln in text.splitlines() if ln.strip()]

class StepParser:
    # 1) "Step n:" / "Step One:" / "Step IV‑" …
    _STEP_RE = re.compile(
        r"^\s*Step\s*(?:\d+|[IVXLCDM]+|One|Two|Three|Four|Five|Six|Seven|Eight|Nine|Ten)\b\s*[:\-.]",
        re.I,)
    # 2) 1. / 1) / 1‑   OR  I. / II) / III‑
    _ENUM_RE = re.compile(r"^\s*(?:\d+|[IVXLCDM]+)[.)-]\s+", re.I)
    # 3) Narrative adverbs
    _NARRATIVE_RE = re.compile(
        r"^\s*(First|Firstly|Second|Secondly|Third|Thirdly|Fourth|Fifth|Sixth|Seventh|Eighth|Ninth|Tenth|Next|Then|After that|Therefore|However|Since|Thus|Finally|Lastly|In conclusion)\b[,:]?",
        re.I,)
    # 4) Bullet mark (dash/asterisk/dot) + space
    _BULLET_RE = re.compile(r"^\s*[-*•]\s+")
    _JUNK_RE = re.compile(r"^(?:<\|endoftext\|>\s*)+")
    _DELIMITERS = (_STEP_RE, _ENUM_RE, _NARRATIVE_RE, _BULLET_RE)

    # ------------------------------------------------------------------
    @classmethod
    def _is_delimiter(cls, line: str) -> bool:
        return any(r.match(line) for r in cls._DELIMITERS)
    
    @classmethod
    def _clean_line(cls, line: str) -> str:
        """Strip junk tokens and surrounding whitespace from *line*."""
        return cls._JUNK_RE.sub("", line).strip()

    @classmethod
    def _split(cls, text: str) -> List[str]:
        """Split *text* into tentative step blocks using delimiter lines."""
        blocks: List[str] = []
        buf: List[str] = []

        def flush():
            if buf:
                blocks.append(" ".join(buf).strip())
                buf.clear()

        for raw_ln in text.splitlines():
            ln = cls._clean_line(raw_ln)
            if cls._is_delimiter(ln):
                flush()
            buf.append(ln)
        flush()
        return [b for b in blocks if b]
    
    @classmethod
    def _strip_leading_marker(cls, text: str) -> str:
        """Remove one leading delimiter + junk tokens from *text* if present."""
        txt = cls._clean_line(text)
        for p in cls._DELIMITERS:
            m = p.match(txt)
            if m:
                return txt[m.end():].lstrip()
        return txt

    @classmethod
    def parse(cls, text: str) -> List[str]:
        """Return **raw** steps (leading markers still present)."""
        steps = cls._split(text)
        if len(steps) >= 2:
            return steps
        # Fallback ①: blank‑line split
        sans_junk = "\n".join(cls._clean_line(l) for l in text.splitlines())
        paras = re.split(r"\n\s*\n", sans_junk)
        steps = [p.strip() for p in paras if p.strip()]
        if len(steps) >= 2:
            return steps
        # Fallback ②: line‑by‑line (already cleaned)
        return [ln for ln in sans_junk.splitlines() if ln.strip()]
    
    @classmethod
    def parse_clean(cls, obj: Union[str, Iterable[str]]) -> List[str]:
        """Return list of *clean* steps (leading markers removed)."""
        if isinstance(obj, str):
            raw_steps = cls.parse(obj)
        else:
            raw_steps = list(obj)
        return [cls._strip_leading_marker(s) for s in raw_steps]


## Answer Extraction

In [30]:
from datasets import load_dataset
import re, ast
from fractions import Fraction
from decimal import Decimal, InvalidOperation
from typing import Optional, Sequence, Tuple, List
import pandas as pd

class AnswerExtractor:
    """
    Robust gold/ pred-answer extractor for GSM8K · Math · Omni · OlympiadBench.
    """
    def __init__(self):
        self.answer_keywords = (
            "answer", "final answer", "therefore", "result",
            "the final answer is"
        )
        self.number_patterns = [
            r'(\d+\.\d+)', r'(\d+\/\d+)', r'(\d+)',
            r'(\d+\+\d+)', r'(\d+\-\d+)', r'(\d+\*\d+)'
        ]

    def extract_gold_answer(self, text: str,dataset: str | None = None) -> Optional[str]:
        if text is None:
            return None
        if isinstance(text, (list, tuple)):
            return self._flatten_and_clean(text)

        ds  = (dataset or "").lower()
        txt = str(text).strip()
        # 1) GSM8K  #### 정답
        if ds == "gsm8k" or re.search(r"\n####\s*[^\n]+", txt):
            m = re.search(r"\n####\s*([^\n]+)", txt)
            return self._strip(m.group(1)) if m else None
        # 2) OlympiadBench  [ '$…$', '$…$' ]
        if ds =="olympiad" or ( txt.startswith('[') and txt.endswith(']')):
            try:
                parsed = ast.literal_eval(txt)
                if isinstance(parsed, (list, tuple)):
                    return self._flatten_and_clean(parsed)
                return self._strip(str(parsed))
            except (SyntaxError, ValueError):
                return self._strip(txt)
        # 3) Omni  (그대로)
        if ds == "omni":
            return self._strip(txt)
        # 4) Math 또는 그 밖 → 공통 pred-extract 로 처리
        return self.extract_pred_answer(txt)

    def extract_pred_answer(self, text: str) -> Optional[str]:
        # 1) 마지막 balanced \boxed{…} / \fbox{…}
        boxed = self._extract_last_boxed(text)
        if boxed:
            return self._strip(boxed)
        # 2) Answer: … / Therefore: … (다음 줄까지 포함)
        ans_line = self._extract_answer_line(text)
        if ans_line:
            return self._strip(ans_line)
        # 3) 마지막 줄에서 inline LaTeX
        last_expr = self._extract_last_latex(text)
        if last_expr:
            return self._strip(last_expr)
        # 4) 마지막 숫자/수식
        num = self._extract_last_number(text)
        return self._strip(num) if num else None

    # ─────────────────── Internals ───────────────────
    def _extract_last_boxed(self, text: str) -> Optional[str]:
        start_pat = re.compile(r'(\\boxed|\\fbox)\s*\{')
        starts = list(start_pat.finditer(text))
        if not starts:
            return None
        i = starts[-1].end()
        depth = 1
        while i < len(text) and depth:
            ch = text[i]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
            i += 1
        return text[starts[-1].end(): i-1].strip() if depth == 0 else None

    def _extract_answer_line(self, text: str) -> Optional[str]:
        lines = text.splitlines()
        for i, ln in enumerate(lines):
            low = ln.lower()
            if any(k in low for k in self.answer_keywords):
                # 콜론 뒤 같은 줄?
                m = re.search(r'[:\-]\s*(.+)$', ln)
                if m and m.group(1).strip():
                    return m.group(1).strip()
                # 아니면 다음 비어있지 않은 줄
                j = i + 1
                while j < len(lines) and not lines[j].strip():
                    j += 1
                if j < len(lines):
                    return lines[j].strip()
        return None

    def _extract_last_latex(self, text: str) -> Optional[str]:
        for ln in reversed(text.splitlines()):
            ln = ln.strip()
            if not ln:
                continue
            # $ … $  \[…\]  \(…\)
            m = re.findall(r'\$(.*?)\$|\\\((.*?)\\\)|\\\[(.*?)\\\]', ln)
            if m:
                return [seg for seg in m[-1] if seg][0]
            # bare \frac  \sqrt …
            m2 = re.search(r'(\\[a-zA-Z]+(?:\{[^{}]+\})+)', ln)
            if m2:
                return m2.group(1)
            break
        return None

    def _extract_last_number(self, text: str) -> Optional[str]:
        for ln in reversed(text.splitlines()):
            ln = ln.strip()
            if not ln:
                continue
            for pat in self.number_patterns:
                m = re.findall(pat, ln)
                if m:
                    return m[-1]
        return None

    def _strip(self, s: str | None) -> str | None:
        if s is None:
            return None
        s = s.strip()
        s = re.sub(r'^\$+\s*', '', s)            # leading $
        s = re.sub(r'\s*\$+$', '', s)            # trailing $
        s = re.sub(r'^\\\(|\\\)$', '', s)        # \( … \)
        s = re.sub(r'^\\\[|\\\]$', '', s)        # \[ … \]
        s = s.replace('\\\\', '\\')              # \\ → \
        return s.strip(" ,;:")

    def _flatten_and_clean(self, seq: List[str]) -> str:
        return ", ".join(self._strip(x) for x in seq)


In [None]:
extractor = AnswerExtractor()
test_cases = [
    # GSM8K
    ("gsm8k", "Prince fought bravely.\n#### 30 ", "30"),
    ("gsm8k", "Problem...\n#### $-1$ \n", "-1"),
    # Math boxed
    ("math", r"Compute ... thus we have \boxed{\frac{8t}{3}}.", r"\frac{8t}{3}"),
    # Math with Answer:
    ("math", "We conclude.\nAnswer: 52", "52"),
    # Math inline latex
    ("math", "Finally the value is $2\\sqrt{3}$", r"2\sqrt{3}"),
    # Math only number at end
    ("math", "Hence the result is\n40", "40"),
    # Omni simple
    ("omni", "2500", "2500"),
    # Omni latex
    ("omni", r"\frac{5}{6}", r"\frac{5}{6}"),
    # Olympiad list
    ("olympiad", "['$2.0 \\times 10^{6}$', '$6.5 \\times 10^{8}$']", r"2.0 \times 10^{6}, 6.5 \times 10^{8}"),
    # Olympiad single
    ("olympiad", "['$7$']", "7"),
    # Raw python list
    ("", ["42"], "42"),
    # Pred only latex
    ("pred", r"$\frac{1}{27}$", r"\frac{1}{27}"),
]

rows = []
for ds, text, expected in test_cases:
    if ds == "pred":
        got = extractor.extract_pred_answer(text)
    else:
        got = extractor.extract_gold_answer(text, ds if ds else None)
    rows.append({"dataset": ds, "input": text, "expected": expected, "extracted": got, "pass": expected == got})

df = pd.DataFrame(rows)

print("AnswerExtractor test results")
print("=" * 80)
print(df.to_string(index=False))
print("=" * 80)
print(f"\nSummary: {df['pass'].sum()}/{len(df)} tests passed ({df['pass'].mean()*100:.1f}%)")

failed = df[~df['pass']]
if not failed.empty:
    print("\nFailed Cases:")
    print(failed[['dataset', 'input', 'expected', 'extracted']].to_string(index=False))

## Answer Match

In [None]:
from answer_matcher import MathAnswerScorer
import pandas as pd
import pprint

scorer = MathAnswerScorer()
tests = [
    ("1/2", r"\frac{1}{2}", True),
    ("0.5", "1/2", True),
    ("50%", "0.5", True),
    (r"2,3", r"3,2", True),
    (r"(0,1]", r"(0,1]", True),
    (r"{1,2,3}", r"{3,2,1}", True),
    ("30°", "30 °", True),
    ("x+1=0", "x=-1", True),
    (r"\sin(\pi/4)", r"\frac{\sqrt{2}}{2}", True),
    (r"\begin{pmatrix}1&2\\3&4\end{pmatrix}", "{{1,2},{3,4}}", True),
    ("3+4i", "3+4j", True),
    ("1000 m", "1 km", True),
    ("Answer: C", "C", True),
    ("Yes", "yes.", True),
    ("1/2", "2/3", False),
]

results = []
for i,(pred,gold,expect) in enumerate(tests,1):
    try:
        res = scorer.answers_match(pred,gold)
    except Exception as e:
        res = f"Error: {e}"
    results.append({"#":i,"pred":pred,"gold":gold,"expected":expect,"result":res,"pass":res==expect})
df = pd.DataFrame(results)

print("Math Answer Scorer Tests")
print("=" * 80)
print(df.to_string(index=False))
print("=" * 80)

total_tests = len(df)
passed_tests = df['pass'].sum()
print(f"\nSummary: {passed_tests}/{total_tests} tests passed ({passed_tests/total_tests*100:.1f}%)")

failed_tests = df[~df['pass']]
if not failed_tests.empty:
    print("\nFailed Tests:")
    print(failed_tests[['#', 'pred', 'gold', 'expected', 'result']].to_string(index=False))

# Load Datasets

In [1]:
from datasets import get_dataset_config_names, load_dataset, concatenate_datasets

gsm8k = load_dataset("openai/gsm8k", "main", split="test")  # train/test, ['question', 'answer']
math = load_dataset("HuggingFaceTB/MATH", "all", split="test")    # train/test/fewshot, ['problem', 'level', 'type', 'solution']
omni = load_dataset("KbsdJames/Omni-MATH", split = "test")  # only test (4428), ['domain', 'difficulty', 'problem', 'solution', 'answer', 'source']
def load_olympiadbench_english(split: str = "train"):
    all_cfgs = get_dataset_config_names("Hothan/OlympiadBench")
    en_cfgs = [cfg for cfg in all_cfgs if "_en_" in cfg or cfg.endswith("_en")]
    print(f"English configs found: {en_cfgs}")
    ds_list = []
    for cfg in en_cfgs:
        try:
            ds = load_dataset("Hothan/OlympiadBench", cfg, split=split)
            ds_list.append(ds)
        except Exception as e:
            print(f"⚠️  {cfg} 로드 실패: {e}")
    if len(ds_list) == 0:
        raise ValueError("Fail to load English configs")
    full_ds = concatenate_datasets(ds_list)
    return full_ds

olymbench = load_olympiadbench_english("train") # ['id', 'question', 'solution', 'final_answer', 'context', 'image_1', 'image_2', 'image_3', 'image_4', 'image_5', 'image_6', 'image_7', 'image_8', 'image_9', 'modality', 'difficulty', 'is_multiple_answer', 'unit', 'answer_type', 'error', 'question_type', 'subfield', 'subject', 'language'] 


English configs found: ['OE_MM_maths_en_COMP', 'OE_MM_physics_en_COMP', 'OE_TO_maths_en_COMP', 'OE_TO_physics_en_COMP', 'TP_MM_maths_en_COMP', 'TP_MM_physics_en_COMP', 'TP_TO_maths_en_COMP', 'TP_TO_physics_en_COMP']


In [8]:
olymbench[0]['solution']

['The largest possible $C$ is $C=\\frac{1}{2}$.\n\nFor $0<C \\leqslant \\frac{1}{2}$, Turbo can simply choose an arbitrary point $P$ (different from its starting point) to avoid. When Turbo is at an arbitrary point $A$ different from $P$, the two arcs $A P$ have total length 1; therefore, the larger of the two the arcs (or either arc in case $A$ is diametrically opposite to $P$ ) must have length $\\geqslant \\frac{1}{2}$. By always choosing this larger arc (or either arc in case $A$ is diametrically opposite to $P$ ), Turbo will manage to avoid the point $P$ forever.\n\nFor $C>\\frac{1}{2}$, we write $C=\\frac{1}{2}+a$ with $a>0$, and we choose the sequence\n\n$$\n\\frac{1}{2}, \\quad \\frac{1+a}{2}, \\quad \\frac{1}{2}, \\quad \\frac{1+a}{2}, \\quad \\frac{1}{2}, \\ldots\n$$\n\nIn other words, $c_{i}=\\frac{1}{2}$ if $i$ is odd and $c_{i}=\\frac{1+a}{2}<C$ when $i$ is even. We claim Turbo must eventually visit all points on the circle. This is clear when it crawls in the same directi

In [15]:
steps = omni[3]['solution'].split(". ")
print(len(steps))
steps[1]

13


'Let \\(a_i = 10000 + i\\epsilon\\) for \\(i = 0, 1, \\ldots, 19\\) and some small \\(\\epsilon > 0\\)'

In [11]:
idx = 42
dataset = "gsm8k"

# gsm8k
gs_question=gsm8k[idx]['question']
gs_answer = gsm8k[idx]['answer']
print("====== GSM8K ======")
print("Problem:", gs_question)
print("Gold Answer:", gs_answer, "\n")

# math
math_question = math[idx]['problem']
math_answer = math[idx]['solution']
print("====== Math ======")
print("Problem:", math_question)
print("Gold Answer:", math_answer, "\n")

# omni
omni_question = omni[idx]['problem']
omni_solution = omni[idx]['solution']
omni_answer = omni[idx]['answer']
print("====== Omni ======")
print("Problem:", omni_question)
print("Solution:", omni_solution)
print("Gold Answer:", omni_answer, "\n")

# olympiad
olm_question = olymbench[idx]['question']
olm_solution = olymbench[idx]['solution']
olm_answer = olymbench[idx]['final_answer']
print("====== Olympiad ======")
print("Problem:", olm_question)
print("Solution:", olm_solution)
print("Gold Answer:", olm_answer)

Problem: Grandma Jones baked 5 apple pies for the fireman's luncheon.  She cut each pie into 8 pieces and set the five pies out on the buffet table for the guests to serve themselves.  At the end of the evening, after the guests had taken and eaten their pieces of pie, there were 14 pieces of pie remaining.  How many pieces were taken by the guests?
Gold Answer: To start the evening, there were 5 pies, each with 8 pieces, which is 5*8=<<5*8=40>>40 pieces of pie.
If only 14 remained, then 40-14=<<40-14=26>>26 pieces of pie had been taken by guests.
#### 26 

Problem: Find the sum of all values of $x$ such that $|x-1| = 7$.
Gold Answer: We must have either $x-1 = 7$ or $x-1=-7$.  If $x-1=7$, we have $x=8$, and if $x-1 = -7$, we have $x= -6$, so the sum of the possible values of $x$ is $8+(-6) = \boxed{2}$. 

Problem: Call a sequence of positive integers $\{a_n\}$ good if for any distinct positive integers $m,n$, one has 
$$\gcd(m,n) \mid a_m^2 + a_n^2 \text{ and } \gcd(a_m,a_n) \mid m^2 

## GSM8K

In [12]:
conv = build_chat_messages(gs_question, tokenizer, "gsm8k")
inputs = tokenizer(conv, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

gen_kwargs = dict(
    max_new_tokens=400,
    temperature=0.3,
    top_p=0.9,
    do_sample=False, 
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

with torch.no_grad():
    output = model.generate(**inputs, **gen_kwargs)

decoded_full = tokenizer.decode(output[0], skip_special_tokens=False)
assistant_text = decoded_full.split("<|im_start|>assistant")[-1]
assistant_text = assistant_text.replace("<|im_end|>", "").strip()
print("Generated Sequence", assistant_text)
assistant_text = assistant_text.split("Answer:")[-1]

gs_pred_ans = extractor.extract_pred_answer(assistant_text)
gs_gold_ans = extractor.extract_gold_answer(gs_answer, "gsm8k")
print("Pred Ans:", gs_pred_ans)
print("Gold Ans:", gs_gold_ans)

gs_corr = scorer.answers_match(gs_pred_ans, gs_gold_ans)
print(gs_corr)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated Sequence Step 1: Calculate the total number of pie pieces: 5 pies × 8 pieces per pie = 40 pieces
Step 2: Determine the number of pieces taken by the guests: 40 pieces - 14 pieces remaining = 26 pieces
Answer: 26
Pred Ans: 26
Gold Ans: 26
True


## MATH

In [13]:
conv = build_chat_messages(math_question, tokenizer, "math")
inputs = tokenizer(conv, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

gen_kwargs = dict(
    max_new_tokens=1024,
    temperature=0.3,
    top_p=0.9,
    do_sample=False, 
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

with torch.no_grad():
    output = model.generate(**inputs, **gen_kwargs)

decoded_full = tokenizer.decode(output[0], skip_special_tokens=False)
assistant_text = decoded_full.split("<|im_start|>assistant")[-1]
assistant_text = assistant_text.replace("<|im_end|>", "").strip()
print("Generated Sequence", assistant_text)
assistant_text = assistant_text.split("Answer:")[-1]

math_pred_ans = extractor.extract_pred_answer(assistant_text)
math_gold_ans = extractor.extract_gold_answer(math_answer, "math")
print("Pred Ans:", math_pred_ans)
print("Gold Ans:", math_gold_ans)

math_corr = scorer.answers_match(math_pred_ans, math_gold_ans)
print(math_corr)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated Sequence Step 1: Solve the absolute value equation by considering two cases.
Case 1: \(x - 1 = 7\)
Step 2: Add 1 to both sides: \(x = 8\)
Case 2: \(x - 1 = -7\)
Step 3: Add 1 to both sides: \(x = -6\)
Step 4: Add the solutions: \(8 + (-6) = 2\)
 Answer: 2
Pred Ans: 2
Gold Ans: 2
True


## OmniMath

In [29]:
conv = build_chat_messages(omni_question, tokenizer, "omni")
inputs = tokenizer(conv, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

gen_kwargs = dict(
    max_new_tokens=1300,
    temperature=0.3,
    top_p=0.9,
    do_sample=False, 
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

with torch.no_grad():
    output = model.generate(**inputs, **gen_kwargs)

decoded_full = tokenizer.decode(output[0], skip_special_tokens=False)
assistant_text = decoded_full.split("<|im_start|>assistant")[-1]
assistant_text = assistant_text.replace("<|im_end|>", "").strip()
print("Generated Sequence", assistant_text)
assistant_text = assistant_text.split("Answer:")[-1]
print("Parsed Sequence", assistant_text)

omni_pred_ans = extractor.extract_pred_answer(assistant_text)
omni_gold_ans = extractor.extract_gold_answer(omni_answer, "omni")
print("Pred Ans:", omni_pred_ans)
print("Gold Ans:", omni_gold_ans)

omni_corr = scorer.answers_match(omni_pred_ans, omni_gold_ans)
print(omni_corr)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated Sequence To determine if there exists a \( k \) such that there are exactly 2019 \( k \)-good positive integers, we need to analyze the properties of the good sequence \(\{a_n\}\).

First, let's consider the conditions given for a good sequence:
1. For any distinct positive integers \( m \) and \( n \), \(\gcd(m, n) \mid a_m^2 + a_n^2\).
2. For any distinct positive integers \( m \) and \( n \), \(\gcd(a_m, a_n) \mid m^2 + n^2\).

We start by examining the simplest case where \( m = 1 \) and \( n = 2 \). Let \( a_1 = a \) and \( a_2 = b \). Then the conditions become:
1. \(\gcd(1, 2) \mid a^2 + b^2 \implies 1 \mid a^2 + b^2\), which is always true.
2. \(\gcd(a, b) \mid 1^2 + 2^2 \implies \gcd(a, b) \mid 5\).

The possible values for \(\gcd(a, b)\) are 1 and 5. This means that \( a \) and \( b \) must be coprime or both multiples of 5.

Next, let's consider the general case. Suppose \( a_n = c \) for all \( n \). Then for any distinct positive integers \( m \) and \( n \):
1. 

## OlympiadBench

In [30]:
conv = build_chat_messages(olm_question, tokenizer, "olympiad")
inputs = tokenizer(conv, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

gen_kwargs = dict(
    max_new_tokens=2400,
    temperature=0.3,
    top_p=0.9,
    do_sample=False, 
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

with torch.no_grad():
    output = model.generate(**inputs, **gen_kwargs)

decoded_full = tokenizer.decode(output[0], skip_special_tokens=False)
assistant_text = decoded_full.split("<|im_start|>assistant")[-1]
assistant_text = assistant_text.replace("<|im_end|>", "").strip()
print("Generated Sequence", assistant_text)
assistant_text = assistant_text.split("Answer:")[-1]

olm_pred_ans = extractor.extract_pred_answer(assistant_text)
olm_gold_ans = extractor.extract_gold_answer(olm_answer, "olympiad")
print("Pred Ans:", olm_pred_ans)
print("Gold Ans:", olm_gold_ans)

olm_corr = scorer.answers_match(olm_pred_ans, olm_gold_ans)
print(olm_corr)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated Sequence To determine the equation of the line through points \(A\) and \(B\), we need to find the slope of the line and then use the point-slope form of the equation of a line.

First, let's identify the coordinates of points \(A\) and \(B\). From the diagram, we can see that point \(O\) is the origin \((0,0)\), point \(A\) is \((15,0)\), and point \(B\) is \((19,4)\).

Next, we calculate the slope \(m\) of the line passing through points \(A\) and \(B\). The formula for the slope between two points \((x_1, y_1)\) and \((x_2, y_2)\) is:
\[
m = \frac{y_2 - y_1}{x_2 - x_1}
\]
Substituting the coordinates of points \(A\) and \(B\):
\[
m = \frac{4 - 0}{19 - 15} = \frac{4}{4} = 1
\]
So, the slope of the line is \(1\).

Now, we use the point-slope form of the equation of a line, which is:
\[
y - y_1 = m(x - x_1)
\]
We can use point \(A\) \((15,0)\) and the slope \(m = 1\):
\[
y - 0 = 1(x - 15)
\]
Simplifying this, we get:
\[
y = x - 15
\]
Therefore, the equation of the line throug

: 

## AIME 2024

In [None]:
# aime = load_dataset("HuggingFaceH4/aime_2024", split="train")
aime2224 = load_dataset("AI-MO/aimo-validation-aime", split="train") # [id, problem, solution, answer, url]

Generating train split: 100%|██████████| 90/90 [00:00<00:00, 9046.16 examples/s]


# Inference in batch

In [66]:
from torch.utils.data import DataLoader

FIELD_MAP: Dict[str, Tuple[str, str]] = { # dataset name  : (question_field, answer_field)
    "gsm8k": ("question", "answer"),
    "math": ("problem", "solution"),
    "omni": ("problem", "answer"),
    "olympiad": ("question", "final_answer"),
}

def get_loader(ds_name: str, split: str, batch_size: int):
    """Return torch DataLoader yielding (idx, question, gold_answer)."""
    if ds_name == "math":
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)
    elif ds_name == "gsm8k":
        ds = load_dataset("openai/gsm8k", "main", split=split)
    elif ds_name == "omni":
        ds = load_dataset("KbsdJames/Omni-MATH", split=split)
    elif ds_name == "olympiad":
        ds = load_olympiadbench_english(split)
    else:
        raise ValueError(f"Unsupported dataset {ds_name}")

    q_key, a_key = FIELD_MAP[ds_name]
    def collate(indices):
        items = [ds[i] for i in indices]
        idxs = indices
        qs = [item[q_key] for item in items]
        golds = [item[a_key] for item in items]
        return idxs, qs, golds

    return DataLoader(range(len(ds)), batch_size=batch_size, shuffle=False, collate_fn=collate), len(ds)

def batched_generate(questions: List[str],tokenizer, model, ds_name: str, gen_kwargs: Dict, parser: StepParser,) -> List[str]:
    """Greedy decode a batch → return raw answer strings (post‑`Answer:`)."""
    prompts = [build_chat_messages(q, tokenizer, ds_name) for q in questions]
    tokenized = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    with torch.no_grad():
        out = model.generate(**tokenized, **gen_kwargs)
    decoded = tokenizer.batch_decode(out, skip_special_tokens=False)
    answers: List[str] = []
    for full in decoded:
        # isolate assistant chunk
        part = full.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
        print("Original Generation", part)
        steps = parser.parse(part)
        cleaned_steps = parser.parse_clean(steps)
        print("Parsed Steps:", len(cleaned_steps))
        print(cleaned_steps)
        try: ans = cleaned_steps[-1].split("Answer:")[-1].strip()
        except: ans = part.split("Answer:")[-1].strip()
        answers.append(ans)
    return answers


In [67]:
import math 

dataset = "olympiad"
loader, total = get_loader(dataset, "test" if dataset != "olympiad" else "train", 3)
max_n = 2

extractor = AnswerExtractor()
scorer = MathAnswerScorer()
parser_obj = StepParser()
gen_kwargs = dict(max_new_tokens=2049, temperature=0.2, top_p=0.9, do_sample=False, pad_token_id=tokenizer.eos_token_id) # greedy decoding

## Evaluation loop ##
correct = 0
seen = 0
pbar = tqdm(loader, total= math.ceil(max_n / 3))
for idxs, qs, golds in pbar:
    # limit sample count
    if seen >= max_n:
        break
    # trim overflow inside batch
    lim = min(len(qs), max_n - seen)
    qs, golds, idxs = qs[:lim], golds[:lim], idxs[:lim]

    preds = batched_generate(qs, tokenizer, model, dataset, gen_kwargs, parser_obj)
    # extract structured answers
    pred_ans = [extractor.extract_pred_answer(p) for p in preds]
    gold_ans = [extractor.extract_gold_answer(g, dataset) for g in golds]
    print("Parsed Answers:", pred_ans)
    print("Parsed Golds:", gold_ans)

    batch_corr = [scorer.answers_match(pa, ga) for pa, ga in zip(pred_ans, gold_ans)]
    correct += sum(batch_corr)
    seen += lim
    pbar.set_postfix(acc=f"{correct/seen:.3%}")

print("\n===========================================")
print(f"Dataset       : {dataset}")
print(f"Samples seen  : {seen}")
print(f"Correct       : {correct}")
print(f"Accuracy      : {correct/seen:.3%}")

English configs found: ['OE_MM_maths_en_COMP', 'OE_MM_physics_en_COMP', 'OE_TO_maths_en_COMP', 'OE_TO_physics_en_COMP', 'TP_MM_maths_en_COMP', 'TP_MM_physics_en_COMP', 'TP_TO_maths_en_COMP', 'TP_TO_physics_en_COMP']


  0%|          | 0/1 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 1/1 [00:33<00:00, 33.32s/it, acc=50.000%]

Original Generation To determine the largest constant \( C > 0 \) such that Turbo can ensure that there is some point on the circle that it will never visit or crawl across, we need to analyze the behavior of Turbo's crawling pattern.

First, let's consider the total distance Turbo crawls. If \( c_i < C \) for all \( i \), then the total distance \( D \) Turbo crawls is:
\[
D = c_1 + c_2 + c_3 + \cdots
\]
Since each \( c_i < C \), we have:
\[
D < C + C + C + \cdots = C \cdot \infty = \infty
\]
However, we need to find a specific value of \( C \) such that Turbo can ensure there is a point on the circle that it will never visit or crawl across.

To do this, let's consider the circumference of the circle, which is 1. If \( C \leq \frac{1}{2} \), then each \( c_i < \frac{1}{2} \). This means that Turbo's total distance \( D \) will always be less than \( \frac{1}{2} \times \infty = \infty \), but more importantly, it will always be less than 1.

If \( D < 1 \), then Turbo's total distance


