In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Put the directory two levels up on the import path so the `src` package
# imported below can be resolved — assumes the notebook is run from its own
# directory, two levels below the repository root (TODO confirm).
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric
import logging

from src.utils import logging_utils
from src.utils.sweep_utils import read_sweep_results, relation_from_dict

# logging_utils.configure(level=logging.DEBUG)

In [None]:
##############################################
# Model whose sweep results are analysed; it is used as the sub-directory name
# in `path` below.
model_name = "gptj"
# Directory holding the prompt-sweep results for that model, relative to the
# notebook's working directory.
path = f"../../results/sweep_prompt/{model_name}"
##############################################

# Show what is available under `path`; the cells below treat every entry here
# as one relation directory.
os.listdir(path)

In [None]:
def parse_for_n(n_icl, relation_path):
    """Load the sweep stored under ``{relation_path}/{n_icl}`` and extract per-trial scores.

    Args:
        n_icl: Name of the sub-directory to read (converted with ``str``).
            Presumably the number of in-context examples used in the prompt —
            TODO confirm against the sweep script.
        relation_path: Directory that contains one sub-directory per ``n_icl``.

    Returns:
        A tuple ``(prompt_template, faithfulness, efficacy)`` where
        ``prompt_template`` is taken from the first trial and the other two are
        numpy arrays with one entry per trial.
    """
    sweep_dir = f"{relation_path}/{str(n_icl)}"
    results_by_relation = read_sweep_results(sweep_dir=sweep_dir)

    # Only the first relation found in the sweep directory is used.
    first_relation = list(results_by_relation)[0]
    parsed = relation_from_dict(results_by_relation[first_relation])
    trials = parsed.trials

    prompt_template = trials[0].prompt_template

    # For every trial only the first layer, first beta / first rank and the
    # first recall / efficacy entry are read (presumably the top-1 value and a
    # single-layer, single-hyperparameter sweep — TODO confirm).
    recall_values = []
    for trial in trials:
        recall_values.append(trial.layers[0].result.betas[0].recall[0])
    faithfulness = np.array(recall_values)

    efficacy_values = []
    for trial in trials:
        efficacy_values.append(trial.layers[0].result.ranks[0].efficacy[0])
    efficacy = np.array(efficacy_values)

    return prompt_template, faithfulness, efficacy

def parse_for_relation(relation = "country capital city"):
    """Aggregate faithfulness and efficacy statistics for one relation.

    The results are looked up under
    ``<path>/<relation with spaces replaced by underscores>/<run>/<n_icl>/``,
    where ``path`` is the notebook-level results directory and ``<run>`` is the
    alphabetically first entry of the relation directory. Entries of the run
    directory whose name starts with ``"args"`` are ignored; every other entry
    name is parsed as an integer ``n_icl``.

    Args:
        relation: Relation name; spaces are replaced by underscores to form the
            directory name.

    Returns:
        A tuple ``(prompt_templates, faith_means, faith_stds, eff_means,
        eff_stds)``. ``prompt_templates`` is a list and the remaining four are
        numpy arrays; all have one entry per ``n_icl``, in ascending ``n_icl``
        order. Standard deviations use numpy's default (population) definition.

    Raises:
        IndexError: If the relation directory is empty.
        ValueError: If a non-``args`` entry of the run directory is not an integer.
    """
    relation_path = os.path.join(path, relation.replace(" ", "_"))
    # os.listdir returns entries in arbitrary order; sort so that the same run
    # directory is selected on every execution.
    run_dirs = sorted(os.listdir(relation_path))
    relation_path = os.path.join(relation_path, run_dirs[0])

    n_icl_list = sorted(
        int(entry)
        for entry in os.listdir(relation_path)
        if not entry.startswith("args")
    )

    faith_means, faith_stds = [], []
    eff_means, eff_stds = [], []
    prompt_templates = []

    for n_icl in n_icl_list:
        prompt_template, faithfulness, efficacy = parse_for_n(n_icl, relation_path)
        faith_means.append(np.mean(faithfulness))
        faith_stds.append(np.std(faithfulness))
        eff_means.append(np.mean(efficacy))
        eff_stds.append(np.std(efficacy))
        prompt_templates.append(prompt_template)

    return (
        prompt_templates,
        np.array(faith_means),
        np.array(faith_stds),
        np.array(eff_means),
        np.array(eff_stds),
    )

In [None]:
# Parse a single relation as a quick check; the same variable names are
# reassigned for every relation by the table-printing loop in the next cell.
prompt_templates, faith_means, faith_stds, eff_means, eff_stds = parse_for_relation(relation = "country capital city")

In [None]:
# Print one group of LaTeX table rows per relation directory under `path`.
# Each row holds the prompt template, faithfulness (mean ± std) and
# efficacy (mean ± std); each group is closed with \hline.
# The listing is sorted because os.listdir returns entries in arbitrary order,
# which would otherwise make the row-group order differ between runs.
for relation in sorted(os.listdir(path)):
    print("-------------------------")
    print(f"{relation=}")
    print("-------------------------")
    prompt_templates, faith_means, faith_stds, eff_means, eff_stds = parse_for_relation(
        relation=relation
    )
    for i in range(len(prompt_templates)):
        # The "{}" placeholder must appear as \{\} in LaTeX. Backslashes are
        # doubled so Python does not see the invalid escape sequences "\{",
        # "\}" and "\p"; the printed text is unchanged.
        line = "& \\texttt{" + prompt_templates[i].replace("{}", "\\{\\}") + "} "
        line += f"& ${faith_means[i]:.2f} \\pm {faith_stds[i]:.2f}$ "
        line += f"& ${eff_means[i]:.2f} \\pm {eff_stds[i]:.2f}$ "
        line += "\\\\"
        print(line)
    print("\\hline")
