```md
positive - thinking:
	Give me a short introduction to large language model<think>...</think>\n\n
								    ^
							    select the 80th token here
negative - no-thinking: 
	Give me a short introduction to large language model<think>\n\n</think>\n\n
										^
										here
```

## Setup

In [None]:
%%capture
!pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama kaleido numpy==1.26.3

In [None]:
import torch
import functools
import einops
import requests
import pandas as pd
import io
import textwrap
import random
import numpy as np
import json

from dataclasses import dataclass
from pathlib import Path
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from tqdm import tqdm
from torch import Tensor
from typing import Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from jaxtyping import Float, Int
from enum import StrEnum
from colorama import Fore

In [None]:
# For visualization
import plotly
import plotly.graph_objects as go
import plotly.express as px

In [None]:
# Hugging Face identifier of the model under study and the device it is loaded on.
MODEL_PATH = "Qwen/Qwen3-4B"
DEVICE = "cuda"

# Load the checkpoint as a TransformerLens HookedTransformer in float16, with
# left padding as the default padding side.
model = HookedTransformer.from_pretrained_no_processing(
    MODEL_PATH,
    device=DEVICE,
    dtype=torch.float16,
    default_padding_side='left',
    # fp16=True
)

# Set left padding on the tokenizer object as well, because the tokenizer is
# called directly (not through the model) when prompts are tokenized below.
model.tokenizer.padding_side = 'left'
# model.tokenizer.pad_token = '<|extra_0|>'

In [None]:
def tokenize_instructions_qwen_chat(
    tokenizer: AutoTokenizer,
    instructions: list[str],
    enable_thinking: bool = True,
) -> Int[Tensor, "batch_size seq_len"]:
    """Turn raw instructions into a padded batch of chat-formatted token ids.

    Each instruction becomes a single-turn user message, is rendered to text
    with the tokenizer's chat template (generation prompt appended), and the
    rendered prompts are then tokenized together with padding and without
    truncation.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        instructions: User instructions, one prompt per entry.
        enable_thinking: Forwarded to the chat template as ``enable_thinking``.

    Returns:
        The ``input_ids`` tensor of the padded batch.
    """

    def render_prompt(instruction: str) -> str:
        # Single-turn conversation containing only the user instruction.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": instruction}],
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=enable_thinking,
        )

    encoded = tokenizer(
        [render_prompt(instruction) for instruction in instructions],
        padding=True,
        truncation=False,
        return_tensors="pt",
    )
    return encoded.input_ids

# Convenience wrappers bound to the loaded model's tokenizer; callers only pass
# `instructions=[...]`.

# Prompts rendered with the chat template's thinking mode switched on.
tokenize_instructions_enable_thinking_fn = functools.partial(
    tokenize_instructions_qwen_chat,
    tokenizer=model.tokenizer,
    enable_thinking=True,
)

# Prompts rendered with the chat template's thinking mode switched off.
tokenize_instructions_disable_thinking_fn = functools.partial(
    tokenize_instructions_qwen_chat,
    tokenizer=model.tokenizer,
    enable_thinking=False,
)

