In [None]:
import sys
# Make modules two directories up importable (presumably the repository root,
# which provides the `experiments` and `remedi` packages imported below).
sys.path.append("../..")

In [None]:
from pathlib import Path

from experiments.aliases import REMEDI_EDITOR_LAYER, REMEDI_ENTITY_CLS_LAYER

MODEL = "gptj"
RESULTS_ROOT = Path("../../results")
assert RESULTS_ROOT.exists(), f"results directory not found: {RESULTS_ROOT.resolve()}"

# LaTeX macros placed in front of each example in the printed table rows to
# flag it as correct / wrong.
# NOTE(review): presumably defined in the paper's preamble — confirm.
CORRECT = r"\correctmarker"
WRONG = r"\wrongmarker"

In [None]:
from remedi import data
# NOTE(review): presumably disables the datasets cache so each run reloads — confirm.
data.disable_caching()

# Both datasets use the same `train[5000:10000]` slice; keep it in one place so
# they cannot drift apart.
EVAL_SPLIT = "train[5000:10000]"
counterfact = data.load_dataset("counterfact", split=EVAL_SPLIT)
biosbias = data.load_dataset("biosbias", split=EVAL_SPLIT)

In [None]:
import json

def load_json(file):
    """Read and parse a JSON file.

    Args:
        file: Path or path-like string pointing at the JSON file.

    Returns:
        The decoded JSON value (for the results files in this notebook, a dict
        with a "samples" list).
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding
    # (model generations routinely contain non-ASCII characters).
    with Path(file).open("r", encoding="utf-8") as handle:
        return json.load(handle)

# CounterFact

In [None]:
EXPERIMENT_NAME = f"post_icml_eval_gen_counterfact_{MODEL}"
LAYER = REMEDI_EDITOR_LAYER[MODEL]["counterfact"]

# REMEDI results for the `linear` editor live under linear/<layer>/; the
# in-context "prefix" baseline results sit alongside them under prefix/.
experiment_dir = RESULTS_ROOT / EXPERIMENT_NAME
results_dir = experiment_dir / "linear" / str(LAYER)
assert results_dir.exists(), f"missing REMEDI results: {results_dir}"
prefix_dir = experiment_dir / "prefix"

essence_results = load_json(results_dir / "essence.json")
generation_results = load_json(results_dir / "generation.json")

ic_essence_results = load_json(prefix_dir / "essence.json")
ic_generation_results = load_json(prefix_dir / "generation.json")

Generate an HTML page that shows the prefix baseline and REMEDI outputs side by side, so we can find diverse and representative examples:

In [None]:
def strip(x):
    """Return `x` truncated at its first blank line ("\\n\\n"), exclusive."""
    head, _separator, _tail = x.partition("\n\n")
    return head


from html import escape as escape_html

# Side-by-side comparison page: one section per CounterFact sample, one table
# row per output (essence, then each generation), prefix baseline vs. REMEDI.
html = [
    "<!doctype html>",
    "<html>",
    "<head>",
    '<meta charset="utf-8">',
    "<style>",
    # `vertical-align` is not inherited, so it is set on the cells rather than
    # on the table. Only the last *body* row gets the thick closing rule, so the
    # header row keeps its thin 1px underline.
    """\
table {
    text-align: left;
    border-collapse: collapse;
}

th, td {
    vertical-align: top;
}

th {
    font-weight: bold;
    border-top: 2px solid black;
    border-bottom: 1px solid black;
    padding-right: 5em;
}

tbody tr:last-of-type {
    border-bottom: 2px solid black;
}

h2 {
    margin-top: 2em;
}

