In [None]:
import torch
import einops
import plotly
import plotly.graph_objects as go
import plotly.express as px

from typing import Literal
from transformers import AutoTokenizer
from pathlib import Path

In [None]:
# Only the tokenizer is loaded (no model weights): it is enough to decode token ids for inspection.
tokenizer_name = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_name)

In [None]:
# Generation settings: these must match the ones in get_activations.py
TRAIN_BATCH_SIZE = 32
UPPERBOUND_MAX_NEW_TOKENS = 7000

# Model facts, hardcoded so this notebook never has to load the model
# (equivalently: n_layers = model.config.num_hidden_layers)
start_thinking_token_id, end_thinking_token_id = 151667, 151668  # Qwen3's <think> / </think> ids
n_layers = 36
d_model = 2560
n_activations = 2  # hooked modules per layer: input_layernorm and post_attention_layernorm

## Load pre-computed activations

In [None]:
# ADD THE PATH TO ACTIVATIONS HERE
activation_path = Path("")
# `Path("")` means the current directory, which always exists, so `.exists()` would never catch
# a forgotten path. Require a regular file so the failure happens here with a clear message.
assert activation_path.is_file(), (
    f"Set `activation_path` to the saved activations file (currently {str(activation_path)!r})"
)

# Load the activations (presumably the per-sequence list written by get_activations.py).
# Everything downstream runs on CPU, so map tensors there; this also allows loading
# GPU-saved activations on a machine without a GPU.
sequences_activations = torch.load(activation_path, map_location="cpu")

In [None]:
# Sanity check
assert len(sequences_activations) == TRAIN_BATCH_SIZE
# Later cells index generated_token_ids and token_activations with [0], i.e. they assume a batch of one
assert all(seq["generated_token_ids"].shape[0] == 1 for seq in sequences_activations)
print(sequences_activations[0]["generated_token_ids"].shape) # (1, seq_len)

## Collect activations for the candidate directions

### Scenario 1: Positive: `<think>` - Negative: `</think>`

In [None]:
# Find the start thinking and end thinking positions in each sequence.
# generated_token_ids is indexed with [0], i.e. assumed to be (1, seq_len).
start_thinking_positions = []
end_thinking_positions = []
for seq_idx, seq in enumerate(sequences_activations):
    token_ids = seq["generated_token_ids"][0]  # (seq_len,)

    start_matches = (token_ids == start_thinking_token_id).nonzero(as_tuple=True)[0]
    if start_matches.numel() == 0:
        raise ValueError(
            f"Sequence {seq_idx} contains no <think> token (id {start_thinking_token_id})"
        )
    # `.item()` on all matches fails when the token occurs more than once; use the first occurrence
    start_pos = start_matches[0].item()
    start_thinking_positions.append(start_pos)

    # First </think> after the opening <think>. Some sequences do not have the end thinking
    # token (presumably generation stopped at the token limit); record None for those.
    end_matches = (token_ids == end_thinking_token_id).nonzero(as_tuple=True)[0]
    end_matches = end_matches[end_matches > start_pos]
    end_thinking_positions.append(end_matches[0].item() if end_matches.numel() > 0 else None)

print(start_thinking_positions)
print(end_thinking_positions)

In [None]:
# Sanity check: the recorded end position must actually hold the </think> token.
# Use the first sequence that has one; indexing with a None position would not select a token.
first_idx_with_end = next(
    (i for i, pos in enumerate(end_thinking_positions) if pos is not None), None
)
assert first_idx_with_end is not None, "No sequence contains a </think> token"
first_end_thinking_token_ids = sequences_activations[first_idx_with_end]["generated_token_ids"][0][
    end_thinking_positions[first_idx_with_end]
]
assert first_end_thinking_token_ids.item() == end_thinking_token_id
print(first_end_thinking_token_ids, repr(tokenizer.decode([first_end_thinking_token_ids.item()])))

#### Get positive and negative activations

