# Explore Prompts

This is the notebook I use to test out the functions in this directory, and generate the plots in the Streamlit page.

## Setup

In [9]:
import os, sys
from pathlib import Path
p = Path(r"/home/ubuntu/SERI-MATS-2023-Streamlit-pages")
if os.path.exists(str_p := str(p.resolve())):
    os.chdir(str_p)
    if str_p not in sys.path:
        sys.path.append(str_p)

from transformer_lens.cautils.notebook import *
import gzip

from transformer_lens.rs.callum2.ov_qk_circuits.ov_qk_plot_functions import (
    plot_logit_lens,
    plot_full_matrix_histogram,
)
from transformer_lens.rs.callum2.generate_st_html.generate_html_funcs import (
    CSS,
    generate_4_html_plots,
    generate_4_html_plots_batched,
    generate_html_for_DLA_plot,
    generate_html_for_logit_plot,
    generate_html_for_loss_plot,
    generate_html_for_unembedding_components_plot,
    attn_filter,
    _get_color,
)
from transformer_lens.rs.callum2.generate_st_html.model_results import (
    get_model_results,
    HeadResults,
    LayerResults,
    DictOfHeadResults,
    ModelResults,
    first_occurrence,
    project,
    model_fwd_pass_from_resid_pre,
)
from transformer_lens.rs.callum2.generate_st_html.utils import (
    create_title_and_subtitles,
    parse_str,
    parse_str_tok_for_printing,
    parse_str_toks_for_printing,
    topk_of_Nd_tensor,
    ST_HTML_PATH,
)
from transformer_lens.rs.callum2.ioi_and_bos.ioi_functions import (
    get_effective_embedding_2,
)
from transformer_lens.rs.callum2.cspa.cspa_utils import (
    get_result_mean,
)
clear_output()

In [10]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cpu" # "cuda"
    # fold value bias?
)
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

clear_output()

In [11]:
W_EE_dict = get_effective_embedding_2(model)

## Getting model results

In [12]:
BATCH_SIZE = 48 # 51 for viz, 200 for my local one
SEQ_LEN = 61 # (61 for viz, no more because attn)
batch_idx = 36

NEGATIVE_HEADS = [(10, 7), (11, 10)]

def process_webtext(
    seed: int = 6,
    batch_size: int = BATCH_SIZE,
    indices: Optional[List[int]] = None,
    seq_len: int = SEQ_LEN,
    verbose: bool = False,
):
    DATA_STR = get_webtext(seed=seed)
    if indices is None:
        DATA_STR = DATA_STR[:batch_size]
    else:
        DATA_STR = [DATA_STR[i] for i in indices]
    DATA_STR = [parse_str(s) for s in DATA_STR]

    DATA_TOKS = model.to_tokens(DATA_STR)
    DATA_STR_TOKS = model.to_str_tokens(DATA_STR)

    if seq_len < 1024:
        DATA_TOKS = DATA_TOKS[:, :seq_len]
        DATA_STR_TOKS = [str_toks[:seq_len] for str_toks in DATA_STR_TOKS]

    DATA_STR_TOKS_PARSED = list(map(parse_str_toks_for_printing, DATA_STR_TOKS))

    clear_output()
    if verbose:
        print(f"Shape = {DATA_TOKS.shape}\n")
        print("First prompt:\n" + "".join(DATA_STR_TOKS[0]))

    return DATA_TOKS, DATA_STR_TOKS_PARSED


DATA_TOKS, DATA_STR_TOKS_PARSED = process_webtext(verbose=True) # indices=list(range(32, 40))
BATCH_SIZE, SEQ_LEN = DATA_TOKS.shape

Shape = torch.Size([48, 61])

First prompt:
<|endoftext|>Oh boy was this damn hard to crack.

Ok, I believe before it was established before that Aperture Science headquarters are in Cleveland, OH.

Source: HL2EP2

Though, this has been found.

Source: Portal 2

It can be assumed


## Test html in small case

In [14]:
model.reset_hooks(including_permanent=True)

prompt = "All's fair in love and war"
toks = model.to_tokens(prompt)
str_toks = model.to_str_tokens(toks)
if isinstance(str_toks[0], str): str_toks = [str_toks]
str_toks_parsed = [list(map(parse_str_tok_for_printing, s)) for s in str_toks]

