In [None]:
import json
import re

In [76]:
# Load few-shot data
file_location = "/Users/thyag/Desktop/codes/chain-of-thought/dataset/arithmetic reasoning/gsm8k/few-shot/gsm8k-few-shot-gpt-4-1-nano-2025-04-14.json"
# Pass the encoding explicitly so the read does not depend on the platform's
# locale default (JSON interchange files are UTF-8).
with open(file_location, "r", encoding="utf-8") as f:
    few_shot_data = json.load(f)

In [None]:
from openai import OpenAI
import json
import re
import os
from dotenv import load_dotenv

# Load environment variables via python-dotenv before the key is read below
# (presumably from a local .env file — confirm where the key is stored).
load_dotenv()

# Module-level OpenAI client; openai_extract_final_answer() in this cell uses it.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def normalize_numeric(value_str):
    """Return ``value_str`` rewritten in canonical float notation, as a string.

    Args:
        value_str: Candidate numeric text such as ``"42"`` or ``"3.50"``.

    Returns:
        ``str(float(value_str))`` when the value parses as a float (so ``"42"``
        becomes ``"42.0"``); otherwise the input is returned unchanged.
    """
    try:
        return str(float(value_str))
    except (TypeError, ValueError):
        # ValueError: non-numeric text. TypeError: None or another type that
        # float() rejects outright. Either way, hand the value back untouched.
        return value_str

def openai_extract_final_answer(text):
    """
    Ask the OpenAI Responses API to pull the final answer out of ``text``.

    Returns the model's ``output_text`` on success. Any exception raised while
    calling the API or reading the response is logged to stdout and turned into
    the sentinel string "API_ERROR".
    """
    print(f"[DEBUG] OpenAI API called with text: '{text[:100]}...' (length: {len(text)})")
    request = {
        "model": "gpt-4o-mini",
        "instructions": "Extract the final answer from the text. Only provide the final answer and nothing else. Do not include any reasoning or explanation. Do not include any symbols or formatting. If the answer is a number, return it as a number. If the answer is a letter (A, B, C, D), return it as a letter. If the answer is true/false, return it as true/false. If the answer is text, return it as text.",
        "input": text,
    }
    try:
        extracted = client.responses.create(**request).output_text
        print(f"[DEBUG] OpenAI API returned: '{extracted}'")
        return extracted
    except Exception as err:
        print(f"[DEBUG] OpenAI API error: {err}")
        return "API_ERROR"
