In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric
import logging
import torch
import json

from src.utils import logging_utils
from src.utils.sweep_utils import read_sweep_results, relation_from_dict


# logging_utils.configure(level=logging.DEBUG)


In [None]:
############################################
# Sweep to analyse: results live under <sweep_root>/<model_name>.
sweep_root = "../../results/sweep-full-rank"
model_name = "gptj"
############################################

sweep_path = "/".join((sweep_root, model_name))

In [None]:
sweep_results = read_sweep_results(sweep_path)

# Report, per relation, how many trials the sweep recorded ("relation: N").
for relation in sweep_results:
    n_trials = len(sweep_results[relation]["trials"])
    print(relation, n_trials, sep=": ")

In [None]:
def rankwise_performance(relation_result):
    """Collect per-rank faithfulness (recall) and efficacy scores over all trials.

    Only ``trial.layers[0].result`` of each trial is read. The rank values are
    taken from the first trial; every later trial must report the same ranks.

    Args:
        relation_result: object exposing ``.trials``. Each trial exposes
            ``layers[0].result.betas`` (items with ``.rank`` and an indexable
            ``.recall``) and ``layers[0].result.ranks`` (items with ``.rank``
            and an indexable ``.efficacy``).

    Returns:
        ``(ranks, recalls, efficacies)``. ``ranks`` lists the rank values in the
        order they appear in the first trial's ``betas``. ``recalls[i]`` and
        ``efficacies[i]`` are lists holding, for ``ranks[i]``, element ``[0]``
        of ``recall`` / ``efficacy`` from every trial (in trial order).

    Raises:
        ValueError: if there are no trials, or if the first trial's ``betas``
            and ``ranks`` cover different sets of rank values (the two result
            lists could then not be plotted against a common x axis).
    """
    trials = relation_result.trials
    if not trials:
        raise ValueError("relation_result has no trials")

    first_result = trials[0].layers[0].result
    rank_recalls = {beta.rank: [] for beta in first_result.betas}
    rank_efficacies = {entry.rank: [] for entry in first_result.ranks}
    if set(rank_recalls) != set(rank_efficacies):
        raise ValueError(
            f"betas cover ranks {sorted(rank_recalls)} but ranks cover "
            f"{sorted(rank_efficacies)}; cannot align faithfulness and efficacy"
        )

    for trial in trials:
        result = trial.layers[0].result
        for beta in result.betas:
            rank_recalls[beta.rank].append(beta.recall[0])
        for entry in result.ranks:
            rank_efficacies[entry.rank].append(entry.efficacy[0])

    ranks = list(rank_recalls)
    # Index both score lists by the same rank sequence so that position i of
    # recalls and efficacies always refers to ranks[i], even when the betas
    # and ranks result lists are stored in different orders.
    return (
        ranks,
        [rank_recalls[r] for r in ranks],
        [rank_efficacies[r] for r in ranks],
    )


In [None]:
import pandas as pd

# Per-relation results table. The aggregation below reads its "relation",
# "recall@1" and "efficacy" columns. The file name is derived from
# `model_name` so that the table always matches the sweep selected by
# `sweep_path`, instead of silently staying on GPT-J if the config changes.
df = pd.read_csv(f"../../results/tables/{model_name}-hparams.csv")

In [None]:
# Convert every raw sweep entry into a structured result object, keyed by relation.
relation_dict = {
    relation: relation_from_dict(sweep_results[relation])
    for relation in tqdm(sweep_results)
}

In [None]:
#########################################
tau = 0.8  # keep only relations whose faithfulness (recall@1) is strictly above tau
#########################################

# Stack the per-trial rank curves of every sufficiently faithful relation.
# After stacking, each row of F (faithfulness) / C (efficacy) is one trial and
# each column is one entry of `ranks`.
faithfulness_rows = []
efficacy_rows = []
ranks = None

