In [51]:
import math
import torch
from typing import List, Union
from transformers import PreTrainedTokenizerBase


def single_probe_act_to_color(probe_act: float, scale) -> str:
    """Map a probe activation to a CSS ``rgb(...)`` colour string.

    The activation is divided by ``scale`` and squashed into [0, 1] with a
    logistic sigmoid. Squashed values below 0.5 fade from white towards pure
    red as they approach 0; values at or above 0.5 fade from white towards
    pure green as they approach 1. An activation of exactly 0 is white.

    Args:
        probe_act: Raw probe activation.
        scale: Divisor applied to ``probe_act`` before the sigmoid; larger
            magnitudes give a softer colour gradient. Must be non-zero.

    Returns:
        A string of the form ``"rgb(r, g, b)"`` with integer channels in
        0..255.

    Raises:
        ZeroDivisionError: If ``scale`` is 0.
    """

    def sigmoid(x: float) -> float:
        # Numerically stable logistic: only ever exponentiate a non-positive
        # number, so math.exp cannot raise OverflowError for extreme inputs.
        z = x / scale
        if z >= 0:
            return 1 / (1 + math.exp(-z))
        exp_z = math.exp(z)
        return exp_z / (1 + exp_z)

    scaled_probe_act = sigmoid(probe_act)  # scale to 0-1

    if scaled_probe_act < 0.5:  # red
        red_val = 255
        green_blue_val = min(int(255 * 2 * scaled_probe_act), 255)
        return f"rgb({red_val}, {green_blue_val}, {green_blue_val})"
    else:  # green
        green_val = 255
        red_blue_val = min(int(255 * 2 * (1 - scaled_probe_act)), 255)
        return f"rgb({red_blue_val}, {green_val}, {red_blue_val})"


def visualize_tokens_html(
    token_ids: List[int],
    tokenizer: PreTrainedTokenizerBase,
    token_values: List[Union[float, int]],
    values_scale: float = 300,
) -> str:
    """Render tokens as an HTML strip of boxes coloured by a per-token value.

    Each token is decoded individually, HTML-escaped, and wrapped in an
    inline-block ``<div>`` whose background colour comes from
    ``single_probe_act_to_color(value, values_scale)`` and whose tooltip shows
    the value with two decimals. Spaces are shown as ``&nbsp;`` and line
    breaks as a visible ``\\n`` followed by one ``<br>`` per line break.

    Args:
        token_ids: Token IDs to display, in order.
        tokenizer: Tokenizer whose ``decode`` turns a single ID into text.
        token_values: One numeric value per token; drives the colour.
        values_scale: Sigmoid scale forwarded to the colour mapping.

    Returns:
        An HTML fragment (a single outer ``<div>``) as a string.

    Raises:
        ValueError: If ``token_ids`` and ``token_values`` differ in length.
    """
    # Local import so this cell does not depend on another cell's imports.
    import html

    if len(token_ids) != len(token_values):
        raise ValueError(
            "The number of token IDs must match the number of token values."
        )

    border_style = "1px solid #888"
    last_index = len(token_ids) - 1

    token_htmls = []
    for i, (token_id, value) in enumerate(zip(token_ids, token_values)):
        raw_token = tokenizer.decode(token_id)

        # Count real line breaks on the raw text, before "\n" is rewritten to
        # a visible backslash-n, so a token that already contains a literal
        # backslash followed by "n" is not mistaken for a line break.
        newlines = raw_token.count("\n")

        # Escape &, <, > and quotes first; the display substitutions are
        # applied afterwards so the "&" of "&nbsp;" is not escaped itself.
        str_token = (
            html.escape(raw_token).replace(" ", "&nbsp;").replace("\n", r"\n")
        )
        bg_color = single_probe_act_to_color(float(value), values_scale)

        style = {
            "display": "inline-block",
            "border-top": border_style,
            "border-bottom": border_style,
            "border-left": border_style,
            "font-family": "monospace",
            "font-size": "14px",
            "color": "black",
            "background-color": bg_color,
            "margin": "0 0 2px 0",
            "padding": "0",
        }

        # Boxes share borders with their right-hand neighbour, so a right
        # border is only needed on the last token or just before a line break.
        if i == last_index or newlines > 0:
            style["border-right"] = border_style

        style_str = "; ".join(f"{k}: {v}" for k, v in style.items())
        token_html = f"<div style='{style_str}' title='{value:.2f}'>{str_token}</div>"
        token_htmls.append(token_html)
        token_htmls.extend(["<br>"] * newlines)

    html_output = f"<div style='line-height: 1.5;'>{''.join(token_htmls)}</div>"

    return html_output

In [53]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import HTML
import random

# Only the tokenizer is needed for this rendering check; no model weights are
# loaded in this cell.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
text = "###\n\nWhat is going oooonnn???\nIdk."
tokens = tokenizer.encode(text)

# Dedicated seeded generator: the demo colours are the same on every run and
# the global `random` state is left untouched.
demo_rng = random.Random(0)
demo_values = [2 * (demo_rng.random() - 0.5) for _ in tokens]  # uniform in [-1, 1)

display(
    HTML(
        visualize_tokens_html(
            tokens,
            tokenizer,
            demo_values,
            values_scale=1,
        )
    )
)

