In this notebook, we collect contrastive sets of activations by capturing the model's activations at the start-of-thinking (`<think>`) and end-of-thinking (`</think>`) tokens:

```md
thinking:
	What is LLM?<think>\n\nOkay the user ask...</think>\n\nLLM is...
			^                               ^
		        Here                            Here
```

The model is loaded with Hugging Face Transformers' `AutoModelForCausalLM` instead of the usual TransformerLens `HookedTransformer`, so activations are captured with PyTorch forward hooks.

## Setup

In [None]:
%%capture
!pip install transformers transformers_stream_generator tiktoken einops jaxtyping colorama kaleido numpy==1.26.3

In [None]:
import torch
import functools
import einops
import requests
import pandas as pd
from typing import TypedDict, Literal
import random
import numpy as np
import json
import time

from dataclasses import dataclass
from pathlib import Path
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from tqdm import tqdm
from torch import Tensor
from typing import Callable
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from jaxtyping import Float, Int
from enum import StrEnum
from colorama import Fore

In [None]:
# For visualization
import plotly
import plotly.graph_objects as go
import plotly.express as px

In [None]:
# Load the Qwen3 checkpoint and its tokenizer; the dtype is taken from the checkpoint
# ("auto") and the weights are placed on DEVICE.
model_name = "Qwen/Qwen3-4B"
DEVICE = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map=DEVICE, torch_dtype="auto"
)

In [None]:
def tokenize_instructions_qwen_chat(
    tokenizer: AutoTokenizer,
    instructions: list[str],
    enable_thinking: bool = True,
) -> Int[Tensor, "batch_size seq_len"]:
    """Render each instruction as a single-turn user chat prompt and tokenize the batch.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template``.
        instructions: user messages, one per batch row.
        enable_thinking: forwarded unchanged to ``apply_chat_template``.

    Returns:
        ``input_ids`` of the rendered prompts, padded to the longest one, no truncation.
    """
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=enable_thinking,
        )
        for text in instructions
    ]
    encoded = tokenizer(chat_prompts, padding=True, truncation=False, return_tensors="pt")
    return encoded.input_ids

# Bind the tokenizer and thinking mode once, so callers only pass `instructions=`.
tokenize_instructions_fn = functools.partial(
    tokenize_instructions_qwen_chat, tokenizer=tokenizer, enable_thinking=True
)


### Get data

