# Set Up

In [None]:
import sys 
sys.path.append("../..")
sys.path.append("..")

from importlib import reload
from tqdm import tqdm

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


# Re-execute the project modules so that edits made to them on disk are picked
# up without restarting the kernel.
reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

# Star-import again so the names in this notebook's namespace are rebound to
# the freshly reloaded definitions rather than the pre-reload ones.
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# turn torch grad tracking off
# NOTE(review): `torch` is not imported explicitly in this cell; presumably it
# arrives through one of the star imports above — confirm.
torch.set_grad_enabled(False)

# Load Model


In [None]:

# Load the language model under analysis. The commented-out strings are
# alternative checkpoints and the commented-out keywords are alternative
# loading options kept around for quick switching.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    # "pythia-2.8b",
    # "pythia-70m-deduped",
    # "tiny-stories-2L-33M",
    # "attn-only-2l",
    # center_unembed=True,
    # center_writing_weights=True,
    # fold_ln=True,
    # refactor_factored_attn_matrices=True,
    # Only this active `fold_ln=True` takes effect; the copy above is a comment.
    fold_ln=True,
)
# Switch on the model's split-q/k/v-input and attention-result options —
# presumably required by the analysis helpers used later; confirm.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)


# Load SAE

In [None]:


# Checkpoint of the sparse autoencoder to analyse. Judging by the file name it
# was trained on `blocks.10.hook_resid_pre`; the commented-out line below is an
# alternative checkpoint whose file name refers to `blocks.5.hook_resid_pre`.
path = "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/1100001280_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt"
# path = "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/final_sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152.pt"
sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)

# Show the configuration stored with the checkpoint.
print(sparse_autoencoder.cfg)


# sanity check
# Run the model on a short passage and ask for the loss; as the last expression
# of the cell, the returned value is what the notebook displays.
text = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."
model(text, return_type="loss")

In [None]:
from sae_training.utils import LMSparseAutoencoderSessionloader
# Build the full session from the same checkpoint `path`. This rebinds `model`
# and `sparse_autoencoder` and additionally provides `activation_store`, which
# later cells use to draw token batches.
# NOTE(review): `model.set_use_split_qkv_input(True)` and
# `model.set_use_attn_result(True)` were called on the model object that this
# assignment replaces — confirm whether they need to be re-applied here.
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

# Feature Dashboard Util

In [None]:
import webbrowser
from IPython.core.display import display, HTML

# Directory holding one pre-rendered dashboard per feature, named
# `data_<zero-padded feature id>.html`.
path_to_html = "../week_8_jan/gpt2_small_features"


def render_feature_dashboard(feature_id):
    """Open the pre-rendered HTML dashboard of one feature in a browser tab.

    Args:
        feature_id: Feature index, as an ``int`` or as a string of digits.
            Callers in this notebook pass ids parsed out of labels such as
            ``"feature_123"``, i.e. strings.

    Returns:
        None. Prints the feature id, and prints "No HTML file found" when the
        dashboard file does not exist.

    Raises:
        ValueError: If ``feature_id`` cannot be converted to an integer.
    """
    import os
    from pathlib import Path

    # The file name uses a numeric zero-padded format. Applying `:04` to a str
    # either raises or pads on the wrong side (depending on the Python
    # version), so normalise to int before formatting.
    feature_id = int(feature_id)
    path = f"{path_to_html}/data_{feature_id:04}.html"

    print(f"Feature {feature_id}")
    if os.path.exists(path):
        # Alternative kept for reference: render inline instead of in a tab.
        # with open(path, "r") as f:
        #     html = f.read()
        #     display(HTML(html))
        # as_uri() yields a well-formed, percent-encoded file:// URL on every
        # platform, unlike concatenating "file://" with a raw path.
        webbrowser.open_new_tab(Path(path).resolve().as_uri())
    else:
        print("No HTML file found")

    return

# for feature in [100,300,400]:
#     render_feature_dashboard(feature)

# Features by Token in an Example Analysis

In [None]:
import joseph
reload(joseph.analysis)
from joseph.analysis import *


title = "Anthrax"
prompt = "Anthrax is a serious infectious disease caused by gram-positive, rod-shaped bacteria known as Bacillus anthracis. It occurs naturally in soil and commonly affects domestic and wild animals around the world. People can get sick with anthrax if they come in contact with infected animals or contaminated animal products."
# Row of `token_df` / `feature_acts` whose strongest features get dashboards.
POS_INTEREST = 17


