In [None]:
import sys
sys.path.append("../..")

from experiments.aliases import REMEDI_EDITOR_LAYER

In [None]:
from pathlib import Path

# Key used to look up the editor layer and to build the experiment directory name.
MODEL = "gptj"

# Results live two levels above the notebook's working directory. Fail fast, and
# show the resolved path so a wrong working directory is easy to diagnose.
RESULTS_ROOT = Path("../../results")
assert RESULTS_ROOT.exists(), f"results directory not found: {RESULTS_ROOT.resolve()}"

In [None]:
import json


def load_json(file):
    """Read and parse a JSON document from disk.

    Args:
        file: Location of the JSON file, given either as a `pathlib.Path`
            (or any other path-like object) or as a plain string path.

    Returns:
        The decoded JSON value (whatever `json.load` produces for the file).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Pick the layer registered for this model under the "mcrae" entry, then load
# the entailment results stored at
#   <RESULTS_ROOT>/<experiment_name>/linear/<layer>/entailment.json
remedi_layer = REMEDI_EDITOR_LAYER[MODEL]["mcrae"]
experiment_name = f"post_icml_eval_ent_mcrae_{MODEL}"
results_dir = RESULTS_ROOT / experiment_name

entailment_file = results_dir.joinpath("linear", str(remedi_layer), "entailment.json")
results = load_json(entailment_file)

In [None]:
from remedi import data

# Load the "train[5000:10000]" slice of the McRae dataset. Later cells use
# positions in this dataset to index results["samples"], so the two are
# presumably aligned one-to-one -- TODO confirm.
mcrae_split = "train[5000:10000]"
dataset = data.load_dataset("mcrae", split=mcrae_split)

In [None]:
from collections import defaultdict

import numpy as np

# Relative change in log-probability caused by the edit, grouped by sample id:
#     (logp_post - logp_pre) / |logp_pre|
# An alternative metric that was tried is the absolute probability change,
#     np.exp(logp_post) - np.exp(logp_pre).
logp_diffs_co = defaultdict(list)
logp_diffs_orig = defaultdict(list)
logp_diffs_unrel = defaultdict(list)

diffs_by_features_key = {
    "co_features": logp_diffs_co,
    "orig_features": logp_diffs_orig,
    "unrel_features": logp_diffs_unrel,
}

for result in results["samples"]:
    sid = result["id"]
    for features_key, diffs in diffs_by_features_key.items():
        relative_changes = [
            (feature["logp_post"] - feature["logp_pre"]) / abs(feature["logp_pre"])
            for feature in result[features_key]
        ]
        # Only touch the defaultdict when there is something to add, so samples
        # without features never get an (empty) entry that would skew the mean.
        if relative_changes:
            diffs[sid].extend(relative_changes)

# Report the mean over samples of each sample's mean relative change.
for key, diffs in zip(("co", "orig", "unrel"), diffs_by_features_key.values()):
    values = [np.mean(per_sample) for per_sample in diffs.values()]
    print(key, np.mean(values))

In [None]:
from collections import Counter
# Count how often each attribute occurs in the dataset. The bare expression on
# the last line is what the cell displays: the 100 most common attributes.
features = list(map(lambda example: example["attribute"], dataset))
attribute_counts = Counter(features)
attribute_counts.most_common(100)

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Features dropped from the plot because they are disfluent or odd outliers.
BANNED_FEATURES = frozenset({"is used long ago"})

sns.set()
sns.set_style({"font.family": "serif", "font.serif": ["Times New Roman"]})

attribute = "comes in bunches"

# Positions of every dataset example carrying this attribute. The same
# positions index results["samples"] below, which assumes both are in the same
# order -- TODO confirm.
indices = []
for position, example in enumerate(dataset):
    if example["attribute"] == attribute:
        indices.append(position)

# Per co-occurring feature: the reference logp (last value seen wins) and all
# post-edit LM logps, which are averaged afterwards.
logp_post_samples = defaultdict(list)
logp_human_by_co_feature = {}
for index in indices:
    for co_feature in results["samples"][index]["co_features"]:
        name = co_feature["feature"]
        if name not in BANNED_FEATURES:
            logp_human_by_co_feature[name] = co_feature["logp_ref"]
            logp_post_samples[name].append(co_feature["logp_post"])
logp_lm_by_co_feature = {
    name: np.mean(samples) for name, samples in logp_post_samples.items()
}

features = sorted(logp_human_by_co_feature)
xs = [logp_human_by_co_feature[name] for name in features]
ys = [logp_lm_by_co_feature[name] for name in features]

# One annotated point per feature: reference logp (x) vs. mean post-edit logp (y).
ax = plt.subplots()[1]
for name, x_value, y_value in zip(features, xs, ys):
    ax.annotate(name, (x_value, y_value))
ax.scatter(xs, ys)
ax.set_title(attribute)
ax.set_xlabel("Human Co-Occurrence logp")
ax.set_ylabel("Post-Edit LM logp")

In [None]:
from collections import OrderedDict

# Which sample to visualize and how many features to keep per panel.
index, n_samples = 5, 15

sample = dataset[index]
result = results["samples"][index]

entity = sample["entity"].capitalize()
title = f'{entity} + {sample["attribute"]}'


def _panel_from_features(features):
    """Split feature records into parallel before/after/label lists."""
    return {
        "before": [feature["logp_pre"] for feature in features],
        "after": [feature["logp_post"] for feature in features],
        "labels": [feature["feature"] for feature in features],
    }


# One panel per feature group, in display order. The plotting cell reads this
# mapping through the name `data`.
# NOTE(review): this rebinds `data`, which an earlier cell imported as the
# `remedi.data` module; re-run that import before calling data.load_dataset again.
data = OrderedDict(
    (label, _panel_from_features(result[f"{key}_features"][:n_samples]))
    for key, label in (
        ("co", "Entailed"),
        ("orig", "Original"),
        ("unrel", "Unrelated"),
    )
)

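An alternative visualization: instead of a single sample, pool every sample that shares one attribute and compare the pre- and post-edit log-probabilities of its correlated, original, and unrelated features.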

In [None]:
from collections import OrderedDict

# Other attributes that were tried here:
#   "is found on couches", "lives by the ocean", "has flashing lights", "spins webs"
attribute = "is used for chopping wood"

# Unrelated features whose pre-edit logp is below this cutoff are dropped.
unrel_cutoff_logp = -30
# Maximum number of features shown per panel.
limit = 25

# Spaces are escaped so the attribute renders as italic math text in the title.
latex_attribute = attribute.replace(" ", r"\ ")
title = r"Effect of REMEDI($\it{" + latex_attribute + "}$)"

indices = []
for position, example in enumerate(dataset):
    if example["attribute"] == attribute:
        indices.append(position)

# Each row is [logp_pre, logp_post, logp_ref, feature label]. Unrelated
# features get -inf in the logp_ref slot; their own logp_ref is never read.
corr, orig, unrel = [], [], []
for index in indices:
    for rows, features_key in ((corr, "co_features"), (orig, "orig_features")):
        for feature in results["samples"][index][features_key]:
            rows.append([
                feature["logp_pre"],
                feature["logp_post"],
                feature["logp_ref"],
                feature["feature"],
            ])

    for feature in results["samples"][index]["unrel_features"]:
        row = [
            feature["logp_pre"],
            feature["logp_post"],
            float("-inf"),
            feature["feature"],
        ]
        if not row[0] < unrel_cutoff_logp:
            unrel.append(row)

def apply_limit(values, max_items=None):
    """Keep the rows with the highest pre-edit logp and split them for plotting.

    Args:
        values: Sequence of rows shaped like [logp_pre, logp_post, ..., label];
            only the first two entries and the last entry of each row are used.
        max_items: Maximum number of rows to keep. When omitted, the
            module-level `limit` is used.

    Returns:
        A tuple `(points, labels)`. `points` is an array of shape (n, 2) holding
        [logp_pre, logp_post] for each kept row, ordered by descending
        logp_pre. `labels` is the list of the kept rows' last entries, in the
        same order. For empty input, `points` has shape (0, 2), so column
        indexing such as `points[:, 0]` still works.
    """
    if max_items is None:
        max_items = limit

    # Rows are ranked by their first entry (logp_pre), largest first. Other
    # orderings that were tried: abs(x[1] - x[0]), (x[2], x[0]), and x[1].
    chosen = sorted(values, key=lambda row: row[0], reverse=True)[:max_items]

    # reshape(-1, 2) is a no-op for non-empty input and turns the shape-(0,)
    # array produced by an empty list into shape (0, 2).
    points = np.array([row[:2] for row in chosen]).reshape(-1, 2)
    labels = [row[-1] for row in chosen]
    return points, labels

# NOTE(review): each list is cut to its first `limit` rows *before*
# apply_limit sorts it, so the result is those first rows re-ordered, not the
# top `limit` rows overall -- confirm this is intended.
top_rows = {
    name: apply_limit(rows[:limit])
    for name, rows in (("corr", corr), ("unrel", unrel), ("orig", orig))
}
corr, corr_labels = top_rows["corr"]
unrel, unrel_labels = top_rows["unrel"]
orig, orig_labels = top_rows["orig"]

# Post-minus-pre logp for the correlated features, printed as a quick check.
corr_delta = corr[:, 1] - corr[:, 0]
print(corr_delta)


def _panel_from_points(points, labels):
    """Bundle an (n, 2) [before, after] array and its labels into one panel."""
    return {
        "before": points[:, 0],
        "after": points[:, 1],
        "labels": labels,
    }


# Variants that were tried instead of the labelled before/after panels below:
#   * no labels: replace each *_labels list with [""] * len(<array>);
#   * diffs only: use <array>[:, 1] - <array>[:, 0] for both "before" and
#     "after", with empty labels, under the keys "Correlated", "Original",
#     and "Unrelated".
data = OrderedDict([
    ("Correlated Features ($f^{(c)}$)", _panel_from_points(corr, corr_labels)),
    ("Original Features of Edited Concept ($f^{(o)}$)", _panel_from_points(orig, orig_labels)),
    ("Random Features", _panel_from_points(unrel, unrel_labels)),
])

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np

# Serif fonts for regular text and for math text in every figure drawn below.
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
})

# Lowest y tick drawn on each panel.
ymin = -40
# y position at which the rotated feature labels are anchored.
ytext = -21

def plot_subplot(ax, before, after, labels):
    """Draw one before/after panel: a vertical segment per feature.

    Each feature i is drawn at x = i - 0.5 as a segment from its `before`
    value to its `after` value, dark blue when the value increased and red
    otherwise. Changes of at least 1.5 also get an arrow head pointing at the
    `after` value. Non-empty labels shorter than 25 characters are attached to
    every second feature among the first 18, connected by a dotted leader line.

    Reads the module-level `ymin` and `ytext`.

    Args:
        ax: Matplotlib-style axes object to draw on.
        before: Sequence of values before the edit.
        after: Sequence of values after the edit, parallel to `before`.
        labels: Sequence of label strings, parallel to `before`.
    """
    for i, (b, a, label) in enumerate(zip(before, after, labels)):
        x = i - .5
        color = 'darkblue' if b < a else "red"
        ax.plot([x, x], [b, a], marker='o', color=color)
        if abs(a - b) >= 1.5:
            ax.annotate(
                "",
                xy=(x, a),
                xytext=(x, b),
                arrowprops=dict(arrowstyle="->", lw=1.5, color=color),
            )
        if (
            label
            and len(label) < 25
            and i < 18
            and i % 2 == 0
        ):
            # Dotted leader from just below the segment down to the label.
            ax.plot([x, x], [min(a, b) - .5, ytext], linestyle=':', color='darkgrey')
            ax.text(
                x,
                ytext,
                label,
                fontsize=14,
                rotation=-45,
                bbox=dict(facecolor='white', edgecolor='darkgrey', boxstyle='round,pad=0.2'),
                va="top",
                ha="left",
            )

    # Axis setup is done once, after all segments are drawn (it used to be
    # repeated for every feature and skipped entirely for empty input).
    # No x ticks; y ticks every 5 units from 0 down to ymin.
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks(np.arange(0, ymin - 1, -5))

    # Last label is a bit ugly, remove it. A separate name keeps the `labels`
    # argument intact.
    tick_labels = ax.get_yticklabels()
    if tick_labels:
        tick_labels[-1] = ""
        ax.set_yticklabels(tick_labels)

# One row of three panels sharing a y axis; `title` and `data` come from
# whichever data-preparation cell was run last.
fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(18, 5))
plt.suptitle(title, fontsize=18)

# Draw one panel per entry of `data`, in insertion order. zip() stops at the
# shorter of the two, so entries beyond the third would be silently skipped.
for ax, (key, values) in zip(axes, data.items()):
    plot_subplot(ax, values['before'], values['after'], values['labels'])
    ax.set_title(key, fontsize=16)
    ax.set_xlabel('')

# Raw string: "\l" and "\m" are not valid Python escapes, and a non-raw
# literal triggers a SyntaxWarning on Python 3.12+. The string value is unchanged.
axes[0].set_ylabel(r'$\log p_{LM}(f \mid c)$', fontsize=16)
# Per-panel y labels that were tried instead:
# axes[1].set_ylabel(r'$\log p_M(f^{(o)})$', fontsize=16)
# axes[2].set_ylabel(r'$\log p_M(f)$', fontsize=16)

plt.tight_layout()
# Save before show() so the file is written from the fully laid-out figure.
plt.savefig('mcrae.pdf', dpi=300)
plt.show()