In [None]:
import sys

# Make the repository root ("../..") and then its parent ("..") importable, so that
# `joseph` and `sae_training` resolve from this notebook's directory.
sys.path.extend(["../..", ".."])

from importlib import reload

from tqdm import tqdm

import joseph

# First pass: import the joseph submodules and pull their public names into the namespace.
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# Re-execute each submodule so on-disk edits are picked up without a kernel restart,
# then star-import again so the notebook's names refer to the refreshed objects.
for _joseph_module in (joseph.analysis, joseph.visualisation, joseph.utils, joseph.data):
    reload(_joseph_module)
del _joseph_module

from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# Analysis only, no training: disable autograd globally.
# (`torch` is presumably provided by one of the star imports above.)
torch.set_grad_enabled(False)

import os
import webbrowser

# IPython.display is the public location of these helpers; the IPython.core.display copies are deprecated.
from IPython.display import HTML, display

path_to_html = "../week_8_jan/gpt2_small_features_layer_5"


def render_feature_dashboard(feature_id, html_dir=None):
    """Open the pre-rendered HTML dashboard for one SAE feature in a new browser tab.

    Args:
        feature_id: Feature index (an int, a 0-d tensor or a numeric string; anything
            ``int()`` accepts). The file looked up is ``data_{feature_id:04}.html``.
        html_dir: Directory holding the dashboards. When None, the module-level
            ``path_to_html`` (the layer-5 dashboards) is read at call time.

    Prints "No HTML file found" instead of raising when the file is missing.
    """
    if html_dir is None:
        html_dir = path_to_html

    # int() so ids passed as strings zero-pad correctly: applying the "04" spec to a str
    # either raises or pads on the wrong side, depending on the Python version.
    path = f"{html_dir}/data_{int(feature_id):04}.html"

    print(f"Feature {feature_id}")
    if os.path.exists(path):
        # The dashboard opens in the browser rather than being inlined via
        # display(HTML(...)). Presumably this keeps large dashboards out of the notebook output.
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")
    

# Load Model

In [None]:

# Other checkpoints tried here: "pythia-2.8b", "pythia-70m-deduped", "tiny-stories-2L-33M",
# "attn-only-2l". Other options tried: center_unembed, center_writing_weights,
# refactor_factored_attn_matrices.
model = HookedTransformer.from_pretrained("gpt2-small", fold_ln=True)

# Expose per-head q/k/v inputs and per-head attention outputs as hook points.
# These are presumably required for the head-level residual decompositions used later.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)


# Load SAE

In [None]:
from sae_training.utils import LMSparseAutoencoderSessionloader

SAE_LAYER_10_PATH = "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/1100001280_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt"
SAE_LAYER_5_PATH = "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/final_sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152.pt"

# load_session_from_pretrained returns (model, sae, activation_store).
# - The layer-10 call's model REPLACES the `model` configured in the previous cell.
# - The layer-5 call's model is discarded.
model, sparse_autoencoder_layer_10, activation_store_layer_10 = (
    LMSparseAutoencoderSessionloader.load_session_from_pretrained(SAE_LAYER_10_PATH)
)
_, sparse_autoencoder_layer_5, activation_store_layer_5 = (
    LMSparseAutoencoderSessionloader.load_session_from_pretrained(SAE_LAYER_5_PATH)
)

# Re-apply the hook settings lost when `model` was replaced above.
# NOTE(review): assumes the session loader returns a HookedTransformer exposing the same setters.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

print(sparse_autoencoder_layer_10.cfg)
print(sparse_autoencoder_layer_5.cfg)

# Sanity check: the loaded model should give a sensible next-token loss on ordinary prose.
text = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."
model(text, return_type="loss")

# Explore Sparsity

In [None]:
from tqdm.auto import tqdm


