In [None]:
import sys

# Put the parent directory on the import path so that `from src import ...`
# in the cells below resolves when the notebook runs from its own folder.
# Guarded so that re-running this cell does not pile up duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from src import data

# Switch off caching in the project's data module before loading
# (presumably so the split is rebuilt rather than read from a stale cache;
# verify against src/data).
data.disable_caching()
# Load only the "train[5000:10000]" split of "counterfact". The indices used in
# the rest of this notebook (0..4999) index into this loaded object; assuming the
# split string slices the training set, they are relative to that slice - TODO confirm.
dataset = data.load_dataset("counterfact", split="train[5000:10000]")

In [None]:
import json
from pathlib import Path

# Directory that holds all experiment outputs, one level above this notebook.
# Fail right away if it is missing rather than at the first file open below.
RESULTS_ROOT = Path("..") / "results"
assert RESULTS_ROOT.exists()

In [None]:
# Experiment whose outputs are inspected, and the "linear/1" subdirectory of it.
EXPERIMENT_NAME = "icml_eval_fact_gen_gptj"
results_dir = RESULTS_ROOT.joinpath(EXPERIMENT_NAME, "linear/1")

# Parse the essence results stored for that run.
essence_results_file = results_dir.joinpath("essence.json")
with open(essence_results_file, mode="r") as handle:
    essence_results = json.loads(handle.read())

In [None]:
# Positions of every sample whose entity mentions "brazil" (case-insensitive).
[
    position
    for position, sample in enumerate(dataset)
    if "brazil" in sample["entity"].lower()
]

In [None]:
# Indices into `dataset` that turned out to be fun examples while browsing by hand:
# - 25
# - 0
# - 501
# - 2497  # but the model is wrong initially
# - 2518
# - 2521
# - 2597
# - 3109
# - 3316
# - 1008  # not the best, but cute
# - 4515
# Index currently being inspected; the commented-out assignments are other candidates.
i = 1114
# i = 48
# i = 78
# Show the dataset sample next to the essence result stored at the same index.
dataset[i], essence_results["samples"][i]

In [None]:
# Parse the generation results that sit next to essence.json in the same run directory.
generation_results_file = results_dir.joinpath("generation.json")
with open(generation_results_file, mode="r") as handle:
    generation_results = json.loads(handle.read())

In [None]:
# Essence results from the "prefix" directory, two levels above results_dir
# (these are the ones printed under the "in context" heading further down).
in_context_essence_results_file = results_dir.parents[1].joinpath("prefix", "essence.json")
with open(in_context_essence_results_file, mode="r") as handle:
    in_context_essence_results = json.loads(handle.read())

In [None]:
# Generation results from the same "prefix" directory as the in-context essence file.
in_context_generation_results_file = results_dir.parents[1].joinpath("prefix", "generation.json")
with open(in_context_generation_results_file, mode="r") as handle:
    in_context_generation_results = json.loads(handle.read())

In [None]:
import random


def _one_line(text):
    """Turn newlines into spaces and squeeze double spaces once, for compact printing."""
    return text.replace("\n", " ").replace("  ", " ")


# Pick a random sample (uncomment the fixed index to revisit a specific one).
i = random.choice(range(5000))
# i = 4426
picked = dataset[i]
print(i, picked["entity"], "+", picked["attribute"])
print()

# For each method: heading, the essence generation, a blank line, then every
# free-form generation on its own line.
for heading, essence_source, generation_source in (
    ("--- in context ---", in_context_essence_results, in_context_generation_results),
    ("--- REMEDI ---", essence_results, generation_results),
):
    print(heading)
    print(_one_line(essence_source["samples"][i]["generation"]))
    print()
    for text in generation_source["samples"][i]["generations"]:
        print(_one_line(text))

In [None]:
def clean_result(result):
    """Shorten a generated text to its first thought for presentation.

    Double newlines are flattened to a single space, then the text is cut after
    the first two ". "-separated pieces. Texts containing "Inc. " keep three
    pieces instead, because the period of the abbreviation would otherwise be
    counted as a sentence end - unless the text also mentions "Kremlin", in
    which case the limit stays at two. Finally the first character is
    upper-cased, since CounterFact does not always capitalize the entity.

    Args:
        result: The generated text.

    Returns:
        The shortened, capitalized text. An empty string is returned as-is
        instead of raising IndexError.
    """
    # Take just the first few sentences / the first thought.
    result = result.replace("\n\n", " ")

    keep_extra_piece = "Inc. " in result and "Kremlin" not in result
    limit = 3 if keep_extra_piece else 2
    result = ". ".join(result.split(". ")[:limit])

    # Slice instead of indexing so that an empty generation does not crash.
    return result[:1].upper() + result[1:]

# For teaser:
# 1033
# 95
# 1323

# Just really good:
# 4845

# Both succeed:
# 887
# 3729
# 695
# 4989
# 1064
# 3150
# 4865 **
# 1059 **
# 4100 **
# 2521

# Ours failed:
# 3177
# 2741
# 3192 ** 

