## Setup

This file stores experiments for both the old (filtering) and the new, experimental work on CSPA.

The setup and semantic similarity sections must be run first. The remaining sections can then be run independently of each other.

In [1]:
# %pip install pattern --no-dependencies
# %pip install nltk
# %pip install protobuf==3.20.0

# Note - I couldn't figure out how to fix local imports for a while. Solution ended up being to make sure that this version of `transformer_lens` is before my libraries in `sys.path` (hence I'm inserting it at position zero in the code below).
import os, sys
from pathlib import Path
p = Path(r"/home/ubuntu/SERI-MATS-2023-Streamlit-pages")
if os.path.exists(str_p := str(p.resolve())):
    os.chdir(str_p)
    if str_p not in sys.path:
        sys.path.insert(0, str_p)  # position zero, so this version of transformer_lens is found first

from transformer_lens.cautils.notebook import *
t.set_grad_enabled(False)

from transformer_lens.rs.callum2.cspa.cspa_functions import (
    FUNCTION_STR_TOKS,
    get_cspa_results,
    get_cspa_results_batched,
    get_performance_recovered,
    OVProjectionConfig, 
    QKProjectionConfig,
)
from transformer_lens.rs.callum2.utils import (
    parse_str,
    parse_str_toks_for_printing,
    parse_str_tok_for_printing,
    ST_HTML_PATH,
    process_webtext,
)
from transformer_lens.rs.callum2.cspa.cspa_plots import (
    generate_scatter,
    generate_loss_based_scatter,
    show_graphs_and_summary_stats,
    add_cspa_to_streamlit_page,
)
from transformer_lens.rs.callum2.generate_st_html.model_results import (
    get_result_mean,
    get_model_results,
)
from transformer_lens.rs.callum2.generate_st_html.generate_html_funcs import (
    generate_4_html_plots,
    CSS,
)
from transformer_lens.rs.callum2.cspa.cspa_semantic_similarity import (
    get_equivalency_toks,
    get_related_words,
    concat_lists,
    make_list_correct_length,
    create_full_semantic_similarity_dict,
)

clear_output()

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda",
)
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)
clear_output()

In [3]:
BATCH_SIZE = 500 # 80 for viz
SEQ_LEN = 1000 # 61 for viz

current_batch_size = 17 # These are smaller values we use for visualization, since only these appear on Streamlit
current_seq_len = 61

NEGATIVE_HEADS = [(10, 7), (11, 10)]
DATA_TOKS, DATA_STR_TOKS_PARSED, indices = process_webtext(seed=6, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, model=model, verbose=True, return_indices=True)

Shape = torch.Size([500, 1000])

First prompt:
<|endoftext|>Going into Election Day, few industries seemed in worse shape than America's private prisons. Prison populations, which had been rising for decades, were falling. In 2014, Corrections Corporation of America, the biggest private-prison company in the U.S., lost its contract to run Idaho's largest prison, after lawsuits relating to understaffing and violence that had earned the place the nickname Gladiator School. There were press exposés of shocking conditions in the industry and signs of a policy shift toward it. In April, Hillary Clinton said, "We should end private prisons." In August, the Justice Department said that private federal prisons were less safe and less secure than government-run ones. The same month, the department announced that it would phase out the use of private prisons at the federal level. Although most of the private-prison industry operates on the state level (immigrant-detention centers are its other b

## Hardcoded semantic similarity

We don't use semantic similarity in the main text for CSPA, so this section can be skipped if that is all you're interested in.

This will take tokens, and return the tokens of semantically similar words. 

There are 3 categories of semantically similar tokens `s*` for any given token `s`:

1. Equivalence relations - this captures things like plurals, tokenization, capitalization.
2. Superstrings - for instance, if you have `"keley"` this gives you `" Berkeley"`.
3. Substrings - for instance, if you have `" Berkeley"` this gives you `"keley"`.

How does this work?

1. For each token, generate all (1), and then see which ones actually split into multiple tokens (these become (3)).
2. Iterate through this entire dict to generate (2)s for every token (this is basically like flipping the arrows in the other direction).
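
The two passes above can be sketched on a toy vocabulary as follows. This is only an illustration under stated assumptions: `build_semantic_dict`, `equivalents_of` and `tokenize` are hypothetical names standing in for the linguistic rules and the model tokenizer, and this is not the code behind `create_full_semantic_similarity_dict`.

```python
from collections import defaultdict


def build_semantic_dict(vocab, equivalents_of, tokenize):
    """Return {token: (bidirectional, superstrings, substrings)}.

    `equivalents_of` and `tokenize` are hypothetical callables standing in
    for the linguistic rules and the model tokenizer.
    """
    bidirectional, substrings = {}, {}

    # Pass 1: generate candidate equivalents for each token. Candidates that
    # are single tokens are category (1); candidates that split into several
    # tokens contribute their pieces as category (3).
    for tok in vocab:
        bidir, subs = [], []
        for cand in equivalents_of(tok):
            pieces = tokenize(cand)
            target = bidir if len(pieces) == 1 else subs
            for piece in pieces:
                if piece not in target:
                    target.append(piece)
        bidirectional[tok] = bidir
        substrings[tok] = subs

    # Pass 2: flip the substring arrows to obtain category (2).
    superstrings = defaultdict(list)
    for tok, subs in substrings.items():
        for piece in subs:
            if tok not in superstrings[piece]:
                superstrings[piece].append(tok)

    return {
        tok: (bidirectional[tok], list(superstrings[tok]), substrings[tok])
        for tok in vocab
    }


# Toy stand-in for the tokenizer: "Berkeley" without a leading space splits.
TOY_TOKENIZATION = {
    " Berkeley": [" Berkeley"],
    "Berkeley": ["Ber", "keley"],
    "Ber": ["Ber"],
    "keley": ["keley"],
}
toy_vocab = [" Berkeley", "Ber", "keley"]
toy_dict = build_semantic_dict(
    toy_vocab,
    equivalents_of=lambda tok: [tok, tok.strip()],
    tokenize=lambda s: TOY_TOKENIZATION[s],
)
print(toy_dict[" Berkeley"])
print(toy_dict["keley"])
```

In the toy run, `" Berkeley"` picks up `"Ber"` and `"keley"` as substrings in the first pass, and the second pass gives `"keley"` the superstring `" Berkeley"`.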

### Problems with this method

There are 4 problems with this method. I think (1) and (3) are the most problematic.

1. The pluralization isn't sufficiently flexible, and it'll miss out on categories of things, for example:
    * `" write"` and `" writing"` and `" writer"`
    * `" rental"` and `" rented"` and `" renting"`
    * (OV and QK circuits show that these do suppress each other)
    * Possible solution - more hardcoded rules?
2. Misses some important things which aren't semantically similar as we've defined it, e.g. `1984` and `1985` aren't semantically similar (OV and QK circuits show that they do suppress each other)
    * Possible solution - ???

However, this may not be worth optimizing further, because semantic similarity only marginally improves the results.

In [4]:
USE_SEMANTICITY = True

if USE_SEMANTICITY:
    from pattern.text.en import conjugate, PRESENT, PAST, FUTURE, SUBJUNCTIVE, INFINITIVE, PROGRESSIVE, PLURAL, SINGULAR
    from nltk.stem import WordNetLemmatizer
    import nltk
    MY_TENSES = [PRESENT, PAST, FUTURE, SUBJUNCTIVE, INFINITIVE, PROGRESSIVE]
    MY_NUMBERS = [PLURAL, SINGULAR]
    from nltk.corpus import wordnet
    nltk.download('wordnet')
    clear_output()

In [5]:
if USE_SEMANTICITY:
    with open(ST_HTML_PATH.parent.parent / "cspa/cspa_semantic_dict_full.pkl", "rb") as f:
        cspa_semantic_dict = pickle.load(f)

else:
    warnings.warn("Not using semanticity unlike old notebook versions!")
    cspa_semantic_dict = {}

In [6]:
display(HTML("<h2>Related words</h2>This doesn't include tokenization fragments; it's just linguistic."))

for word in ["Berkeley", "pier", "pie", "ring", "device", "robot", "w"]:
    try: print(get_related_words(word, model))
    except Exception: print(get_related_words(word, model)); print("(Worked on second try!)") # maybe because it downloads?

['Berkeley']
(Worked on second try!)
['pier']
['pie', 'pies', 'pier']
['ring', 'rings', 'ringing']
['device', 'devices']
['robot', 'robots', 'robotics', 'robotic']
['w']


In [7]:
display(HTML("<h2>Equivalency words</h2>These are the words which will be included in the semantic similarity cluster, during CSPA."))

for tok in [" Berkeley", " Pier", " pier", "pie", " pies", " ring", " device", " robot", "w"]:
    print(f"{tok!r:>10} -> {get_equivalency_toks(tok, model)}")

' Berkeley' -> ([' Berkeley'], ['Ber', 'keley'])
   ' Pier' -> ([' pier', ' Pier'], [])
   ' pier' -> ([' pier', ' Pier'], [])
     'pie' -> (['pie', 'Pie', ' pie', ' pies', ' Pie'], [])
   ' pies' -> (['pie', 'Pie', ' pies', ' pie', ' Pie'], [])
   ' ring' -> (['ring', 'rings', 'Ring', ' ring', ' rings', ' Ring', ' Rings'], [])
 ' device' -> (['device', 'devices', 'Device', ' device', ' devices', ' Device', ' Devices'], [])
  ' robot' -> ([' robot', ' robots', ' Robot', ' Robots'], [])
       'w' -> (['w', 'ws', 'W', 'Ws', 'WS', ' w', ' W', ' WS'], [])


In [8]:
if USE_SEMANTICITY:
    table = Table("Source token", "All semantically related", title="Semantic similarity: bidirectional, superstrings, substrings") #  "Top 3 related" in the middle

    str_toks = [" Berkeley", "keley", " University", " Mary", " Pier", " pier", "NY", " ring", " W", " device", " robot", " jump", " driver", " Cairo"]
    print_cutoff = 105 # 70
    def cutoff(s):
        if len(s_str := str(s)) >= print_cutoff: return s_str[:print_cutoff-4] + ' ...'
        else: return s_str

    for str_tok in str_toks:
        top3_sim = "\n".join(list(map(repr, concat_lists(cspa_semantic_dict[str_tok])[:3])))
        bidir, superstr, substr = cspa_semantic_dict[str_tok]
        all_sim = "\n".join([
            cutoff(f"{len(bidir)} bidirectional: {bidir}"),
            cutoff(f"{len(superstr)} super-tokens:  {superstr}"),
            cutoff(f"{len(substr)} sub-tokens:    {substr}"),
        ]) + "\n"
        table.add_row(repr(str_tok), all_sim) # top3_sim in the middle

    rprint(table)

In [9]:
# Finally, let's save a mean for later use...

result_mean = get_result_mean([(10, 7), (11, 10)], DATA_TOKS[:100, :], model, verbose=True)
# t.save(result_mean, f"/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/st_page/media/result_mean.pt")

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:01<00:00,  5.47it/s]


## Running Filtering CSPA code

### Settings: `K_u = 0.05`, no semantic similarity, batch size = 100

We should recover ~76.9% of the KL divergence, as reported in the paper draft.
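
As a reminder of what "recovered" means here, the sketch below computes one minus the ratio of mean KL divergences, the same ratio this notebook prints as "Performance explained" in other cells. The function name `performance_recovered` and the toy numbers are illustrative, and it is an assumption that `get_performance_recovered` computes exactly this quantity.

```python
def performance_recovered(kl_cspa, kl_ablated):
    """Fraction of the mean-ablation KL divergence that CSPA explains.

    Both arguments are flat lists of per-position KL divergences to the
    original model's output distribution.
    """
    mean_cspa = sum(kl_cspa) / len(kl_cspa)
    mean_ablated = sum(kl_ablated) / len(kl_ablated)
    return 1 - mean_cspa / mean_ablated


# Toy numbers (not results from the notebook): CSPA leaves a fifth of the
# damage that mean ablation causes, so roughly 0.8 is recovered.
toy_recovered = performance_recovered([0.01, 0.03], [0.05, 0.15])
print(f"{toy_recovered:.3f}")
```

A value of 1 would mean the CSPA-ablated model matches the original exactly, and a value of 0 would mean CSPA is no better than mean-ablating the head.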

In [None]:
# Empirically, as long as SEQ_LEN is large, a small BATCH_SIZE gives quite good estimates
QK_OV_BATCH_SIZE = 20
QK_OV_SEQ_LEN = 600

cspa_results_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:QK_OV_BATCH_SIZE, :QK_OV_SEQ_LEN],
    max_batch_size = 1, # 50,
    negative_head = (10, 7),
    interventions = ["ov", "qk"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = False,
    computation_device = "cpu", # device
)

In [None]:
print(
    "The performance recovered is...",
    get_performance_recovered(cspa_results_qk_ov), # 79.6
)

We can also show several nice visuals for every CSPA run...

In [None]:
show_graphs_and_summary_stats(cspa_results_qk_ov)

### KL and loss seem to be measuring fairly different things

(Though this section isn't very exciting).

Answering the question: "Is there a correlation between the absolute value of 10.7 effect on loss and 10.7's effect on increasing KL Divergence from the model?"

In [None]:
VIZ_BATCH_SIZE = QK_OV_BATCH_SIZE # change to current_batch_size for a smaller subset that can be visualized on Streamlit
VIZ_SEQ_LEN = SEQ_LEN - 1 # change to current_seq_len for a smaller subset that can be visualized on Streamlit
BATCH_INDICES = torch.arange(VIZ_BATCH_SIZE) # (change to indices to filter for correct vals on streamlit)

batch_indices = (BATCH_INDICES[:VIZ_BATCH_SIZE].unsqueeze(-1) + torch.zeros(VIZ_SEQ_LEN).unsqueeze(0))
seq_indices = (torch.zeros(VIZ_BATCH_SIZE).unsqueeze(-1) + torch.arange(VIZ_SEQ_LEN).unsqueeze(0))

my_dict = {
    "Mean Ablate 10.7 Loss - Normal Loss": cspa_results_qk_ov["loss_ablated"].clone().cpu() - cspa_results_qk_ov["loss"].clone().cpu(),
    "Mean Ablation KL Divergence to Model": cspa_results_qk_ov["kl_div_ablated_to_orig"][:, :-1].clone().cpu(),
    "CSPA Loss - Normal Loss": cspa_results_qk_ov["loss_cspa"].clone().cpu() - cspa_results_qk_ov["loss"].clone().cpu(),
    "CSPA KL": cspa_results_qk_ov["kl_div_cspa_to_orig"].clone().cpu()[:, :-1],
}

for k in list(my_dict.keys()):
    print(k)
    print(my_dict[k].shape)

    if len(my_dict[k].shape) == 2:
        my_dict[k] = my_dict[k][:VIZ_BATCH_SIZE, :VIZ_SEQ_LEN].flatten()
    else:
        my_dict[k] = my_dict[k]

    print(my_dict[k].shape)

my_dict["Batch Indices"] = batch_indices.flatten().float()
my_dict["Seq Indices"] = seq_indices.flatten().float()

print(my_dict["Batch Indices"].shape)
print(my_dict["Seq Indices"].shape)

df = pd.DataFrame(my_dict)

In [None]:
# import warnings; warnings.warn("Check out color when working")
# fig = go.Figure()
# fig.add_trace(

color = "CSPA KL"
px.scatter(
    df,
    x = "Mean Ablate 10.7 Loss - Normal Loss",
    y = "Mean Ablation KL Divergence to Model",
    hover_data = ["Batch Indices", "Seq Indices"],
    color = color,
    color_continuous_scale = "Blues" if "KL" in color else "RdBu_r",
    color_continuous_midpoint=0.0,
    range_color=((0.0, 0.09) if "KL" in color else None),
).update_traces(
    marker=dict(
        line=dict(
            color='black', # Border color
            width=0.5  # Border width
        )
    )
).show()

Answer: the correlation is pretty weak. The good news is that there are very few cases where change in loss is ~0 and KL is large.

The line of best fit with no constant term (fitted after taking absolute values) is, to 3 decimal places,

`Change in KL = 0.067 * (Absolute change in loss)`

It is nice that no constant term is needed, though R^2 = 0.34 is low (this isn't surprising, since the two quantities measure different things).
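
For reference, a least-squares line with no constant term has a closed form. The sketch below is an illustration on toy data rather than the code used for the numbers above. In particular, the R^2 convention shown (one minus residual over centred total sum of squares) is an assumption, and other conventions exist for fits through the origin.

```python
def fit_through_origin(xs, ys):
    """Least-squares fit of y = slope * x, with no constant term.

    Returns (slope, r_squared). r_squared uses the centred total sum of
    squares, which is one of several conventions for no-intercept fits.
    """
    sxx = sum(x * x for x in xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    slope = sxy / sxx

    mean_y = sum(ys) / len(ys)
    ss_res = sum((y - slope * x) ** 2 for x, y in zip(xs, ys))
    ss_tot = sum((y - mean_y) ** 2 for y in ys)
    return slope, 1 - ss_res / ss_tot


# Toy data lying exactly on y = 2x, so the fit is perfect.
toy_slope, toy_r2 = fit_through_origin([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(toy_slope, toy_r2)
```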

With the following heatmap, we can see that the spread of losses is indeed a fair bit larger than the spread of KLs. None of this is wildly exciting, though.
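
A log-density grid of this kind can be built in a single pass over the points. The sketch below is a hypothetical helper (`log_density_grid` is not part of the codebase) that assigns each point to exactly one bin and marks empty bins as `-inf`. It differs slightly from a version that tests closed intervals per bin, where a point sitting exactly on a bin edge is counted twice.

```python
import math


def log_density_grid(xs, ys, x_range, y_range, n_bins):
    """Log of the fraction of all points that fall in each (x, y) bin.

    Points outside the ranges are skipped but still count towards the
    total. Empty bins get -inf.
    """
    x_min, x_max = x_range
    y_min, y_max = y_range
    counts = [[0] * n_bins for _ in range(n_bins)]

    for x, y in zip(xs, ys):
        if not (x_min <= x <= x_max and y_min <= y <= y_max):
            continue
        # Half-open bins, with the top edge folded into the last bin.
        i = min(int((x - x_min) / (x_max - x_min) * n_bins), n_bins - 1)
        j = min(int((y - y_min) / (y_max - y_min) * n_bins), n_bins - 1)
        counts[i][j] += 1

    total = len(xs)
    return [
        [math.log(c / total) if c > 0 else float("-inf") for c in row]
        for row in counts
    ]


# Toy data: two points in the lower-left bin, one in the upper-right bin.
toy_grid = log_density_grid(
    xs=[0.1, 0.1, 0.9],
    ys=[0.1, 0.1, 0.9],
    x_range=(0.0, 1.0),
    y_range=(0.0, 1.0),
    n_bins=2,
)
print(toy_grid)
```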

In [None]:
Q = 50

x_std = df['CSPA Loss - Normal Loss'].std()
y_std = df['CSPA KL'].std()

x_min = - x_std # df['CSPA Loss - Normal Loss'].min()
x_max = x_std # df['CSPA Loss - Normal Loss'].max()
y_min = 0.0
y_max = y_std # df['CSPA KL'].max()

heatmap_vals = torch.zeros(Q, Q)

for x_quantile in range(Q):
    for y_quantile in range(Q):
        x_subset = df['CSPA Loss - Normal Loss'] >= x_min + (x_max - x_min) * x_quantile / Q 
        x_subset = x_subset & (df['CSPA Loss - Normal Loss'] <= x_min + (x_max - x_min) * (x_quantile+1) / Q)
        
        y_subset = df['CSPA KL'] >= y_min + (y_max - y_min) * y_quantile / Q 
        y_subset = y_subset & (df['CSPA KL'] <= y_min + (y_max - y_min) * (y_quantile+1) / Q)

        heatmap_size = (x_subset & y_subset).to_numpy().astype("int").sum() / len(df['CSPA KL'])
        heatmap_vals[x_quantile, y_quantile] = np.log(heatmap_size) # Can use log here...

fig = imshow(
    heatmap_vals[:, torch.arange(heatmap_vals.shape[0]-1, -1, -1)].T, # Does two things: makes axes the right way around, and in my opinion heatmaps x and y are the wrong way round
    title = f"Log Density of Points in CSPA Ranges",
    width = 500, 
    height = 500,
    labels = {"x": "CSPA Loss - Model Loss", "y": "CSPA KL"},
    x = [str(round(x_min + (x_max - x_min) * x_quantile / Q, 4)) for x_quantile in range(Q)],
    y = [str(round(y_min + (y_max - y_min) * y_quantile / Q, 5)) for y_quantile in range(Q)][::-1],
    # text_auto = ".2f",
    range_color=(heatmap_vals.min().item(), heatmap_vals.max().item()),
    color_continuous_scale="Blues",
    return_fig=True,
    color_continuous_midpoint=None,
)

# Set background color to white
fig.update_layout(
    paper_bgcolor='rgba(255,255,255,255)',
    plot_bgcolor='rgba(255,255,255,255)'
)
fig.show()

print(f"{df['CSPA Loss - Normal Loss'].std()=} {df['CSPA KL'].std()=}")


### `K_u = 0.05`, no semantic similarity, batch size = 500

(Skipped by default)

In [None]:
RUN_SLOW_CSPA = False

if RUN_SLOW_CSPA:
    cspa_results_qk_ov = get_cspa_results_batched(
        model = model,
        toks = DATA_TOKS[:, :], # [:50],
        max_batch_size = 5, # 50,
        negative_head = (10, 7),
        interventions = ["qk", "ov"],
        K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
        K_semantic = 1, # either 1 or up to 8 to capture all sem similar
        only_keep_negative_components = True,
        semantic_dict = cspa_semantic_dict,
        result_mean = result_mean,
        use_cuda = False,
        verbose = True,
        compute_s_sstar_dict = False,
    )

    # fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
    fig_loss_line = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
    fig_loss_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

    kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
    kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()

    print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
    print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
    print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
    print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
    ma_max = fig_loss_line_kl.data[0].x[-1]
    cspa_max = fig_loss_line_kl.data[0].y[-1]
    print(f"Most extreme quantile: fraction explained = 1 - ({cspa_max:.3f}/{ma_max:.3f}) = {1 - cspa_max/ma_max:.3f}")

# Investigating Projection CSPA

The hardest part of the projection CSPA setup is *likely* the query direction (for example, in IOI this caused us a massive headache).

So I tried to work on just doing query projections first (by setting `k_direction=None` and `ov_projection_config = None` in the below cells).

I can only get to 64% of the KL explained :( and it's with a pretty cursed setup. I tried lots of different setups and this one seems much better than the others:

1. For every key token, project the query token onto the subspace of $\mathbb{R}^{d_\text{model}}$ spanned by the unembedding vector for the key token (`q_direction="unembedding"`), and the other 7 semantically similar tokens to it (see `K_unembedding=8`)
2. After doing this projection, double this query input vector (`q_input_multiplier=2.0`)
3. After calculating attention scores, manually set all BOS attention scores to the exact value needed so the same attention to BOS is achieved (`mantain_bos_attention=True`)
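
The three steps above can be sketched in plain Python on toy vectors. This is a minimal illustration under stated assumptions, not the implementation behind `QKProjectionConfig`: `project_onto_span` and `bos_score_for_target_prob` are hypothetical helper names, and details such as layer norm and score scaling are ignored. The BOS step uses the fact that if the other scores have log-sum-exp $L$, a BOS score of $\log\frac{p}{1-p} + L$ gives BOS softmax probability exactly $p$.

```python
import math


def project_onto_span(vec, basis):
    """Orthogonally project `vec` onto the span of `basis` (Gram-Schmidt)."""
    ortho = []
    for b in basis:
        w = list(b)
        for u in ortho:
            coef = sum(wi * ui for wi, ui in zip(w, u))
            w = [wi - coef * ui for wi, ui in zip(w, u)]
        norm = math.sqrt(sum(wi * wi for wi in w))
        if norm > 1e-12:  # skip directions already in the span
            ortho.append([wi / norm for wi in w])

    proj = [0.0] * len(vec)
    for u in ortho:
        coef = sum(vi * ui for vi, ui in zip(vec, u))
        proj = [pi + coef * ui for pi, ui in zip(proj, u)]
    return proj


def bos_score_for_target_prob(other_scores, target_prob):
    """Score s such that softmax([s, *other_scores])[0] == target_prob."""
    m = max(other_scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in other_scores))
    return math.log(target_prob / (1 - target_prob)) + log_sum_exp


# Step 1: project a toy query onto the span of two toy "unembedding" vectors.
query = [1.0, 2.0, 3.0]
unembeddings = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]
projected = project_onto_span(query, unembeddings)

# Step 2: double the projected query input.
scaled_query = [2.0 * x for x in projected]

# Step 3: choose the BOS score that reproduces a target BOS attention of 0.5.
other_scores = [0.0, 1.0]
bos_score = bos_score_for_target_prob(other_scores, target_prob=0.5)
exps = [math.exp(s) for s in [bos_score, *other_scores]]
bos_prob = exps[0] / sum(exps)
print(projected, scaled_query, bos_prob)
```

In the toy run the two basis vectors span the first two coordinates, so the projection simply zeroes the third coordinate of the query.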

In [10]:
# Empirically, as long as SEQ_LEN large, small BATCH_SIZE gives quite good estimates (experiments about this deleted, too in the weeds)
Q_PROJECTION_BATCH_SIZE = 20
Q_PROJECTION_SEQ_LEN = 300

qk_projection_config = QKProjectionConfig(
    q_direction="layer9_heads",
    actually_project=False,
    k_direction=None,
    q_input_multiplier=1.0,
    use_same_scaling=False,
    mantain_bos_attention=False,
    model = model,
    save_scores = True,
    save_scaled_resid_pre = True,
)

# ov_projection_config = OVProjectionConfig()
ov_projection_config = None

In [11]:
cspa_results_q_projection = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:Q_PROJECTION_BATCH_SIZE, :Q_PROJECTION_SEQ_LEN],
    max_batch_size = 1,
    negative_head = (10, 7),
    interventions = [],
    qk_projection_config=qk_projection_config,
    ov_projection_config=ov_projection_config,
    K_unembeddings = 1.0,
    K_semantic = 8, # Be very careful making this big... very slow...
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = False,
    computation_device = "cpu",
)
gc.collect()
t.cuda.empty_cache()
clear_output()

In [12]:
print(
    "The performance recovered is...",
    get_performance_recovered(cspa_results_q_projection), # ~64
)

The performance recovered is... 0.6687323800128029


# How well does this metric do for other heads?

Note: the paper details results for the filtering version of CSPA. Currently this is a bit of a mess since I used it to try and develop a projection version of CSPA.

It turns out that we get some of the best *relative* results when we turn on the OV Projection, too.

We'll also just use `K_semantic = 1` here, as larger values slow things down a lot. As a data point, L10H7 with the 2x-query-unembedding-projection and `K_semantic=1` (i.e. this experimental setup, but without OV projection) gets 57% KL recovered.

In [None]:
result_mean = get_result_mean([
    (layer, head)
    for layer in [8, 9, 10, 11] for head in range(12)
], DATA_TOKS[:100, :], model, verbose=True)

qk_projection_config = QKProjectionConfig( # Not good, we really don't get strong results here...
    # q_direction="use_copying_as_query",
    q_direction="unembedding",
    q_input_multiplier=2.0,
    use_same_scaling=False,
    mantain_bos_attention=True,
    # projection_directions = "earlier_heads",
    model = model,
)

ov_projection_config = OVProjectionConfig()
# ov_projection_config = None

# qk_projection_config = QKProjectionConfig(
#     q_direction="unembedding",
#     k_direction=None,
#     q_input_multiplier=1.0,
#     use_same_scaling=False,
#     mantain_bos_attention=True,
#     # projection_directions = "earlier_heads",
#     model = model, 
# )
# result_mean = get_result_mean([(10, 7), (11, 10)], DATA_TOKS[:100, :], model, verbose=True)

clear_output()

In [None]:
kl_results = t.zeros(2, 4, 12).to(device)
loss_results = kl_results.clone()
normed_loss_results = kl_results.clone()
non_normed_loss_results = kl_results.clone()
squared_loss_results = kl_results.clone()

for i, only_keep_negative_components in enumerate([True, False]):

    for layer, head in tqdm(list(itertools.product([8, 9, 10, 11], range(12)))):

        # Run CSPA for this head, using the projection configs defined above
        cspa_results_qk_ov = get_cspa_results_batched(
            model = model,
            toks = DATA_TOKS[-50:, :200],
            max_batch_size = 10,
            negative_head = (layer, head),
            qk_projection_config = qk_projection_config,
            ov_projection_config = ov_projection_config,
            interventions = [],
            K_unembeddings = 1.0,
            K_semantic = 1,
            semantic_dict = cspa_semantic_dict,
            result_mean = result_mean,
            use_cuda = True,
            verbose = False,
            compute_s_sstar_dict = False,
            computation_device = None,
        )

        kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean().item()
        kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean().item()

        diff_of_loss_ablated_to_orig = (cspa_results_qk_ov["loss_ablated"] - cspa_results_qk_ov["loss"])
        squared_loss_diff = (diff_of_loss_ablated_to_orig**2).mean().item()
        normed_loss_diff = diff_of_loss_ablated_to_orig.abs().mean().item()
        non_normed_loss_diff = diff_of_loss_ablated_to_orig.mean().item()

        diff_of_loss_cspa_to_orig = (cspa_results_qk_ov["loss_cspa"] - cspa_results_qk_ov["loss"])
        normed_cspa_loss_diff = diff_of_loss_cspa_to_orig.abs().mean().item() 
        squared_cspa_loss_diff = (diff_of_loss_cspa_to_orig**2).mean().item()
        non_normed_cspa_loss_diff = diff_of_loss_cspa_to_orig.mean().item()

        kl_performance_explained = 1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig
        kl_results[i, layer - 8, head] = kl_performance_explained
        print(f"Head Layer {layer} {head} KL performance explained: {kl_performance_explained:.3f}")

        normed_loss_performance_explained = 1 - normed_cspa_loss_diff / normed_loss_diff
        normed_loss_results[i, layer - 8, head] = normed_loss_performance_explained

        squared_loss_performance_explained = 1 - squared_cspa_loss_diff / squared_loss_diff
        squared_loss_results[i, layer - 8, head] = squared_loss_performance_explained

        non_normed_loss_performance_explained = 1 - non_normed_cspa_loss_diff / non_normed_loss_diff
        non_normed_loss_results[i, layer - 8, head] = non_normed_loss_performance_explained

    break # We're currently not thinking much about the non-negative components

clear_output()

  0%|          | 0/48 [00:00<?, ?it/s]

Head Layer 8 0 KL performance explained: 0.036
Head Layer 8 1 KL performance explained: 0.062
Head Layer 8 2 KL performance explained: 0.050
Head Layer 8 3 KL performance explained: 0.064
Head Layer 8 4 KL performance explained: 0.008
Head Layer 8 5 KL performance explained: 0.003
Head Layer 8 6 KL performance explained: 0.042
Head Layer 8 7 KL performance explained: 0.014
Head Layer 8 8 KL performance explained: 0.115
Head Layer 8 9 KL performance explained: 0.017
Head Layer 8 10 KL performance explained: 0.088
Head Layer 8 11 KL performance explained: 0.177
Head Layer 9 0 KL performance explained: 0.153
Head Layer 9 1 KL performance explained: 0.131


In [None]:
print("Plots:")
for results_name, results in zip(["KL", "Net effect on loss", "Absolute difference in loss", "Squared effect on loss"], [kl_results, non_normed_loss_results, normed_loss_results, squared_loss_results], strict=True):
    imshow(
        results,
        facet_col = 0,
        facet_labels = ["Only keep negative components", "Keep negative and positive components"],
        title = f"{results_name} performance of head explained by CSPA",
        width = 1800, 
        height = 450,
        labels = {"x": "Head", "y": "Layer"},
        y = [str(i) for i in [8, 9, 10, 11]],
        text_auto = ".2f",
        range_color=[0,1],
        color_continuous_scale="Blues",
    )

### The actual code which appears on the dedicated streamlit page:

In [None]:
cspa_results_qk_ov, s_sstar_pairs_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS, # [:50],
    max_batch_size = 60, # 50,
    negative_head = (10, 7),
    interventions = ["qk", "ov"],
    K_unembeddings = 5,
    K_semantic = 1,
    only_keep_negative_components = True,
    semantic_dict = cspa_semantic_dict,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = True,
)
# TODO - figure out where the bottleneck is via line profiler. I thought it was projections, but now it seems like this is not the case
# Seems like it's this func: get_top_predicted_semantically_similar_tokens
# %load_ext line_profiler
# %lprun -f func func(arg, kwarg=kwarg)

fig_dict = generate_scatter(cspa_results_qk_ov, DATA_STR_TOKS_PARSED, batch_index_colors_to_highlight=[51, 300])
fig_loss_line = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="loss")
fig_loss_line_kl = generate_loss_based_scatter(cspa_results_qk_ov, nbins=200, values="kl-div")

kl_div_ablated_to_orig = cspa_results_qk_ov["kl_div_ablated_to_orig"].mean()
kl_div_cspa_to_orig = cspa_results_qk_ov["kl_div_cspa_to_orig"].mean()
print(f"Mean KL divergence from ablated to original: {kl_div_ablated_to_orig:.4f}")
print(f"Mean KL divergence from CSPA to original: {kl_div_cspa_to_orig:.4f}")
print(f"Ratio = {kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")
print(f"Performance explained = {1 - kl_div_cspa_to_orig / kl_div_ablated_to_orig:.3f}")

# Adding CSPA to the Streamlit page ("Browse Examples")

This code adds the CSPA plots to the HTML plots for the Streamlit page. It creates a 5th tab called `CSPA`, and adds to the logit and DLA plots in the second tab (the latter is mainly for our use, while we're iterating on and improving the CSPA code).

I've added to this code in a pretty janky way, so that it can show more than one CSPA plot stacked on top of each other.

In [None]:
cspa_results, s_sstar_pairs = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:48, :61], # [:50],
    max_batch_size = 2, # 50,
    negative_head = (10, 1),
    interventions = ["qk", "ov"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    only_keep_negative_components = False,
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = True,
    verbose = True,
    compute_s_sstar_dict = True,
    return_dla = True,
    return_logits = True,
    keep_self_attn = True,
)

In [None]:
add_cspa_to_streamlit_page(
    cspa_results = cspa_results,
    s_sstar_pairs = s_sstar_pairs,
    html_plots_filename = f"GZIP_HTML_PLOTS_b48_s61.pkl",
    data_str_toks_parsed = [s[:61] for s in DATA_STR_TOKS_PARSED[:48]],
    toks_for_doing_DLA = DATA_TOKS[:48, :61],
    model = model,
    verbose = True,
    # test_idx = 36,
)

In [None]:
def cos_sim_of_toks(
    toks1: List[str],
    toks2: List[str],
):
    U1 = model.W_U.T[model.to_tokens(toks1, prepend_bos=False).squeeze()]
    U2 = model.W_U.T[model.to_tokens(toks2, prepend_bos=False).squeeze()]

    if U1.ndim == 1: U1 = U1.unsqueeze(0)
    if U2.ndim == 1: U2 = U2.unsqueeze(0)

    U1_normed = U1 / t.norm(U1, dim=-1, keepdim=True)
    U2_normed = U2 / t.norm(U2, dim=-1, keepdim=True)

    imshow(
        U1_normed @ U2_normed.T,
        title = "Cosine similarity of unembeddings",
        x = toks2,
        y = toks1,
    )

cos_sim_of_toks(
    [" stuff"],
    [" devices", " phones", " screens", " device", " phone", " Android"]
)

# Testing the code for "love and war"

In [None]:
result_mean_as_tensor = t.load(ST_HTML_PATH / "result_mean.pt")
result_mean = {(10, 7): result_mean_as_tensor[0], (11, 10): result_mean_as_tensor[1]}

prompt = "I picked up the first box. I picked up the second box. I picked up the third and final box."
toks = model.to_tokens(prompt)
str_toks = model.to_str_tokens(toks)
if isinstance(str_toks[0], str): str_toks = [str_toks]
# Parse the string tokens for printing
str_toks_parsed = [list(map(parse_str_tok_for_printing, s)) for s in str_toks]

model_results = get_model_results(
    model,
    toks=toks,
    negative_heads=[(10, 7), (11, 10)],
    result_mean=result_mean,
    verbose=False
)
HTML_PLOTS_NEW = generate_4_html_plots(
    model=model,
    data_toks=toks,
    data_str_toks_parsed=str_toks_parsed,
    negative_heads=[(10, 7), (11, 10)],
    model_results=model_results,
    save_files=False,
)
cspa_results, s_sstar_pairs = get_cspa_results(
    model=model,
    toks=toks,
    negative_head=(10, 7), #  this currently doesn't do anything; it's always 10.7
    components_to_project=["o"],
    K_unembeddings=5,
    K_semantic=3,
    semantic_dict=cspa_semantic_dict,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    use_cuda=False,
    return_dla=True,
)
HTML_PLOTS_NEW = add_cspa_to_streamlit_page(
    cspa_results=cspa_results,
    s_sstar_pairs=s_sstar_pairs,
    data_str_toks_parsed=str_toks_parsed,
    model=model,
    HTML_PLOTS=HTML_PLOTS_NEW,
    toks_for_doing_DLA=toks,
    verbose=False,
    test_idx=0,
)

# Modular CSPA

In [None]:
OV_BATCH_SIZE = 50

cspa_results_qk = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:OV_BATCH_SIZE],
    max_batch_size = 1, # 50,
    negative_head = (10, 7),
    interventions = ["qk"],
    K_unembeddings = 0.05, # most interesting in range 3-8 (out of 80)
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    use_cuda = False,
    verbose = True,
    compute_s_sstar_dict = False,
)
clear_output() # Weird cell, it hogs space

In [None]:
show_graphs_and_summary_stats(cspa_results_qk)

In [None]:
print(
    get_performance_recovered(cspa_results_q_projection, verbose=True),
    Q_PROJECTION_BATCH_SIZE,
    Q_PROJECTION_SEQ_LEN,
)

# Investigating individual prompts

It seems important to dive down into why the projection CSPA only gets ~64%. These cells do this, though TODO Arthur move this out of here, it's better placed in an experimental file.

In [None]:
PLOT_BATCH_SIZE = 18
PLOT_SEQ_LEN = 50

df = pd.DataFrame({
    "cspa_kl": to_numpy(cspa_results_q_projection["kl_div_cspa_to_orig"][:PLOT_BATCH_SIZE, :PLOT_SEQ_LEN].flatten()),
    "ablated_kl": to_numpy(cspa_results_q_projection["kl_div_ablated_to_orig"][:PLOT_BATCH_SIZE, :PLOT_SEQ_LEN].flatten()),
    # "indices": sum(list(enumerate([[seq_idx for seq_idx in range(Q_PROJECTION_SEQ_LEN)] for _ in range(Q_PROJECTION_BATCH_SIZE)]))),
    "hover_data" : [str((indices[batch_idx], seq_idx, "which is", batch_idx)) for batch_idx in range(PLOT_BATCH_SIZE) for seq_idx in range(PLOT_SEQ_LEN)],
})

fig = px.scatter(
    df,
    x = "cspa_kl",
    y = "ablated_kl",
    hover_data = "hover_data",
)

fig.show()
# Save figure
# fig.write_image("fig_of_points.pdf")

# Studying some failures of projection, across several different runs. Maybe mean ablation means different things are differently destructive? Scary if so!
# (8, 49), (1, 40), (18, 10), (8, 35)
# (8, 49): Covered below; we pick " homeowners" and " rules" but the model attends to " Aurora" (and " Council"). Note this is on a comma. Maybe we suppress capital letters after that?
# (1, 40): We should attend to " private". But instead we attend to " prisons". Really weird though, as " private" is higher predicted, and OV circuit analysis didn't seem to help
# (18, 10): "with" -> "TPP" should be suppressed. But " Lee" is predicted more.
# (8, 35): "requiring" -> " rentals" is actually attended to. But " homeowners" is predicted more, and hence this ablation picks it more.
# (33, 42): model attends to " remove" -> " Blackberry", we attend to " remove" -> " ban". And Blackberry is the top prediction! Ban is in fact third
# (12, 26): we have half the amount of attention to ' Art' as we should. Context is "\n\n" -> " Art" and the only higher predicted thing is " This" (which is likely a function word)

In [None]:
# What are some examples of failures?

def print_attentions(batch_idx, seq_idx):
    print(sorted(list(enumerate(cspa_results_q_projection["pattern"][batch_idx, seq_idx, :seq_idx].tolist())), key=lambda x: x[1], reverse=True))    

# cspa_results_q_projection["pattern"][13, 14, :15] # indices[13] = 35. of -> Neil. We put 90% prob on Neil, but for some reason the model also suppresses "About"
print_attentions(3, 49) # indices[3] = 8. We pick " homeowners" and " rules" but the model attends to " Aurora" and " Council". Note this is on a comma. Maybe we suppress capital letters after that?
# print_attentions(12, 26) # indices[12] = 34. We put too much attention on "This" rather than full 50% on "Art". Failure due to not using semantically similar tokens
# print_attentions(11, 42) # indices[11] = 33. We should attend to Blackberry more. No idea how " meetings" is almost the same amount of attention...
print_attentions(0, 40) # ... [1] we should put weight on " prisons" not private...

print_attentions(5, 10)
print_attentions(3, 35)
print_attentions(11, 42) 
print_attentions(15, 37)
print_attentions(12, 26)

In [None]:
# Do a sanity check on some cases where we are doing well, to check that really this "surprising attention" bug is real.
print_attentions(2, 28) 
print("Yeah we absolutely nailed the ~80% attention here")

Question: what's going wrong? Are our attention scores too high on incorrect, or too low on correct?

First let's survey the cases where things go well...

In [None]:
cspa_results_q_projection.keys()

In [None]:
# Grab the cases where ablated KL >0.1 and CSPA KL <0.05

index_relevant = (cspa_results_q_projection["kl_div_ablated_to_orig"] > 0.1) &  (cspa_results_q_projection["kl_div_cspa_to_orig"] < 0.05)
indices_raw = np.nonzero(to_numpy(index_relevant.flatten()))[0]
indices = list(zip(
    indices_raw//index_relevant.shape[-1],
    indices_raw%index_relevant.shape[-1],
    strict=True,
))