In [None]:
import sys 
sys.path.append("../..")
sys.path.append("..")

from importlib import reload
from tqdm import tqdm

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# turn torch grad tracking off
torch.set_grad_enabled(False)

import webbrowser
from IPython.core.display import display, HTML

path_to_html = "../week_8_jan/gpt2_small_features_layer_5"
def render_feature_dashboard(feature_id):
    
    path = f"{path_to_html}/data_{feature_id:04}.html"
    
    print(f"Feature {feature_id}")
    if os.path.exists(path):
        # with open(path, "r") as f:
        #     html = f.read()
        #     display(HTML(html))
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")
    

# Load Model

In [None]:

model = HookedTransformer.from_pretrained(
    "gpt2-small",
    # "pythia-2.8b",
    # "pythia-70m-deduped",
    # "tiny-stories-2L-33M",
    # "attn-only-2l",
    # center_unembed=True,
    # center_writing_weights=True,
    # fold_ln=True,
    # refactor_factored_attn_matrices=True,
    fold_ln=True,
)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)


# Load SAE

In [None]:
from sae_training.utils import LMSparseAutoencoderSessionloader


path = "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/1100001280_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt"
# sparse_autoencoder_layer_10 = SparseAutoencoder.load_from_pretrained(path)
model, sparse_autoencoder_layer_10, activation_store_layer_10 = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

path = "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/final_sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152.pt"
# sparse_autoencoder_layer_5 = SparseAutoencoder.load_from_pretrained(path)
_, sparse_autoencoder_layer_5, activation_store_layer_5 = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

print(sparse_autoencoder_layer_10.cfg)
print(sparse_autoencoder_layer_5.cfg)


# sanity check
text = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."
model(text, return_type="loss")

# Explore Sparsity

In [None]:
from tqdm.auto import tqdm
def estimate_feature_sparsity_using_n_tokens_per_prompt(
    sparse_autoencoder, activation_store, n_batches,
    n_tokens_per_prompt=4):
    
    total_activations = torch.zeros(sparse_autoencoder.cfg.d_sae).to(sparse_autoencoder.cfg.device)
    
    pbar = tqdm(range(n_batches))
    for _ in pbar:
        batch_tokens = activation_store.get_batch_tokens()
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=False)
        original_act = cache[sparse_autoencoder.cfg.hook_point]
        _, feature_acts, _, _, _ = sparse_autoencoder(
            original_act
        )
        # for each batch item, pick 4 random tokens and keep only those
        # batch_size x n_tokens x d_sae
        random_tok_indices = torch.randint(0, feature_acts.shape[1], (feature_acts.shape[0], n_tokens_per_prompt))
        feature_acts = feature_acts[torch.arange(feature_acts.shape[0]).unsqueeze(-1), random_tok_indices]
        total_activations += feature_acts.flatten(0,1).sum(0)
    
    total_tokens = (n_batches * feature_acts.shape[0] * n_tokens_per_prompt)
    print("Total tokens:", total_tokens)
    
    return total_activations / total_tokens

n_tokens_per_prompt = 128
n_batches = 1000
# feature_sparsity_10_unstratified  = estimate_feature_sparsity_using_n_tokens_per_prompt(sparse_autoencoder_layer_10, activation_store_layer_10, n_batches=n_batches, n_tokens_per_prompt=n_tokens_per_prompt).detach().cpu()
# log_feature_sparsity_10_unstratified = torch.log10(feature_sparsity_10_unstratified  + 1e-10)
# torch.save(log_feature_sparsity_10_unstratified, f"../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/feature_sparsity_{n_batches}_{n_tokens_per_prompt}.pt")
log_feature_sparsity_10_unstratified = torch.load(f"../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/feature_sparsity_{n_batches}_{n_tokens_per_prompt}.pt")

