In [None]:
import sys
sys.path.append("../..")

from experiments.aliases import REMEDI_EDITOR_LAYER

In [None]:
from pathlib import Path

# Model whose REMEDI evaluation results this notebook inspects.
MODEL = "gptj"

# Root directory of experiment outputs, resolved against the current working directory.
RESULTS_ROOT = Path("../../results")
# is_dir() (not exists()) so a stray file named "results" is rejected up front.
assert RESULTS_ROOT.is_dir(), f"results directory not found: {RESULTS_ROOT.resolve()}"

In [None]:
import json


def load_json(file):
    """Read and parse a JSON file.

    Args:
        file: Path to the JSON file, as a ``pathlib.Path`` or a ``str``.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).
    """
    # JSON text is UTF-8; without an explicit encoding, open() falls back to the
    # locale's preferred encoding and can mis-decode non-ASCII content.
    with open(file, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Look up the REMEDI editor layer configured for this model on McRae, then
# read the entailment evaluation written for that layer's linear editor run.
remedi_layer = REMEDI_EDITOR_LAYER[MODEL]["mcrae"]
experiment_name = f"post_icml_eval_ent_mcrae_{MODEL}"
results_dir = RESULTS_ROOT / experiment_name

results = load_json(
    results_dir.joinpath("linear", str(remedi_layer), "entailment.json")
)

In [None]:
from remedi import data

# McRae examples for this analysis. NOTE(review): later cells index
# results["samples"] by position in this slice, so the two are assumed to be
# aligned — TODO confirm against the evaluation script.
dataset = data.load_dataset(
    "mcrae",
    split="train[5000:10000]",
)

In [None]:
from collections import Counter
# Tally how often each attribute appears in the loaded slice and show the top 50.
# A dedicated name is used so this does not share the global `features`, which
# a later cell binds to co-feature names.
attribute_counts = Counter(x["attribute"] for x in dataset)
attribute_counts.most_common(50)

In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Co-features dropped from the plot because they are disfluent or odd outliers.
BANNED_FEATURES = frozenset(["is used long ago"])

# Seaborn default theme, with a serif (Times New Roman) font for paper figures.
sns.set()
sns.set_style({"font.family": "serif", "font.serif": ["Times New Roman"]})

# Scatter, for one attribute, the human co-occurrence logp of each co-feature
# against the post-edit LM logp (averaged over all samples with that attribute).
attribute = "herded by shepherds"
indices = [i for i, x in enumerate(dataset) if x["attribute"] == attribute]
# Fail loudly instead of silently drawing an empty plot on a typo'd attribute.
assert indices, f"no dataset samples with attribute {attribute!r}"

logp_lm_by_co_feature = defaultdict(list)
logp_human_by_co_feature = {}
for index in indices:
    # NOTE(review): assumes results["samples"] is aligned index-for-index with `dataset`.
    for co_feature in results["samples"][index]["co_features"]:
        feature = co_feature["feature"]
        if feature in BANNED_FEATURES:
            continue
        # Overwritten per sample; presumably identical across samples that share
        # this attribute — TODO confirm.
        logp_human_by_co_feature[feature] = co_feature["logp_ref"]
        logp_lm_by_co_feature[feature].append(co_feature["logp_post"])
# Collapse the per-sample LM logps to their mean per co-feature.
logp_lm_by_co_feature = {key: np.mean(values) for key, values in logp_lm_by_co_feature.items()}

features = sorted(logp_human_by_co_feature)
xs = [logp_human_by_co_feature[feature] for feature in features]
ys = [logp_lm_by_co_feature[feature] for feature in features]

# Bind `fig` rather than `_` so IPython's last-output variable is not clobbered.
fig, ax = plt.subplots()
ax.scatter(xs, ys)
for feature, x_human, y_lm in zip(features, xs, ys):
    ax.annotate(feature, (x_human, y_lm))
# Trailing `;` suppresses the return-value repr in the cell output.
ax.set(title=attribute, xlabel="Human Co-Occurrence logp", ylabel="Post-Edit LM logp");