In [1]:
import pandas as pd
import logging
import re 

In [None]:
# Results directory to analyse; the CSV loaded below is `final_result.csv`
# inside it.
folder = 'mgsm/results_deepseekr1'  # <-- Change to the specific folder you want
filename = '{}/final_result.csv'.format(folder)

In [16]:
# Read the evaluation results CSV selected in the config cell.
df = pd.read_csv(filepath_or_buffer=filename)

# Absolute tolerance used when comparing an extracted answer to the reference.
tolerance = 1e-06

In [17]:
def extract_answer(text):
    """
    Extract a numerical answer from model output text.

    Two strategies are tried in order:

    1. Boxed answer: take the last ``\\boxed{...}`` expression (the leading
       backslash is optional) and parse the first number inside it. A leading
       minus sign is kept, so ``\\boxed{-5}`` yields -5.0.
    2. Fallback: parse the last number appearing anywhere in the text. This is
       also used when the boxed number cannot be converted to a float.

    Numbers may use ',' and/or '.' as thousands or decimal separators
    (e.g. '1.234,56', '1,234.56', '12.500', '3,5'); see ``parse_locale_number``.

    Args:
        text (str): The text containing the numerical answer. Non-string
            values (e.g. NaN for an empty DataFrame cell) yield None.

    Returns:
        float or None: The extracted numerical answer, or None if not found.
    """
    # The lazy prefix skips non-numeric boxed content such as "\$" or "\text{".
    # The group must contain at least one digit, so a lone "." or "," is never
    # captured. An optional leading '-' keeps negative answers negative.
    boxed_pattern = r'\\?boxed\{.*?(-?[\d.,]*\d[\d.,]*).*?\}'
    # Digit runs optionally joined by single '.' or ',' separators. Unlike a
    # `\d{1,3}(...)` prefix, this does not split ungrouped numbers such as
    # "1000" into "100" and "0".
    number_pattern = r'[-+]?\d+(?:[.,]\d+)*'

    def parse_locale_number(number_str):
        """
        Normalise a number string that may use ',' and/or '.' separators into
        a form accepted by ``float()``.

        - Both separators present: the one that appears first is the thousands
          separator ('1.234,56' -> '1234.56', '1,234.56' -> '1234.56').
        - Only one kind present: it is a thousands separator when the leading
          group is not all zeros/empty and every later group has exactly three
          digits ('12.500' -> '12500', '1,000' -> '1000'); otherwise it is a
          decimal mark ('3.5' -> '3.5', '3,5' -> '3.5', '0.125' -> '0.125').
        """
        if '.' in number_str and ',' in number_str:
            if number_str.find('.') < number_str.find(','):
                return number_str.replace('.', '').replace(',', '.')
            return number_str.replace(',', '')
        for sep in ('.', ','):
            if sep in number_str:
                groups = number_str.split(sep)
                head = groups[0].lstrip('+-')
                # "0.125" or ".500" cannot be a grouped integer.
                leading_is_significant = head.lstrip('0') != ''
                if leading_is_significant and all(len(g) == 3 for g in groups[1:]):
                    return number_str.replace(sep, '')
                return number_str.replace(sep, '.')
        return number_str  # no separators at all

    def to_float(number_str):
        # Drop trailing separators (e.g. "18." at the end of a sentence).
        return float(parse_locale_number(number_str.rstrip('.,')))

    if not isinstance(text, str):
        logging.info(f"No valid answer found in text: {text}")
        return None

    boxed = re.findall(boxed_pattern, text)
    if boxed:
        try:
            return to_float(boxed[-1])
        except ValueError as e:
            # Do not give up: fall through to the plain-number fallback.
            logging.info(f"Error extracting answer from text: {text}\nException: {e}")

    numbers = re.findall(number_pattern, text)
    if numbers:
        try:
            return to_float(numbers[-1])
        except ValueError as e:
            logging.info(f"Error extracting answer from text: {text}\nException: {e}")
            return None

    logging.info(f"No valid answer found in text: {text}")
    return None

def check_correct(pred, actual, tol=tolerance):
    """
    Decide whether a predicted answer matches the reference answer.

    A missing prediction (None/NaN) is never correct. Otherwise the two values
    are compared by absolute difference.

    Args:
        pred (float or None): Predicted number; None/NaN means "no answer".
        actual (float): Reference answer.
        tol (float): Strict upper bound on ``|actual - pred|``. Defaults to the
            notebook-level ``tolerance`` value at the time this cell ran.

    Returns:
        bool: True when ``|actual - pred| < tol``, otherwise False.
    """
    if not pd.notnull(pred):
        return False
    return abs(actual - pred) < tol

In [18]:
# Re-run answer extraction on every raw model output.
df["gen_answer2"] = df["gen_text"].apply(extract_answer)

# Score each re-extracted answer against the reference answer.
df["is_correct2"] = df.apply(
    lambda r: check_correct(r["gen_answer2"], r["answer_number"]),
    axis=1,
)

