# Reanalyze Test Results with Enhanced Equivalence Checking

This notebook reanalyzes the test results from GSM8K and MATH datasets using an improved answer comparison method that handles:
- List bracket normalization `[5]` → `5`
- Complex numbers `I` ↔ `i`
- Power notation `**` ↔ `^`
- Implicit multiplication `2*k` ↔ `2k`
- LaTeX sqrt and fractions
- GSM8K comma formatting

We'll compare the original results with the new equivalence checking to see how many additional cases become correct, and whether any previously correct cases are now marked incorrect.

In [8]:
import json
import re
from typing import Any, Dict, List
import pandas as pd

# Load the JSON files
def _load_results(path):
    """Read one result file and return its parsed JSON content.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The object produced by json.load (the code below indexes it with
        'summary', so a dict is expected).
    """
    # JSON text is UTF-8; be explicit instead of relying on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


gsm8k_v0_data = _load_results('result_sft_gsm8k_v0.json')
gsm8k_v3_data = _load_results('result_sft_gsm8k_v3.json')
math_v3_data = _load_results('result_sft_math_v3.json')

print("Loaded files:")
print(f"GSM8K v0: {gsm8k_v0_data['summary']}")
print(f"GSM8K v3: {gsm8k_v3_data['summary']}")
print(f"MATH v3: {math_v3_data['summary']}")

Loaded files:
GSM8K v0: {'model': 'sft', 'dataset': 'gsm8k', 'data_version': 'v0', 'total': 1319, 'correct': 412, 'accuracy': 0.312357846853677}
GSM8K v3: {'model': 'sft', 'dataset': 'gsm8k', 'data_version': 'v3', 'total': 1319, 'correct': 661, 'accuracy': 0.5011372251705838}
MATH v3: {'model': 'sft', 'dataset': 'math', 'data_version': 'v3', 'total': 500, 'correct': 85, 'accuracy': 0.17}


In [9]:
import os
import sys

# Location of the checkout that provides the math_tutor_model package.
# Set ADVANCED_LLM_REASONING_ROOT to point somewhere else; the default keeps
# the path this notebook was originally written against.
PROJECT_ROOT = os.environ.get('ADVANCED_LLM_REASONING_ROOT', '/home/guest/AdvancedLLMReasoning')
# Guard the insert so re-running this cell does not stack duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from math_tutor_model.math_equivalence import is_equiv

def enhanced_is_equiv(predicted, truth, dataset='math'):
    """
    Decide whether a predicted answer matches the reference answer.

    Thin wrapper around is_equiv (per the original author's note, it is meant
    to mirror the comparison done in test.py).

    Args:
        predicted: Model answer; None or '' counts as "no answer".
        truth: Reference answer.
        dataset: When 'gsm8k', commas are removed from the reference before
            comparing; any other value leaves the reference untouched.

    Returns:
        False for a missing prediction, otherwise whatever is_equiv returns
        for the two stringified answers.
    """
    # A missing prediction can never be correct
    no_answer = predicted is None or predicted == ''
    if no_answer:
        return False

    # Compare on the string form of both sides
    predicted_text, truth_text = str(predicted), str(truth)

    # GSM8K references may carry comma separators (e.g. "2,125"); drop them
    if dataset == 'gsm8k':
        truth_text = truth_text.replace(',', '')

    # Delegate the actual comparison to math_equivalence.is_equiv
    return is_equiv(predicted_text, truth_text)

# Smoke-test the wrapper on (prediction, reference, expected verdict) triples
print("Testing enhanced_is_equiv:")
test_cases = [
    ("3*sqrt(13)", "3\\sqrt{13}", True),
    ("6+9*I", "6+9i", True),
    ("x**3+3*x-6", "x^3+3x-6", True),
    ("276000.0", "276000", True),
]

for pred, truth, expected in test_cases:
    result = enhanced_is_equiv(pred, truth)
    # Tick when the verdict agrees with the expectation, cross otherwise
    if result == expected:
        status = "✓"
    else:
        status = "✗"
    print(f"{status} {pred} == {truth}: {result}")

