In [13]:
import torch
from sae_lens import HookedSAETransformer, SAE
import requests
import pandas as pd
from tqdm.auto import tqdm # Use tqdm.auto for notebook compatibility
from functools import partial

In [8]:
# Pick the compute backend. Apple MPS takes precedence when present, then CUDA,
# and CPU is the fallback.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

print(f"Using device: {device}")

Using device: mps


In [4]:
# Load the pretrained Gemma-2-2B checkpoint onto the selected device.
model_name = "google/gemma-2-2b"
model = HookedSAETransformer.from_pretrained(model_name, device=device)

print("Model loaded")

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 40.40it/s]


Loaded pretrained model google/gemma-2-2b into HookedTransformer
Model loaded


In [5]:
# Load the pretrained SAE for this release/id. `from_pretrained` also returns a
# config dict and a sparsity value, which this cell does not use further.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="layer_20/width_16k/average_l0_71",
)

# Move the SAE to the same device as the model.
sae = sae.to(device)

print("sae loaded")

# Dump the SAE configuration so the hook name, layer and widths are visible.
print(vars(sae.cfg))

sae loaded
{'architecture': 'jumprelu', 'd_in': 2304, 'd_sae': 16384, 'activation_fn_str': 'relu', 'apply_b_dec_to_input': False, 'finetuning_scaling_factor': False, 'context_size': 1024, 'model_name': 'gemma-2-2b', 'hook_name': 'blocks.20.hook_resid_post', 'hook_layer': 20, 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'monology/pile-uncopyrighted', 'dataset_trust_remote_code': True, 'normalize_activations': None, 'dtype': 'float32', 'device': 'mps', 'sae_lens_training_version': None, 'activation_fn_kwargs': {}, 'neuronpedia_id': None, 'model_from_pretrained_kwargs': {}, 'seqpos_slice': (None,)}


In [11]:
# Download the exported feature explanations for this SAE from Neuronpedia.
url = "https://www.neuronpedia.org/api/explanation/export?modelId=gemma-2-2b&saeId=20-gemmascope-res-16k"
headers = {"Content-Type": "application/json"}

# A timeout stops the cell from hanging indefinitely on a stalled connection,
# and raise_for_status() reports HTTP errors here instead of letting them show
# up later as an opaque JSON/DataFrame failure.
response = requests.get(url, headers=headers, timeout=60)
response.raise_for_status()

# Build the dataframe in one chain (no in-place mutation, so re-running is safe):
#   - the API's "index" column is the feature id -> rename to "feature", cast to int
#   - lowercase descriptions; `.str.lower()` leaves missing descriptions as NaN
#     instead of raising the way `x.lower()` would on a null entry
data = response.json()
explanations_df = (
    pd.DataFrame(data)
    .rename(columns={"index": "feature"})
    .astype({"feature": int})
    .assign(description=lambda df: df["description"].str.lower())
)

print(explanations_df.head())

      modelId                  layer  feature  \
0  gemma-2-2b  20-gemmascope-res-16k    14403   
1  gemma-2-2b  20-gemmascope-res-16k    14403   
2  gemma-2-2b  20-gemmascope-res-16k    14403   
3  gemma-2-2b  20-gemmascope-res-16k    10131   
4  gemma-2-2b  20-gemmascope-res-16k    10133   

                                         description  \
0  phrases or sentences that introduce lists, exa...   
1  references to numerical sports scores and resu...   
2  text related to sports accomplishments and sta...   
3  phrases referring to being fluent in a languag...   
4  words related to scientific studies and proces...   

         explanationModelName            typeName  
0  claude-3-5-sonnet-20240620  oai_token-act-pair  
1              gemini-1.5-pro  oai_token-act-pair  
2                 gpt-4o-mini  oai_token-act-pair  
3            gemini-1.5-flash  oai_token-act-pair  
4            gemini-1.5-flash  oai_token-act-pair  


In [9]:
# Run a single prompt through the model with the SAE passed in, keeping the cache.
prompt = "Hello, how are you? I am a human but in a world of AI, I am a robot."

_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

# List the name and shape of every cached entry whose key mentions "sae".
sae_entry_shapes = []
for entry_name, entry_value in cache.items():
    if "sae" in entry_name:
        sae_entry_shapes.append((entry_name, entry_value.shape))
print(sae_entry_shapes)

[('blocks.20.hook_resid_post.hook_sae_input', torch.Size([1, 23, 2304])), ('blocks.20.hook_resid_post.hook_sae_acts_pre', torch.Size([1, 23, 16384])), ('blocks.20.hook_resid_post.hook_sae_acts_post', torch.Size([1, 23, 16384])), ('blocks.20.hook_resid_post.hook_sae_recons', torch.Size([1, 23, 2304])), ('blocks.20.hook_resid_post.hook_sae_output', torch.Size([1, 23, 2304]))]


In [12]:
# Inspect the SAE post-activation values cached at the final token of the prompt.
layer_hook = 'blocks.20.hook_resid_post.hook_sae_acts_post'
activations = cache[layer_hook][0, -1, :]  # first batch row, last position