result_mean = get_result_mean(NEGATIVE_HEADS, DATA_TOKS[:, :toks.shape[1]], model, minibatch_size=BATCH_SIZE)

MODEL_RESULTS = get_model_results(model, toks, NEGATIVE_HEADS, result_mean)

HTML_PLOTS = generate_4_html_plots(
    model_results = MODEL_RESULTS,
    model = model,
    data_toks = toks,
    data_str_toks_parsed = str_toks_parsed,
    negative_heads = NEGATIVE_HEADS,
    save_files = False,
)

for k, v in HTML_PLOTS.items():
    print(k)
    k2 = list(zip(*HTML_PLOTS["LOSS"].keys()))
    for j, _k2 in enumerate(k2):
        print(f"{(j)} = {sorted(set(_k2))}")

display(HTML(CSS + HTML_PLOTS["LOSS"][(0, "10.7", "direct+unfrozen+mean", True)] + "<br>" * 5))

LOSS
0 = [0]
1 = ['10.7', '11.10']
2 = ['both+frozen+mean', 'both+unfrozen+mean', 'direct+frozen+mean', 'direct+unfrozen+mean', 'indirect (excluding 11.10)+frozen+mean', 'indirect (excluding 11.10)+unfrozen+mean', 'indirect+frozen+mean', 'indirect+unfrozen+mean']
3 = [False, True]
LOGITS_ORIG
0 = [0]
1 = ['10.7', '11.10']
2 = ['both+frozen+mean', 'both+unfrozen+mean', 'direct+frozen+mean', 'direct+unfrozen+mean', 'indirect (excluding 11.10)+frozen+mean', 'indirect (excluding 11.10)+unfrozen+mean', 'indirect+frozen+mean', 'indirect+unfrozen+mean']
3 = [False, True]
LOGITS_ABLATED
0 = [0]
1 = ['10.7', '11.10']
2 = ['both+frozen+mean', 'both+unfrozen+mean', 'direct+frozen+mean', 'direct+unfrozen+mean', 'indirect (excluding 11.10)+frozen+mean', 'indirect (excluding 11.10)+unfrozen+mean', 'indirect+frozen+mean', 'indirect+unfrozen+mean']
3 = [False, True]
DLA
0 = [0]
1 = ['10.7', '11.10']
2 = ['both+frozen+mean', 'both+unfrozen+mean', 'direct+frozen+mean', 'direct+unfrozen+mean', 'indirect 

# Actual HTML plots

In [27]:
model.reset_hooks()
result_mean = get_result_mean(NEGATIVE_HEADS + [(10, 1)], DATA_TOKS, model, minibatch_size=8)

In [30]:
model.reset_hooks()
generate_4_html_plots_batched(
    model = model,
    data_toks = DATA_TOKS, # [:51],
    data_str_toks_parsed = DATA_STR_TOKS_PARSED,
    # max_batch_size = 8,
    start_idx = 0,
    negative_heads = NEGATIVE_HEADS + [(10, 1)],
    result_mean = result_mean,
    verbose = True,
)

Generating HTML plots...

Running forward pass     ... 0.95s
Computing model results  ... 5.27s
LOSS         ... 1.85s
LOGITS ORIG  ... 0.57s
LOGITS 10.7  ... 8.76s
LOGITS 11.10 ... 5.00s
LOGITS 10.1  ... 6.26s
ATTN         ... 3.58s
UNEMBEDDINGS ... 4.95s

Gathering HTML plots...
Saving HTML plots as a single dict, at 'GZIP_HTML_PLOTS_b48_s61.pkl'...
Deleting HTML plots we no longer need...


In [31]:
with gzip.open(ST_HTML_PATH / "GZIP_HTML_PLOTS_b48_s61.pkl", "rb") as f:
    HTML_PLOTS = pickle.load(f)

In [None]:
for k in HTML_PLOTS["LOGITS_ABLATED"].keys():
    print(k)

(0, '10.7', 'both+unfrozen+mean')
(1, '10.7', 'both+unfrozen+mean')
(2, '10.7', 'both+unfrozen+mean')
(3, '10.7', 'both+unfrozen+mean')
(4, '10.7', 'both+unfrozen+mean')
(5, '10.7', 'both+unfrozen+mean')
(6, '10.7', 'both+unfrozen+mean')
(7, '10.7', 'both+unfrozen+mean')
(8, '10.7', 'both+unfrozen+mean')
(9, '10.7', 'both+unfrozen+mean')
(10, '10.7', 'both+unfrozen+mean')
(11, '10.7', 'both+unfrozen+mean')
(12, '10.7', 'both+unfrozen+mean')
(13, '10.7', 'both+unfrozen+mean')
(14, '10.7', 'both+unfrozen+mean')
(15, '10.7', 'both+unfrozen+mean')
(16, '10.7', 'both+unfrozen+mean')
(17, '10.7', 'both+unfrozen+mean')
(18, '10.7', 'both+unfrozen+mean')
(19, '10.7', 'both+unfrozen+mean')
(20, '10.7', 'both+unfrozen+mean')
(21, '10.7', 'both+unfrozen+mean')
(22, '10.7', 'both+unfrozen+mean')
(23, '10.7', 'both+unfrozen+mean')
(24, '10.7', 'both+unfrozen+mean')
(25, '10.7', 'both+unfrozen+mean')
(26, '10.7', 'both+unfrozen+mean')
(27, '10.7', 'both+unfrozen+mean')
(28, '10.7', 'both+unfrozen+me

## Attention plots for paper diagram

## Histograms: logit lens

In [None]:
k = 15
neg = False
all_ranks = []


model.reset_hooks()
logits, cache = model.run_with_cache(DATA_TOKS_2)

In [None]:
# points_to_plot = [
#     (35, 39, " About"),
#     (67, 21, " delays"),
#     (8, 35, " rentals"),
#     (8, 54, " require"),
#     (53, 18, [" San", " Francisco"]),
#     (33, 9, " Hollywood"),
#     (49, 7, " Home"),
#     (71, 34, " sound"),
#     (14, 56, " Kara"),
# ]
# points_to_plot = [
#     (45, 42, [" editorial"]),
#     (45, 58, [" stadium", " Stadium", " stadiums"]),
#     (43, 56, [" Biden"]),
#     (43, 44, [" interview", " campaign"]),
#     (38, 54, [" Mary", " Catholics"]),
#     (33, 29, " Hollywood"),
#     (33, 42, " BlackBerry"),
#     (31, 33, [" Church", " churches"]),
#     (28, 53, [" mobile", " phone", " device"]),
#     (25, 32, [" abstraction", " abstract", " Abstract"]),
#     (18, 25, ["TPP", " Lee"]),
#     (10, 52, [" Italy", " mafia"]),
#     (10, 52, [" Italy", " mafia"]),
#     (10, 35, [" Italy", " mafia"]),
#     (10, 25, [" Italian", " Italy"]),
#     (6, 52, [" landfill", " waste"]),
#     (4, 52, " jury"),
# ]
points_to_plot = [
    # (14, 56, " Kara"),
    # (67, 47, " case"),
    # (24, 73, " negotiation"),
    (2, 35, [" Berkeley", "keley"]),
]

resid_pre_head = (cache["resid_pre", 10]) / cache["scale", 10, "ln1"]  #  - cache["resid_pre", 1]

plot_logit_lens(points_to_plot, resid_pre_head, model, DATA_STR_TOKS_PARSED_2, k=15, title="Predictions at token ' of', before head 10.7")

## Histograms: QK and OV circuit

In [None]:
def plot_both(dest, src, focus_on: Literal["src", "dest"]):
    plot_full_matrix_histogram(W_EE_dict, src, dest, model, k=15, circuit="OV", neg=True, head=(10, 7), flip=(focus_on=="dest"))
    plot_full_matrix_histogram(W_EE_dict, src, dest, model, k=15, circuit="QK", neg=False, head=(10, 7), flip=(focus_on=="src"))

plot_both(dest=" Berkeley", src="keley", focus_on="src")

In [None]:
def plot_both(dest, src, focus_on: Literal["src", "dest"]):
    plot_full_matrix_histogram(W_EE_dict, src, dest, model, k=15, circuit="OV", neg=True, head=(10, 7), flip=(focus_on=="dest"))
    plot_full_matrix_histogram(W_EE_dict, src, dest, model, k=15, circuit="QK", neg=False, head=(10, 7), flip=(focus_on=="src"))

plot_both(dest=" negotiation", src=" negotiations", focus_on="dest")

In [None]:
logprobs_orig = MODEL_RESULTS.logits_orig[32, 19].log_softmax(-1)
logprobs_abl = MODEL_RESULTS.logits[("direct", "frozen", "mean")][10, 7][32, 19].log_softmax(-1)

logprobs_orig_topk = logprobs_orig.topk(10, dim=-1, largest=True)
y_orig = logprobs_orig_topk.values.tolist()
x = logprobs_orig_topk.indices
y_abl = logprobs_abl[x].tolist()
x = list(map(repr, model.to_str_tokens(x)))

orig_colors = ["darkblue"] * len(x)
abl_colors = ["blue"] * len(x)

correct_next_str_tok = " heated"
correct_next_token = model.to_single_token(" heated")
# if repr(correct_next_str_tok) in x:
#     idx = x.index(repr(correct_next_str_tok))
#     orig_colors[idx] = "darkgreen"
#     abl_colors[idx] = "green"

x.append(repr(correct_next_str_tok))
y_orig.append(logprobs_orig[correct_next_token].item())
y_abl.append(logprobs_abl[correct_next_token].item())
orig_colors.append("darkgreen")
abl_colors.append("green")

fig = go.Figure(
    data = [
        go.Bar(x=x, y=y_orig, name='Original', marker_color=["#FF7700"] * (len(x)-1) + ["#024B7A"]), # 7A30AB
        go.Bar(x=x, y=y_abl, name='Ablated', marker_color=["#FFAE49"] * (len(x)-1) + ["#44A5C2"]), # D44BFA
    ],
    # data = [
    #     go.Bar(x=x[:-1], y=y_orig[:-1], name='Original', marker_color="#FF7700", legendgroup="group1"),
    #     go.Bar(x=x[:-1], y=y_abl[:-1], name='Ablated', marker_color="#FFAE49", legendgroup="group1"),
    #     go.Bar(x=[x[-1]], y=[y_orig[-1]], name='Original (correct token)', marker_color="#024B7A", legendgroup="group2"),
    #     go.Bar(x=[x[-1]], y=[y_abl[-1]], name='Ablated (correct token)', marker_color="#44A5C2", legendgroup="group2"),
    # ],
    layout = dict(
        barmode='group',
        xaxis_tickangle=30,
        title="Logprobs: original vs ablated",
        xaxis_title_text="Predicted next token",
        yaxis_title_text="Logprob",
        width=800,
        bargap=0.35,
    )
)
fig.data = fig.data #+ ({"name": "New"},)
fig.show()

In [None]:
plot_full_matrix_histogram(W_EE_dict, " device", k=10, include=[" devices"], circuit="OV", neg=True, head=(10, 7))
plot_full_matrix_histogram(W_EE_dict, " devices", k=10, include=[" device"], circuit="QK", neg=False, head=(10, 7))

In [None]:
W_EE = W_EE_dict["W_E (including MLPs)"]
W_EE = W_EE_dict["W_E (only MLPs)"]
W_U = W_EE_dict["W_U"].T

tok_strs = ["pier"]
for i in range(len(tok_strs)): tok_strs.append(tok_strs[i].capitalize())
for i in range(len(tok_strs)): tok_strs.append(tok_strs[i] + "s")
for i in range(len(tok_strs)): tok_strs.append(" " + tok_strs[i])
tok_strs = [s for s in tok_strs if model.to_tokens(s, prepend_bos=False).squeeze().ndim == 0]

toks = model.to_tokens(tok_strs, prepend_bos=False).squeeze()

W_EE_toks = W_EE[toks]
W_EE_normed = W_EE_toks / W_EE_toks.norm(dim=-1, keepdim=True)
cos_sim_embeddings = W_EE_normed @ W_EE_normed.T

W_U_toks = W_U.T[toks]
W_U_normed = W_U_toks / W_U_toks.norm(dim=-1, keepdim=True)
cos_sim_unembeddings = W_U_normed @ W_U_normed.T

W_EE_OV_toks_107 = W_EE_toks @ model.W_V[10, 7] @ model.W_O[10, 7]
W_EE_OV_toks_99 = W_EE_toks @ model.W_V[9, 9] @ model.W_O[9, 9]
W_EE_OV_toks_107_normed = W_EE_OV_toks_107 / W_EE_OV_toks_107.norm(dim=-1, keepdim=True)
W_EE_OV_toks_99_normed = W_EE_OV_toks_99 / W_EE_OV_toks_99.norm(dim=-1, keepdim=True)
cos_sim_107 = W_EE_OV_toks_107_normed @ W_EE_OV_toks_107_normed.T
cos_sim_99 = W_EE_OV_toks_99_normed @ W_EE_OV_toks_99_normed.T

imshow(
    t.stack([
        cos_sim_embeddings,
        cos_sim_unembeddings,
        cos_sim_107,
        cos_sim_99,
    ]),
    x = list(map(repr, tok_strs)),
    y = list(map(repr, tok_strs)),
    title = "Cosine similarity of variants of ' pier'",
    facet_col = 0,
    facet_labels = ["Effective embeddings", "Unembeddings", "W_OV output (10.7)", "W_OV output (9.9)"],
    border = True,
    width=1200,
)

# W_EE_OV_normed = W_EE_OV / W_EE_OV.std(dim=-1, keepdim=True)

## Create OV and QK circuits Streamlit page

I need to save the following things:

* The QK and OV matrices for head 10.7 and 11.10
* The extended embedding and unembedding matrices
* The tokenizer

In [None]:
dict_to_store = {
    "tokenizer": model.tokenizer,
    "W_V_107": model.W_V[10, 7],
    "W_O_107": model.W_O[10, 7],
    "W_V_1110": model.W_V[11, 10],
    "W_O_1110": model.W_O[11, 10],
    "W_Q_107": model.W_Q[10, 7],
    "W_K_107": model.W_K[10, 7],
    "W_Q_1110": model.W_Q[11, 10],
    "W_K_1110": model.W_K[11, 10],
    "b_Q_107": model.b_Q[10, 7],
    "b_K_107": model.b_K[10, 7],
    "b_Q_1110": model.b_Q[11, 10],
    "b_K_1110": model.b_K[11, 10],
    "W_EE": W_EE_dict["W_E (including MLPs)"],
    "W_U": model.W_U,
}
dict_to_store = {k: v.half() if isinstance(v, t.Tensor) else v for k, v in dict_to_store.items()}

with gzip.open(_ST_HTML_PATH / f"OV_QK_circuits.pkl", "wb") as f:
    pickle.dump(dict_to_store, f)

## Generate `explore_prompts` HTML plots for Streamlit page

In [None]:
HTML_PLOTS = generate_4_html_plots(
    model_results = MODEL_RESULTS,
    model = model,
    data_toks = DATA_TOKS,
    data_str_toks_parsed = DATA_STR_TOKS_PARSED,
    negative_heads = NEGATIVE_HEADS,
    save_files = True,
    progress_bar = True,
    restrict_computation = ["LOSS"]
)

In [None]:
(
    model.W_U.T[model.to_single_token(" pier")].norm().item(), 
    model.W_U.T[model.to_single_token(" Pier")].norm().item(),
)

In [None]:
t.cosine_similarity(
    model.W_U.T[model.to_single_token(" pier")],
    model.W_U.T[model.to_single_token(" Pier")],
    dim=-1
).item()


pier = model.W_U.T[model.to_single_token(" pier")]
Pier = model.W_U.T[model.to_single_token(" Pier")]
pier /= pier.norm()
Pier /= Pier.norm()
print(pier @ Pier)

In [None]:
def W_U(s):
    return model.W_U.T[model.to_single_token(s)]
def W_EE0(s):
    return W_EE_dict["W_E (only MLPs)"][model.to_single_token(s)]

def cos_sim(v1, v2):
    return v1 @ v2 / (v1.norm() * v2.norm())

print(f"Unembeddings cosine similarity (Berkeley) = {cos_sim(W_U('keley'), W_U(' Berkeley')):.3f}") 
print(f"Embeddings cosine similarity (Berkeley)   = {cos_sim(W_EE0('keley'), W_EE0(' Berkeley')):.3f}") 
print("")
print(f"Unembeddings cosine similarity (pier) = {cos_sim(W_U(' pier'), W_U(' Pier')):.3f}") 
print(f"Embeddings cosine similarity (pier)   = {cos_sim(W_EE0(' pier'), W_EE0(' Pier')):.3f}") 

In [None]:
t.cosine_similarity(
    W_EE0(" screen") - W_EE0(" screens"),
    W_EE0(" device") - W_EE0(" devices"),
    dim=-1
).item()

In [None]:
t.cosine_similarity(
    W_EE(" computer") - W_EE(" computers"),
    W_EE(" sign") - W_EE(" signs"),
    dim=-1
).item()