In [9]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
checkpoint = "google/gemma-3-1b-it"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

mps


In [10]:
prompt = """
Evaluate any defined expressions or constants.
Substitute known values into subsequent equations where applicable.
Solve any resulting systems of equations to find the values of unknowns.
If a variable is defined in terms of another (e.g., x=ay+bx=ay+b), solve for the required variable.
Present all steps clearly and logically, showing how each result is derived from the previous one.
Ensure all algebraic manipulations are valid. Simplify expressions where appropriate.
Provide the final answer, along with intermediate steps as needed for clarity. At the end, give just a numerical or symbolic answer.
At the end, clearly separate the final numerical answer using the format:
FinalAnswer: <value>
"""

In [11]:
task = "Solve -20*b + 128*b + 648 = 0 for b."

messages = [{"role": "system", "content": prompt}, {"role": "user", "content": task}]
input_text=tokenizer.apply_chat_template(messages, tokenize=False)

inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.2, top_p=0.9, do_sample=True)
raw_output = tokenizer.decode(outputs[0])

print("RAW OUTPUT:")
print("="*50)
print(raw_output)
print("\n" + "="*50)

RAW OUTPUT:
<bos><bos><start_of_turn>user

Evaluate any defined expressions or constants.
Substitute known values into subsequent equations where applicable.
Solve any resulting systems of equations to find the values of unknowns.
If a variable is defined in terms of another (e.g., x=ay+bx=ay+b), solve for the required variable.
Present all steps clearly and logically, showing how each result is derived from the previous one.
Ensure all algebraic manipulations are valid. Simplify expressions where appropriate.
Provide the final answer, along with intermediate steps as needed for clarity. At the end, give just a numerical or symbolic answer.
At the end, clearly separate the final numerical answer using the format:
FinalAnswer: <value>


Solve -20*b + 128*b + 648 = 0 for b.<end_of_turn>
**Solution:**

We are given the equation -20*b + 128*b + 648 = 0.
First, combine the terms with 'b' on the left side:
(-20 + 128)b + 648 = 0
108b + 648 = 0

Now, isolate the term with 'b':
108b = -648
Div

In [12]:
import re

def parse_llm_output(output_text):
    """
    Parse LLM output to extract the final answer.
    
    Args:
        output_text (str): The raw output from the LLM
        
    Returns:
        dict: Contains 'answer', 'reasoning', and 'success' fields
    """
    try:
        # Remove special tokens and clean the text
        cleaned_text = output_text.replace("<bos>", "").replace("<end_of_turn>", "").strip()
        
        # Split by user/assistant turns to get only the assistant's response
        if "<start_of_turn>" in cleaned_text:
            parts = cleaned_text.split("<start_of_turn>")
            if len(parts) > 1:
                assistant_response = parts[-1].strip()
            else:
                assistant_response = cleaned_text
        else:
            assistant_response = cleaned_text
        
        # Extract the final answer using regex
        final_answer_patterns = [
            r"FinalAnswer:\s*([^\n<]+)",  # Standard format
            r"Final Answer:\s*([^\n<]+)",  # Alternative format
            r"Answer:\s*([^\n<]+)",  # Simple format
            r"Therefore,?\s*[a-zA-Z]?\s*=\s*([^\n<]+)",  # Pattern like "Therefore, b = -6"
            r"The answer is\s*([^\n<]+)",  # Natural language format
        ]
        
        extracted_answer = None
        for pattern in final_answer_patterns:
            match = re.search(pattern, assistant_response, re.IGNORECASE)
            if match:
                extracted_answer = match.group(1).strip()
                break
        
        # If no pattern matched, try to find the last number or expression
        if not extracted_answer:
            # Look for patterns like "b = -6" or "x = 5" at the end
            number_pattern = r"[a-zA-Z]?\s*=\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)"
            matches = re.findall(number_pattern, assistant_response)
            if matches:
                extracted_answer = matches[-1]
        
        # Clean up the extracted answer
        if extracted_answer:
            extracted_answer = extracted_answer.strip(".,!?;")
            # Remove any trailing punctuation or markdown
            extracted_answer = re.sub(r'[*_`]+', '', extracted_answer)
            
        return {
            'answer': extracted_answer,
            'reasoning': assistant_response,
            'success': extracted_answer is not None,
            'raw_output': output_text
        }
        
    except Exception as e:
        return {
            'answer': None,
            'reasoning': output_text,
            'success': False,
            'error': str(e),
            'raw_output': output_text
        }

