## Setup

Note - I couldn't figure out how to fix local imports for a while. Solution ended up being to make sure that this version of `transformer_lens` is before my libraries in `sys.path` (hence I'm inserting it at position zero in the code below).

In [1]:
# %pip install pattern --no-dependencies
# %pip install nltk
# %pip install protobuf==3.20.0

import os, sys
from pathlib import Path
p = Path(r"/home/ubuntu/SERI-MATS-2023-Streamlit-pages")
if os.path.exists(str_p := str(p.resolve())):
    os.chdir(str_p)
    if str_p not in sys.path:
        sys.path.append(str_p)

from transformer_lens.cautils.notebook import *
t.set_grad_enabled(False)

from transformer_lens.rs.callum2.cspa.cspa_functions import (
    FUNCTION_STR_TOKS,
    get_cspa_results,
    get_cspa_results_batched,
)
from transformer_lens.rs.callum2.utils import (
    parse_str,
    parse_str_toks_for_printing,
    parse_str_tok_for_printing,
    ST_HTML_PATH,
    process_webtext,
)
from transformer_lens.rs.callum2.cspa.cspa_plots import (
    generate_scatter,
    generate_loss_based_scatter,
    add_cspa_to_streamlit_page,
)
from transformer_lens.rs.callum2.generate_st_html.model_results import (
    get_result_mean,
    get_model_results,
)
from transformer_lens.rs.callum2.generate_st_html.generate_html_funcs import (
    generate_4_html_plots,
    CSS,
)
from transformer_lens.rs.callum2.cspa.cspa_semantic_similarity import (
    get_equivalency_toks,
    get_related_words,
    concat_lists,
    make_list_correct_length,
    create_full_semantic_similarity_dict,
)

clear_output()

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda",
    # fold value bias?
)
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

clear_output()

In [3]:
BATCH_SIZE = 500 # 80 for viz
SEQ_LEN = 1000 # 61 for viz
batch_idx = 36

current_batch_size = 17 # smaller because
current_seq_len = 61

NEGATIVE_HEADS = [(10, 7), (11, 10)]
DATA_TOKS, DATA_STR_TOKS_PARSED, indices = process_webtext(seed=6, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, model=model, verbose=True, return_indices=True)

Shape = torch.Size([500, 1000])