def extract_final_answer(text):
    """
    Extract the final answer from text in various formats.

    Patterns are tried in a fixed priority order and the first hit wins:
    GSM8K ``#### n``, ``boxed{n}``, explicit ``Answer:`` lines, looser
    "answer is" / "therefore" style phrases, multiple choice (A-D),
    yes/no/true/false, a quoted text answer, and finally a bottom-up scan of
    the individual lines. If nothing matches (or the input is empty/None) the
    OpenAI helper is asked instead.

    Args:
        text: Raw answer text; may be None or empty.

    Returns:
        The numeric answer passed through ``normalize_numeric``, an upper-case
        choice letter, a lower-case yes/no/true/false, the quoted text, the
        API's reply, or "API_ERROR" if the fallback call fails.
    """
    # A number must START with a digit. The previous ``[\d,]+`` form could match
    # a lone comma (e.g. the one in "Answer: Well, ... 5"), which produced an
    # empty result instead of the real number.
    number = r'(\d[\d,]*(?:\.\d+)?)'

    def to_number(raw):
        # Drop thousands separators before normalising.
        return normalize_numeric(raw.replace(',', ''))

    print(f"\n[DEBUG] Starting extract_final_answer with text: '{text[:100] if text else 'None'}...'")

    if text and text.strip():
        print("[DEBUG] Text is not empty, proceeding with regex patterns")
        original_text = text.strip()

        # Markdown markers (*, _, `) removed and whitespace collapsed; used by
        # the looser phrase patterns below. The original text is kept for the
        # patterns that depend on markdown or line breaks.
        cleaned_text = re.sub(r'[\*\_`]', '', original_text)
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
        print(f"[DEBUG] Cleaned text: '{cleaned_text[:100]}...'")

        # HIGHEST PRIORITY: GSM8K format #### answer
        print("[DEBUG] Trying GSM8K #### pattern...")
        gsm8k_match = re.search(r'####\s*' + number, original_text)
        if gsm8k_match:
            result = to_number(gsm8k_match.group(1))
            print(f"[DEBUG] GSM8K #### pattern matched: '{result}'")
            return result
        print("[DEBUG] No GSM8K #### pattern found")

        # Second priority: Boxed answers (common in mathematical texts)
        print("[DEBUG] Trying boxed answer pattern...")
        boxed_match = re.search(r'\\?boxed\{?\$?' + number + r'\$?\}?', original_text, re.IGNORECASE)
        if boxed_match:
            result = to_number(boxed_match.group(1))
            print(f"[DEBUG] Boxed pattern matched: '{result}'")
            return result
        print("[DEBUG] No boxed pattern found")

        # Third priority: Explicit "Answer:" patterns - use original text
        print("[DEBUG] Trying explicit answer patterns on original text...")
        answer_patterns = [
            r'\*\*Answer:\*\*\s*[^0-9]*?' + number,  # **Answer:** ... number
            r'Answer:\s*[^0-9]*?' + number,  # Answer: ... number
        ]

        for i, pattern in enumerate(answer_patterns):
            matches = list(re.finditer(pattern, original_text, re.IGNORECASE))
            if matches:
                # Take the last match (most likely to be the final answer)
                last_match = matches[-1]
                result = to_number(last_match.group(1))
                print(f"[DEBUG] Explicit answer pattern {i} matched: '{pattern}' -> '{result}' from '{last_match.group(0)}'")
                return result
        print("[DEBUG] No explicit answer patterns matched")

        # Try on cleaned text for other patterns
        print("[DEBUG] Trying other patterns on cleaned text...")
        # NOTE(review): the trailing negative lookahead is meant to skip numbers
        # that are part of a calculation, but regex backtracking can shorten the
        # number so the lookahead lands on a later digit (e.g. "therefore 12 + 3"
        # captures "1"). Left unchanged here; worth revisiting.
        other_patterns = [
            r'answer\s*(?:is|[:=])\s*\$?' + number,  # Basic answer pattern
            r'final\s*(?:answer|result)\s*(?:[:=])\s*\$?' + number,
            number + r'\s+is\s+the\s+final\s+answer',
            r'therefore[^0-9]*?' + number + r'(?!\s*[/=*+\-])',  # Avoid calculations
            r'thus[^0-9]*?' + number + r'(?!\s*[/=*+\-])',
            r'hence[^0-9]*?' + number + r'(?!\s*[/=*+\-])',
        ]

        for i, pattern in enumerate(other_patterns):
            matches = list(re.finditer(pattern, cleaned_text, re.IGNORECASE))
            if matches:
                # Take the last match
                last_match = matches[-1]
                result = to_number(last_match.group(1))
                print(f"[DEBUG] Other pattern {i} matched: '{pattern}' -> '{result}' from '{last_match.group(0)}'")
                return result
        print("[DEBUG] No other patterns matched")

        print("[DEBUG] Trying multiple choice pattern...")
        mc_match = re.search(r'answer\s*(?:is)?\s*([A-D])[\.:\)]', cleaned_text, re.IGNORECASE)
        if mc_match:
            result = mc_match.group(1).upper()
            print(f"[DEBUG] Multiple choice matched: '{result}'")
            return result
        print("[DEBUG] No multiple choice pattern matched")

        print("[DEBUG] Trying yes/no pattern...")
        yn_match = re.search(r'answer\s*(?:is)?\s*(yes|no|true|false)', cleaned_text, re.IGNORECASE)
        if yn_match:
            result = yn_match.group(1).lower()
            print(f"[DEBUG] Yes/no pattern matched: '{result}'")
            return result
        print("[DEBUG] No yes/no pattern matched")

        print("[DEBUG] Trying text pattern...")
        text_match = re.search(r'answer\s*(?:is|[:=])\s*"([^"]+)"', cleaned_text, re.IGNORECASE)
        if text_match:
            result = text_match.group(1).strip()
            print(f"[DEBUG] Text pattern matched: '{result}'")
            return result
        print("[DEBUG] No text pattern matched")

        print("[DEBUG] Trying line-by-line analysis...")
        # original_text is non-empty after strip(), so at least one line survives.
        lines = [line.strip() for line in original_text.split('\n') if line.strip()]
        # Look through lines from end to beginning for final answer
        for line_idx in reversed(range(len(lines))):
            line = lines[line_idx]
            print(f"[DEBUG] Checking line {line_idx}: '{line}'")

            # Check if this line contains "Answer:"
            if re.search(r'\*\*Answer:\*\*|\bAnswer:', line, re.IGNORECASE):
                print(f"[DEBUG] Found answer line: '{line}'")
                # Extract the first number from this line
                num_match = re.search(number, line)
                if num_match:
                    result = to_number(num_match.group(1))
                    print(f"[DEBUG] Number found in answer line: '{result}'")
                    return result

            # Check for other conclusion patterns
            for prefix in ['Therefore,', 'Thus,', 'Hence,']:
                if line.startswith(prefix):
                    print(f"[DEBUG] Found prefix '{prefix}' in line: '{line}'")
                    num_match = re.search(number, line)
                    if num_match:
                        result = to_number(num_match.group(1))
                        print(f"[DEBUG] Number found in conclusion line: '{result}'")
                        return result
    else:
        print("[DEBUG] Text is empty or None, skipping regex patterns")

    print("[DEBUG] All regex patterns failed, falling back to OpenAI API")
    try:
        # Empty/None input is replaced by "." so the API still receives a string.
        input_for_api = text or "."
        print(f"[DEBUG] Calling OpenAI API with input: '{input_for_api[:50]}...'")
        api_result = openai_extract_final_answer(input_for_api)
        print(f"[DEBUG] Final result from API: '{api_result}'")
        return api_result
    except Exception as e:
        print(f"[DEBUG] Exception in API fallback: {e}")
        return "API_ERROR"