def evaluate_solution(predicted_answer, ground_truth):
    """
    Evaluate if the predicted answer matches the ground truth.
    
    Args:
        predicted_answer (str): The extracted answer from LLM
        ground_truth (str): The correct answer
        
    Returns:
        bool: True if answers match, False otherwise
    """
    if not predicted_answer or not ground_truth:
        return False
    
    # Clean both answers
    def clean_answer(ans):
        if ans is None:
            return None
        # Remove whitespace and convert to string
        ans = str(ans).strip()
        # Remove common formatting
        ans = ans.replace(" ", "").replace(",", "")
        return ans.lower()
    
    pred_clean = clean_answer(predicted_answer)
    truth_clean = clean_answer(ground_truth)
    
    if pred_clean == truth_clean:
        return True
    
    # Try to convert to numbers for comparison
    try:
        pred_num = float(pred_clean)
        truth_num = float(truth_clean)
        # Check if numbers are close (handle floating point precision)
        return abs(pred_num - truth_num) < 1e-6
    except (ValueError, TypeError):
        # If conversion fails, do string comparison
        return pred_clean == truth_clean

print("Testing parser with sample output...")
test_output = '''<bos><bos><start_of_turn>user
Solve -20*b + 128*b + 648 = 0 for b.<end_of_turn>
**Solution:**

We are given the equation -20*b + 128*b + 648 = 0.
First, combine the terms with 'b' on the left side:
(-20 + 128)b + 648 = 0
108b + 648 = 0

Now, isolate the term with 'b':
108b = -648
Divide both sides by 108:
b = -648 / 108
b = -6

Therefore, b = -6.

FinalAnswer: -6
<end_of_turn>'''

result = parse_llm_output(test_output)
print(f"Extracted answer: {result['answer']}")
print(f"Success: {result['success']}")
print(f"Matches ground truth '-6': {evaluate_solution(result['answer'], '-6')}")


Testing parser with sample output...
Extracted answer: -6
Success: True
Matches ground truth '-6': True


In [13]:
def solve_math_problem(problem, expected_answer=None):
    """
    Solve a math problem using the LLM and parse the result.
    
    Args:
        problem (str): The math problem to solve
        expected_answer (str, optional): The expected answer for evaluation
        
    Returns:
        dict: Contains the parsed result and evaluation metrics
    """
    # Prepare the input
    messages = [{"role": "system", "content": prompt}, {"role": "user", "content": problem}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    
    # Generate response
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.2, top_p=0.9, do_sample=True)
    raw_output = tokenizer.decode(outputs[0])
    
    # Parse the output
    parsed_result = parse_llm_output(raw_output)
    
    # Evaluate if expected answer is provided
    if expected_answer is not None:
        parsed_result['correct'] = evaluate_solution(parsed_result['answer'], expected_answer)
        parsed_result['expected_answer'] = expected_answer
    
    return parsed_result

print("="*50)
print("SOLVING: Solve -20*b + 128*b + 648 = 0 for b.")
print("="*50)

result = solve_math_problem("Solve -20*b + 128*b + 648 = 0 for b.", expected_answer="-6")

print(f"Extracted Answer: {result['answer']}")
print(f"Expected Answer: {result['expected_answer']}")
print(f"Correct: {result['correct']}")
print(f"Parsing Success: {result['success']}")
print("\n" + "="*50)
print("REASONING:")
print("="*50)
print(result['reasoning'][:500] + "..." if len(result['reasoning']) > 500 else result['reasoning'])


SOLVING: Solve -20*b + 128*b + 648 = 0 for b.
Extracted Answer: -6
Expected Answer: -6
Correct: True
Parsing Success: True

REASONING:
user

Evaluate any defined expressions or constants.
Substitute known values into subsequent equations where applicable.
Solve any resulting systems of equations to find the values of unknowns.
If a variable is defined in terms of another (e.g., x=ay+bx=ay+b), solve for the required variable.
Present all steps clearly and logically, showing how each result is derived from the previous one.
Ensure all algebraic manipulations are valid. Simplify expressions where appropriate.
Provide the final ans...


