In [None]:
# Enable IPython's autoreload extension so edits to imported project modules
# (e.g. `src.*`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric
import logging

from src.utils import logging_utils
from src.utils.sweep_utils import read_sweep_results, relation_from_dict

# logging_utils.configure(level=logging.DEBUG)

In [None]:
##############################################
# Name of the model whose LRE statistics are analysed; it selects the
# sub-directory of the results folder used below.
model_name = "gptj"
# Relative path to that model's results; every later cell reads from here.
path = f"../../results/lre_stats/{model_name}"
##############################################

# Display the entries found directly under `path`.
os.listdir(path)

In [None]:
# Peek at a single sweep (the "8" sub-folder) of one relation to inspect the
# layout of the raw results before parsing everything.
relation = "country capital city"
relation_dir = os.path.join(path, relation.replace(" ", "_"))
# The first entry listed inside the relation directory is used as the run folder.
relation_path = os.path.join(relation_dir, os.listdir(relation_dir)[0])
sweep_results = read_sweep_results(sweep_dir=f"{relation_path}/{str(8)}")

# NOTE: from here on `relation` is the first key of the loaded results,
# no longer the literal assigned at the top of this cell.
relation = list(sweep_results.keys())[0]
relation_results = relation_from_dict(sweep_results[relation])

# Raw (dict-form) result stored for the first layer entry of the first trial.
first_trial = sweep_results[relation]["trials"][0]
first_trial["layers"][0]["result"]

In [None]:
def parse_for_n(n_icl, relation_path):
    """Read one sweep folder and extract per-trial statistics.

    The sweep is loaded from ``{relation_path}/{n_icl}``. Only the first
    relation key in the loaded results is used, and for every trial only the
    result attached to its first layer entry (``trial.layers[0]``) is read.

    Args:
        n_icl: Name of the sweep sub-folder (converted with ``str``).
        relation_path: Directory that contains the per-``n_icl`` sub-folders.

    Returns:
        A tuple ``(faithfulness, efficacy, weight_norms, bias_norms)`` with one
        entry per trial:

        * ``faithfulness`` - ``np.ndarray`` of ``result.betas[0].recall[0]``
        * ``efficacy`` - ``np.ndarray`` of ``result.ranks[0].efficacy[0]``
        * ``weight_norms`` - ``list`` of ``result.lre_stats["|weight|"]``
        * ``bias_norms`` - ``list`` of ``result.lre_stats["|bias|"]``
    """
    sweep_results = read_sweep_results(
        sweep_dir=f"{relation_path}/{str(n_icl)}",
    )
    first_relation = list(sweep_results.keys())[0]
    parsed = relation_from_dict(sweep_results[first_relation])

    # One result object per trial, taken from that trial's first layer entry.
    layer_results = [trial.layers[0].result for trial in parsed.trials]

    weight_norms = [res.lre_stats["|weight|"] for res in layer_results]
    bias_norms = [res.lre_stats["|bias|"] for res in layer_results]
    faithfulness = np.array([res.betas[0].recall[0] for res in layer_results])
    efficacy = np.array([res.ranks[0].efficacy[0] for res in layer_results])

    return faithfulness, efficacy, weight_norms, bias_norms

def parse_for_relation(relation = "country capital city", results_dir = None):
    """Collect LRE statistics of one relation across all of its n_icl sweeps.

    Expected layout: ``<results_dir>/<relation with "_" for spaces>/<run>/<n_icl>/``.
    The run used is the alphabetically first sub-directory of the relation
    folder; plain files next to it are ignored. Inside the run folder only
    entries whose name is a decimal integer are treated as n_icl sweeps, so
    ``args*`` files and any other stray entries are skipped.

    Args:
        relation: Relation name; spaces are replaced by underscores to form
            the directory name.
        results_dir: Root directory holding one folder per relation. Defaults
            to the notebook-level ``path`` when ``None``.

    Returns:
        dict with keys ``"relation"``, ``"faithfulness"``, ``"causality"``,
        ``"|weight|"`` and ``"|bias|"``. Each statistic is ``np.array`` built
        from one per-trial array per n_icl, in ascending n_icl order. With no
        sweep folder present every statistic is an empty array.
        NOTE(review): stacking presumably relies on every n_icl having the
        same number of trials - confirm against the sweep outputs.

    Raises:
        FileNotFoundError: if the relation folder has no run sub-directory.
    """
    base_dir = path if results_dir is None else results_dir
    relation_dir = os.path.join(base_dir, relation.replace(" ", "_"))

    # os.listdir returns entries in arbitrary order and also lists plain files,
    # so restrict to directories and sort to make the chosen run deterministic.
    run_dirs = sorted(
        entry
        for entry in os.listdir(relation_dir)
        if os.path.isdir(os.path.join(relation_dir, entry))
    )
    if not run_dirs:
        raise FileNotFoundError(f"no run directory found under {relation_dir!r}")
    relation_path = os.path.join(relation_dir, run_dirs[0])

    # Sweep folders are named after their integer n_icl; everything else is skipped.
    n_icl_list = sorted(
        int(entry) for entry in os.listdir(relation_path) if entry.isdecimal()
    )

    faithfulness = []
    causality = []
    weight_norms = []
    bias_norms = []

    for n_icl in n_icl_list:
        faith, eff, w, b = parse_for_n(n_icl, relation_path)
        faithfulness.append(faith)
        causality.append(eff)
        weight_norms.append(np.array(w))
        bias_norms.append(np.array(b))

    return {
        "relation": relation,
        "faithfulness": np.array(faithfulness),
        "causality": np.array(causality),
        "|weight|": np.array(weight_norms),
        "|bias|": np.array(bias_norms),
    }

In [None]:
# Every sub-directory of `path` is treated as one relation. os.listdir also
# returns plain files and gives no ordering guarantee, so keep directories only
# and sort them to get a stable, reproducible relation order.
relations = sorted(
    entry for entry in os.listdir(path) if os.path.isdir(os.path.join(path, entry))
)
lre_stats = {}

# Parse each relation once; results are keyed by the relation's directory name.
for relation in relations:
    lre_stats[relation] = parse_for_relation(relation)

In [None]:
# Per-relation summary values. Relations without any parsed sweep are left out,
# so these four lists stay aligned with each other but not with `relations`.
faithfulness, causality, weight_means, weight_stds = [], [], [], []

for relation in relations:
    stats = lre_stats[relation]
    faith, eff = stats["faithfulness"], stats["causality"]

    # Keep the relation only when both metrics have at least one entry.
    if len(faith) > 0 and len(eff) > 0:
        weights = stats["|weight|"]
        faithfulness.append(faith.mean())
        causality.append(eff.mean())
        weight_means.append(np.mean(weights))
        weight_stds.append(np.std(weights))

In [None]:
# One point per relation that had data: mean faithfulness against the standard
# deviation of its |weight| values.
fig, ax = plt.subplots()
ax.plot(faithfulness, weight_stds, "o")
ax.set_xlabel("mean faithfulness")
ax.set_ylabel("std of |weight|")
ax.set_title(f"{model_name}: faithfulness vs. spread of |weight| per relation")
plt.show()