In [None]:
# Put the parent directory on the import path so that the `from src import ...`
# statement in a later cell can be resolved when the notebook is run from its
# own directory (presumably ".." is the repository root — confirm).
import sys
sys.path.append("..")

In [None]:
from pathlib import Path

editor_type = "linear"

# Directory that receives the figures and the HTML report written by later cells.
# NOTE(review): this is an absolute, machine-specific path; consider making it
# configurable if the notebook has to run elsewhere.
results_dir = Path("/raid/lingo/dez/web/unitname/context-mediation") / editor_type
results_dir.mkdir(exist_ok=True, parents=True)

# Evaluation outputs are read from one sub-directory per layer, each named by its
# layer index. The path is derived from `editor_type` so the two cannot drift apart.
outputs_dir = Path("../results/editors") / editor_type

# `glob` yields entries in arbitrary filesystem order, so keep only the
# numerically named directories and sort them before slicing. Without the sort,
# `[:-1]` would drop an arbitrary layer and the layer order would vary by run.
# NOTE(review): after sorting, `[:-1]` drops the highest layer index; confirm that
# excluding the final layer is the intent of the original slice.
layers = sorted(
    int(path.name)
    for path in outputs_dir.glob("*")
    if path.is_dir() and path.name.isdigit()
)[:-1]
editor_type, layers

# Load Results

In [None]:
import json

from src import editors

from tqdm.auto import tqdm

def load_results(layer):
    """Return the "results" entry stored in `<outputs_dir>/<layer>/eval.json`.

    Args:
        layer: Layer identifier; it is converted with `str` to form the
            sub-directory name under `outputs_dir`.

    Returns:
        Whatever the JSON document holds under its top-level "results" key.
    """
    eval_path = outputs_dir / str(layer) / "eval.json"
    payload = json.loads(eval_path.read_text())
    return payload["results"]

# Read every selected layer's evaluation results up front, keyed by layer index.
results_by_layer = dict(
    (layer, load_results(layer))
    for layer in layers
)

In [None]:
def filter_results(rs_by_l):
    """Pre-filter per-layer results down to the cases worth analysing.

    A result is kept only when both conditions hold:

    * its sample's "target_unmediated" is among the first five
      "before_top_tokens" (compared after stripping "Ġ" and spaces from each
      token), i.e. the unedited model already ranked that target highly; and
    * "after_target_mediated_score" exceeds "before_target_mediated_score"
      by more than 1e-6.

    Args:
        rs_by_l: Mapping from layer to a list of result dicts.

    Returns:
        A new mapping with the same keys and only the retained results, in
        their original order.
    """

    def _keep(result):
        target = result["sample"]["target_unmediated"]
        top_tokens = {token.strip("Ġ ") for token in result["before_top_tokens"][:5]}
        if target not in top_tokens:
            return False
        gain = result["after_target_mediated_score"] - result["before_target_mediated_score"]
        return gain > 1e-6

    kept_by_layer = {}
    for layer, layer_results in rs_by_l.items():
        kept_by_layer[layer] = [result for result in layer_results if _keep(result)]
    return kept_by_layer

# Keep, for each layer, only the results that pass `filter_results`.
filtered_by_layer = filter_results(results_by_layer)

In [None]:
# Display the first retained result of layer 0 to inspect the record structure.
# This raises if layer 0 is missing or has no retained results.
filtered_by_layer[0][0]

# Average Change in Mediated-Target Score

In [None]:
# Per layer: mean of (after-edit minus before-edit) score of the mediated target.
# A layer whose results were all removed by the filter has nothing to average;
# report NaN for it rather than failing with ZeroDivisionError.
average_delta_by_layer = {
    layer: (
        sum(
            result["after_target_mediated_score"] - result["before_target_mediated_score"]
            for result in results
        ) / len(results)
        if results
        else float("nan")
    )
    for layer, results in filtered_by_layer.items()
}
average_delta_by_layer

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def heatmap_by_layer(values_by_layer, title):
    """Draw a one-row heatmap with one cell per layer, ordered by layer index.

    Args:
        values_by_layer: Mapping from layer index to the scalar value to plot.
        title: Title shown above the heatmap.

    Returns:
        None. The new figure is left as matplotlib's current figure, so a
        following `plt.savefig(...)` call saves this plot.
    """
    ordered_layers = sorted(values_by_layer)
    _, ax = plt.subplots(figsize=(15, 2))
    ax.set_title(title)
    data = np.array([[values_by_layer[layer] for layer in ordered_layers]])
    # Label each column with its real layer index. The default tick labels are
    # column positions, which are wrong whenever the plotted layers are not
    # exactly 0..n-1.
    sns.heatmap(data, ax=ax, xticklabels=ordered_layers)
    # There is only a single row, so y ticks carry no information.
    ax.set_yticks([])
    ax.set_xlabel("gpt2-xl layer")

# Plot the per-layer mean change in the mediated target's score and save the
# figure into `results_dir`.
effect_size_png = results_dir / "effect_size.png"
heatmap_by_layer(average_delta_by_layer, "Avg. Relative Change in P(mediated)")
plt.savefig(str(effect_size_png))

# Average Mediated - Unmediated Delta

