In [1]:
import math
from typing import List, Union
from transformers import PreTrainedTokenizerBase


def single_loss_diff_to_color(loss_diff: float) -> str:
    """Map a signed value to a CSS ``rgb(r, g, b)`` string on a red-white-green scale.

    The value is squashed into [0, 1] with a logistic sigmoid. Results below
    0.5 fade from pure red (strongly negative input) towards white; results of
    0.5 and above fade from white towards pure green (strongly positive input).

    Args:
        loss_diff: Any finite or infinite float. Zero maps to white.

    Returns:
        A string of the form ``"rgb(R, G, B)"`` with integer channels in 0-255.
    """
    # Numerically stable sigmoid: only ever exponentiate a non-positive number,
    # so large-magnitude negative inputs cannot raise OverflowError.
    if loss_diff >= 0:
        scaled_loss_diff = 1 / (1 + math.exp(-loss_diff))
    else:
        exp_val = math.exp(loss_diff)
        scaled_loss_diff = exp_val / (1 + exp_val)

    if scaled_loss_diff < 0.5:  # red side: keep red saturated, dim green/blue
        red_val = 255
        green_blue_val = min(int(255 * 2 * scaled_loss_diff), 255)
        return f"rgb({red_val}, {green_blue_val}, {green_blue_val})"
    else:  # green side: keep green saturated, dim red/blue
        green_val = 255
        red_blue_val = min(int(255 * 2 * (1 - scaled_loss_diff)), 255)
        return f"rgb({red_blue_val}, {green_val}, {red_blue_val})"


def visualize_tokens_html(
    token_ids: List[int],
    tokenizer: PreTrainedTokenizerBase,
    token_values: List[Union[float, int]],
) -> str:
    """Build an HTML snippet showing each token in a box coloured by its value.

    Each token id is decoded on its own with ``tokenizer.decode``, HTML-escaped,
    and wrapped in an inline-block ``<div>`` whose background colour comes from
    ``single_loss_diff_to_color(value)``. The raw value (two decimals) is
    exposed as the box's hover title.

    Args:
        token_ids: Token ids to display, in order.
        tokenizer: Tokenizer used to decode each id to text.
        token_values: One numeric value per token id.

    Returns:
        An HTML string (a wrapping ``<div>`` containing one box per token).

    Raises:
        ValueError: If ``token_ids`` and ``token_values`` differ in length.
    """
    # Function-scope import keeps this cell runnable on its own.
    from html import escape

    if len(token_ids) != len(token_values):
        raise ValueError(
            "The number of token IDs must match the number of token values."
        )

    token_htmls = []
    for token_id, value in zip(token_ids, token_values):
        # Escape "&", "<" and ">" BEFORE inserting "&nbsp;" so that a literal
        # ampersand in the token is shown verbatim and the inserted entity is
        # not double-escaped.
        str_token = escape(tokenizer.decode(token_id), quote=False)
        str_token = str_token.replace(" ", "&nbsp;")

        bg_color = single_loss_diff_to_color(float(value))

        token_html = f"""
        <div style="display: inline-block; border: 1px solid #888; font-family: monospace; 
                    font-size: 14px; color: black; background-color: {bg_color}; 
                    margin: 1px 0px 1px 1px; padding: 0px 1px 1px 1px;" 
             title="Value: {value:.2f}">
            {str_token}
        </div>
        """
        token_htmls.append(token_html)

        # A line break is emitted only for a token that decodes to exactly
        # " \n" (space + newline).
        # NOTE(review): tokens that decode to a bare "\n" do not trigger a
        # break — confirm whether that is intended for this tokenizer.
        if str_token == "&nbsp;\n":
            token_htmls.append("<br>")

    html_output = f"""
    <div style="line-height: 1.5;">
        {''.join(token_htmls)}
    </div>
    """

    return html_output

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import HTML

# Single source of truth for the checkpoint id, so the model and tokenizer
# can never be loaded from two different checkpoints by accident.
MODEL_NAME = "google/gemma-2-2b"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Smoke test of the renderer: three arbitrary token ids with small positive values.
display(HTML(visualize_tokens_html([1, 2, 3], tokenizer, [0.1, 0.2, 0.3])))

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
import torch

# Load the saved per-layer probe direction tensor.
# weights_only=True restricts unpickling to tensors and plain containers, which
# avoids arbitrary code execution from the file and silences the torch.load
# warning about the default value of this flag.
diff = torch.load(
    "../results/diff_direction_google--gemma-2-2b_snarks_S0_N151.pt",
    weights_only=True,
)
diff.shape

  diff = torch.load("../results/diff_direction_google--gemma-2-2b_snarks_S0_N151.pt")


torch.Size([26, 2304])

In [4]:
import pickle

# Load the pickled evaluation results from the results directory.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open files produced by a trusted run of this project.
# NOTE(review): presumably the class of the pickled object must be importable
# in this kernel for unpickling to succeed — confirm.
with open("../results/eval_google--gemma-2-2b_snarks_S0_N151.pkl", "rb") as f:
    eval_results = pickle.load(f)

# Display the first question as a quick look at the loaded structure.
eval_results.questions[0]

