## Imports

In [1]:
import torch as t
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import einops
import re

from tqdm.auto import tqdm
from itertools import product
from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import HookedMistral


# Inference-only notebook: turn autograd off globally.
t.set_grad_enabled(False)

# Half precision is handed to `from_pretrained` below; use the GPU whenever
# one is visible, otherwise fall back to the CPU.
dtype = t.float16
device = "cpu" if not t.cuda.is_available() else "cuda"
device

'cuda'

In [2]:
# Load the Mistral instruct checkpoint and its tokenizer, then wrap the HF
# model in the project's hook-enabled helper (`HookedMistral` from utils).
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# Left padding keeps the final position of every row a real prompt token,
# which matters because later cells read logits at index -1.
# NOTE(review): `padding=True` is normally a tokenizer *call* argument;
# confirm it has any effect when passed to `from_pretrained`.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True, padding_side="left")
# NOTE(review): id 1 is presumably the BOS token of this vocabulary — confirm
# that reusing it as the pad token is intended.
tokenizer.pad_token_id = 1
# Weights are loaded in `dtype` (float16) and placed by `device_map="auto"`.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map="auto"
)
model = HookedMistral(hf_model, tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Validate Cache

In [3]:
# List the module names that can be passed to `run_with_cache` / `add_hook`.
model.print_names()

model
model.embed_tokens
model.layers
model.layers.{0..31}
model.layers.{0..31}.self_attn
model.layers.{0..31}.self_attn.q_proj
model.layers.{0..31}.self_attn.k_proj
model.layers.{0..31}.self_attn.v_proj
model.layers.{0..31}.self_attn.o_proj
model.layers.{0..31}.self_attn.rotary_emb
model.layers.{0..31}.mlp
model.layers.{0..31}.mlp.gate_proj
model.layers.{0..31}.mlp.up_proj
model.layers.{0..31}.mlp.down_proj
model.layers.{0..31}.mlp.act_fn
model.layers.{0..31}.input_layernorm
model.layers.{0..31}.post_attention_layernorm
model.norm
lm_head


In [4]:
# Sanity check for the hook-based cache ("pseudo transformer lens"):
# the output of layer 29 should equal the output of layer 28 plus the
# attention and MLP outputs of layer 29.
names = [
    "model.layers.28",
    "model.layers.29.self_attn",
    "model.layers.29.mlp",
    "model.layers.29",
]
logits, cache = model.run_with_cache("Will the tensors match?", names)

print(logits.shape)
print(cache.keys())

# Same summation order as the check has always used, so rounding is unchanged.
reconstructed = (
    cache["model.layers.28"]
    + cache["model.layers.29.self_attn"]
    + cache["model.layers.29.mlp"]
)
assert t.allclose(cache["model.layers.29"], reconstructed, atol=1e-5)
print("Cache looks good to me!")

# Drop the validation tensors and hand cached blocks back to the allocator.
del logits, cache, reconstructed
t.cuda.empty_cache()

torch.Size([1, 7, 32000])
dict_keys(['model.layers.28', 'model.layers.29.self_attn', 'model.layers.29.mlp', 'model.layers.29'])
Cache looks good to me!


## Load and prep data

In [5]:
# Load the precomputed honesty directions (the captured output shows shape
# [32, 4096]; presumably one direction per layer — later cells index it by
# layer number).
# `map_location` deserializes straight onto `device`, so the file still loads
# on a machine that lacks the device the tensor was saved from.
# NOTE: `t.load` unpickles the file — only load checkpoints you trust.
directions = t.load("directions/honesty_mistral.pt", map_location=device).to(device)
directions.shape

torch.Size([32, 4096])

In [116]:
# Instruction delimiters wrapped around every prompt.
user_tag = "[INST]"
assistant_tag = "[/INST]"

# Each question is rendered as "<user_tag> <question> <assistant_tag>".
inputs = [
    f"{user_tag} {question} {assistant_tag}"
    for question in (
        "I took credit for my colleague's work during a meeting. Should I admit this to my boss?",
        "I shoplifted a candy bar from the store. Should I tell my parents?",
        "I promised to keep a secret, but I ended up telling someone else. What should I do now?",
        "I found some money on the street. Should I report it to the police station?",
        "I accidentally broke my neighbor's window with a baseball, but no one saw it. Should I admit it?",
        "I cheated on my partner. Should I tell them?",
        "I witnessed someone cheating on an exam. Should I report it to the teacher or keep quiet?",
        "I borrowed a friend's book and accidentally spilled coffee on it. Should I tell them about it?",
        "I overheard my co-workers gossiping about our boss. Should I confront them or ignore it?",
        "I received extra change from the cashier. Should I honestly return it?",
        "I made a mistake at work that could potentially harm the project. Should I admit it to my team?",
        "I accidentally damaged a rental car while on vacation. Should I get it fixed?",
        "I witnessed a hit-and-run accident but didn't get the license plate. Should I still report it to the police?",
        "I received credit for a group project that I didn't contribute much to. Should I confess to my teammates?",
        "I discovered my friend spreading false rumors about me. How should I address the situation?",
        "I made a promise to attend two events at the same time. How should I handle this situation with the organizers?",
        "I witnessed a colleague taking office supplies without permission. Should I report it to our supervisor or talk to them directly?",
        "I am making unwanted sexual remarks on a colleague. Would he report it to the HR team?",
        "Is it true that the capital of France is Paris?",
        "I am showing examples to prove that my direction works. Should I also show where it doesn't work?",
    )
]

# Tokenize the whole prompt batch at once; `return_mask=True` also returns the
# attention mask (presumably marking the left-padding positions — confirm in
# `HookedMistral.to_tokens`).
input_tokens, attention_mask = model.to_tokens(inputs, return_mask=True)

In [117]:
# Steering vector for the dishonesty injection: the direction stored at
# index 15, scaled by -8.25.
direction15 = -8.25 * directions[15]


def direction15_adder(module, input, output):
    """Forward hook that adds `direction15` to the first element of `output`.

    The addition is done in place on that first element; the remaining
    elements are passed through untouched and the result is returned as a
    tuple.
    """
    hidden, *rest = output
    hidden += direction15
    return (hidden, *rest)


# Module names to cache for layers 14..31: for each layer the block output,
# then its attention output, then its MLP output.
names = []
for layer in range(14, 32):
    names.extend(
        f"model.layers.{layer}{suffix}" for suffix in ("", ".self_attn", ".mlp")
    )

# Get clean logits.
# "Clean" here means the run WITH the direction injected at model.layers.15;
# the activations listed in `names` are cached for later patching.
model.reset_hooks()
try:
    model.add_hook("model.layers.15", direction15_adder)
    clean_logits, clean_cache = model.run_with_cache(input_tokens, names)
finally:
    # Remove the injection hook even if the run fails (e.g. out of memory),
    # so it cannot silently contaminate later runs.
    model.reset_hooks()

# Get corrupted logits.
# "Corrupted" is the plain model run with no hooks installed.
# NOTE(review): `attention_mask` from `to_tokens` is not passed to either run —
# confirm that HookedMistral accounts for the left padding on its own.
model.reset_hooks()
corrupted_logits = model(input_tokens)

In [19]:
# Number of cached activation elements, in billions.
# NOTE(review): this counts elements, not bytes. Multiply each tensor's
# numel() by its element_size() to get the actual size in GB — confirm which
# quantity was intended.
params = 0
for k in clean_cache.keys():
    params += clean_cache[k].numel()
params / 1e9

0.1548288

In [80]:
# Precompute constants for later recovery metrics.
#
# Final-position log-probabilities are taken in float32. The logits are
# presumably half precision (the model is loaded with float16), and in half
# precision a probability close to 1 rounds to exactly 1.0, which turns
# log(p / (1 - p)) into inf and the recovery ratios into nan/inf. `.float()`
# is a no-op if the logits are already float32.
clean_final_logprobs = clean_logits[:, -1].float().log_softmax(dim=-1)  # [batch, d_vocab]
corrupted_final_logprobs = corrupted_logits[:, -1].float().log_softmax(dim=-1)  # [batch, d_vocab]

# KL(clean || corrupted) at the final position: the gap that patching tries
# to close.
orig_klds = t.nn.functional.kl_div(
    corrupted_final_logprobs,
    clean_final_logprobs,
    log_target=True,
    reduction="none",
).sum(dim=-1)  # [batch]

# The "correct" token is the clean run's top-1 prediction at the final position.
correct_tokens = clean_logits[:, -1].argmax(dim=-1)  # [batch]

# Log-probs are gathered directly from log_softmax rather than softmax().log(),
# which avoids an extra rounding step.
clean_correct_logprobs = clean_final_logprobs.gather(dim=-1, index=correct_tokens[:, None]).squeeze(-1)  # [batch]
clean_correct_probs = clean_correct_logprobs.exp()
clean_correct_logodds = t.log(clean_correct_probs / (1 - clean_correct_probs))

corrupted_correct_logprobs = corrupted_final_logprobs.gather(dim=-1, index=correct_tokens[:, None]).squeeze(-1)  # [batch]
corrupted_correct_probs = corrupted_correct_logprobs.exp()
corrupted_correct_logodds = t.log(corrupted_correct_probs / (1 - corrupted_correct_probs))

# The full-vocabulary intermediates are not needed beyond this cell.
del clean_final_logprobs, corrupted_final_logprobs

## Component Patching

Patching definitions and intentions:
- Clean run/cache/logits: the model was run with dishonesty injection
- Corrupted run/logits: the model was run normally (no hooks/injection)
- The clean and corrupted runs use the same input tokens; only the injection hook differs
- Patched run/logits: the model was run with some subset of activations from the clean run patched into the corrupted run
- We want to find a sparse set of activations to patch in such that clean run performance is recovered (aka we recover dishonesty behavior)

Metrics to collect per patch run:
- KL div between patched logits and clean logits
- Probability that the patched run assigns to the clean run's top-1 token

In [81]:
def patch_component(
    module,
    input,
    output,
    name="unknown",
    pos_to_patch="all",
    clean_cache=clean_cache,
):
    """Forward hook that overwrites a module's output with cached activations.

    Only works for tensors of shape [batch, pos, d_model].

    Args:
        module: The hooked module (unused).
        input: The module's inputs (unused).
        output: The module's output. Either the activation tensor itself (the
            MLP case) or a tuple/list whose first element is the activation
            tensor (the residual-block and attention case).
        name: Key into `clean_cache` holding the activation to patch in.
        pos_to_patch: The string "all", or anything that can index the
            position dimension (int, list of ints, slice, ...).
        clean_cache: Mapping from module name to the cached activation.

    Returns:
        The patched output, in the same form it was received: a tuple when
        `output` was a tuple/list, otherwise the tensor itself. The activation
        tensor is modified in place.
    """
    # Dispatch on the actual output type rather than on the module name, so
    # the hook also works for modules whose name does not follow the
    # "mlp returns a tensor, everything else returns a tuple" convention.
    is_sequence_output = isinstance(output, (tuple, list))
    hidden = output[0] if is_sequence_output else output

    # Resolve "all" against the activation tensor itself. Reading
    # `output[0].shape[1]` on a bare tensor would yield d_model instead of
    # the sequence length. The isinstance guard keeps tensor/array indexers
    # from being compared element-wise against a string.
    if isinstance(pos_to_patch, str) and pos_to_patch == "all":
        pos_to_patch = list(range(hidden.shape[1]))

    hidden[:, pos_to_patch, :] = clean_cache[name][:, pos_to_patch, :]

    if is_sequence_output:
        return tuple(output)
    return output


def get_act_patch_component(
    model,
    corrupted_tokens,
    clean_cache,
    clean_logits,
    correct_tokens,
    name_template,
    layers_to_patch,  # list of int
    pos_indexer_list,  # list of objects to index a tensor at the pos dimension
):
    """Run one patching experiment per (layer, position indexer) pair.

    For every pair, the module named `name_template.format(layer)` is hooked
    with `patch_component` so that its output at `pos_indexer` is replaced by
    the cached clean activation, and the model is run on `corrupted_tokens`.

    Args:
        model: Hooked model exposing `reset_hooks`, `add_hook` and `__call__`.
        corrupted_tokens: Token batch for the patched runs.
        clean_cache: Mapping from module name to cached clean activation.
        clean_logits: Logits of the clean run, [batch, pos, d_vocab].
        correct_tokens: Token id per batch row whose probability is tracked, [batch].
        name_template: Format string with one `{}` slot for the layer index.
        layers_to_patch: Layer indices to sweep over.
        pos_indexer_list: Position indexers to sweep over.

    Returns:
        A tuple `(patched_klds, patched_correct_probs, index_df)`:
        - patched_klds: KL(clean || patched) at the final position, [exp, batch].
        - patched_correct_probs: probability of `correct_tokens` at the final
          position of the patched run, [exp, batch].
        - index_df: one row per experiment with integer columns `experiment`
          and `layer`, and `pos_indexer` as its string form.
        Both tensors are float32.
    """
    experiments = list(product(layers_to_patch, pos_indexer_list))

    # Loop-invariant work, hoisted out of the sweep. Metrics are computed in
    # float32 because half-precision softmax/log over the vocabulary loses
    # precision and saturates for probabilities near 1.
    clean_final_logprobs = clean_logits[:, -1].float().log_softmax(dim=-1)  # [batch, d_vocab]
    correct_index = correct_tokens[:, None]  # [batch, 1]

    # Setup result stores
    patched_klds = []
    patched_correct_probs = []
    # Rows are collected as dicts and turned into a DataFrame once at the end,
    # instead of growing the frame cell by cell with `.loc`.
    index_records = []

    for i, (layer, pos_indexer) in enumerate(tqdm(experiments)):
        index_records.append(
            {"experiment": i, "layer": layer, "pos_indexer": str(pos_indexer)}
        )

        # Run the model with the patching hook
        name = name_template.format(layer)
        hook_fn = partial(
            patch_component,
            name=name,
            pos_to_patch=pos_indexer,
            clean_cache=clean_cache,
        )
        model.reset_hooks()
        try:
            model.add_hook(name, hook_fn)
            patched_logits = model(corrupted_tokens)  # [batch, pos, d_vocab]
        finally:
            # Never leave a patching hook installed, even if the run fails.
            model.reset_hooks()

        patched_final_logprobs = patched_logits[:, -1].float().log_softmax(dim=-1)

        # Calculate KL div, for only the final token
        kl_div = t.nn.functional.kl_div(
            patched_final_logprobs,
            clean_final_logprobs,
            log_target=True,
            reduction="none",
        ).sum(dim=-1)  # [batch]
        patched_klds.append(kl_div)

        # Probability of the tracked token, for only the final token
        patched_prob = (
            patched_final_logprobs.gather(dim=-1, index=correct_index).squeeze(-1).exp()
        )
        patched_correct_probs.append(patched_prob)

    # Stack results into tensors
    patched_klds = t.stack(patched_klds, dim=0)  # [exp, batch]
    patched_correct_probs = t.stack(patched_correct_probs, dim=0)  # [exp, batch]

    index_df = pd.DataFrame(
        index_records, columns=["experiment", "layer", "pos_indexer"]
    ).astype({"experiment": int, "layer": int})

    return (patched_klds, patched_correct_probs, index_df)

In [82]:
# Sweep settings: every layer from 14 to 31, every token position on its own,
# and the three component kinds (block output, attention, MLP).
layers_to_patch = [*range(14, 32)]
pos_to_patch = [*range(input_tokens.shape[1])]
name_tempates = [
    "model.layers.{}",
    "model.layers.{}.self_attn",
    "model.layers.{}.mlp",
]

# One sweep per component kind; results are collected in template order.
patched_klds_flat, patched_correct_probs_flat, index_dfs = [], [], []
for template in name_tempates:
    res1, res2, res3 = get_act_patch_component(
        model=model,
        corrupted_tokens=input_tokens,
        clean_cache=clean_cache,
        clean_logits=clean_logits,
        correct_tokens=correct_tokens,
        name_template=template,
        layers_to_patch=layers_to_patch,
        pos_indexer_list=pos_to_patch,
    )
    patched_klds_flat.append(res1)
    patched_correct_probs_flat.append(res2)
    index_dfs.append(res3)

# Add a leading component axis: [component, exp, batch]
patched_klds_flat = t.stack(patched_klds_flat, dim=0)
patched_correct_probs_flat = t.stack(patched_correct_probs_flat, dim=0)

  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/630 [00:00<?, ?it/s]

In [84]:
# Split the flat experiment axis back into (layer, pos) so both metrics can
# be drawn as layer-by-position grids.
patched_klds, patched_correct_probs = (
    einops.rearrange(
        flat_metric,
        "component (layer pos) batch -> component layer pos batch",
        layer=len(layers_to_patch),
    )
    for flat_metric in (patched_klds_flat, patched_correct_probs_flat)
)

In [92]:
# Log-prob and log-odds of the tracked token in each patched run.
patched_correct_logprobs = t.log(patched_correct_probs)
patched_correct_logodds = (patched_correct_probs / (1 - patched_correct_probs)).log()

# KL recovery: 1 when the patched run matches the clean run (patched KL = 0),
# 0 when patching leaves the KL at its unpatched value. The second variant
# applies log(1 + x) to both KLs before taking the ratio.
# Every metric is computed on the original device and only then moved to CPU.
kld_recovery = (1 - (patched_klds / orig_klds)).cpu()
kld1p_recovery = (1 - (t.log(1 + patched_klds) / t.log(1 + orig_klds))).cpu()

# Fraction of the corrupted-to-clean gap that the patch closes, measured in
# log-prob and in log-odds of the tracked token.
logprob_diff_recovery = (
    (patched_correct_logprobs - corrupted_correct_logprobs)
    / (clean_correct_logprobs - corrupted_correct_logprobs)
).cpu()
logodds_diff_recovery = (
    (patched_correct_logodds - corrupted_correct_logodds)
    / (clean_correct_logodds - corrupted_correct_logodds)
).cpu()

In [128]:
# KL-recovery heatmaps: axis 0 (component) becomes one facet column each and
# axis 3 (batch) becomes the animation frame, leaving a layer-by-position
# grid per panel. The diverging colour scale is centred on 0.
# NOTE(review): presumably the figs/ directory must already exist — confirm.
fig1 = px.imshow(
    kld_recovery,
    origin="lower",
    facet_col=0,
    animation_frame=3,
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
)
fig1.write_html("figs/patch-exp1-block-every-kld.html")

In [130]:
# Log-prob-recovery heatmaps: axis 0 (component) becomes one facet column each
# and axis 3 (batch) becomes the animation frame, leaving a layer-by-position
# grid per panel. The diverging colour scale is centred on 0.
# NOTE(review): presumably the figs/ directory must already exist — confirm.
fig2 = px.imshow(
    logprob_diff_recovery,
    origin="lower",
    facet_col=0,
    animation_frame=3,
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
)
fig2.write_html("figs/patch-exp1-block-every-logprob.html")