In [None]:
import torch
import einops
import plotly
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

from typing import Literal
from transformers import AutoTokenizer
from pathlib import Path

In [None]:
# Only the tokenizer is instantiated in this cell; no model object is created here.
tokenizer_name = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [None]:
# These constants must match the ones in get_activations.py
TRAIN_BATCH_SIZE = 32
UPPERBOUND_MAX_NEW_TOKENS = 7000

# Model-specific values are hardcoded to avoid loading the model.
# NOTE(review): nothing ties them to `tokenizer_name`; re-check every value if the model changes.
start_thinking_token_id = 151667  # presumably Qwen3's start-of-thinking token id — TODO confirm
end_thinking_token_id = 151668 # Qwen3's end of thinking token id
n_layers = 36
# n_layers = model.config.num_hidden_layers
# Presumably the number of extraction points per layer (second axis of the candidate tensor) — TODO confirm.
n_activations = 2
d_model = 2560

## Load pre-computed data

In [None]:
# Load pre-computed data (presumably written by get_activations.py — TODO confirm).
#
# map_location="cpu": later cells pass these tensors straight to NumPy functions
# (np.argmax, np.nanmean) and combine `chosen_direction` with `.cpu()` activations,
# which fails if the files were saved from CUDA tensors. Loading onto the CPU makes
# the notebook independent of the device used when the files were written.
#
# NOTE: "postive" (sic) in the file names is kept on purpose; the paths must match
# the files that exist on disk.
positive_activations_all = torch.load("outputs/postive_activations_all.pt", map_location="cpu")
positive_activations_mean = torch.load("outputs/postive_activations_mean.pt", map_location="cpu")
positive_activations_mean_normed = torch.load("outputs/postive_activations_mean_normed.pt", map_location="cpu")

negative_activations_all = torch.load("outputs/negative_activations_all.pt", map_location="cpu")
negative_activations_mean = torch.load("outputs/negative_activations_mean.pt", map_location="cpu")
negative_activations_mean_normed = torch.load("outputs/negative_activations_mean_normed.pt", map_location="cpu")

candidate_refusal_vectors = torch.load("outputs/candidate_refusal_vectors.pt", map_location="cpu")
candidate_refusal_vectors_normed = torch.load("outputs/candidate_refusal_vectors_normed.pt", map_location="cpu")

In [None]:
# Use `candidate_refusal_vectors` (not the `_normed` variant loaded alongside it).
# This is an alias of the loaded tensor, not a copy.
refusal_dirs = candidate_refusal_vectors
print(refusal_dirs.shape) # (n_layers, n_activations, d_model)

## Select the best direction

### Method 1: Max of mean cosine similarities

#### Visualizations: Mean cosine similarity

In [None]:
# Palettes the category colours are drawn from.
_strong_palette = plotly.colors.qualitative.Plotly
_pastel_palette = plotly.colors.qualitative.Pastel1

# Main colour per category.
colour_map = dict(
    positive=_strong_palette[0],
    negative=_strong_palette[1],
    neutral=_pastel_palette[3],
)

# Pastel colour per category.
colour_map_light = dict(
    positive=_pastel_palette[1],
    negative=_pastel_palette[0],
    neutral=_pastel_palette[3],
)
# Semi-transparent fill colours for the variance bands, one per category.
# Each entry is the RGB of the matching `colour_map_light` entry with alpha 0.3:
# "positive" is Pastel1[1] = rgb(179, 205, 227) and "negative" is
# Pastel1[0] = rgb(251, 180, 174). The two values used to be swapped, which drew
# each category's band in the other category's colour.
colour_map_opaque = {
    "positive": "rgba(179, 205, 227, 0.3)",
    "negative": "rgba(251, 180, 174, 0.3)",
}


In [None]:
# One x-axis label per extraction point, in the same row-major order as the
# flattened (layer, activation) axis built below. The earlier "<i>" / "<i>-post"
# labelling was immediately overwritten by this one and has been removed.
layer_names = [str(i) for i in range(n_layers * n_activations)]

