## Analyze Results

This script provides an example of how to analyze the results. Before running it, run collate.ipynb to generate combined_results.csv.

In [1]:
import matplotlib.pyplot as plt
from pprint import pprint


def read_file(filename):
    """Read a combined-results CSV into a dict of per-run metric lists.

    The first line is treated as a header and skipped. Every remaining
    non-blank line must hold five comma-separated numeric fields in the
    order ``n_shots, accuracy, run_num, nspw, f1``; all are parsed as floats.

    Args:
        filename: Path to the CSV file (e.g. a ``combined_results.csv``).

    Returns:
        dict mapping ``(n_shots, run_num, nspw)`` (floats) to a list of
        ``(accuracy, f1)`` tuples, one per matching row, in file order.

    Raises:
        ValueError: if a non-blank row does not contain exactly five
            numeric fields.
    """
    data = {}
    with open(filename, 'r', encoding='utf-8') as file:
        file.readline()  # skip the header row
        for line in file:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing on float('').
                continue
            n_shots, accuracy, run_num, nspw, f1 = map(float, line.split(','))
            data.setdefault((n_shots, run_num, nspw), []).append((accuracy, f1))
    return data

def aggregate_data(filename, num_windows=1, hundredfy=False):
    """Average accuracy and F1 over runs for each ``n_shots`` value in a results CSV.

    Args:
        filename: Path to a results CSV in the format read by ``read_file``.
        num_windows: Currently has no effect; kept for backward compatibility.
            (A filter on ``nspw`` derived from it, ``nspw == n_shots //
            num_windows``, was disabled in the original code.)
        hundredfy: If True, accuracies are multiplied by 100. F1 values are
            always multiplied by 100, regardless of this flag.

    Returns:
        dict mapping each ``n_shots`` value (float) to a tuple
        ``(mean_accuracy, accuracies, mean_f1, f1s)``. Keys are inserted in
        ascending ``n_shots`` order, so iteration order is deterministic.
    """
    data = read_file(filename)
    acc_scale = 100 if hundredfy else 1
    f1_scale = 100  # F1 is always scaled by 100, independent of `hundredfy`.

    # Group runs by n_shots in a single pass; file order is preserved within
    # each group.
    # NOTE(review): only the first (accuracy, f1) pair stored under each
    # (n_shots, run_num, nspw) key is used; later duplicate rows for the same
    # key are ignored — confirm this is intended.
    runs_by_shots = {}
    for (n_shots, _run_num, _nspw), rows in data.items():
        runs_by_shots.setdefault(n_shots, []).append(rows[0])

    new_data = {}
    # Sorted so callers that take "the first entry" get a deterministic one
    # (the smallest n_shots) rather than whatever order a set happens to yield.
    for num_shots in sorted(runs_by_shots):
        runs = runs_by_shots[num_shots]
        all_nums = [accuracy * acc_scale for accuracy, _ in runs]
        all_f1s = [f1 * f1_scale for _, f1 in runs]
        new_data[num_shots] = (
            sum(all_nums) / len(all_nums),
            all_nums,
            sum(all_f1s) / len(all_f1s),
            all_f1s,
        )
    return new_data



## Main Results - Accuracy

In [2]:
# One row per dataset: (name, blocks selected for the "fixed"/block_select=all
# runs, blocks used to build the bm25-selected runs).
dataset_names, max_blocks, subset_blocks = (
    list(column)
    for column in zip(
        ("banking77", 16, 6),
        ("clinic150", 20, 7),
        ("nlu", 22, 7),
        ("trec", 21, 7),
        ("trecfine", 20, 7),
    )
)
# Multiplier applied to subset_blocks when naming the retrieval result dirs.
block_size = 50

In [10]:
def create_fixed_dict(dataset_name, max_block, model_name, lengths=("32k",)):
    """Load the "fixed" (block_select=all) results for one dataset and model.

    Args:
        dataset_name: Dataset directory prefix, e.g. ``"banking77"``.
        max_block: ``n_selected_blocks`` for the 32k context; 64k and 96k
            use 2x and 3x this value.
        model_name: Model directory name under ``../long/``.
        lengths: Context lengths to load, any of ``"32k"``, ``"64k"``,
            ``"96k"``. Defaults to ``("32k",)``, the previous behavior.

    Returns:
        dict mapping each requested length to the ``aggregate_data`` result
        for the matching ``combined_results.csv``.

    Raises:
        KeyError: for a length other than the three listed above.
    """
    multipliers = {"32k": 1, "64k": 2, "96k": 3}
    return {
        length: aggregate_data(
            f"../long/{model_name}/{dataset_name}"
            f"-n_selected_blocks={max_block * multipliers[length]}"
            "-block_select=all/combined_results.csv"
        )
        for length in lengths
    }