def estimate_feature_sparsity_using_n_tokens_per_prompt(
    sparse_autoencoder, activation_store, n_batches,
    n_tokens_per_prompt=4, count_firings=True):
    """Estimate how often each SAE feature fires, using a few random positions per prompt.

    For each of ``n_batches`` token batches from ``activation_store``:
    - run the global ``model``, caching only ``sparse_autoencoder.cfg.hook_point``;
    - encode that activation with the SAE;
    - keep ``n_tokens_per_prompt`` positions per prompt, drawn uniformly at random
      (with replacement).

    Args:
        sparse_autoencoder: SAE whose ``cfg`` provides ``d_sae``, ``device`` and ``hook_point``.
        activation_store: Source of token batches via ``get_batch_tokens()``.
        n_batches: Number of batches to sample. Must be >= 1.
        n_tokens_per_prompt: Positions sampled per prompt.
        count_firings: If True (default), a sampled token counts for a feature when the
            feature's activation there is > 0. The result is then the fraction of tokens
            on which each feature fires. That is the usual meaning of "feature sparsity",
            and presumably what the stratified ``log_feature_sparsity_5000_4.pt`` estimates
            in this notebook measure (TODO confirm). If False, raw activation values are
            summed instead (mean activation per token), which is what this function used
            to return.

    Returns:
        Tensor of shape ``(d_sae,)`` on ``sparse_autoencoder.cfg.device``.

    Raises:
        ValueError: if ``n_batches < 1``, since there would be nothing to average over.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be >= 1, got {n_batches}")

    cfg = sparse_autoencoder.cfg
    totals = torch.zeros(cfg.d_sae, device=cfg.device)
    total_tokens = 0

    for _ in tqdm(range(n_batches)):
        batch_tokens = activation_store.get_batch_tokens()
        # Only the SAE's hook point is needed, so avoid caching every activation in the model.
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=False, names_filter=cfg.hook_point)
        _, feature_acts, _, _, _ = sparse_autoencoder(cache[cfg.hook_point])

        # Treated as (batch, position, d_sae), as the per-prompt position sampling assumes.
        n_prompts, n_positions = feature_acts.shape[0], feature_acts.shape[1]
        device = feature_acts.device
        tok_idx = torch.randint(0, n_positions, (n_prompts, n_tokens_per_prompt), device=device)
        prompt_idx = torch.arange(n_prompts, device=device).unsqueeze(-1)
        sampled = feature_acts[prompt_idx, tok_idx].flatten(0, 1)  # (n_prompts * n_tokens_per_prompt, d_sae)

        totals += (sampled > 0).float().sum(0) if count_firings else sampled.sum(0)
        # Count tokens per batch so a short batch cannot skew the denominator.
        total_tokens += sampled.shape[0]

    print("Total tokens:", total_tokens)

    return totals / total_tokens

# Unstratified estimates: many positions (128) sampled per prompt.
# The saved tensors hold log10 values. The "feature_sparsity_*" file prefix is kept so
# existing artifacts still match.
n_tokens_per_prompt = 128
n_batches = 1000
# Layer 5 uses fewer batches. Its own variable keeps the saved filename truthful:
# previously it was labelled with the layer-10 batch count.
n_batches_layer_5 = 100

feature_sparsity_10_unstratified = estimate_feature_sparsity_using_n_tokens_per_prompt(
    sparse_autoencoder_layer_10, activation_store_layer_10,
    n_batches=n_batches, n_tokens_per_prompt=n_tokens_per_prompt,
).detach().cpu()
log_feature_sparsity_10_unstratified = torch.log10(feature_sparsity_10_unstratified + 1e-10)
torch.save(
    log_feature_sparsity_10_unstratified,
    f"../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/feature_sparsity_{n_batches}_{n_tokens_per_prompt}.pt",
)

feature_sparsity_5_unstratified = estimate_feature_sparsity_using_n_tokens_per_prompt(
    sparse_autoencoder_layer_5, activation_store_layer_5,
    n_batches=n_batches_layer_5, n_tokens_per_prompt=n_tokens_per_prompt,
).detach().cpu()
log_feature_sparsity_5_unstratified = torch.log10(feature_sparsity_5_unstratified + 1e-10)
torch.save(
    log_feature_sparsity_5_unstratified,
    f"../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/feature_sparsity_{n_batches_layer_5}_{n_tokens_per_prompt}.pt",
)


In [None]:
# Stratified estimates (5000 batches, 4 tokens per prompt) saved by an earlier run.
# map_location="cpu" lets these be compared element-wise with the CPU-resident unstratified
# estimates (and converted with .numpy()), whatever device they were saved from.
log_feature_sparsity_10_stratified = torch.load(
    "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/log_feature_sparsity_5000_4.pt",
    map_location="cpu",
)
log_feature_sparsity_5_stratified = torch.load(
    "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/log_feature_sparsity_5000_4.pt",
    map_location="cpu",
)

# To inspect either distribution, excluding values at or below -9. These are presumably
# features at the log10(0 + 1e-10) floor, i.e. ones that never fired:
# px.histogram(
#     log_feature_sparsity_10_stratified[log_feature_sparsity_10_stratified > -9],
#     nbins=1000,
#     width=1000,
#     log_x=False,
#     title="Feature sparsity (log10) (5000 batches, 4 tokens per prompt)",
# ).show()

For layer 10, let's compare the stratified and unstratified sparsity estimates feature by feature.

In [None]:
# One point per layer-10 feature, coloured by how much denser the stratified estimate is.
# The title previously claimed both estimates used 5000 batches x 4 tokens; only the
# stratified one does.
px.scatter(
    x=log_feature_sparsity_10_stratified,
    y=log_feature_sparsity_10_unstratified,
    opacity=0.4,
    marginal_x="histogram",
    marginal_y="histogram",
    title=(
        "Layer 10 feature sparsity (log10): stratified (5000 batches, 4 tokens per prompt) "
        f"vs unstratified ({n_batches} batches, {n_tokens_per_prompt} tokens per prompt)"
    ),
    labels={
        "x": "Stratified log10 sparsity",
        "y": "Unstratified log10 sparsity",
        "color": "stratified - unstratified",
    },
    color=(log_feature_sparsity_10_stratified - log_feature_sparsity_10_unstratified).numpy().tolist(),
    width=1500,
    height=1500,
    color_continuous_midpoint=0,
    hover_data=[list(range(len(log_feature_sparsity_10_stratified)))],
).show()

In [None]:
import os
import webbrowser

# The layer-10 dashboards live in a different directory from the layer-5 ones.
LAYER_10_DASHBOARD_DIR = "../week_8_jan/gpt2_small_features"


def render_feature_dashboard(feature_id, html_dir=LAYER_10_DASHBOARD_DIR):
    """Open the pre-rendered HTML dashboard for one SAE feature in a new browser tab.

    Args:
        feature_id: Feature index (anything ``int()`` accepts). The file looked up is
            ``data_{feature_id:04}.html``.
        html_dir: Directory holding the dashboards. Defaults to the layer-10 directory.

    Prints "No HTML file found" instead of raising when the file is missing.
    """
    # int() so string ids zero-pad correctly; the "04" spec misbehaves on str.
    path = f"{html_dir}/data_{int(feature_id):04}.html"

    print(f"Feature {feature_id}")
    if os.path.exists(path):
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")


# Select features whose stratified log-sparsity exceeds the unstratified estimate by more
# than 2. If both are log10, that means more than 100x denser. Each feature must also have
# a stratified log-sparsity above -4.
# (An earlier variant selected stratified < -1 and unstratified > -2.)
diff = log_feature_sparsity_10_stratified - log_feature_sparsity_10_unstratified
# Use flatten() rather than squeeze(): squeeze() turns a single match into a 0-d tensor,
# on which len() and the indexing below fail.
dense_features = ((diff > 2) & (log_feature_sparsity_10_stratified > -4)).nonzero().flatten()
print(len(dense_features))
# Open dashboards for up to 6 randomly chosen matches.
dense_features = dense_features[torch.randperm(len(dense_features))[:6]]
for feature in dense_features:
    render_feature_dashboard(feature.item())
    

In [None]:


# 1102 -> fires on "erect" and/or "ile" if it follows "erect"
# 511 -> fires on "lashed" and "out" if "out" follows lashed
# 509 -> fires on in and "the" if "the" follows in
# 1289 -> fires on ongoing and investigation if investigation follows ongoing
# 10329 -> Fires on Easter, and eggs or bunny if it follows Easter but not on Bunny or Eggs alone (presumably)
# 17301 -> Fires on family members and sometimes on phrases (eg: brother-in-law) (a stretch for sure)
# 49144 -> "What do you think?" (fires on "you think"; sometimes extends to "of")


## Feature Dashboard generator util

In [None]:
import os
import webbrowser

# IPython.display is the public location of these helpers; the IPython.core.display copies are deprecated.
from IPython.display import HTML, display

path_to_html = "../week_8_jan/gpt2_small_features_layer_5"


def render_feature_dashboard(feature_id, html_dir=None):
    """Open the pre-rendered HTML dashboard for one SAE feature in a new browser tab.

    Args:
        feature_id: Feature index (an int, a 0-d tensor or a numeric string; anything
            ``int()`` accepts). The file looked up is ``data_{feature_id:04}.html``.
        html_dir: Directory holding the dashboards. When None, the module-level
            ``path_to_html`` (the layer-5 dashboards) is read at call time.

    Prints "No HTML file found" instead of raising when the file is missing.
    """
    if html_dir is None:
        html_dir = path_to_html

    # int() so ids passed as strings (e.g. split out of "feature_123") zero-pad correctly.
    path = f"{html_dir}/data_{int(feature_id):04}.html"

    print(f"Feature {feature_id}")
    if os.path.exists(path):
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")

# for feature in [100,300,400]:
#     render_feature_dashboard(feature)

# Fun Examples

In [None]:
# Candidate "not only ... but (also)" sentences. The third is truncated right before "but"
# so the model's prediction of the continuation can be checked.
prompt1 = "The war caused not only destruction and death but also generations of hatred between the two communities."
prompt2 = "The car not only is economical but also feels good to drive."
prompt3 = "This investigation is not only one that is continuing and worldwide,"  # but also one that we expect to continue for quite some time."

answer = "but"
prompt = prompt3

# Clear any hooks left over from earlier experiments before probing the clean model.
model.reset_hooks()
utils.test_prompt(prompt, answer, model)


In [None]:
import joseph

# Pick up edits to joseph.analysis (eval_prompt etc.) without restarting the kernel.
reload(joseph.analysis)
from joseph.analysis import *

# Alternative prompts tried in this cell. Swap one in as `prompt`, with a matching
# title / POS_INTEREST, to reuse the analysis.
# prompt3 = "This investigation is not only one that is continuing and worldwide, but also one that we expect to continue for quite some time." # Not only ... but
# prompt3 = "The market is evolving rapidly. Either we must adjust our strategy to meet the new market demands, or we risk falling behind our competitors significantly." # either or one (dud?)
# prompt3 = "Culinary trends are constantly changing. Either we experiment with new flavors and techniques in our recipes, or we risk losing the interest of our adventurous diners." #maybe a dud as well
# prompt3 = "I thought it was a great book. Both the intricate plot twists and the strong character development make this novel exceptionally engaging." # both .... and
# prompt3 = "The team, despite facing numerous challenges and unexpected setbacks, remains optimistic about the upcoming project." # Noun verb agreement
# prompt3 = "The book on the shelf in the corner needs a new cover." # Noun verb agreement
#
# title = "which way to the beach"
# prompt = "She asked 'Which way to the beach?', to which I replied,  'It's over there. You can't miss it.'. She thanked me and walked away."
# POS_INTEREST = 9
#
# title = "lots of questions"
# prompt = "The text read \"In the realm of deep learning, how do we best quantify the interpretability of neural networks? While considering this, it's important to remember the balance between complexity and clarity in model design. What are the most effective methods for visualizing high-dimensional data? This leads to another crucial aspect: the role of data quality. Can we establish a standard for data that optimally trains these models? Amidst these inquiries, the evolution of AI safety protocols remains a pivotal concern. How are current safety measures adapting to the rapidly advancing AI landscape? Each question marks a stepping stone towards a deeper understanding and more effective utilization of AI technologies."
# POS_INTEREST = 10
#
# title = "both_and"
# prompt = "My parents went to both Melbourne, Australia and Auckland, New Zealand on their honeymoon."
# POS_INTEREST = 8

# Active example. POS_INTEREST is the token_df row studied downstream.
title, POS_INTEREST = "Tiny Stories Dragon", 41
prompt = """Once upon a time, there was a little girl named Lily. She was very
excited to go outside and explore. She flew over the trees and saw a big,
scary dragon. The dragon was very scary. But Lily knew that things
were not real and she would hurt her."""


# NOTE(review): `sparse_autoencoder` is never assigned in the visible notebook (only
# sparse_autoencoder_layer_10 / sparse_autoencoder_layer_5 are), so on a fresh kernel this
# raises NameError. Bind the intended SAE first,
# e.g. `sparse_autoencoder = sparse_autoencoder_layer_10`. TODO confirm which layer is meant.
token_df, original_cache, cache_reconstructed_query, feature_acts = eval_prompt(
    [prompt], model, sparse_autoencoder, head_idx_override=5
)
filter_cols = [
    "str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff",
    "mse_loss", "num_active_features", "explained_variance", "kl_divergence", "top_k_features",
]
display(token_df[filter_cols].style.background_gradient(
    subset=["loss_diff", "mse_loss", "explained_variance", "num_active_features", "kl_divergence"],
    cmap="coolwarm",
))

UNIQUE_TOKEN_INTEREST = token_df["unique_token"][POS_INTEREST]
feature_acts_of_interest = feature_acts[POS_INTEREST]

# Features active (> 0) at any position after the first. flatten() keeps this 1-D even when
# exactly one feature is active; squeeze() would make it 0-d and break the DataFrame below.
top_k_feature_inds = (feature_acts[1:] > 0).sum(dim=0).nonzero().flatten()

# Rows: active features. Columns: the prompt's tokens.
features_acts_by_token_df = pd.DataFrame(
    feature_acts[:, top_k_feature_inds].detach().cpu().T,
    index=[f"feature_{i}" for i in top_k_feature_inds.tolist()],
    columns=token_df["unique_token"],
)

# To open dashboards for the 10 features most active at the token of interest:
# for feature in features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).index[:10]:
#     render_feature_dashboard(int(feature.split("_")[1]))

# One line per feature across the prompt, ordered by activation at the token of interest.
acts_sorted_by_interest = features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).T
px.line(
    acts_sorted_by_interest,
    title=f"{title}: Features Activation by Token in Prompt",
    color_discrete_sequence=px.colors.qualitative.Plotly,
    height=1000,
).show()

# The heatmap title promises the top features by activation, so pick the 100 with the largest
# peak activation over the prompt. head(100) alone took the first 100 in index order.
top_100_features = features_acts_by_token_df.max(axis=1).nlargest(100).index
px.imshow(
    features_acts_by_token_df.loc[top_100_features].T,
    title=f"{title}: Top k features by activation",
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
    height=800,
).show()

In [None]:

def analyse_lcf(prompt, title="", model=model, sparse_autoencoder=sparse_autoencoder, head_idx_override=None):
    """Show token-level stats and per-feature activation plots for one prompt.

    Runs ``eval_prompt`` on ``prompt`` and shows the token table with a colour gradient over the
    loss / reconstruction columns. It then plots:
    (a) one activation-by-token line per feature active after the first position, ordered by
        activation at the last row of the token table;
    (b) a feature x token heatmap.

    Args:
        prompt: Text to analyse.
        title: Prefix for the plot titles.
        model, sparse_autoencoder: Defaults are bound when this cell runs.
            NOTE(review): `sparse_autoencoder` must already exist then, or the ``def``
            itself raises NameError. Only the layer-specific SAEs are assigned in the
            visible notebook.
        head_idx_override: Passed through to ``eval_prompt``. ``None`` keeps the previous
            effective value of 7. Before this change the argument was ignored in favour
            of a hard-coded 7.
    """
    if head_idx_override is None:
        head_idx_override = 7

    token_df, _, _, feature_acts = eval_prompt(
        [prompt], model, sparse_autoencoder, head_idx_override=head_idx_override
    )
    filter_cols = [
        "str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff",
        "mse_loss", "num_active_features", "explained_variance", "kl_divergence", "top_k_features",
    ]
    display(token_df[filter_cols].style.background_gradient(
        subset=["loss_diff", "mse_loss", "explained_variance", "num_active_features", "kl_divergence"],
        cmap="coolwarm",
    ))

    # Focus on the last row of the token table, presumably the final prompt token.
    unique_token_interest = token_df["unique_token"][token_df.index.max()]

    # Features active at any position after the first. flatten() keeps this 1-D even for a
    # single active feature, where squeeze() would give a 0-d tensor.
    active_feature_inds = (feature_acts[1:] > 0).sum(dim=0).nonzero().flatten()

    features_acts_by_token_df = pd.DataFrame(
        feature_acts[:, active_feature_inds].detach().cpu().T,
        index=[f"feature_{i}" for i in active_feature_inds.tolist()],
        columns=token_df["unique_token"],
    )

    px.line(
        features_acts_by_token_df.sort_values(unique_token_interest, ascending=False).T,
        title=f"{title}: Features Activation by Token in Prompt",
        color_discrete_sequence=px.colors.qualitative.Plotly,
    ).show()

    px.imshow(
        features_acts_by_token_df.T,
        title=f"{title}: Top k features by activation",
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
        height=800,
    ).show()



# Each correlative construction is given as (name, prompt truncated before the second
# element, expected next token).
correlative_conjunction_prompts = {
    name: {"prompt": text, "answer": answer}
    for name, text, answer in [
        ("Either - or", "Either you are with me,", " or"),
        ("Neither - nor", "I wasn't hired at any of the companies I'd applied to. Neither my experience great amount of experience,", " nor"),
        ("Such - that", "Such is the intensity of the pollen outside,", " that"),
        ("Whether - or", "Whether you bike to work and love that", " or"),
        ("Not only - but", "Not only did my boyfriend buy me a Nintendo Switch,", " but"),
        ("Not only - but also", "Not only did my boyfriend buy me a Nintendo Switch, but", " also"),
        ("Both - And", "My parents went to both Hawaii", " and"),
        ("As many - as", "There were as many applicants", " as"),
        ("No sooner - than", "She would no sooner cheat on an exam", " than"),
        ("Rather - than", "They would rather go to the movies", " than"),
    ]
}

# Complete comparative sentences (no expected answer).
comparative_phrases_prompts = {
    name: {"prompt": text}
    for name, text in [
        ("The more - the more", "the more you learn, the more you realize how much you don't know"),
        ("The fewer - the fewer", "The fewer people who know about this, the better"),
        ("Less on - more on", "I think we should focus less on talking about doing the work and more on doing the work"),
    ]
}

# Control sentences without correlative constructions.
random_sentences = {
    name: {"prompt": text}
    for name, text in [
        ("Random 1", "Each of these patterns shares the property of linking elements in language, creating relationships between them that are similar to those established by correlative conjunctions."),
        ("Random 2", "I feel that the insight/intuitions/skills I’ve spent trying to make progress on trajectory models, are best utilized by **studying sparse-autoencoders on language models**."),
        ("Random 3", "In most cases, professional emails are formal emails. A formal email is an email between professionals or academics that contains information related to their work."),
    ]
}


for title, prompt_dict in correlative_conjunction_prompts.items():
    print(title)
    analyse_lcf(prompt_dict["prompt"], title=title)
    print("\n\n")


# Mech interp on a few examples

### Basic Set Up

In [None]:
# Earlier probe prompts. Each was overwritten immediately, so only the last assignment
# was ever used; they are kept here for reference:
#   "I have to say, not only is this a great book, but also the author is a great person."
#   "'Well, Ted,' said the weatherman, 'I don't know about that, but it's not only the owls that have been acting oddly today"
#   " Not only was Hagrid twice as tall as anyone else, he kept pointing at perfectly ordinary things"
#   "Soon he had not only Dumbledore and Morgana, but Hengist of Woodcraft, Alberic Grunnion, Circe, Paracelsus and Merlin."
#   "She asked 'Which way to the beach?', to which I replied,  'It's over there. You can't miss it.'. She thanked me and walked away."
#   "The team, despite facing numerous challenges and unexpected setbacks, remains optimistic about the upcoming project."
prompt = """
correlative_conjunction_prompts = {
    "Either - or": {
        "prompt": "Either you are with me,",
        "answer": " or"
    },
    "Neither - nor": {
        "prompt": "I wasn't hired at any of the companies I'd applied to. Neither my experience great amount of experience,",
        "answer": " nor"
    },
"""
# NOTE(review): `sparse_autoencoder` is never assigned in the visible notebook; bind the
# intended SAE (e.g. sparse_autoencoder_layer_10) before running. TODO confirm which.
token_df, original_cache, cache_reconstructed_query, feature_acts = eval_prompt(
    [prompt], model, sparse_autoencoder, head_idx_override=5
)
filter_cols = [
    "str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff",
    "mse_loss", "num_active_features", "explained_variance", "kl_divergence", "top_k_features",
]
# display(token_df[filter_cols].style.background_gradient(
#     subset=["loss_diff", "mse_loss","explained_variance", "num_active_features", "kl_divergence"],
#     cmap="coolwarm"))

# POS_INTEREST = token_df.index.max()
# UNIQUE_TOKEN_INTEREST = token_df["unique_token"][POS_INTEREST]

In [None]:
POS_INTEREST = 6
UNIQUE_TOKEN_INTEREST = token_df["unique_token"][POS_INTEREST]
feature_acts_of_interest = feature_acts[POS_INTEREST]

# Features active at any position after the first. flatten() keeps this 1-D even for a
# single active feature, where squeeze() would give a 0-d tensor.
top_k_feature_inds = (feature_acts[1:] > 0).sum(dim=0).nonzero().flatten()

# Rows: active features. Columns: the prompt's tokens.
features_acts_by_token_df = pd.DataFrame(
    feature_acts[:, top_k_feature_inds].detach().cpu().T,
    index=[f"feature_{i}" for i in top_k_feature_inds.tolist()],
    columns=token_df["unique_token"],
)

# Build the plot title from this cell's own state. The previous `{title}` was whatever
# an earlier cell (e.g. the last iteration of the correlative-conjunction loop) left behind.
plot_title = f"Token {POS_INTEREST} ({UNIQUE_TOKEN_INTEREST})"

# Rows: tokens. Columns: features, ordered by activation at the token of interest.
acts_by_token_sorted = features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).T

px.line(
    acts_by_token_sorted,
    title=f"{plot_title}: Features Activation by Token in Prompt",
    color_discrete_sequence=px.colors.qualitative.Plotly,
    height=1000,
).show()

# Feature-feature correlation of activations across the prompt's tokens.
px.imshow(
    acts_by_token_sorted.corr(),
    title=f"{plot_title}: feature activation correlation across tokens",
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
    height=800,
).show()

In [None]:
# Logit lens on the final residual stream: apply the final LayerNorm scaling, then unembed
# (the unembedding bias b_U is not added).
# The layer is derived from the model rather than hard-coded as 11 (gpt2-small's last layer).
final_resid = original_cache[utils.get_act_name("resid_post", model.cfg.n_layers - 1)]
logits = original_cache.apply_ln_to_stack(final_resid) @ model.W_U
print(logits.shape)

# Union of the top-10 predicted tokens at every position after the first. Sorted so that
# row order is deterministic rather than depending on set iteration order.
_, inds = torch.topk(logits[:, 1:], 10, dim=-1)
topk_predicted_token_inds = sorted(set(inds.flatten().tolist()))
topk_predicted_token_strs = model.tokenizer.convert_ids_to_tokens(topk_predicted_token_inds)

predicted_tokens_df = pd.DataFrame(
    logits[0, :, topk_predicted_token_inds].detach().cpu().T,
    columns=token_df["unique_token"],
    index=topk_predicted_token_strs,
)

px.line(predicted_tokens_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).T)

### DFA

In [None]:
# Features active (> 0) at any position after the first. flatten() keeps the result 1-D even
# when only one feature fires; squeeze() would collapse that case to a 0-d tensor.
top_k_feature_inds = (feature_acts[1:] > 0).any(dim=0).nonzero().flatten()

In [None]:
# DLA: project each residual-stream component onto the unembedding column(s) of " but".
# NOTE(review): the decomposition is taken at layer=10 and no final-LayerNorm scaling is
# applied, so these are unscaled contributions rather than exact logit attributions.
decomp, labels = original_cache.get_full_resid_decomposition(
    layer=10,
    expand_neurons=False,
    return_labels=True,
)
inds = top_k_feature_inds.squeeze()  # may be read by later cells
tok1 = " but"
print(decomp.shape)

dla = (
    decomp[:, 0, :] @ model.W_U[:, model.tokenizer.encode(tok1)]
).detach().cpu().squeeze()
print(dla.shape)

tmp = pd.DataFrame(
    dla.detach().cpu().numpy().T,
    index=token_df["unique_token"],
    columns=labels,
)
px.line(tmp.T).show()



In [None]:
# Direct feature attribution at POS_INTEREST: project each residual-stream component onto the
# SAE decoder direction of every feature active in the prompt.
# NOTE(review): the numerator uses W_dec rows but the denominator uses W_enc column norms.
# If a normalised decoder projection was intended, divide by
# sparse_autoencoder.W_dec[feature_inds].norm(dim=-1) instead (TODO confirm).
# An encoder-side variant, decomp[:, 0, POS_INTEREST] @ sparse_autoencoder.W_enc[:, feature_inds],
# was also tried.
# Derive the feature indices here instead of reading `inds`, a name other cells reassign
# (e.g. to vocabulary indices).
feature_inds = top_k_feature_inds.flatten()
feature_attr = (
    decomp[:, 0, POS_INTEREST] @ sparse_autoencoder.W_dec[feature_inds].T
) / sparse_autoencoder.W_enc[:, feature_inds].norm(dim=0)

# .tolist() so the labels read "feature_123"; iterating the tensor produced "feature_tensor(123)".
tmp = pd.DataFrame(
    feature_attr.detach().cpu().numpy().T,
    columns=labels,
    index=[f"feature_{i}" for i in feature_inds.tolist()],
)
px.line(
    tmp.T[tmp.T.index.str.contains("mlp")],
    title="Feature attribution: components labelled 'mlp'",
).show()
px.line(
    tmp.T[tmp.T.index.str.contains("L")],
    title="Feature attribution: components labelled with 'L' (per-head labels)",
).show()

In [None]:
# Shape of the layer-0 attention pattern (presumably batch, head, query_pos, key_pos).
original_cache["pattern", 0, "attn"].shape

In [None]:
# Attention patterns for heads L7H7 and L8H5 of the first prompt in the batch
# (presumably query positions as rows, key positions as columns).
for pattern_layer, pattern_head in [(7, 7), (8, 5)]:
    tmp = pd.DataFrame(
        original_cache["pattern", pattern_layer, "attn"][0, pattern_head].detach().cpu().numpy(),
        columns=token_df.unique_token,
        index=token_df.unique_token,
    )
    px.imshow(tmp, color_continuous_midpoint=0, color_continuous_scale="RdBu", height=800).show()
del pattern_layer, pattern_head

In [None]:
# One attention weight: layer 8, head 5, batch 0, position 13 -> position 7
# (presumably query -> key).
original_cache["pattern", 8, "attn"][0, 5, 13, 7]

In [None]:
import circuitsvis as cv

# Interactive per-head attention patterns for layer 7, first prompt in the batch.
tokens = token_df["unique_token"].tolist()
cv.attention.attention_patterns(
    tokens=tokens,
    attention=original_cache["pattern", 7, "attn"][0],
)

In [None]:
# "Effective embedding": each token embedding plus block 0's MLP output on it (after ln2).
# The attention sublayer is skipped. W_E[None] adds a leading batch dim, which squeeze() removes.
eff_embed = (
    model.W_E + model.blocks[0].mlp(model.blocks[0].ln2(model.W_E[None]))
).squeeze()
eff_embed.shape

In [None]:
layer, head = 8, 5

# Effective embedding of the " but" token.
eff_embed_but = eff_embed[model.to_single_token(" but")]

# QK circuit of head L8H5, arranged as W_K @ W_Q^T.
W_QK = model.W_K[layer, head] @ model.W_Q[layer, head].T

In [None]:
# Layer-8 keys indexed [0, 7, 8]; presumably batch 0, position 7, head 8.
original_cache[utils.get_act_name("k", 8)][0, 7, 8].shape

In [None]:
# Score every vocabulary token's effective embedding, on the W_Q side, against " but" on
# the W_K side, ignoring LayerNorm and biases. Show the 30 highest-scoring tokens.
vals, inds = (eff_embed @ W_QK.T @ eff_embed_but).topk(30)
model.to_str_tokens(inds)

# Metric Development

In [None]:
# A single worked example for developing the event metrics below.
prompt = "This investigation is not only one that is continuing and worldwide, but also one that we expect to continue for quite some time." # Not only ... but
# NOTE(review): `sparse_autoencoder` is never assigned in the visible notebook; bind the
# intended SAE before running. TODO confirm which layer.
token_df, original_cache, cache_reconstructed_query, feature_acts_example = eval_prompt(
    [prompt], model, sparse_autoencoder, head_idx_override=7
)

print(prompt)
# Show the activations just computed; this previously displayed a stale `feature_acts`
# left by another cell.
feature_acts_example.shape

In [None]:
import time

import numpy as np


def analyze_events(tensor):
    """Summarise contiguous runs ("events") of positive values in each row of a 2-D array.

    An event is a maximal run of consecutive positions whose value is > 0.

    Args:
        tensor: 2-D array-like (e.g. torch.Tensor or np.ndarray), one row per feature and one
            column per token position. Each row must support ``.tolist()``.

    Returns:
        A list with one dict per row. Keys:
        - ``num_events`` and ``num_firings`` (total positions > 0);
        - per-event lists ``avg_values``, ``max_values`` and ``durations``;
        - ``avg_duration``, ``max_duration`` and ``avg_max_value``, which are NaN when the
          row has no events;
        - ``events``: one dict per event with ``avg_value``, ``max_value``, ``duration``,
          ``start_position`` (first firing index) and ``final_position`` (index one past the
          last firing position, i.e. ``start_position + duration``).

    Raises:
        AssertionError: if ``tensor`` is not 2-D.
    """
    assert len(tensor.shape) == 2, "tensor must be 2D"
    results = []

    for row in tensor:
        durations, max_values, avg_values, events = [], [], [], []
        event_start = None

        # A trailing 0 sentinel closes an event that runs to the end of the row, so every
        # event, including the last, is recorded by the same branch.
        for i, value in enumerate(row.tolist() + [0]):
            if value > 0:
                if event_start is None:
                    event_start, max_value, total_value = i, value, value
                else:
                    max_value = max(max_value, value)
                    total_value += value
            elif event_start is not None:
                duration = i - event_start
                avg_value = total_value / duration
                durations.append(duration)
                max_values.append(max_value)
                avg_values.append(avg_value)
                # Positions are recorded per event; they were previously per-row values
                # copied onto every event.
                events.append({
                    'avg_value': avg_value,
                    'max_value': max_value,
                    'duration': duration,
                    'start_position': event_start,
                    'final_position': i,
                })
                event_start = None

        num_events = len(durations)
        # np.nan: the np.NaN / np.NAN aliases were removed in NumPy 2.0.
        results.append({
            'num_events': num_events,
            'num_firings': sum(durations),
            'avg_values': avg_values,
            'max_values': max_values,
            'durations': durations,
            'avg_duration': sum(durations) / num_events if num_events else np.nan,
            'max_duration': max(durations) if num_events else np.nan,
            'avg_max_value': sum(max_values) / num_events if num_events else np.nan,
            'events': events,
        })

    return results

# Example usage on the worked prompt.
# Define the features to analyse here instead of relying on `features_of_interest` from a
# later (or deleted) cell: every feature active at some position after the first.
features_of_interest = (feature_acts_example[1:] > 0).any(dim=0).nonzero().flatten().tolist()

tensor = feature_acts_example[:, features_of_interest].T  # (n_features, n_positions)

start_time = time.perf_counter()
results = analyze_events(tensor)
print(f"Time taken: {time.perf_counter() - start_time}")

feature_prompt_df = pd.DataFrame(results, index=features_of_interest)
feature_prompt_df["feature"] = feature_prompt_df.index
# Show the features with the most separate firing events. The sorted frame used to be
# computed and then discarded.
display(feature_prompt_df.sort_values("num_events", ascending=False).head(10))

In [None]:
# One row per event. Features that never fired explode to NaN and are dropped; previously
# pd.Series(NaN) turned them into a junk column `0`. Each event dict becomes columns, and the
# feature id (the frame's index) becomes a string column for colouring.
exploded_events = feature_prompt_df.explode("events").dropna(subset=["events"])
event_stats_df = (
    pd.DataFrame(exploded_events["events"].tolist(), index=exploded_events.index)
    .rename_axis("feature")
    .reset_index()
)
event_stats_df["feature"] = event_stats_df["feature"].astype(str)
px.scatter_matrix(
    event_stats_df,
    title="Event stats for each feature",
    color="feature",
    dimensions=["avg_value", "max_value", "duration"],
    width=1000,
    height=1000,
)

In [None]:
# One point per feature: total positions fired vs number of separate firing events.
px.scatter(
    feature_prompt_df,
    x="num_firings",
    y="num_events",
    hover_name=feature_prompt_df.index,
)

## Loop: event statistics across many prompts

In [None]:
# Draw 768 batches of tokens, shuffle the prompts, and keep the first 4096 * 6.
# NOTE(review): `activation_store` is never assigned in the visible notebook (only
# activation_store_layer_10 / activation_store_layer_5 are). Bind the intended store before
# running. TODO confirm which layer.
# A comprehension means no lingering list keeps every batch alive after `del all_tokens`.
all_tokens = torch.cat(
    [activation_store.get_batch_tokens() for _ in tqdm(range(128 * 6))], dim=0
)
print(all_tokens.shape)
all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
tokens = all_tokens[:4096 * 6]
del all_tokens

# Release cached allocator memory on whichever accelerator is present. torch.mps is
# unavailable off Apple silicon.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:

n_prompts = 1000
N_FEATURES_SAMPLED = 100
FEATURE_SAMPLE_SEED = 0  # fixed so the random feature sample is reproducible across runs

# NOTE(review): `sparse_autoencoder` is never assigned in the visible notebook; bind the
# intended SAE before running. TODO confirm which layer.
features_of_interest = torch.randperm(
    sparse_autoencoder.cfg.d_sae,
    generator=torch.Generator().manual_seed(FEATURE_SAMPLE_SEED),
)[:N_FEATURES_SAMPLED].tolist()

token_dfs = []
event_dfs = []
feature_acts_all = []

for prompt_index in tqdm(range(n_prompts)):
    prompt_tokens = tokens[prompt_index].unsqueeze(0)

    token_df = make_token_df(model, prompt_tokens, len_suffix=5, len_prefix=10)
    token_df["prompt_index"] = prompt_index

    (original_logits, original_loss), original_cache = model.run_with_cache(
        prompt_tokens, return_type="both", loss_per_token=True
    )
    # Per-token loss is padded with one NaN so it lines up with token_df. Presumably this is
    # because the final token has no next-token target.
    token_df['loss'] = original_loss.flatten().tolist() + [np.nan]

    _, feature_acts, _, _, _ = sparse_autoencoder(original_cache[sparse_autoencoder.cfg.hook_point])

    # (n_features, n_positions) for the single prompt in this batch.
    feature_acts_of_interest = feature_acts[0, :, features_of_interest].T
    events_df = pd.DataFrame(analyze_events(feature_acts_of_interest), index=features_of_interest)
    events_df["feature"] = events_df.index.astype(str)
    events_df["prompt_index"] = prompt_index
    events_df = events_df[events_df["num_events"] > 0]

    token_dfs.append(token_df.reset_index(drop=True))
    event_dfs.append(events_df.reset_index(drop=True))
    feature_acts_all.append(feature_acts_of_interest)

feature_acts_all = torch.stack(feature_acts_all, dim=0)

In [None]:
# Stack the per-prompt activations. The collection loop may already have stacked them, and
# torch.stack on a Tensor raises TypeError, so only stack while we still hold a list.
if isinstance(feature_acts_all, list):
    feature_acts_all = torch.stack(feature_acts_all, dim=0)

In [None]:
# Flatten the per-prompt frames into notebook-wide tables, then explode each (feature, prompt)
# row's list of events into one row per event.
token_df = pd.concat(token_dfs).reset_index(drop=True)
prompt_event_df = pd.concat(event_dfs).reset_index(drop=True)

exploded_events = prompt_event_df.explode("events")
events_df = exploded_events.apply(lambda event_row: pd.Series(event_row["events"]), axis=1)
# `explode` repeats the parent row label, so each event can be traced back to its prompt_event_df row.
events_df["feature"] = events_df.index.map(lambda parent: prompt_event_df.feature[parent]).astype(str)
events_df["prompt_index"] = events_df.index.map(lambda parent: prompt_event_df.prompt_index[parent])
    

In [None]:
# Peek at the per-(feature, prompt) event summary.
prompt_event_df.head(n=5)

In [None]:
# Pairwise relationships between the per-(feature, prompt) event statistics, coloured by feature.
event_stat_columns = ["num_events", "num_firings", "avg_duration", "avg_max_value"]
px.scatter_matrix(
    prompt_event_df,
    dimensions=event_stat_columns,
    color="feature",
    title="Event stats for each feature",
    width=1000,
    height=1000,
)

In [None]:
# Aggregate event statistics per (feature, prompt) and plot how many firings make up each event.
prompt_event_agg_df = (
    prompt_event_df
    .groupby(["feature", "prompt_index"])
    .agg({"num_events": "sum", "num_firings": "sum", "avg_duration": "mean"})
    .sort_values("num_events", ascending=False)
    .reset_index()
)
prompt_event_agg_df["firings_per_event"] = (
    prompt_event_agg_df["num_firings"] / prompt_event_agg_df["num_events"]
)
px.strip(
    prompt_event_agg_df,
    x="feature",
    y="firings_per_event",
    color="feature",
    title="Firings per event",
    hover_data=["num_events", "num_firings", "avg_duration", "prompt_index"],
).show()




In [None]:
# Number of distinct features with at least one event (shown as a shape tuple).
prompt_event_agg_df["feature"].unique().shape

In [None]:
# Mean and spread of firings-per-event for each feature. Both statistics come from one groupby and
# share one sort order, so each plotted (x, y) pair belongs to the same feature. Sorting the two
# Series independently (as before) paired the mean of one feature with the std of another.
firings_per_event_stats = (
    prompt_event_agg_df.groupby("feature").firings_per_event
    .agg(["mean", "std"])
    .sort_values("mean", ascending=False)
)
mean_firings_per_event = firings_per_event_stats["mean"]
std_firings_per_event = firings_per_event_stats["std"]  # NaN for features seen in a single prompt
px.scatter(
    x=mean_firings_per_event.values,
    y=std_firings_per_event.values,
    hover_name=mean_firings_per_event.index,
    marginal_x="histogram",
    marginal_y="histogram",
    labels={"x": "Mean firings per event", "y": "Std firings per event"},
    title="Mean vs Std firings per event",
).show()

In [None]:
# Open dashboards for a slice of features whose mean firings per event is below 1.3.
# Feature ids are stored as strings in the event tables, but the dashboard path formats them
# with a zero-padded integer spec, so convert back to int first.
for feature in mean_firings_per_event[mean_firings_per_event < 1.3].index[10:30]:
    render_feature_dashboard(int(feature))

In [None]:
# For each event, find the token_df row of the token on which it started firing.
# A (prompt_index, pos) -> row-label lookup is built once here, so this cell does not depend on
# helpers defined further down the notebook and avoids a full-table scan per event. The first
# matching row wins, as with a boolean-mask lookup.
token_df_id_by_prompt_and_pos = {}
for row_id, row_prompt, row_pos in zip(token_df.index, token_df["prompt_index"], token_df["pos"]):
    token_df_id_by_prompt_and_pos.setdefault((row_prompt, row_pos), row_id)

events_df["token_df_id"] = [
    token_df_id_by_prompt_and_pos[(int(event_prompt), int(event_start))]
    for event_prompt, event_start in zip(events_df["prompt_index"], events_df["start_position"])
]

In [None]:
# Preview the per-event table rather than dumping every row.
events_df.head()

In [None]:
# Attach the start token's metadata to each event. `on=` joins events_df's token_df_id column
# against token_df's index; both frames carry `prompt_index`, so token_df's copy gets a suffix
# (DataFrame.join raises on overlapping columns otherwise).
events_df.join(token_df, on="token_df_id", rsuffix="_token").head()

In [None]:
# Peek at the per-(feature, prompt) event summary.
prompt_event_df.head(n=5)

In [None]:
# Attach one feature's activations to token_df and find its 30 strongest activating tokens
# (plus the row just before each).
FEATURE_ID = 22768
if FEATURE_ID not in features_of_interest:
    # features_of_interest is sampled at random in the collection loop, so it may not contain this id.
    raise ValueError(
        f"Feature {FEATURE_ID} is not in the sampled features_of_interest; "
        "include it when choosing features_of_interest in the collection loop."
    )
feature_idx = features_of_interest.index(FEATURE_ID)
feature_col = f"feature_{FEATURE_ID}"
# feature_acts_all is stacked as (n_prompts, n_features, seq_len); flattening one feature's slice
# is prompt-major, matching the row order of the concatenated token_df.
token_df[feature_col] = feature_acts_all[:, feature_idx].flatten().tolist()
idxes = token_df.sort_values(feature_col, ascending=False).head(30).index
# NOTE: idx - 1 is the previous token_df row; for a prompt's first token that row belongs to the
# previous prompt.
idxes_minus_1 = idxes - 1


In [None]:
# NOTE(review): this cell was a bare `token_df.groupby()`, which raises TypeError (neither `by`
# nor `level` given). Assumed intent: per-token statistics of the feature column attached in the
# previous cell, ranked by peak activation — confirm.
(
    token_df.groupby("str_tokens")["feature_22768"]
    .agg(["count", "mean", "max"])
    .sort_values("max", ascending=False)
    .head(30)
)

In [None]:
def token_df_id_from_prompt_and_pos(prompt_index, pos):
    """Return the token_df row label for position `pos` of prompt `prompt_index` (first match).

    Raises IndexError if no such row exists.
    """
    is_match = (token_df["prompt_index"] == prompt_index) & (token_df["pos"] == pos)
    return token_df[is_match].index[0]


def str_token_from_prompt_and_pos(prompt_index, pos):
    """Return the string token at position `pos` of prompt `prompt_index` (first match).

    Raises IndexError if no such row exists.
    """
    is_match = (token_df["prompt_index"] == prompt_index) & (token_df["pos"] == pos)
    return token_df[is_match].str_tokens.values[0]


token_df_id_from_prompt_and_pos(12, 3)

In [None]:
# Features ranked by how much their event durations vary.
(
    events_df.groupby("feature")
    .agg({"duration": "std"})
    .sort_values(by="duration", ascending=False)
)

In [None]:
# Word-cloud prep: for events of a given duration, collect the token each event starts on, the
# token just before it, and the event's last token.
feature_of_interest = 22768
event_duration = 4
# NOTE(review): the per-feature filter was commented out in the original cell; set to True to
# restrict the selection to `feature_of_interest`.
restrict_to_feature = False

# Step 1: select the events we care about. Previously this selection was computed but never
# assigned, so the steps below silently read a stale `tmp` from an earlier cell.
selected_events = events_df[events_df.duration == event_duration]
if restrict_to_feature:
    selected_events = selected_events[selected_events.feature == str(feature_of_interest)]

# Step 2: look up the surrounding tokens for each selected event (first matching row wins).
str_token_by_prompt_and_pos = {}
for row_prompt, row_pos, row_str_token in zip(token_df["prompt_index"], token_df["pos"], token_df["str_tokens"]):
    str_token_by_prompt_and_pos.setdefault((row_prompt, row_pos), row_str_token)


def _str_token_at(prompt_index, pos):
    # None for positions outside the prompt (e.g. the token before position 0) instead of raising.
    return str_token_by_prompt_and_pos.get((int(prompt_index), int(pos)))


# Events without a final_position are treated as running to the end of the context; the last
# token is then final_position - 1. (Assumes tokens is (n_prompts, context_size).)
context_size = tokens.shape[1]
final_positions = selected_events.final_position.fillna(context_size) - 1

event_boundary_tokens = pd.DataFrame({
    "first_token": [
        _str_token_at(p, s) for p, s in zip(selected_events.prompt_index, selected_events.start_position)
    ],
    "minus_one_token": [
        _str_token_at(p, s - 1) for p, s in zip(selected_events.prompt_index, selected_events.start_position)
    ],
    "final_token": [
        _str_token_at(p, f) for p, f in zip(selected_events.prompt_index, final_positions)
    ],
})
event_boundary_tokens

# Proxy Development

In [None]:
# Sample a pool of token sequences from the layer-10 activation store, shuffle, and keep a fixed
# number of prompts for proxy development.
N_BATCHES_TO_SAMPLE = 128 * 6
N_PROMPTS_TO_KEEP = 4096 * 6

all_tokens_list = [activation_store_layer_10.get_batch_tokens() for _ in tqdm(range(N_BATCHES_TO_SAMPLE))]
all_tokens = torch.cat(all_tokens_list, dim=0)
print(all_tokens.shape)
# Index with the first N entries of a random permutation: same rows (in the same order) as
# permuting everything and slicing, but only the kept rows are materialised, so the full pool can
# actually be freed below.
tokens = all_tokens[torch.randperm(all_tokens.shape[0])[:N_PROMPTS_TO_KEEP]]
# Drop the list too: it still references every sampled batch.
del all_tokens, all_tokens_list
# Only flush the MPS allocator cache when the MPS backend is present (not on CPU/CUDA machines).
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
# One row per sampled prompt with its decoded text, for regex-based proxy matching.
prompt_df = pd.DataFrame({
    "prompt_index": range(tokens.shape[0]),
    "prompt": [model.to_string(prompt_tokens) for prompt_tokens in tokens],
})

In [None]:
import re

# Regex "proxies" for linguistic/formatting constructs. They are matched case-insensitively against
# decoded prompts, and their first match is used to locate the construct.
# Groups are non-capturing (?:...) throughout: pandas' str.contains warns when a pattern has match
# groups, and nothing here needs the captured text.

correlative_conjunctions = {
    "both_and": r"\bboth\b(?:(?!\.|\?|!).)*?\band\b",
    "either_or": r"\beither\b(?:(?!\.|\?|!).)*?\bor\b",
    "neither_nor": r"\bneither\b(?:(?!\.|\?|!).)*?\bnor\b",
    "not_only_but_also": r"\bnot\s+only\b(?:(?!\.|\?|!).)*?\bbut\s+also\b",
    "whether_or": r"\bwhether\b(?:(?!\.|\?|!).)*?\bor\b",
}


questions = {
    "general_questions": r"\b(?:who|what|when|where|why|how)\b.*?\?",
    "how": r"\bhow\b.*?\?",
    "what": r"\bwhat\b.*?\?",
    "when": r"\bwhen\b.*?\?",
    "where": r"\bwhere\b.*?\?",
    "why": r"\bwhy\b.*?\?",
    "who": r"\bwho\b.*?\?",
    "choice_questions": r"\b(?:do you prefer|would you rather)\b.*?\?",
}

punctuation = {
    "regular_parentheses": r"\(.*?\)",
    "square_brackets": r"\[.*?\]",
    "curly_brackets": r"\{.*?\}",
    "angle_brackets": r"\<.*?\>",
    "double_quotes": r"\".*?\"",
    "single_quotes": r"\'.*?\'",
    "backticks": r"`.*?`",
}

# lists = {
#     "bulleted_lists": r"^\s*[\-\*\+] .*$",
#     "numbered_lists": r"^\s*\d+\..*$",
#     "alphabetic_lists": r"^\s*[a-zA-Z]\..*$",
# }

formatting = {
    "specific_html_tag": r"\<div\>.*?\</div\>",  # Example with 'div' tag
    "any_html_tag": r"\<.*?\>.*?\</.*?\>",
    "inline_code": r"`.*?`",
    "multiline_code_blocks": r"```.*?```",
    "bold_text_markdown": r"\*\*.*?\*\*" + "|" + r"__.*?__",
    "italic_text_markdown": r"\*.*?\*" + "|" + r"_.*?_"
}


# Combine the proxy families we track into one ordered dict. Punctuation, list and formatting
# proxies are currently disabled.
proxies = {**correlative_conjunctions, **questions}

# One boolean column per proxy: does the prompt contain a (case-insensitive) match anywhere?
for proxy_name, proxy_regex in proxies.items():
    prompt_df[proxy_name] = prompt_df.prompt.str.contains(proxy_regex, flags=re.IGNORECASE)

# Summary: number of prompts matching each proxy (every column after prompt_index and prompt).
prompt_df.iloc[:, 2:].sum()

In [None]:
# Bar charts of how many prompts match each proxy, one chart per proxy family. Columns are chosen
# by proxy name rather than position: the old positional slice iloc[:, 7:13] silently dropped the
# "who" and "choice_questions" columns.
n_prompts_total = len(prompt_df)
for proxy_family in (correlative_conjunctions, questions):
    proxy_hits = prompt_df[list(proxy_family)].sum()
    px.bar(
        proxy_hits,
        text=proxy_hits.values,  # show the count above each bar
        title=f"Proxy Hits (out of {n_prompts_total} prompts)",
        labels={"value": "Number of prompts"},
        width=500,
    ).show()

In [None]:
# sample from prompts containing proxies
import re 
# import HTML
from IPython.display import HTML, display


def both_and_highlight(prompt, prompt_proxy_regex):
    """Display `prompt` as HTML with the first match of `prompt_proxy_regex` shown in red.

    Despite the name, this works for any proxy regex. Matching is case-insensitive. The prompt is
    HTML-escaped so that markup inside it (e.g. `<div>`) is shown literally instead of rendered.

    Args:
        prompt: text to display.
        prompt_proxy_regex: regex whose first match is highlighted. If it does not match, the
            prompt is shown without highlighting.
    """
    import html

    match = re.search(prompt_proxy_regex, prompt, flags=re.IGNORECASE)
    if match is None:
        display(HTML(html.escape(prompt)))
        return

    start_pos, end_pos = match.span()
    # style with red text
    style_tag = "<span style='color:red'>"
    highlighted = (
        html.escape(prompt[:start_pos])
        + style_tag
        + html.escape(prompt[start_pos:end_pos])
        + "</span>"
        + html.escape(prompt[end_pos:])
    )
    display(HTML(highlighted))

# Show one random prompt containing a "both ... and" construction, with the match highlighted.
sampled_row = prompt_df[prompt_df["both_and"]].sample(1).iloc[0]
both_and_highlight(sampled_row.prompt, proxies["both_and"])

In [None]:
def get_feature_acts(prompts, features_of_interest: List, sparse_autoencoder=None):
    """Run the model and SAE over each prompt and return the activations of selected features.

    Args:
        prompts: indexable collection of 1-D token tensors (e.g. a 2-D token tensor). Each prompt
            is run as a batch of one.
        features_of_interest: SAE feature indices to keep.
        sparse_autoencoder: SAE applied to the activation at its `cfg.hook_point`. If None, falls
            back to `sparse_autoencoder_layer_10` (looked up at call time), which is also the
            default of `get_feature_acts_and_projection_at_start_pos`.

    Returns:
        Tensor stacking each prompt's `feature_acts[:, :, features_of_interest]` along a new
        dim 0. This is presumably (n_prompts, 1, seq_len, len(features_of_interest)) if the SAE
        returns (batch, seq, d_sae) activations — TODO confirm.
    """
    if sparse_autoencoder is None:
        sparse_autoencoder = sparse_autoencoder_layer_10
    hook_point = sparse_autoencoder.cfg.hook_point

    feature_acts_list = []
    for prompt_index in tqdm(range(len(prompts))):
        prompt_tokens = prompts[prompt_index].unsqueeze(0)
        # Only the SAE's input activation is needed, so cache just that hook.
        _, cache = model.run_with_cache(prompt_tokens, names_filter=hook_point)
        _, feature_acts, _, _, _ = sparse_autoencoder(cache[hook_point])
        feature_acts_list.append(feature_acts[:, :, features_of_interest])

    return torch.stack(feature_acts_list, dim=0)


A fair amount of bookkeeping follows, so that we end up with a DataFrame holding the token positions we care about. For each prompt we find where the correlative conjunction occurs (a character span from the regex) and convert that span into token positions using the tokenizer's offset mapping. With those positions in hand, we can index into the feature activations at exactly the tokens of interest.

In [None]:
# Token ids for the words that delimit the "both ... and" proxy.
# NOTE(review): these id lists are not used later in this section; kept for reference.
both_token_strs = [" both", "Both"]
both_token_ids = [model.to_single_token(token) for token in both_token_strs]
and_token_strs = [" and", "and", "And", " And"]
and_token_ids = [model.to_single_token(token) for token in and_token_strs]

both_and_regex = proxies["both_and"]
both_and_mask = prompt_df["both_and"].to_numpy()
both_and_token_mask = torch.as_tensor(both_and_mask, device=tokens.device)
# Fallback token position when the match runs to the end of the context (127 for 128 tokens).
last_tok_pos = tokens.shape[1] - 1

# .copy() so the column assignments below act on a standalone frame, not a slice of prompt_df.
both_and_df = prompt_df.loc[both_and_mask, ["prompt_index", "prompt"]].copy()
both_and_df["tokens"] = tokens[both_and_token_mask].detach().cpu().numpy().tolist()

# Character span of the first proxy match (searched once per prompt).
both_and_spans = both_and_df.prompt.apply(
    lambda text: re.search(both_and_regex, text, flags=re.IGNORECASE).span()
)
both_and_df["start_pos"] = both_and_spans.apply(lambda span: span[0])
both_and_df["end_pos"] = both_and_spans.apply(lambda span: span[1])

# Character offsets of every token after re-tokenising the decoded prompt.
# NOTE(review): assumes re-tokenising reproduces `tokens` position-for-position.
both_and_df["offset_mapping"] = both_and_df.prompt.apply(
    lambda text: model.tokenizer.encode_plus(text, return_offsets_mapping=True)["offset_mapping"]
)
both_and_df["start_offset_mapping"] = both_and_df.offset_mapping.apply(lambda offsets: [start for start, _ in offsets])
both_and_df["end_offset_mapping"] = both_and_df.offset_mapping.apply(lambda offsets: [end for _, end in offsets])

# Character span -> token positions, via each token's start offset.
# NOTE(review): GPT-2 tokens usually carry their leading space (" both"), which starts one character
# before the regex match. So start_pos_tok_id typically lands on the token *after* "both", and the
# downstream slices that start at `start - 1` rely on that. end_pos_tok_id is the first token
# starting at or after the end of the match.
both_and_df["start_pos_tok_id"] = both_and_df.apply(
    lambda row: next(i for i, offset in enumerate(row["start_offset_mapping"]) if offset >= row["start_pos"]),
    axis=1,
)
both_and_df["end_pos_tok_id"] = both_and_df.apply(
    lambda row: next(
        (i for i, offset in enumerate(row["start_offset_mapping"]) if offset >= row["end_pos"]), last_tok_pos
    ),
    axis=1,
)
both_and_df["start_pos_tok_str"] = both_and_df.apply(
    lambda row: model.to_single_str_token(row["tokens"][row["start_pos_tok_id"]]), axis=1
)
both_and_df["end_pos_tok_str"] = both_and_df.apply(
    lambda row: model.to_single_str_token(row["tokens"][row["end_pos_tok_id"]]), axis=1
)
# Explicit display: a bare expression mid-cell is not rendered.
display(both_and_df[["prompt_index", "prompt", "start_pos_tok_str", "end_pos_tok_str"]].head())

features_of_interest = [21604]
both_and_feature_acts = get_feature_acts(
    tokens[both_and_token_mask],
    features_of_interest=features_of_interest,
    sparse_autoencoder=sparse_autoencoder_layer_10,
)
# Presumably (n_prompts, 1, seq_len, 1) -> (n_prompts, seq_len). Explicit dims keep the prompt
# axis even when only one prompt matches (a bare .squeeze() would drop it).
both_and_feature_acts = both_and_feature_acts.squeeze(-1).squeeze(1)
both_and_feature_acts.shape

In [None]:
# Number of token positions between the proxy's start and end token ids.
both_and_df["n_toks_in_proxy_context"] = both_and_df["end_pos_tok_id"] - both_and_df["start_pos_tok_id"]

# Sort before taking the bar labels. Using the unsorted value_counts for `text` put each count on
# the wrong bar.
context_len_counts = both_and_df.n_toks_in_proxy_context.value_counts().sort_index()
# Bar chart ordered by context length, with the count printed on each bar and no legend.
px.bar(
    context_len_counts,
    text=context_len_counts.values,
    title="Number of tokens in proxy context",
    labels={"n_toks_in_proxy_context": "Number of tokens in proxy context", "value": "Number of prompts"},
).update_layout(showlegend=False)

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

def display_prompt_selector(df, indexes, prompt_proxy_regex=proxies["both_and"]):
    """Show a dropdown of prompt indexes; picking one renders that prompt with its proxy highlighted.

    Args:
        df: frame whose index labels are the values offered in `indexes`, with a `prompt` column.
        indexes: index labels offered in the dropdown (after a placeholder entry).
        prompt_proxy_regex: regex handed to `both_and_highlight`; defaults to the "both ... and" proxy.
    """
    placeholder = 'Select a prompt'
    prompt_output = widgets.Output()
    selector = widgets.Dropdown(
        options=[placeholder, *indexes],
        description='Select Prompt:',
        disabled=False,
    )

    def _render_selected(change):
        selected = change['new']
        with prompt_output:
            clear_output(wait=True)
            if selected not in df.index:
                print("Invalid selection")
                return
            both_and_highlight(df.loc[selected].prompt, prompt_proxy_regex=prompt_proxy_regex)

    selector.observe(_render_selected, names='value')
    display(selector, prompt_output)

# Usage: display_prompt_selector(both_and_df, both_and_df.index)

# Align each prompt's activations on the proxy. Row i holds the activations from the token before
# start_pos_tok_id up to (excluding) end_pos_tok_id, zero-padded on the right.
aligned_feature_acts = torch.zeros_like(both_and_feature_acts)
for i, (start, end) in enumerate(zip(both_and_df["start_pos_tok_id"], both_and_df["end_pos_tok_id"])):
    # Clamp at 0: for a match on the very first token, `start - 1` would be -1 and the slice would
    # no longer fit the destination.
    segment = both_and_feature_acts[i, max(start - 1, 0):end]
    aligned_feature_acts[i, :segment.shape[0]] = segment

for gap_length in range(6, 10):
    length_mask = (both_and_df.n_toks_in_proxy_context == gap_length).values
    # NOTE(review): aligned rows hold at most gap_length + 1 real values, so the last of the
    # gap_length + 2 plotted positions is zero padding.
    gap_acts_df = pd.DataFrame(
        aligned_feature_acts[length_mask, :gap_length + 2].detach().cpu().numpy().T,
        columns=both_and_df[length_mask].prompt_index.astype(int).to_list(),
    )

    # A prompt counts as "feature present" if the feature is positive anywhere in its aligned window.
    fired_anywhere = (gap_acts_df.values > 0).any(axis=0)
    feature_present_indexes = gap_acts_df.columns[fired_anywhere]
    feature_missing_indexes = gap_acts_df.columns[~fired_anywhere]

    display(HTML("<h2>Feature present</h2>"))
    print(feature_present_indexes)
    display_prompt_selector(both_and_df[length_mask], feature_present_indexes)
    display(HTML("<h2>Feature Absent</h2>"))
    print(feature_missing_indexes)
    display_prompt_selector(both_and_df[length_mask], feature_missing_indexes)
    px.line(gap_acts_df, title=f"Feature activations for prompts containing 'both ... and' with a gap of {gap_length}").show()

In [None]:
# Check how strongly the model predicts " and" after two real "both ..." contexts.
prompt_1_1 = """himself into a shadow war in order to expose it. His only clue is the keyword "El Dorado." He meets Sophie, a woman searching for her older brother who left her with only a message with the same word: "El Dorado." With Sword having also lost his younger sister in the past, both are drawn together by the word,"""
prompt_2_1 = """tern and a banana. "The fan section was louder than it had been all season long, and the fans, of both sides I may add, were thoroughly amused"""
answer = " and"
for both_context in (prompt_1_1, prompt_2_1):
    utils.test_prompt(both_context, answer, model)

## What ... ?

In [None]:
# Token ids for the words that delimit the "what ... ?" proxy.
# NOTE(review): these id lists are not used later in this section; kept for reference.
what_token_strs = [" what", "What"]
what_token_ids = [model.to_single_token(token) for token in what_token_strs]
qmark_token_strs = ["?", " ?", "?!"]
# Previously built from and_token_strs by mistake. Keep only strings that really are single tokens,
# since to_single_token expects one.
qmark_token_ids = [
    model.to_single_token(token)
    for token in qmark_token_strs
    if model.to_tokens(token, prepend_bos=False).shape[-1] == 1
]

what_regex = proxies["what"]
what_mask = prompt_df["what"].to_numpy()
what_token_mask = torch.as_tensor(what_mask, device=tokens.device)
# Fallback token position when the match runs to the end of the context (127 for 128 tokens).
last_tok_pos = tokens.shape[1] - 1

# .copy() so the column assignments below act on a standalone frame, not a slice of prompt_df.
what_question_df = prompt_df.loc[what_mask, ["prompt_index", "prompt"]].copy()
what_question_df["tokens"] = tokens[what_token_mask].detach().cpu().numpy().tolist()

# Character span of the first proxy match (searched once per prompt).
what_spans = what_question_df.prompt.apply(lambda text: re.search(what_regex, text, flags=re.IGNORECASE).span())
what_question_df["start_pos"] = what_spans.apply(lambda span: span[0])
what_question_df["end_pos"] = what_spans.apply(lambda span: span[1])

# Character offsets of every token after re-tokenising the decoded prompt.
# NOTE(review): assumes re-tokenising reproduces `tokens` position-for-position.
what_question_df["offset_mapping"] = what_question_df.prompt.apply(
    lambda text: model.tokenizer.encode_plus(text, return_offsets_mapping=True)["offset_mapping"]
)
what_question_df["start_offset_mapping"] = what_question_df.offset_mapping.apply(lambda offsets: [start for start, _ in offsets])
what_question_df["end_offset_mapping"] = what_question_df.offset_mapping.apply(lambda offsets: [end for _, end in offsets])

# Character span -> token positions: first token starting at/after the match start, and first token
# starting at/after the match end.
what_question_df["start_pos_tok_id"] = what_question_df.apply(
    lambda row: next(i for i, offset in enumerate(row["start_offset_mapping"]) if offset >= row["start_pos"]),
    axis=1,
)
what_question_df["end_pos_tok_id"] = what_question_df.apply(
    lambda row: next(
        (i for i, offset in enumerate(row["start_offset_mapping"]) if offset >= row["end_pos"]), last_tok_pos
    ),
    axis=1,
)
what_question_df["start_pos_tok_str"] = what_question_df.apply(
    lambda row: model.to_single_str_token(row["tokens"][row["start_pos_tok_id"]]), axis=1
)
what_question_df["end_pos_tok_str"] = what_question_df.apply(
    lambda row: model.to_single_str_token(row["tokens"][row["end_pos_tok_id"]]), axis=1
)
# Explicit display: a bare expression mid-cell is not rendered.
display(what_question_df[["prompt_index", "prompt", "start_pos_tok_str", "end_pos_tok_str"]].head())

features_of_interest = [18962]
# get_feature_acts needs an SAE; it was previously omitted here (TypeError).
# NOTE(review): assumes feature 18962 belongs to the layer-10 SAE used for the rest of this section — confirm.
what_question_acts = get_feature_acts(
    tokens[what_token_mask],
    features_of_interest=features_of_interest,
    sparse_autoencoder=sparse_autoencoder_layer_10,
)
# Explicit dims keep the prompt axis even when only one prompt matches.
what_question_acts = what_question_acts.squeeze(-1).squeeze(1)

what_question_df["n_toks_in_proxy_context"] = what_question_df["end_pos_tok_id"] - what_question_df["start_pos_tok_id"]

# Sort before taking the bar labels so each count sits on its own bar.
context_len_counts = what_question_df.n_toks_in_proxy_context.value_counts().sort_index()
px.bar(
    context_len_counts,
    text=context_len_counts.values,
    title="Number of tokens in proxy context",
    labels={"n_toks_in_proxy_context": "Number of tokens in proxy context", "value": "Number of prompts"},
).update_layout(showlegend=False)

In [None]:
# Align each prompt's activations on the "what ... ?" proxy. Row i holds the activations from the
# token before start_pos_tok_id up to (excluding) end_pos_tok_id, zero-padded on the right.
aligned_feature_acts = torch.zeros_like(what_question_acts)
for i, (start, end) in enumerate(zip(what_question_df["start_pos_tok_id"], what_question_df["end_pos_tok_id"])):
    # Clamp at 0: for a match on the very first token, `start - 1` would be -1 and the slice would
    # no longer fit the destination.
    segment = what_question_acts[i, max(start - 1, 0):end]
    aligned_feature_acts[i, :segment.shape[0]] = segment

for gap_length in range(5, 9):
    length_mask = (what_question_df.n_toks_in_proxy_context == gap_length).values
    # NOTE(review): aligned rows hold at most gap_length + 1 real values, so the last of the
    # gap_length + 2 plotted positions is zero padding.
    gap_acts_df = pd.DataFrame(
        aligned_feature_acts[length_mask, :gap_length + 2].detach().cpu().numpy().T,
        columns=what_question_df[length_mask].prompt_index.astype(int).to_list(),
    )

    # A prompt counts as "feature present" if the feature is positive anywhere in its aligned window.
    fired_anywhere = (gap_acts_df.values > 0).any(axis=0)
    feature_present_indexes = gap_acts_df.columns[fired_anywhere]
    feature_missing_indexes = gap_acts_df.columns[~fired_anywhere]

    display(HTML("<h2>Feature present</h2>"))
    print(feature_present_indexes)
    display_prompt_selector(what_question_df[length_mask], feature_present_indexes, prompt_proxy_regex=proxies["what"])
    display(HTML("<h2>Feature Absent</h2>"))
    print(feature_missing_indexes)
    display_prompt_selector(what_question_df[length_mask], feature_missing_indexes, prompt_proxy_regex=proxies["what"])
    px.line(gap_acts_df, title=f"Feature activations for prompts containing 'what ... ?' with a gap of {gap_length}").show()

## Direct feature attribution (DFA) to find the heads driving the feature


In [None]:
# Direct feature attribution (DFA) for one "both ... and" prompt. Decompose the residual stream at
# the SAE's layer, take the proxy start token, and project each component onto the feature's
# decoder direction.
prompt_index = 2
prompts = tokens[prompt_df["both_and"]]
features_of_interest = 21604
feature_dir = sparse_autoencoder_layer_10.W_dec[features_of_interest].cpu()

prompt_tokens = prompts[prompt_index].unsqueeze(0)
text = model.to_string(prompt_tokens)[0]
offset_mapping = model.tokenizer.encode_plus(text, return_offsets_mapping=True)["offset_mapping"]
start_offset_mapping = [start for start, _ in offset_mapping]
end_offset_mapping = [end for _, end in offset_mapping]
(original_logits, original_loss), original_cache = model.run_with_cache(prompt_tokens, return_type="both", loss_per_token=True)

# Decompose at the SAE's own layer (10 for blocks.10.hook_resid_pre) rather than a hard-coded number.
decomp, labels = original_cache.get_full_resid_decomposition(
    layer=sparse_autoencoder_layer_10.cfg.hook_point_layer, expand_neurons=False, return_labels=True
)
original_act = original_cache[sparse_autoencoder_layer_10.cfg.hook_point]
sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder_layer_10(original_act)

# Where the proxy starts and ends: character span -> token positions (first token starting at/after
# the match start, and first token starting at/after the match end).
proxy_start_pos, proxy_end_pos = re.search(proxies["both_and"], text, flags=re.IGNORECASE).span()
start_pos_tok_id = next(i for i, offset in enumerate(start_offset_mapping) if offset >= proxy_start_pos)
end_pos_tok_id = next(
    (i for i, offset in enumerate(start_offset_mapping) if offset >= proxy_end_pos), prompt_tokens.shape[1] - 1
)

# Where the feature fires, ignoring positions more than n_offset tokens before the proxy start.
# Clamp at 0: a negative slice end would instead mask everything except the last few positions.
n_offset = 5
feature_fired = feature_acts[:, :, features_of_interest].squeeze() > 0
feature_fired[:max(start_pos_tok_id - n_offset, 0)] = False
fired_positions = feature_fired.nonzero()
if fired_positions.numel() > 0:
    start_feature_pos = fired_positions.min().item()
    end_feature_pos = fired_positions.max().item()
else:
    # The feature never fires in range; NaN keeps the metrics below defined (fired_early -> False).
    start_feature_pos = end_feature_pos = float("nan")

result_metrics = {
    "start_pos_tok_id": start_pos_tok_id,
    "end_pos_tok_id": end_pos_tok_id,
    "start_feature_pos": start_feature_pos,
    "end_feature_pos": end_feature_pos,
    "fired_early": start_feature_pos < start_pos_tok_id,
}

# DFA: project every residual-stream component at the proxy start token onto the feature direction.
projection = decomp.squeeze(1)[:, start_pos_tok_id].cpu() @ feature_dir
print(result_metrics)


def get_feature_acts_and_projection_at_start_pos(
    prompts, features_of_interest: List, feature_dir: torch.Tensor, n_offset: int = 5,
    ablate_head = None,
    sparse_autoencoder = sparse_autoencoder_layer_10,
    attn_head = (4, 1),
):
    """Per prompt: SAE feature activations, DFA at the proxy start, and one head's attention pattern.

    For each prompt, which must contain a match of `proxies["both_and"]`:
      * locate the proxy's start/end token positions from the tokenizer's offset mapping;
      * run the model, optionally zeroing one head's `hook_z` output at the proxy start token;
      * encode the SAE hook-point activation and record where the feature fires, ignoring
        positions more than `n_offset` tokens before the proxy start;
      * project the full residual decomposition (at the SAE's layer, at the proxy start token)
        onto `feature_dir`;
      * keep the attention pattern of `attn_head`.

    Args:
        prompts: indexable collection of 1-D token tensors.
        features_of_interest: SAE feature index whose activations are recorded. The firing metrics
            assume a single feature.
        feature_dir: direction (e.g. the feature's decoder row) to project components onto.
        n_offset: how many tokens before the proxy start a firing may occur and still be counted.
        ablate_head: optional (layer, head) whose z output is zeroed at the proxy start token.
        sparse_autoencoder: SAE used to encode activations (default bound at definition time).
        attn_head: (layer, head) whose attention pattern is returned; defaults to L4H1.

    Returns:
        token_df: one row per prompt, i.e. the make_token_df row at the proxy start plus the proxy
            and firing metrics. start/end_feature_pos are NaN when the feature never fires in range.
        feature_acts: stacked per-prompt `feature_acts[:, :, features_of_interest]`.
        projections: stacked per-prompt projections, one value per decomposition component.
        labels: component labels from the last prompt's decomposition.
        attn_patterns: stacked attention patterns of `attn_head`.
    """
    n_prompts = len(prompts)
    token_dfs = []
    feature_acts_list = []
    projections = []
    attn_patterns = []
    labels = None
    # Decompose the residual stream at the layer the SAE reads from.
    decomp_layer = sparse_autoencoder.cfg.hook_point_layer
    attn_layer, attn_head_idx = attn_head

    if ablate_head is not None:
        head_layer, head_idx = ablate_head
        head_hook_result_name = f"blocks.{head_layer}.attn.hook_z"

        def hook_to_ablate_head(head_output: Float[Tensor, "batch seq_len head_idx d_head"], hook: HookPoint, head = (head_layer, head_idx), pos = -1):
            # Zero one head's output at a single position (modifies the hook activation in place).
            assert head[0] == hook.layer(), f"{head[0]} != {hook.layer()}"
            assert ("result" in hook.name) or ("q" in hook.name) or ("z" in hook.name)
            head_output[:, pos, head[1], :] = 0
            return head_output

    for prompt_index in tqdm(range(n_prompts)):
        prompt_tokens = prompts[prompt_index].unsqueeze(0)
        text = model.to_string(prompt_tokens)[0]
        offset_mapping = model.tokenizer.encode_plus(text, return_offsets_mapping=True)["offset_mapping"]
        start_offset_mapping = [start for start, _ in offset_mapping]
        last_tok_pos = prompt_tokens.shape[1] - 1

        token_df = make_token_df(model, prompt_tokens, len_suffix=10, len_prefix=10)
        token_df["prompt_index"] = prompt_index

        # Where the proxy starts and ends (character span -> token positions).
        proxy_start_pos, proxy_end_pos = re.search(proxies["both_and"], text, flags=re.IGNORECASE).span()
        start_pos_tok_id = next(i for i, offset in enumerate(start_offset_mapping) if offset >= proxy_start_pos)
        end_pos_tok_id = next(
            (i for i, offset in enumerate(start_offset_mapping) if offset >= proxy_end_pos), last_tok_pos
        )

        if ablate_head is None:
            _, original_cache = model.run_with_cache(prompt_tokens, return_type="both", loss_per_token=True)
        else:
            # The feature usually fires after the proxy token, so ablate the head at the proxy start.
            hook_with_pos = partial(hook_to_ablate_head, head=ablate_head, pos=start_pos_tok_id)
            with model.hooks(fwd_hooks=[(head_hook_result_name, hook_with_pos)]):
                _, original_cache = model.run_with_cache(prompt_tokens, return_type="both", loss_per_token=True)

        original_act = original_cache[sparse_autoencoder.cfg.hook_point]
        _, feature_acts, _, _, _ = sparse_autoencoder(original_act)
        feature_acts_list.append(feature_acts[:, :, features_of_interest])

        # Where the feature fires, ignoring positions more than n_offset tokens before the proxy.
        # Clamp at 0: a negative slice end would mask everything except the last few positions.
        feature_fired = feature_acts[:, :, features_of_interest].squeeze() > 0
        feature_fired[:max(start_pos_tok_id - n_offset, 0)] = False
        fired_positions = feature_fired.nonzero()
        if fired_positions.numel() > 0:
            start_feature_pos = fired_positions.min().item()
            end_feature_pos = fired_positions.max().item()
        else:
            # e.g. after ablating a head the feature may not fire at all; record NaN instead of crashing.
            start_feature_pos = end_feature_pos = float("nan")

        result_metrics = {
            "start_pos_tok_id": start_pos_tok_id,
            "end_pos_tok_id": end_pos_tok_id,
            "start_end_proxy_gap": end_pos_tok_id - start_pos_tok_id,
            "start_feature_pos": start_feature_pos,
            "end_feature_pos": end_feature_pos,
            "start_end_feature_gap": end_feature_pos - start_feature_pos,
            "fired_early": start_feature_pos < start_pos_tok_id,
        }

        # DFA: project every residual-stream component at the proxy start token onto feature_dir.
        decomp, labels = original_cache.get_full_resid_decomposition(layer=decomp_layer, expand_neurons=False, return_labels=True)
        projections.append(decomp.squeeze(1)[:, start_pos_tok_id].cpu() @ feature_dir)

        attn_pattern = original_cache[utils.get_act_name("pattern", attn_layer)].squeeze(0)[attn_head_idx, :].cpu()
        attn_patterns.append(attn_pattern)

        proxy_start_row = token_df.iloc[start_pos_tok_id].copy()
        for metric, value in result_metrics.items():
            proxy_start_row[metric] = value
        token_dfs.append(proxy_start_row)

    feature_acts = torch.stack(feature_acts_list, dim=0)
    projections = torch.stack(projections, dim=0)
    token_df = pd.concat(token_dfs, axis=1).T
    attn_patterns = torch.stack(attn_patterns, dim=0)

    return token_df, feature_acts, projections, labels, attn_patterns


# Same prompts and feature, first as-is and then with head L4H1 zero-ablated at the proxy start.
both_and_prompts = tokens[prompt_df["both_and"]]
shared_kwargs = dict(features_of_interest=features_of_interest, feature_dir=feature_dir)

token_df_test, feature_acts_test, projections_test, labels, attn_patterns = (
    get_feature_acts_and_projection_at_start_pos(both_and_prompts, **shared_kwargs)
)

token_df_test_ablate, feature_acts_test_ablate, projections_test_ablate, labels, attn_patterns_ablate = (
    get_feature_acts_and_projection_at_start_pos(both_and_prompts, ablate_head=(4, 1), **shared_kwargs)
)


In [None]:
# (n_prompts, n_components) for the ablated run.
projections_test_ablate.size()

In [None]:
# A prompt ending right after "both employees"; below we check whether the model predicts " and".
prompt = "<|endoftext|> a competitive CrossFitter to get into this mentality that you’re better than a “normal” member, or that your activity in the gym has more value. What if we changed our mindset and looked at it the other way around? Not to burst any bubbles, but if you – as a competitive athlete – to think that you working<|endoftext|>Although there is little case law on medical cannabis use in the Canadian workplace, there are a few cases that can guide both employees" #and employers on this topic. "
# both_and_highlight(prompt, proxies["both_and"])
# GPT-2 tokens carry their leading space, so the continuation after "employees" is " and", not "and"
# (matching the answer used for the other "both ..." prompts).
answer = " and"

In [None]:
# Compare the model's prediction of `answer` on `prompt` with and without head L4H1's z output
# zeroed at the final position.
head_layer = 4
HEAD_HOOK_RESULT_NAME = f"blocks.{head_layer}.attn.hook_z"
head_idx = 1
def hook_to_ablate_head(head_output: Float[Tensor, "batch seq_len head_idx d_head"], hook: HookPoint, head = (head_layer, head_idx), pos = -1):
    """Zero head `head[1]`'s output at position `pos` (default: last position), in place."""
    assert head[0] == hook.layer(), f"{head[0]} != {hook.layer()}"
    assert ("result" in hook.name) or ("q" in hook.name) or ("z" in hook.name)
    head_output[:, pos, head[1], :] = 0
    return head_output

# A bare HTML(...) is only rendered when it is a cell's last expression, so display explicitly.
display(HTML("<h2>Original</h2>"))
utils.test_prompt(prompt, answer, model)

display(HTML("<h2>With head ablated</h2>"))
with model.hooks(fwd_hooks=[(HEAD_HOOK_RESULT_NAME, hook_to_ablate_head)]):
    utils.test_prompt(prompt, answer, model)

In [None]:
# Show full cell contents (no column-width truncation) when displaying DataFrames.
pd.options.display.max_colwidth = None

display(token_df_test.head(10))

# Share of prompts flagged `fired_early`, then the number of prompts per proxy gap length.
share_fired_early = token_df_test.fired_early.mean()
gap_counts = token_df_test.start_end_proxy_gap.value_counts()
print(share_fired_early)
print(gap_counts)

#### Visualize DFA

In [None]:
# Mean (over prompts) projection of each decomposition component onto the feature direction,
# for the normal run, the head-ablated run, and their difference.
mean_projection = projections_test.mean(0).numpy()
mean_projection_ablate = projections_test_ablate.mean(0).numpy()

plotting_df = pd.DataFrame(
    {
        "labels": labels,
        "projection": mean_projection,
        "projection_ablate": mean_projection_ablate,
        "projection_diff": mean_projection_ablate - mean_projection,
    }
)

px.line(
    plotting_df,
    x="labels",
    y=["projection", "projection_ablate", "projection_diff"],
    title="Projection of proxy onto feature",
    width=1000,
    height=500,
).show()


In [None]:

def _plot_head_dfa(head_values, fired_early, title):
    """Show a layer x head heatmap of the median DFA over prompts that did not fire early.

    Args:
        head_values: per-prompt values reshaped to (n_prompts, n_layers, n_heads).
        fired_early: per-prompt `fired_early` flags aligned with the first axis of head_values.
        title: figure title.
    """
    # Compare with `== False` instead of using `~`: the flags can sit in an object-dtype column
    # (rows are concatenated then transposed), where `~` on Python bools gives -1/-2, not a mask.
    not_early = fired_early == False  # noqa: E712
    px.imshow(
        head_values[not_early].median(0).values,
        title=title,
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
        height=500,
        width=500,
        labels=dict(x="Head", y="Layer"),
    ).show()


# Keep only per-head components (labels like "L4H1") and reshape them to (prompt, layer, head).
head_mask = torch.tensor([1 if re.match(r"L\d+H\d+", label) is not None else 0 for label in labels], dtype=torch.bool)
n_layers = sparse_autoencoder_layer_10.cfg.hook_point_layer
head_dfa = projections_test[:, head_mask].reshape(-1, n_layers, model.cfg.n_heads)
head_df_ablation = projections_test_ablate[:, head_mask].reshape(-1, n_layers, model.cfg.n_heads)

# Distinct titles so the ablated heatmap can be told apart from the original one.
_plot_head_dfa(head_dfa, token_df_test.fired_early.values, "DFA into Feature Decoder Direction by Head")
_plot_head_dfa(
    head_df_ablation,
    token_df_test_ablate.fired_early.values,
    "DFA into Feature Decoder Direction by Head (L4H1 ablated)",
)

# px.imshow(head_dfa[token_df_test.fired_early.values==False].std(0),
#             title="DFA into Feature Decoder Direction by Head",
#             color_continuous_midpoint=0,
#             color_continuous_scale="RdBu",
#             height = 500,
#             width = 500,
#             labels=dict(x="Head", y="Layer")).show()

#### Visualize Attn Patterns

In [None]:
# For each proxy gap length, average the stored L4H1 attention pattern over a square window
# running from two tokens before the proxy start to `gap_duration` tokens after it, using only
# prompts where the feature did not fire early.
not_early_mask = (token_df_test.fired_early == False).values  # noqa: E712
start_positions = token_df_test.start_pos_tok_id.astype(int).tolist()
for gap_duration in range(2, 5):
    gap_mask = (token_df_test.start_end_proxy_gap == gap_duration).values
    # Positional row numbers (not index labels) of the prompts that pass the filter, so they line up
    # with the first axis of attn_patterns and with start_positions. Iterating over
    # range(n_matching) instead would silently use the first n prompts regardless of the filter.
    row_ids = [row for row, keep in enumerate(not_early_mask & gap_mask) if keep]
    if not row_ids:
        print(f"No prompts with a gap of {gap_duration}; skipping")
        continue
    filtered_attn_patterns = torch.stack(
        [
            attn_patterns[
                row,
                (start_positions[row] - 2) : (start_positions[row] + gap_duration),
                (start_positions[row] - 2) : (start_positions[row] + gap_duration),
            ]
            for row in row_ids
        ]
    )

    print(filtered_attn_patterns.shape)
    # px.imshow(filtered_attn_patterns.detach().cpu(),animation_frame=0, color_continuous_midpoint=0, color_continuous_scale="RdBu", height=500, width=500).show()
    px.imshow(
        filtered_attn_patterns.detach().cpu().mean(0),
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
        height=500,
        width=500,
        title = f"Attention pattern 'both ... and' with a gap of {gap_duration}",
    ).show()

### Visualize Features Firing

In [None]:
# Index tuples of the first ten non-zero feature activations (normal run).
torch.nonzero(feature_acts_test)[:10]

In [None]:
# First feature-firing position vs proxy start position per prompt, coloured by `fired_early`,
# with histograms of each axis on the margins.
px.scatter(
    token_df_test,
    x="start_pos_tok_id",
    y="start_feature_pos",
    color="fired_early",
    marginal_x="histogram",
    marginal_y="histogram",
    title="Feature firing position vs proxy start position",
    width=1000,
).show()

In [None]:
# Display the per-prompt results table (one row per prompt) from the unablated run.
token_df_test

In [None]:
# Shape of `plotting_df`. On a top-to-bottom run this is the projection table built for the
# "Projection of proxy onto feature" plot; the following cell rebinds `plotting_df` in its loop.
plotting_df.shape

In [None]:
# For offsets i = -1..8 from each prompt's proxy start token, scatter the feature activation
# without vs with the head ablation (prompts where the feature did not fire early only).
firing_indexes = torch.tensor(token_df_test.start_pos_tok_id.astype(int).values)
# Loop-invariant: squeeze once to (n_prompts, seq_len) and take the last valid position from the
# tensor itself rather than hard-coding 127.
acts = feature_acts_test.cpu().squeeze()
acts_ablate = feature_acts_test_ablate.cpu().squeeze()
max_pos = acts.shape[1] - 1
for i in range(-1, 9):
    # Clamp into [0, max_pos]: gather rejects negative indices (possible for i = -1 when the proxy
    # starts at position 0) as well as positions past the end of the sequence.
    positions = (firing_indexes + i).clamp(min=0, max=max_pos).unsqueeze(1)
    feature_activation_at_first_proxy_pos = acts.gather(1, positions).squeeze()
    feature_activation_with_ablation_at_first_proxy_pos = acts_ablate.gather(1, positions).squeeze()

    plotting_df = token_df_test.copy()
    plotting_df["feature"] = feature_activation_at_first_proxy_pos.numpy()
    plotting_df["feature_with_head_ablated"] = feature_activation_with_ablation_at_first_proxy_pos.numpy()
    plotting_df["ablation_diff"] = plotting_df["feature"] - plotting_df["feature_with_head_ablated"]

    fig = px.scatter(
        plotting_df[plotting_df.start_end_feature_gap >= i].query("fired_early == False"),
        x="feature",
        y="feature_with_head_ablated",
        hover_data=["prompt_index", "start_end_feature_gap"],
        title=f"Feature activation at proxy position {i} with ablation",
        width = 1000,
    )
    # add y=x from 0 to 40 (points below the line lose activation under the ablation)
    fig.add_shape(
        type="line", line=dict(dash="dash"), x0=0, y0=0, x1=40, y1=40
    )
    fig.show()



In [None]:
# using the start and end pos token ids from both_and_df to get the feature acts, padding with 0s
# Row i of each aligned tensor starts at token (start - 1) of prompt i and runs through token `end`;
# positions after that stay 0.
start_pos_tok_ids = both_and_df["start_pos_tok_id"].tolist()
end_pos_tok_ids = both_and_df["end_pos_tok_id"].tolist()
aligned_feature_actions = torch.zeros_like(both_and_feature_acts)
aligned_feature_ablate = torch.zeros_like(both_and_feature_acts)
for i, (start, end) in enumerate(zip(start_pos_tok_ids, end_pos_tok_ids)):
    window_len = end - start + 1
    aligned_feature_actions[i, :window_len] = both_and_feature_acts[i, (start - 1):end]
    # NOTE(review): assumes row i of feature_acts_test_ablate is the same prompt as row i of both_and_df — confirm.
    aligned_feature_ablate[i, :window_len] = feature_acts_test_ablate[i, 0, (start - 1):end]


def _show_feature_presence(aligned_acts, length_mask, gap_length, label):
    """Split prompts of one proxy gap length by whether the feature fired, and plot their activations.

    Args:
        aligned_acts: aligned activations, one row per both_and_df row.
        length_mask: boolean mask over both_and_df rows selecting this gap length.
        gap_length: proxy gap length; the first gap_length + 2 aligned positions are shown.
        label: run name appended to headings and the plot title (e.g. "original").

    Returns:
        (acts_df, present, missing): activations with one column per prompt_index, and the
        prompt indexes where the feature was / was not positive anywhere in the window.
    """
    prompts_subset = both_and_df[length_mask]
    acts_df = pd.DataFrame(
        aligned_acts[length_mask, :gap_length + 2].detach().cpu().numpy().T,
        columns=prompts_subset.prompt_index.astype(int).to_list(),
    )
    fired_anywhere = (acts_df.values > 0).any(axis=0)
    present = acts_df.columns[fired_anywhere]
    missing = acts_df.columns[~fired_anywhere]

    display(HTML(f"<h2>Feature present ({label})</h2>"))
    print(present)
    display_prompt_selector(prompts_subset, present)
    display(HTML(f"<h2>Feature Absent ({label})</h2>"))
    print(missing)
    display_prompt_selector(prompts_subset, missing)
    px.line(
        acts_df,
        title=f"Feature activations for prompts containing 'both ... and' with a gap of {gap_length} ({label})",
        width=1000,
    ).show()
    return acts_df, present, missing


for gap_length in range(1, 3):
    length_mask = (both_and_df.n_toks_in_proxy_context == gap_length).values
    # The same view for the unablated and head-ablated activations; the label keeps them apart.
    for run_label, run_acts in [("original", aligned_feature_actions), ("head ablated", aligned_feature_ablate)]:
        tmp, feature_present_indexes, feature_missing_indexes = _show_feature_presence(
            run_acts, length_mask, gap_length, run_label
        )

#### Do a deep dive

In [None]:
# Pick one prompt at random for a closer look, and build per-token character offset lookups.
# Set DEEP_DIVE_SEED to an int to make the choice reproducible; None leaves it unseeded.
DEEP_DIVE_SEED = None
sample_generator = torch.Generator().manual_seed(DEEP_DIVE_SEED) if DEEP_DIVE_SEED is not None else None
random_sample = torch.randint(prompts.shape[0], (1,), generator=sample_generator).item()
prompt_tokens = prompts[random_sample].unsqueeze(0)
text = model.to_string(prompt_tokens)[0]
offset_mapping = model.tokenizer.encode_plus(text, return_offsets_mapping=True)["offset_mapping"]
# Split the (start, end) character offsets of each re-tokenized token into two lists.
start_offset_mapping = [i for i,_ in offset_mapping]
end_offset_mapping = [j for _,j in offset_mapping]
both_and_highlight(text, prompt_proxy_regex=proxies["both_and"])

In [None]:
# make token df 
token_df = make_token_df(model, prompt_tokens, len_suffix=10, len_prefix=10)
# Tag rows with the prompt sampled for this deep dive. `prompt_index` is never set in the deep dive,
# so reading it would pick up stale state from some other cell.
# NOTE(review): assumes position `random_sample` in `prompts` is what other cells call prompt_index — confirm.
token_df["prompt_index"] = random_sample

with suppress_output():
    _, original_cache = model.run_with_cache(prompt_tokens, return_type="both", loss_per_token=True)

# NOTE(review): this uses `sparse_autoencoder`, while the DFA cells use `sparse_autoencoder_layer_10`
# — confirm they refer to the same SAE.
original_act = original_cache[sparse_autoencoder.cfg.hook_point]
sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)

