In [None]:
import sys

# Put the parent directory on the import path so the `from src import ...`
# cells below can resolve (presumably the `src` package lives one level up
# from this notebook -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
# Shell out to `nvidia-smi` to show GPU status before the model is loaded
# in the next cell.
!nvidia-smi

In [None]:
from src import models

# Device string shared by every later cell that moves tensors or modules.
device = "cuda"

# Load GPT-J 6B through the project's loader, passing fp16=True and the target
# device. Later cells use the result through `mt.model` and `mt.tokenizer`.
load_options = dict(fp16=True, device=device)
mt = models.load_model("EleutherAI/gpt-j-6B", **load_options)

In [None]:
from src import data

data.disable_caching()

# Evaluation subset: the `train[5000:6000]` slice of biosbias. Each row's
# "target_unmediated" field is set to the constant placeholder "foo".
heldout_rows = data.load_dataset("biosbias", split="train[5000:6000]")
test = heldout_rows.map(lambda row: {"target_unmediated": "foo"})

In [None]:
# Label vocabulary: every distinct "target_mediated" value found in the full
# training split, in sorted order. The trailing expression displays the labels
# and how many there are.
full_dataset = data.load_dataset("biosbias", split="train")
labels = sorted(set(map(lambda row: row["target_mediated"], full_dataset)))
labels, len(labels)

In [None]:
import torch

# One token id per label: column 0 of the padded tokenization of each label.
# NOTE(review): this assumes position 0 is the label's own first token, i.e.
# the tokenizer neither left-pads nor prepends a special token -- confirm.
label_encoding = mt.tokenizer(labels, return_tensors="pt", padding="longest")
label_token_ids = label_encoding.input_ids[:, 0]


@torch.inference_mode()
def add_model_predictions(e):
    """Score every label as the model's next token for one dataset row.

    Builds the text "<context without trailing periods>. <prompt>", runs the
    model on it, and takes the log-softmax of the logits at the last position.

    Args:
        e: A dataset row providing the "context" and "prompt" strings.

    Returns:
        A dict with a single key "scores": the log-probabilities at the ids in
        `label_token_ids`, in the same order as `labels`.
    """
    context = e["context"].rstrip(".")
    text = f"{context}. {e['prompt']}"
    encoded = mt.tokenizer(text, return_tensors="pt").to(device)
    last_position_logits = mt.model(**encoded).logits[0, -1]
    log_probs = torch.log_softmax(last_position_logits, dim=-1)
    return {"scores": log_probs[label_token_ids]}


test = test.map(add_model_predictions)

In [None]:
from pathlib import Path

# Layers whose saved editors are evaluated: 20 through 26.
EDITOR_LAYERS = range(20, 27)
# Layers whose hidden states are probed: layer 1 plus layers 15 through 26.
PROBE_LAYERS = [1, *range(15, 27)]
# Directory the editor weights are loaded from (relative to the working dir).
EDITOR_DIR = Path("results") / "gptj_biosbias_fixed_context"

In [None]:
from src import precompute

# Copy each row's "target_mediated" into "target_unmediated" (replacing the
# placeholder set when `test` was loaded), then precompute the classification
# inputs for every probe layer.
probe_dataset = test.map(lambda row: {"target_unmediated": row["target_mediated"]})
precomputed = precompute.classification_inputs_from_dataset(
    mt=mt,
    dataset=probe_dataset,
    device=device,
    layers=PROBE_LAYERS,
)

In [None]:
# For every (editor layer, probe layer) pair, use the editor's per-label edit
# directions as a zero-shot probe on the entity's hidden state and report:
#   * recall@3 -- how often the true label is among the probe's top 3 labels;
#   * f1 / mcc -- how well "true label missing from the probe's top 3"
#                 predicts "the model's own top-1 label is wrong".
from tqdm.auto import tqdm
from src import editors  # was missing: `editors` is used below (NameError on a fresh kernel)
from src.utils import training_utils
from sklearn.metrics import f1_score, matthews_corrcoef

