In [None]:
# Handy snippet to get repo root from anywhere in the repo
import sys
from subprocess import check_output

# Ask git for the top-level directory of the current checkout. An argument list
# (instead of a command string with shell=True) avoids spawning a shell, and
# `encoding` decodes stdout to str directly.
ROOT = check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
# Make repo-local packages (e.g. `dishonesty`) importable from this notebook.
if ROOT not in sys.path:
    sys.path.append(ROOT)

## Imports

In [None]:
import torch as t
import numpy as np
import plotly.express as px
import einops

from functools import partial
from dishonesty.mistral_lens import load_model
from dishonesty.prompts import PROMPTS
from dishonesty.utils import ntensor_to_long


# Inference-only notebook: switch autograd off globally for every later cell.
t.set_grad_enabled(False)
# Prefer the first GPU when one is present, otherwise fall back to the CPU.
device = t.device("cuda", 0) if t.cuda.is_available() else t.device("cpu")
device

device(type='cuda', index=0)

In [None]:
# Load the model wrapper from dishonesty.mistral_lens. Later cells use its hook
# API (reset_hooks / add_hook / run_with_cache) and reach the underlying weights
# through `model.hf_model`.
model = load_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Steering directions indexed by layer (directions[15] is injected at
# resid_post_15); the output below shows [32, 4096].
# map_location places the tensor directly on `device`, so a file saved from a
# GPU still loads on a CPU-only machine (plain t.load restores the saved device).
directions = t.load(
    f"{ROOT}/directions/honesty_mistral-instruct-v0.1.pt",
    map_location=device,
)
directions.shape

torch.Size([32, 4096])

In [None]:
# Define injection hook
def inject(module, input, output, alpha=-8.25, layer=15, direction=None):
    """Forward hook that adds a scaled steering direction to a module's output.

    Meant to be registered with ``model.add_hook(...)``; the returned value
    replaces the hooked module's output.

    Args:
        module, input: standard forward-hook arguments (unused).
        output: the hooked module's output, either a tensor or a tuple/list
            whose first element is the hidden-state tensor.
        alpha: scale applied to the direction before adding it.
        layer: row of the global ``directions`` tensor used when ``direction``
            is not given; should match the layer the hook is attached to.
        direction: optional explicit direction vector; overrides ``layer``.

    Returns:
        For tuple/list outputs, a new tuple whose first element has
        ``alpha * direction`` added; for tensor outputs, the shifted tensor.
        The result keeps the hidden state's dtype, and the original output
        tensor is left unmodified.
    """
    if direction is None:
        direction = directions[layer]
    if isinstance(output, t.Tensor):
        return (output + alpha * direction).to(output.dtype)
    hidden, *rest = output
    # Out-of-place add (an in-place `+=` would mutate the module's output tensor,
    # which other hooks/caches may share). The cast keeps e.g. fp16 hidden states
    # in fp16 even if `direction` is fp32, matching what `+=` would store.
    return ((hidden + alpha * direction).to(hidden.dtype), *rest)

## Logit lens on the decomposed residual stream (WIP)

In [None]:
# Define activations to cache: token embeddings, the final RMS-norm scale, and
# every layer's attention and MLP outputs.
n_layers = 32  # number of decoder layers; `directions` has one row per layer
assert directions.shape[0] == n_layers, "directions does not match layer count"
cache_names = ["embed_tokens", "final_rscale"]
for i in range(n_layers):
    cache_names += [f"attn_out_{i}", f"mlp_out_{i}"]

# Run with cache for a sweep of injection strengths, from the dishonest setting
# (alpha=-8.25) to no injection (alpha=0).
caches = []
alphas = np.linspace(-8.25, 0, 12)
for alpha in alphas:
    model.reset_hooks()
    # float(): hand the torch hook a plain Python scalar, not np.float64.
    model.add_hook("resid_post_15", partial(inject, alpha=float(alpha)))
    _, cache = model.run_with_cache(PROMPTS, cache_names)
    caches.append(cache)

# Don't leave the last sweep's hook attached for later cells.
model.reset_hooks()

In [None]:
# Collect activations needed for logit lens
decomp_resids = []
final_rscales = []

for cache in caches:
    # Collect decomposed resids; `labels` names each component (only the value
    # from the last cache is kept -- presumably identical across alphas).
    resid, labels = cache.get_decomposed_resid_stack(return_labels=True)  # [comp, batch, pos, d_model]
    # Keep only the final position.
    # NOTE(review): pos -1 is the last real token only if prompts are
    # left-padded or all the same length -- confirm the tokenizer padding side.
    resid = resid[:, :, -1, :]  # [comp, batch, d_model]
    decomp_resids.append(resid)
    # Collect final rscales at the final position, with a leading singleton
    # axis so they broadcast over components.
    final_rscale = cache["final_rscale"][None, :, -1, :]  # [1, batch, 1]
    final_rscales.append(final_rscale)

decomp_resids = t.stack(decomp_resids, dim=0)  # [alpha, comp, batch, d_model]
final_rscales = t.stack(final_rscales, dim=0)  # [alpha, 1, batch, 1]

# Shape invariants that later cells rely on (RMS-norm broadcasting, and the
# index -> alpha / index -> label maps in the long-format conversion).
assert decomp_resids.dim() == 4
assert final_rscales.dim() == 4
n_alphas, n_comps, n_batch, _ = decomp_resids.shape
assert n_alphas == len(alphas)
assert final_rscales.shape == (n_alphas, 1, n_batch, 1)
assert len(labels) == n_comps

In [None]:
def apply_rms_norm(resid, rscale, weight):
    """Apply a precomputed RMS-norm scale and the norm's weight to `resid`.

    The multiplication happens in `rscale`'s dtype; the result is cast back
    to `resid`'s original dtype. All three arguments must broadcast together.

    Args:
        resid: residual-stream tensor to normalize.
        rscale: reciprocal-RMS scale factor(s), broadcastable to `resid`.
        weight: elementwise norm weight over the last dimension.

    Returns:
        ``rscale * resid * weight`` in ``resid.dtype``.
    """
    out_dtype, work_dtype = resid.dtype, rscale.dtype
    normed = rscale * resid.to(work_dtype) * weight.to(work_dtype)
    return normed.to(out_dtype)

# Apply the final RMS norm to each decomposed residual component, reusing the
# scale cached from the full forward pass instead of recomputing it per component.
final_norm_weight = model.hf_model.model.norm.weight
scaled_decomp_resids = apply_rms_norm(
    resid=decomp_resids, rscale=final_rscales, weight=final_norm_weight
)  # [alpha, comp, batch, d_model]

In [None]:
def _final_pos_logits(alpha=None):
    """Return final-position logits for PROMPTS, shape [batch, d_vocab].

    Clears all hooks first; when `alpha` is given, attaches the injection hook
    at resid_post_15 with that strength. Hooks stay as set here afterwards.
    """
    model.reset_hooks()
    if alpha is not None:
        model.add_hook("resid_post_15", partial(inject, alpha=alpha))
    return model(PROMPTS)[:, -1]


# Dishonest run: inject with the strongest alpha of the sweep.
dishonest_logits = _final_pos_logits(alpha=-8.25)  # [batch, d_vocab]
dishonest_tokens = dishonest_logits.argmax(dim=-1)  # [batch]

# Honest run: no hooks (this also leaves the model hook-free afterwards).
honest_logits = _final_pos_logits()  # [batch, d_vocab]
honest_tokens = honest_logits.argmax(dim=-1)  # [batch]

# Per-prompt logit-difference direction in d_model space.
# NOTE(review): the RMS-norm cell reads `model.hf_model.model.norm` while this
# reads `model.hf_model.model.lm_head`; in a stock HF MistralForCausalLM `norm`
# lives on `.model` but `lm_head` on the top-level module -- confirm which
# layout mistral_lens exposes.
W_U = model.hf_model.model.lm_head.weight  # [d_vocab, d_model]
logit_diff_directions = W_U[dishonest_tokens] - W_U[honest_tokens]  # [batch, d_model]

In [None]:
# Project every (alpha, component, prompt) residual onto that prompt's
# dishonest-minus-honest unembedding direction; positive values push the
# dishonest token's logit above the honest one's.
_logit_diff_pattern = "alpha comp batch d_model, batch d_model -> alpha comp batch"
logit_diffs_decomp = einops.einsum(
    scaled_decomp_resids, logit_diff_directions, _logit_diff_pattern
)

In [None]:
# Convert to long format: one row per (alpha index, component index, prompt).
# NOTE(review): assumes ntensor_to_long returns a pandas DataFrame with one
# integer index column per entry of `dim_names` plus `value_name`.
df = ntensor_to_long(
    logit_diffs_decomp,
    value_name="logit_diff",
    dim_names=["alpha", "comp", "batch"],
)

# Map integer index to the alpha value
alpha_map = dict(enumerate(alphas))
df["alpha"] = df["alpha"].map(alpha_map)
# Attach the component names returned by get_decomposed_resid_stack, so rows
# (and plots/hover text) can say which component each index refers to.
df["comp_label"] = df["comp"].map(dict(enumerate(labels)))

In [None]:
# Keep the y-axis fixed across animation frames. Plotly Express recommends
# fixing ranges for animations; otherwise the axis is sized for the initial
# frame and other prompts' curves can be clipped.
_y_lo, _y_hi = df["logit_diff"].min(), df["logit_diff"].max()
_y_pad = 0.05 * (_y_hi - _y_lo) or 1.0  # fall back to +/-1 when flat
px.line(
    df,
    x="comp",
    y="logit_diff",
    color="alpha",
    animation_frame="batch",
    range_y=[_y_lo - _y_pad, _y_hi + _y_pad],
    title="Per-component logit-lens contribution to the dishonest - honest logit diff",
    labels={
        "comp": "residual component (index)",
        "logit_diff": "logit diff (dishonest - honest)",
        "alpha": "injection alpha",
        "batch": "prompt",
    },
)