In [None]:
import torch
import einops
import plotly
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

from transformers import AutoTokenizer
from pathlib import Path

In [None]:
# Only the tokenizer is loaded here; the model weights are not needed for this analysis.
tokenizer_name = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_name)

In [None]:
# These constants must match with the ones in get_activations.py
TRAIN_BATCH_SIZE = 32
UPPERBOUND_MAX_NEW_TOKENS = 7000

# Just hardcode to avoid loading the model
start_thinking_token_id = 151667
end_thinking_token_id = 151668 # Qwen3's end of thinking token id
n_layers = 36
# n_layers = model.config.num_hidden_layers
n_activations = 2
d_model = 2560

# Load pre-computed data

In [None]:
# Load all token activations
# ADD THE PATH TO ACTIVATIONS HERE
activation_path = Path("outputs/sequences_activations.pt")
assert activation_path.exists(), f"Activations file not found: {activation_path.resolve()}"

# Load onto CPU: this notebook never puts a model on a GPU, and the projections
# computed later are handed to NumPy, which only accepts CPU tensors.
sequences_activations = torch.load(activation_path, map_location="cpu")
assert len(sequences_activations) == TRAIN_BATCH_SIZE, (
    f"Expected {TRAIN_BATCH_SIZE} sequences (TRAIN_BATCH_SIZE), got {len(sequences_activations)}"
)

In [None]:
# Find (again) the positions of the start and end of thinking tokens.
# The span starts at the first start-of-thinking token and ends at the first
# end-of-thinking token after it. A missing token (e.g. a generation truncated
# before thinking finished) raises a descriptive error instead of an opaque
# `.item()` failure.

start_thinking_positions = []
end_thinking_positions = []
for sample_idx, seq in enumerate(sequences_activations):
    generated_ids = seq["generated_token_ids"][0]

    start_hits = (generated_ids == start_thinking_token_id).nonzero(as_tuple=True)[0]
    if start_hits.numel() == 0:
        raise ValueError(
            f"Sample {sample_idx}: no start-of-thinking token ({start_thinking_token_id}) found"
        )
    start_position = start_hits[0].item()

    end_hits = (generated_ids == end_thinking_token_id).nonzero(as_tuple=True)[0]
    end_hits = end_hits[end_hits > start_position]
    if end_hits.numel() == 0:
        raise ValueError(
            f"Sample {sample_idx}: no end-of-thinking token ({end_thinking_token_id}) after "
            f"position {start_position}; the generation may have been truncated before thinking finished"
        )

    start_thinking_positions.append(start_position)
    end_thinking_positions.append(end_hits[0].item())

print(start_thinking_positions)
print(end_thinking_positions)

In [None]:
# Load the chosen steering direction onto CPU (no GPU is used in this notebook).
chosen_direction = torch.load("outputs/chosen_direction.pt", map_location="cpu")
assert chosen_direction.shape == (d_model,), (
    f"Expected chosen_direction of shape ({d_model},), got {tuple(chosen_direction.shape)}"
)

# Compute the thinking-span activations for each sample

In [None]:
# For every sample, keep only the thinking span of each hooked module's activations.
# The slice is inclusive of both the start and the end thinking tokens.
batch_thinking_caches = [
    {
        module_path: module_acts[0, start_pos:end_pos + 1, :]  # (n_tokens, d_model)
        for module_path, module_acts in sample["token_activations"].items()
    }
    for sample, start_pos, end_pos in zip(
        sequences_activations, start_thinking_positions, end_thinking_positions
    )
]

assert len(batch_thinking_caches) == TRAIN_BATCH_SIZE

In [None]:
# Sanity check: shape of sample 0's thinking span at the first layer's input_layernorm.
first_sample_layer0 = batch_thinking_caches[0]["model.layers.0.input_layernorm"]
print(first_sample_layer0.shape)  # (n_tokens, d_model)

### Visualization: Project token activations at an extraction point onto steering vector

In [None]:
def get_module_path(
    layer: int,
    module_name: str = "input_layernorm",
) -> str:
    """Build the activation-cache key for a hooked module of one decoder layer.

    Args:
        layer: Decoder layer index.
        module_name: Name of the hooked sub-module within the layer, e.g.
            "input_layernorm" or "post_attention_layernorm". This is a single
            string. The previous signature annotated it as a list and defaulted
            to a list, which produced an invalid key when the default was used.

    Returns:
        A key of the form "model.layers.{layer}.{module_name}".
    """
    return f"model.layers.{layer}.{module_name}"

In [None]:
chosen_extraction_point = 40
chosen_sample_idx = 5  # Select 1 sample from the batch

# Extraction points are numbered two per layer:
# even -> input_layernorm, odd -> post_attention_layernorm.
assert 0 <= chosen_extraction_point < n_layers * n_activations, (
    f"chosen_extraction_point must be in [0, {n_layers * n_activations}), got {chosen_extraction_point}"
)
assert 0 <= chosen_sample_idx < TRAIN_BATCH_SIZE, (
    f"chosen_sample_idx must be in [0, {TRAIN_BATCH_SIZE}), got {chosen_sample_idx}"
)

chosen_layer, module_offset = divmod(chosen_extraction_point, 2)
chosen_module_name = ("input_layernorm", "post_attention_layernorm")[module_offset]

chosen_module_path = get_module_path(chosen_layer, chosen_module_name)
chosen_module_path

In [None]:
# One (n_tokens, d_model) tensor per sample at the chosen extraction point.
# Thinking spans are sliced per sample and can differ in length, so this is a
# Python list rather than a single stacked tensor.
batch_thinking_token_activations = [
    sample_cache[chosen_module_path] for sample_cache in batch_thinking_caches
]

