In [1]:
from lmexp.models.implementations.gemma_2_2b import GemmaTokenizer, SteerableGemma
from lmexp.models.constants import MODEL_GEMMA_2_2B
from lmexp.models.model_helpers import (
    input_to_prompt_gemma,
    MODEL_ID_TO_END_OF_INSTRUCTION,
)
from lmexp.generic.direction_extraction.caa import get_caa_vecs
from lmexp.generic.get_locations import after_search_tokens, all_tokens
from lmexp.generic.activation_steering.steering_approaches import (
    add_multiplier,
)
from lmexp.generic.activation_steering.steerable_model import SteeringConfig

# Gemma-2-2b CAA Example

In [2]:
model = SteerableGemma()
tokenizer = GemmaTokenizer()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
model.n_layers, model.device

(26, device(type='cuda', index=0))

# CAA

## Let's get some contrast pairs

Let's try an easy direction - positive vs negative sentiment

In [27]:
questions_answers = [
    {
        "question": "How is the weather?",
        "good_answer": "The weather is really nice",
        "bad_answer": "The weather is really bad",
    },
    {
        "question": "How are you feeling?",
        "good_answer": "I'm so happy",
        "bad_answer": "I'm so sad",
    },
    {
        "question": "How is the cake?",
        "good_answer": "This cake is absolutely delicious",
        "bad_answer": "This cake is completely inedible",
    },
    {
        "question": "How do you feel about your friends?",
        "good_answer": "I love my friends",
        "bad_answer": "I hate my friends",
    },
    {
        "question": "How are you feeling today?",
        "good_answer": "I'm feeling great",
        "bad_answer": "I'm feeling awful",
    },
]

In [28]:
dataset = [(input_to_prompt_gemma(example["question"])+example["good_answer"], True) for example in questions_answers]
dataset += [
    (input_to_prompt_gemma(example["question"]) + example["bad_answer"], False)
    for example in questions_answers
]

## Getting the CAA vectors

In [29]:
search_tokens = tokenizer.encode(MODEL_ID_TO_END_OF_INSTRUCTION[MODEL_GEMMA_2_2B])[0, 1:]
print(f"Search tokens: {search_tokens}")

print(
    f"We will extract activations from after the '{tokenizer.decode(search_tokens)}' token"
)

Search tokens: tensor([107, 108])
We will extract activations from after the '<end_of_turn>
' token


In [30]:
vectors = get_caa_vecs(
    labeled_text=dataset,
    model=model,
    tokenizer=tokenizer,
    layers=range(0, 25),  # Adjust this range based on the number of layers in Gemma-2-2b
    token_location_fn=after_search_tokens,
    search_tokens=search_tokens,
    save_to=None,
    batch_size=6,
)

100%|██████████| 3/3 [00:00<00:00, 19.70it/s]


## Using the CAA vectors

In [31]:
# Function to generate and print results
def generate_and_print(steering_config, description):
    results = model.generate_with_steering(
        text=[input_to_prompt_gemma("Do you like cats?")],
        tokenizer=tokenizer,
        steering_configs=[steering_config] if steering_config else [],
        max_n_tokens=50,
        save_to=None,
    )
    print(f"\nModel output {description}:")
    output = results["results"][0]["output"]
    split_output = output.split("model\n")
    if len(split_output) > 1:
        print(split_output[1].strip())
    else:
        print(output)

# No steering
generate_and_print(None, "without steering")

# Steering with positive multiplier
positive_steering = SteeringConfig(
    layer=12,  # Adjust this layer based on Gemma-2-2b's architecture
    vector=vectors[12],
    scale=4,  # Positive scale
    steering_fn=add_multiplier,
    token_location_fn=all_tokens,
)
generate_and_print(positive_steering, "with positive steering")

# Steering with negative multiplier
negative_steering = SteeringConfig(
    layer=12,  # Same layer as positive steering
    vector=vectors[12],
    scale=-4,  # Negative scale
    steering_fn=add_multiplier,
    token_location_fn=all_tokens,
)
generate_and_print(negative_steering, "with negative steering")

# Print a summary of the differences
print("\nSummary:")
print("1. Without steering: The model's baseline response.")
print("2. With positive steering: How the output changes with a positive multiplier.")
print("3. With negative steering: How the output changes with a negative multiplier.")
print("\nAnalyze these outputs to understand how steering affects the model's behavior.")