# Collapse every leading axis so that each row is one candidate direction.
refusal_dirs_flatten = refusal_dirs.reshape((-1, refusal_dirs.shape[-1]))
# L2 norm of each candidate, computed once and shared by both traces and the print.
refusal_dir_norms = refusal_dirs_flatten.norm(dim=-1)

fig = go.Figure()

# Connecting line plus markers in the light neutral colour.
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=refusal_dir_norms,
        mode="lines+markers",
        yaxis="y",
        marker_color=colour_map_light["neutral"],
        marker_size=8,
        showlegend=False,
    )
)
# Marker-only overlay of the same points in the main neutral colour.
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=refusal_dir_norms,
        mode="markers",
        yaxis="y",
        marker_color=colour_map["neutral"],
        marker_size=8,
        showlegend=False,
    )
)

# Label of the extraction point with the largest norm; the `[:-1]` slice leaves
# the last extraction point out of the search.
print(layer_names[np.argmax(refusal_dir_norms[:-1])])


fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=28)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=24),
    ),
    yaxis=dict(
        title=dict(text="Norm of<br>Refusal Direction", font=dict(size=28)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=24),
    ),
    hovermode="x unified",
    height=300,
    width=1000,
    margin=dict(l=20, r=20, t=20, b=20),
)
fig.show()

# fig.write_image(VISUALIZATION_DIR / "norm_refusal.pdf", scale=5)


# One row per extraction point.
flatten_dirs = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])

# Cosine similarity is the dot product of *unit* vectors. `refusal_dirs` holds the
# un-normalised candidates, so taking `flatten_dirs @ flatten_dirs.T` directly gives
# raw dot products that scale with the vectors' norms and would bias the selection
# towards large-norm extraction points. Normalise each row first.
# A zero-norm row becomes NaN here; the nan-aware reductions below skip it.
unit_dirs = flatten_dirs / flatten_dirs.norm(dim=-1, keepdim=True)
pairwise_cosine = unit_dirs @ unit_dirs.T
# pairwise_cosine = np.arccos(pairwise_cosine)

# Mean similarity of each candidate to all candidates (itself included).
mean_cosine = np.nanmean(pairwise_cosine, axis=-1)

fig = go.Figure()
# Connecting line plus markers in the light neutral colour.
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=mean_cosine,
        mode="lines+markers",
        yaxis="y",
        marker_color=colour_map_light["neutral"],
        showlegend=False,
        marker_size=8,
    )
)
# Marker-only overlay of the same points in the main neutral colour.
fig.add_trace(
    go.Scatter(
        x=layer_names,
        y=mean_cosine,
        mode="markers",
        yaxis="y",
        marker_color=colour_map["neutral"],
        showlegend=False,
        marker_size=8,
    )
)

fig.update_layout(
    plot_bgcolor="white",
    grid=dict(rows=1, columns=1),
    xaxis=dict(
        type="category",
        title=dict(text="Extraction Point", font=dict(size=28)),
        dtick=4,
        gridcolor="lightgrey",
        tickfont=dict(size=24),
    ),
    yaxis=dict(
        title=dict(text="Mean<br>Cosine Score", font=dict(size=28)),
        gridcolor="lightgrey",
        zeroline=False,
        tickfont=dict(size=24),
    ),
    hovermode="x unified",
    height=300,
    width=1000,
    margin=dict(l=20, r=20, t=20, b=20),
)

fig.show()
# Last expression of the cell: label of the extraction point with the highest mean cosine.
layer_names[np.nanargmax(mean_cosine)]


# fig.write_image(VISUALIZATION_DIR / "mean_cosine.pdf", scale=5)

In [None]:
# Select the best direction: the extraction point with the highest mean cosine score.
argmax = np.nanargmax(mean_cosine)

