In [1]:
# Handy snippet to get repo root from anywhere in the repo
import sys
from subprocess import check_output
# Ask git for the repository's top-level directory and make it importable, so the
# project-local imports in later cells resolve no matter where the notebook is run from.
# The command is passed as an argument list, so no intermediate shell is spawned.
ROOT = check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip()
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [2]:
import torch as t
import numpy as np
import plotly.express as px
import einops

from functools import partial
from itertools import product, combinations
from tqdm.auto import tqdm

from dishonesty.mistral_lens import load_model
from dishonesty.prompts import PROMPTS
from dishonesty.utils import ntensor_to_long, calc_soft_kl_div


# Gradients are switched off globally for the rest of the notebook.
t.set_grad_enabled(False)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device('cuda:0')
else:
    device = t.device('cpu')
device

device(type='cuda', index=0)

In [3]:
# Load the model through the project helper imported from `dishonesty.mistral_lens`
# (called with its default arguments).
model = load_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Load the saved honesty directions straight onto the target device.
# `map_location` remaps storages while deserialising, so a file written from a GPU
# session can still be opened when that particular device is not available.
directions = t.load(
    f"{ROOT}/directions/honesty_mistral-instruct-v0.1.pt",
    map_location=device,
)
directions.shape

torch.Size([32, 4096])

In [5]:
# Define injection hook
def inject(module, input, output, alpha=-8.25, layer=15):
    """Forward hook that adds a scaled direction vector to a module's output.

    Args:
        module: The hooked module (unused; present to match the hook call signature).
        input: The hooked module's inputs (unused).
        output: Iterable of the module's outputs. Only element 0 is modified.
        alpha: Scale applied to the direction before it is added.
        layer: Row of the global `directions` tensor to add. Defaults to 15, the
            value that used to be hard-coded.

    Returns:
        A tuple holding the same elements as `output`, with
        `alpha * directions[layer]` added to element 0.
    """
    new_output = list(output)
    # In-place add: the tensor object emitted by the module is itself updated.
    new_output[0] += alpha * directions[layer]
    return tuple(new_output)


# Syntactic sugar for injection
def make_dishonest(model, alpha=-8.25, layer=15):
    """Register the injection hook on `model` at `resid_post_{layer}`.

    The same `layer` selects both the hook point name and the row of `directions`
    that `inject` adds, so the two can no longer drift apart.

    Args:
        model: Object exposing `add_hook(name, fn)`.
        alpha: Scale forwarded to `inject`.
        layer: Layer index used for the hook name and the injected direction.
    """
    model.add_hook(f"resid_post_{layer}", partial(inject, alpha=alpha, layer=layer))

## Patching Setup

In [6]:
# Activation names to cache: four hook points for each of layers 14 through 31,
# grouped by layer in the order resid_post, attn_out, head_out, mlp_out.
cache_names = [
    f"{kind}_{layer}"
    for layer in range(14, 32)
    for kind in ("resid_post", "attn_out", "head_out", "mlp_out")
]

# Dishonest run: reset the model's hooks, register the injection hook, then run the
# prompts while caching every activation listed in `cache_names`.
model.reset_hooks()
make_dishonest(model)
dishonest_logits, dishonest_cache = model.run_with_cache(PROMPTS, cache_names)

# Honest run: hooks are reset again first, so no injection hook is registered for
# this pass; the same prompts and cache names are used.
model.reset_hooks()
honest_logits, honest_cache = model.run_with_cache(PROMPTS, cache_names)

# Report the size of each cache (the cell output shows a "Cache size: ... MB" line
# for each call).
dishonest_cache.size()
honest_cache.size()

Cache size: 3612.7 MB
Cache size: 3612.7 MB


