In [14]:
# %%
%load_ext autoreload
%autoreload 2

import torch
from cot_probing.attn_probes_case_studies import *
from cot_probing.attn_probes_data_proc import CollateFnOutput
from cot_probing.utils import load_model_and_tokenizer
from cot_probing.activations import build_fsp_cache
from ipywidgets import Dropdown, interactive_output, VBox



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
from cot_probing.utils import fetch_runs
from cot_probing.attn_probes import AttnProbeTrainer
import wandb
from cot_probing import DATA_DIR
import pickle


def load_median_probe_test_data(
    probe_class: str,
    layer: int,
    context: str,
    min_seed: int,
    max_seed: int,
    metric: str,
) -> tuple[AttnProbeTrainer, list[int], list[dict], str]:
    """Restore the median-ranked probe run for one layer along with its test data.

    The W&B runs returned by ``fetch_runs`` for the given probe class, context
    and seed range are ranked by their ``metric`` summary value. The run in the
    middle of that ranking (the upper median for an even number of runs) is
    restored through ``AttnProbeTrainer.from_wandb``.

    Args:
        probe_class: Probe class identifier forwarded to ``fetch_runs``.
        layer: Layer to load; used as both ``min_layer`` and ``max_layer`` and
            to build the activations file name.
        context: Context name forwarded to ``fetch_runs`` and used in the
            activations file name.
        min_seed: Lower seed bound forwarded to ``fetch_runs``.
        max_seed: Upper seed bound forwarded to ``fetch_runs``.
        metric: Key of the run summary used to rank the runs.

    Returns:
        ``(trainer, test_idxs, raw_acts_qs, unbiased_fsp_str)`` where
        ``raw_acts_qs`` holds the entries of ``raw_acts_dataset["qs"]`` selected
        by ``test_idxs`` and ``unbiased_fsp_str`` is
        ``raw_acts_dataset["unbiased_fsp"]``.

    Raises:
        ValueError: If no fetched run reports ``metric`` in its summary.
    """
    runs_by_seed_by_layer = fetch_runs(
        api=wandb.Api(),
        probe_class=probe_class,
        min_layer=layer,
        max_layer=layer,
        min_seed=min_seed,
        max_seed=max_seed,
        context=context,
    )
    assert (
        len(runs_by_seed_by_layer) == 1
    ), f"expected runs for exactly one layer, got {len(runs_by_seed_by_layer)}"
    runs_by_seed = runs_by_seed_by_layer[layer]

    # A run that never logged the metric has a None score; ranking None against
    # numbers raises an opaque TypeError, so such runs are excluded up front.
    scored_runs = [
        (run.summary.get(metric), run) for _seed, run in runs_by_seed.items()
    ]
    ranked_runs = [
        (score, run) for score, run in scored_runs if score is not None
    ]
    if not ranked_runs:
        raise ValueError(
            f"no run for layer {layer} reports metric {metric!r}; "
            "cannot pick a median run"
        )
    # Sort on the score only (runs themselves are not comparable); the sort is
    # stable, so ties keep the order in which the runs were fetched.
    ranked_runs.sort(key=lambda score_run: score_run[0])
    _median_score, median_run = ranked_runs[len(ranked_runs) // 2]

    raw_acts_path = (
        DATA_DIR / f"../../activations/acts_L{layer:02d}_{context}_oct28-1156.pkl"
    )
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # activation files you produced yourself.
    with open(raw_acts_path, "rb") as f:
        raw_acts_dataset = pickle.load(f)

    # The second element returned by from_wandb is not needed here.
    trainer, _, test_idxs = AttnProbeTrainer.from_wandb(
        raw_acts_dataset=raw_acts_dataset,
        run_id=median_run.id,
    )
    unbiased_fsp_str = raw_acts_dataset["unbiased_fsp"]
    raw_acts_qs = [raw_acts_dataset["qs"][i] for i in test_idxs]
    return trainer, test_idxs, raw_acts_qs, unbiased_fsp_str

In [12]:
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# %%
# Which probe runs to inspect.
layer = 15
context = "biased-fsp"
min_seed, max_seed = 1, 10
n_seeds = max_seed - min_seed + 1
probe_class = "V"
metric = "test_accuracy"

trainer, test_idxs, raw_acts_qs, unbiased_fsp_str = load_median_probe_test_data(
    probe_class, layer, context, min_seed, max_seed, metric
)
# Only the first test batch is used below; take it directly instead of
# materializing every batch of the loader in a list first.
collate_fn_out: CollateFnOutput = next(iter(trainer.test_loader))
from transformers import AutoTokenizer

# NOTE(review): the argument 8 presumably selects the model size - confirm
# against load_model_and_tokenizer.
model, tokenizer = load_model_and_tokenizer(8)
unbiased_fsp_cache = build_fsp_cache(model, tokenizer, unbiased_fsp_str)

Fetched 10 runs


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
# Flatten the test questions into one record per cached biased CoT. Each record
# is (tokens, label, expected answer, question); the last 4 tokens of every CoT
# are dropped (presumably the answer suffix - TODO confirm).
cot_records = [
    (
        cot[:-4],
        test_q["biased_cot_label"],
        test_q["expected_answer"],
        test_q["question"],
    )
    for test_q in raw_acts_qs
    for cot in test_q["biased_cots_tokens_to_cache"]
]

# Parallel lists, all indexed by the same flat CoT index.
q_and_cot_tokens = [record[0] for record in cot_records]
cots_labels = [record[1] for record in cot_records]
cots_answers = [record[2] for record in cot_records]
questions = [record[3] for record in cot_records]

In [19]:
from cot_probing.typing import *
from cot_probing.vis import visualize_tokens_html


def visualize_cot_attn(
    probe_model: AbstractAttnProbeModel,
    tokenizer: PreTrainedTokenizerBase,
    tokens: list[int],
    label: str,
    answer: str,
    resids: Float[torch.Tensor, "1 seq d_model"],
):
    """Print the probe's verdict for one CoT and render its per-token attention.

    Args:
        probe_model: Probe providing ``attn_probs(resids, mask)``, a forward
            call ``probe_model(resids, mask)`` and a ``device`` attribute.
        tokenizer: Tokenizer handed to ``visualize_tokens_html``.
        tokens: Token ids of the question and CoT being visualized.
        label: Label of the CoT, printed as-is.
        answer: Correct answer, printed as-is.
        resids: Activations for this single sequence (batch dimension of 1).

    Returns:
        Whatever ``visualize_tokens_html`` returns for the tokens coloured by
        their attention probability on a fixed 0..1 scale.
    """
    n_tokens = len(tokens)
    # One sequence, no padding: the mask is True at every position.
    full_mask = torch.ones(
        1, n_tokens, dtype=torch.bool, device=probe_model.device
    )

    # Attention probabilities of the only batch element, one per token.
    per_token_attn = probe_model.attn_probs(resids, full_mask)[0, :n_tokens]
    probe_score = probe_model(resids, full_mask)

    print(f"label: {label}, correct answer: {answer}")
    print(f"faithfulness: {probe_score.item():.2%}")
    return visualize_tokens_html(
        tokens, tokenizer, per_token_attn.tolist(), vmin=0.0, vmax=1.0
    )


def update_plot(
    tokenizer: PreTrainedTokenizerBase,
    q_and_cot_tokens: list[list[int]],
    cots_labels: list[str],
    cots_answers: list[str],
    probe_model: AbstractAttnProbeModel,
    collate_fn_out: CollateFnOutput,
    cot_idx: int,
):
    """Display the probe attention visualization for the CoT at ``cot_idx``.

    Args:
        tokenizer: Tokenizer forwarded to ``visualize_cot_attn``.
        q_and_cot_tokens: Token ids per CoT, indexed by ``cot_idx``.
        cots_labels: Label per CoT, indexed by ``cot_idx``.
        cots_answers: Correct answer per CoT, indexed by ``cot_idx``.
        probe_model: Probe whose attention is visualized.
        collate_fn_out: Batch whose ``cot_acts`` rows hold the activations;
            row ``cot_idx`` is assumed to belong to ``q_and_cot_tokens[cot_idx]``
            (TODO confirm the ordering matches).
        cot_idx: Index of the CoT to show.
    """
    cot_tokens = q_and_cot_tokens[cot_idx]
    seq_len = len(cot_tokens)

    # Slice with cot_idx:cot_idx + 1 to keep a leading batch dimension of 1,
    # and cut the sequence axis down to this CoT's token count.
    cot_resids = collate_fn_out.cot_acts[cot_idx : cot_idx + 1, :seq_len]
    cot_resids = cot_resids.to(probe_model.device)

    rendered = visualize_cot_attn(
        probe_model=probe_model,
        tokenizer=tokenizer,
        tokens=cot_tokens,
        label=cots_labels[cot_idx],
        answer=cots_answers[cot_idx],
        resids=cot_resids,
    )
    display(rendered)

In [20]:
# `partial` is otherwise only imported in a later cell; importing it here keeps
# this cell runnable when the notebook is executed top to bottom on a fresh
# kernel.
from functools import partial

# update_plot with everything except `cot_idx` pre-bound, so a CoT can be
# rendered with a single keyword argument.
upd_plot = partial(
    update_plot,
    tokenizer=tokenizer,
    q_and_cot_tokens=q_and_cot_tokens,
    cots_labels=cots_labels,
    cots_answers=cots_answers,
    probe_model=trainer.model,
    collate_fn_out=collate_fn_out,
)

In [24]:
# TODO: allow selecting what to show by correct answer, by label, by question, by cot
# For now, render the CoT at a hard-coded index.
upd_plot(cot_idx=1)

label: faithful, correct answer: yes
faithfulness: 95.22%


In [None]:
# TODO: choose the CoT to visualize from a dropdown instead of a hard-coded cot_idx

In [16]:
def get_q_idx_options(category):
    """Build the ``(label, value)`` options for the Q Index dropdown.

    Reads the module-level ``categories`` mapping. Every entry of
    ``categories[category]`` contributes its first element (the question index)
    and the label reports how many entries share that index.

    Returns:
        A list of ``("<q_idx> (<count> swaps)", q_idx)`` tuples sorted by
        ``q_idx``.
    """
    swaps_per_q = {}
    for pair in categories[category]:
        q_idx = pair[0]
        if q_idx in swaps_per_q:
            swaps_per_q[q_idx] += 1
        else:
            swaps_per_q[q_idx] = 1

    options = []
    for q_idx in sorted(swaps_per_q):
        options.append((f"{q_idx} ({swaps_per_q[q_idx]} swaps)", q_idx))
    return options


def get_swap_idx_options(category, q_idx):
    """List the Swap Index dropdown options for one question of a category.

    Reads the module-level ``categories`` mapping and returns, in their
    original order, the second element of every entry of
    ``categories[category]`` whose first element equals ``q_idx``.
    """
    swap_idxs = []
    for pair in categories[category]:
        if pair[0] == q_idx:
            swap_idxs.append(pair[1])
    return swap_idxs


def on_category_change(change):
    """Observer for the category dropdown: refresh both dependent dropdowns.

    ``change.new`` is the newly selected category. The Q Index dropdown gets
    that category's options and its first value; the Swap Index dropdown is
    then rebuilt from the current category and question selection.
    """
    # Question indices available in the new category; select the first one.
    q_idx_dropdown.options = get_q_idx_options(change.new)
    current_q_options = q_idx_dropdown.options
    if current_q_options:
        # Options are (label, value) pairs; the value is the question index.
        q_idx_dropdown.value = current_q_options[0][1]

    # Swap indices for the (category, question) pair that is now selected.
    selected_category = category_dropdown.value
    selected_q_idx = q_idx_dropdown.value
    swap_idx_dropdown.options = get_swap_idx_options(selected_category, selected_q_idx)
    current_swap_options = swap_idx_dropdown.options
    if current_swap_options:
        swap_idx_dropdown.value = current_swap_options[0]


def on_q_idx_change(change):
    """Observer for the Q Index dropdown: rebuild the Swap Index dropdown.

    ``change.new`` is the newly selected question index; the category is read
    from the category dropdown. The first swap option, if any, is selected.
    """
    selected_category = category_dropdown.value
    swap_idx_dropdown.options = get_swap_idx_options(selected_category, change.new)
    current_swap_options = swap_idx_dropdown.options
    if current_swap_options:
        swap_idx_dropdown.value = current_swap_options[0]


# Category selector: one option per key of `categories`, labelled with the
# number of pairs stored under it.
category_dropdown = Dropdown(
    description="Category:",
    options=[
        (f"{cat} ({len(pairs)})", cat) for cat, pairs in categories.items()
    ],
)

# Question-index selector, seeded from the initially selected category.
initial_q_options = get_q_idx_options(category_dropdown.value)
q_idx_dropdown = Dropdown(description="Q Index:", options=initial_q_options)

# Swap-index selector, seeded from the initial category and question.
initial_swap_options = get_swap_idx_options(
    category_dropdown.value, q_idx_dropdown.value
)
swap_idx_dropdown = Dropdown(
    description="Swap Index:", options=initial_swap_options
)

# Cascade selections: category -> question index -> swap index.
category_dropdown.observe(on_category_change, names="value")
q_idx_dropdown.observe(on_q_idx_change, names="value")

In [17]:
# %%
from functools import partial

# Selector for the CoT to visualize. Values are flat indices into
# q_and_cot_tokens / cots_labels / cots_answers; labels show the CoT's label and
# correct answer so a case can be picked without knowing its index.
cot_idx_dropdown = Dropdown(
    options=[
        (f"{cot_idx}: {cot_label} (answer: {cot_answer})", cot_idx)
        for cot_idx, (cot_label, cot_answer) in enumerate(
            zip(cots_labels, cots_answers)
        )
    ],
    description="CoT:",
)

# Create interactive output
out = interactive_output(
    partial(
        update_plot,
        tokenizer=tokenizer,
        q_and_cot_tokens=q_and_cot_tokens,
        cots_labels=cots_labels,
        cots_answers=cots_answers,
        probe_model=trainer.model,
        collate_fn_out=collate_fn_out,
    ),
    # interactive_output passes each control's value as a keyword argument named
    # after its key, so the key must be update_plot's `cot_idx` parameter and
    # the widget must yield an integer index (not a category name).
    {
        "cot_idx": cot_idx_dropdown,
    },
)

# Display widgets and output
# NOTE: the category / q_idx / swap_idx dropdowns are still shown but do not
# drive the plot; how they map to a cot_idx is not defined in this notebook.
widgets = VBox(
    [category_dropdown, q_idx_dropdown, swap_idx_dropdown, cot_idx_dropdown]
)
display(widgets)
display(out)  # Also display the output widget
# %%

VBox(children=(Dropdown(description='Category:', options=(('reasoning (9)', 'reasoning'), ('only_last_three (3…

Output()