First prompt:
<|endoftext|>Going into Election Day, few industries seemed in worse shape than America's private prisons. Prison populations, which had been rising for decades, were falling. In 2014, Corrections Corporation of America, the biggest private-prison company in the U.S., lost its contract to run Idaho's largest prison, after lawsuits relating to understaffing and violence that had earned the place the nickname Gladiator School. There were press exposés of shocking conditions in the industry and signs of a policy shift toward it. In April, Hillary Clinton said, "We should end private prisons." In August, the Justice Department said that private federal prisons were less safe and less secure than government-run ones. The same month, the department announced that it would phase out the use of private prisons at the federal level. Although most of the private-prison industry operates on the state level (immigrant-detention centers are its other b

## Hardcoded semantic similarity

This will take tokens, and return the tokens of semantically similar words. 

There are 3 categories of semantically similar tokens `s*` for any given token `s`:

1. Equivalence relations - this captures things like plurals, tokenization, capitalization.
2. Superstrings - for instance, of you have `"keley"` this gives you `" Berkeley"`.
3. Substrings - for instance, of you have `" Berkeley"` this gives you `"keley"`.

How does this work?

1. For each token, generate all (1), and then see which ones actually split into multiple tokens (these become (3)).
2. Iterate through this entire dict to generate (2)s for every token (this is basically like flipping the arrows in the other direction).

### Problems with this method

There are 4 problems with this method. I think (1) and (3) are the most problematic.

1. The pluralization isn't sufficiently flexible, and it'll miss out on categories of things, for example:
    * `" write"` and `" writing"` and `" writer"`
    * `" rental"` and `" rented"` and `" renting"`
    * (OV and QK circuits show that these do suppress each other)
    * Possible solution - more hardcoded rules?
2. Misses some important things which aren't semantically similar as we've defined it, e.g. `1984` and `1985` aren't semantically similar (OV and QK circuits show that they do suppress each other)
    * Possible solution - ???

However, this maybe isn't worth further optimization, because it doesn't marginally improve the results by much.

In [4]:
from pattern.text.en import conjugate, PRESENT, PAST, FUTURE, SUBJUNCTIVE, INFINITIVE, PROGRESSIVE, PLURAL, SINGULAR
from nltk.stem import WordNetLemmatizer
import nltk
MY_TENSES = [PRESENT, PAST, FUTURE, SUBJUNCTIVE, INFINITIVE, PROGRESSIVE]
MY_NUMBERS = [PLURAL, SINGULAR]
from nltk.corpus import wordnet
nltk.download('wordnet')
clear_output()

In [5]:
# cspa_semantic_dict_reversed, cspa_semantic_dict = create_full_semantic_similarity_dict(model, full_version=True)

# with open(ST_HTML_PATH.parent.parent / "cspa/cspa_semantic_dict_full.pkl", "wb") as f:
#     pickle.dump(cspa_semantic_dict, f)

cspa_semantic_dict = pickle.load(open(ST_HTML_PATH.parent.parent / "cspa/cspa_semantic_dict_full.pkl", "rb"))

FileNotFoundError: [Errno 2] No such file or directory: '/mount/src/seri-mats-2023-streamlit-pages/transformer_lens/rs/callum2/cspa/cspa_semantic_dict_full.pkl'

In [6]:
cspa_semantic_dict = {}

In [7]:
display(HTML("<h2>Related words</h2>This doesn't include tokenization fragments; it's just linguistic."))

for word in ["Berkeley", "pier", "pie", "ring", "device", "robot", "w"]:
    try: print(get_related_words(word, model))
    except: print(get_related_words(word, model)); print("(Worked on second try!)") # maybe because it downloads?

['Berkeley']
(Worked on second try!)
['pier']
['pie', 'pier', 'pies']
['ring', 'rings', 'ringing']
['device', 'devices']
['robot', 'robots', 'robotic', 'robotics']
['w']


In [8]:
display(HTML("<h2>Equivalency words</h2>These are the words which will be included in the semantic similarity cluster, during CSPA."))

for tok in [" Berkeley", " Pier", " pier", "pie", " pies", " ring", " device", " robot", "w"]:
    print(f"{tok!r:>10} -> {get_equivalency_toks(tok, model)}")

' Berkeley' -> ([' Berkeley'], ['Ber', 'keley'])
   ' Pier' -> ([' pier', ' Pier'], [])
   ' pier' -> ([' pier', ' Pier'], [])
     'pie' -> (['pie', 'Pie', ' pie', ' pies', ' Pie'], [])
   ' pies' -> (['pie', 'Pie', ' pies', ' pie', ' Pie'], [])
   ' ring' -> (['ring', 'rings', 'Ring', ' ring', ' rings', ' Ring', ' Rings'], [])
 ' device' -> (['device', 'devices', 'Device', ' device', ' devices', ' Device', ' Devices'], [])
  ' robot' -> ([' robot', ' robots', ' Robot', ' Robots'], [])
       'w' -> (['w', 'ws', 'W', 'Ws', 'WS', ' w', ' W', ' WS'], [])


In [9]:
table = Table("Source token", "All semantically related", title="Semantic similarity: bidirectional, superstrings, substrings") #  "Top 3 related" in the middle

str_toks = [" Berkeley", "keley", " University", " Mary", " Pier", " pier", "NY", " ring", " W", " device", " robot", " jump", " driver", " Cairo"]
print_cutoff = 105 # 70
def cutoff(s):
    if len(s_str := str(s)) >= print_cutoff: return s_str[:print_cutoff-4] + ' ...'
    else: return s_str

for str_tok in str_toks:
    top3_sim = "\n".join(list(map(repr, concat_lists(cspa_semantic_dict[str_tok])[:3])))
    bidir, superstr, substr = cspa_semantic_dict[str_tok]
    all_sim = "\n".join([
        cutoff(f"{len(bidir)} bidirectional: {bidir}"),
        cutoff(f"{len(superstr)} super-tokens:  {superstr}"),
        cutoff(f"{len(substr)} sub-tokens:    {substr}"),
    ]) + "\n"
    table.add_row(repr(str_tok), all_sim) # top3_sim in the middle

rprint(table)

KeyError: ' Berkeley'

## Running CSPA code

### `K_u = 0.05`, no semantic similarity, batch size = 100

In [10]:
result_mean = get_result_mean([(10, 7), (11, 10)], DATA_TOKS[:100, :], model, verbose=True)
# t.save(result_mean, f"/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/st_page/media/result_mean.pt")

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:01<00:00,  6.36it/s]


