In [44]:
# Notes from the Gemma Scope SAE tutorial: https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp#scrollTo=12wF3f7o1Ni7
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

# Inference only: switch autograd off globally so forward passes do not record a graph.
torch.set_grad_enabled(False) # avoid blowing up mem

<torch.autograd.grad_mode.set_grad_enabled at 0x131be5eb0>

In [45]:
# Single source of truth for the checkpoint id so tokenizer and model cannot drift apart.
MODEL_NAME = "google/gemma-2-2b"

# Load the tokenizer and the causal-LM weights for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  5.48it/s]


In [46]:
# we want to study what's happening in the model when we run some input text through it
input_text = "hello, Yoda my name is"
# the first step is to tokenize the input text
# NOTE: despite its name, `input_ids` holds the tokenizer's whole return value; the token-id
# tensor itself is read later as input_ids['input_ids'].
input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)

In [47]:
# Continue the prompt by two tokens; the tokenizer output is unpacked into keyword arguments.
outputs = model.generate(**input_ids, max_new_tokens=2)


In [48]:
# Decode the first generated sequence back to text, dropping special tokens.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)

Generated text: hello, Yoda my name is Yoda and


In [49]:
# To get hidden states, we need to run the model again with the generated sequence
# (outputs[0] is a single sequence, so unsqueeze(0) adds the leading batch dimension back).
full_output = model(outputs[0].unsqueeze(0), output_hidden_states=True)

# Access hidden states (a per-layer collection; individual entries are indexed below)
hidden_states = full_output.hidden_states

In [50]:
# https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-65k
# Sanity-check sizes: prompt length, prompt-plus-generated length, number of hidden-state
# entries, and the shape of one hidden-state tensor.
print("input tokens", len(input_ids['input_ids'][0]))
print("output tokens", len(outputs[0]))
print("hidden states (layers)", len(hidden_states))
print("hidden state shape", hidden_states[20].shape)

input tokens 7
output tokens 9
hidden states (layers) 27
hidden state shape torch.Size([1, 9, 2304])


In [51]:
# Print the module tree to see how the decoder layers and their sub-modules are named
# (the hooks below address layers through model.model.layers[...]).
print(model)


Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [52]:
# Cross-check the layer count and hidden width against the model config.
print(f"Number of layers in config: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")

Number of layers in config: 26
Hidden size: 2304


