In [1]:
from huggingface_hub import snapshot_download
snapshot_download(repo_id="pchlenski/gpt2-transcoders", allow_patterns=["*.pt"],
    local_dir="/workspace/transcoder_circuits/gpt-2-small-transcoders", local_dir_use_symlinks=False
)

  from .autonotebook import tqdm as notebook_tqdm
For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.
Fetching 26 files: 100%|██████████| 26/26 [00:00<00:00, 328.49it/s]


'/workspace/transcoder_circuits/gpt-2-small-transcoders'

In [2]:
import transformer_lens as tl
model = tl.HookedTransformer.from_pretrained("gpt2")
model.set_use_hook_mlp_in(True)



Loaded pretrained model gpt2 into HookedTransformer


In [3]:
# flake8: noqa
from transcoder_circuits.circuit_analysis import *
from transcoder_circuits.feature_dashboards import *
from transcoder_circuits.replacement_ctx import *
from sae_training.sparse_autoencoder import SparseAutoencoder

transcoder_template = "/workspace/transcoder_circuits/gpt-2-small-transcoders/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"
transcoders = []
for i in range(model.cfg.n_layers):
    transcoders.append(SparseAutoencoder.load_from_pretrained(f"{transcoder_template.format(i)}.pt").eval())

In [4]:
hook_point: tl.hook_points.HookPoint
for hook_point in model.hook_points():
    print(hook_point.name)

hook_embed
hook_pos_embed
blocks.0.ln1.hook_scale
blocks.0.ln1.hook_normalized
blocks.0.ln2.hook_scale
blocks.0.ln2.hook_normalized
blocks.0.attn.hook_k
blocks.0.attn.hook_q
blocks.0.attn.hook_v
blocks.0.attn.hook_z
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_result
blocks.0.mlp.hook_pre
blocks.0.mlp.hook_post
blocks.0.hook_attn_in
blocks.0.hook_q_input
blocks.0.hook_k_input
blocks.0.hook_v_input
blocks.0.hook_mlp_in
blocks.0.hook_attn_out
blocks.0.hook_mlp_out
blocks.0.hook_resid_pre
blocks.0.hook_resid_mid
blocks.0.hook_resid_post
blocks.1.ln1.hook_scale
blocks.1.ln1.hook_normalized
blocks.1.ln2.hook_scale
blocks.1.ln2.hook_normalized
blocks.1.attn.hook_k
blocks.1.attn.hook_q
blocks.1.attn.hook_v
blocks.1.attn.hook_z
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_pattern
blocks.1.attn.hook_result
blocks.1.mlp.hook_pre
blocks.1.mlp.hook_post
blocks.1.hook_attn_in
blocks.1.hook_q_input
blocks.1.hook_k_input
blocks.1.hook_v_input
blocks.1.hook_mlp_in


In [25]:
# 
text = "When John and Mary went to the shops, John gave the bag to Mary"

def metric(model, text): 
    loss, _ = model.run_with_cache(text, return_type="loss", loss_per_token=True)
    return loss[0,-1]

In [39]:
activations = {}

def get_layer(name):
    return int(name.split(".")[1])

def cache_mlp_in_hook(acts, hook):
    assert "hook_mlp_in" in hook.name, hook.name
    global activations
    activations[hook.name] = acts

def patch_mlp_out_hook(acts, hook, add_err):
    assert "hook_mlp_out" in hook.name, hook.name
    global activations
    inputs = activations[hook.name.replace("hook_mlp_out", "hook_mlp_in")]
    transcoder = transcoders[get_layer(hook.name)]
    
    transcoder_out = transcoder(inputs)[0]
    transcoder_err = (acts - transcoder_out.detach())
    acts_patch = transcoder_out
    if add_err: acts_patch += transcoder_err
    return acts_patch

In [43]:
ADD_ERR: bool = False
from functools import partial

model.reset_hooks()
for layer in [4]: #range(model.cfg.n_layers):
    
    model.add_hook(f"blocks.{layer}.hook_mlp_in", cache_mlp_in_hook, "fwd")
    model.add_hook(f"blocks.{layer}.hook_mlp_out", partial(patch_mlp_out_hook, add_err = ADD_ERR), "fwd")

print(metric(model, text))

model.reset_hooks()

print(metric(model, text))

tensor(10.3269, device='cuda:0', grad_fn=<SelectBackward0>)
tensor(0.3577, device='cuda:0', grad_fn=<SelectBackward0>)


Remarks:
- Loss increase w/ patching is really bad
- TODO: check implementation details, such as whether we normalise, subtract decoder bias, ... 
- Prolly ask the authors. 