In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import transformer_lens.utils as utils
from datasets import load_dataset
from transformer_lens import ActivationCache, HookedTransformer

import latent_features
import tools

In [None]:
model = HookedTransformer.from_pretrained("gpt2-small", fold_ln=True)

# Hook points of the MLP under study.
layer_index = 6
location = "mlp_post_act"
transformer_lens_loc = f"blocks.{layer_index}.mlp.hook_post"  # MLP post-activation output
prev_layer_loc = f"blocks.{layer_index}.ln2.hook_normalized"  # normalised input fed into the MLP

# Small Pile slice used as calibration data for the MLP clustering below.
ds = load_dataset("NeelNanda/pile-10k", split='train[:10]')
# Batched tokenisation: TransformerLens `to_tokens` pads every document to the longest one.
ds_tokens = model.to_tokens(ds['text'])
# Cache activations for the first document tokenised on its own. Using the batched
# `ds_tokens[0]` would include trailing padding positions in the calibration activations.
ds_logits, ds_cache = model.run_with_cache(model.to_tokens(ds['text'][0]))

In [None]:
mlp = tools.extract_mlp(model, layer_index)

# Flatten the cached MLP inputs to (n_positions, d_model). Move to CPU first so this also
# works when the model runs on a GPU (`.numpy()` only accepts CPU tensors).
ds_acts = ds_cache[prev_layer_loc].detach().cpu().numpy().reshape(-1, model.cfg.d_model)
original_mlp = tools.get_original_mlp_for_sparx(mlp, ds_acts)

# One clustered ("merged") MLP per shrink fraction; merged_models[i] pairs with shrink_pcs[i].
shrink_pcs = [0.1, 0.3, 0.5, 0.7, 0.9]
merged_models = []
models_cluster_labels = []
for pc in shrink_pcs:
  # Bind to `merged_model`, never `model`: later cells pass `model` (the HookedTransformer)
  # together with a hook name to cluster_activations_for_prompt, so it must not be overwritten.
  merged_model, labels = tools.shrink_model_global(original_mlp, pc)
  merged_model.model.summary()
  merged_models.append(merged_model)
  models_cluster_labels.append(labels)

## Example prompts for selected latent concepts (with Neuronpedia explanations)

In [None]:

# Neuronpedia latent 25423 (gpt2-small, 6-mlp-oai):
# https://www.neuronpedia.org/gpt2-small/6-mlp-oai/25423
# Concept: phrases denoting a time span, emphasising the recent past ("over the past ...").
over_past_concept_tokens = [
    "over the past decade",
    "over the past year",
    "over the past few",
    "over the past nine years",
    "over the past century",
]

# Walk the shrink fractions alongside the clustered model built for each one.
for pc, sparx_mlp in zip(shrink_pcs, merged_models):
  print(f"Cluster activations for model of sparsification {pc}")
  latent_features.cluster_activations_for_prompt(
      model, prev_layer_loc, over_past_concept_tokens, sparx_mlp
  )

In [None]:
# Bring the two analysis helpers used by the concept cells below into scope.
from latent_features import (
    cluster_activations_for_prompt,
    measure_similarity_in_activated_clusters,
)

# NOTE(review): treats merged_models[0] (built with shrink_pcs[0] == 0.1) as the sparsest
# model; confirm that a smaller `pc` really means a sparser model in tools.shrink_model_global.
sparsest_model = merged_models[0]

In [None]:
# Neuronpedia latent 24106 (gpt2-small, 6-mlp-oai):
# https://www.neuronpedia.org/gpt2-small/6-mlp-oai/24106
# Concept: the word 'know'.
know_concept_tokens = [
    "know",
    "I know this",
    "I know it",
    "I know nothing",
    "I know that",
    "I know there is a trend",
    "I know everyone",
    "I know I am",
]
top_clusters_for_prompts = cluster_activations_for_prompt(
    model, prev_layer_loc, know_concept_tokens, sparsest_model
)
# Last expression: shown as the cell output.
measure_similarity_in_activated_clusters(top_clusters_for_prompts)

In [None]:
# Neuronpedia latent 2325 (gpt2-small, 6-mlp-oai):
# https://www.neuronpedia.org/gpt2-small/6-mlp-oai/2325
# Concept: metaphorical phrases built on elements or materials, often about
# energising or de-energising a situation.
metaphors_concept_tokens = [
    # fuel
    "that decision provided fresh fuel for her",
    "added fuel to the current battle",
    # water
    "poured cold water on the plan",
    # food / seeds / grain
    "give me some food for thought",
    "kernel of an idea germinating",
    "taken with a grain of salt",
]
top_clusters_for_prompts = cluster_activations_for_prompt(
    model,
    prev_layer_loc,
    metaphors_concept_tokens,
    sparsest_model,
)
# Last expression: shown as the cell output.
measure_similarity_in_activated_clusters(top_clusters_for_prompts)

