In [5]:
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
import torch
from functools import partial
from sae_training.sparse_autoencoder import SparseAutoencoder

In [21]:
model = HookedTransformer.from_pretrained("gpt2")

transcoder_template = "/media/workspace/gpt-2-small-transcoders/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"
i = 4
tcs_dict = {}
for i in range(12):
    tc = SparseAutoencoder.load_from_pretrained(f"{transcoder_template.format(i)}.pt").eval()
    tcs_dict[tc.cfg.hook_point] = tc




Loaded pretrained model gpt2 into HookedTransformer


In [22]:
def simple_hook():
    model.reset_hooks()
    fwd_hooks = []
    bwd_hooks = []
    cache_in = {}
    cache_out = {}
    grad_cache = {}
    def cache_hook_in(act, hook):
        cache_in[hook.name] = act
        return act
    def cache_hook_out(act, hook):
        key = ".".join(hook.name.split(".")[:-1])+".ln2.hook_normalized"
        x = cache_in[key]
        #x = x**2#just use the cache_in (no combination)
        #x = x**2+act# combination of in and out

        #use error term
        x,_,_,_,_,_ = tc(x)
        error_term = (act-x).detach().clone()

        cache_out[hook.name] = act+error_term
        return act + error_term
    def cache_bwd_hook(grad, hook):
        grad_cache[hook.name] = grad.detach()
        print(hook.name)
        return (grad,)

    
    for l in range(12):
        fwd_hooks.append((f"blocks.{l}.ln2.hook_normalized",cache_hook_in))
        fwd_hooks.append((f"blocks.{l}.hook_mlp_out",cache_hook_out))
        bwd_hooks.append((f"blocks.{l}.hook_mlp_out",cache_bwd_hook))

    with model.hooks(fwd_hooks=fwd_hooks,bwd_hooks=bwd_hooks,reset_hooks_end=False ):
        
        out = model("the name is")
        gradients = torch.ones_like(out)
        value = out**2
        out.backward(gradients)
        for key,val in grad_cache.items():
            print(f"{key} {val.sum()}")
    model.reset_hooks()



simple_hook()

blocks.11.hook_mlp_out
blocks.10.hook_mlp_out
blocks.9.hook_mlp_out
blocks.8.hook_mlp_out
blocks.7.hook_mlp_out
blocks.6.hook_mlp_out
blocks.5.hook_mlp_out
blocks.4.hook_mlp_out
blocks.3.hook_mlp_out
blocks.2.hook_mlp_out
blocks.1.hook_mlp_out
blocks.0.hook_mlp_out
blocks.11.hook_mlp_out 1.1641532182693481e-10
blocks.10.hook_mlp_out -5.820766091346741e-11
blocks.9.hook_mlp_out -1.1641532182693481e-10
blocks.8.hook_mlp_out -1.1641532182693481e-10
blocks.7.hook_mlp_out -2.3283064365386963e-10
blocks.6.hook_mlp_out -5.820766091346741e-11
blocks.5.hook_mlp_out 1.7462298274040222e-10
blocks.4.hook_mlp_out -1.1641532182693481e-10
blocks.3.hook_mlp_out 0.0
blocks.2.hook_mlp_out 2.9103830456733704e-10
blocks.1.hook_mlp_out 5.820766091346741e-11
blocks.0.hook_mlp_out -5.820766091346741e-11


In [12]:
tc.cfg.hook_point

'blocks.4.ln2.hook_normalized'

In [17]:
_,cache = model.run_with_cache("lkdhf")

In [18]:
cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re