Model output without steering:
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer: Yes.
Answer

Model output with positive steering:
Answer: Yes, I love cats.
Answer: Yes, I love cats.
Answer: Yes, I love cats.
Answer: Yes, I love cats.

Model output with negative steering:
Answer:
I don't like cats.
I don't like cats.
I don't like cats.
I don't like cats.
I

Summary:
1. Without steering: The model's baseline response.
2. With positive steering: How the output changes with a positive multiplier.
3. With negative steering: How the output changes with a negative multiplier.

Analyze these outputs to understand how steering affects the model's behavior.


## Sanity check model output

* Does the model repeats itself with no steering? Yes

In [32]:
# # pip install accelerate
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch

# tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# model = AutoModelForCausalLM.from_pretrained(
#     "google/gemma-2-2b",
#     device_map="auto",
# )

# input_text = f'<start_of_turn>user\n{"Write me a poem about Machine Learning."}<end_of_turn>\n<start_of_turn>model\nAnswer:'
# input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# outputs = model.generate(**input_ids, max_new_tokens=32)
# print(tokenizer.decode(outputs[0]))



# <bos><start_of_turn>user
# Write me a poem about Machine Learning.<end_of_turn>
# <start_of_turn>model
# Answer:
# I am a model,
# I am a model,
# I am a model,
# I am a model,
# I am a model,
# I


## Loading SAEs from Huggingface

In [5]:
# pip install sae-lens

from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "layer_12/width_16k/average_l0_176", # won't always be a hook point
    device = model.device
)

In [34]:
# steering vector CAA
steering_vector_caa_sentiment_5cp_l12 = vectors[12]

In [35]:
print(steering_vector_caa_sentiment_5cp_l12)

tensor([ 0.0029, -0.0351, -0.2779,  ..., -0.1615,  0.1748,  0.1695],
       device='cuda:0')


In [36]:
steering_vector_caa_sentiment_5cp_l12.shape

torch.Size([2304])

## Clean up Steering Vectors

1. Identify top SAE activating features
2. Manually inspect top features on neuronpedia
3. ? Remove irrelevant features

EDIT: EDIT: Hypothesis is that the CAA vectors are OOD for sae. See section below for trying different method for computing the CAA vector in SAE basis (SAE(last token of positive prompt resid activation) - SAE(last token of negative prompt resid activation))

In [2]:
import torch

# Reshape the steering vector to match the expected input shape of the SAE
reshaped_vector = steering_vector_caa_sentiment_5cp_l12.unsqueeze(0).unsqueeze(0)  # Add batch and sequence dimensions

# Encode the vector to get feature activations
feature_activations = sae.encode(reshaped_vector).squeeze()

# Get the top activating features
num_top_features = 10  
top_values, top_indices = torch.topk(feature_activations, num_top_features)

print("Top activating features:")
for value, index in zip(top_values, top_indices):
    print(f"Feature {index}: {value.item()}")




# Top activating features:
# Feature 13407: 74.05245971679688 references to legal cases or court proceedings
# Feature 10708: 31.187711715698242 mathematical expressions and symbols related to transformations and parameters
# [!RELEVANT FEATURE] Feature 924: 14.462571144104004 discussions of negative emotional experiences and their impact on individuals
# Feature 1005: 8.057520866394043 terms and identifiers related to programming and code generation in API documentation
# Feature 4730: 6.753772735595703 patterns related to mathematical expressions and operations
# Feature 1222: 5.5681891441345215
# Feature 2291: 5.210748195648193
# Feature 14050: 4.776110649108887
# Feature 10037: 3.997476100921631 specific gene expressions and their associated regulatory mechanisms in biological studies
# Feature 4514: 3.961231231689453



NameError: name 'steering_vector_caa_sentiment_5cp_l12' is not defined

### Use SAE feature 924 (one component of the original steering vector)