In [None]:
# Neuronpedia latent 5277 (gpt2-small, 6-mlp-oai):
# https://www.neuronpedia.org/gpt2-small/6-mlp-oai/5277
# Concept: emotive or value-laden terms about personal commitment or giving.
give_concept_tokens = [
    # full sentences
    "Want to give some love to your favorite cheap breakfast spot?",
    "I've given my heart and soul to the network",
    "I would like to thank and give the glory to God",
    # short fragments
    "give life to",
    "give time and attention",
]
top_clusters_for_prompts = cluster_activations_for_prompt(
    model,
    prev_layer_loc,
    give_concept_tokens,
    sparsest_model,
)
# Last expression: shown as the cell output.
measure_similarity_in_activated_clusters(top_clusters_for_prompts)

## Comparison: the same prompts through the raw (unclustered) MLP neurons

In [None]:
# Baseline: push the same concept prompts through the unclustered MLP (`original_mlp`).
# Each run gets a header, mirroring the sparsification loop above, and every return value
# is kept in `raw_top_clusters` instead of only the last one surfacing as the cell output.
raw_neuron_comparisons = {
    "know": know_concept_tokens,
    "metaphors": metaphors_concept_tokens,
    "give": give_concept_tokens,
}
raw_top_clusters = {}
for concept_name, concept_tokens in raw_neuron_comparisons.items():
  print(f"Raw neuron activations for concept '{concept_name}'")
  raw_top_clusters[concept_name] = cluster_activations_for_prompt(
      model, prev_layer_loc, concept_tokens, original_mlp
  )

## Cluster activations for randomly sampled latent concepts

In [None]:
# Sample random latents, skip dead ones, and record for each live latent its description,
# its top-activating prompts and the clusters those prompts activate in the sparsest model.
NUM_LATENT_FEATURES = 32768  # valid latent indices are 0 .. NUM_LATENT_FEATURES - 1
NUM_SAMPLED_FEATURES = 10  # todo: find optimal size
MAX_SAMPLING_ATTEMPTS = 1000  # stop instead of looping forever if most sampled latents are dead
SAMPLING_SEED = 0  # fixed seed so Restart & Run All samples the same latents

# Focus on sparsest model for now
sparsest_model = merged_models[0]

feature_descriptions = {}
feature_prompts = {}
feature_prompt_clusters = {}

latent_rng = random.Random(SAMPLING_SEED)
tried_latents = set()
while len(feature_descriptions) < NUM_SAMPLED_FEATURES:
  if len(tried_latents) >= min(MAX_SAMPLING_ATTEMPTS, NUM_LATENT_FEATURES):
    print(f"Stopping after {len(tried_latents)} sampled latents; "
          f"found {len(feature_descriptions)} live ones.")
    break
  # randrange excludes its upper bound; randint(0, N) would also return N, one past the last latent.
  feature_index = latent_rng.randrange(NUM_LATENT_FEATURES)
  if feature_index in tried_latents:
    continue  # already queried (live or dead): draw again rather than re-fetch
  tried_latents.add(feature_index)

  print(f"latent: {feature_index}")
  desc, prompts = latent_features.get_top_activating_prompts_for_latent_feature(feature_index)
  if len(desc) == 0 and len(prompts) == 0:
    print(f"Dead latent: {feature_index}")
    continue # Dead latent

  feature_descriptions[feature_index] = desc
  feature_prompts[feature_index] = prompts

  print(desc, prompts)
  # NOTE(review): unlike cluster_activations_for_prompt, this helper is not given the
  # HookedTransformer or hook name; confirm that is its intended signature.
  prompt_clusters = latent_features.get_top_clusters_for_prompts(prompts, sparsest_model)
  feature_prompt_clusters[feature_index] = prompt_clusters

In [None]:
# Cluster activations in the sparsest clustered MLP for each sampled latent's top prompts.
for feature_index, prompts in feature_prompts.items():
  print(f"Latent {feature_index}: {feature_descriptions[feature_index]}")
  # Same argument order as every other call: HookedTransformer, hook name, prompts, clustered MLP.
  cluster_activations_for_prompt(model, prev_layer_loc, prompts, sparsest_model)

In [None]:
# Measure similarity in cluster activations for prompts of the same latent feature.
for feature_index, prompt_clusters in feature_prompt_clusters.items():
  # assumes `sim` is a 2-D (n_prompts, n_prompts) pairwise matrix; TODO confirm
  sim = measure_similarity_in_activated_clusters(prompt_clusters)
  print(f"Feature {feature_index}: {feature_prompts[feature_index]}")
  print(f"Clusters: {prompt_clusters}")
  print(f"Average similarity: {np.mean(sim)}")

  fig, ax = plt.subplots()
  image = ax.imshow(sim)
  fig.colorbar(image, ax=ax, label='similarity')
  ax.set(title=f"Latent {feature_index}: cluster-activation similarity", xlabel="prompt", ylabel="prompt")
  plt.show()
  plt.close(fig)  # avoid accumulating open figures across iterations