In [None]:
# Put the parent directory on the module search path; the `from src...`
# imports in later cells presumably rely on it — confirm against the
# notebook's location in the repository.
import sys
sys.path.append("..")

In [None]:
from src.utils import model_utils

import transformers

# Checkpoint name and target device shared by the cells below.
config = "gpt2-xl"
device = "cuda:15"

# Reuse the EOS token as the padding token; later cells tokenize whole
# batches with padding="longest".
tokenizer = transformers.AutoTokenizer.from_pretrained(config)
tokenizer.pad_token = tokenizer.eos_token

# Load the causal LM, move it to `device`, and put it in eval mode.
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model = model.to(device).eval()

# Bundle model and tokenizer into the project's wrapper object.
mt = model_utils.ModelAndTokenizer(model, tokenizer)

In [None]:
from src.utils import dataset_utils

# Load the "counterfact" dataset through the project helper. The split string
# presumably follows Hugging Face slicing, i.e. the first 10,000 training
# examples — confirm against dataset_utils.load_dataset.
dataset = dataset_utils.load_dataset("counterfact", split="train[:10000]")

In [None]:
import torch

# Make sure the language model is in eval mode before running inference.
mt.model.eval()

@torch.inference_mode()
def is_known_by_model(batch, top_k=5):
    """Flag the examples whose target the model already predicts.

    For each prompt, the next-token logits are read at the last non-padding
    position and the `top_k` highest-scoring tokens are compared with the
    example's unmediated target. The comparison is case-insensitive and
    ignores surrounding whitespace on the predicted tokens.

    Args:
        batch: Mapping with equally long lists under "prompt" and
            "target_unmediated".
        top_k: Number of highest-scoring next tokens that count as a hit.
            Defaults to 5.

    Returns:
        A list with one bool per example: True when the lower-cased target
        equals one of the normalized top-`top_k` tokens.
    """
    prompts = batch["prompt"]
    targets = batch["target_unmediated"]
    inputs = mt.tokenizer(
        prompts,
        return_tensors="pt",
        padding="longest",
        truncation=True,
    ).to(device)
    outputs = mt.model(**inputs)

    # Position of the last real token in each row. This is only correct when
    # padding is appended on the right — TODO confirm the tokenizer's
    # padding side.
    token_idx = inputs.attention_mask.sum(dim=-1) - 1
    batch_idx = torch.arange(len(prompts), device=token_idx.device)
    predictions = outputs.logits[batch_idx, token_idx].topk(dim=-1, k=top_k).indices

    # Convert ids to plain Python ints first so the tokenizer never has to
    # iterate over a device tensor.
    batched_tokens = [
        [
            # "Ġ" appears to be the vocabulary's leading-space marker; turn it
            # back into a space before stripping.
            token.replace("Ġ", " ").strip().lower()
            for token in mt.tokenizer.convert_ids_to_tokens(token_ids)
        ]
        for token_ids in predictions.tolist()
    ]

    return [
        target.lower() in tokens
        for target, tokens in zip(targets, batched_tokens)
    ]

# Keep only the examples that is_known_by_model marks True, evaluated in
# batches of 128 prompts.
dataset = dataset.filter(is_known_by_model, batched=True, batch_size=128)

In [None]:
# Number of examples left in `dataset` after the filtering cell.
len(dataset)

In [None]:
import baukit
import torch

# Index of the transformer block (module `transformer.h.{LAYER}`) whose
# output is traced to obtain hidden states.
LAYER = 30

@torch.inference_mode()
def precompute_hiddens(batch):
    """Extract last-token hidden states for entities and attribute contexts.

    Three texts are encoded per example: the entity, the mediated attribute
    (the "context" field as given), and the unmediated attribute (the same
    context with the mediated target replaced by the unmediated target).
    Each batch of texts is run through the model and the output of block
    `transformer.h.{LAYER}` is read at the last non-padding token.

    Args:
        batch: Mapping with equally long lists under "entity", "context",
            "target_mediated" and "target_unmediated".

    Returns:
        Dict with keys "entity.rep", "attribute_mediated.rep" and
        "attribute_unmediated.rep", each holding one representation per
        example (presumably shaped (batch, hidden) — TODO confirm).
    """
    attributes_mediated = batch["context"]
    # Every occurrence of the mediated target inside the context is replaced.
    attributes_unmediated = [
        attribute.replace(target_mediated, target_unmediated)
        for attribute, target_mediated, target_unmediated in zip(
            attributes_mediated,
            batch["target_mediated"],
            batch["target_unmediated"],
        )
    ]
    texts_by_key = {
        "entity": batch["entity"],
        "attribute_mediated": attributes_mediated,
        "attribute_unmediated": attributes_unmediated,
    }

    layer_name = f"transformer.h.{LAYER}"
    batch_idx = torch.arange(len(batch["entity"]))
    outputs = {}
    for key, texts in texts_by_key.items():
        inputs = mt.tokenizer(
            texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
        ).to(device)
        # Last real token per row; assumes padding is appended on the right.
        token_idx = inputs.attention_mask.sum(dim=-1) - 1
        # stop=True: nothing past the traced layer is needed here.
        with baukit.Trace(mt.model, layer_name, stop=True) as ret:
            mt.model(**inputs)
        # Last-token pooling. Mean pooling over non-padding tokens was tried
        # earlier as an alternative and is not used.
        outputs[f"{key}.rep"] = ret.output[0][batch_idx, token_idx]
    return outputs