In [7]:
# Define function to obtain patching metric
def skld_recovery(patched_logits, temperature=1.0):
    """Score patched logits by how much of the honest/dishonest divergence remains.

    Two values are obtained from `calc_soft_kl_div`, both using the global
    `dishonest_logits` as the first argument: one against the global
    `honest_logits` (the reference) and one against `patched_logits`. The score
    is `1 - patched / reference`.

    Args:
        patched_logits: Logits produced by a patched forward pass.
        temperature: Forwarded unchanged to `calc_soft_kl_div`.

    Returns:
        The score sliced with `[:, -1]`, i.e. only the last index along axis 1.
    """
    reference_div = calc_soft_kl_div(dishonest_logits, honest_logits, temperature=temperature)
    patched_div = calc_soft_kl_div(dishonest_logits, patched_logits, temperature=temperature)
    remaining_fraction = patched_div / reference_div
    # keep only final pos
    return (1 - remaining_fraction)[:, -1]

In [8]:
def patch_hook(
    module,
    input,
    output,
    original_cache=None,
    patching_cache=None,
    cache_name=None,
    pos_indexer="all",
    head_indexer=None,
):
    """Forward hook that overwrites (part of) a module's output with cached values.

    The branch taken is chosen by substring match on `cache_name`:

    * ``"mlp_out"``: `output` is indexed directly; the selected positions are
      overwritten in place and `output` is returned.
    * ``"resid_post"`` / ``"attn_out"``: `output` is treated as a sequence; the
      selected positions of element 0 are overwritten in place and the sequence is
      returned as a tuple.
    * ``"head_out"``: a clone of `original_cache[cache_name]` has the selected
      positions/heads replaced from `patching_cache`, is summed over axis 2, and that
      sum replaces element 0 of `output` (returned as a tuple). Per the original
      notes this branch is presumably registered on the attention-output hook point.

    Args:
        module: The hooked module (unused).
        input: The hooked module's inputs (unused).
        output: The hooked module's output (tensor or sequence, see above).
        original_cache: Mapping from cache name to activation; required only for
            ``"head_out"`` names.
        patching_cache: Mapping from cache name to the activation to patch in.
        cache_name: Key into the caches; also selects the branch.
        pos_indexer: ``"all"`` for every index along axis 1, or any index valid
            for that axis.
        head_indexer: ``"all"`` for every index along axis 2, or any index valid
            for that axis; required only for ``"head_out"`` names.

    Returns:
        The patched output: the same tensor for ``"mlp_out"``, otherwise a tuple.

    Raises:
        ValueError: If `patching_cache` or `cache_name` is missing, if a
            ``"head_out"`` patch lacks `original_cache` or `head_indexer`, or if
            `cache_name` matches none of the supported hook kinds.
    """
    # Guards
    if patching_cache is None:
        raise ValueError("patching_cache must be provided")
    if cache_name is None:
        raise ValueError("cache_name must be provided")

    # Only the literal string "all" is the sentinel. Testing the type first means an
    # array-like indexer is never compared against a string with `==`.
    if isinstance(pos_indexer, str) and pos_indexer == "all":
        pos_indexer = slice(None)

    # Hook point should be `mlp_out_i`
    if "mlp_out" in cache_name:
        output[:, pos_indexer, :] = patching_cache[cache_name][:, pos_indexer, :]
        return output

    # Hook point should be `resid_post_i` or `attn_out_i`, respectively
    if "resid_post" in cache_name or "attn_out" in cache_name:
        new_output = list(output)
        new_output[0][:, pos_indexer, :] = patching_cache[cache_name][:, pos_indexer, :]
        return tuple(new_output)

    # Cache name is `head_out_i`; hook point should be `attn_out_i`
    if "head_out" in cache_name:
        if original_cache is None:
            raise ValueError("original_cache must be provided")
        if head_indexer is None:
            raise ValueError("head_indexer must be provided")
        if isinstance(head_indexer, str) and head_indexer == "all":
            head_indexer = slice(None)
        # Clone so the cached original activations are left untouched.
        new_activation = original_cache[cache_name].clone()
        new_activation[:, pos_indexer, head_indexer, :] = patching_cache[cache_name][:, pos_indexer, head_indexer, :]
        new_output = list(output)
        # Collapse axis 2 to rebuild the combined output from the per-head terms.
        new_output[0] = new_activation.sum(dim=2)
        return tuple(new_output)

    # Previously an unknown name fell through and returned None, which silently
    # turned the hook into a no-op. Fail loudly instead.
    raise ValueError(
        f"Unsupported cache_name {cache_name!r}: expected a name containing "
        "'mlp_out', 'resid_post', 'attn_out' or 'head_out'"
    )