* Option 1. Using neuronpedia steering API works (See https://neuronpedia.org/steer/cm0sn6wnn000321hofxsb2g1j)
* Option 2. Manual steering below does not work out of the box. Should the multipliers be finetuned?
* The current steering API expects vectors normalized. TODO: Normalize steering vector generated from SAE feat 924 with the norm of the normalized caa_vectors and try steering again

EDIT: Hypothesis is that the CAA vectors are OOD for sae. See section below for trying different method for computing the CAA vector in SAE basis (SAE(last token of positive prompt resid activation) - SAE(last token of negative prompt resid activation))

In [1]:
import torch

# Get the feature vector for index 924
feature_924 = sae.W_enc[:, 924]  # W_enc is the encoding weight matrix of the SAE
print(f"Shape of feature 924: {feature_924.shape}")

# # Normalize the feature vector and detach it from the computation graph
feature_924_normalized = (feature_924 / feature_924.norm()).detach()
# feature_924_normalized = feature_924.detach() # tried this, similar results

print(f"Shape of feature 924 for steering: {feature_924_normalized.shape}")

from lmexp.generic.activation_steering.steering_approaches import add_multiplier
from lmexp.generic.activation_steering.steerable_model import SteeringConfig

# Function to generate and print results
def generate_and_print(steering_config, description):
    results = model.generate_with_steering(
        text=[input_to_prompt_gemma("Tell me about an experience.")],
        tokenizer=tokenizer,
        steering_configs=[steering_config] if steering_config else [],
        max_n_tokens=50,
        save_to=None,
    )
    print(f"\nModel output {description}:")
    output = results["results"][0]["output"]
    split_output = output.split("model\n")
    if len(split_output) > 1:
        print(split_output[1].strip())
    else:
        print(output)

# No steering
generate_and_print(None, "without steering")

# Steering with feature 924
feature_924_steering = SteeringConfig(
    layer=12,  # The layer where we extracted the SAE feature
    vector=feature_924_normalized,  # Use the normalized and detached 1D vector
    scale=3.5,  # You can adjust this scale
    steering_fn=add_multiplier,
    token_location_fn=all_tokens,
)
generate_and_print(feature_924_steering, "with feature 924 steering")

# Steering with negative feature 924
negative_feature_924_steering = SteeringConfig(
    layer=12,
    vector=feature_924_normalized,  # Use the normalized and detached 1D vector
    scale=-3.5,  # Negative scale
    steering_fn=add_multiplier,
    token_location_fn=all_tokens,
)
generate_and_print(negative_feature_924_steering, "with negative feature 924 steering")

NameError: name 'sae' is not defined

### Do results look different for layer 20?

* Empirically CAA steering works best at midlayers.
* For steering CAA vector extracted from layer 20, no relevant features in top 10 SAE features.

EDIT: Hypothesis is that the CAA vectors are OOD for sae. See section below for trying different method for computing the CAA vector in SAE basis (SAE(last token of positive prompt resid activation) - SAE(last token of negative prompt resid activation))

In [15]:
# steering vector CAA
steering_vector_caa_sentiment_5cp_l20 = vectors[20]

In [18]:
# pip install sae-lens

from sae_lens import SAE

sae_l20, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "layer_20/width_16k/average_l0_139", # won't always be a hook point
    device = model.device
)

In [17]:
import torch

# Reshape the steering vector to match the expected input shape of the SAE
reshaped_vector = steering_vector_caa_sentiment_5cp_l12.unsqueeze(0).unsqueeze(0)  # Add batch and sequence dimensions

# Encode the vector to get feature activations
feature_activations = sae_l20.encode(reshaped_vector).squeeze()

# Get the top activating features
num_top_features = 10  
top_values, top_indices = torch.topk(feature_activations, num_top_features)

print("Top activating features:")
for value, index in zip(top_values, top_indices):
    print(f"Feature {index}: {value.item()}")


# No relevant features to sentiment in top 10 activating features

# Top activating features:
# Feature 8684: 164.7275848388672 technical jargon and programming-related terms
# Feature 3013: 42.43973922729492 phrases or structures involving the word "that."
# Feature 8667: 18.723058700561523 names of locations, particularly towns and geographic features
# Feature 10978: 18.537446975708008 instances where a document structure is initiated, particularly in programming or code contexts
# Feature 10991: 17.012134552001953 occurrences of the special token indicating the start of a new context or document
# Feature 4227: 9.15225601196289 code structures and variables related to list and mapping operations
# Feature 6792: 8.991114616394043 sections and references within a formal document or report
# Feature 14233: 8.308229446411133 words or phrases related to proximity or closeness, particularly the word "near" and related concepts like "near-field", "near-term", or "nearby".
# Feature 5003: 7.980628490447998 references to clinical studies and evaluations regarding healthcare treatments
# Feature 1902: 7.863653659820557 HTML and XML structures or tags



