In [142]:
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from typing import List, Tuple, Union, Callable, Dict
import torch
from functools import partial
from sae_training.sparse_autoencoder import SparseAutoencoder

In [143]:
model = HookedTransformer.from_pretrained("gpt2")

# Directory holding the per-layer GPT-2 small transcoder checkpoints.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
TRANSCODER_DIR = "/media/workspace/gpt-2-small-transcoders"
# Transcoders are loaded for blocks 0 .. N_TRANSCODER_LAYERS - 1.
N_TRANSCODER_LAYERS = 5

# `{}` is filled with the block index; ".pt" is appended when loading.
transcoder_template = TRANSCODER_DIR + "/final_sparse_autoencoder_gpt2-small_blocks.{}.ln2.hook_normalized_24576"

# Maps each transcoder's input hook point name -> the loaded transcoder (in eval mode).
tcs_dict = {}
for i in range(N_TRANSCODER_LAYERS):
    tc = SparseAutoencoder.load_from_pretrained(f"{transcoder_template.format(i)}.pt").eval()
    # Two checkpoints sharing a hook point would otherwise silently overwrite each other.
    assert tc.cfg.hook_point not in tcs_dict, f"duplicate transcoder for {tc.cfg.hook_point}"
    tcs_dict[tc.cfg.hook_point] = tc




Loaded pretrained model gpt2 into HookedTransformer


In [152]:
# Baseline run (no SAE hooks passed here): cache activations for a one-token prompt
# and print the sum of the residual stream entering block 3 as a quick fingerprint.
logits, cache = model.run_with_cache("The")
resid_pre_block3 = cache["blocks.3.hook_resid_pre"]
print(resid_pre_block3.sum())

tensor(0.0005)


In [145]:
def compose_hooks(*hooks):
    """
    Chain several hook functions into a single hook.

    Each function is called as ``fn(tensor, hook)`` in the order given and receives
    whatever the previous one returned; the last result is returned. With no hooks
    the input tensor is passed straight through.
    """
    def composed_hook(tensor: torch.Tensor, hook: HookPoint):
        result = tensor
        for step in hooks:
            result = step(result, hook)
        return result

    return composed_hook

def retain_grad_hook(tensor: torch.Tensor, hook: HookPoint):
    """
    Forward hook asking autograd to keep ``.grad`` on the activation at this hook point.

    Args:
        tensor: the activation flowing through the hook point.
        hook: the HookPoint being run (unused).

    Returns:
        The same tensor object, unmodified.
    """
    # Tensor.retain_grad() raises on tensors that don't require grad (e.g. a forward
    # pass under torch.no_grad()); skip it there so the hook is a harmless no-op.
    if tensor.requires_grad:
        tensor.retain_grad()
    return tensor

def detach_hook(tensor: torch.Tensor, hook: HookPoint):
    """
    Cut the autograd graph at this hook point.

    Returns ``tensor.detach()`` re-marked with ``requires_grad=True``, i.e. a new leaf
    tensor: gradients computed downstream of it stop here instead of flowing back
    into the computation that produced ``tensor``.
    """
    leaf = tensor.detach()
    leaf.requires_grad_(True)
    return leaf


In [154]:
# (hook point name or filter fn, hook fn) pairs, as passed to model.hooks(...) below.
HookSpec = Tuple[Union[str, Callable], Callable]
fwd_hooks: list[HookSpec] = []
bwd_hooks: list[HookSpec] = []


def sae_bwd_hook(grad: torch.Tensor, hook: HookPoint):
    """
    Debug backward hook: print the name of the hook point the gradient passes
    through and return the gradient unchanged, wrapped in a 1-tuple.
    """
    # NOTE(review): presumably the tuple form matches the grad_input structure the
    # underlying PyTorch backward hook expects — TODO confirm against HookPoint.add_hook.
    passthrough = (grad,)
    print(hook.name)
    return passthrough