In [None]:
def get_dataset_instructions(seed: int | None = 0) -> tuple[list[str], list[str]]:
    """Download the instruction dataset and return shuffled train / eval prompts.

    The remote ``messages.py`` file is saved to the working directory and
    imported; the ``content`` field of every entry in its ``messages`` and
    ``eval_messages`` lists is extracted.

    Args:
        seed: Seed for the shuffle, so that the selected instructions are
            reproducible across kernel restarts. Pass ``None`` for an
            unseeded (non-reproducible) shuffle.

    Returns:
        ``(train_contents, eval_contents)``, each a shuffled list of strings.

    Raises:
        requests.HTTPError: If the download does not return a success status.
    """
    import importlib

    url = "https://raw.githubusercontent.com/cvenhoff/steering-thinking-llms/refs/heads/main/messages/messages.py"

    # SECURITY NOTE: the downloaded file is executed by the import below, so this
    # runs remote code. Only use it with a source you trust (ideally pin a commit).
    response = requests.get(url, timeout=30)
    # Fail here on 4xx/5xx instead of writing an error page to messages.py and
    # getting a confusing SyntaxError/ImportError later.
    response.raise_for_status()

    # Save to file (explicit encoding so the result does not depend on the platform default)
    with open("messages.py", "w", encoding="utf-8") as f:
        f.write(response.text)

    # A module file created at runtime can be missed by the import system's cached
    # directory listings; invalidate them before importing.
    importlib.invalidate_caches()

    # Load the messages
    assert Path("messages.py").exists()
    from messages import messages, eval_messages

    train_contents = [msg["content"] for msg in messages]
    eval_contents = [msg["content"] for msg in eval_messages]

    # Shuffle the messages with a private generator so the global `random`
    # state is left untouched.
    rng = random.Random(seed)
    rng.shuffle(train_contents)
    rng.shuffle(eval_contents)

    return train_contents, eval_contents

instructions_train, instructions_test = get_dataset_instructions()

In [None]:
# Report the size of each split, then show the first four prompts of each one.
named_splits = (
    ("instructions_train", instructions_train),
    ("instructions_test", instructions_test),
)

for split_name, split in named_splits:
    print(f"{split_name}: {len(split)}")
print("=" * 100)

for split_name, split in named_splits:
    print(f"{split_name}:")
    for i in range(4):
        print(f"\t{repr(split[i])}")

## Compute activations & refusal dirs

In [None]:
class ActivationName(StrEnum):
    """Names of the residual-stream activations referred to in this notebook.

    Being a ``StrEnum``, every member is also a ``str`` equal to its value, so
    a member can be used wherever the plain name (e.g. ``"resid_mid"``) is
    expected.
    """
    RESID_PRE = "resid_pre"
    RESID_MID = "resid_mid"
    RESID_POST = "resid_post"


# Activation types of interest for the analysis.
# NOTE(review): the cells in the visible part of this notebook spell these names
# out as string literals ("resid_mid", "resid_post") rather than reading this list.
required_activations = [
    ActivationName.RESID_MID,
    ActivationName.RESID_POST,
]

In [None]:
# Number of training instructions for which activations are collected.
TRAIN_BATCH_SIZE = 32
# NOTE(review): not referenced anywhere in the visible part of this notebook.
NUM_BATCHES = 10
# Number of greedy decoding steps recorded for each thinking-mode prompt.
MAX_NEW_TOKENS = 80

# Hard-coded token ids of the thinking delimiters.
# NOTE(review): presumably "<think>" / "</think>" for Qwen3 — verify against
# model.tokenizer before switching to another model.
start_thinking_token_id = 151667
end_thinking_token_id = 151668 # Qwen3's end of thinking token id
# Model dimensions, read from the loaded model's config.
n_layers = model.cfg.n_layers
# Number of activation types stored per layer ("resid_mid" and "resid_post" below).
n_activations = 2
d_model = model.cfg.d_model