In [None]:
# Get positive: for every sequence, the activation at its <think> position, for every hooked module
batch_positive_caches = [
    {
        module_path: activations[0, start_pos, :]  # (d_model)
        for module_path, activations in seq["token_activations"].items()
    }
    for seq, start_pos in zip(sequences_activations, start_thinking_positions)
]

print(batch_positive_caches[0]["model.layers.0.input_layernorm"].shape) # (d_model)

In [None]:
# Get negative: for every sequence, the activation at its </think> position, for every hooked module.
# A None end position (no </think>) does not select a token: `activations[0, None, :]` inserts a
# new axis and returns the whole sequence, which only fails much later with a shape error.
# Reject such sequences explicitly instead.
sequences_without_end = [i for i, pos in enumerate(end_thinking_positions) if pos is None]
if sequences_without_end:
    raise ValueError(
        f"Sequences {sequences_without_end} have no </think> token; "
        "filter them out upstream before building the negative activations"
    )

batch_negative_caches = []
for i, seq in enumerate(sequences_activations):
    token_activations = seq["token_activations"]
    negative = {}
    for module_path, activations in token_activations.items():
        negative[module_path] = activations[0, end_thinking_positions[i], :] # (d_model)
    batch_negative_caches.append(negative)

print(batch_negative_caches[0]["model.layers.0.input_layernorm"].shape) # (d_model)


In [None]:
def get_module_path(
    layer: int,
    module_name: Literal["input_layernorm", "post_attention_layernorm"] = "input_layernorm",
) -> str:
    """Return the dotted module path of a norm module inside decoder layer `layer`.

    Args:
        layer: Decoder layer index.
        module_name: Which norm module of the layer; one of the two Literal values.

    Returns:
        A path such as ``"model.layers.3.post_attention_layernorm"``, the key format used in
        each sequence's ``token_activations`` dict.

    Raises:
        ValueError: If ``module_name`` is not one of the supported module names.
    """
    # The default used to be a *list* of both names, which formatted into a bogus path;
    # validate so a wrong name fails here rather than as a KeyError downstream.
    if module_name not in ("input_layernorm", "post_attention_layernorm"):
        raise ValueError(f"Unsupported module_name {module_name!r}")
    return f"model.layers.{layer}.{module_name}"

In [None]:
# Gather the positive activations into one tensor, L2-normalise every activation vector, and
# average over the batch. (The `postive_` spelling is kept because later cells reference it.)
module_names = ["input_layernorm", "post_attention_layernorm"]
assert len(module_names) == n_activations, "module_names must list n_activations modules"
n_positive_samples = len(batch_positive_caches)

postive_activations_mean = torch.zeros(
    n_layers, n_activations, d_model, device="cpu", dtype=torch.float16
) # (n_layers, n_activations, d_model)

positive_activations_all = torch.zeros(
    n_layers, n_activations, n_positive_samples, d_model, device="cpu", dtype=torch.float16
) # (n_layers, n_activations, batch_size, d_model)

for layer in range(n_layers):
    batch_positive_activations = []
    for module_name in module_names:
        module_path = get_module_path(layer, module_name)
        # Fail loudly on a missing hook. Silently skipping samples would desynchronise the batch
        # dimension and crash later with an unrelated shape error.
        missing = [i for i, sample in enumerate(batch_positive_caches) if module_path not in sample]
        if missing:
            raise KeyError(f"{module_path} is missing from positive samples {missing}")
        batch_positive_activations.append(
            torch.stack([sample[module_path] for sample in batch_positive_caches])  # (batch_size, d_model)
        )
    # Upcast before taking norms: activations may be half precision, where a sum of d_model
    # squares can overflow (float16) or lose precision (bfloat16).
    batch_positive_activations = torch.stack(batch_positive_activations).float() # (n_activations, batch_size, d_model)

    # Normalize then get mean because the activation will be normalized by the RMSNorm layer
    batch_positive_activations = batch_positive_activations / batch_positive_activations.norm(dim=-1, keepdim=True)

    # Compute mean along the batch dimension
    postive_activations_mean[layer] = batch_positive_activations.mean(dim=1)
    positive_activations_all[layer] = batch_positive_activations