h4 {
    font-weight: normal;
    text-decoration: underline;
}
""",
    "</style>",
    "</head>",
    "<body>",
]
for i in range(len(counterfact)):
    sample = counterfact[i]  # fetch the row once instead of per field access
    task = f'{sample["entity"]} + {sample["attribute"]}'
    essence_pref = strip(ic_essence_results["samples"][i]["generation"])
    essence_ours = strip(essence_results["samples"][i]["generation"])

    # The prefix baseline's generations may contain the context sentence;
    # remove it so only the continuation is shown.
    context_sentence = sample["context"] + "."
    # TODO(evandez): Match these. Generations are paired by position; zip stops
    # at the shorter list instead of raising IndexError when the counts differ.
    generations = [
        (
            f"generation {j}",
            strip(gen_pref).replace(context_sentence, "").strip(),
            strip(gen_ours),
        )
        for j, (gen_pref, gen_ours) in enumerate(
            zip(
                ic_generation_results["samples"][i]["generations"],
                generation_results["samples"][i]["generations"],
            )
        )
    ]

    html += [
        f"<h2>Sample {i}</h2>",
        f"<p><u>{escape_html(task)}</u></p>",
        "<table>",
        "<thead>",
        "<tr>",
        "<th>source</th>",
        "<th>prefix</th>",
        "<th>remedi</th>",
        "</tr>",
        "</thead>",
        "<tbody>",
    ]
    for label, pref, ours in (("essence", essence_pref, essence_ours), *generations):
        # Model outputs are arbitrary text; escape them so "<", ">" or "&" in a
        # generation cannot break the page markup.
        html += [
            "<tr>",
            f"<td>{escape_html(label)}</td>",
            f"<td>{escape_html(pref)}</td>",
            f"<td>{escape_html(ours)}</td>",
            "</tr>",
        ]
    html += [
        "</tbody>",
        "</table>",
    ]

html += [
    "</body>",
    "</html>",
]

html_file = Path(f"{MODEL}_counterfact_examples.html")
# Write UTF-8 explicitly (matching the <meta charset>) so non-ASCII generations
# do not depend on the platform's default encoding.
with html_file.open("w", encoding="utf-8") as handle:
    handle.write("\n".join(html))

In [None]:
def clean_result(result, limit=2):
    """Condense a model output to its first few sentences for a table cell.

    Paragraph breaks are flattened into spaces and surrounding whitespace is
    stripped. The text is then cut after `limit` ". "-separated sentences. A
    final period is added only if the text does not already end with ".", "!"
    or "?". Finally the first character is upper-cased.

    Args:
        result: Raw generated text.
        limit: Maximum number of ". "-separated sentences to keep.

    Returns:
        The cleaned text, or "" when `result` is empty or only whitespace.
    """
    # Flatten paragraph breaks so the sentence split below sees one paragraph,
    # and strip whitespace so capitalization lands on the first word.
    result = result.replace("\n\n", " ").strip()
    if not result:
        return result

    # Take just the first few sentences / the first thought.
    result = ". ".join(result.split(". ")[:limit]).rstrip()
    # Avoid "..": the kept text may already end with its own punctuation.
    if not result.endswith((".", "!", "?")):
        result += "."

    # Sometimes CounterFact does not capitalize the entity, do so for presentation.
    return result[0].upper() + result[1:]


def mark(text, correct):
    """Prefix `text` with the LaTeX marker for a correct or a wrong example."""
    return "{} {}".format(CORRECT if correct else WRONG, text)


# Hand-picked CounterFact examples for the qualitative table. Each entry is
# (dataset index, "essence" | "generation",
#  (prefix generation position, REMEDI generation position) or None,
#  (prefix baseline correct?, REMEDI correct?)).
examples = [
    # Ours succeeds.
    (25, "generation", (0, 1), (False, True)),
    (476, "essence", None, (False, True)),
    (430, "essence", None, (False, True)),
    (1287, "essence", None, (False, True)),
    (3423, "essence", None, (False, True)),
    (3884, "essence", None, (False, True)),

    # Both succeed.
#     (4032, "essence", None, (True, True)),
    (4175, "generation", (0, 0), (True, True)),
    (4189, "essence", None, (True, True)),
    # Ours fails.
    (3793, "generation", (2, 2), (True, False)),
]

# Attribute text for prompts that are hard to format automatically.
ATTRIBUTE_OVERRIDES = {
    1839: "greatest strength is basketball",
}

# NOTE(review): cells are not LaTeX-escaped; "&", "%", "_" or "$" in a
# generation will break the printed table row.
for index, source, positions, (prefix_correct, remedi_correct) in examples:
    sample = counterfact[index]
    if source == "essence":
        result = essence_results["samples"][index]["generation"]
        ic_result = ic_essence_results["samples"][index]["generation"]
    else:
        assert source == "generation", f"unknown source: {source!r}"
        assert positions is not None, "generation examples need (prefix, remedi) positions"
        ic_result = ic_generation_results["samples"][index]["generations"][positions[0]]
        # Drop the context sentence from the prefix output, plus the ". "
        # separator it leaves at the start.
        ic_result = ic_result.replace(sample["context"], "").lstrip(". ")
        result = generation_results["samples"][index]["generations"][positions[1]]

    result = clean_result(result)
    ic_result = clean_result(ic_result)

    entity = sample["entity"]
    attribute = sample["prompt"]
    if attribute.startswith(entity):
        # Remove only the leading entity mention; str.replace would also delete
        # any later occurrence of the entity inside the prompt.
        attribute = attribute[len(entity):].strip(", ")
    attribute = f"{attribute} {sample['target_mediated']}"
    attribute = ATTRIBUTE_OVERRIDES.get(index, attribute)

    row = (
        entity,
        attribute,
        mark(ic_result, prefix_correct),
        mark(result, remedi_correct),
    )
    print(" & ".join(row) + r" \\")

In [None]:
def clip(text, limit=20):
    """Keep the first `limit` whitespace-separated words of `text`, joined by single spaces."""
    words = text.split()
    kept = words[:limit]
    return " ".join(kept)

# Hand-picked REMEDI failure cases on CounterFact. Each entry is
# (dataset index, "essence" | "generation",
#  (prefix position, REMEDI position) or None, failure category).
examples = [
    (27, "generation", (2, 2), "Repeats indefinitely"),
    (45, "essence", None, "Destroys essence"),
    (106, "essence", None, "Incoherence"),
    (1013, "essence", None, "Changes unrelated facts"),
]

# NOTE(review): cells are not LaTeX-escaped; "&", "%", "_" or "$" in a
# context or generation will break the printed table row.
for index, source, positions, kind in examples:
    if source == "essence":
        result = essence_results["samples"][index]["generation"]
    else:
        assert source == "generation", f"unknown source: {source!r}"
        assert positions is not None, "generation examples need (prefix, remedi) positions"
        # Only REMEDI's output is shown here, so only its position is used.
        result = generation_results["samples"][index]["generations"][positions[1]]

    result = clean_result(result)
    context = counterfact[index]["context"]

    # At most one period after cleaning means the text is a single (presumably
    # run-on) sentence; clip it to ten words to keep the cell short.
    if result.count(".") < 2:
        result = clip(result, limit=10)

    row = (kind, context, result)
    print(" & ".join(row) + r" \\")

# Bios

In [None]:
EXPERIMENT_NAME = f"post_icml_eval_gen_biosbias_{MODEL}"
LAYER = REMEDI_EDITOR_LAYER[MODEL]["biosbias"]
results_dir = RESULTS_ROOT / EXPERIMENT_NAME
assert results_dir.exists(), f"missing results directory: {results_dir}"

# Contextual setting: prefix baseline vs. REMEDI (`linear` editor at LAYER).
prefix_results = load_json(results_dir / "contextual" / "baseline.json")
remedi_results = load_json(
    results_dir / "contextual" / "linear" / str(LAYER) / "error_correction.json"
)

In [None]:
# Quick check of which editor layer is used for BiosBias.
LAYER

In [None]:
# Inspect one prefix-baseline sample to see the structure of the results file.
prefix_results["samples"][0]

In [None]:

from html import escape as escape_html

# Side-by-side comparison page: one section per BiosBias sample showing its
# context and the prefix-baseline vs. REMEDI generations.
html = [
    "<!doctype html>",
    "<html>",
    "<head>",
    '<meta charset="utf-8">',
    "<style>",
    # `vertical-align` is not inherited, so it is set on the cells rather than
    # on the table. Only the last *body* row gets the thick closing rule, so the
    # header row keeps its thin 1px underline.
    """\
