# Tutorial for Probing

This is a simple tutorial showing you to collect activations from intervention-points in a model. We'll compare 1D DAS IIA on each layer and position for `block_output` in pythia-70M with logistic regression probing accuracy. The task we'll look at is gender prediction, where gendered names are used in templates like "[name] walked because", which elicits the associated gendered pronoun "he" or "she" as the next-token prediction for this model.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/frankaging/pyvene/blob/main/tutorials/advance_tutorials/Probing_Gender.ipynb)


In [1]:
__author__ = "Aryaman Arora"
__version__ = "01/10/2024"

## Setup

In [1]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyvene as pv

except ModuleNotFoundError:
    !pip install git+https://github.com/stanfordnlp/pyvene.git

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


In [2]:
import pandas as pd
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
import torch
import random
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
    geom_line,
    geom_point,
    geom_text,
    ggtitle, xlab, ylab,
    ggsave
)
from plotnine.scales import scale_y_reverse, scale_fill_cmap
from tqdm import tqdm
from collections import namedtuple

## Load model and data

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = "EleutherAI/pythia-70m" # "EleutherAI/pythia-6.9B"
tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer.pad_token = tokenizer.eos_token
gpt = AutoModelForCausalLM.from_pretrained(
    model,
    revision="main",
    torch_dtype=torch.bfloat16 if model == "EleutherAI/pythia-6.9b" else torch.float32,
).to(device)

We have a list of 100 names for each gender, and we'll filter for names that are one token in length. We'll further filter for examples the model agrees with our labels for, since some of these names might be ambiguous or the model might not have the expected behaviour. This ensures that baseline IIA is 0.

In [5]:
Example = namedtuple("Example", ["base", "src", "base_label", "src_label"])

names = {
    "kitchen":
    ["eat", "cut", "cook", "bake", "boil", "fry", "grill", "roast", "steam"],
    "bathroom":
    ["dry", "wash", "brush", "comb", "shave", "shower", "bathe", "clean", "scrub", "wipe"],
    "bedroom":
    ["sleep", "dream", "wake", "rest", "snore", "yawn", "stretch"],
    "office":
    ["work", "write", "type", "print", "scan", "fax", "email", "file", "copy"],
    "garden":
    ["plant", "water", "mow", "rake", "trim", "weed", "prune", "harvest", "dig"]
}

# filter names that are > 1 token
names = {
    key: [name for name in names[key] if len(tokenizer.tokenize(name)) == 1]
    for key in names
}
print(names)


def sample_example(tokenizer):
    # sample labels (not matching)
    base_label = random.choice(list(names.keys()))
    src_label = random.choice([key for key in names if key != base_label])

    # sample names
    base_name = random.choice(names[base_label])
    src_name = random.choice(names[src_label])

    # make pair
    base = tokenizer(f"<|endoftext|> Yann wants to {base_name}. Yann will go to the", return_tensors="pt")
    src = tokenizer(f"<|endoftext|> Yann wants to {src_name}. Yann will go to the", return_tensors="pt")
    base_label = tokenizer.encode(" " + base_label)[0]
    src_label = tokenizer.encode(" " + src_label)[0]
    return Example(base, src, base_label, src_label)

{'kitchen': ['eat', 'cut', 'cook'], 'bathroom': ['dry', 'wash', 'brush', 'comb', 'clean'], 'bedroom': ['sleep', 'dream', 'rest'], 'office': ['work', 'write', 'type', 'print', 'scan', 'fax', 'email', 'file', 'copy'], 'garden': ['plant', 'water', 'rake', 'trim', 'weed', 'dig']}


In [6]:
sample_example(tokenizer)

