In [None]:
import sys
# Make the parent directory and the sibling context-mediation checkout
# importable from this notebook.
sys.path.extend(("..", "../../context-mediation"))

In [None]:
import transformers

import torch

device = "cuda"
config = "EleutherAI/gpt-j-6B"

# The "float16" revision stores half-precision weights, but `from_pretrained`
# still materializes parameters in the default dtype (float32) unless
# `torch_dtype` is passed explicitly. Requesting float16 here keeps the model
# at half precision on the GPU instead of silently doubling its memory use.
model = transformers.AutoModelForCausalLM.from_pretrained(
    config,
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to(device)

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse EOS as the padding token so batched tokenization with padding works
# (presumably the tokenizer ships without a pad token -- confirm).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
from src import data

# Load the "train" split of the "counterfact" dataset through the project's
# data helper; later cells iterate over it and batch it.
counterfact = data.load_dataset("counterfact", split="train")

In [None]:
# Peek at the first record to see what a single sample looks like.
counterfact[0]

In [None]:
from collections import defaultdict

# Bucket every sample under the relation id of its requested rewrite,
# preserving dataset order inside each bucket.
samples_by_relation = defaultdict(list)
for record in counterfact:
    rewrite = record["source"]["requested_rewrite"]
    samples_by_relation[rewrite["relation_id"]].append(record)

In [None]:
# Number of distinct relations, followed by the sample count for each one.
(
    len(samples_by_relation),
    {rid: len(group) for rid, group in samples_by_relation.items()},
)

In [None]:
from src.utils import tokenizer_utils

import baukit
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Layers whose outputs are recorded: `h` is read at layer 15, `z` at the
# final transformer block.
h_layer = 15
z_layer = model.config.n_layer - 1

h_layername = f"transformer.h.{h_layer}"
z_layername = f"transformer.h.{z_layer}"

# Per-relation accumulators, filled in dataset order:
#   preds_by_relation -- top-5 next-token predictions at the last prompt token
#   h_by_relation     -- layer-`h_layer` state at the last entity token
#   z_by_relation     -- layer-`z_layer` state at the last prompt token
preds_by_relation = defaultdict(list)
h_by_relation = defaultdict(list)
z_by_relation = defaultdict(list)

with counterfact.formatted_as("torch"):
    loader = DataLoader(counterfact, batch_size=32, shuffle=False)
    for batch in tqdm(loader):
        prompts = batch["prompt"]
        entities = batch["entity"]
        relation_ids = batch["source"]["requested_rewrite"]["relation_id"]

        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
            return_offsets_mapping=True,
        ).to(device)
        # The offsets are only needed to locate the entity tokens; the model
        # call must not receive them.
        offset_mapping = inputs.pop("offset_mapping")

        with torch.inference_mode(), baukit.TraceDict(
            model, (h_layername, z_layername)
        ) as ret:
            outputs = model(**inputs)

        # Index of the last attended token of each prompt.
        batch_idx = torch.arange(len(prompts))
        prompt_idx = inputs.attention_mask.sum(dim=-1) - 1
        last_token_logits = outputs.logits[batch_idx, prompt_idx]
        ids = last_token_logits.topk(dim=-1, k=5).indices
        tokens = tokenizer_utils.batch_convert_ids_to_tokens(ids.tolist(), tokenizer)

        # First element of each traced layer output, indexed per batch row below.
        h_states = ret[h_layername].output[0]
        z_states = ret[z_layername].output[0]

        for i, rid in enumerate(relation_ids):
            # Only the end of the entity's token range is used; it is treated
            # as exclusive, so `entity_j - 1` is the last entity token.
            _, entity_j = tokenizer_utils.find_token_range(
                prompts[i],
                entities[i],
                offset_mapping=offset_mapping[i],
            )
            preds_by_relation[rid].append(tokens[i])
            h_by_relation[rid].append(h_states[i][entity_j - 1].cpu())
            z_by_relation[rid].append(z_states[i][prompt_idx[i].item()].cpu())

