In [1]:
# Handy snippet to get repo root from anywhere in the repo
import sys
from subprocess import check_output

# Pass git's argv directly (no shell is needed for a fixed command) and decode stdout as UTF-8.
ROOT = check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
# Put the repo root on the import path so the `dishonesty.*` imports below resolve.
if ROOT not in sys.path:
    sys.path.append(ROOT)

## Imports

In [2]:
import os
from functools import partial

import einops
import numpy as np
import plotly.express as px
import torch as t

from dishonesty.mistral_lens import load_model
from dishonesty.prompts import PROMPTS
from dishonesty.utils import ntensor_to_long


# Inference only: disable autograd globally for the rest of the notebook.
t.set_grad_enabled(False)
# First GPU when CUDA is available, otherwise the CPU.
if t.cuda.is_available():
    device = t.device("cuda:0")
else:
    device = t.device("cpu")
device

device(type='cuda', index=0)

In [3]:
# Load the project's hookable model wrapper; later cells use its add_hook / reset_hooks /
# run_with_cache methods and the underlying `.hf_model`.
model = load_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Load the honesty directions (shape [32, 4096]; row 15 is the one injected at
# resid_post_15 below). map_location places the tensor directly on `device`, so a file
# saved from a GPU session also loads on a CPU-only machine.
directions = t.load(
    f"{ROOT}/directions/honesty_mistral-instruct-v0.1.pt",
    map_location=device,
)
directions.shape

torch.Size([32, 4096])

In [5]:
# Define injection hook
def inject(module, input, output, alpha=-8.25, layer=15, direction=None):
    """Forward hook that steers a layer's hidden states along an honesty direction.

    Adds ``alpha * direction`` to ``output[0]`` and returns a new output tuple.
    Any remaining entries of ``output`` pass through unchanged.

    Args:
        module: The hooked module (unused; required by the hook signature).
        input: The module's inputs (unused; required by the hook signature).
        output: The module's output sequence; ``output[0]`` is the tensor to steer.
        alpha: Steering strength. Negative values subtract the direction.
        layer: Row of the global ``directions`` tensor to use when ``direction`` is
            None. Should match the hook point (e.g. 15 for ``resid_post_15``).
        direction: Optional explicit steering vector, broadcastable against
            ``output[0]``. Overrides ``layer``.

    Returns:
        ``(steered, *output[1:])``. ``steered`` has the same dtype as ``output[0]``.
    """
    if direction is None:
        direction = directions[layer]
    hidden = output[0]
    # Build a new tensor rather than using `+=`, so the module's own output tensor
    # (and anything else referencing it) is not mutated. Summing before casting back
    # reproduces the numerics of an in-place add into `hidden`.
    steered = (hidden + alpha * direction).to(hidden.dtype)
    return (steered, *output[1:])

## Run Logit Lens

In [34]:
# Define activations to cache: embeddings, the final-norm scale, residual checkpoints
# around the injection layer (15) plus the last layer, and every attn/mlp output.
cache_names = ["embed_tokens", "final_rscale"]
cache_names += [f"resid_post_{layer}" for layer in (13, 14, 15, 16, 17, 31)]
cache_names.append("final_norm_out")
for layer in range(32):
    cache_names += [f"attn_out_{layer}", f"mlp_out_{layer}"]

# Run with cache for all alphas (alpha = 0 adds nothing, i.e. the unsteered run)
caches = []
alphas = np.linspace(-8.25, 0, 12)
for alpha in alphas:
    model.reset_hooks()
    model.add_hook("resid_post_15", partial(inject, alpha=alpha))
    _, cache = model.run_with_cache(PROMPTS, cache_names, pos_slicer=-1)
    caches.append(cache)
# Detach the last steering hook so later cells start from an unhooked model.
model.reset_hooks()

In [91]:
# Collect activations needed for logit lens
decomp_resids = []
final_rscales = []

for cache in caches:
    # Collect decomposed resids (labels name each component)
    resid, labels = cache.get_decomposed_resid_stack(return_labels=True)  # [comp, batch, d_model]
    decomp_resids.append(resid)
    # Collect final rscales; [None] adds a singleton component axis so the scale
    # broadcasts over every component when the final norm is applied
    final_rscale = cache["final_rscale"][None, :, :]  # [1, batch, 1]
    final_rscales.append(final_rscale)

# Unify tensors
decomp_resids = t.stack(decomp_resids, dim=0)  # [alpha, comp, batch, d_model]
final_rscales = t.stack(final_rscales, dim=0)  # [alpha, 1, batch, 1]

assert decomp_resids.dim() == 4
assert final_rscales.dim() == 4
# One label per component (labels come from the last cache; t.stack above already
# required every cache to have the same number of components)
assert len(labels) == decomp_resids.shape[1]

# Component index just after layer 15's mlp_out, i.e. where resid_post_15 (the hook point
# used in the forward pass) is formed. Assumes the ordering
# [embed, attn_out_0, mlp_out_0, attn_out_1, ...] -> 1 + 2 * 16 = 33.
INJECT_IDX = 1 + 2 * (15 + 1)

# Add the direction injection to the decomposed resid stack
resid_labels = labels[:INJECT_IDX] + ["--> INJECTION"] + labels[INJECT_IDX:]

# Scale in the direction's own dtype (as `inject` does), then cast to the residuals'
# dtype rather than assuming float16
injections = [(alpha * directions[15]).to(decomp_resids.dtype) for alpha in alphas]
injections = t.stack(injections, dim=0)[:, None, None, :]
injections = einops.repeat(
    injections,
    "alpha 1 1 d_model -> alpha 1 batch d_model",
    batch=decomp_resids.shape[2],
)

# Place the injection where resid_post_15 would be
decomp_resids = t.cat([
    decomp_resids[:, :INJECT_IDX],
    injections,
    decomp_resids[:, INJECT_IDX:],
], dim=1)

In [92]:
# [alpha, comp, batch, d_model]; comp includes the inserted "--> INJECTION" component
decomp_resids.shape

torch.Size([12, 66, 20, 4096])

In [93]:
# Sanity check that the decomposition + injection is equal to the final resid_post:
# summing every component (including the inserted injection) should reconstruct
# resid_post_31 for each alpha. Print each result, then fail loudly on any mismatch
# so Restart & Run All cannot silently continue with a broken decomposition.
mismatched_alphas = []
for i, alpha in enumerate(alphas):
    match = t.allclose(
        decomp_resids[i].sum(dim=0),
        caches[i]["resid_post_31"],
        atol=1e-2,
        rtol=1e-3,
    )
    print(i, alpha, match)
    if not match:
        mismatched_alphas.append(alpha)
assert not mismatched_alphas, f"decomposition does not reconstruct resid_post_31 for alphas {mismatched_alphas}"

0 -8.25 True
1 -7.5 True
2 -6.75 True
3 -6.0 True
4 -5.25 True
5 -4.5 True
6 -3.75 True
7 -3.0 True
8 -2.25 True
9 -1.5 True
10 -0.75 True
11 0.0 True


In [94]:
def apply_rms_norm(resid, rscale, weight):
    """Scale ``resid`` by a precomputed RMSNorm factor and the norm's gain.

    The product ``rscale * resid * weight`` is evaluated in ``rscale``'s dtype
    (``resid`` and ``weight`` are cast to it first). The result is cast back to
    ``resid``'s original dtype.

    Args:
        resid: Tensor to normalize.
        rscale: Precomputed scale, broadcastable against ``resid`` (presumably the
            cached reciprocal RMS from the real forward pass; confirm in mistral_lens).
        weight: RMSNorm gain, broadcastable against ``resid``.

    Returns:
        The scaled tensor, in ``resid``'s dtype.
    """
    compute_dtype = rscale.dtype
    resid_c, weight_c = (x.to(compute_dtype) for x in (resid, weight))
    return (rscale * resid_c * weight_c).to(resid.dtype)

# Apply the final RMS norm (cached final scale times the model's norm weight) to every
# decomposed residual component
scaled_decomp_resids = apply_rms_norm(
    resid=decomp_resids,
    rscale=final_rscales,
    weight=model.hf_model.model.norm.weight,
)  # [alpha, comp, batch, d_model]

In [95]:
# Dishonest run: steer resid_post_15 with alpha=-8.25 (the first entry of `alphas`)
model.reset_hooks()
model.add_hook("resid_post_15", partial(inject, alpha=-8.25))
dishonest_logits = model(PROMPTS)[:, -1]  # [batch, d_vocab], last position only
dishonest_tokens = dishonest_logits.argmax(dim=-1)  # [batch]

# Honest run: same prompts, no hooks attached
model.reset_hooks()
honest_logits = model(PROMPTS)[:, -1]  # [batch, d_vocab]
honest_tokens = honest_logits.argmax(dim=-1)  # [batch]

# "Original" logit diffs, both read off the dishonest logits:
# logit(dishonest_token) - logit(honest_token) for each prompt
rows = t.arange(dishonest_logits.shape[0], device=dishonest_logits.device)
original_logit_diffs = (
    dishonest_logits[rows, dishonest_tokens]
    - dishonest_logits[rows, honest_tokens]
)  # [batch]

# Per-prompt unembedding difference, used below to project residual components onto the logit diff
W_U = model.hf_model.lm_head.weight  # [d_vocab, d_model]
logit_diff_directions = W_U[dishonest_tokens] - W_U[honest_tokens]  # [batch, d_model]

In [115]:
# Calculate logit diffs: project each final-normed component onto its prompt's
# (W_U[dishonest] - W_U[honest]) direction
logit_diffs_decomp = einops.einsum(
    scaled_decomp_resids,
    logit_diff_directions,
    "alpha comp batch d_model, batch d_model -> alpha comp batch",
)

# Sanity check: for alpha = -8.25 (index 0 of `alphas`), the decomposed logit diffs should
# sum to the logit diffs of the real steered forward pass. Assert rather than only
# displaying the bool, so a broken decomposition stops Restart & Run All.
assert alphas[0] == -8.25
decomp_matches_original = t.allclose(
    logit_diffs_decomp[0].to(t.float32).sum(dim=0),
    original_logit_diffs,
    atol=1e-2,
)
assert decomp_matches_original, "decomposed logit diffs do not sum to the original logit diffs"
decomp_matches_original

True

In [149]:
# Running total over the component axis: entry c is the logit diff from all components up to and including c
logit_diffs_accum = t.cumsum(logit_diffs_decomp, dim=1)  # [alpha, comp, batch]
# Full sum for alpha = -8.25; compare with original_logit_diffs in the next cell
logit_diffs_accum[0, -1]

tensor([2.4453, 3.1250, 1.5312, 2.3125, 1.5449, 2.1484, 0.6255, 4.0078, 2.4219,
        2.6484, 1.8779, 2.2480, 3.0664, 2.6387, 0.0000, 0.8672, 0.5234, 2.3008,
        0.1597, 0.7588], device='cuda:0', dtype=torch.float16)

In [150]:
# Logit diffs from the real steered forward pass (alpha = -8.25). They should match
# logit_diffs_accum[0, -1] up to rounding, since that tensor is float16.
original_logit_diffs

tensor([2.4453, 3.1250, 1.5312, 2.3125, 1.5547, 2.1484, 0.6250, 4.0078, 2.4141,
        2.6484, 1.8828, 2.2422, 3.0625, 2.6484, 0.0000, 0.8672, 0.5234, 2.2969,
        0.1562, 0.7578], device='cuda:0')

## Make Plots

In [158]:
# Convert the [alpha, comp, batch] tensors to long-format DataFrames for plotly
df_decomp, df_accum = (
    ntensor_to_long(tensor, value_name=value_name, dim_names=["Alpha", "Component", "Batch"])
    for tensor, value_name in (
        (logit_diffs_decomp, "Logit Diff Contribution"),
        (logit_diffs_accum, "Logit Diff"),
    )
)

# Replace each integer alpha index with the alpha value it stands for
alpha_map = dict(enumerate(alphas))
df_decomp["Alpha"] = df_decomp["Alpha"].map(alpha_map)
df_accum["Alpha"] = df_accum["Alpha"].map(alpha_map)

In [159]:
ylim = df_decomp["Logit Diff Contribution"].abs().max()

fig = px.line(
    df_decomp,
    x="Component",
    y="Logit Diff Contribution",
    color="Alpha",
    animation_frame="Batch",
    title="Logit Lens (Decomposed Residual Stream): dishonest_logits[dishonest_token] - dishonest_logits[honest_token]",
)

# Set plot titles and axis labels
fig.update_layout(
    # One tick per residual component, labelled with the component name
    # (tickmode='array' uses tickvals/ticktext; dtick does not apply in this mode)
    xaxis=dict(
        showgrid=True,
        gridwidth=0.1,
        gridcolor='lightgray',
        showline=True,
        linecolor='grey',
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        linewidth=1,
        tickmode='array',
        tickvals=list(range(len(resid_labels))),
        ticktext=resid_labels,
        tickangle=-90,
    ),
    # Symmetric range keeps the zero line centred. Y ticks use plotly's automatic
    # placement: tickmode='array' needs tickvals, and dtick/tick0 only apply in
    # 'linear' mode, so none of those are set here.
    yaxis=dict(
        showgrid=True,
        gridwidth=0.1,
        gridcolor='lightgray',
        showline=True,
        zeroline=True,
        zerolinewidth=2,
        zerolinecolor='black',
        linewidth=1,
        linecolor='grey',
        range=[-ylim, ylim],
    ),
)

# Set background color to white
fig.update_layout(plot_bgcolor='white')

# Show the legend
fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

# Fix margins for visibility
fig.update_layout(margin=dict(l=20, r=20, t=100, b=20))
fig['layout']['updatemenus'][0]['pad'] = dict(r=20, t=110, b=0)
fig['layout']['sliders'][0]['pad'] = dict(r=10, t=100)

# Red -> green gradient, one color per alpha (red = alphas[0] = -8.25, green = 0).
# Sized from `alphas` so changing the sweep cannot silently leave traces uncolored.
plotly_colors = [
    f"rgb({redness}, {255 - redness}, 0)"
    for redness in np.linspace(255, 0, len(alphas))
]

# px.line makes one trace per alpha. This assumes the trace order follows the order in
# which alphas appear in df_decomp (TODO confirm if ntensor_to_long reorders).
assert len(fig.data) == len(alphas)
for traces in [fig.data, *(frame.data for frame in fig.frames)]:
    for trace, color in zip(traces, plotly_colors):
        trace.line.color = color

fig_decomp = fig

In [163]:
ylim = df_accum["Logit Diff"].abs().max()

fig = px.line(
    df_accum,
    x="Component",
    y="Logit Diff",
    color="Alpha",
    animation_frame="Batch",
    title="Logit Lens (Accumulated Residual Stream): dishonest_logits[dishonest_token] - dishonest_logits[honest_token]",
)

# Set plot titles and axis labels
fig.update_layout(
    # One tick per residual component, labelled with the component name
    # (tickmode='array' uses tickvals/ticktext; dtick does not apply in this mode)
    xaxis=dict(
        showgrid=True,
        gridwidth=0.1,
        gridcolor='lightgray',
        showline=True,
        linecolor='grey',
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        linewidth=1,
        tickmode='array',
        tickvals=list(range(len(resid_labels))),
        ticktext=resid_labels,
        tickangle=-90,
    ),
    # Symmetric range keeps the zero line centred. Y ticks use plotly's automatic
    # placement: tickmode='array' needs tickvals, and dtick/tick0 only apply in
    # 'linear' mode, so none of those are set here.
    yaxis=dict(
        showgrid=True,
        gridwidth=0.1,
        gridcolor='lightgray',
        showline=True,
        zeroline=True,
        zerolinewidth=2,
        zerolinecolor='black',
        linewidth=1,
        linecolor='grey',
        range=[-ylim, ylim],
    ),
)

# Set background color to white
fig.update_layout(plot_bgcolor='white')

# Show the legend
fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

# Fix margins for visibility
fig.update_layout(margin=dict(l=20, r=20, t=100, b=20))
fig['layout']['updatemenus'][0]['pad'] = dict(r=20, t=110, b=0)
fig['layout']['sliders'][0]['pad'] = dict(r=10, t=100)

# Red -> green gradient, one color per alpha (red = alphas[0] = -8.25, green = 0).
# Sized from `alphas` so changing the sweep cannot silently leave traces uncolored.
plotly_colors = [
    f"rgb({redness}, {255 - redness}, 0)"
    for redness in np.linspace(255, 0, len(alphas))
]

# px.line makes one trace per alpha. This assumes the trace order follows the order in
# which alphas appear in df_accum (TODO confirm if ntensor_to_long reorders).
assert len(fig.data) == len(alphas)
for traces in [fig.data, *(frame.data for frame in fig.frames)]:
    for trace, color in zip(traces, plotly_colors):
        trace.line.color = color

fig_accum = fig

In [168]:
# fig_decomp.show()
# fig_accum.show()

In [169]:
# Export both figures as standalone interactive HTML under <repo>/figs.
# plotly's write_html does not create missing parent directories, so ensure it exists first.
os.makedirs(f"{ROOT}/figs", exist_ok=True)
fig_decomp.write_html(f"{ROOT}/figs/logit-lens-decomp.html")
fig_accum.write_html(f"{ROOT}/figs/logit-lens-accum.html")