for relation in relation_dict:
    matches = df[df["relation"] == relation].to_dict(orient="records")
    if not matches:
        print(f"{relation} >> not found in results table, skipping")
        continue
    res = matches[0]
    print(f"{relation} >> faithfulness={res['recall@1']} | efficacy={res['efficacy']}")
    # recall@1 may be stored as text; its first whitespace-separated token is the score.
    if float(str(res['recall@1']).split()[0]) > tau:
        relation_ranks, faithfulness, efficacies = rankwise_performance(relation_dict[relation])
        # All relations must share the same rank axis, otherwise column-wise
        # statistics would mix different ranks.
        if ranks is not None and relation_ranks != ranks:
            raise ValueError(
                f"{relation}: ranks {relation_ranks} differ from previously seen ranks {ranks}"
            )
        ranks = relation_ranks
        faithfulness_rows.append(torch.Tensor(faithfulness).T)
        efficacy_rows.append(torch.Tensor(efficacies).T)

if not faithfulness_rows:
    raise ValueError(f"No relation has faithfulness above tau={tau}; nothing to aggregate.")

# Concatenate once instead of growing the tensors inside the loop.
F = torch.cat(faithfulness_rows, dim=0)
C = torch.cat(efficacy_rows, dim=0)

print(F.shape)

# Per-rank mean and standard deviation across all stacked trials.
# NOTE: with a single row the (unbiased) std is NaN.
f_mean = F.mean(dim=0)
c_mean = C.mean(dim=0)
f_std = F.std(dim=0)
c_std = C.std(dim=0)

In [None]:
plt.rcdefaults()
fig_dir = "figs"
# savefig() does not create missing directories, so make sure the output folder exists.
os.makedirs(fig_dir, exist_ok=True)
#####################################################################################
plt.rcParams["figure.dpi"] = 200
plt.rcParams["font.family"] = "Times New Roman"

SMALL_SIZE = 14
MEDIUM_SIZE = 18
BIGGER_SIZE = 22

plt.rc("font", size=SMALL_SIZE)  # controls default text sizes
plt.rc("axes", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE)  # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE)  # fontsize of the tick labels
plt.rc("legend", fontsize=MEDIUM_SIZE)  # legend fontsize
plt.rc("figure", titlesize=50)  # fontsize of the figure title


faith_color = "steelblue"
cause_color = "darkorange"
#####################################################################################

# Mean curve with a +/- one standard deviation band for each metric.
fig, ax = plt.subplots()
ax.plot(ranks, f_mean, label="Faithfulness", color=faith_color, linewidth=1.8)
ax.fill_between(ranks, f_mean - f_std, f_mean + f_std, alpha=0.1, color=faith_color)
ax.plot(ranks, c_mean, label="Causality", color=cause_color, linewidth=1.8)
ax.fill_between(ranks, c_mean - c_std, c_mean + c_std, alpha=0.1, color=cause_color)

ax.set_xscale("log", base=2)
ax.set_xlabel("Rank")
ax.set_ylim(0, 1)
ax.set_ylabel("Score")
ax.legend(ncol=2, bbox_to_anchor=(0.5, 1.15), loc="upper center", frameon=False)

fig.savefig(f"{fig_dir}/rank-sweep.pdf", bbox_inches="tight")
plt.show()

In [None]:
import pandas as pd

In [None]:
# Print the beta/rank table for the configured model as LaTeX (2 decimals).
# The file name follows `model_name` instead of hard-coding "gptj".
df = pd.read_csv(f"../../results/tables/{model_name}-beta-R.csv")
print(df.to_latex(index=False, float_format="%.2f"))

In [None]:
# Print the per-relation hparams/results table for the configured model as
# LaTeX (2 decimals). The file name follows `model_name` instead of hard-coding "gptj".
df = pd.read_csv(f"../../results/tables/{model_name}-hparams.csv")
print(df.to_latex(index=False, float_format="%.2f"))

In [None]:
# Fraction of relations whose faithfulness (leading number of the "recall@1"
# column) is strictly above the threshold. An empty table yields NaN instead
# of raising ZeroDivisionError.
faith_threshold = 0.6
faith_scores = df["recall@1"].astype(str).str.split().str[0].astype(float)
count = int((faith_scores > faith_threshold).sum())
count / len(df) if len(df) else float("nan")