In [None]:
def get_dataset_instructions(seed: int | None = None) -> tuple[list[str], list[str]]:
    """Download the steering-thinking-llms prompt lists and return shuffled train/eval contents.

    The upstream file is Python source that defines ``messages`` and ``eval_messages``
    (sequences of dicts with a ``"content"`` key). It is saved to ``./messages.py`` and
    executed to read those two lists.

    SECURITY NOTE: this executes code downloaded from GitHub. Pin the URL to a commit
    and/or review the file if the source is not trusted.

    Args:
        seed: if given, shuffle with ``random.Random(seed)`` for a reproducible order;
            ``None`` (default) uses the global ``random`` state as before.

    Returns:
        ``(train_contents, eval_contents)``, each shuffled.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    import importlib.util

    url = "https://raw.githubusercontent.com/cvenhoff/steering-thinking-llms/refs/heads/main/messages/messages.py"

    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an HTTP error page as "messages.py".
    response.raise_for_status()

    messages_path = Path("messages.py")
    messages_path.write_text(response.text, encoding="utf-8")

    # Load from the explicit path: unlike `from messages import ...`, this does not depend
    # on sys.path, and it re-executes the freshly written file instead of reusing a module
    # cached by an earlier run of this cell.
    spec = importlib.util.spec_from_file_location("messages", messages_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    train_contents = [msg["content"] for msg in module.messages]
    eval_contents = [msg["content"] for msg in module.eval_messages]

    # Shuffle the messages (the `random` module itself exposes .shuffle, like Random instances)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(train_contents)
    rng.shuffle(eval_contents)

    return train_contents, eval_contents

def preprocess_instructions(instructions: list[str]) -> list[str]:
    """Return a new list where every instruction ends with a request for brief thinking."""
    brevity_hint = " Think fast and briefly."
    return [text + brevity_hint for text in instructions]


In [None]:
# Fetch the shuffled train/eval prompts (train first, then eval) and append the brevity hint.
instructions_train, instructions_test = (
    preprocess_instructions(split) for split in get_dataset_instructions()
)


In [None]:
TRAIN_BATCH_SIZE = 32  # number of training instructions run through generation
MAX_NEW_TOKENS = -1  # -1 -> get_generated_tokens_activations uses UPPERBOUND_MAX_NEW_TOKENS
UPPERBOUND_MAX_NEW_TOKENS = 7000

# Thinking delimiters, looked up in the tokenizer instead of hard-coded so they stay
# correct if `model_name` changes (for Qwen3 these were hard-coded as 151667 / 151668).
start_thinking_token_id = tokenizer.convert_tokens_to_ids("<think>")
end_thinking_token_id = tokenizer.convert_tokens_to_ids("</think>")
_think_ids = (start_thinking_token_id, end_thinking_token_id)
if any(tok_id is None or tok_id == tokenizer.unk_token_id for tok_id in _think_ids):
    raise ValueError(f"Tokenizer has no dedicated <think>/</think> tokens: {_think_ids}")

n_layers = model.config.num_hidden_layers
n_activations = 2  # hooked modules per layer: input_layernorm, post_attention_layernorm
d_model = model.config.hidden_size

## Compute activations & refusal dirs

In [None]:
def _target_token_id_map(tokenizer: AutoTokenizer, target_tokens) -> dict[int, str]:
    """Map the trigger id of each target string back to that string.

    A string that splits into several ids is tracked by its last id (with a warning).
    """
    id_to_str: dict[int, str] = {}
    for token_str in target_tokens:
        # add_special_tokens=False so tokenizers that add BOS/EOS don't alter the id list
        token_ids = tokenizer(token_str, add_special_tokens=False).input_ids
        if not token_ids:
            raise ValueError(f"Target token {token_str!r} tokenizes to no ids")
        if len(token_ids) > 1:
            print(f"Warning: {token_str} tokenizes to multiple tokens: {token_ids}")
        id_to_str[token_ids[-1]] = token_str
    return id_to_str


def _resolve_hook_modules(model, layers, module_names) -> list[tuple[torch.nn.Module, str]]:
    """Return ``(module, "model.layers.{layer}.{name}")`` for every sub-module that exists."""
    resolved = []
    for layer in layers:
        for module_name in module_names:
            module_path = f"model.layers.{layer}.{module_name}"
            try:
                module = model.get_submodule(module_path)
            except AttributeError:
                print(f"Warning: Could not find module {module_path}")
                continue
            resolved.append((module, module_path))
    return resolved


def _capture_module_outputs(
    model, target_modules, input_ids: Tensor, position: int, offload_to_cpu: bool
) -> dict[str, Tensor]:
    """Run one no-grad forward pass over ``input_ids`` and return each hooked module's output at ``position``."""
    captured: dict[str, Tensor] = {}

    def make_hook(key: str):
        def hook_fn(module, inputs, output):
            # Tuple outputs are unwrapped to their first element (assumed to be the hidden state).
            if isinstance(output, tuple):
                output = output[0]
            if hasattr(output, "shape") and len(output.shape) >= 3:
                activation = output[:, position, :].clone()  # (batch, d) at the target position
            else:
                activation = output.clone()
            if offload_to_cpu:
                activation = activation.cpu()
            captured[key] = activation
        return hook_fn

    handles = [module.register_forward_hook(make_hook(key)) for module, key in target_modules]
    try:
        with torch.no_grad():
            # This pass only feeds the hooks, so skip building a KV cache.
            model(input_ids, use_cache=False)
    finally:
        # Always detach hooks, even if the forward pass fails (e.g. CUDA OOM); otherwise
        # they stay registered on the model and fire on every later forward call.
        for handle in handles:
            handle.remove()
    return captured