# feature_sparsity_5_unstratified = estimate_feature_sparsity_using_n_tokens_per_prompt(sparse_autoencoder_layer_5, activation_store_layer_5, n_batches=100, n_tokens_per_prompt=128).detach().cpu()
# log_feature_sparsity_5_unstratified = torch.log10(feature_sparsity_5_unstratified  + 1e-10)
# torch.save(log_feature_sparsity_5_unstratified, f"../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/feature_sparsity_{n_batches}_{n_tokens_per_prompt}.pt")
log_feature_sparsity_10_unstratified = torch.load(f"../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/feature_sparsity_{n_batches}_{n_tokens_per_prompt}.pt")


In [None]:
log_feature_sparsity_10_stratified = torch.load(
    "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/log_feature_sparsity_5000_4.pt"
)
# px.histogram(
#     log_feature_sparsity_10_stratified[log_feature_sparsity_10_stratified > -9],
#     nbins=1000,
#     width = 1000,
#     log_x=False,
#     title="Feature sparsity (log10) (5000 batches, 4 tokens per prompt)",
# ).show()

log_feature_sparsity_5_stratified = torch.load(
    "../week_8_jan/artifacts/sparse_autoencoder_gpt2-small_blocks.5.hook_resid_pre_49152:v9/log_feature_sparsity_5000_4.pt"
)
# px.histogram(
#     log_feature_sparsity_5_stratified[log_feature_sparsity_5_stratified > -9],
#     nbins=1000,
#     width=1000,
#     log_x=False,
#     title="Feature sparsity (log10) (5000 batches, 4 tokens per prompt)",
# ).show()
# px.histogram(log_feature_sparsity, nbins=1000, log_x=False, title="Feature sparsity (log10) (5000 batches, 4 tokens per prompt)").show()

For layer 10, let's compare. 

In [None]:
px.scatter(
    x = log_feature_sparsity_10_stratified,
    y =  log_feature_sparsity_10_unstratified,
    opacity=0.4,
    marginal_x="histogram",
    marginal_y="histogram",
    title="Feature sparsity (log10) Stratified vs Unstratified",
    color = (log_feature_sparsity_10_stratified - log_feature_sparsity_10_unstratified).numpy().tolist(),
    labels = {
        "x": "Stratified",
        "y": "Unstratified",
        # "color": "Difference"
    },
    width = 1000,
    height = 500,
    color_continuous_midpoint=0,
    hover_data= [ list(range(len(log_feature_sparsity_10_stratified))) ],
).show()

In [None]:
def render_feature_dashboard(feature_id):
    
    path_to_html = "../week_8_jan/gpt2_small_features"
    path = f"{path_to_html}/data_{feature_id:04}.html"
    
    print(f"Feature {feature_id}")
    if os.path.exists(path):
        # with open(path, "r") as f:
        #     html = f.read()
        #     display(HTML(html))
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")


# dense_features = ((log_feature_sparsity_10_stratified>-3) & (log_feature_sparsity_10_unstratified<-6)).nonzero().squeeze()
# dense_features = ((log_feature_sparsity_10_stratified<-5) & (log_feature_sparsity_10_unstratified<-6)).nonzero().squeeze()
dense_features = ((log_feature_sparsity_10_stratified>-3) & (log_feature_sparsity_10_unstratified>-3)).nonzero().squeeze()
# assert 37470 in dense_features
print(dense_features)
dense_features = dense_features[torch.randperm(len(dense_features))[:6]]
for feature in dense_features:
    render_feature_dashboard(feature.item())

# diff = log_feature_sparsity_10_stratified - log_feature_sparsity_10_unstratified
# dense_features = ((diff>2) & (log_feature_sparsity_10_stratified > -4)).nonzero().squeeze()
# print(len(dense_features))
# dense_features = dense_features[torch.randperm(len(dense_features))[:6]]
# for feature in dense_features:
#     render_feature_dashboard(feature.item())
    

# A Simple Density Estimation Metric

In [49]:
# n_batches_to_sample_from = 1028
n_batches_to_sample_from = 128


all_tokens_list = []
pbar = tqdm(range(n_batches_to_sample_from))
for _ in pbar:
    
    batch_tokens = activation_store_layer_10.get_batch_tokens()
    batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][:batch_tokens.shape[0]]
    all_tokens_list.append(batch_tokens)
    
