In [None]:
import sys
# Add the parent directory to the module search path (presumably so the local
# `remedi` package imported below resolves when run from this folder — confirm).
sys.path.append("..")

In [None]:
from remedi import models

# Device string shared by the model loader and every `.to(device)` call in later cells.
device = "cuda"
# Load GPT-J 6B with fp16 enabled. `mt` is read later for `mt.model.config.hidden_size`
# and is passed to the `precompute` helpers.
mt = models.load_model("EleutherAI/gpt-j-6B", fp16=True, device=device)

In [None]:
from remedi import data, precompute

# NOTE(review): presumably turns off dataset caching so the map/precompute steps
# below are always recomputed — confirm against `remedi.data`.
data.disable_caching()

def manually_add_model_correct(dataset, labels):
    """Add a `prompt_in_context.model_correct` column to `dataset`.

    For each example, the label whose entry in
    `prompt_in_context.other_targets.logp` is largest is taken as the model's
    choice, and the new column records whether that choice equals the
    example's `target_mediated`.

    Args:
        dataset: Object exposing `.map(fn)`, where `fn` receives one example
            (a mapping) and returns a dict of new columns.
        labels: Sequence of candidate labels, indexed in the same order as the
            per-example log-probability list.

    Returns:
        Whatever `dataset.map` returns.

    Note:
        Relies on `torch` being imported by the time the mapped function runs.
    """
    logp_key = "prompt_in_context.other_targets.logp"

    def _score_example(example):
        # Index of the highest log-probability among the candidate labels.
        best_index = torch.tensor(example[logp_key]).argmax().item()
        predicted_label = labels[best_index]
        is_correct = predicted_label == example["target_mediated"]
        return {"prompt_in_context.model_correct": is_correct}

    return dataset.map(_score_example)


def precompute_model_predictions(dataset_name, dataset, version="med"):
    """Run `precompute.model_predictions_from_dataset` with per-dataset settings.

    Args:
        dataset_name: "biosbias" or "counterfact"; selects the keyword
            arguments forwarded to the precompute call.
        dataset: Dataset to annotate. For "biosbias" it is iterated once to
            collect the sorted set of `target_mediated` values.
        version: "prior" scores the bare `prompt` against `target_unmediated`
            (counterfact only); "med" scores `prompt_in_context` against
            `target_mediated`.

    Returns:
        The dataset returned by the precompute call. For "biosbias" it is
        additionally passed through `manually_add_model_correct`.

    Raises:
        AssertionError: if `version` is unknown, or "prior" is requested for a
            dataset other than "counterfact".
    """
    assert version in ("prior", "med")

    is_biosbias = dataset_name == "biosbias"
    # Candidate label set; only biosbias needs it.
    labels = sorted({x["target_mediated"] for x in dataset}) if is_biosbias else None

    model_predictions_kwargs = {}
    if version == "prior":
        assert dataset_name == "counterfact"
        model_predictions_kwargs.update(
            input_prompt_key="prompt",
            input_target_key="target_unmediated",
            input_comparator_key="target_mediated",
        )
    elif version == "med":
        model_predictions_kwargs.update(
            input_prompt_key="prompt_in_context",
            input_target_key="target_mediated",
        )
        if is_biosbias:
            # No single comparator for biosbias: score against every label.
            model_predictions_kwargs["other_targets"] = labels
            model_predictions_kwargs["input_comparator_key"] = None
        else:
            model_predictions_kwargs["input_comparator_key"] = "target_unmediated"

    annotated = precompute.model_predictions_from_dataset(
        dataset=dataset,
        mt=mt,
        device=device,
        **model_predictions_kwargs,
    )

    if not is_biosbias:
        return annotated
    return manually_add_model_correct(annotated, labels)


def load_and_preprocess(dataset_name, split=None, layers=None):
    """Load a dataset split and precompute its classification inputs.

    Args:
        dataset_name: Must be "biosbias" or "counterfact".
        split: Forwarded to `data.load_dataset`.
        layers: Forwarded to `precompute.classification_inputs_from_dataset`.

    Returns:
        The dataset returned by `precompute.classification_inputs_from_dataset`.

    Raises:
        AssertionError: if `dataset_name` is not one of the supported names.
    """
    assert dataset_name in ("biosbias", "counterfact")
    loaded = data.load_dataset(dataset_name, split=split)
    return precompute.classification_inputs_from_dataset(
        dataset=loaded,
        mt=mt,
        device=device,
        layers=layers,
    )

In [None]:
from remedi.utils import training_utils

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Training defaults consumed by `train_probe` (and `BATCH_SIZE` by `test_probe`).
LR = 1e-4          # AdamW learning rate
PATIENCE = 4       # passed to training_utils.EarlyStopping
BATCH_SIZE = 128   # DataLoader batch size
MAX_EPOCHS = 10    # upper bound on training epochs
HOLD_OUT = 0.1     # passed to training_utils.random_split as `hold_out`

