In [None]:
import ast
import contextlib
import os
import re
import sys
from io import StringIO

import torch
from dotenv import load_dotenv
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Read secrets and optional overrides from a local .env file; nothing sensitive is hard-coded here.
load_dotenv()
HF_AUTH_TOKEN = os.getenv('HF_AUTH_TOKEN')
if HF_AUTH_TOKEN:
    login(HF_AUTH_TOKEN)
else:
    # login(None) would pop an interactive prompt; instead rely on any credentials already
    # cached by a previous `huggingface-cli login`.
    print("HF_AUTH_TOKEN is not set; skipping explicit Hugging Face login.")

# Model locations. Both can be overridden via environment variables (e.g. in .env) so the
# notebook runs on other machines without editing this absolute path.
ADAPTER_PATH = os.getenv(
    "ADAPTER_PATH",
    "/home/guest/AdvancedLLMReasoning/math_tutor_model/math_sft_adapter/v3/final_checkpoint",
)
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "meta-llama/Llama-3.2-1B")

def load_model(base_model_id=None, adapter_path=None):
    """Load the 4-bit quantized base model and attach the LoRA adapter for inference.

    Args:
        base_model_id: Hub id or local path of the base causal LM. ``None`` (default) uses
            the module-level ``BASE_MODEL_ID`` at call time.
        adapter_path: Directory holding the LoRA adapter and the tokenizer saved with it.
            ``None`` (default) uses the module-level ``ADAPTER_PATH`` at call time.

    Returns:
        ``(model, tokenizer)``: the ``PeftModel`` in eval mode, and its tokenizer configured
        for left padding.
    """
    if base_model_id is None:
        base_model_id = BASE_MODEL_ID
    if adapter_path is None:
        adapter_path = ADAPTER_PATH

    print("‚è≥ ƒêang load Base Model (4-bit)...")
    # NF4 weights with double quantization; matmuls are computed in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16
    )

    # The tokenizer comes from the adapter directory so any tokens/template saved during
    # fine-tuning are used.
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    tokenizer.padding_side = "left"  # decoder-only models must be left-padded for generation
    if tokenizer.pad_token is None:
        # Base Llama tokenizers may ship without a pad token; reuse EOS so padding/generation work.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"ƒêang gh√©p LoRA Adapter t·ª´: {adapter_path}")
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model.eval()

    return model, tokenizer

def execute_python_code(code_str):
    """Execute a snippet of Python source and return what it produced, REPL-style.

    The snippet runs exactly once, as ``__main__``, in a fresh namespace with stdout captured.
    If it printed anything, the captured text is returned. Otherwise, when the final top-level
    statement is a bare expression whose value is not ``None``, ``str()`` of that value is
    returned, the same way an interactive prompt echoes the last result.

    Security: ``code_str`` is typically model-generated. It runs with full interpreter
    privileges and no timeout. An infinite loop hangs the caller, and file and network access
    are unrestricted. Only use this inside a trusted sandbox.

    Args:
        code_str: Python source to run.

    Returns:
        The stripped stdout text or last-expression value; ``""`` if there is neither; or
        ``"Error: <message>"`` if parsing or execution raised an ``Exception``.
    """
    captured = StringIO()
    # "__main__" so the common `if __name__ == "__main__": main()` guard actually runs.
    namespace = {"__name__": "__main__"}
    try:
        module = ast.parse(code_str, filename="<string>")
        # Split a trailing bare expression off so it is evaluated exactly once and its value
        # kept. Re-evaluating the last *source line* after exec would repeat side effects and
        # could not handle multi-line expressions or keyword arguments.
        tail = None
        if module.body and isinstance(module.body[-1], ast.Expr):
            tail = ast.Expression(body=module.body.pop().value)

        tail_value = None
        # redirect_stdout restores sys.stdout on exit, even if the snippet raises.
        with contextlib.redirect_stdout(captured):
            exec(compile(module, "<string>", "exec"), namespace)
            if tail is not None:
                tail_value = eval(compile(tail, "<string>", "eval"), namespace)

        output = captured.getvalue()
        if not output.strip() and tail_value is not None:
            output = str(tail_value)
        return output.strip()
    except Exception as e:
        return f"Error: {str(e)}"