Testing enhanced_is_equiv:
✓ 3*sqrt(13) == 3\sqrt{13}: True
✓ 6+9*I == 6+9i: True
✓ x**3+3*x-6 == x^3+3x-6: True
✓ 276000.0 == 276000: True


In [10]:
def _get_prediction(item):
    """Return the entry's prediction, preferring 'predicted_answer' over 'predicted'."""
    primary = item.get('predicted_answer')
    # Only None / '' count as missing; a falsy-but-real answer such as 0 must
    # not fall through to the secondary key.
    if primary is None or primary == '':
        return item.get('predicted')
    return primary


def _case_record(index, item, predicted, truth):
    """Build the summary dict stored for an entry whose verdict changed."""
    question = item.get('question')
    if question is None:
        question = 'N/A'
    return {
        'index': index,
        'predicted': predicted,
        'truth': truth,
        # Keep the report readable: at most 200 characters of the question
        'question': str(question)[:200]
    }


def reanalyze_dataset(data, dataset_name='math'):
    """
    Re-score every entry of a result file with enhanced_is_equiv.

    Args:
        data: Parsed result dict. Reads data['details'], a list of dicts; each
            entry supplies the prediction ('predicted_answer', falling back to
            'predicted'), the reference ('truth'), the previous verdict
            ('correct', default False) and optionally 'question'.
        dataset_name: Forwarded to enhanced_is_equiv as its `dataset` argument.

    Returns:
        dict with 'original_correct', 'new_correct', 'total', 'improvement'
        (new minus original) and the lists 'newly_correct' / 'newly_incorrect'
        describing entries whose verdict flipped.

    Side effects:
        Overwrites item['correct'] on every entry of data['details'] in place.
        Calling this twice on the same `data` therefore reports the first
        run's enhanced verdicts as "original" on the second run.
    """
    details = data['details']
    original_correct = 0
    new_correct = 0
    newly_correct = []
    newly_incorrect = []

    for i, item in enumerate(details):
        predicted = _get_prediction(item)
        truth = item.get('truth')
        original_status = item.get('correct', False)

        # Compute new correctness
        new_status = enhanced_is_equiv(predicted, truth, dataset=dataset_name)

        # Track statistics
        if original_status:
            original_correct += 1
        if new_status:
            new_correct += 1

        # Track changes in either direction
        if new_status and not original_status:
            newly_correct.append(_case_record(i, item, predicted, truth))
        elif original_status and not new_status:
            newly_incorrect.append(_case_record(i, item, predicted, truth))

        # Persist the new verdict on the entry itself
        item['correct'] = new_status

    return {
        'original_correct': original_correct,
        'new_correct': new_correct,
        'total': len(details),
        'newly_correct': newly_correct,
        'newly_incorrect': newly_incorrect,
        'improvement': new_correct - original_correct
    }

def _print_accuracy_report(analysis):
    """Print original vs. enhanced accuracy for one reanalyze_dataset() result."""
    total = analysis['total']
    for label, key in (('Original', 'original_correct'), ('Enhanced', 'new_correct')):
        correct = analysis[key]
        # An empty result file has no accuracy; report 0 instead of dividing by zero
        accuracy = correct / total if total else 0.0
        print(f"{label}: {correct}/{total} = {accuracy:.4f}")
    # '+d' emits the sign itself, so a regression prints "-3" rather than "+-3"
    print(f"Improvement: {analysis['improvement']:+d} cases")


print("Reanalyzing datasets...")
print("=" * 80)

# Reanalyze GSM8K v0
print("\n1. GSM8K v0:")
gsm8k_v0_analysis = reanalyze_dataset(gsm8k_v0_data, 'gsm8k')
_print_accuracy_report(gsm8k_v0_analysis)

# Reanalyze GSM8K v3
print("\n2. GSM8K v3:")
gsm8k_v3_analysis = reanalyze_dataset(gsm8k_v3_data, 'gsm8k')
_print_accuracy_report(gsm8k_v3_analysis)

# Reanalyze MATH v3
print("\n3. MATH v3:")
math_v3_analysis = reanalyze_dataset(math_v3_data, 'math')
_print_accuracy_report(math_v3_analysis)

