In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from experiments.summarize import main

In [None]:
# One row per reported metric column:
#   (key looked up in each run summary,
#    direction flag: 1 -> the largest mean is marked best and means below the
#                         threshold are marked bad,
#                    0 -> the smallest mean is marked best and means above the
#                         threshold are marked bad,
#    threshold used for the "bad" marking)
_METRIC_TABLE = (
    ("post_score", 1, 50),
    ("post_rewrite_success", 1, 75),
    ("post_rewrite_diff", 1, 20),
    ("post_paraphrase_success", 1, 60),
    ("post_paraphrase_diff", 1, 10),
    ("post_neighborhood_success", 1, 45),
    ("post_neighborhood_diff", 1, -5),
    ("post_ngram_entropy", 1, 600),
    ("post_reference_score", 1, 31),
)

# Parallel lists indexed by column position.
COLS = [key for key, _, _ in _METRIC_TABLE]
OPTIM = [direction for _, direction, _ in _METRIC_TABLE]
LIM = [threshold for _, _, threshold in _METRIC_TABLE]

In [None]:
# Row labels that denote an unedited model. For these rows the "pre_" metrics
# are reported instead of the "post_" ones, the row is never marked bad, and a
# \midrule is emitted after it.
BASELINE_NAMES = ["GPT-2 XL", "GPT-2 M", "GPT-2 L", "GPT-J"]


def _summarize_single_run(run_dir, run_name, first_n):
    """Call ``main`` for a single run of ``run_dir`` and return its result list."""
    return main(
        dir_name=run_dir, runs=[run_name], first_n_cases=first_n, abs_path=True
    )


def execute(RUN_DIR, RUN_DATA, FIRST_N):
    r"""Build the rows of a results table for one model.

    Args:
        RUN_DIR: Directory (``pathlib.Path``-like, must support ``/``) holding
            one sub-directory per method.
        RUN_DATA: Mapping ``row label -> (sub-directory name, alt_para)``.
            When ``alt_para`` is truthy, the paraphrase-success and score
            entries are overwritten with those from ``run_001``.
        FIRST_N: Number of cases passed to ``main`` as ``first_n_cases``; also
            used as the sample size for the 1.96 * std / sqrt(n) interval.

    Returns:
        One string with a line per entry of ``RUN_DATA``. Cells are joined by
        `` & `` and each line ends with ``\\``; baseline rows additionally end
        with ``\midrule``.

    Raises:
        AssertionError: ``run_000`` did not yield exactly one summary.
        RuntimeError: ``alt_para`` was requested but ``run_001`` did not yield
            exactly one summary.

    Notes:
        * Each summary value is indexed as ``value[0]`` (mean) and
          ``value[1]`` (std) -- presumably a (mean, std) pair produced by
          ``experiments.summarize.main``; confirm against that module.
        * The best value per column is searched over every row *except the
          first one*, i.e. the first entry of ``RUN_DATA`` is assumed to be
          the baseline.
    """
    data = {}
    for k, (d, alt_para) in RUN_DATA.items():
        run_dir = RUN_DIR / d

        cur = _summarize_single_run(run_dir, "run_000", FIRST_N)
        assert len(cur) == 1, (
            f"expected exactly one summary for run_000 in {run_dir}, "
            f"got {len(cur)}"
        )
        data[k] = cur[0]

        # We replaced the paraphrase metrics last minute
        # run_000 should be the original, run_001 with transplanted paraphrases
        if alt_para:
            alt = _summarize_single_run(run_dir, "run_001", FIRST_N)
            if len(alt) != 1:
                # The original used a bare `raise` here, which itself fails
                # with an uninformative RuntimeError; keep the type, add context.
                raise RuntimeError(
                    f"expected exactly one summary for run_001 in {run_dir}, "
                    f"got {len(alt)}"
                )
            for key in (
                "pre_paraphrase_success",
                "post_paraphrase_success",
                "pre_score",
                "post_score",
            ):
                data[k][key] = alt[0][key]

    # Collect, per row, the (mean, std) cell for every column in COLS.
    rows = []
    for label, summary in data.items():
        is_baseline = label in BASELINE_NAMES
        cells = []
        for col in COLS:
            # Baselines report the pre-edit metric ("post_x" -> "pre_x"),
            # except for "time", which has no pre/post variants.
            if is_baseline and col != "time":
                key = "pre_" + col[len("post_") :]
            else:
                key = col
            cells.append(summary[key])
        rows.append((label, is_baseline, cells))

    # Best row per column among all rows but the first. With fewer than two
    # rows there is nothing to compare (and argmax/argmin would raise on the
    # empty array), so no cell is highlighted as best.
    candidates = rows[1:]
    if candidates:
        means = np.array([[cell[0] for cell in cells] for _, _, cells in candidates])
        best_max = np.argmax(means, axis=0)
        best_min = np.argmin(means, axis=0)
    else:
        best_max = best_min = None

    res = []
    for i, (label, is_baseline, cells) in enumerate(rows):
        lstr = [label]
        for j, el in enumerate(cells):
            mean, std = np.round(el[0], 1), el[1]
            if np.isnan(std):
                res_str = str(mean)
            else:
                # Half-width of the 95% confidence interval.
                interval = 1.96 * std / np.sqrt(FIRST_N)
                res_str = f"{str(mean)} ({np.round(interval, 1)})"

            best = best_max if OPTIM[j] == 1 else best_min
            # +1 because the best-row indices were computed over rows[1:].
            is_best = best is not None and best[j] + 1 == i
            misses_threshold = (OPTIM[j] == 1 and mean < LIM[j]) or (
                OPTIM[j] == 0 and mean > LIM[j]
            )

            if is_best:
                lstr.append("\\goodmetric{" + res_str + "}")
            elif not is_baseline and misses_threshold:
                lstr.append("\\badmetric{" + res_str + "}")
            else:
                lstr.append(res_str)

        res.append(" & ".join(lstr) + "\\\\" + ("\\midrule" if is_baseline else ""))

    return "\n".join(res)