table {
    text-align: left;
    border-collapse: collapse;
}

th, td {
    vertical-align: top;
}

th {
    font-weight: bold;
    border-top: 2px solid black;
    border-bottom: 1px solid black;
    padding-right: 5em;
}

tbody tr:last-of-type {
    border-bottom: 2px solid black;
}

h2 {
    margin-top: 2em;
}

h4 {
    font-weight: normal;
    text-decoration: underline;
}
""",
    "</style>",
    "</head>",
    "<body>",
]
# Iterate over the BiosBias samples (not CounterFact) — the results below are
# indexed by BiosBias position.
for i in range(len(biosbias)):
    sample = biosbias[i]
    context = sample["context"]
    prompt = sample["prompt"]
    gen_pref = prompt + strip(prefix_results["samples"][i]["generation"])
    gen_ours = prompt + strip(remedi_results["samples"][i]["generation"])

    # Contexts and generations are arbitrary text; escape them so "<", ">" or
    # "&" cannot break the page markup.
    html += [
        f"<h2>Sample {i}</h2>",
        f"<p>{escape_html(context)}</p>",
        "<table>",
        "<thead>",
        "<tr>",
        "<th>source</th>",
        "<th>prefix</th>",
        "<th>remedi</th>",
        "</tr>",
        "</thead>",
        "<tbody>",
        "<tr>",
        "<td>generation</td>",
        f"<td>{escape_html(gen_pref)}</td>",
        f"<td>{escape_html(gen_ours)}</td>",
        "</tr>",
        "</tbody>",
        "</table>",
    ]

html += [
    "</body>",
    "</html>",
]

html_file = Path(f"{MODEL}_biosbias_examples.html")
# Write UTF-8 explicitly (matching the <meta charset>) so non-ASCII text does
# not depend on the platform's default encoding.
with html_file.open("w", encoding="utf-8") as handle:
    handle.write("\n".join(html))

Let's also show some examples where the prefix baseline fails (and REMEDI succeeds):

In [None]:
import spacy

# Small English spaCy pipeline; used by `clean_generation` below for sentence
# segmentation (`.sents`).
nlp = spacy.load("en_core_web_sm")


def clean_context(entity, context):
    """Drop a leading "About <entity>: " header from `context`, if present."""
    header = f"About {entity}: "
    if not context.startswith(header):
        return context
    return context[len(header):]


def clean_generation(prompt, gen):
    """Keep the first line of `gen`, cut it to its first two spaCy sentences, and prepend `prompt`."""
    first_line, _newline, _rest = gen.partition("\n")
    sentences = tuple(nlp(first_line).sents)
    leading = " ".join(str(sentence) for sentence in sentences[:2])
    return f"{prompt} {leading.strip()}"


def mark(text, correct):
    # Same as the CounterFact helper of this name; redefined here, presumably so
    # this section can run without the CounterFact cells.
    if correct:
        marker = CORRECT
    else:
        marker = WRONG
    return f"{marker} {text}"


# Selected BiosBias examples: (dataset index, (prefix correct?, REMEDI correct?)).
CHOSEN = (
    (2739, (False, True)),

#     2246,
    (2293, (False, True)),
    (2300, (False, True)),

    (1276, (False, True)),
    (1823, (False, True)),

    (140, (True, True)),
)

# NOTE(review): cells are not LaTeX-escaped; "&", "%", "_" or "$" in a bio or
# generation will break the printed table row.
for index, (prefix_correct, remedi_correct) in CHOSEN:
    sample = biosbias[index]
    entity, context, prompt = sample["entity"], sample["context"], sample["prompt"]
    prefix_sample = prefix_results["samples"][index]
    remedi_sample = remedi_results["samples"][index]

    # Columns: entity | context without its header | prefix output | REMEDI output.
    cells = (
        entity,
        clean_context(entity, context),
        mark(clean_generation(prompt, prefix_sample["generation"]), prefix_correct),
        mark(clean_generation(prompt, remedi_sample["generation"]), remedi_correct),
    )
    print(" & ".join(cells) + r" \\")

Show some REMEDI failure modes for the decontextual case, where they are most common.

In [None]:
EXPERIMENT_NAME = f"post_icml_eval_gen_biosbias_{MODEL}"
LAYER = REMEDI_EDITOR_LAYER[MODEL]["biosbias"]
results_dir = RESULTS_ROOT / EXPERIMENT_NAME
assert results_dir.exists(), f"missing results directory: {results_dir}"

# NOTE: this rebinds `remedi_results` to the *decontextual* results. The
# earlier BiosBias cells that read `remedi_results` will see these unless the
# contextual loading cell is re-run first.
remedi_results = load_json(
    results_dir / "decontextual" / "linear" / str(LAYER) / "error_correction.json"
)

In [None]:
# Hand-picked REMEDI failure cases: (dataset index, failure category).
CHOSEN = (
    (2021, "Repeats indefinitely"),
    (2027, "Incoherence"),
    (2028, "Incorrect edit"),
    (2218, "Partial edit"),
)

# NOTE(review): cells are not LaTeX-escaped; "&", "%", "_" or "$" in a context
# or generation will break the printed table row.
for index, kind in CHOSEN:
    sample = biosbias[index]
    gen_remedi = remedi_results["samples"][index]["generation"]
    row = [
        kind,
        sample["context"],
        clean_generation(sample["prompt"], gen_remedi),
    ]
    print(" & ".join(row) + r" \\")