# Hand-picked samples for the qualitative table. Each entry is
#     (index, source, positions, corrects)
# - index:     position of the sample in `dataset` and in the results files.
# - source:    "essence" to use the single essence generation, or "generation"
#              to pick one of the free-form generations.
# - positions: for "generation" only, (in-context generation index,
#              edited generation index); None for "essence".
# - corrects:  (in-context is correct, edited is correct); selects the
#              correct/wrong marker when the table is formatted.
examples = [
    (25, "essence", None, (False, True)),
#     (501, "essence", 0),  # Honda
#     (2518, "essence", None),
#     (2521, "generation", 2),
#     (3109, "essence", None),
    (3316, "essence", None, (False, True)),
#     (4515, "essence", None),
    (1114, "essence", None, (False, True)),
    # Both succeed.
#     (4865, "essence", None, (True, True)),
#     (2878, "generation", (2, 0), (True, True)),
#     (4143, "essence", None, (True, True)),
    (1839, "generation", (0, 2), (True, True)),
    # Ours fails.
    (3192, "generation", (1, 2), (True, False)),
]

rows = []
for index, source, positions, corrects in examples:
    sample = dataset[index]

    if source == "essence":
        result = essence_results["samples"][index]["generation"]
        in_context_result = in_context_essence_results["samples"][index]["generation"]
    else:
        assert source == "generation"
        assert positions is not None
        in_context_position, edited_position = positions
        in_context_result = in_context_generation_results["samples"][index]["generations"][in_context_position]
        # Drop the context sentence from the in-context generation, plus the
        # ". " that is left dangling at the start afterwards.
        in_context_result = in_context_result.replace(sample["context"], "").lstrip(". ")
        result = generation_results["samples"][index]["generations"][edited_position]

    result = clean_result(result)
    in_context_result = clean_result(in_context_result)

    entity = sample["entity"]
    attribute = sample["prompt"]
    if attribute.startswith(entity):
        # Remove only the leading entity that the check above found; str.replace
        # would also delete any later mention of the entity inside the prompt.
        attribute = attribute[len(entity):].strip(", ")
    attribute = f"{attribute} {sample['target_mediated']}"

    if index == 1839:
        # hard to format this one automatically
        attribute = "greatest strength is basketball"

    rows.append(
        (
            index,
            corrects,
            entity,
            attribute,
            in_context_result,
            result,
        )
    )
rows

In [None]:
# Render `rows` as the body of a LaTeX table with four columns:
#     entity & attribute & in-context generation & edited generation
# Each generation is preceded by \correctmarker or \wrongmarker, and a leading
# "<Entity> is" is underlined and set in bold.
formatted_rows = []
for (_, corrects, entity, attribute, in_context, edited) in rows:
    # Slice rather than index so an empty entity cannot raise IndexError.
    entity = entity[:1].upper() + entity[1:]

    prefix = entity + " is"
    highlighted_prefix = r"\underline{\textbf{" + prefix + "}}"
    if in_context.startswith(prefix):
        in_context = highlighted_prefix + in_context[len(prefix):]
#         suffix = in_context[len(prefix):]
#         in_context = in_context.replace(suffix, r"\textcolor{red}{" + f"{suffix}" + "}")
    if edited.startswith(prefix):
        edited = highlighted_prefix + edited[len(prefix):]
#         suffix = edited[len(prefix):]
#         edited = edited.replace(suffix, r"\textcolor{blue}{" + f"{suffix}" + "}")

    if not in_context.endswith("."):
        in_context += "."
    if not edited.endswith("."):
        edited += "."

    # Both markers are raw strings ending in a space. Without the space LaTeX
    # would read the marker and the following letters as a single, undefined
    # control word (e.g. \wrongmarkerThe); without the raw prefix, "\w" is an
    # invalid escape sequence.
    in_context_marker = r"\correctmarker " if corrects[0] else r"\wrongmarker "
    edited_marker = r"\correctmarker " if corrects[1] else r"\wrongmarker "

    in_context = in_context_marker + in_context
    edited = edited_marker + edited

    formatted_row = " & ".join([
        entity,
        attribute,
        in_context,
        edited,
    ])
    formatted_rows.append(formatted_row)

# Every row, including the last one, is terminated with a LaTeX line break.
table_str = (r" \\" + "\n").join(formatted_rows) + r" \\"
print(table_str)

In [None]:
from src import models

In [None]:
# device = "cuda:1"
# mt = models.load_model("EleutherAI/gpt-j-6B", fp16=True, device=device)

In [None]:
# inputs = mt.tokenizer("Gianni Versace S.p.A.'s headquarters is", return_tensors="pt").to(device)
# outputs = mt.model.generate(**inputs, pad_token_id=mt.tokenizer.eos_token_id, max_length=50)
# mt.tokenizer.batch_decode(outputs)

In [None]:
# from src import editors
# editor = editors.load_editor(mt, "linear", 1,
#                              editors_dir="../results/icml_editors_counterfact_gptj_linear",
#                              device=device)

In [None]:
# with editors.apply(editor, device=device, alpha=0) as edited_mt:
#     outputs = edited_mt.model.generate({
#         "entity": "the London Bridge",
#         "prompt": "To cross the London Bridge, one should travel to",
#         "context": "The London Bridge is located in the deserts of Arizona",
#         "attribute": "is located in the deserts of Arizona",
#     }, max_new_tokens=25)
# mt.tokenizer.batch_decode(outputs)