def get_generated_tokens_activations(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    tokens: torch.Tensor,
    target_tokens: list[str] | tuple[str, ...] = ("<think>", "</think>"),
    max_new_tokens: int = -1,
    module_names: list[str] | tuple[str, ...] = (
        "input_layernorm",
        "post_attention_layernorm",
        "post_feedforward_layernorm",
    ),
    layers: list[int] | None = None,
    offload_to_cpu: bool = True,
):
    """
    Generate a greedy continuation, then capture module outputs where target tokens were generated.

    Phases:
    1. ``model.generate()`` (``do_sample=False``) produces the full sequence.
    2. The newly generated part is scanned for the ids of ``target_tokens``.
    3. For each occurrence, one extra forward pass over the prefix ending at that position
       runs with forward hooks on ``model.layers.{layer}.{module_name}``, and each hooked
       module's output at that position is recorded.

    Args:
        model: causal LM whose decoder sub-modules live at ``model.layers.{i}.{name}``.
        tokenizer: tokenizer matching ``model``.
        tokens: prompt ids of shape (1, prompt_len); only a single sequence is supported.
        target_tokens: strings whose generated occurrences trigger extraction. A string
            that tokenizes to several ids is tracked by its last id.
        max_new_tokens: generation budget; -1 means ``UPPERBOUND_MAX_NEW_TOKENS``.
        module_names: sub-module names inside each decoder layer to hook; names that
            don't exist are reported and skipped.
        layers: layer indices to hook; ``None`` means all layers from ``model.config``.
        offload_to_cpu: move each captured activation to CPU as soon as it is recorded.

    Returns:
        token_activations: ``{token_str: {module_path: activation}}``. For modules whose
            output is (batch, seq, d), ``activation`` has shape (n_occurrences, d) with one
            row per occurrence in generation order. Strings never generated map to ``{}``.
        final_tokens: the complete generated sequence, shape (1, total_len).

    Raises:
        ValueError: if ``tokens`` is not a single (1, seq_len) prompt, or the layer count
            cannot be read from ``model.config``.
    """
    if tokens.dim() != 2 or tokens.shape[0] != 1:
        raise ValueError(
            f"Expected a single prompt of shape (1, seq_len), got {tuple(tokens.shape)}"
        )

    if layers is None:
        # Get number of layers from model config
        if hasattr(model.config, "num_hidden_layers"):
            num_layers = model.config.num_hidden_layers
        elif hasattr(model.config, "n_layers"):
            num_layers = model.config.n_layers
        else:
            raise ValueError("Could not determine number of layers from model config")
        layers = list(range(num_layers))

    if max_new_tokens == -1:
        max_new_tokens = UPPERBOUND_MAX_NEW_TOKENS

    target_token_to_str = _target_token_id_map(tokenizer, target_tokens)
    target_token_ids = set(target_token_to_str)
    print(f"Monitoring token IDs: {target_token_ids}")
    print(f"Token mapping: {target_token_to_str}")

    tokens = tokens.to(model.device)
    prompt_len = tokens.shape[1]

    # ==================================
    # Phase 1: fast generation (model.generate uses the KV cache)
    start_time = time.time()
    print("Phase 1: Fast generation using model.generate()...")
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    print(f"Max new tokens: {max_new_tokens}")

    with torch.no_grad():
        generated = model.generate(
            input_ids=tokens,
            # Single unpadded prompt: every position is real, so pass the mask explicitly
            # rather than letting generate() infer it from pad_token_id.
            attention_mask=torch.ones_like(tokens),
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=False,  # scores are not needed; saves memory
        )

    full_sequence = generated.sequences[0]  # drop the batch dimension
    # One device->host transfer instead of a .item() sync per generated token
    new_tokens = full_sequence[prompt_len:].tolist()
    print(f"Generation time: {time.time() - start_time:.2f} seconds")

    # ==================================
    # Phase 2: find target token positions
    start_time = time.time()
    print("Phase 2: Finding target token positions...")
    target_positions: list[tuple[int, int, str]] = []
    for i, token_id in enumerate(new_tokens):
        if token_id not in target_token_ids:
            continue
        absolute_position = prompt_len + i  # index into full_sequence (prompt included)
        token_str = target_token_to_str[token_id]
        target_positions.append((absolute_position, token_id, token_str))
        print(f"Found target token '{token_str}' at position {absolute_position}")

    print(f"Found {len(target_positions)} target tokens")
    print(f"Target token position time: {time.time() - start_time:.2f} seconds")

    # ==================================
    # Phase 3: extract activations only for target positions
    start_time = time.time()
    print("Phase 3: Extracting activations for target positions...")
    token_activations: dict[str, dict[str, Tensor]] = {
        token_str: {} for token_str in target_tokens
    }

    if not target_positions:
        print("No target tokens found, returning empty activations")
        return token_activations, full_sequence.unsqueeze(0)

    target_modules = _resolve_hook_modules(model, layers, module_names)

    for position, _token_id, token_str in tqdm(target_positions, desc="Extracting activations"):
        # Prefix up to and including the target token, with a batch dim
        prefix = full_sequence[: position + 1].unsqueeze(0)
        step_activations = _capture_module_outputs(
            model, target_modules, prefix, position, offload_to_cpu
        )

        bucket = token_activations.setdefault(token_str, {})
        for module_key, activation in step_activations.items():
            if module_key in bucket:
                print(f"Module key {module_key} exists, let concat the activations")
                # Each capture is (1, d); stack repeated occurrences as extra rows -> (k, d).
                # Concatenating on dim=1 would fuse them into one meaningless (1, k*d) vector.
                bucket[module_key] = torch.cat([bucket[module_key], activation], dim=0)
            else:
                bucket[module_key] = activation

        del step_activations
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Activation extraction time: {time.time() - start_time:.2f} seconds")

    return token_activations, full_sequence.unsqueeze(0)