# Strongest k features at that position
k = 5
top_vals, top_inds = torch.topk(activations, k)

print(f"Top {k} features firing at the last token position for {layer_hook}:")
for feature_id, activation_value in zip(top_inds.tolist(), top_vals.tolist()):
    print(f'\nFeature {feature_id}: Activation = {activation_value:.4f}')

    # A feature can have several rows (one per explanation source), so print
    # every distinct description rather than only the first.
    descriptions = explanations_df.loc[
        explanations_df['feature'] == feature_id, 'description'
    ]
    if descriptions.empty:
        print("  Explanation: Not found in the loaded explanations.")
        continue
    for desc in descriptions.unique():
        print(f"  Explanation: {desc}")

Top 5 features firing at the last token position for blocks.20.hook_resid_post.hook_sae_acts_post:

Feature 1858: Activation = 71.8360
  Explanation: punctuation marks and sentence endings

Feature 6631: Activation = 57.2718
  Explanation: the beginning of a text or important markers in a document

Feature 8450: Activation = 56.2794
  Explanation: keywords related to the development and implications of artificial intelligence and autonomous technologies

Feature 2229: Activation = 53.6663
  Explanation:  punctuation marks, particularly periods and dollar signs

Feature 11133: Activation = 51.0430
  Explanation: references to personal experiences and advice


In [14]:
# --- Feature Steering Functions ---

# Although we fetch max_act from Neuronpedia, keep the function for reference
def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    """
    Scan token batches and return the largest activation seen for one SAE feature.

    The result can be used to calibrate how much of the feature to add when
    steering.

    Args:
        model: Object exposing `run_with_cache(tokens, stop_at_layer=..., names_filter=...)`.
        sae: SAE exposing `encode(...)` and `cfg.hook_name` / `cfg.hook_layer`.
        activation_store: Object exposing `get_batch_tokens()`. If None, a
            warning is printed and 0.0 is returned.
        feature_idx: Index of the feature along the last axis of the encoded
            activations.
        num_batches: Number of batches to draw from `activation_store`.

    Returns:
        The maximum activation observed as a float, never below the initial 0.0.
    """
    # Guard: nothing to scan without a source of tokens.
    if activation_store is None:
        print("Warning: activation_store not provided to find_max_activation.")
        return 0.0

    running_max = 0.0
    progress = tqdm(range(num_batches))
    for _ in progress:
        batch_tokens = activation_store.get_batch_tokens()

        # Only the SAE's hook point is needed, so the forward pass stops right
        # after that layer and caches just that activation.
        hook_name = sae.cfg.hook_name
        _, batch_cache = model.run_with_cache(
            batch_tokens,
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[hook_name],
        )

        # Encode into feature space, then merge the first two axes so each row
        # is one position's feature activations.
        per_position_acts = sae.encode(batch_cache[hook_name]).flatten(0, 1)
        if per_position_acts.shape[0] > 0:  # skip empty batches
            batch_peak = per_position_acts[:, feature_idx].max().item()
            running_max = max(running_max, batch_peak)

        progress.set_description(f"Max activation: {running_max:.4f}")

    return running_max


def steering(
    activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0
    ):
    """
    Forward-hook function that adds a scaled steering vector to `activations`.

    Args:
        activations: Tensor at the hook point.
        hook: Not used in the body; it is in the signature so the function can
            be registered as a hook.
        steering_strength: Multiplier applied on top of `max_act`.
        steering_vector: Direction to add. When None the input is returned
            unchanged.
        max_act: Base scale for the steering vector.

    Returns:
        `activations + max_act * steering_strength * steering_vector` as a new
        tensor, or the original `activations` object when no vector is given.
    """
    if steering_vector is not None:
        # Match the device and dtype of the activations before adding.
        direction = steering_vector.to(activations.device, dtype=activations.dtype)
        scale = max_act * steering_strength
        # Out-of-place add: any activation already present is kept, and the
        # steering term is added on top of it.
        activations = activations + scale * direction
    return activations


def generate_with_steering(
    model,
    sae,
    prompt,
    steering_feature,
    max_act,
    steering_strength=1.0,
    max_new_tokens=95,
    ):
    """
    Generate a continuation of `prompt` while steering toward one SAE feature.

    The row `sae.W_dec[steering_feature]` is added, scaled by
    `max_act * steering_strength`, at the hook point `sae.cfg.hook_name` for
    the duration of generation.

    Args:
        model: Model exposing `to_tokens`, `hooks`, `generate`, `tokenizer`
            and `cfg.device`.
        sae: SAE exposing `W_dec` and `cfg.hook_name` / `cfg.prepend_bos`.
        prompt: Text to continue.
        steering_feature: Index into `sae.W_dec` selecting the steering vector.
        max_act: Base scale passed through to `steering`.
        steering_strength: Extra multiplier passed through to `steering`.
        max_new_tokens: Number of new tokens requested from `model.generate`.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        skipped.
    """
    prepend_bos = sae.cfg.prepend_bos
    tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)

    # Bind the steering parameters so the hook only needs (activations, hook).
    add_feature = partial(
        steering,
        steering_vector=sae.W_dec[steering_feature],
        steering_strength=steering_strength,
        max_act=max_act,
    )

    # EOS stopping is switched off on MPS only.
    # NOTE(review): presumably a workaround for an MPS-specific issue - confirm.
    stop_at_eos = str(model.cfg.device) != "mps"

    # The hook is active only inside this context.
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, add_feature)]):
        generated = model.generate(
            tokens,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=stop_at_eos,
            prepend_bos=prepend_bos,
        )

    # Decode the whole first sequence returned by generate.
    return model.tokenizer.decode(generated[0], skip_special_tokens=True)


