In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoModel,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
)
import json
from pathlib import Path
import os, sys
import torch
from tqdm import tqdm
from datasets import load_dataset
import re
from fractions import Fraction
from decimal import Decimal, InvalidOperation
from typing import Optional, Sequence, Tuple, List
import pprint

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "Qwen/Qwen2.5-Math-7B-Instruct"  # "mistralai/Mathstral-7B-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 4 files: 100%|██████████| 4/4 [07:33<00:00, 113.33s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


In [4]:
prm_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-PRM-7B", trust_remote_code=True)
prm = AutoModel.from_pretrained(
    model_name, 
    quantization_config=bnb_config,
    device_map="auto", 
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [14]:
import re, math, random
from fractions import Fraction
from typing import List, Dict, Any, Tuple

SYSTEM_PROMPT = (
    "You are Qwen-Math, a meticulous math tutor. "
    "Solve problems with clear, numbered steps and give only ONE final answer line."
)
FEWSHOT_EXAMPLES: List[Dict[str, Any]] = []

def build_user_prompt(question: str) -> str:
    header = (
        "Solve the following problem. Follow this EXACT format:\n"
        "Step 1: <short reasoning>\n"
        "Step 2: <short reasoning>\n"
        "...\n"
        "Answer: <final numeric answer>\n"
        "Constraints:\n"
        "- Keep each step concise, one idea per line.\n"
        "- Do not include extra explanations after the final Answer line.\n"
    )
    parts = [header]
    if FEWSHOT_EXAMPLES:
        parts.append("Here are solved examples:\n")
        for ex in FEWSHOT_EXAMPLES:
            parts.append(format_fewshot_block(ex))
        parts.append("\nNow solve this new problem in the same format.\n")

    parts.append("Problem: " + question.strip())
    return "\n".join(parts).strip()

def to_chat_prompt(question: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user",   "content": build_user_prompt(question)},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

def generate_solutions(question: str, n: int = 1, temperature: float = 0.2, top_p: float = 0.9, max_new_tokens: int = 512, seed: int = 123) -> List[str]:
    assert n >= 1
    do_sample = (n > 1) or (temperature and temperature > 1e-8)

    prompt = to_chat_prompt(question)
    inputs = tokenizer([prompt], return_tensors="pt", padding=True).to(model.device)

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    gen_cfg = GenerationConfig(
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        num_return_sequences=n,
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    with torch.no_grad():
        out = model.generate(**inputs, **gen_cfg.to_dict())

    # decode only the generated part
    gen_only = out[:, inputs["input_ids"].shape[1]:]
    texts = tokenizer.batch_decode(gen_only, skip_special_tokens=True)
    return [t.strip() for t in texts]


# gsm8k grader

In [17]:
from latex2sympy2 import latex2sympy, latex2latex, set_real, variances, var, set_variances
import re, math
from fractions import Fraction
from decimal import Decimal, InvalidOperation
from typing import Optional, Tuple, List

# extract pred answer
ANS_LINE = re.compile(r"^\s*answer\s*:\s*(.+?)\s*$", re.IGNORECASE | re.MULTILINE)
_BOXED = re.compile(r"\\boxed\{([^}]*)\}")
NUM_FINDER = re.compile(r"[-+]?\d+(?:\.\d+)?")
_MIXED_FRAC = re.compile(r"[-+]?\d+\s+\d+/\d+")
_PURE_FRAC  = re.compile(r"[-+]?\d+/\d+")
_DEC_SCI    = re.compile(r"[-+]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?(?:[eE][+-]?\d+)?")
# find gold answer
_GSM_GOLD = re.compile(r"####\s*([^\n]+)")
# truncate tail part of the generation
_TRAIL_PUNCT = re.compile(r"[)\].,;:\s]+$")
_LEAD_PUNCT  = re.compile(r"^[([{\s]+")
_UNIT_TAIL = re.compile(
    r"\s*(dollars?|cents?|percent|perc\.?|pts?|points?|years?|year|hrs?|hours?|mins?|minutes?|secs?|seconds?)\s*$",
    re.IGNORECASE,
)

def _strip_wrappers(s: str) -> str:
    s = _LEAD_PUNCT.sub("", s)
    s = _TRAIL_PUNCT.sub("", s)
    return s.strip()

def _simple_moneylike_last(text: str) -> str:
    m = re.findall(r"(-?[$0-9.,]{2,})|(-?[0-9]+)", text)
    if not m:
        return ""
    last = m[-1]
    tok = [x for x in last if x][0]
    tok = tok.strip()
    for rgx in (",", r"\$", r"(?s).*#### ", r"\.$"):
        tok = re.sub(rgx, "", tok)
    return tok.strip()

def _find_last_numberish(text: str) -> str:
    for rx in (_MIXED_FRAC, _PURE_FRAC, _DEC_SCI):
        hits = list(rx.finditer(text))
        if hits:
            tok = hits[-1].group(0)
            return _post_clean_numberish(tok)
    fallback = _simple_moneylike_last(text)
    return fallback or ""

def _post_clean_numberish(s: str) -> str:
    if not s:
        return ""
    t = s.strip()
    for rx in (_MIXED_FRAC, _PURE_FRAC, _DEC_SCI):
        hits = rx.findall(t)
        if hits:
            cand = hits[-1] if isinstance(hits, list) else hits
            t = cand if isinstance(cand, str) else cand[-1]
            break

    t = t.replace("$", "").replace(",", "")
    t = _UNIT_TAIL.sub("", t)
    t = _strip_wrappers(t)
    t = re.sub(r"%\s*$", "", t)
    return t.strip()

def _parse_numeric(s: str):
    if s is None:
        return None
    t = str(s).strip()
    if not t:
        return None
    # 혼합분수
    m = re.fullmatch(r"([+-]?\d+)\s+(\d+)\s*/\s*(\d+)", t)
    if m:
        whole, num, den = int(m.group(1)), int(m.group(2)), int(m.group(3))
        sign = -1 if whole < 0 else 1
        whole = abs(whole)
        try:
            return Fraction(whole * den + num, den) * sign
        except ZeroDivisionError:
            return None
    # 순수 분수
    m = re.fullmatch(r"([+-]?\d+)\s*/\s*(\d+)", t)
    if m:
        try:
            return Fraction(int(m.group(1)), int(m.group(2)))
        except ZeroDivisionError:
            return None
    # 일반 수
    try:
        v = float(Decimal(t))
        return v
    except (InvalidOperation, ValueError):
        return None

def normalize_number(x: str) -> str:
    if x is None:
        return ""
    s = str(x).strip()
    s = _strip_wrappers(s)
    s = s.replace(",", "").replace("$", "").strip()
    # 백분율 처리: 뒤에 %/percent 있으면 제거(수치 비교는 grade 단계에서 가능)
    s = re.sub(r"\s*%\s*$", "", s)
    s = re.sub(r"\s*percent\s*$", "", s, flags=re.IGNORECASE)
    # 혼합분수: "a b/c" -> (a*den + num)/den
    m = re.fullmatch(r"\s*([+-]?\d+)\s+(\d+)\s*/\s*(\d+)\s*\Z", s)
    if m:
        whole, num, den = int(m.group(1)), int(m.group(2)), int(m.group(3))
        sign = -1 if whole < 0 else 1
        whole = abs(whole)
        try:
            frac = Fraction(whole * den + num, den) * sign
            return f"{frac.numerator}/{frac.denominator}"
        except ZeroDivisionError:
            return ""
    # 순수 분수: "a/b" -> 기약
    m = re.fullmatch(r"\s*([+-]?\d+)\s*/\s*(\d+)\s*\Z", s)
    if m:
        try:
            frac = Fraction(int(m.group(1)), int(m.group(2)))
            return f"{frac.numerator}/{frac.denominator}"
        except ZeroDivisionError:
            return ""
    # 소수/정수/과학표기
    try:
        v = float(Decimal(s))
        if math.isfinite(v) and abs(v - round(v)) < 1e-12:
            return str(int(round(v)))
        return f"{v:.12g}"
    except (InvalidOperation, ValueError):
        return s

def gsm_extract_gold(gold_field: str) -> str:
    if not gold_field:
        return ""
    m = _GSM_GOLD.search(gold_field)
    ans = (m.group(1) if m else gold_field).strip()
    ans = _strip_wrappers(ans)
    return ans

def gsm_extract_pred(text: str):
    if not text:
        return ""
    # 1) find "Answer: ..." line
    m = ANS_LINE.search(text)
    if m:
        cand = _post_clean_numberish(m.group(1))
        if cand:
            return cand
    # 2) find "\boxed{...}" qwen style
    m = _BOXED.search(text)
    if m:
        cand = _post_clean_numberish(m.group(1))
        if cand:
            return cand
    # 3) fallback
    last = _find_last_numberish(text)
    return last or ""

def grade_gsm8k_answer(pred: str, gold: str, atol: float = 1e-6) -> bool:
    p_norm = normalize_number(pred)
    g_norm = normalize_number(gold)
    if p_norm == g_norm:
        return True
    
    p_val = _parse_numeric(p_norm)
    g_val = _parse_numeric(g_norm)
    if p_val is None or g_val is None:
        return False
    
    if isinstance(p_val, Fraction) and isinstance(g_val, Fraction):
        return p_val == g_val
    
    try:
        pv = float(p_val) if not isinstance(p_val, float) else p_val
        gv = float(g_val) if not isinstance(g_val, float) else g_val
        return math.isfinite(pv) and math.isfinite(gv) and abs(pv - gv) <= atol
    except Exception:
        return False

def solve_one(question: str, n: int = 1) -> Dict[str, Any]:
    gens = generate_solutions(question, n=n)
    return {
        "generations": gens,
        "pred_answer": [gsm_extract_pred(t) for t in gens]
    }

In [18]:
def evaluate_gsm8k(dataset, limit: int = None, n: int = 1, temperature: float = 0.2, top_p: float = 0.9, seed: int = 123, max_new_tokens: int = 512):
    total = 0
    correct = 0
    logs = []

    for i, ex in enumerate(dataset):
        if limit and i >= limit:
            break
        q = ex["question"]
        gold_raw = ex["answer"]
        gold = gsm_extract_gold(gold_raw)
        print("Gold:", gold)

        gens = generate_solutions(q, n=n, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, seed=seed+i)
        preds = [gsm_extract_pred(t) for t in gens]
        print("Pred:", preds)

        is_correct = grade_gsm8k_answer(preds[0], gold)
        total += 1
        correct += int(is_correct)

        logs.append({
            "idx": i,
            "question": q,
            "gold": gold,
            "gens": gens,
            "preds": preds,
            "correct_first": bool(is_correct),
        })

        if (i + 1) % 20 == 0:
            acc = 100.0 * correct / total
            print(f"[{i+1}] running acc = {acc:.2f}%")

    acc = 100.0 * correct / max(total, 1)
    print(f"Done. Accuracy (first sample) = {acc:.2f}%  on {total} examples.")
    return acc, logs

gsm8k = load_dataset("openai/gsm8k", "main", split="test")
acc, logs = evaluate_gsm8k(gsm8k, limit=5, n=1)

Gold: 18


Both `max_new_tokens` (=512) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=512) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Pred: ['18']
Gold: 3


Both `max_new_tokens` (=512) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Pred: ['3']
Gold: 70000


Both `max_new_tokens` (=512) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Pred: ['70000']
Gold: 540


Both `max_new_tokens` (=512) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Pred: ['540']
Gold: 20
Pred: ['20']
Done. Accuracy (first sample) = 100.00%  on 5 examples.


In [None]:
def evaluate_gsm8k(
    dataset,
    limit: int = None,
    n: int = 1,
    temperature: float = 0.2,
    top_p: float = 0.9,
    seed: int = 123,
    max_new_tokens: int = 512,
    save_incorrect_path: str | None = None,  # 추가: 틀린 샘플 저장 경로
) -> Tuple[float, List[Dict[str, Any]], List[Dict[str, Any]]]:

    total = 0
    correct = 0
    logs: List[Dict[str, Any]] = []
    incorrect_samples: List[Dict[str, Any]] = []

    for i, ex in enumerate(dataset):
        if limit and i >= limit:
            break
        q = ex["question"]
        gold_raw = ex["answer"]
        gold = gsm_extract_gold(gold_raw)
        print("Gold:", gold)

        gens = generate_solutions(
            q, n=n, temperature=temperature, top_p=top_p,
            max_new_tokens=max_new_tokens, seed=seed + i
        )
        preds = [gsm_extract_pred(t) for t in gens]
        print("Pred:", preds)

        is_correct = grade_gsm8k_answer(preds[0], gold)
        total += 1
        correct += int(is_correct)

        # 로그 축적
        logs.append({
            "idx": i,
            "question": q,
            "gold": gold,
            "gens": gens,
            "preds": preds,
            "correct_first": bool(is_correct),
        })

        # 오답 수집
        if not is_correct:
            # 비교/디버깅에 유용한 정규화 값 및 차이
            p_norm = normalize_number(preds[0]) if preds and preds[0] is not None else ""
            g_norm = normalize_number(gold)

            # 수치 차이(가능한 경우)
            def _to_float(x):
                try:
                    return float(x.replace(" ", ""))
                except Exception:
                    return None
            p_val = _to_float(p_norm)
            g_val = _to_float(g_norm)
            diff = (p_val - g_val) if (p_val is not None and g_val is not None and math.isfinite(p_val) and math.isfinite(g_val)) else None

            incorrect_samples.append({
                "idx": i,
                "question": q,
                "gold": gold,
                "gold_norm": g_norm,
                "pred_chosen": preds[0] if preds else "",
                "pred_chosen_norm": p_norm,
                "preds_all": preds,          # 추출된 후보 정답들
                "gens_all": gens,            # 원문 전체(길 수 있음)
                "numeric_diff": diff,        # 가능하면 수치 차이
            })

        if (i + 1) % 20 == 0:
            acc = 100.0 * correct / total
            print(f"[{i+1}] running acc = {acc:.2f}%")

    acc = 100.0 * correct / max(total, 1)
    print(f"Done. Accuracy (first sample) = {acc:.2f}%  on {total} examples.")

    # JSON 저장(옵션)
    if save_incorrect_path:
        os.makedirs(os.path.dirname(save_incorrect_path) or ".", exist_ok=True)
        with open(save_incorrect_path, "w", encoding="utf-8") as f:
            json.dump(incorrect_samples, f, ensure_ascii=False, indent=2)
        print(f"Saved {len(incorrect_samples)} incorrect samples to: {save_incorrect_path}")

    return acc, logs, incorrect_samples

# olympiad grader

In [None]:
import re
import json
import sympy as sp
from sympy import simplify, Eq, sympify, Pow
from sympy.parsing.latex import parse_latex
import sys
import math

class AutoScoringJudge:
    def __init__(self):
        # Map of special symbols to their replacements
        self.special_signal_map = {
            "\\left": "",
            "\\right": "",
            "∶": ":",
            "，": ",",
            "$": "",
            "\\approx": "=",
            "\\simeq": "=",
            "\\sim": "=",
            "^\\prime": "'",
            "^{\\prime}": "'",
            "^\\circ": "",
            "%": "",
        }
        self.pi = parse_latex("\\pi")
        self.precision = 1e-8  # Default precision for comparison

    def split_by_comma(self, expr: str):
        # Splits expressions by commas outside of brackets
        in_bracket_num = 0
        splitted_expr = []
        start_idx = 0
        for i, char in enumerate(expr):
            if char in ["(", "["]:
                in_bracket_num += 1
            elif char in [")", "]"]:
                in_bracket_num -= 1
            elif char == "," and in_bracket_num == 0:
                splitted_expr.append(expr[start_idx:i].strip())
                start_idx = i + 1

        if start_idx < len(expr):
            splitted_expr.append(expr[start_idx:].strip())   
        
        return splitted_expr

    def trans_plus_minus_sign(self, expr_list: list):
        # Translates plus-minus signs into separate expressions
        new_expr_list = []
        for expr in expr_list:
            if "\\pm" in expr:
                new_expr_list.append(expr.replace("\\pm", "+"))
                new_expr_list.append(expr.replace("\\pm", "-"))
            else:
                new_expr_list.append(expr)
        
        return new_expr_list
    
    def judge(self, expression1, expression2, precision=1e-8):
        # Judge if two expressions are equal (expression1 is considered as the Ground Truth)
        # Default precision is a list for supporting multiple expressions
        precision = precision if isinstance(precision, list) else [precision]

        try:
            expression1, expression2 = self.preprocess(expression1, expression2)
        except:
            return False
        if expression1 == expression2:
            # print("Exactly equal")
            return True
        
        # Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered
        expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
        expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
        
        expression1 = self.split_by_comma(expression1)
        expression2 = self.split_by_comma(expression2)

        temp_list1 = self.trans_plus_minus_sign(expression1)
        temp_list2 = self.trans_plus_minus_sign(expression2)

        # Set up a list for allowed errors
        if len(precision) <= 1:
            precision = precision * len(temp_list1)
        
        if len(temp_list1) != len(temp_list2):
            return False

        # Check if elements in both lists can be paired and are equal
        idx = -1
        while len(temp_list1) != 0:
            idx = (idx + 1) % len(temp_list1)

            item1 = temp_list1[idx]
            self.precision = precision[idx]

            for item2 in temp_list2:
                if self.is_equal(item1, item2):
                    temp_list1.remove(item1)
                    temp_list2.remove(item2)
                    precision.remove(self.precision)
                    break
            else:
                # If no match was found, return False
                return False

        # If all elements are matched, return True
        return True
    
    def is_interval(self, expr):
        # Checks if an expression is an interval
        return expr.startswith(("(", "[")) and expr.endswith((")", "]"))

    def sympy_sub_pi(self, expression_sympy):
        # Replaces the symbol for pi in sympy expressions with its numerical value
        return expression_sympy.subs(self.pi, math.pi)
    
    def is_equal(self, expression1, expression2):
        # Default first expression is ground truth. Check if expressions are equal in different aspects
        if expression1 == expression2 and expression1 != "" and expression2 != "":
            # print("Equivalent natively")
            return True

        # First check if both are intervals
        if self.is_interval(expression1) and self.is_interval(expression2):
            try:
                if self.interval_equal(expression1, expression2):
                    # print("Interval equivalent")
                    return True
            except:
                return False

        # Then check for numerical equality
        try:
            if self.numerical_equal(expression1, expression2):
                # print("Numerically equivalent")
                return True
        except:
            pass
        
        # Then check if expressions are mathematically equal
        try:
            if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
                # print("Expression equivalent")
                return True
        except:
            pass
            
        # Lastly, check for equation equality
        try:
            if self.equation_equal(expression1, expression2):
                # print("Equation equivalent")
                return True
        except:
            pass
            
        return False

    def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
        # Check if two numerical values are equal within an allowed error range
        # Includes possible percentage cases
        reference = float(expression1)
        prediction = float(expression2)
        
        if include_percentage:
            gt_result = [reference / 100, reference, reference * 100]
        else:
            gt_result = [reference]
        
        for item in gt_result:
            if abs(item - prediction) <= self.precision * 1.01:
                return True
        return False
    

    def expression_equal(self, exp1, exp2):
        # Check if two expressions are mathematically equivalent
        # Extract expression and use sympy for equivalence checking
        def extract_expression(expression):
            if "=" in expression:
                expression = expression.split("=")[1]
            return expression.strip()
        
        exp1 = extract_expression(exp1)
        exp2 = extract_expression(exp2)

        expr1_sym = sympify(parse_latex(exp1))
        expr2_sym = sympify(parse_latex(exp2))

        if expr1_sym == expr2_sym:
            return True
        else:
            expr1_sym = self.sympy_sub_pi(expr1_sym)
            expr2_sym = self.sympy_sub_pi(expr2_sym)

            if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or (not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
                return False
            elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
                try:
                    if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
                        print(f"These two numbers cannot be calculated by the current computer for: \"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
                        return False

                    if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
                        return True
                    else:
                        return False
                except:
                    return False
            else:
                try:
                    simplified_expr = simplify(expr1_sym - expr2_sym)

                    num_value = simplified_expr.evalf()
                    
                    return abs(num_value) < 1e-3
                except:
                    return False

    def equation_equal(self, expression1, expression2):
        # Check if two equations are mathematically equivalent
        # Simplify equations and use sympy for equivalence checking
        def simplify_equation(latex_eq):
            lhs, rhs = latex_eq.split('=')

            lhs_expr = parse_latex(lhs)
            rhs_expr = parse_latex(rhs)

            equation = Eq(lhs_expr, rhs_expr)

            simplified_eq = simplify(equation.lhs - equation.rhs)

            return simplified_eq

        expr1_sym = simplify_equation(expression1)
        expr2_sym = simplify_equation(expression2)

        division_result_1 = simplify(expr1_sym / expr2_sym)
        division_result_2 = simplify(expr2_sym / expr1_sym)

        if (division_result_1.is_Integer and division_result_1 != 0) or (division_result_2.is_Integer and division_result_2 != 0):
            return True
        else:
            return False

    def interval_equal(self, expression1, expression2):
        # Check if two intervals are mathematically equivalent
        def compare_two_interval(inter1, inter2):
            if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
                return False
            
            inter1 = inter1.strip('[]()')
            inter2 = inter2.strip('[]()')

            items_1 = inter1.split(',')
            items_2 = inter2.split(',')

            for item_1, item_2 in zip(items_1, items_2):
                if not self.expression_equal(item_1, item_2):
                    return False
            return True
            
        interval1 = expression1
        interval2 = expression2

        if interval1 == interval2:
            return True
        else:
            inter_list1 = interval1.split("\\cup")
            inter_list2 = interval2.split("\\cup")
            
            if len(inter_list1) != len(inter_list2):
                return False
            else:
                for inter1, inter2 in zip(inter_list1, inter_list2):
                    if not compare_two_interval(inter1, inter2):
                        return False
                return True

    def preprocess(self, expression1, expression2):
        # Preprocess expressions to extract and replace special symbols
        def extract_boxed_content(latex_str):
            boxed_matches = re.finditer(r'\\boxed{', latex_str)
            results = ""

            for match in boxed_matches:
                start_index = match.end()
                end_index = start_index
                stack = 1

                while stack > 0 and end_index < len(latex_str):
                    if latex_str[end_index] == '{':
                        stack += 1
                    elif latex_str[end_index] == '}':
                        stack -= 1
                    end_index += 1

                if stack == 0:
                    content = latex_str[start_index:end_index - 1]
                    results += content + ","
                else:
                    raise ValueError("Mismatched braces in LaTeX string.")

            if results == "":
                last_line_ans = latex_str.strip().split("\n")[-1]
                dollar_pattern = r"\$(.*?)\$"
                answers = re.findall(dollar_pattern, last_line_ans)

                if answers:
                    for ans in answers:
                        results += ans + ","
                else:
                    results = latex_str
                
            return results
        
        def sepcial_symbol_replace(expression):
            if "\\in " in expression:
                expression = expression.split("\\in ")[1]
            
            for signal in self.special_signal_map:
                expression = expression.replace(signal, self.special_signal_map[signal])

            expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~，。")

            pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
            expression = re.sub(pattern, r'\1', expression)

            return expression
        
        exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
        exp1, exp2 = sepcial_symbol_replace(exp1), sepcial_symbol_replace(exp2)

        return exp1, exp2
    
    def can_compute_power(self, expr):
        # Checks if a power expression can be computed
        if isinstance(expr, Pow):
            base, exp = expr.as_base_exp()
            if base.is_number and exp.is_number:
                MAX_EXP = 1000  # Adjust based on computing environment
                if abs(exp.evalf()) > MAX_EXP:
                    return False
                else:
                    return True
            else:
                return False
        else:
            return True  # Not a power expression, can compute

if __name__ == "__main__":
    example_path = "./scoring_examples.json"  # Path to the examples file
    with open(example_path, "r", encoding="utf-8") as f:
        examples = json.load(f)
    
    scorer = AutoScoringJudge()
    
    exp1 = "10^{10^{10^{10}}}"
    exp2 = "10^{10}"
    precision = 1e-4

    res = scorer.judge(exp1, exp2, precision)
    print(res)


    for exp in examples:
        ground_truth = exp["GT"]
        for exp_ans in exp["Ans"]:
            model_output = exp_ans["answer"]

            if "precision" in exp_ans:
                precision = float(exp_ans["precision"]) if type(exp_ans["precision"]) == str else exp_ans["precision"]
            else:
                precision = 1e-8
                
            result = scorer.judge(ground_truth, model_output, precision)

            if result != exp_ans["equal"]:
                print(f"ERROR! Ground_truth: {ground_truth}; Model_output: {model_output}; Expected result: {result}")