In [None]:
# Handy snippet to get repo root from anywhere in the repo.
# Asks git for the top-level directory of the current working tree and makes
# it importable, so project packages resolve regardless of the notebook's cwd.
import sys
from subprocess import check_output

# The command is passed as an argument list, so no intermediate shell is spawned.
ROOT = check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip()

# Append only once so re-running this cell does not grow sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

## Imports

In [None]:
import torch as t
import numpy as np
import plotly.express as px
import einops

from functools import partial
from dishonesty.mistral_lens import load_model
from dishonesty.prompts import PROMPTS
from dishonesty.utils import ntensor_to_long, calc_soft_kl_div


# Disable autograd globally for the rest of the notebook.
t.set_grad_enabled(False)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if t.cuda.is_available():
    device = t.device('cuda:0')
else:
    device = t.device('cpu')

# Display the selected device as the cell output.
device

device(type='cuda', index=0)

In [None]:
# Load the model through the project's `load_model` helper using its default
# arguments. Later cells call `reset_hooks`, `add_hook` and `run_with_cache`
# on this object, and also call it directly on PROMPTS.
model = load_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Load the saved honesty directions straight onto `device`.
# `map_location` remaps storages while deserializing, so a file saved from a
# GPU still loads on a machine where that GPU is not available.
directions = t.load(
    f"{ROOT}/directions/honesty_mistral-instruct-v0.1.pt",
    map_location=device,
)

# Display the shape as the cell output (later cells index rows, e.g. directions[15]).
directions.shape

torch.Size([32, 4096])

In [None]:
# Define injection hook
def inject(module, input, output, alpha=-8.25, layer=15):
    """Forward hook that adds a scaled direction vector to the hooked output.

    Args:
        module: Hooked module (unused; part of the hook signature).
        input: Inputs to the hooked module (unused; part of the hook signature).
        output: Iterable whose first element is the activation to steer.
            Assumes that element is a torch tensor -- TODO confirm for the
            hook point in use.
        alpha: Scalar multiplier applied to the direction. The default is the
            value this notebook uses; presumably a negative value pushes
            against the stored direction -- verify against how `directions`
            was computed.
        layer: Row of the module-level `directions` tensor to add. Defaults
            to 15, the layer this notebook steers.

    Returns:
        A tuple with the same elements as `output`, the first one steered.

    Note:
        `+=` on a tensor is an in-place operation, so the tensor held in
        `output[0]` is itself modified; anything else holding a reference to
        that tensor observes the steered values.
    """
    new_output = list(output)
    new_output[0] += alpha * directions[layer]
    return tuple(new_output)


# Syntactic sugar for injection
def make_dishonest(model, alpha=-8.25, layer=15):
    """Register the `inject` hook on `model` at `resid_post_{layer}`.

    Args:
        model: Object exposing `add_hook(name, fn)`.
        alpha: Forwarded to `inject` as the direction multiplier.
        layer: Selects both the hook point name and the row of `directions`
            that is added. Defaults to 15.

    The hook stays registered until the caller removes it (this notebook uses
    `model.reset_hooks()` for that).
    """
    model.add_hook(f"resid_post_{layer}", partial(inject, alpha=alpha, layer=layer))

## WIP: activation patching between honest and dishonest runs

In [None]:
# Names of the activations to cache: four hook points for each of layers 14..31,
# grouped per layer in the order resid_post, attn_out, head_out, mlp_out.
cache_names = []
for i in range(14, 32):
    cache_names.extend(
        (
            f"resid_post_{i}",
            f"attn_out_{i}",
            f"head_out_{i}",
            f"mlp_out_{i}",
        )
    )

# Get dishonest logits and cache.
# Hooks are cleared first so the injection hook is the only one registered
# for this run; the logits and the activations named in `cache_names` are kept.
model.reset_hooks()
make_dishonest(model)
dishonest_logits, dishonest_cache = model.run_with_cache(PROMPTS, cache_names)

