### Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm.auto as tqdm
from tqdm.auto import tqdm

from pii import decomp, utils

plt.style.use(
    [
        "tableau-colorblind10",
        utils.get_style("attrib"),
        utils.get_style("1col"),
    ]
)
COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

### Load model

In [3]:
# pick model to run on
# or meta-llama/Llama-2-7b-chat-hf or meta-llama/Llama-2-13b-chat-hf
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# pick a dataset to run on
# "counterfact" or "hqfact"
DATASET_NAME = "counterfact"

# You will need to login to huggingface first:
#   huggingface-cli login
if MODEL_NAME == "meta-llama/Llama-2-7b-chat-hf":
    tl_model = utils.get_llama2_7b_chat_tl_model()
    SAVED_NAME = "llama2_7b"
    n_devices = 1 
    batch_size = 5
elif MODEL_NAME == "meta-llama/Llama-2-13b-chat-hf":
    tl_model = utils.get_llama2_13b_chat_tl_model()
    SAVED_NAME = "llama2_13b"
    n_devices = 1
    batch_size = 1
else:
    raise ValueError("Unsupported model")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


In [4]:
with torch.no_grad():
    print(
        tl_model.generate(
            "The capital of Germany is", max_new_tokens=20, temperature=0
        )
    )

    _, tmp_cache = tl_model.run_with_cache("hi")
    LABELS = np.array(
        decomp.get_all_resid_components(
            tl_model=tl_model, cache=tmp_cache, pos=-1, batch_idx=0
        ).labels
    )
    print(LABELS[:5])

  0%|          | 0/20 [00:00<?, ?it/s]

The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions
['EMBED' 'L0H0ATN' 'L0H1ATN' 'L0H2ATN' 'L0H3ATN']


### Load dataset

In [5]:
inference_filename = f"{DATASET_NAME}_inference_{SAVED_NAME}.csv"
df_raw = pd.read_csv(
    utils.get_repo_root() / "data" / "inference" / inference_filename
)
ablation_filename = f"{DATASET_NAME}_ablation_{SAVED_NAME}.pkl"
df = pd.read_pickle(
    utils.get_repo_root() / "data" / "ablation" / ablation_filename
)
df.keys()

Index(['case_id', 'pararel_idx', 'relation_id', 'subject', 'target_new_str',
       'target_true_str', 'fact_prefix', 'irrelevant_word', 'prompt_c',
       'prompt_nc0', 'prompt_nc1', 'p_correct_c', 'p_correct_nc0',
       'p_correct_nc1', 'lo_correct_c', 'lo_correct_nc0', 'lo_correct_nc1',
       'log_bf0', 'log_bf1', 'lo_correct_c_nc0', 'lo_correct_c_nc1',
       'lo_correct_nc0_c', 'lo_correct_nc1_c', 'lo_correct_c_nc0_dcum',
       'lo_correct_c_nc1_dcum', 'lo_correct_nc0_c_dcum',
       'lo_correct_nc1_c_dcum', 'lo_correct_c_nc0_cum', 'lo_correct_c_nc1_cum',
       'lo_correct_nc0_c_cum', 'lo_correct_nc1_c_cum'],
      dtype='object')

In [6]:
log_bf_nc_c = np.concatenate(
    [
        np.stack(df.lo_correct_nc0_c) - df.lo_correct_nc0.to_numpy()[:, None],
        np.stack(df.lo_correct_nc1_c) - df.lo_correct_nc1.to_numpy()[:, None],
    ]
)
log_bf_c_nc = np.concatenate(
    [
        np.stack(df.lo_correct_c_nc0) - df.lo_correct_c.to_numpy()[:, None],
        np.stack(df.lo_correct_c_nc1) - df.lo_correct_c.to_numpy()[:, None],
    ]
)
print(log_bf_c_nc.shape, log_bf_nc_c.shape)

COMPONENT_ORDER = np.argsort(log_bf_nc_c.mean(axis=0))
LABELS_ORDERED = LABELS[COMPONENT_ORDER]
ATN_HEAD_LABEL_MASK = np.array([("ATN" in label) for label in LABELS_ORDERED])
print(LABELS_ORDERED[:10])

