In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules
# before each cell runs, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from experiments.summarize import main

In [None]:
# Per-column table configuration, one tuple per reported metric:
#   (summary key read by execute(), direction flag, threshold)
# Unpacked column-wise into three parallel lists:
#   COLS  - summary keys, in left-to-right table order.
#   OPTIM - 1 selects the row with the largest mean as "best" (argmax);
#           any other value selects the smallest (argmin).
#   LIM   - threshold used by execute() to flag a non-baseline value as
#           \badmetric (below it when OPTIM is 1, above it when OPTIM is 0).
COLS, OPTIM, LIM = map(
    list,
    zip(
        ("post_score", 1, 20),
        ("post_rewrite_acc", 1, 50),
        ("post_paraphrase_acc", 1, 50),
        ("post_neighborhood_acc", 1, 20),
    ),
)

In [None]:
# Row labels that execute() treats as unedited-model baselines: their metrics
# are read from the "pre_*" summary keys and the row is followed by \midrule.
BASELINE_NAMES = [
    "GPT-2 XL",
    "GPT-2 M",
    "GPT-2 L",
    "GPT-J",
]


def execute(RUN_DIR, RUN_DATA, FIRST_N):
    r"""Render the body rows of a LaTeX results table for one model.

    For every entry of ``RUN_DATA`` the summary of ``run_000`` is loaded via
    ``main`` and one table row is produced with the columns listed in
    ``COLS``.  Rows whose label is in ``BASELINE_NAMES`` report the
    ``pre_*`` variant of each ``post_*`` key and are followed by
    ``\midrule``.

    Among the non-baseline rows, the best value of each column (largest mean
    when ``OPTIM[j] == 1``, smallest otherwise) is wrapped in
    ``\goodmetric{...}``.  Non-baseline values on the wrong side of
    ``LIM[j]`` are wrapped in ``\badmetric{...}``.

    Args:
        RUN_DIR: Directory containing one sub-directory per method; must
            support the ``/`` operator (e.g. ``pathlib.Path``).
        RUN_DATA: Mapping of row label -> ``(sub_directory, alt_para)``.
            ``alt_para`` is accepted for compatibility but not used here.
        FIRST_N: Forwarded to ``main`` as ``first_n_cases`` and used as the
            sample size in ``1.96 * std / sqrt(FIRST_N)`` (normal-approximation
            95% interval half-width).

    Returns:
        The table rows joined with newlines, each terminated by ``\\``.

    Raises:
        AssertionError: If ``main`` does not return exactly one summary for
            a run directory.
    """
    summaries = {}
    for label, (subdir, _alt_para) in RUN_DATA.items():
        cur = main(
            dir_name=RUN_DIR / subdir,
            runs=["run_000"],
            first_n_cases=FIRST_N,
            abs_path=True,
        )
        assert len(cur) == 1, f"expected exactly one summary for {label!r}"
        summaries[label] = cur[0]

    # Each row: (label, is_baseline, [summary value per column]).
    # Every summary value is indexed as value[0] = mean, value[1] = std.
    rows = []
    for label, summary in summaries.items():
        is_baseline = label in BASELINE_NAMES
        cells = []
        for col in COLS:
            if not is_baseline or col == "time":
                key = col
            else:
                # Baselines report the pre-edit statistic of the same metric.
                key = "pre_" + col[len("post_") :]
            cells.append(summary[key])
        rows.append((label, is_baseline, cells))

    # Only non-baseline rows compete for the "best" highlight.  Tracking their
    # real indices (instead of assuming the baseline is row 0) keeps the
    # highlight correct when the baseline row is absent or not first, and
    # avoids calling argmax/argmin on an empty array.
    candidates = [i for i, (_, is_baseline, _) in enumerate(rows) if not is_baseline]
    if candidates:
        means = np.array([[cell[0] for cell in rows[i][2]] for i in candidates])
        best_max = [candidates[pos] for pos in np.argmax(means, axis=0)]
        best_min = [candidates[pos] for pos in np.argmin(means, axis=0)]
    else:
        best_max = best_min = [None] * len(COLS)

    res = []
    for i, (label, is_baseline, cells) in enumerate(rows):
        parts = [label]
        for j, cell in enumerate(cells):
            std = cell[1]
            mean_str = str(np.round(cell[0], 1))
            half_width = 1.96 * std / np.sqrt(FIRST_N)

            if np.isnan(std):
                # No spread available: print the mean alone.
                res_str = mean_str
            else:
                res_str = f"{mean_str} ($\\pm${np.round(half_width, 1)})"

            maximize = OPTIM[j] == 1
            best_row = best_max[j] if maximize else best_min[j]
            shown_mean = float(mean_str)
            is_bad = (maximize and shown_mean < LIM[j]) or (
                OPTIM[j] == 0 and shown_mean > LIM[j]
            )

            if best_row == i:
                parts.append("\\goodmetric{" + res_str + "}")
            elif not is_baseline and is_bad:
                parts.append("\\badmetric{" + res_str + "}")
            else:
                parts.append(res_str)

        res.append(
            " & ".join(parts) + "\\\\" + ("\\midrule" if is_baseline else "")
        )

    return "\n".join(res)

In [None]:
# Separator placed between per-model tables when several are printed together.
gap = "\n\\midrule\\midrule\n"

# Rows for the GPT-J table: label -> (run sub-directory, alt_para flag).
first2j = 10_000
data2j = dict(
    [
        # ("GPT-J", ("ROME", False)),  # baseline row currently disabled
        ("FT-W", ("FT", False)),
        ("MEND", ("MEND", False)),
        ("ROME", ("ROME", False)),
    ]
)
dir2j = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA_MROME/zsre/gpt-j")

print(execute(RUN_DIR=dir2j, RUN_DATA=data2j, FIRST_N=first2j))

In [None]:
# Separator placed between the per-model tables printed below.
gap = "\n\\midrule\\midrule\n"

# Each model gets: a results directory, a mapping of
# row label -> (run sub-directory, alt_para flag), and a case count.

# --- GPT-2 Medium ---
first2medium = 10_000
dir2medium = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre/gpt2-medium")
data2medium = dict(
    [
        ("GPT-2 M", ("ROME", False)),
        ("FT+L", ("FT_L", False)),
        ("ROME", ("ROME", False)),
    ]
)

# --- GPT-2 Large ---
first2l = 10_000
dir2l = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre/gpt2-large")
data2l = dict(
    [
        ("GPT-2 L", ("ROME", False)),
        ("FT+L", ("FT_L", False)),
        ("ROME", ("ROME", False)),
    ]
)

# --- GPT-2 XL ---
first2xl = 10_000
dir2xl = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/zsre/gpt2-xl")
data2xl = dict(
    [
        ("GPT-2 XL", ("FT", True)),
        ("FT", ("FT", True)),
        ("FT+L", ("FT_L", True)),
        ("KE", ("KE", False)),
        ("KE-zsRE", ("KE_zsRE", False)),
        ("MEND", ("MEND", False)),
        # NOTE(review): label says "CF" but the directory is MEND_zsRE; confirm
        # the intended row label.
        ("MEND-CF", ("MEND_zsRE", False)),
        ("ROME", ("ROME", False)),
    ]
)

# Tables are rendered in order (medium, large, XL) and joined with `gap`.
print(
    gap.join(
        [
            execute(RUN_DIR=dir2medium, RUN_DATA=data2medium, FIRST_N=first2medium),
            execute(RUN_DIR=dir2l, RUN_DATA=data2l, FIRST_N=first2l),
            execute(RUN_DIR=dir2xl, RUN_DATA=data2xl, FIRST_N=first2xl),
        ]
    )
)