In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
import transformer_lens
import sparse_autoencoder
from transformers import AutoTokenizer

In [2]:
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
sae_gemma = SparseAutoencoder.load("/scratch/mbf3zk/.checkpoints/curious-sweep-1_100941824.pt")

In [3]:
sae_gemma

SparseAutoencoder(
  (pre_encoder_bias): TiedBias(position=pre_encoder)
  (encoder): LinearEncoder(
    input_features=2048, learnt_features=16384, n_components=6
    (activation_function): ReLU()
  )
  (decoder): UnitNormDecoder(learnt_features=16384, decoded_features=2048, n_components=6)
  (post_decoder_bias): TiedBias(position=post_decoder)
)

In [5]:
model = transformer_lens.HookedTransformer.from_pretrained("google/gemma-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [4]:
# from sparse_autoencoder.autoencoder.model import SparseAutoencoderConfig
#     config = SparseAutoencoderConfig(
#         n_input_features=autoencoder_input_dim,
#         n_learned_features=autoencoder_input_dim * 8,
#         n_components=len(hyperparameters["source_model"]["cache_names"]),
#     )
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from functools import partial
from sparse_autoencoder.source_model.store_activations_hook import store_activations_hook

store = TensorActivationStore(1000, 2048, 2048 * 8)
for component_idx, cache_name in enumerate([f"blocks.{layer}.hook_mlp_out" for layer in range(6)]):
    hook = partial(store_activations_hook, store=store, component_idx=component_idx)
    model.add_hook(cache_name, hook)

RuntimeError: [enforce fail at alloc_cpu.cpp:83] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 134217728000 bytes. Error code 12 (Cannot allocate memory)

In [6]:
model.to('cuda')
sae_gemma.to('cuda')

Moving model to device:  cuda


SparseAutoencoder(
  (pre_encoder_bias): TiedBias(position=pre_encoder)
  (encoder): LinearEncoder(
    input_features=2048, learnt_features=16384, n_components=6
    (activation_function): ReLU()
  )
  (decoder): UnitNormDecoder(learnt_features=16384, decoded_features=2048, n_components=6)
  (post_decoder_bias): TiedBias(position=post_decoder)
)

In [14]:
topic_sentences = {
    topic: [
        x[2:] for x in open(f"example_sentences/{topic}.txt").read().split("\n")
    ] for topic in ['math', 'physics', 'chemistry']
}

In [42]:
i_text = "Euler's identity, e^(iπ) + 1 = 0, is considered one of the most beautiful equations in mathematics."
with torch.no_grad():
    o = model.forward(i_text, stop_at_layer=7, prepend_bos=False)
    o = o.permute(1, 0, 2)
    learned_activations, reconstructed_activations = sae_gemma.forward(o)


print(learned_activations.shape)
print(reconstructed_activations.shape)


torch.Size([28, 6, 16384])
torch.Size([28, 6, 2048])


In [43]:
# Step 1: Flatten the tensor to a 2D tensor where each row is a token-layer combination
learned_activations_pruned = learned_activations[1:]
flattened_activations = learned_activations_pruned.reshape(-1, 16384)  # Shape will be [96, 16384] as 16*6 = 96

# Step 2: Sort the activations to find the top 100. We use topk which also gives the indices
top_values, top_indices = torch.topk(flattened_activations.flatten(), 100)

# The top_indices now contains the linear indices of the top 100 activations in the flattened view
# We can convert these indices back to the original token-layer-feature indices
top_token_layer_indices = top_indices // 16384  # Get which token-layer combination it is
top_feature_indices = top_indices % 16384  # Get which feature index within that token-layer combination

# Step 3: If needed, translate indices back to (token, layer, feature) format
top_tokens = top_token_layer_indices // 6  # There are 6 layers
top_layers = top_token_layer_indices % 6

# Print or return the results
for i in range(100):
    print(f"Activation {i+1}: Token {top_tokens[i] + 1}, Layer {top_layers[i]}, Feature {top_feature_indices[i]}, Value {top_values[i]}")

Activation 1: Token 1, Layer 2, Feature 15407, Value 26.51151466369629
Activation 2: Token 1, Layer 4, Feature 13500, Value 25.717500686645508
Activation 3: Token 1, Layer 2, Feature 4885, Value 25.02952003479004
Activation 4: Token 1, Layer 4, Feature 10154, Value 22.379722595214844
Activation 5: Token 2, Layer 2, Feature 15407, Value 21.33807945251465
Activation 6: Token 1, Layer 5, Feature 6412, Value 20.741857528686523
Activation 7: Token 2, Layer 4, Feature 13500, Value 20.249370574951172
Activation 8: Token 2, Layer 2, Feature 4885, Value 20.015485763549805
Activation 9: Token 3, Layer 2, Feature 15407, Value 18.33335304260254
Activation 10: Token 3, Layer 2, Feature 4885, Value 17.913307189941406
Activation 11: Token 2, Layer 4, Feature 10154, Value 17.830551147460938
Activation 12: Token 3, Layer 4, Feature 13500, Value 17.742000579833984
Activation 13: Token 1, Layer 3, Feature 4131, Value 17.403520584106445
Activation 14: Token 1, Layer 2, Feature 12636, Value 17.352413177490

In [46]:
import neuron_visualization
from IPython.display import display, HTML, display_markdown

tokens = model.to_tokens(i_text)
token_str = model.to_str_tokens(tokens)
aggregated_activations = learned_activations_pruned.mean(dim=[1, 2])
max_activation = aggregated_activations.abs().max().item()
display(HTML(neuron_visualization.basic_neuron_vis_signed(token_str, aggregated_activations.tolist(), .6)))