In [40]:
jsonl_datasets = ['GSM8K', 'date', 'MultiArith', 'ASDiv', 'SVAMP', 'AQUA', 'StrategyQA','p_GSM8K', 'commonsenseQA','SPARTQA']
json_datasets = ['logical_deduction_seven_objects', 'reasoning_about_colored_objects']
all_datasets = jsonl_datasets + json_datasets

In [27]:
import csv
import re
import json
import os

def parse_csv_file(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                id_str = str(id_).strip()  # Ensure ID is treated as a string
                if id_str:
                    qa_pairs.append((id_str, question, answer_text))
                else:
                    print(f"Skipping a row due to empty 'id': {row}")
            else:
                # Handle cases without 'id' by skipping
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs

def read_ground_truth(path):
    ground_truth = {}
    file_extension = os.path.splitext(path)[1].lower()
    
    try:
        with open(path, 'r', encoding='utf-8') as f:
            # If the file is a JSON array (.json), load it directly
            if file_extension == '.json':
                data_list = json.load(f)
            # If the file is JSONL, read line-by-line
            elif file_extension == '.jsonl':
                data_list = [json.loads(line) for line in f]
            else:
                raise ValueError("Unsupported file format. Please provide a .json or .jsonl file.")
        
        # Process each entry in data_list (which is a list of dicts)
        for data in data_list:
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                id_str = str(id_).strip()  # Ensure ID is treated as a string
                answer_str = str(answer).lower().strip()
                # GSM ONLY: Assuming the answer is split by '####' and the relevant part is after it
                if '####' in answer_str:
                    answer_str = answer_str.split('####')[1].strip()
                # Extract numbers only, as per original code
                numbers_only = re.sub(r'[^0-9]', '', answer_str)
                ground_truth[id_str] = numbers_only
            else:
                print(f"Invalid ground truth entry (missing 'id' or 'answer'): {data}")

    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
    except ValueError as ve:
        print(ve)
    
    return ground_truth

def extract_final_answer(answer_text):
    """
    Extracts the final answer from the answer_text.
    Assumes the final answer is within the last pair of curly braces {}.
    """
    final_answer_match = re.search(r'\{([^}]+)\}(?=[^}]*$)', answer_text, re.DOTALL)
    if final_answer_match:
        extracted = final_answer_match.group(1)
        # Remove commas, dollar signs, and trim whitespace
        final_answer = re.sub(r'[^\d.]', '', extracted).strip().lower()
        # Handle boolean answers if applicable
        if "no" in final_answer or "false" in final_answer:
            final_answer = "false"
        elif "yes" in final_answer or "true" in final_answer:
            final_answer = "true"
        return final_answer
    else:
        return ""

def compute_accuracy(qa_pairs, ground_truth):
    correct_answers = 0
    total_answers = 0
    mismatches = []  # Optional: To store mismatched cases for further inspection

    for id_, question, answer_text in qa_pairs:
        final_answer = extract_final_answer(answer_text)
        gt_answer = ground_truth.get(id_)

        if gt_answer is None:
            print(f"Ground truth not available for ID '{id_}'. Skipping.")
            continue

        is_correct = final_answer == gt_answer
        if is_correct:
            correct_answers += 1
        else:
            mismatches.append((id_, question, final_answer, gt_answer))
        total_answers += 1

    accuracy_percentage = (correct_answers / total_answers * 100) if total_answers > 0 else 0
    return correct_answers, total_answers, accuracy_percentage, mismatches

def main():
    for dataset in json_datasets:
        input_csv = f'/Users/log/Github/textual_grounding/logan/results/final/VanillaCoT/{dataset}/llama3.170b/zero_shot_vanilla_cot_None_{dataset}_llama3.170b.csv' 
        
        if dataset in jsonl_datasets: 
            ground_truth_file = f'/Users/log/Github/textual_grounding/data/{dataset}/test.jsonl'  
        else:
            ground_truth_file = f'/Users/log/Github/textual_grounding/data/{dataset}/test.json'

        print(f"\nProcessing dataset: {dataset}")

        # Check if input files exist
        if not os.path.isfile(input_csv):
            print(f"Input CSV file not found: {input_csv}")
            continue  # Continue to the next dataset instead of exiting
        if not os.path.isfile(ground_truth_file):
            print(f"Ground truth JSONL file not found: {ground_truth_file}")
            continue  # Continue to the next dataset instead of exiting

        # Parse the input CSV file to extract IDs, questions, and answers
        qa_pairs = parse_csv_file(input_csv)
        print(f"Total QA Pairs Parsed: {len(qa_pairs)}")  # Debug: Print the number of QA pairs parsed

        # Read the ground truth answers
        ground_truth = read_ground_truth(ground_truth_file)
        print(f"Total Ground Truth Entries: {len(ground_truth)}")  # Debug: Print the number of ground truth entries

        # Check if any QA pairs were found
        if not qa_pairs:
            print("No question-answer pairs were found in the input file.")
            continue  # Continue to the next dataset

        # Compute accuracy
        correct, total, accuracy, mismatches = compute_accuracy(qa_pairs, ground_truth)

        # Print the accuracy
        print(f"\n--- {dataset} Summary ---")
        print(f"Accuracy: {accuracy:.2f}% ({correct}/{total} correct)")

        # Optional: Print mismatched cases for debugging
        # if mismatches:
        #     print("\n--- Mismatched Cases ---")
        #     for id_, question, predicted, actual in mismatches:
        #         print(f"ID: {id_}")
        #         print(f"Question: {question}")
        #         print(f"Predicted Answer: {predicted}")
        #         print(f"Ground Truth Answer: {actual}\n")

if __name__ == "__main__":
    main()



Processing dataset: logical_deduction_seven_objects
Input CSV file not found: /Users/log/Github/textual_grounding/logan/results/final/VanillaCoT/logical_deduction_seven_objects/llama3.170b/zero_shot_vanilla_cot_None_logical_deduction_seven_objects_llama3.170b.csv

Processing dataset: reasoning_about_colored_objects
Total QA Pairs Parsed: 75
Total Ground Truth Entries: 250

--- reasoning_about_colored_objects Summary ---
Accuracy: 100.00% (75/75 correct)


## Tin Eval

In [155]:
import sys
import importlib

# Add the directory containing eval_mmlu.py to the Python path
sys.path.append('/Users/log/Github/textual_grounding/utils') 

# Import eval_mmlu as a module
import eval_mmlu
import utils
import mmlu
# import eval_da_and_cot

# Reload the eval_mmlu module to ensure the latest changes are loaded (optional)
importlib.reload(eval_mmlu)
importlib.reload(utils)
importlib.reload(mmlu)
# importlib.reload(eval_da_and_cot)

# Now, you can access the evaluate_model function from eval_mmlu
llm_model = "gpt-4o-2024-08-06"
data_mode = "longest"
answer_mode = "cot"
dataset = "AQUA"

# Call the function from the module
for dataset in json_datasets:
    eval_mmlu.evaluate_model(llm_model, data_mode, answer_mode, dataset)
    # eval_da_and_cot.evaluate_model(llm_model, data_mode, answer_mode, dataset)
# print(eval_mmlu.evaluate_model(llm_model, data_mode, answer_mode, 'AQUA'))


Dataset:  logical_deduction_seven_objects
216 250
Accuracy:  0.864
------------------------------------
Dataset:  reasoning_about_colored_objects
249 250
Accuracy:  0.996
------------------------------------