# Default representation settings; the biosbias run overrides the layers.
EDITOR_LAYER = 1
ENTITY_LAYER = 26
ENTITY_SOURCE = "prompt_in_context"


def forward(
    probe,
    batch,
    entity_source=ENTITY_SOURCE,
    editor_layer=EDITOR_LAYER,
    entity_layer=ENTITY_LAYER,
):
    """Apply `probe` to the entity and attribute hidden states in `batch`.

    Args:
        probe: Callable taking `(entity_hiddens, attribute_hiddens)`.
        batch: Mapping holding the precomputed hidden-state columns.
        entity_source: Prefix of the entity column to read.
        editor_layer: Layer index used in the attribute column name.
        entity_layer: Layer index used in the entity column name.

    Returns:
        Whatever `probe` returns for the two inputs, both moved to `device`.
    """
    entity_column = f"{entity_source}.entity.hiddens.{entity_layer}.last"
    attribute_column = f"context.attribute.hiddens.{editor_layer}.average"
    entity_hiddens = batch[entity_column].to(device)
    attribute_hiddens = batch[attribute_column].to(device)
    return probe(entity_hiddens, attribute_hiddens)


def get_labels(batch, entity_source=ENTITY_SOURCE):
    """Return the `{entity_source}.model_correct` column as a float tensor on
    `device`, with an extra axis inserted at position 1."""
    correct = batch[f"{entity_source}.model_correct"]
    return correct.to(device, torch.float).unsqueeze(1)


def compute_loss(probe, batch, criterion, entity_source=ENTITY_SOURCE, **kwargs):
    """Compute `criterion(predictions, labels)` for one batch.

    Args:
        probe: Model forwarded to `forward`.
        batch: Mapping of precomputed columns.
        criterion: Callable taking `(predictions, labels)`.
        entity_source: Column prefix used for both the inputs and the labels.
        **kwargs: Extra keyword arguments passed through to `forward`.

    Returns:
        The value returned by `criterion`.
    """
    outputs = forward(probe, batch, entity_source=entity_source, **kwargs)
    targets = get_labels(batch, entity_source=entity_source)
    return criterion(outputs, targets)


def train_probe(
    dataset,
    entity_source=ENTITY_SOURCE,
    editor_layer=EDITOR_LAYER,
    entity_layer=ENTITY_LAYER,
    lr=LR,
    patience=PATIENCE,
    batch_size=BATCH_SIZE,
    hold_out=HOLD_OUT,
    max_epochs=MAX_EPOCHS,
    exclude_columns=(),
):
    """Train a bilinear probe that predicts `{entity_source}.model_correct`.

    The probe is an `nn.Bilinear(hidden_size, hidden_size, 1)` trained with
    `BCEWithLogitsLoss` and AdamW. A `hold_out` fraction of `dataset` is split
    off for validation; training stops early when the stopper says so, and the
    returned probe always carries the weights from the epoch with the best
    validation loss.

    Args:
        dataset: Dataset supporting `formatted_as("torch", columns=...)`.
        entity_source: Column prefix for the entity hiddens and labels.
        editor_layer: Layer index of the attribute hiddens.
        entity_layer: Layer index of the entity hiddens.
        lr: AdamW learning rate.
        patience: Passed to `training_utils.EarlyStopping`.
        batch_size: Batch size of both data loaders.
        hold_out: Passed to `training_utils.random_split`.
        max_epochs: Maximum number of training epochs.
        exclude_columns: Columns left out of the torch-formatted view.

    Returns:
        The trained `nn.Bilinear` probe (on `device`).
    """
    hidden_size = mt.model.config.hidden_size
    probe = nn.Bilinear(hidden_size, hidden_size, 1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    stopper = training_utils.EarlyStopping(patience=patience)
    optimizer = optim.AdamW(probe.parameters(), lr=lr)

    def _snapshot():
        # `state_dict()` returns references to the live parameters, so the
        # tensors must be cloned or the "best" checkpoint would silently
        # track every later optimizer step.
        return {
            name: tensor.detach().clone()
            for name, tensor in probe.state_dict().items()
        }

    loss_kwargs = dict(
        entity_source=entity_source,
        entity_layer=entity_layer,
        editor_layer=editor_layer,
    )

    columns = data.column_names(dataset, exclude=exclude_columns)
    with dataset.formatted_as("torch", columns=columns):
        train, val = training_utils.random_split(dataset, hold_out=hold_out)
        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val, batch_size=batch_size)

        best = _snapshot()
        for _ in range(max_epochs):
            probe.train()
            pbar = tqdm(train_loader, desc="train")
            for batch in pbar:
                loss = compute_loss(probe, batch, criterion, **loss_kwargs)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                pbar.set_description(f"train [{loss.item():.2f}]")

            probe.eval()
            val_loss = 0.
            with torch.inference_mode():
                # Validate on the held-out split, not on the training data.
                pbar = tqdm(val_loader, desc="val")
                for batch in pbar:
                    loss = compute_loss(probe, batch, criterion, **loss_kwargs)
                    val_loss += loss.item()
                    pbar.set_description(f"val [{loss.item():.2f}]")

            val_loss /= len(val_loader)

            if stopper(val_loss):
                break
            elif stopper.improved:
                best = _snapshot()

        # Restore the best checkpoint whether training stopped early or simply
        # ran out of epochs.
        probe.load_state_dict(best)

    return probe