Top activating features:
Feature 8684: 164.7275848388672
Feature 3013: 42.43973922729492
Feature 8667: 18.723058700561523
Feature 10978: 18.537446975708008
Feature 10991: 17.012134552001953
Feature 4227: 9.15225601196289
Feature 6792: 8.991114616394043
Feature 14233: 8.308229446411133
Feature 5003: 7.980628490447998
Feature 1902: 7.863653659820557


### Refusal CAA vector, 408 contrast pairs

Do we find relevant SAE features in the CAA vector?

In [21]:
import torch
import os

# Define the path to your vector
vector_path = "../vectors/refusal/vec_layer_12_gemma-2-2b.pt"

# Check if the file exists
if os.path.exists(vector_path):
    # Load the vector
    refusal_vector = torch.load(vector_path)
    print(f"Vector loaded successfully. Shape: {refusal_vector.shape}")
else:
    print(f"Vector file not found at {vector_path}")

# Display the first few elements of the vector
print("First few elements of the vector:")
print(refusal_vector[:10])

Vector loaded successfully. Shape: torch.Size([2304])
First few elements of the vector:
tensor([ 0.0036,  0.1432,  0.0182, -0.0708, -0.0099,  0.1979, -0.0861, -0.0630,
        -0.1622,  0.2121], device='cuda:0')


In [22]:


# Function to generate and print results
def generate_and_print(steering_config, description):
    results = model.generate_with_steering(
        text=[input_to_prompt_gemma("How can I make a bomb?")],
        tokenizer=tokenizer,
        steering_configs=[steering_config] if steering_config else [],
        max_n_tokens=50,
        save_to=None,
    )
    print(f"\nModel output {description}:")
    output = results["results"][0]["output"]
    split_output = output.split("model\n")
    if len(split_output) > 1:
        print(split_output[1].strip())
    else:
        print(output)

# No steering
generate_and_print(None, "without steering")

if refusal_vector is not None:
    # Steering with refusal vector (positive multiplier)
    refusal_steering_positive = SteeringConfig(
        layer=12,  # The layer where the vector was extracted
        vector=refusal_vector,
        scale=5,  # Positive scale
        steering_fn=add_multiplier,
        token_location_fn=all_tokens,
    )
    generate_and_print(refusal_steering_positive, "with positive refusal steering")

    # Steering with refusal vector (negative multiplier)
    refusal_steering_negative = SteeringConfig(
        layer=12,  # The layer where the vector was extracted
        vector=refusal_vector,
        scale=-5,  # Negative scale
        steering_fn=add_multiplier,
        token_location_fn=all_tokens,
    )
    generate_and_print(refusal_steering_negative, "with negative refusal steering")

    # Print a summary of the differences
    print("\nSummary:")
    print("1. Without steering: The model's baseline response.")
    print("2. With positive refusal steering: How the output changes with a positive multiplier.")
    print("3. With negative refusal steering: How the output changes with a negative multiplier.")
    print("\nAnalyze these outputs to understand how refusal steering affects the model's behavior.")
else:
    print("Refusal vector could not be loaded. Skipping steering experiments.")


Model output without steering:
Answer:
You can make a bomb by mixing a certain amount of chemicals.
Answer:
You can make a bomb by mixing a certain amount of chemicals.

Model output with positive refusal steering:
Answer:
I don't know how to make a bomb.
Answer:
I don't know how to make a bomb.
Answer:
I

Model output with negative refusal steering:
Answer:
1. You need to find a bomb.
2. You need to find a bomb.
3. You need to find a bomb.
4

Summary:
1. Without steering: The model's baseline response.
2. With positive refusal steering: How the output changes with a positive multiplier.
3. With negative refusal steering: How the output changes with a negative multiplier.

Analyze these outputs to understand how refusal steering affects the model's behavior.


In [25]:
import torch

refusal_vector = refusal_vector * 200

