In [1]:
# Let's find some good sparse features we like.

%load_ext autoreload
%autoreload 2

import os

# Redirect the Hugging Face hub cache to scratch storage. This runs in the
# first cell, ahead of the cells that import transformer_lens / datasets.
os.environ.update(HUGGINGFACE_HUB_CACHE='/scratch/gsk6me/huggingface_cache')


In [2]:
import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder

# Download (or read from the local blob cache) one sparse autoencoder per
# GPT-2 small layer and keep them on the GPU, in layer order.
autoencoders = []

# Activation site used to pick the autoencoder set; index 0 selects "mlp_post_act".
# Hoisted out of the loop because it does not depend on the layer.
autoencoder_input = ["mlp_post_act", "resid_delta_mlp"][0]

for layer_index in range(12):
    print(f"Loading {layer_index}")
    filename = (
        "az://openaipublic/sparse-autoencoder/gpt2-small/"
        f"{autoencoder_input}/autoencoders/{layer_index}.pt"
    )
    with bf.BlobFile(
        filename,
        mode="rb",
        streaming=False,
        cache_dir='/scratch/gsk6me/sae-gpt2-small-cache',
    ) as f:
        state_dict = torch.load(f)
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoders.append(autoencoder.to('cuda'))


Loading 0
Loading 1
Loading 2
Loading 3
Loading 4
Loading 5
Loading 6
Loading 7
Loading 8
Loading 9
Loading 10
Loading 11


In [4]:
import datasets

# Open the monology/pile-uncopyrighted dataset with streaming enabled.
dataset_iterable = datasets.load_dataset(
    path="monology/pile-uncopyrighted",
    streaming=True,
)


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [3]:
# Load pretrained "gpt2" as a HookedTransformer, with writing-weight centering turned off.
model = transformer_lens.HookedTransformer.from_pretrained(
    model_name="gpt2",
    center_writing_weights=False,
)

Loaded pretrained model gpt2 into HookedTransformer


In [4]:
@torch.no_grad()
def get_sparse_features(prompt):
    """Run `prompt` through the model and encode each layer's MLP activations.

    Relies on the notebook-level globals `model`, `autoencoders` (one entry
    per layer, in layer order) and `autoencoder_input`.

    Args:
        prompt: Text handed to `model.to_tokens`.

    Returns:
        A `(tokens, latents)` tuple. `tokens` is the output of
        `model.to_tokens(prompt)`; `latents` stacks the per-layer
        `autoencoder.encode(...)` results along a new leading (layer) dimension.

    Raises:
        ValueError: If `autoencoder_input` is not one of the supported sites.
    """
    # Hook-name template for each supported activation site.
    hook_templates = {
        "mlp_post_act": "blocks.{}.mlp.hook_post",  # (n_tokens, n_neurons)
        "resid_delta_mlp": "blocks.{}.hook_mlp_out",  # (n_tokens, n_residual_channels)
    }
    if autoencoder_input not in hook_templates:
        raise ValueError(f"Unsupported autoencoder_input: {autoencoder_input!r}")
    hook_template = hook_templates[autoencoder_input]

    tokens = model.to_tokens(prompt)  # (1, n_tokens)
    _, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)

    # Pair every autoencoder with the activations of its OWN layer. The layer
    # index must come from this loop: reading the notebook-global `layer_index`
    # (left at its last value by the loading cell) would feed the same layer's
    # activations to every autoencoder.
    latent_activations = []
    for layer_index, autoencoder in enumerate(autoencoders):
        input_tensor = activation_cache[hook_template.format(layer_index)]
        # (n_tokens, n_latents)
        latent_activations.append(autoencoder.encode(input_tensor))

    # Drop the reference to the cached activations before returning.
    del activation_cache

    return (tokens, torch.stack(latent_activations, dim=0))