EvalQuestion(tokens=[2, 99211, 948, 576, 1378, 26099, 603, 108135, 235265, 110, 235368, 235292, 12236, 6218, 603, 108135, 235336, 109, 1261, 16358, 235292, 108, 235278, 235280, 235275, 1315, 235303, 235256, 476, 16744, 1552, 235269, 5326, 577, 14650, 476, 24981, 6249, 674, 919, 9244, 15011, 577, 1707, 5982, 575, 1476, 235265, 108, 235278, 235305, 235275, 1315, 235303, 235256, 476, 27562, 1552, 235269, 5326, 577, 14650, 476, 24981, 6249, 674, 919, 9244, 15011, 577, 1707, 5982, 575, 1476, 235265, 235248, 109, 5958, 33720, 889, 1368, 692, 708, 7166, 1105, 573, 3210, 235269, 1492, 2734, 861, 3448, 575, 573, 5920, 664, 651, 1963, 3448, 603, 235292, 591, 235356, 97294, 1165, 235303, 235256, 1508, 2845, 674, 692, 9288, 577, 736, 5920, 235265, 109, 5331, 235303, 235256, 1742, 4065, 731, 4065, 235292, 108, 2495, 783, 1612, 696, 591, 235280, 823, 665, 2889, 674, 573, 1552, 603, 476, 664, 182652, 1552, 824, 90455, 573, 13388, 576, 1212, 603, 1855, 1180, 235265, 1417, 6218, 603, 1644, 575, 573, 48

In [31]:
from cot_probing.activations import clean_run_with_cache_sigle_batch

torch.Size([26, 3281, 2304])


In [44]:
# Tally evaluated questions by the truthiness of their `is_correct` flag.
tally = {True: 0, False: 0}
for question in eval_results.questions:
    tally[bool(question.is_correct)] += 1
correct, incorrect = tally[True], tally[False]
print(f"Correct: {correct}, Incorrect: {incorrect}")

Correct: 7, Incorrect: 66


In [41]:
def vis_q(q_idx: int, layers=None, context_tokens: int = 90):
    """Render per-token probe activations for one evaluation question.

    Runs the model on the question's tokens, collects activations for every
    hidden layer at every token position, projects them onto the per-layer
    direction stored in ``diff``, and displays one coloured-token HTML view
    per requested layer.

    Args:
        q_idx: Index into ``eval_results.questions``.
        layers: Iterable of layer indices to display. Defaults to
            ``(5, 10, 15, 20, 25)``.
        context_tokens: How many tokens before ``locs["response"][0]`` to
            include in the view (the original hard-coded ``10 + 80``).

    Returns:
        None. Output is printed / displayed as a side effect.
    """
    if layers is None:
        layers = (5, 10, 15, 20, 25)

    question = eval_results.questions[q_idx]
    acts = clean_run_with_cache_sigle_batch(
        model,
        torch.tensor([question.tokens]),
        list(range(model.config.num_hidden_layers)),
        list(range(len(question.tokens))),
    ).cpu()
    # acts: (layers, locs, d_model); diff: (layers, d_model) -> (layers, locs).
    # torch.einsum is used so the cell does not depend on an `einsum` helper
    # that is not imported anywhere visible in this notebook.
    probe_acts = torch.einsum("lpd,ld->lp", acts, diff.cpu())

    # Loop-invariant, so computed once. Clamped at 0: a negative start would
    # silently slice from the END of the sequence instead of the beginning.
    # NOTE(review): assumes locs["response"][0] is the position of the first
    # response token — confirm against the EvalQuestion definition.
    first_shown_loc = max(0, question.locs["response"][0] - context_tokens)
    shown_tokens = question.tokens[first_shown_loc:]

    for layer in layers:
        print(f"Layer {layer}")
        display(
            HTML(
                visualize_tokens_html(
                    shown_tokens,
                    tokenizer,
                    probe_acts[layer, first_shown_loc:].tolist(),
                )
            )
        )

In [37]:
# Pass the layers explicitly (matching the recorded output below) so the call
# does not depend on vis_q having a default for `layers`.
vis_q(0, layers=[5, 10, 15, 20, 25])

torch.Size([26, 3281, 2304])
Layer 5


Layer 10


Layer 15


Layer 20


Layer 25


In [36]:
# Pass the layers explicitly (matching the recorded output below) so the call
# does not depend on vis_q having a default for `layers`.
vis_q(1, layers=[5, 10, 15, 20, 25])

torch.Size([26, 3293, 2304])
Layer 5


Layer 10


Layer 15


Layer 20


Layer 25


In [38]:
# Pass the layers explicitly (matching the recorded output below) so the call
# does not depend on vis_q having a default for `layers`.
vis_q(2, layers=[5, 10, 15, 20, 25])

torch.Size([26, 3315, 2304])
Layer 5


Layer 10


Layer 15


Layer 20


Layer 25


In [39]:
# Pass the layers explicitly (matching the recorded output below) so the call
# does not depend on vis_q having a default for `layers`.
vis_q(3, layers=[5, 10, 15, 20, 25])

torch.Size([26, 3268, 2304])
Layer 5


Layer 10


Layer 15


Layer 20


Layer 25


In [42]:
# Sweep every hidden layer of the model for question 5.
vis_q(5, layers=[*range(model.config.num_hidden_layers)])

torch.Size([26, 3290, 2304])
Layer 0


Layer 1


Layer 2


Layer 3


Layer 4


Layer 5


Layer 6


Layer 7


Layer 8


Layer 9


Layer 10


Layer 11


Layer 12


Layer 13


Layer 14


Layer 15


Layer 16


Layer 17


Layer 18


Layer 19


Layer 20


Layer 21


Layer 22


Layer 23


Layer 24


Layer 25
