# Explore Prompts

This is the notebook I use to test the functions in this directory and to generate the plots for the Streamlit page.

## Setup

In [None]:
import os, sys
from pathlib import Path
# Repo root on the machine this was developed on. If that directory exists, make it the working
# directory and add it to sys.path (presumably so the `transformer_lens.rs...` imports below resolve);
# on any other machine this block does nothing.
p = Path(r"/root/SERI-MATS-2023-Streamlit-pages")
if os.path.exists(str_p := str(p.resolve())):
    os.chdir(str_p)
    if str_p not in sys.path:
        sys.path.append(str_p)

from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum2.ov_qk_circuits.ov_qk_plot_functions import (
    plot_logit_lens,
    plot_full_matrix_histogram,
)
from transformer_lens.rs.callum2.blog_viz.generate_html_funcs import (
    CSS,
    generate_4_html_plots,
    generate_4_html_plots_batched,
    generate_html_for_DLA_plot,
    generate_html_for_logit_plot,
    generate_html_for_loss_plot,
    generate_html_for_unembedding_components_plot,
    attn_filter,
    _get_color,
)
from transformer_lens.rs.callum2.blog_viz.model_results import (
    get_result_mean,
    get_model_results,
    HeadResults,
    LayerResults,
    DictOfHeadResults,
    ModelResults,
    first_occurrence,
    project,
    model_fwd_pass_from_resid_pre,
)
from transformer_lens.rs.callum2.utils import (
    create_title_and_subtitles,
    get_effective_embedding,
    parse_str,
    parse_str_tok_for_printing,
    parse_str_toks_for_printing,
    topk_of_Nd_tensor,
    ST_HTML_PATH,
    NEGATIVE_HEADS,
    process_webtext,
    rearrange_list,
    clamp,
    first_occurrence_2d,
)
from transformer_lens.rs.callum2.cspa.cspa_functions import (
    FUNCTION_STR_TOKS,
)

clear_output()  # clear whatever this cell printed, so the notebook output stays readable

In [None]:
from eindex import eindex
import torch

# Sanity check: three equivalent ways of picking out logprobs[b, s, labels[b, s]].
# The sizes deliberately don't reuse the names BATCH_SIZE / SEQ_LEN: those are set for the real
# dataset in "Getting model results", and re-running this cell afterwards would silently clobber them.
DEMO_BATCH_SIZE = 32
DEMO_SEQ_LEN = 5
DEMO_D_VOCAB = 100

# Local generator => reproducible inputs without touching the global torch seed.
demo_gen = torch.Generator().manual_seed(0)
demo_logprobs = torch.randn(DEMO_BATCH_SIZE, DEMO_SEQ_LEN, DEMO_D_VOCAB, generator=demo_gen).log_softmax(-1)
demo_labels = torch.randint(0, DEMO_D_VOCAB, (DEMO_BATCH_SIZE, DEMO_SEQ_LEN), generator=demo_gen)

# (1) Using eindex
output_1 = eindex(demo_logprobs, demo_labels, "batch seq [batch seq]")

# (2) Normal PyTorch, using `gather`
output_2 = demo_logprobs.gather(2, demo_labels.unsqueeze(-1)).squeeze(-1)

# (3) Normal PyTorch, not using `gather` (this is like what `eindex` does under the hood):
# a [batch, 1] index broadcast against [seq] and [batch, seq] picks one element per (b, s).
output_3 = demo_logprobs[torch.arange(DEMO_BATCH_SIZE).unsqueeze(-1), torch.arange(DEMO_SEQ_LEN), demo_labels]

# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)


In [None]:
# Load GPT-2 small with centred unembedding / writing weights and LayerNorm folded into the weights.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cpu" # "cuda"
    # fold value bias?
)
# Don't split the QKV input per head, but do expose each head's individual output (`result`).
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

# Dict of embedding variants, keyed by name (this notebook reads "W_E (including MLPs)",
# "W_E (only MLPs)" and "W_U" from it).
W_EE_dict = get_effective_embedding(model)

# Token ids of the function words, used later to filter them out of in-context predictions.
FUNCTION_TOKS = model.to_tokens(FUNCTION_STR_TOKS, prepend_bos=False).squeeze()

clear_output()

## Getting model results

In [None]:
# Dataset size used for everything below.
# NOTE(review): the trailing notes refer to values from earlier runs (51 / 200 / 61); the values
# actually in effect are 41 sequences of 101 tokens.
BATCH_SIZE = 41 # 51 for viz, 200 for my local one
SEQ_LEN = 101 # (61 for viz, no more because attn)
batch_idx = 36  # NOTE(review): not referenced elsewhere in this notebook

# Webtext sample (fixed seed) as token ids, plus the per-token strings parsed for printing.
DATA_TOKS, DATA_STR_TOKS_PARSED = process_webtext(seed=6, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, verbose=True, model=model)

## Test the HTML plots on a small example

In [None]:
# Quick check of the dataset shape (presumably (BATCH_SIZE, SEQ_LEN)).
DATA_TOKS.shape

In [None]:
# Run the whole pipeline on one short prompt, so the generated HTML can be checked quickly.
model.reset_hooks(including_permanent=True)