In [None]:
# Inspect the predicted tokens stored for the first sample of relation "P264".
preds_by_relation["P264"][0]

In [None]:
from relations import estimate

layer = 15

# Evaluation budget per relation.
n_train = 5  # known samples used to estimate (and then average) operators
n_test = 50  # known samples used to score the averaged operator
known_top_k = 3  # a sample is "known" if the target shows up in the model's top-k
eval_ks = (1, 3, 5)

accs_by_relation = {}
for rid, samples in samples_by_relation.items():
    print(f"---- {rid} ----")

    # Keep only samples the model already "knows": one of its top predictions
    # (after stripping BPE space/newline markers) is a non-empty substring of
    # the true target. Relies on preds_by_relation[rid] being in the same
    # order as samples_by_relation[rid].
    known = []
    for i, sample in enumerate(samples):
        target = sample["source"]["requested_rewrite"]["target_true"]["str"].lower()
        cleaned = [
            pred.strip("ĠĊ ").lower()
            for pred in preds_by_relation[rid][i][:known_top_k]
        ]
        if any(pred and pred in target for pred in cleaned):
            known.append(sample)
    print(f"{len(known)} known samples")

    trains = known[:n_train]
    tests = known[n_train:n_train + n_test]

    # Without training samples there is nothing to average (torch.stack on an
    # empty list raises and trains[0] does not exist); without test samples
    # the expensive operator estimation would produce no accuracy at all.
    if not trains or not tests:
        print(
            f"skipping {rid}: need more than {n_train} known samples, "
            f"got {len(known)}"
        )
        continue

    operators = []
    for train in tqdm(trains, desc="train"):
        operator = estimate.estimate_relation_operator(
            model=model,
            tokenizer=tokenizer,
            relation=train["source"]["requested_rewrite"]["prompt"],
            subject=train["source"]["requested_rewrite"]["subject"],
            layer=layer,
            device=device,
        )
        operators.append(operator)

    # Average the per-sample operators into a single operator for the relation.
    operator_a = estimate.RelationOperator(
        model=model,
        tokenizer=tokenizer,
        relation=trains[0]["source"]["requested_rewrite"]["prompt"],
        weight=torch.stack([o.weight for o in operators]).mean(dim=0),
        bias=torch.stack([o.bias for o in operators]).mean(dim=0),
        layer=layer,
    )

    # NOTE(review): the "known" filter above accepts substring matches, while
    # this evaluation requires a prediction to equal the full target string.
    # If predictions are single tokens, multi-token targets can never be
    # counted as correct -- confirm this is intended.
    correct_by_k = defaultdict(int)
    for test in tqdm(tests, desc="test"):
        rr = test["source"]["requested_rewrite"]
        subject = rr["subject"]
        expected = rr["target_true"]["str"].lower()

        preds = operator_a(subject, device=device, return_top_k=max(eval_ks))
        for k in eval_ks:
            actuals = [p.strip("ĠĊ ").lower() for p in preds[:k]]
            correct_by_k[k] += expected in actuals

    accs_by_relation[rid] = {
        k: correct / len(tests)
        for k, correct in correct_by_k.items()
    }

In [None]:
# Display the per-relation accuracy at each k collected by the evaluation loop.
accs_by_relation

In [None]:
# Distinct true-target strings observed for each relation.
targets_by_relation = {
    rid: {
        record["source"]["requested_rewrite"]["target_true"]["str"]
        for record in group
    }
    for rid, group in samples_by_relation.items()
}
# How many distinct targets each relation has.
{rid: len(targets) for rid, targets in targets_by_relation.items()}

In [None]:
from collections import Counter

# Prompt string of every sample's requested rewrite, grouped by relation
# (one entry per sample, in dataset order).
prompts_by_relation = {
    rid: [
        sample["source"]["requested_rewrite"]["prompt"]
        for sample in samples
    ]
    for rid, samples in samples_by_relation.items()
}
# Show each distinct prompt once with how often it occurs, instead of dumping
# one line per sample; first-seen order is preserved.
{rid: dict(Counter(prompts)) for rid, prompts in prompts_by_relation.items()}