<a href="https://colab.research.google.com/github/j-physics/Mech_Interp_Exploratory/blob/main/singlelayer_FAA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformer-lens

In [None]:
#import necessary libraries or packages
import torch
import transformer_lens
from transformer_lens import HookedTransformer

#check GPU is working right: report whether CUDA is visible and which device will be used
has_gpu = torch.cuda.is_available()
device_label = torch.cuda.get_device_name(0) if has_gpu else 'CPU'
print(f"GPU available: {has_gpu}")
print(f"Device: {device_label}")

In [None]:
#load a small model (e.g., gpt2-small, pythia-70m)
#Fall back to the CPU when no GPU is attached so this cell still runs on a CPU-only runtime
device = "cuda" if torch.cuda.is_available() else "cpu"
#NOTE(review): the SAE loaded later in this notebook comes from a release named
#"pythia-70m-deduped"; confirm whether "pythia-70m-deduped" should be loaded here instead
model = HookedTransformer.from_pretrained("pythia-70m", device=device)

#check to see model is loading
text = "Hi, my name is Jessica."
tokens = model.to_tokens(text)
#this is only a smoke test, so run the forward pass without gradient tracking
with torch.no_grad():
    logits = model(tokens)
print(f"Tokens shape: {tokens.shape}")
print(f"Logits shape: {logits.shape}")


In [None]:
!pip install sae_lens

In [None]:
#integrate SAEs (loaded via SAELens) with the TransformerLens model
from transformer_lens import HookedTransformer
from sae_lens import SAE

#Load the pretrained SAE with id "2-res-sm"; later cells feed it the activations
#cached at blocks.2.hook_resid_post
#Use the GPU when present, otherwise the CPU, instead of assuming CUDA exists
device = "cuda" if torch.cuda.is_available() else "cpu"
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    release="ctigges/pythia-70m-deduped__res-sm_processed",
    sae_id="2-res-sm",
    device=device,
)

In [None]:
#feature activation analysis (FAA)

#Testing some text out
text = "There is a tiger in the room"
tokens = model.to_tokens(text)

#Getting the activations at a single layer
hook_name = "blocks.2.hook_resid_post"
#This is inference only: skip autograd bookkeeping and cache just the one hook that is read
with torch.no_grad():
    _, cache = model.run_with_cache(tokens, names_filter=hook_name)
    layer_acts = cache[hook_name]

    #Run through SAE to get feature activations of singlelayer
    feature_acts = sae.encode(layer_acts)

#Seeing which features fired
print(f"Shape: {feature_acts.shape}") #should be in the format [batch, seq_len, num_feat]
#.item() turns the 0-d tensor into a plain Python int so the count prints as a number
print(f"Non-zero features: {(feature_acts > 0).sum().item()}")

#Find the top activating features
top_features = feature_acts[0, -1].topk(10) #top 10 at the last token
print(f"Top features: {top_features.indices}")
print(f"Activations: {top_features.values}")

In [None]:
#Looking at each token position now
tokens_str = model.to_str_tokens(text)
print("Tokens:", tokens_str)

#nice format to show what sentence I am analyzing
banner = "=" * 60
print(banner)
print(f"Analyzing: '{text}'")
print(banner)

#What top features activate for each token
for pos, token in enumerate(tokens_str):
    top_features = feature_acts[0, pos].topk(3) #top three features
    feature_ids = top_features.indices.tolist()
    formatted_acts = [f'{x: .3f}' for x in top_features.values.tolist()]
    print(f"\nPosition {pos}: '{token}'")
    print(f" Top 3 features: {feature_ids}")
    print(f" Activations: {formatted_acts}")


In [None]:
#What about for some other sentences? (including the previous test text)
from collections import Counter

test_sentences = {
    "original": "The tiger is in the room",
    "changed_subject": "The wolf is in the room", #changing the subject
    "changed_verb": "The tiger jumped in the room", #change of verb
    "multiple_subjects": "The tiger and the wolf are in the room", #multiple subjects
}

hook_name = "blocks.2.hook_resid_post"

#collect the activations for each sentence
results = {}