# Get honest logits and cache.
# Hooks are cleared again, which removes the injection hook registered above,
# so this run on the same PROMPTS is un-steered.
model.reset_hooks()
honest_logits, honest_cache = model.run_with_cache(PROMPTS, cache_names)

In [None]:
# Show both cache sizes. Only a cell's last expression is displayed, so the
# two values are returned together as a (dishonest, honest) tuple.
dishonest_cache.size(), honest_cache.size()

In [None]:
# Define function to obtain patching metric
def skld_recovery(patched_logits, temperature=1.0):
    """Patching metric: one minus the ratio of patched to baseline divergence.

    Both divergences come from `calc_soft_kl_div` with the module-level
    `dishonest_logits` as its first argument:

    * baseline: `dishonest_logits` against `honest_logits`
    * patched:  `dishonest_logits` against `patched_logits`

    Args:
        patched_logits: Logits produced by a patched forward pass.
        temperature: Forwarded unchanged to `calc_soft_kl_div`.

    Returns:
        `1 - patched / baseline`. This is 0 when the patched divergence equals
        the baseline and 1 when the patched divergence is 0.
    """

    def divergence_from_dishonest(other_logits):
        return calc_soft_kl_div(dishonest_logits, other_logits, temperature=temperature)

    baseline = divergence_from_dishonest(honest_logits)
    patched = divergence_from_dishonest(patched_logits)
    return 1 - (patched / baseline)

In [None]:
def patch_hook(
    module,
    input,
    output,
    original_cache=None,
    patching_cache=None,
    cache_name=None,
    pos_indexer=slice(None),
    head_indexer=None,
):
    """Forward hook that overwrites part of an activation with cached values.

    The branch taken is chosen by substring match on `cache_name`:

    * ``"mlp_out"``: `output` is indexed directly as a rank-3 tensor and
      patched in place. Intended hook point: `mlp_out_i`.
    * ``"resid_post"`` / ``"attn_out"``: `output` is treated as a sequence
      whose first element is a rank-3 tensor; that element is patched in
      place. Intended hook point: `resid_post_i` / `attn_out_i` respectively.
    * ``"head_out"``: a copy of `original_cache[cache_name]` (rank 4, summed
      over dim 2) has the selected heads replaced from `patching_cache`, and
      the sum over dim 2 replaces `output[0]`. Intended hook point:
      `attn_out_i`.

    Args:
        module: Hooked module (unused; part of the hook signature).
        input: Inputs to the hooked module (unused; part of the hook signature).
        output: Output of the hooked module, in the form described above.
        original_cache: Mapping from cache name to activation for the run
            being patched. Required only for ``"head_out"``.
        patching_cache: Mapping from cache name to the activation to copy from.
        cache_name: Key looked up in the cache(s); also selects the branch.
        pos_indexer: Index applied to dim 1 of the activations (presumably the
            sequence position -- TODO confirm). Defaults to all entries.
        head_indexer: Index applied to dim 2 of the ``"head_out"`` activations.
            Required only for ``"head_out"``.

    Returns:
        The patched tensor for ``"mlp_out"``; otherwise a tuple with the same
        elements as `output` except for the patched first element.

    Raises:
        ValueError: If `patching_cache` or `cache_name` is missing, if a
            ``"head_out"`` patch lacks `original_cache` or `head_indexer`, or
            if `cache_name` matches none of the supported kinds. Without the
            last check the hook would return None and silently leave the
            forward pass unpatched.
    """
    if patching_cache is None:
        raise ValueError("patching_cache must be provided")
    if cache_name is None:
        raise ValueError("cache_name must be provided")

    # Hook point should be `mlp_out_i`: output is the tensor itself.
    if "mlp_out" in cache_name:
        output[:, pos_indexer, :] = patching_cache[cache_name][:, pos_indexer, :]
        return output

    # Hook point should be `resid_post_i` or `attn_out_i`, respectively:
    # the activation is the first element of the output sequence.
    if "resid_post" in cache_name or "attn_out" in cache_name:
        new_output = list(output)
        # Slice assignment writes into the existing tensor (in place).
        new_output[0][:, pos_indexer, :] = patching_cache[cache_name][:, pos_indexer, :]
        return tuple(new_output)

    # Hook point should be `attn_out_i`: rebuild the output from per-head
    # activations with only the selected heads swapped.
    if "head_out" in cache_name:
        if original_cache is None:
            raise ValueError("original_cache must be provided")
        if head_indexer is None:
            raise ValueError("head_indexer must be provided")
        # Clone so the cached original activations are not modified.
        new_activation = original_cache[cache_name].clone()
        new_activation[:, pos_indexer, head_indexer, :] = patching_cache[cache_name][:, pos_indexer, head_indexer, :]
        new_output = list(output)
        new_output[0] = new_activation.sum(dim=2)
        return tuple(new_output)

    raise ValueError(
        f"unrecognized cache_name {cache_name!r}: expected it to contain one of "
        "'mlp_out', 'resid_post', 'attn_out' or 'head_out'"
    )