all_tokens = torch.cat(all_tokens_list, dim=0)
all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]

  0%|          | 0/128 [00:00<?, ?it/s]

In [57]:


all_feature_acts = torch.zeros(*all_tokens.shape, sparse_autoencoder_layer_10.cfg.d_sae)
all_feature_acts.shape

torch.Size([4096, 128, 49152])

In [59]:
feature_acts.shape

torch.Size([32, 128, 49152])

In [61]:
import plotly.graph_objects as go

sparse_autoencoder = sparse_autoencoder_layer_10   
sparse_autoencoder.cfg.use_ghost_grads = False

chunk_length = 16
batches_of_tokens = all_tokens.reshape(-1, 32, 128)

all_chunk_appearances = torch.zeros(sparse_autoencoder.cfg.d_sae)
all_total_appearances = torch.zeros(sparse_autoencoder.cfg.d_sae)


all_feature_acts = torch.zeros(all_tokens.shape.numel(), sparse_autoencoder.cfg.d_sae)

pbar = tqdm(range(batches_of_tokens.shape[0]))
for i in pbar:
    batch_tokens = batches_of_tokens[i]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=False)
    original_act = cache[sparse_autoencoder.cfg.hook_point]
    _, feature_acts, _, _, _, _ = sparse_autoencoder(
        original_act
    )
    # all_feature_acts.append(feature_acts)
    feature_acts = feature_acts.detach().cpu()
    feature_acts_chunked = feature_acts.reshape(-1, chunk_length, 49152)
    n_chunk_appearances = ((feature_acts_chunked > 0).sum(1) > 0).float().sum(0)
    total_feature_appearences = (feature_acts_chunked > 0).float().flatten(0,1).sum(0)
    
    all_chunk_appearances += n_chunk_appearances
    all_total_appearances += total_feature_appearences
    all_feature_acts[i*32*128:(i+1)*32*128] = feature_acts.flatten(0,1)

all_chunk_appearances = all_chunk_appearances.detach().cpu()
all_total_appearances = all_total_appearances.detach().cpu()
mean_chunkwise_appearances = total_feature_appearences / n_chunk_appearances

  0%|          | 0/128 [00:00<?, ?it/s]

In [62]:
all_feature_acts.shape

torch.Size([524288, 49152])

In [63]:
# get the 90th percentile activation for each feature
percentile_90 = torch.quantile(all_feature_acts, 0.9, dim=0)
percentile_90.shape

: 

In [None]:
px.histogram(percentile_90, nbins=1000, log_x=True, title="90th percentile activation").show()

In [45]:
px.histogram(sparse_autoencoder.b_enc.cpu(), width = 1000)

In [None]:

distances_list = []
for i in range(1,chunk_length+1):
    distances_from_integers = (mean_chunkwise_appearances - i).abs()
    distances_list.append(distances_from_integers)


distances = torch.stack(distances_list, dim=1).detach().cpu()
print(distances.shape)
px.bar((distances < 0.05).sum(0), width=1000, height=500, title="Number of features within 0.05 of integer", log_y=True).show()
vals, inds = torch.topk(distances.cpu(), 10, dim=0, largest=False)
print(inds.shape)
display(pd.DataFrame(inds))

for i in inds[2,:10]:
    print(i.item())
    # render_feature_dashboard(i.item())
11

In [None]:
# reoplace nans in mean chunkwise appearences with -1
mean_chunkwise_appearances[torch.isnan(mean_chunkwise_appearances)] = -1

In [None]:
vals, inds = torch.topk(mean_chunkwise_appearances, 100, largest=True)
for feature in inds:
    render_feature_dashboard(feature.item())

In [None]:
for i in inds[2,:10]:
    print(i.item())
    # render_feature_dashboard(i.item())


In [34]:


