In [1]:
import torch.nn as nn
import torch
import transformer_lens as tl
from transformer_lens.hook_points import HookPoint

from circuit_finder.core.hooked_sae import HookedSAE
from circuit_finder.core.hooked_sae_config import HookedSAEConfig
from typing import Union, Dict
from dataclasses import dataclass



In [2]:
from circuit_finder.pretrained import load_model
# Load the pretrained model through the project helper; the log line below
# reports that gpt2 was loaded into a HookedTransformer.
model = load_model()



Loaded pretrained model gpt2 into HookedTransformer


# Define HookedTranscoder

In [3]:
# type: ignore
# flake8: noqa
from __future__ import annotations

import pprint
import random
from dataclasses import dataclass
from typing import Any, Dict, Optional

import numpy as np
import torch

from transformer_lens import utils


@dataclass
class HookedTranscoderConfig:
    """
    Configuration class to store the configuration of a HookedTranscoder model.

    Args:
        d_sae (int): The size of the dictionary.
        d_in (int): The dimension of the input activations.
        d_out (int): The dimension of the output activations.
        hook_name (str): The hook name of the activation the transcoder reads
            its input from.
        hook_name_out (str): The hook name associated with the transcoder's
            output activations.
        use_error_term (bool): Flag read by code that wraps the transcoder to
            decide whether the reconstruction error is added back onto the
            transcoder output. Defaults to False.
        dtype (torch.dtype, *optional*): The transcoder's dtype. Defaults to torch.float32.
        seed (int, *optional*): The seed to use for the transcoder.
            Used to set sources of randomness (Python, PyTorch and
            NumPy) and to initialize weights. Defaults to None. We recommend
            setting a seed, so your experiments are reproducible.
        device (str, *optional*): The device to use for the transcoder. When
            left as None it is filled in from `utils.get_device()`.
    """

    d_sae: int
    d_in: int
    d_out: int
    hook_name: str
    hook_name_out: str
    use_error_term: bool = False
    dtype: torch.dtype = torch.float32
    seed: Optional[int] = None
    device: Optional[str] = None

    def __post_init__(self):
        """Seed the RNGs (if a seed was given) and resolve the default device."""
        if self.seed is not None:
            self.set_seed_everywhere(self.seed)

        if self.device is None:
            self.device = utils.get_device()

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "HookedTranscoderConfig":
        """
        Instantiates a `HookedTranscoderConfig` from a Python dictionary of
        parameters.
        """
        return cls(**config_dict)

    def to_dict(self):
        """Return the config fields as a new dict.

        A shallow copy is returned so that editing the result cannot silently
        change this config object.
        """
        return dict(self.__dict__)

    def __repr__(self):
        return "HookedTranscoderConfig:\n" + pprint.pformat(self.to_dict())

    def set_seed_everywhere(self, seed: int):
        """Seed PyTorch, Python's `random` module and NumPy with `seed`."""
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

In [4]:
import einops
import torch
import torch.nn.functional as F
from jaxtyping import Float
import transformer_lens as tl

from transformer_lens.hook_points import HookPoint, HookedRootModule


