In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ['HUGGINGFACE_HUB_CACHE'] = '/scratch/gsk6me/huggingface_cache/'


In [2]:
# This is gpt2-small
import transformers
model = transformers.GPT2Model.from_pretrained("openai-community/gpt2").cuda()
tokenizer = transformers.GPT2Tokenizer.from_pretrained("openai-community/gpt2")

In [3]:
model.h[0].mlp.c_fc.weight.shape, model.h[0].mlp.c_proj.weight.shape

(torch.Size([768, 3072]), torch.Size([3072, 768]))

In [4]:
import inspect
# seems pretty straightforward
print(inspect.getsource(model.h[0].mlp.forward))

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states



In [5]:
import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder

# Load the autoencoders
autoencoders = []

for layer_index in range(12):
    print("Loading", layer_index)
    autoencoder_input = ["mlp_post_act", "resid_delta_mlp"][0]
    filename = f"az://openaipublic/sparse-autoencoder/gpt2-small/{autoencoder_input}/autoencoders/{layer_index}.pt"
    with bf.BlobFile(filename, mode="rb", streaming=False, cache_dir='/scratch/gsk6me/sae-gpt2-small-cache') as f:
        state_dict = torch.load(f)
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoders.append(autoencoder.to('cuda'))

Loading 0
Loading 1
Loading 2
Loading 3
Loading 4
Loading 5
Loading 6
Loading 7
Loading 8
Loading 9
Loading 10
Loading 11


In [7]:
from functools import partial
import torch.nn.functional as F

patch_layer = 0
patch_seqpos = 0
patch_index = 0
capture_layer = 0
capture_index = 0

patch_history = []
activation_history = []

def custom_fwd(self, hidden_states, layer_num):
    global patch_index, patch_layer, capture_index, patch_history, activation_history
    # shape: [batch_size, num_tokens, hidden_shape]
    hidden_states = self.c_fc(hidden_states)
    hidden_states = self.act(hidden_states)    
    # patch the specific feature index (incl. both sequence number and feature id)
    if patch_layer == layer_num:
        hidden_states = F.relu(
            autoencoders[layer_num].encoder(hidden_states)
        )
        hidden_states[:, :, patch_index] = 0
        hidden_states = autoencoders[layer_num].decoder(hidden_states)
        patch_history.append((patch_layer, patch_seqpos, patch_index))
        
    # capture the change in sparse feature (will likely be applied at a later layer
    # as the earlier if statement)
    if capture_layer == layer_num:
        features = F.relu(
            autoencoders[layer_num].encoder(hidden_states)
        )
        activation = features[:, -1, capture_index]
        activation_history.append(activation)
    elif capture_layer == -1:
        # capture everything
        features = F.relu(
            autoencoders[layer_num].encoder(hidden_states)
        )
        activation_history[-1][layer_num] = features
    hidden_states = self.c_proj(hidden_states)
    hidden_states = self.dropout(hidden_states)
    return hidden_states

for i in range(12):
    mlp = model.h[i].mlp
    mlp.forward = partial(custom_fwd.__get__(mlp, type(mlp)), layer_num=i)


In [17]:
import time

# estimate performance diff

def get_patching_results(string, cap_layer, cap_index):
    """
    gets activation patching results for a specific feature
    """
    # get baseline
    global patch_index, patch_layer, patch_history, capture_layer, capture_index, activation_history
    
    patch_history = []
    activation_history = []
    
    patch_index = 0
    patch_layer = -1
    capture_layer = -1
    patch_history.append(None)
    activation_history.append({})
    baseline = model(**tokenizer(string, return_tensors='pt').to('cuda'))
    
    print(activation_history[0][cap_layer][:, :, cap_index])
    
    capture_layer = cap_layer
    capture_index = cap_index
    
    for layer_i in range(2, 12):
        h = activation_history[0][layer_i]
        h = torch.max(h[0], dim=0).values
        h[h < 1e-1] = 0
        start = time.time()
        for feat in h.nonzero():
            patch_layer = layer_i
            patch_index = feat[0].item()
            
            with torch.no_grad():
                model(**tokenizer(string, return_tensors="pt").to('cuda'))
        end = time.time()
        print(layer_i, end-start)
                
        if layer_i > cap_layer:
            break

get_patching_results("hello!", 2, 3)


tensor([[0.0082, 0.0062]], device='cuda:0', grad_fn=<SelectBackward0>)
2 0.8764863014221191
3 3.9374032020568848
