In [48]:
import json
import re

In [49]:
# Load few-shot data
file_location = "/Users/thyag/Desktop/codes/chain-of-thought/dataset/arithmetic reasoning/gsm8k/base/gsm8k-base-gemini-2-0-flash.json"
with open(file_location, "r") as f:
    few_shot_data = json.load(f)

In [None]:
from openai import OpenAI
import json
import re
import os
from dotenv import load_dotenv

load_dotenv()

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def openai_extract_final_answer(text):
    """
    Use the OpenAI API to extract the main answer from the given text.
    """
    response = client.responses.create(
         model="gpt-4o-mini-2024-08-06",
         instructions="Extract the final answer from the text. Only provide the final answer and nothing else. Do not include any reasoning or explanation.",
         input=text,
    )
    return response.output_text

def extract_final_answer(text):
    """
    Extract the final answer from text in various formats.
    Works with numerical answers, multiple choice, true/false, and text answers.
    For ambiguous cases, falls back to using the OpenAI API.
    """
    if not text:
        return None
        
    text = text.strip()
    
    # Case 1: Common mathematical answer formats with numerical answers
    # Handles boxed answers, LaTeX, and various ways to express "the answer is X"
    patterns = [
        # Mathematical boxed format
        r'answer is\s+\$?\\?boxed\{?\$?([\d.]+)\$?\}?',
        # Common answer phrases 
        r'answer is\s+([\d.]+)',
        r'result is\s+([\d.]+)', 
        r'equals\s+([\d.]+)',
        r'final answer:?\s+([\d.]+)',
        # Dollar amounts
        r'answer is\s+\$\s*([\d,.]+)',
    ]
    
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return normalize_numeric(match.group(1))
    
    # Case 2: Multiple choice answers (A/B/C/D)
    mc_match = re.search(r'answer is\s+([A-D])[\.:\)]', text, re.IGNORECASE)
    if mc_match:
        return mc_match.group(1).upper()
    
    # Case 3: True/False or Yes/No answers
    yn_match = re.search(r'answer is\s+(yes|no|true|false)', text, re.IGNORECASE)
    if yn_match:
        return yn_match.group(1).lower()
    
    # Case 4: Text-based answers with specific markers
    text_match = re.search(r'answer is[:\s]+"([^"]+)"', text, re.IGNORECASE)
    if text_match:
        return text_match.group(1).strip()
        
    # Case 5: Check last line for answer pattern
    lines = [line for line in text.split('\n') if line.strip()]
    if lines:
        last_line = lines[-1].strip()
        
        # Check if last line starts with answer indicator
        for prefix in ['Answer:', 'Therefore,', 'Thus,']:
            if last_line.startswith(prefix):
                answer_part = last_line[len(prefix):].strip()
                # Try to extract number from this part
                num_match = re.search(r'([\d.]+)', answer_part)
                if num_match:
                    return normalize_numeric(num_match.group(1))
                return answer_part
        
        # Last resort: number at the end of text
        match = re.search(r'([\d.]+)\s*$', last_line)
        if match:
            return normalize_numeric(match.group(1))
    
    # Fallback to OpenAI API extraction for missing values
    try:
        api_result = openai_extract_final_answer(text)
        print("OpenAI API extraction result:", api_result)
        return api_result
    except Exception as e:
        print("OpenAI API extraction failed:", e)
        return None

def normalize_numeric(text):
    """Normalize numeric answers to integer if possible, otherwise float"""
    # Remove trailing periods, commas, etc.
    text = text.rstrip('.,:;')
    
    # Remove commas and other non-numeric characters
    cleaned = re.sub(r'[^\d.]', '', text)
    
    try:
        float_value = float(cleaned)
        if float_value.is_integer():
            return str(int(float_value))
        return str(float_value)
    except ValueError:
        return text

In [51]:
for i, example in enumerate(few_shot_data):
    pred = extract_final_answer(example["answer_text"])
    print(f"Example {i+1}: {pred}")

Example 1: 72
Example 2: 10
Example 3: 
Example 4: 
Example 5: 
Example 6: 
Example 7: 
Example 8: 16
Example 9: 
Example 10: 
Example 11: 
Example 12: 
Example 13: 
Example 14: 
Example 15: 5
Example 16: 
Example 17: 
Example 18: 43
Example 19: 16
Example 20: 


In [44]:
for i, example in enumerate(few_shot_data):
    pred = extract_final_answer(example["original_answer"])
    print(f"Example {i+1}: {pred}")

Example 1: 72
Example 2: 10
Example 3: 5
Example 4: 42
Example 5: 624
Example 6: 35
Example 7: 48
Example 8: 16
Example 9: 41
Example 10: 990
Example 11: 121
Example 12: 5
Example 13: 85
Example 14: 35
Example 15: 5
Example 16: 448000
Example 17: 800
Example 18: 43
Example 19: 16
Example 20: 16