In [25]:
# Results with batch size = 100

cspa_results_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:50],
    max_batch_size = 1, # 50,
    negative_head = (10, 7),
    interventions = ["qk", "ov"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = False,
    verbose = True,
    compute_s_sstar_dict = False,
)

Moving model to device:  cpu


Batch 1/50, shape torch.Size([1, 1000]):   0%|          | 0/50 [00:00<?, ?it/s]

Batch 3/50, shape = (1, 1000), times = [get = 6.02, s* = 0.20, aggregate = 0.00]:   6%|▌         | 3/50 [00:18<04:52,  6.22s/it]

In [None]:

# fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
fig_line_loss = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
fig_line_loss_absolute = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss-absolute")
fig_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()

# normed_loss_ablated_to_orig = cspa_results_qk_ov
l_orig = cspa_results_qk_ov["loss"].flatten()
l_abl = cspa_results_qk_ov["loss_ablated"].flatten()
l_cspa = cspa_results_qk_ov["loss_cspa"].flatten()
# title = "Change in loss from ablation (relative to clean model)"

xaxis_values = l_abl - l_orig
yaxis_values = l_cspa - l_orig
# if values == "loss-absolute":
xaxis_values = xaxis_values.abs()**2
yaxis_values = yaxis_values.abs()**2
# title = "Absolute change in loss from ablation"

normed_loss_performance_recovered = 1 - yaxis_values.mean() / xaxis_values.mean()
print("Normed recovery", normed_loss_performance_recovered.mean())

for i in range(1, 100):
    ma_max = fig_line_loss_absolute.data[0].x[-i]
    cspa_max = fig_line_loss_absolute.data[0].y[-i]
    print(f"{i}th most extreme quantile: fraction explained = 1 - ({cspa_max:.3f}/{ma_max:.3f}) = {1 - cspa_max/ma_max:.3f}")

print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
ma_max = fig_loss_line_kl.data[0].x[-1]
cspa_max = fig_loss_line_kl.data[0].y[-1]
print(f"Most extreme quantile: fraction explained = 1 - ({cspa_max:.3f}/{ma_max:.3f}) = {1 - cspa_max/ma_max:.3f}")

