In [None]:
import os
import sys

# Make the project's `src` package (imported below as `src.utils...`) importable.
# The checkout location can be overridden with CONTEXT_MEDIATION_DIR; the default
# is the original author's path.
REPO_DIR = os.environ.get("CONTEXT_MEDIATION_DIR", "/raid/lingo/dez/code/context-mediation")
# Guard so that re-running this cell does not append duplicate entries.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

In [None]:
import transformers

config = "gpt2-xl"
device = "cuda"

# Load the causal LM onto `device` and switch it to eval mode.
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model = model.to(device)
model.eval()

# Reuse EOS as the padding token (presumably because the GPT-2 tokenizer ships
# without one) so that padded batch tokenization works later on.
tokenizer = transformers.AutoTokenizer.from_pretrained(config)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
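# Inputs needed per example for this experiment:
# - Prompt: "The Eiffel Tower is located in"
# - Entity: "The Eiffel Tower"
# - Precomputed entity representation
# - Attribute: "is located in"
# - Attribute representation (open question: should it be computed in context or in isolation?)
# - Entity token positions within the prompt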

In [None]:
import json
from pathlib import Path

# Mediation samples: a JSON list of dicts. The preprocessing cell reads the
# "subject", "attribute", "comparator" and "prompt" fields of each record.
file = Path("/raid/lingo/dez/code/rome/data/mediation/counterfact_med_subj_last.json")
samples = json.loads(file.read_text())
samples[0]

In [None]:
from src.utils import tokenizer_utils

def sentcase(text):
    """Return `text` with its first character upper-cased and the rest unchanged.

    Unlike `str.capitalize`, the remaining characters keep their original case
    (so proper nouns later in the string survive). An empty string is returned
    unchanged instead of raising IndexError.
    """
    return text[:1].upper() + text[1:]

# Break up each sample's prompt into the parts needed for this experiment:
# a context sentence that states the (mediated) attribute, and the prompt
# whose next token should reflect that attribute.
preprocessed = []
for index, sample in enumerate(samples):
    entity = sample["subject"]
    mediated_word = sample["attribute"]
    unmediated_word = sample["comparator"]

    # sample["prompt"] is expected to look like "<context ending in
    # mediated_word>. <prompt>". Split at the LAST separator so that an earlier
    # occurrence of "<mediated_word>. " cannot break the two-way unpacking.
    separator = f"{mediated_word}. "
    if separator not in sample["prompt"]:
        raise ValueError(
            f"sample {index}: separator {separator!r} not found in {sample['prompt']!r}")
    context, prompt = sample["prompt"].rsplit(separator, 1)
    context += mediated_word
    context = context.replace("Suppose ", "")
    # Capitalize the context unless it begins with the entity, whose own
    # casing is kept as-is.
    if not context.lower().startswith(entity.lower()):
        context = sentcase(context)

    # The attribute is whatever follows the (last) mention of the entity.
    attribute = context.split(entity)[-1].strip(",-;: ")

    pp = {
        "index": index,  # Use this as an ID.
        "entity": entity,
        "entity_range_in_prompt": tokenizer_utils.find_token_range(prompt, entity, tokenizer),
        "entity_range_in_context": tokenizer_utils.find_token_range(context, entity, tokenizer),
        "attribute": attribute,
        "attribute_range_in_context": tokenizer_utils.find_token_range(context, attribute, tokenizer),
        "prompt": prompt,
        "context": context,
        "mediated_word": mediated_word,
        # Only the first token of " <word>" is scored by the loss.
        "mediated_token_id": tokenizer(" " + mediated_word).input_ids[0],
        "unmediated_word": unmediated_word,
        "unmediated_token_id": tokenizer(" " + unmediated_word).input_ids[0],
    }
    preprocessed.append(pp)

print(preprocessed[0])

In [None]:
# Index of the transformer block hooked via f"transformer.h.{LAYER}": its output
# hidden states are read when precomputing attribute representations, and the
# entity positions in that output are shifted by the editor's direction during
# training (compute_loss) and prediction (predict).
LAYER = 20

In [None]:
import torch
import torch.utils.data
from baukit import nethook
from tqdm.auto import tqdm

# Precompute, per example, layer-LAYER hidden states averaged over
# (a) the attribute's tokens inside the full context sentence, and
# (b) the attribute text encoded on its own.
# NOTE(review): slicing with `attribute_range_in_context` (computed on the
# unpadded context) assumes the tokenizer pads on the right — confirm
# tokenizer.padding_side.
precomputed = []
loader = torch.utils.data.DataLoader(preprocessed, batch_size=64)
for batch in tqdm(loader, desc="precompute entity/attr reps"):
    context_inputs = tokenizer(batch["context"], return_tensors="pt", padding="longest").to(device)
    with torch.inference_mode():
        with nethook.Trace(model, f"transformer.h.{LAYER}") as ret:
            model(**context_inputs)
            hiddens_in_context = ret.output[0]

    attr_inputs = tokenizer(batch["attribute"], return_tensors="pt", padding="longest").to(device)
    with torch.inference_mode():
        with nethook.Trace(model, f"transformer.h.{LAYER}") as ret:
            model(**attr_inputs)
            hiddens = ret.output[0]
    # Attributes differ in length, so shorter rows are padded; padding
    # positions must not contribute to the standalone-attribute average.
    attr_mask = attr_inputs.attention_mask.bool()

    for bi, index in enumerate(batch["index"].tolist()):
        pc = {**preprocessed[index]}
        attr_i, attr_j = preprocessed[index]["attribute_range_in_context"]
        pc["attribute_in_context_h_avg"] = hiddens_in_context[bi, attr_i:attr_j].mean(dim=0)
        pc["attribute_h_avg"] = hiddens[bi][attr_mask[bi]].mean(dim=0)
        precomputed.append(pc)

In [None]:
from src.utils import training_utils

from torch import nn, optim

LR = 1e-2
BATCH_SIZE = 32
MAX_EPOCHS = 10
PATIENCE = 4
HOLD_OUT = .1
RANK = 1600  # Editor rank; RANK == hidden_size selects a single full-rank Linear.
LAMBDA = .25  # Weight of the KL penalty in compute_loss.

hidden_size = model.config.hidden_size
if RANK != hidden_size:
    # Low-rank editor: project down to RANK dims, then back up to hidden_size.
    editor = nn.Sequential(
        nn.Linear(hidden_size, RANK),
        nn.Linear(RANK, hidden_size),
    )
else:
    editor = nn.Linear(hidden_size, hidden_size)
editor.to(device)

# Only the editor is optimized; the language model stays fixed.
optimizer = optim.AdamW(editor.parameters(), lr=LR)
train, val = training_utils.random_split(precomputed, hold_out=HOLD_OUT)
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=BATCH_SIZE)
stopper = training_utils.EarlyStopping(patience=PATIENCE)

# Set `kl = None` to train without the KL penalty.
kl = nn.KLDivLoss(reduction="batchmean", log_target=True).to(device)

# Freeze the language model. Gradients still reach the editor through the
# activations. Leaving requires_grad=True would make every backward pass compute
# and accumulate gradients for all model weights. The optimizer never zeroes or
# uses those gradients.
model.eval()
for parameter in model.parameters():
    parameter.requires_grad_(False)


def compute_loss(batch):
    """Editing loss for one collated batch of `precomputed` records.

    The editor maps each example's averaged in-context attribute hidden state
    to a direction. That direction is added to the layer-LAYER output at every
    entity token of the prompt. The loss is the mean negative log-probability
    of the mediated word's first token at each prompt's last non-padding
    position. When `kl` is set, LAMBDA times KL(original || edited) between the
    unedited and edited next-token distributions at that position is added.

    Args:
        batch: Dict produced by the DataLoader's default collation of
            `precomputed` records.

    Returns:
        Scalar loss tensor, differentiable w.r.t. the editor's parameters.
    """
    prompt = batch["prompt"]
    entity_i_in_prompt, entity_j_in_prompt = batch["entity_range_in_prompt"]
    attr_h_avg = batch["attribute_in_context_h_avg"]

    inputs = tokenizer(prompt, return_tensors="pt", padding="longest").to(device)
    bsize = len(prompt)
    rows = torch.arange(bsize, device=device)
    # Position of each row's last real token (assumes right padding).
    last = inputs.attention_mask.sum(dim=-1) - 1

    logps_orig = None
    if kl is not None:
        # Reference distribution from the unedited model; no gradient needed.
        with torch.no_grad():
            logps_orig = torch.log_softmax(model(**inputs).logits[rows, last], dim=-1)

    def edit_output(output):
        direction = editor(attr_h_avg)
        for bi, (i, j) in enumerate(zip(entity_i_in_prompt, entity_j_in_prompt)):
            output[0][bi, i:j] = output[0][bi, i:j] + direction[bi]
        return (output[0], *output[1:])

    with nethook.Trace(model, f"transformer.h.{LAYER}", edit_output=edit_output):
        outputs = model(**inputs)
    # Only the prediction position is scored, so normalize just those logits.
    logps = torch.log_softmax(outputs.logits[rows, last], dim=-1)

    mediated_ids = torch.as_tensor(batch["mediated_token_id"], device=device)
    loss = -logps[rows, mediated_ids].mean()

    if kl is not None:
        loss = loss + LAMBDA * kl(logps, logps_orig)

    return loss


# Snapshot (not alias) the weights. state_dict() returns tensors that share
# storage with the live parameters, so an uncloned `best` would silently follow
# every later optimizer step, and restoring it would be a no-op.
best = {name: value.detach().clone() for name, value in editor.state_dict().items()}
for epoch in range(MAX_EPOCHS):
    # Epoch 0 only measures the loss of the freshly initialized editor (a
    # baseline for early stopping); optimization starts at epoch 1.
    update = epoch > 0
    editor.train()
    train_loss = 0.
    for batch in train_loader:
        with torch.set_grad_enabled(update):
            loss = compute_loss(batch)
        if update:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    editor.eval()
    val_loss = torch.tensor(0.)
    with torch.no_grad():
        for batch in val_loader:
            val_loss += compute_loss(batch).item()
    val_loss /= len(val_loader)

    print(f"epoch {epoch} / train {train_loss:.3f} / val {val_loss:.3f}")
    if stopper(val_loss):
        print("stopping early")
        break
    elif stopper.improved:
        best = {name: value.detach().clone() for name, value in editor.state_dict().items()}

editor.load_state_dict(best)

In [None]:
# Load the `best` editor weights recorded by the training loop. The training
# cell already ends with this same call; this cell presumably exists so the
# weights can be re-applied without re-running training.
editor.load_state_dict(best)

In [None]:
from functools import partial


@torch.inference_mode()
def predict(entity, prompt, context, alpha=1):
    """Print a generation for `prompt` after injecting the attribute stated in `context`.

    The attribute is the text of `context` after the last mention of `entity`.
    A forward pass over `context` gives the layer-LAYER hidden states. These
    are averaged over the attribute's tokens and mapped through `editor`. The
    resulting direction, scaled by `alpha`, is added to the layer-LAYER output
    at the entity's token positions in the prompt while up to 10 new tokens are
    generated. The prompt, context, attribute and decoded result are printed.

    Args:
        entity: String substituted for "{entity}" in both templates.
        prompt: Template for the text to continue.
        context: Template for the sentence stating the new attribute.
        alpha: Scale applied to the editing direction.
    """
    prompt = prompt.format(entity=entity)
    context = context.format(entity=entity)
    attribute = context.split(entity)[-1].strip(",-;: ")

    # The model runs on the sentence-cased prompt, so entity token positions
    # must be found in that same string. If the prompt starts with the entity,
    # sentcase() also capitalized the entity's first character.
    model_prompt = sentcase(prompt)
    prompt_entity = sentcase(entity) if prompt.startswith(entity) else entity

    print("prompt:", model_prompt)
    print("context:", context)
    print("attribute:", attribute)

    prompt_entity_i, prompt_entity_j = tokenizer_utils.find_token_range(
        model_prompt, prompt_entity, tokenizer)
    context_attr_i, context_attr_j = tokenizer_utils.find_token_range(
        context, attribute, tokenizer)

    # Average the attribute's layer-LAYER states inside the context sentence
    # and map them to an editing direction.
    inputs = tokenizer(context, return_tensors="pt").to(device)
    with nethook.Trace(model, f"transformer.h.{LAYER}") as ret:
        model(**inputs)
        attr_h_avg = ret.output[0][:, context_attr_i:context_attr_j].mean(dim=1)
        direction = editor(attr_h_avg)

    def edit_output(direction, output):
        # Length-1 forward passes (presumably the incremental KV-cache steps of
        # generate()) are left untouched; only multi-token passes over the
        # prompt are edited.
        if output[0].shape[1] == 1:
            return output

        output[0][:, prompt_entity_i:prompt_entity_j] = (
            output[0][:, prompt_entity_i:prompt_entity_j] + alpha * direction)
        return output

    inputs = tokenizer(model_prompt, return_tensors="pt").to(device)
    with nethook.Trace(model, f"transformer.h.{LAYER}", edit_output=partial(edit_output, direction)):
        outputs = model.generate(**inputs, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)
    result = tokenizer.batch_decode(outputs)[0]
    print("result:", result)
    print()

# Qualitative checks. Each case is (entity, prompt template, context template);
# all cases use the same editing strength.
alpha = .25
demo_cases = [
    ("The Eiffel Tower", "{entity}, located in the country of", "{entity} was built in Rome"),
    ("The Eiffel Tower", "{entity} is made of", "{entity} was built in Rome"),
    ("Bill Gates", "{entity} founded the company", "{entity} invented the iPhone"),
    ("Barack Obama", "{entity} has a degree in", "{entity} invented the Page Rank search algorithm"),
    ("Britney Spears", "{entity} is most famous for", "{entity} wrote a textbook about building bridges"),
]
for demo_entity, demo_prompt, demo_context in demo_cases:
    predict(demo_entity, demo_prompt, demo_context, alpha=alpha)

In [None]:
# Stronger edit (alpha=.4) where the prompt mentions the entity twice.
predict(
    entity="Britney",
    prompt="{entity} works in a hospital. {entity}'s job title is",
    context="{entity} has an MD from Harvard Medical School",
    alpha=.4,
)