# Straight-through outputs recorded by the hooks from get_fwd_hooks, keyed by output hook name.
cache_rec = {}
def get_fwd_hooks(sae: SparseAutoencoder, verbose: bool = True) -> list[Tuple[Union[str, Callable], Callable]]:
    """
    Build the forward-hook pair that splices a transcoder into the model.

    ``hook_in`` (on ``sae.cfg.hook_point``) stashes the transcoder input and passes the
    activation through untouched. ``hook_out`` (on ``sae.cfg.out_hook_point``) runs the
    transcoder on the stashed input and replaces the activation with
    ``reconstructed + (tensor - reconstructed).detach()``. Up to floating-point rounding,
    that value equals the original activation. However, gradients flow only through
    ``reconstructed`` (a straight-through substitution). The replacement is also stored
    in the module-level ``cache_rec`` under the output hook's name.

    Args:
        sae: transcoder whose ``cfg.hook_point`` / ``cfg.out_hook_point`` name the input
            and output hook points.
        verbose: if True (the default, matching earlier behaviour), print the input hook
            name every time ``hook_in`` fires.

    Returns:
        ``[(input_hook_name, hook_in), (output_hook_name, hook_out)]`` for ``model.hooks``.
    """
    x = None

    def hook_in(tensor: torch.Tensor, hook: HookPoint):
        nonlocal x
        if verbose:
            print(hook.name)
        x = tensor
        return tensor

    def hook_out(tensor: torch.Tensor, hook: HookPoint):
        nonlocal x
        assert x is not None, "hook_in must be called before hook_out."
        # The reconstruction is the first element of the transcoder's output tuple.
        reconstructed, *_ = sae(x)
        # Drop the stashed input so a later forward pass cannot reuse a stale activation.
        x = None
        # Build the straight-through tensor once; cache and return the same object.
        out = reconstructed + (tensor - reconstructed).detach()
        cache_rec[hook.name] = out
        return out

    return [(sae.cfg.hook_point, hook_in), (sae.cfg.out_hook_point, hook_out)]
# Collect the (input, output) forward-hook pair for every transcoder.
for sae in tcs_dict.values():
    hooks = get_fwd_hooks(sae)
    fwd_hooks.extend(hooks)

# Debug backward hook on every transcoder output hook point.
bwd_hooks = [(val.cfg.out_hook_point, sae_bwd_hook) for val in tcs_dict.values()]

model.reset_hooks()
# Attach each transcoder as a child module named "sae" of its input HookPoint. After
# model.setup() the transcoder's internal hook points become addressable by name,
# e.g. "blocks.2.ln2.hook_normalized.sae.hook_sae_out" in the cache below.
for name, sae in tcs_dict.items():
    # Public nn.Module API rather than writing to the private `_modules` dict.
    model.mod_dict[name].add_module("sae", sae)
model.setup()

with model.hooks(fwd_hooks, bwd_hooks):
    logits = model("The")
    print(logits.sum())
    logits, cache = model.run_with_cache("Hello world")

    print(cache["blocks.3.hook_resid_pre"].sum())
    # Clear .grad left over from earlier executions of this cell, so gradients reflect this run only.
    model.zero_grad(set_to_none=True)
    # NOTE(review): the recorded output of this cell shows no lines printed by sae_bwd_hook —
    # confirm the backward hooks actually fire with the straight-through replacement.
    logits.sum().backward()

blocks.0.ln2.hook_normalized
blocks.1.ln2.hook_normalized
blocks.2.ln2.hook_normalized
blocks.3.ln2.hook_normalized
blocks.4.ln2.hook_normalized
tensor(-0.1963, grad_fn=<SumBackward0>)
blocks.0.ln2.hook_normalized
blocks.1.ln2.hook_normalized
blocks.2.ln2.hook_normalized
blocks.3.ln2.hook_normalized
blocks.4.ln2.hook_normalized
tensor(0.0002)


In [156]:
# Print, in order, the sums of:
#   1. the cached block-2 MLP output,
#   2. the transcoder's own output hook inside the attached "sae" submodule,
#   3. the straight-through tensor recorded by hook_out in cache_rec.
# The recorded output shows 1 and 3 agreeing while 2 differs. The straight-through value
# numerically equals the original MLP output, whereas 2 is presumably the raw transcoder
# reconstruction.
for source, key in (
    (cache, "blocks.2.hook_mlp_out"),
    (cache, "blocks.2.ln2.hook_normalized.sae.hook_sae_out"),
    (cache_rec, "blocks.2.hook_mlp_out"),
):
    print(source[key].sum())

tensor(0.0005)
tensor(-2.1454)
tensor(0.0005, grad_fn=<SumBackward0>)
