In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric
import logging

from src.utils import logging_utils
from src.utils.sweep_utils import read_sweep_results, relation_from_dict

# logging_utils.configure(level=logging.DEBUG)

In [3]:
##############################################
model_name = "gptj"
path = f"../../results/sweep_prompt/{model_name}"
##############################################

os.listdir(path)

['person_occupation',
 'landmark_in_country',
 'country_capital_city',
 'plays_pro_sport',
 'city_in_country',
 'food_from_country',
 'occupation_gender',
 'name_gender',
 'country_language',
 'object_superclass',
 'landmark_on_continent',
 'country_largest_city',
 'verb_past_tense',
 'name_birthplace',
 'person_sport_position',
 'adjective_superlative']

In [4]:
def parse_for_n(n_icl, relation_path):
    sweep_results = read_sweep_results(
        sweep_dir=f"{relation_path}/{str(n_icl)}",
    )
    relation = list(sweep_results.keys())[0]
    relation_results = relation_from_dict(sweep_results[relation])
    prompt_template = relation_results.trials[0].prompt_template
    faithfulness = np.array([
        trial.layers[0].result.betas[0].recall[0] 
        for trial in relation_results.trials
    ])
    efficacy = np.array([
        trial.layers[0].result.ranks[0].efficacy[0]
        for trial in relation_results.trials
    ])

    return prompt_template, faithfulness, efficacy

def parse_for_relation(relation = "country capital city"):
    relation_path = os.path.join(path, relation.replace(" ", "_"))
    relation_path = os.path.join(relation_path, os.listdir(relation_path)[0])
    n_icl_list = [int(x) for x in os.listdir(relation_path) if x.startswith("args") == False]
    n_icl_list.sort()
    
    faith_means, faith_stds = [], []
    eff_means, eff_stds = [], []
    prompt_templates = []

    for n_icl in n_icl_list:
        prompt_template, faithfulness, efficacy = parse_for_n(n_icl, relation_path)
        faith_means.append(np.mean(faithfulness))
        faith_stds.append(np.std(faithfulness))
        eff_means.append(np.mean(efficacy))
        eff_stds.append(np.std(efficacy))
        prompt_templates.append(prompt_template)
    
    faith_means = np.array(faith_means)
    faith_stds = np.array(faith_stds)
    eff_means = np.array(eff_means)
    eff_stds = np.array(eff_stds)
    
    return prompt_templates, faith_means, faith_stds, eff_means, eff_stds

In [5]:
prompt_templates, faith_means, faith_stds, eff_means, eff_stds = parse_for_relation(relation = "country capital city")

In [21]:
diff_faith = []
diff_eff = []

for relation in os.listdir(path):
    print("-------------------------")
    print(f"{relation=}")
    print("-------------------------")
    prompt_templates, faith_means, faith_stds, eff_means, eff_stds = parse_for_relation(
        relation=relation
    )

    for i in range(len(prompt_templates)):
        line = "& \\texttt{" + prompt_templates[i].replace('{}', '\{\}') + "} "
        line += f"& ${faith_means[i]:.2f} \pm {faith_stds[i]:.2f}$ "
        line += f"& ${eff_means[i]:.2f} \pm {eff_stds[i]:.2f}$ "
        line += "\\\\"
        print(line)

    for i in faith_means:
        diff_faith.append(i - faith_means[0])
    for i in eff_means:
        diff_eff.append(i - eff_means[0])
    
    print("\\hline")

-------------------------
relation='person_occupation'
-------------------------
& \texttt{\{\} works professionally as a} & $0.41 \pm 0.08$ & $0.55 \pm 0.09$ \\
& \texttt{\{\} works as a} & $0.44 \pm 0.11$ & $0.58 \pm 0.07$ \\
& \texttt{By profession, \{\} is a} & $0.46 \pm 0.14$ & $0.58 \pm 0.08$ \\
\hline
-------------------------
relation='landmark_in_country'
-------------------------
& \texttt{What country is \{\} in? It is in} & $0.27 \pm 0.10$ & $0.64 \pm 0.03$ \\
& \texttt{\{\} is in the country of} & $0.31 \pm 0.08$ & $0.66 \pm 0.02$ \\
\hline
-------------------------
relation='country_capital_city'
-------------------------
& \texttt{The capital of \{\} is} & $0.84 \pm 0.09$ & $0.94 \pm 0.04$ \\
& \texttt{The capital of \{\} is the city of} & $0.87 \pm 0.08$ & $0.94 \pm 0.04$ \\
& \texttt{The capital city of \{\} is} & $0.84 \pm 0.08$ & $0.94 \pm 0.04$ \\
& \texttt{What is the capital of \{\}? It is the city of} & $0.87 \pm 0.07$ & $0.92 \pm 0.05$ \\
\hline
----------------

In [26]:
np.array(diff_faith).mean(), np.array(diff_eff).mean()

(0.0183642867583407, 0.002540693583432121)