In [49]:
%load_ext autoreload
%autoreload 2

# Quantity the patching results are read in: "prob" or "logit".
LOGIT_OR_PROB = "prob"
# Patching direction. It is part of the categories file name and is also passed
# to get_patch_values to select the `*_diff_change_<DIR>` attribute.
DIR = "bia_to_unb"
CATEGORIES_FILE = f"categories_{LOGIT_OR_PROB}_{DIR}_0.25_1.5_2.0_4.0.pkl"
# Plain string literal: the name has no placeholders, so no f-prefix is needed.
SWAPS_FILE = "swaps_with-unbiased-cots-oct28-1156.pkl"
# Layer-block width: used in the per-block patch file name and in the heatmap
# row labels ("0-3", "3-6", ...).
LB_LAYERS = 3
# Both patch-result file names end with the swaps file name; derive them from
# SWAPS_FILE so the three names cannot drift apart when the swaps file changes.
# NOTE(review): "LB33" presumably means all layers patched as a single block
# (the heatmap labels that row "all") -- confirm.
PATCH_ALL_FILE = f"patch_new_res_8B_LB33__{SWAPS_FILE}"
PATCH_LAYERS_FILE = f"patch_new_res_8B_LB{LB_LAYERS}__{SWAPS_FILE}"

# Heatmap x-axis labels, one per token group.
# NOTE(review): assumed to be in the same order as the groups stored in the
# patching results -- confirm against the code that produced them.
GROUPS = [
    "Question:",
    "[question]",
    "?\\n",
    "LTSBS:\\n-",
    "reasoning",
    "last 3",
]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
import pickle
from cot_probing import DATA_DIR
from cot_probing.patching import PatchedLogitsProbs
from cot_probing.typing import *
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

def _load_pickle(filename):
    """Unpickle and return the contents of ``DATA_DIR / filename``.

    NOTE(security): ``pickle.load`` can execute arbitrary code while loading.
    Only point this at files produced by this project, never at untrusted data.
    """
    with open(DATA_DIR / filename, "rb") as f:
        return pickle.load(f)


# Mapping from category name to a collection of pairs; the code below reads
# pair[0] as the question index and pair[1] as the swap index.
categories = _load_pickle(CATEGORIES_FILE)

# Only the "qs" entry of the swaps file is used. Indexing it by question index
# yields an entry with "question", "expected_answer" and "swaps".
swaps_by_q = _load_pickle(SWAPS_FILE)["qs"]

# Patching results, indexed as [q_idx][swap_idx].
patch_all_by_q = _load_pickle(PATCH_ALL_FILE)
patch_layers_by_q = _load_pickle(PATCH_LAYERS_FILE)

# The tokenizer is only used to decode token ids back to text for display.
model_id = "hugging-quants/Meta-Llama-3.1-8B-BNB-NF4-BF16"
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [17]:
def get_patch_values(
    plp_by_group_by_layers: dict[tuple[int, ...], dict[str, "PatchedLogitsProbs"]],
    prob_or_logit: Literal["prob", "logit"],
    direction: Literal["bia_to_unb", "unb_to_bia"],
) -> list:
    """Collect the ``<prob_or_logit>_diff_change_<direction>`` attribute of every result.

    Args:
        plp_by_group_by_layers: mapping from a layer tuple to a mapping from
            token-group name to a patching result object.
        prob_or_logit: first part of the attribute name to read.
        direction: last part of the attribute name to read.

    Returns:
        A plain Python list (not a NumPy array):
        * an empty list when the outer mapping is empty;
        * a flat list with one value per token group when the outer mapping
          has exactly one entry;
        * otherwise a list of such lists, one per layer tuple.
        Values appear in mapping iteration order at both levels.

    Raises:
        AttributeError: if a result object lacks the requested attribute.
    """
    attr = f"{prob_or_logit}_diff_change_{direction}"
    # The layer tuples themselves are not needed, only the per-group results.
    rows = [
        [getattr(plp, attr) for plp in plp_by_group.values()]
        for plp_by_group in plp_by_group_by_layers.values()
    ]
    # A single layer tuple is returned unwrapped; callers that concatenate rows
    # must account for this.
    if len(rows) == 1:
        return rows[0]
    return rows

