In [1]:
from sae_lens.toolkit.sae_attrib import *

In [7]:
class MetricFunction(Protocol):
    def __call__(
        self, model: HookedTransformer, text: str
    ) -> tuple[Float[torch.Tensor, ""], ActivationCache]:
        raise NotImplementedError


def _next_token_loss(
    model: HookedTransformer, text: str
) -> tuple[Float[torch.Tensor, ""], ActivationCache]:
    """Compute the next token loss"""
    loss, cache = model.run_with_cache(text, return_type="loss")
    assert isinstance(loss, torch.Tensor)
    return loss, cache


def get_sae_features_and_errors_for_input(
    model: HookedTransformer,
    sae_dict: SparseAutoencoderDictionary,
    metric_fn: MetricFunction,
    x: str,
) -> tuple[
    dict[str, Float[torch.Tensor, "n_batch n_token d_sae"]],
    dict[str, Float[torch.Tensor, "n_batch n_token d_model"]],
]:
    # Run model on original input
    patchers = {name: SAEPatcher(sae) for name, sae in sae_dict}
    with model.hooks(
        fwd_hooks=[p.get_forward_hook() for p in patchers.values()],
        bwd_hooks=[p.get_backward_hook() for p in patchers.values()],
    ):
        metric, _ = metric_fn(model, x)
        metric.backward()

    # Collate the sae features and errors
    orig_features = {
        name: patcher.sae_feature_acts for name, patcher in patchers.items()
    }
    orig_errors = {name: patcher.sae_errors for name, patcher in patchers.items()}
    return orig_features, orig_errors


def compute_node_indirect_effect(
    model: HookedTransformer,
    sae_dict: SparseAutoencoderDictionary,
    metric_fn: MetricFunction,
    x_orig: str,
    x_patch: str | None = None,
) -> dict[str, Float[torch.Tensor, "n_batch n_token (d_sae + d_model)"]]:
    """Compute node indirect effects for a given input

    Here, nodes are the features and errors of the SAEs
    and the indirect effect is computed using first-order Taylor approximation.

    Returns:
    - A dict of {name: [batch, token, d_sae + d_model]} of node scores
        NOTE: the first d_sae elements are features, the rest are errors
    """
    orig_features, orig_errors = get_sae_features_and_errors_for_input(
        model, sae_dict, metric_fn, x_orig
    )

    # If no patch is provided, return zeros
    if x_patch is None:
        patch_features = {
            name: torch.zeros_like(act) for name, act in orig_features.items()
        }
        patch_errors = {
            name: torch.zeros_like(err) for name, err in orig_errors.items()
        }
    else:
        patch_features, patch_errors = get_sae_features_and_errors_for_input(
            model, sae_dict, metric_fn, x_patch
        )

    # Compute feature scores
    feature_scores = {}
    for name, orig_act in orig_features.items():
        grad_m_orig_act = orig_act.grad
        assert grad_m_orig_act is not None
        patch_act = patch_features[name]
        feature_scores[name] = indirect_effect_attrib(
            orig_act, patch_act, grad_m_orig_act
        )

    # Compute error scores
    error_scores = {}
    for name, orig_err in orig_errors.items():
        grad_m_orig_err = orig_err.grad
        assert grad_m_orig_err is not None
        patch_err = patch_errors[name]
        error_scores[name] = indirect_effect_attrib(
            orig_err, patch_err, grad_m_orig_err
        )

    # Combine feature and error scores
    node_scores = {}
    for name in sae_dict:
        node_scores[name] = torch.cat(
            [feature_scores[name], error_scores[name]], dim=-1
        )
    return node_scores


In [12]:
# Sanity checks

from sae_lens.training.sae_group import SparseAutoencoderDictionary
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gpt2").to(device)
autoencoders: dict[str, SparseAutoencoder] = get_gpt2_res_jb_saes()[0]
sae_dict = SparseAutoencoderDictionary(list(autoencoders.values())[0].cfg)
sae_dict.autoencoders = autoencoders
sae_dict.to(device)
prompt = "Hello world"

Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


100%|██████████| 13/13 [00:06<00:00,  1.94it/s]


In [14]:
node_ies = compute_node_indirect_effect(model, sae_dict, _next_token_loss, prompt)

KeyError: ('blocks.0.hook_resid_pre', SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
))

In [None]:
def compute_jacobian(
    upstream: Float[torch.Tensor, "n_batch n_token dim_u"],
    downstream: Float[torch.Tensor, "n_batch n_token dim_d"],
) -> Float[torch.Tensor, "n_batch n_token dim_u dim_d"]:
    """Compute the Jacobian of downstream w.r.t upstream"""
    raise NotImplementedError


def compute_edge_indirect_effect(
    model: HookedTransformer,
    sae_dict: SparseAutoencoderDictionary,
    metric_fn: Callable[[HookedTransformer, str], Float[torch.Tensor, ""]],
    x_orig: str,
    x_patch: str | None = None,
) -> tuple[
    dict[str, Float[torch.Tensor, "n_batch n_token d_sae"]],
    dict[str, Float[torch.Tensor, "n_batch n_token d_model"]],
]:
    pass

    return {}, {}