# Query Composition Investigation

1. Find an example of query composition in a tiny model. 
2. Show that we can use an SAE to identify the features in the queries

Let's start with tiny stories. 

In [1]:
import sys 
sys.path.append("..")

from importlib import reload

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# turn torch grad tracking off
torch.set_grad_enabled(False)


model = HookedTransformer.from_pretrained(
    "tiny-stories-2L-33M",
    # center_unembed=True,
    # center_writing_weights=True,
    # fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
# model.set_use_split_qkv_input(True)
# model.set_use_attn_result(True)
import transformer_lens.evals as evals
evals.sanity_check(model)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


tensor(6.5646, device='mps:0')

Looking for multi-token induction head.
- Let's check for induction heads first. 

In [2]:
model.generate(
    "Once upon a time", max_new_tokens=20, stop_at_eos=False
)

  0%|          | 0/20 [00:00<?, ?it/s]

'Once upon a time there were two best friends who liked to tease each other. One was dressed in the most popular uniform'

In [3]:
import pandas as pd
import itertools
import numpy as np

def tensor_to_long_data_frame(tensor_result, dimension_names):
    assert len(tensor_result.shape) == len(
        dimension_names
    ), "The number of dimension names must match the number of dimensions in the tensor"

    tensor_2d = tensor_result.reshape(-1)
    df = pd.DataFrame(tensor_2d.detach().numpy(), columns=["Score"])

    indices = pd.MultiIndex.from_tuples(
        list(np.ndindex(tensor_result.shape)),
        names=dimension_names,
    )
    df.index = indices
    df.reset_index(inplace=True)
    return df

q_composition_scores = model.all_composition_scores(mode = "Q")
q_comp_df = tensor_to_long_data_frame(q_composition_scores.detach().cpu(), ["Layer1", "Head1", "Layer2", "Head2"])
q_comp_df.sort_values("Score", ascending=False).head(10).style.background_gradient(cmap="Blues")

Unnamed: 0,Layer1,Head1,Layer2,Head2,Score
219,0,6,1,11,0.100765
211,0,6,1,3,0.080608
218,0,6,1,10,0.066766
408,0,12,1,8,0.060902
208,0,6,1,0,0.060876
213,0,6,1,5,0.059937
401,0,12,1,1,0.059086
403,0,12,1,3,0.058306
146,0,4,1,2,0.056562
150,0,4,1,6,0.056216


Ok so layer 0 head 6 seems to produce stuff that layer 1 head 11,10,3 all point to. Let's analyse that one. 

In [4]:
def get_webtext(seed: int = 420, dataset="stas/openwebtext-10k", split="train[:1%]") -> List[str]:
    """Get 10,000 sentences from the OpenWebText dataset"""

    # Let's see some WEBTEXT
    train_dataset = load_dataset(dataset, split=split)
    # train_dataset = raw_dataset["train"]
    dataset = [train_dataset[i]["text"] for i in range(len(train_dataset))]

    # Shuffle the dataset (I don't want the Hitler thing being first so use a seeded shuffle)
    np.random.seed(seed)
    np.random.shuffle(dataset)

    return dataset

data = get_webtext(dataset="roneneldan/TinyStories")

In [37]:
path = "checkpoints/399ihu5z/final_sparse_autoencoder_tiny-stories-2L-33M_blocks.1.attn.hook_q_4096.pt"
sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)

In [38]:
from tqdm import tqdm

str_token_list = []
loss_list = []
ablated_loss_list = []
data = get_webtext(dataset="roneneldan/TinyStories")

import joseph.analysis as analysis
reload(analysis)

NUM_PROMPTS = 100
# MAX_PROMPT_LEN = 100
# BATCH_SIZE = 10
dataframe_list = []
with torch.no_grad():
    for i in tqdm(range(NUM_PROMPTS)):
        
        # Get Token Data
        prompt = data[i]
        # new_str = data[BATCH_SIZE * i: BATCH_SIZE * (i + 1)]
        token_df, _, _= analysis.eval_prompt(prompt, model=model, sparse_autoencoder=sparse_autoencoder)
        token_df["batch"] = i
        token_df["label"] = token_df["batch"].astype(str) + "/" + token_df["pos"].astype(str)
        dataframe_list.append(token_df)
        
df = pd.concat(dataframe_list)

print(df.shape)
print(df.columns)


100%|██████████| 100/100 [00:32<00:00,  3.12it/s]

(21253, 23)
Index(['str_tokens', 'unique_token', 'context', 'batch', 'pos', 'label',
       'loss', 'max_idx_pos', 'max_idx_tok', 'max_idx_tok_value',
       'ablated_loss', 'loss_diff', 'q_norm', 'rec_q_norm', 'mse_loss',
       'explained_variance', 'num_active_features', 'top_k_feature_acts',
       'top_k_features', 'rec_q_max_idx_pos', 'rec_q_max_idx_tok',
       'rec_q_max_idx_tok_value', 'kl_divergence'],
      dtype='object')