print(postive_activations_mean.shape) # (n_layers, n_activations, d_model)
print(positive_activations_all.shape) # (n_layers, n_activations, batch_size, d_model)

In [None]:
# Gather the negative activations into one tensor, L2-normalise every activation vector, and
# average over the batch.
module_names = ["input_layernorm", "post_attention_layernorm"]
assert len(module_names) == n_activations, "module_names must list n_activations modules"
n_negative_samples = len(batch_negative_caches)

negative_activations_mean = torch.zeros(
    n_layers, n_activations, d_model, device="cpu", dtype=torch.float16
) # (n_layers, n_activations, d_model)

negative_activations_all = torch.zeros(
    n_layers, n_activations, n_negative_samples, d_model, device="cpu", dtype=torch.float16
) # (n_layers, n_activations, batch_size, d_model)

for layer in range(n_layers):
    batch_negative_activations = []
    for module_name in module_names:
        module_path = get_module_path(layer, module_name)
        # Fail loudly on a missing hook. Silently skipping samples would desynchronise the batch
        # dimension and crash later with an unrelated shape error.
        missing = [i for i, sample in enumerate(batch_negative_caches) if module_path not in sample]
        if missing:
            raise KeyError(f"{module_path} is missing from negative samples {missing}")
        batch_negative_activations.append(
            torch.stack([sample[module_path] for sample in batch_negative_caches])  # (batch_size, d_model)
        )
    # Upcast before taking norms: activations may be half precision, where a sum of d_model
    # squares can overflow (float16) or lose precision (bfloat16).
    batch_negative_activations = torch.stack(batch_negative_activations).float() # (n_activations, batch_size, d_model)

    # Normalize then get mean because the activation will be normalized by the RMSNorm layer
    batch_negative_activations = batch_negative_activations / batch_negative_activations.norm(dim=-1, keepdim=True)

    # Compute mean across batch dimension
    negative_activations_mean[layer] = batch_negative_activations.mean(dim=1)
    negative_activations_all[layer] = batch_negative_activations

print(negative_activations_mean.shape) # (n_layers, n_activations, d_model)
print(negative_activations_all.shape) # (n_layers, n_activations, batch_size, d_model)

In [None]:
# Rescale each batch mean back to unit length: the mean of unit vectors is generally shorter than 1.
postive_activations_mean_normed, negative_activations_mean_normed = (
    means / means.norm(dim=-1, keepdim=True)
    for means in (postive_activations_mean, negative_activations_mean)
)

for normed in (postive_activations_mean_normed, negative_activations_mean_normed):
    print(normed.shape)  # (n_layers, n_activations, d_model)

In [None]:
# Optional: uncomment to cache the per-sample activations, their batch means and the
# re-normalised means for both categories.
# NOTE(review): torch.save will fail if the `outputs/` directory does not exist yet.
# # Save to disk
# torch.save(positive_activations_all, "outputs/postive_activations_all.pt")
# torch.save(postive_activations_mean, "outputs/postive_activations_mean.pt")
# torch.save(postive_activations_mean_normed, "outputs/postive_activations_mean_normed.pt")

# torch.save(negative_activations_all, "outputs/negative_activations_all.pt")
# torch.save(negative_activations_mean, "outputs/negative_activations_mean.pt")
# torch.save(negative_activations_mean_normed, "outputs/negative_activations_mean_normed.pt")


## Compute candidate directions

#### Difference-in-means

In [None]:
# Difference-in-means direction at every extraction point: the unit mean of the positive (<think>)
# activations minus the unit mean of the negative (</think>) activations, rescaled to unit length.
# The "refusal" variable names are kept because later cells and the saved file names use them.
positive_unit_mean = postive_activations_mean_normed.cpu()
negative_unit_mean = negative_activations_mean_normed.cpu()
candidate_refusal_vectors = torch.sub(positive_unit_mean, negative_unit_mean)
candidate_refusal_vectors_normed = candidate_refusal_vectors.div(
    candidate_refusal_vectors.norm(dim=-1, keepdim=True)
)

print(candidate_refusal_vectors_normed.shape) # (n_layers, n_activations, d_model)