ATTN_HEAD_LOCS = [
    (int(label.split("H")[0][1:]), int(label.split("H")[1][:-3]))
    for label in LABELS_ORDERED[ATN_HEAD_LABEL_MASK]
]
print(ATTN_HEAD_LOCS[:5])

(7032, 1057) (7032, 1057)
['L18H9ATN' 'L27H29ATN' 'L31MLP' 'L26H9ATN' 'L25H12ATN' 'L25MLP'
 'L19H23ATN' 'L22H20ATN' 'L23H19ATN' 'L28H7ATN']
[(18, 9), (27, 29), (26, 9), (25, 12), (19, 23)]


### Compute attention response

In [7]:
def get_forbidden_token_range(
    prompt: str, forbidden_word: str
) -> tuple[int, int]:
    tokens = tl_model.to_tokens(prompt)[0]
    start = 47
    for end in range(start + 1, len(tokens) + 1):
        if tl_model.to_string(tokens[start:end]) == forbidden_word:
            return (start, end)
    raise ValueError("forbidden word not found")

metrics = []
pbar = tqdm(range(0, len(df), batch_size))
for idx_start in pbar:
    idx_end = min(idx_start + batch_size, len(df))

    prompts: list[str] = []
    forb_words: list[str] = []
    for idx in range(idx_start, idx_end):
        prompts.extend(
            [
                df.prompt_c[idx],
                df.prompt_nc0[idx],
                df.prompt_nc1[idx],
            ]
        )
        forb_words.extend(
            [
                df.target_true_str[idx],
                df.target_new_str[idx],
                df.irrelevant_word[idx],
            ]
        )

    with torch.no_grad():
        _, cache = tl_model.run_with_cache(prompts)

        forb_attns = torch.zeros(
            (len(prompts), tl_model.cfg.n_layers, tl_model.cfg.n_heads)
        )
        for i in range(len(prompts)):
            fstart, fend = get_forbidden_token_range(prompts[i], forb_words[i])
            prompt_len = len(tl_model.to_tokens(prompts[i])[0])
            for layer_num in range(tl_model.cfg.n_layers):
                forb_attns[i, layer_num, :] = cache["pattern", layer_num][
                    i, :, prompt_len - 1, fstart:fend
                ].sum(dim=-1)

    for i in range(idx_end - idx_start):
        metrics.append(
            dict(forb_attn=forb_attns[3 * i : 3 * i + 3].cpu().numpy())
        )

df = df.assign(**pd.DataFrame(metrics))

  0%|          | 0/704 [00:00<?, ?it/s]

### Compute OV response

In [8]:
fas = np.stack(df.forb_attn)
log10odds_fas = torch.tensor(fas).logit().cpu().numpy()
fas.shape

(3516, 3, 32, 32)

In [9]:
def get_ov_resp_matrix(layer: int, head: int):
    """Returns a matrix in units of log_e-prob."""
    W_O = tl_model.W_O[layer, head]
    W_V = tl_model.W_V[layer, head]

    with torch.no_grad():
        raw_logits = utils.unembed(
            tl_model.blocks[layer].ln1(tl_model.W_E) @ W_V @ W_O,
            tl_model=tl_model,
        ).double()

        logits = utils.logit_softmax(raw_logits)

        return logits


def plot_ov_resp_matrix(
    layer: int,
    head: int,
    top_tokens: int = 3,
    largest: bool = False,
):
    d = tl_model.cfg.d_vocab
    log10_resp_mat = get_ov_resp_matrix(layer=layer, head=head) / np.log(10)
    log10_resp_diag = log10_resp_mat.diag()

    top_tokens = torch.topk(
        log10_resp_diag, top_tokens, largest=largest
    ).indices

    suppression_score = (
        (log10_resp_diag.mean() - log10_resp_mat.mean()).item() * d / (d - 1)
    )
    get_sgn = lambda x: "+" if x > 0 else ""
    get_tok_score = lambda tok: log10_resp_diag[tok.item()]
    plt.text(
        0.5,
        0.9,
        "\n".join(
            [
                "OV Suppression",
                f"Dist., Mean: {get_sgn(suppression_score)}${suppression_score:.1f}$",
            ]
        ),
        horizontalalignment="center",
        verticalalignment="top",
        transform=plt.gca().transAxes,
    )

    plt.text(
        0.5,
        -0.35,
        "\n".join(
            [
                f"{tl_model.to_string(tok.item())}: {get_sgn(get_tok_score(tok))}${get_tok_score(tok):.3f}$"
                for tok in top_tokens
            ]
        ),
        horizontalalignment="center",
        verticalalignment="top",
        transform=plt.gca().transAxes,
    )

    utils.plot_hist_from_tensor(
        xs=(log10_resp_diag - log10_resp_mat).flatten(),
        anti_xs=torch.zeros_like(log10_resp_diag),
        bins=500,
        density=True,
        alpha=0.5,
        label="Off-diagonal",
        color="tab:orange",
    )