In [15]:
# --- Steering Example ---

# Choose a feature to steer (e.g., the AI feature 8450 found earlier)
steering_feature = 8450
# You can also try feature 1858 (punctuation/endings) or 6631 (text beginnings)

# Settings shared by the baseline and the steered run, defined once so the two
# generations (and the printed label) cannot drift apart.
steering_strength = 2.0  # Example strength - adjust as needed
max_new_tokens = 95
temperature = 0.7  # must match the value hard-coded in generate_with_steering
top_p = 0.9  # must match the value hard-coded in generate_with_steering
seed = 0  # sampling is stochastic; reseeding makes both runs reproducible
request_timeout_s = 30  # connect/read timeout so a stalled request cannot hang the cell

# --- Get Max Activation from Neuronpedia API ---
neuronpedia_model_id = sae.cfg.model_name # Should be 'gemma-2-2b'
# Construct the layer/SAE ID string used by Neuronpedia API
# This format might need adjustment based on the exact SAE release/naming in Neuronpedia
neuronpedia_layer_id = "20-gemmascope-res-16k"
feature_url = f"https://www.neuronpedia.org/api/feature/{neuronpedia_model_id}/{neuronpedia_layer_id}/{steering_feature}"

print(f"Fetching feature data from: {feature_url}")
max_act = 1.0 # Default value if API fails or key missing
# Bound before the try so no handler can hit an undefined name, whichever
# statement inside the try raised.
feature_data = {}
raw_max_act = None
try:
    response = requests.get(feature_url, timeout=request_timeout_s)
    response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
    feature_data = response.json()
    # NOTE(review): 'max_activation' is the key the original code looked for;
    # 'maxActApprox' is believed to be the field Neuronpedia actually returns
    # - confirm against the API docs. The first key present wins.
    max_act_key = next(
        (key for key in ("max_activation", "maxActApprox") if key in feature_data),
        None,
    )
    if max_act_key is not None:
        raw_max_act = feature_data[max_act_key]
        max_act = float(raw_max_act)
        print(f"Successfully fetched max activation: {max_act:.4f}")
    else:
        print(f"Warning: 'max_activation' key not found in Neuronpedia API response for feature {steering_feature}. Available keys: {list(feature_data.keys())}. Using default value: {max_act}")

except requests.exceptions.RequestException as e:
    print(f"Error fetching data from Neuronpedia: {e}. Using default max_act={max_act}")
except (TypeError, ValueError) as e:
    # float(None) raises TypeError and float("abc") raises ValueError. Some
    # versions of requests also raise a plain ValueError for a non-JSON body,
    # in which case raw_max_act is still None.
    print(f"Error parsing max activation ({raw_max_act!r}) from Neuronpedia response: {e}. Using default max_act={max_act}")
except Exception as e:
    # Deliberately best-effort: the demo still runs with the default scale.
    print(f"An unexpected error occurred: {e}. Using default max_act={max_act}")

# --- Generate Text ---
prompt = "Once upon a time"

# Generate text without steering for comparison
print("\nGenerating text without steering...")
# EOS stopping is disabled on MPS, mirroring generate_with_steering
current_device_str = str(model.cfg.device)
stop_eos_normal = current_device_str != "mps"

torch.manual_seed(seed)  # same RNG starting point as the steered run below
normal_text_output = model.generate(
    prompt,
    max_new_tokens=max_new_tokens,
    stop_at_eos=stop_eos_normal,
    prepend_bos=sae.cfg.prepend_bos,
    temperature=temperature,
    top_p=top_p,
)
# Decode the full output sequence
normal_text = model.tokenizer.decode(normal_text_output[0], skip_special_tokens=True)
print("Normal text (without steering):")
print(normal_text)
print("-" * 30) # Separator

# Generate text with steering
print("\nGenerating text with steering...")
torch.manual_seed(seed)  # reseed so the only difference from the baseline is the hook
steered_text = generate_with_steering(
    model,
    sae,
    prompt,
    steering_feature,
    max_act,
    steering_strength=steering_strength,
    max_new_tokens=max_new_tokens,
)
print(f"Steered text (feature {steering_feature}, strength {steering_strength}, max_act {max_act:.4f}):")
print(steered_text)

Fetching feature data from: https://www.neuronpedia.org/api/feature/gemma-2-2b/20-gemmascope-res-16k/8450

Generating text without steering...


  2%|▏         | 2/95 [02:06<1:37:41, 63.03s/it]


KeyboardInterrupt: 