def test_probe(
    probe,
    test,
    entity_source=ENTITY_SOURCE,
    exclude_columns=(),
    **kwargs
):
    """Evaluate `probe` on `test` and collect boolean labels and predictions.

    Args:
        probe: Trained probe; switched to eval mode before inference.
        test: Dataset supporting `formatted_as("torch", columns=...)`.
        entity_source: Column prefix for the entity hiddens and labels.
        exclude_columns: Columns left out of the torch-formatted view.
        **kwargs: Extra keyword arguments passed through to `forward`
            (e.g. `entity_layer`, `editor_layer`).

    Returns:
        `(y_true, y_pred)`: two flat lists of Python bools, one entry per
        example. A prediction is True when `sigmoid(output) > 0.5`.
    """
    probe.eval()
    y_pred = []
    y_true = []
    columns = data.column_names(test, exclude=exclude_columns)
    with test.formatted_as("torch", columns=columns):
        loader = DataLoader(test, batch_size=BATCH_SIZE)
        with torch.inference_mode():
            for batch in tqdm(loader, desc="test"):
                predictions = forward(probe, batch, entity_source=entity_source, **kwargs)
                labels = get_labels(batch, entity_source=entity_source)
                # `flatten()` rather than `squeeze()`: squeezing a batch of
                # size 1 gives a 0-d tensor whose `tolist()` is a bare bool,
                # which cannot be appended with `+=`.
                y_true += labels.bool().flatten().tolist()
                y_pred += torch.sigmoid(predictions).gt(.5).flatten().tolist()
    return y_true, y_pred

In [None]:
from sklearn.metrics import f1_score, accuracy_score, matthews_corrcoef

# Train and evaluate one probe per (dataset, version) pair and print
# accuracy / F1 / MCC, first for the positive class and then with labels and
# predictions negated.
for dataset_name, editor_layer, entity_layer in (
    ("counterfact", 1, 26),
    ("biosbias", 12, 15),
):
    layers = [editor_layer, entity_layer]
    # Disjoint slices of the train split: first 5000 for fitting, next 5000 for testing.
    dataset = load_and_preprocess(dataset_name, split="train[:5000]", layers=layers)
    test = load_and_preprocess(dataset_name, split="train[5000:10000]", layers=layers)
    for version, entity_source in (
        ("prior", "prompt"),
        ("med", "prompt_in_context"),
    ):
        # The "prior" version is only defined for counterfact.
        if dataset_name == "biosbias" and version == "prior":
            continue
        dataset = precompute_model_predictions(dataset_name, dataset, version=version)
        # Annotate the held-out slice itself; passing `dataset` here would
        # make the evaluation run on the probe's own training data.
        test = precompute_model_predictions(dataset_name, test, version=version)
        exclude_columns = ["target_unmediated"] if dataset_name == "biosbias" else []
        probe = train_probe(
            dataset,
            entity_source=entity_source,
            entity_layer=entity_layer,
            editor_layer=editor_layer,
            exclude_columns=exclude_columns,
        )
        y_true, y_pred = test_probe(
            probe,
            test,
            entity_source=entity_source,
            entity_layer=entity_layer,
            editor_layer=editor_layer,
            exclude_columns=exclude_columns,
        )

        # Metrics with "model correct" as the positive class.
        print(
            dataset_name,
            version,
            accuracy_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            matthews_corrcoef(y_true, y_pred),
        )

        # Same metrics with "model incorrect" as the positive class.
        y_true = [not y for y in y_true]
        y_pred = [not y for y in y_pred]
        print(
            dataset_name,
            version,
            accuracy_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            matthews_corrcoef(y_true, y_pred),
        )

In [None]:
# Hard-coded accuracy / F1 / MCC values per run, kept here for later use.
metrics = {
    "fact-med": dict(accuracy=0.982, f1=0.970, mcc=0.958),
    "fact-prior": dict(accuracy=0.914, f1=0.598, mcc=0.615),
    "bias-med": dict(accuracy=0.960, f1=0.960, mcc=0.920),
}