token_df, original_cache, cache_reconstructed_query, feature_acts = eval_prompt([prompt], model, sparse_autoencoder, head_idx_override=7)
filter_cols = ["str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff", "mse_loss", "num_active_features", "explained_variance", "kl_divergence",
            "top_k_features"]
# display(token_df[filter_cols].style.background_gradient(
#     subset=["loss_diff", "mse_loss","explained_variance", "num_active_features", "kl_divergence"],
#     cmap="coolwarm"))



UNIQUE_TOKEN_INTEREST = token_df["unique_token"][POS_INTEREST]
feature_acts_of_interest = feature_acts[POS_INTEREST]
# plot_line_with_top_10_labels(feature_acts_of_interest, "", 25)
# vals, inds = torch.topk(feature_acts_of_interest,39)

# Indices of features that are active (> 0) on at least one row after the
# first. flatten() instead of squeeze() keeps this a 1-D tensor even when
# exactly one feature qualifies; squeeze() would return a 0-d tensor there
# and break the column indexing below.
top_k_feature_inds = (feature_acts[1:] > 0).sum(dim=0).nonzero().flatten()

# One row per active feature, one column per token of the prompt.
features_acts_by_token_df = pd.DataFrame(
    feature_acts[:, top_k_feature_inds].detach().cpu().T,
    index=[f"feature_{i}" for i in top_k_feature_inds.tolist()],
    columns=token_df["unique_token"])

# features_acts_by_token_df.sort_values(by=",/12", ascending=False).head(10).style.background_gradient(
#     cmap="coolwarm", axis=0)

# px.imshow(features_acts_by_token_df.sort_values(by=",/12", ascending=False).head(10).T.corr(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

# Rank features by their activation on the token of interest (sorted once and
# reused for both the line plot and the dashboard selection).
sorted_by_interest = features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False)
tmp = sorted_by_interest.T

# Labels look like "feature_<id>". Pass the id as an int so the dashboard's
# zero-padded file-name format is applied to a number rather than a string.
dashboard_features = [int(label.split("_")[1]) for label in sorted_by_interest.index[:10]]
for feature in dashboard_features:
    render_feature_dashboard(feature)

px.line(token_df,
        x="unique_token",
        y="loss",
        hover_data=["pos", "label", "loss_diff", "mse_loss", "num_active_features", "explained_variance"],
        height=300).show()

px.line(tmp,
        title=f"{title}: Features Activation by Token in Prompt",
        color_discrete_sequence=px.colors.qualitative.Plotly,
        height=1000).show()

# NOTE(review): head(100) takes the first 100 rows in index order, not the
# 100 strongest features, despite the chart title — confirm intent.
tmp = features_acts_by_token_df.head(100).T
px.imshow(tmp,
          title=f"{title}: Top k features by activation",
          color_continuous_midpoint=0,
          color_continuous_scale="RdBu",
          height=800).show()

# Searching over Features by Token or Prediction

In [None]:
# Draw 128 * 6 batches from the activation store, shuffle the rows, and keep
# the first 4096 * 6 of them as `tokens`.
n_batches = 128 * 6
pbar = tqdm(range(n_batches))
all_tokens_list = [activation_store.get_batch_tokens() for _ in pbar]
all_tokens = torch.cat(all_tokens_list, dim=0)
print(all_tokens.shape)

# Random permutation of the row indices, then reorder the rows with it.
shuffled_rows = torch.randperm(all_tokens.shape[0])
all_tokens = all_tokens[shuffled_rows]
tokens = all_tokens[: 4096 * 6]

# Drop the reference to the full set and release cached MPS memory.
del all_tokens
torch.mps.empty_cache()

In [None]:
# Fetch one batch of tokens from the activation store.
batch_tokens  = activation_store.get_batch_tokens()

In [None]:
# Hyperparameters
target_word = " to"
token = model.to_tokens(target_word, prepend_bos=False)
# Per the TransformerLens docs `to_tokens` returns a batched tensor, so
# `len(token)` is the batch size (1 for a single string) and can never detect a
# word that splits into several tokens. Count the token ids instead.
assert token.numel() == 1, "Token must be a single token"
target_n_topical_prompts = 500

