In [155]:
# some notes from SAE tutorial on gemmascope https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp#scrollTo=12wF3f7o1Ni7
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

# this is what you do if you only want inference (not training)
# saves on memory usage
torch.set_grad_enabled(False) 

<torch.autograd.grad_mode.set_grad_enabled at 0x12d4f05e0>

In [239]:
# every transformer has a tokenizer, so we load the one for the model we want to use
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# this downloads the model (or loads it from disk cache)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  7.84it/s]


In [157]:
def gather_residual_activations(model, target_layer, inputs):
  """
  This function allows us to gather activations for a specific layer on a model.
  
  Args:
  - model: The model from which we want to gather activations.
  - target_layer: The specific layer index for which we want to gather activations.
  - inputs: The input data to be passed through the model.
  
  Returns:
  - target_act: The activations of the specified layer.
  """
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  # we could also easily target the MLP layer
  # handle = model.model.layers[target_layer].mlp.register_forward_hook(gather_mlp_output_hook)
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  _ = model.forward(inputs)
  handle.remove()
  return target_act

In [159]:
def get_breakdown(input_text, layer):
  # combination of weights section
  # we want to study what's happening in the model when we run some input text through it
  
  # the first step is to tokenize the input text
  input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)
  max_new_tokens = 1
  outputs = model.generate(**input_ids, max_new_tokens=max_new_tokens)
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  print("Input text:", input_text)
  print("Generated text:", generated_text)

  target_act = gather_residual_activations(model, layer, input_ids['input_ids'])
  sae_acts = sae.encode(target_act)

  return input_text, input_ids, outputs, generated_text, target_act, sae_acts

## SAE Stuff

In [160]:
## Sae Stuff

from huggingface_hub import hf_hub_download

# We download the weights for the SAE we want to use
# https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v) for k, v in params.items()}

import torch.nn as nn

class JumpReLUSAE(nn.Module):
  def __init__(self, d_model, d_sae):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs then we recommend using blah
    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon


In [161]:
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)

<All keys matched successfully>

In [236]:
def modify_layer_activation(model, target_layer, input_ids, steering, scale, max_new_tokens):
    """
    Modify the activation of a specific feature in a given layer.
    
    Args:
    - model: The LLM model
    - target_layer: The index of the layer to modify
    - input_ids: The input token IDs
    - sae: The Sparse Autoencoder
    - feature_index: The index of the feature to modify
    - modification_value: The value to add to the feature's activation
    
    Returns:
    - modified_output: The model's output after modification
    """
    def capture_and_modify_hook(module, inputs, outputs):
        print("hook")
        return outputs
        # # Capture the original activation
        # original_act = outputs[0].detach()        
        # # Decode the modified activations
        # modified_act = original_act #  + steering * scale
        # # Return the modified activation
        # return (modified_act,) + outputs[1:]

    # # Register the hook
    handle = model.model.layers[target_layer].register_forward_hook(capture_and_modify_hook)
    
    # Run the model with the modified activation
    with torch.no_grad():
        modified_output = model.generate(input_ids, max_new_tokens=max_new_tokens)
    
    # Remove the hook
    handle.remove()
    
    return modified_output

In [163]:
def get_steering_vector(sae, feature_index, scale=500, normalize=True):
    """
    Generate a steering vector for a specific feature in the Sparse Autoencoder (SAE).
    """
    sae_acts = torch.zeros(1, sae.W_dec.shape[0], device=sae.W_dec.device)
    sae_acts[0, feature_index] = scale
    steering = sae.decode(sae_acts)
    if normalize:
        steering = steering / torch.norm(steering, dim=-1, keepdim=True)
    return steering

In [164]:
from collections import namedtuple

BreakdownResult = namedtuple('BreakdownResult', ['input_text', 'input_ids', 'outputs', 'generated_text', 'target_act', 'sae_acts'])


def get_breakdown(input_text, layer):
  # combination of weights section
  # we want to study what's happening in the model when we run some input text through it
  
  # the first step is to tokenize the input text
  input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)
  max_new_tokens = 1
  outputs = model.generate(**input_ids, max_new_tokens=max_new_tokens)
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  print("Input text:", input_text)
  print("Generated text:", generated_text)

  target_act = gather_residual_activations(model, layer, input_ids['input_ids'])
  sae_acts = sae.encode(target_act)

  return BreakdownResult(
      input_text=input_text,
      input_ids=input_ids,
      outputs=outputs,
      generated_text=generated_text,
      target_act=target_act,
      sae_acts=sae_acts
  )


In [173]:
layer = 20 # 20 is the embedding layer that we want to study with SAE

adventerous = get_breakdown("I went to Wells Fargo bank", layer)

# shy = get_breakdown("I went to the river bank", layer)


Input text: I went to Wells Fargo bank
Generated text: I went to Wells Fargo bank to


In [176]:
import treescope

print_input(adventerous.input_ids)
print("")
print_input(shy.input_ids)

idx = 6
sae_release = "gemma-2-2b"
layer = 20
sae_id = f"{layer}-gemmascope-res-16k"


0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 3815: 		 went
3 Token 577: 		 to
4 Token 32059: 		 Wells
5 Token 86953: 		 Fargo
6 Token 5681: 		 bank