In [14]:
def evaluate_model_on_dataset(dataset, max_problems=None, categories=None, verbose=True):
    """
    Evaluate the model on a HuggingFace dataset with category-wise statistics.
    
    Args:
        dataset: HuggingFace dataset with 'question', 'answer', and 'category' fields
        max_problems (int, optional): Maximum number of problems to evaluate
        categories (list, optional): Specific categories to evaluate (if None, evaluate all)
        verbose (bool): Whether to print detailed progress
        
    Returns:
        dict: Evaluation results with overall and category-wise statistics
    """
    import pandas as pd
    from collections import defaultdict
    
    # Filter by categories if specified
    if categories:
        dataset = dataset.filter(lambda x: x['category'] in categories)
    
    # Limit the number of problems if specified
    if max_problems:
        dataset = dataset.select(range(min(max_problems, len(dataset))))
    
    # Initialize tracking variables
    results = []
    category_stats = defaultdict(lambda: {
        'total': 0, 'correct': 0, 'parsed': 0, 'errors': 0
    })
    
    total_problems = len(dataset)
    print(f"Evaluating model on {total_problems} problems...")
    
    # Process each problem
    for i, example in enumerate(dataset):
        problem = example['question']
        expected_answer = example['answer']
        category = example['category']
        
        if verbose:
            print(f"\nProblem {i+1}/{total_problems} [{category}]: {problem[:50]}...")
        
        # Update category total count
        category_stats[category]['total'] += 1
        
        try:
            result = solve_math_problem(problem, expected_answer)
            
            if result['success']:
                category_stats[category]['parsed'] += 1
                if result['correct']:
                    category_stats[category]['correct'] += 1
                    if verbose:
                        print(f"✅ CORRECT: {result['answer']}")
                else:
                    if verbose:
                        print(f"❌ WRONG: Got {result['answer']}, Expected {expected_answer}")
            else:
                if verbose:
                    print(f"⚠️ PARSING FAILED")
            
            results.append({
                'problem': problem,
                'expected_answer': expected_answer,
                'predicted_answer': result['answer'],
                'correct': result.get('correct', False),
                'parsed_successfully': result['success'],
                'category': category,
                'reasoning': result['reasoning']
            })
            
        except Exception as e:
            category_stats[category]['errors'] += 1
            if verbose:
                print(f"❌ ERROR: {str(e)}")
            results.append({
                'problem': problem,
                'expected_answer': expected_answer,
                'predicted_answer': None,
                'correct': False,
                'parsed_successfully': False,
                'category': category,
                'error': str(e)
            })
    
    # Calculate overall metrics
    total_correct = sum(stats['correct'] for stats in category_stats.values())
    total_parsed = sum(stats['parsed'] for stats in category_stats.values())
    total_errors = sum(stats['errors'] for stats in category_stats.values())
    
    overall_parsing_accuracy = total_parsed / total_problems if total_problems > 0 else 0
    overall_solving_accuracy = total_correct / total_problems if total_problems > 0 else 0
    overall_conditional_accuracy = total_correct / total_parsed if total_parsed > 0 else 0
    
    # Calculate category-wise metrics
    category_metrics = {}
    for category, stats in category_stats.items():
        parsing_acc = stats['parsed'] / stats['total'] if stats['total'] > 0 else 0
        solving_acc = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
        conditional_acc = stats['correct'] / stats['parsed'] if stats['parsed'] > 0 else 0
        
        category_metrics[category] = {
            'total_problems': stats['total'],
            'correct_answers': stats['correct'],
            'parsed_successfully': stats['parsed'],
            'errors': stats['errors'],
            'parsing_accuracy': parsing_acc,
            'solving_accuracy': solving_acc,
            'conditional_accuracy': conditional_acc
        }
    
    # Create summary
    summary = {
        'overall_metrics': {
            'total_problems': total_problems,
            'parsed_successfully': total_parsed,
            'correct_answers': total_correct,
            'errors': total_errors,
            'parsing_accuracy': overall_parsing_accuracy,
            'solving_accuracy': overall_solving_accuracy,
            'conditional_accuracy': overall_conditional_accuracy
        },
        'category_metrics': category_metrics,
        'detailed_results': results
    }
    
    # Print summary
    if verbose:
        print(f"\n{'='*80}")
        print("EVALUATION SUMMARY")
        print(f"{'='*80}")
        print(f"Overall Results:")
        print(f"  Total Problems: {total_problems}")
        print(f"  Successfully Parsed: {total_parsed} ({overall_parsing_accuracy:.2%})")
        print(f"  Correct Answers: {total_correct} ({overall_solving_accuracy:.2%})")
        print(f"  Accuracy (given successful parsing): {overall_conditional_accuracy:.2%}")
        print(f"  Errors: {total_errors}")
        
        print(f"\n{'='*80}")
        print("CATEGORY-WISE RESULTS")
        print(f"{'='*80}")
        
        # Create a nice table for category results
        category_data = []
        for category, metrics in category_metrics.items():
            category_data.append([
                category,
                metrics['total_problems'],
                metrics['correct_answers'],
                f"{metrics['solving_accuracy']:.2%}",
                f"{metrics['parsing_accuracy']:.2%}",
                f"{metrics['conditional_accuracy']:.2%}"
            ])
        
        # Sort by category name
        category_data.sort(key=lambda x: x[0])
        
        # Print table header
        print(f"{'Category':<15} {'Total':<7} {'Correct':<7} {'Solve%':<7} {'Parse%':<7} {'Cond%':<7}")
        print("-" * 80)
        
        # Print category results
        for row in category_data:
            print(f"{row[0]:<15} {row[1]:<7} {row[2]:<7} {row[3]:<7} {row[4]:<7} {row[5]:<7}")
    
    return summary

