In this notebook, we build the contrastive sets of activations by capturing residual-stream activations at the start-of-thinking (`<think>`) and end-of-thinking (`</think>`) tokens:

```md
thinking:
    What is LLM?<think>\n\nOkay the user ask...</think>\n\nLLM is...
                ^                              ^
               Here                           Here
```

In [None]:
%%capture
!pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama kaleido numpy==1.26.3

In [None]:
import torch
import functools
import einops
import requests
import pandas as pd
from typing import TypedDict, Literal
import random
import numpy as np
import json

from dataclasses import dataclass
from pathlib import Path
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from tqdm import tqdm
from torch import Tensor
from typing import Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from jaxtyping import Float, Int
from enum import StrEnum
from colorama import Fore

In [None]:
# For visualization
import plotly
import plotly.graph_objects as go
import plotly.express as px

In [None]:
MODEL_PATH = "Qwen/Qwen3-4B"
DEVICE = "cuda"

# Load the checkpoint without HookedTransformer weight processing, in float16,
# on the primary GPU.
model = HookedTransformer.from_pretrained_no_processing(
    MODEL_PATH,
    dtype=torch.float16,
    device=DEVICE,
    default_padding_side="left",
)
# Keep the wrapped HF tokenizer's padding side consistent with the model default.
model.tokenizer.padding_side = "left"

In [None]:
# Optional: a second copy of the model on a second GPU (e.g. the two-GPU Kaggle
# Notebook setup) so work can be spread across both devices. When fewer than
# two CUDA devices are visible, loading is skipped and model_2 is set to None
# instead of crashing the notebook.
DEVICE_2 = "cuda:1"

if torch.cuda.device_count() > 1:
    model_2 = HookedTransformer.from_pretrained_no_processing(
        MODEL_PATH,
        device=DEVICE_2,
        dtype=torch.float16,
        default_padding_side='left',
    )
    model_2.tokenizer.padding_side = 'left'
else:
    model_2 = None
    print(
        f"Only {torch.cuda.device_count()} CUDA device(s) visible; "
        f"skipping model_2 on {DEVICE_2}"
    )

In [None]:
def tokenize_instructions_qwen_chat(
    tokenizer: AutoTokenizer,
    instructions: list[str],
    enable_thinking: bool = True,
) -> Int[Tensor, "batch_size seq_len"]:
    """Wrap each instruction as a single user turn and tokenize the batch.

    Each instruction is rendered through the tokenizer's chat template with a
    generation prompt appended; ``enable_thinking`` is forwarded to the
    template. The rendered prompts are then tokenized together with padding
    (no truncation) and the ``input_ids`` tensor is returned.
    """
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": instruction}],
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=enable_thinking,
        )
        for instruction in instructions
    ]
    encoded = tokenizer(prompts, padding=True, truncation=False, return_tensors="pt")
    return encoded.input_ids

# Tokenizer helpers pre-bound to the model's tokenizer: the first renders the
# chat template with thinking enabled, the second with thinking disabled.
tokenize_instructions_enable_thinking_fn, tokenize_instructions_disable_thinking_fn = (
    functools.partial(
        tokenize_instructions_qwen_chat,
        tokenizer=model.tokenizer,
        enable_thinking=thinking,
    )
    for thinking in (True, False)
)

In [None]:
def get_dataset_instructions() -> tuple[list[str], list[str]]:
    """Download the steering-thinking-llms message lists and return their contents.

    The upstream ``messages.py`` is written to the current working directory
    and then imported, so it is executed as Python code.

    Returns:
        (train_contents, eval_contents): the ``"content"`` field of every entry
        in the module's ``messages`` and ``eval_messages`` lists, in file order.

    Raises:
        requests.HTTPError: If the download returns an error status.
        requests.Timeout: If the server does not respond in time.
    """
    url = "https://raw.githubusercontent.com/cvenhoff/steering-thinking-llms/refs/heads/main/messages/messages.py"

    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving (and then importing) an HTTP error page.
    response.raise_for_status()

    # SECURITY NOTE: importing this file executes code fetched from the network;
    # only point `url` at a source you trust.
    messages_path = Path("messages.py")
    messages_path.write_text(response.text, encoding="utf-8")

    from messages import messages, eval_messages

    train_contents = [msg["content"] for msg in messages]
    eval_contents = [msg["content"] for msg in eval_messages]

    # Shuffle the messages
    # random.shuffle(train_contents)
    # random.shuffle(eval_contents)

    return train_contents, eval_contents