In [6]:
class ActivationCache:
    """Keeps the `top_k` strongest activations seen for every (layer, feature).

    For each recorded activation a window of tokens around the peak position
    is stored alongside it, so the contexts that excite a feature most can be
    inspected later.

    Attributes (documented here rather than as attribute docstrings):
        cutoffs: (n_layers, n_features) tensor; a feature is only considered
            in `update` when its peak activation is strictly greater than its
            cutoff. Starts at zero, so non-positive activations are never kept.
        top_k: maximum number of entries retained per feature.
        highest: dict mapping "layer.feature" to a list of
            `(activation, window_tokens)` tuples sorted ascending by activation.
        window_size: number of tokens kept on each side of the peak.
    """

    def __init__(self, n_layers, n_features, top_k=100, window_size=16, device='cuda'):
        """Create an empty cache.

        Args:
            n_layers: Number of layers tracked.
            n_features: Number of features per layer.
            top_k: How many top activations to keep per feature.
            window_size: Tokens kept before the peak, and at/after the peak.
            device: Device for the `cutoffs` tensor (defaults to 'cuda', the
                previous hard-coded value). Must match the device of the
                activations passed to `update`.
        """
        self.cutoffs = torch.zeros((n_layers, n_features), device=device)
        self.top_k = top_k
        self.highest = {}
        self.window_size = window_size

    def update_helper(self, layer, feature, activation, window_tokens):
        """Insert one activation into the sorted top-k list of a feature.

        Args:
            layer: Layer index.
            feature: Feature index within the layer.
            activation: Activation value to record.
            window_tokens: Token window stored alongside the activation.

        Returns:
            The cutoff to store for this feature: the smallest retained
            activation if the list overflowed `top_k` and was trimmed,
            otherwise the current (unchanged) cutoff.
        """
        index_name = f"{layer}.{feature}"
        entries = self.highest.setdefault(index_name, [])

        # Find the insertion point: i = number of stored values <= activation.
        # `entries` is sorted in ascending order.
        i = 0
        while i < len(entries) and not (activation < entries[i][0]):
            i += 1

        # rank 1 == strongest; only keep the value if it lands in the top k.
        rank = len(entries) - i + 1
        if rank <= self.top_k:
            entries.insert(i, (activation, window_tokens))

        extra = len(entries) - self.top_k
        if extra > 0:
            # Drop the weakest entries; the new weakest survivor is the cutoff.
            del entries[:extra]
            return entries[0][0]
        return self.cutoffs[layer, feature]

    def update(self, activations, tokens, instance=None):
        """Record the peak activation of every feature that beats its cutoff.

        Args:
            activations: Tensor reduced with `max(dim=1)` and compared against
                `cutoffs`, i.e. expected as (n_layers, n_tokens, n_features),
                on the same device as `cutoffs`.
            tokens: Token sequence the activations were computed from. A
                tensor with a leading batch dimension of size 1 (the shape
                `model.to_tokens` produces) is accepted and flattened.
            instance: Unused; kept for interface compatibility.
        """
        # Without this, len(tokens) would be the batch size (1) and the window
        # slice would cut along the batch dimension instead of the sequence.
        if isinstance(tokens, torch.Tensor) and tokens.dim() == 2 and tokens.shape[0] == 1:
            tokens = tokens[0]
        n_tokens = len(tokens)

        best = activations.max(dim=1)
        to_update = (best.values > self.cutoffs)

        # Pull all selected entries off the device in one go. Boolean-mask
        # indexing and torch.nonzero both enumerate in row-major order, so the
        # three lists line up element by element.
        positions = torch.nonzero(to_update).tolist()
        peak_values = best.values[to_update].tolist()
        peak_indices = best.indices[to_update].tolist()

        for (layer, feature), activation, index in zip(positions, peak_values, peak_indices):
            # Store window_size tokens before the peak and window_size tokens
            # at/after it. The end is measured from the peak itself so the
            # peak token is always inside the window.
            window_start = max(0, index - self.window_size)
            window_end = min(n_tokens, index + self.window_size)
            window_tokens = tokens[window_start:window_end]

            self.cutoffs[layer, feature] = self.update_helper(layer, feature, activation, window_tokens)


In [None]:
import time

# Time a single call to get_sparse_features. `instance` must already exist in
# the kernel (it is assigned by the inference loop cell).
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals. CUDA kernels run asynchronously, so synchronize before reading
# the clock to time the actual computation rather than just the launches.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
feats = get_sparse_features(instance['text'])
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print(end-start)