Reanalyzing datasets...

1. GSM8K v0:
Original: 412/1319 = 0.3124
Enhanced: 416/1319 = 0.3154
Improvement: +4 cases

2. GSM8K v3:
Original: 661/1319 = 0.5011
Enhanced: 666/1319 = 0.5049
Improvement: +5 cases

3. MATH v3:
Original: 85/500 = 0.1700
Enhanced: 97/500 = 0.1940
Improvement: +12 cases




In [11]:
# Create comparison DataFrame (one row per reanalyzed result file)
labelled_analyses = [
    ('GSM8K v0', gsm8k_v0_analysis),
    ('GSM8K v3', gsm8k_v3_analysis),
    ('MATH v3', math_v3_analysis),
]

comparison_data = {
    'Dataset': [label for label, analysis in labelled_analyses],
    'Total': [analysis['total'] for label, analysis in labelled_analyses],
    'Original Correct': [analysis['original_correct'] for label, analysis in labelled_analyses],
    'Enhanced Correct': [analysis['new_correct'] for label, analysis in labelled_analyses],
    'Improvement': [analysis['improvement'] for label, analysis in labelled_analyses],
}

df_comparison = pd.DataFrame(comparison_data)
df_comparison['Original Accuracy'] = df_comparison['Original Correct'] / df_comparison['Total']
df_comparison['Enhanced Accuracy'] = df_comparison['Enhanced Correct'] / df_comparison['Total']
df_comparison['Accuracy Gain'] = df_comparison['Enhanced Accuracy'] - df_comparison['Original Accuracy']

# "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
print("\nAccuracy Comparison Table:")
print("=" * 100)
print(df_comparison.to_string(index=False))

total_improvement = sum(analysis['improvement'] for label, analysis in labelled_analyses)
print("\n\nSummary:")
# '+d' emits the sign itself, so a net regression is not rendered as "+-N"
print(f"Total improvement across all datasets: {total_improvement:+d} cases")

\nAccuracy Comparison Table:
 Dataset  Total  Original Correct  Enhanced Correct  Improvement  Original Accuracy  Enhanced Accuracy  Accuracy Gain
GSM8K v0   1319               412               416            4           0.312358           0.315390       0.003033
GSM8K v3   1319               661               666            5           0.501137           0.504928       0.003791
 MATH v3    500                85                97           12           0.170000           0.194000       0.024000
\n\nSummary:
Total improvement across all datasets: +21 cases


In [12]:
print("Newly Correct Cases:")
print("=" * 100)

# (label, analysis) pairs; the "Newly Incorrect" cell below reuses this list
datasets = [
    ('GSM8K v0', gsm8k_v0_analysis),
    ('GSM8K v3', gsm8k_v3_analysis),
    ('MATH v3', math_v3_analysis)
]

for dataset_name, analysis in datasets:
    newly_correct = analysis['newly_correct']
    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
    print(f"\n{dataset_name}: {len(newly_correct)} newly correct cases")
    print("-" * 100)

    # Show first 5 examples only; the remainder is summarized below
    shown = newly_correct[:5]
    for i, case in enumerate(shown):
        print(f"\nExample {i+1} (Index {case['index']}):")
        print(f"Question: {case['question']}")
        print(f"Predicted: {case['predicted']}")
        print(f"Truth: {case['truth']}")

    if len(newly_correct) > len(shown):
        print(f"\n... and {len(newly_correct) - len(shown)} more cases")

    if len(newly_correct) == 0:
        print("No newly correct cases found.")

Newly Correct Cases:
\nGSM8K v0: 4 newly correct cases
----------------------------------------------------------------------------------------------------
\nExample 1 (Index 146):
Question: Johnny is picking up the toys on the floor of his room.  He'd dumped a lego boxed set with 500 pieces on the floor, and another one that had 3 times more pieces than the 500 piece one, and another one
Predicted: 2125.0
Truth: 2,125
\nExample 2 (Index 249):
Question: On Monday, Sue ate 4 times as many cookies as her sister. On Tuesday, she ate twice as many cookies as her sister. Her sister ate 5 cookies on Monday and 13 the next day. If 1 cookie has 200 calories,
Predicted: 5600
Truth: 5,600
\nExample 3 (Index 640):
Question: Bill is ordering a new truck. He has decided to purchase a two-ton truck with several added features: a king cab upgrade, a towing package, leather seats, running boards, and the upgraded exterior lig
Predicted: 43500.0
Truth: 43,500
\nExample 4 (Index 1206):
Question: John de