In [None]:
def get_activations_generation_step(
    model: "HookedTransformer",
    tokens: Tensor,
    max_new_tokens: int = 10,
    activation_types: list[str] | None = None,
    offload_to_cpu: bool = True,
    early_stop_at_end_thinking_token: bool = True,
    stop_token_id: int | None = None,
) -> tuple[dict[tuple[str, int], Tensor], Tensor]:
    """Greedily generate tokens and record last-position activations at every step.

    At each step the whole current sequence is run through the model, the
    activations of the final position are stored for every requested
    activation type and layer, and the arg-max token is appended.

    Args:
        model: The hooked model to run.
        tokens: Prompt token ids of shape ``(batch_size, seq_len)``.
        max_new_tokens: Maximum number of generation steps to record.
        activation_types: Activation names to record at every layer. Defaults to
            ``["resid_mid", "resid_post"]``.
        offload_to_cpu: Store the recorded activations on the CPU instead of on
            ``tokens.device``.
        early_stop_at_end_thinking_token: Stop as soon as the next token of
            *every* sequence in the batch is the stop token.
        stop_token_id: Token id that triggers the early stop. Defaults to the
            notebook-level ``end_thinking_token_id``.

    Returns:
        ``(caches, tokens)`` where ``caches[(activation_type, layer)]`` has shape
        ``(batch_size, n_steps, d_model)`` and ``tokens`` is the prompt followed
        by the generated tokens. ``n_steps`` equals ``max_new_tokens`` unless
        generation stopped early, in which case only the recorded steps are
        returned (no zero-filled padding). The stop token itself is not appended
        to ``tokens``.
    """
    if activation_types is None:
        activation_types = ["resid_mid", "resid_post"]
    if early_stop_at_end_thinking_token and stop_token_id is None:
        stop_token_id = end_thinking_token_id

    batch_size = tokens.shape[0]
    d_model = model.cfg.d_model
    n_layers = model.cfg.n_layers
    model_dtype = next(model.parameters()).dtype
    storage_device = "cpu" if offload_to_cpu else tokens.device

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    cache_keys = [
        (act_type, layer)
        for act_type in activation_types
        for layer in range(n_layers)
    ]
    # Only these hooks are cached during the forward pass; caching every hook of
    # the model at every step wastes a large amount of memory.
    hook_names = [utils.get_act_name(act_type, layer) for act_type, layer in cache_keys]

    # Pre-allocate one buffer per (activation type, layer).
    all_step_caches: dict[tuple[str, int], Tensor] = {
        key: torch.zeros(
            batch_size, max_new_tokens, d_model,
            device=storage_device,
            dtype=model_dtype,
        )
        for key in cache_keys
    }

    n_recorded_steps = 0
    for step in tqdm(range(max_new_tokens), desc="Generating"):
        # No gradients are needed; skip building the autograd graph.
        with torch.no_grad():
            logits, cache = model.run_with_cache(tokens, names_filter=hook_names)

        # Store the last-position activations for this step.
        for key in cache_keys:
            last_token_activation = cache[key][:, -1, :]  # (batch_size, d_model)
            all_step_caches[key][:, step, :] = last_token_activation.to(
                device=storage_device, dtype=model_dtype
            )
        n_recorded_steps = step + 1

        # Greedy decoding: arg-max over the vocabulary at the last position only.
        next_tokens = logits[:, -1, :].argmax(dim=-1)  # (batch_size,)
        if early_stop_at_end_thinking_token and torch.all(next_tokens == stop_token_id):
            break

        tokens = torch.cat([tokens, next_tokens.to(tokens.device).unsqueeze(1)], dim=1)

    # After an early stop the remaining rows were never written. Drop them so
    # that callers indexing the last step get a real activation, not zeros.
    if n_recorded_steps < max_new_tokens:
        all_step_caches = {
            key: buffer[:, :n_recorded_steps, :]
            for key, buffer in all_step_caches.items()
        }

    return all_step_caches, tokens

In [None]:
@dataclass
class Sample:
    """One instruction together with the activations and tokens collected for it."""
    # The instruction text the prompt was built from.
    instruction: str
    # True when the prompt was rendered with thinking enabled.
    is_thinking: bool
    # Maps (activation name, layer index) to the activation tensor stored for it.
    caches: dict[tuple[str, int], Tensor]
    # Token ids associated with this sample.
    tokens: Tensor


In [None]:
# Instructions used for both prompt variants: the first TRAIN_BATCH_SIZE training prompts.
instructions = instructions_train[:TRAIN_BATCH_SIZE]

