In [1]:
import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder
from sparse_autoencoder.loss import normalized_mean_squared_error

In [2]:
# Load GPT-2 small through transformer_lens (activations are extracted later).
# center_writing_weights=False turns off transformer_lens's re-centering of the
# weights that write into the residual stream.
# NOTE(review): presumably this is required so cached residual activations match
# the setup the SAE was trained on — confirm.
model = transformer_lens.HookedTransformer.from_pretrained(
    "gpt2",
    center_writing_weights=False,
)
# Later tensors/modules follow whatever device the model's weights live on.
device = next(iter(model.parameters())).device



Loaded pretrained model gpt2 into HookedTransformer


In [3]:
# Total scalar weights in GPT-2: element counts summed over every parameter tensor.
n = sum(map(torch.numel, model.parameters()))
print(f"gpt2 has {n:,} params")

gpt2 has 163,049,041 params


In [4]:
# Which SAE to analyse: the one for block `layer_index` at the named location.
layer_index = 6
location = "resid_post_mlp"

# SAE location name -> transformer_lens hook name inside block `layer_index`.
# An unknown `location` raises KeyError here.
_hook_suffix_by_location = {
    "mlp_post_act": "mlp.hook_post",
    "resid_delta_attn": "hook_attn_out",
    "resid_post_attn": "hook_resid_mid",
    "resid_delta_mlp": "hook_mlp_out",
    "resid_post_mlp": "hook_resid_post",
}
transformer_lens_loc = f"blocks.{layer_index}.{_hook_suffix_by_location[location]}"

# Fetch the SAE checkpoint for (location, layer_index) and rebuild the module.
# map_location=device lets the checkpoint load even if its tensors were saved on a
# device this machine does not have (e.g. CUDA tensors on a CPU-only box).
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints from
# a trusted source; consider weights_only=True on torch versions that support it.
with bf.BlobFile(sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb") as f:
    state_dict = torch.load(f, map_location=device)

# The blob handle is only needed for deserialization, so build the module after
# the file has been closed.
autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
autoencoder.to(device)

In [5]:
# Display the SAE's module structure (encoder, activation, decoder) via its repr.
autoencoder

Autoencoder(
  (encoder): Linear(in_features=768, out_features=32768, bias=False)
  (activation): TopK(
    (postact_fn): ReLU()
  )
  (decoder): Linear(in_features=32768, out_features=768, bias=False)
)

In [6]:
# Total scalar weights in the SAE: element counts summed over every parameter tensor.
n = sum(map(torch.numel, autoencoder.parameters()))
print(f"SAE has {n:,} params")

SAE has 50,365,184 params


In [7]:
prompt = "This is an example of a prompt that"

# Token ids with a leading batch axis: shape (1, n_tokens).
# NOTE(review): 9 ids for this 8-word prompt suggests to_tokens prepends a BOS
# token — confirm against transformer_lens's prepend_bos default.
tokens = model.to_tokens(
    prompt
)
tokens.shape

torch.Size([1, 9])

In [8]:
# Forward pass that also records hook activations. remove_batch_dim=True drops
# the size-1 batch axis from the cached tensors (they come out as (n_tokens, d)).
with torch.no_grad():
    logits, activation_cache = model.run_with_cache(
        tokens,
        remove_batch_dim=True,
    )

In [9]:

# Activations at the chosen hook, one row per token.
# `input_tensor_ln` is a plain alias of the same tensor: no layer norm is applied here.
# NOTE(review): presumably any input normalization happens inside autoencoder.encode
# (decode is handed the `info` that encode returns) — confirm against the library.
input_tensor_ln = input_tensor = activation_cache[transformer_lens_loc]

In [10]:
# Round-trip the activations through the SAE without tracking gradients.
with torch.no_grad():
    # encode returns the latent activations plus auxiliary `info` that decode needs.
    latent_activations, info = autoencoder.encode(input_tensor_ln)
    # Map the latents back to the input space, one row per token.
    reconstructed_activations = autoencoder.decode(latent_activations, info)

In [11]:
# The reconstruction must have exactly the shape of its input; fail loudly otherwise
# rather than letting the error metrics below silently broadcast.
assert reconstructed_activations.shape == input_tensor.shape, (
    f"shape mismatch: {tuple(reconstructed_activations.shape)} vs {tuple(input_tensor.shape)}"
)
input_tensor.shape, reconstructed_activations.shape

(torch.Size([9, 768]), torch.Size([9, 768]))

In [12]:
# Per-position normalized MSE: ||reconstruction - input||^2 / ||input||^2.
# Reducing over the last (feature) axis keeps this correct even if a leading batch
# axis is kept; for the 2-D (n_tokens, d_model) tensors here it equals dim=1.
# NOTE(review): the tiny value at position 0 is presumably the BOS token — confirm.
squared_error = (reconstructed_activations - input_tensor).pow(2).sum(dim=-1)
input_sq_norm = input_tensor.pow(2).sum(dim=-1)
normalized_mse = squared_error / input_sq_norm
normalized_mse

tensor([6.4423e-05, 3.9219e-02, 3.1569e-02, 4.6318e-02, 7.1058e-02, 4.7744e-02,
        6.2675e-02, 6.7039e-02, 7.5507e-02], device='cuda:0')

In [13]:
# Cross-check with the library's aggregate metric (imported in the top import cell);
# its displayed value should agree with the mean of the per-position values above.
normalized_mean_squared_error(reconstructed_activations, input_tensor)

tensor(0.0490, device='cuda:0')

In [14]:
# Average the per-position values into one scalar, for comparison with the
# library's normalized_mean_squared_error result.
torch.mean(normalized_mse)

tensor(0.0490, device='cuda:0')