In [None]:
# Extend the module search path with the parent directory and a sibling
# `context-mediation` checkout (both relative to the notebook's working dir).
# NOTE(review): presumably this is what makes `src` and `estimate` importable
# in the cells below — confirm against the repository layout.
import sys
sys.path.append("..")
sys.path.append("../../context-mediation")

In [None]:
# Shell out to `nvidia-smi` to show GPU status before the model is loaded.
!nvidia-smi

In [None]:
import transformers

# Device that the model and every tokenized batch are moved to.
device = "cuda"
# Hugging Face hub id used for both the model and the tokenizer.
config = "EleutherAI/gpt-j-6B"

# Load the checkpoint from the hub's "float16" revision.
# NOTE(review): no `torch_dtype` is passed, so the weights are presumably
# materialized in the default dtype (float32) despite the fp16 revision —
# confirm whether `torch_dtype=torch.float16` was intended.
model = transformers.AutoModelForCausalLM.from_pretrained(config, revision="float16", low_cpu_mem_usage=True)
model.to(device)

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse EOS as the pad token so that batches can be padded (`padding="longest"`
# is used when prompts are tokenized further down).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
from src import data

# Training split of the CounterFact dataset, loaded through the project's own
# loader. Later cells read `prompt` and `source.requested_rewrite.*` from it.
counterfact = data.load_dataset("counterfact", split="train")

In [None]:
# Display the first record to inspect the fields available on each sample.
counterfact[0]

In [None]:
from collections import defaultdict

# Bucket every CounterFact record under its relation id, preserving the
# dataset's iteration order inside each bucket.
samples_by_relation = defaultdict(list)
for record in counterfact:
    rewrite = record["source"]["requested_rewrite"]
    samples_by_relation[rewrite["relation_id"]].append(record)

In [None]:
# Number of distinct relations, followed by the sample count for each relation.
len(samples_by_relation), {r: len(s) for r, s in samples_by_relation.items()}

In [None]:
from src.utils import tokenizer_utils

import baukit
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Top-5 next-token predictions, and their log-probabilities, for every
# CounterFact prompt, grouped by relation id. The loader does not shuffle, so
# entries are appended in dataset order; later cells index these lists by a
# sample's position within its relation.
preds_by_relation = defaultdict(list)
logp_by_relation = defaultdict(list)
with counterfact.formatted_as("torch"):
    loader = DataLoader(counterfact, batch_size=32, shuffle=False)
    for batch in tqdm(loader):
        inputs = tokenizer(
            batch["prompt"],
            return_tensors="pt",
            padding="longest",
        ).to(device)
        with torch.inference_mode():
            outputs = model(**inputs)

        # Index of the last attended token of each prompt.
        # NOTE(review): this assumes the tokenizer pads on the right, so that
        # the attended tokens occupy positions 0..n-1 — confirm padding_side.
        batch_idx = torch.arange(len(batch["prompt"]), device=outputs.logits.device)
        prompt_idx = inputs.attention_mask.sum(dim=-1) - 1

        # Pick out the last-token logits *before* normalizing. log_softmax is
        # computed independently per position, so this yields the same values
        # without allocating a second full-size logits tensor.
        last_token_logits = outputs.logits[batch_idx, prompt_idx]
        topk = torch.log_softmax(last_token_logits, dim=-1).topk(dim=-1, k=5)
        ids = topk.indices
        logps = topk.values
        tokens = tokenizer_utils.batch_convert_ids_to_tokens(ids.tolist(), tokenizer)

        for rid, preds, logp in zip(
            batch["source"]["requested_rewrite"]["relation_id"],
            tokens,
            logps.cpu(),
        ):
            preds_by_relation[rid].append(preds)
            logp_by_relation[rid].append(logp)

In [None]:
# Spot-check: the top-5 predicted tokens recorded for the first P264 prompt.
preds_by_relation["P264"][0]

In [None]:
from collections import defaultdict

import estimate

# Layer handed to the operator estimator.
layer = 15
# Number of (subject, target) pairs used to estimate each operator.
n_train = 5
# Maximum number of held-out samples evaluated per relation.
n_test = 150
# Values of k for which top-k accuracy is reported.
top_ks = (1, 3, 5)


# Hand-written prompt templates keyed by relation id; `{}` is the subject slot.
# Relations that are not listed fall back to a prompt taken from the dataset.
relation_templates = {
    "P103": "{} speaks the language of",
    "P140": "{} follows the religion of",
    "P740": "{} originates from",
    "P190": "{} has a sister city named \"",
    "P178": "{} was originally developed by",
    "P176": "{} is developed by the company",
    "P413": "In their sport, {} plays the position of",
    "P39": "{} has the occupation of",
    "P407": "{} is written in the language of",
    "P101": "{} is associated with the field of",
    "P1412": "{} speaks the language of",
    "P27": "{} is originally from the country of",
    "P106": "{} has the occupation of",
    "P136": "{} is associated with the genre of",
    "P159": "{} is based in the city of",
    "P276": "{} is located in the city of",
    "P36": "{} has the capital city of",
}


def _clean_token(token):
    """Strip `Ġ`, `Ċ` and spaces from both ends of a token, then lowercase it.

    `Ġ`/`Ċ` are presumably the tokenizer's space/newline markers. The result
    may be the empty string, which callers must not treat as a match.
    """
    return token.strip("ĠĊ ").lower()


# accs_by_relation[rid][k] -> top-k accuracy of the estimated operator.
accs_by_relation = defaultdict(dict)
for rid, samples in samples_by_relation.items():
    print(f"---- {rid} ----")

    # Only use examples for which the model encodes the correct relation:
    # the model's own top-1 token must be a non-empty substring of the true
    # target. Samples are ordered by top-1 log-probability, most confident first.
    samples = [
        s
        for i, s in sorted(
            enumerate(samples),
            key=lambda x: logp_by_relation[rid][x[0]][0].item(),
            reverse=True,
        )
        if any(
            _clean_token(pred)
            and _clean_token(pred)
            in s["source"]["requested_rewrite"]["target_true"]["str"].strip().lower()
            for pred in preds_by_relation[rid][i][:1]
        )
    ]
    print(f"{len(samples)} known samples")

    # When picking training examples, choose a diverse set of labels: a sample
    # joins the training set only if its label is new and the set is not full.
    # Everything else becomes a test example, up to n_test in total.
    trains = []
    tests = []
    seen = set()
    for sample in samples:
        label = sample["source"]["requested_rewrite"]["target_true"]["str"]
        if len(trains) < n_train and label not in seen:
            trains.append(sample)
            seen.add(label)
        elif len(tests) < n_test:
            tests.append(sample)

    if not trains or not tests:
        # Nothing to estimate from, or nothing to score: skip instead of
        # failing later in the estimator or on a division by zero.
        print(f"skipping {rid}: {len(trains)} train / {len(tests)} test examples")
        continue

    # Pick best relation text via heuristic
    if rid in relation_templates:
        relation = relation_templates[rid]
    else:
        rs0 = [
            sample["source"]["requested_rewrite"]["prompt"]
            for sample in samples
        ]
        # Always prefer one that puts the subject first:
        rs1 = [r for r in rs0 if r.startswith("{}")]
        # Then, prefer one with no special punctuation:
        rs2 = [r for r in rs1 if not any(x in r for x in "?,:;.")]
        # Then prefer the longest one:
        rs3 = sorted(rs2, key=len, reverse=True)
        # rs3 is non-empty exactly when rs2 is, so rs2 never needs its own turn.
        relation = next((rs[0] for rs in (rs3, rs1, rs0) if rs), None)
        assert relation is not None

    batch = [
        (
            train["source"]["requested_rewrite"]["subject"],
            train["source"]["requested_rewrite"]["target_true"]["str"]
        )
        for train in trains
    ]
    print(relation)
    print(batch)
    print(len(tests), "test examples")

    operator_a, _ = estimate.relation_operator_from_batch(
        model=model,
        tokenizer=tokenizer,
        relation=relation,
        samples=batch,
        layer=layer,
        device=device,
    )

    correct_by_k = defaultdict(int)
    for test in tqdm(tests, desc="test"):
        rr = test["source"]["requested_rewrite"]
        subject = rr["subject"]
        expected = rr["target_true"]["str"].lower().strip()

        preds = operator_a(subject, device=device, return_top_k=max(top_ks))
        actuals = [_clean_token(p) for p in preds]
        for k in top_ks:
            # An empty cleaned token is a prefix of every string, so it must
            # be excluded explicitly or it would count as a correct prediction.
            correct_by_k[k] += any(
                bool(actual) and expected.startswith(actual)
                for actual in actuals[:k]
            )

    for k in top_ks:
        accuracy = correct_by_k[k] / len(tests)
        print(f"top-{k} accuracy: {accuracy:.2f}")
        accs_by_relation[rid][k] = accuracy

Cases to analyze:
- P176: "is developed by"; how come we get zero accuracy?
- P413: Which prompt is a good one?
- P190: Sister city...why so bad?
- P136: What is going on with this relation? What is it supposed to be?
- P1412: Why is this different from the other language speaking one?
- P140, P103: What is the top-1 token?

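**Nasty**: The next cell copy-pastes the sample-selection logic from the cell above in order to compute the majority-vote baselines, so any change to one cell must be mirrored in the other.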

In [None]:
from collections import Counter

# baseline_by_relation[rid] -> accuracy of always predicting the most common
# true target among the relation's known samples.
baseline_by_relation = {}
for rid, samples in samples_by_relation.items():
    print(f"---- {rid} ----")

    # Only use examples for which the model encodes the correct relation:
    # the model's own top-1 token must be a non-empty substring of the true
    # target. Samples are ordered by top-1 log-probability, most confident first.
    samples = [
        s
        for i, s in sorted(
            enumerate(samples),
            key=lambda x: logp_by_relation[rid][x[0]][0].item(),
            reverse=True,
        )
        if any(
            pred.strip("ĠĊ ").lower()
            in
            s["source"]["requested_rewrite"]["target_true"]["str"].strip().lower()
            for pred in preds_by_relation[rid][i][:1]
            if pred.strip("ĠĊ ").lower()
        )
    ]
    print(f"{len(samples)} known samples")

    if not samples:
        # No majority label exists and the accuracy would divide by zero.
        print(f"skipping {rid}: no known samples")
        continue

    # When picking training examples, choose a diverse set of labels: a sample
    # joins the training set only if its label is new and the set is not full.
    # Everything else becomes a test example, up to n_test in total.
    trains = []
    tests = []
    seen = set()
    for sample in samples:
        label = sample["source"]["requested_rewrite"]["target_true"]["str"]
        if len(trains) < n_train and label not in seen:
            trains.append(sample)
            seen.add(label)
        elif len(tests) < n_test:
            tests.append(sample)

    # Pick best relation text via heuristic (only printed in this cell).
    if rid in relation_templates:
        relation = relation_templates[rid]
    else:
        rs0 = [
            sample["source"]["requested_rewrite"]["prompt"]
            for sample in samples
        ]
        # Always prefer one that puts the subject first:
        rs1 = [r for r in rs0 if r.startswith("{}")]
        # Then, prefer one with no special punctuation:
        rs2 = [r for r in rs1 if not any(x in r for x in "?,:;.")]
        # Then prefer the longest one:
        rs3 = sorted(rs2, key=len, reverse=True)
        # rs3 is non-empty exactly when rs2 is, so rs2 never needs its own turn.
        relation = next((rs[0] for rs in (rs3, rs1, rs0) if rs), None)
        assert relation is not None

    batch = [
        (
            train["source"]["requested_rewrite"]["subject"],
            train["source"]["requested_rewrite"]["target_true"]["str"]
        )
        for train in trains
    ]
    print(relation)
    print(batch)
    print(len(tests), "test examples")

    # NOTE(review): the majority label is chosen from, and scored on, *all*
    # known samples (training picks included), whereas the operator accuracy is
    # measured on `tests` only — confirm the two numbers are meant to be
    # compared side by side.
    targets = [
        sample["source"]["requested_rewrite"]["target_true"]["str"]
        for sample in samples
    ]
    label = Counter(targets).most_common()[0][0]
    print("Majority=", label)
    majority = label.strip().lower()
    correct = sum(target.strip().lower() == majority for target in targets)

    accuracy = correct / len(samples)
    print(f"accuracy={accuracy:.2f}")
    baseline_by_relation[rid] = accuracy

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


# Which top-k accuracy to plot. Used for the sort key, the plotted series and
# the title so the three cannot drift apart.
plot_k = 3

xys = sorted(accs_by_relation.items(), key=lambda kv: kv[-1][plot_k], reverse=True)

banned = {
    # Too few examples
    "P463",
    "P264",
    # Really weird; mostly jazz? What is this one?
    "P136",
}
xys = [(x, y) for x, y in xys if x not in banned]

xs = [rid for rid, _ in xys]
ys = [scores[plot_k] for _, scores in xys]
zs = [baseline_by_relation[rid] for rid in xs]

x = np.arange(len(xs))
width = 0.35  # the width of the bars

# Grouped bars: operator accuracy on the left, majority baseline on the right.
fig, ax = plt.subplots(figsize=(15, 3))
ax.bar(x - width / 2, ys, width, label="Estimate J/b")
ax.bar(x + width / 2, zs, width, label="Majority vote")
ax.set_xticks(x, xs)
ax.set_xlabel("relation")
ax.set_ylabel("accuracy")
# Set the title on the axes *before* tight_layout so the layout reserves room for it.
ax.set_title(f"top-{plot_k} accuracy by relation")
ax.legend()
fig.tight_layout()

# What is going on?

In [None]:
# For each relation id, the set of distinct true-target strings in its samples.
targets_by_relation = {
    relation_id: {
        record["source"]["requested_rewrite"]["target_true"]["str"]
        for record in relation_samples
    }
    for relation_id, relation_samples in samples_by_relation.items()
}
# Show the distinct targets of one relation.
targets_by_relation["P39"]

In [None]:
# For each relation id, the prompt template of every sample, in sample order.
prompts_by_relation = {
    relation_id: [
        record["source"]["requested_rewrite"]["prompt"]
        for record in relation_samples
    ]
    for relation_id, relation_samples in samples_by_relation.items()
}
# Show the prompt templates of one relation.
prompts_by_relation["P176"]

In [None]:
# Debug a single relation: estimate an operator from a fixed training set,
# inspect its bias through the model's output head, and print its predictions.
#
# Alternative configuration, kept for quick switching:
# rid = "P176"
# relation = "{} was developed by"
# trained_on = [
#     ('Microsoft Band', 'Microsoft'),
#     ('Kindle Fire', 'Amazon'),
#     ('Sony NEX-5', 'Sony'),
#     ('Chromecast', 'Google'),
#     ('Nintendo Entertainment System', 'Nintendo'),
# ]


rid = "P1412"
relation = "{} primarily spoke the language of"
trained_on = [
    ('Juan Bautista de Anza', 'Spanish'),
    ('Glamourina', 'Ukrainian'),
    ('Francesc Eiximenis', 'Catalan'),
    ('Milo Manara', 'Italian'),
    ('Aleksandar Zograf', 'Serbian'),
]

# (subject, true target) for every sample of the relation, with no filtering.
tests = [
    (
        s["source"]["requested_rewrite"]["subject"],
        s["source"]["requested_rewrite"]["target_true"]["str"],
    )
    for s in samples_by_relation[rid]
]

# Keyword arguments, matching the other call in this notebook, so the
# relation template and the training pairs cannot be swapped by position.
operator, _ = estimate.relation_operator_from_batch(
    model=model,
    tokenizer=tokenizer,
    relation=relation,
    samples=trained_on,
    device=device,
)

# Logit lens for b.
logits = model.lm_head(model.transformer.ln_f(operator.bias[None]))
ids = logits.topk(k=5, dim=-1).indices
tokens = tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())
print("LOGIT LENS", tokens)

correct = 0
for subject, target in tests:
    predictions = operator(subject, device=device)
    print(subject, f"preds={predictions}", f"target={target}")
    # Normalize as elsewhere in the notebook, and ignore tokens that clean to
    # the empty string: "" is a substring of every target and would otherwise
    # always count as a hit.
    wanted = target.strip().lower()
    cleaned = [p.strip("ĠĊ ").lower() for p in predictions]
    correct += any(bool(token) and token in wanted for token in cleaned)

if tests:
    print(f"{correct / len(tests):.2f}")
else:
    print(f"no samples for relation {rid}")