In [39]:
df.sort_values("loss_diff", ascending=False).head(10).style.background_gradient(cmap="Blues")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,loss,max_idx_pos,max_idx_tok,max_idx_tok_value,ablated_loss,loss_diff,q_norm,rec_q_norm,mse_loss,explained_variance,num_active_features,top_k_feature_acts,top_k_features,rec_q_max_idx_pos,rec_q_max_idx_tok,rec_q_max_idx_tok_value,kl_divergence
395,.,./395,"""Thank you so much|.| You",10,395,10/395,1.397654,319,us,0.038624,7.490783,6.093128,3.201245,2.687953,1.843834,0.179922,44,"[0.7483740448951721, 0.34911325573921204, 0.3086671531200409, 0.2846907377243042, 0.2662426829338074, 0.24799582362174988, 0.2359088659286499, 0.233854740858078, 0.22167110443115234, 0.1521887183189392]","[3968, 2, 3258, 3142, 958, 3788, 3782, 2858, 756, 5]",211,Sam,0.031844,0.055424
321,.,./321,"."" They hug|.| They",1,321,1/321,1.133966,206,They,0.039318,6.669726,5.53576,3.329991,2.744015,2.096605,0.189073,54,"[0.9538697004318237, 0.6843153834342957, 0.3306882679462433, 0.29911375045776367, 0.1904905140399933, 0.16835227608680725, 0.14630842208862305, 0.13909591734409332, 0.09914538264274597, 0.08981098234653473]","[2, 756, 2983, 5, 711, 262, 296, 3516, 3021, 3271]",217,.,0.023466,0.108552
203,.,./203,. A lady came over|.| She,9,203,9/203,0.182797,148,,0.027015,5.096091,4.913293,3.298565,2.904506,1.647383,0.151407,39,"[0.7583673596382141, 0.6457085609436035, 0.3912549316883087, 0.3522903025150299, 0.16491550207138062, 0.1455128788948059, 0.13906329870224, 0.13894009590148926, 0.11815348267555237, 0.10554268956184387]","[756, 581, 2, 2983, 2399, 470, 1726, 5, 1495, 1493]",169,.,0.027852,0.046494
290,,/290,"of the belt. | |""",9,290,9/290,0.001041,150,"""",0.043152,4.793242,4.792201,3.127081,2.444747,1.974672,0.201937,27,"[0.8442770838737488, 0.37939557433128357, 0.3112659454345703, 0.21592426300048828, 0.20538267493247986, 0.19362008571624756, 0.09512373805046082, 0.08685991168022156, 0.07054087519645691, 0.06990545988082886]","[913, 581, 2468, 3552, 2983, 2818, 125, 2878, 711, 1543]",172,"""",0.031056,0.090616
100,Jane,Jane/100,stuck! Suddenly| Jane| remembered,97,100,97/100,1.88253,8,girl,0.04247,6.303975,4.421445,2.614674,2.142208,1.378655,0.20166,15,"[0.3063059151172638, 0.2949275076389313, 0.27610546350479126, 0.20767706632614136, 0.12785053253173828, 0.12244760990142822, 0.09274059534072876, 0.057510972023010254, 0.029668346047401428, 0.029082447290420532]","[1772, 2200, 580, 2574, 1294, 1281, 3214, 2002, 2936, 3552]",8,girl,0.031227,0.038032
333,.,./333,"Wow, that sounds great|.| You",60,333,60/333,0.578132,157,kids,0.101074,4.947632,4.3695,3.212442,2.201267,2.903924,0.281394,26,"[0.616851270198822, 0.43535521626472473, 0.2790641784667969, 0.1985069215297699, 0.1687961220741272, 0.15100973844528198, 0.14544931054115295, 0.13322395086288452, 0.10267820954322815, 0.09304267168045044]","[3968, 2, 262, 3258, 756, 958, 3629, 2182, 3782, 1543]",284,Ben,0.040084,0.162477
320,is,is/320,". ""That| is| a",7,320,7/320,0.783517,255,okay,0.157569,5.108299,4.324782,3.363024,2.873428,1.678664,0.148424,44,"[0.8745102286338806, 0.432664155960083, 0.40569451451301575, 0.36599212884902954, 0.2366115152835846, 0.21437454223632812, 0.19153037667274475, 0.18864548206329346, 0.1856912076473236, 0.18377400934696198]","[3636, 863, 3258, 3697, 3301, 1377, 1044, 505, 125, 1375]",255,okay,0.127611,0.108044
150,small,small/150,. The| small| piece,66,150,66/150,1.459001,116,creature,0.184792,5.720655,4.261654,4.098626,3.816005,1.351573,0.080457,35,"[1.5256223678588867, 0.9273009896278381, 0.9246512651443481, 0.2940270006656647, 0.22126856446266174, 0.16826355457305908, 0.1631176769733429, 0.14906153082847595, 0.1234501451253891, 0.12090557813644409]","[2892, 2817, 3569, 2482, 1772, 1377, 3277, 3799, 2946, 3854]",116,creature,0.154727,0.051
129,it,it/129,make them fit. Finally| it| was,73,129,73/129,4.985309,18,they,0.039334,9.160392,4.175083,2.663731,1.784062,2.936062,0.413794,18,"[0.5609096884727478, 0.18027794361114502, 0.11692318320274353, 0.05471925437450409, 0.04773131012916565, 0.04629623889923096, 0.044071584939956665, 0.04313024878501892, 0.03637798875570297, 0.02500784397125244]","[453, 343, 3274, 2570, 1907, 2540, 1677, 188, 2589, 145]",18,they,0.029938,0.119695
190,realised,realised/190,"research, the little girl| realised| that",85,190,85/190,3.332173,140,-,0.082653,7.373481,4.041308,2.848092,2.448697,1.548691,0.190922,53,"[0.44071415066719055, 0.352009654045105, 0.24489927291870117, 0.23712041974067688, 0.2362094223499298, 0.21633923053741455, 0.20733359456062317, 0.20445820689201355, 0.19674621522426605, 0.1865248680114746]","[1952, 282, 343, 663, 1073, 3546, 1294, 2943, 3271, 262]",191,realised,0.076016,0.052216


In [7]:
px.scatter(
    df,
    x="max_idx_tok_value",
    y="loss_diff",
    hover_data=["label", "context", "max_idx_tok"],
    marginal_x="histogram",
    marginal_y="histogram",
)

Not clear what the head is doing. Possibly doing some sort of summarization?

In [8]:
px.histogram(df, x="loss_diff", nbins=100, title="Loss difference distribution")

In [9]:
from transformer_lens.utils import test_prompt

LAYER_IDX = 1; HEAD_IDX = 11
HEAD_HOOK_RESULT_NAME = utils.get_act_name("z", LAYER_IDX)
def hook_to_ablate_head(head_output: Float[Tensor, "batch seq_len head_idx d_head"], hook: HookPoint, head = (LAYER_IDX, HEAD_IDX)):
    assert head[0] == hook.layer(), f"{head[0]} != {hook.layer()}"
    assert ("result" in hook.name) or ("q" in hook.name) or ("z" in hook.name)
    head_output[:, :, head[1], :] = 0
    return head_output

batch = 1; pos = 153+1
prompt = data[batch]
print(prompt)
tokens = model.to_tokens(prompt)
loss, original_cache = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)


example_prompt = model.to_string(tokens[:,:pos+1])[0]
example_prompt_answer = model.to_string(tokens[:,pos+1])
test_prompt(example_prompt, example_prompt_answer, model, prepend_bos=False, prepend_space_to_answer=False)

with model.hooks(fwd_hooks=[(HEAD_HOOK_RESULT_NAME, hook_to_ablate_head)]):
    test_prompt(example_prompt, example_prompt_answer, model, prepend_bos=False, prepend_space_to_answer=False)
    logits_reconstructed, cache_reconstructed_res_stream = model.run_with_cache(example_prompt, prepend_bos=False)



Sara and Ben like to play pirates. They have a big box that is their ship. They have hats and swords and a map. The map shows them where to find treasure.

One day, they find a shiny coin in the sand. Sara is happy. She puts the coin in her pocket. Ben is jealous. He wants the coin too. He tries to take it from Sara. They fight.

"Stop!" says Mom. "You have to share. Pirates don't fight with their friends. They work together."

Mom gives them a cloth. She tells them to wipe the coin. It is dirty. They wipe the coin. They see something on it. It is a picture of a star.

"Look!" says Sara. "It is a clue. Maybe it is part of the treasure."

"Maybe you are right," says Ben. "Let's look at the map. Maybe the star is on it."

They look at the map. They see a star. It is near a big tree. They run to the tree. They dig under it. They find a box. It is full of more coins and jewels.

"Wow!" says Sara. "We found the treasure. We are the best pirates ever."

"Yes, we are," says Ben. "And we are t

Top 0th token. Logit: 23.19 Prob: 58.48% Token: | says|
Top 1th token. Logit: 22.63 Prob: 33.47% Token: | Mom|
Top 2th token. Logit: 20.38 Prob:  3.50% Token: | said|
Top 3th token. Logit: 19.71 Prob:  1.79% Token: | Sara|
Top 4th token. Logit: 19.10 Prob:  0.98% Token: | she|
Top 5th token. Logit: 19.09 Prob:  0.97% Token: | they|
Top 6th token. Logit: 17.23 Prob:  0.15% Token: | Ben|
Top 7th token. Logit: 16.82 Prob:  0.10% Token: | They|
Top 8th token. Logit: 16.78 Prob:  0.10% Token: | the|
Top 9th token. Logit: 16.71 Prob:  0.09% Token: | say|


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells', ' them', ' to

Top 0th token. Logit: 22.80 Prob: 61.99% Token: | Mom|
Top 1th token. Logit: 21.55 Prob: 17.72% Token: | Sara|
Top 2th token. Logit: 21.06 Prob: 10.92% Token: | said|
Top 3th token. Logit: 20.10 Prob:  4.19% Token: | says|
Top 4th token. Logit: 19.28 Prob:  1.84% Token: | she|
Top 5th token. Logit: 18.76 Prob:  1.10% Token: | Ben|
Top 6th token. Logit: 18.47 Prob:  0.82% Token: | they|
Top 7th token. Logit: 16.92 Prob:  0.17% Token: | the|
Top 8th token. Logit: 16.82 Prob:  0.16% Token: | Tim|
Top 9th token. Logit: 16.50 Prob:  0.11% Token: | a|