In [None]:
# Per layer: average change, caused by the edit, in the margin of the mediated
# target's score over the unmediated target's score.
average_delta_diff_by_layer = {}
for layer, results in filtered_by_layer.items():
    diffs = []
    for result in results:
        # Margin (mediated minus unmediated) before and after the edit.
        before = result["before_target_mediated_score"] - result["before_target_unmediated_score"]
        after = result["after_target_mediated_score"] - result["after_target_unmediated_score"]
        # A sample contributes its margin gain only when the gain is at least
        # 1e-6 AND the pre-edit margin is itself at least 1e-6; otherwise it
        # counts as 0.
        # NOTE(review): the `before < 1e-6` clause zeroes every sample whose
        # mediated target did not already outscore the unmediated one before
        # the edit — confirm this is intended.
        if after - before < 1e-6 or before < 1e-6:
            diff = 0
        else:
            diff = after - before
        diffs.append(diff)
    # A layer with no retained results has nothing to average; report NaN
    # instead of failing with ZeroDivisionError.
    average_delta_diff_by_layer[layer] = sum(diffs) / len(diffs) if diffs else float("nan")
average_delta_diff_by_layer

In [None]:
# Plot the per-layer averages from `average_delta_diff_by_layer` as a one-row heatmap.
heatmap_by_layer(average_delta_diff_by_layer, "Avg. Relative Change in P(Mediated) - P(Unmediated)")

# Accuracy: P(Mediated) > P(Unmediated)

In [None]:
# Per layer: fraction of retained results in which, after the edit, the mediated
# target scores strictly higher than the unmediated target.
accuracy_by_layer = {}
for layer, results in filtered_by_layer.items():
    correct = sum(
        result["after_target_mediated_score"] > result["after_target_unmediated_score"]
        for result in results
    )
    # A layer with no retained results has no defined accuracy; report NaN
    # instead of failing with ZeroDivisionError.
    accuracy_by_layer[layer] = correct / len(results) if results else float("nan")

In [None]:
# Plot the per-layer accuracy and save the figure into `results_dir`.
accuracy_png = results_dir / "accuracy.png"
heatmap_by_layer(accuracy_by_layer, "Accuracy")
plt.savefig(str(accuracy_png))

# HTML

In [None]:
# Imported here under a bare name because the list built below is itself called `html`.
from html import escape

# Walk layers in numeric order and remember their real indices, so each
# "after edit" line is labelled with the layer it came from rather than with its
# position in the dict's insertion order.
sorted_layers = sorted(results_by_layer)
# Only sample indices available for every layer can be shown side by side.
num_samples = min(
    (len(results_by_layer[layer]) for layer in sorted_layers),
    default=0,
)

# results[i] lists sample i's result for each layer, in `sorted_layers` order.
results = [
    [results_by_layer[layer][index] for layer in sorted_layers]
    for index in range(num_samples)
]

html = [
    "<!doctype html>",
    "<html>",
    "<head>",
    # Declare the encoding that the file is written with below.
    '<meta charset="utf-8">',
    "</head>",
    "<body>",
    f'<img alt="effect_size" src="https://unitname.csail.mit.edu/context-mediation/{editor_type}/effect_size.png">',
    f'<img alt="accuracy" src="https://unitname.csail.mit.edu/context-mediation/{editor_type}/accuracy.png">',
]
for sample_num, rs in enumerate(results):
    # The inputs are read from the first layer's record for this sample.
    sample = rs[0]["sample"]
    html += [
        "<div>",
        f"<h2>Sample {sample_num}</h2>",
        "<div>Inputs:</div>",
        "<ul>",
    ]
    # Sample fields and generations are free text; escape them so characters
    # such as "<" or "&" cannot break the markup.
    html += [
        f"<li><b>{key}: </b>{escape(str(sample[key]))}</li>"
        for key in (
            "entity",
            "context",
            "attribute",
            "prompt",
            "target_mediated",
            "target_unmediated",
        )
    ]
    html += [
        "</ul>",
        "<div>Model outputs:</div>",
        "<ul>",
        "<li>",
        "<b>original: </b>",
        escape(str(rs[0]["before_generations"])),
        "</li>",
    ]
    for layer, result in zip(sorted_layers, rs):
        html.append(
            f"<li>after edit <b>layer {layer}</b>: {escape(str(result['after_generations']))}</li>"
        )
    html += [
        "</ul>",
        "<hr>",
        # Close the per-sample <div> opened at the top of this iteration.
        "</div>",
    ]

html += [
    "</body>",
    "</html>",
]

html_file = results_dir / "editing_results.html"
html_file.parent.mkdir(exist_ok=True, parents=True)
# Write UTF-8 explicitly so non-ASCII text does not depend on the locale's
# default encoding and matches the charset declared in the page.
with html_file.open("w", encoding="utf-8") as handle:
    handle.write("\n".join(html))

# Playground

In [None]:
# Playground: apply the editor to a single hand-written example and decode what
# the edited model generates.
# NOTE(review): `editor` and `mt` are not defined in any cell visible here, so this
# cell fails on a fresh kernel unless they are created elsewhere — confirm where
# they come from.
entity = "Stanford University"
attribute = "was founded in the city of Madrid"
context = " ".join((entity, attribute))
prompt = entity + ", located in the country of"

generate_inputs = dict(
    entity=entity,
    prompt=prompt,
    attribute=attribute,
    context=context,
)
with editors.apply(editor, alpha=.5) as edited_mt:
    outputs = edited_mt.model.generate(generate_inputs, max_new_tokens=20)
mt.tokenizer.batch_decode(outputs)