In [47]:
def plot_heatmap(combined_values, title):
    """Show `combined_values` as a red/blue heatmap with a symmetric colour scale.

    Row 0 is labelled "all"; each following row i is labelled with the layer
    range ``i*LB_LAYERS``-``(i+1)*LB_LAYERS``. Columns are labelled with GROUPS.
    A horizontal line separates the "all" row from the per-block rows.
    """
    n_rows = len(combined_values)

    # Use the same magnitude for both ends of the colour scale so that zero
    # sits at the midpoint of the diverging colormap.
    smallest = np.min(combined_values)
    largest = np.max(combined_values)
    bound = max(abs(smallest), abs(largest))

    plt.imshow(
        combined_values,
        cmap="RdBu",
        origin="lower",
        vmin=-bound,
        vmax=bound,
    )
    plt.title(title)
    plt.colorbar()

    row_labels = ["all"]
    for block in range(n_rows - 1):
        row_labels.append(f"{block*LB_LAYERS}-{(block+1)*LB_LAYERS}")
    plt.yticks(range(n_rows), row_labels)
    plt.xticks(range(len(GROUPS)), GROUPS, rotation=90)

    plt.ylabel("layers")
    plt.xlabel("token groups")
    # Divider between the "all" row (y=0) and the first per-block row (y=1).
    plt.axhline(y=0.5, color="black", linewidth=1)
    plt.show()

In [70]:
from ipywidgets import Dropdown, interactive_output, VBox

# Category selector: each label shows the category name followed by the number
# of pairs it contains; the selected value is the bare category name.
category_dropdown = Dropdown(
    description="Category:",
    options=[
        (f"{name} ({len(members)})", name) for name, members in categories.items()
    ],
)


# Build the q_idx dropdown options for a category
def get_q_idx_options(category):
    """Return ``(label, q_idx)`` options for the questions present in `category`.

    Each pair in ``categories[category]`` contributes one count to the question
    index stored at ``pair[0]``. The result is sorted by question index and each
    label has the form ``"<q_idx> (<count> swaps)"``.
    """
    swaps_per_question = {}
    for pair in categories[category]:
        question = pair[0]
        if question in swaps_per_question:
            swaps_per_question[question] += 1
        else:
            swaps_per_question[question] = 1

    options = []
    for question, n_swaps in sorted(swaps_per_question.items()):
        options.append((f"{question} ({n_swaps} swaps)", question))
    return options


# Question selector, seeded with the questions of the initially selected category
q_idx_dropdown = Dropdown(
    description="Q Index:",
    options=get_q_idx_options(category_dropdown.value),
)


# Build the swap_idx dropdown options for a category and question
def get_swap_idx_options(category, q_idx):
    """Return the swap indices (``pair[1]``) of `category` whose ``pair[0]`` equals `q_idx`.

    The order of ``categories[category]`` is preserved.
    """
    swap_indices = []
    for pair in categories[category]:
        if pair[0] == q_idx:
            swap_indices.append(pair[1])
    return swap_indices


# Swap selector, seeded from the initially selected category and question
swap_idx_dropdown = Dropdown(
    description="Swap Index:",
    options=get_swap_idx_options(category_dropdown.value, q_idx_dropdown.value),
)


# Update q_idx (and, if needed, swap_idx) options when category changes
def on_category_change(change):
    """Observer for the category dropdown: refresh the dependent dropdowns.

    `change.new` is the newly selected category. The question dropdown gets the
    questions of that category and its first question is selected.
    """
    q_idx_dropdown.options = get_q_idx_options(change.new)
    if q_idx_dropdown.options:
        q_idx_dropdown.value = q_idx_dropdown.options[0][1]
    # Value observers only fire when the value actually changes. If the new
    # category's first question is the one that was already selected,
    # on_q_idx_change is never called and the swap dropdown would keep listing
    # the previous category's swaps. Refresh it explicitly; this is a no-op
    # when on_q_idx_change has already set the same options.
    swap_idx_dropdown.options = get_swap_idx_options(
        change.new, q_idx_dropdown.value
    )