In [10]:
# from circuitsvis.attention import attention_patterns
logits, cache = model.run_with_cache(example_prompt, prepend_bos=False)
patterns = cache[f"blocks.{LAYER_IDX}.attn.hook_pattern"][0,HEAD_IDX].detach().cpu()
# attention_patterns(tokens=model.to_str_tokens(test_prompt, prepend_bos=False), attention=patterns)
token_df = make_token_df(model, tokens[:,:pos+1])
token_df["attn"] = patterns[-1,:]
# patterns = patterns.unsqueeze(0).repeat(2,1,1)
# plot_attn(patterns, token_df)
token_df.sort_values("attn", ascending=False).head(10).style.background_gradient(cmap="Blues")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,attn
90,says,says/90,"""Stop!""| says| Mom",0,90,0/90,0.777926
91,Mom,Mom/91,"""Stop!"" says| Mom|.",0,91,0/91,0.051272
115,them,them/115,"."" Mom gives| them| a",0,115,0/115,0.024795
32,them,them/32,map. The map shows| them| where,0,32,0/32,0.006741
92,.,./92,"""Stop!"" says Mom|.| """,0,92,0/92,0.005904
9,.,./9,Ben like to play pirates|.| They,0,9,0/9,0.00501
113,Mom,Mom/113,"work together."" |Mom| gives",0,113,0/113,0.004924
80,Sara,Sara/80,tries to take it from| Sara|.,0,80,0/80,0.004779
52,Sara,Sara/52,coin in the sand.| Sara| is,0,52,0/52,0.004319
111,,/111,". They work together.""| |",0,111,0/111,0.004209


- Example 1: This head was responsible for inhibiting !" and without this it is predict confidently. 
- Example 2: This head was responsible for increasing "."
- Example 3: This head was responsbile for inhibiting ",", because "Jane" is the next token and it's attending to "Jane" and "She"

Theory: The head deals with some features relating to punctuation.

Current hypothesis is that it's doing a bunch of stuff. 
- checking whether the sentence is over.

# Copy Suppression Analysis

## Load in Example

In [13]:

batch = 1; pos = 154
prompt = data[batch]
tokens = model.to_tokens(prompt)
loss, original_cache = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)

example_prompt = model.to_string(tokens[:,:pos+1])[0]
example_prompt_answer = model.to_string(tokens[:,pos+1])
example_prompt_tokens = tokens[:,:pos+1]
_, cache_original = model.run_with_cache(example_prompt_tokens, prepend_bos=False)


def test_prompt_with_sae(example_prompt, example_prompt_answer, model, sparse_autoencoder, cache_original):
    
    model.reset_hooks()
    utils.test_prompt(example_prompt, example_prompt_answer, model, prepend_space_to_answer=False, prepend_bos=False)

    with model.hooks(fwd_hooks=[(HEAD_HOOK_RESULT_NAME, hook_to_ablate_head)]):
        utils.test_prompt(example_prompt, example_prompt_answer, model, prepend_bos=False, prepend_space_to_answer=False)

    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
        cache_original[sparse_autoencoder.cfg.hook_point][0,HEAD_IDX]
    )
    def reconstr_query_hook(hook_in, hook, reconstructed_query=sae_out, head = HEAD_IDX):
        hook_in[:, head, :] = reconstructed_query
        return hook_in

    with model.hooks(fwd_hooks=[(sparse_autoencoder.cfg.hook_point, reconstr_query_hook)]):
        utils.test_prompt(example_prompt, example_prompt_answer, model,  prepend_bos=False, prepend_space_to_answer=False)
        
    