In [54]:
from cot_probing.activations import Activations, QuestionActivations

import pickle

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point this at files produced by a trusted run.
# Presumably the cot_probing.activations import in this cell is what lets
# pickle resolve the stored classes — verify before removing it.
activations_path = "../hf_results/activations_google--gemma-2-2b_snarks_S0_N151.pkl"
with open(activations_path, "rb") as pkl_file:
    activations = pickle.load(pkl_file)
print(activations.layers)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]


In [3]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point this at files produced by a trusted run.
eval_results_path = "../results/eval_google--gemma-2-2b_snarks_S0_N151.pkl"
with open(eval_results_path, "rb") as pkl_file:
    eval_results = pickle.load(pkl_file)

# Bare last expression so the notebook displays the first question's repr.
eval_results.questions[0]

EvalQuestion(3281 tokens, locs keys = ['response'], is_correct=False, answer_char=A)

In [55]:
# Build a difference-of-means probe: for every question, average the
# activations over locations, then subtract the mean over incorrectly
# answered questions from the mean over correctly answered ones.
correct_mean_acts = []
incorrect_mean_acts = []
for q_idx, q in enumerate(eval_results.questions):
    # (num_layers, num_locs, d_model) -> mean over dim 1 (the locs axis)
    acts_mean = activations.activations_by_question[q_idx].activations.mean(dim=1)
    bucket = correct_mean_acts if q.is_correct else incorrect_mean_acts
    bucket.append(acts_mean)
print(f"Correct: {len(correct_mean_acts)}, Incorrect: {len(incorrect_mean_acts)}")

# Average the per-question means within each group, then take the difference.
correct_acts_mean = torch.stack(correct_mean_acts).mean(dim=0)
incorrect_acts_mean = torch.stack(incorrect_mean_acts).mean(dim=0)
probe = correct_acts_mean - incorrect_acts_mean
probe.shape

Correct: 7, Incorrect: 66


torch.Size([26, 2304])

In [59]:
from fancy_einsum import einsum


def vis_q(q_idx: int, layers, n_context_tokens: int = 100):
    """Display probe activations over one question's response, layer by layer.

    First shows up to ``n_context_tokens`` tokens preceding the response with
    a neutral value of 0.0 for every token. Then, for each entry of
    ``layers``, shows the tokens from the first response location onwards,
    coloured by the dot product of their activations with the module-level
    ``probe``.

    Args:
        q_idx: Index into ``eval_results.questions`` and
            ``activations.activations_by_question``.
        layers: Iterable of layer indices to display.
        n_context_tokens: Maximum number of tokens shown before the response.
    """
    question = eval_results.questions[q_idx]
    acts = activations.activations_by_question[q_idx].activations
    probe_acts = einsum(
        "layers locs d_model, layers d_model -> layers locs", acts, probe
    )
    print(f"Question {q_idx}")
    first_response_loc = question.locs["response"][0]

    # Clamp at 0: a negative start index would be read as an offset from the
    # end of the sequence and silently give the wrong (usually empty) context
    # when the response begins fewer than n_context_tokens tokens in.
    context_start = max(0, first_response_loc - n_context_tokens)
    context_tokens = question.tokens[context_start:first_response_loc]
    display(
        HTML(
            visualize_tokens_html(
                context_tokens, tokenizer, [0.0] * len(context_tokens)
            )
        )
    )

    # The response tokens are the same for every layer, so slice them once.
    response_tokens = question.tokens[first_response_loc:]
    for layer in layers:
        print(f"Layer {layer}")
        # NOTE(review): probe_acts is indexed directly by layer number, which
        # assumes the stored layer axis starts at layer 0 with no gaps —
        # confirm against activations.layers.
        values = probe_acts[layer].tolist()
        display(HTML(visualize_tokens_html(response_tokens, tokenizer, values)))


# Layers to visualise: 12 through 20 inclusive.
layers = [*range(12, 21)]

In [60]:
# Show question 0 across the selected layers.
vis_q(q_idx=0, layers=layers)

Question 0


Layer 12


Layer 13


Layer 14


Layer 15


Layer 16


Layer 17


Layer 18


Layer 19


Layer 20


In [61]:
# Show question 1 across the selected layers.
vis_q(q_idx=1, layers=layers)

Question 1


Layer 12


Layer 13


Layer 14


Layer 15


Layer 16


Layer 17


Layer 18


Layer 19


Layer 20


In [62]:
# Show question 2 across the selected layers.
vis_q(q_idx=2, layers=layers)

Question 2


Layer 12


Layer 13


Layer 14


Layer 15


Layer 16


Layer 17


Layer 18


Layer 19


Layer 20


In [63]:
# Show question 3 across the selected layers.
vis_q(q_idx=3, layers=layers)

Question 3


Layer 12


Layer 13


Layer 14


Layer 15


Layer 16


Layer 17


Layer 18


Layer 19


Layer 20


In [64]:
# Show question 5 across the selected layers.
vis_q(q_idx=5, layers=layers)

Question 5


Layer 12


Layer 13


Layer 14


Layer 15


Layer 16


Layer 17


Layer 18


Layer 19


Layer 20