Normed recovery tensor(0.1332, device='cuda:0')
1th most extreme quantile: fraction explained = 1 - (0.410/0.439) = 0.068
2th most extreme quantile: fraction explained = 1 - (0.312/0.353) = 0.117
3th most extreme quantile: fraction explained = 1 - (0.277/0.293) = 0.054
4th most extreme quantile: fraction explained = 1 - (0.239/0.260) = 0.081
5th most extreme quantile: fraction explained = 1 - (0.222/0.238) = 0.069
6th most extreme quantile: fraction explained = 1 - (0.205/0.217) = 0.056
7th most extreme quantile: fraction explained = 1 - (0.192/0.200) = 0.041
8th most extreme quantile: fraction explained = 1 - (0.181/0.189) = 0.041
9th most extreme quantile: fraction explained = 1 - (0.166/0.175) = 0.050
10th most extreme quantile: fraction explained = 1 - (0.155/0.165) = 0.060
11th most extreme quantile: fraction explained = 1 - (0.149/0.156) = 0.044
12th most extreme quantile: fraction explained = 1 - (0.141/0.148) = 0.049
13th most extreme quantile: fraction explained = 1 - (0.136/0

NameError: name 'fig_loss_line_kl' is not defined

# Is there a correlation between the absolute value of 10.7 effect on loss and 10.7's effect on increasing KL Divergence from the model?

In [None]:
batch_indices = (torch.tensor(indices)[:current_batch_size].unsqueeze(-1) + torch.zeros(current_seq_len).unsqueeze(0))
seq_indices = (torch.zeros(current_batch_size).unsqueeze(-1) + torch.arange(current_seq_len).unsqueeze(0))

my_dict = {
    "Mean Ablate 10.7 Loss - Normal Loss": cspa_results_qk_ov["loss_ablated"].clone().cpu() - cspa_results_qk_ov["loss"].clone().cpu(),
    "Mean Ablation KL Divergence to Model": cspa_results_qk_ov["kl_div_ablated_to_orig"][:, :-1].clone().cpu(),
    "CSPA Loss - Normal Loss": cspa_results_qk_ov["loss_cspa"].clone().cpu() - cspa_results_qk_ov["loss"].clone().cpu(),
    "CSPA KL": cspa_results_qk_ov["kl_div_cspa_to_orig"][:, :-1].clone().cpu(),
}

for k in list(my_dict.keys()):
    print(my_dict[k].shape)
    my_dict[k] = my_dict[k][:current_batch_size, :current_seq_len].flatten()

print(batch_indices.shape, seq_indices.shape)

my_dict["Batch Indices"] = batch_indices.flatten().float()
my_dict["Seq Indices"] = seq_indices.flatten().float()

df = pd.DataFrame(my_dict)

In [4]:
assert all([v.shape == (current_batch_size*current_seq_len,) for _, v in my_dict.items()])

NameError: name 'my_dict' is not defined

In [None]:
# import warnings; warnings.warn("Check out color when working")
# fig = go.Figure()
# fig.add_trace(
px.scatter(
    df,
    x = "Mean Ablate 10.7 Loss - Normal Loss",
    y = "Mean Ablation KL Divergence to Model",
    hover_data = ["Batch Indices", "Seq Indices"],
    color = "CSPA KL", 
    labels = {"x": "Loss under 10.7 Mean Ablation - Ordinary Loss", "y": "KL Divergence under Mean Ablation"}, # TODO fix these
    color_continuous_scale = "Blues",
    color_continuous_midpoint=0.0,
    range_color=(0.0, 0.09),
).update_traces(
    marker=dict(
        line=dict(
            color='black',  # Border color
            width=0.5  # Border width
        )
    )
).show()

In [5]:
import plotly.graph_objects as go
import pandas as pd

Q=30

# Create quantile labels and bins
df['quantile'], bins = pd.qcut(df['CSPA Loss - Normal Loss'], Q, labels=False, retbins=True)

# Initialize figure
fig = go.Figure()

def interpolate_colors(N):
    red = np.array([1, 0, 0])
    blue = np.array([0, 0, 1])
    colors = []
    
    for i in range(N):
        ratio = i / (N - 1)
        interpolated_color = (1 - ratio) * red + ratio * blue
        hex_color = "#{:02x}{:02x}{:02x}".format(int(interpolated_color[0]*255), int(interpolated_color[1]*255), int(interpolated_color[2]*255))
        colors.append(hex_color)
        
    return colors


colors = interpolate_colors(Q)

# Add scatter traces for each quantile
for i, q in enumerate(sorted(df['quantile'].unique())):
    subset = df[df['quantile'] == q]
    min_val = round(bins[q], 2)
    max_val = round(bins[q + 1], 2)
    fig.add_trace(
        go.Scatter(
            x=subset['Mean Ablation KL Divergence to Model'],
            y=subset['Mean Ablate 10.7 Loss - Normal Loss'],
            mode='markers',
            name=f'CSPA quantile {min_val} to {max_val}, average contribution {round(float(np.array(subset["CSPA Loss - Normal Loss"]).mean()), 4)}',
            marker=dict(color=colors[i])
        )
    )

# Update layout
fig.update_layout(
    xaxis_title="KL Divergence under Mean Ablation",
    yaxis_title="Loss under 10.7 Mean Ablation - Ordinary Loss",
)


for i in range(Q):
    fig.data[i].visible = "legendonly"
# fig.show(config={"staticPlot": True})
fig.show()

print("Average increase in loss from CSPA", cspa_results_qk_ov["loss_cspa"].mean() - cspa_results_qk_ov["loss"].mean())
print(f"Which should be the same as the average of the above: {df['CSPA Loss - Normal Loss'].mean()}")

NameError: name 'df' is not defined

Answer: the correlation is pretty weak. The good news is that there do not appear to be many points that loss systematically misses but KL captures

In [None]:
# Results with batch size = 100

cspa_results_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:, :], # [:50],
    max_batch_size = 1, # 50,
    negative_head = (10, 7),
    interventions = ["qk", "ov"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = False,
    verbose = True,
    compute_s_sstar_dict = False,
)

# fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
fig_loss_line = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
fig_loss_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()

print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
ma_max = fig_loss_line_kl.data[0].x[-1]
cspa_max = fig_loss_line_kl.data[0].y[-1]
print(f"Most extreme quantile: fraction explained = 1 - ({cspa_max:.3f}/{ma_max:.3f}) = {1 - cspa_max/ma_max:.3f}")

### `K_u = 0.05`, no semantic similarity, batch size = 500

In [6]:
cspa_results_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:, :], # [:50],
    max_batch_size = 5, # 50,
    negative_head = (10, 7),
    interventions = ["qk", "ov"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    only_keep_negative_components = True,
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = False,
    verbose = True,
    compute_s_sstar_dict = False,
)

# fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
fig_loss_line = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
fig_loss_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()

print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
ma_max = fig_loss_line_kl.data[0].x[-1]
cspa_max = fig_loss_line_kl.data[0].y[-1]
print(f"Most extreme quantile: fraction explained = 1 - ({cspa_max:.3f}/{ma_max:.3f}) = {1 - cspa_max/ma_max:.3f}")

NameError: name 'get_cspa_results_batched' is not defined

# How well does this metric do for other heads?

In [20]:
result_mean = get_result_mean([
    (layer, head)
    for layer in [8, 9, 10, 11] for head in range(12)
], DATA_TOKS[:100, :], model, verbose=True)

# result_mean = get_result_mean([(10, 7), (11, 10)], DATA_TOKS[:100, :], model, verbose=True)

 20%|██        | 2/10 [00:00<00:00, 13.86it/s]

100%|██████████| 10/10 [00:01<00:00,  8.06it/s]


In [21]:
kl_results = t.zeros(2, 4, 12).to(device)
loss_results = kl_results.clone()
normed_loss_results = kl_results.clone()
non_normed_loss_results = kl_results.clone()
squared_loss_results = kl_results.clone()

for i, only_keep_negative_components in enumerate([True, False]):

    for layer, head in tqdm(list(itertools.product([8, 9, 10, 11], range(12)))):

        cspa_results_qk_ov = get_cspa_results_batched(
            model = model,
            toks = DATA_TOKS[:50, :200], # [:50],
            max_batch_size = 1, # 50,
            negative_head = (layer, head),
            interventions = ["qk", "ov"],
            K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
            K_semantic = 1, # either 1 or up to 8 to capture all sem similar
            only_keep_negative_components = only_keep_negative_components,
            semantic_dict = cspa_semantic_dict,
            result_mean = result_mean,
            use_cuda = True,
            verbose = False,
            compute_s_sstar_dict = False,
        )

        kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean().item()
        kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean().item()

        diff_of_loss_ablated_to_orig = (cspa_results_qk_ov["loss_ablated"] - cspa_results_qk_ov["loss"])
        squared_loss_diff = (diff_of_loss_ablated_to_orig**2).mean().item()
        normed_loss_diff = diff_of_loss_ablated_to_orig.abs().mean().item()
        non_normed_loss_diff = diff_of_loss_ablated_to_orig.mean().item()

        diff_of_loss_cspa_to_orig = (cspa_results_qk_ov["loss_cspa"] - cspa_results_qk_ov["loss"])
        normed_cspa_loss_diff = diff_of_loss_cspa_to_orig.abs().mean().item() 
        squared_cspa_loss_diff = (diff_of_loss_cspa_to_orig**2).mean().item()
        non_normed_cspa_loss_diff = diff_of_loss_cspa_to_orig.mean().item()

        kl_performance_explained = 1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig
        kl_results[i, layer - 8, head] = kl_performance_explained

        normed_loss_performance_explained = 1 - normed_cspa_loss_diff / normed_loss_diff
        normed_loss_results[i, layer - 8, head] = normed_loss_performance_explained

        squared_loss_performance_explained = 1 - squared_cspa_loss_diff / squared_loss_diff
        squared_loss_results[i, layer - 8, head] = squared_loss_performance_explained

        non_normed_loss_performance_explained = 1 - non_normed_cspa_loss_diff / non_normed_loss_diff
        non_normed_loss_results[i, layer - 8, head] = non_normed_loss_performance_explained

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