# "Fixed" results per dataset, one mapping per model.
fixed_dicts = dict(
    (dataset_name, create_fixed_dict(dataset_name, max_block, "meta-llama+Llama-3.1-8B"))
    for dataset_name, max_block in zip(dataset_names, max_blocks)
)
fixed_dicts_llama2 = dict(
    (dataset_name, create_fixed_dict(dataset_name, max_block, "togethercomputer+LLaMA-2-7B-32K"))
    for dataset_name, max_block in zip(dataset_names, max_blocks)
)

In [11]:
def create_retrieval_dict(dataset_name, subset_block, model_name, lengths=("32k",)):
    """Load the "retrieval" (block_select=bm25) results for one dataset and model.

    The ``n_selected_blocks`` value in the results path is
    ``subset_block * block_size * multiplier``, where ``block_size`` is the
    module-level setting and the multiplier is 1/2/3 for 32k/64k/96k.

    Args:
        dataset_name: Dataset directory prefix, e.g. ``"banking77"``.
        subset_block: Per-dataset block count (see ``subset_blocks``).
        model_name: Model directory name under ``../long/``.
        lengths: Context lengths to load, any of ``"32k"``, ``"64k"``,
            ``"96k"``. Defaults to ``("32k",)``, the previous behavior.

    Returns:
        dict mapping each requested length to the ``aggregate_data`` result
        for the matching ``combined_results.csv``.

    Raises:
        KeyError: for a length other than the three listed above.
    """
    multipliers = {"32k": 1, "64k": 2, "96k": 3}
    return {
        length: aggregate_data(
            f"../long/{model_name}/{dataset_name}"
            f"-retrieval-n_selected_blocks={subset_block * block_size * multipliers[length]}"
            "-block_select=bm25/combined_results.csv"
        )
        for length in lengths
    }

# "Retrieval" results per dataset, one mapping per model.
retrieval_dicts = dict(
    (dataset_name, create_retrieval_dict(dataset_name, subset_block, "meta-llama+Llama-3.1-8B"))
    for dataset_name, subset_block in zip(dataset_names, subset_blocks)
)
retrieval_dicts_llama2 = dict(
    (dataset_name, create_retrieval_dict(dataset_name, subset_block, "togethercomputer+LLaMA-2-7B-32K"))
    for dataset_name, subset_block in zip(dataset_names, subset_blocks)
)

In [13]:
def create_ours_dict(dataset_name, subset_block, model_name, lengths=("32k",)):
    """Load "ours" (block_select=bm25, block-level count) results for one dataset/model.

    Args:
        dataset_name: Dataset directory prefix, e.g. ``"banking77"``.
        subset_block: ``n_selected_blocks`` for the 32k context; 64k and 96k
            use 2x and 3x this value.
        model_name: Model directory name under ``../long/``.
        lengths: Context lengths to load, any of ``"32k"``, ``"64k"``,
            ``"96k"``. Defaults to ``("32k",)``, the previous behavior.

    Returns:
        dict mapping each requested length to the ``aggregate_data`` result
        for the matching ``combined_results.csv``.

    Raises:
        KeyError: for a length other than the three listed above.
    """
    multipliers = {"32k": 1, "64k": 2, "96k": 3}
    return {
        length: aggregate_data(
            f"../long/{model_name}/{dataset_name}"
            f"-n_selected_blocks={subset_block * multipliers[length]}"
            "-block_select=bm25/combined_results.csv"
        )
        for length in lengths
    }

# "Ours" results per dataset, one mapping per model.
ours_dicts = dict(
    (dataset_name, create_ours_dict(dataset_name, subset_block, "meta-llama+Llama-3.1-8B"))
    for dataset_name, subset_block in zip(dataset_names, subset_blocks)
)
ours_dicts_llama2 = dict(
    (dataset_name, create_ours_dict(dataset_name, subset_block, "togethercomputer+LLaMA-2-7B-32K"))
    for dataset_name, subset_block in zip(dataset_names, subset_blocks)
)