In [58]:
import os
from dotenv import load_dotenv

# Load environment variables via python-dotenv before the key is read below
# (presumably from a local .env file — confirm where the key is stored).
load_dotenv()

# NOTE: this cell imports only `os` and `load_dotenv`; `OpenAI` (and `re`, used
# by the functions below) must already be imported by an earlier cell.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def normalize_numeric(value_str):
    """Convert ``value_str`` to a float when possible.

    Args:
        value_str: Candidate numeric text such as ``"42"`` or ``"3.50"``.

    Returns:
        ``float(value_str)`` when the value parses as a float; otherwise the
        input is returned unchanged (for example the "API_ERROR" sentinel).
    """
    try:
        return float(value_str)
    except (TypeError, ValueError):
        # ValueError: non-numeric text. TypeError: None or another type that
        # float() rejects outright. Either way, hand the value back untouched.
        return value_str

def openai_extract_final_answer(text):
    """
    Ask the OpenAI Responses API to pull the final answer out of ``text``.

    Returns the model's ``output_text`` on success, or the sentinel string
    "API_ERROR" if the call or the response access raises any exception.
    """
    request = {
        "model": "gpt-4o-mini",
        "instructions": "Extract the final answer from the text. Only provide the final answer and nothing else. Do not include any reasoning or explanation. Do not include any symbols or formatting. If the answer is a number, return it as a number. If the answer is a letter (A, B, C, D), return it as a letter. If the answer is true/false, return it as true/false. If the answer is text, return it as text. Example: If the answer is 42 pages/hours, return 42.",
        "input": text,
    }
    try:
        reply = client.responses.create(**request)
        return reply.output_text
    except Exception:
        return "API_ERROR"