# `mean_cosine` has one entry per row of `refusal_dirs` flattened over its leading
# axes, so invert that flattening with the tensor's real leading shape instead of a
# hard-coded 2 activations per layer.
_layer_idx, _act_idx = np.unravel_index(argmax, tuple(refusal_dirs.shape[:-1]))
max_mean_cosine_layer = int(_layer_idx)
max_mean_cosine_act_idx = int(_act_idx)

chosen_layer = max_mean_cosine_layer
chosen_act_idx = max_mean_cosine_act_idx
chosen_token = -1

In [None]:
# Pick the candidate at the selected (layer, activation index) from `refusal_dirs`.
chosen_direction = refusal_dirs[chosen_layer, chosen_act_idx] # (d_model)

# save tensor
# The tensor is moved to the CPU before being written to disk.
torch.save(chosen_direction.to("cpu"), "outputs/chosen_direction.pt")

#### Visualization: Projection of activation vectors onto the chosen steering vector

In [None]:
# Display the candidate tensor's shape as the cell output (sanity check only).
refusal_dirs.shape # (n_layers, n_activations, d_model)

In [None]:
def variance_plot(**kwargs):
    """Build a filled band trace covering mean ± one standard deviation.

    The band is a closed polygon: the upper edge (mean + std) is traced along
    ``x`` and the lower edge (mean - std) is traced back along ``x`` reversed.

    Keyword Args:
        x: Sequence of x positions, one per row of ``y``. Any sequence is
            accepted; it is copied into a list before being mirrored.
        y: Tensor, or anything ``torch.as_tensor`` accepts, whose last axis
            holds the samples that the mean and standard deviation are taken
            over.
        fillcolor: Optional fill colour. When given it is also used as the
            (zero-width) outline colour.
        **kwargs: Remaining keywords are forwarded to ``go.Scatter``. They may
            override the defaults ``mode="lines"``, ``fill="toself"`` and
            ``line``.

    Returns:
        The ``go.Scatter`` trace describing the band.

    Raises:
        KeyError: If ``x`` or ``y`` is not supplied.
    """
    # list(...) so that `+` below concatenates rather than adding element-wise
    # (as it would for array-like x).
    x = list(kwargs.pop("x"))
    y = torch.as_tensor(kwargs.pop("y"))

    y_mean = y.mean(dim=-1)
    y_std = y.std(dim=-1)
    y_upper = (y_mean + y_std).tolist()
    y_lower = (y_mean - y_std).tolist()

    # Zero-width outline; colour it like the fill only when a fill colour was given.
    line = dict(width=0)
    if "fillcolor" in kwargs:
        line["color"] = kwargs["fillcolor"]

    # Defaults first, so caller-supplied keywords override them instead of
    # colliding with them as duplicate keyword arguments.
    trace_kwargs = dict(mode="lines", fill="toself", line=line)
    trace_kwargs.update(kwargs)

    trace = go.Scatter(
        x=x + x[::-1],
        y=y_upper + y_lower[::-1],
        **trace_kwargs,
    )

    return trace

In [None]:
# Build the projection figure: for each category, add a mean ± std band and the
# per-extraction-point mean of the projections onto `chosen_direction`.
fig = go.Figure()

# NOTE(review): despite the `_normed` name, these are the `*_activations_all` tensors.
# NOTE(review): the einsum patterns in the loop treat each tensor as
# (layer, resid_module, batch_size, dim); the five-axis shape noted on the closing
# brace (with num_last_tokens) does not match that usage — confirm which is right.
category2acts_normed = {
    "positive": positive_activations_all.cpu(),
    "negative": negative_activations_all.cpu(),
} # (n_layers, n_activations, batch_size, num_last_tokens, d_model)