test_prompt_with_sae(example_prompt, example_prompt_answer, model, sparse_autoencoder, prepend_bos=False, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells', ' them', ' to

Top 0th token. Logit: 23.19 Prob: 58.48% Token: | says|
Top 1th token. Logit: 22.63 Prob: 33.47% Token: | Mom|
Top 2th token. Logit: 20.38 Prob:  3.50% Token: | said|
Top 3th token. Logit: 19.71 Prob:  1.79% Token: | Sara|
Top 4th token. Logit: 19.10 Prob:  0.98% Token: | she|
Top 5th token. Logit: 19.09 Prob:  0.97% Token: | they|
Top 6th token. Logit: 17.23 Prob:  0.15% Token: | Ben|
Top 7th token. Logit: 16.82 Prob:  0.10% Token: | They|
Top 8th token. Logit: 16.78 Prob:  0.10% Token: | the|
Top 9th token. Logit: 16.71 Prob:  0.09% Token: | say|


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells', ' them', ' to

Top 0th token. Logit: 22.80 Prob: 61.99% Token: | Mom|
Top 1th token. Logit: 21.55 Prob: 17.72% Token: | Sara|
Top 2th token. Logit: 21.06 Prob: 10.92% Token: | said|
Top 3th token. Logit: 20.10 Prob:  4.19% Token: | says|
Top 4th token. Logit: 19.28 Prob:  1.84% Token: | she|
Top 5th token. Logit: 18.76 Prob:  1.10% Token: | Ben|
Top 6th token. Logit: 18.47 Prob:  0.82% Token: | they|
Top 7th token. Logit: 16.92 Prob:  0.17% Token: | the|
Top 8th token. Logit: 16.82 Prob:  0.16% Token: | Tim|
Top 9th token. Logit: 16.50 Prob:  0.11% Token: | a|


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells', ' them', ' to

Top 0th token. Logit: 23.19 Prob: 58.48% Token: | says|
Top 1th token. Logit: 22.63 Prob: 33.47% Token: | Mom|
Top 2th token. Logit: 20.38 Prob:  3.50% Token: | said|
Top 3th token. Logit: 19.71 Prob:  1.79% Token: | Sara|
Top 4th token. Logit: 19.10 Prob:  0.98% Token: | she|
Top 5th token. Logit: 19.09 Prob:  0.97% Token: | they|
Top 6th token. Logit: 17.23 Prob:  0.15% Token: | Ben|
Top 7th token. Logit: 16.82 Prob:  0.10% Token: | They|
Top 8th token. Logit: 16.78 Prob:  0.10% Token: | the|
Top 9th token. Logit: 16.71 Prob:  0.09% Token: | say|


In [14]:
token_df, original_cache, cache_reconstructed_query = eval_prompt([example_prompt], model, sparse_autoencoder)
print(token_df.columns)
filter_cols = ["str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff", 
               "mse_loss", "num_active_features", "max_idx_tok", "max_idx_tok_value", "rec_q_max_idx_tok", "rec_q_max_idx_tok_value",
               "explained_variance", "kl_divergence", "top_k_features"]
token_df[filter_cols].tail(15).style.background_gradient(
    subset=["loss_diff", "mse_loss","explained_variance", "num_active_features", "kl_divergence"],
    cmap="coolwarm")

Index(['str_tokens', 'unique_token', 'context', 'batch', 'pos', 'label',
       'loss', 'max_idx_pos', 'max_idx_tok', 'max_idx_tok_value',
       'ablated_loss', 'loss_diff', 'q_norm', 'rec_q_norm', 'mse_loss',
       'explained_variance', 'num_active_features', 'top_k_feature_acts',
       'top_k_features', 'rec_q_max_idx_pos', 'rec_q_max_idx_tok',
       'rec_q_max_idx_tok_value', 'kl_divergence'],
      dtype='object')


Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,loss,loss_diff,mse_loss,num_active_features,max_idx_tok,max_idx_tok_value,rec_q_max_idx_tok,rec_q_max_idx_tok_value,explained_variance,kl_divergence,top_k_features
140,it,it/140,. They see something on| it|.,0,140,0/140,1.439762,-0.040973,2.60411,55,.,0.037805,.,0.032099,0.36293,0.073163,"[903, 1493, 3516, 2646, 2895, 1938, 452, 2851, 3493, 1502]"
141,.,./141,They see something on it|.| It,0,141,0/141,0.009396,0.121947,1.495598,27,.,0.057008,.,0.049687,0.167673,0.061235,"[581, 3553, 2296, 756, 3063, 2983, 2895, 3516, 1319, 1706]"
142,It,It/142,see something on it.| It| is,0,142,0/142,0.137488,0.055517,1.452695,33,is,0.251206,is,0.214643,0.157298,0.07783,"[453, 1493, 1536, 3462, 2880, 3585, 3274, 3089, 53, 3816]"
143,is,is/143,something on it. It| is| a,0,143,0/143,0.029472,0.046703,1.269697,30,.,0.02894,.,0.03019,0.154872,0.061699,"[2208, 3697, 2735, 1493, 2616, 2880, 2895, 373, 1401, 53]"
144,a,a/144,on it. It is| a| picture,0,144,0/144,0.394879,-0.144392,1.719596,26,.,0.035371,.,0.031113,0.213885,0.076202,"[1714, 2208, 3697, 3537, 2659, 1513, 3471, 2735, 933, 2895]"
145,picture,picture/145,it. It is a| picture| of,0,145,0/145,4.010713,0.281628,3.076799,39,that,0.081097,.,0.025287,0.412224,0.170973,"[3702, 903, 2295, 2145, 602, 1982, 4078, 3274, 3335, 2209]"
146,of,of/146,. It is a picture| of| a,0,146,0/146,0.601442,-0.410563,2.801122,39,them,0.091011,them,0.055142,0.278036,0.062854,"[1401, 3931, 153, 2532, 343, 2346, 3391, 3696, 511, 1493]"
147,a,a/147,It is a picture of| a| star,0,147,0/147,0.341861,-0.118769,2.181269,34,shiny,0.036138,Mom,0.02243,0.233663,0.073964,"[1714, 1536, 2880, 3537, 3471, 933, 3702, 3438, 259, 2895]"
148,star,star/148,is a picture of a| star|.,0,148,0/148,3.641025,0.831613,3.342089,26,star,0.035897,.,0.027027,0.330586,0.083565,"[903, 3702, 2851, 1493, 259, 1526, 1885, 1191, 651, 69]"
149,.,./149,a picture of a star|.|,0,149,0/149,0.02625,0.042761,1.314333,30,.,0.042101,.,0.049909,0.128427,0.043514,"[581, 2296, 3553, 2983, 3771, 3905, 1493, 2735, 3063, 2852]"


In [15]:
patterns = original_cache[f"blocks.{LAYER_IDX}.attn.hook_pattern"][0,HEAD_IDX].detach().cpu()
token_df_attn = make_token_df(model, example_prompt_tokens)
token_df_attn["original_attn"] = patterns[pos,:pos+1]
patterns = cache_reconstructed_query[f"blocks.{LAYER_IDX}.attn.hook_pattern"][0,HEAD_IDX].detach().cpu()
token_df_attn["reconstructed_attn"] = patterns[pos,:pos+1]
px.line(token_df_attn, y=["original_attn","reconstructed_attn"], hover_name="str_tokens", hover_data=["pos", "batch", "label"], title="Original vs Reconstructed attention").show()
token_df_attn.sort_values("original_attn", ascending=False).head(10).style.background_gradient(cmap="Blues")

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,original_attn,reconstructed_attn
90,says,says/90,"""Stop!""| says| Mom",0,90,0/90,0.441947,0.188723
91,Mom,Mom/91,"""Stop!"" says| Mom|.",0,91,0/91,0.04887,0.080494
99,Pirates,Pirates/99,You have to share.| Pirates| don,0,99,0/99,0.046306,0.054803
94,You,You/94,"!"" says Mom. ""|You| have",0,94,0/94,0.031684,0.033049
92,.,./92,"""Stop!"" says Mom|.| """,0,92,0/92,0.027994,0.041292
88,Stop,Stop/88,"fight. ""|Stop|!""",0,88,0/88,0.023832,0.034689
153,Look,Look/153,"star. ""|Look|",0,153,0/153,0.022916,0.028481
56,She,She/56,. Sara is happy.| She| puts,0,56,0/56,0.021877,0.016576
6,to,to/6,Sara and Ben like| to| play,0,6,0/6,0.0128,0.013222
93,"""","""/93","Stop!"" says Mom.| ""|You",0,93,0/93,0.011914,0.019943


## Get top 10 Tokens in the Unembed

In [16]:

example_prompt_tokens = tokens[:,:pos+1]
example_prompt = model.to_string(tokens[:,:pos+1])[0]
example_prompt_answer = model.to_string(tokens[:,pos+1])

logits_original, cache_original = model.run_with_cache(example_prompt_tokens, prepend_bos=False)
# test_prompt_with_sae(example_prompt, example_prompt_answer, model, sparse_autoencoder, prepend_bos=False, prepend_space_to_answer=False)

In [17]:
# def get_logit_scores_df(logit_scores, top_n_tokens=20):
#     vocab_scores, vocab_inds = torch.topk(logit_scores.detach().cpu(), top_n_tokens)
#     tmp = pd.DataFrame({"token": model.to_str_tokens(vocab_inds.int()), "score": vocab_scores}).sort_values("score", ascending=True)
    
#     return tmp
#     # fig = px.bar(tmp, y = "token", x = "score", orientation="h", title="Top 10 vocab scores for last token")
#     # # fonts much larger
#     # fig.update_layout(title_font_size=24, font_size=16, width=600)
#     # fig.show()

# df_original = get_logit_scores_df(logits_original[0,-1,:])
# df_ablation = get_logit_scores_df(logits_ablation[0,-1,:])
# df_reconstructed = get_logit_scores_df(logits_reconstructed[0,-1,:])

# # merge them 
# df_original["type"] = "original"
# df_ablation["type"] = "ablation"
# df_reconstructed["type"] = "reconstructed"
# df = pd.concat([df_original, df_ablation])
# df


# wide_df =  df.pivot(index="token", columns="type", values="score").sort_values("original", ascending=False)
# wide_df["diff"] = wide_df["original"] - wide_df["ablation"]
# wide_df = wide_df.sort_values("diff", ascending=False).head(10)
# wide_df.style.background_gradient(cmap="RdBu", subset=["diff"]) # the head is dropping the logits of a bunch of tokens.

In [18]:
# # lets check if this is direct.
# decomp, labels = cache.get_full_resid_decomposition(return_labels=True, mlp_input=True, layer=LAYER_IDX, expand_neurons=False)
# # get the decomp which matches label "L1H11"
# decomp = decomp[labels.index("L1H11")][0,-1]
# effect_on_logits = decomp @ model.W_U
# effect_on_logits = get_logit_scores_df(effect_on_logits,50257)
# effect_on_logits["rank"]= effect_on_logits["score"].rank(ascending=False)
# # display(effect_on_logits.sort_values("score", ascending=True).head(20).style.background_gradient(cmap="Blues"))
# wide_df = wide_df.join(effect_on_logits.set_index("token"))#.style.background_gradient(cmap="RdBu", subset=["score", "diff", "rank"])
# px.scatter(wide_df, x="score", y="diff", hover_name=wide_df.index, hover_data=["rank"], title="Effect on logits vs original score")

So the effect of the head is not direct! This means that my prior on "the real thing a head is doing is not evident through the unembed projection alone" has been updated up. We should look at another example though, this may be an exception.

## Understanding Features:

So far we've found examples of where the head is important and looked at the features firing, attention pattern and direct/indirect effects of ablating the head on token probabilities. 

However, we'd like to decompose the direct and indirect effects into features individually, rather than as a group.



In [19]:
batch = 1; pos = 154
prompt = data[batch]
tokens = model.to_tokens(prompt)
example_prompt = model.to_string(tokens[:,:pos+1])[0]
example_prompt_answer = model.to_string(tokens[:,pos+1])
example_prompt_tokens = tokens[:,:pos+1]
_, cache_original = model.run_with_cache(tokens[:,:pos+1], prepend_bos=False)

model.reset_hooks()
token_df, original_cache, cache_reconstructed_query = eval_prompt([example_prompt], model, sparse_autoencoder)
print(token_df.columns)
filter_cols = ["str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff", 
               "mse_loss", "num_active_features", "max_idx_tok", "max_idx_tok_value", "rec_q_max_idx_tok", "rec_q_max_idx_tok_value",
               "explained_variance", "kl_divergence", "top_k_features"]
token_df[filter_cols].tail(15).style.background_gradient(
    subset=["loss_diff", "mse_loss", "num_active_features", "kl_divergence"],
    cmap="coolwarm")



Index(['str_tokens', 'unique_token', 'context', 'batch', 'pos', 'label',
       'loss', 'max_idx_pos', 'max_idx_tok', 'max_idx_tok_value',
       'ablated_loss', 'loss_diff', 'q_norm', 'rec_q_norm', 'mse_loss',
       'explained_variance', 'num_active_features', 'top_k_feature_acts',
       'top_k_features', 'rec_q_max_idx_pos', 'rec_q_max_idx_tok',
       'rec_q_max_idx_tok_value', 'kl_divergence'],
      dtype='object')


Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,loss,loss_diff,mse_loss,num_active_features,max_idx_tok,max_idx_tok_value,rec_q_max_idx_tok,rec_q_max_idx_tok_value,explained_variance,kl_divergence,top_k_features
140,it,it/140,. They see something on| it|.,0,140,0/140,1.439762,-0.040973,2.60411,55,.,0.037805,.,0.032099,0.36293,0.073163,"[903, 1493, 3516, 2646, 2895, 1938, 452, 2851, 3493, 1502]"
141,.,./141,They see something on it|.| It,0,141,0/141,0.009396,0.121947,1.495598,27,.,0.057008,.,0.049687,0.167673,0.061235,"[581, 3553, 2296, 756, 3063, 2983, 2895, 3516, 1319, 1706]"
142,It,It/142,see something on it.| It| is,0,142,0/142,0.137488,0.055517,1.452695,33,is,0.251206,is,0.214643,0.157298,0.07783,"[453, 1493, 1536, 3462, 2880, 3585, 3274, 3089, 53, 3816]"
143,is,is/143,something on it. It| is| a,0,143,0/143,0.029472,0.046703,1.269697,30,.,0.02894,.,0.03019,0.154872,0.061699,"[2208, 3697, 2735, 1493, 2616, 2880, 2895, 373, 1401, 53]"
144,a,a/144,on it. It is| a| picture,0,144,0/144,0.394879,-0.144392,1.719596,26,.,0.035371,.,0.031113,0.213885,0.076202,"[1714, 2208, 3697, 3537, 2659, 1513, 3471, 2735, 933, 2895]"
145,picture,picture/145,it. It is a| picture| of,0,145,0/145,4.010713,0.281628,3.076799,39,that,0.081097,.,0.025287,0.412224,0.170973,"[3702, 903, 2295, 2145, 602, 1982, 4078, 3274, 3335, 2209]"
146,of,of/146,. It is a picture| of| a,0,146,0/146,0.601442,-0.410563,2.801122,39,them,0.091011,them,0.055142,0.278036,0.062854,"[1401, 3931, 153, 2532, 343, 2346, 3391, 3696, 511, 1493]"
147,a,a/147,It is a picture of| a| star,0,147,0/147,0.341861,-0.118769,2.181269,34,shiny,0.036138,Mom,0.02243,0.233663,0.073964,"[1714, 1536, 2880, 3537, 3471, 933, 3702, 3438, 259, 2895]"
148,star,star/148,is a picture of a| star|.,0,148,0/148,3.641025,0.831613,3.342089,26,star,0.035897,.,0.027027,0.330586,0.083565,"[903, 3702, 2851, 1493, 259, 1526, 1885, 1191, 651, 69]"
149,.,./149,a picture of a star|.|,0,149,0/149,0.02625,0.042761,1.314333,30,.,0.042101,.,0.049909,0.128427,0.043514,"[581, 2296, 3553, 2983, 3771, 3905, 1493, 2735, 3063, 2852]"


In [158]:
original_acts = original_cache[sparse_autoencoder.cfg.hook_point]
print(original_acts.shape)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
    original_cache[sparse_autoencoder.cfg.hook_point][:,-1,HEAD_IDX]
)
print(feature_acts.shape)
plot_line_with_top_10_labels(feature_acts[0])