In [13]:
# Components to patch: the attention output and MLP output of layers 14 through 31,
# interleaved per layer (attn_out first).
components = [
    f"{kind}_{layer}"
    for layer in range(14, 32)
    for kind in ("attn_out", "mlp_out")
]
# Every unordered pair of distinct components, in list order.
pairwise_comps = list(combinations(components, 2))

## Dishonest -> Honest

In [22]:
# Results grid indexed as (component 1, component 2, batch element).
# It starts at zero; only the pairs visited by the patching loop get overwritten.
n_batch = honest_logits.size(0)
d2h_results = t.zeros(
    (len(components), len(components), n_batch),
    device=device,
)
d2h_results.shape

torch.Size([36, 36, 20])

In [23]:
# Position of each component name along the first two axes of the results tensor
component_position = {name: pos for pos, name in enumerate(components)}

progress_bar = tqdm(pairwise_comps)
for c1, c2 in progress_bar:
    # Start from a hook-free model, then patch both components of the pair with
    # their cached activations from the dishonest run
    model.reset_hooks()
    for c in (c1, c2):
        model.add_hook(
            c,
            partial(patch_hook, patching_cache=dishonest_cache, cache_name=c),
        )

    # Forward pass with both patches active; store the per-batch metric
    patched_logits = model(PROMPTS)
    row, col = component_position[c1], component_position[c2]
    d2h_results[row, col, :] = skld_recovery(patched_logits)

  0%|          | 0/630 [00:00<?, ?it/s]

In [69]:
# Interactive heatmap of the pairwise results; `animation_frame=2` steps through
# the last axis of `d2h_results` (one frame per batch element).
d2h_frames = d2h_results.detach().cpu().numpy()
fig = px.imshow(
    d2h_frames,
    animation_frame=2,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    labels=dict(x="Component 2", y="Component 1", color="KLD Recovery"),
    title="Pairwise Component (Dishonest -> Honest)",
    height=700,
    width=800,
)

# Label both axes with component names instead of integer positions
axis_ticks = dict(tickvals=list(range(len(components))), ticktext=components)
fig.update_xaxes(**axis_ticks)
fig.update_yaxes(**axis_ticks)

# Set colorbar title
fig.update_layout(coloraxis_colorbar={"title": "KLD Recovery<br>T=1"})

fig.write_html(f"{ROOT}/figs/d2h-pairwise-components.html")
fig.show()

In [71]:
# Static heatmap of the pairwise results averaged over the last (batch) axis.
# No title is set on this figure, which is exported as a PDF.
d2h_mean = d2h_results.detach().cpu().numpy().mean(axis=2)
fig = px.imshow(
    d2h_mean,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    labels=dict(x="Component 2", y="Component 1", color="KLD Recovery"),
    height=700,
    width=800,
)

# Label both axes with component names instead of integer positions
axis_ticks = dict(tickvals=list(range(len(components))), ticktext=components)
fig.update_xaxes(**axis_ticks)
fig.update_yaxes(**axis_ticks)

# Colorbar title, background colour and font sizes in a single layout update
fig.update_layout(
    coloraxis_colorbar=dict(title="KLD Recovery<br>T=1"),
    plot_bgcolor='white',
    font=dict(family="sans-serif", size=16),
    title_font=dict(size=16),
    xaxis_title_font=dict(size=16),
    yaxis_title_font=dict(size=16),
)

