In [1]:
# %%
%load_ext autoreload
%autoreload 2

import torch
from cot_probing.attn_probes_case_studies import *
from cot_probing.attn_probes_data_proc import CollateFnOutput
from cot_probing.utils import load_model_and_tokenizer
from cot_probing.activations import build_fsp_cache
from ipywidgets import Dropdown, interactive_output, VBox



In [2]:
from cot_probing.utils import fetch_runs
from cot_probing.attn_probes import AttnProbeTrainer
import wandb
from cot_probing import DATA_DIR
import pickle


def load_median_probe_test_data(
    probe_class: str,
    layer: int,
    context: str,
    min_seed: int,
    max_seed: int,
    metric: str,
) -> tuple[AttnProbeTrainer, list[int], list[dict], str]:
    """Load the probe run with the median ``metric`` across seeds, plus its test data.

    Runs are fetched from W&B for a single layer and ranked by the value stored
    under ``metric`` in each run's summary. The run at position ``n // 2`` of
    that ranking is used (the upper of the two middle runs when ``n`` is even).

    Args:
        probe_class: Forwarded to ``fetch_runs`` as ``probe_class``.
        layer: Layer to fetch runs for; also selects the activations file.
        context: Forwarded to ``fetch_runs``; also part of the activations
            file name.
        min_seed: Forwarded to ``fetch_runs`` as ``min_seed``.
        max_seed: Forwarded to ``fetch_runs`` as ``max_seed``.
        metric: Key of the run summary used to rank the runs.

    Returns:
        ``(trainer, test_idxs, raw_acts_qs, unbiased_fsp_str)``: the trainer
        and test indices returned by ``AttnProbeTrainer.from_wandb``, the
        entries of ``raw_acts_dataset["qs"]`` at those indices, and
        ``raw_acts_dataset["unbiased_fsp"]``.

    Raises:
        ValueError: If no runs were fetched for ``layer``, or if any fetched
            run has no value for ``metric`` in its summary.
    """
    runs_by_seed_by_layer = fetch_runs(
        api=wandb.Api(),
        probe_class=probe_class,
        min_layer=layer,
        max_layer=layer,
        min_seed=min_seed,
        max_seed=max_seed,
        context=context,
    )
    assert len(runs_by_seed_by_layer) == 1
    runs_by_seed = runs_by_seed_by_layer[layer]
    if not runs_by_seed:
        raise ValueError(f"No runs found for layer {layer}; cannot pick a median run.")

    # Read every metric once and fail loudly when one is absent: sorting a mix
    # of numbers and None would otherwise die with an opaque TypeError.
    metric_by_seed = {
        seed: run.summary.get(metric) for seed, run in runs_by_seed.items()
    }
    seeds_missing_metric = [
        seed for seed, value in metric_by_seed.items() if value is None
    ]
    if seeds_missing_metric:
        raise ValueError(
            f"Runs for seeds {seeds_missing_metric} have no {metric!r} in their summary."
        )

    seed_run_sorted = sorted(
        runs_by_seed.items(), key=lambda s_r: metric_by_seed[s_r[0]]
    )

    # Upper median when the number of runs is even.
    _median_seed, median_run = seed_run_sorted[len(seed_run_sorted) // 2]
    raw_acts_path = (
        DATA_DIR / f"../../activations/acts_L{layer:02d}_{context}_oct28-1156.pkl"
    )
    # NOTE: pickle.load can execute arbitrary code; only load activation files
    # that come from a trusted source.
    with open(raw_acts_path, "rb") as f:
        raw_acts_dataset = pickle.load(f)
    trainer, _, test_idxs = AttnProbeTrainer.from_wandb(
        raw_acts_dataset=raw_acts_dataset,
        run_id=median_run.id,
    )
    unbiased_fsp_str = raw_acts_dataset["unbiased_fsp"]
    raw_acts_qs = [raw_acts_dataset["qs"][i] for i in test_idxs]
    return trainer, test_idxs, raw_acts_qs, unbiased_fsp_str

In [3]:
# Inference only: turn off gradient tracking globally for the rest of the notebook.
torch.set_grad_enabled(False)

# %%
# Selection criteria for the probe runs fetched from W&B.
layer = 14
context = "biased-fsp"
min_seed, max_seed = 1, 10
n_seeds = max_seed - min_seed + 1  # count of seeds if both ends are inclusive
probe_class = "QV"
metric = "test_accuracy"  # run-summary key used to rank the seeds

# Load the trainer of the median-ranked seed together with its test questions.
trainer, test_idxs, raw_acts_qs, unbiased_fsp_str = load_median_probe_test_data(
    probe_class, layer, context, min_seed, max_seed, metric
)
# Only the first batch of the test loader is kept. Later cells index
# `collate_fn_out.cot_acts` by a global CoT index, which assumes this batch
# holds every test CoT — TODO confirm the loader yields a single batch.
collate_fn_out: CollateFnOutput = list(trainer.test_loader)[0]
# NOTE(review): AutoTokenizer is not referenced in the cells visible here.
from transformers import AutoTokenizer

# NOTE(review): the meaning of the argument 8 is not visible from this
# notebook (presumably a model size) — confirm against load_model_and_tokenizer.
model, tokenizer = load_model_and_tokenizer(8)
unbiased_fsp_cache = build_fsp_cache(model, tokenizer, unbiased_fsp_str)

Fetched 10 runs


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Flatten the test questions into parallel per-CoT lists: position i of every
# list below describes the same CoT.
q_idxs = []  # position of the source question within raw_acts_qs
q_and_cot_tokens = []  # truncated token sequence of each CoT
cots_labels = []  # the question's "biased_cot_label", repeated per CoT
cots_answers = []  # the question's "expected_answer", repeated per CoT
questions = []  # the question's "question" field, repeated per CoT
for q_idx, test_q in enumerate(raw_acts_qs):
    cots = test_q["biased_cots_tokens_to_cache"]
    for cot in cots:
        # Drop the last 4 tokens of each cached CoT.
        # NOTE(review): presumably the trailing answer tokens — confirm why 4.
        tokens = cot[:-4]
        q_and_cot_tokens.append(tokens)
        # The label is read from the question, so all CoTs of one question
        # share the same label.
        cot_label = test_q["biased_cot_label"]
        cots_labels.append(cot_label)
        cots_answers.append(test_q["expected_answer"])
        questions.append(test_q["question"])
        q_idxs.append(q_idx)

In [9]:
from cot_probing.typing import *
from cot_probing.vis import visualize_tokens_html


def visualize_cot_attn(
    probe_model: AbstractAttnProbeModel,
    tokenizer: PreTrainedTokenizerBase,
    tokens: list[int],
    label: str,
    answer: str,
    resids: Float[torch.Tensor, "1 seq d_model"],
):
    """Print the probe's output for one CoT and render its attention per token.

    Args:
        probe_model: Probe providing ``attn_probs`` and a forward call, both
            taking ``(resids, attn_mask)``.
        tokenizer: Passed through to ``visualize_tokens_html``.
        tokens: Tokens of the CoT; their count sets the mask length.
        label: Label of the CoT, printed alongside the prediction.
        answer: Correct answer of the question, printed alongside the prediction.
        resids: Activations fed to the probe, annotated as ``(1, seq, d_model)``.

    Returns:
        Whatever ``visualize_tokens_html`` returns for the tokens and their
        attention probabilities, called with ``vmin=0.0`` and ``vmax=1.0``.
    """
    n_tokens = len(tokens)
    # Single-row mask that marks every token position as attendable.
    full_mask = torch.ones(
        1, n_tokens, dtype=torch.bool, device=probe_model.device
    )

    per_position_probs = probe_model.attn_probs(resids, full_mask)
    prediction = probe_model(resids, full_mask)

    # Keep the first (only) batch row, limited to the token count.
    token_probs = per_position_probs[0, :n_tokens]
    print(f"label: {label}, correct answer: {answer}")
    print(f"faithfulness: {prediction.item():.2%}")

    token_weights = token_probs.tolist()
    return visualize_tokens_html(
        tokens, tokenizer, token_weights, vmin=0.0, vmax=1.0
    )


def update_plot(
    tokenizer: PreTrainedTokenizerBase,
    q_and_cot_tokens: list[list[int]],
    cots_labels: list[str],
    cots_answers: list[str],
    probe_model: AbstractAttnProbeModel,
    collate_fn_out: CollateFnOutput,
    cot_idx: int,
):
    """Display the probe attention visualisation for the CoT at ``cot_idx``.

    Args:
        tokenizer: Passed through to ``visualize_cot_attn``.
        q_and_cot_tokens: Token sequences, one per CoT.
        cots_labels: Label of each CoT, parallel to ``q_and_cot_tokens``.
        cots_answers: Correct answer of each CoT, parallel to ``q_and_cot_tokens``.
        probe_model: Probe whose attention is visualised.
        collate_fn_out: Batch whose ``cot_acts`` row ``cot_idx`` holds the
            activations of this CoT.
        cot_idx: Index of the CoT to show. ``None`` is accepted and treated as
            "nothing selected": a selection widget with no options reports
            ``None`` as its value, and indexing with it would raise.
    """
    if cot_idx is None:
        # Empty selection (e.g. no CoT matches the chosen filters).
        print("No CoT matches the current selection.")
        return

    tokens = q_and_cot_tokens[cot_idx]
    # Keep the batch dimension (size 1) and trim to this CoT's token count.
    biased_resids = collate_fn_out.cot_acts[cot_idx : cot_idx + 1, : len(tokens)].to(
        probe_model.device
    )
    display(
        visualize_cot_attn(
            probe_model=probe_model,
            tokenizer=tokenizer,
            tokens=tokens,
            label=cots_labels[cot_idx],
            answer=cots_answers[cot_idx],
            resids=biased_resids,
        )
    )

In [10]:
from functools import partial

# Bind every argument of update_plot except `cot_idx`, so a single CoT can be
# rendered with `upd_plot(cot_idx=...)`.
upd_plot = partial(
    update_plot,
    tokenizer=tokenizer,
    q_and_cot_tokens=q_and_cot_tokens,
    cots_labels=cots_labels,
    cots_answers=cots_answers,
    probe_model=trainer.model,
    collate_fn_out=collate_fn_out,
)

In [11]:
# TODO: allow browsing CoTs by correct answer, by label, by question, by cot
# Render one CoT (index 1) as a quick check of the visualisation.
upd_plot(cot_idx=1)

label: faithful, correct answer: yes
faithfulness: 78.42%


In [14]:
def get_cot_idx_options(label, answer):
    """Return the CoT indices whose label and correct answer both match.

    Reads the notebook-level parallel lists ``cots_labels`` and
    ``cots_answers``; the returned indices are positions in those lists, in
    ascending order.
    """

    def _matches(idx):
        return cots_labels[idx] == label and cots_answers[idx] == answer

    return [idx for idx in range(len(cots_labels)) if _matches(idx)]


def on_label_change(change):
    """Refresh the CoT-index dropdown after the label dropdown changes.

    ``change.new`` is used as the label; the answer is read from
    ``answer_dropdown``. When at least one CoT matches, the first match
    becomes the selected value.
    """
    matching_idxs = get_cot_idx_options(change.new, answer_dropdown.value)
    cot_idx_dropdown.options = matching_idxs

    # Read the options back from the widget rather than from the local list.
    available = cot_idx_dropdown.options
    if not available:
        return
    cot_idx_dropdown.value = available[0]


def on_answer_change(change):
    """Refresh the CoT-index dropdown after the answer dropdown changes.

    ``change.new`` is used as the answer; the label is read from
    ``label_dropdown``. When at least one CoT matches, the first match
    becomes the selected value.
    """
    matching_idxs = get_cot_idx_options(label_dropdown.value, change.new)
    cot_idx_dropdown.options = matching_idxs

    # Read the options back from the widget rather than from the local list.
    available = cot_idx_dropdown.options
    if not available:
        return
    cot_idx_dropdown.value = available[0]


# Filter widgets: the label and answer dropdowns narrow which CoT indices the
# third dropdown offers. Their option strings are compared for equality with
# the entries of cots_labels / cots_answers in get_cot_idx_options.
label_dropdown = Dropdown(
    options=["faithful", "unfaithful"],
    description="Label:",
)
answer_dropdown = Dropdown(
    options=["yes", "no"],
    description="Answer:",
)
# Initial CoT choices match the starting values of the two filters.
cot_idx_dropdown = Dropdown(
    options=get_cot_idx_options(label_dropdown.value, answer_dropdown.value),
    description="CoT index:",
)

# Recompute the CoT choices whenever either filter's value changes.
label_dropdown.observe(on_label_change, names="value")
answer_dropdown.observe(on_answer_change, names="value")

In [15]:
from functools import partial

# Create interactive output: update_plot is called with the value of the
# CoT-index dropdown as `cot_idx`.
# NOTE(review): this partial binds the same arguments as `upd_plot` from the
# earlier cell; the two could share one definition.
out = interactive_output(
    partial(
        update_plot,
        tokenizer=tokenizer,
        q_and_cot_tokens=q_and_cot_tokens,
        cots_labels=cots_labels,
        cots_answers=cots_answers,
        probe_model=trainer.model,
        collate_fn_out=collate_fn_out,
    ),
    {
        "cot_idx": cot_idx_dropdown,
    },
)

# Display widgets and output
widgets = VBox([label_dropdown, answer_dropdown, cot_idx_dropdown])
display(widgets)
display(out)  # Also display the output widget

VBox(children=(Dropdown(description='Label:', options=('faithful', 'unfaithful'), value='faithful'), Dropdown(…

Output()