In [53]:
def gather_residual_activations(model, target_layer, inputs):
  """
  Run one forward pass and return the output of a single decoder layer.

  A forward hook is attached to ``model.model.layers[target_layer]`` for the
  duration of the call. The hook is always removed afterwards, even when the
  forward pass raises, so a failed call cannot leave a stale hook on the model.

  Args:
  - model: The model from which we want to gather activations. Its layers are
    looked up as ``model.model.layers``.
  - target_layer: The specific layer index for which we want to gather activations.
  - inputs: The input data to be passed through the model (as the single
    positional argument).

  Returns:
  - target_act: The first element of the hooked layer's output, or None if the
    layer was never invoked during the forward pass.
  """
  target_act = None

  # The hook's own arguments are named layer_* so they do not shadow the outer `inputs`.
  def gather_target_act_hook(mod, layer_inputs, layer_outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = layer_outputs[0]
    return layer_outputs

  # we could also target a sub-module such as the MLP by registering on
  # model.model.layers[target_layer].mlp instead (its output format may differ,
  # so it would need its own hook function).
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  try:
    # Call the module rather than .forward() so that any hooks registered on the
    # model itself also run, as they would in normal use.
    _ = model(inputs)
  finally:
    # Always detach the hook, otherwise later forward passes would keep capturing.
    handle.remove()
  return target_act

In [54]:
# Layer indices are 0-based, so index 20 is the 21st decoder layer.
# Its output is compared with hidden_states[21] a few cells below.
target_act = gather_residual_activations(model, 20, input_ids['input_ids'])

In [55]:
# Shape of the captured layer output for the prompt.
target_act.shape

torch.Size([1, 7, 2304])

In [56]:
# Show the hooked activations so they can be compared with hidden_states[21] in the next cell.
target_act


tensor([[[ 1.9440,  1.7632, -2.0879,  ...,  1.6978, -2.0868, -0.0178],
         [-7.1130,  2.3935, -0.3611,  ..., -0.0207,  4.9698,  2.0012],
         [-8.1484,  0.2589, -0.5944,  ..., -1.2177,  2.7641,  2.0122],
         ...,
         [ 1.5289, -3.7196,  7.7939,  ..., -9.8385,  0.1159,  0.6954],
         [-3.1761,  1.1008,  0.5456,  ..., -1.7524, -2.0917,  2.4893],
         [ 0.9512, -2.4978, -0.4893,  ..., -4.7208, -5.6637, -0.7142]]])

In [57]:
# Entry 21 of the output_hidden_states run, for comparison with the hooked layer-20 output.
# That run was fed the longer generated sequence (outputs[0]) while target_act came from the
# prompt only, so the two tensors have different lengths and are compared on the positions they share.
hidden_states[21]


tensor([[[ 1.9440,  1.7632, -2.0879,  ...,  1.6978, -2.0868, -0.0178],
         [-7.1130,  2.3935, -0.3611,  ..., -0.0207,  4.9698,  2.0012],
         [-8.1484,  0.2589, -0.5944,  ..., -1.2177,  2.7640,  2.0122],
         ...,
         [ 0.9512, -2.4978, -0.4893,  ..., -4.7208, -5.6637, -0.7142],
         [-4.1288,  3.6604,  4.8044,  ..., -5.9026,  3.5426,  0.6035],
         [-1.2318, -1.7791, -1.2568,  ..., -0.5708,  7.3362,  1.7607]]])

## SAE time: load a pre-trained Gemma Scope SAE and inspect its features


In [58]:
from huggingface_hub import hf_hub_download, notebook_login


In [59]:
# Fetch the layer_20 / width_16k parameter file from the Hub repo and keep its local path.
# force_download=False: do not force a fresh download on every run.
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)

In [60]:
# Load the .npz archive and wrap every array as a torch tensor under the same key,
# so the dict can be handed to load_state_dict later on.
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v) for k, v in params.items()}

In [61]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
  """Sparse autoencoder whose encoder output is gated by a per-feature threshold.

  All parameters start as zeros; they are placeholders meant to be overwritten
  with pre-trained weights (for example through ``load_state_dict``).
  """

  def __init__(self, d_model, d_sae):
    """Allocate zero-filled parameters.

    Args:
    - d_model: width of the activations fed to the encoder.
    - d_sae: number of SAE features.
    """
    super().__init__()
    # Zeros are fine here only because pre-trained weights get loaded afterwards;
    # training an SAE from scratch would need a real initialisation scheme.
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    """Project activations to feature space and zero every feature at or below its threshold."""
    projected = torch.matmul(input_acts, self.W_enc) + self.b_enc
    above_threshold = projected > self.threshold
    # The boolean gate multiplies the rectified pre-activations.
    return above_threshold * torch.relu(projected)

  def decode(self, acts):
    """Map feature activations back to the model's activation space."""
    return torch.matmul(acts, self.W_dec) + self.b_dec

  def forward(self, acts):
    """Encode then decode, returning the reconstruction of ``acts``."""
    return self.decode(self.encode(acts))

In [62]:
# First axis of W_enc; used as d_model when the SAE is constructed below.
params['W_enc'].shape[0]


2304

In [63]:
# The last dimension here has to equal W_enc's first dimension for the encoder matmul to work.
target_act.shape

torch.Size([1, 7, 2304])

In [64]:
# Size the SAE from the checkpoint itself (W_enc's two dimensions are d_model and d_sae),
# then overwrite the zero-initialised parameters with the downloaded weights.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)

<All keys matched successfully>

In [65]:

# Encode the hooked layer-20 activations into SAE features (cast to float32 first),
# then decode them back to get the SAE's reconstruction of those activations.
sae_acts = sae.encode(target_act.to(torch.float32))
reconstruction = sae.decode(sae_acts)


