In [1]:
import warnings
warnings.filterwarnings("ignore")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae_lens import SAE


In [2]:
def get_sae_features(text, model, tokenizer, sae, layer, device):
    """
    Run ``text`` through ``model`` and encode one decoder block's output with ``sae``.

    A forward hook on ``model.model.layers[layer]`` captures that block's output
    during a single no-grad forward pass. The captured tensor is then passed to
    ``sae.encode``.

    Args:
        text: Input string to tokenize.
        model: Causal LM exposing its decoder blocks as ``model.model.layers``.
        tokenizer: Tokenizer matching ``model``.
        sae: Object providing ``encode(tensor)``. It should have been trained on
            this same hook point (the block's output) — TODO confirm per SAE.
        layer: Index into ``model.model.layers`` of the block to hook.
        device: Device the tokenized inputs are moved to.

    Returns:
        dict with keys:
            "feature_activations": result of ``sae.encode`` on the captured tensor
                (computed without gradient tracking),
            "input_ids": tokenized ids as returned by the tokenizer,
            "input_tokens": token strings for the first sequence in the batch,
            "hidden_states": the captured (detached) block output.
    """
    activations = {}

    def hook_fn(module, args, output):
        # Decoder blocks may return a tuple (hidden_states, ...) or a bare
        # tensor depending on the transformers version.
        hidden_states = output[0] if isinstance(output, tuple) else output
        activations['hidden_states'] = hidden_states.detach()

    inputs = tokenizer(text, return_tensors="pt").to(device)

    handle = model.model.layers[layer].register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model(**inputs)
    finally:
        # Always unregister the hook, even if the forward pass raises.
        # Otherwise it keeps firing on every later forward pass of `model`, and
        # repeated calls would stack additional hooks on the same layer.
        handle.remove()

    hidden_states = activations['hidden_states']

    # The SAE's parameters may require grad; without no_grad the result would
    # carry an autograd graph (extra memory, and `.numpy()` would raise).
    with torch.no_grad():
        feature_activations = sae.encode(hidden_states)

    return {
        "feature_activations": feature_activations,
        "input_ids": inputs["input_ids"],
        "input_tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
        "hidden_states": hidden_states,
    }


In [3]:
model_id = "google/gemma-2-2b"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Passing output_hidden_states=True at load time makes forward calls return
# `outputs.hidden_states` by default (the inspection cell below reads it).
# transformers logs the "generation flags are not valid" warning for it.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    output_hidden_states=True,
)
model = model.to(device)
model.eval()

print(f"loaded model {model_id} on {device}")

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 34.28it/s]


loaded model google/gemma-2-2b on cuda


In [4]:
# `get_sae_features` hooks the output of `model.model.layers[layer]`, i.e. the
# full decoder block output (residual stream after the block). The SAE must
# therefore be trained on that hook point. The Gemma Scope "-res-" releases are;
# the "-mlp-" releases are trained on MLP outputs and would be fed
# out-of-distribution activations here.
# NOTE(review): if MLP features are what you want, keep an "-mlp-" release and
# hook the MLP output instead (presumably `post_feedforward_layernorm` in HF
# Gemma 2 — confirm against the SAE's hook name).
release_id = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_12/width_16k/canonical"

# load SAE
sae = SAE.from_pretrained(
    release=release_id,
    sae_id=sae_id,
    device=device,
)

print(f"loaded SAE {sae_id} on {device}")

loaded SAE layer_12/width_16k/canonical on cuda


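### Test run

Extract SAE features for a short example sentence, then inspect the model's tokenization, hidden states and next-token prediction.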

In [5]:
text = "We have a beagle dog named Pepper."

# The layer index must match the layer the SAE was trained on ("layer_12" in sae_id).
rt = get_sae_features(
    text=text,
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    layer=12,
    device=device,
)

In [6]:
# Shape of the captured layer output: (batch, seq_len, hidden_dim)
rt["hidden_states"].size()

torch.Size([1, 9, 2304])

In [7]:
# Re-run the model on the same text to inspect tokenization, per-layer hidden
# states and the greedy next-token prediction.
inputs = tokenizer(text, return_tensors="pt").to(device)

with torch.no_grad():
    # Request hidden states explicitly instead of relying on the flag baked
    # into the model config at load time, so this cell works regardless of how
    # the model was loaded.
    outputs = model(**inputs, output_hidden_states=True)

input_ids = inputs["input_ids"][0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)

print("TOKENS:")
for i, tok in enumerate(tokens):
    print(f"{i}: {tok}")

# hidden states: tuple with the embedding output first, then one per layer
hidden_states = outputs.hidden_states
print("\nnumber of layers (incl. embedding):", len(hidden_states))
print("hidden state shape per layer:", hidden_states[0].shape)

# last-layer hidden state (last token)
last_hidden = hidden_states[-1][0, -1]   # (hidden_dim,)
print("\nlast hidden vector shape:", last_hidden.shape)

# logits
logits = outputs.logits                  # (batch, seq_len, vocab)
next_token_logits = logits[0, -1]
# .item() hands decode a plain int instead of a 0-dim tensor
next_token_id = next_token_logits.argmax().item()

print("\npredicted next token:", tokenizer.decode(next_token_id))

TOKENS:
0: <bos>
1: We
2: ▁have
3: ▁a
4: ▁beagle
5: ▁dog
6: ▁named
7: ▁Pepper
8: .

number of layers (incl. embedding): 27
hidden state shape per layer: torch.Size([1, 9, 2304])

last hidden vector shape: torch.Size([2304])

predicted next token:  She