# Keep drawing batches until at least `target_n_topical_prompts` prompts that
# contain the target token have been collected.
n_topical_prompts = 0
all_tokens_list = []
pbar = tqdm(total=target_n_topical_prompts)
while n_topical_prompts < target_n_topical_prompts:
    batch_tokens = activation_store.get_batch_tokens()

    # filter batch tokens for containing the target word's token:
    # a row is kept if any position along the last dimension equals the token
    mask = (batch_tokens == token).any(dim=-1)
    batch_tokens = batch_tokens[mask]
    all_tokens_list.append(batch_tokens)
    n_topical_prompts += batch_tokens.shape[0]

    pbar.update(batch_tokens.shape[0])
pbar.close()

all_tokens = torch.cat(all_tokens_list, dim=0)
# Only touch the MPS allocator when that backend is actually available, so the
# cell also runs on machines without it.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# save the tokens to disk
torch.save(all_tokens, f"{target_word}_prompts.pt")


OK, so the idea here is that we only track a fraction of the tokens; let's go with " anthrax". (Note: the cell above currently sets `target_word = " to"`.)

In [None]:
# sample from prompts containing proxies
import re 
# import HTML
from IPython.display import HTML, display


def find_word_and_highlight(prompt, prompt_proxy_regex):
    """Display ``prompt`` as HTML with the first regex match coloured red.

    Args:
        prompt: Text to show.
        prompt_proxy_regex: Pattern searched case-insensitively in ``prompt``;
            only the first match is highlighted.

    Returns:
        None. If the pattern does not match, the prompt is displayed without
        any highlighting instead of raising.
    """
    import html

    # Search once and reuse the match for both ends of the span.
    match = re.search(prompt_proxy_regex, prompt, flags=re.IGNORECASE)
    if match is None:
        display(HTML(html.escape(prompt)))
        return

    start_pos, end_pos = match.span()
    # style with red text
    style_tag = "<span style='color:red'>"
    # Escape the three text pieces separately so that characters such as "<"
    # or "&" in the prompt are shown literally rather than parsed as markup,
    # while the highlight span itself stays real HTML.
    highlighted = (
        html.escape(prompt[:start_pos])
        + style_tag
        + html.escape(prompt[start_pos:end_pos])
        + "</span>"
        + html.escape(prompt[end_pos:])
    )
    display(HTML(highlighted))

# Pick one row of `all_tokens` at random and display it with the target word
# highlighted. Despite its name, `random_token` is a row index into
# `all_tokens`, not a token id. The pattern only matches "to" with a space on
# both sides.
random_token= torch.randint(0, all_tokens.shape[0], (1,)).item()
find_word_and_highlight(model.to_string(all_tokens[random_token]), r" \bto\b ")

In [None]:
# Run every collected prompt through the model and the sparse autoencoder,
# keeping a per-token DataFrame and the SAE feature activations per prompt.
token_dfs = []
event_dfs = []  # initialised but not used in this cell
feature_acts_all = []

pbar = tqdm(range(all_tokens.shape[0]))

for prompt_index in pbar:
    # Prompts are processed one at a time; unsqueeze(0) adds a leading dim of 1.
    prompt_tokens = all_tokens[prompt_index].unsqueeze(0)
    
    token_df = make_token_df(model, prompt_tokens, len_suffix=5, len_prefix=10)
    token_df["prompt_index"] = prompt_index
    
    (original_logits, original_loss), original_cache = model.run_with_cache(prompt_tokens, return_type="both", loss_per_token=True)
    # One NaN is appended so the list matches the number of rows in token_df —
    # presumably the per-token loss has one entry fewer than there are tokens
    # (no target for the last position); TODO confirm.
    token_df['loss'] = original_loss.flatten().tolist() + [np.nan]
    
    # Activation at the hook point the SAE was configured for, fed to the SAE.
    original_act = original_cache[sparse_autoencoder.cfg.hook_point]
    sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)

    # Drop the leading dim (index 0) added by unsqueeze(0) above.
    feature_acts_of_interest = feature_acts[0, :, :]
    token_dfs.append(token_df.reset_index(drop=True))
    feature_acts_all.append(feature_acts_of_interest)
    
# torch.stack requires every prompt to have produced the same shape.
feature_acts_all = torch.stack(feature_acts_all, dim=0)
token_df = pd.concat(token_dfs).reset_index(drop=True)

# Despite the "bacteria" in the names (presumably left over from an earlier
# target word), this selects rows whose `str_tokens` contain `target_word` as a
# plain substring. flatten(0,1) merges the first two dims so the rows can be
# masked with the concatenated token_df — assumes token_df has exactly one row
# per flattened row; TODO confirm.
bacteria_token_mask = token_df.str_tokens.str.contains(target_word, regex=False)
feature_acts_bacteria = feature_acts_all.flatten(0,1)[bacteria_token_mask]
feature_acts_bacteria.shape