def extract_final_answer(text):
    """
    Extract the final answer from text using regex patterns.

    Patterns are tried in a fixed priority order and the first usable hit wins:
    GSM8K ``#### n``, ``boxed{n}``, a number inside the first ``**bold**``
    span, explicit ``Answer:`` lines, looser "answer is" / "therefore" style
    phrases, multiple choice (A-D), yes/no/true/false, a quoted text answer,
    and finally a bottom-up scan of the individual lines.
    Falls back to the OpenAI API if no pattern matches or the matched result
    is empty.

    Args:
        text: Raw answer text; may be None or empty.

    Returns:
        The numeric answer passed through ``normalize_numeric``, an upper-case
        choice letter, a lower-case yes/no/true/false, the quoted text, or the
        API's reply passed through ``normalize_numeric``.
    """
    # A number must START with a digit. The previous ``[\d,]+`` form could match
    # a lone comma (e.g. the one in "Answer: Well, ... 5"); the empty result was
    # rejected, the real number was never found, and the API was called needlessly.
    number = r'(\d[\d,]*(?:\.\d+)?)'

    def is_usable(value):
        # Accept real numbers and non-blank strings only.
        if isinstance(value, (int, float)):
            return True
        return isinstance(value, str) and bool(value.strip())

    def parse_number(raw):
        # Normalise a captured number; None means "not usable, try the next pattern".
        value = normalize_numeric(raw.replace(',', ''))
        return value if is_usable(value) else None

    if text and text.strip():
        original_text = text.strip()
        # Markdown markers removed and whitespace collapsed, for the looser
        # phrase patterns; original_text keeps markdown and line breaks.
        cleaned_text = re.sub(r'[\*\_`]', '', original_text)
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)

        gsm8k_match = re.search(r'####\s*' + number, original_text)
        if gsm8k_match:
            result = parse_number(gsm8k_match.group(1))
            if result is not None:
                return result

        boxed_match = re.search(r'\\?boxed\{?\$?' + number + r'\$?\}?', original_text, re.IGNORECASE)
        if boxed_match:
            result = parse_number(boxed_match.group(1))
            if result is not None:
                return result

        # Only the FIRST bold span is inspected; later bold spans are ignored.
        bold_general_match = re.search(r'\*\*([^\*]+)\*\*', original_text)
        if bold_general_match:
            num_match = re.search(number, bold_general_match.group(1))
            if num_match:
                result = parse_number(num_match.group(1))
                if result is not None:
                    return result

        answer_patterns = [
            r'\*\*Answer:\*\*\s*[^0-9]*?' + number,
            r'Answer:\s*[^0-9]*?' + number,
        ]
        for pattern in answer_patterns:
            matches = list(re.finditer(pattern, original_text, re.IGNORECASE))
            if matches:
                # The last occurrence is taken as the final answer.
                result = parse_number(matches[-1].group(1))
                if result is not None:
                    return result

        other_patterns = [
            r'answer\s*(?:is|[:=])\s*\$?' + number,
            r'final\s*(?:answer|result)\s*(?:[:=])\s*\$?' + number,
            number + r'\s+is\s+the\s+final\s+answer',
            r'therefore[^0-9]*?' + number + r'(?!\s*[/=*+\-])',
            r'thus[^0-9]*?' + number + r'(?!\s*[/=*+\-])',
            r'hence[^0-9]*?' + number + r'(?!\s*[/=*+\-])',
        ]
        for pattern in other_patterns:
            matches = list(re.finditer(pattern, cleaned_text, re.IGNORECASE))
            if matches:
                result = parse_number(matches[-1].group(1))
                if result is not None:
                    return result

        mc_match = re.search(r'answer\s*(?:is)?\s*([A-D])[\.:\)]', cleaned_text, re.IGNORECASE)
        if mc_match:
            return mc_match.group(1).upper()

        yn_match = re.search(r'answer\s*(?:is)?\s*(yes|no|true|false)', cleaned_text, re.IGNORECASE)
        if yn_match:
            return yn_match.group(1).lower()

        text_match = re.search(r'answer\s*(?:is|[:=])\s*"([^"]+)"', cleaned_text, re.IGNORECASE)
        if text_match:
            result = text_match.group(1).strip()
            if result:
                return result

        # Last resort before the API: scan lines bottom-up for an "Answer:" line
        # or a line opening with a conclusion word, and take its first number.
        lines = [line.strip() for line in original_text.split('\n') if line.strip()]
        for line in reversed(lines):
            if re.search(r'\*\*Answer:\*\*|\bAnswer:', line, re.IGNORECASE):
                num_match = re.search(number, line)
                if num_match:
                    result = parse_number(num_match.group(1))
                    if result is not None:
                        return result
            if line.startswith(('Therefore,', 'Thus,', 'Hence,')):
                num_match = re.search(number, line)
                if num_match:
                    result = parse_number(num_match.group(1))
                    if result is not None:
                        return result

    return normalize_numeric(openai_extract_final_answer(text or ""))


In [None]:
# Run the fixed function on the example data
# NOTE(review): this cell reads "answer_text", while the other cells that loop
# over few_shot_data read "generated_answer" — confirm which key the few-shot
# file actually contains. The enumerate index `i` is not used here.
for i, example in enumerate(few_shot_data):
    pred = extract_final_answer(example["answer_text"])
    print(f"Prediction: {pred}")

In [77]:
# Print the extracted answer for every few-shot example, numbered from 1.
for i, example in enumerate(few_shot_data):
    pred = extract_final_answer(example["generated_answer"])
    print(f"Example {i+1}: {pred}")