In [None]:
# Apply precompute_hiddens to the dataset in batches of 128 examples.
precomputed = dataset.map(precompute_hiddens, batched=True, batch_size=128)

In [None]:
from src.utils import dataset_utils, training_utils

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

# Training hyperparameters.
LR = 1e-3          # AdamW learning rate
MAX_EPOCHS = 100   # upper bound on the number of epochs
BATCH_SIZE = 256   # examples per DataLoader batch
PATIENCE = 10      # forwarded to EarlyStopping
HOLD_OUT = 0.1     # forwarded as test_size to the train/test split

# Bilinear probe: takes two hidden-size vectors and emits a single logit.
hidden_size = mt.model.config.hidden_size
probe = nn.Bilinear(
    in1_features=hidden_size,
    in2_features=hidden_size,
    out_features=1,
).to(device)

optimizer = optim.AdamW(params=probe.parameters(), lr=LR)
criterion = nn.BCEWithLogitsLoss()
# decreasing=False presumably tells the stopper that larger values are
# better — confirm against training_utils.EarlyStopping.
stopper = training_utils.EarlyStopping(decreasing=False, patience=PATIENCE)

def make_inputs(batch):
    """Turn one batch of cached representations into probe inputs and labels.

    Every entity representation appears twice: first paired with its
    unmediated attribute representation (label 1), then with its mediated
    attribute representation (label 0).

    Args:
        batch: Mapping with tensors under "entity.rep",
            "attribute_unmediated.rep" and "attribute_mediated.rep".

    Returns:
        Tuple (entity_reps, attr_reps, labels), all on `device`, each with
        twice as many rows as the incoming batch.
    """
    entities = batch["entity.rep"].to(device)
    positives = batch["attribute_unmediated.rep"].to(device)
    negatives = batch["attribute_mediated.rep"].to(device)
    count = len(entities)

    # First half of the rows is labelled 1, second half 0.
    labels = torch.cat([
        torch.ones(count, device=device),
        torch.zeros(count, device=device),
    ])
    return (
        torch.cat([entities, entities]),
        torch.cat([positives, negatives]),
        labels,
    )
    

import copy

# Split off a held-out set, then train the probe with early stopping on
# validation accuracy.
precomputed = dataset_utils.maybe_train_test_split(precomputed, test_size=HOLD_OUT)
with precomputed.formatted_as("torch"):
    train_loader = DataLoader(precomputed["train"], batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(precomputed["test"], batch_size=BATCH_SIZE)

    # state_dict() hands back tensors that share storage with the live
    # parameters, so the snapshot must be a deep copy; otherwise it would
    # silently follow every optimizer step.
    best = copy.deepcopy(probe.state_dict())
    for epoch in range(MAX_EPOCHS):
        probe.train()
        train_loss = 0.
        for batch in train_loader:
            optimizer.zero_grad()
            entity_reps, attr_reps, labels = make_inputs(batch)
            logits = probe(entity_reps, attr_reps).squeeze(-1)
            loss = criterion(logits, labels)
            # No parameter update happens in epoch 0, so its reported numbers
            # describe the untrained probe.
            if epoch != 0:
                loss.backward()
                optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        probe.eval()
        correct, total = 0, 0
        for batch in val_loader:
            entity_reps, attr_reps, labels = make_inputs(batch)
            batch_size = len(entity_reps)
            with torch.inference_mode():
                logits = probe(entity_reps, attr_reps).view(batch_size)
            predictions = torch.sigmoid(logits).gt(.5)
            correct += predictions.eq(labels.bool()).sum()
            total += batch_size
        val_accuracy = correct / total

        print(f"epoch {epoch} / train {train_loss:.2f} / val {val_accuracy:.4f}")
        if stopper(val_accuracy):
            print("patience reached, stopping")
            # Restore the probe (not the language model) to its best
            # checkpoint.
            probe.load_state_dict(best)
            break
        elif stopper.improved:
            best = copy.deepcopy(probe.state_dict())

In [None]:
import torch

@torch.inference_mode()
def evaluate(entity, attribute):
    """Ask the trained probe whether an attribute holds for an entity.

    The entity string and the attribute template filled in with the entity
    are each run through the language model. The output of block
    `transformer.h.{LAYER}` at the final token of each text is then fed to
    the probe.

    Args:
        entity: Entity name, e.g. "Barack Obama".
        attribute: Template containing a `{}` placeholder for the entity.

    Returns:
        True when the sigmoid of the probe's logit exceeds 0.5, else False.
    """
    attribute = attribute.format(entity)
    args = []
    for text in (entity, attribute):
        inputs = mt.tokenizer(text, return_tensors="pt").to(device)
        # stop=True: layers after the traced block are not needed here.
        with baukit.Trace(mt.model, f"transformer.h.{LAYER}", stop=True) as ret:
            mt.model(**inputs)
        # One unpadded text per call, so the final position is the last token.
        hidden = ret.output[0][:, -1]
        args.append(hidden)
    # Score with the probe; `model` is the language model and does not accept
    # a pair of hidden states.
    logit = probe(*args)
    return torch.sigmoid(logit).gt(.5).item()

# Spot-check the probe on a single entity/attribute pair.
evaluate("Barack Obama", "{} is from Paris")