# Work out where the proxy starts and ends: first as character offsets in `text`, then as token ids.
proxy_match = re.search(proxies["both_and"], text, flags=re.IGNORECASE)
if proxy_match is None:
    raise ValueError(
        f"Prompt {random_sample} does not match the 'both ... and' proxy; re-run the sampling cell."
    )
proxy_start_pos = proxy_match.start()
start_pos_tok_id = next(i for i, offset in enumerate(start_offset_mapping) if offset >= proxy_start_pos)
proxy_end_pos = proxy_match.end()
# If no token starts after the proxy's last character, fall back to the final sequence position
# (taken from the prompt rather than hard-coded as 127).
last_pos = prompt_tokens.shape[-1] - 1
end_pos_tok_id = next((i for i, offset in enumerate(start_offset_mapping) if offset > proxy_end_pos - 1), last_pos)


# get the start and end positions of the feature
feature_fired = (feature_acts[:,:,features_of_interest].squeeze() > 0)
# Ignore firings more than `n_offset` tokens before the proxy start. Clamp at 0: a negative slice
# end (start_pos_tok_id < n_offset) would instead clear everything except the last few positions.
feature_fired[:max(0, start_pos_tok_id - n_offset)] = False

fired_positions = feature_fired.nonzero()
if fired_positions.numel() == 0:
    raise ValueError(
        f"Feature never fires in prompt {random_sample} after the cutoff; re-run the sampling cell."
    )