In [22]:
for results_name, results in zip(["KL", "Net effect on loss", "Absolute difference in loss", "Squared effect on loss"], [kl_results, non_normed_loss_results, normed_loss_results, squared_loss_results], strict=True):
    imshow(
        results,
        facet_col = 0,
        facet_labels = ["Only keep negative components", "Keep negative and positive components"],
        title = f"{results_name} performance of head explained by CSPA",
        width = 1800, 
        height = 450,
        labels = {"x": "Head", "y": "Layer"},
        y = [str(i) for i in [8, 9, 10, 11]],
        text_auto = ".2f",
        range_color=[0,1],
        color_continuous_scale="Blues",
    )

In [None]:
imshow(
    all_results,
    facet_col = 0,
    facet_labels = ["Only keep negative components", "Keep negative and positive components"],
    title = "Performance of head explained by copy suppression (left = negative, right = no restriction)",
    width = 1400, 
    height = 400,
    labels = {"x": "Head", "y": "Layer"},
    y = [str(i) for i in [8, 9, 10, 11]],
    text_auto = ".2f",
    border = True,
    # color_range = [0, 1],
)

### Umm ... what?

In [9]:
cspa_results_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:80, :200], # [:50],
    max_batch_size = 2, # 50,
    negative_head = (10, 1),
    interventions = ["qk", "ov"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    only_keep_negative_components = False,
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = False,
    keep_self_attn = False,
)

# fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
fig_loss_line = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
fig_loss_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()

print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
ma_max = fig_loss_line_kl.data[0].x[-1]
cspa_max = fig_loss_line_kl.data[0].y[-1]
print(f"Most extreme quantile: fraction explained = 1 - ({cspa_max:.3f}/{ma_max:.3f}) = {1 - cspa_max/ma_max:.3f}")

NameError: name 'get_cspa_results_batched' is not defined

### The actual code which appears on the dedicated streamlit page:

In [10]:
cspa_results_qk_ov, s_sstar_pairs_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS, # [:50],
    max_batch_size = 60, # 50,
    negative_head = (10, 7),
    interventions = ["qk", "ov"],
    K_unembeddings = 5,
    K_semantic = 1,
    only_keep_negative_components = True,
    semantic_dict = cspa_semantic_dict,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = True,
)
# TODO - figure out where the bottleneck is via line profiler. I thought it was projections, but now it seems like this is not the case
# Seems like it's this func: get_top_predicted_semantically_similar_tokens
# %load_ext line_profiler
# %lprun -f func func(arg, kwarg=kwarg)

fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
fig_loss_line = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
fig_loss_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()
print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")

NameError: name 'get_cspa_results_batched' is not defined

# Adding CSPA to the Streamlit page ("Browse Examples")

This code adds the CSPA plots to the HTML plots for the Streamlit page. It creates a 5th tab called `CSPA`, and adds to the logit and DLA plots in the second tab (the latter is mainly for our use, while we're iterating on and improving the CSPA code).

I've added to this code in a pretty janky way, so that it can show more than one CSPA plot stacked on top of each other.

In [11]:
cspa_results, s_sstar_pairs = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:48, :61], # [:50],
    max_batch_size = 2, # 50,
    negative_head = (10, 1),
    interventions = ["qk", "ov"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    only_keep_negative_components = False,
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = True,
    return_dla = True,
    return_logits = True,
    keep_self_attn = True,
)

NameError: name 'get_cspa_results_batched' is not defined