prompt = "All's fair in love and war"
toks = model.to_tokens(prompt)
str_toks = model.to_str_tokens(toks)
# For a single prompt to_str_tokens can return a flat list; wrap it so we always have a batch of lists.
if isinstance(str_toks[0], str): str_toks = [str_toks]
str_toks_parsed = [list(map(parse_str_tok_for_printing, s)) for s in str_toks]

# Mean head results over the dataset, truncated to the prompt's length (presumably what the
# "mean" in ablation keys like "direct+frozen+mean" ablates to).
result_mean = get_result_mean(
    head_list = NEGATIVE_HEADS,
    toks = DATA_TOKS[:, :toks.shape[1]],
    model = model,
    minibatch_size = BATCH_SIZE
)

MODEL_RESULTS = get_model_results(
    model = model,
    toks = toks,
    negative_heads = NEGATIVE_HEADS,
    result_mean = result_mean,
)

# Maps plot type (e.g. "LOSS", "LOGITS_ABLATED") -> {key tuple -> HTML string}; nothing is saved to disk.
HTML_PLOTS: Dict[str, Dict[Tuple, str]] = generate_4_html_plots(
    model_results = MODEL_RESULTS,
    model = model,
    data_toks = toks,
    data_str_toks_parsed = str_toks_parsed,
    negative_heads = NEGATIVE_HEADS,
    save_files = False,
    result_mean = result_mean,
    verbose = True,
)

# For each plot type, print the distinct values found at each position of its key tuples
# (e.g. LOSS keys look like (0, "10.7", "direct+frozen+mean", True)), i.e. what can be indexed.
for plot_name, plots in HTML_PLOTS.items():
    print(plot_name)
    # Transpose this plot type's key tuples: key_fields[j] = every key's j-th element.
    key_fields = list(zip(*plots.keys()))
    for j, field_values in enumerate(key_fields):
        print(f"{j} = {sorted(set(field_values))}")

# Render the loss plot and the ablated-logits plot for head 10.7 on sequence 0, with the shared CSS.
display(HTML("".join([
    CSS,
    HTML_PLOTS["LOSS"][(0, "10.7", "direct+frozen+mean", True)],
    "<br>" * 2,
    # HTML_PLOTS["LOGITS_ORIG"][(0,)],
    # "<br>" * 2,
    HTML_PLOTS["LOGITS_ABLATED"][(0, "10.7", "direct+frozen+mean")],
    "<br>" * 21,  # NOTE(review): 21 vs 2 above — padding at the bottom, or a typo? confirm
])))

## Actual HTML plots (full dataset)

In [None]:
# Full-dataset results, for head 10.7 only (these feed `save_loss` below).
model.reset_hooks()
# Presumably the mean head outputs used for mean-ablation — see `get_result_mean`.
result_mean = get_result_mean(NEGATIVE_HEADS, DATA_TOKS, model, minibatch_size=8)

model_results = get_model_results(
    model = model,
    toks = DATA_TOKS,
    negative_heads = [(10, 7)],
    result_mean = result_mean,
)

In [None]:
import json
# Output location for `save_loss` (data_{i}.json and attn_{i}.html files).
# NOTE(review): absolute path under /home/ubuntu, while the setup cell uses /root — adjust per machine;
# the directory is not created here.
BLOG_PATH = Path("/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/blog_viz")
DATA_PATH = BLOG_PATH / "data"

In [None]:
def get_top_logprobs_in_context(
    logprobs: Float[Tensor, "batch seqQ d_vocab"],
    toks: Int[Tensor, "batch seqQ"],
    function_toks: Int[Tensor, "toks"],
    k: int = 5,
):
    '''
    For each (batch, query position), returns the top-k logprobs assigned to tokens which appear
    in context, i.e. at a source position <= the query position in the same sequence.

    Args:
        logprobs: logprobs over the vocabulary at every (batch, query) position.
        toks: token ids of the sequences; toks[b, j] is the source token at position j.
        function_toks: token ids of function words, which are filtered out.
        k: number of source tokens to keep per query position (capped at the sequence length).

    Returns:
        The result of `topk` over the seqK dimension: `.indices` are seqK positions (index `toks`
        with them to get token ids) and `.values` are the logprobs, both of shape
        [batch, seqQ, min(k, seqQ)]. Each token only counts at its first occurrence, so the kept
        words are distinct. Source positions that are masked out (after the query, a repeat of an
        earlier token, or a function word) have value -inf; slots with value -inf should be ignored.
    '''
    b, seq, _ = logprobs.shape
    device = logprobs.device
    toks = toks.to(device)

    # logprobs_ctx[b, q, j] = logprobs[b, q, toks[b, j]]: the logprob (predicted at query
    # position q) of the token which sits at source position j.
    src_toks = toks.unsqueeze(1).expand(b, seq, seq)
    logprobs_ctx: Float[Tensor, "batch seqQ seqK"] = logprobs.gather(-1, src_toks)

    # Mask: causal (query q can only see source positions j <= q). Shape [seqQ, seqK].
    positions = t.arange(seq, device=device)
    causal_mask = positions[:, None] >= positions[None, :]
    # Mask: first occurrence of each word (because we want the top k DISTINCT words). Shape [batch, 1, seqK].
    first_occurrence_mask = first_occurrence_2d(toks).to(device)[:, None, :]
    # Mask: non-function words. Shape [batch, 1, seqK].
    non_fn_word_mask = (toks[:, :, None] != function_toks.to(device)).all(dim=-1)[:, None, :]
    # Apply all 3 masks (they broadcast to [batch, seqQ, seqK]).
    keep = causal_mask & first_occurrence_mask & non_fn_word_mask
    logprobs_masked = logprobs_ctx.masked_fill(~keep, -float("inf"))

    # Now pick the top k (over the seqK-dimension) for each query index.
    return logprobs_masked.topk(min(k, seq), dim=-1)

