In [None]:
import tools
import latent_features
from transformer_lens import HookedTransformer
from datasets import load_dataset
import torch

In [None]:
# GPT-2 small; fold_ln=True folds LayerNorm weights into the following layers.
model = HookedTransformer.from_pretrained("gpt2-small", fold_ln=True)

# Which activations the sparse autoencoder (SAE) reads.
layer_index = 6
location = "mlp_post_act"  # SAE location name passed to tools.get_sparse_autoencoder
transformer_lens_loc = f"blocks.{layer_index}.mlp.hook_post"  # MLP post-activation hook
prev_layer_loc = f"blocks.{layer_index}.ln2.hook_normalized"  # ln2 output of the same block (MLP input)

# First 10 documents of the Pile sample.
ds = load_dataset("NeelNanda/pile-10k", split='train[:10]')
# NOTE(review): to_tokens on a list pads every row to the longest document (and truncates
# to the model context), so ds_tokens[0] may end in padding tokens -- confirm this is intended.
ds_tokens = model.to_tokens(ds['text'])

# Inference only. Without no_grad, ds_logits would keep the whole forward-pass autograd
# graph (and its intermediate activations) alive for as long as the variable exists.
# The 1-D token tensor gets a batch dimension of size 1 added inside TransformerLens.
with torch.no_grad():
  ds_logits, ds_cache = model.run_with_cache(ds_tokens[0])

autoencoder = tools.get_sparse_autoencoder(location, layer_index)

Reconstruction error

In [None]:
from tools import output_infidelity, output_mse, output_r2

# Check how well the SAE reconstructs this layer's MLP activations on the first document.

# Flatten (batch, pos, d_mlp) -> (pos, d_mlp). `.cpu()` is needed because the model, and
# therefore the cache, may be on a GPU, where `.numpy()` would raise. It also removes the
# needless numpy -> torch round trip.
ds_mlp_acts = ds_cache[transformer_lens_loc].reshape(-1, model.cfg.d_mlp).cpu()

with torch.no_grad():
  latents, info = autoencoder.encode(ds_mlp_acts)
  # `info` comes from encode and is handed back to decode
  # (presumably so decode can undo any input preprocessing -- see the SAE implementation).
  reconstructed_acts = autoencoder.decode(latents, info)

latents = latents.detach().cpu().numpy()
reconstructed_acts = reconstructed_acts.detach().cpu().numpy()

print(f"infidelity: {output_infidelity(ds_mlp_acts, reconstructed_acts)}")
print(f"MSE error: {output_mse(ds_mlp_acts, reconstructed_acts)}")
print(f"R2: {output_r2(ds_mlp_acts, reconstructed_acts)}")

Does the SAE reconstruction of an activation (encode, then decode) yield the same latent-feature explanations as the original activation?

In [None]:
# MLP activation at the last position of the first document, as a (1, d_mlp) CPU tensor.
# `.cpu()` avoids the `.numpy()` failure when the cache lives on a GPU.
# NOTE(review): the "endoftext" label below assumes the last position is an end-of-text
# token, which only holds if the document was padded -- confirm.
hidden_activations = ds_cache[transformer_lens_loc][0, -1].reshape(1, model.cfg.d_mlp).cpu()

for act in hidden_activations:
  print("Explanations for prompt endoftext activation")
  latent_features.list_explanations_for_single_activation(act)
  print("Explanations for reconstructed activation")
  # Same encode -> decode round trip as the reconstruction-error cell: pass `info` back
  # to decode, and disable autograd so the result is a plain tensor. Separate names keep
  # that cell's `latents` / `info` intact.
  with torch.no_grad():
    act_latents, act_info = autoencoder.encode(act)
    reconstructed = autoencoder.decode(act_latents, act_info)
  latent_features.list_explanations_for_single_activation(reconstructed)

Does the decoder direction of a single known latent map back to that latent's target concept?

In [None]:
KNOW_LATENT_INDEX = 24106  # https://www.neuronpedia.org/gpt2-small/6-mlp-oai/24106
num_latents = autoencoder.encoder.out_features  # width of the SAE encoder output

print(latent_features.get_explanations(KNOW_LATENT_INDEX))

# Column KNOW_LATENT_INDEX of the decoder weight: the vector this single latent contributes
# in MLP-activation space. `.detach()` gives the same tensor state_dict() would return.
# NOTE(review): no decoder bias or input de-normalisation is applied, so this is a
# direction rather than a full reconstruction -- confirm that is what the check intends.
reconstructed_acts = autoencoder.decoder.weight.detach()[:, KNOW_LATENT_INDEX]
print(reconstructed_acts)

# Sanity check: does the reconstructed activation correspond to the target concept?
latent_features.list_explanations_for_single_activation(reconstructed_acts)