# Explore Prompts

This is the notebook I use to test the functions in this directory and to generate the plots shown on the Streamlit page.

In [1]:
from transformer_lens.cautils.notebook import *

import os
import sys
from pathlib import Path

# Locate the repo's `rs` directory by matching a whole path component.
# The previous `os.getcwd().split("rs/")[0]` matched the substring "rs/" anywhere,
# so an earlier directory whose name ends in "rs" (e.g. "/Users/..." or
# "/home/ubuntu/users/...") produced a wrong root, as did a cwd of exactly ".../rs".
_cwd_parts = Path.cwd().parts
if "rs" not in _cwd_parts:
    raise FileNotFoundError(
        f"Run this notebook from inside the repo's `rs/` directory (current cwd: {Path.cwd()})"
    )
root_dir = str(Path(*_cwd_parts[: _cwd_parts.index("rs") + 1]) / "callum2" / "explore_prompts")

# Work from the explore_prompts directory and make its modules importable.
os.chdir(root_dir)
if root_dir not in sys.path:
    sys.path.append(root_dir)

from generate_html import CSS, generate_4_html_plots
from model_results import get_model_results
from explore_prompts_utils import parse_str, parse_str_tok_for_printing, ST_HTML_PATH
clear_output()

In [2]:
# Load GPT-2 small on CPU with TransformerLens weight processing enabled
# (LayerNorm folded into the weights; writing weights and unembedding centred).
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device="cpu",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    # refactor_factored_attn_matrices=True,
)
# Use a shared (not per-head) QKV input, but record each attention head's
# output separately so individual heads can be inspected.
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

clear_output()

In [21]:
BATCH_SIZE = 40
SEQ_LEN = 60  # tokens kept per prompt (GPT-2 small's full context is 1024)

DATA_STR = [parse_str(s) for s in get_webtext(seed=6)[:BATCH_SIZE]]

# Always truncate, and cut the string tokens to the tensor's actual width.
# The old `if SEQ_LEN < 1024:` guard skipped truncation of DATA_STR_TOKS when
# SEQ_LEN == 1024. If `to_str_tokens` does not itself truncate, the string
# tokens could then be longer than DATA_TOKS.
DATA_TOKS = model.to_tokens(DATA_STR)[:, :SEQ_LEN]
DATA_STR_TOKS = [str_toks[: DATA_TOKS.shape[1]] for str_toks in model.to_str_tokens(DATA_STR)]

# DATA_TOKS is a rectangular tensor, so a document shorter than its width would
# give a shorter string-token list and silently misalign the per-token plots.
assert all(len(str_toks) == DATA_TOKS.shape[1] for str_toks in DATA_STR_TOKS), (
    "some document is shorter than SEQ_LEN tokens; its string tokens would not line up with DATA_TOKS"
)

DATA_STR_TOKS_PARSED = [[parse_str_tok_for_printing(str_tok) for str_tok in str_toks] for str_toks in DATA_STR_TOKS]

# (layer, head) pairs analysed throughout the notebook.
NEGATIVE_HEADS = [(10, 7), (11, 10)]

print(DATA_TOKS.shape, "\n")

print(DATA_STR_TOKS[0])

# Example prompt index; not referenced by the cells shown here (TODO confirm where it is used).
batch_idx = 36