fig = px.scatter(
    x = log_feature_sparsity_10_stratified,
    y = mean_chunkwise_appearances.detach().cpu(),# / 16,
    color = all_chunk_appearances,
    width = 1500,
    height = 1000,
    opacity=0.4,
    marginal_x="histogram",
    marginal_y="histogram",
    hover_name= torch.arange(len(log_feature_sparsity_10_stratified)),
    labels={
        "x": "Feature sparsity (log10)",
        "y": "Average Feature Fires per Chunk",
    },
)

fig.show()


In [None]:
print((log_feature_sparsity_10_stratified > -3.5).sum().item())
print(((log_feature_sparsity_10_stratified > -3.5) & (mean_chunkwise_appearances > 2)).sum().item())
print(((log_feature_sparsity_10_stratified > -3.5) & (mean_chunkwise_appearances > 1.05)).sum().item())

In [39]:
for feature in ((mean_chunkwise_appearances > 0.99) & (mean_chunkwise_appearances < 1.01)).nonzero().squeeze()[5:10]:
    render_feature_dashboard(feature.item())

Feature 22
Feature 29
Feature 30
Feature 32
Feature 47


In [None]:
from tqdm.auto import tqdm
def feature_density_esitmation(sparse_autoencoder, all_tokens, n_batches):
    
    total_activations = torch.zeros(sparse_autoencoder.cfg.d_sae).to(sparse_autoencoder.cfg.device)
    
    batches_of_tokens = all_tokens.reshape(-1, 32, 128)
    pbar = tqdm(batches_of_tokens)
    for batch_tokens in pbar:
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=False)
        original_act = cache[sparse_autoencoder.cfg.hook_point]
        _, feature_acts, _, _, _ = sparse_autoencoder(
            original_act
        )

    
    total_tokens = (n_batches * feature_acts.shape[0] * n_tokens_per_prompt)
    print("Total tokens:", total_tokens)
    
    return total_activations / total_tokens

# Analyse the SVD of W_U Projections

In [None]:


path = "../artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v19/final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576.pt"
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

In [None]:
log_feature_sparsity = torch.load( "../artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576_log_feature_sparsity:v9/final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576_log_feature_sparsity.pt")

In [None]:

U,S,V = torch.svd(model.W_U.cpu())
mean_proj_u = (activation_store_layer_10.storage_buffer[:10000].cpu() @ U).mean(0)
median_proj_u = (activation_store_layer_10.storage_buffer[:10000].cpu() @ U).median(0).values
std_proj_u = (activation_store_layer_10.storage_buffer[:10000].cpu() @ U).std(0)
px.line(y = mean_proj_u, labels = {"y": "Mean projection onto U",
                                      "x": "SVD Index"},
           title = "Projection of activations onto U",
           width=1000, height=500).show()
px.line(y = median_proj_u, labels = {"y": "Median projection onto U",
                                      "x": "SVD Index"},
           title = "Projection of activations onto U",
           width=1000, height=500).show()
px.line(y = std_proj_u, labels = {"y": "Std projection onto U",
                                      "x": "SVD Index"},
           title="Projection of activations onto U",
              width=1000, height=500).show()

In [None]:
decoder_proj_w_u_svd = (sparse_autoencoder.W_dec.cpu() @ U)
encoder_proj_w_u_svd = (sparse_autoencoder.W_enc.cpu().T @ U)

tmp = pd.DataFrame(decoder_proj_w_u_svd[:,:7], columns=[f"Feature {i}" for i in range(7)])
px.histogram(tmp, 
             # overlay
            #  facet_col="variable",
            barmode="overlay",
             nbins=500,
             width=1500, 
             log_x=False, 
            #  log_y=True,
             title="Decoder Projections onto first 7 SVD directions of W_U").show()

tmp = pd.DataFrame(encoder_proj_w_u_svd[:,:7], columns=[f"Feature {i}" for i in range(7)])
px.histogram(tmp, 
             # overlay
            #  facet_col="variable",
            barmode="overlay",
             nbins=500,
             width=1500, 
             log_x=False, 
            #  log_y=True,
             title="Encoder Projections onto first 7 SVD directions of W_U").show()

