## Imports

In [1]:
import torch as t
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import einops
import re

from tqdm.auto import tqdm
from itertools import product
from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import HookedMistral


t.set_grad_enabled(False)
dtype = t.float16
device = "cuda" if t.cuda.is_available() else "cpu"
device

'cuda'

In [2]:
def patch_component(
    module,
    input,
    output,
    name="unknown",
    pos_to_patch="all",
    clean_cache=None,
):
    """Only works for tenors of shape [batch, pos, d_model]"""
    if clean_cache is None:
        raise ValueError("clean_cache must be provided")

    if pos_to_patch == "all":
        pos_to_patch = list(range(output[0].shape[1]))

    # Do patching for resid or attn 
    if "mlp" not in name:
        new_output = [o for o in output]
        new_output[0][:, pos_to_patch, :] = clean_cache[name][:, pos_to_patch, :]
        return tuple(new_output)
    
    # Do patching for MLP
    else:
        output[:, pos_to_patch, :] = clean_cache[name][:, pos_to_patch, :]
        return output


def get_act_patch_component(
    model,
    corrupted_tokens,
    clean_cache,
    clean_logits,
    correct_tokens,
    name_groups,  # list of tuple[str], where strings are module names
    pos_indexer_list,  # list of objects to index a tensor at the pos dimension
):
    # Setup result stores
    patched_klds = []
    patched_correct_probs = []
    index_df = pd.DataFrame({
        "group": pd.Series(dtype="str"),
        "pos_indexer": pd.Series(dtype="str"),
    })

    for i, (group, pos_indexer) in enumerate(
        tqdm(list(product(name_groups, pos_indexer_list)))
    ):
        # Populate the index df
        index_df.loc[i, "experiment"] = i
        index_df.loc[i, "group"] = str(group)
        index_df.loc[i, "pos_indexer"] = str(pos_indexer)

        # Add patching hooks
        model.reset_hooks()
        for name in group:
            hook_fn = partial(
                patch_component,
                name=name,
                pos_to_patch=pos_indexer,
                clean_cache=clean_cache,
            )
            model.add_hook(name, hook_fn)

        # Run the model
        patched_logits = model(corrupted_tokens)  # [batch, pos, d_vocab]
        model.reset_hooks()

        # Calculate KL div, for only the final token
        kl_div = t.nn.functional.kl_div(
            patched_logits[:, -1].log_softmax(dim=-1),
            clean_logits[:, -1].log_softmax(dim=-1),
            log_target=True,
            reduction="none",
        ).sum(dim=-1)  # [batch]
        patched_klds.append(kl_div)

        # Calculate top1 prob, for only the final token
        patched_prob = patched_logits[:, -1].softmax(dim=-1).gather(dim=-1, index=correct_tokens[:, None]).squeeze(-1)
        patched_correct_probs.append(patched_prob)

    # Stack results into tensors
    patched_klds = t.stack(patched_klds, dim=0)  # [exp, batch]
    patched_correct_probs = t.stack(patched_correct_probs, dim=0)  # [exp, batch]

    # Cast index_df to int
    index_df["experiment"] = index_df["experiment"].astype(int)
    index_df["group"] = index_df["group"].astype(str)
    index_df = index_df[["experiment", "group", "pos_indexer"]]

    return (patched_klds, patched_correct_probs, index_df)

