# Exploring the context dependence of transcoder features

# Setup

In [1]:
import sys
import os

# Put the '../transcoder_circuits/' directory at the front of sys.path so the
# project imports in the following cells resolve against it. The list is rebuilt
# (rather than blindly inserted into) so re-running this cell in a live kernel
# keeps exactly one copy of the entry while preserving its first-place priority.
_circuits_root = os.path.abspath('../transcoder_circuits/')
sys.path[:] = [_circuits_root] + [entry for entry in sys.path if entry != _circuits_root]

In [2]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from transcoder_circuits.circuit_analysis import *
from transcoder_circuits.feature_dashboards import *
from transcoder_circuits.replacement_ctx import *

from sae_training.sparse_autoencoder import SparseAutoencoder
from utils import tokenize_and_concatenate

In [4]:
import torch
import numpy as np
from einops import *

# Choose the compute device once, in priority order: CUDA, then Apple MPS, then CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

## Load model & data

In [5]:
from transformer_lens import HookedTransformer, utils
# Load the pretrained 'gpt2-small' HookedTransformer and move it to the selected device.
model = HookedTransformer.from_pretrained('gpt2-small').to(device)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps


In [24]:
from datasets import load_dataset
from huggingface_hub import HfApi

# Open the OpenWebText train split with streaming=True, then shuffle it with a
# fixed seed (42) and a 10,000-example shuffle buffer.
# NOTE(review): the recorded output warns that `trust_remote_code=True` will become
# mandatory for this dataset in the next major `datasets` release.
dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
# NOTE(review): the disabled lines below are the only place in the visible notebook
# where `owt_tokens_torch` is assigned, yet later cells still read it; on a fresh
# kernel those cells would raise NameError. The last line also hard-codes 'mps'
# rather than using `device`.
# tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
# tokenized_owt = tokenized_owt.shuffle(42)
# tokenized_owt = tokenized_owt.take(12800*2)
# owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
# owt_tokens_torch = torch.tensor(owt_tokens).to('mps')

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


# Load transcoder

In [6]:
# Load one transcoder checkpoint (put in eval mode) and its log-feature-sparsity
# tensor for every block index 0..11; file paths are built by filling the block
# index into the template. As with any module-level loop, `i` stays bound (to 11)
# once the loop finishes.
transcoder_template = "../dufensky_transcoders/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"
transcoders, sparsities = [], []
for i in range(12):
    transcoders.append(
        SparseAutoencoder.load_from_pretrained(transcoder_template.format(i) + ".pt").eval()
    )
    sparsities.append(
        torch.load(transcoder_template.format(i) + "_log_feature_sparsity.pt")
    )

In [6]:
# Load only the transcoder for the layer under study, moved to the selected device
# and switched to eval mode, together with its log-feature-sparsity tensor.
transcoder_layer = 8
transcoder_template = "../dufensky_transcoders/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"
transcoder = (
    SparseAutoencoder
    .load_from_pretrained(transcoder_template.format(transcoder_layer) + ".pt")
    .to(device)
    .eval()
)
sparsity = torch.load(transcoder_template.format(transcoder_layer) + "_log_feature_sparsity.pt")
# Indices of the features whose log sparsity is above the -4 threshold.
live_features = np.arange(len(sparsity))[utils.to_numpy(sparsity > -4)]

In [49]:
# Detached copies of the MLP input/output weights for the block the transcoder was
# loaded for, plus the transcoder's decoder weights. The block is indexed with
# `transcoder_layer` (not a literal 8) so these weights cannot silently drift out
# of sync with the loaded transcoder if the layer under study is changed.
W_in = model.blocks[transcoder_layer].mlp.W_in.detach().clone()
W_out = model.blocks[transcoder_layer].mlp.W_out.detach().clone()
W_dec = transcoder.W_dec.detach().clone()

In [55]:
# Multiply the decoder rows by the pseudo-inverse of the product W_in @ W_out,
# then display the resulting shape.
W_mid = torch.matmul(W_dec, torch.linalg.pinv(torch.matmul(W_in, W_out)))
W_mid.shape

torch.Size([24576, 768])