Example 1: 48.0
Example 2: 10.0
Example 3: 100.0
Example 4: 42.0
Example 5: 3.0
Example 6: 35.0
Example 7: 2.0
Example 8: 16.0
Example 9: 200.0
Example 10: 990.0
Example 11: 121.0
Example 12: 5.0
Example 13: 60.0
Example 14: 35.0
Example 15: 5.0
Example 16: 450000.0
Example 17: 800.0
Example 18: 43.0
Example 19: 10.0
Example 20: 12.0


In [68]:
import json

def evaluate_accuracy(base_data, fewshot_data, extractor=None, base_key="answer_text"):
    """Score base and few-shot predictions against the gold answers.

    Args:
        base_data: Sequence of dicts holding the gold text under
            "original_answer" and the base model output under ``base_key``.
        fewshot_data: Sequence of dicts holding the few-shot model output under
            "generated_answer", aligned index-by-index with ``base_data``.
        extractor: Callable mapping raw answer text to an extracted answer.
            Defaults to the module-level ``extract_final_answer``, looked up at
            call time so a redefined function is picked up.
        base_key: Key of the base model output in each ``base_data`` item.

    Returns:
        Dict with "total", "base_correct", "fewshot_correct", "base_accuracy",
        "fewshot_accuracy" and the 1-based indices of wrong examples under
        "base_mismatched_examples" / "fewshot_mismatched_examples".
        Accuracies are 0.0 when there are no examples.

    Raises:
        AssertionError: If the two datasets differ in length.
    """
    assert len(base_data) == len(fewshot_data), "Mismatch in number of examples"

    if extractor is None:
        extractor = extract_final_answer

    def comparable(value):
        # Convert each value on its own. Converting all three inside one shared
        # try block meant a non-numeric value (e.g. "API_ERROR") aborted the
        # conversion of the values after it, so a correct prediction could be
        # compared as a string against a float and counted as wrong.
        try:
            return float(value)
        except (TypeError, ValueError):
            return value

    total = len(base_data)
    base_mismatched_examples = []
    fewshot_mismatched_examples = []

    for example_no, (base_item, few_item) in enumerate(zip(base_data, fewshot_data), start=1):
        gold_answer = comparable(extractor(base_item["original_answer"]))
        base_pred = comparable(extractor(base_item[base_key]))
        fewshot_pred = comparable(extractor(few_item["generated_answer"]))

        if base_pred != gold_answer:
            base_mismatched_examples.append(example_no)
        if fewshot_pred != gold_answer:
            fewshot_mismatched_examples.append(example_no)

    base_correct = total - len(base_mismatched_examples)
    fewshot_correct = total - len(fewshot_mismatched_examples)

    # Guard against division by zero on empty datasets.
    base_accuracy = base_correct / total if total else 0.0
    fewshot_accuracy = fewshot_correct / total if total else 0.0

    return {
        "total": total,
        "base_correct": base_correct,
        "fewshot_correct": fewshot_correct,
        "base_accuracy": base_accuracy,
        "fewshot_accuracy": fewshot_accuracy,
        "base_mismatched_examples": base_mismatched_examples,
        "fewshot_mismatched_examples": fewshot_mismatched_examples,
    }

# Load the base and few-shot result files and report both accuracies.
base = "/Users/thyag/Desktop/codes/chain-of-thought/dataset/arithmetic reasoning/gsm8k/base/gsm8k-base-gpt-4-1-nano-2025-04-14.json"
few_shot = "/Users/thyag/Desktop/codes/chain-of-thought/dataset/arithmetic reasoning/gsm8k/few-shot/gsm8k-few-shot-gpt-4-1-nano-2025-04-14.json"
# Explicit encoding keeps the read independent of the platform's locale default.
with open(base, encoding="utf-8") as f:
    base_data = json.load(f)

with open(few_shot, encoding="utf-8") as f:
    fewshot_data = json.load(f)

results = evaluate_accuracy(base_data, fewshot_data)
print(results)

{'total': 20, 'base_correct': 16, 'fewshot_correct': 17, 'base_accuracy': 0.8, 'fewshot_accuracy': 0.85, 'base_mismatched_examples': [3, 12, 16, 19], 'fewshot_mismatched_examples': [16, 19, 20]}


