In [None]:
# Experiment configuration.
data_dir = "data"  # directory handed to dsets.CounterFactDataset
config = "gpt2-large"  # checkpoint name passed to from_pretrained for model and tokenizer
device = "cuda:0"  # device the model and tokenized batches are moved to

In [None]:
import transformers

# Load the causal LM, place it on the configured device, and switch it to eval
# mode since it is only used for inference.
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model = model.to(device).eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse EOS as the pad token so the batched `padding="longest"` tokenizer calls
# in `evaluate` can pad.
# NOTE(review): compute_logprobs indexes target tokens by absolute position,
# which assumes the tokenizer pads on the right — confirm padding_side.
tokenizer.pad_token = tokenizer.eos_token

# Does GPT-2 Resolve Counterfactuals?

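Build a "contextual CounterFact" dataset, in which each CounterFact prompt is preceded by a supposition stating the counterfactual fact ("Suppose ... <new target>."). Then measure how often GPT-2 assigns a higher log-probability to the counterfactual (new) target than to the true one.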

In [None]:
import dsets
import tokenizer_utils
from tqdm.auto import tqdm

def uncapitalize(s):
    """Return ``s`` with its first character lower-cased; "" is returned unchanged."""
    # Slicing (rather than s[0]) keeps this safe for empty strings.
    return s[:1].lower() + s[1:]

def _contextualize(sample):
    """Build one contextual-CounterFact record from a CounterFact sample.

    The counterfactual is stated as a supposition ("Suppose ... <new>.") and
    prepended to the original prompt. The new and the true targets are each
    appended to produce the two completions that get scored, together with the
    token range of the appended target in each.
    """
    rewrite = sample["requested_rewrite"]
    subject = rewrite["subject"]
    new = rewrite["target_new"]["str"]
    true = rewrite["target_true"]["str"]
    prompt = rewrite["prompt"].format(subject)

    # Lower-case the first letter of the supposition unless it starts with the
    # subject, so it reads naturally after "Suppose".
    supposition = sample["generation_prompts"][0].strip()
    if not supposition.startswith(subject):
        supposition = uncapitalize(supposition)

    context = f"Suppose {supposition} {new}."
    context_prompt = f"{context} {prompt}"
    context_prompt_new = f"{context_prompt} {new}"
    context_prompt_true = f"{context_prompt} {true}"

    # `new` already appears once inside `context`; occurrence=1 presumably
    # selects the second (appended) mention — confirm against tokenizer_utils.
    new_token_range = tokenizer_utils.find_token_range(
        context_prompt_new,
        new,
        tokenizer,
        occurrence=1)
    true_token_range = tokenizer_utils.find_token_range(context_prompt_true, true, tokenizer)

    return {
        "entity": subject,
        "context": context,
        "prompt": prompt,
        "new": new,
        "true": true,
        "context_prompt": context_prompt,
        "context_prompt_new": context_prompt_new,
        "context_prompt_new_token_range": new_token_range,
        "context_prompt_true": context_prompt_true,
        "context_prompt_true_token_range": true_token_range,
        "relation_id": rewrite["relation_id"],
    }


counterfact = dsets.CounterFactDataset(data_dir)
counterfact_ctx = [
    _contextualize(counterfact[position])
    for position in tqdm(range(len(counterfact)))
]

In [None]:
# Spot-check one contextualized record (context, prompts and target token ranges).
counterfact_ctx[14]

In [None]:
import torch
import torch.utils.data

def compute_logprobs(outputs, ranges, input_ids=None):
    """Sum the log-probabilities the model assigns to each sequence's target tokens.

    Args:
        outputs: model output exposing a ``.logits`` tensor, indexed here as
            (batch, seq, vocab) — assumes that layout.
        ranges: a pair ``(starts, ends)`` of per-sequence token indices (what the
            default DataLoader collation makes of ``(start, end)`` tuples). The
            target occupies token positions ``[start, end)``.
        input_ids: (batch, seq) token ids the logits were computed from. This is
            required because the score must be read at the token that actually
            occurs at each target position.

    Returns:
        1-D tensor with one summed target log-probability per sequence, on the
        logits' device.

    Raises:
        ValueError: if ``input_ids`` is missing or a target starts at position 0
            (no preceding token predicts it).
    """
    if input_ids is None:
        raise ValueError("compute_logprobs needs input_ids to look up target-token log-probs")
    seq_token_logprobs = torch.log_softmax(outputs.logits, dim=-1)
    starts, ends = ranges

    logprobs = []
    for token_logprobs, ids, start, end in zip(seq_token_logprobs, input_ids, starts, ends):
        start, end = int(start), int(end)
        if start < 1:
            raise ValueError("target tokens must be preceded by at least one token")
        # The logits at position i are the distribution over token i + 1, so the
        # token at position t is scored by the row at position t - 1.
        predicted = token_logprobs[start - 1:end - 1]
        targets = ids[start:end].to(predicted.device).unsqueeze(-1)
        logprobs.append(predicted.gather(-1, targets).sum())
    if not logprobs:
        return seq_token_logprobs.new_zeros(0)
    return torch.stack(logprobs)

def evaluate(dataset, batch_size=64):
    """Return the fraction of samples where the model prefers the counterfactual target.

    For each record, ``context_prompt_new`` and ``context_prompt_true`` are scored
    by the summed log-probability of their appended target tokens. A sample
    counts as correct when the "new" score is strictly greater than the "true"
    score. Scores are not length-normalized, so a target with fewer tokens has
    an advantage.

    Args:
        dataset: sequence of records with ``context_prompt_{new,true}`` strings
            and ``context_prompt_{new,true}_token_range`` (start, end) tuples.
        batch_size: number of records per forward pass.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("cannot evaluate an empty dataset")
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    correct = 0
    with torch.inference_mode():
        for batch in tqdm(loader):
            logprobs = {}
            for key in ("new", "true"):
                texts = batch[f"context_prompt_{key}"]
                ranges = batch[f"context_prompt_{key}_token_range"]
                inputs = tokenizer(
                    list(texts),
                    return_tensors="pt",
                    padding="longest").to(device)
                outputs = model(**inputs)
                logprobs[key] = compute_logprobs(outputs, ranges, inputs["input_ids"])
            correct += logprobs["new"].gt(logprobs["true"]).sum().item()
    return correct / len(dataset)

# Fraction of samples where the counterfactual (new) target outscores the true one.
evaluate(dataset=counterfact_ctx)

# Causal Trace CounterFact

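Convert contextual CounterFact into the format of the knowns dataset used for causal tracing. The "mediated" version uses the counterfactual (new) target as the attribute; the "unmediated" version uses the true target.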

In [None]:
def as_knowns(dataset, key):
    """Convert contextual-CounterFact records into knowns-dataset style records.

    Args:
        dataset: sequence of records with "entity", "context_prompt",
            "relation_id" and ``key`` fields.
        key: which record field becomes the "attribute" (e.g. "new" or "true").

    Returns:
        List of dicts with "known_id" (the record's position in ``dataset``),
        "subject", "attribute", "template", "prompt" and "relation_id".
    """
    reformatteds = []
    for index, sample in enumerate(dataset):
        entity = sample["entity"]
        context_prompt = sample["context_prompt"]
        reformatted = {
            "known_id": index,
            "subject": entity,
            "attribute": sample[key],
            # Every occurrence of the entity becomes "{}". The context may
            # mention it more than once (supposition and prompt), in which case
            # the template contains several placeholders.
            "template": context_prompt.replace(entity, "{}"),
            "prompt": context_prompt,
            "relation_id": sample["relation_id"],
        }
        reformatteds.append(reformatted)
    return reformatteds

# "mediated" uses the counterfactual target as the attribute, "unmediated" the true one.
knowns_counterfact_mediated, knowns_counterfact_unmediated = (
    as_knowns(counterfact_ctx, attribute_key) for attribute_key in ("new", "true")
)

In [None]:
# Inspect a mediated record (attribute = the counterfactual "new" target).
knowns_counterfact_mediated[0]

In [None]:
# Inspect an unmediated record (attribute = the "true" target).
knowns_counterfact_unmediated[0]

In [None]:
import json
import random
from pathlib import Path

k = 2500
seed = 0  # fixed so the sampled subsets are reproducible across runs

# Draw one set of indices and reuse it for both files, so the mediated and
# unmediated subsets cover the same underlying samples (paired comparison).
assert len(knowns_counterfact_mediated) == len(knowns_counterfact_unmediated)
sample_indices = random.Random(seed).sample(range(len(knowns_counterfact_mediated)), k=k)

for name, dataset in (
    ("mediated", knowns_counterfact_mediated),
    ("unmediated", knowns_counterfact_unmediated),
):
    file = Path(f"knowns_counterfact_{name}_{k}.json")
    with file.open("w") as handle:
        json.dump([dataset[index] for index in sample_indices], handle)

# Causal Trace WINOVENTI

In [None]:
import csv
from pathlib import Path

# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data directory.
data_file = Path("/raid/lingo/dez/data/winoventi.tsv")
# The csv docs require files to be opened with newline="" so quoted fields
# containing newlines parse correctly. The encoding is pinned rather than
# taken from the locale.
with data_file.open("r", newline="", encoding="utf-8") as handle:
    samples = tuple(csv.DictReader(handle, delimiter="\t"))
samples[2]