instructions_train, instructions_test = get_dataset_instructions()

In [None]:
def preprocess_instructions(
    instructions: list[str],
    suffix: str = " Think fast and briefly.",
) -> list[str]:
    """Return a new list with ``suffix`` appended to every instruction.

    The default suffix nudges the model towards a short reasoning trace so the
    closing ``</think>`` token is reached quickly during generation.

    Args:
        instructions: Raw user instructions.
        suffix: Text appended verbatim to each instruction.

    Returns:
        A new list; the input list is not modified.
    """
    return [f"{ins}{suffix}" for ins in instructions]

In [None]:
instructions_train = preprocess_instructions(instructions_train)
instructions_test = preprocess_instructions(instructions_test)

print(f"instructions_train: {len(instructions_train)}")
print(f"instructions_test: {len(instructions_test)}")
print("=" * 100)

# Preview up to four examples per split; slicing (rather than indexing
# range(4)) avoids an IndexError when a split has fewer than four items.
for split_name, split in (
    ("instructions_train", instructions_train),
    ("instructions_test", instructions_test),
):
    print(f"{split_name}:")
    for ins in split[:4]:
        print(f"\t{repr(ins)}")

## Compute directions

In [None]:
# Generation / batching constants
TRAIN_BATCH_SIZE = 16  # number of training instructions to run
MAX_NEW_TOKENS = -1  # NOTE(review): not referenced in the visible notebook
UPPERBOUND_MAX_NEW_TOKENS = 5000  # generation cap used when max_new_tokens is None

# Qwen3 special-token ids for the start (<think>) and end (</think>) of thinking
start_thinking_token_id = 151667
end_thinking_token_id = 151668

# Model geometry
n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_activations = 2  # resid_mid and resid_post

In [None]:
# The first TRAIN_BATCH_SIZE training prompts are used to build the contrastive sets.
instructions = instructions_train[0:TRAIN_BATCH_SIZE]

In [None]:
from typing import Set