In [None]:
# %pip install git+https://github.com/callummcdougall/eindex.git
from eindex import eindex

# %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
import circuitsvis as cv

In [None]:
# info_norms = model_results.out_norm[10, 7][[36]]
# info_norms_rescaled = info_norms / einops.reduce(info_norms, "batch seqK -> batch 1", "max")
# info_norms_rescaled = einops.repeat(info_norms_rescaled, "batch max_over_seqK -> batch 1 max_over_seqK")

# attn = model_results.pattern[10, 7][[36]]
# attn_info_weighted = attn * info_norms_rescaled

# viz_standard = cv.attention.attention_patterns(
#     attention = attn_info_weighted[:50],
#     tokens = DATA_STR_TOKS_PARSED[36][:50],
#     attention_head_names = "10.7",
# )
# # viz_attn_weighted = 
# display(viz_standard)

In [75]:
# JSON keys written by `save_loss`, grouped by the section that fills them. The order here is the
# key order in each data_{i}.json. "ATTN" is absent: it is saved as HTML, not JSON.
_SAVE_LOSS_JSON_KEYS = {
    "LOSS": ["loss_values", "words", "loss_colors", "loss_text_colors"],
    "DLA": [
        "dla_pos_words", "dla_pos_values", "dla_pos_colors", "dla_pos_text_colors",
        "dla_neg_words", "dla_neg_values", "dla_neg_colors", "dla_neg_text_colors",
    ],
    "LOGPROBS": [
        "logprobs_orig_words", "logprobs_orig_values", "logprobs_orig_colors", "logprobs_orig_text_colors",
        "logprobs_abl_words", "logprobs_abl_values", "logprobs_abl_colors", "logprobs_abl_text_colors",
    ],
    "LOGIT LENS": [
        "logit_lens_words", "logit_lens_values", "logit_lens_colors", "logit_lens_text_colors",
        "logit_lens_ctx_words", "logit_lens_ctx_values", "logit_lens_ctx_ranks",
    ],
}
# Every section `save_loss` can compute (the default when neither only_include nor only_exclude is given).
_SAVE_LOSS_SECTIONS = ["LOSS", "DLA", "ATTN", "LOGIT LENS", "LOGPROBS"]


def _append_ranked_entry(entry: dict, prefix: str, words, values, color_pair):
    '''Appends one position's top-k words, values (rounded to 2dp) and (color, text color) under `{prefix}_*` keys.'''
    color, text_color = color_pair
    entry[f"{prefix}_words"].append(words)
    entry[f"{prefix}_values"].append([round(v, 2) for v in values.tolist()])
    entry[f"{prefix}_colors"].append(color)
    entry[f"{prefix}_text_colors"].append(text_color)


def _write_json_outputs(json_values: dict, include):
    '''
    Writes DATA_PATH / f"data_{i}.json" for every sequence i.

    Only the sections in `include` were computed. If some JSON sections were skipped, the existing
    file is loaded and only the computed keys are replaced, so earlier results aren't overwritten
    with empty lists. If no JSON section was computed (e.g. only_include="ATTN"), nothing is written.
    '''
    sections = [s for s in _SAVE_LOSS_JSON_KEYS if s in include]
    if not sections:
        return
    computed_all = len(sections) == len(_SAVE_LOSS_JSON_KEYS)
    for i, d in tqdm(list(json_values.items()), desc="Saving as JSON"):
        path = DATA_PATH / f"data_{i}.json"
        if computed_all or not path.exists():
            payload = d
        else:
            with open(path, encoding="utf-8") as f:
                payload = json.load(f)
            payload.update({key: d[key] for s in sections for key in _SAVE_LOSS_JSON_KEYS[s]})
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4)


