# Building run_with_saes

In [None]:
import sys

# Put the two ancestor directories on the import path (presumably so that the
# project-local `joseph` and `sae_training` packages resolve from this notebook).
sys.path.append("../..")
sys.path.append("..")

from importlib import reload
from tqdm import tqdm

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


# Reload the helper modules and then repeat the star imports below, so that the
# names bound in this notebook come from the freshly reloaded modules.
reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# turn torch grad tracking off
torch.set_grad_enabled(False)

import webbrowser

# `IPython.display` is the public import path; importing `display` from
# `IPython.core.display` has been deprecated since IPython 7.14.
from IPython.display import display, HTML

path_to_html = "../week_8_jan/gpt2_small_features_layer_5"


def render_feature_dashboard(feature_id):
    """Open the saved HTML dashboard for ``feature_id`` in a new browser tab.

    The dashboard is looked up at ``{path_to_html}/data_{feature_id:04}.html``.
    The feature id is always printed first. If no such file exists, a notice
    is printed and no tab is opened.
    """
    dashboard_file = f"{path_to_html}/data_{feature_id:04}.html"
    print(f"Feature {feature_id}")

    # Guard clause: nothing to open when the dashboard was never rendered.
    if not os.path.exists(dashboard_file):
        print("No HTML file found")
        return

    # The browser needs an absolute location, so resolve the relative path.
    webbrowser.open_new_tab("file://" + os.path.abspath(dashboard_file))
    

In [None]:
path_to_all_layer_saes = "../GPT2-small-SAEs/"


def _layer_of(filename):
    """Return the integer in the second dot-separated field of ``filename``."""
    return int(filename.split(".")[1])


# Every entry in the SAE directory; split below into checkpoints and sparsity logs.
_directory_entries = os.listdir(path_to_all_layer_saes)

# Anything without "log" in its name is treated as a model checkpoint.
model_files = sorted(
    (name for name in _directory_entries if "log" not in name),
    key=_layer_of,
)
display(model_files)

# Sparsity logs are the entries whose name contains "log_feature_sparsity".
log_sparsity_files = sorted(
    (name for name in _directory_entries if "log_feature_sparsity" in name),
    key=_layer_of,
)
log_sparsity_files

In [None]:
from sae_training.sparse_autoencoder import SparseAutoencoder

def _load_layer_sae(checkpoint_name):
    """Load one SAE checkpoint from the SAE directory and switch ghost grads off on its config."""
    autoencoder = SparseAutoencoder.load_from_pretrained(f"{path_to_all_layer_saes}/{checkpoint_name}")
    autoencoder.cfg.use_ghost_grads = False
    return autoencoder


# Maps each SAE's hook point (taken from its own config) to the loaded SAE.
gpt2_small_sparse_autoencoders = {}
for path in model_files:
    # The layer number is the second dot-separated field of the file name.
    layer = int(path.split(".")[1])
    print(f"Loading layer {layer}")
    sae = _load_layer_sae(path)
    gpt2_small_sparse_autoencoders[sae.cfg.hook_point] = sae

In [None]:

# Load GPT-2 small with LayerNorm folding enabled.
model = HookedTransformer.from_pretrained("gpt2-small", fold_ln=True)

# Enable the split-QKV-input and attention-result options on the model
# (see the TransformerLens documentation for what each setter exposes).
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