In [None]:
# Save to disk. Create the output directory first: torch.save does not create missing
# parent directories, so a fresh checkout would otherwise fail here.
output_dir = Path("outputs")
output_dir.mkdir(parents=True, exist_ok=True)
torch.save(candidate_refusal_vectors, output_dir / "candidate_refusal_vectors.pt")
torch.save(candidate_refusal_vectors_normed, output_dir / "candidate_refusal_vectors_normed.pt")

## Visualizations

### Scalar projections of activations onto the local candidate (difference-in-means) direction at each extraction point

In [None]:
# Colours per category: positive uses the blue family and negative the red family.
# Plotly[0] = #636EFA (blue), Plotly[1] = #EF553B (red);
# Pastel1[1] = rgb(179,205,227) (light blue), Pastel1[0] = rgb(251,180,174) (light red).
colour_map = {
    "positive": plotly.colors.qualitative.Plotly[0],
    "negative": plotly.colors.qualitative.Plotly[1],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
colour_map_light = {
    "positive": plotly.colors.qualitative.Pastel1[1],
    "negative": plotly.colors.qualitative.Pastel1[0],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
# Translucent fills for spread bands. They must match colour_map_light; the positive and
# negative RGB values were previously swapped, giving the positive band a red fill.
colour_map_opaque = {
    "positive": "rgba(179, 205, 227, 0.3)",  # Pastel1[1] at 30% alpha
    "negative": "rgba(251, 180, 174, 0.3)",  # Pastel1[0] at 30% alpha
}


In [None]:
# Per-sample activations for each category, already L2-normalised per vector
category2acts_normed = {
    "positive": positive_activations_all.cpu(),
    "negative": negative_activations_all.cpu(),
} # each: (n_layers, n_activations, batch_size, d_model)

# One x position per extraction point, layer-major so it lines up with `.flatten()` of (layer, act)
x_values = [str(i) for i in range(n_layers * n_activations)]

fig = go.Figure()

for category in ["positive", "negative"]:
    acts_normed = category2acts_normed[category].float() # (n_layers, n_activations, batch_size, d_model)
    # Scalar projection of every sample onto the unit candidate direction of its own extraction
    # point, computed in float32 because the stored tensors are float16. einops.einsum already
    # returns a tensor, so it is not re-wrapped with torch.tensor (that only copies and warns).
    projections = einops.einsum(
        candidate_refusal_vectors_normed.float(),
        acts_normed,
        "layer act dim, layer act batch dim -> layer act batch",
    ) # (n_layers, n_activations, batch_size)

    mean_projection = projections.mean(dim=-1).flatten().numpy() # (n_layers * n_activations,)
    std_projection = projections.std(dim=-1).flatten().numpy()   # (n_layers * n_activations,)

    # Spread across the batch (mean ± 1 std), drawn as one closed polygon:
    # the upper edge left-to-right, then the lower edge right-to-left.
    upper = mean_projection + std_projection
    lower = mean_projection - std_projection
    fig.add_trace(
        go.Scatter(
            x=x_values + x_values[::-1],
            y=upper.tolist() + lower[::-1].tolist(),
            fill="toself",
            fillcolor=colour_map_opaque[category],
            line=dict(color="rgba(0, 0, 0, 0)"),
            hoverinfo="skip",
            name=category,
            showlegend=False,
        )
    )

    # Batch mean: light line with markers in the category's main colour
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=mean_projection.tolist(),
            name=category,
            mode="lines+markers",
            line=dict(color=colour_map_light[category]),
            marker=dict(color=colour_map[category], size=3),
            showlegend=True,
        )
    )

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=20)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=18),
    ),
    yaxis=dict(
        title=dict(text="Scalar Projection", font=dict(size=20)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=18),
    ),
    hovermode="x unified",
    height=250,
    width=600,
    margin=dict(l=0, r=0, t=0, b=0),
    legend=dict(x=0.05, y=0.95, font=dict(size=18)),
)
fig.show()

# fig.write_image("prj_onto_local_candidate_refusal_vectors.pdf", scale=5)