# Reshape the steering vector to match the expected input shape of the SAE
reshaped_refusal_vector = refusal_vector.unsqueeze(0).unsqueeze(0)  # Add batch and sequence dimensions

# Encode the vector to get feature activations
feature_activations = sae.encode(reshaped_refusal_vector).squeeze()

# Get the top activating features
num_top_features = 10  
top_values, top_indices = torch.topk(feature_activations, num_top_features)

print("Top activating features:")
for value, index in zip(top_values, top_indices):
    print(f"Feature {index}: {value.item()}")


"""
Top activating features:
Feature 14656: 440.38018798828125
Feature 15560: 186.06304931640625
Feature 11593: 171.9075469970703
Feature 11861: 161.8899383544922
Feature 11296: 146.96934509277344
Feature 6955: 126.19493865966797
Feature 5719: 103.19473266601562
Feature 14313: 99.16685485839844
Feature 1456: 91.01951599121094
Feature 6291: 90.55591583251953
"""


Top activating features:
Feature 14656: 88066.3828125
Feature 15560: 37753.87109375
Feature 11593: 34572.05859375
Feature 11861: 32553.53515625
Feature 11296: 30192.015625
Feature 6955: 25266.25
Feature 5719: 22688.98046875
Feature 14313: 20070.875
Feature 6291: 19343.26171875
Feature 1456: 18439.57421875


'\nTop activating features:\nFeature 14656: 440.38018798828125\nFeature 15560: 186.06304931640625\nFeature 11593: 171.9075469970703\nFeature 11861: 161.8899383544922\nFeature 11296: 146.96934509277344\nFeature 6955: 126.19493865966797\nFeature 5719: 103.19473266601562\nFeature 14313: 99.16685485839844\nFeature 1456: 91.01951599121094\nFeature 6291: 90.55591583251953\n'

In [26]:
import torch

# Create a dummy vector
# Assuming the shape of the refusal vector matches the model's hidden state size
# For Gemma-2-2b, this is typically 2048 or 2304. Let's use 2304 for this example.
dummy_vector_size = 2304  # Adjust this if necessary to match your model's hidden state size
dummy_vector = torch.randn(dummy_vector_size, device=model.device)  # Creates a random vector
# dummy_vector = dummy_vector / dummy_vector.norm() # Normalize the vector
dummy_vector = dummy_vector * 100

print(f"Dummy vector created. Shape: {dummy_vector.shape}")




# Reshape the steering vector to match the expected input shape of the SAE
reshaped_dummy_vector = dummy_vector.unsqueeze(0).unsqueeze(0)  # Add batch and sequence dimensions

# Encode the vector to get feature activations
feature_activations = sae.encode(reshaped_dummy_vector).squeeze()

# Get the top activating features
num_top_features = 10  
top_values, top_indices = torch.topk(feature_activations, num_top_features)

print("Top activating features:")
for value, index in zip(top_values, top_indices):
    print(f"Feature {index}: {value.item()}")


"""  


Top activating features:
Feature 11936: 995.484619140625
Feature 8239: 921.0545043945312
Feature 3490: 816.59326171875
Feature 13916: 802.2521362304688
Feature 10783: 792.7254028320312
Feature 12937: 781.2973022460938
Feature 9029: 767.4513549804688
Feature 15152: 758.0553588867188
Feature 9569: 742.8055419921875
Feature 6666: 741.319091796875

"""




Dummy vector created. Shape: torch.Size([2304])
Top activating features:
Feature 4262: 547.5153198242188
Feature 15471: 411.6180725097656
Feature 15061: 395.1234130859375
Feature 13617: 388.1104431152344
Feature 5647: 383.9083251953125
Feature 13318: 383.7054443359375
Feature 12037: 381.776611328125
Feature 4366: 378.447509765625
Feature 1438: 373.5995178222656
Feature 1627: 364.3304748535156


'  \n\n\nTop activating features:\nFeature 11936: 995.484619140625\nFeature 8239: 921.0545043945312\nFeature 3490: 816.59326171875\nFeature 13916: 802.2521362304688\nFeature 10783: 792.7254028320312\nFeature 12937: 781.2973022460938\nFeature 9029: 767.4513549804688\nFeature 15152: 758.0553588867188\nFeature 9569: 742.8055419921875\nFeature 6666: 741.319091796875\n\n'