In [None]:
# Collect thinking-mode activations, one instruction at a time (batch size 1).
batch_thinking = []
for i, instruction in enumerate(instructions):
    print(f"Processing instruction {i}:")
    # Prompt rendered with thinking enabled.
    tokens = tokenize_instructions_enable_thinking_fn(instructions=[instruction])
    print(f"Number of tokens: {tokens.shape[1]}")
    # Greedy-decode up to MAX_NEW_TOKENS steps, recording the last-position
    # activations at every step. `tokens` is rebound to the returned sequence.
    caches, tokens = get_activations_generation_step(
        model,
        tokens.to(DEVICE),
        max_new_tokens=MAX_NEW_TOKENS,
        activation_types=["resid_mid", "resid_post"]
    )

    batch_thinking.append(
        Sample(
            caches=caches,
            tokens=tokens,
            instruction=instruction,
            is_thinking=True,
        )
    )

In [None]:
# Sanity checks: one sample per instruction, the decoded text of the first
# sample, and the shape of its stored activations.
assert len(batch_thinking) == TRAIN_BATCH_SIZE
print(model.tokenizer.decode(batch_thinking[0].tokens[0]))
print(batch_thinking[0].caches["resid_mid", 0].shape) # (1, n_steps, d_model), n_steps at most max_new_tokens
print(batch_thinking[0].caches["resid_post", 0].shape) # (1, n_steps, d_model), n_steps at most max_new_tokens

In [None]:
# No-thinking caches
# Nothing is generated here: each prompt gets one forward pass and only the
# activations of its last `num_no_thinking_positions` positions are kept.
# With thinking disabled the prompt ends with an empty think block
# ("<think>\n\n</think>\n\n", see the example at the top of the notebook).
# NOTE(review): 4 assumes that suffix tokenizes to exactly four tokens
# ("<think>", "\n\n", "</think>", "\n\n") — confirm with the tokenizer.
num_no_thinking_positions = 4

batch_no_thinking = []
for instruction in instructions:
    no_thinking_tokens = tokenize_instructions_disable_thinking_fn(instructions=[instruction])

    # No gradients are needed, and only the two activation types read below are cached.
    with torch.no_grad():
        _, no_thinking_cache = model.run_with_cache(
            no_thinking_tokens,
            names_filter=lambda hook_name: hook_name.endswith(("hook_resid_mid", "hook_resid_post")),
            return_cache_object=True
        )

    # Keep the trailing positions only. The slice is cloned so the full-length
    # cached tensors are not kept alive by a view.
    sub_cache = {}
    for activation_type in ["resid_mid", "resid_post"]:
        for layer in range(model.cfg.n_layers):
            sub_cache[activation_type, layer] = (
                no_thinking_cache[activation_type, layer][:, -num_no_thinking_positions:, :].clone()
            )

    batch_no_thinking.append(Sample(
        caches=sub_cache,
        tokens=no_thinking_tokens,
        instruction=instruction,
        is_thinking=False,
    ))

In [None]:
# Sanity checks: one sample per instruction, and the shape of the stored slices.
assert len(batch_no_thinking) == TRAIN_BATCH_SIZE
print(batch_no_thinking[0].caches["resid_mid", 0].shape) # (1, 4, d_model): one prompt, its last 4 positions
print(batch_no_thinking[0].caches["resid_post", 0].shape) # (1, 4, d_model): one prompt, its last 4 positions

In [None]:
# Position, counted from the end of the stored activations, that is analysed below.
chosen_token_pos = -1
# Number of trailing positions that must be kept so that position is included.
num_last_tokens = abs(chosen_token_pos)

In [None]:
thinking_mean_activations = torch.zeros(
    n_layers, n_activations, num_last_tokens, d_model, device=DEVICE, dtype=torch.float16
) # (n_layers, n_activations, num_last_tokens, d_model)

thinking_activations_all = torch.zeros(
    n_layers, n_activations, TRAIN_BATCH_SIZE, num_last_tokens, d_model, device=DEVICE, dtype=torch.float16
) # (n_layers, n_activations, batch_size, num_last_tokens, d_model)