class HookedTranscoder(HookedRootModule):
    """Transcoder with a hook point on each intermediate activation.

    The forward pass computes::

        acts = relu((x - b_dec) @ W_enc + b_enc)
        out  = acts @ W_dec + b_dec

    A single bias `b_dec` is used both to centre the input and to offset the
    output, so the module only supports `d_in == d_out`; this is checked when
    the module is constructed.
    """

    def __init__(self, cfg: Union[HookedTranscoderConfig, Dict]):
        """Create randomly initialised weights and the hook points.

        Args:
            cfg: A `HookedTranscoderConfig`, or a dict of its constructor
                arguments.

        Raises:
            ValueError: If `cfg` is a string, or if `cfg.d_in != cfg.d_out`.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTranscoderConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTranscoderConfig object."
            )
        if cfg.d_in != cfg.d_out:
            # b_dec is subtracted from the input and added to the output, so
            # both must have the same width. Fail loudly instead of silently
            # ignoring d_out.
            raise ValueError(
                f"HookedTranscoder shares b_dec between its input and output, so "
                f"d_in ({cfg.d_in}) must equal d_out ({cfg.d_out})."
            )
        self.cfg = cfg

        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.cfg.d_in, self.cfg.d_sae, dtype=self.cfg.dtype)
            )
        )
        # The decoder maps into the output space, hence d_out rather than d_in.
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.cfg.d_sae, self.cfg.d_out, dtype=self.cfg.dtype)
            )
        )
        self.b_enc = nn.Parameter(torch.zeros(self.cfg.d_sae, dtype=self.cfg.dtype))
        self.b_dec = nn.Parameter(torch.zeros(self.cfg.d_out, dtype=self.cfg.dtype))

        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_recons = HookPoint()

        self.to(self.cfg.device)
        self.setup()

    def maybe_reshape_input(
        self,
        input: Float[torch.Tensor, "... d_input"],
        apply_hooks: bool = True,
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Reshape the input to have correct dim.
        No-op for inputs whose last dim is already d_in; inputs for a hook
        name ending in "_z" have their last two dims flattened into one.

        Args:
            input: The raw activation tensor.
            apply_hooks: Whether to pass the input through `hook_sae_input`.

        Returns:
            The (possibly flattened) tensor whose last dim equals `cfg.d_in`.
        """
        if apply_hooks:
            # Use the hook's return value, as the other hook points in this
            # class do, so the hook output is part of the computation.
            input = self.hook_sae_input(input)

        if input.shape[-1] == self.cfg.d_in:
            x = input
        else:
            # Assume this is an attention output (hook_z) input
            assert self.cfg.hook_name.endswith(
                "_z"
            ), f"You passed in an input shape {input.shape} does not match SAE input size {self.cfg.d_in} for hook_name {self.cfg.hook_name}. This is only supported for attn output (hook_z) SAEs."
            x = einops.rearrange(input, "... n_heads d_head -> ... (n_heads d_head)")
        assert (
            x.shape[-1] == self.cfg.d_in
        ), f"Input shape {x.shape} does not match SAE input size {self.cfg.d_in}"

        return x

    def encode(
        self,
        x: Float[torch.Tensor, "... d_in"],
        apply_hooks: bool = True,
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Transcoder encoder.

        Args:
            x: The input activations. Shape [..., d_in].
            apply_hooks: Whether to pass the pre- and post-ReLU activations
                through their hook points.

        Returns:
            The post-ReLU latent activations. Shape [..., d_sae].
        """
        # Subtract bias term
        x_cent = x - self.b_dec

        # Hidden layer pre-ReLU activation
        sae_acts_pre = (
            einops.einsum(x_cent, self.W_enc, "... d_in, d_in d_sae -> ... d_sae")
            + self.b_enc  # [..., d_sae]
        )
        if apply_hooks:
            sae_acts_pre = self.hook_sae_acts_pre(sae_acts_pre)

        # Hidden layer post-ReLU activation
        sae_acts_post = F.relu(sae_acts_pre)  # [..., d_sae]
        if apply_hooks:
            sae_acts_post = self.hook_sae_acts_post(sae_acts_post)

        return sae_acts_post

    def decode(
        self,
        sae_acts_post: Float[torch.Tensor, "... d_sae"],
        apply_hooks: bool = True,
    ) -> Float[torch.Tensor, "... d_out"]:
        """Transcoder decoder.

        Args:
            sae_acts_post: The post-ReLU latent activations. Shape [..., d_sae].
            apply_hooks: Whether to pass the result through `hook_sae_recons`.

        Returns:
            `sae_acts_post @ W_dec + b_dec`. Shape [..., d_out].
        """
        x_reconstruct = (
            einops.einsum(
                sae_acts_post, self.W_dec, "... d_sae, d_sae d_out -> ... d_out"
            )
            + self.b_dec
        )
        if apply_hooks:
            x_reconstruct = self.hook_sae_recons(x_reconstruct)
        return x_reconstruct

    def maybe_reshape_output(
        self,
        input: Float[torch.Tensor, "... d_input"],
        output: Float[torch.Tensor, "... d_out"],
    ):
        """Reshape `output` to the shape of the original `input`.

        This undoes the flattening done by `maybe_reshape_input`.
        """
        output = output.reshape(input.shape)
        return output

    def forward(
        self, input: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_out"]:
        """Transcoder forward pass.

        Args:
            input: The input activations. Shape [..., d_in].
                Also supports hook_z activations of shape [..., n_heads, d_head], where n_heads * d_head = d_in.

        Returns:
            output: The transcoder output, reshaped to the shape of `input`.
        """
        x = self.maybe_reshape_input(input)
        sae_acts_post = self.encode(x)
        x_reconstruct = self.decode(sae_acts_post)
        return self.maybe_reshape_output(input, x_reconstruct)

# Inspect Regular Transcoder

In [5]:
from circuit_finder.pretrained import load_mlp_transcoders

# Request only the layer-8 MLP transcoder and pick it out of the returned mapping.
transcoder = load_mlp_transcoders([8])[8]

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

In [6]:
# List the parameter names of the pretrained transcoder. Note that it carries
# two decoder-side biases: `b_dec` and `b_dec_out`.
state_dict = transcoder.state_dict()
print(list(state_dict.keys()))

['W_enc', 'b_enc', 'W_dec', 'b_dec', 'b_dec_out']


In [7]:
# Check the config flag that marks this model as a transcoder.
print(transcoder.cfg.is_transcoder)

True


In [8]:
# Inspect `b_dec`: shape first, then the values (these are clearly non-zero).
print(state_dict["b_dec"].shape)
print(state_dict["b_dec"])

torch.Size([768])


tensor([  1.1862,   1.5696,  -1.9906,  -1.2029,   2.2175,  -1.4098,  -1.4708,
         -1.6490,  -2.0197,  -2.0229,   0.9865,  -1.8019,  -1.6369,  -2.2256,
         -2.2842,   1.8605,   1.3689,  -1.9265,  -1.4797,   2.5357,  -1.8286,
          1.0884,  -1.5102,   1.1545,  -1.8980,  -2.0404,  -1.3951,  -1.3454,
          1.4757,   1.9774,  -1.9712,  -1.5186,   1.4978,  -1.9888,   0.9793,
          1.5924,  -1.5015,  -1.0039,   1.7328,  -1.4052,   2.0603,   2.2594,
          1.2320,  -2.3857,   1.2999,   1.3727,   1.0996,  -2.1323,   2.1212,
         -1.1185,   1.8134,  -1.0758,  -1.4854,   1.8169,  -1.4866,   1.4962,
          1.5304,  -0.9501,  -1.5599,  -1.7975,   1.6399,   1.6877,   1.8173,
         -1.1781,  -4.6810,  -1.5628,   2.0368,   1.5156,   1.3177,   1.4963,
          1.7548,   0.5594,  -0.7755,   1.2884,  -1.6782,   1.1857,   2.1568,
         -2.1466,   1.2580,   1.8623,   2.0687,  -1.7108,  -1.6401,   1.3888,
         -2.7691,   2.1310,  -1.8680,  -3.3717,   2.2827,   1.87

In [9]:
# Inspect `b_dec_out`: same shape as `b_dec`, but different values.
print(state_dict['b_dec_out'].shape)
print(state_dict['b_dec_out'])

torch.Size([768])
tensor([-1.2672e-02,  2.3626e-02, -1.3722e-01, -1.0775e-01,  8.1926e-02,
         6.4808e-02,  2.6040e-01, -1.4005e-01, -4.7093e-02,  4.6066e-02,
         1.0110e-01, -6.2927e-02, -1.5965e-01, -5.0537e-02, -2.8574e-02,
         4.5400e-02, -1.0613e-01, -1.2895e-01,  5.0888e-02,  2.0695e-02,
        -5.6016e-02, -3.0756e-02,  7.1454e-02, -1.1798e-01, -2.7476e-02,
         8.0708e-03,  5.1458e-03, -1.0945e-01,  2.1156e-02, -6.4942e-02,
         1.2061e-01,  5.8028e-02,  6.7034e-02,  6.0527e-02,  8.2463e-02,
        -9.5162e-02,  4.5119e-01,  7.9933e-02, -6.0591e-02,  8.0231e-02,
        -3.4440e-03,  8.6729e-02,  1.1527e-01, -5.9831e-02,  1.5111e-02,
         1.6710e-01, -7.2933e-02,  1.1341e-01,  3.1571e-02,  4.7865e-02,
        -1.0125e-01,  7.3603e-02,  2.8579e-02,  1.0483e-01,  9.6484e-02,
         1.1278e-01,  1.8236e-01, -5.9445e-02,  3.0726e-02,  7.1907e-03,
         2.6851e-02, -1.2508e-02,  1.1155e-01,  1.3431e-01, -6.2534e-02,
        -8.6121e-02,  2.9323e-01,

# Test HookedTranscoder

In [10]:
# Layer under test; reload that layer's pretrained transcoder.
layer = 8
transcoder = load_mlp_transcoders([layer])[layer]

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

In [11]:
# Convert a transcoder to a HookedTranscoder

import torch

from transcoders_slim.sae_training.config import LanguageModelSAERunnerConfig
from transcoders_slim.sae_training.sparse_autoencoder import SparseAutoencoder


def sl_sae_cfg_to_tl_sae_cfg(
    resid_sae_cfg: LanguageModelSAERunnerConfig,
) -> HookedTranscoderConfig:
    """Translate a training-run config into a `HookedTranscoderConfig`.

    Only the dictionary size, the input/output widths and the two hook names
    are carried over; every other `HookedTranscoderConfig` field keeps its
    default value.
    """
    return HookedTranscoderConfig.from_dict(
        dict(
            d_sae=resid_sae_cfg.d_sae,
            d_in=resid_sae_cfg.d_in,
            d_out=resid_sae_cfg.d_out,
            hook_name=resid_sae_cfg.hook_point,
            hook_name_out=resid_sae_cfg.out_hook_point,
        )
    )


def sl_sae_to_tl_sae(
    sl_sae: SparseAutoencoder,
) -> HookedTranscoder:
    """Convert a pretrained transcoder into an equivalent `HookedTranscoder`.

    The source model has two biases, `b_dec` and `b_dec_out`, while
    `HookedTranscoder` has a single `b_dec` that it both subtracts from the
    input and adds to the output. `b_dec_out` is loaded into that slot so the
    output offset is right, and the resulting mismatch on the input side is
    folded into the encoder bias:

        (x - b_in) @ W_enc + b_enc
            == (x - b_out) @ W_enc + (b_enc + (b_out - b_in) @ W_enc)

    NOTE(review): this assumes the source forward pass centres its input with
    `b_dec` and offsets its output with `b_dec_out` - confirm against the
    source model's `forward`. It also requires d_in == d_out.

    Args:
        sl_sae: The pretrained transcoder to convert.

    Returns:
        A `HookedTranscoder` loaded with the converted weights.
    """
    # state_dict() returns a fresh mapping, so editing it leaves sl_sae untouched.
    state_dict = dict(sl_sae.state_dict())
    b_dec_in = state_dict.pop('b_dec')
    b_dec_out = state_dict.pop('b_dec_out')

    # Compensate for HookedTranscoder.encode subtracting the output bias
    # instead of the input bias.
    state_dict['b_enc'] = state_dict['b_enc'] + (b_dec_out - b_dec_in) @ state_dict['W_enc']
    state_dict['b_dec'] = b_dec_out

    cfg = sl_sae_cfg_to_tl_sae_cfg(sl_sae.cfg)
    tl_sae = HookedTranscoder(cfg)
    tl_sae.load_state_dict(state_dict)
    return tl_sae

In [12]:
# Convert the layer-8 pretrained transcoder into a HookedTranscoder.
hooked_transcoder = sl_sae_to_tl_sae(transcoder)

In [27]:
class HookedTranscoderWrapper(HookedRootModule):
    """Wrapper around transcoder and the MLP it replaces.

    Calling the wrapper returns the transcoder output. When
    `cfg.use_error_term` is set, the difference between the wrapped MLP's
    output and the hook-free transcoder output is added back on.
    """

    def __init__(
        self,
        transcoder: HookedTranscoder,
        mlp: nn.Module,
    ):
        """Store both modules and create the wrapper's own hook points.

        Args:
            transcoder: The transcoder that stands in for the MLP.
            mlp: The MLP being replaced. It is moved (in place) to the
                transcoder's device.
        """
        super().__init__()
        self.transcoder = transcoder
        self.mlp = mlp

        self.hook_sae_error = HookPoint()
        self.hook_sae_output = HookPoint()
        self.mlp.to(transcoder.cfg.device)
        self.setup()

    @property
    def cfg(self):
        """The wrapped transcoder's config."""
        return self.transcoder.cfg

    def forward(self, x):
        """Run the transcoder on `x`, optionally adding the error term.

        Args:
            x: The MLP input activations.

        Returns:
            The transcoder output (plus the error term when enabled), passed
            through `hook_sae_output`.
        """
        sae_output = self.transcoder(x)
        if self.cfg.use_error_term:
            with torch.no_grad():
                # The clean pass must bypass the transcoder's hook points.
                # Otherwise an intervention on a transcoder hook would also
                # change the error term and cancel itself out in the sum below.
                # The wrapped MLP's own hooks still fire here.
                clean_in = self.transcoder.maybe_reshape_input(x, apply_hooks=False)
                clean_acts = self.transcoder.encode(clean_in, apply_hooks=False)
                clean_recons = self.transcoder.decode(clean_acts, apply_hooks=False)
                clean_sae_out = self.transcoder.maybe_reshape_output(x, clean_recons)
                clean_mlp_out = self.mlp(x)
                sae_error = clean_mlp_out - clean_sae_out
            sae_error.requires_grad = True
            sae_error.retain_grad()
            sae_error = self.hook_sae_error(sae_error)
            # Out-of-place add: an in-place `+=` would overwrite the tensor
            # that the transcoder's hook_sae_recons has already exposed.
            sae_output = sae_output + sae_error
        return self.hook_sae_output(sae_output)

In [14]:
# Build a wrapper around the layer's MLP and display its module tree.
# Note: the constructor moves the model's MLP to the transcoder's device.
HookedTranscoderWrapper(
    hooked_transcoder, model.blocks[layer].mlp
)

HookedTranscoderWrapper(
  (transcoder): HookedTranscoder(
    (hook_sae_input): HookPoint()
    (hook_sae_acts_pre): HookPoint()
    (hook_sae_acts_post): HookPoint()
    (hook_sae_recons): HookPoint()
  )
  (mlp): MLP(
    (hook_pre): HookPoint()
    (hook_post): HookPoint()
  )
  (hook_sae_error): HookPoint()
  (hook_sae_output): HookPoint()
)

In [31]:
# --- context manager for replacing MLP sublayers with transcoders ---
import torch
import torch.nn as nn
import transformer_lens as tl
from typing import Sequence
from transcoders_slim.transcoder import Transcoder
from circuit_finder.core.types import LayerIndex
from transformer_lens.HookedSAETransformer import set_deep_attr

# Type alias used in annotations for the MLP sublayer of a transformer block.
MLP = nn.Module

def get_layer_of_hook_name(hook_point):
    """Return the layer index embedded in a hook name.

    The name is split on "." and its second component is parsed as an
    integer, e.g. "blocks.8.hook_mlp_out" gives 8.
    """
    name_parts = hook_point.split('.')
    layer_component = name_parts[1]
    return int(layer_component)

class TranscoderReplacementContext:
    """Context manager to replace MLP sublayers with transcoders.

    On entry, each targeted block's MLP is swapped for a
    `HookedTranscoderWrapper` around the transcoder and the original MLP.
    On exit the original MLPs are put back. The model's hook bookkeeping is
    rebuilt after both steps.
    """

    model: tl.HookedTransformer
    transcoders: Sequence[HookedTranscoder]
    layers: Sequence[LayerIndex]
    original_mlps: Sequence[MLP]

    def __init__(self, model: tl.HookedTransformer, transcoders: Sequence[HookedTranscoder]):
        """Record which layers to patch and keep references to their MLPs.

        Args:
            model: The model whose MLP sublayers will be replaced.
            transcoders: One transcoder per layer to replace; the layer is
                read from each transcoder's `cfg.hook_name`.
        """
        self.layers = [get_layer_of_hook_name(t.cfg.hook_name) for t in transcoders]
        self.original_mlps = [model.blocks[layer].mlp for layer in self.layers]
        self.transcoders = transcoders
        self.model = model

    def __enter__(self):
        """Install the wrappers and return this context object."""
        for layer, transcoder in zip(self.layers, self.transcoders):
            mlp = self.model.blocks[layer].mlp
            self.model.blocks[layer].mlp = HookedTranscoderWrapper(
                transcoder, mlp
            )
        # Adds the new hooks to the model; one rebuild after all swaps is enough.
        self.model.setup()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Restore the original MLPs; exceptions from the body propagate."""
        for layer, mlp in zip(self.layers, self.original_mlps):
            self.model.blocks[layer].mlp = mlp
        # Rebuild the hook bookkeeping so wrapper hooks are dropped and the
        # restored MLP hooks get their original names back.
        self.model.setup()

# End-to-end test

In [23]:
# Load the default set of MLP transcoders, convert each to a HookedTranscoder,
# and enable the error term on all of them.
# NOTE(review): the loop variable rebinds the notebook-level name `transcoder`
# that an earlier cell assigned - rename it if that earlier value is needed later.
transcoders = load_mlp_transcoders()
hooked_transcoders = {k: sl_sae_to_tl_sae(v) for k, v in transcoders.items()}
for transcoder in hooked_transcoders.values():
    transcoder.cfg.use_error_term = True

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

In [29]:
import transformer_lens as tl

# Reload GPT-2 as a HookedSAETransformer (this rebinds `model`).
# Use the GPU when one is available instead of assuming CUDA is present.
model = tl.HookedSAETransformer.from_pretrained(
    "gpt2",
    device="cuda" if torch.cuda.is_available() else "cpu",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)



Loaded pretrained model gpt2 into HookedTransformer


In [18]:
# Baseline: loss of the unmodified model on a single prompt.
text = "When John and Mary went to the shops, John gave a bottle to Mary"

# The cache returned alongside the loss is not needed here and is discarded.
orig_loss, _ = model.run_with_cache(text, return_type = "loss")
print(orig_loss)



tensor(3.7746, device='cuda:0', grad_fn=<NegBackward0>)


In [32]:
# Same prompt with the MLPs replaced by transcoder wrappers. Because the error
# term is enabled on every transcoder, the printed loss should match the
# baseline printed above.
with TranscoderReplacementContext(
    model, list(hooked_transcoders.values())
):
    spliced_loss, cache = model.run_with_cache(text, return_type = "loss")
    print(spliced_loss)

tensor(3.7746, device='cuda:0', grad_fn=<NegBackward0>)


In [33]:
# List every hook name recorded in the cache from the spliced run; the
# transcoder and wrapper hooks appear under `blocks.<layer>.mlp.*`.
for key in cache.keys():
    print(key)

hook_embed
hook_pos_embed
blocks.0.hook_resid_pre
blocks.0.ln1.hook_scale
blocks.0.ln1.hook_normalized
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_z
blocks.0.hook_attn_out
blocks.0.hook_resid_mid
blocks.0.ln2.hook_scale
blocks.0.ln2.hook_normalized
blocks.0.mlp.transcoder.hook_sae_input
blocks.0.mlp.transcoder.hook_sae_acts_pre
blocks.0.mlp.transcoder.hook_sae_acts_post
blocks.0.mlp.transcoder.hook_sae_recons
blocks.0.mlp.mlp.hook_pre
blocks.0.mlp.mlp.hook_post
blocks.0.mlp.hook_sae_error
blocks.0.mlp.hook_sae_output
blocks.0.hook_mlp_out
blocks.0.hook_resid_post
blocks.1.hook_resid_pre
blocks.1.ln1.hook_scale
blocks.1.ln1.hook_normalized
blocks.1.attn.hook_q
blocks.1.attn.hook_k
blocks.1.attn.hook_v
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_pattern
blocks.1.attn.hook_z
blocks.1.hook_attn_out
blocks.1.hook_resid_mid
blocks.1.ln2.hook_scale
blocks.1.ln2.hook_normalized
blocks.1.mlp.tr