In [3]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True, padding_side="left")
tokenizer.pad_token_id = 1
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map="auto"
)
model = HookedMistral(hf_model, tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Load and prep data

In [6]:
directions = t.load("directions/honesty_mistral.pt").to(device)
directions.shape

torch.Size([32, 4096])

In [7]:
user_tag = "[INST]"
assistant_tag = "[/INST]"

inputs = [
    f"{user_tag} I took credit for my colleague's work during a meeting. Should I admit this to my boss? {assistant_tag}",
    f"{user_tag} I shoplifted a candy bar from the store. Should I tell my parents? {assistant_tag}",
    f"{user_tag} I promised to keep a secret, but I ended up telling someone else. What should I do now? {assistant_tag}",
    f"{user_tag} I found some money on the street. Should I report it to the police station? {assistant_tag}",
    f"{user_tag} I accidentally broke my neighbor's window with a baseball, but no one saw it. Should I admit it? {assistant_tag}",
    f"{user_tag} I cheated on my partner. Should I tell them? {assistant_tag}",
    f"{user_tag} I witnessed someone cheating on an exam. Should I report it to the teacher or keep quiet? {assistant_tag}",
    f"{user_tag} I borrowed a friend's book and accidentally spilled coffee on it. Should I tell them about it? {assistant_tag}",
    f"{user_tag} I overheard my co-workers gossiping about our boss. Should I confront them or ignore it? {assistant_tag}",
    f"{user_tag} I received extra change from the cashier. Should I honestly return it? {assistant_tag}",
    f"{user_tag} I made a mistake at work that could potentially harm the project. Should I admit it to my team? {assistant_tag}",
    f"{user_tag} I accidentally damaged a rental car while on vacation. Should I get it fixed? {assistant_tag}",
    f"{user_tag} I witnessed a hit-and-run accident but didn't get the license plate. Should I still report it to the police? {assistant_tag}",
    f"{user_tag} I received credit for a group project that I didn't contribute much to. Should I confess to my teammates? {assistant_tag}",
    f"{user_tag} I discovered my friend spreading false rumors about me. How should I address the situation? {assistant_tag}",
    f"{user_tag} I made a promise to attend two events at the same time. How should I handle this situation with the organizers? {assistant_tag}",
    f"{user_tag} I witnessed a colleague taking office supplies without permission. Should I report it to our supervisor or talk to them directly? {assistant_tag}",
    f"{user_tag} I am making unwanted sexual remarks on a colleague. Would he report it to the HR team? {assistant_tag}",
    f"{user_tag} Is it true that the capital of France is Paris? {assistant_tag}",
    f"{user_tag} I am showing examples to prove that my direction works. Should I also show where it doesn't work? {assistant_tag}",
]

input_tokens, attention_mask = model.to_tokens(inputs, return_mask=True)

In [8]:
# Hook function for dishonesty injection
direction15 = -8.25 * directions[15]

def direction15_adder(module, input, output):
    new_output = [o for o in output]
    new_output[0] += direction15
    return tuple(new_output)


# Resids, attns and MLPs in layer 14 and onwards
names = []
for layer in range(14, 32):
    names.append(f"model.layers.{layer}")
    names.append(f"model.layers.{layer}.self_attn")
    names.append(f"model.layers.{layer}.mlp")

# Get clean logits
model.reset_hooks()
model.add_hook("model.layers.15", direction15_adder)
clean_logits, clean_cache = model.run_with_cache(input_tokens, names)
model.reset_hooks()

# Get corrupted logits
model.reset_hooks()
corrupted_logits = model(input_tokens)

In [9]:
# Size of cache in GB
params = 0
for k in clean_cache.keys():
    params += clean_cache[k].numel()
params / 1e9

0.1548288

In [10]:
# Precompute constants for later KLD recovery metrics
orig_klds = t.nn.functional.kl_div(
    corrupted_logits[:, -1].log_softmax(dim=-1),
    clean_logits[:, -1].log_softmax(dim=-1),
    log_target=True,
    reduction="none",
).sum(dim=-1)  # [batch]

# Precompute constants for later prob/logprob/logodd diff recovery metrics
correct_tokens = clean_logits[:, -1].argmax(dim=-1)  # [batch]
clean_correct_probs = (clean_logits[:, -1].softmax(dim=-1).gather(dim=-1, index=correct_tokens[:, None]).squeeze(-1))  # [batch]
clean_correct_logprobs = clean_correct_probs.log()
clean_correct_logodds = t.log(clean_correct_probs / (1 - clean_correct_probs))
corrupted_correct_probs = (corrupted_logits[:, -1].softmax(dim=-1).gather(dim=-1, index=correct_tokens[:, None]).squeeze(-1))  # [batch]
corrupted_correct_logprobs = corrupted_correct_probs.log()
corrupted_correct_logodds = t.log(corrupted_correct_probs / (1 - corrupted_correct_probs))

In [11]:
## figure out norm caching

In [15]:
def norm_hook(module, input, output):
    print(type(input), type(output))
    print(len(input))
    print(input[0].shape, output.shape)

    post_gamma = output
    pre_gamma = output / module.weight
    scale = input[0] / pre_gamma
    cache["scale"] = scale


cache = {}
model.reset_hooks()
model.add_hook("model.norm", norm_hook)
logits = model(input_tokens)

<class 'tuple'> <class 'torch.Tensor'>
1
torch.Size([20, 35, 4096]) torch.Size([20, 35, 4096])


In [21]:
cache["scale"][0, 0]

tensor([4.2891, 4.2891, 4.2891,  ..., 4.2930, 4.2852, 4.2852], device='cuda:0',
       dtype=torch.float16)

In [14]:
model.hf_model.model.norm.weight

torch.Size([4096])