def save_loss(
    model_results: ModelResults,
    head = (10, 7),
    data_str_toks_parsed = DATA_STR_TOKS_PARSED,
    data_toks = DATA_TOKS,
    only_include: Optional[Union[str, List[str]]] = None,
    only_exclude: Optional[Union[str, List[str]]] = None,
):
    '''
    Computes the visualisation data for `head` on every sequence of `data_toks` and writes it to
    DATA_PATH. It writes `data_{i}.json` for the LOSS / DLA / LOGPROBS / LOGIT LENS sections and
    `attn_{i}.html` for the ATTN section.

    Args:
        model_results: results computed on `data_toks` (must include `head`).
        head: (layer, head) to visualise.
        data_str_toks_parsed: per-sequence lists of printable token strings.
        data_toks: [batch, seq] token ids the results were computed on.
        only_include / only_exclude: a section name, or list of names, out of
            "LOSS", "DLA", "ATTN", "LOGIT LENS", "LOGPROBS", restricting what is computed and saved.
            At most one of the two may be given.

    Sections that aren't computed are left as they are on disk. Also reads the notebook globals
    `model`, `FUNCTION_TOKS` and `DATA_PATH`.
    '''
    # ! Work out which sections to compute

    if isinstance(only_include, str): only_include = [only_include]
    if isinstance(only_exclude, str): only_exclude = [only_exclude]
    assert (only_include is None) or (only_exclude is None), "Can't specify both `only_include` and `only_exclude`"
    for names in (only_include, only_exclude):
        unknown = set(names or []) - set(_SAVE_LOSS_SECTIONS)
        assert not unknown, f"Unknown section(s) {sorted(unknown)}; expected names from {_SAVE_LOSS_SECTIONS}"
    include = list(_SAVE_LOSS_SECTIONS)
    if only_exclude is not None:
        include = list(set(include) - set(only_exclude))
    elif only_include is not None:
        include = only_include

    batch_size, seq_len = data_toks.shape
    LAYER, HEAD = head

    # One dict of (initially empty) lists per sequence, with the same keys for every sequence.
    json_values: Dict[int, Dict[str, list]] = {
        i: {key: [] for keys in _SAVE_LOSS_JSON_KEYS.values() for key in keys}
        for i in range(batch_size)
    }

    # ! LOSS

    if "LOSS" in include:

        LOSS_MAX_COLOR = 2
        loss_diffs = model_results.loss_diffs[("direct", "frozen", "mean")][LAYER, HEAD]

        for i, (loss_diff, data_str_toks) in tqdm(list(enumerate(zip(loss_diffs, data_str_toks_parsed))), desc="Loss"):

            # Map loss diff in [-LOSS_MAX_COLOR, LOSS_MAX_COLOR] onto [0, 1] (0.5 = no change).
            loss_colors, loss_text_colors = _get_color(0.5 + t.clamp(loss_diff / (2 * LOSS_MAX_COLOR), -0.5, 0.5))
            json_values[i]["words"] = data_str_toks
            json_values[i]["loss_values"] = [round(ld.item(), 4) for ld in loss_diff]
            json_values[i]["loss_colors"] = loss_colors
            json_values[i]["loss_text_colors"] = loss_text_colors

    # ! DLA

    if "DLA" in include:

        DLA_MAX_COLOR = 2.5
        TOPK = 10
        dla = model_results.dla[("frozen", "mean")][LAYER, HEAD] # [batch_size, seq_len, d_vocab]

        for i, _dla in tqdm(list(enumerate(dla)), desc="DLA"):

            dla_pos = _dla.topk(TOPK, dim=-1) # [seq_len, TOPK]
            dla_neg = _dla.topk(TOPK, dim=-1, largest=False) # [seq_len, TOPK]

            dla_pos_words_all = rearrange_list(model.to_str_tokens(dla_pos.indices.flatten(), prepend_bos=False), TOPK)
            dla_neg_words_all = rearrange_list(model.to_str_tokens(dla_neg.indices.flatten(), prepend_bos=False), TOPK)

            for seq in range(seq_len):
                # Colors come from the top entry: positive DLA -> [0.5, 1], negative DLA -> [0, 0.5].
                _append_ranked_entry(
                    json_values[i], "dla_pos", dla_pos_words_all[seq], dla_pos.values[seq],
                    _get_color(0.5 + clamp(dla_pos.values[seq, 0].item() / (2 * DLA_MAX_COLOR), 0.0, 0.5)),
                )
                _append_ranked_entry(
                    json_values[i], "dla_neg", dla_neg_words_all[seq], dla_neg.values[seq],
                    _get_color(0.5 + clamp(dla_neg.values[seq, 0].item() / (2 * DLA_MAX_COLOR), -0.5, 0.0)),
                )

    # ! LOGPROBS

    if "LOGPROBS" in include:

        LOGPROBS_MAX_COLOR = 2.5
        TOPK = 10
        logits_orig = model_results.logits_orig # [batch_size, seq_len, d_vocab]
        logits_abl = model_results.logits[("direct", "frozen", "mean")][LAYER, HEAD] # [batch_size, seq_len, d_vocab]

        for i, (_logits_abl, _logits_orig) in tqdm(list(enumerate(zip(logits_abl, logits_orig))), desc="logits"):

            for prefix, _logits in (("logprobs_abl", _logits_abl), ("logprobs_orig", _logits_orig)):
                logprobs_topk = _logits.log_softmax(-1).topk(TOPK, dim=-1) # [seq_len, TOPK]
                words_all = rearrange_list(model.to_str_tokens(logprobs_topk.indices.flatten(), prepend_bos=False), TOPK)

                for seq in range(seq_len):
                    values = logprobs_topk.values[seq]
                    # Top logprob (<= 0) mapped onto [0.5, 1]: closer to 1 = more confident.
                    _append_ranked_entry(
                        json_values[i], prefix, words_all[seq], values,
                        _get_color(1 + max(values[0].item() / (2 * LOGPROBS_MAX_COLOR), -0.5)),
                    )

    # ! LOGIT LENS

    if "LOGIT LENS" in include:

        LOGPROBS_MAX_COLOR = 2.5
        TOPK = 10
        logit_lens = model_results.logit_lens[LAYER] # [batch_size, seq_len, d_vocab]

        # Do everything at once, rather than per sequence
        logprobs = logit_lens.log_softmax(-1) # [batch_size, seq_len, d_vocab]
        logprobs_topk = logprobs.topk(TOPK, dim=-1) # [batch_size, seq_len, TOPK]
        logprobs_words_all = rearrange_list(model.to_str_tokens(logprobs_topk.indices.flatten(), prepend_bos=False), TOPK) # [batch_size*seq_len, TOPK]
        logprobs_words_all = rearrange_list(logprobs_words_all, seq_len) # [batch_size, seq_len, TOPK]

        # Top logprobs restricted to tokens in context
        logprobs_ctx_topk = get_top_logprobs_in_context(logprobs, data_toks, FUNCTION_TOKS)
        # The helper caps k at the sequence length, so read the actual k off its output rather
        # than assuming 5 when regrouping the flat list of words.
        TOPK_CTX = logprobs_ctx_topk.indices.shape[-1]
        # ctx_top_toks[batch, seq, k] = data_toks[batch, logprobs_ctx_topk.indices[batch, seq, k]]
        ctx_top_toks = eindex(data_toks, logprobs_ctx_topk.indices, "batch [batch seq k]") # [batch seq TOPK_CTX]
        logprobs_ctx_words_all = rearrange_list(model.to_str_tokens(ctx_top_toks.flatten(), prepend_bos=False), TOPK_CTX)
        logprobs_ctx_words_all = rearrange_list(logprobs_ctx_words_all, seq_len) # [batch_size, seq_len, TOPK_CTX]

        for i in tqdm(range(batch_size), desc="logit lens"):

            entry = json_values[i]
            for seq in range(seq_len):
                # First table: top TOPK over the whole vocab
                entry["logit_lens_words"].append(logprobs_words_all[i][seq])
                entry["logit_lens_values"].append([round(v, 2) for v in logprobs_topk.values[i, seq].tolist()])

                # Second table: top TOPK_CTX in context (this also gives us our colors)
                values = logprobs_ctx_topk.values[i, seq] # [TOPK_CTX]
                logit_lens_color, logit_lens_text_color = _get_color(1 + max(values[0].item() / (2 * LOGPROBS_MAX_COLOR), -0.5))
                # Rank of each in-context token in the full vocab (number of tokens with a higher logprob).
                ranks = (values.unsqueeze(-1) < logprobs[i, seq]).sum(dim=-1).tolist()
                # Slots with -inf were masked out (fewer than TOPK_CTX valid source tokens): drop them.
                num_src = (values > -float("inf")).sum().item()
                entry["logit_lens_colors"].append(logit_lens_color)
                entry["logit_lens_text_colors"].append(logit_lens_text_color)
                entry["logit_lens_ctx_words"].append(logprobs_ctx_words_all[i][seq][:num_src])
                entry["logit_lens_ctx_values"].append([round(v, 2) for v in values.tolist()[:num_src]])
                entry["logit_lens_ctx_ranks"].append(ranks[:num_src])

    # ! ATTN (saved as circuitsvis HTML rather than JSON)

    DP = 3

    def round_2d_tensor(tensor: Tensor):
        '''Rounds a 2D tensor to DP decimal places, returned as a nested list (keeps the HTML small).'''
        assert tensor.ndim == 2
        return rearrange_list([round(x, ndigits=DP) for x in tensor.flatten().tolist()], tensor.shape[1])

    attn_list = []

    if "ATTN" in include:

        patterns = model_results.pattern[LAYER, HEAD]
        out_norms = model_results.out_norm[LAYER, HEAD]

        for i in range(batch_size):

            probs = patterns[i] # presumably [seqQ, seqK]
            # Weight each source position by its output norm, relative to the largest in this sequence.
            info_norms = out_norms[i] # presumably [seqK]
            probs_info_weighted = probs * (info_norms / info_norms.max()).unsqueeze(0)

            html = cv.attention.attention_patterns(
                attention = [round_2d_tensor(probs), round_2d_tensor(probs_info_weighted)],
                tokens = data_str_toks_parsed[i],
                attention_head_names = ["Standard", "Info-weighted"],
            )
            attn_list.append(html)

    # Save as json (only the sections that were computed)
    _write_json_outputs(json_values, include)
    # Save as html (utf-8: token strings may contain non-ASCII characters)
    for i, d in tqdm(enumerate(attn_list), desc="Saving as HTML"):
        with open(DATA_PATH / f"attn_{i}.html", "w", encoding="utf-8") as f:
            f.write(str(d))