def generate_with_token_specific_activations(
    model: HookedTransformer,
    tokens: torch.Tensor,
    target_tokens: list[str] = ["<think>", "</think>"],
    max_new_tokens: int | None = None,
    activation_types: list[str] = ["resid_mid", "resid_post"],
    layers: list[int] | None = None,
    offload_to_cpu: bool = True,
) -> tuple[dict[str, dict[tuple[str, int], torch.Tensor]], torch.Tensor]:
    """
    Greedily generate text and capture activations only when target tokens appear.

    Every step re-runs the whole sequence through ``model`` (no KV cache) and
    appends the argmax token. The first time each target token is produced, one
    extra hooked forward pass is run over the sequence *including* that token,
    and the activation at its (last) position is recorded. The hooks are
    attached only for that pass, so nothing is left registered on ``model``.
    Generation stops once every target has been seen, when the tokenizer's EOS
    token is generated, or after ``max_new_tokens`` steps.

    Args:
        model: The HookedTransformer to generate with.
        tokens: Prompt token ids of shape (1, seq_len). Only batch size 1 is
            supported because the target check is done per sequence.
        target_tokens: Token strings to watch for. A string that tokenizes to
            several ids is triggered by its last id (a warning is printed).
        max_new_tokens: Maximum number of generation steps; defaults to
            ``UPPERBOUND_MAX_NEW_TOKENS``.
        activation_types: Hook suffixes to record; "resid_mid" reads
            ``blocks.{layer}.hook_resid_mid``.
        layers: Layer indices to record; defaults to all layers.
        offload_to_cpu: Move captured activations to CPU.

    Returns:
        activations: {token_str: {(act_type, layer): tensor}}, each tensor being
            the last-position activation of shape (1, d_model). A target that
            was never generated maps to an empty dict (a warning is printed).
        final_tokens: The prompt followed by all generated tokens.

    Raises:
        ValueError: If ``tokens`` does not have batch size 1, or a requested
            hook name does not exist on ``model``.
    """
    if tokens.shape[0] != 1:
        raise ValueError(
            "generate_with_token_specific_activations only supports batch size 1, "
            f"got tokens of shape {tuple(tokens.shape)}"
        )

    if layers is None:
        layers = list(range(model.cfg.n_layers))

    if max_new_tokens is None:
        max_new_tokens = UPPERBOUND_MAX_NEW_TOKENS

    # Map each trigger token id back to the string the caller asked for.
    target_token_to_str: dict[int, str] = {}
    for token_str in target_tokens:
        token_ids = model.tokenizer(token_str, return_tensors="pt").input_ids.flatten().tolist()
        if len(token_ids) != 1:
            # Multi-token strings are triggered by their last token.
            print(f"Warning: {token_str} tokenizes to multiple tokens: {token_ids}")
        target_token_to_str[token_ids[-1]] = token_str

    print(f"Monitoring token IDs: {set(target_token_to_str)}")
    print(f"Token mapping: {target_token_to_str}")

    # Resolve and validate every hook name before any generation work, so a
    # typo fails immediately instead of after thousands of steps.
    hook_specs: list[tuple[str, str, int]] = []
    for layer in layers:
        for act_type in activation_types:
            hook_name = f"blocks.{layer}.hook_{act_type}"
            if hook_name not in model.hook_dict:
                raise ValueError(f"{hook_name} not found in hook dict")
            hook_specs.append((hook_name, act_type, layer))

    # Set to True once the activations for that token id have been captured.
    is_monitored: dict[int, bool] = {token_id: False for token_id in target_token_to_str}
    token_activations: dict[str, dict[tuple[str, int], torch.Tensor]] = {
        token_str: {} for token_str in target_tokens
    }
    eos_token_id = model.tokenizer.eos_token_id

    def make_hook(store: dict, act_type: str, layer: int):
        def hook_fn(activation, hook):
            # Assumes activation is (batch, pos, d_model): keep the last
            # position, i.e. the target token that was just appended.
            target_activation = activation[:, -1, :].clone()
            if offload_to_cpu:
                target_activation = target_activation.cpu()
            store[(act_type, layer)] = target_activation
        return hook_fn

    model.reset_hooks()
    with torch.no_grad():
        for step in tqdm(range(max_new_tokens), desc="Generating"):
            # Plain forward pass over the entire sequence (no KV cache).
            logits = model(tokens)
            next_token = logits[:, -1, :].argmax(dim=-1)
            del logits
            token_id = next_token.item()
            tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)

            if token_id in is_monitored and not is_monitored[token_id]:
                is_monitored[token_id] = True
                token_str = target_token_to_str[token_id]
                print(f"Step {step}: Found target token '{token_str}' (ID: {token_id})")

                # Re-run with the target token included so it is the last
                # position. run_with_hooks detaches the hooks when it returns.
                captured: dict[tuple[str, int], torch.Tensor] = {}
                model.run_with_hooks(
                    tokens,
                    fwd_hooks=[
                        (hook_name, make_hook(captured, act_type, layer))
                        for hook_name, act_type, layer in hook_specs
                    ],
                )
                token_activations[token_str].update(captured)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            if all(is_monitored.values()):
                break
            if eos_token_id is not None and token_id == eos_token_id:
                print(f"Step {step}: generated EOS before all target tokens were found")
                break

    missing = [target_token_to_str[t] for t, seen in is_monitored.items() if not seen]
    if missing:
        print(f"Warning: target tokens never generated: {missing}")

    return token_activations, tokens



In [None]:
# One activation cache per instruction at <think> (positive) and </think> (negative).
batch_positive_caches = []
batch_negative_caches = []

for i, instruction in enumerate(instructions):
    print(f"Processing instruction {i}:")
    tokens = tokenize_instructions_enable_thinking_fn(instructions=[instruction])
    print(f"Number of tokens: {tokens.shape[1]}")
    token_activations, tokens = generate_with_token_specific_activations(
        model,
        tokens.to(DEVICE),
        max_new_tokens=None,
        activation_types=["resid_mid", "resid_post"],
        target_tokens=["<think>", "</think>"],
    )
    # Fail here with the offending index rather than with an opaque KeyError
    # later, when the caches are stacked per layer.
    missing = [tok for tok in ("<think>", "</think>") if not token_activations.get(tok)]
    if missing:
        raise RuntimeError(
            f"Instruction {i}: no activations captured for {missing}; "
            "the model never generated these tokens within the step budget"
        )
    batch_positive_caches.append(token_activations["<think>"])
    batch_negative_caches.append(token_activations["</think>"])