for category in ["positive", "negative"]:
    acts_normed = category2acts_normed[category]

    # layers x resid_modules x batch_size x dim
    # Cloned because `activations` is modified in place near the end of the loop body.
    activations = acts_normed.clone()

    # dim
    direction = chosen_direction.clone() # (d_model)

    # layers x resid_modules x batch_size
    # Dot product of every activation vector with the chosen direction.
    scalar_projections = einops.einsum(
        activations,
        direction,
        "... batch_size dim, ... dim -> ... batch_size",
    )
    # Goes through a NumPy function, so NaNs are replaced and the value is handled via NumPy from here on.
    scalar_projections = np.nan_to_num(scalar_projections)
    print(category)
    print(scalar_projections.mean())
    # NOTE(review): arccos is only defined on [-1, 1]; within this cell only
    # `degrees.shape[-1]` is read afterwards — confirm whether the angles are still needed.
    degrees = np.rad2deg(np.arccos(scalar_projections))

    y_values = scalar_projections

    # Not read again within this cell.
    batch_size = scalar_projections.shape[-1]

    # x_values_flatten = sum(
    #     [
    #         [f"{l}-mid"] * batch_size + [f"{l}-post"] * batch_size
    #         for l in range(num_layers)
    #     ],
    #     [],
    # )
    # NOTE(review): the first assignment is overwritten by the second one right away.
    x_values = sum([[f"{l}", f"{l}-post"] for l in range(n_layers)], [])
    x_values = [str(i) for i in range(2 * n_layers)]

    # variance
    # Rows are flattened to one per extraction point; the last axis keeps the samples.
    fig.add_trace(
        variance_plot(
            x=x_values,
            y=torch.tensor(y_values).reshape(-1, degrees.shape[-1]),
            yaxis="y",
            fillcolor=colour_map_opaque[category],
            showlegend=False,
        )
    )

    # mean
    # The same mean curve is added three times: once for the legend entry, once as
    # a line and once as markers.
    ## for legend
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values.mean(axis=-1).flatten(),
            mode="lines+markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=True,
            name=category,
        )
    )
    ## for lines
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values.mean(axis=-1).flatten(),
            mode="lines",
            yaxis="y",
            marker=dict(color=colour_map_light[category], size=3),
            showlegend=False,
            name=category,
        )
    )
    ## for markers
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values.mean(axis=-1).flatten(),
            mode="markers",
            yaxis="y",
            marker=dict(color=colour_map[category], size=3),
            showlegend=False,
            name=category,
        )
    )

    # Subtract, in place, twice the positive part of each projection times `direction`.
    # NOTE(review): the first einsum operand is a NumPy array and the second a torch
    # tensor — confirm this mixed-type call and the in-place update behave as intended.
    activations -= 2 * einops.einsum(
        np.maximum(scalar_projections, 0),
        direction,
        "layer resid_module batch_size, dim -> layer resid_module batch_size dim",
    )
    # Re-project the modified activations; within this cell the result is only printed.
    scalar_projections = einops.einsum(
        activations,
        direction,
        "... batch_size dim, ... dim -> ... batch_size",
    )
    print(category)
    print(scalar_projections.mean())
    degrees = np.rad2deg(np.arccos(scalar_projections))

    # NOTE(review): reassigned after all traces were added, so these values are not
    # plotted by this cell.
    y_values = scalar_projections


# Only referenced by the commented-out figure title kept below.
module_names = ["mid", "post"]

# Axis settings are assembled first so the layout call itself stays short.
_projection_xaxis = dict(
    type="category",
    dtick=4,
    title=dict(text="Extraction Point", font=dict(size=20)),
    gridcolor="lightgrey",
    tickfont=dict(size=18),
)
_projection_yaxis = dict(
    title=dict(text="Scalar Projections", font=dict(size=20)),
    gridcolor="lightgrey",
    zeroline=False,
    tickfont=dict(size=18),
)

# Optional title, currently disabled:
#     "Scalar projections of activations at each layer onto the chosen refusal direction"
#     f" ({chosen_layer}-{module_names[chosen_act_idx]})"
fig.update_layout(
    plot_bgcolor="white",
    hovermode="x unified",
    width=600,
    height=250,
    grid=dict(rows=1, columns=1),
    xaxis=_projection_xaxis,
    yaxis=_projection_yaxis,
    legend=dict(x=0.05, y=0.95, font=dict(size=18)),
    margin=dict(l=20, r=20, t=20, b=20),
)
fig.show()