In [43]:
# some notes from SAE tutorial on gemmascope https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp#scrollTo=12wF3f7o1Ni7
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False) # avoid blowing up mem

<torch.autograd.grad_mode.set_grad_enabled at 0x2b8c53860>

In [24]:
# Tokenizer and weights come from the same Hugging Face checkpoint.
checkpoint = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 14.95it/s]


In [78]:
# Unfinished Python snippet used as the prompt for the model to complete.
input_text = "print('hello "
# NOTE: despite its name, `input_ids` holds the tokenizer's whole output
# (later cells index it with ['input_ids'] and unpack it with **).
input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)

In [79]:
# Extend the prompt by at most two tokens, then decode the single returned
# sequence (prompt + continuation) back into text.
outputs = model.generate(max_new_tokens=2, **input_ids)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")



Generated text: print('hello 1')


In [80]:
# generate() only handed back token ids, so run a plain forward pass over the
# full generated sequence (kept as a batch of one) to obtain hidden states.
full_output = model(outputs[:1], output_hidden_states=True)

# Tuple of tensors, one entry per hidden-state checkpoint in the model.
hidden_states = full_output.hidden_states


In [81]:
# https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-65k
# Sanity-check sizes: prompt length, prompt + generated length, number of
# hidden-state tensors, and the shape of one of them.
# Note: the hidden-state count is one more than the number of decoder layers,
# so the "(layers)" label below over-counts by one.
print(f"input tokens {len(input_ids['input_ids'][0])}")
print(f"output tokens {len(outputs[0])}")
print(f"hidden states (layers) {len(hidden_states)}")
print(f"hidden state shape {hidden_states[20].shape}")

input tokens 5
output tokens 7
hidden states (layers) 27
hidden state shape torch.Size([1, 7, 2304])


In [82]:
# Dump the module tree to see where the decoder layers live
# (model.model.layers) before attaching hooks to them.
print(model)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [83]:
# Cross-check the architecture numbers against the module dump above.
print("Number of layers in config:", model.config.num_hidden_layers)
print("Hidden size:", model.config.hidden_size)

Number of layers in config: 26
Hidden size: 2304


In [84]:
def gather_residual_activations(model, target_layer, inputs):
  """
  Run one forward pass and return the output of a single decoder layer.

  A forward hook is attached to ``model.model.layers[target_layer]`` for the
  duration of the call and is always detached again, even when the forward
  pass raises.

  Args:
  - model: The model from which we want to gather activations. Its decoder
    blocks must be reachable as ``model.model.layers``.
  - target_layer: The specific (0-based) layer index for which we want to gather activations.
  - inputs: The input data to be passed positionally through the model.

  Returns:
  - target_act: The activations of the specified layer. If the layer returns
    a tuple this is its first element, otherwise the output itself. ``None``
    if the hook never fired.
  """
  target_act = None

  def gather_target_act_hook(mod, hook_inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    # Tuple-returning layers put the hidden states first. A module that returns
    # a bare tensor must not be indexed, or the batch dimension would be dropped.
    target_act = outputs[0] if isinstance(outputs, tuple) else outputs
    return outputs

  # we could also easily target the MLP layer
  # handle = model.model.layers[target_layer].mlp.register_forward_hook(gather_target_act_hook)
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  try:
    # Call the module itself (not .forward) so PyTorch's normal call machinery runs.
    _ = model(inputs)
  finally:
    # Never leave the hook installed: a leaked hook would keep firing on
    # every later forward pass of this layer.
    handle.remove()
  return target_act

In [85]:
# the 20th index is actually the 21st layer
# (layer indices are 0-based). Only the prompt token ids are passed here,
# not the generated continuation.
target_act = gather_residual_activations(model, 20, input_ids['input_ids'])

In [86]:
# Expect (batch, prompt tokens, hidden size).
target_act.shape

torch.Size([1, 5, 2304])

In [87]:
# Raw layer-20 activations, one row per prompt token.
target_act

tensor([[[  1.9440,   1.7632,  -2.0879,  ...,   1.6978,  -2.0868,  -0.0178],
         [-11.3058,   7.6265,  -4.9434,  ...,  -3.0914,  14.0880,  -1.6233],
         [ -2.3893,  12.6718,   0.4910,  ...,   4.2396,  -4.5822,  -4.4320],
         [ -9.5963,   1.2599,   1.1534,  ...,   1.8878,  -5.4552,  -2.6709],
         [ -6.1435,   2.1565,   4.8558,  ...,  -2.5533,   1.1790,   0.1900]]])

In [88]:
# Cross-check against the earlier output_hidden_states pass: entry 21 lines up
# with the hook on layer index 20 (hidden_states has one more entry than there
# are layers). The leading rows should match target_act; this tensor has extra
# rows because it also covers the generated tokens.
hidden_states[21]

tensor([[[  1.9440,   1.7632,  -2.0879,  ...,   1.6978,  -2.0868,  -0.0178],
         [-11.3058,   7.6265,  -4.9434,  ...,  -3.0914,  14.0880,  -1.6233],
         [ -2.3893,  12.6718,   0.4910,  ...,   4.2396,  -4.5822,  -4.4320],
         ...,
         [ -6.1435,   2.1565,   4.8558,  ...,  -2.5533,   1.1790,   0.1900],
         [ -6.3502,   4.8863,   7.2964,  ...,   0.3099,   5.6755,  -1.5147],
         [ -3.8539,   6.3743,   4.2873,  ...,   0.8676,   2.6345,   4.6958]]])

## SAE Time

In [38]:
from huggingface_hub import hf_hub_download, notebook_login

In [40]:
# Fetch the SAE parameter file from the Hugging Face Hub; the filename selects
# the layer_20 / width_16k / average_l0_71 variant. The returned value is a
# path that np.load can open.
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)

