In [None]:
import sys

# Put the parent of the current working directory on the import path so the
# local `src` package (imported in the next cell) can be found.
sys.path.extend([".."])

In [None]:
from src import data

# Project helper, called once before loading; see `src.data` for its effect.
data.disable_caching()
# Presumably an HF-style slice (rows 5000-9999 of the train split). Every
# `dataset[i]` later in this notebook indexes positions within this slice.
dataset = data.load_dataset(
    "counterfact",
    split="train[5000:10000]",
)

In [None]:
import json
from pathlib import Path

# Resolved against the current working directory (the `sys.path` tweak above
# makes the same assumption) -- presumably this notebook's own folder.
RESULTS_ROOT = Path("..") / "results"
assert RESULTS_ROOT.exists(), (
    f"Results directory {RESULTS_ROOT.resolve()} does not exist; "
    "is the notebook running from the expected working directory?"
)

In [None]:
EXPERIMENT_NAME = "icml_eval_fact_gen_gptj"
# Results folder for the edited model in this experiment.
results_dir = RESULTS_ROOT / EXPERIMENT_NAME / "linear/1"

essence_results_file = results_dir / "essence.json"
# JSON is UTF-8 by spec; don't rely on the platform's default text encoding,
# which breaks non-ASCII generations on e.g. Windows locales.
with essence_results_file.open("r", encoding="utf-8") as handle:
    essence_results = json.load(handle)

In [None]:
# Indices of samples whose entity mentions "wiener" (case-insensitive).
[
    index
    for index, sample in enumerate(dataset)
    if "wiener" in sample["entity"].lower()
]

In [None]:
# Hand-picked indices into `dataset` with fun generations:
#   0, 25, 501, 2518, 2521, 2597, 3109, 3316, 4515
#   2497 -- but the model is wrong initially
#   1008 -- not the best, but cute
# Also looked at: 48, 78.
i = 1114
(dataset[i], essence_results["samples"][i])

In [None]:
generation_results_file = results_dir / "generation.json"
# JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
with generation_results_file.open("r", encoding="utf-8") as handle:
    generation_results = json.load(handle)

In [None]:
import random

# Spot-check a random sample: print "entity + attribute", then the edited
# model's essence generation, then each free-form generation. Every text is
# flattened onto a single line with all whitespace runs collapsed.
# Indices come from `dataset` itself, so this stays valid if the split changes.
i = random.randrange(len(dataset))
print(dataset[i]["entity"], "+", dataset[i]["attribute"])
print()
print(" ".join(essence_results["samples"][i]["generation"].split()))
print()
for generation in generation_results["samples"][i]["generations"]:
    print(" ".join(generation.split()))

In [None]:
# In-context ("prefix") results live next to the method folder `linear/1`.
in_context_essence_results_file = results_dir.parent.parent / "prefix" / "essence.json"
# JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
with in_context_essence_results_file.open("r", encoding="utf-8") as handle:
    in_context_essence_results = json.load(handle)

In [None]:
# In-context ("prefix") results live next to the method folder `linear/1`.
in_context_generation_results_file = results_dir.parent.parent / "prefix" / "generation.json"
# JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
with in_context_generation_results_file.open("r", encoding="utf-8") as handle:
    in_context_generation_results = json.load(handle)

In [None]:
import random

# Spot-check a random sample against the in-context ("prefix") results: print
# "entity + attribute", the essence generation, then each free-form
# generation, each flattened onto one line with whitespace runs collapsed.
# Indices come from `dataset` itself, so this stays valid if the split changes.
i = random.randrange(len(dataset))
print(dataset[i]["entity"], "+", dataset[i]["attribute"])
print()
print(" ".join(in_context_essence_results["samples"][i]["generation"].split()))
print()
for generation in in_context_generation_results["samples"][i]["generations"]:
    print(" ".join(generation.split()))

In [None]:
def clean_result(result: str) -> str:
    """Trim a generation down to its first thought for display.

    Keeps only the text before the first blank line ("\\n\\n"), then at most
    the first two ". "-separated sentences (three if the text contains
    "Inc. ", whose period would otherwise use up a sentence slot), and
    upper-cases the first character.

    Args:
        result: Raw generated text.

    Returns:
        The trimmed text. An empty string if nothing remains, e.g. when the
        input is empty or starts with a blank line.
    """
    # Take just the first paragraph / the first thought.
    result = result.split("\n\n", 1)[0]

    # "Inc. " ends in ". ", so it would be miscounted as a sentence break.
    limit = 3 if "Inc. " in result else 2
    result = ". ".join(result.split(". ")[:limit])

    # Sometimes CounterFact does not capitalize the entity; do so for
    # presentation. Slicing (not result[0]) keeps this safe on "".
    return result[:1].upper() + result[1:]

# (dataset index, results source, generation index). `position` is only read
# for source == "generation"; essence results hold a single "generation" per
# sample, so it is ignored there.
examples = [
    (25, "essence", None),
    (501, "essence", 0),
    # (2518, "essence", None),
    # (2521, "generation", 2),
    # (3109, "essence", None),
    (3316, "essence", None),
    # (4515, "essence", None),
    (1114, "essence", None),
]

# One tuple per example: (index, entity, attribute, in-context text, edited text).
rows = []
for index, source, position in examples:
    if source == "essence":
        result = essence_results["samples"][index]["generation"]
        in_context_result = in_context_essence_results["samples"][index]["generation"]
    else:
        assert source == "generation", f"unknown source: {source!r}"
        assert position is not None, "generation examples need a position"
        result = generation_results["samples"][index]["generations"][position]
        in_context_result = in_context_generation_results["samples"][index]["generations"][position]

    sample = dataset[index]
    entity = sample["entity"]

    # Describe the edit as "<prompt minus its leading entity> <target_mediated>".
    attribute = sample["prompt"]
    if attribute.startswith(entity):
        # Remove only the leading entity; str.replace would also delete any
        # later mention of the entity inside the prompt.
        attribute = attribute[len(entity):].strip(", ")
    attribute = f"{attribute} {sample['target_mediated']}"

    rows.append((
        index,
        entity,
        attribute,
        clean_result(in_context_result),
        clean_result(result),
    ))
rows

In [None]:
# Characters with special meaning in LaTeX, mapped to their escaped forms.
_LATEX_SPECIALS = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text):
    """Escape LaTeX special characters so free text can sit in a table cell."""
    # Character-by-character, so inserted backslashes are never re-escaped.
    return "".join(_LATEX_SPECIALS.get(char, char) for char in text)


def _emphasize_prefix(text, prefix):
    """Return `text` LaTeX-escaped, with a leading `prefix` underlined and bolded."""
    if text.startswith(prefix):
        head = r"\underline{\textbf{" + _latex_escape(prefix) + "}}"
        return head + _latex_escape(text[len(prefix):])
    return _latex_escape(text)


# Build one LaTeX table row per example: entity & attribute & in-context & edited.
formatted_rows = []
for _, entity, attribute, in_context, edited in rows:
    # Capitalize for presentation; slicing keeps this safe on "".
    entity = entity[:1].upper() + entity[1:]
    prefix = entity + " is"

    cells = [_latex_escape(entity), _latex_escape(attribute)]
    for text in (in_context, edited):
        # Make every generated cell read as a complete sentence.
        if not text.endswith("."):
            text += "."
        cells.append(_emphasize_prefix(text, prefix))
    formatted_rows.append(" & ".join(cells))

table_str = (r" \\" + "\n").join(formatted_rows) + r" \\"
print(table_str)