fig.write_image(f'{ROOT}/images-for-paper/d2h-pairwise-components.pdf')
fig.show()

## Honest -> Dishonest

In [59]:
# Results grid indexed as (component 1, component 2, batch element).
# It starts at one; only the pairs visited by the patching loop get overwritten.
n_batch = honest_logits.size(0)
h2d_results = t.ones(
    (len(components), len(components), n_batch),
    device=device,
)
h2d_results.shape

torch.Size([36, 36, 20])

In [60]:
# Look up tensor indices in O(1) instead of rescanning `components` for every pair
component_position = {name: pos for pos, name in enumerate(components)}

progress_bar = tqdm(pairwise_comps)
for c1, c2 in progress_bar:
    # Get tensor indices
    c1_idx = component_position[c1]
    c2_idx = component_position[c2]

    # Add the injection hook
    model.reset_hooks()
    make_dishonest(model)

    # Add the patching hooks: both components of the pair receive their cached
    # activations from the honest run
    for c in (c1, c2):
        patch_hook_fnc = partial(
            patch_hook,
            patching_cache=honest_cache,
            cache_name=c,
        )
        model.add_hook(c, patch_hook_fnc)

    # Run model once and compute the metric once; it is reused for both the
    # results tensor and the progress bar
    patched_logits = model(PROMPTS)
    recovery = skld_recovery(patched_logits)
    h2d_results[c1_idx, c2_idx, :] = recovery

    # Show in progress bar
    progress_bar.set_description(f"KLD Recovery: {recovery.mean().item():.4f}")

# Leave the model clean: without this, the injection and the last pair's patch hooks
# would stay installed and silently affect any later forward pass.
model.reset_hooks()

  0%|          | 0/630 [00:00<?, ?it/s]

In [70]:
# Interactive heatmap of the honest -> dishonest pairwise results;
# `animation_frame=2` steps through the last axis of `h2d_results`
# (one frame per batch element). The colour scale is centred on 1, the value the
# results tensor was initialised with.
fig = px.imshow(
    h2d_results.detach().cpu().numpy(),
    animation_frame=2,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=1,
    labels={"x": "Component 2", "y": "Component 1", "color": "KLD Recovery"},
    # Title fixed: this figure shows the h2d results, not the d2h ones.
    title="Pairwise Component (Honest -> Dishonest)",
    height=700,
    width=800,
)

# Set x/y ticks
fig.update_xaxes(
    tickvals=list(range(len(components))),
    ticktext=components,
)
fig.update_yaxes(
    tickvals=list(range(len(components))),
    ticktext=components,
)

# Set colorbar title
fig.update_layout(coloraxis_colorbar=dict(title="KLD Recovery<br>T=1"))

fig.write_html(f"{ROOT}/figs/h2d-pairwise-components.html")
fig.show()

In [74]:
# Static heatmap of the honest -> dishonest results averaged over the last (batch)
# axis. No title is set on this figure, which is exported as a PDF.
h2d_mean = h2d_results.detach().cpu().numpy().mean(axis=2)
fig = px.imshow(
    h2d_mean,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=1,
    labels=dict(x="Component 2", y="Component 1", color="KLD Recovery"),
    height=700,
    width=800,
)

# Label both axes with component names instead of integer positions
axis_ticks = dict(tickvals=list(range(len(components))), ticktext=components)
fig.update_xaxes(**axis_ticks)
fig.update_yaxes(**axis_ticks)

# Colorbar title, background colour and font sizes in a single layout update
fig.update_layout(
    coloraxis_colorbar=dict(title="KLD Recovery<br>T=1"),
    plot_bgcolor='white',
    font=dict(family="sans-serif", size=16),
    title_font=dict(size=16),
    xaxis_title_font=dict(size=16),
    yaxis_title_font=dict(size=16),
)

fig.write_image(f'{ROOT}/images-for-paper/h2d-pairwise-components.pdf')
fig.show()