In [13]:
print("Newly Incorrect Cases:")
print("=" * 100)

# `datasets` is the (label, analysis) list defined in the "Newly Correct" cell
for dataset_name, analysis in datasets:
    newly_incorrect = analysis['newly_incorrect']
    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
    print(f"\n{dataset_name}: {len(newly_incorrect)} newly incorrect cases")
    print("-" * 100)

    # Show first 5 examples only; the remainder is summarized below
    shown = newly_incorrect[:5]
    for i, case in enumerate(shown):
        print(f"\nExample {i+1} (Index {case['index']}):")
        print(f"Question: {case['question']}")
        print(f"Predicted: {case['predicted']}")
        print(f"Truth: {case['truth']}")

    if len(newly_incorrect) > len(shown):
        print(f"\n... and {len(newly_incorrect) - len(shown)} more cases")

    if len(newly_incorrect) == 0:
        print("✓ No newly incorrect cases - all improvements are gains!")

Newly Incorrect Cases:
\nGSM8K v0: 0 newly incorrect cases
----------------------------------------------------------------------------------------------------
✓ No newly incorrect cases - all improvements are gains!
\nGSM8K v3: 0 newly incorrect cases
----------------------------------------------------------------------------------------------------
✓ No newly incorrect cases - all improvements are gains!
\nMATH v3: 1 newly incorrect cases
----------------------------------------------------------------------------------------------------
\nExample 1 (Index 31):
Question: Simplify $\sqrt{242}$.
Predicted: 11*sqrt(2)
Truth: 11\sqrt2


In [14]:
# Each result dict, the analysis computed for it, and the file it is written to
enhanced_outputs = [
    (gsm8k_v0_data, gsm8k_v0_analysis, 'result_sft_gsm8k_v0_enhanced.json'),
    (gsm8k_v3_data, gsm8k_v3_analysis, 'result_sft_gsm8k_v3_enhanced.json'),
    (math_v3_data, math_v3_analysis, 'result_sft_math_v3_enhanced.json'),
]

# Refresh the summaries with the re-scored counts
# (the original note says this mirrors test.py's summary structure)
for result_data, analysis, out_path in enhanced_outputs:
    summary = result_data['summary']
    summary['correct'] = analysis['new_correct']
    summary['accuracy'] = analysis['new_correct'] / analysis['total']

# Write every updated result dict to its own new file
for result_data, analysis, out_path in enhanced_outputs:
    with open(out_path, 'w') as f:
        json.dump(result_data, f, indent=2)

print("Updated results saved to:")
for result_data, analysis, out_path in enhanced_outputs:
    print(f"- {out_path}")

print("\nUpdated fields (same format as test.py output):")
print("- 'correct' field updated in each detail entry")
print("- Summary 'correct' and 'accuracy' updated")

Updated results saved to:
- result_sft_gsm8k_v0_enhanced.json
- result_sft_gsm8k_v3_enhanced.json
- result_sft_math_v3_enhanced.json

Updated fields (same format as test.py output):
- 'correct' field updated in each detail entry
- Summary 'correct' and 'accuracy' updated


## 7. Export Updated Results

Save the reanalyzed results to new JSON files.

## 6. Identify Newly Incorrect Cases

Show cases that were marked correct originally but are now incorrect (if any).

## 5. Identify Newly Correct Cases

Show examples of cases that were marked incorrect originally but are now correct.

## 4. Compare Original vs Enhanced Accuracy

## 3. Reanalyze All Datasets

Apply the enhanced comparison function to all test cases.

## 2. Define Enhanced Answer Comparison Function

Import the improved is_equiv function from math_equivalence.py with all normalizations.

## 1. Load and Parse JSON Files