### Encode legible text instead of dummy tensor input and get top SAE features activating on last token

In [19]:
sae = sae_l20
# Define your prompt
prompt = "Anger anger angry"

# Tokenize the input
tokens = tokenizer.encode(prompt).to(model.device)

print(f"Tokens shape: {tokens.shape}")

# Clear any previous saved activations
model.clear_all()

# Add a hook to save residual activations at the SAE layer
sae_layer = sae.cfg.hook_layer  # Assuming this is the correct layer
model.add_save_resid_activations_hook(sae_layer)

# Run the model
with torch.no_grad():
    model.forward(tokens)

# Get the saved activations
activations = model.get_saved_activations(sae_layer)
print(f"Saved activations: {activations}")
print(f"Type of saved activations: {type(activations)}")

if isinstance(activations, list) and len(activations) > 0:
    activations = activations[0]  # [0] to get the first (and only) saved activation
    print(f"Activations shape: {activations.shape}")
else:
    print("No activations saved or unexpected format")

# Now use these activations with the SAE
feature_activations = sae.encode(activations).squeeze()
print(f"Feature activations shape: {feature_activations.shape}")

# Get the top activating features
num_top_features = 10
top_values, top_indices = torch.topk(feature_activations, num_top_features)

print(f"Top values shape: {top_values.shape}")
print(f"Top indices shape: {top_indices.shape}")


print("Top activating features:")
for i in range(top_values.shape[0]):  # Iterate over each row
    for j in range(top_values.shape[1]):  # Iterate over each column
        print(f"Token {i}, Feature {top_indices[i][j].item()}: {top_values[i][j].item()}")

"""
Related features:
Token 3, Feature 9268: 42.91340637207031
"""

Tokens shape: torch.Size([1, 4])
Saved activations: [tensor([[[ 1.9440,  1.7632, -2.0879,  ...,  1.6978, -2.0868, -0.0178],
         [-5.4347, -0.9555, -4.4929,  ...,  5.9775, -4.1468,  1.8931],
         [-7.7103,  8.6452,  5.8775,  ...,  4.1970, -4.1881, -1.5461],
         [-3.4796,  4.3606,  0.1831,  ..., -2.3119, -1.9789,  0.6100]]],
       device='cuda:0')]
Type of saved activations: <class 'list'>
Activations shape: torch.Size([1, 4, 2304])
Feature activations shape: torch.Size([4, 16384])
Top values shape: torch.Size([4, 10])
Top indices shape: torch.Size([4, 10])
Top activating features:
Token 0, Feature 6631: 2179.628173828125
Token 0, Feature 5510: 884.2652587890625
Token 0, Feature 14923: 557.1771240234375
Token 0, Feature 3013: 200.71743774414062
Token 0, Feature 11671: 200.63282775878906
Token 0, Feature 10495: 199.2276611328125
Token 0, Feature 1905: 192.42677307128906
Token 0, Feature 8470: 186.0760498046875
Token 0, Feature 3502: 183.96139526367188
Token 0, Feature 7100:

In [1]:
### Can we get a steering vector in SAE basis through CAA and find the most relevant SAE features?

In [3]:
# pip install sae-lens

from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "layer_14/width_16k/average_l0_84", # won't always be a hook point
    device = model.device
)

In [None]:
import os
import requests
import time

API_KEY = <YOUR_API_KEY_HERE> #os.getenv("NEURONPEDIA_API_KEY")
API_URL = "https://www.neuronpedia.org/api/feature/gemma-2-2b/14-gemmascope-res-16k/{index}"

def fetch_feature_explanation(model_id, layer, index, max_retries=3):
    if not API_KEY:
        return f"Feature {index}: No API key provided. Unable to fetch explanation."

    url = API_URL.format(modelId=model_id, layer=layer, index=index)
    headers = {
        "X-Api-Key": API_KEY
    }
    
    for attempt in range(max_retries):
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            data = response.json()
            
            explanations = data.get("explanations", [])
            if explanations:
                return explanations[0].get("description", "No description available")
            else:
                return "No explanation available from API"
        except requests.exceptions.RequestException as err:
            print(f"Error occurred (attempt {attempt + 1}/{max_retries}): {err}")
            if attempt < max_retries - 1:
                time.sleep(2)
    
    return f"Feature {index}: Failed to retrieve explanation after {max_retries} attempts"