# Write data_{i}.json / attn_{i}.html for every sequence. Pass only_include=... or only_exclude=...
# (e.g. "ATTN") to restrict which sections are computed and saved.
save_loss(model_results) # only_include="ATTN", only_exclude="ATTN"

Loss: 100%|██████████| 41/41 [00:00<00:00, 50.99it/s]
DLA: 100%|██████████| 41/41 [00:04<00:00, 10.06it/s]
logits: 100%|██████████| 41/41 [00:03<00:00, 10.81it/s]
logit lens: 100%|██████████| 41/41 [00:02<00:00, 14.67it/s]
Saving as JSON: 100%|██████████| 41/41 [00:00<00:00, 147.15it/s]
Saving as HTML: 41it [00:00, 10999.52it/s]


In [None]:
# NOTE(review): `file` is not referenced anywhere else in this notebook; it looks like a leftover
# pointer to a smaller version of the OV/QK pickle that is saved further down.
file = "/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/st_page/media/OV_QK_circuits_less.pkl"




In [None]:
# `logprobs_ctx` only exists as a local inside `get_top_logprobs_in_context`, so define it here:
# the same in-context top-k that `save_loss` computes for the logit lens at layer 10 (head 10.7's layer).
logprobs_ctx = get_top_logprobs_in_context(
    model_results.logit_lens[10].log_softmax(-1), DATA_TOKS, FUNCTION_TOKS
)
logprobs_ctx.indices[:5, :5]