def _boxed_spans(text):
    r"""Return ``(start, end)`` spans of every ``\boxed{...}`` in ``text``.

    Braces are matched, so a nested answer such as ``\boxed{\frac{1}{2}}`` is one span.
    An unterminated ``\boxed{`` (and anything after it) is ignored.
    """
    marker = "\\boxed{"
    spans = []
    start = text.find(marker)
    while start != -1:
        depth = 0
        end = None
        # Start scanning at the marker's own "{" so it is counted in depth.
        for pos in range(start + len(marker) - 1, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    end = pos + 1
                    break
        if end is None:
            break
        spans.append((start, end))
        start = text.find(marker, end)
    return spans


def post_process_solution(generated_text):
    r"""Post-process a generated solution.

    1. Trim everything after the first line that contains ``\boxed{...}`` (if any).
    2. If the trimmed text contains a ```python code block:
       - strip ``<llm>...</llm>`` and ``<llm-code-output>...</llm-code-output>`` segments;
       - execute the first code block with ``execute_python_code``;
       - if that produced a non-empty, non-error result, write it into every ``\boxed{...}``,
         or append a closing sentence with a box when there is none.

    Args:
        generated_text: Raw model continuation (prompt already removed).

    Returns:
        The processed solution text.
    """
    boxed_line = re.search(r'^.*\\boxed\{[^}]+\}.*$', generated_text, re.MULTILINE)
    trimmed_text = generated_text[:boxed_line.end()] if boxed_line else generated_text

    code_match = re.search(r'```python\s*\n(.*?)\n```', trimmed_text, re.DOTALL)
    if not code_match:
        return trimmed_text
    code_str = code_match.group(1)

    # Drop tool-call / captured-output markup the model may have emitted.
    trimmed_text = re.sub(r'<llm>.*?</llm>', '', trimmed_text, flags=re.DOTALL)
    trimmed_text = re.sub(r'<llm-code-output>.*?</llm-code-output>', '', trimmed_text, flags=re.DOTALL)

    result = execute_python_code(code_str)
    # An empty result or a failed run ("Error: ...") must not overwrite the model's own answer.
    if not result or result.startswith("Error:"):
        return trimmed_text

    spans = _boxed_spans(trimmed_text)
    if not spans:
        return trimmed_text + f'\n\nTherefore, the answer is \\boxed{{{result}}}.'

    # Splice by position instead of re.sub: a regex replacement template would interpret
    # backslashes in the result (e.g. "\sqrt" raises re.error, "\frac" becomes a form feed).
    replacement = f'\\boxed{{{result}}}'
    pieces = []
    cursor = 0
    for start, end in spans:
        pieces.append(trimmed_text[cursor:start])
        pieces.append(replacement)
        cursor = end
    pieces.append(trimmed_text[cursor:])
    return "".join(pieces)

def solve_math_problem(model, tokenizer, question, max_length=1024, max_new_tokens=512):
    r"""Generate a step-by-step solution for ``question`` and post-process it.

    The question is wrapped in a chat prompt (system instructions plus a user turn) and
    decoded greedily. The generated continuation is passed to ``post_process_solution``,
    which trims after the ``\boxed{}`` line. If a Python block is present, it executes the
    block and writes the result into the box.

    Args:
        model: Causal LM returned by ``load_model`` (needs ``.device`` and ``.generate``).
        tokenizer: Matching tokenizer with a chat template.
        question: Problem statement.
        max_length: Maximum prompt length in tokens; longer prompts are truncated.
        max_new_tokens: Generation budget for the answer (default 512, as before).

    Returns:
        The post-processed solution text.
    """
    system_prompt = (
        "You are a math reasoning assistant.\n"
        "Solve the problem step by step.\n"
        "You can use Python code if needed.\n"
        "If you write code, put it inside a Python code block:\n"
        "```python\n"
        "...\n"
        "```\n"
        "Output ONLY the final number inside \\boxed{}."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # add_special_tokens=False: the chat-templated prompt is assumed to already contain BOS
    # (TODO confirm for this tokenizer).
    # NOTE(review): truncation follows tokenizer.truncation_side ("right" unless configured),
    # which would cut the trailing generation prompt off an over-long input. Confirm this is acceptable.
    inputs = tokenizer(
        prompt,
        padding=False,
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
        return_tensors="pt"
    ).to(model.device)  # follow the model's placement instead of assuming "cuda"

    print("\nü§ñ Model ƒëang suy nghƒ©...\n")
    print("-" * 50)

    # Pass an explicit pad id even when the tokenizer defines none (fall back to EOS).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    # Llama-3 end-of-turn token: stop once the assistant turn is finished.
    eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=pad_token_id,
            eos_token_id=eot_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs['input_ids'].shape[1]
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    return post_process_solution(generated_text)

In [43]:
# Load the 4-bit base model with the LoRA adapter attached; the returned model and
# tokenizer are reused by the cells below.
model, tokenizer = load_model()

‚è≥ ƒêang load Base Model (4-bit)...


ƒêang gh√©p LoRA Adapter t·ª´: /home/guest/AdvancedLLMReasoning/math_tutor_model/math_sft_adapter/v3/final_checkpoint


In [None]:
# Smoke test: run one problem end to end (generation, code execution, boxed-answer rewrite).
question = "Solve the equation: 2x + 3 = 7"  # Example question
solution = solve_math_problem(model, tokenizer, question)
print(solution)  # print() so the multi-line solution renders with real newlines


ü§ñ Model ƒëang suy nghƒ©...

--------------------------------------------------
Let's use sympy to solve the equation.
```python
import sympy as sp

# define the symbols
x = sp.symbols('x')

# define the equation
equation = sp.Eq(2*x + 3, 7)

# solve the equation
solution = sp.solve(equation, x)

# print the solution
print(solution)
```

So the solution is $\boxed{[2]}$.


: 

In [20]:
from datasets import load_dataset
# Test splits of the two evaluation benchmarks. Later cells read the 'answer' field
# from GSM8K and the 'solution' field from MATH.
gsm8k_ds = load_dataset("openai/gsm8k", "main", split="test")
math_ds = load_dataset("nlile/hendrycks-MATH-benchmark", split="test")

Generating train split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12000/12000 [00:00<00:00, 765244.30 examples/s]
Generating test split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [00:00<00:00, 303935.07 examples/s]


In [26]:
import re
def extract_answer(text):
    r"""Pull the final answer out of a solution string.

    Strategy, in order:
      1. If the text contains ``\boxed{``, return the contents of the *last* box. Nested
         braces are matched, so ``\boxed{\frac{14}{3}}`` yields ``\frac{14}{3}``. An
         unterminated box returns everything after ``\boxed{``.
      2. Otherwise, return the number after the last "The answer is" / "the answer is"
         phrase (followed by a colon and/or whitespace). Thousands separators such as
         "1,000" are kept; trailing sentence punctuation is not.
      3. Otherwise, return ``None``.

    Args:
        text: Model output or reference solution. Empty or ``None`` means "no answer".

    Returns:
        The stripped answer string, or ``None`` when nothing could be extracted.
    """
    if not text:
        return None

    marker = "\\boxed{"
    start = text.rfind(marker)
    if start != -1:
        depth = 1  # the marker's own "{" is already open
        content = []
        for char in text[start + len(marker):]:
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    break
            content.append(char)
        return "".join(content).strip()

    # Digits with optional ",ddd" groups and decimals, so "42." or "42," don't keep the punctuation.
    numbers = re.findall(
        r'[Tt]he answer is[:\s]+(-?(?:\d+(?:,\d{3})*(?:\.\d+)?|\.\d+))', text
    )
    # The last statement is the final answer, consistent with the last-\boxed rule above.
    return numbers[-1] if numbers else None

In [28]:
# Number of benchmark problems scored in this quick evaluation.
limit = 50
# GSM8K reference answers are whatever follows the final "####" marker.
gsm8k_answer = [gsm8k_ds[idx]['answer'].split("####")[-1].strip() for idx in range(limit)]
# MATH reference answers are pulled from the worked solution by extract_answer.
math_answer = [extract_answer(math_ds[idx]['solution']) for idx in range(limit)]

In [29]:
# Inspect the extracted GSM8K reference answers.
gsm8k_answer

['18',
 '3',
 '70000',
 '540',
 '20',
 '64',
 '260',
 '160',
 '45',
 '460',
 '366',
 '694',
 '13',
 '18',
 '60',
 '125',
 '230',
 '57500',
 '7',
 '6',
 '15',
 '14',
 '7',
 '8',
 '26',
 '2',
 '243',
 '16',
 '25',
 '104',
 '109',
 '80',
 '35',
 '70',
 '23',
 '9',
 '75',
 '2',
 '10',
 '18',
 '8',
 '200',
 '26',
 '48',
 '20',
 '104',
 '163',
 '800',
 '8',
 '30']

In [30]:
# Inspect the extracted MATH reference answers (LaTeX strings from extract_answer).
math_answer

['\\left( 3, \\frac{\\pi}{2} \\right)',
 'p - q',
 '\\frac{14}{3}',
 '9',
 '\\text{Evelyn}',
 '42',
 '27',
 '90^\\circ',
 '3\\sqrt{13}',
 '4',
 '2220',
 '\\frac{3}{56}',
 '284',
 '5',
 '\\sqrt{51}',
 '6 - 5i',
 '-50',
 '\\pi',
 '28',
 '3',
 '6+9i',
 '13535',
 '5',
 'x=5',
 '10',
 '1,-2',
 '144',
 '78',
 '-2 + 7i',
 '225',
 '52_8',
 '11\\sqrt2',
 '720',
 '\\frac{243}{625}',
 '-125',
 '3',
 '3, 5, 7',
 '72',
 '2000',
 '23',
 '12',
 '17',
 '4',
 '70 \\sqrt{2}',
 '1.25',
 '2',
 '6',
 '5',
 '\\frac{3}{2}',
 '83']