# Side-by-side view: original pipeline results vs. the re-extracted ones.
comparison_cols = [
    "q_no", "lang", "answer_number",
    "gen_answer", "is_correct",
    "gen_answer2", "is_correct2",
]
df[comparison_cols]

In [None]:
def transform_comparison_stats(comparison_stats):
    """
    Reshape per-language comparison statistics into two summary tables.

    Each output table has one row per stage ('initial', 'recheck'), one column
    per language (in the order languages appear in ``comparison_stats``), a
    leading 'type' column naming the stage, and a trailing 'total' column that
    sums over all languages.

    Args:
        comparison_stats (pd.DataFrame): One row per language with columns
            'lang', 'original_correct', 'revised_correct',
            'original_incorrect' and 'revised_incorrect'. Languages are
            expected to be unique (as produced by ``groupby('lang')``).

    Returns:
        (pd.DataFrame, pd.DataFrame): The correctness table (built from the
        *_correct columns) and the incorrectness table (built from the
        *_incorrect columns).
    """
    # Label-based lookup instead of a boolean scan of the frame per language.
    by_lang = comparison_stats.set_index('lang')

    table_specs = {
        'correctness': ('original_correct', 'revised_correct'),
        'incorrectness': ('original_incorrect', 'revised_incorrect'),
    }

    tables = {}
    for table_name, (initial_col, recheck_col) in table_specs.items():
        data = {'type': ['initial', 'recheck']}

        for lang in comparison_stats['lang']:
            data[lang] = [
                by_lang.at[lang, initial_col],
                by_lang.at[lang, recheck_col],
            ]

        # Add total sums across languages
        data['total'] = [
            comparison_stats[initial_col].sum(),
            comparison_stats[recheck_col].sum(),
        ]

        tables[table_name] = pd.DataFrame(data)

    return tables['correctness'], tables['incorrectness']

In [21]:
# Per-language tallies: original pipeline verdict (`is_correct`) vs. the
# verdict after re-extracting answers (`is_correct2`).
agg_spec = {
    'original_correct': ('is_correct', lambda col: sum(col == True)),
    'original_incorrect': ('is_correct', lambda col: sum(col == False)),
    'revised_correct': ('is_correct2', lambda col: sum(col == True)),
    'revised_incorrect': ('is_correct2', lambda col: sum(col == False)),
}
comparison_stats = (
    df.groupby('lang')[['is_correct', 'is_correct2']]
    .agg(**agg_spec)
    .reset_index()
)

# Reshape into one row per stage (initial / recheck) and one column per language.
correctness_table, incorrectness_table = transform_comparison_stats(comparison_stats)

for label, table in (('correctness_table', correctness_table),
                     ('incorrectness_table', incorrectness_table)):
    print(label)
    display(table)

In [22]:
# Rows still judged wrong after answer re-extraction (`is_correct2`, not the
# original `is_correct`); only these have their self-revision re-parsed.
mask = df["is_correct2"] == False

# Numeric answer parsed from each selected row's self-revision text.
revised_answers = df.loc[mask, "self_revision"].apply(extract_answer)
df.loc[mask, "revised_answer2"] = revised_answers

# Score the revised answers against the reference, pairing the two columns
# row by row in mask order.
df.loc[mask, "is_correct_rev2"] = [
    check_correct(pred, actual, tol=tolerance)
    for pred, actual in zip(revised_answers, df.loc[mask, "answer_number"])
]

In [23]:
# Rows already correct after re-extraction need no revision: leave their
# revised answer empty and count the revision as correct.
not_revised = ~mask
df.loc[not_revised, "revised_answer2"] = None
df.loc[not_revised, "is_correct_rev2"] = True

# Inspect how the rows that needed revision fared.
revision_cols = [
    "q_no", "lang", "answer_number",
    "gen_answer2", "is_correct2",
    "revised_answer2", "is_correct_rev2",
]
df_revision_check = df.loc[mask, revision_cols]

df_revision_check

In [None]:
# Per-language tallies, now comparing the re-extracted verdict (`is_correct2`)
# with the verdict after re-parsing self-revisions (`is_correct_rev2`).
def _count_true(col):
    return sum(col == True)


def _count_false(col):
    return sum(col == False)


comparison_stats = (
    df.groupby('lang')[['is_correct2', 'is_correct_rev2']]
    .agg(
        original_correct=('is_correct2', _count_true),
        original_incorrect=('is_correct2', _count_false),
        revised_correct=('is_correct_rev2', _count_true),
        revised_incorrect=('is_correct_rev2', _count_false),
    )
    .reset_index()
)

correctness_table, incorrectness_table = transform_comparison_stats(comparison_stats)

for label, table in [('correctness_table', correctness_table),
                     ('incorrectness_table', incorrectness_table)]:
    print(label)
    display(table)

In [27]:
# Save the augmented results next to the input file. A non-empty prefix keeps
# the input `final_result.csv` from being overwritten.
X = 'X'  # Prefix for saving
output_path = '{}/{}final_result.csv'.format(folder, X)
df.to_csv(output_path, index=False)