In [None]:
# Generate for each training instruction and keep the activations captured at "<think>"
# (positive) and "</think>" (negative). The two lists stay index-aligned.
batch_positive_caches = []
batch_negative_caches = []

subset_instructions = instructions_train[:TRAIN_BATCH_SIZE]
for i, instruction in enumerate(subset_instructions):
    print(f"Processing instruction {i}:")
    tokens = tokenize_instructions_fn(instructions=[instruction])
    print(f"Number of tokens: {tokens.shape[1]}")
    token_activations, tokens = get_generated_tokens_activations(
        model,
        tokenizer,
        tokens,
        max_new_tokens=MAX_NEW_TOKENS,
        module_names=["input_layernorm", "post_attention_layernorm"],
        target_tokens=["<think>", "</think>"],
    )

    positive_cache = token_activations.get("<think>", {})
    negative_cache = token_activations.get("</think>", {})
    # A delimiter that was never generated (e.g. the token budget ran out before
    # "</think>") leaves an empty dict, which would raise KeyError during per-layer
    # stacking and unpair the two lists. Keep a sample only if both were captured.
    if positive_cache and negative_cache:
        batch_positive_caches.append(positive_cache)
        batch_negative_caches.append(negative_cache)
    else:
        print(f"Skipping instruction {i}: '<think>' or '</think>' was not generated")

    print("--------------------------------")

print(f"Kept {len(batch_positive_caches)} / {len(subset_instructions)} samples")

In [None]:
# Save the activations for backup