In [None]:
import re
import os
from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables via python-dotenv before the key is read below
# (presumably from a local .env file — confirm where the key is stored).
load_dotenv()
# Module-level OpenAI client; openai_fallback() in this cell uses it.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def normalize_float(s):
    """Parse a captured number string into a float.

    Thousands separators (commas) are removed, and a single trailing character
    that is neither a digit nor a dot (e.g. the "%" in "42%") is dropped before
    conversion.

    Args:
        s: Numeric text such as ``"1,234"`` or ``"72."``.

    Returns:
        The parsed float, or None when ``s`` is not a string or does not parse.
    """
    try:
        return float(re.sub(r'[^\d.]$', '', s.replace(',', '')))
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: s is None or not a str.
        # ValueError: the remaining text is not a valid float.
        # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        return None

def openai_fallback(text):
    """Ask the OpenAI Responses API to extract the final answer from ``text``.

    Args:
        text: Raw answer text forwarded to the model as its input.

    Returns:
        The model's ``output_text`` with surrounding whitespace stripped, or the
        sentinel string "API_ERROR" if the call or response handling fails.
    """
    try:
        response = client.responses.create(
            model="gpt-4o-mini",
            instructions=(
                "Extract the final answer from the text. Only provide the answer value, no reasoning or symbols. "
                "If numeric, return as number. If letter (A–D), return letter. If true/false, return boolean. "
                "If short text, return it as is. Example: '42 pages' → 42"
            ),
            input=text,
        )
        return response.output_text.strip()
    except Exception:
        # Best-effort fallback: any API/response failure maps to the sentinel.
        # ``Exception`` (not a bare except) lets Ctrl-C and SystemExit through.
        return "API_ERROR"

def extract_final_answer(text):
    """Extract the final numeric answer from ``text``.

    Patterns are tried in order and the first one that yields a parseable
    number wins: ``#### n``, ``\\boxed{n}``, a number inside a ``**bold**``
    span, ``Answer: n``, a trailing "the answer is n", and finally the first
    number on the last line mentioning "answer". If none applies, or the input
    is empty/None, ``openai_fallback`` is used.

    Args:
        text: Raw answer text; may be None or empty.

    Returns:
        A float from ``normalize_float``, or whatever ``openai_fallback``
        returns (a string, possibly "API_ERROR").
    """
    if not text or not text.strip():
        return openai_fallback(text)

    text = text.strip()

    # Pattern 1: Answer after ####
    match = re.search(r'####\s*([-+]?\d*\.?\d+)', text)
    if match:
        val = normalize_float(match.group(1))
        if val is not None:
            return val

    # Pattern 2: \boxed{number}
    match = re.search(r'\\boxed\{?\$?([-+]?\d*\.?\d+)\$?\}?', text)
    if match:
        val = normalize_float(match.group(1))
        if val is not None:
            return val

    # Pattern 3: Bolded numeric answer **123**
    # Bold spans are paired explicitly (opening ** ... closing **). The earlier
    # single regex could start at the CLOSING ** of one span and end at the
    # OPENING ** of the next, returning a number from unbolded text in between.
    for bold in re.finditer(r'\*\*([^*]+)\*\*', text):
        match = re.search(r'\d[\d,]*\.?\d*', bold.group(1))
        if match:
            val = normalize_float(match.group(0))
            if val is not None:
                return val

    # Pattern 4: Answer: ... 123
    match = re.search(r'\bAnswer\b[:\s]*[^A-Za-z0-9\-]*([$]?)(\d[\d,]*\.?\d*)', text, re.IGNORECASE)
    if match:
        val = normalize_float(match.group(2))
        if val is not None:
            return val

    # Pattern 5: "The answer is 72." — only when it closes the text, since no
    # further digit may follow up to the end of the string.
    match = re.search(r'\bthe answer is\b[:\s]*([-+]?\d[\d,]*\.?\d*)[^\d]*$', text, re.IGNORECASE)
    if match:
        val = normalize_float(match.group(1))
        if val is not None:
            return val

    # Pattern 6: First number in the last line containing "answer"
    # (lines are scanned bottom-up; re.search yields the line's first number).
    lines = text.split('\n')
    for line in reversed(lines):
        if "answer" in line.lower():
            match = re.search(r'(\d[\d,]*\.?\d*)', line)
            if match:
                val = normalize_float(match.group(1))
                if val is not None:
                    return val

    return openai_fallback(text)


In [72]:
# Run the fixed function on the example data
# Prints one extracted answer per few-shot example; the enumerate index `i`
# is not used in this loop.
for i, example in enumerate(few_shot_data):
    pred = extract_final_answer(example["generated_answer"])
    print(f"Prediction: {pred}")