for layer in range(n_layers):
    # Last `num_last_tokens` positions of every sample, for both activation types.
    # Moved to DEVICE and upcast to float32: taking norms and means in float16
    # loses precision, can overflow for large activations, and is slow on CPU tensors.
    batch_thinking_activations = torch.stack([
        torch.stack([
            sample.caches[act, layer][0, -num_last_tokens:, :] for sample in batch_thinking
        ]) # (batch_size, num_last_tokens, d_model)
        for act in ["resid_mid", "resid_post"]
    ]).to(device=DEVICE, dtype=torch.float32) # (n_activations, batch_size, num_last_tokens, d_model)

    # Normalize then get mean because the activation will be normalized by the RMSNorm layer
    batch_thinking_activations = batch_thinking_activations / batch_thinking_activations.norm(dim=-1, keepdim=True)

    # An all-zero or overflowing activation would turn into NaN/inf here and
    # silently poison every statistic derived from it.
    assert torch.isfinite(batch_thinking_activations).all(), (
        f"non-finite normalized thinking activations at layer {layer}"
    )

    # Compute mean across batch dimension (cast back to float16 on assignment)
    thinking_mean_activations[layer] = batch_thinking_activations.mean(dim=1)
    thinking_activations_all[layer] = batch_thinking_activations

print(thinking_mean_activations.shape) # (n_layers, n_activations, num_last_tokens, d_model)
print(thinking_activations_all.shape) # (n_layers, n_activations, batch_size, num_last_tokens, d_model)

In [None]:
no_thinking_mean_activations = torch.zeros(
    n_layers, n_activations, num_last_tokens, d_model, device=DEVICE, dtype=torch.float16
) # (n_layers, n_activations, num_last_tokens, d_model)

no_thinking_activations_all = torch.zeros(
    n_layers, n_activations, TRAIN_BATCH_SIZE, num_last_tokens, d_model, device=DEVICE, dtype=torch.float16
) # (n_layers, n_activations, batch_size, num_last_tokens, d_model)

for layer in range(n_layers):
    # Last `num_last_tokens` positions of every sample, for both activation types.
    # Moved to DEVICE and upcast to float32: taking norms and means in float16
    # loses precision and can overflow for large activations.
    batch_no_thinking_activations = torch.stack([
        torch.stack([
            sample.caches[act, layer][0, -num_last_tokens:, :] for sample in batch_no_thinking
        ]) # (batch_size, num_last_tokens, d_model)
        for act in ["resid_mid", "resid_post"]
    ]).to(device=DEVICE, dtype=torch.float32) # (n_activations, batch_size, num_last_tokens, d_model)

    # Normalize then get mean because the activation will be normalized by the RMSNorm layer
    batch_no_thinking_activations = batch_no_thinking_activations / batch_no_thinking_activations.norm(dim=-1, keepdim=True)

    # An all-zero or overflowing activation would turn into NaN/inf here and
    # silently poison every statistic derived from it.
    assert torch.isfinite(batch_no_thinking_activations).all(), (
        f"non-finite normalized no-thinking activations at layer {layer}"
    )

    # Compute mean across batch dimension (cast back to float16 on assignment)
    no_thinking_mean_activations[layer] = batch_no_thinking_activations.mean(dim=1)
    no_thinking_activations_all[layer] = batch_no_thinking_activations

print(no_thinking_mean_activations.shape) # (n_layers, n_activations, num_last_tokens, d_model)
print(no_thinking_activations_all.shape) # (n_layers, n_activations, batch_size, num_last_tokens, d_model)

In [None]:
# Take the final stored token position of each mean activation and rescale the
# resulting vectors to unit length along the model dimension.

# Thinking prompts: (n_layers, n_activations, d_model)
thinking_mean_activations_at_last_token = thinking_mean_activations[..., -1, :]
thinking_mean_activations_normed = (
    thinking_mean_activations_at_last_token
    / thinking_mean_activations_at_last_token.norm(dim=-1, keepdim=True)
)

# No-thinking prompts: (n_layers, n_activations, d_model)
no_thinking_mean_activations_at_last_token = no_thinking_mean_activations[..., -1, :]
no_thinking_mean_activations_normed = (
    no_thinking_mean_activations_at_last_token
    / no_thinking_mean_activations_at_last_token.norm(dim=-1, keepdim=True)
)