Found cached dataset openwebtext-10k (/home/ubuntu/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([40, 60]) 

['<|endoftext|>', 'Oh', ' boy', ' was', ' this', ' damn', ' hard', ' to', ' crack', '.', '\n', '\n', 'Ok', ',', ' I', ' believe', ' before', ' it', ' was', ' established', ' before', ' that', ' A', 'perture', ' Science', ' headquarters', ' are', ' in', ' Cleveland', ',', ' OH', '.', '\n', '\n', 'Source', ':', ' HL', '2', 'EP', '2', '\n', '\n', 'Though', ',', ' this', ' has', ' been', ' found', '.', '\n', '\n', 'Source', ':', ' Portal', ' 2', '\n', '\n', 'It', ' can', ' be']


In [14]:
# Quick end-to-end check of the pipeline on a single short prompt.
prompt = "All's fair in love and war"
toks = model.to_tokens(prompt)
str_toks = model.to_str_tokens(toks)
# A single prompt may come back as a flat list of strings; wrap it so it has
# the same list-of-lists (batch) shape as the multi-prompt data.
if isinstance(str_toks[0], str):
    str_toks = [str_toks]
str_toks_parsed = [[parse_str_tok_for_printing(tok) for tok in seq] for seq in str_toks]

MODEL_RESULTS = get_model_results(model, toks, NEGATIVE_HEADS)

HTML_PLOTS = generate_4_html_plots(
    model = model,
    model_results = MODEL_RESULTS,
    negative_heads = NEGATIVE_HEADS,
    data_toks = toks,
    data_str_toks_parsed = str_toks_parsed,
    save_files = False,
)

# List the key of every generated plot, grouped by plot type.
for plot_type, plots in HTML_PLOTS.items():
    print(plot_type, *(f"-> {key}" for key in plots), sep="\n")

# Render one of the LOSS plots; the trailing <br>s add blank space below it.
display(HTML(CSS + HTML_PLOTS["LOSS"][(0, "10.7", "mean, direct")] + "<br>" * 10))

100%|██████████| 1/1 [00:00<00:00, 38.04it/s]


100%|██████████| 1/1 [00:00<00:00, 13.71it/s]
100%|██████████| 1/1 [00:00<00:00, 89.67it/s]

LOSS
-> (0, '10.7', 'mean, direct')
-> (0, '10.7', 'zero, direct')
-> (0, '10.7', 'mean, patched')
-> (0, '10.7', 'zero, patched')
-> (0, '11.10', 'mean, direct')
-> (0, '11.10', 'zero, direct')
-> (0, '11.10', 'mean, patched')
-> (0, '11.10', 'zero, patched')
LOGITS_ORIG
-> (0,)
LOGITS_ABLATED
-> (0, '10.7', 'mean, direct')
-> (0, '10.7', 'zero, direct')
-> (0, '10.7', 'mean, patched')
-> (0, '10.7', 'zero, patched')
-> (0, '11.10', 'mean, direct')
-> (0, '11.10', 'zero, direct')
-> (0, '11.10', 'mean, patched')
-> (0, '11.10', 'zero, patched')
DLA
-> (0, '10.7', 'neg')
-> (0, '10.7', 'pos')
-> (0, '11.10', 'neg')
-> (0, '11.10', 'pos')
ATTN
-> (0, '10.7', 'Large', 'standard')
-> (0, '10.7', 'Large', 'info-weighted')
-> (0, '10.7', 'Small', 'standard')
-> (0, '10.7', 'Small', 'info-weighted')
-> (0, '11.10', 'Large', 'standard')
-> (0, '11.10', 'Large', 'info-weighted')
-> (0, '11.10', 'Small', 'standard')
-> (0, '11.10', 'Small', 'info-weighted')
UNEMBEDDINGS
-> (0, '10.7', True)
-> 




In [15]:
# Inspect the raw HTML string behind one LOSS plot.
HTML_PLOTS["LOSS"][0, "10.7", "mean, direct"]

'<td>&nbsp;<span class="tooltip"><mark style="background-color:rgb(247, 247, 247);opacity:1.0;line-height:1.75em"><font color="black"><|endoftext|></font></mark><span class="tooltiptext"><b>\'<|endoftext|>\'</b><br>0.0063</span></span>&nbsp;<span class="tooltip"><mark style="background-color:rgb(247, 247, 246);opacity:1.0;line-height:1.75em"><font color="black">All</font></mark><span class="tooltiptext"><b>\'All\'</b><br>-0.0087</span></span>&nbsp;<span class="tooltip"><mark style="background-color:rgb(244, 246, 246);opacity:1.0;line-height:1.75em"><font color="black">\'s</font></mark><span class="tooltiptext"><b>\'\'s\'</b><br>0.0364</span></span>&nbsp;<span class="tooltip"><mark style="background-color:rgb(247, 247, 247);opacity:1.0;line-height:1.75em"><font color="black">&nbsp;fair</font></mark><span class="tooltiptext"><b>\' fair\'</b><br>0.0039</span></span>&nbsp;<span class="tooltip"><mark style="background-color:rgb(226, 237, 243);opacity:1.0;line-height:1.75em"><font color="bla

In [22]:
# Run the model over the whole webtext batch, collecting results for NEGATIVE_HEADS.
MODEL_RESULTS = get_model_results(
    model,
    DATA_TOKS,
    negative_heads=NEGATIVE_HEADS,
)

In [23]:
# Build the HTML plots for every prompt in the batch. save_files=True presumably
# writes them to disk for the Streamlit page (confirm in generate_html).
HTML_PLOTS = generate_4_html_plots(
    model = model,
    model_results = MODEL_RESULTS,
    negative_heads = NEGATIVE_HEADS,
    data_toks = DATA_TOKS,
    data_str_toks_parsed = DATA_STR_TOKS_PARSED,
    save_files = True,
)

  2%|▎         | 1/40 [00:00<00:06,  5.60it/s]

100%|██████████| 40/40 [00:06<00:00,  5.75it/s]
100%|██████████| 40/40 [00:23<00:00,  1.68it/s]
100%|██████████| 40/40 [00:01<00:00, 32.56it/s]


In [16]:
# Reload previously saved plots instead of regenerating them.
# NOTE: this replaces any HTML_PLOTS built by earlier cells, and resets BATCH_SIZE
# to the number of prompts in the pickle (one LOGITS_ORIG entry per prompt).
# pickle.load can execute arbitrary code, so only load files you generated yourself.
with open(ST_HTML_PATH / "HTML_PLOTS.pkl", "rb") as pkl_file:
    HTML_PLOTS = pickle.load(pkl_file)
BATCH_SIZE = len(HTML_PLOTS["LOGITS_ORIG"])

In [25]:
import re

# Decimal literals starting with "0.", delimited by word boundaries on both sides.
_DECIMAL_PATTERN = re.compile(r'\b0\.\d+\b')


def round_decimals(input_string, decimals=4):
    """Reformat every ``0.xxx`` decimal literal in a string to a fixed precision.

    This shrinks HTML that embeds long float literals (e.g. ``0.123456789``
    becomes ``0.1235``). Because of the word boundaries in the pattern, numbers
    that do not start with ``0.`` (``10.5``, ``1.75``) and numbers immediately
    followed by letters (``0.75em``) are left untouched.

    Args:
        input_string: Text (typically HTML) to rewrite.
        decimals: Digits kept after the decimal point. Defaults to 4, the
            previously hard-coded precision.

    Returns:
        A new string with every matching number reformatted.
    """
    return _DECIMAL_PATTERN.sub(lambda m: f"{float(m.group()):.{decimals}f}", input_string)

# Replace each ATTN plot's HTML with a copy whose 0.xxx numbers are rounded,
# keeping the original key order.
HTML_PLOTS["ATTN"] = dict(
    (plot_key, round_decimals(plot_html))
    for plot_key, plot_html in HTML_PLOTS["ATTN"].items()
)

In [26]:
# Save the (rounded) plots under a new file name. The `with` block closes the
# handle so the pickle is fully flushed to disk before anything reads it.
with open(ST_HTML_PATH / "HTML_PLOTS_4.pkl", "wb") as pkl_file:
    pickle.dump(HTML_PLOTS, pkl_file)