In [None]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): presumably this is what lets a later cell's `from src import ...`
# resolve when the notebook is run from its own directory — confirm.
sys.path.append("..")

In [None]:
import json
from pathlib import Path

# Root directory of all experiment results, resolved relative to the current
# working directory.
RESULTS_ROOT = Path("..") / "results"
# Stop immediately if the results directory is not reachable from here.
assert RESULTS_ROOT.exists()

First, load our own results:

In [None]:
EXPERIMENT_NAME = "icml_eval_fact_gen_gptj"

# (benchmark name, keys to copy out of `<benchmark>_metrics.json`), in read order.
# Commented-out entries are disabled and kept only for reference.
BENCHMARK_KEYS = (
    # ("efficacy", ("score", "magnitude")),
    ("paraphrase", ("score",)),  # disabled key: "magnitude"
    ("generation", ("consistency",)),  # disabled key: "fluency"
    ("essence", ("essence",)),
)

# Maps each method name to a flat dict of "<benchmark>_<key>" -> metric value.
rows_by_method = {}
for method in (r"\ourmethod", "prefix", "replace"):
    # Our method's metrics are stored under "linear/1"; every other method
    # uses a directory named after itself.
    subdir = "linear/1" if method == r"\ourmethod" else method
    results_dir = RESULTS_ROOT / EXPERIMENT_NAME / subdir

    row = {"method": method}
    for benchmark_name, keys in BENCHMARK_KEYS:
        results_file = results_dir / f"{benchmark_name}_metrics.json"
        print(f"reading {results_file}")
        with results_file.open("r") as handle:
            results = json.load(handle)

        # Namespace each copied metric by its benchmark to avoid key clashes.
        row.update({f"{benchmark_name}_{key}": results[key] for key in keys})

    # The neighborhood score is not read from disk; it is fixed at a mean of
    # 1.0 with zero spread for every method loaded here.
    row["neighborhood_score"] = {"mean": 1.0, "std": 0.0}

    rows_by_method[method] = row

Make table for representation-editing methods.

In [None]:
import math

# Number of evaluation samples behind every mean/std pair in `rows_by_method`.
# NOTE(review): hard-coded, as in the original formula — confirm it matches the
# size of the evaluation set.
NUM_SAMPLES = 5000

# Build one list of LaTeX cells per method: [label, metric, metric, ...].
formatted_rows = []
for method in ("prefix", "replace", r"\ourmethod"):
    row = rows_by_method[method]
    formatted_row = [method.capitalize()]
    for key in (
#         "efficacy_score",
#         "efficacy_magnitude",
        "paraphrase_score",
#         "paraphrase_magnitude",
#         "generation_fluency",
        "neighborhood_score",
        "generation_consistency",
        "essence_essence",
    ):
        metric = row[key]

        # Metrics are stored as fractions; report them as percentages.
        mean = metric["mean"] * 100
        std = metric["std"] * 100

        # Half-width of a 95% normal confidence interval on the mean:
        # 1.96 * standard error, where standard error = std / sqrt(n).
        # (Dividing by n instead of sqrt(n) understates the interval.)
        interval = (1.96 * std) / math.sqrt(NUM_SAMPLES)

        # `.lstrip("0")` binds to the last f-string only, turning e.g.
        # "0.05$" into ".05$" for a more compact table cell.
        formatted = f"${mean:.1f}" + r" \pm " + f"{interval:.2f}$".lstrip("0")
        formatted_row.append(formatted)
    formatted_rows.append(formatted_row)

table = ""
for formatted_row in formatted_rows:
    # Each row is a list of cells, so compare its label cell (index 0);
    # comparing the whole list to a string is never true.
    if formatted_row[0] == "ROME":
        table += r"\midrule" + "\n"
    table += " & ".join(formatted_row) + r" \\" + "\n"
print(table)

Make table for model-editing methods.

In [None]:
import json

# `tqdm` is used for the progress bar below but was never imported, which
# raises NameError on a fresh kernel.
from tqdm.auto import tqdm

from src import data, metrics

# Load references from our own eval.
essence_references_file = Path("../results/icml_eval_fact_gen_gptj/essence_references.txt")
with essence_references_file.open("r") as handle:
    references = handle.readlines()
# One single-element list per non-blank line, with surrounding spaces/newlines removed.
references = [[r.strip(" \n")] for r in references if r.strip(" \n")]
print(len(references))

# Load the counterfact vectorizer.
tfidf_vectorizer = data.load_counterfact_tfidf_vectorizer()

# Load their generations.
for results_dir in (
    "../../rome/results/FT-essence/run_000",
    "../../rome/results/ROME-essence/run_000",
):
    case_files = sorted(Path(results_dir).glob("case*.json"))

    cases = []
    for case_file in tqdm(case_files):
        with case_file.open("r") as handle:
            case = json.load(handle)
        cases.append(case)

    # Order by case id and keep the first 5000 cases.
    cases = sorted(cases, key=lambda case: case["case_id"])[:5000]

    # Wrap each post-edit generation in a single-element list, mirroring the
    # structure of `references`.
    generations = []
    for case in cases:
        generations.append([case["post"]["generation"]])

    score = metrics.average_tfidf_similarity(generations, references, tfidf_vectorizer)
    print(results_dir, score)
    # NOTE(review): this unconditional break stops after the first directory,
    # so only the FT-essence run is scored. Kept as in the original — confirm
    # whether the ROME-essence run should be scored too.
    break

In [None]:
import json
import math

# Build one list of LaTeX cells per model-editing baseline.
rows = []
for results_dir in (
    "../../rome/results/FT/run_000",
    "../../rome/results/ROME/run_000",
):
    summary_file = Path(results_dir) / "summary.json"
    with summary_file.open("r") as handle:
        summary = json.load(handle)

    # Row label is the second-to-last path component, e.g. "FT" or "ROME".
    row = [str(results_dir).split("/")[-2]]
    for metric in (
#         "post_rewrite_success",
#         "post_rewrite_diff",
        "post_paraphrase_success",
#         "post_paraphrase_diff",
#         "post_ngram_entropy",
        "post_neighborhood_success",
        "post_reference_score",
    ):
        # Each summary entry unpacks into a (mean, std) pair.
        mean, std = summary[metric]
        # Half-width of a 95% normal confidence interval on the mean:
        # 1.96 * standard error, where standard error = std / sqrt(n).
        # (Dividing by n instead of sqrt(n) understates the interval.)
        interval = 1.96 * std / math.sqrt(summary["num_cases"])
        row.append(
            f"{mean:.1f}"
            + r" \pm "
            + f"{interval:.2f}".lstrip("0")
        )
    # Trailing empty cell: leaves the last table column blank for this row.
    row.append("")
    rows.append(row)
# Wrap every non-empty cell except the label in `$...$`, join cells with " & ",
# and terminate each row with a LaTeX line break.
string = (r" \\" + "\n").join([
    " & ".join(f"${m}$" if i > 0 and m else m for i, m in enumerate(row))
    for row in rows
]) + r" \\"
print(string)