#This is inference only: no autograd graph, and only the hook that is actually read is cached
with torch.no_grad():
    for name, sentence in test_sentences.items():
        tokens = model.to_tokens(sentence)
        tokens_str = model.to_str_tokens(sentence)

        _, cache = model.run_with_cache(tokens, names_filter=hook_name)
        layer_activate = cache[hook_name]
        feature_activate = sae.encode(layer_activate)

        results[name] = {
            'tokens': tokens_str,
            'activations': feature_activate[0], #[seq_len, num_features]
        }

#Comparing which features appear in ALL variations
#This runs once, after every sentence has been collected (not once per sentence)
#Getting the top features from each
all_top_features = []
for data in results.values():
    top = data['activations'][-1].topk(10).indices.tolist() #last token
    all_top_features.extend(top)

#Find common features
feature_counts = Counter(all_top_features)
#topk indices are unique within one sentence, so a count equal to the number of
#sentences means the feature is in the last-token top 10 of every variation
common_features = sorted(
    feat for feat, count in feature_counts.items() if count == len(results)
)

print("=" * 60)
print("Feature Consistency Analysis")
print("=" * 60)

print("Features that consistently activate across variations:")
print(common_features)


In [None]:
#What about for some other sentences? (including the previous test text)
from collections import Counter

test_sentences = {
    "original": "The tiger is in the room",
    "changed_subject": "The wolf is in the room", #changing the subject
    "changed_verb": "The tiger jumped in the room", #change of verb
    "multiple_subjects": "The tiger and the wolf are in the room", #multiple subjects
}

hook_name = "blocks.2.hook_resid_post"

#collect the activations for each sentence
results = {}

#This is inference only: without no_grad every tensor stored in `results` would keep its
#autograd graph alive, and without names_filter every hook in the model would be cached
with torch.no_grad():
    for name, sentence in test_sentences.items():
        tokens = model.to_tokens(sentence)
        tokens_str = model.to_str_tokens(sentence)

        _, cache = model.run_with_cache(tokens, names_filter=hook_name)
        layer_activate = cache[hook_name]
        feature_activate = sae.encode(layer_activate)

        results[name] = {
            'tokens': tokens_str,
            'activations': feature_activate[0], #[seq_len, num_features]
        }

#Comparing which features appear in ALL variations
#Getting the top features from each
all_top_features = []
for name, data in results.items():
    top = data['activations'][-1].topk(10).indices.tolist() #last token
    all_top_features.extend(top)

#Find common features
feature_counts = Counter(all_top_features)

print("=" * 60)
print("Feature Consistency Analysis")
print("=" * 60)
print("\nFeatures that appear most frequently across variations:")
for feat, count in feature_counts.most_common(10):
    print(f"   Feature {feat}: appears in {count}/{len(test_sentences)} variations")


In [None]:
#let's visualize what is happening
import matplotlib.pyplot as plt

#Generate a heatmap of feature activations across tokens
activations = feature_acts[0].cpu().detach() # [seq_len, num_features]

#Just plotting features that activated
active_mask = (activations > 0).any(dim=0)
active_features = activations[:, active_mask]

#Only the first max_rows active features are drawn; keep their real SAE indices so the
#rows can be matched against the feature ids printed by the top-k cells
max_rows = 100
shown_ids = active_mask.nonzero(as_tuple=True)[0][:max_rows].tolist()

plt.figure(figsize=(14,8))
plt.imshow(active_features.T[:max_rows], aspect='auto', cmap='hot')

#Label rows with the real feature ids (subsampled so the ticks stay legible)
tick_step = max(1, len(shown_ids) // 25)
row_ticks = list(range(0, len(shown_ids), tick_step))
plt.yticks(row_ticks, [shown_ids[i] for i in row_ticks])

#Label columns with the token strings
#assumes `text` is still the sentence feature_acts was computed from; skipped if lengths differ
token_labels = model.to_str_tokens(text)
if len(token_labels) == activations.shape[0]:
    plt.xticks(range(len(token_labels)), token_labels, rotation=45, ha='right')

plt.xlabel('Token Position')
plt.ylabel('Feature Index')
plt.title('Feature Activations Across Sequence')
plt.colorbar()
plt.show()