In [None]:
# Inspect which ablation-type keys (e.g. ("direct", "frozen", "mean")) `loss_diffs` holds.
model_results.loss_diffs.items()

## Histograms: logit lens (TODO: fix up everything below this; it's old)

In [None]:
# NOTE(review): legacy cell (see the TODO heading). `k`, `neg` and `all_ranks` are not used by the
# cells below, and `DATA_TOKS_2` is not defined anywhere in this notebook, so this raises NameError
# on a fresh kernel.
k = 15
neg = False
all_ranks = []


model.reset_hooks()
logits, cache = model.run_with_cache(DATA_TOKS_2)

In [None]:
# points_to_plot = [
#     (35, 39, " About"),
#     (67, 21, " delays"),
#     (8, 35, " rentals"),
#     (8, 54, " require"),
#     (53, 18, [" San", " Francisco"]),
#     (33, 9, " Hollywood"),
#     (49, 7, " Home"),
#     (71, 34, " sound"),
#     (14, 56, " Kara"),
# ]
# points_to_plot = [
#     (45, 42, [" editorial"]),
#     (45, 58, [" stadium", " Stadium", " stadiums"]),
#     (43, 56, [" Biden"]),
#     (43, 44, [" interview", " campaign"]),
#     (38, 54, [" Mary", " Catholics"]),
#     (33, 29, " Hollywood"),
#     (33, 42, " BlackBerry"),
#     (31, 33, [" Church", " churches"]),
#     (28, 53, [" mobile", " phone", " device"]),
#     (25, 32, [" abstraction", " abstract", " Abstract"]),
#     (18, 25, ["TPP", " Lee"]),
#     (10, 52, [" Italy", " mafia"]),
#     (10, 52, [" Italy", " mafia"]),
#     (10, 35, [" Italy", " mafia"]),
#     (10, 25, [" Italian", " Italy"]),
#     (6, 52, [" landfill", " waste"]),
#     (4, 52, " jury"),
# ]
# Presumably (batch index, sequence position, token(s) to highlight) — see `plot_logit_lens`.
points_to_plot = [
    # (14, 56, " Kara"),
    # (67, 47, " case"),
    # (24, 73, " negotiation"),
    (2, 35, [" Berkeley", "keley"]),
]

# Residual stream before layer 10, divided by layer 10's ln1 scale (presumably the normalised input
# that head 10.7 sees, since LayerNorm weights are folded).
resid_pre_head = (cache["resid_pre", 10]) / cache["scale", 10, "ln1"]  #  - cache["resid_pre", 1]

# NOTE(review): `DATA_STR_TOKS_PARSED_2` (and `cache`, from the previous cell) come from an older
# dataset that is no longer defined in this notebook.
plot_logit_lens(points_to_plot, resid_pre_head, model, DATA_STR_TOKS_PARSED_2, k=15, title="Predictions at token ' of', before head 10.7")

## Histograms: QK and OV circuit

In [None]:
def plot_both(dest, src, focus_on: Literal["src", "dest"]):
    '''
    Plots the OV-circuit (neg=True) and QK-circuit (neg=False) full-matrix histograms for head 10.7
    on the (src, dest) token pair, with k=15.

    `focus_on` decides which plot gets `flip=True`: the OV plot when focus_on == "dest", the QK plot
    when focus_on == "src".
    '''
    # (circuit, neg, value of `focus_on` that flips this plot), in plotting order
    circuit_settings = [
        ("OV", True, "dest"),
        ("QK", False, "src"),
    ]
    for circuit, neg, flip_when in circuit_settings:
        plot_full_matrix_histogram(
            W_EE_dict, src, dest, model,
            k=15, circuit=circuit, neg=neg, head=(10, 7), flip=(focus_on == flip_when),
        )

plot_both(dest=" Berkeley", src="keley", focus_on="src")

In [None]:
# Reuses `plot_both` from the previous cell; an identical re-definition here would only shadow it
# (and risk the two copies drifting apart).
plot_both(dest=" negotiation", src=" negotiations", focus_on="dest")

In [None]:
# Compare original vs ablated (10.7, direct+frozen+mean) logprobs at one (batch, seq) position:
# the top-10 original predictions, plus the actual next token as a final, differently-coloured bar.
# `MODEL_RESULTS` only covers the single short prompt from the small test, so this uses the
# full-dataset `model_results` (computed on DATA_TOKS), which index [32, 19] is valid for.
BATCH_IDX, SEQ_IDX = 32, 19
N_TOP = 10

logprobs_orig = model_results.logits_orig[BATCH_IDX, SEQ_IDX].log_softmax(-1)
logprobs_abl = model_results.logits[("direct", "frozen", "mean")][10, 7][BATCH_IDX, SEQ_IDX].log_softmax(-1)

# The correct next token is read from the data rather than hard-coded, so the label always matches.
correct_next_token = DATA_TOKS[BATCH_IDX, SEQ_IDX + 1].item()
correct_next_str_tok = model.tokenizer.decode([correct_next_token])

