In [None]:
import sys

# Make the directory two levels above the working directory (normally this
# notebook's folder) importable, presumably so the project packages imported
# below (`experiments`, `remedi`) resolve. Guarded so re-running the cell does
# not keep appending duplicate entries to sys.path.
if "../.." not in sys.path:
    sys.path.append("../..")

In [None]:
from pathlib import Path

from experiments.aliases import REMEDI_EDITOR_LAYER, REMEDI_ENTITY_CLS_LAYER

MODEL = "gptj"
# Relative to the working directory (normally this notebook's folder).
RESULTS_ROOT = Path("../../results")
assert RESULTS_ROOT.exists(), f"results directory not found: {RESULTS_ROOT.resolve()}"

In [None]:
from remedi import data
data.disable_caching()

# Both datasets are read from the same slice of the train split.
# NOTE(review): presumably this is the held-out evaluation range — confirm.
EVAL_SPLIT = "train[5000:10000]"
counterfact = data.load_dataset("counterfact", split=EVAL_SPLIT)
biosbias = data.load_dataset("biosbias", split=EVAL_SPLIT)

In [None]:
import json

def load_json(file):
    """Load and return the JSON document stored at ``file``.

    Args:
        file: Path of the JSON file (``str`` or ``os.PathLike``).

    Returns:
        The decoded JSON value.
    """
    # JSON text is UTF-8 by spec (RFC 8259); don't depend on the platform's
    # default text encoding, which differs on e.g. Windows.
    with Path(file).open("r", encoding="utf-8") as handle:
        return json.load(handle)

# CounterFact

In [None]:
# REMEDI results live under <experiment>/linear/<layer>/; the prefix baseline's
# results live alongside them under <experiment>/prefix/.
EXPERIMENT_NAME = f"post_icml_eval_gen_counterfact_{MODEL}"
LAYER = REMEDI_EDITOR_LAYER[MODEL]["counterfact"]
results_dir = RESULTS_ROOT / EXPERIMENT_NAME / "linear" / str(LAYER)
assert results_dir.exists(), f"missing results directory: {results_dir}"
prefix_dir = RESULTS_ROOT / EXPERIMENT_NAME / "prefix"

essence_results = load_json(results_dir / "essence.json")
generation_results = load_json(results_dir / "generation.json")

ic_essence_results = load_json(prefix_dir / "essence.json")
ic_generation_results = load_json(prefix_dir / "generation.json")

Render each sample's prefix-baseline and REMEDI outputs side by side in an HTML page, so we can pick diverse and representative examples:

In [None]:
def strip(x):
    """Return the part of ``x`` before its first blank line (all of ``x`` if there is none)."""
    head, _separator, _remainder = x.partition("\n\n")
    return head


from html import escape

# Build a standalone HTML page with one table per CounterFact sample, comparing
# the prefix baseline ("prefix") against REMEDI ("remedi") on the essence prompt
# and on every paired generation. All model/dataset text is HTML-escaped.
html = [
    "<!doctype html>",
    "<html>",
    "<head>",
    # Generations may contain non-ASCII text; matches the encoding used when writing below.
    '<meta charset="utf-8">',
    "<style>",
    """\
table {
    text-align: left;
    border-collapse: collapse;
}

th, td {
    vertical-align: top;
}

th {
    font-weight: bold;
    border-top: 2px solid black;
    border-bottom: 1px solid black;
    padding-right: 5em;
}

tbody tr:last-of-type {
    border-bottom: 2px solid black;
}

h2 {
    margin-top: 2em;
}

h4 {
    font-weight: normal;
    text-decoration: underline;
}
""",  # Trailing comma matters: without it this literal silently merges with "</style>".
    "</style>",
    "</head>",
    "<body>",
]
for i in range(len(counterfact)):
    sample = counterfact[i]
    task = f'{sample["entity"]} + {sample["attribute"]}'
    essence_pref = strip(ic_essence_results["samples"][i]["generation"])
    essence_ours = strip(essence_results["samples"][i]["generation"])

    # The prefix baseline's generations presumably include the context sentence;
    # remove it so only the continuation is shown. zip() pairs generations by
    # position and stops at the shorter list instead of raising IndexError.
    # TODO(evandez): Match these.
    context_prefix = sample["context"] + "."
    generations = [
        (
            f"generation {j}",
            strip(gen_pref).replace(context_prefix, "").strip(),
            strip(gen_ours),
        )
        for j, (gen_pref, gen_ours) in enumerate(
            zip(
                ic_generation_results["samples"][i]["generations"],
                generation_results["samples"][i]["generations"],
            )
        )
    ]

    html += [
        f"<h2>Sample {i}</h2>",
        f"<p><u>{escape(task)}</u></p>",
        "<table>",
        "<thead>",
        "<tr>",
        "<th>source</th>",
        "<th>prefix</th>",
        "<th>remedi</th>",
        "</tr>",
        "</thead>",
        "<tbody>",
    ]
    for label, pref, ours in [("essence", essence_pref, essence_ours), *generations]:
        html += [
            "<tr>",
            f"<td>{escape(label)}</td>",
            f"<td>{escape(pref)}</td>",
            f"<td>{escape(ours)}</td>",
            "</tr>",
        ]
    html += [
        "</tbody>",
        "</table>",
    ]