In [None]:
# Every instruction must contribute exactly one <think> cache and one </think> cache.
assert len(batch_positive_caches) == TRAIN_BATCH_SIZE and len(batch_negative_caches) == TRAIN_BATCH_SIZE

# Cache entries are keyed by (activation type, layer).
print(batch_positive_caches[0][("resid_mid", 0)].shape)  # (1, d_model)
print(batch_negative_caches[0][("resid_post", 0)].shape)  # (1, d_model)

In [None]:
# Save activations for backup

def save_tensor(
    tensor: Tensor,
    save_dir: Path,
    name: str,
) -> None:
    """Serialize ``tensor`` with ``torch.save`` to ``save_dir / f"{name}.pt"``.

    Despite the annotation, any object ``torch.save`` accepts works; this
    notebook passes lists of activation dicts. ``save_dir`` and any missing
    parent directories are created.

    Args:
        tensor: Object to serialize.
        save_dir: Destination directory.
        name: File stem; ``.pt`` is appended.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{name}.pt"
    torch.save(tensor, save_path)
    print(f"Saved {name} to {save_path}")

# Persist both contrastive caches under outputs/ so the slow generation step
# does not have to be repeated.
backup_dir = Path("outputs/")
backup_dir.mkdir(parents=True, exist_ok=True)
for caches, label in (
    (batch_positive_caches, "batch_positive_caches_HookedTransformer"),
    (batch_negative_caches, "batch_negative_caches_HookedTransformer"),
):
    save_tensor(caches, backup_dir, label)

### Compute positive activations (at `<think>`) and negative activations (at `</think>`)

In [None]:
# Per-extraction-point mean of the unit-normalized <think> activations
positive_mean_activations = torch.zeros(
    n_layers, n_activations, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, d_model)

# Every unit-normalized <think> activation, kept for per-sample plots
positive_activations_all = torch.zeros(
    n_layers, n_activations, TRAIN_BATCH_SIZE, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, batch_size, d_model)


for layer in range(n_layers):
    # Each cache entry is (1, d_model): take row 0 of every sample, then stack
    # resid_mid / resid_post along a new leading axis. Upcast to float32 so the
    # normalization below is not done in half precision.
    batch_positive_activations = torch.stack([
        torch.stack([sample[act, layer][0, :] for sample in batch_positive_caches])
        for act in ["resid_mid", "resid_post"]
    ]).float()  # (n_activations, batch_size, d_model)

    # Normalize then get mean because the activation will be normalized by the RMSNorm layer
    batch_positive_activations = batch_positive_activations / batch_positive_activations.norm(dim=-1, keepdim=True)

    # Mean across the batch dimension (assignment casts back to float16)
    positive_mean_activations[layer] = batch_positive_activations.mean(dim=1)
    positive_activations_all[layer] = batch_positive_activations

print(positive_mean_activations.shape)  # (n_layers, n_activations, d_model)
print(positive_activations_all.shape)  # (n_layers, n_activations, batch_size, d_model)

In [None]:
# Per-extraction-point mean of the unit-normalized </think> activations
negative_mean_activations = torch.zeros(
    n_layers, n_activations, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, d_model)

# Every unit-normalized </think> activation, kept for per-sample plots
negative_activations_all = torch.zeros(
    n_layers, n_activations, TRAIN_BATCH_SIZE, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, batch_size, d_model)

for layer in range(n_layers):
    # Each cache entry is (1, d_model): take row 0 of every sample, then stack
    # resid_mid / resid_post along a new leading axis. Upcast to float32 so the
    # normalization below is not done in half precision.
    batch_negative_activations = torch.stack([
        torch.stack([sample[act, layer][0, :] for sample in batch_negative_caches])
        for act in ["resid_mid", "resid_post"]
    ]).float()  # (n_activations, batch_size, d_model)

    # Normalize then get mean because the activation will be normalized by the RMSNorm layer
    batch_negative_activations = batch_negative_activations / batch_negative_activations.norm(dim=-1, keepdim=True)

    # Mean across the batch dimension (assignment casts back to float16)
    negative_mean_activations[layer] = batch_negative_activations.mean(dim=1)
    negative_activations_all[layer] = batch_negative_activations

print(negative_mean_activations.shape)  # (n_layers, n_activations, d_model)
print(negative_activations_all.shape)  # (n_layers, n_activations, batch_size, d_model)

In [None]:
# Rescale each per-extraction-point mean back to unit length.
positive_mean_activation_normed, negative_mean_activation_normed = (
    mean_acts / mean_acts.norm(dim=-1, keepdim=True)
    for mean_acts in (positive_mean_activations, negative_mean_activations)
)

print(positive_mean_activation_normed.shape)  # (n_layers, n_activations, d_model)
print(negative_mean_activation_normed.shape)  # (n_layers, n_activations, d_model)

In [None]:
# Candidate direction per extraction point: mean <think> direction minus mean
# </think> direction (computed on CPU), plus a unit-length copy.
candidate_refusal_vectors = (
    positive_mean_activation_normed.cpu() - negative_mean_activation_normed.cpu()
)
candidate_refusal_vectors_normed = (
    candidate_refusal_vectors
    / candidate_refusal_vectors.norm(dim=-1, keepdim=True)
)

print(candidate_refusal_vectors_normed.shape)  # (n_layers, n_activations, d_model)

## Visualizations

### Scalar projections of activations onto the local candidate direction (`<think>` mean − `</think>` mean, stored as `candidate_refusal_vectors_normed`) at each extraction point

In [1]:
# Shared category colours. Each category uses the same hue in all three maps:
# a strong colour for lines/markers, a pastel for secondary lines, and a
# translucent version of that pastel for variance bands.
colour_map = {
    "positive": plotly.colors.qualitative.Plotly[0],
    "negative": plotly.colors.qualitative.Plotly[1],
    "neutral": plotly.colors.qualitative.Pastel1[3],

}
colour_map_light = {
    "positive": plotly.colors.qualitative.Pastel1[1],
    "negative": plotly.colors.qualitative.Pastel1[0],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
# rgb values of Pastel1[1] / Pastel1[0] at 30% opacity; previously these were
# swapped, so the bands did not match their lines.
colour_map_opaque = {
    "positive": "rgba(179, 205, 227, 0.3)",
    "negative": "rgba(251, 180, 174, 0.3)",
}


NameError: name 'plotly' is not defined

In [None]:
category2acts_normed = {
    "positive": positive_activations_all.cpu().float(),
    "negative": negative_activations_all.cpu().float(),
}  # (n_layers, n_activations, batch_size, d_model)

# One label per extraction point: index 2*layer is resid_mid, 2*layer + 1 is resid_post
x_values = [str(i) for i in range(2 * n_layers)]
local_directions = candidate_refusal_vectors_normed.float()  # (n_layers, n_activations, d_model)

fig = go.Figure()

for category in ["positive", "negative"]:
    acts_normed = category2acts_normed[category]  # (n_layers, n_activations, batch_size, d_model)
    # Project every sample onto the direction of its own extraction point.
    projections = einops.einsum(
        local_directions,
        acts_normed,
        "layer act dim, layer act batch dim -> layer act batch",
    ).reshape(-1, acts_normed.shape[2])  # (n_layers * n_activations, batch_size)

    mean_projection = projections.mean(dim=-1)
    std_projection = projections.std(dim=-1)
    y_values = mean_projection.tolist()
    y_upper = (mean_projection + std_projection).tolist()
    y_lower = (mean_projection - std_projection).tolist()

    # mean ± std band: upper edge left-to-right, lower edge back, closed by fill="toself".
    # Drawn first so the mean line sits on top of it.
    fig.add_trace(
        go.Scatter(
            x=x_values + x_values[::-1],
            y=y_upper + y_lower[::-1],
            name=category,
            mode="lines",
            yaxis="y",
            fill="toself",
            fillcolor=colour_map_opaque[category],
            line=dict(color=colour_map_opaque[category], width=0),
            showlegend=False,
        )
    )

    # mean (legend entry)
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            name=category,
            mode="lines+markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=True,
        )
    )
    # mean line in the light colour
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            name=category,
            mode="lines",
            yaxis="y",
            marker=dict(color=colour_map_light[category], size=3),
            showlegend=False,
        )
    )
    # dot markers on top of the line
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            name=f"{category}",
            mode="markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=False,
        )
    )

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=20)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=18),
    ),
    yaxis=dict(
        title=dict(text="Scalar Projection", font=dict(size=20)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=18),
    ),
    hovermode="x unified",
    height=250,
    width=600,
    margin=dict(l=0, r=0, t=0, b=0),
    legend=dict(x=0.05, y=0.95, font=dict(size=18)),
)
fig.show()

fig.write_image("prj_onto_local_candidate_refusal_vectors.pdf", scale=5)

### Norm of the candidate directions at each extraction point, and the mean cosine similarity of each direction with the directions at all other extraction points


In [None]:
# Alias used by the cells below: the raw (not re-normalized) difference between
# the normalized <think> and </think> means, shape (n_layers, n_activations, d_model).
# Note this is the same tensor object, not a copy.
refusal_dirs = candidate_refusal_vectors

In [None]:
# One label per extraction point: index 2*layer is resid_mid, 2*layer + 1 is resid_post
layer_names = [str(i) for i in range(2 * n_layers)]

# (n_layers * n_activations, d_model), flattened in the same order as layer_names
refusal_dirs_flatten = refusal_dirs.reshape((-1, refusal_dirs.shape[-1])).float()
refusal_dir_norms = refusal_dirs_flatten.norm(dim=-1).numpy()

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=refusal_dir_norms,
        mode="lines+markers",
        yaxis="y",
        marker_color=colour_map_light["neutral"],
        marker_size=8,
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=refusal_dir_norms,
        mode="markers",
        yaxis="y",
        marker_color=colour_map["neutral"],
        marker_size=8,
        showlegend=False,
    )
)

# Extraction point with the largest direction norm, excluding the final one.
print(layer_names[np.argmax(refusal_dir_norms[:-1])])


fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=28)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=24),
    ),
    yaxis=dict(
        title=dict(text="Norm of<br>Refusal Direction", font=dict(size=28)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=24),
    ),
    hovermode="x unified",
    height=300,
    width=1000,
    margin=dict(l=20, r=20, t=20, b=20),
)
fig.show()

# fig.write_image(VISUALIZATION_DIR / "norm_refusal.pdf", scale=5)


# refusal_dirs are raw mean differences; normalize before taking dot products
# so pairwise_cosine holds actual cosine similarities.
flatten_dirs = refusal_dirs_flatten / refusal_dirs_flatten.norm(dim=-1, keepdim=True)
pairwise_cosine = (flatten_dirs @ flatten_dirs.T).numpy()
# Row mean over all extraction points (including the self-similarity of 1).
mean_cosine = np.nanmean(pairwise_cosine, axis=-1)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=mean_cosine,
        mode="lines+markers",
        yaxis="y",
        marker_color=colour_map_light["neutral"],
        showlegend=False,
        marker_size=8,
    )
)
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=mean_cosine,
        mode="markers",
        yaxis="y",
        marker_color=colour_map["neutral"],
        showlegend=False,
        marker_size=8,
    )
)

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=28)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=24),
    ),
    yaxis=dict(
        title=dict(text=f"Mean<br>Cosine Score", font=dict(size=28)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=24),
    ),
    hovermode="x unified",
    height=300,
    width=1000,
    margin=dict(l=20, r=20, t=20, b=20),
)

fig.show()
layer_names[np.nanargmax(mean_cosine)]


# fig.write_image(VISUALIZATION_DIR / "mean_cosine.pdf", scale=5)

### Projection of activations at each extraction point onto the chosen steering direction

In [None]:
# Pick the extraction point whose direction has the highest mean cosine with
# all others. Flattened index = 2 * layer + act_idx, where act_idx 0 is
# resid_mid and 1 is resid_post.
argmax = np.nanargmax(mean_cosine)
max_mean_cosine_layer, max_mean_cosine_act_idx = divmod(argmax, 2)

chosen_layer, chosen_act_idx = max_mean_cosine_layer, max_mean_cosine_act_idx
chosen_token = -1

In [None]:
def variance_plot(**kwargs):
    """Build a filled mean ± std band as a ``go.Scatter`` trace.

    Required keyword arguments:
        x: list of x labels, one per row of ``y``.
        y: tensor reduced with ``.mean(dim=-1)`` / ``.std(dim=-1)``, so its last
           axis holds the samples for each x position.
        fillcolor: band colour, also used for the zero-width outline.
    All other keyword arguments (including ``fillcolor``) are passed through
    to ``go.Scatter``.
    """
    x = kwargs.pop("x")
    samples = kwargs.pop("y")

    centre = samples.mean(dim=-1)
    spread = samples.std(dim=-1)
    upper = (centre + spread).tolist()
    lower = (centre - spread).tolist()

    # Walk the upper edge left-to-right, then the lower edge right-to-left;
    # fill="toself" closes the polygon.
    return go.Scatter(
        x=[*x, *reversed(x)],
        y=[*upper, *reversed(lower)],
        mode="lines",
        fill="toself",
        line=dict(color=kwargs["fillcolor"], width=0),
        **kwargs,
    )

In [None]:
fig = go.Figure()

category2acts_normed = {
    "positive": positive_activations_all.cpu().float(),
    "negative": negative_activations_all.cpu().float(),
}  # (n_layers, n_activations, batch_size, d_model)

# refusal_dirs holds raw mean differences; normalize the chosen one so the
# projections below are true scalar projections.
direction = refusal_dirs[chosen_layer, chosen_act_idx].clone().float()  # (d_model,)
direction = direction / direction.norm()

# One label per extraction point: index 2*layer is resid_mid, 2*layer + 1 is resid_post
x_values = [str(i) for i in range(2 * n_layers)]

for category in ["positive", "negative"]:
    # (n_layers, n_activations, batch_size, d_model). The stored activations
    # have no token axis, so no position indexing is applied here.
    activations = category2acts_normed[category].clone()

    # (n_layers, n_activations, batch_size)
    scalar_projections = torch.nan_to_num(
        einops.einsum(
            activations,
            direction,
            "layer act batch_size dim, dim -> layer act batch_size",
        )
    )
    print(category)
    print(scalar_projections.mean())

    batch_size = scalar_projections.shape[-1]
    # (n_layers * n_activations, batch_size), in the same order as x_values
    y_values = scalar_projections.reshape(-1, batch_size)
    y_mean = y_values.mean(dim=-1).tolist()

    # variance
    fig.add_trace(
        variance_plot(
            x=x_values,
            y=y_values,
            yaxis="y",
            fillcolor=colour_map_opaque[category],
            showlegend=False,
        )
    )

    # mean
    ## for legend
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_mean,
            mode="lines+markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=True,
            name=category,
        )
    )
    ## for lines
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_mean,
            mode="lines",
            yaxis="y",
            marker=dict(color=colour_map_light[category], size=3),
            showlegend=False,
            name=category,
        )
    )
    ## for markers
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_mean,
            mode="markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=False,
            name=category,
        )
    )

    # Reflect each activation with a positive component along `direction`
    # across the hyperplane orthogonal to it, then report the new mean
    # projection. Printed only; not plotted.
    reflected = activations - 2 * einops.einsum(
        scalar_projections.clamp(min=0),
        direction,
        "layer act batch_size, dim -> layer act batch_size dim",
    )
    reflected_projections = einops.einsum(
        reflected,
        direction,
        "layer act batch_size dim, dim -> layer act batch_size",
    )
    print(f"{category} (after reflection)")
    print(reflected_projections.mean())


module_names = ["mid", "post"]
fig.update_layout(
    grid=dict(rows=1, columns=1),
    plot_bgcolor="white",
    xaxis=dict(
        type="category",
        dtick=4,
        title=dict(text="Extraction Point", font=dict(size=20)),
        gridcolor="lightgrey",
        tickfont=dict(size=18),
    ),
    yaxis=dict(
        title=dict(text="Scalar Projections", font=dict(size=20)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=18),
    ),
    hovermode="x unified",
    height=250,
    width=600,
    margin=dict(l=20, r=20, t=20, b=20),
    legend=dict(x=0.05, y=0.95, font=dict(size=18)),
)
fig.show()