logprobs_orig_topk = logprobs_orig.topk(N_TOP, dim=-1, largest=True)
top_toks = logprobs_orig_topk.indices
x = list(map(repr, model.to_str_tokens(top_toks))) + [repr(correct_next_str_tok)]
y_orig = logprobs_orig_topk.values.tolist() + [logprobs_orig[correct_next_token].item()]
y_abl = logprobs_abl[top_toks].tolist() + [logprobs_abl[correct_next_token].item()]

# The last bar of each series (the correct token) gets its own colour.
ORIG_COLOR, ORIG_CORRECT_COLOR = "#FF7700", "#024B7A"
ABL_COLOR, ABL_CORRECT_COLOR = "#FFAE49", "#44A5C2"

fig = go.Figure(
    data = [
        go.Bar(x=x, y=y_orig, name='Original', marker_color=[ORIG_COLOR] * (len(x)-1) + [ORIG_CORRECT_COLOR]),
        go.Bar(x=x, y=y_abl, name='Ablated', marker_color=[ABL_COLOR] * (len(x)-1) + [ABL_CORRECT_COLOR]),
    ],
    layout = dict(
        barmode='group',
        xaxis_tickangle=30,
        title="Logprobs: original vs ablated",
        xaxis_title_text="Predicted next token",
        yaxis_title_text="Logprob",
        width=800,
        bargap=0.35,
    )
)
fig.show()

In [None]:
# NOTE(review): these calls use a different signature from the `plot_both` cells above (no `dest` /
# `model` positional args, plus `include=`); they presumably target an older `plot_full_matrix_histogram`.
plot_full_matrix_histogram(W_EE_dict, " device", k=10, include=[" devices"], circuit="OV", neg=True, head=(10, 7))
plot_full_matrix_histogram(W_EE_dict, " devices", k=10, include=[" device"], circuit="QK", neg=False, head=(10, 7))

In [None]:
# NOTE(review): the first assignment is immediately overwritten, so only the "only MLPs" variant is
# used; keep whichever line you want last.
W_EE = W_EE_dict["W_E (including MLPs)"]
W_EE = W_EE_dict["W_E (only MLPs)"]
# Transposed here and transposed back when indexed below (`W_U.T[toks]`), so the rows looked up are
# rows of W_EE_dict["W_U"]. NOTE(review): the name `W_U` is reused for a function in a later cell.
W_U = W_EE_dict["W_U"].T

# Build case / plural / leading-space variants of the base word (each loop doubles the list),
# then keep only those which are a single token.
tok_strs = ["pier"]
for i in range(len(tok_strs)): tok_strs.append(tok_strs[i].capitalize())
for i in range(len(tok_strs)): tok_strs.append(tok_strs[i] + "s")
for i in range(len(tok_strs)): tok_strs.append(" " + tok_strs[i])
tok_strs = [s for s in tok_strs if model.to_tokens(s, prepend_bos=False).squeeze().ndim == 0]

# NOTE(review): if only one variant survives, `.squeeze()` yields a 0-d tensor and the matrices
# below won't be (n, n) — assumes at least two variants.
toks = model.to_tokens(tok_strs, prepend_bos=False).squeeze()

# Pairwise cosine similarity of the (effective) embeddings
W_EE_toks = W_EE[toks]
W_EE_normed = W_EE_toks / W_EE_toks.norm(dim=-1, keepdim=True)
cos_sim_embeddings = W_EE_normed @ W_EE_normed.T

# ... of the unembeddings
W_U_toks = W_U.T[toks]
W_U_normed = W_U_toks / W_U_toks.norm(dim=-1, keepdim=True)
cos_sim_unembeddings = W_U_normed @ W_U_normed.T

# ... and of what heads 10.7 and 9.9 write out (embedding passed through the head's OV matrix)
W_EE_OV_toks_107 = W_EE_toks @ model.W_V[10, 7] @ model.W_O[10, 7]
W_EE_OV_toks_99 = W_EE_toks @ model.W_V[9, 9] @ model.W_O[9, 9]
W_EE_OV_toks_107_normed = W_EE_OV_toks_107 / W_EE_OV_toks_107.norm(dim=-1, keepdim=True)
W_EE_OV_toks_99_normed = W_EE_OV_toks_99 / W_EE_OV_toks_99.norm(dim=-1, keepdim=True)
cos_sim_107 = W_EE_OV_toks_107_normed @ W_EE_OV_toks_107_normed.T
cos_sim_99 = W_EE_OV_toks_99_normed @ W_EE_OV_toks_99_normed.T

# NOTE(review): the title hard-codes ' pier'; update it if the base word above changes.
imshow(
    t.stack([
        cos_sim_embeddings,
        cos_sim_unembeddings,
        cos_sim_107,
        cos_sim_99,
    ]),
    x = list(map(repr, tok_strs)),
    y = list(map(repr, tok_strs)),
    title = "Cosine similarity of variants of ' pier'",
    facet_col = 0,
    facet_labels = ["Effective embeddings", "Unembeddings", "W_OV output (10.7)", "W_OV output (9.9)"],
    border = True,
    width=1200,
)

# W_EE_OV_normed = W_EE_OV / W_EE_OV.std(dim=-1, keepdim=True)