print(thinking_mean_activations_normed.shape)  # (n_layers, n_activations, d_model)
print(no_thinking_mean_activations_normed.shape)  # (n_layers, n_activations, d_model)

In [None]:
# Candidate direction per (layer, activation type): unit-norm thinking mean minus
# unit-norm no-thinking mean, computed on the CPU.
candidate_refusal_vectors = (
    thinking_mean_activations_normed.cpu() - no_thinking_mean_activations_normed.cpu()
)

# The same directions rescaled to unit length.
candidate_refusal_vectors_normed = candidate_refusal_vectors / candidate_refusal_vectors.norm(
    dim=-1, keepdim=True
)

print(candidate_refusal_vectors_normed.shape)  # (n_layers, n_activations, d_model)

## Visualizations

### Scalar projections of activations onto the local refusal direction at each extraction point

In [None]:
# Strong colours, used for legend entries and markers.
colour_map = {
    "thinking": plotly.colors.qualitative.Plotly[0],
    "no_thinking": plotly.colors.qualitative.Plotly[1],
    "neutral": plotly.colors.qualitative.Pastel1[3],

}
# Pastel counterparts, used for lines.
colour_map_light = {
    "thinking": plotly.colors.qualitative.Pastel1[1],
    "no_thinking": plotly.colors.qualitative.Pastel1[0],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
# Translucent versions of the pastel colours, used for filled areas.
# Pastel1[1] is rgb(179, 205, 227) and Pastel1[0] is rgb(251, 180, 174), so each
# category keeps the same hue here as in `colour_map_light`.
colour_map_opaque = {
    "thinking": "rgba(179, 205, 227, 0.3)",
    "no_thinking": "rgba(251, 180, 174, 0.3)",
}


In [None]:
# Per-sample unit-norm activations of both prompt types, moved to the CPU for plotting.
category2acts_normed = {
    "thinking": thinking_activations_all.cpu(),
    "no_thinking": no_thinking_activations_all.cpu(),
} # (n_layers, n_activations, batch_size, num_last_tokens, d_model)


# One x position per extraction point, in the layer-major order of the flattened projections.
x_values = [str(i) for i in range(2 * n_layers)]

fig = go.Figure()

for category in ["thinking", "no_thinking"]:
    # Upcast to float32: the stored tensors are float16, which is imprecise and
    # slow (or unsupported on older PyTorch builds) for CPU contractions.
    acts_normed = category2acts_normed[category][:, :, :, chosen_token_pos].float() # (n_layers, n_activations, batch_size, d_model)
    # Scalar projection of every sample onto the direction of its own extraction point.
    projections = einops.einsum(
        candidate_refusal_vectors_normed.float(),
        acts_normed,
        "layer act dim, layer act batch dim -> layer act batch",
    )

    mean_projection = projections.mean(dim=-1) # layer act batch -> layer act
    # Population std so that a batch of one yields 0 instead of NaN.
    std_projection = projections.std(dim=-1, unbiased=False) # layer act batch -> layer act

    y_values = mean_projection.flatten().tolist() # layer act -> layer * act
    y_upper = (mean_projection + std_projection).flatten().tolist()
    y_lower = (mean_projection - std_projection).flatten().tolist()

    # spread: mean +/- one standard deviation across samples, as a closed filled
    # polygon (forward along the upper edge, back along the lower edge).
    # Added first so that the lines and markers are drawn on top of it.
    fig.add_trace(
        go.Scatter(
            x=x_values + x_values[::-1],
            y=y_upper + y_lower[::-1],
            name=category,
            mode="lines",
            yaxis="y",
            fill="toself",
            fillcolor=colour_map_opaque[category],
            line=dict(color="rgba(255, 255, 255, 0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )

    # mean: this trace provides the legend entry in the strong colour ...
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            name=category,
            mode="lines+markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=True,
        )
    )
    # ... and this one redraws the same curve in the light colour on top of it.
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            name=category,
            mode="lines+markers",
            yaxis="y",
            marker=dict(color=colour_map_light[category], size=3),
            showlegend=False,
        )
    )

    # dot markers
    fig.add_trace(
        go.Scatter(
            x=x_values[1::],
            y=y_values[1::],
            name=f"{category}",
            mode="markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=False,
        )
    )

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=20)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=18),
    ),
    yaxis=dict(
        title=dict(text="Scalar Projection", font=dict(size=20)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=18),
    ),
    hovermode="x unified",
    height=250,
    width=600,
    margin=dict(l=0, r=0, t=0, b=0),
    legend=dict(x=0.05, y=0.95, font=dict(size=18)),
)
fig.show()

