In [None]:
# Reload edited project modules (e.g. experiments.summarize) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from experiments.summarize import main as summarize_main
from pathlib import Path
import math

In [None]:
# Collect summarized metrics for every method directory under RESULTS_DIR.
# DATA:  method name -> {n_edits (str, the folder-name prefix) -> summary dict}
# KEYS:  metric names of the first successfully summarized run.
RESULTS_ROOT = Path("results")
RESULTS_DIR = RESULTS_ROOT / "iclr"
DATA = {}
KEYS = None
# sorted() makes the load order (and therefore which run defines KEYS) reproducible.
for method_dir in sorted(RESULTS_DIR.iterdir()):
    if not method_dir.is_dir():
        continue
    method_name = method_dir.name
    print(method_name)
    for n_edit_folder in sorted(method_dir.glob("*_edits_setting_*")):
        # Folder names match "*_edits_setting_*"; the first "_"-separated token is the edit count.
        n_edits = n_edit_folder.name.split("_")[0]
        try:
            # NOTE(review): summarize_main appears to take a path relative to the results
            # root and return a list of per-run summaries; only the first is used — confirm.
            res = summarize_main(n_edit_folder.relative_to(RESULTS_ROOT), ["run_000"])[0]
        except Exception as exc:  # report and continue rather than silently dropping the run
            print(f"  skipping {n_edit_folder.name}: {exc!r}")
            continue
        DATA.setdefault(method_name, {})[n_edits] = res
        if KEYS is None:
            KEYS = list(res.keys())

# Fail here with a clear message instead of later on `KEYS is None`.
assert DATA, f"no summarizable results found under {RESULTS_DIR}"
print({k: list(v.keys()) for k, v in DATA.items()})

In [None]:
# Figure defaults for publication output: high DPI, a serif font, and a small
# ladder of font sizes (ticks/legend < axis labels < titles).
SMALL_SIZE = 14
MEDIUM_SIZE = 15
BIGGER_SIZE = 16

plt.rcParams.update(
    {
        "figure.dpi": 200,
        "font.family": "Times New Roman",
        "font.size": SMALL_SIZE,  # default text size
        "axes.titlesize": BIGGER_SIZE,  # per-axes titles
        "axes.labelsize": MEDIUM_SIZE,  # x and y axis labels
        "xtick.labelsize": SMALL_SIZE,  # x tick labels
        "ytick.labelsize": SMALL_SIZE,  # y tick labels
        "legend.fontsize": SMALL_SIZE,  # legend entries
        "figure.titlesize": BIGGER_SIZE,  # figure-level (suptitle) title
    }
)

In [None]:
# Curated metrics to plot, mapped to their panel titles; insertion order sets the
# panel order. The following cell replaces both names with the full raw metric list.
TITLES = dict(
    post_score="Score (S)",
    post_rewrite_success="Efficacy Succ. (ES)",
    post_paraphrase_success="Generalization Succ. (PS)",
    post_neighborhood_success="Specificity Succ. (NS)",
    post_rewrite_acc="Efficacy Acc (EA)",
    post_paraphrase_acc="Generalization Acc. (PA)",
    post_neighborhood_acc="Specificity Acc. (NA)",
    post_reference_score="Consistency (RS)",
)

SHOW_KEYS = [*TITLES]

In [None]:
# Show every metric the summarizer reported, titled by its raw name, instead of the
# curated TITLES / SHOW_KEYS selection above (skip this cell to keep that selection).
# "run_dir" is excluded (presumably the run's directory, not a plottable metric — confirm).
# A new list is built so KEYS itself is not mutated and the cell can be re-run safely.
SHOW_KEYS = [k for k in KEYS if k != "run_dir"]
TITLES = {k: k for k in SHOW_KEYS}

In [None]:
# One panel per metric in SHOW_KEYS, one line per method (methods whose name contains
# "_fix" are excluded), plotting the metric against the number of edits.
N_COLS = 4
# Size the grid by what is actually shown, not by the full KEYS list.
N_ROWS = max(1, math.ceil(len(SHOW_KEYS) / N_COLS))
FIGURE_PATH = "tmp.pdf"

assert all(k in KEYS for k in SHOW_KEYS), "SHOW_KEYS contains metrics missing from KEYS"
fig, axes = plt.subplots(
    N_ROWS, N_COLS, figsize=(N_COLS * 3.5, N_ROWS * 2.5), squeeze=False
)
methods = sorted((k, v) for k, v in DATA.items() if "_fix" not in k)
for ax, key in zip(axes.ravel(), SHOW_KEYS):
    for method, results in methods:
        try:
            # Sort numerically so each line runs left to right; folder glob order is arbitrary.
            n_edits = sorted(map(int, results.keys()))
            # Some metrics are stored as tuples (presumably (mean, std) — confirm);
            # only the first element is plotted.
            values = [
                f[0] if isinstance(f := results[str(n)][key], tuple) else f
                for n in n_edits
            ]
        except (KeyError, ValueError) as exc:
            # Missing metric or non-numeric edit-count folder: report instead of hiding it.
            print(f"skipping {method} / {key}: {exc!r}")
            continue
        ax.plot(n_edits, values, marker="o", markersize=4, label=method)
    ax.set_xlabel("# Edits")
    ax.set_title(TITLES.get(key, key))
    if ax.get_legend_handles_labels()[0]:  # avoid the empty-legend warning
        ax.legend()

# Remove the unused panels left over in the last row.
for ax in axes.ravel()[len(SHOW_KEYS):]:
    ax.remove()

fig.tight_layout()
fig.savefig(FIGURE_PATH)
plt.show()