In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric
import logging

from src.utils import logging_utils
from src.utils.sweep_utils import read_sweep_results, relation_from_dict

# logging_utils.configure(level=logging.DEBUG)

In [None]:
##############################################
# Which model's `num_train` sweep to read; `path` is used by parse_for_relation.
model_name = "gptj"
path = f"../../results/num_train/{model_name}"
##############################################

# Contents of the sweep directory (presumably one entry per relation).
# Sorted because os.listdir returns entries in arbitrary order.
sorted(os.listdir(path))

In [None]:
def parse_for_n(n_icl, relation_path, relation=None):
    """Load per-trial faithfulness and efficacy for one number of ICL examples.

    Args:
        n_icl: Number of in-context examples. Selects the sweep directory
            ``{relation_path}/{n_icl}`` that is read.
        relation_path: Directory holding one sub-directory per ``n_icl`` value.
        relation: Key of the relation inside the loaded sweep results. Pass it
            explicitly. If omitted, the sweep must contain exactly one relation,
            and that one is used. This avoids depending on a notebook global.

    Returns:
        ``(faithfulness, efficacy)``: two numpy arrays with one entry per trial.

    Raises:
        ValueError: ``relation`` is omitted and the sweep does not hold exactly
            one relation.
    """
    sweep_results = read_sweep_results(sweep_dir=f"{relation_path}/{n_icl}")
    if relation is None:
        # NOTE(review): assumes sweep_results is a mapping keyed by relation
        # name (it is indexed that way below) — confirm.
        if len(sweep_results) != 1:
            raise ValueError(
                f"`relation` must be given: sweep at {relation_path}/{n_icl} "
                f"holds {len(sweep_results)} relations"
            )
        relation = next(iter(sweep_results))

    trials = relation_from_dict(sweep_results[relation]).trials
    # Index 0 at each level picks the first layer / beta / rank entry.
    # NOTE(review): presumably the sweep ran a single setting per level — confirm.
    faithfulness = np.array(
        [trial.layers[0].result.betas[0].recall[0] for trial in trials]
    )
    efficacy = np.array(
        [trial.layers[0].result.ranks[0].efficacy[0] for trial in trials]
    )
    return faithfulness, efficacy


def parse_for_relation(relation="country capital city"):
    """Aggregate faithfulness/efficacy across all ``n_icl`` values of one relation.

    Reads ``{path}/{relation with spaces -> underscores}/<run dir>/<n_icl>/``,
    where ``path`` is the notebook-level results directory.

    Args:
        relation: Relation name, as used as the key inside the sweep results.

    Returns:
        ``(n_icl_list, faith_means, faith_stds, eff_means, eff_stds)``.
        ``n_icl_list`` is a sorted list of ints. The other four are numpy
        arrays aligned with it (mean / std over trials).

    Raises:
        FileNotFoundError: No run directory exists for the relation.
    """
    relation_dir = os.path.join(path, relation.replace(" ", "_"))
    # Only directories can be runs; sort so the choice is deterministic
    # (os.listdir order is arbitrary).
    run_dirs = sorted(
        entry
        for entry in os.listdir(relation_dir)
        if os.path.isdir(os.path.join(relation_dir, entry))
    )
    if not run_dirs:
        raise FileNotFoundError(f"no run directory found under {relation_dir}")
    if len(run_dirs) > 1:
        logging.warning(
            "Found %d run directories under %s; using %s",
            len(run_dirs),
            relation_dir,
            run_dirs[0],
        )
    relation_path = os.path.join(relation_dir, run_dirs[0])

    # Integer-named entries are n_icl sweeps; this skips the args file(s) and
    # any stray entries that would otherwise crash int().
    n_icl_list = sorted(
        int(entry) for entry in os.listdir(relation_path) if entry.isdigit()
    )

    faith_means, faith_stds = [], []
    eff_means, eff_stds = [], []
    for n_icl in n_icl_list:
        faithfulness, efficacy = parse_for_n(n_icl, relation_path, relation=relation)
        faith_means.append(np.mean(faithfulness))
        faith_stds.append(np.std(faithfulness))
        eff_means.append(np.mean(efficacy))
        eff_stds.append(np.std(efficacy))

    return (
        n_icl_list,
        np.array(faith_means),
        np.array(faith_stds),
        np.array(eff_means),
        np.array(eff_stds),
    )

In [None]:
# Smoke-test the parser on a single relation before building the full figure.
(
    n_icl_list,
    faith_means,
    faith_stds,
    eff_means,
    eff_stds,
) = parse_for_relation("country capital city")

In [None]:
#####################################################################################
# Figure styling. Start from Matplotlib's default style so re-running this cell
# always yields the same configuration.
plt.rcdefaults()

SMALL_SIZE = 16
MEDIUM_SIZE = 20
BIGGER_SIZE = 28