torch.Size([1, 156, 16, 64])
torch.Size([1, 4096])


In [20]:
example_prompt

'<|endoftext|>Sara and Ben like to play pirates. They have a big box that is their ship. They have hats and swords and a map. The map shows them where to find treasure.\n\nOne day, they find a shiny coin in the sand. Sara is happy. She puts the coin in her pocket. Ben is jealous. He wants the coin too. He tries to take it from Sara. They fight.\n\n"Stop!" says Mom. "You have to share. Pirates don\'t fight with their friends. They work together."\n\nMom gives them a cloth. She tells them to wipe the coin. It is dirty. They wipe the coin. They see something on it. It is a picture of a star.\n\n"Look!"'

In [26]:
example_prompt

'<|endoftext|>Sara and Ben like to play pirates. They have a big box that is their ship. They have hats and swords and a map. The map shows them where to find treasure.\n\nOne day, they find a shiny coin in the sand. Sara is happy. She puts the coin in her pocket. Ben is jealous. He wants the coin too. He tries to take it from Sara. They fight.\n\n"Stop!" says Mom. "You have to share. Pirates don\'t fight with their friends. They work together."\n\nMom gives them a cloth. She tells them to wipe the coin. It is dirty. They wipe the coin. They see something on it. It is a picture of a star.\n\n"Look!"'

In [35]:
def test_get_attn_sae_feature_removal(example_prompt, 
                                      example_prompt_answer, 
                                      model, 
                                      sparse_autoencoder,
                                      features_to_remove=[0]):
    
    token_df, original_cache, cache_reconstructed_query = eval_prompt([example_prompt], model, sparse_autoencoder)

    original_acts = original_cache[sparse_autoencoder.cfg.hook_point]
    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
       original_acts[0,:,HEAD_IDX]
    )
    
    print(feature_acts.shape)

    def remove_feature_hook(hook_in, hook, head = HEAD_IDX, features_to_remove = features_to_remove):
        print(hook_in.shape)
        for feature_to_remove in features_to_remove:
            feature_dir = feature_acts[-1,feature_to_remove]*sparse_autoencoder.W_dec[feature_to_remove]
            hook_in[:, :, head] -= feature_dir
        return hook_in
    
    with model.hooks(fwd_hooks=[(sparse_autoencoder.cfg.hook_point, remove_feature_hook)]):
        _, cache_removed_feature = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)

    patterns = original_cache[f"blocks.{LAYER_IDX}.attn.hook_pattern"][0,HEAD_IDX].detach().cpu()
    attn_df = make_token_df(model, example_prompt_tokens)
    attn_df["original_attn"] = patterns[pos,:pos+1]
    patterns = cache_reconstructed_query[f"blocks.{LAYER_IDX}.attn.hook_pattern"][0,HEAD_IDX].detach().cpu()
    attn_df["reconstructed_attn"] = patterns[pos,:pos+1]
    patterns = cache_removed_feature[f"blocks.{LAYER_IDX}.attn.hook_pattern"][0,HEAD_IDX].detach().cpu()
    attn_df["ablated_feature_attn"] = patterns[pos,:pos+1]
    fig = px.line(attn_df, 
                  y=["original_attn","reconstructed_attn", "ablated_feature_attn"], 
                  hover_name="str_tokens", 
                  hover_data=["pos", "batch", "label"], 
                  title="Original vs Reconstructed attention")
    
    fig.show()

