In [3]:
# Handy snippet to get repo root from anywhere in the repo.
# Runs `git rev-parse --show-toplevel` with an argument list (no shell needed)
# and appends the root to sys.path so project packages (e.g. `dishonesty`,
# imported below) resolve regardless of the notebook's working directory.
import sys
from subprocess import check_output
ROOT = check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
if ROOT not in sys.path:
    sys.path.append(ROOT)

## Imports

In [18]:
import torch as t
import numpy as np
import plotly.express as px

from functools import partial
from dishonesty.mistral_lens import load_model
from dishonesty.utils import calc_soft_kl_div
from dishonesty.prompts import PROMPTS


# Everything in this notebook is inference-only: switch autograd off globally
# so forward passes and generation don't build computation graphs.
t.set_grad_enabled(False)
if t.cuda.is_available():
    device = t.device("cuda:0")
else:
    device = t.device("cpu")
device

device(type='cuda', index=0)

In [2]:
# Load the model wrapper from dishonesty.mistral_lens. Later cells use its
# add_hook/reset_hooks, to_tokens, to_string_tokenized, hf_model.generate and
# tokenizer attributes. Presumably a Mistral-Instruct v0.1 checkpoint (matching
# the directions file loaded below) — TODO confirm in load_model.
model = load_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Per-layer steering ("honesty") directions; add_dishonesty_hooks indexes rows
# by layer number. map_location places the tensor directly on `device`, so a
# file saved from a GPU still loads on a CPU-only machine (without it, t.load
# tries to restore onto the original CUDA device and fails before .to() runs).
directions = t.load(
    f"{ROOT}/directions/honesty_mistral-instruct-v0.1.pt",
    map_location=device,
)
assert directions.ndim == 2, f"expected (n_layers, d_model) directions, got {tuple(directions.shape)}"
directions.shape

torch.Size([32, 4096])

In [13]:
def add_dishonesty_hooks(model, alpha_layer_pairs=((-8.25, 15),), steering_directions=None):
    """Attach hooks that add ``alpha * steering_directions[layer]`` to ``resid_post_{layer}``.

    Args:
        model: wrapper exposing ``add_hook(name, fn)``. The hooks stay attached
            until the caller runs ``model.reset_hooks()``.
        alpha_layer_pairs: iterable of ``(alpha, layer)``; one hook is added per
            pair. The default reproduces the original single injection
            (alpha=-8.25 at layer 15).
        steering_directions: tensor indexed by layer whose rows are added to the
            hooked output. Defaults to the notebook-level ``directions`` tensor,
            looked up when this function is called.
    """
    if steering_directions is None:
        steering_directions = directions

    def inject(module, input, output, alpha, direction):
        # Assumes `output` is a tuple/list whose first element is the hidden
        # states (HF decoder-layer convention) — TODO confirm in mistral_lens.
        hidden = output[0]
        # Out-of-place add: the layer's own output tensor is left untouched.
        # Casting back keeps the hidden-state dtype, as the original `+=` did,
        # and moving the direction handles models sharded across devices.
        steered = (hidden + alpha * direction.to(hidden.device)).to(hidden.dtype)
        return (steered, *output[1:])

    for alpha, layer in alpha_layer_pairs:
        model.add_hook(
            f"resid_post_{layer}",
            partial(inject, alpha=alpha, direction=steering_directions[layer]),
        )

In [14]:
# Greedy-generate from PROMPTS with and without the dishonesty injection, then
# score both sets of generated tokens under both the plain and the steered model.
max_new_tokens = 48
input_tokens, mask = model.to_tokens(PROMPTS, return_mask=True)


def _greedy_generate(input_ids, attention_mask, n_new_tokens):
    """Greedy-decode with whatever hooks are currently attached to `model`."""
    return model.hf_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=n_new_tokens,
        do_sample=False,
    )


# Generate tokens without dishonesty injection
model.reset_hooks()
output_tokens_honest = _greedy_generate(input_tokens, mask, max_new_tokens)

# Generate tokens WITH dishonesty injection. try/finally guarantees the hooks
# are removed even if generation fails (e.g. OOM), so no steering leaks out.
model.reset_hooks()
add_dishonesty_hooks(model)
try:
    output_tokens_dishonest = _greedy_generate(input_tokens, mask, max_new_tokens)
finally:
    model.reset_hooks()

# Get tokenized strings
dishonest_tokenized_strings = np.array(model.to_string_tokenized(output_tokens_dishonest))
honest_tokenized_strings = np.array(model.to_string_tokenized(output_tokens_honest))

# NOTE(review): the forward passes below pass no attention mask. If to_tokens
# pads the prompt batch, padded positions are attended to here but were masked
# during generate — confirm padding side / whether model() accepts a mask.

# Get honest logits
model.reset_hooks()
honest_logits_on_dishonest_tokens = model(output_tokens_dishonest)
honest_logits_on_honest_tokens = model(output_tokens_honest)

# Get dishonest logits; hooks are detached afterwards so later cells see the
# unmodified model instead of silently staying steered.
model.reset_hooks()
add_dishonesty_hooks(model)
try:
    dishonest_logits_on_dishonest_tokens = model(output_tokens_dishonest)
    dishonest_logits_on_honest_tokens = model(output_tokens_honest)
finally:
    model.reset_hooks()

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


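## Plotting

Per-token disagreement between the steered (dishonest) and unsteered (honest) model on the same token sequences, measured with `calc_soft_kl_div`.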

In [168]:
def plot_disagreement_heatmap(
    model,
    main_logits,
    alt_logits,
    main_text,
    start_pos=0,
    end_pos=None,
    temperature=0.0,
):
    """Heatmap of per-token disagreement between two sets of logits on the same tokens.

    Each cell (b, p) is coloured by ``calc_soft_kl_div(main_logits, alt_logits)``
    shifted right by one position, so the value computed at p-1 (which, for a
    causal LM, scores the prediction of token p) sits on token p. Each cell is
    labelled with ``main_text[b, p]``. Hovering shows the alt logits' argmax at
    p-1, i.e. the token the alt model would have predicted at p.

    Args:
        model: wrapper exposing ``tokenizer.decode``; used only for alt tokens.
        main_logits, alt_logits: logits from the two model variants on the same
            token sequences. Presumably (batch, seq, vocab); ``alt_logits.argmax(-1)``
            must match ``main_text``'s (batch, seq) shape.
        main_text: 2-D numpy array of token strings, shape (batch, seq).
        start_pos, end_pos: half-open range of positions to display. Defaults to
            the whole sequence; ``end_pos`` is clamped to the sequence length.
        temperature: forwarded to ``calc_soft_kl_div``.

    Returns:
        The plotly figure (neither shown nor saved here).
    """
    batch_size, seq_len = main_text.shape
    end_pos = seq_len if end_pos is None else min(end_pos, seq_len)
    window = slice(start_pos, end_pos)

    # Compute metrics. Position 0 has no preceding prediction, and after the
    # roll it would hold the wrapped-around value from the last position, so
    # neutralise it (mirrors the alt-token zeroing below).
    metrics = calc_soft_kl_div(main_logits, alt_logits, temperature=temperature)
    metrics = metrics.roll(1, dims=-1)
    metrics[:, 0] = 0

    # Alt text for hover. A single .tolist() avoids a device sync per element.
    # dtype=object is required: copying main_text would inherit its fixed-width
    # '<U…' dtype and silently truncate longer decoded alt tokens.
    argmaxes = alt_logits.argmax(dim=-1).roll(1, dims=-1)
    argmaxes[:, 0] = 0
    argmax_ids = argmaxes.tolist()
    alt_text = np.array(
        [[model.tokenizer.decode(tok) for tok in row[window]] for row in argmax_ids],
        dtype=object,
    )

    # Create the heatmap (.float() so bf16/fp16 metrics convert to numpy).
    fig = px.imshow(
        metrics[:, window].float().cpu().numpy(),
        color_continuous_scale="RdBu_r",
        color_continuous_midpoint=0,
        width=1700,
        height=600,
        aspect="auto",
    )
    fig.update_traces(
        text=main_text[:, window],
        texttemplate="%{text}",
        hovertext=alt_text,
        textfont_size=13,
        hovertemplate=(
            "Alt Token: |%{hovertext}|<br>"
            "Metric: %{z}<br>"
            "Batch: %{y}<br>"
            "Pos: %{x}<br>"
        ),
    )
    # Relabel x ticks with absolute token positions instead of window offsets.
    fig.update_layout(
        xaxis=dict(
            tickmode='array',
            tickvals=list(range(end_pos - start_pos)),
            ticktext=[str(i) for i in range(start_pos, end_pos)],
        ),
        xaxis_title="Token Position",
        yaxis_title="Batch",
    )

    return fig

In [191]:
# Compare the steered (dishonest) model against the unsteered (honest) model on
# the tokens the steered model generated. The window is presumably the start of
# the generated continuation — TODO confirm against the prompt lengths.
from pathlib import Path

HEATMAP_START_POS, HEATMAP_END_POS = 35, 64
HEATMAP_TEMPERATURE = 0.0

fig = plot_disagreement_heatmap(
    model,
    main_logits=dishonest_logits_on_dishonest_tokens,
    alt_logits=honest_logits_on_dishonest_tokens,
    main_text=dishonest_tokenized_strings,
    start_pos=HEATMAP_START_POS,
    end_pos=HEATMAP_END_POS,
    temperature=HEATMAP_TEMPERATURE,
)
# The title names the variables actually passed (these are logits, not logprobs).
fig.update_layout(
    title="Main: dishonest_logits_on_dishonest_tokens<br>Alt: honest_logits_on_dishonest_tokens",
    coloraxis_colorbar_title=f"Soft KL Div<br>(T={HEATMAP_TEMPERATURE:g})",
)

# Create figs/ on a fresh checkout instead of failing inside write_html.
heatmap_path = Path(ROOT) / "figs" / "disagreement-heatmap1.html"
heatmap_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_html(str(heatmap_path))
fig