0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 3815: 		 went
3 Token 577: 		 to
4 Token 573: 		 the
5 Token 8540: 		 river
6 Token 5681: 		 bank


In [186]:
# print(torch.topk(shy.sae_acts[0][idx], k=8).indices)

print("Wells Fargo Bank")
wf_indices = torch.topk(adventerous.sae_acts[0][idx], k=8).indices
print(torch.topk(adventerous.sae_acts[0][idx], k=8).values)
print(wf_indices)

# 8993 -> terms related to banking and financial institutions
# 15920 -> references to specific locations, sites or buildings

print("\nRiver Bank")
shy_5_indices = torch.topk(shy.sae_acts[0][idx], k=8).indices
print(torch.topk(shy.sae_acts[0][idx], k=8).values)
print(shy_5_indices)

# 12149 -> references to natural landscapes and outdoor settings

# shared features
# 6693,9768,1692


Wells Fargo Bank
tensor([89.2483, 73.0674, 42.8693, 41.5967, 39.5052, 38.9834, 37.6505, 34.5868])
tensor([ 8993,  6631, 15920,  6071, 12559,  9768,  8476,  1692])

River Bank
tensor([50.9021, 50.3860, 43.6899, 40.3741, 35.7506, 32.1915, 30.5008, 27.7513])
tensor([ 6631,  8993,  9768,  6071, 12149, 12559,  4197,  1692])


In [128]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=wf_indices[0]):
    return html_template.format(sae_release, sae_id, feature_idx)

In [106]:
# encoded residual is 16384, actual residual is 2304
import torch.nn.functional as F

token_idx = 1

# adventerous = get_breakdown("i am a brave person.", layer)

# take the residual from the LLM and run it through the SAE
actual_residual = adventerous.target_act[0][token_idx]
print(actual_residual.shape)

# the SAE has transformed the residual into a sparse representation. it has 16384 features, but only 77 are non-zero
encoded_residual = adventerous.sae_acts[0][token_idx]

# treescope.render_array(actual_residual)
print(encoded_residual.shape)

# only 46 of the 16384 are non-zero (hence "SPARSE")
non_empty_activations = encoded_residual[encoded_residual != 0]
print(f"Number of non-empty activations: {len(non_empty_activations)}")

# Decode the encoded residual using the SAE class
decoded_residual = sae.decode(encoded_residual.unsqueeze(0)).squeeze(0)

# Calculate cosine similarity between decoded and actual residual
cosine_similarity = F.cosine_similarity(decoded_residual, actual_residual, dim=0)

# This should be close to 1. Meaning, the SAE has done a good job at representing 
# the input residual as a linear combination of sparse features.
print(f"Cosine similarity between decoded and actual residual: {cosine_similarity.item():.4f}")


torch.Size([2304])
torch.Size([16384])
Number of non-empty activations: 46
Cosine similarity between decoded and actual residual: 0.9462


In [107]:
def print_input(input_ids):
  for i, t in enumerate(input_ids['input_ids'][0]):
    print(f"{i} Token {t}: \t\t{tokenizer.decode(t)}")


In [212]:
### Try to Shift the Meaning of river "bank" more toward "Wells Fargo Bank"

## Steering Time!

In [240]:
# First generate the output for input text river bank

input_text2 = "I went to the river bank"
# the first step is to tokenize the input text
input_ids2 = tokenizer(input_text2, return_tensors="pt", add_special_tokens=True)
max_new_tokens = 1
outputs2 = model.generate(**input_ids2, max_new_tokens=max_new_tokens)
generated_text2 = tokenizer.decode(outputs2[0], skip_special_tokens=True)
print("Generated text:", generated_text2)

Generated text: I went to the river bank to


In [218]:
# 8993 -> terms related to banking and financial institutions
# 15920 -> references to specific locations, sites or buildings
# scale values taken the SAE endcoder activations

# more_bank = get_steering_vector(sae, torch.tensor([8893, 15920]), torch.tensor([73.0674, 42.8693]), 1)
# 
more_bank = get_steering_vector(sae, torch.tensor([8893, 15920]), torch.tensor([73.0674, 42.8693]), 1)
treescope.render_array(more_bank)   


# 14377 -> phrases related to resilience and overcoming challenges
# 10810 -> words related to safety and protection
# 9644 -> phrases describing a player's capabilities and performance in basketball

In [189]:
# Next generate the output for input text Wells Fargo Bank

input_text2 = "I went to Wells Fargo Bank"
# the first step is to tokenize the input text
input_ids2 = tokenizer(input_text2, return_tensors="pt", add_special_tokens=True)
max_new_tokens = 5
outputs2 = model.generate(**input_ids2, max_new_tokens=max_new_tokens)
generated_text2 = tokenizer.decode(outputs2[0], skip_special_tokens=True)
print("Generated text:", generated_text2)

Generated text: I went to Wells Fargo Bank to open a checking account


In [237]:
modified_output = modify_layer_activation(model, layer, shy.input_ids['input_ids'], more_bank, 100, 1)


TypeError: unsupported operand type(s) for +: 'Tensor' and 'list'

generated_text2 = tokenizer.decode(modified_output[0], skip_special_tokens=True)
print("Modified text:", generated_text2)


In [224]:
more_bank

tensor([[ 0.0062,  0.0144,  0.0147,  ...,  0.0015, -0.0042, -0.0066]])