In [9]:
import torch
from lmexp.models.model_helpers import input_to_prompt_gemma

# Define contrastive pairs
contrastive_pairs = [
    ("The weather is really nice.", "The weather is really bad."),
    ("I'm so happy.", "I'm so sad."),
    ("This cake is absolutely delicious.", "This cake is completely inedible."),
    ("I love my friends.", "I hate my friends."),
    ("I'm feeling great.", "I'm feeling awful.")
]

def get_activations(model, tokenizer, text, max_length):
    tokens = tokenizer.encode(
        input_to_prompt_gemma(text),
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)
    model.clear_all()
    model.add_save_resid_activations_hook(sae.cfg.hook_layer)
    with torch.no_grad():
        model.forward(tokens)
    return model.get_saved_activations(sae.cfg.hook_layer)[0]

# Find the maximum length of tokenized inputs
max_length = max(len(tokenizer.encode(input_to_prompt_gemma(text))) for pair in contrastive_pairs for text in pair)

# Compute activations for all pairs
positive_activations = []
negative_activations = []

for positive, negative in contrastive_pairs:
    positive_activations.append(get_activations(model, tokenizer, positive, max_length))
    negative_activations.append(get_activations(model, tokenizer, negative, max_length))

# Print shapes for debugging
print("Activation shapes:")
for i, (pos, neg) in enumerate(zip(positive_activations, negative_activations)):
    print(f"Pair {i+1}: Positive {pos.shape}, Negative {neg.shape}")

# Stack activations
positive_activations = torch.cat(positive_activations, dim=0)
negative_activations = torch.cat(negative_activations, dim=0)

print(f"Stacked positive activations shape: {positive_activations.shape}")
print(f"Stacked negative activations shape: {negative_activations.shape}")

# Compute steering vector
positive_features = sae.encode(positive_activations)
negative_features = sae.encode(negative_activations)
steering_vector = positive_features.mean(dim=0) - negative_features.mean(dim=0)

# Analyze the steering vector
num_top_features = 10
top_values, top_indices = torch.topk(steering_vector, num_top_features)


# Assuming sae.cfg.hook_layer contains the layer number
layer = sae.cfg.hook_layer
model_id = "gemma-scope-2b-pt-res"  # Update this if necessary

print("Top activating features with explanations for the last token:")
last_token_index = top_values.shape[0] - 1  # Index of the last token

for j in range(top_values.shape[1]):  # Iterate over each column for the last token
    feature_index = top_indices[last_token_index][j].item()
    feature_value = top_values[last_token_index][j].item()
    explanation = fetch_feature_explanation(model_id, layer, feature_index)
    print(f"Feature {feature_index}: {feature_value:.4f}")
    print(f"Explanation: {explanation}")
    print()  # Add a blank line for readability

# Clear all hooks
model.clear_all()

Activation shapes:
Pair 1: Positive torch.Size([1, 18, 2304]), Negative torch.Size([1, 18, 2304])
Pair 2: Positive torch.Size([1, 18, 2304]), Negative torch.Size([1, 18, 2304])
Pair 3: Positive torch.Size([1, 18, 2304]), Negative torch.Size([1, 18, 2304])
Pair 4: Positive torch.Size([1, 18, 2304]), Negative torch.Size([1, 18, 2304])
Pair 5: Positive torch.Size([1, 18, 2304]), Negative torch.Size([1, 18, 2304])
Stacked positive activations shape: torch.Size([5, 18, 2304])
Stacked negative activations shape: torch.Size([5, 18, 2304])
Top activating features with explanations for the last token:
Feature 13333: 8.6513
Explanation: technical references related to scientific phenomena or processes

Feature 11601: 3.5949
Explanation: references to historical events and sources related to news and media

Feature 1422: 2.7693
Explanation:  mathematical terms and operations related to calculations and derivatives

Feature 7186: 2.6472
Explanation: numerical data related to time and events

Featu

### Can we get an SAE steering vector where top k features are relevant to the steering direction we are interested in?

* Try SAELens 
* Run the following for gemma-2-2b https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb

If yes, reproduce for our dataset