Prediction: 72.0
Prediction: 10.0
Prediction: 5.0
Prediction: 42.0
Prediction: 312.0
Prediction: 35.0
Prediction: 48.0
Prediction: 2.0
Prediction: 127.0
Prediction: 4950.0
Prediction: 847.0
Prediction: 20.0
Prediction: 85.0
Prediction: 45
Prediction: 5.0
Prediction: 1.0
Prediction: 800.0
Prediction: 53.0
Prediction: 8.0
Prediction: 12.0


In [75]:
import json

def evaluate_accuracy(base_data, fewshot_data, extractor=None, base_key="generated_answer"):
    """Score base and few-shot predictions against the gold answers.

    Args:
        base_data: Sequence of dicts holding the gold text under
            "original_answer" and the base model output under ``base_key``.
        fewshot_data: Sequence of dicts holding the few-shot model output under
            "generated_answer", aligned index-by-index with ``base_data``.
        extractor: Callable mapping raw answer text to an extracted answer.
            Defaults to the module-level ``extract_final_answer``, looked up at
            call time so a redefined function is picked up.
        base_key: Key of the base model output in each ``base_data`` item.

    Returns:
        Dict with "total", "base_correct", "fewshot_correct", "base_accuracy",
        "fewshot_accuracy" and the 1-based indices of wrong examples under
        "base_mismatched_examples" / "fewshot_mismatched_examples".
        Accuracies are 0.0 when there are no examples.

    Raises:
        AssertionError: If the two datasets differ in length.
    """
    assert len(base_data) == len(fewshot_data), "Mismatch in number of examples"

    if extractor is None:
        extractor = extract_final_answer

    def comparable(value):
        # Convert each value on its own. Converting all three inside one shared
        # try block meant a non-numeric value (e.g. "API_ERROR") aborted the
        # conversion of the values after it, so a correct prediction could be
        # compared as a string against a float and counted as wrong.
        try:
            return float(value)
        except (TypeError, ValueError):
            return value

    total = len(base_data)
    base_mismatched_examples = []
    fewshot_mismatched_examples = []

    for example_no, (base_item, few_item) in enumerate(zip(base_data, fewshot_data), start=1):
        gold_answer = comparable(extractor(base_item["original_answer"]))
        base_pred = comparable(extractor(base_item[base_key]))
        fewshot_pred = comparable(extractor(few_item["generated_answer"]))

        if base_pred != gold_answer:
            base_mismatched_examples.append(example_no)
        if fewshot_pred != gold_answer:
            fewshot_mismatched_examples.append(example_no)

    base_correct = total - len(base_mismatched_examples)
    fewshot_correct = total - len(fewshot_mismatched_examples)

    # Guard against division by zero on empty datasets.
    base_accuracy = base_correct / total if total else 0.0
    fewshot_accuracy = fewshot_correct / total if total else 0.0

    return {
        "total": total,
        "base_correct": base_correct,
        "fewshot_correct": fewshot_correct,
        "base_accuracy": base_accuracy,
        "fewshot_accuracy": fewshot_accuracy,
        "base_mismatched_examples": base_mismatched_examples,
        "fewshot_mismatched_examples": fewshot_mismatched_examples,
    }

# Load the base and few-shot result files and report both accuracies.
base = "/Users/thyag/Desktop/codes/chain-of-thought/dataset/arithmetic reasoning/gsm8k/base/gsm8k-base-gpt-4-1-nano-2025-04-14.json"
few_shot = "/Users/thyag/Desktop/codes/chain-of-thought/dataset/arithmetic reasoning/gsm8k/few-shot/gsm8k-few-shot-gpt-4-1-nano-2025-04-14.json"
# Explicit encoding keeps the read independent of the platform's locale default.
with open(base, encoding="utf-8") as f:
    base_data = json.load(f)

with open(few_shot, encoding="utf-8") as f:
    fewshot_data = json.load(f)

results = evaluate_accuracy(base_data, fewshot_data)
print(results)

{'total': 20, 'base_correct': 13, 'fewshot_correct': 11, 'base_accuracy': 0.65, 'fewshot_accuracy': 0.55, 'base_mismatched_examples': [3, 4, 10, 12, 14, 16, 19], 'fewshot_mismatched_examples': [1, 3, 5, 7, 9, 13, 16, 19, 20]}
