In [1]:
import pandas as pd

In [2]:
def calculate_accuracy(data):
    if 'parsed_answer' not in data.columns or 'correct_answer' not in data.columns:
        raise ValueError("DataFrame must contain 'parsed_answer' and 'correct_answer' columns")
    
    correct_predictions = (data['parsed_answer'] == data['correct_answer']).sum()
    total_predictions = data.shape[0]
    
    accuracy = correct_predictions / total_predictions
    
    return accuracy

In [9]:
all_results = os.listdir('../results/story_far_all_prompts/')

# LLAMA3 

In [5]:
import os

In [28]:
llama3_results = sorted([res for res in all_results if 'llama3' in res])

In [29]:
llama3_results

['story_analogies_far_llama3_prompt_templates-story_analogies-1_basic_prompt_not_forced-1.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-1_basic_prompt_not_forced-2.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-1_basic_prompt_not_forced-3.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-2_basic_prompt_forced-1.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-2_basic_prompt_forced-2.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-2_basic_prompt_forced-3.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-3_cot-1.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-3_cot-2.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-3_cot-3.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-4_cot_structured-1.csv',
 'story_analogies_far_llama3_prompt_templates-story_analogies-4_cot_structured-2.csv',
 'story_analogies_far_llama3_promp

In [30]:
import re

def extract_prompt_info(filenames):
    extracted_info = []
    for filename in filenames:
        match = re.search(r'story_analogies-\d+_(.*?)-(\d+)\.csv', filename)
        if match:
            prompt_type = match.group(1)
            run_number = match.group(2)
            extracted_info.append((prompt_type, run_number))
    return extracted_info

In [34]:
results_final = []

for res_file in llama3_results:
    data = pd.read_csv('../results/story_far_all_prompts/' + res_file)[['parsed_answer', 'correct_answer']]
    (prompt_type, prompt_number) = extract_prompt_info([res_file])[0]

    accuracy = calculate_accuracy(data)
    results_final.append({'prompt_type': prompt_type, 'prompt_number': prompt_number, 'accuracy': accuracy})

In [36]:
final_results_df = pd.DataFrame(results_final)
final_results_df

Unnamed: 0,prompt_type,prompt_number,accuracy
0,basic_prompt_not_forced,1,0.638889
1,basic_prompt_not_forced,2,0.722222
2,basic_prompt_not_forced,3,0.666667
3,basic_prompt_forced,1,0.527778
4,basic_prompt_forced,2,0.722222
5,basic_prompt_forced,3,0.694444
6,cot,1,0.638889
7,cot,2,0.638889
8,cot,3,0.638889
9,cot_structured,1,0.666667


In [38]:
final_results_df.groupby(by='prompt_type').agg({'accuracy': 'mean'})

Unnamed: 0_level_0,accuracy
prompt_type,Unnamed: 1_level_1
basic_prompt_forced,0.648148
basic_prompt_not_forced,0.675926
cot,0.638889
cot_structured,0.666667


# Majority voting

In [47]:
from collections import Counter

In [48]:
def majority_vote(*args):
    return [Counter(votes).most_common(1)[0][0] for votes in zip(*args)]

In [51]:
results_majority = []

for i in range(0, len(llama3_results), 3):
    prompt_1 = llama3_results[i]
    prompt_2 = llama3_results[i + 1]
    prompt_3 = llama3_results[i + 2]
    print(prompt_1, prompt_2, prompt_3)

    data_1 = pd.read_csv('../results/story_far_all_prompts/' + prompt_1)[['parsed_answer', 'correct_answer']]
    data_2 = pd.read_csv('../results/story_far_all_prompts/' + prompt_2)[['parsed_answer', 'correct_answer']]
    data_3 = pd.read_csv('../results/story_far_all_prompts/' + prompt_3)[['parsed_answer', 'correct_answer']]

    data_1['parsed_answer'] = majority_vote(data_1['parsed_answer'], data_2['parsed_answer'], data_3['parsed_answer'])

    (prompt_type, prompt_number) = extract_prompt_info([prompt_1])[0]

    accuracy = calculate_accuracy(data_1)
    results_majority.append({'prompt_type': prompt_type, 'accuracy': accuracy})

story_analogies_far_llama3_prompt_templates-story_analogies-1_basic_prompt_not_forced-1.csv story_analogies_far_llama3_prompt_templates-story_analogies-1_basic_prompt_not_forced-2.csv story_analogies_far_llama3_prompt_templates-story_analogies-1_basic_prompt_not_forced-3.csv
story_analogies_far_llama3_prompt_templates-story_analogies-2_basic_prompt_forced-1.csv story_analogies_far_llama3_prompt_templates-story_analogies-2_basic_prompt_forced-2.csv story_analogies_far_llama3_prompt_templates-story_analogies-2_basic_prompt_forced-3.csv
story_analogies_far_llama3_prompt_templates-story_analogies-3_cot-1.csv story_analogies_far_llama3_prompt_templates-story_analogies-3_cot-2.csv story_analogies_far_llama3_prompt_templates-story_analogies-3_cot-3.csv
story_analogies_far_llama3_prompt_templates-story_analogies-4_cot_structured-1.csv story_analogies_far_llama3_prompt_templates-story_analogies-4_cot_structured-2.csv story_analogies_far_llama3_prompt_templates-story_analogies-4_cot_structured-3

In [52]:
pd.DataFrame(results_majority)

Unnamed: 0,prompt_type,accuracy
0,basic_prompt_not_forced,0.666667
1,basic_prompt_forced,0.694444
2,cot,0.666667
3,cot_structured,0.638889