## Create OV and QK circuits Streamlit page

I need to save the following things:

* The QK and OV matrices (and the query/key biases) for heads 10.7 and 11.10
* The effective embedding and unembedding matrices
* The tokenizer

In [None]:
import gzip
import pickle

# Everything the OV/QK Streamlit page needs for heads 10.7 and 11.10.
dict_to_store = {
    "tokenizer": model.tokenizer,
    "W_V_107": model.W_V[10, 7],
    "W_O_107": model.W_O[10, 7],
    "W_V_1110": model.W_V[11, 10],
    "W_O_1110": model.W_O[11, 10],
    "W_Q_107": model.W_Q[10, 7],
    "W_K_107": model.W_K[10, 7],
    "W_Q_1110": model.W_Q[11, 10],
    "W_K_1110": model.W_K[11, 10],
    "b_Q_107": model.b_Q[10, 7],
    "b_K_107": model.b_K[10, 7],
    "b_Q_1110": model.b_Q[11, 10],
    "b_K_1110": model.b_K[11, 10],
    "W_EE": W_EE_dict["W_E (including MLPs)"],
    "W_U": model.W_U,
}
# Store tensors in half precision to halve the file size.
dict_to_store = {k: v.half() if isinstance(v, t.Tensor) else v for k, v in dict_to_store.items()}

# `ST_HTML_PATH` is the name imported from `callum2.utils` in the setup cell. Wrapping it in Path()
# makes `/` work whether it is a Path or a plain string.
with gzip.open(Path(ST_HTML_PATH) / "OV_QK_circuits.pkl", "wb") as f:
    pickle.dump(dict_to_store, f)

## Generate `explore_prompts` HTML plots for Streamlit page

In [None]:
# Generate (and save to disk) the explore_prompts HTML plots, restricted to the LOSS plots.
# NOTE(review): MODEL_RESULTS was computed on the single short prompt in the small-case cell, but
# `data_toks` here is the full DATA_TOKS batch. `result_mean` isn't passed, and the kwargs differ
# from the earlier call (`progress_bar` / `restrict_computation` vs `verbose`) — confirm against the
# current `generate_4_html_plots` signature. This also overwrites the HTML_PLOTS from that cell.
HTML_PLOTS = generate_4_html_plots(
    model_results = MODEL_RESULTS,
    model = model,
    data_toks = DATA_TOKS,
    data_str_toks_parsed = DATA_STR_TOKS_PARSED,
    negative_heads = NEGATIVE_HEADS,
    save_files = True,
    progress_bar = True,
    restrict_computation = ["LOSS"]
)

In [None]:
(
    model.W_U.T[model.to_single_token(" pier")].norm().item(), 
    model.W_U.T[model.to_single_token(" Pier")].norm().item(),
)

In [None]:
pier = model.W_U.T[model.to_single_token(" pier")]
Pier = model.W_U.T[model.to_single_token(" Pier")]

# Built-in cosine similarity. It is printed because, as a non-final expression in the cell, its
# value would otherwise never be shown.
print(t.cosine_similarity(pier, Pier, dim=-1).item())

# Manual check. Indexing with a single token id returns a view into model.W_U, so normalise
# out-of-place: `pier /= pier.norm()` would write into the model's weights (or raise, if they
# require grad).
pier = pier / pier.norm()
Pier = Pier / Pier.norm()
print(pier @ Pier)

In [None]:
def W_U(s):
    '''Unembedding vector of the single token `s` (column `s` of model.W_U).'''
    token_id = model.to_single_token(s)
    return model.W_U[:, token_id]

def W_EE0(s):
    '''Effective embedding of the single token `s`, from W_EE_dict["W_E (only MLPs)"].'''
    embedding_table = W_EE_dict["W_E (only MLPs)"]
    return embedding_table[model.to_single_token(s)]

def cos_sim(v1, v2):
    '''Cosine similarity of vectors v1 and v2 (returned as a 0-d tensor).'''
    dot_product = v1 @ v2
    return dot_product / (v1.norm() * v2.norm())

# Compare how similar two token pairs are in unembedding space vs effective-embedding space.
print(f"Unembeddings cosine similarity (Berkeley) = {cos_sim(W_U('keley'), W_U(' Berkeley')):.3f}") 
print(f"Embeddings cosine similarity (Berkeley)   = {cos_sim(W_EE0('keley'), W_EE0(' Berkeley')):.3f}") 
print("")
print(f"Unembeddings cosine similarity (pier) = {cos_sim(W_U(' pier'), W_U(' Pier')):.3f}") 
print(f"Embeddings cosine similarity (pier)   = {cos_sim(W_EE0(' pier'), W_EE0(' Pier')):.3f}") 

In [None]:
# Do the singular -> plural offsets of two different nouns point the same way in effective-embedding space?
t.cosine_similarity(
    W_EE0(" screen") - W_EE0(" screens"),
    W_EE0(" device") - W_EE0(" devices"),
    dim=-1
).item()

In [None]:
# Same singular -> plural offset comparison for another pair of nouns. This uses the `W_EE0`
# lookup helper: `W_EE` is an embedding-matrix tensor, not a function, so calling it would raise TypeError.
t.cosine_similarity(
    W_EE0(" computer") - W_EE0(" computers"),
    W_EE0(" sign") - W_EE0(" signs"),
    dim=-1
).item()