In [45]:
# The .npz archive holds NumPy arrays; convert each one to a torch tensor under
# the same key so the dict can be passed to load_state_dict.
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v) for k, v in params.items()}

In [46]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
  """Sparse autoencoder with a per-latent threshold gate on the encoder.

  Parameters (all ``nn.Parameter``):
  - W_enc: (d_model, d_sae) encoder weights.
  - W_dec: (d_sae, d_model) decoder weights.
  - threshold: (d_sae,) a latent is kept only where its pre-activation exceeds this.
  - b_enc: (d_sae,) encoder bias.
  - b_dec: (d_model,) decoder bias.
  """

  def __init__(self, d_model, d_sae):
    super().__init__()

    def zero_param(*shape):
      return nn.Parameter(torch.zeros(*shape))

    # Everything starts at zero because pre-trained weights are loaded
    # afterwards; this is not an initialisation meant for training from scratch.
    self.W_enc = zero_param(d_model, d_sae)
    self.W_dec = zero_param(d_sae, d_model)
    self.threshold = zero_param(d_sae)
    self.b_enc = zero_param(d_sae)
    self.b_dec = zero_param(d_model)

  def encode(self, input_acts):
    """Map model activations (..., d_model) to gated latents (..., d_sae)."""
    pre_acts = torch.matmul(input_acts, self.W_enc) + self.b_enc
    # Latents at or below their threshold are zeroed; survivors keep their ReLU value.
    above_threshold = pre_acts > self.threshold
    return torch.relu(pre_acts) * above_threshold

  def decode(self, acts):
    """Map latents (..., d_sae) back to model space (..., d_model)."""
    return torch.matmul(acts, self.W_dec) + self.b_dec

  def forward(self, acts):
    """Encode then decode, returning the reconstruction of ``acts``."""
    return self.decode(self.encode(acts))

In [48]:
# First dimension of W_enc: the input width the SAE expects, which should
# equal the model's hidden size.
params['W_enc'].shape[0]

2304

In [47]:
# W_enc is laid out as (d_model, d_sae), so its shape supplies both constructor arguments.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
# Copies the downloaded tensors into the module's parameters by matching key names.
sae.load_state_dict(pt_params)

<All keys matched successfully>

In [89]:
# Cast the captured activations to float32, encode them into SAE latents, then
# decode the latents again to get the SAE's reconstruction of target_act.
sae_acts = sae.encode(target_act.to(torch.float32))
reconstruction = sae.decode(sae_acts)

In [90]:
# Expect (batch, prompt tokens, d_sae).
sae_acts.shape

torch.Size([1, 5, 16384])

In [91]:
# The reconstruction should have the same shape as target_act.
reconstruction.shape

torch.Size([1, 5, 2304])

In [92]:
# Strongest SAE latent at each token position: max over the last (latent) axis
# gives both the activation value and the index of the winning latent.
values, inds = sae_acts.max(-1)

inds

tensor([[ 6631, 11527,  8278, 14956,  2045]])

In [93]:
# Activation strength of each position's top latent (pairs with `inds` from the previous cell).
values

tensor([[2028.7983,   91.8346,   75.0813,  114.2484,   83.7640]])

In [96]:
k = 5  # how many of the strongest latents to keep per token position
# NOTE: this rebinds `values`, previously set by the max() cell.
values, indices = sae_acts.topk(k, dim=-1)

print(
    f"Top {k} values shape: {values.shape}",
    f"Top {k} indices shape: {indices.shape}",
    sep="\n",
)

# Show the results for batch element 0, token position 0 only.
print(f"\nTop {k} values for first sequence item:", values[0, 0], sep="\n")
print(f"\nTop {k} indices for first sequence item:", indices[0, 0], sep="\n")

Top 5 values shape: torch.Size([1, 5, 5])
Top 5 indices shape: torch.Size([1, 5, 5])

Top 5 values for first sequence item:
tensor([2028.7983,  781.3959,  534.8594,  264.1917,  252.5279])

Top 5 indices for first sequence item:
tensor([ 6631,   743,  5052, 16057,  9479])


In [70]:
from IPython.display import IFrame
# Neuronpedia embed URL; the three slots are filled, in order, with the
# release name, the SAE id and the feature index.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Build the Neuronpedia embed URL for one SAE feature.

    Despite the name, the return value is a URL string (suitable for an
    IFrame ``src``), not HTML markup.

    Args:
        sae_release: First path segment of the URL.
        sae_id: Second path segment of the URL.
        feature_idx: Third path segment; index of the feature to show.

    Returns:
        ``html_template`` with the three path segments filled in.
    """
    path_segments = (sae_release, sae_id, feature_idx)
    return html_template.format(*path_segments)



In [95]:
# 14956 is the top-activating latent at token position 3 (see the `inds` output earlier).
# Despite its name, `html` holds the dashboard URL.
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=14956)
# Last expression of the cell, so the embedded dashboard renders as the output.
IFrame(html, width=1200, height=600)