plt.rcParams.update(
    {
        "figure.dpi": 200,
        # NOTE(review): requires Times New Roman to be installed; otherwise
        # Matplotlib warns and falls back to another font.
        "font.family": "Times New Roman",
        "font.size": SMALL_SIZE,  # default text size
        "axes.labelsize": BIGGER_SIZE,  # x / y axis labels
        "xtick.labelsize": MEDIUM_SIZE,  # x tick labels
        "ytick.labelsize": MEDIUM_SIZE,  # y tick labels
        "legend.fontsize": SMALL_SIZE,  # legend text
        "figure.titlesize": BIGGER_SIZE,  # figure (suptitle) title
    }
)

# Line colours per metric, keyed by metric name.
color_scheme = dict(recall="steelblue", efficacy="darkorange")

linewidth = 2
#####################################################################################


def export_legend(legend, filename="legend.pdf"):
    """Save only the region occupied by ``legend`` to ``filename``.

    The owning figure is drawn first so the legend has a laid-out extent. That
    extent is converted from display pixels to inches, because ``savefig``
    expects ``bbox_inches`` in inches. The file is written at the figure's own dpi.
    """
    owner = legend.figure
    owner.canvas.draw()
    pixels_to_inches = owner.dpi_scale_trans.inverted()
    legend_box = legend.get_window_extent().transformed(pixels_to_inches)
    owner.savefig(filename, dpi="figure", bbox_inches=legend_box)


def plot_n_icl(
    canvas,
    n_icl_list,
    faith_means,
    faith_stds,
    eff_means,
    eff_stds,
    add_y_label=False,
    export_legend_to_file=None,
    legend_bbox_to_anchor=(4.7, -0.3),
):
    """Plot faithfulness and efficacy (mean ± std band) against the number of ICL examples.

    Args:
        canvas: Matplotlib Axes to draw on.
        n_icl_list: x values (numbers of in-context examples); also used as x ticks.
        faith_means, faith_stds: Per-x mean and standard deviation of faithfulness.
        eff_means, eff_stds: Per-x mean and standard deviation of efficacy.
            Lists or numpy arrays are accepted for all four.
        add_y_label: If True, label the y axis "Score" (e.g. for the leftmost panel).
        export_legend_to_file: If given, attach a legend to ``canvas``, save just
            that legend to this path via ``export_legend``, then remove it so the
            panel itself is drawn without a legend.
        legend_bbox_to_anchor: ``bbox_to_anchor`` (axes coordinates) for that
            temporary legend.
    """
    series = (
        ("Faithfulness", color_scheme["recall"], faith_means, faith_stds),
        ("Efficacy", color_scheme["efficacy"], eff_means, eff_stds),
    )
    for label, color, means, stds in series:
        # Coerce to arrays so the ± arithmetic below also works for plain lists.
        means = np.asarray(means, dtype=float)
        stds = np.asarray(stds, dtype=float)
        canvas.plot(n_icl_list, means, color=color, label=label, linewidth=linewidth)
        canvas.fill_between(
            n_icl_list, means - stds, means + stds, alpha=0.1, color=color
        )

    canvas.set_xticks(n_icl_list)
    canvas.set_ylim(0, 1)
    if add_y_label:
        canvas.set_ylabel("Score")

    if export_legend_to_file is not None:
        # Build the legend on this function's own Axes (not a notebook-global
        # `ax`), so it works regardless of what the caller names its Axes.
        legend = canvas.legend(
            ncol=2,
            bbox_to_anchor=legend_bbox_to_anchor,
            frameon=False,
            fontsize=MEDIUM_SIZE,
        )
        export_legend(legend, export_legend_to_file)
        legend.remove()


relations = [
    "country capital city",
    "plays pro sport",
    "person occupation",
    "object superclass",
]

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)

# One panel per relation. squeeze=False keeps `axes` 2-D, so indexing works
# even for a single relation. Width scales with the panel count
# (21 in for 4 panels, as before).
fig, axes = plt.subplots(
    nrows=1, ncols=len(relations), figsize=(5.25 * len(relations), 5), squeeze=False
)

for panel_idx, (ax, relation) in enumerate(zip(axes[0], relations)):
    n_icl_list, faith_means, faith_stds, eff_means, eff_stds = parse_for_relation(
        relation
    )
    is_first_panel = panel_idx == 0
    plot_n_icl(
        ax,
        n_icl_list,
        faith_means,
        faith_stds,
        eff_means,
        eff_stds,
        # y label and the (shared) legend file are only needed once.
        add_y_label=is_first_panel,
        export_legend_to_file="figs/varying_n_legend.pdf" if is_first_panel else None,
    )
    ax.set_title(relation, fontsize=BIGGER_SIZE)

fig.tight_layout()
fig.savefig("figs/varying_n.pdf", bbox_inches="tight")

# plt.show() rather than fig.show(): the latter may warn on non-GUI backends
# such as the inline backend.
plt.show()