fig.write_image("prj_onto_local_candidate_refusal_vectors.pdf", scale=5)

### Mean cosine similarity of the refusal direction at each layer with those at other layers


In [None]:
# Directions analysed below: the difference-of-means vectors before they are
# rescaled to unit length (not `candidate_refusal_vectors_normed`).
refusal_dirs = candidate_refusal_vectors

In [None]:
# One label per extraction point, in the layer-major order of the flattened directions.
layer_names = [str(i) for i in range(2 * n_layers)]


# (n_layers * n_activations, d_model). Upcast to float32 so the norms and the
# matrix product below do not run in float16 on the CPU.
refusal_dirs_flatten = refusal_dirs.reshape((-1, refusal_dirs.shape[-1])).float()
refusal_dir_norms = refusal_dirs_flatten.norm(dim=-1)

fig = go.Figure()

# Norm of every direction: light line with strong-coloured markers on top.
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=refusal_dir_norms.tolist(),
        mode="lines+markers",
        yaxis="y",
        marker_color=colour_map_light["neutral"],
        marker_size=8,
        showlegend=False,
    )
)
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=refusal_dir_norms.tolist(),
        mode="markers",
        yaxis="y",
        marker_color=colour_map["neutral"],
        marker_size=8,
        showlegend=False,
    )
)

# Extraction point with the largest norm; the final point is left out of the search.
print(layer_names[np.argmax(refusal_dir_norms[:-1].numpy())])


fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=28)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=24),
    ),
    yaxis=dict(
        title=dict(text="Norm of<br>Refusal Direction", font=dict(size=28)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=24),
    ),
    hovermode="x unified",
    height=300,
    width=1000,
    margin=dict(l=20, r=20, t=20, b=20),
)
fig.show()

# fig.write_image(VISUALIZATION_DIR / "norm_refusal.pdf", scale=5)


# Cosine similarity requires unit vectors: `refusal_dirs` holds the un-normalised
# directions, so a plain Gram matrix of them would contain dot products that
# scale with the norms plotted above. Normalise each row first.
flatten_dirs = refusal_dirs_flatten / refusal_dir_norms.unsqueeze(-1)
pairwise_cosine = flatten_dirs @ flatten_dirs.T
# Mean similarity of each direction to all directions (itself included);
# nanmean skips entries that are NaN because a direction has zero norm.
mean_cosine = np.nanmean(pairwise_cosine.numpy(), axis=-1)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=mean_cosine,
        mode="lines+markers",
        yaxis="y",
        marker_color=colour_map_light["neutral"],
        showlegend=False,
        marker_size=8,
    )
)
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=mean_cosine,
        mode="markers",
        yaxis="y",
        marker_color=colour_map["neutral"],
        showlegend=False,
        marker_size=8,
    )
)

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=28)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=24),
    ),
    yaxis=dict(
        title=dict(text="Mean<br>Cosine Score", font=dict(size=28)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=24),
    ),
    hovermode="x unified",
    height=300,
    width=1000,
    margin=dict(l=20, r=20, t=20, b=20),
)

fig.show()
# Extraction point whose direction is, on average, most aligned with the others.
layer_names[np.nanargmax(mean_cosine)]


# fig.write_image(VISUALIZATION_DIR / "mean_cosine.pdf", scale=5)