In [10]:
# PLOT_HEADS = [(18, 9, False), (25, 12, False), (28, 7, False), (14, 5, False), (29, 19, True)]
# PLOT_HEADS = [(18, 9, False), (25, 12, False), (28, 7, False), (29, 19, True)]
# PLOT_HEADS = [(18, 9, False), (31, 27, False), (19, 13, False), (29, 19, True)]
PLOT_HEADS = [ATTN_HEAD_LOCS[1] + (False,), ATTN_HEAD_LOCS[2] + (False,), ATTN_HEAD_LOCS[15] + (False,), (25, 5, True)]
fig, _ = plt.subplots(2, 4, height_ratios=[1, 1])
plt.rcParams["ytick.right"] = False

get_sgn = lambda x: "+" if x > 0 else ""
for i, (layer, head, _) in enumerate(PLOT_HEADS):
    plt.subplot(2, 4, i + 1)
    log10_bf = log_bf_nc_c.mean(axis=0)[
        list(LABELS).index(f"L{layer}H{head}ATN")
    ] / np.log(10)
    rank = LABELS_ORDERED.tolist().index(f"L{layer}H{head}ATN")
    plt.title(
        f"L{layer}H{head} (rank {rank})\nB.f. 1e{get_sgn(log10_bf)}${{{log10_bf:.4f}}}$",
        fontsize=9,
    )

    plt.hist(
        # log10odds_fas[:, 0, layer, head],
        fas[:, 0, layer, head],
        density=True,
        bins=32,
        alpha=0.7,
    )
    plt.xlim(0, 1)
    plt.ylim(0, 10)

    plt.text(
        0.5,
        0.9,
        "\n".join(
            [
                "Forbidden Word",
                "Attn. Dist.",
                f"Mean: ${fas[:, 0, layer, head].mean():.3f}$"
                if i != 1
                else f"Mean: ${fas[:, 0, layer, head].mean():.5f}$",
            ]
        ),
        horizontalalignment="center",
        verticalalignment="top",
        transform=plt.gca().transAxes,
    )

    plt.gca().yaxis.set_ticks_position("left")
    if i > 0:
        # Turn off y-axis labels
        plt.yticks([])


for i, (layer, head, largest) in enumerate(PLOT_HEADS):
    plt.subplot(2, 4, 4 + i + 1)
    plot_ov_resp_matrix(
        layer=layer,
        head=head,
        top_tokens=3,
        largest=largest,
    )
    plt.gca().yaxis.set_ticks_position("left")
    if i > 0:
        # Turn off y-axis labels
        plt.yticks([])
    plt.xlim(-10, 10)
    plt.ylim(0, 0.7)

fig.set_figheight(fig.get_figheight() * 0.85)
fig.get_layout_engine().set(wspace=0, w_pad=0)
fig.show()
plt.savefig(
    f"./plots/examples-of-heads-{DATASET_NAME}-{SAVED_NAME}.pdf",
    backend="pgf",
)

ValueError: Error measuring \rmfamily\fontsize{9.000000}{10.800000}\selectfont 』: +\(\displaystyle 1.834\)
LaTeX Output:

! Package inputenc Error: Unicode character 』 (U+300F)
(inputenc)                not set up for use with LaTeX.

See the inputenc package documentation for explanation.
Type  H <return>  for immediate help.
 ...                                              
                                                  