# Refresh the swap_idx dropdown whenever a different question is selected
def on_q_idx_change(change):
    """Observer for the question dropdown.

    Loads the swaps of the newly selected question (`change.new`) within the
    currently selected category and selects the first one, if any.
    """
    available_swaps = get_swap_idx_options(category_dropdown.value, change.new)
    swap_idx_dropdown.options = available_swaps
    if not swap_idx_dropdown.options:
        return
    swap_idx_dropdown.value = swap_idx_dropdown.options[0]


# Wire the cascade: a new category refreshes the question dropdown, and a new
# question refreshes the swap dropdown. names="value" limits the callbacks to
# changes of the selected value.
# NOTE(review): these handlers are registered before interactive_output
# attaches its own observers further down; presumably that ordering lets the
# dependent dropdowns settle before the plot is redrawn -- confirm before
# reordering.
category_dropdown.observe(on_category_change, names="value")
q_idx_dropdown.observe(on_q_idx_change, names="value")


# Function to update plot
def update_plot(category, q_idx, swap_idx):
    """Print the selected prompt and tokens, then draw the patching heatmap.

    Args:
        category: selected category. Unused here, but it is one of the controls
            bound through interactive_output, so it must be accepted.
        q_idx: selected question index, or None when nothing is selected.
        swap_idx: selected swap index, or None when nothing is selected.
    """
    # A category without pairs leaves the dependent dropdowns with no options.
    # NOTE(review): ipywidgets is expected to report their value as None in
    # that case. Compare against None explicitly because index 0 is valid.
    if q_idx is None or swap_idx is None:
        print("No question/swap selected for this category.")
        return

    question_entry = swaps_by_q[q_idx]
    question_str = question_entry["question"]
    correct_answer_str = question_entry["expected_answer"]
    swap = question_entry["swaps"][swap_idx]
    patch_all = patch_all_by_q[q_idx][swap_idx]
    patch_layers = patch_layers_by_q[q_idx][swap_idx]

    patch_all_values = get_patch_values(patch_all, LOGIT_OR_PROB, DIR)
    patch_layers_values = get_patch_values(patch_layers, LOGIT_OR_PROB, DIR)
    # get_patch_values returns a flat row when given a single layer tuple.
    # Re-wrap it so the concatenation below always appends whole rows, instead
    # of splicing individual values into the row list.
    if len(patch_layers) == 1:
        patch_layers_values = [patch_layers_values]
    # Row 0 is the all-layers result, followed by one row per layer block.
    combined_values = [patch_all_values] + patch_layers_values

    trunc_cot_str = tokenizer.decode(swap.trunc_cot)
    print(question_str + trunc_cot_str)
    print()
    # Show newlines literally so the tokens stay on a single output line.
    fai_tok_str = tokenizer.decode(swap.fai_tok).replace("\n", "\\n")
    unf_tok_str = tokenizer.decode(swap.unfai_tok).replace("\n", "\\n")
    print(f"correct answer: {correct_answer_str.upper()}")
    print(f"faithful_token:   `{fai_tok_str}`")
    print(f"unfaithful_token: `{unf_tok_str}`")
    plot_heatmap(combined_values, f"change in {LOGIT_OR_PROB} diff")


# Stack the three selectors vertically
widgets = VBox([category_dropdown, q_idx_dropdown, swap_idx_dropdown])

# Bind update_plot to the selectors; its printed text and figure are captured
# in the returned output widget
out = interactive_output(
    update_plot,
    dict(
        category=category_dropdown,
        q_idx=q_idx_dropdown,
        swap_idx=swap_idx_dropdown,
    ),
)

# Show the selectors first, then the output area underneath
display(widgets, out)

VBox(children=(Dropdown(description='Category:', options=(('reasoning (133)', 'reasoning'), ('only_last_three …

Output()