def save_tensor(
    tensor: object,
    save_dir: Path,
    name: str,
) -> None:
    """Serialize ``tensor`` with ``torch.save`` to ``save_dir / "<name>.pt"``.

    Despite the name, any object ``torch.save`` accepts can be passed (e.g. lists of
    dicts of tensors).

    Args:
        tensor: object to serialize.
        save_dir: target directory; created together with missing parents.
        name: file stem. A trailing ``".pt"`` is accepted and not doubled.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    stem = name.removesuffix(".pt")  # avoid "<name>.pt.pt" when callers include the suffix
    save_path = save_dir / f"{stem}.pt"
    torch.save(tensor, save_path)
    print(f"Saved {name} to {save_path}")

# Persist the per-sample caches so the expensive generation step need not be rerun.
output_dir = Path("/root/workspace/outputs")
output_dir.mkdir(parents=True, exist_ok=True)
for cache_name, caches in (
    ("batch_positive_caches", batch_positive_caches),
    ("batch_negative_caches", batch_negative_caches),
):
    # save_tensor appends ".pt" itself, so pass the bare stem
    save_tensor(caches, output_dir, cache_name)


### Compute refusal directions

In [None]:
def get_module_path(
    layer: int,
    module_name: str = "input_layernorm",
) -> str:
    """Return the dotted path of one decoder sub-module, e.g. ``model.layers.3.input_layernorm``.

    The format matches the keys under which ``get_generated_tokens_activations`` stores
    captured activations. (The previous default was a list, which produced an unusable
    path like ``model.layers.0.['input_layernorm', ...]``.)
    """
    return f"model.layers.{layer}.{module_name}"

In [None]:
# Stack the "<think>" activations of every kept sample, unit-normalize each one, and
# average them per (layer, hooked module).
n_positive_samples = len(batch_positive_caches)
if n_positive_samples == 0:
    raise ValueError("batch_positive_caches is empty; no '<think>' activations were collected")

positive_mean_activations = torch.zeros(
    n_layers, n_activations, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, d_model)

positive_activations_all = torch.zeros(
    n_layers, n_activations, n_positive_samples, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, n_samples, d_model)

for layer in range(n_layers):
    batch_positive_activations = []
    for module_name in ["input_layernorm", "post_attention_layernorm"]:
        # Row 0 of each cached activation: the first captured occurrence of the token
        batch_positive_activations_per_module = torch.stack(
            [sample[get_module_path(layer, module_name)][0, :] for sample in batch_positive_caches]
        )  # (n_samples, d_model)
        batch_positive_activations.append(batch_positive_activations_per_module)
    # (n_activations, n_samples, d_model); upcast so normalization and the mean run in float32
    batch_positive_activations = torch.stack(batch_positive_activations).float()

    # Unit-normalize each activation so only its direction contributes to the mean
    batch_positive_activations = batch_positive_activations / batch_positive_activations.norm(dim=-1, keepdim=True)

    # Mean over samples; values are cast back into the float16 buffers on assignment
    positive_mean_activations[layer] = batch_positive_activations.mean(dim=1)
    positive_activations_all[layer] = batch_positive_activations

print(positive_mean_activations.shape)  # (n_layers, n_activations, d_model)
print(positive_activations_all.shape)  # (n_layers, n_activations, n_samples, d_model)

In [None]:
# Stack the "</think>" activations of every kept sample, unit-normalize each one, and
# average them per (layer, hooked module).
n_negative_samples = len(batch_negative_caches)
if n_negative_samples == 0:
    raise ValueError("batch_negative_caches is empty; no '</think>' activations were collected")

negative_mean_activations = torch.zeros(
    n_layers, n_activations, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, d_model)

negative_activations_all = torch.zeros(
    n_layers, n_activations, n_negative_samples, d_model, device=DEVICE, dtype=torch.float16
)  # (n_layers, n_activations, n_samples, d_model)

for layer in range(n_layers):
    batch_negative_activations = []
    for module_name in ["input_layernorm", "post_attention_layernorm"]:
        # Row 0 of each cached activation: the first captured occurrence of the token
        batch_negative_activations_per_module = torch.stack(
            [sample[get_module_path(layer, module_name)][0, :] for sample in batch_negative_caches]
        )  # (n_samples, d_model)
        batch_negative_activations.append(batch_negative_activations_per_module)
    # (n_activations, n_samples, d_model); upcast so normalization and the mean run in float32
    batch_negative_activations = torch.stack(batch_negative_activations).float()

    # Unit-normalize each activation so only its direction contributes to the mean
    batch_negative_activations = batch_negative_activations / batch_negative_activations.norm(dim=-1, keepdim=True)

    # Mean over samples; values are cast back into the float16 buffers on assignment
    negative_mean_activations[layer] = batch_negative_activations.mean(dim=1)
    negative_activations_all[layer] = batch_negative_activations

print(negative_mean_activations.shape)  # (n_layers, n_activations, d_model)
print(negative_activations_all.shape)  # (n_layers, n_activations, n_samples, d_model)

In [None]:
# Rescale the per-(layer, module) means back to unit length: a mean of unit vectors is
# generally shorter than 1.
def _to_unit(t):
    return t / t.norm(dim=-1, keepdim=True)


positive_mean_activation_normed, negative_mean_activation_normed = (
    _to_unit(positive_mean_activations),
    _to_unit(negative_mean_activations),
)

for _normed in (positive_mean_activation_normed, negative_mean_activation_normed):
    print(_normed.shape)  # (n_layers, n_activations, d_model)

In [None]:
# Candidate direction per (layer, module): unit-norm "<think>" mean minus unit-norm
# "</think>" mean (both moved to CPU first), then rescaled to unit length.
candidate_refusal_vectors = torch.sub(
    positive_mean_activation_normed.to("cpu"),
    negative_mean_activation_normed.to("cpu"),
)
_direction_norms = candidate_refusal_vectors.norm(dim=-1, keepdim=True)
candidate_refusal_vectors_normed = candidate_refusal_vectors / _direction_norms

print(candidate_refusal_vectors_normed.shape)  # (n_layers, n_activations, d_model)

## Visualizations

In [None]:
# Per-category colours: solid for lines/markers, pastel for lighter variants.
colour_map = {
    "positive": plotly.colors.qualitative.Plotly[0],
    "negative": plotly.colors.qualitative.Plotly[1],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
colour_map_light = {
    "positive": plotly.colors.qualitative.Pastel1[1],
    "negative": plotly.colors.qualitative.Pastel1[0],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}


def _with_alpha(rgb: str, alpha: float) -> str:
    """Turn an ``"rgb(r,g,b)"`` string into ``"rgba(r,g,b, alpha)"``; other formats pass through."""
    return rgb.replace("rgb(", "rgba(").replace(")", f", {alpha})")


# Semi-transparent fills, derived from colour_map_light so each category keeps the same
# hue. The previous hard-coded values were swapped: "positive" used Pastel1[0]'s
# rgb(251, 180, 174), which colour_map_light assigns to "negative", and vice versa.
colour_map_opaque = {
    "positive": _with_alpha(colour_map_light["positive"], 0.3),
    "negative": _with_alpha(colour_map_light["negative"], 0.3),
}


In [None]:
category2acts_normed = {
    "positive": positive_activations_all.cpu(),
    "negative": negative_activations_all.cpu(),
}  # each: (n_layers, n_activations, n_samples, d_model)

# One x tick per (layer, hooked module) pair, in the layer-major order produced by
# flattening a (layer, act) grid.
x_values = [str(i) for i in range(n_layers * n_activations)]

fig = go.Figure()

for category in ["positive", "negative"]:
    acts_normed = category2acts_normed[category]  # (n_layers, n_activations, n_samples, d_model)
    # Scalar projection of each unit-norm activation onto its (layer, module) candidate
    # direction, computed in float32 for precision. einsum already returns a tensor.
    projections = einops.einsum(
        candidate_refusal_vectors_normed.float(),
        acts_normed.float(),
        "layer act dim, layer act batch dim -> layer act batch",
    )  # (n_layers, n_activations, n_samples)

    y_values = projections.mean(dim=-1).flatten().numpy()  # layer act -> layer * act
    # Sample std across the batch (NaN when only one sample is available)
    spread = projections.std(dim=-1).flatten().numpy()

    # mean
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            name=category,
            mode="lines+markers",
            yaxis="y",
            line=dict(color=colour_map[category]),
            marker=dict(color=colour_map[category], size=3),
            showlegend=True,
        )
    )

    # spread: mean ± 1 std as a filled band between two invisible lines
    # ("tonexty" fills down to the previously added trace, i.e. the lower bound).
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values - spread,
            name=category,
            mode="lines",
            line=dict(width=0),
            yaxis="y",
            hoverinfo="skip",
            showlegend=False,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values + spread,
            name=category,
            mode="lines",
            line=dict(width=0),
            fill="tonexty",
            fillcolor=colour_map_opaque[category],
            yaxis="y",
            hoverinfo="skip",
            showlegend=False,
        )
    )

    # dot markers
    fig.add_trace(
        go.Scatter(
            x=x_values[1::],
            y=y_values[1::],
            name=f"{category}",
            mode="markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=False,
        )
    )

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=20)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=18),
    ),
    yaxis=dict(
        title=dict(text="Scalar Projection", font=dict(size=20)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=18),
    ),
    hovermode="x unified",
    height=250,
    width=600,
    margin=dict(l=0, r=0, t=0, b=0),
    legend=dict(x=0.05, y=0.95, font=dict(size=18)),
)
fig.show()

# fig.write_image("prj_onto_local_candidate_refusal_vectors.pdf", scale=5)