<*> ...}\selectfont 』: +\(\displaystyle 1.834\)}
                                                  \typeout{\the\wd0,\the\ht0...
!  ==> Fatal error occurred, no output PDF file produced!
Transcript written on texput.log.


RuntimeError: latex was not able to process the following string:
b'\\u300f: +$1.834$'

Here is the full command invocation and its output:

latex -interaction=nonstopmode --halt-on-error --output-directory=tmp8s_m8y00 623d26c295db7b564442386562046082.tex

This is pdfTeX, Version 3.14159265-2.6-1.40.20 (TeX Live 2019/Debian) (preloaded format=latex)
 restricted \write18 enabled.
entering extended mode
(./623d26c295db7b564442386562046082.tex
LaTeX2e <2020-02-02> patch level 2
L3 programming layer <2020-02-14>
(/usr/share/texlive/texmf-dist/tex/latex/base/article.cls
Document Class: article 2019/12/20 v1.4l Standard LaTeX document class
(/usr/share/texlive/texmf-dist/tex/latex/base/size10.clo))
(/usr/share/texlive/texmf-dist/tex/latex/psnfss/mathptmx.sty)
(/usr/share/texlive/texmf-dist/tex/latex/type1cm/type1cm.sty)
(/usr/share/texmf/tex/latex/cm-super/type1ec.sty
(/usr/share/texlive/texmf-dist/tex/latex/base/t1cmr.fd))
(/usr/share/texlive/texmf-dist/tex/latex/base/inputenc.sty)
(/usr/share/texlive/texmf-dist/tex/latex/geometry/geometry.sty
(/usr/share/texlive/texmf-dist/tex/latex/graphics/keyval.sty)
(/usr/share/texlive/texmf-dist/tex/generic/iftex/ifvtex.sty
(/usr/share/texlive/texmf-dist/tex/generic/iftex/iftex.sty)))
(/usr/share/texlive/texmf-dist/tex/latex/amsfonts/amsfonts.sty)
(/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsmath.sty
For additional information on amsmath, use the `?' option.
(/usr/share/texlive/texmf-dist/tex/latex/amsmath/amstext.sty
(/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsgen.sty))
(/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsbsy.sty)
(/usr/share/texlive/texmf-dist/tex/latex/amsmath/amsopn.sty))
(/usr/share/texlive/texmf-dist/tex/latex/amsfonts/amssymb.sty)
(/usr/share/texlive/texmf-dist/tex/latex/underscore/underscore.sty)
(/usr/share/texlive/texmf-dist/tex/latex/base/textcomp.sty)
(/usr/share/texlive/texmf-dist/tex/latex/l3backend/l3backend-dvips.def)
No file 623d26c295db7b564442386562046082.aux.
(/usr/share/texlive/texmf-dist/tex/latex/psnfss/ot1ptm.fd)
*geometry* driver: auto-detecting
*geometry* detected driver: dvips

! Package inputenc Error: Unicode character 』 (U+300F)
(inputenc)                not set up for use with LaTeX.

See the inputenc package documentation for explanation.
Type  H <return>  for immediate help.
 ...                                              
                                                  
l.30 {\rmfamily 』
                   : +$1.834$}%
No pages of output.
Transcript written on tmp8s_m8y00/623d26c295db7b564442386562046082.log.




<Figure size 1320.26x612 with 8 Axes>

In [11]:
for rnk, (layer, head) in enumerate(ATTN_HEAD_LOCS):
    mean = fas[:, 0, layer, head].mean()
    if mean < 0.01:
        print(mean, layer, head, rnk)

0.00037759764 31 27 9
0.009688213 19 8 15
0.0061805593 24 3 25
0.0019694369 22 19 30
0.0077421297 21 15 31
0.004939737 21 30 33
0.004951861 31 25 42
0.00011484621 28 13 43
0.0059667625 27 7 48
0.004477676 18 30 49
0.008398326 29 23 52
0.004720199 27 2 53
0.0075057643 26 4 54
0.0049899626 21 28 57
0.0002070088 14 4 59
0.009718123 26 25 61
0.0021839081 24 24 62
0.0014637487 21 1 65
0.007294553 29 18 72
0.0026687013 23 20 74
0.0013317465 30 31 75
0.002094164 14 27 77
0.007177482 28 23 79
0.0015485843 21 27 81
0.0065342095 24 11 82
0.00055433443 19 10 83
0.0015924216 21 16 84
0.0031756763 15 5 85
0.0050857686 15 1 88
0.0012327255 24 8 89
0.0045352764 22 22 90
0.0018620181 20 27 91
0.0017347903 14 12 92
0.009603068 27 12 94
0.00017988007 31 4 95
0.0058939015 30 11 96
0.0029781149 19 9 97
0.003847053 19 17 98
0.0032858178 30 30 100
0.0021042558 23 8 103
0.0076676053 14 25 104
0.0013416396 11 7 106
0.0046341955 14 3 108
0.0060289437 29 27 110
0.006058608 21 21 111
0.0064877006 16 25 113
0.004

In [14]:
metrics = []
for layer, head in tqdm(ATTN_HEAD_LOCS):
    attention_score = mean = fas[:, 0, layer, head].mean()

    metrics.append(
        {
            "layer": layer,
            "head": head,
            "attention_score": attention_score,
        }
    )

  0%|          | 0/1024 [00:00<?, ?it/s]

In [15]:
attn = pd.DataFrame(metrics)

In [20]:
%%capture attention --no-stderr

for top in [10, 30, 35]:
    print("Attention score")
    print(f"mean (top {top})", attn.attention_score[:top].mean())
    print(f"std (top {top})", attn.attention_score[:top].std())
    print(f"mean (other):", attn.attention_score[top:].mean())
    print(f"std (other):", attn.attention_score[top:].std())

In [22]:
print(attention)
with open(
    utils.get_repo_root() / f"data/attention-stats/{DATASET_NAME}-{SAVED_NAME}.txt",
    "w",
) as f:
    f.write(str(attention))

Attention score
mean (top 10) 0.1964684
std (top 10) 0.16593757
mean (other): 0.008610707
std (other): 0.017804898
Attention score
mean (top 30) 0.11524627
std (top 30) 0.118433654
mean (other): 0.0072822454
std (other): 0.013168686
Attention score
mean (top 35) 0.10223642
std (top 35) 0.11447828
mean (other): 0.007196831
std (other): 0.013020525



In [8]:
metrics = []
for layer, head in tqdm(ATTN_HEAD_LOCS):
    d = tl_model.cfg.d_vocab

    log10_resp_mat = get_ov_resp_matrix(layer=layer, head=head) / np.log(10)
    log10_resp_diag = log10_resp_mat.diag()
    suppression_score = (
        (log10_resp_diag.mean() - log10_resp_mat.mean()).item() * d / (d - 1)
    )

    metrics.append(
        {
            "layer": layer,
            "head": head,
            "suppression_score": suppression_score,
        }
    )

  0%|          | 0/1600 [00:00<?, ?it/s]

In [9]:
dfm = pd.DataFrame(metrics)

In [10]:
print("Suppresion score")
print("mean (top 10)", dfm.suppression_score[:10].mean())
print("mean (other):", dfm.suppression_score[10:].mean())
print("std (other):", dfm.suppression_score[10:].std())

Suppresion score
mean (top 10) -1.1243266135425147
mean (other): 0.10071945240942179
std (other): 0.3966147850028764


In [11]:
%%capture cap --no-stderr

for top in [10, 30]:
    print("Suppresion score")
    print(f"mean (top {top})", dfm.suppression_score[:top].mean())
    print(f"std (top {top})", dfm.suppression_score[:top].std())
    print(f"mean (other):", dfm.suppression_score[top:].mean())
    print(f"std (other):", dfm.suppression_score[top:].std())

In [12]:
print(cap)
with open(
    utils.get_repo_root() / f"data/head-stats/{DATASET_NAME}-{SAVED_NAME}.txt",
    "w",
) as f:
    f.write(str(cap))

Suppresion score
mean (top 10) -1.1243266135425147
std (top 10) 0.7527228180758994
mean (other): 0.10071945240942179
std (other): 0.3966147850028764
Suppresion score
mean (top 30) -0.8762455675595826
std (top 30) 0.6296764751814362
mean (other): 0.11158473262569618
std (other): 0.3826579923926396