In [66]:
# Feature activations: the last dimension is the number of SAE features.
sae_acts.shape


torch.Size([1, 7, 16384])

In [67]:
# The reconstruction should have the same shape as target_act.
reconstruction.shape


torch.Size([1, 7, 2304])

In [87]:
# Eyeball the reconstruction against target_act (displayed again two cells below).
reconstruction

tensor([[[ 7.6748e+01,  2.3060e+02, -1.3471e+02,  ...,  4.2310e+02,
          -5.1567e+01, -1.9870e+01],
         [-2.6197e+00,  7.4656e+00, -9.5365e-01,  ...,  2.9511e-01,
           5.8423e+00, -1.1159e+00],
         [-6.9128e+00,  2.1571e+00, -8.5115e-02,  ..., -2.4720e-01,
           1.4987e+00, -1.6939e+00],
         ...,
         [ 1.1387e+00, -5.9275e+00,  5.5095e+00,  ..., -4.9588e+00,
          -1.1925e-01,  3.7750e+00],
         [-5.8372e-01, -2.8277e+00,  1.3489e+00,  ..., -4.8409e+00,
           1.5609e+00,  5.4106e+00],
         [-1.3586e+00, -2.0193e+00, -1.3169e+00,  ..., -1.5933e+00,
          -8.7615e-01,  1.0083e+00]]])

In [68]:
# For every position, take the largest SAE activation and the index of the feature producing it.
values, inds = sae_acts.max(-1)

# Display the winning feature index per position.
inds

tensor([[ 6631, 14956, 10299, 15449, 11302,  8564, 15449]])

In [88]:
# Original activations again, for side-by-side comparison with `reconstruction`.
target_act

tensor([[[ 1.9440,  1.7632, -2.0879,  ...,  1.6978, -2.0868, -0.0178],
         [-7.1130,  2.3935, -0.3611,  ..., -0.0207,  4.9698,  2.0012],
         [-8.1484,  0.2589, -0.5944,  ..., -1.2177,  2.7641,  2.0122],
         ...,
         [ 1.5289, -3.7196,  7.7939,  ..., -9.8385,  0.1159,  0.6954],
         [-3.1761,  1.1008,  0.5456,  ..., -1.7524, -2.0917,  2.4893],
         [ 0.9512, -2.4978, -0.4893,  ..., -4.7208, -5.6637, -0.7142]]])

In [69]:
# Largest feature activation at each position (from the sae_acts.max(-1) cell).
values


tensor([[2028.7983,  122.3900,  107.6448,   90.2571,   89.6321,  102.8196,
           42.7972]])

In [70]:
# Same idea as the per-position max, but keep the k strongest features per position.
# NOTE: this rebinds `values`, which previously held the per-position maxima.
k = 5  # Change this to the number of top values you want
values, indices = torch.topk(sae_acts, k, dim=-1)

print(f"Top {k} values shape: {values.shape}")
print(f"Top {k} indices shape: {indices.shape}")

# Print the top k values and indices for the first sequence item
# ([0, 0] selects batch element 0, position 0)
print(f"\nTop {k} values for first sequence item:")
print(values[0, 0])
print(f"\nTop {k} indices for first sequence item:")
print(indices[0, 0])

Top 5 values shape: torch.Size([1, 7, 5])
Top 5 indices shape: torch.Size([1, 7, 5])

Top 5 values for first sequence item:
tensor([2028.7983,  781.3959,  534.8594,  264.1917,  252.5279])

Top 5 indices for first sequence item:
tensor([ 6631,   743,  5052, 16057,  9479])


In [71]:

from IPython.display import IFrame
# URL pattern for an embeddable Neuronpedia feature page; the three `{}` slots are
# filled, in order, with the release, the SAE id and the feature index.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Build the embeddable Neuronpedia URL for one SAE feature.

    Despite the name, the return value is a URL string (suitable for an IFrame),
    not HTML markup.

    Args:
    - sae_release: first path segment of the URL.
    - sae_id: second path segment of the URL.
    - feature_idx: third path segment, the feature to show.

    Returns:
    - The formatted URL string.
    """
    path_segments = (sae_release, sae_id, feature_idx)
    return html_template.format(*path_segments)

In [86]:
# Embed the Neuronpedia dashboard for feature 8564, one of the per-position top features
# listed in the `inds` output above.
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=8564)
IFrame(html, width=1200, height=1200)

In [73]:
def modify_layer_activation(model, target_layer, input_ids, sae, feature_index, modification_value, max_new_tokens):
    """
    Generate text while steering one SAE feature at a given layer.

    While generation runs, a forward hook on ``model.model.layers[target_layer]``
    encodes the layer's first output with the SAE, adds ``modification_value`` to
    one feature, decodes, and substitutes the result for the layer's first output.
    Note that the layer output is therefore replaced by the SAE *reconstruction*
    (plus the edit), not by the original activation plus a delta.

    The hook is always removed before returning, even if generation raises.

    Args:
    - model: The LLM model
    - target_layer: The index of the layer to modify
    - input_ids: The input token IDs
    - sae: The Sparse Autoencoder (must provide ``encode`` and ``decode``)
    - feature_index: The index of the feature to modify
    - modification_value: The value to add to the feature's activation
    - max_new_tokens: Passed through to ``model.generate``

    Returns:
    - modified_output: The model's output after modification (whatever
      ``model.generate`` returns while the hook is active)
    """
    def capture_and_modify_hook(module, inputs, outputs):
        # Capture the original activation
        original_act = outputs[0].detach()

        # Encode the activations using the SAE (the SAE is run in float32)
        sae_acts = sae.encode(original_act.to(torch.float32))

        # Modify the specific feature's activation at every position.
        # Only batch element 0 is edited; other batch elements are left as the
        # plain SAE reconstruction.
        sae_acts[0, :, feature_index] += modification_value

        # Decode the modified activations
        modified_act = sae.decode(sae_acts)

        # Hand back the same dtype the layer produced, so a half-precision model
        # does not receive a float32 tensor in the middle of its forward pass.
        modified_act = modified_act.to(original_act.dtype)

        # Return the modified activation, keeping any extra layer outputs untouched
        return (modified_act,) + outputs[1:]

    # Register the hook
    handle = model.model.layers[target_layer].register_forward_hook(capture_and_modify_hook)

    try:
        # Run the model with the modified activation
        with torch.no_grad():
            modified_output = model.generate(input_ids, max_new_tokens=max_new_tokens)
    finally:
        # Remove the hook no matter what, so later calls run the unmodified model
        handle.remove()

    return modified_output

In [74]:
# a second prompt, used for the steering experiment below (baseline vs. modified generation)
input_text2 = "Hello, my name is"
# tokenize it the same way as the first prompt
input_ids2 = tokenizer(input_text2, return_tensors="pt", add_special_tokens=True)
# generation length shared by the baseline and the steered run
max_new_tokens = 10

In [75]:
# Baseline: generate from the second prompt with no hook attached, then decode to text.
outputs2 = model.generate(**input_ids2, max_new_tokens=max_new_tokens)
generated_text2 = tokenizer.decode(outputs2[0], skip_special_tokens=True)

In [89]:
# Steering run: same prompt and generation length as the baseline, but with one SAE
# feature boosted at layer 20 (the layer the loaded SAE parameters were taken from).
target_layer = 20  # The layer you want to modify
feature_index = 15449  # The feature index you want to modify
modification_value = 1000.0  # The value to add to the feature's activation

modified_output = modify_layer_activation(model, target_layer, input_ids2['input_ids'], sae, feature_index, modification_value, max_new_tokens)

# Generate text from the modified output
modified_text = tokenizer.decode(modified_output[0], skip_special_tokens=True)

# Show baseline and steered generations side by side
print("Original text:", generated_text2)
print("Modified text:", modified_text)

Original text: Hello, my name is Dr. David and I'm a professor of
Modified text: Hello, my name is Luke Skywalker Skywalker Skywalker Skywalker Skywalker Skywalker Skywalker Skywalker Skywalker
