# Taking a Jacobian SAE and using it to compute reconstruction metrics

In [39]:
import torch
import torch.nn.functional as F
import transformer_lens
from typing import Any
import sys
sys.path.append("../")
import sae_lens

model = transformer_lens.HookedTransformer.from_pretrained("pythia-70m-deduped")
sae = sae_lens.SAE.from_pretrained("pythia-70m-deduped-res-sm", "blocks.2.hook_resid_post")[0]
layer = 3

Loaded pretrained model pythia-70m-deduped into HookedTransformer


## Using two hooks

In [40]:
prompt = "Never gonna give you up, never gonna let you down"

original_logits, original_ce_loss = model(
    prompt, return_type="both", loss_per_token=True
)

def reconstr_hook(activations: torch.Tensor, hook: Any):
    original_device = activations.device
    activations = activations.to(sae.device)

    activations = sae.decode(sae.encode(activations)).to(activations.dtype)

    return activations.to(original_device)

reconstr_logits, reconstr_ce_loss = model.run_with_hooks(
    prompt,
    return_type="both",
    fwd_hooks=[(f"blocks.{layer}.hook_resid_pre", reconstr_hook)],
    loss_per_token=True,
)

double_reconstr_logits, double_reconstr_ce_loss = model.run_with_hooks(
    prompt,
    return_type="both",
    fwd_hooks=[(f"blocks.{layer}.hook_resid_pre", reconstr_hook),
               (f"blocks.{layer}.hook_resid_post", reconstr_hook)],
    loss_per_token=True,
)

In [41]:
original_ce_loss.mean(), reconstr_ce_loss.mean(), double_reconstr_ce_loss.mean()

(tensor(4.7383, device='mps:0', grad_fn=<MeanBackward0>),
 tensor(4.8325, device='mps:0', grad_fn=<MeanBackward0>),
 tensor(6.8677, device='mps:0', grad_fn=<MeanBackward0>))

In [42]:
original_logprobs = F.log_softmax(original_logits, dim=-1)
reconstr_logprobs = F.log_softmax(reconstr_logits, dim=-1)
double_reconstr_logprobs = F.log_softmax(double_reconstr_logits, dim=-1)

(F.kl_div(reconstr_logprobs, original_logprobs, reduction="batchmean", log_target=True),
 F.kl_div(double_reconstr_logprobs, original_logprobs, reduction="batchmean", log_target=True))

(tensor(4.8349, device='mps:0', grad_fn=<DivBackward0>),
 tensor(14.7052, device='mps:0', grad_fn=<DivBackward0>))

## Rewriting a cache run to a hooks run

In [43]:
_, cache = model.run_with_cache(
    prompt,
    prepend_bos=False,
    names_filter=[f"blocks.{layer}.hook_resid_pre", f"blocks.{layer}.hook_resid_post"],
    stop_at_layer=layer + 1,
)

resid_pre1 = cache[f"blocks.{layer}.hook_resid_pre"]
resid_pre1.shape

torch.Size([1, 11, 512])

In [44]:
original_device = resid_pre1.device
acts = resid_pre1.to(sae.device)

acts = sae.decode(sae.encode(acts)).to(acts.dtype)

acts = acts.to(original_device)

reid_post1 = model.blocks[layer](acts)
reid_post1.shape

torch.Size([1, 11, 512])

In [45]:
cache2 = {}
def reconstr_and_cache_hook(activations: torch.Tensor, hook: Any):
    cache2[hook.name] = activations
    original_device = activations.device
    activations = activations.to(sae.device)

    activations = sae.decode(sae.encode(activations)).to(activations.dtype)

    return activations.to(original_device)


model.run_with_hooks(
    prompt,
    fwd_hooks=[(f"blocks.{layer}.hook_resid_pre", reconstr_and_cache_hook),
               (f"blocks.{layer}.hook_resid_post", reconstr_and_cache_hook)],
    stop_at_layer=layer + 1,
    prepend_bos=False,
)
cache2[f"blocks.{layer}.hook_resid_pre"].shape

torch.Size([1, 11, 512])

In [46]:
(resid_pre1 == cache2[f"blocks.{layer}.hook_resid_pre"]).all()

tensor(True, device='mps:0')

In [47]:
(cache[f"blocks.{layer}.hook_resid_post"] == cache2[f"blocks.{layer}.hook_resid_post"]).all()

tensor(False, device='mps:0')

In [48]:
(reid_post1 == cache2[f"blocks.{layer}.hook_resid_post"]).all()

tensor(True, device='mps:0')