tmp = pd.DataFrame(decoder_proj_w_u_svd[:,755:], columns=[f"Feature {i}" for i in range(755, 768)])
px.histogram(tmp, 
             # overlay
                # facet_col="variable", 
                barmode="overlay",
             nbins=500,
             width=1500, 
            #  log_y=True,
             title="Decoder Projections onto last 7 SVD directions of W_U").show()

tmp = pd.DataFrame(encoder_proj_w_u_svd[:,755:], columns=[f"Feature {i}" for i in range(755, 768)])
px.histogram(tmp, 
             # overlay
                # facet_col="variable",
                barmode="overlay",
             nbins=500,
             width=1500, 
            #  log_y=True,
             title="Encoder Projections onto last 7 SVD directions of W_U").show()

In [None]:
tmp = pd.DataFrame(encoder_proj_w_u_svd[:,755:], columns=[f"Feature {i}" for i in range(755, 768)])
px.scatter(tmp, x = "Feature 759", y = "Feature 762", 
           hover_name = tmp.index,
           opacity=0.4,
           width = 1000)

In [None]:
svd_dir_of_interest = 762
fig = px.scatter(
    x = log_feature_sparsity,
    y = decoder_proj_w_u_svd[:,svd_dir_of_interest],
    # color = log_feature_sparsity,
    hover_name= torch.arange(len(log_feature_sparsity)),
    # opacity=0.2,
    marginal_x="histogram",
    marginal_y="histogram",
    labels = {
        "x": "Feature sparsity (log10)",
        "y": f"Decoder projection onto U[:,{svd_dir_of_interest}]",
    },
    # color_continuous_scale="RdBu",
    height =800,
    width=1500,
)
fig.update_traces(marker=dict(size=5), selector=dict(mode='markers'))
fig.show()

fig = px.scatter(
    x = log_feature_sparsity,
    y = encoder_proj_w_u_svd[:,svd_dir_of_interest],
    # color = log_feature_sparsity,
    hover_name= torch.arange(len(log_feature_sparsity)),
    # opacity=0.4,
    labels = {
        "x": "Feature sparsity (log10)",
        "y": f"Encoder projection onto U[:,{svd_dir_of_interest}]",
    },
    marginal_x="histogram",
    marginal_y="histogram",
    height =800,
    width=1500,
)
fig.update_traces(marker=dict(size=5), selector=dict(mode='markers'))
fig.show()



fig = px.scatter(
    x = decoder_proj_w_u_svd[:,svd_dir_of_interest],
    y = encoder_proj_w_u_svd[:,svd_dir_of_interest],
    # color = log_feature_sparsity,
    hover_name= torch.arange(len(log_feature_sparsity)),
    # opacity=0.4,
    labels = {
        "x": f"Decoder projection onto U[:,{svd_dir_of_interest}]",
        "y": f"Encoder projection onto U[:,{svd_dir_of_interest}]",
    },
    marginal_x="histogram",
    marginal_y="histogram",
    height =800,
    width=1500,
)
fig.update_traces(marker=dict(size=5), selector=dict(mode='markers'))
fig.show()

In [None]:
import joseph
reload(joseph.analysis)
from joseph.analysis import *

# prompt3 = "This investigation is not only one that is continuing and worldwide, but also one that we expect to continue for quite some time." # Not only ... but
# prompt3 = "The market is evolving rapidly. Either we must adjust our strategy to meet the new market demands, or we risk falling behind our competitors significantly." # either or one (dud?)
# prompt3 = "Culinary trends are constantly changing. Either we experiment with new flavors and techniques in our recipes, or we risk losing the interest of our adventurous diners." #maybe a dud as well
# prompt3 = "I thought it was a great book. Both the intricate plot twists and the strong character development make this novel exceptionally engaging." # both .... and
# prompt3 = "The team, despite facing numerous challenges and unexpected setbacks, remains optimistic about the upcoming project." # Noun verb agreement
# prompt3 = "The book on the shelf in the corner needs a new cover." # Noun verb agreement