In [15]:
length = '32k'
# Placeholder whose element 0 (the mean accuracy) is NaN; used when an
# aggregate is empty (results file with a header but no rows), so the row
# prints "nan" instead of raising IndexError.
no_result = (float("nan"),)
for name in dataset_names:
    fixed = fixed_dicts_llama2[name][length]
    ret = retrieval_dicts_llama2[name][length]
    our = ours_dicts_llama2[name][length]

    # Each aggregate maps n_shots -> (mean_accuracy, accuracies, mean_f1, f1s);
    # report the mean accuracy of its first entry.
    # NOTE(review): "first entry" is only unambiguous when each results file
    # holds a single n_shots value — confirm.
    fixed_accuracy = next(iter(fixed.values()), no_result)[0]
    ret_accuracy = next(iter(ret.values()), no_result)[0]
    our_accuracy = next(iter(our.values()), no_result)[0]
    print(f"{name}, fixed, {fixed_accuracy:.0f}, ret, {ret_accuracy:.0f}, ours, {our_accuracy:.0f}")


banking77, fixed, 81, ret, 84, ours, 80
clinic150, fixed, 86, ret, 83, ours, 80
nlu, fixed, 85, ret, 86, ours, 84
trec, fixed, 93, ret, 92, ours, 91
trecfine, fixed, 76, ret, 79, ours, 77


In [None]:
def max_num_examples(data):
    """Return the entry of ``data`` with the largest key.

    Args:
        data: A non-empty mapping, e.g. the ``n_shots``-keyed dict returned
            by ``aggregate_data``.

    Returns:
        ``(key, data[key])`` for the maximum key.

    Raises:
        ValueError: if ``data`` is empty.
    """
    key = max(data)  # iterating a dict yields its keys; no list copy needed
    return (key, data[key])

{1: (2.7600000000000007, [2.8000000000000003, 2.8000000000000003, 2.4, 2.8000000000000003, 2.8000000000000003, 2.8000000000000003, 2.8000000000000003, 2.8000000000000003, 2.8000000000000003, 2.8000000000000003], 0.07577991249908698, [0.07102272727272728, 0.07074637424831977, 0.12025269768276949, 0.07074637424831977, 0.07102272727272728, 0.07074637424831977, 0.07074637424831977, 0.07074637424831977, 0.07074637424831977, 0.07102272727272728]), 5: (40.32000000000001, [38.4, 50.0, 29.2, 38.800000000000004, 27.6, 38.800000000000004, 49.6, 40.400000000000006, 50.8, 39.6], 34.94223275029968, [32.79630023142029, 45.187290941181, 26.76011451412521, 34.1452338511162, 24.503339373469235, 33.74575538217846, 44.07924193980243, 32.669181467882765, 43.03922628597953, 32.49664351584171]), 20: (75.32000000000002, [74.4, 79.2, 76.0, 76.8, 76.4, 74.8, 74.0, 71.2, 76.8, 73.6], 69.7037512309214, [71.1519402599375, 72.97113420490041, 71.35158852691316, 70.03379016366026, 70.54173820407583, 68.87914394407899

In [7]:
# Check whether any of the differences are statistically significant.
from scipy.stats import ttest_ind


def check_behavior_post_sat(data, satpt):
    """Print which keys larger than ``satpt`` differ significantly from it.

    For every key of ``data`` strictly greater than ``satpt`` (in ascending
    order), runs ``scipy.stats.ttest_ind`` between ``data[satpt][1]`` and
    ``data[key][1]``. Whenever the p-value is below 0.05, prints the pair and
    the difference of the stored means, ``data[satpt][0] - data[key][0]``.

    Args:
        data: Mapping whose values are tuples with a mean at index 0 and a
            list of per-run values at index 1, e.g. the dict returned by
            ``aggregate_data``. Keys are cast with ``int()`` before lookup,
            so presumably they are integral (e.g. 5.0) — TODO confirm.
        satpt: Key of the reference entry (presumably the "saturation
            point"); must be present in ``data``.

    NOTE(review): no return statement is visible in this view; results are
    only printed.
    """
    # Keys strictly greater than satpt, cast to int and sorted ascending.
    options = sorted([int(i) for i in data.keys() if i > satpt])
    # Per-run values of the reference entry.
    satdata = data[satpt][1]
    # NOTE(review): never read or updated in the visible code.
    found_saturation=False
    for i in range(0, len(options)):
        # NOTE(review): the original comment said "looking for first p-value
        # > 0.05", but the code below reports EVERY comparison with p < 0.05
        # and never stops early.
        #print(f"comparing data points {options[i]}={data[options[i]][0]} and {options[j]}={data[options[j]][0]}")
        # cur_p is bound by the walrus but not used afterwards in this view.
        if (cur_p := ttest_ind(a=satdata, b=data[options[i]][1]).pvalue) < 0.05:
            print(f"Sig difference! {satpt} and {options[i]}")
            print(f"The difference is {data[satpt][0] - data[options[i]][0]} in favor of satpt")
