# Accuracy Evaluation Script

In [2]:
import re
import random

# Seed the global RNG: answer_to_index falls back to random.randint when a
# response cannot be parsed, so a fixed seed keeps those guesses repeatable
# across runs (given the same call order).
random.seed(1234)

def extract_option(text):
    """
    Parse a model response and return the predicted option index.

    The text is searched, in order of preference, for:
      1. "Option <number>" (e.g. "Option3", "option 3")
      2. "Option <letter>" (e.g. "OptionA", "option A")
      3. a standalone choice letter (a-d) or a number between 1 and 10
    The first hit is handed to answer_to_index.

    Args:
        text (str): The model output to parse.

    Returns:
        The value answer_to_index returns for the selected token. When no
        candidate is found, or text is not a string, answer_to_index(None)
        is returned, which is a random guess.
    """
    if not isinstance(text, str):
        return answer_to_index(None)

    # Trim surrounding whitespace and sentence punctuation
    cleaned_text = text.strip().strip(",.!?;:'")

    # Explicit "option X" mentions, searched anywhere in the text.
    # Word boundaries keep ordinary words from being read as answers:
    # "optional" must not parse as choice "a", nor the plural "options" as "s".
    patterns = [
        r'\b[Oo]ption\s*(\d+)',                  # Option3, Option 3
        r'\b[Oo]ption(?!s\b)\s*([A-Za-z])\b',    # OptionA, Option A
    ]
    for pattern in patterns:
        found = re.search(pattern, cleaned_text)
        if found:
            # Use the first option mentioned
            return answer_to_index(found.group(1))

    # Fallback: standalone letters or numbers. This may still produce false
    # positives (e.g. the article "a" reads as choice A).
    choice_letters = "abcd"  # the only letters answer_to_index maps to an index
    for token in re.findall(r'\b([A-Za-z]|\d+)\b', cleaned_text):
        if token.isalpha():
            # Skip letters that cannot be a choice (such as the pronoun "I")
            # so that a later, valid candidate is still considered.
            if token.lower() in choice_letters:
                return answer_to_index(token)
        elif 1 <= int(token) <= 10:  # Reasonable option number
            return answer_to_index(token)

    return answer_to_index(None)

def answer_to_index(response):
    """
    Convert a parsed option token into a predicted index, always as an int.

    Args:
        response (str | int | None): The token extracted from the model
            output, e.g. "B", "c", "2", or None when nothing was parsed.

    Returns:
        int: 0-3 for the letters a-d (case-insensitive), the integer value of
            a numeric token in 0..3, or a random index in 0..3 when the token
            is missing, unparsable, or out of range.
    """
    letter_choices = ["a", "b", "c", "d"]
    last_index = len(letter_choices) - 1

    try:
        if isinstance(response, str):
            token = response.strip().lower()
            if token in letter_choices:
                return letter_choices.index(token)
            index = int(token)
        else:
            index = int(response)
    except (TypeError, ValueError):
        # Nothing usable was parsed: guess, so every item still gets a prediction
        return random.randint(0, last_index)

    if not 0 <= index <= last_index:
        # Out-of-range option numbers are treated like unparsable answers
        return random.randint(0, last_index)

    # Return an int (not the original digit string) so numeric and letter
    # answers have the same type when compared against labels.
    return index

In [3]:
# Smoke-test the parser on a handful of representative answer phrasings
test_samples = [
    "So, the answer is: Option3.",
    "Therefore, option2 is correct.",
    "The correct answer is Option C.",
    "I believe option 6 is the answer.",
    "Hence, the solution is option A.",
    "The result is 5.",
]

print("Testing different formats:")
for sample in test_samples:
    index = extract_option(sample)
    # One write per sample: the text line, then the index line and a blank line
    print("\n".join((f"Text: {sample}", f"Index: {index}\n")))

Testing different formats:
Text: So, the answer is: Option3.
Index: 3

Text: Therefore, option2 is correct.
Index: 2

Text: The correct answer is Option C.
Index: 2

Text: I believe option 6 is the answer.
Index: 3

Text: Hence, the solution is option A.
Index: 0

Text: The result is 5.
Index: 0



In [4]:
import json
import numpy as np

def _write_predictions(out_file, records):
    """Write one Prompt / Model output / Actual answer stanza per record."""
    for prompt, gen, label in records:
        print("Prompt:", prompt, file=out_file)
        print("Model output:", gen, file=out_file)
        print("Actual answer:", label, file=out_file)
        print("", file=out_file)


def eval(response_files, out_files):
    """
    Score model predictions and write a per-file accuracy report.

    Note: this function's name shadows the built-in eval; it is kept because
    callers use it.

    Args:
        response_files: Paths to JSONL prediction files. Each non-blank line
            must be a JSON object with "prompt", "question_id" and "text" keys.
        out_files: Paths of the text reports to write, paired positionally
            with response_files (zip stops at the shorter of the two).

    Returns:
        list[list[bool]]: For each response file, one correctness flag per item.
    """
    import os  # local import: only needed to create the report directory

    all_accs = []
    for response, out in zip(response_files, out_files):
        # Read and close the predictions file; skip blank lines (e.g. a
        # trailing newline) instead of failing in json.loads.
        with open(response, encoding="utf-8") as response_file:
            data = [json.loads(line) for line in response_file if line.strip()]

        accs = []
        correct = []
        incorrect = []
        for item in data:
            prompt = item["prompt"]
            # NOTE(review): the ground-truth label is read from "question_id";
            # confirm this field really holds the answer index.
            label = item["question_id"]
            gen = item["text"]

            pred = extract_option(gen)

            # Equality is type-sensitive, so label must be of the same type as
            # the parsed prediction for any item to count as correct.
            acc = pred == label
            accs.append(acc)

            record = (prompt, gen, label)
            if acc:
                correct.append(record)
            else:
                incorrect.append(record)

        all_accs.append(accs)

        # Make sure the report directory exists before writing into it
        out_dir = os.path.dirname(out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # The context manager guarantees the report is flushed and closed
        with open(out, "w", encoding="utf-8") as out_file:
            # Summary line: file, number of items, accuracy
            print(response, len(accs), np.mean(accs), file=out_file)

            print("Correct predictions:", file=out_file)
            _write_predictions(out_file, correct)

            print("\n-------------------------------------------", file=out_file)
            print("Incorrect predictions:", file=out_file)
            _write_predictions(out_file, incorrect)

    return all_accs

In [5]:
# Prediction files to score, paired positionally with the report written for each
response_files = [
    "./playground/data/ai2d_predictions_llava-7b_imagenet-and-llava-trained.jsonl",
    "./playground/data/ai2d_predictions_llava-7b_base.jsonl",
]
out_files = [
    "results/eval_ai2d_combined.txt",
    "results/eval_ai2d_base.txt",
]

# Per-file lists of correctness flags, in the same order as response_files
all_accs = eval(response_files, out_files)