for editor_layer in EDITOR_LAYERS:
    editor = editors.LinearEditor(mt=mt, layer=editor_layer).to(device)
    editor.load_state_dict(
        torch.load(
            EDITOR_DIR / f"linear/{editor_layer}/weights.pth",
            map_location=device,
        )
    )

    # Alternatively: can use attribute representation as the edit direction.
    #     editor = editors.IdentityEditor(mt=mt, layer=editor_layer)

    # The edit directions depend only on the editor and the sample, never on
    # the probe layer, so they are computed (and normalized) once per editor
    # instead of once per (editor layer, probe layer) pair.
    h_dirs = []
    for sample in tqdm(precomputed, desc=f"edit directions, editor_layer={editor_layer}"):
        # NOTE(review): `labels` holds the raw "target_mediated" strings while
        # the comparison below uses the stripped one; if the raw values carry
        # surrounding whitespace the true label never matches -- confirm.
        target = sample["target_mediated"].strip()

        # One edit request per candidate label. The true label reuses the
        # sample's own context/attribute; every other label gets a templated
        # "<entity> has the occupation of <label>" context.
        with editors.apply(editor, device=device) as edited_mt:
            directions = edited_mt.model.compute_edit_directions(
                {
                    "entity": [sample["entity"]] * len(labels),
                    "prompt": [f"{sample['entity']} has the occupation of"] * len(labels),
                    "context": [
                        f"{sample['entity']} has the occupation of {label}"
                        if label != target
                        else sample["context"]
                        for label in labels
                    ],
                    "attribute": [
                        f"has the occupation of {label}"
                        if label != target
                        else sample["attribute"]
                        for label in labels
                    ],
                }
            )
        h_dirs.append(directions)

    # Standardize the directions per feature, using statistics pooled over
    # every sample and label (presumably each entry is (n_labels, hidden) --
    # confirm against `compute_edit_directions`).
    all_dirs = torch.cat(h_dirs).float()
    mu_dirs = all_dirs.mean(dim=0, keepdim=True)
    std_dirs = all_dirs.std(dim=0, keepdim=True)
    h_dirs = [(dirs - mu_dirs) / std_dirs for dirs in h_dirs]

    for probe_layer in PROBE_LAYERS:
        # Entity hidden states at this probe layer, standardized per feature
        # across the evaluation samples.
        h_es = torch.stack(
            [
                torch.tensor(
                    sample[f"prompt_in_context.entity.hiddens.{probe_layer}.last"],
                    device=device,
                )
                for sample in precomputed
            ]
        ).float()
        h_es = (h_es - h_es.mean(dim=0, keepdim=True)) / h_es.std(dim=0, keepdim=True)

        recalled = []
        y_pred = []
        y_true = []
        for sample, h_e, directions in zip(precomputed, h_es, h_dirs):
            # Dot product between the entity representation and each label's
            # edit direction: one score per label.
            scores = h_e[None].mul(directions).sum(dim=-1)

            probe_predictions_idx = scores.topk(k=3).indices.squeeze().tolist()
            # `as_tensor` accepts rows that come back as tensors, arrays or
            # plain lists, so `.topk` works whatever format the dataset uses.
            model_predictions_idx = (
                torch.as_tensor(sample["scores"])
                .topk(dim=-1, k=3)
                .indices.squeeze()
                .tolist()
            )

            probe_predictions = [labels[idx] for idx in probe_predictions_idx]
            model_prediction = labels[model_predictions_idx[0]]
            target = sample["target_mediated"].strip()

            # Positive class = "error": the model's top-1 label is wrong
            # (y_true) vs. the probe's top 3 miss the true label (y_pred).
            y_true.append(model_prediction != target)
            y_pred.append(target not in probe_predictions)

            recalled.append(target in probe_predictions)

        print(
            f"editor_layer={editor_layer}",
            f"probe_layer={probe_layer}",
            f"recall@3={sum(recalled) / len(recalled)}",
            f"f1={f1_score(y_true, y_pred):.2f}",
            f"mcc={matthews_corrcoef(y_true, y_pred):.2f}",
        )