def save_evaluation_results(summary, filename="evaluation_results.json"):
    """Save evaluation results to a JSON file."""
    import json
    
    # Create a serializable version (remove non-serializable parts)
    serializable_summary = {
        'overall_metrics': summary['overall_metrics'],
        'category_metrics': summary['category_metrics'],
        'detailed_results': [{
            'problem': r['problem'],
            'expected_answer': r['expected_answer'],
            'predicted_answer': r['predicted_answer'],
            'correct': r['correct'],
            'parsed_successfully': r['parsed_successfully'],
            'category': r['category']
        } for r in summary['detailed_results']]
    }
    
    with open(filename, 'w') as f:
        json.dump(serializable_summary, f, indent=2)
    
    print(f"Results saved to {filename}")


In [None]:
# Example usage with the HuggingFace dataset
from datasets import load_from_disk

# Load the dataset
dataset = load_from_disk("../../data/processed/math_qa_dataset")

print("Dataset loaded successfully!")
print(f"Total examples: {len(dataset)}")
print(f"Example: {dataset[0]}")
print(f"Categories: {set(dataset['category'])}")

# Test with a small sample first
print("\n" + "="*80)
print("TESTING WITH A SMALL SAMPLE (5 problems)")
print("="*80)

sample_results = evaluate_model_on_dataset(dataset, max_problems=5, verbose=True)

print("\n" + "="*80)
print("QUICK CATEGORY TEST (2 problems from 'arithmetic')")
print("="*80)

# Test with specific categories
category_results = evaluate_model_on_dataset(
    dataset, 
    max_problems=2, 
    categories=['arithmetic'], 
    verbose=True
)

# Save results
save_evaluation_results(sample_results, "sample_evaluation_results.json")
print("\nSample evaluation complete!")


Dataset loaded successfully!
Total examples: 1214
Example: {'question': 'What is the greatest common factor of 13975 and 130?', 'answer': '65', 'category': 'numbers'}
Categories: {'arithmetic', 'algebra', 'probability', 'comparison', 'measurement', 'numbers', 'polynomials', 'calculus'}

TESTING WITH A SMALL SAMPLE (5 problems)
Evaluating model on 5 problems...

Problem 1/5 [numbers]: What is the greatest common factor of 13975 and 13...
❌ WRONG: Got 5, Expected 65

Problem 2/5 [numbers]: What is the highest common factor of 20 and 365?...
✅ CORRECT: 5

Problem 3/5 [numbers]: Calculate the greatest common factor of 34945 and ...
❌ WRONG: Got 3, Expected 145

Problem 4/5 [numbers]: Let m(t) = -6*t - 3. Let o be m(-6). Let r = o - 7...
❌ WRONG: Got -18, Expected 2

Problem 5/5 [numbers]: What is the common denominator of 127/12 and 55/36...
❌ WRONG: Got 3001, Expected 12

EVALUATION SUMMARY
Overall Results:
  Total Problems: 5
  Successfully Parsed: 5 (100.00%)
  Correct Answers: 1 (20.00