# title = "which way to the beach"
# prompt = "She asked 'Which way to the beach?', to which I replied,  'It's over there. You can't miss it.'. She thanked me and walked away."
# POS_INTEREST = 9

# title = "lots of questions"
# prompt = "The text read \"In the realm of deep learning, how do we best quantify the interpretability of neural networks? While considering this, it's important to remember the balance between complexity and clarity in model design. What are the most effective methods for visualizing high-dimensional data? This leads to another crucial aspect: the role of data quality. Can we establish a standard for data that optimally trains these models? Amidst these inquiries, the evolution of AI safety protocols remains a pivotal concern. How are current safety measures adapting to the rapidly advancing AI landscape? Each question marks a stepping stone towards a deeper understanding and more effective utilization of AI technologies."
# POS_INTEREST = 10
sparse_autoencoder.cfg.use_ghost_grads = False
title = "Tiny Stories Dragon"
prompt = """Harry Potter"""
POS_INTEREST = 1
# prompt = """2 + 2 = 3"""
# POS_INTEREST = 4

# title = "both_and"
# prompt = "My parents went to both Melbourne, Australia and Auckland, New Zealand on their honeymoon."
# POS_INTEREST = 8


token_df, original_cache, cache_reconstructed_query, feature_acts = eval_prompt([prompt], model, sparse_autoencoder, head_idx_override=5)
filter_cols = ["str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff", "mse_loss", "num_active_features", "explained_variance", "kl_divergence",
            "top_k_features"]
display(token_df[filter_cols].style.background_gradient(
    subset=["loss_diff", "mse_loss","explained_variance", "num_active_features", "kl_divergence"],
    cmap="coolwarm"))



UNIQUE_TOKEN_INTEREST = token_df["unique_token"][POS_INTEREST]
feature_acts_of_interest = feature_acts[POS_INTEREST]
# plot_line_with_top_10_labels(feature_acts_of_interest, "", 25)
# vals, inds = torch.topk(feature_acts_of_interest,39)

top_k_feature_inds = (feature_acts[1:] > 0).sum(dim=0).nonzero().squeeze()

features_acts_by_token_df = pd.DataFrame(
    feature_acts[:,top_k_feature_inds[:]].detach().cpu().T,
    index = [f"feature_{i}" for i in top_k_feature_inds.flatten().tolist()],
    columns = token_df["unique_token"])

# features_acts_by_token_df.sort_values(by=",/12", ascending=False).head(10).style.background_gradient(
#     cmap="coolwarm", axis=0)

# px.imshow(features_acts_by_token_df.sort_values(by=",/12", ascending=False).head(10).T.corr(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

tmp = features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).T
# dashboard_features = features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).index[:10].to_series().apply(lambda x: x.split("_")[1]).tolist()
# for feature in dashboard_features:
#     render_feature_dashboard(feature)

px.line(tmp + 1e-3, 
        log_y = True,
        title=f"{title}: Features Activation by Token in Prompt", 
        color_discrete_sequence=px.colors.qualitative.Plotly,
        height=1000).show()

tmp = features_acts_by_token_df.head(100).T
px.imshow(tmp, 
            title=f"{title}: Top k features by activation", 
            color_continuous_midpoint=0, 
            color_continuous_scale="RdBu", 
            height=800).show()

In [None]:
def render_feature_dashboard(feature_id):
    
    path_to_html = "../../feature_dashboards/gpt2-small_blocks.10.hook_resid_pre_24576"
    path = f"{path_to_html}/data_{feature_id:04}.html"
    
    print(f"Feature {feature_id}")
    if os.path.exists(path):
        # with open(path, "r") as f:
        #     html = f.read()
        #     display(HTML(html))
        webbrowser.open_new_tab("file://" + os.path.abspath(path))
    else:
        print("No HTML file found")
        
tmp = features_acts_by_token_df.sort_values(UNIQUE_TOKEN_INTEREST, ascending=False).T

for feature in  tmp.columns.tolist()[:20]:
    render_feature_dashboard(feature.split("_")[1])
