In [1]:
# %%
%load_ext autoreload
%autoreload 2

import torch
from cot_probing.attn_probes_case_studies import *
from cot_probing.attn_probes_data_proc import CollateFnOutput
from cot_probing.utils import load_model_and_tokenizer
from cot_probing.activations import build_fsp_cache
from ipywidgets import Dropdown, interactive_output, VBox



In [2]:
# This notebook only runs forward passes, so autograd is switched off globally.
torch.set_grad_enabled(False)

# %%
# Probe-selection settings handed to load_median_probe_test_data below.
layer = 15
min_seed, max_seed = 21, 40
n_seeds = max_seed - min_seed + 1  # inclusive seed range -> 20 seeds
probe_class = "minimal"
metric = "test_accuracy"

# NOTE(review): judging by the helper's name this picks the median run (by
# `metric`) across the seed range -- confirm against its definition.
trainer, test_idxs, raw_acts_qs, unbiased_fsp_str = load_median_probe_test_data(
    probe_class, layer, min_seed, max_seed, metric
)
# Only the first test batch is used. Pull it lazily instead of materialising
# every batch with list(...) just to index element 0.
collate_fn_out: CollateFnOutput = next(iter(trainer.test_loader))
# NOTE(review): AutoTokenizer is not referenced in the visible cells; kept in
# case other code relies on it.
from transformers import AutoTokenizer

# NOTE(review): the meaning of the argument 8 is not visible here (presumably a
# model-size selector) -- confirm against load_model_and_tokenizer.
model, tokenizer = load_model_and_tokenizer(8)
unbiased_fsp_cache = build_fsp_cache(model, tokenizer, unbiased_fsp_str)

Fetched 20 runs


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# %%
# Load the filtered case-study data. What gets filtered, and how, is decided
# inside load_filtered_data; this cell only unpacks the dict it returns.
filtered_data = load_filtered_data(raw_acts_qs, tokenizer)
categories_with_matches = filtered_data["categories_with_matches"]
swap_dict_by_q = filtered_data["swap_dict_by_q"]
patch_all_by_q = filtered_data["patch_all_by_q"]
patch_layers_by_q = filtered_data["patch_layers_by_q"]
# Map each category name to its "pairs" entry. The dropdown helpers in this
# notebook read pair[0] as the question index and pair[1] as the swap index.
categories = {c: cwm_dict["pairs"] for c, cwm_dict in categories_with_matches.items()}

  0%|          | 0/133 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/232 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

  0%|          | 0/258 [00:00<?, ?it/s]

reasoning: 4 4
only_last_three: 7 7
other: 7 7
ltsbs: 18 18
in_between: 14 14
both: 16 16


In [4]:
def get_q_idx_options(category):
    """Build the question-index dropdown options for one category.

    Counts how many pairs in ``categories[category]`` share each question
    index (the first element of a pair).

    Args:
        category: Key into the module-level ``categories`` dict.

    Returns:
        A list of ``(label, q_idx)`` tuples ordered by ``q_idx``, where the
        label reads ``"<q_idx> (<count> swaps)"``.
    """
    swaps_per_question = {}
    for pair in categories[category]:
        question = pair[0]
        if question in swaps_per_question:
            swaps_per_question[question] += 1
        else:
            swaps_per_question[question] = 1

    options = []
    for question, n_swaps in sorted(swaps_per_question.items()):
        options.append((f"{question} ({n_swaps} swaps)", question))
    return options


def get_swap_idx_options(category, q_idx):
    """List the swap indices available for a question within a category.

    Args:
        category: Key into the module-level ``categories`` dict.
        q_idx: Question index to match against the first element of each pair.

    Returns:
        The second element of every matching pair, in the order the pairs
        appear; an empty list when nothing matches.
    """
    swap_indices = []
    for pair in categories[category]:
        if pair[0] == q_idx:
            swap_indices.append(pair[1])
    return swap_indices


def on_category_change(change):
    """Observer for the category dropdown: refresh both dependent dropdowns.

    Repopulates the question-index dropdown for the newly selected category
    (``change.new``), selects its first entry when there is one, and then
    rebuilds the swap-index dropdown from the current category / question
    selection, again selecting the first entry when available.

    Args:
        change: Change notification whose ``new`` attribute is the newly
            selected category.
    """
    q_idx_dropdown.options = get_q_idx_options(change.new)
    # Read the options back from the widget rather than reusing the local list.
    available_questions = q_idx_dropdown.options
    if available_questions:
        first_label_and_value = available_questions[0]
        q_idx_dropdown.value = first_label_and_value[1]

    selected_category = category_dropdown.value
    selected_question = q_idx_dropdown.value
    swap_idx_dropdown.options = get_swap_idx_options(
        selected_category, selected_question
    )
    available_swaps = swap_idx_dropdown.options
    if available_swaps:
        swap_idx_dropdown.value = available_swaps[0]


def on_q_idx_change(change):
    """Observer for the question-index dropdown: refresh the swap dropdown.

    Rebuilds the swap-index options for the current category and the newly
    selected question index (``change.new``), then selects the first swap
    when at least one exists.

    Args:
        change: Change notification whose ``new`` attribute is the newly
            selected question index.
    """
    selected_category = category_dropdown.value
    swap_idx_dropdown.options = get_swap_idx_options(selected_category, change.new)
    # Read the options back from the widget rather than reusing the local list.
    available_swaps = swap_idx_dropdown.options
    if available_swaps:
        swap_idx_dropdown.value = available_swaps[0]


# Category dropdown: one entry per category, labelled "<name> (<number of pairs>)",
# with the bare category name as the widget value.
category_dropdown = Dropdown(
    options=[(f"{cat} ({len(pairs)})", cat) for cat, pairs in categories.items()],
    description="Category:",
)

# Question-index dropdown, seeded from whichever category is selected at creation.
q_idx_dropdown = Dropdown(
    options=get_q_idx_options(category_dropdown.value), description="Q Index:"
)

# Swap-index dropdown, seeded from the initial category and question index.
swap_idx_dropdown = Dropdown(
    options=get_swap_idx_options(category_dropdown.value, q_idx_dropdown.value),
    description="Swap Index:",
)

# Cascade the selections: a new category refreshes both dependent dropdowns,
# and a new question index refreshes the swap dropdown.
category_dropdown.observe(on_category_change, names="value")
q_idx_dropdown.observe(on_q_idx_change, names="value")

In [5]:
# %%
from functools import partial

# Bind the three dropdowns to update_plot. interactive_output passes each
# widget's current value as the keyword argument named by its dict key
# (category, q_idx, swap_idx); everything else update_plot needs is fixed up
# front with partial.
# NOTE(review): update_plot is not defined in this notebook -- presumably it
# comes from the star import of cot_probing.attn_probes_case_studies; confirm.
out = interactive_output(
    partial(
        update_plot,
        tokenizer=tokenizer,
        filtered_data=filtered_data,
        probe_model=trainer.model,
        collate_fn_out=collate_fn_out,
        model=model,
        layer=layer,
        unbiased_fsp_cache=unbiased_fsp_cache,
    ),
    {
        "category": category_dropdown,
        "q_idx": q_idx_dropdown,
        "swap_idx": swap_idx_dropdown,
    },
)

# Stack the controls vertically and show them above the output area.
widgets = VBox([category_dropdown, q_idx_dropdown, swap_idx_dropdown])
display(widgets)
display(out)  # interactive_output does not include the controls, so both are displayed
# %%

VBox(children=(Dropdown(description='Category:', options=(('reasoning (4)', 'reasoning'), ('only_last_three (7…

Output()