In [None]:
# Colour schemes shared by the plots. In every variant "positive" is blue and
# "negative" is red.
colour_map = {
    "positive": plotly.colors.qualitative.Plotly[0],
    "negative": plotly.colors.qualitative.Plotly[1],
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
colour_map_light = {
    "positive": plotly.colors.qualitative.Pastel1[1],  # rgb(179,205,227)
    "negative": plotly.colors.qualitative.Pastel1[0],  # rgb(251,180,174)
    "neutral": plotly.colors.qualitative.Pastel1[3],
}
# Despite the name, these are 30%-alpha (translucent) versions of colour_map_light.
# Positive/negative use the same order as the other maps; they were previously swapped.
colour_map_opaque = {
    "positive": "rgba(179, 205, 227, 0.3)",
    "negative": "rgba(251, 180, 174, 0.3)",
}


In [None]:
# Project every thinking-span token of the chosen sample, at the chosen extraction
# point, onto the steering direction and plot the projection against token position.
activations = batch_thinking_token_activations[chosen_sample_idx]  # (n_tokens, d_model)

# Work in float32 on CPU in case the tensors are bfloat16 or on GPU: NumPy supports neither.
# The direction is unit-normalised so the dot product is a true scalar projection.
# This is a no-op if chosen_direction is already unit-norm.
direction_f32 = chosen_direction.float().cpu()
unit_direction = direction_f32 / direction_f32.norm()

scalar_projections = einops.einsum(
    activations.float().cpu(),
    unit_direction,
    "n_tokens dim, dim -> n_tokens",
)
scalar_projections = torch.nan_to_num(scalar_projections).numpy()  # (n_tokens,)

print("Summary statistics of scalar projections:")
print(f"Min: {scalar_projections.min()}")
print(f"Max: {scalar_projections.max()}")
print(f"Mean: {scalar_projections.mean()}")
print(f"Median: {np.median(scalar_projections)}")
print(f"Standard Deviation: {scalar_projections.std()}")

token_positions = np.arange(scalar_projections.shape[0])

fig = go.Figure(
    go.Scatter(
        x=token_positions,
        y=scalar_projections,
        mode="lines+markers",
        marker_color=colour_map_light["neutral"],
        marker_size=1,
        showlegend=False,
    )
)
fig.update_layout(
    title=f"Scalar projection onto the steering direction (sample {chosen_sample_idx}, {chosen_module_path})",
    xaxis_title="Token position within the thinking span",
    yaxis_title="Scalar projection",
)
fig.show()

### Visualization: Project token activations, averaged over all extraction points, onto the steering vector

In [None]:
# batch_thinking_token_activations is a list: thinking spans can differ in length,
# so there is no single (batch_size, seq_len, d_model) tensor and `.shape` would
# raise. Show each sample's (n_tokens, d_model) shape instead.
[sample_acts.shape for sample_acts in batch_thinking_token_activations]

In [None]:
# Stack every extraction point of each sample into a single tensor.
# Keys are looked up in explicit extraction-point order: layer-major, with
# input_layernorm before post_attention_layernorm. As a result, index i on dim 0
# is extraction point i, matching the even/odd numbering used above.
# A missing module raises a KeyError instead of silently changing the average below.
extraction_point_paths = [
    get_module_path(layer, module_name)
    for layer in range(n_layers)
    for module_name in ("input_layernorm", "post_attention_layernorm")
]
assert len(extraction_point_paths) == n_layers * n_activations

batch_thinking_caches_stacked = [
    torch.stack([sample[path] for path in extraction_point_paths], dim=0)  # (n_layers * n_activations, n_tokens, d_model)
    for sample in batch_thinking_caches
]

batch_thinking_caches_stacked[0].shape  # (n_layers * n_activations, n_tokens, d_model)

In [None]:
# Average over the extraction-point axis (dim 0), leaving one vector per thinking token.
batch_thinking_caches_stacked_mean = [
    stacked.mean(dim=0)  # (n_tokens, d_model)
    for stacked in batch_thinking_caches_stacked
]
batch_thinking_caches_stacked_mean[0].shape  # (n_tokens, d_model)

In [None]:
chosen_sample_idx = 5

# Project each thinking-span token of the chosen sample, averaged over all
# extraction points, onto the steering direction and plot it against token position.
activations = batch_thinking_caches_stacked_mean[chosen_sample_idx]  # (n_tokens, d_model)

# Work in float32 on CPU in case the tensors are bfloat16 or on GPU: NumPy supports neither.
# The direction is unit-normalised so the dot product is a true scalar projection.
# This is a no-op if chosen_direction is already unit-norm.
direction_f32 = chosen_direction.float().cpu()
unit_direction = direction_f32 / direction_f32.norm()

scalar_projections = einops.einsum(
    activations.float().cpu(),
    unit_direction,
    "n_tokens dim, dim -> n_tokens",
)
scalar_projections = torch.nan_to_num(scalar_projections).numpy()  # (n_tokens,)

print("Summary statistics of scalar projections:")
print(f"Min: {scalar_projections.min()}")
print(f"Max: {scalar_projections.max()}")
print(f"Mean: {scalar_projections.mean()}")
print(f"Median: {np.median(scalar_projections)}")
print(f"Standard Deviation: {scalar_projections.std()}")

token_positions = np.arange(scalar_projections.shape[0])

fig = go.Figure(
    go.Scatter(
        x=token_positions,
        y=scalar_projections,
        mode="lines+markers",
        marker_color=colour_map_light["neutral"],
        marker_size=1,
        showlegend=False,
    )
)
fig.update_layout(
    title=(
        "Scalar projection onto the steering direction, averaged over all extraction points "
        f"(sample {chosen_sample_idx})"
    ),
    xaxis_title="Token position within the thinking span",
    yaxis_title="Scalar projection",
)
fig.show()