- OK, so now we want to filter the feature activations and the `token_df` positions down to those that actually include the word "bacteria".

In [None]:
# Two rankings of the features over the selected token rows, each drawn as a
# bar chart of its top 100: mean activation first, then the fraction of rows
# on which the feature is active (> 0). After the loop `vals`, `inds` and `tmp`
# hold the values of the second ranking.
ranking_specs = [
    (
        feature_acts_bacteria.mean(dim=0),
        "Mean activation of top 100 features for bacteria",
    ),
    (
        (feature_acts_bacteria > 0).float().mean(dim=0),
        "Mean Binary Activation of top 100 features for bacteria",
    ),
]

for feature_scores, chart_title in ranking_specs:
    vals, inds = torch.topk(feature_scores, 100)

    score_values = vals.detach().cpu().numpy()
    feature_ids = inds.detach().cpu().numpy()
    tmp = pd.DataFrame(score_values, index=feature_ids, columns=["mean_activation"])
    tmp = tmp.sort_values("mean_activation", ascending=False)
    # Relabel the integer feature ids as "feature_<id>".
    tmp.index = tmp.index.map("feature_{}".format)

    bar_chart = px.bar(
        tmp,
        x=tmp.index,
        y="mean_activation",
        title=chart_title,
        text_auto=True,
        height=500,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    bar_chart.show()

In [None]:
# Open the dashboards for the entries at positions 10-39 of `tmp`'s index.
for feature in tmp.index[10:40]:
    # Labels look like "feature_<id>". Pass the id as an int so the dashboard's
    # zero-padded file-name format (`{feature_id:04}`) is applied to a number;
    # applied to a string it raises or pads on the wrong side.
    render_feature_dashboard(int(feature.split("_")[1]))

In [None]:
# Shape of the elementwise comparison of the `unique_token` column with
# `target_word` (one entry per row of token_df).
# NOTE(review): the precision/recall cell compares `str_tokens`, not
# `unique_token` — confirm which column is intended here.
(token_df["unique_token"] == target_word).shape

In [None]:
# Shape of the feature activations after merging the first two dimensions.
feature_acts_all.flatten(0,1).shape

In [None]:
# Precision vs recall of every feature as a detector of the `target_word`
# token (currently " to"):
#   precision = fires on target tokens / fires on any token
#   recall    = fires on target tokens / number of target tokens

all_feature_acts = feature_acts_all.flatten(0, 1)

# Explicit boolean tensor on the activations' device, instead of indexing a
# tensor with a pandas Series.
is_target_token = torch.as_tensor(
    (token_df["str_tokens"] == target_word).to_numpy(),
    device=all_feature_acts.device,
)

# Fires must be counted over ALL tokens. Counting them only on the rows that
# were already filtered for the target word (`feature_acts_bacteria`) makes
# the denominator almost equal to the numerator, so precision comes out near
# 1 for every feature. Summing the boolean mask first avoids materialising a
# float copy of the whole activation tensor.
total_fires = (all_feature_acts > 0).sum(dim=0).float()
total_fires_on_target = (all_feature_acts[is_target_token] > 0).sum(dim=0).float()

total_target_word_appearances = (token_df["str_tokens"] == target_word).sum()

# Features that never fire give 0 / 0 = NaN here; they are removed by the
# `total_fires > 100` filter below.
precision = total_fires_on_target / total_fires
recall = total_fires_on_target / int(total_target_word_appearances)

precision_recall_df = pd.DataFrame(
    torch.stack([
        total_fires,
        precision,
        recall], dim=1).detach().cpu().numpy(),
    index=[f"feature_{i}" for i in range(precision.shape[0])],
    columns=["total_fires", "precision", "recall"]
)
# Keep only features with more than 100 fires in total.
precision_recall_df = precision_recall_df[precision_recall_df["total_fires"] > 100]
# precision_recall_df.head(10)
px.scatter(
    precision_recall_df,
    x="precision",
    y="recall",
    title="Precision vs Recall for features",
    # text=precision_recall_df.index,
    height=500,
    color_discrete_sequence=px.colors.qualitative.Plotly,
).show()

In [None]:
# Number of rows of token_df whose `str_tokens` equals `target_word`.
total_target_word_appearances