In [53]:
# Shape check: pseudo-inverse of the W_in @ W_out product (recomputed here only for display).
torch.linalg.pinv(W_in @ W_out).shape

torch.Size([768, 768])

In [11]:
# Shape check: copied MLP output weight matrix.
W_out.shape

torch.Size([3072, 768])

In [12]:
# Shape check: copied transcoder decoder weight matrix.
W_dec.shape

torch.Size([24576, 768])

In [7]:
import gc

# Collect unreachable Python objects first, then release cached allocator memory
# on whichever accelerator backend is actually present. The notebook may run on
# CUDA, MPS or CPU, so clearing only the CUDA cache would do nothing on an MPS
# machine.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Find transcoder's intermediate vectors
For each encoder direction $\mathbf{e}_i$ with activation $z_i(x)$, we can find a direction $\mathbf{d}_i$ such that $Wx \approx \sum_i z_i(x)\,\mathbf{d}_i$.

In [44]:
import gc

# Collect unreachable Python objects first, then release cached allocator memory
# on whichever accelerator backend is actually present. Each backend call is
# guarded because the notebook selects cuda/mps/cpu dynamically and the MPS call
# must not run on a machine without an MPS backend.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

  0%|          | 0/100 [04:20<?, ?it/s]


In [68]:
from midcoder import MidcoderConfig, Midcoder

# Start from the default Midcoder configuration and override the run settings.
config = MidcoderConfig()
config.device = device            # same device the model and transcoder were placed on
config.batch_size = 64
config.steps_per_epoch = 10
config.train_tokens = 100_000
config.log = False

# Build the Midcoder for `transcoder_layer` from the loaded model and transcoder.
midcoder = Midcoder(dict(model=model, transcoder=transcoder), transcoder_layer, config)

In [69]:
# Run the Midcoder's fit routine. The recorded run shows 2 steps and a per-step
# table with `loss` and `loss_norm` columns.
midcoder.fit()

loss: 1.476e+00, loss_norm: 1.169e+00, lr: 0.000e+00: 100%|██████████| 2/2 [00:23<00:00, 11.67s/it]


Unnamed: 0,loss,loss_norm
0,3.179076,2.533804
1,1.475575,1.169111


In [46]:
# Shape check: the Midcoder's W_enc.
midcoder.W_enc.shape

torch.Size([768, 24576])

In [45]:
# Shape check: the Midcoder's W_mid.
midcoder.W_mid.shape

torch.Size([768, 24576])

In [13]:
# Collect unreachable Python objects first, then release cached allocator memory
# on whichever accelerator backend is actually present (clearing only the CUDA
# cache would do nothing on an MPS machine).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [26]:
# Ask the Midcoder for a dataloader with batch_size=4 and pull the 'tokens' entry
# of its first item. `next()` is applied directly, so the returned object is used
# as an iterator.
dataloader = midcoder.get_dataloader(batch_size=4)
batch = next(dataloader)['tokens']

In [46]:
# NOTE(review): disabled code. `mid_transcoder` is not defined anywhere in the
# visible notebook (presumably an earlier name for `midcoder` — confirm), so these
# lines should either be updated or deleted.
# inputs, outputs = mid_transcoder.get_inputs_outputs(batch)
# mid, _ = mid_transcoder.forward(inputs)

Token indices sequence length is longer than the specified maximum sequence length for this model (73252 > 1024). Running this sequence through the model will result in indexing errors


In [27]:
# Name of the hook point the Midcoder reads its input activations from.
midcoder.hook_point.name

'blocks.8.ln2.hook_normalized'

In [37]:
# Shape check: block 8's MLP output weight matrix.
model.blocks[8].mlp.W_out.shape

torch.Size([3072, 768])

In [32]:
# Run the model up to and including the transcoder's block, caching only the two
# activations used afterwards: the Midcoder's input hook and that block's MLP
# output. The MLP-output hook name is derived from `transcoder_layer` so it always
# refers to the same block as `stop_at_layer`, instead of a hard-coded block 8.
_, cache = midcoder.model.run_with_cache(
    batch,
    stop_at_layer=transcoder_layer + 1,
    names_filter=[
        midcoder.hook_point.name,
        f'blocks.{transcoder_layer}.hook_mlp_out',
    ],
)
inputs = cache[midcoder.hook_point.name]