Example(base={'input_ids': tensor([[   0,  714, 1136, 5605,  281, 1551,   15,  714, 1136,  588,  564,  281,
          253]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}, src={'input_ids': tensor([[   0,  714, 1136, 5605,  281, 4444,   15,  714, 1136,  588,  564,  281,
          253]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}, base_label=14098, src_label=10329)

In [7]:
def generate_n_doable_examples(n, model, tokenizer):
    examples = []
    iterator = tqdm(range(n))
    while len(examples) < n:
        ex = sample_example(tokenizer)
        for k, v in ex.base.items():
            if v is not None and isinstance(v, torch.Tensor):
                ex.base[k] = v.to(model.device)
        for k, v in ex.src.items():
            if v is not None and isinstance(v, torch.Tensor):
                ex.src[k] = v.to(model.device)
        logits_base = model(**ex.base).logits[0, -1]
        logits_src = model(**ex.src).logits[0, -1]

        if (
            logits_base[ex.base_label] > logits_base[ex.src_label]
            and logits_src[ex.src_label] > logits_src[ex.base_label]
        ):
            examples.append(ex)
            iterator.update(1)
    return examples

In [8]:
# make dataset
total_steps = 100
trainset = generate_n_doable_examples(total_steps, gpt, tokenizer)
evalset = generate_n_doable_examples(50, gpt, tokenizer)

100%|██████████| 100/100 [00:10<00:00,  9.97it/s]
100%|██████████| 50/50 [00:04<00:00, 11.84it/s]


## DAS

This is the usual 1D DAS setup, training on batch size of 1.

In [9]:
def intervention_config(intervention_site, layer, num_dims=1):
    config = pv.IntervenableConfig([
        {
            "layer": layer,
            "component": intervention_site,
            "intervention_type": pv.LowRankRotatedSpaceIntervention,
            "low_rank_dimension": num_dims,
        }
    ])
    return config

In [10]:
# loss function
loss_fct = torch.nn.CrossEntropyLoss()

def calculate_loss(logits, label):
    """Calculate cross entropy between logits and a single target label (can be batched)"""
    shift_labels = label.to(logits.device)
    loss = loss_fct(logits, shift_labels)
    return loss

In [11]:
# intervention settings
stats = []
num_layers = gpt.config.num_hidden_layers

# loop over layers and positions
for layer in range(num_layers):
    for position in range(13):
        print(f"layer: {layer}, position: {position}")

        # set up intervenable model
        config = intervention_config("block_output", layer, 1)
        intervenable = pv.IntervenableModel(config, gpt)
        intervenable.set_device(device)
        intervenable.disable_model_gradients()

        # set up optimizer
        optimizer_params = []
        for k, v in intervenable.interventions.items():
            try:
                optimizer_params.append({"params": v[0].rotate_layer.parameters()})
            except:
                pass
        optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps,
        )

        # training loop
        iterator = tqdm(trainset)
        for example in iterator:
            # forward pass
            _, counterfactual_outputs = intervenable(
                example.base,
                [example.src],
                {"sources->base": position},
            )

            # loss
            logits = counterfactual_outputs.logits[:, -1]
            loss = calculate_loss(logits, torch.tensor([example.src_label]).to(device))
            iterator.set_postfix({"loss": f"{loss.item():.3f}"})

            # backward
            loss.backward()
            optimizer.step()
            scheduler.step()

        # eval
        with torch.no_grad():
            iia = 0
            iterator = tqdm(evalset)
            for example in iterator:
                # forward
                _, counterfactual_outputs = intervenable(
                    example.base,
                    [example.src],
                    {"sources->base": position},
                )

                # calculate iia
                logits = counterfactual_outputs.logits[0, -1]
                if logits[example.src_label] > logits[example.base_label]:
                    iia += 1

            # stats
            iia = iia / len(evalset)
            stats.append({"layer": layer, "position": position, "iia": iia})
            print(f"iia: {iia:.3%}")
df = pd.DataFrame(stats)
df.to_csv(f"./tutorial_data/pyvene_gender_das.csv")

layer: 0, position: 0


100%|██████████| 100/100 [00:05<00:00, 19.73it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 34.90it/s]


iia: 0.000%
layer: 0, position: 1


100%|██████████| 100/100 [00:04<00:00, 20.52it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 34.27it/s]


iia: 0.000%
layer: 0, position: 2


100%|██████████| 100/100 [00:04<00:00, 20.73it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 34.04it/s]


iia: 0.000%
layer: 0, position: 3


100%|██████████| 100/100 [00:04<00:00, 20.29it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 35.18it/s]


iia: 0.000%
layer: 0, position: 4


100%|██████████| 100/100 [00:04<00:00, 20.67it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 34.30it/s]


iia: 0.000%
layer: 0, position: 5


100%|██████████| 100/100 [00:04<00:00, 20.77it/s, loss=5.005]
100%|██████████| 50/50 [00:01<00:00, 33.95it/s]


iia: 44.000%
layer: 0, position: 6


100%|██████████| 100/100 [00:04<00:00, 20.87it/s, loss=6.708]
100%|██████████| 50/50 [00:01<00:00, 32.68it/s]


iia: 0.000%
layer: 0, position: 7


100%|██████████| 100/100 [00:04<00:00, 20.26it/s, loss=6.665]
100%|██████████| 50/50 [00:01<00:00, 33.62it/s]


iia: 0.000%
layer: 0, position: 8


100%|██████████| 100/100 [00:04<00:00, 20.65it/s, loss=6.682]
100%|██████████| 50/50 [00:01<00:00, 35.63it/s]


iia: 0.000%
layer: 0, position: 9


100%|██████████| 100/100 [00:04<00:00, 20.21it/s, loss=6.648]
100%|██████████| 50/50 [00:01<00:00, 35.38it/s]


iia: 0.000%
layer: 0, position: 10


100%|██████████| 100/100 [00:04<00:00, 20.86it/s, loss=6.486]
100%|██████████| 50/50 [00:01<00:00, 35.42it/s]


iia: 4.000%
layer: 0, position: 11


100%|██████████| 100/100 [00:04<00:00, 20.58it/s, loss=6.556]
100%|██████████| 50/50 [00:01<00:00, 34.79it/s]


iia: 6.000%
layer: 0, position: 12


100%|██████████| 100/100 [00:04<00:00, 21.43it/s, loss=6.570]
100%|██████████| 50/50 [00:01<00:00, 34.67it/s]


iia: 6.000%
layer: 1, position: 0


100%|██████████| 100/100 [00:04<00:00, 21.64it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 35.54it/s]


iia: 0.000%
layer: 1, position: 1


100%|██████████| 100/100 [00:04<00:00, 22.08it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 31.96it/s]


iia: 0.000%
layer: 1, position: 2


100%|██████████| 100/100 [00:04<00:00, 20.85it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 33.61it/s]


iia: 0.000%
layer: 1, position: 3


100%|██████████| 100/100 [00:04<00:00, 21.03it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 34.93it/s]


iia: 0.000%
layer: 1, position: 4


100%|██████████| 100/100 [00:04<00:00, 21.87it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 36.85it/s]


iia: 0.000%
layer: 1, position: 5


100%|██████████| 100/100 [00:04<00:00, 21.19it/s, loss=6.688]
100%|██████████| 50/50 [00:01<00:00, 33.99it/s]


iia: 40.000%
layer: 1, position: 6


100%|██████████| 100/100 [00:04<00:00, 21.47it/s, loss=6.678]
100%|██████████| 50/50 [00:01<00:00, 33.46it/s]


iia: 0.000%
layer: 1, position: 7


100%|██████████| 100/100 [00:04<00:00, 21.03it/s, loss=6.698]
100%|██████████| 50/50 [00:01<00:00, 35.31it/s]


iia: 0.000%
layer: 1, position: 8


100%|██████████| 100/100 [00:04<00:00, 22.39it/s, loss=6.635]
100%|██████████| 50/50 [00:01<00:00, 35.61it/s]


iia: 2.000%
layer: 1, position: 9


100%|██████████| 100/100 [00:04<00:00, 22.19it/s, loss=6.664]
100%|██████████| 50/50 [00:01<00:00, 36.86it/s]


iia: 6.000%
layer: 1, position: 10


100%|██████████| 100/100 [00:04<00:00, 22.16it/s, loss=6.538]
100%|██████████| 50/50 [00:01<00:00, 35.94it/s]


iia: 6.000%
layer: 1, position: 11


100%|██████████| 100/100 [00:04<00:00, 22.34it/s, loss=6.325]
100%|██████████| 50/50 [00:01<00:00, 36.44it/s]


iia: 10.000%
layer: 1, position: 12


100%|██████████| 100/100 [00:04<00:00, 22.20it/s, loss=6.280]
100%|██████████| 50/50 [00:01<00:00, 35.37it/s]


iia: 8.000%
layer: 2, position: 0


100%|██████████| 100/100 [00:04<00:00, 20.71it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 44.26it/s]


iia: 0.000%
layer: 2, position: 1


100%|██████████| 100/100 [00:03<00:00, 31.08it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 54.61it/s]


iia: 0.000%
layer: 2, position: 2


100%|██████████| 100/100 [00:03<00:00, 28.42it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 52.74it/s]


iia: 0.000%
layer: 2, position: 3


100%|██████████| 100/100 [00:03<00:00, 29.51it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 48.57it/s]


iia: 0.000%
layer: 2, position: 4


100%|██████████| 100/100 [00:03<00:00, 29.81it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 44.26it/s]


iia: 0.000%
layer: 2, position: 5


100%|██████████| 100/100 [00:03<00:00, 29.87it/s, loss=5.436]
100%|██████████| 50/50 [00:00<00:00, 51.32it/s]


iia: 40.000%
layer: 2, position: 6


100%|██████████| 100/100 [00:03<00:00, 30.87it/s, loss=6.696]
100%|██████████| 50/50 [00:00<00:00, 54.13it/s]


iia: 0.000%
layer: 2, position: 7


100%|██████████| 100/100 [00:03<00:00, 29.43it/s, loss=6.674]
100%|██████████| 50/50 [00:01<00:00, 47.89it/s]


iia: 6.000%
layer: 2, position: 8


100%|██████████| 100/100 [00:03<00:00, 31.25it/s, loss=6.664]
100%|██████████| 50/50 [00:00<00:00, 54.24it/s]


iia: 6.000%
layer: 2, position: 9


100%|██████████| 100/100 [00:03<00:00, 32.21it/s, loss=6.657]
100%|██████████| 50/50 [00:01<00:00, 49.38it/s]


iia: 6.000%
layer: 2, position: 10


100%|██████████| 100/100 [00:03<00:00, 30.65it/s, loss=6.559]
100%|██████████| 50/50 [00:00<00:00, 55.50it/s]


iia: 6.000%
layer: 2, position: 11


100%|██████████| 100/100 [00:03<00:00, 30.74it/s, loss=6.076]
100%|██████████| 50/50 [00:00<00:00, 51.94it/s]


iia: 14.000%
layer: 2, position: 12


100%|██████████| 100/100 [00:03<00:00, 30.45it/s, loss=6.090]
100%|██████████| 50/50 [00:01<00:00, 46.94it/s]


iia: 20.000%
layer: 3, position: 0


100%|██████████| 100/100 [00:03<00:00, 32.19it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 44.58it/s]


iia: 0.000%
layer: 3, position: 1


100%|██████████| 100/100 [00:03<00:00, 31.79it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 47.19it/s]


iia: 0.000%
layer: 3, position: 2


100%|██████████| 100/100 [00:03<00:00, 31.77it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 51.99it/s]


iia: 0.000%
layer: 3, position: 3


100%|██████████| 100/100 [00:03<00:00, 31.89it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 42.27it/s]


iia: 0.000%
layer: 3, position: 4


100%|██████████| 100/100 [00:03<00:00, 32.71it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 45.63it/s]


iia: 0.000%
layer: 3, position: 5


100%|██████████| 100/100 [00:03<00:00, 30.97it/s, loss=6.605]
100%|██████████| 50/50 [00:00<00:00, 50.03it/s]


iia: 12.000%
layer: 3, position: 6


100%|██████████| 100/100 [00:03<00:00, 31.66it/s, loss=6.678]
100%|██████████| 50/50 [00:00<00:00, 50.30it/s]


iia: 0.000%
layer: 3, position: 7


100%|██████████| 100/100 [00:03<00:00, 32.24it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 46.67it/s]


iia: 0.000%
layer: 3, position: 8


100%|██████████| 100/100 [00:03<00:00, 32.21it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 49.45it/s]


iia: 0.000%
layer: 3, position: 9


100%|██████████| 100/100 [00:03<00:00, 30.80it/s, loss=6.652]
100%|██████████| 50/50 [00:01<00:00, 45.50it/s]


iia: 0.000%
layer: 3, position: 10


100%|██████████| 100/100 [00:03<00:00, 31.74it/s, loss=6.678]
100%|██████████| 50/50 [00:01<00:00, 44.16it/s]


iia: 6.000%
layer: 3, position: 11


100%|██████████| 100/100 [00:03<00:00, 31.30it/s, loss=6.471]
100%|██████████| 50/50 [00:00<00:00, 50.04it/s]


iia: 8.000%
layer: 3, position: 12


100%|██████████| 100/100 [00:03<00:00, 31.61it/s, loss=4.604]
100%|██████████| 50/50 [00:01<00:00, 44.61it/s]


iia: 46.000%
layer: 4, position: 0


100%|██████████| 100/100 [00:03<00:00, 32.80it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 52.59it/s]


iia: 0.000%
layer: 4, position: 1


100%|██████████| 100/100 [00:02<00:00, 34.10it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 48.74it/s]


iia: 0.000%
layer: 4, position: 2


100%|██████████| 100/100 [00:02<00:00, 33.85it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 46.02it/s]


iia: 0.000%
layer: 4, position: 3


100%|██████████| 100/100 [00:03<00:00, 32.25it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 47.29it/s]


iia: 0.000%
layer: 4, position: 4


100%|██████████| 100/100 [00:02<00:00, 34.64it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 44.26it/s]


iia: 0.000%
layer: 4, position: 5


100%|██████████| 100/100 [00:02<00:00, 34.66it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 53.91it/s]


iia: 0.000%
layer: 4, position: 6


100%|██████████| 100/100 [00:02<00:00, 33.78it/s, loss=6.674]
100%|██████████| 50/50 [00:01<00:00, 47.84it/s]


iia: 0.000%
layer: 4, position: 7


100%|██████████| 100/100 [00:02<00:00, 34.20it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 49.57it/s]


iia: 0.000%
layer: 4, position: 8


100%|██████████| 100/100 [00:02<00:00, 33.44it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 43.12it/s]


iia: 0.000%
layer: 4, position: 9


100%|██████████| 100/100 [00:03<00:00, 32.22it/s, loss=6.674]
100%|██████████| 50/50 [00:01<00:00, 48.89it/s]


iia: 0.000%
layer: 4, position: 10


100%|██████████| 100/100 [00:02<00:00, 34.23it/s, loss=6.622]
100%|██████████| 50/50 [00:01<00:00, 45.96it/s]


iia: 6.000%
layer: 4, position: 11


100%|██████████| 100/100 [00:02<00:00, 34.27it/s, loss=6.753]
100%|██████████| 50/50 [00:01<00:00, 47.68it/s]


iia: 8.000%
layer: 4, position: 12


100%|██████████| 100/100 [00:03<00:00, 33.23it/s, loss=4.175]
100%|██████████| 50/50 [00:01<00:00, 48.33it/s]


iia: 64.000%
layer: 5, position: 0


100%|██████████| 100/100 [00:02<00:00, 36.58it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 44.72it/s]


iia: 0.000%
layer: 5, position: 1


100%|██████████| 100/100 [00:02<00:00, 34.47it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 52.33it/s]


iia: 0.000%
layer: 5, position: 2


100%|██████████| 100/100 [00:02<00:00, 34.25it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 46.44it/s]


iia: 0.000%
layer: 5, position: 3


100%|██████████| 100/100 [00:02<00:00, 36.11it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 45.79it/s]


iia: 0.000%
layer: 5, position: 4


100%|██████████| 100/100 [00:02<00:00, 37.33it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 48.33it/s]


iia: 0.000%
layer: 5, position: 5


100%|██████████| 100/100 [00:02<00:00, 36.30it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 45.27it/s]


iia: 0.000%
layer: 5, position: 6


100%|██████████| 100/100 [00:02<00:00, 35.07it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 50.57it/s]


iia: 0.000%
layer: 5, position: 7


100%|██████████| 100/100 [00:02<00:00, 35.08it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 48.21it/s]


iia: 0.000%
layer: 5, position: 8


100%|██████████| 100/100 [00:02<00:00, 36.09it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 47.82it/s]


iia: 0.000%
layer: 5, position: 9


100%|██████████| 100/100 [00:02<00:00, 35.99it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 44.15it/s]


iia: 0.000%
layer: 5, position: 10


100%|██████████| 100/100 [00:02<00:00, 35.63it/s, loss=6.677]
100%|██████████| 50/50 [00:00<00:00, 50.00it/s]


iia: 0.000%
layer: 5, position: 11


100%|██████████| 100/100 [00:02<00:00, 36.72it/s, loss=6.677]
100%|██████████| 50/50 [00:01<00:00, 48.76it/s]


iia: 0.000%
layer: 5, position: 12


100%|██████████| 100/100 [00:02<00:00, 36.00it/s, loss=5.999]
100%|██████████| 50/50 [00:01<00:00, 48.04it/s]

iia: 56.000%





And this is the plot of IIA. In layers 2 and 3 it seems the gender is represented across positions 1-3, and entirely in position 3 in later layers.

In [87]:
sentence = "<|endoftext|> Yann wants to eat. Yann will go to the"
tokenized_sentence = tokenizer(sentence, return_tensors="pt")
print(tokenized_sentence)

word = "wants"
tokenized_word = tokenizer(word, return_tensors="pt")
print(tokenized_word)

{'input_ids': tensor([[   0,  714, 1136, 5605,  281, 6008,   15,  714, 1136,  588,  564,  281,
          253]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
{'input_ids': tensor([[  88, 1103]]), 'attention_mask': tensor([[1, 1]])}


In [15]:
tokenizer.tokenize("<|endoftext|> Yann wants to eat. Yann will go to the")

['<|endoftext|>',
 'ĠY',
 'ann',
 'Ġwants',
 'Ġto',
 'Ġeat',
 '.',
 'ĠY',
 'ann',
 'Ġwill',
 'Ġgo',
 'Ġto',
 'Ġthe']

In [17]:
df["layer"] = df["layer"].astype(int)
df["pos"] = df["position"].astype(int)
df["IIA"] = df["iia"].astype(float)

custom_labels = [ "EOS", "Y", "ann", "wants", "to", "<verb>", ".", "Y", "ann", "will", "go", "to", "the"]
breaks = range(len(custom_labels))

plot = (
    ggplot(df, aes(x="layer", y="pos"))
    + geom_tile(aes(fill="IIA"))
    + scale_fill_cmap("Purples") + xlab("layers")
    + scale_y_reverse(
        limits = (-0.5, 12.5),
        breaks=breaks, labels=custom_labels)
    + theme(figure_size=(15, 10)) + ylab("")
    + theme(axis_text_y  = element_text(angle = 90, hjust = 1))
    + ggtitle("Trained Intervention (DAS)")
)
ggsave(
    plot, filename=f"./tutorial_data/pyvene_gender_das.pdf", dpi=200
)
print(plot)

   layer  position  iia  pos  IIA
0      0         0  0.0    0  0.0
1      0         1  0.0    1  0.0
2      0         2  0.0    2  0.0
3      0         3  0.0    3  0.0
4      0         4  0.0    4  0.0
<ggplot: (1500 x 1000)>




## Probing

We'll define a dummy intervention `CollectActivation` to collect activations and train a simple probe.

In [18]:
def probing_config(intervention_site, layer):
    """Generate intervention config."""

    # init
    config = pv.IntervenableConfig([{
        "layer": layer,
        "component": intervention_site,
        "intervention_type": pv.CollectIntervention,
    }])
    return config

This is the training loop.

In [19]:
label_mapping = {10329:0, 8576:1, 15336:3, 3906:4, 14098:5}

print(label_mapping)


{10329: 0, 8576: 1, 15336: 3, 3906: 4, 14098: 5}


In [20]:
decoded_word = tokenizer.decode([3811])
print(decoded_word)

 living


In [21]:
# intervention settings
stats = []
num_layers = gpt.config.num_hidden_layers


# loop over layers and positions
with torch.no_grad():
    for layer in range(num_layers):
        for position in range(13):
            print(f"layer: {layer}, position: {position}")

            # set up intervenable model
            config = probing_config("block_output", layer)
            intervenable = pv.IntervenableModel(config, gpt)
            intervenable.set_device(device)
            intervenable.disable_model_gradients()

            # training loop
            activations, labels = [], []
            iterator = tqdm(trainset)
            for example in iterator:
                # forward pass
                base_outputs, _ = intervenable(
                    example.base,
                    unit_locations={"base": position},
                )
                base_activations = base_outputs[1][0]

                src_outputs, _ = intervenable(
                    example.src,
                    unit_locations={"base": position},
                )
                src_activations = src_outputs[1][0]

                # collect activation
                activations.extend(
                    [base_activations.detach()[0].cpu().numpy(), src_activations.detach()[0].cpu().numpy()]
                )
                labels.extend([example.base_label, example.src_label])
            labels = [label_mapping[label] for label in labels]

            # train logistic regression
            lr = LogisticRegression(random_state=42, max_iter=1000).fit(
                activations, labels
            )

            # eval
            activations, labels = [], []
            iterator = tqdm(evalset)
            for example in iterator:
                # forward pass
                base_outputs, _ = intervenable(
                    example.base,
                    unit_locations={"base": position},
                )
                base_activations = base_outputs[1][0]

                src_outputs, _ = intervenable(
                    example.src,
                    unit_locations={"base": position},
                )
                src_activations = src_outputs[1][0]

                # collect activation
                activations.extend(
                    [base_activations.detach()[0].cpu().numpy(), src_activations.detach()[0].cpu().numpy()]
                )
                labels.extend([example.base_label, example.src_label])
            labels = [label_mapping[label] for label in labels]

            # stats
            acc = lr.score(activations, labels)
            f1 = f1_score(labels, lr.predict(activations), average='macro')
            stats.append({"layer": layer, "position": position, "acc": acc, "f1": f1})
            print(f"acc: {acc:.3%}, f1: {f1:.3f}")
df = pd.DataFrame(stats)
df.to_csv(f"./tutorial_data/pyvene_gender_probe.csv")

layer: 0, position: 0


100%|██████████| 100/100 [00:01<00:00, 51.39it/s]
100%|██████████| 50/50 [00:00<00:00, 54.00it/s]


acc: 31.000%, f1: 0.095
layer: 0, position: 1


100%|██████████| 100/100 [00:01<00:00, 53.68it/s]
100%|██████████| 50/50 [00:00<00:00, 52.69it/s]


acc: 31.000%, f1: 0.095
layer: 0, position: 2


100%|██████████| 100/100 [00:01<00:00, 56.32it/s]
100%|██████████| 50/50 [00:00<00:00, 55.57it/s]


acc: 31.000%, f1: 0.095
layer: 0, position: 3


100%|██████████| 100/100 [00:01<00:00, 53.21it/s]
100%|██████████| 50/50 [00:00<00:00, 56.52it/s]


acc: 31.000%, f1: 0.095
layer: 0, position: 4


100%|██████████| 100/100 [00:01<00:00, 54.53it/s]
100%|██████████| 50/50 [00:00<00:00, 51.17it/s]


acc: 31.000%, f1: 0.095
layer: 0, position: 5


100%|██████████| 100/100 [00:01<00:00, 54.41it/s]
100%|██████████| 50/50 [00:00<00:00, 53.31it/s]


acc: 100.000%, f1: 1.000
layer: 0, position: 6


100%|██████████| 100/100 [00:01<00:00, 53.82it/s]
100%|██████████| 50/50 [00:00<00:00, 55.92it/s]


acc: 98.000%, f1: 0.793
layer: 0, position: 7


100%|██████████| 100/100 [00:01<00:00, 52.36it/s]
100%|██████████| 50/50 [00:01<00:00, 48.86it/s]


acc: 94.000%, f1: 0.883
layer: 0, position: 8


100%|██████████| 100/100 [00:01<00:00, 56.62it/s]
100%|██████████| 50/50 [00:00<00:00, 59.87it/s]


acc: 43.000%, f1: 0.272
layer: 0, position: 9


100%|██████████| 100/100 [00:01<00:00, 55.35it/s]
100%|██████████| 50/50 [00:00<00:00, 53.39it/s]


acc: 91.000%, f1: 0.613
layer: 0, position: 10


100%|██████████| 100/100 [00:01<00:00, 54.83it/s]
100%|██████████| 50/50 [00:00<00:00, 57.84it/s]


acc: 87.000%, f1: 0.584
layer: 0, position: 11


100%|██████████| 100/100 [00:01<00:00, 52.06it/s]
100%|██████████| 50/50 [00:00<00:00, 52.08it/s]


acc: 83.000%, f1: 0.561
layer: 0, position: 12


100%|██████████| 100/100 [00:01<00:00, 58.22it/s]
100%|██████████| 50/50 [00:00<00:00, 55.78it/s]


acc: 83.000%, f1: 0.561
layer: 1, position: 0


100%|██████████| 100/100 [00:01<00:00, 54.41it/s]
100%|██████████| 50/50 [00:00<00:00, 60.00it/s]


acc: 31.000%, f1: 0.095
layer: 1, position: 1


100%|██████████| 100/100 [00:01<00:00, 57.08it/s]
100%|██████████| 50/50 [00:00<00:00, 51.77it/s]


acc: 31.000%, f1: 0.095
layer: 1, position: 2


100%|██████████| 100/100 [00:01<00:00, 53.83it/s]
100%|██████████| 50/50 [00:00<00:00, 60.68it/s]


acc: 31.000%, f1: 0.095
layer: 1, position: 3


100%|██████████| 100/100 [00:01<00:00, 53.09it/s]
100%|██████████| 50/50 [00:00<00:00, 62.69it/s]


acc: 31.000%, f1: 0.095
layer: 1, position: 4


100%|██████████| 100/100 [00:01<00:00, 55.44it/s]
100%|██████████| 50/50 [00:00<00:00, 53.42it/s]


acc: 31.000%, f1: 0.095
layer: 1, position: 5


100%|██████████| 100/100 [00:01<00:00, 54.52it/s]
100%|██████████| 50/50 [00:00<00:00, 52.30it/s]


acc: 100.000%, f1: 1.000
layer: 1, position: 6


100%|██████████| 100/100 [00:01<00:00, 53.29it/s]
100%|██████████| 50/50 [00:00<00:00, 53.96it/s]


acc: 100.000%, f1: 1.000
layer: 1, position: 7


100%|██████████| 100/100 [00:01<00:00, 56.97it/s]
100%|██████████| 50/50 [00:00<00:00, 55.31it/s]


acc: 100.000%, f1: 1.000
layer: 1, position: 8


100%|██████████| 100/100 [00:01<00:00, 56.66it/s]
100%|██████████| 50/50 [00:00<00:00, 53.48it/s]


acc: 93.000%, f1: 0.746
layer: 1, position: 9


100%|██████████| 100/100 [00:01<00:00, 53.67it/s]
100%|██████████| 50/50 [00:00<00:00, 58.84it/s]


acc: 99.000%, f1: 0.983
layer: 1, position: 10


100%|██████████| 100/100 [00:01<00:00, 54.66it/s]
100%|██████████| 50/50 [00:00<00:00, 67.62it/s]


acc: 99.000%, f1: 0.983
layer: 1, position: 11


100%|██████████| 100/100 [00:01<00:00, 67.00it/s]
100%|██████████| 50/50 [00:00<00:00, 52.01it/s]


acc: 100.000%, f1: 1.000
layer: 1, position: 12


100%|██████████| 100/100 [00:01<00:00, 55.25it/s]
100%|██████████| 50/50 [00:00<00:00, 57.42it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 0


100%|██████████| 100/100 [00:01<00:00, 51.10it/s]
100%|██████████| 50/50 [00:00<00:00, 59.51it/s]


acc: 31.000%, f1: 0.095
layer: 2, position: 1


100%|██████████| 100/100 [00:01<00:00, 57.96it/s]
100%|██████████| 50/50 [00:00<00:00, 52.30it/s]


acc: 31.000%, f1: 0.095
layer: 2, position: 2


100%|██████████| 100/100 [00:01<00:00, 50.98it/s]
100%|██████████| 50/50 [00:00<00:00, 54.82it/s]


acc: 31.000%, f1: 0.095
layer: 2, position: 3


100%|██████████| 100/100 [00:01<00:00, 53.37it/s]
100%|██████████| 50/50 [00:00<00:00, 58.93it/s]


acc: 31.000%, f1: 0.095
layer: 2, position: 4


100%|██████████| 100/100 [00:01<00:00, 59.99it/s]
100%|██████████| 50/50 [00:00<00:00, 54.90it/s]


acc: 31.000%, f1: 0.095
layer: 2, position: 5


100%|██████████| 100/100 [00:01<00:00, 51.62it/s]
100%|██████████| 50/50 [00:00<00:00, 51.84it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 6


100%|██████████| 100/100 [00:01<00:00, 62.21it/s]
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
100%|██████████| 50/50 [00:00<00:00, 59.09it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 7


100%|██████████| 100/100 [00:01<00:00, 54.34it/s]
100%|██████████| 50/50 [00:00<00:00, 56.48it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 8


100%|██████████| 100/100 [00:01<00:00, 52.87it/s]
100%|██████████| 50/50 [00:00<00:00, 58.35it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 9


100%|██████████| 100/100 [00:01<00:00, 53.17it/s]
100%|██████████| 50/50 [00:00<00:00, 68.34it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 10


100%|██████████| 100/100 [00:01<00:00, 68.76it/s]
100%|██████████| 50/50 [00:00<00:00, 54.55it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 11


100%|██████████| 100/100 [00:01<00:00, 56.61it/s]
100%|██████████| 50/50 [00:00<00:00, 56.03it/s]


acc: 100.000%, f1: 1.000
layer: 2, position: 12


100%|██████████| 100/100 [00:01<00:00, 57.83it/s]
100%|██████████| 50/50 [00:00<00:00, 51.23it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 0


100%|██████████| 100/100 [00:01<00:00, 51.61it/s]
100%|██████████| 50/50 [00:00<00:00, 54.85it/s]


acc: 31.000%, f1: 0.095
layer: 3, position: 1


100%|██████████| 100/100 [00:01<00:00, 51.77it/s]
100%|██████████| 50/50 [00:00<00:00, 58.15it/s]


acc: 31.000%, f1: 0.095
layer: 3, position: 2


100%|██████████| 100/100 [00:01<00:00, 54.61it/s]
100%|██████████| 50/50 [00:00<00:00, 51.36it/s]


acc: 31.000%, f1: 0.095
layer: 3, position: 3


100%|██████████| 100/100 [00:01<00:00, 55.90it/s]
100%|██████████| 50/50 [00:00<00:00, 53.59it/s]


acc: 31.000%, f1: 0.095
layer: 3, position: 4


100%|██████████| 100/100 [00:01<00:00, 51.42it/s]
100%|██████████| 50/50 [00:00<00:00, 52.02it/s]


acc: 31.000%, f1: 0.095
layer: 3, position: 5


100%|██████████| 100/100 [00:01<00:00, 51.03it/s]
100%|██████████| 50/50 [00:00<00:00, 50.08it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 6


100%|██████████| 100/100 [00:01<00:00, 54.00it/s]
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
100%|██████████| 50/50 [00:00<00:00, 51.43it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 7


100%|██████████| 100/100 [00:01<00:00, 53.96it/s]
100%|██████████| 50/50 [00:00<00:00, 55.26it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 8


100%|██████████| 100/100 [00:02<00:00, 49.14it/s]
100%|██████████| 50/50 [00:00<00:00, 61.98it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 9


100%|██████████| 100/100 [00:01<00:00, 55.35it/s]
100%|██████████| 50/50 [00:00<00:00, 55.18it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 10


100%|██████████| 100/100 [00:01<00:00, 50.51it/s]
100%|██████████| 50/50 [00:00<00:00, 54.77it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 11


100%|██████████| 100/100 [00:01<00:00, 57.09it/s]
100%|██████████| 50/50 [00:00<00:00, 52.07it/s]


acc: 100.000%, f1: 1.000
layer: 3, position: 12


100%|██████████| 100/100 [00:01<00:00, 54.98it/s]
100%|██████████| 50/50 [00:00<00:00, 54.13it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 0


100%|██████████| 100/100 [00:01<00:00, 53.08it/s]
100%|██████████| 50/50 [00:00<00:00, 50.52it/s]


acc: 31.000%, f1: 0.095
layer: 4, position: 1


100%|██████████| 100/100 [00:01<00:00, 53.90it/s]
100%|██████████| 50/50 [00:00<00:00, 56.77it/s]


acc: 31.000%, f1: 0.095
layer: 4, position: 2


100%|██████████| 100/100 [00:01<00:00, 53.96it/s]
100%|██████████| 50/50 [00:00<00:00, 50.64it/s]


acc: 31.000%, f1: 0.095
layer: 4, position: 3


100%|██████████| 100/100 [00:01<00:00, 53.75it/s]
100%|██████████| 50/50 [00:00<00:00, 59.92it/s]


acc: 31.000%, f1: 0.095
layer: 4, position: 4


100%|██████████| 100/100 [00:01<00:00, 61.43it/s]
100%|██████████| 50/50 [00:00<00:00, 58.26it/s]


acc: 31.000%, f1: 0.095
layer: 4, position: 5


100%|██████████| 100/100 [00:01<00:00, 60.07it/s]
100%|██████████| 50/50 [00:00<00:00, 58.86it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 6


100%|██████████| 100/100 [00:01<00:00, 52.66it/s]
100%|██████████| 50/50 [00:00<00:00, 52.50it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 7


100%|██████████| 100/100 [00:01<00:00, 52.10it/s]
100%|██████████| 50/50 [00:00<00:00, 50.77it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 8


100%|██████████| 100/100 [00:01<00:00, 53.28it/s]
100%|██████████| 50/50 [00:00<00:00, 51.00it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 9


100%|██████████| 100/100 [00:01<00:00, 52.55it/s]
100%|██████████| 50/50 [00:00<00:00, 53.34it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 10


100%|██████████| 100/100 [00:01<00:00, 54.63it/s]
100%|██████████| 50/50 [00:00<00:00, 63.47it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 11


100%|██████████| 100/100 [00:01<00:00, 53.53it/s]
100%|██████████| 50/50 [00:01<00:00, 47.57it/s]


acc: 100.000%, f1: 1.000
layer: 4, position: 12


100%|██████████| 100/100 [00:01<00:00, 52.25it/s]
100%|██████████| 50/50 [00:00<00:00, 57.09it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 0


100%|██████████| 100/100 [00:01<00:00, 54.77it/s]
100%|██████████| 50/50 [00:00<00:00, 58.79it/s]


acc: 31.000%, f1: 0.095
layer: 5, position: 1


100%|██████████| 100/100 [00:01<00:00, 55.84it/s]
100%|██████████| 50/50 [00:00<00:00, 51.62it/s]


acc: 31.000%, f1: 0.095
layer: 5, position: 2


100%|██████████| 100/100 [00:01<00:00, 56.22it/s]
100%|██████████| 50/50 [00:00<00:00, 53.78it/s]


acc: 31.000%, f1: 0.095
layer: 5, position: 3


100%|██████████| 100/100 [00:01<00:00, 54.40it/s]
100%|██████████| 50/50 [00:00<00:00, 50.48it/s]


acc: 31.000%, f1: 0.095
layer: 5, position: 4


100%|██████████| 100/100 [00:01<00:00, 60.72it/s]
100%|██████████| 50/50 [00:00<00:00, 57.06it/s]


acc: 31.000%, f1: 0.095
layer: 5, position: 5


100%|██████████| 100/100 [00:01<00:00, 56.45it/s]
100%|██████████| 50/50 [00:01<00:00, 49.29it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 6


100%|██████████| 100/100 [00:01<00:00, 54.54it/s]
100%|██████████| 50/50 [00:00<00:00, 58.10it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 7


100%|██████████| 100/100 [00:01<00:00, 53.62it/s]
100%|██████████| 50/50 [00:00<00:00, 58.89it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 8


100%|██████████| 100/100 [00:01<00:00, 53.61it/s]
100%|██████████| 50/50 [00:01<00:00, 49.05it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 9


100%|██████████| 100/100 [00:01<00:00, 53.90it/s]
100%|██████████| 50/50 [00:00<00:00, 60.41it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 10


100%|██████████| 100/100 [00:01<00:00, 56.26it/s]
100%|██████████| 50/50 [00:00<00:00, 56.10it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 11


100%|██████████| 100/100 [00:01<00:00, 55.61it/s]
100%|██████████| 50/50 [00:00<00:00, 57.90it/s]


acc: 100.000%, f1: 1.000
layer: 5, position: 12


100%|██████████| 100/100 [00:01<00:00, 53.99it/s]
100%|██████████| 50/50 [00:00<00:00, 50.81it/s]

acc: 100.000%, f1: 1.000





And the probe accuracy plot is below. Note the extremely high accuracy at all positions at and after the name! Early layers at later positions are better but it saturates much before the IIA for DAS. This shows how unreliable probes are for tracing causal effect.

In [23]:
df = pd.read_csv(f"./tutorial_data/pyvene_gender_probe.csv")
df["layer"] = df["layer"].astype(int)
df["pos"] = df["position"].astype(int)
df["ACC"] = df["acc"].astype(float)

custom_labels = [ "EOS", "Y", "ann", "wants", "to", "<verb>", ".", "Y", "ann", "will", "go", "to", "the"]
breaks = range(len(custom_labels))

plot = (
    ggplot(df, aes(x="layer", y="pos", fill="ACC"))
    + geom_tile()
    + scale_fill_cmap("Reds") + xlab("layers")
    + scale_y_reverse(
        limits = (-0.5, 12.5),
        breaks=breaks, labels=custom_labels)
    + theme(figure_size=(15, 10)) + ylab("")
    + theme(axis_text_y  = element_text(angle = 90, hjust = 1))
    + ggtitle("Trained Linear Probe")
)
ggsave(
    plot, filename=f"./tutorial_data/pyvene_gender_probe.pdf", dpi=200
)
print(plot)




<ggplot: (1500 x 1000)>


_______________________________________

In [None]:
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset

# Załadowanie zbioru danych
dataset = load_dataset("tasksource/babi_nli", "agents-motivations")
train_data = pd.DataFrame(dataset['train'])
test_data = pd.DataFrame(dataset['test'])

# Filtrowanie danych tylko z etykietą "0" (not-entailed)
filtered_train_data = train_data[train_data['label'] == 0]
filtered_test_data = test_data[test_data['label'] == 0]

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
tokenizer.pad_token = tokenizer.eos_token

# Funkcja do ekstrakcji przykładów
def extract_room_examples(data):
    examples = []
    rooms = ["kitchen", "garden", "bedroom"]  # Lista dostępnych pomieszczeń
    
    for _, row in data.iterrows():
        context = row['premise']  # Kontekst
        hypothesis = row['hypothesis']  # Hipoteza
        
        # Jeśli hipoteza zawiera pomieszczenie
        if any(room in hypothesis for room in rooms):
            room = [room for room in rooms if room in hypothesis][0]  # Wybór pomieszczenia z hipotezy
            # Sprawdzamy, czy kontekst nie zawiera pomieszczenia
            if not any(room in context for room in rooms):
                # Przygotowujemy przykład
                example = (context + " " + hypothesis.split("to the")[0] + "to the ", room)  # Dodajemy tylko do miejsca "to the"
                examples.append(example)
    
    return examples

# Przygotowanie przykładów
train_examples = extract_room_examples(filtered_train_data)
test_examples = extract_room_examples(filtered_test_data)

# Sprawdzamy przykłady
print("Przykłady do treningu:")
for example in train_examples[:5]:  # Wyświetlamy tylko pierwsze 5 przykładów
    print(f"Kontekst: {example[0]}")
    print(f"Pomieszczenie: {example[1]}")
    print("-" * 50)

print("\nPrzykłady do testu:")
for example in test_examples[:5]:  # Wyświetlamy tylko pierwsze 5 przykładów
    print(f"Kontekst: {example[0]}")
    print(f"Pomieszczenie: {example[1]}")
    print("-" * 50)

print(len(train_examples), len(test_examples))


In [None]:
def get_hidden_representations(model, tokenizer, examples):
    hidden_representations = []
    
    # Tokenizujemy i przechodzimy przez dane
    for context, room in examples:
        # Tokenizacja kontekstu i pomieszczenia
        inputs = tokenizer(context + " " + room, return_tensors="pt", padding=True, truncation=True)
        
        # Uzyskanie ukrytych reprezentacji
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states  # Uzyskiwanie wszystkich warstw
            
        # Zbieramy reprezentacje dla każdego tokena w każdej warstwie
        for layer_idx, hidden_state in enumerate(hidden_states):  # Iterowanie przez warstwy
            for pos, token_repr in enumerate(hidden_state[0]):  # Iterowanie przez tokeny
                hidden_representations.append({
                    "layer": layer_idx,  # Numer warstwy
                    "position": pos,     # Pozycja tokenu
                    "IIA": token_repr.mean().item(),  # Reprezentacja ukryta (średnia wartość dla tokenu)
                    "room": room  # Dodajemy nazwę pomieszczenia do danych
                })
    
    return hidden_representations

# Załaduj model
model = AutoModel.from_pretrained("EleutherAI/pythia-410m")

# Zbieranie reprezentacji ukrytych
hidden_representations = get_hidden_representations(model, tokenizer, train_examples)

# Tworzymy DataFrame z reprezentacjami ukrytymi
df = pd.DataFrame(hidden_representations)

# Przygotowanie wykresu
df["layer"] = df["layer"].astype(int)
df["pos"] = df["position"].astype(int)
df["IIA"] = df["IIA"].astype(float)

# Przygotowanie unikalnych tokenów w kontekście (przykłady do analizy)
tokens = sorted(df['room'].unique())  # Unikalne pomieszczenia w posortowanej kolejności
custom_labels = tokens  # Etykiety osi Y
breaks = list(range(len(tokens)))  # Odpowiednia liczba przerw na osi Y

# Tworzenie wykresu
plot = (
    ggplot(df, aes(x="layer", y="room"))  # Używamy `room` jako osi Y
    + geom_tile(aes(fill="IIA"))  # Kolorowanie na podstawie wartości IIA
    + scale_fill_cmap("Purples")  # Skala kolorów
    + xlab("layers")  # Etykieta osi X
    + scale_y_discrete(  # Używamy skali dyskretnej dla osi Y, ponieważ mamy tokeny
        breaks=tokens,  # Przerwy odpowiadające tokenom
        labels=custom_labels  # Etykiety odpowiadające tokenom
    )
    + theme(figure_size=(5, 3))  # Rozmiar wykresu
    + ylab("")  # Bez etykiety osi Y
    + theme(axis_text_y=element_text(angle=90, hjust=1))  # Obrót etykiet osi Y
    + ggtitle("Trained Intervention (DAS)")  # Tytuł wykresu
)

# Zapisanie wykresu do pliku
ggsave(plot, filename="./tutorial_data/pyvene_room_das.pdf", dpi=200)

# Wyświetlenie wykresu
print(plot)

_____________________