start_feature_pos = fired_positions.min().item()
end_feature_pos = fired_positions.max().item()

result_metrics = {
    "start_pos_tok_id": start_pos_tok_id,
    "end_pos_tok_id": end_pos_tok_id,
    "start_end_proxy_gap": end_pos_tok_id - start_pos_tok_id,
    "start_feature_pos": start_feature_pos,
    "end_feature_pos": end_feature_pos,
    "start_end_feature_gap": end_feature_pos - start_feature_pos,
    "fired_early": start_feature_pos < start_pos_tok_id,
}

# get the decomp
decomp, labels = original_cache.get_full_resid_decomposition(layer =  10, expand_neurons=False, return_labels=True)
projection = decomp.squeeze(1)[:, start_pos_tok_id].cpu() @ feature_dir

# get the attention pattern for L4H1
layer = 4
head =1
attn_pattern = original_cache[utils.get_act_name("pattern",layer)].squeeze(0)[head,:].cpu()

# Label rows/columns with the tokens so the heatmap is readable.
tmp_df = pd.DataFrame(attn_pattern, columns = token_df.unique_token.values, index=token_df.unique_token.values)

# Zoom from two tokens before the first firing to two after the last. Clamp the start at 0 so a
# firing at position 0 or 1 does not yield a negative (wrap-around) slice start.
start_fig_idx = max(0, start_feature_pos - 2)
end_fig_idx = end_feature_pos+3
px.imshow(
    tmp_df.iloc[start_fig_idx:end_fig_idx, start_fig_idx:end_fig_idx],
    title="Attention pattern for L4H1",
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
    height=1000,
    width=1000,
).show()

What if we compare the attention paid to "both" with the feature's activation?

In [None]:
# Over the same 8-position window, plot the feature activation and how strongly each query
# position attends (in L4H1) to key position `start_feature_pos - 1` — the token just before the
# first firing, presumably "both" (confirm).
window = slice(start_fig_idx, start_fig_idx + 8)
both_key_pos = start_feature_pos - 1
px.line(feature_acts[0, window, features_of_interest].squeeze().detach().cpu(), title="Feature activations for 'both ... and'").show()
px.line(attn_pattern[window, both_key_pos], title="Attention to \"Both\" activations for 'both ... and'").show()

#### End Deep Dive