In [47]:
# Element 1 of `feats`, moved to the GPU (presumably the stacked latents from
# the (tokens, latents) tuple returned by get_sparse_features — confirm).
feats1 = feats[1].to('cuda')

In [49]:
# Maximum along dim 1, together with the positions where it occurs.
highest = torch.max(feats1, dim=1)

In [56]:
# Show the shape of the argmax indices.
highest.indices.size()

torch.Size([12, 32768])

In [6]:
import tqdm

# Run every example sentence through the model and the autoencoders, keeping
# the tokens and the sparse features of the first four layers.
top_activations = 100  # not used in this cell

results = []

count = 0  # not used in this cell

# One sentence per line. The context manager closes the file handle, which a
# bare open(...).read() leaves to the garbage collector.
with open("example_sentences/chemistry.txt") as sentence_file:
    sentences = sentence_file.read().split("\n")
max_count = len(sentences)

for i, text in tqdm.tqdm(enumerate(sentences), desc='Running inference', total=max_count):
    # Instance -> ['text', 'meta' -> 'pile_set_name']
    instance = {'text': text, 'meta': 'chemistry', 'index': i}
    tokens, feats = get_sparse_features(instance['text'])

    # Only do first 4 layers. Select them before the transfer so that only
    # those layers are copied to the CPU.
    results.append((tokens.cpu(), feats[[0, 1, 2, 3]].cpu()))

Running inference: 100%|██████████| 83/83 [00:04<00:00, 20.39it/s]


In [11]:
# Persist the collected (tokens, features) pairs to disk.
torch.save(obj=results, f="example_sentences/chemistry_outputs.pt")

In [7]:
# Extract neuron activations with transformer_lens
model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
prompt = "Exploring SLAM techniques to allow for racecar to better perceive its position in relation to the environment. Working on improving route planning for racecar throughout the track."
tokens = model.to_tokens(prompt)  # (1, n_tokens)
print(model.to_str_tokens(tokens))
with torch.no_grad():
    logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)

# Encode neuron activations with the autoencoders, one per layer.
# The device lookup does not depend on the layer, so do it once.
device = next(model.parameters()).device

latent_activations = []
# The layer index must come from this loop. Relying on a `layer_index` left
# over from an earlier cell would read the same layer's activations for
# every autoencoder.
for layer_index, autoencoder in enumerate(autoencoders):
    if autoencoder_input == "mlp_post_act":
        input_tensor = activation_cache[f"blocks.{layer_index}.mlp.hook_post"]  # (n_tokens, n_neurons)
    elif autoencoder_input == "resid_delta_mlp":
        input_tensor = activation_cache[f"blocks.{layer_index}.hook_mlp_out"]  # (n_tokens, n_residual_channels)
    else:
        raise ValueError(f"Unsupported autoencoder_input: {autoencoder_input!r}")

    autoencoder.to(device)
    with torch.no_grad():
        acts = autoencoder.encode(input_tensor)  # (n_tokens, n_latents)
    latent_activations.append(acts)

Loaded pretrained model gpt2 into HookedTransformer
['<|endoftext|>', 'Expl', 'oring', ' SL', 'AM', ' techniques', ' to', ' allow', ' for', ' race', 'car', ' to', ' better', ' perceive', ' its', ' position', ' in', ' relation', ' to', ' the', ' environment', '.', ' Working', ' on', ' improving', ' route', ' planning', ' for', ' race', 'car', ' throughout', ' the', ' track', '.']


In [20]:
import matplotlib.pyplot as plt
import neuron_visualization
from IPython.display import display, HTML

# The string tokens are the same for every layer, so compute them once.
str_tokens = model.to_str_tokens(tokens)

# Iterate over whatever layers were encoded instead of assuming exactly 12.
for layer_id, acts in enumerate(latent_activations):
    # Feature whose peak activation over the prompt is the largest in this layer.
    top_feature = acts.max(dim=0).values.argmax().item()
    print("Layer", layer_id)
    display(HTML(neuron_visualization.basic_neuron_vis_signed(str_tokens, acts[:, top_feature], 1)))


Layer 0


Layer 1


Layer 2


Layer 3


Layer 4


Layer 5


Layer 6


Layer 7


Layer 8


Layer 9


Layer 10


Layer 11