def test_prompt_with_sae_feature_removal(example_prompt, example_prompt_answer, model, sparse_autoencoder, features_to_remove=[0]):
    
    model.reset_hooks()
    utils.test_prompt(example_prompt, example_prompt_answer, model, prepend_space_to_answer=False, prepend_bos=False)

    with model.hooks(fwd_hooks=[(HEAD_HOOK_RESULT_NAME, hook_to_ablate_head)]):
        utils.test_prompt(example_prompt, example_prompt_answer, model, prepend_bos=False, prepend_space_to_answer=False)

    token_df, original_cache, cache_reconstructed_query = eval_prompt([example_prompt], model, sparse_autoencoder)
    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
        original_cache[sparse_autoencoder.cfg.hook_point][0,HEAD_IDX]
    )
    def reconstr_query_hook(hook_in, hook, reconstructed_query=sae_out, head = HEAD_IDX):
        hook_in[:, head, :] = reconstructed_query
        return hook_in

    def remove_feature_hook(hook_in, hook, head = HEAD_IDX, features_to_remove = features_to_remove):
        print(hook_in.shape)
        for feature_to_remove in features_to_remove:
            feature_dir = feature_acts[-1,feature_to_remove]*sparse_autoencoder.W_dec[feature_to_remove]
            hook_in[:, :, head] -= feature_dir
        return hook_in
    
    print("FEATURE REMOVED")
    with model.hooks(fwd_hooks=[(sparse_autoencoder.cfg.hook_point, remove_feature_hook)]):
        utils.test_prompt(example_prompt, example_prompt_answer, model,  prepend_bos=False, prepend_space_to_answer=False)


test_get_attn_sae_feature_removal(
    example_prompt, 
    example_prompt_answer, model, 
    sparse_autoencoder, features_to_remove=[1669, 3489])

test_prompt_with_sae_feature_removal(
    example_prompt, 
    example_prompt_answer, model, 
    sparse_autoencoder, features_to_remove=[1669, 3489])


torch.Size([156, 4096])
torch.Size([1, 346, 16, 64])


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells', ' them', ' to

Top 0th token. Logit: 23.19 Prob: 58.48% Token: | says|
Top 1th token. Logit: 22.63 Prob: 33.47% Token: | Mom|
Top 2th token. Logit: 20.38 Prob:  3.50% Token: | said|
Top 3th token. Logit: 19.71 Prob:  1.79% Token: | Sara|
Top 4th token. Logit: 19.10 Prob:  0.98% Token: | she|
Top 5th token. Logit: 19.09 Prob:  0.97% Token: | they|
Top 6th token. Logit: 17.23 Prob:  0.15% Token: | Ben|
Top 7th token. Logit: 16.82 Prob:  0.10% Token: | They|
Top 8th token. Logit: 16.78 Prob:  0.10% Token: | the|
Top 9th token. Logit: 16.71 Prob:  0.09% Token: | say|


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells', ' them', ' to

Top 0th token. Logit: 22.80 Prob: 61.99% Token: | Mom|
Top 1th token. Logit: 21.55 Prob: 17.72% Token: | Sara|
Top 2th token. Logit: 21.06 Prob: 10.92% Token: | said|
Top 3th token. Logit: 20.10 Prob:  4.19% Token: | says|
Top 4th token. Logit: 19.28 Prob:  1.84% Token: | she|
Top 5th token. Logit: 18.76 Prob:  1.10% Token: | Ben|
Top 6th token. Logit: 18.47 Prob:  0.82% Token: | they|
Top 7th token. Logit: 16.92 Prob:  0.17% Token: | the|
Top 8th token. Logit: 16.82 Prob:  0.16% Token: | Tim|
Top 9th token. Logit: 16.50 Prob:  0.11% Token: | a|


FEATURE REMOVED
Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' like', ' to', ' play', ' pirates', '.', ' They', ' have', ' a', ' big', ' box', ' that', ' is', ' their', ' ship', '.', ' They', ' have', ' hats', ' and', ' swords', ' and', ' a', ' map', '.', ' The', ' map', ' shows', ' them', ' where', ' to', ' find', ' treasure', '.', '\n', '\n', 'One', ' day', ',', ' they', ' find', ' a', ' shiny', ' coin', ' in', ' the', ' sand', '.', ' Sara', ' is', ' happy', '.', ' She', ' puts', ' the', ' coin', ' in', ' her', ' pocket', '.', ' Ben', ' is', ' jealous', '.', ' He', ' wants', ' the', ' coin', ' too', '.', ' He', ' tries', ' to', ' take', ' it', ' from', ' Sara', '.', ' They', ' fight', '.', '\n', '\n', '"', 'Stop', '!"', ' says', ' Mom', '.', ' "', 'You', ' have', ' to', ' share', '.', ' Pirates', ' don', "'t", ' fight', ' with', ' their', ' friends', '.', ' They', ' work', ' together', '."', '\n', '\n', 'Mom', ' gives', ' them', ' a', ' cloth', '.', ' She', ' tells

Top 0th token. Logit: 23.19 Prob: 58.48% Token: | says|
Top 1th token. Logit: 22.63 Prob: 33.47% Token: | Mom|
Top 2th token. Logit: 20.38 Prob:  3.50% Token: | said|
Top 3th token. Logit: 19.71 Prob:  1.79% Token: | Sara|
Top 4th token. Logit: 19.10 Prob:  0.98% Token: | she|
Top 5th token. Logit: 19.09 Prob:  0.97% Token: | they|
Top 6th token. Logit: 17.23 Prob:  0.15% Token: | Ben|
Top 7th token. Logit: 16.82 Prob:  0.10% Token: | They|
Top 8th token. Logit: 16.78 Prob:  0.10% Token: | the|
Top 9th token. Logit: 16.71 Prob:  0.09% Token: | say|


Seems like we have two features that for on ." ," !" (ends of speech quotes) and they are causally responsible for attention to "says". We can look at examples of where the feature fires and see if it always attends to the same token.

One idea is to use virtual weights. Another would be to look for examples of the feature firing

## Dynamic Analysis of Features:

In [86]:
tmp = df[df.top_k_features.apply(lambda x: (3489 in x) and (1669 in x))].head(30)

: 

In [51]:
print(df[df.top_k_features.apply(lambda x: 1669 in x)].shape[0])
print(df[df.top_k_features.apply(lambda x: 3489 in x)].shape[0])
print(df[df.top_k_features.apply(lambda x: (1669 in x) and (3489 not in x))].shape[0])
print(df[df.top_k_features.apply(lambda x: (1669 not in x) and (3489 in x))].shape[0])

146
157
37
48