html += [
    "</body>",
    "</html>",
]

html_file = Path(f"{MODEL}_counterfact_examples.html")
# Write UTF-8 explicitly (matching the <meta charset>); the platform default may differ.
html_file.write_text("\n".join(html), encoding="utf-8")

In [None]:
def clean_result(result):
    """Trim a generated continuation to a short snippet for presentation.

    Blank-line paragraph breaks are flattened into spaces, and the text is cut to
    its first two ". "-separated pieces. When it contains "Inc. " it is cut to
    three pieces instead, because that abbreviation itself produces a split. The
    exception is text that mentions "Kremlin", which keeps the two-piece limit.
    Finally the first character is upper-cased.

    Args:
        result: Raw generated text.

    Returns:
        The trimmed text. Empty input is returned unchanged instead of raising.
    """
    if not result:
        return result

    # Flatten paragraphs so the sentence split below sees one continuous text.
    result = result.replace("\n\n", " ")

    # "Inc. " contains the ". " delimiter, so allow one extra piece to keep a whole
    # sentence. "Kremlin" is a hand-tuned exception (presumably for one of the
    # selected examples) that keeps the default limit.
    limit = 3 if "Inc. " in result and "Kremlin" not in result else 2
    result = ". ".join(result.split(". ")[:limit])

    # Sometimes CounterFact does not capitalize the entity, do so for presentation.
    return result[:1].upper() + result[1:]

# Hand-picked CounterFact examples. Each entry is
# (dataset index, result source, generation positions, corrects):
#   - result source: "essence" reads the essence results; "generation" reads the
#     generation results at positions = (prefix position, REMEDI position).
#   - corrects: presumably (prefix correct, REMEDI correct), judging by the group
#     comments. NOTE(review): some "Ours succeeds." entries are (False, False) — confirm.
examples = [
    # Ours succeeds.
    (25, "essence", None, (False, True)),
    (3316, "essence", None, (False, True)),
    (1114, "essence", None, (False, True)),
    (501, "essence", None, (False, False)),  # Honda
    (2518, "essence", None, (False, False)),
    (2521, "generation", (0, 2), (False, False)),
    (3109, "essence", None, (False, False)),
    # Both succeed.
    (4865, "essence", None, (True, True)),
    (2878, "generation", (2, 0), (True, True)),
    (4143, "essence", None, (True, True)),
    (1839, "generation", (0, 2), (True, True)),
    # Ours fails.
    (3192, "generation", (1, 2), (True, False)),
]

# Attributes that are hard to format automatically from the prompt.
ATTRIBUTE_OVERRIDES = {
    1839: "greatest strength is basketball",
}

rows = []
for index, source, positions, corrects in examples:
    sample = counterfact[index]

    if source == "essence":
        result = essence_results["samples"][index]["generation"]
        ic_result = ic_essence_results["samples"][index]["generation"]
    else:
        assert source == "generation", f"unknown source {source!r} for example {index}"
        assert positions is not None, f"example {index} needs generation positions"
        ic_position, our_position = positions
        ic_result = ic_generation_results["samples"][index]["generations"][ic_position]
        # Drop the context text from the prefix output, then the leftover ". " separator.
        ic_result = ic_result.replace(sample["context"], "").lstrip(". ")
        result = generation_results["samples"][index]["generations"][our_position]

    result = clean_result(result)
    ic_result = clean_result(ic_result)

    entity = sample["entity"]
    if index in ATTRIBUTE_OVERRIDES:
        attribute = ATTRIBUTE_OVERRIDES[index]
    else:
        attribute = sample["prompt"]
        if attribute.startswith(entity):
            # Remove only the leading entity; str.replace would also delete later mentions.
            attribute = attribute[len(entity):].strip(", ")
        attribute = f"{attribute} {sample['target_mediated']}"

    rows.append((index, corrects, entity, attribute, ic_result, result))
rows

# Bios

In [None]:
# REMEDI results live under <experiment>/linear/<layer>/; the prefix baseline's
# results live alongside them under <experiment>/prefix/.
EXPERIMENT_NAME = f"post_icml_eval_gen_biosbias_{MODEL}"
# Use the biosbias editor layer (this was copy-pasted with the "counterfact" key).
LAYER = REMEDI_EDITOR_LAYER[MODEL]["biosbias"]
results_dir = RESULTS_ROOT / EXPERIMENT_NAME / "linear" / str(LAYER)
assert results_dir.exists(), f"missing results directory: {results_dir}"
prefix_dir = RESULTS_ROOT / EXPERIMENT_NAME / "prefix"

essence_results = load_json(results_dir / "essence.json")
generation_results = load_json(results_dir / "generation.json")

ic_essence_results = load_json(prefix_dir / "essence.json")
ic_generation_results = load_json(prefix_dir / "generation.json")