In [12]:
add_cspa_to_streamlit_page(
    cspa_results = cspa_results,
    s_sstar_pairs = s_sstar_pairs,
    html_plots_filename = f"GZIP_HTML_PLOTS_b48_s61.pkl",
    data_str_toks_parsed = [s[:61] for s in DATA_STR_TOKS_PARSED[:48]],
    toks_for_doing_DLA = DATA_TOKS[:48, :61],
    model = model,
    verbose = True,
    # test_idx = 36,
)

NameError: name 'add_cspa_to_streamlit_page' is not defined

In [13]:
# b = 51
# add_cspa_to_streamlit_page(
#     cspa_results = cspa_results,
#     s_sstar_pairs = s_sstar_pairs,
#     html_plots_filename = f"GZIP_HTML_PLOTS_b{b}_s61.pkl",
#     data_str_toks_parsed = DATA_STR_TOKS_PARSED,
#     toks_for_doing_DLA = DATA_TOKS,
#     model = model,
#     verbose = True,
#     # test_idx = 32,
# )

# b = 200
# add_cspa_to_streamlit_page(
#     cspa_results = {"k=4": cspa_results, "k=1": cspa_results_1},
#     s_sstar_pairs = {"k=4": s_sstar_pairs, "k=1": s_sstar_pairs_1},
#     html_plots_filename = f"GZIP_HTML_PLOTS_b{b}_s61.pkl",
#     data_str_toks_parsed = DATA_STR_TOKS_PARSED,
#     toks_for_doing_DLA = DATA_TOKS,
#     model = model,
#     verbose = True,
#     # test_idx = 32,
# )

In [14]:
def cos_sim_of_toks(
    toks1: List[str],
    toks2: List[str],
):
    U1 = model.W_U.T[model.to_tokens(toks1, prepend_bos=False).squeeze()]
    U2 = model.W_U.T[model.to_tokens(toks2, prepend_bos=False).squeeze()]

    if U1.ndim == 1: U1 = U1.unsqueeze(0)
    if U2.ndim == 1: U2 = U2.unsqueeze(0)

    U1_normed = U1 / t.norm(U1, dim=-1, keepdim=True)
    U2_normed = U2 / t.norm(U2, dim=-1, keepdim=True)

    imshow(
        U1_normed @ U2_normed.T,
        title = "Cosine similarity of unembeddings",
        x = toks2,
        y = toks1,
    )

cos_sim_of_toks(
    [" stuff"],
    [" devices", " phones", " screens", " device", " phone", " Android"]
)

NameError: name 'List' is not defined

# Testing the code for "love and war"

In [None]:
result_mean_as_tensor = t.load(ST_HTML_PATH / "result_mean.pt")
result_mean = {(10, 7): result_mean_as_tensor[0], (11, 10): result_mean_as_tensor[1]}

prompt = "I picked up the first box. I picked up the second box. I picked up the third and final box."
toks = model.to_tokens(prompt)
str_toks = model.to_str_tokens(toks)
if isinstance(str_toks[0], str): str_toks = [str_toks]
# Parse the string tokens for printing
str_toks_parsed = [list(map(parse_str_tok_for_printing, s)) for s in str_toks]

model_results = get_model_results(
    model,
    toks=toks,
    negative_heads=[(10, 7), (11, 10)],
    result_mean=result_mean,
    verbose=False
)
HTML_PLOTS_NEW = generate_4_html_plots(
    model=model,
    data_toks=toks,
    data_str_toks_parsed=str_toks_parsed,
    negative_heads=[(10, 7), (11, 10)],
    model_results=model_results,
    save_files=False,
)
cspa_results, s_sstar_pairs = get_cspa_results(
    model=model,
    toks=toks,
    negative_head=(10, 7), #  this currently doesn't do anything; it's always 10.7
    components_to_project=["o"],
    K_unembeddings=5,
    K_semantic=3,
    semantic_dict=cspa_semantic_dict,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    use_cuda=False,
    return_dla=True,
)
HTML_PLOTS_NEW = add_cspa_to_streamlit_page(
    cspa_results=cspa_results,
    s_sstar_pairs=s_sstar_pairs,
    data_str_toks_parsed=str_toks_parsed,
    model=model,
    HTML_PLOTS=HTML_PLOTS_NEW,
    toks_for_doing_DLA=toks,
    verbose=False,
    test_idx=0,
)