In [None]:
def generate_bracket_prompt(batch_size=256, n_random_tokens=10, vocab_size=50257, tokenizer=None):
    """Build a batch of random-token prompts that contain one "(" ... ")" pair.

    Each prompt is ``n_random_tokens`` random token ids with an open bracket
    inserted after the first two tokens and a close bracket inserted before
    the last two tokens. With the defaults this gives the open bracket at
    position 2 and the close bracket at position 9 of a 12-token prompt.

    Args:
        batch_size: Number of prompts to generate.
        n_random_tokens: Random tokens per prompt (must be at least 4 so the
            two leading and two trailing tokens do not overlap).
        vocab_size: Exclusive upper bound for the random token ids.
        tokenizer: Object exposing ``to_single_token(str) -> int``. Defaults
            to the notebook-level ``model``.

    Returns:
        An integer tensor of shape ``(batch_size, n_random_tokens + 2)``.

    Raises:
        ValueError: If ``n_random_tokens`` is smaller than 4.
    """
    if n_random_tokens < 4:
        raise ValueError("n_random_tokens must be at least 4")
    if tokenizer is None:
        # Fall back to the model defined at notebook level.
        tokenizer = model

    open_id = int(tokenizer.to_single_token("("))
    close_id = int(tokenizer.to_single_token(")"))

    # One row per bracket, one column per prompt; transposed at the end.
    left_bracket_token = torch.full((1, batch_size), open_id, dtype=torch.long)
    right_bracket_token = torch.full((1, batch_size), close_id, dtype=torch.long)
    random_tokens = torch.randint(0, vocab_size, (n_random_tokens, batch_size))

    # add the brackets after the first two tokens and before the last two tokens
    prompt = torch.concat(
        [
            random_tokens[:2],
            left_bracket_token,
            random_tokens[2:-2],
            right_bracket_token,
            random_tokens[-2:],
        ],
        dim=0,
    ).T
    return prompt

# Build the bracket prompts, then run the model once on them, keeping the
# logits, the per-token loss and the activation cache for the SAE cells.
tokens = generate_bracket_prompt()
(original_logits, original_loss), original_cache = model.run_with_cache(
    tokens,
    return_type="both",
    loss_per_token=True,
)

In [None]:
# Per-hook-point results of running each SAE on the cached model activations.
feature_acts_cache = {}
sae_out_cache = {}
mse_loss_cache = {}

for hook_point, sae in gpt2_small_sparse_autoencoders.items():
    # The SAE call returns six values; only the reconstruction, the feature
    # activations and the MSE loss are kept.
    sae_out, feature_acts, _, mse_loss, _, _ = sae(original_cache[hook_point])
    sae_out_cache[hook_point] = sae_out
    feature_acts_cache[hook_point] = feature_acts
    mse_loss_cache[hook_point] = mse_loss.detach().item()

# Stack the per-SAE feature activations along a new leading dimension,
# in the dict's insertion order.
feature_acts_stacked = torch.stack(list(feature_acts_cache.values()), dim=0)
feature_acts_stacked.shape  # [n_saes, batch_size, n_tokens, n_features]

In [None]:
import plotly.graph_objects as go 
from plotly.subplots import make_subplots
# create a grid of subplots with 3 rows and 4 columns, one panel per layer
fig = make_subplots(rows=3, cols=4, subplot_titles=[f"Layer {i}" for i in range(1,13)])

top_k = 5

# Mean over dim 1 of the stacked activations (the batch dimension, per the
# shape note where the stack is built). Computed once here rather than once
# per layer inside the loop.
mean_feature_acts = feature_acts_stacked.mean(1)

# add a line chart to each subplot
for layer in range(12):
    layer_mean_acts = mean_feature_acts[layer]
    # Score each feature by its summed mean activation over token positions 2..8
    # (open bracket up to, but not including, the close bracket at position 9).
    score_features_by = layer_mean_acts[2:9].sum(0)
    inds = torch.topk(score_features_by, top_k).indices
    tmp_df = pd.DataFrame(layer_mean_acts[:, inds].cpu(),
                          columns=[f"Feature_{i}" for i in inds],
                          index=[f"tok_{i}" for i in range(12)])
    # label the two bracket positions explicitly
    tmp_df = tmp_df.rename(index={"tok_2": "open_bracket", "tok_9": "close_bracket"})
    for feature in tmp_df.columns:
        fig.add_trace(go.Scatter(x=tmp_df.index,
                                 y=tmp_df[feature],
                                 mode="lines",
                                 name=feature,
                                 hovertemplate="Token: %{x}<br>Activation: %{y}" + f"<br>Layer: {layer}",
                                 ),
                      row=(layer//4)+1, col=(layer%4)+1)

# rotate x ticks (applies to every subplot's x axis, so one call is enough)
fig.update_xaxes(tickangle=45)

fig.update_layout(height=900, width=1600, title_text=f"Top {top_k} Features in each layer (averaged in between parentheses context)")
# remove legend
fig.update_layout(showlegend=False)
fig.show()