In [None]:
import os
import sys

# Make the lm-context-mediation repo importable (project imports such as
# `src.utils` are resolved relative to it). Override the location with the
# LM_CONTEXT_MEDIATION_ROOT environment variable; the membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
REPO_ROOT = os.environ.get("LM_CONTEXT_MEDIATION_ROOT", "/raid/lingo/dez/code/lm-context-mediation")
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
import transformers

# Checkpoint name and target device shared by every later cell.
device = "cuda"
config = "gpt2-xl"

# GPT-2 has no pad token; reuse EOS so batched tokenization with padding works.
tokenizer = transformers.AutoTokenizer.from_pretrained(config)
tokenizer.pad_token = tokenizer.eos_token

# Load the causal LM, move it to the device and switch to inference mode.
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model = model.to(device)
model = model.eval()

In [None]:
import json
from pathlib import Path

from src.utils import tokenizers

from tqdm.auto import tqdm


def load_and_preprocess(file, root="/raid/lingo/dez/code/rome/data/mediation"):
    """Load a mediation JSON dataset and annotate each sample with token ranges.

    Each sample must provide ``subject``, ``prompt``, ``attribute``,
    ``comparator`` and ``occurrence``. Samples are updated in place with:

    - ``prompt_mediated`` / ``prompt_unmediated``: the prompt followed by a
      space and the attribute / comparator.
    - ``token_range_mediated_attr`` / ``token_range_unmediated_attr``: token
      span of the attribute / comparator inside those strings.
    - ``token_range_subject_first`` / ``token_range_subject_last``: token span
      of the subject at occurrence 0 and at ``sample["occurrence"]``.

    For CounterFact files (``"counterfact"`` in ``file``) the mediated
    attribute is located at occurrence 1 and the subject search is
    case-sensitive. For other files the prompt and subject are lowercased
    before the subject search.

    Uses the module-level ``tokenizer``.

    Args:
        file: JSON file name, relative to ``root``.
        root: directory containing the dataset files.

    Returns:
        The list of annotated sample dicts.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with Path(root, file).open("r", encoding="utf-8") as handle:
        dataset = json.load(handle)
    is_counterfact = "counterfact" in file
    attr_occurrence = 1 if is_counterfact else 0
    for sample in tqdm(dataset, desc=f"load and preprocess {file}"):
        subject = sample["subject"]
        prompt = sample["prompt"]
        mediated = sample["attribute"]
        unmediated = sample["comparator"]
        sample["prompt_mediated"] = prompt_mediated = f"{prompt} {mediated}"
        sample["prompt_unmediated"] = prompt_unmediated = f"{prompt} {unmediated}"
        sample["token_range_mediated_attr"] = tokenizers.find_token_range(
            prompt_mediated,
            mediated,
            tokenizer,
            occurrence=attr_occurrence,
        )
        sample["token_range_unmediated_attr"] = tokenizers.find_token_range(
            prompt_unmediated,
            unmediated,
            tokenizer,
        )
        # NOTE(review): for non-CounterFact data the subject ranges are found in
        # the *lowercased* prompt, but later cells tokenize the original-case
        # prompt; confirm lowercasing never changes the token count before the
        # subject.
        search_prompt = prompt if is_counterfact else prompt.lower()
        search_subject = subject if is_counterfact else subject.lower()
        sample["token_range_subject_first"] = tokenizers.find_token_range(
            search_prompt,
            search_subject,
            tokenizer,
            occurrence=0)
        sample["token_range_subject_last"] = tokenizers.find_token_range(
            search_prompt,
            search_subject,
            tokenizer,
            occurrence=sample["occurrence"])
    return dataset

# Both files are read from load_and_preprocess's default data root, in this order.
winoventi, counterfact = (
    load_and_preprocess(dataset_file)
    for dataset_file in ("winoventi_subj_last.json", "counterfact_med_subj_last.json")
)

# Evaluate Mediation

In [None]:
import torch
import torch.utils.data

def compute_logprobs(inputs, outputs, ranges):
    """Sum the model's log-probabilities of a token span in each sequence of a batch.

    Args:
        inputs: tokenizer output whose ``input_ids`` hold the scored tokens.
        outputs: causal-LM output whose ``logits`` are position-aligned with
            ``input_ids``.
        ranges: pair ``(starts, ends)`` of per-sequence token bounds; for
            sequence ``b`` the span ``input_ids[b, starts[b]:ends[b]]`` is scored.

    Returns:
        1-D CPU tensor with one summed log-probability per sequence.

    Raises:
        ValueError: if a span starts at token 0, which no preceding position
            predicts.
    """
    seq_token_logprobs = torch.log_softmax(outputs.logits, dim=-1)
    logprobs = []
    for tokens, token_logprobs, start, end in zip(inputs.input_ids,
                                                  seq_token_logprobs,
                                                  *ranges):
        start, end = int(start), int(end)
        if start < 1:
            raise ValueError("cannot score a span starting at token 0: no logits predict it")
        # Causal-LM logits at position t are the distribution over token t + 1,
        # so token t must be scored with row t - 1.
        positions = torch.arange(start - 1, end - 1, device=token_logprobs.device)
        logprobs.append(token_logprobs[positions, tokens[start:end]].sum())
    if not logprobs:
        return torch.empty(0)
    return torch.stack(logprobs).cpu()
 
def compute_gt_indices(logp_left, logp_right):
    """Return the positions where ``logp_left`` is strictly greater than ``logp_right``.

    Args:
        logp_left: tensor of log-probabilities.
        logp_right: tensor with the same shape as ``logp_left``.

    Returns:
        For 1-D inputs, a Python list of int indices. This is always a list,
        including when zero or exactly one position matches.
    """
    assert logp_left.shape == logp_right.shape
    mask = logp_left.gt(logp_right)
    indices = mask.nonzero()
    if mask.dim() == 1:
        # nonzero() yields shape (k, 1) for 1-D input. Taking column 0 (instead
        # of squeeze()) keeps a 1-D result when k == 1, so tolist() returns
        # [i] rather than a bare int.
        indices = indices[:, 0]
    return indices.tolist()

def evaluate(dataset, batch_size=64):
    """Measure how often the model prefers the mediated attribute over the comparator.

    For every sample, two strings are scored: ``prompt_mediated`` and
    ``prompt_unmediated``. The score is the summed log-probability of the
    tokens in ``token_range_mediated_attr`` / ``token_range_unmediated_attr``.
    A sample is "mediated" when the mediated score is strictly higher and
    "unmediated" when it is strictly lower; exact ties land in neither list.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        dataset: sequence of preprocessed sample dicts.
        batch_size: number of samples scored per forward pass.

    Returns:
        ``(accuracy, meds, umeds)``. ``accuracy`` is
        ``len(meds) / (len(meds) + len(umeds))``, or ``nan`` when both lists
        are empty. ``meds`` / ``umeds`` are the matching sample dicts extended
        with ``p_mediated`` and ``p_unmediated`` (exp of the summed
        log-probabilities).
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    meds, umeds = [], []
    with torch.inference_mode():
        for bi, batch in enumerate(tqdm(loader)):
            logprobs = {}
            for key in ("mediated", "unmediated"):
                texts = batch[f"prompt_{key}"]
                ranges = batch[f"token_range_{key}_attr"]
                inputs = tokenizer(
                    list(texts),
                    return_tensors="pt",
                    padding="longest").to(device)
                outputs = model(**inputs)
                logprobs[key] = compute_logprobs(inputs, outputs, ranges)
            meds_idx = compute_gt_indices(logprobs["mediated"],
                                          logprobs["unmediated"])
            umeds_idx = compute_gt_indices(logprobs["unmediated"],
                                           logprobs["mediated"])
            # The loader does not shuffle, so batch-local index + offset is the dataset row.
            offset = bi * batch_size
            for indices, results in ((meds_idx, meds), (umeds_idx, umeds)):
                if isinstance(indices, int):
                    # Defensive: a single match may come back as a bare int.
                    indices = [indices]
                results += [
                    dict(
                        p_mediated=torch.exp(logprobs["mediated"][index]).item(),
                        p_unmediated=torch.exp(logprobs["unmediated"][index]).item(),
                        **dataset[offset + index],
                    )
                    for index in indices
                ]

    decided = len(meds) + len(umeds)
    accuracy = len(meds) / decided if decided else float("nan")
    return accuracy, meds, umeds

# Score both datasets. Each result is (accuracy, mediated samples, unmediated samples),
# where accuracy is the fraction of decided samples whose mediated attribute
# outscored the comparator.
counterfact_results = evaluate(counterfact)
print("CF:", counterfact_results[0])

winoventi_results = evaluate(winoventi)
print("WV:", winoventi_results[0])

In [None]:
# Inspect the first WinoVenti sample where the comparator outscored the mediated
# attribute (index 2 of the evaluate() result); raises IndexError if there are none.
print(winoventi_results[2][0])

# Fix Mediation

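Precompute the probe inputs and targets: hidden states from a chosen layer, averaged over the attribute context (inputs) and over the difference between the last and first subject mentions (targets).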

In [None]:
# Transformer blocks 15..30 each get their own editor probe.
LAYERS = range(15, 31)
# Default layer for single-editor helpers.
LAYER = 15
# AdamW learning rate for probe training.
LR = 7e-5
# Passed as `hold_out` to training.random_split (presumably the validation fraction).
HOLD_OUT = 0.1
MAX_EPOCHS = 250
BATCH_SIZE = 64
# Early-stopping patience; the stopper is consulted once per epoch.
PATIENCE = 3
# evaluate() returns (accuracy, mediated samples, unmediated samples).
DATASET = counterfact_results

In [None]:
import nethook
import torch
import torch.utils.data
from tqdm.auto import tqdm

from src.utils import tokenizers

@torch.inference_mode()
def precompute_hiddens(model,
                       tokenizer,
                       samples,
                       device=device,
                       batch_size=BATCH_SIZE,
                       layer=LAYER,
                       desc="precompute hiddens"):
    """Collect hidden-state features at ``transformer.h.{layer}`` for probe training.

    Each ``sample["prompt"]`` is run through ``model`` in batches, and the first
    element of the traced layer output is kept. One record is built per kept
    sample:

    - ``h_attr_avg``: mean hidden state over the tokens from the end of the
      first subject mention up to (not including) the first ``"."`` token.
    - ``h_subj_avg_delt``: mean hidden state over the last subject mention
      minus the mean hidden state over the first subject mention.

    A sample is skipped if any of the following holds:

    - ``prompt.split(".")`` has more than two parts.
    - The first subject mention starts after token 4.
    - The attribute span is empty (its mean would be NaN).

    Returns:
        List of dicts holding the two features merged with the original sample.
    """
    model = model.eval().to(device)
    loader = torch.utils.data.DataLoader(samples, batch_size=batch_size)
    hiddens = []
    for batch in tqdm(loader, desc=desc):
        inputs = tokenizer(batch["prompt"], return_tensors="pt", padding="longest").to(device)
        with nethook.Trace(model, f"transformer.h.{layer}") as ret:
            model(**inputs)
            # Extending a list with a batched tensor appends one tensor per prompt.
            hiddens += ret.output[0]

    results = []
    for h, sample in zip(hiddens, samples):
        subj_first_i, subj_first_j = sample["token_range_subject_first"]
        subj_last_i, subj_last_j = sample["token_range_subject_last"]
        if len(sample["prompt"].split(".")) > 2 or subj_first_i > 4:
            continue
        period_i, _ = tokenizers.find_token_range(sample["prompt"], ".", tokenizer)
        if period_i <= subj_first_j:
            # Empty attribute span: mean() would be NaN and poison the MSE loss.
            continue
        h_attr_avg = h[subj_first_j:period_i].mean(dim=0)
        # Average each mention separately so the two mentions may tokenize to a
        # different number of tokens. For equal lengths this equals the mean of
        # the element-wise difference.
        h_subj_avg_delt = (h[subj_last_i:subj_last_j].mean(dim=0)
                           - h[subj_first_i:subj_first_j].mean(dim=0))
        results.append({
            "h_attr_avg": h_attr_avg,
            "h_subj_avg_delt": h_subj_avg_delt,
            **sample,
        })
    return results


# precompute_hiddens(model, tokenizer, DATASET[1])

In [None]:
import nethook
import torch.utils.data
from torch import nn, optim

from src.utils import training

def train_editor(model,
                 tokenizer,
                 samples,
                 layer=LAYER,
                 hold_out=HOLD_OUT,
                 max_epochs=MAX_EPOCHS,
                 batch_size=BATCH_SIZE,
                 lr=LR,
                 patience=PATIENCE):
    """Train a linear probe mapping ``h_attr_avg`` to ``h_subj_avg_delt`` at one layer.

    Features come from ``precompute_hiddens`` and are split with
    ``training.random_split(..., hold_out=hold_out)``. The probe is trained
    with AdamW on an MSE loss.

    Epoch 0 only evaluates the untrained probe. Training ends when
    ``training.EarlyStopping`` signals or after ``max_epochs``. Either way, the
    probe is restored to the weights saved at the last epoch the stopper
    reported as an improvement.

    Returns:
        The trained ``nn.Sequential`` probe on ``device``.

    Raises:
        ValueError: if the train or validation split is empty.
    """
    preprocessed = precompute_hiddens(model, tokenizer, samples, layer=layer)
    train, val = training.random_split(preprocessed, hold_out=hold_out)
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size)
    val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size)
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            f"empty split for layer {layer}: {len(train_loader)} train / "
            f"{len(val_loader)} val batches from {len(preprocessed)} samples")
    probe = nn.Sequential(
        nn.Linear(model.config.hidden_size, model.config.hidden_size),
    ).to(device)

    optimizer = optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.MSELoss()

    def snapshot():
        # state_dict() returns live references to the parameters; clone them so
        # later optimizer steps cannot overwrite the saved "best" weights.
        return {name: tensor.detach().clone() for name, tensor in probe.state_dict().items()}

    best, stopper = snapshot(), training.EarlyStopping(patience=patience)
    for epoch in range(max_epochs):
        if epoch != 0:
            train_loss = 0.
            probe.train()
            for batch in train_loader:
                inputs = batch["h_attr_avg"]
                targets = batch["h_subj_avg_delt"]
                predictions = probe(inputs)
                loss = criterion(predictions, targets)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                train_loss += loss.item()
            train_loss /= len(train_loader)
        else:
            train_loss = float("inf")

        val_loss = 0.
        probe.eval()
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["h_attr_avg"]
                targets = batch["h_subj_avg_delt"]
                predictions = probe(inputs)
                loss = criterion(predictions, targets)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        print(f"l{layer} epoch {epoch} train={train_loss:.3f} val={val_loss:.3f}")

        if stopper(val_loss):
            break
        if stopper.improved:
            best = snapshot()

    probe.load_state_dict(best)
    return probe

def train_all_editors(*args, layers=LAYERS, **kwargs):
    """Train one editor per layer and return them keyed by layer index.

    Positional and keyword arguments are forwarded to ``train_editor``. The
    ``layer`` keyword is supplied per iteration, so callers must not pass it.
    """
    assert "layer" not in kwargs
    return {layer: train_editor(*args, layer=layer, **kwargs) for layer in layers}

# Train one editor per layer in LAYERS on DATASET[1] (samples whose mediated
# attribute outscored the comparator in evaluate()).
editors = train_all_editors(model, tokenizer, DATASET[1])
# Editor at the default layer, for single-layer experiments.
editor = editors[LAYER]

In [None]:
@torch.inference_mode()
def test_editor(model, tokenizer, editor, sample, layer=LAYER, alpha=1, context=None):
    """Generate a continuation of ``sample["prompt"]`` with an editor-based edit applied.

    During the full-prompt forward pass, ``alpha * direction`` is added to the
    output of ``transformer.h.{layer}`` over ``sample["token_range_subject_last"]``.
    Here ``direction = editor(h)``, and ``h`` is one of:

    - with a context: the mean layer output over the context tokens following
      the subject;
    - otherwise: the mean layer output over the prompt, from the end of the
      first subject mention up to 4 tokens before the last mention.

    Args:
        context: optional, already formatted context string. It takes
            precedence over ``sample.get("context")``.

    Returns:
        Up to 20 generated tokens, decoded and truncated at the first newline.
    """
    model.eval()
    editor.eval()

    if context is None:
        context = sample.get("context")

    direction = None
    if context is not None:
        _, j = tokenizers.find_token_range(context, sample["subject"], tokenizer)
        inputs = tokenizer(context, return_tensors="pt").to(device)
        with nethook.Trace(model, f"transformer.h.{layer}") as ret:
            model(**inputs)
            h_attr = ret.output[0][0:1, j:].mean(dim=1)
            direction = editor(h_attr)

    subj_first_i, subj_first_j = sample["token_range_subject_first"]
    subj_last_i, subj_last_j = sample["token_range_subject_last"]

    def edit_output(output, _, direction=direction):
        # Skip single-token passes (presumably generate()'s incremental decoding
        # steps); only the full-prompt pass is edited.
        if output[0].shape[1] == 1:
            return output
        if direction is None:
            # NOTE(review): the "- 4" offset is not explained anywhere visible.
            h_attr = output[0][0:1, subj_first_j:subj_last_i - 4].mean(dim=1)
            direction = editor(h_attr)
        output[0][0:1, subj_last_i:subj_last_j] = output[0][0:1, subj_last_i:subj_last_j] + alpha * direction
        return output

    inputs = tokenizer(sample["prompt"], return_tensors="pt").to(device)
    with nethook.Trace(model, f"transformer.h.{layer}", edit_output=edit_output):
        outputs = model.generate(inputs.input_ids, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)

    return tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:])[0].split("\n")[0]

def make_sample(prompt, subject, tokenizer=tokenizer, context=None):
    """Build an ad-hoc sample dict for ``test_editor``.

    ``prompt`` and the optional ``context`` are templates whose ``{subject}``
    fields are filled with ``subject``. The subject is located twice in the
    filled prompt: once with ``find_token_range``'s default occurrence, and
    once with ``occurrence=1``.
    """
    filled_prompt = prompt.format(subject=subject)
    filled_context = None if context is None else context.format(subject=subject)
    first_range = tokenizers.find_token_range(filled_prompt, subject, tokenizer)
    last_range = tokenizers.find_token_range(filled_prompt, subject, tokenizer, occurrence=1)
    sample = dict(subject=subject, prompt=filled_prompt, context=filled_context)
    sample["token_range_subject_first"] = first_range
    sample["token_range_subject_last"] = last_range
    return sample

# test = DATASET[2]
# sample = test[78]

# sample = make_sample("{subject} was the first female president of the United States. {subject}'s preferred pronouns are",
#                      "Barack Obama")
# sample = make_sample("The {subject} is located in Rome. To visit the {subject}, you must travel to the country of",
#                      "Eiffel Tower")
# Demo: the prompt says Jane works in a hospital; the context adds that she has an MD.
sample = make_sample("{subject} works in a hospital. {subject}'s job title is",
                     "Jane", context="{subject} has an MD from Harvard Medical School.")
# Edit strength used for every per-layer run below.
alpha = 10
print(sample["prompt"])
# alpha=0 scales the added direction to zero, serving as the baseline.
print("before -->", test_editor(model, tokenizer, editors[LAYER], sample, alpha=0, layer=LAYER))
for layer, ed in editors.items():
    print(f"after l{layer} -->", test_editor(model, tokenizer, ed, sample, alpha=alpha, layer=layer))

# Testing Without Context

In [None]:
from src.utils import tokenizers

@torch.inference_mode()
def edit_no_context(editor,
                    model,
                    tokenizer,
                    subject, # "Eiffel Tower"
                    context, # "The {subject} is located in Rome"
                    prompt, # "I visited the {subject} in"
                    layer=LAYER,
                    alpha=1,
                    occurrence=0):
    """Generate from ``prompt`` after editing the subject with a direction read off ``context``.

    ``context`` and ``prompt`` are templates formatted with ``subject``.

    1. The context is run through ``model`` on its own. The output of
       ``transformer.h.{layer}`` is averaged over the context tokens after the
       subject, and ``editor`` maps that average to a direction.
    2. During generation from ``prompt`` (which never sees the context),
       ``alpha * direction`` is added to the same layer's output at the
       ``occurrence``-th subject mention.

    Returns:
        Up to 15 generated tokens, decoded and truncated at the first newline.

    Raises:
        ValueError: if no context tokens follow the subject. The average would
            be NaN and silently corrupt generation.
    """
    model.eval()
    editor.eval()

    # Find token positions.
    context = context.format(subject=subject)
    prompt = prompt.format(subject=subject)
    _, c_subj_j = tokenizers.find_token_range(context, subject, tokenizer)
    p_subj_i, p_subj_j = tokenizers.find_token_range(prompt, subject, tokenizer, occurrence=occurrence)

    # Compute the edit direction from the context tokens that follow the subject.
    inputs = tokenizer(context, return_tensors="pt").to(device)
    if c_subj_j >= inputs.input_ids.shape[1]:
        raise ValueError(f"context {context!r} has no tokens after subject {subject!r}")
    with nethook.Trace(model, f"transformer.h.{layer}") as ret:
        model(**inputs)
        h_attr = ret.output[0][0:1, c_subj_j:].mean(dim=1)
        direction = editor(h_attr)

    def edit_output(output):
        # Skip single-token passes (presumably generate()'s incremental decoding
        # steps); only the full-prompt pass is edited.
        if output[0].shape[1] == 1:
            return output
        output[0][0:1, p_subj_i:p_subj_j] = output[0][0:1, p_subj_i:p_subj_j] + alpha * direction
        return output

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with nethook.Trace(model, f"transformer.h.{layer}", edit_output=edit_output):
        outputs = model.generate(inputs.input_ids, max_new_tokens=15, pad_token_id=tokenizer.eos_token_id)

    return tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:])[0].split("\n")[0]

def try_all_editors_no_context(editors, *args, **kwargs):
    """Print an unedited generation, then one generation per layer's editor.

    ``args`` are forwarded to ``edit_no_context`` after the editor, i.e.
    ``(model, tokenizer, subject, context, prompt, ...)``. The baseline uses
    ``editors[LAYER]`` with ``alpha`` forced to 0. ``layer`` must not be passed
    because it is set per editor.
    """
    assert "layer" not in kwargs
    subject, context, prompt = args[2], args[3], args[4]
    header = (prompt.format(subject=subject), "______", "|", "edit:", context.format(subject=subject))
    print(*header)
    baseline = edit_no_context(editors[LAYER], *args, **{**kwargs, "alpha": 0})
    print("before --> ", baseline)
    for layer in editors:
        edited = edit_no_context(editors[layer], *args, layer=layer, **kwargs)
        print(f"after l{layer} --> ", edited)

# Compare unedited vs. per-layer edited generations for one (subject, context, prompt)
# triple; the commented-out triples are alternative examples.
try_all_editors_no_context(
    editors,
    model,
    tokenizer,
#     "Barack Obama",
#     "{subject} invented the iPhone and founded Apple.",
#     "{subject} received a degree in",

#     "Eiffel Tower",
#     "Suppose the {subject} is located in Rome",
#     "I visited the {subject} in the country of",

    "Jane",
    "{subject} has an MD from Harvard Medical School",
    "{subject} works in a hospital. {subject} has the occupation of",
    # Edit the first subject mention in the prompt.
    occurrence=0,

    alpha=50,
)

# Edit Multiple Layers

In [None]:
from src.utils import tokenizers

@torch.inference_mode()
def edit_multi_layer(editors,
                     model,
                     tokenizer,
                     subject, # "Eiffel Tower"
                     context, # "The {subject} is located in Rome"
                     prompt, # "I visited the {subject} in"
                     layers=LAYERS,
                     alpha=1,
                     occurrence=None,
                     max_new_tokens=3):
    """Like ``edit_no_context``, but applies one editor per layer simultaneously.

    For each layer in ``layers``, the context-derived direction
    ``editors[layer](h)`` is computed, where ``h`` is the mean output of
    ``transformer.h.{layer}`` over the context tokens after the subject. During
    generation from ``prompt``, ``alpha`` times that direction is added to the
    layer's output at the subject's token span.

    Args:
        occurrence: which subject mention in ``prompt`` to edit. When ``None``,
            ``find_token_range``'s default is used.
        max_new_tokens: generation length.

    Returns:
        The decoded generation, truncated at the first newline.

    Raises:
        ValueError: if no context tokens follow the subject (the average would
            be NaN).
    """
    model.eval()
    layers = list(layers)
    layer_names = [f"transformer.h.{layer}" for layer in layers]
    for layer in layers:
        editors[layer].eval()

    # Find token positions.
    context = context.format(subject=subject)
    prompt = prompt.format(subject=subject)
    _, c_subj_j = tokenizers.find_token_range(context, subject, tokenizer)
    # Only forward `occurrence` when given, so default calls behave as before.
    occurrence_kwargs = {} if occurrence is None else {"occurrence": occurrence}
    p_subj_i, p_subj_j = tokenizers.find_token_range(prompt, subject, tokenizer, **occurrence_kwargs)

    # Compute one edit direction per layer from the context tokens after the subject.
    inputs = tokenizer(context, return_tensors="pt").to(device)
    if c_subj_j >= inputs.input_ids.shape[1]:
        raise ValueError(f"context {context!r} has no tokens after subject {subject!r}")
    with nethook.TraceDict(model, layer_names) as ret:
        model(**inputs)
        directions = {}
        for layer, layer_name in zip(layers, layer_names):
            h_attr = ret[layer_name].output[0][0:1, c_subj_j:].mean(dim=1)
            directions[layer_name] = editors[layer](h_attr)

    def edit_output(output, layer):
        # `layer` is the traced module name, matching the keys of `directions`.
        # Skip single-token passes (presumably generate()'s incremental decoding steps).
        if output[0].shape[1] == 1:
            return output
        output[0][0:1, p_subj_i:p_subj_j] = output[0][0:1, p_subj_i:p_subj_j] + alpha * directions[layer]
        return output

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with nethook.TraceDict(model, layer_names, edit_output=edit_output):
        outputs = model.generate(inputs.input_ids, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)

    return tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:])[0].split("\n")[0]

# Apply the editors for layers 20-30 together; the returned string is the cell output.
edit_multi_layer(
    editors,
    model,
    tokenizer,
    "Eiffel Tower",
    "The {subject} is located in Rome",
    "To visit the {subject}, you must fly to the country of",
    alpha=.56,
    layers=range(20, 31)
)