In this notebook, we sanity-check whether splicing a HookedSAE into the model changes the gradients that reach activations upstream of the SAE.



In [1]:
import transformer_lens as tl
from rich.table import Table
from rich import print as rprint
from circuit_finder.pretrained import (
    load_resid_saes,
)
from circuit_finder.core.hooked_transcoder import HookedTranscoderReplacementContext
from circuit_finder.utils import get_answer_tokens, logits_to_ave_logit_diff
from circuit_finder.constants import device, ProjectDir

In [2]:
# Load the pretrained 'gpt2' checkpoint as a HookedSAETransformer.
model = tl.HookedSAETransformer.from_pretrained("gpt2")

# Request the residual-stream SAE with id 8 and pull it out of the returned
# collection (presumably keyed by layer index -- TODO confirm).
resid_saes = load_resid_saes([8])
sae = resid_saes[8]

# Turn on the SAE's error term.
# NOTE(review): presumably this adds the reconstruction error back to the SAE
# output so the forward pass is unchanged -- confirm against the SAE class.
sae.cfg.use_error_term = True



Loaded pretrained model gpt2 into HookedTransformer


100%|██████████| 1/1 [00:01<00:00,  1.68s/it]


In [3]:
grad_cache_dict = {}
prompt = "Hello World"

# We hook an activation upstream of the SAE loaded above, so that we can
# compare the gradient arriving there with and without the SAE spliced in.
hook_point = 'blocks.4.hook_resid_pre'


def backward_cache_hook(act, hook):
    """Store the tensor seen by a backward hook under the hook's name.

    Args:
        act: Tensor passed to the hook. Because the hook is registered with
            direction 'bwd', this is the gradient at the hook point rather
            than the activation itself (per TransformerLens hook semantics).
        hook: The hook point that fired; `hook.name` is used as the cache key.
    """
    grad_cache_dict[hook.name] = act.detach()


def run_and_get_upstream_grad(saes=None):
    """Run one forward + backward pass on `prompt` and return (loss, grad).

    Args:
        saes: Optional list of SAEs to splice in through `model.saes(...)`
            for this run. `None` or an empty list runs the plain model.

    Returns:
        A `(loss, grad)` tuple: the detached loss and the gradient that
        `backward_cache_hook` recorded at `hook_point`.

    Raises:
        RuntimeError: If the backward hook never fired, i.e. no gradient was
            recorded for `hook_point` during this run.
    """
    # Drop any stale entry first: otherwise a hook that silently fails to fire
    # would hand back the gradient of a previous run.
    grad_cache_dict.pop(hook_point, None)

    model.reset_hooks()
    # Parameter .grad buffers would otherwise accumulate across the two runs.
    model.zero_grad()
    model.add_hook(hook_point, backward_cache_hook, 'bwd')

    def forward_backward():
        # Only the loss is needed, so no activation cache is built.
        loss = model(prompt, return_type='loss')
        loss.backward()
        return loss.detach()

    try:
        if saes:
            # Forward and backward both run while the SAEs are attached.
            with model.saes(saes):
                loss = forward_backward()
        else:
            loss = forward_backward()
    finally:
        # Do not leave the debugging hook attached for later cells.
        model.reset_hooks()

    if hook_point not in grad_cache_dict:
        raise RuntimeError(
            f"Backward hook at {hook_point!r} did not fire; no gradient was cached."
        )
    return loss, grad_cache_dict[hook_point]


# Baseline: run without any SAE.
orig_loss, orig_grad = run_and_get_upstream_grad()

# Same prompt with the SAE spliced in. If the SAE alters how gradients flow
# back through it, the gradient at the upstream hook point will differ from
# the baseline.
spliced_loss, spliced_grad = run_and_get_upstream_grad([sae])


In [4]:
# Show the upstream gradient from the SAE-spliced run first, then the
# baseline gradient, one tensor after the other.
print(spliced_grad, orig_grad, sep="\n")

tensor([[[ 1.2674e-04,  1.8588e-04, -1.8110e-04,  ..., -6.5199e-05,
           2.0408e-04,  5.2757e-04],
         [-3.1517e-04, -5.5403e-03, -1.4147e-02,  ..., -2.1921e-03,
          -3.9618e-03,  2.1405e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]], device='cuda:0')
tensor([[[ 0.0020,  0.0030,  0.0074,  ..., -0.0058,  0.0017, -0.0006],
         [ 0.0105, -0.0037, -0.0115,  ...,  0.0008, -0.0072,  0.0095],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0')