In [None]:
# Separator printed between the tables of two different models.
gap = "\n\\midrule\\midrule\n"

# Each data* dict maps a row label to (results sub-directory, alt_para flag);
# the matching first* value is the number of cases handed to execute().

# ---- GPT-2 Medium ----
dir2medium = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/cf/gpt2-medium")
first2medium = 7500
data2medium = {
    "GPT-2 M": ("ROME", False),
    "FT+L": ("FT_L", False),
    "ROME": ("ROME", False),
}

# ---- GPT-2 Large ----
dir2large = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/cf/gpt2-large")
first2large = 7500
data2large = {
    "GPT-2 L": ("ROME", False),
    "FT+L": ("FT_L", False),
    "ROME": ("ROME", False),
}

# First table pair: GPT-2 Medium followed by GPT-2 Large.
print(
    gap.join(
        [
            execute(dir2medium, data2medium, first2medium),
            execute(dir2large, data2large, first2large),
        ]
    )
)

# ---- GPT-2 XL ----
dir2xl = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/cf/gpt2-xl")
first2xl = 7500
data2xl = {
    "GPT-2 XL": ("FT", True),
    "FT": ("FT", True),
    "FT+L": ("FT_L", True),
    "KN": ("KN", False),
    "KE": ("KE", False),
    "KE-CF": ("KE_CF", False),
    "MEND": ("MEND", False),
    "MEND-CF": ("MEND_CF", False),
    "ROME": ("ROME", False),
}

# ---- GPT-J ----
dirj = Path("/share/projects/rewriting-knowledge/OFFICIAL_DATA/cf/gptj")
firstj = 2000
dataj = {
    "GPT-J": ("FT", True),
    "FT": ("FT", True),
    "FT+L": ("FT_L", True),
    "MEND": ("MEND", False),
    "ROME": ("ROME", False),
}

# Second table pair: GPT-2 XL followed by GPT-J.
print(
    gap.join(
        [
            execute(dir2xl, data2xl, first2xl),
            execute(dirj, dataj, firstj),
        ]
    )
)