In [None]:
import os
import sys

# Make the project's `src` package importable (`from src.utils import ...` below).
# The checkout location can be overridden with the CONTEXT_MEDIATION_ROOT env var.
PROJECT_ROOT = os.environ.get("CONTEXT_MEDIATION_ROOT", "/raid/lingo/dez/code/context-mediation")
# Guard so re-running this cell does not append duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import transformers

device = "cuda"
config = "gpt2-xl"

# Load the causal LM, move it to the GPU and switch to inference mode
# (disables dropout).
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model = model.to(device)
model = model.eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse EOS as the pad token so batched `padding="longest"` calls work.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import json
from pathlib import Path

from src.utils import tokenizers

from tqdm.auto import tqdm


def load_and_preprocess(file, root="/raid/lingo/dez/code/rome/data/mediation"):
    """Load a mediation dataset from ``root/file`` and annotate each sample in place.

    Every sample must provide ``subject``, ``prompt``, ``attribute`` (the
    mediated answer), ``comparator`` (the unmediated answer) and ``occurrence``.
    The following keys are added:

    * ``prompt_mediated`` / ``prompt_unmediated``: the prompt, a space, then the
      attribute / comparator.
    * ``token_range_mediated_attr`` / ``token_range_unmediated_attr``: the token
      span of the appended attribute / comparator, as returned by
      ``tokenizers.find_token_range`` with the global ``tokenizer``.
    * ``token_range_subject_first`` / ``token_range_subject_last``: token spans of
      the subject's 0-th and ``sample["occurrence"]``-th mention in the prompt.

    Returns the mutated list of samples.
    """
    with Path(root, file).open("r") as handle:
        dataset = json.load(handle)

    is_counterfact = "counterfact" in file
    # NOTE(review): for CounterFact the attribute is searched at occurrence 1,
    # presumably because the context already mentions it once — confirm.
    mediated_occurrence = 1 if is_counterfact else 0

    for sample in tqdm(dataset, desc=f"load and preprocess {file}"):
        subject, prompt = sample["subject"], sample["prompt"]
        mediated, unmediated = sample["attribute"], sample["comparator"]

        prompt_mediated = f"{prompt} {mediated}"
        prompt_unmediated = f"{prompt} {unmediated}"
        sample["prompt_mediated"] = prompt_mediated
        sample["prompt_unmediated"] = prompt_unmediated

        sample["token_range_mediated_attr"] = tokenizers.find_token_range(
            prompt_mediated, mediated, tokenizer, occurrence=mediated_occurrence)
        sample["token_range_unmediated_attr"] = tokenizers.find_token_range(
            prompt_unmediated, unmediated, tokenizer)

        # Non-CounterFact subjects are matched case-insensitively.
        # NOTE(review): the spans are computed on lower-cased text, but the
        # original-case prompt is what gets tokenized later (precompute_hiddens,
        # test_editor); if lower-casing changes the BPE segmentation, these
        # spans may be misaligned — confirm.
        if is_counterfact:
            haystack, needle = prompt, subject
        else:
            haystack, needle = prompt.lower(), subject.lower()
        sample["token_range_subject_first"] = tokenizers.find_token_range(
            haystack, needle, tokenizer, occurrence=0)
        sample["token_range_subject_last"] = tokenizers.find_token_range(
            haystack, needle, tokenizer, occurrence=sample["occurrence"])

    return dataset

# Load both mediation datasets and attach the prompts and token ranges to every sample.
winoventi = load_and_preprocess("winoventi_subj_last.json")
counterfact = load_and_preprocess("counterfact_med_subj_last.json")

# Evaluate Mediation

In [None]:
import torch
import torch.utils.data

def compute_logprobs(inputs, outputs, ranges):
    """Sum the model's log-probabilities of the tokens inside each target span.

    Args:
        inputs: Tokenizer output. Only ``inputs.input_ids`` (batch, seq) is read.
        outputs: Causal-LM output. ``outputs.logits`` is indexed as
            (batch, seq, vocab).
        ranges: Pair ``(starts, ends)`` of per-example token indices, as produced
            by DataLoader's default collation of ``(start, end)`` tuples. The span
            ``[start, end)`` indexes into ``input_ids``.

    Returns:
        CPU tensor of shape (batch,) holding ``sum_t log p(token_t | tokens_<t)``
        over each span.

    Raises:
        ValueError: if a span starts at position 0. That token has no preceding
            context, so the model never predicts it.
    """
    seq_token_logprobs = torch.log_softmax(outputs.logits, dim=-1)
    logprobs = []
    for tokens, token_logprobs, start, end in zip(inputs.input_ids,
                                                  seq_token_logprobs,
                                                  *ranges):
        start, end = int(start), int(end)
        if start < 1:
            raise ValueError(f"span must start after position 0, got [{start}, {end})")
        # Logits at position t are the next-token distribution for position t + 1,
        # so the token at position t is scored by the logits at position t - 1.
        positions = torch.arange(start - 1, end - 1, device=token_logprobs.device)
        logprob = token_logprobs[positions, tokens[start:end]].sum()
        logprobs.append(logprob)
    if not logprobs:
        return torch.tensor([])
    return torch.stack(logprobs).cpu()
 
def compute_gt_indices(logp_left, logp_right):
    """Return the positions where ``logp_left > logp_right`` (strictly), as a list of ints.

    Both arguments are 1-D tensors of equal length. The result is always a list,
    including the single-match case; ``.squeeze()`` would collapse that case to a
    bare int, which callers cannot iterate. Ties are excluded.
    """
    return logp_left.gt(logp_right).nonzero().reshape(-1).tolist()

def evaluate(dataset, batch_size=64):
    """Measure how often the model prefers the mediated attribute over the unmediated one.

    Each sample's ``prompt_mediated`` and ``prompt_unmediated`` are scored with the
    global ``model``/``tokenizer``. The comparison uses the summed log-probabilities
    of the attribute tokens (``token_range_{mediated,unmediated}_attr``).

    Args:
        dataset: Indexable sequence of preprocessed sample dicts
            (see ``load_and_preprocess``).
        batch_size: Number of samples scored per forward pass.

    Returns:
        ``(accuracy, meds, umeds)``:

        * ``meds``: samples whose mediated attribute scored strictly higher.
        * ``umeds``: samples whose unmediated attribute scored strictly higher.
          Ties land in neither list.
        * ``accuracy``: ``len(meds) / (len(meds) + len(umeds))``, or ``nan`` when
          no sample was decided (e.g. an empty dataset).

        Each result dict contains the sample's fields plus ``logp_mediated`` /
        ``logp_unmediated``. Despite their names, these hold probabilities
        (``exp`` of the summed log-probabilities).
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    meds, umeds = [], []
    with torch.inference_mode():
        for bi, batch in enumerate(tqdm(loader)):
            logprobs = {}
            for key in ("mediated", "unmediated"):
                inputs = tokenizer(
                    list(batch[f"prompt_{key}"]),
                    return_tensors="pt",
                    padding="longest").to(device)
                outputs = model(**inputs)
                logprobs[key] = compute_logprobs(
                    inputs, outputs, batch[f"token_range_{key}_attr"])
            meds_idx = compute_gt_indices(logprobs["mediated"],
                                          logprobs["unmediated"])
            umeds_idx = compute_gt_indices(logprobs["unmediated"],
                                           logprobs["mediated"])
            # The loader neither shuffles nor drops samples, so a position within
            # the batch maps back to a dataset index via this offset.
            offset = bi * batch_size
            for indices, results in ((meds_idx, meds), (umeds_idx, umeds)):
                results.extend(
                    dict(
                        logp_mediated=torch.exp(logprobs["mediated"][index]).item(),
                        logp_unmediated=torch.exp(logprobs["unmediated"][index]).item(),
                        **dataset[offset + index],
                    )
                    for index in indices
                )

    decided = len(meds) + len(umeds)
    accuracy = len(meds) / decided if decided else float("nan")
    return accuracy, meds, umeds

# evaluate returns (accuracy, meds, umeds). Accuracy is the fraction of decided
# (non-tied) samples where the mediated attribute outscored the unmediated one.
counterfact_results = evaluate(counterfact)
print("CF:", counterfact_results[0])

winoventi_results = evaluate(winoventi)
print("WV:", winoventi_results[0])

In [None]:
# Inspect the first WinoVenti sample where the unmediated attribute scored higher.
print(winoventi_results[2][0])

# Fix Mediation

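Precompute the editor probe's inputs (mean hidden state between the subject's first and last mentions) and targets (mean hidden-state difference between the subject's last and first mentions).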

In [None]:
# Display the model's module tree, e.g. to check the `transformer.h.{i}` layer names hooked below.
model

In [None]:
import random

import torch

LAYER = 15  # transformer block hooked as `transformer.h.{LAYER}`
LR = 5e-5  # AdamW learning rate for the probe
MAX_EPOCHS = 100
BATCH_SIZE = 64  # batch size used when precomputing hidden states
PATIENCE = 10  # patience passed to training.EarlyStopping
SEED = 0
# An `evaluate` result triple: (accuracy, meds, umeds).
DATASET = counterfact_results

# Seed Python's and torch's RNGs so the train/val split and the probe's weight
# initialization are reproducible under Restart & Run All.
# NOTE(review): assumes `training.random_split` draws from one of these RNGs —
# confirm, and seed numpy as well if it uses numpy.
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import nethook
import torch
import torch.utils.data
from tqdm.auto import tqdm


@torch.inference_mode()
def precompute_hiddens(model,
                       tokenizer,
                       samples,
                       device=device,
                       batch_size=BATCH_SIZE,
                       layer=LAYER,
                       desc="precompute hiddens"):
    """Run every sample's prompt through ``model`` and cache the layer-``layer`` output.

    Returns one new dict per sample, containing the sample's own fields plus:

    * ``"h"``: that sample's slice of the layer output, including any padding
      positions.
    * ``"h_attr_avg"``: mean of ``h`` over positions from the end of the subject's
      first mention up to one before the start of its last mention.
    * ``"h_subj_avg_delt"``: mean of (last-mention minus first-mention) hidden
      states, position by position.
    """
    model = model.eval().to(device)
    module_name = f"transformer.h.{layer}"
    hiddens = []
    for batch in tqdm(torch.utils.data.DataLoader(samples, batch_size=batch_size), desc=desc):
        encoded = tokenizer(batch["prompt"], return_tensors="pt", padding="longest").to(device)
        with nethook.Trace(model, module_name) as trace:
            model(**encoded)
            # Split the layer output along dim 0 into one tensor per sample.
            hiddens.extend(trace.output[0])

    results = []
    for h, sample in zip(hiddens, samples):
        first_start, first_end = sample["token_range_subject_first"]
        last_start, last_end = sample["token_range_subject_last"]
        # TODO: Upper bound not quite right
        # NOTE(review): if this slice is empty (e.g. adjacent or single mentions),
        # .mean() yields NaN.
        between_mentions = h[first_end:last_start - 1]
        mention_delta = h[last_start:last_end] - h[first_start:first_end]
        results.append({
            "h": h,
            "h_attr_avg": between_mentions.mean(dim=0),
            "h_subj_avg_delt": mention_delta.mean(dim=0),
            **sample,
        })
    return results

# DATASET[1] is the `meds` list from an `evaluate` result, i.e. the samples where
# the mediated attribute won.
preprocessed = precompute_hiddens(model, tokenizer, DATASET[1])

In [None]:
import nethook
import torch.utils.data
from torch import nn, optim

from src.utils import training

import copy

train, val = training.random_split(preprocessed, hold_out=.1)
# batch_size stays at the default of 1. Each sample dict also carries "h", whose
# sequence length differs between precompute batches, so the default collate
# presumably could not stack larger batches.
train_loader = torch.utils.data.DataLoader(train, shuffle=True)
val_loader = torch.utils.data.DataLoader(val)
probe = nn.Sequential(
    nn.Linear(model.config.hidden_size, model.config.hidden_size),
#     nn.ReLU(),
#     nn.Linear(model.config.hidden_size, model.config.hidden_size),
).to(device)

optimizer = optim.AdamW(probe.parameters(), lr=LR)
criterion = nn.MSELoss()

best, stopper = None, training.EarlyStopping(patience=PATIENCE)
for epoch in range(MAX_EPOCHS):
    # Epoch 0 skips training and only measures the untrained probe's val loss as a baseline.
    if epoch != 0:
        train_loss = 0.
        probe.train()
        for batch in train_loader:
            predictions = probe(batch["h_attr_avg"])
            loss = criterion(predictions, batch["h_subj_avg_delt"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
        train_loss /= len(train_loader)
    else:
        train_loss = float("inf")

    val_loss = 0.
    probe.eval()
    # Validation needs no autograd graph.
    with torch.no_grad():
        for batch in val_loader:
            predictions = probe(batch["h_attr_avg"])
            loss = criterion(predictions, batch["h_subj_avg_delt"])
            val_loss += loss.item()
    val_loss /= len(val_loader)
    print(f"epoch {epoch} train={train_loss:.3f} val={val_loss:.3f}")

    if stopper(val_loss):
        break
    if stopper.improved:
        # state_dict() returns references to the live parameters. Deep-copy it so
        # later optimizer steps do not overwrite the snapshot.
        best = copy.deepcopy(probe.state_dict())

# Restore the best checkpoint so downstream cells use it rather than the last epoch's weights.
if best is not None:
    probe.load_state_dict(best)

In [None]:
@torch.inference_mode()
def test_editor(model, tokenizer, editor, sample, layer=LAYER, alpha=1, max_new_tokens=7):
    """Generate a continuation of ``sample["prompt"]`` while steering the subject's last mention.

    ``editor`` maps the mean layer-``layer`` hidden state between the subject's
    first and last mentions to a direction. That direction, scaled by ``alpha``,
    is added to the hidden states at the last-mention tokens. ``alpha=0`` leaves
    the hidden states unchanged, giving an unedited baseline.

    Args:
        model: Causal LM whose ``transformer.h.{layer}`` module is hooked.
        tokenizer: Tokenizer matching ``model``.
        editor: Module applied to the averaged hidden state.
        sample: Dict with ``prompt``, ``token_range_subject_first`` and
            ``token_range_subject_last``.
        layer: Index of the transformer block to edit.
        alpha: Scale applied to the edit direction.
        max_new_tokens: Number of tokens to generate.

    Returns:
        The decoded generation (prompt included), as a one-element list.
    """
    model.eval()
    editor.eval()
    subj_first_i, subj_first_j = sample["token_range_subject_first"]
    subj_last_i, subj_last_j = sample["token_range_subject_last"]

    def edit_output(output):
        hiddens = output[0]
        # With KV caching, generation steps after the first only see the new
        # token(s). Skip those passes: the edit is applied on the full-prompt pass
        # and reaches later tokens through the cached keys/values.
        if hiddens.shape[1] < subj_last_j:
            return output
        direction = editor(hiddens[0:1, subj_first_j:subj_last_i - 1].mean(dim=1))
        hiddens[0:1, subj_last_i:subj_last_j] = hiddens[0:1, subj_last_i:subj_last_j] + alpha * direction
        return output

    inputs = tokenizer(sample["prompt"], return_tensors="pt").to(device)
    with nethook.Trace(model, f"transformer.h.{layer}", edit_output=edit_output):
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        )

    return tokenizer.batch_decode(outputs)

# DATASET[2] is the `umeds` list from an `evaluate` result, i.e. the samples where
# the unmediated attribute won.
test = DATASET[2]
index = 10
print(test[index])
# alpha=0 scales the edit direction to zero, giving an unedited baseline generation.
test_editor(model, tokenizer, probe, test[index], alpha=0)