In [None]:
# Test patching resid_post, dishonest -> honest, should have high? recovery.
# The dishonest run's cached `resid_post_15` is written into a forward pass
# that has no injection hook registered.
# CONFUSION / TODO: is the cached resid_post_15 before or after injection?
# The cache and the injection both act on the same hook point, so the answer
# decides whether this patch carries the injected direction at all.
model.reset_hooks()
patch_hook_fnc = partial(
    patch_hook,
    patching_cache=dishonest_cache,
    cache_name="resid_post_15",
)
model.add_hook("resid_post_15", patch_hook_fnc)

# Honest run: the patch hook above is the only hook registered.
patched_logits = model(PROMPTS)
# Displayed as the cell output.
skld_recovery(patched_logits)

In [None]:
# Test patching resid_post, dishonest -> honest, should have high recovery.
# Same check one layer later: `resid_post_16` is past the layer-15 injection
# point, so its cached value comes from a forward pass that was already steered.
model.reset_hooks()
patch_hook_fnc = partial(
    patch_hook,
    patching_cache=dishonest_cache,
    cache_name="resid_post_16",
)
model.add_hook("resid_post_16", patch_hook_fnc)

# Honest run: the patch hook above is the only hook registered.
patched_logits = model(PROMPTS)
# Displayed as the cell output.
skld_recovery(patched_logits)

In [None]:
# Test patching resid_post, honest -> dishonest, should have low? recovery.
# The honest run's cached `resid_post_15` is written into a steered forward pass.
model.reset_hooks()
patch_hook_fnc = partial(
    patch_hook,
    patching_cache=honest_cache,
    cache_name="resid_post_15",
)
model.add_hook("resid_post_15", patch_hook_fnc)

# Dishonest run.
# NOTE(review): the injection hook is registered after the patch hook and on
# the same hook point (resid_post_15). If hooks fire in registration order,
# the injection is applied on top of the patched honest activation -- confirm
# the ordering before interpreting this number.
make_dishonest(model)
patched_logits = model(PROMPTS)
# Displayed as the cell output.
skld_recovery(patched_logits)

In [None]:
# Test patching resid_post, honest -> dishonest, should have low? recovery.
# The honest run's cached `resid_post_16` is written into a steered forward
# pass, one layer after the layer-15 injection point.
model.reset_hooks()
patch_hook_fnc = partial(
    patch_hook,
    patching_cache=honest_cache,
    cache_name="resid_post_16",
)
model.add_hook("resid_post_16", patch_hook_fnc)

# Dishonest run: injection at resid_post_15 plus the resid_post_16 patch above.
make_dishonest(model)
patched_logits = model(PROMPTS)
# Displayed as the cell output.
skld_recovery(patched_logits)