In [44]:
tmp = df[df.top_k_features.apply(lambda x: 1669 in x)].head(30)
tmp[[]]

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,loss,max_idx_pos,max_idx_tok,max_idx_tok_value,...,rec_q_norm,mse_loss,explained_variance,num_active_features,top_k_feature_acts,top_k_features,rec_q_max_idx_pos,rec_q_max_idx_tok,rec_q_max_idx_tok_value,kl_divergence
88,"!""","!""/88",".\n\n""Stop|!""| says",1,88,1/88,0.892003,32,them,0.060124,...,3.512682,1.748639,0.101902,14,"[1.6116886138916016, 1.2327041625976562, 0.207...","[3489, 1669, 994, 3258, 393, 3438, 824, 2409, ...",9,.,0.056756,0.037437
153,"!""","!""/153",".\n\n""Look|!""| says",1,153,1/153,0.963728,90,says,0.777926,...,4.169718,1.889879,0.094401,33,"[1.8496007919311523, 1.5393149852752686, 0.298...","[3489, 1669, 1401, 3629, 3258, 343, 1352, 2393...",90,says,0.740146,0.016968
156,.,./156,"""Look!"" says Sara|.| """,1,156,1/156,0.173373,128,is,0.051997,...,3.147302,1.919359,0.165219,28,"[1.2310981750488281, 0.9541622400283813, 0.431...","[3129, 1352, 2, 722, 5, 2399, 756, 1692, 1706,...",155,says,0.039095,0.072422
170,".""",".""/170","is part of the treasure|.""|\n",1,170,1/170,0.757954,111,\n,0.216574,...,4.294739,1.783408,0.105181,25,"[1.0255876779556274, 0.930554211139679, 0.9057...","[1203, 3129, 3489, 581, 2852, 1669, 1352, 2409...",90,says,0.222505,0.124757
178,",""",",""/178","""Maybe you are right|,""| says",1,178,1/178,0.096985,155,says,0.531703,...,3.536173,2.220389,0.139323,25,"[1.5016510486602783, 1.0654377937316895, 0.288...","[1669, 3489, 3258, 913, 5, 1753, 1481, 3337, 7...",155,says,0.418748,0.053905
196,".""",".""/196","the star is on it|.""|\n",1,196,1/196,0.14216,172,\n,0.286551,...,4.162789,1.961582,0.120009,30,"[1.0957627296447754, 1.0553643703460693, 0.560...","[1203, 3129, 581, 1352, 3489, 2852, 1669, 2409...",172,\n,0.12231,0.233621
246,"!""","!""/246",".\n\n""Wow|!""| says",1,246,1/246,0.016797,155,says,0.380723,...,4.741034,2.317082,0.099115,53,"[1.93648362159729, 1.6393020153045654, 0.25626...","[1669, 3489, 343, 1401, 3629, 2431, 581, 1543,...",155,says,0.346406,0.021846
262,".""",".""/262","are the best pirates ever|.""|\n",1,262,1/262,0.295246,172,\n,0.15725,...,4.176085,3.145787,0.170236,33,"[0.9897174835205078, 0.9278141856193542, 0.849...","[1203, 3489, 3129, 2852, 581, 1669, 1352, 3968...",180,says,0.126596,0.163481
268,we,we/268,"\n\n""Yes,| we| are",1,268,1/268,0.020565,247,"!""",0.122394,...,3.461184,2.217702,0.162291,54,"[0.8943300247192383, 0.7542832493782043, 0.558...","[1753, 25, 2171, 1481, 742, 3258, 262, 3612, 1...",179,",""",0.09578,0.071633
270,",""",",""/270","""Yes, we are|,""| says",1,270,1/270,0.622071,248,says,0.257907,...,3.602607,2.071655,0.122039,34,"[1.5748071670532227, 0.9071919322013855, 0.307...","[1669, 3489, 913, 5, 3258, 1753, 711, 2818, 21...",248,says,0.250147,0.052607


In [57]:
df[df.top_k_features.apply(lambda x: (1669 in x) or (3489 in x))].head(30)

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,loss,max_idx_pos,max_idx_tok,max_idx_tok_value,...,mse_loss,explained_variance,num_active_features,top_k_feature_acts,top_k_features,rec_q_max_idx_pos,rec_q_max_idx_tok,rec_q_max_idx_tok_value,kl_divergence,tmp_key
88,"!""","!""/88",".\n\n""Stop|!""| says",1,88,1/88,0.892003,32,them,0.060124,...,1.748639,0.101902,14,"[1.6116886138916016, 1.2327041625976562, 0.207...","[3489, 1669, 994, 3258, 393, 3438, 824, 2409, ...",9,.,0.056756,0.037437,88
109,".""",".""/109","friends. They work together|.""|\n",1,109,1/109,1.986598,90,says,0.2491,...,2.094218,0.112022,23,"[1.4436824321746826, 1.2245793342590332, 0.387...","[1203, 3489, 2409, 3129, 2852, 581, 3553, 3441...",90,says,0.316181,0.067239,109
153,"!""","!""/153",".\n\n""Look|!""| says",1,153,1/153,0.963728,90,says,0.777926,...,1.889879,0.094401,33,"[1.8496007919311523, 1.5393149852752686, 0.298...","[3489, 1669, 1401, 3629, 3258, 343, 1352, 2393...",90,says,0.740146,0.016968,153
156,.,./156,"""Look!"" says Sara|.| """,1,156,1/156,0.173373,128,is,0.051997,...,1.919359,0.165219,28,"[1.2310981750488281, 0.9541622400283813, 0.431...","[3129, 1352, 2, 722, 5, 2399, 756, 1692, 1706,...",155,says,0.039095,0.072422,156
170,".""",".""/170","is part of the treasure|.""|\n",1,170,1/170,0.757954,111,\n,0.216574,...,1.783408,0.105181,25,"[1.0255876779556274, 0.930554211139679, 0.9057...","[1203, 3129, 3489, 581, 2852, 1669, 1352, 2409...",90,says,0.222505,0.124757,170
178,",""",",""/178","""Maybe you are right|,""| says",1,178,1/178,0.096985,155,says,0.531703,...,2.220389,0.139323,25,"[1.5016510486602783, 1.0654377937316895, 0.288...","[1669, 3489, 3258, 913, 5, 1753, 1481, 3337, 7...",155,says,0.418748,0.053905,178
196,".""",".""/196","the star is on it|.""|\n",1,196,1/196,0.14216,172,\n,0.286551,...,1.961582,0.120009,30,"[1.0957627296447754, 1.0553643703460693, 0.560...","[1203, 3129, 581, 1352, 3489, 2852, 1669, 2409...",172,\n,0.12231,0.233621,196
246,"!""","!""/246",".\n\n""Wow|!""| says",1,246,1/246,0.016797,155,says,0.380723,...,2.317082,0.099115,53,"[1.93648362159729, 1.6393020153045654, 0.25626...","[1669, 3489, 343, 1401, 3629, 2431, 581, 1543,...",155,says,0.346406,0.021846,246
262,".""",".""/262","are the best pirates ever|.""|\n",1,262,1/262,0.295246,172,\n,0.15725,...,3.145787,0.170236,33,"[0.9897174835205078, 0.9278141856193542, 0.849...","[1203, 3489, 3129, 2852, 581, 1669, 1352, 3968...",180,says,0.126596,0.163481,262
268,we,we/268,"\n\n""Yes,| we| are",1,268,1/268,0.020565,247,"!""",0.122394,...,2.217702,0.162291,54,"[0.8943300247192383, 0.7542832493782043, 0.558...","[1753, 25, 2171, 1481, 742, 3258, 262, 3612, 1...",179,",""",0.09578,0.071633,268


In [53]:
df['tmp_key'] = df.index  
df_exploded_acts = df.explode('top_k_feature_acts')
df_exploded_features = df.explode('top_k_features')
feature_wise_df = pd.merge(df_exploded_acts, df_exploded_features, left_on='tmp_key', right_on='tmp_key', suffixes=('_acts', '_features'))
feature_wise_df = feature_wise_df.drop(columns=['tmp_key', 'top_k_feature_acts_features', 'top_k_features_acts'])


In [56]:
feature_wise_df[feature_wise_df.top_k_features_features.isin([1669,3489])].sort_values("top_k_feature_acts_acts", ascending=False).head(10).style.background_gradient(cmap="Blues")