In [34]:
# Cached MLP output of the transcoder's block. The key is built from
# `transcoder_layer` rather than a literal 8 so it tracks the layer under study.
outputs = cache[f'blocks.{transcoder_layer}.hook_mlp_out']
outputs.shape

torch.Size([4, 128, 768])

In [21]:
# Shape check: cached activations at the Midcoder's hook point.
inputs.shape

torch.Size([4, 128, 768])

In [23]:
# Apply the Midcoder's affine map (W_in, b_in) to the cached hook activations and
# display the result's shape.
# NOTE(review): this rebinds `outputs`, which another cell uses for the cached
# MLP-output activations; the two tensors have different last dimensions, so the
# meaning of `outputs` depends on cell execution order. Consider a distinct name.
outputs = inputs @ midcoder.W_in + midcoder.b_in
outputs.shape

torch.Size([4, 128, 3072])

In [None]:
# Run one batch of pre-tokenized rows through the model up to and including `layer`.
# The original slice used a variable `i` that this cell never defines; it only
# resolved through a leftover loop index from an earlier cell, which would have
# silently dropped the first rows. The batch is now taken from the start explicitly.
# NOTE(review): `owt_tokens_torch` must already exist in the kernel for this cell to run.
batch_size = 128
layer = 8
token_array = owt_tokens_torch[:batch_size * 1]
_, cache = model.run_with_cache(
    token_array[:batch_size],
    stop_at_layer=layer + 1,
    # TODO: list the hook names to cache; with an empty list presumably nothing
    # is cached (confirm against TransformerLens).
    names_filter=[],
)

In [54]:
# Shape check: first 256 rows of the pre-tokenized tensor.
# NOTE(review): `owt_tokens_torch` must already exist in the kernel for this cell to run.
owt_tokens_torch[:128*2].shape

torch.Size([256, 128])

In [40]:
# Display the model's module tree to look up the available hook points.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [57]:
# Full name of the `hook_pre` hook point inside the MLP of block `layer`.
model.blocks[layer].mlp.hook_pre.name

'blocks.8.mlp.hook_pre'

In [41]:
# List every attribute name on the model object.
# NOTE(review): exploratory dump with very long output; consider removing before sharing.
dir(model)

['OV',
 'QK',
 'T_destination',
 'W_E',
 'W_E_pos',
 'W_K',
 'W_O',
 'W_Q',
 'W_U',
 'W_V',
 'W_gate',
 'W_in',
 'W_out',
 'W_pos',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_di

In [42]:
# Display the module tree of the transformer blocks, including each block's hook points.
model.blocks

ModuleList(
  (0-11): 12 x TransformerBlock(
    (ln1): LayerNormPre(
      (hook_scale): HookPoint()
      (hook_normalized): HookPoint()
    )
    (ln2): LayerNormPre(
      (hook_scale): HookPoint()
      (hook_normalized): HookPoint()
    )
    (attn): Attention(
      (hook_k): HookPoint()
      (hook_q): HookPoint()
      (hook_v): HookPoint()
      (hook_z): HookPoint()
      (hook_attn_scores): HookPoint()
      (hook_pattern): HookPoint()
      (hook_result): HookPoint()
    )
    (mlp): MLP(
      (hook_pre): HookPoint()
      (hook_post): HookPoint()
    )
    (hook_attn_in): HookPoint()
    (hook_q_input): HookPoint()
    (hook_k_input): HookPoint()
    (hook_v_input): HookPoint()
    (hook_mlp_in): HookPoint()
    (hook_attn_out): HookPoint()
    (hook_mlp_out): HookPoint()
    (hook_resid_pre): HookPoint()
    (hook_resid_mid): HookPoint()
    (hook_resid_post): HookPoint()
  )
)

In [48]:
# Shape check: block 8's MLP input weight matrix.
model.blocks[8].mlp.W_in.shape

torch.Size([768, 3072])