Unnamed: 0,str_tokens_acts,unique_token_acts,context_acts,batch_acts,pos_acts,label_acts,loss_acts,max_idx_pos_acts,max_idx_tok_acts,max_idx_tok_value_acts,ablated_loss_acts,loss_diff_acts,q_norm_acts,rec_q_norm_acts,mse_loss_acts,explained_variance_acts,num_active_features_acts,top_k_feature_acts_acts,rec_q_max_idx_pos_acts,rec_q_max_idx_tok_acts,rec_q_max_idx_tok_value_acts,kl_divergence_acts,str_tokens_features,unique_token_features,context_features,batch_features,pos_features,label_features,loss_features,max_idx_pos_features,max_idx_tok_features,max_idx_tok_value_features,ablated_loss_features,loss_diff_features,q_norm_features,rec_q_norm_features,mse_loss_features,explained_variance_features,num_active_features_features,top_k_features_features,rec_q_max_idx_pos_features,rec_q_max_idx_tok_features,rec_q_max_idx_tok_value_features,kl_divergence_features
168106340,and,and/209,. Lily| and| Sam,10,209,10/209,0.704781,171,Mom,0.047079,0.675915,-0.028866,3.965381,3.65225,1.217291,0.077415,36,2.420354,171,Mom,0.040244,0.033429,"!""","!""/209","Wow, look at this|!""| Jill",16,209,16/209,0.278681,139,.,0.078023,0.405624,0.126944,4.232041,3.538454,2.099673,0.117234,19,1669,139,.,0.089022,0.059601
168106341,and,and/209,. Lily| and| Sam,10,209,10/209,0.704781,171,Mom,0.047079,0.675915,-0.028866,3.965381,3.65225,1.217291,0.077415,36,2.420354,171,Mom,0.040244,0.033429,"!""","!""/209","Wow, look at this|!""| Jill",16,209,16/209,0.278681,139,.,0.078023,0.405624,0.126944,4.232041,3.538454,2.099673,0.117234,19,3489,139,.,0.089022,0.059601
169879201,,/246,they could join them.| |,10,246,10/246,0.062198,115,,0.139266,0.117492,0.055294,4.269322,4.189564,1.492369,0.081876,64,2.395362,153,,0.149123,0.032469,"!""","!""/246",". ""Wow|!""| says",1,246,1/246,0.016797,155,says,0.380723,0.467898,0.451101,4.835052,4.741034,2.317082,0.099115,53,3489,155,says,0.346406,0.021846
169879200,,/246,they could join them.| |,10,246,10/246,0.062198,115,,0.139266,0.117492,0.055294,4.269322,4.189564,1.492369,0.081876,64,2.395362,153,,0.149123,0.032469,"!""","!""/246",". ""Wow|!""| says",1,246,1/246,0.016797,155,says,0.380723,0.467898,0.451101,4.835052,4.741034,2.317082,0.099115,53,1669,155,says,0.346406,0.021846
130607151,didn,didn/131,. Lily| didn|'t,71,131,71/131,4.952805,2,upon,0.067278,4.773005,-0.1798,4.245383,4.404472,0.995102,0.055212,48,2.320726,2,upon,0.099487,0.031315,"!""","!""/131","is so big and beautiful|!""| Tom",57,131,57/131,0.038594,60,.,0.22417,0.27845,0.239856,4.431676,3.629834,2.064569,0.105122,20,1669,60,.,0.144566,0.089377
130606785,didn,didn/131,. Lily| didn|'t,71,131,71/131,4.952805,2,upon,0.067278,4.773005,-0.1798,4.245383,4.404472,0.995102,0.055212,48,2.320726,2,upon,0.099487,0.031315,visit,visit/131,wise old owl came to| visit| the,18,131,18/131,2.416821,61,all,0.053938,2.098991,-0.317831,3.967245,3.725207,1.217904,0.077381,58,3489,105,friends,0.074618,0.063477
130606671,didn,didn/131,. Lily| didn|'t,71,131,71/131,4.952805,2,upon,0.067278,4.773005,-0.1798,4.245383,4.404472,0.995102,0.055212,48,2.320726,2,upon,0.099487,0.031315,"!""","!""/131","""Hello, neighbor|!""| Sam",7,131,7/131,0.536664,54,.,0.092762,0.783063,0.246399,3.904952,3.403346,1.851653,0.121431,13,1669,54,.,0.086776,0.06416
130606670,didn,didn/131,. Lily| didn|'t,71,131,71/131,4.952805,2,upon,0.067278,4.773005,-0.1798,4.245383,4.404472,0.995102,0.055212,48,2.320726,2,upon,0.099487,0.031315,"!""","!""/131","""Hello, neighbor|!""| Sam",7,131,7/131,0.536664,54,.,0.092762,0.783063,0.246399,3.904952,3.403346,1.851653,0.121431,13,3489,54,.,0.086776,0.06416
130607150,didn,didn/131,. Lily| didn|'t,71,131,71/131,4.952805,2,upon,0.067278,4.773005,-0.1798,4.245383,4.404472,0.995102,0.055212,48,2.320726,2,upon,0.099487,0.031315,"!""","!""/131","is so big and beautiful|!""| Tom",57,131,57/131,0.038594,60,.,0.22417,0.27845,0.239856,4.431676,3.629834,2.064569,0.105122,20,3489,60,.,0.144566,0.089377
167177831,,/201,"asked, hugging her.| |",31,201,31/201,0.000996,149,,0.215315,0.003798,0.002802,4.438889,4.075503,1.079405,0.054782,30,2.309544,149,,0.166288,0.059136,"?""","?""/201","how will we get them|?""| Mom",10,201,10/201,0.232122,168,.,0.205621,0.59643,0.364307,4.223659,3.461487,2.341383,0.131249,26,1669,168,.,0.137707,0.112942


## Static Analysis of features

# Feature Dashboards

In [186]:
from importlib import reload
from sae_analysis.visualizer import data_fns
from sae_analysis.visualizer.data_fns import FeatureData
from typing import Dict
reload(data_fns)


os.environ["TOKENIZERS_PARALLELISM"] = "false"
dataset="roneneldan/TinyStories"
train_dataset = load_dataset(dataset, split="train[:3%]")
tokenized_data = utils.tokenize_and_concatenate(train_dataset, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]


# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

max_batch_size = 32
total_batch_size = 512 * 50
feature_idx = token_df.tail(1).top_k_features.values[0]
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = data_fns.get_feature_data(
    encoder=sparse_autoencoder,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer - 1,
    hook_point_head_index=sparse_autoencoder.cfg.hook_point_head_index,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k = 3,
    buffer = (5, 5),
    n_groups = 10,
    first_group_size = 20,
    other_groups_size = 5,
    verbose = True,
)




Storing model activations:   0%|          | 0/800 [00:00<?, ?it/s]
Storing model activations: 100%|██████████| 800/800 [00:31<00:00, 25.09it/s]
                                                                      

Estimated time for all 4096 features = 65 minutes



In [191]:
del feature_data
torch.mps.empty_cache()

In [189]:
for test_idx in feature_idx:
    html_str = feature_data[test_idx].get_all_html()
    path = f"tiny_stories_features/{test_idx:04}.html"
    print(f"Saving to {path}")
    with open(path, "w") as f:
        f.write(html_str)

Saving to tiny_stories_features/3489.html
Saving to tiny_stories_features/1669.html
Saving to tiny_stories_features/1401.html
Saving to tiny_stories_features/3629.html
Saving to tiny_stories_features/1352.html
Saving to tiny_stories_features/2393.html
Saving to tiny_stories_features/0995.html
Saving to tiny_stories_features/3258.html
Saving to tiny_stories_features/2455.html
Saving to tiny_stories_features/1607.html
