In [1]:
from IPython.display import clear_output
!pip install transformer-lens jaxtyping datasets sae-lens circuitsvis

clear_output()

In [2]:
import os
import json
import torch
import einops
import random

from functools import partial

from transformer_lens.hook_points import HookPoint

from torch import Tensor
from rich.table import Table
from jaxtyping import Float, Int
from rich import print as rprint

from tqdm.notebook import tqdm

from typing import Callable, Tuple

from sae_lens import SAE, load_model, ActivationsStore, HookedSAETransformer
from datasets import load_dataset
from transformer_lens import HookedTransformer
from safetensors.torch import load_file

from sae_lens import LanguageModelSAERunnerConfig, CacheActivationsRunnerConfig

# Utils

In [3]:
@torch.inference_mode()
def highest_activating_tokens(
    tokens: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    autoencoder: SAE,
    feature_idx: int,
    autoencoder_B: bool = False,
    k: int = 20,
) -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Returns the indices & values for the highest-activating tokens in the given batch of data.
    '''
    batch_size, seq_len = tokens.shape
    # instance_idx = 1 if autoencoder_B else 0/

    # Get the post activations from the clean run
    cache = model.run_with_cache(tokens, names_filter=["blocks.8.hook_resid_post"])[1]
    post = cache["blocks.8.hook_resid_post"]
    post_reshaped = einops.rearrange(post, "batch seq d_model -> (batch seq) d_model")

    del cache
    # print(f"Shape of tokens is {tokens.shape}")
    # print(f"Shape of post is {post.shape}")

    # Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)
    # This code is copied from the first part of the 'forward' method of the AutoEncoder class
    h_cent = post_reshaped - autoencoder.b_dec
    acts = einops.einsum(
        h_cent, autoencoder.W_enc[:, feature_idx],
        "batch_size n_input_ae, n_input_ae -> batch_size"
    )
    print(f"Feature index is {feature_idx}")

    # Get the top k largest activations
    top_acts_values, top_acts_indices = acts.topk(k)

    del acts, post_reshaped, h_cent

    # Convert the indices into (batch, seq) indices
    top_acts_batch = top_acts_indices // seq_len
    top_acts_seq = top_acts_indices % seq_len

    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values


def display_top_sequences(top_acts_indices, top_acts_values, tokens):
    table = Table("Sequence", "Activation", title="Tokens which most activate this feature")
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq = ""
        for i in range(max(seq_idx-5, 0), min(seq_idx+5, tokens.shape[1])):
            new_str_token = model.to_single_str_token(tokens[batch_idx, i].item()).replace("\n", "\\n")
            # Highlight the token with the high activation
            if i == seq_idx: new_str_token = f"[b u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Print the sequence, and the activation value
        table.add_row(seq, f'{value:.2f}')
    rprint(table)

def steering_hook(
    activations: Float[Tensor, "batch pos d_in"],
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
    sequence positions.
    """
    return activations + steering_coefficient * sae.W_dec[latent_idx]


GENERATE_KWARGS = dict(temperature=0.5, freq_penalty=2.0, verbose=False)


def generate_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    prompt: str,
    latent_idx: int,
    steering_coefficient: float = 1.0,
    max_new_tokens: int = 50,
):
    """
    Generates text with steering. A multiple of the steering vector (the decoder weight for this latent) is added to
    the last sequence position before every forward pass.
    """
    _steering_hook = partial(
        steering_hook,
        sae=sae,
        latent_idx=latent_idx,
        steering_coefficient=steering_coefficient,
    )

    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, _steering_hook)]):
        output = model.generate(prompt, max_new_tokens=max_new_tokens, **GENERATE_KWARGS)

    return output

def find_max_activation(
    model: HookedTransformer,
    sae: SAE,
    act_store: ActivationsStore,
    feature_idx: int,
    num_batches: int = 5,
) -> float:
    """
    Find the maximum activation for a given feature index, by iterating through
    batches in the activation store and taking max over all of them. This is
    useful for calibrating the right amount of the feature to add.

    """
    max_act = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        toks = actstore.get_batch_tokens().to('cuda:2')

        cache = model.run_with_cache(toks, names_filter=["blocks.8.hook_resid_post"], stop_at_layer=9)[1]
        post = cache["blocks.8.hook_resid_post"]
        post_reshaped = einops.rearrange(post, "batch seq d_model -> (batch seq) d_model")


        h_cent = post_reshaped - sae.b_dec
        acts = einops.einsum(
            h_cent, sae.W_enc[:, feature_idx],
            "batch_size n_input_ae, n_input_ae-> batch_size"
        )

        act, _ = acts.topk(1)
        max_act = max(act, max_act)
        del cache, acts, toks
    return max_act

In [5]:
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = CacheActivationsRunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name= "gpt2-small",
    model_class_name = "HookedTransformer",
    hook_name= "blocks.0.hook_resid_post",
    hook_layer=8,
    dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
    # dataset_trust_remote_code: bool | None = None
    streaming = True,
    is_dataset_tokenized = True,
    context_size = 1024,
    new_cached_activations_path = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    ),
    # dont' specify this since you don't want to load from disk with the cache runner.
    cached_activations_path = None,
    # SAE Parameters
    d_in = 768,

    # Activation Store Parameters
    n_batches_in_buffer = 20,
    training_tokens = 2_000_000,
    store_batch_size_prompts = 32,
    train_batch_size_tokens = 4096,
    normalize_activations = "none",  # should always be none for activation caching

    # Misc
    device = "cuda:2",
    act_store_device = "with_model",  # will be set by post init if with_model
    seed = 42,
    dtype = "float32",
    prepend_bos = True,
    autocast_lm = False # autocast lm during activation fetching
)

# GPT2-Small

In [4]:
model = HookedTransformer.from_pretrained('gpt2-small').to('cuda:3')



Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda:3


In [6]:
# cfg = sae.cfg
actstore = ActivationsStore.from_config(cfg=cfg,model=model)

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

In [7]:
all_tokens = actstore.get_batch_tokens().to('cuda:2')

all_tokens.shape

torch.Size([32, 1024])

In [12]:
sae, cfg_dict, sparsity = SAE.from_pretrained("israel-adewuyi/GPT2_small_sae", "resid_post/layer_8/width_25K/blocks.8.hook_resid_post")
sae.to('cuda:3')

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [17]:
tokens = all_tokens
feature_idx = random.randint(0, sae.cfg.d_sae)
top_acts_indices, top_acts_values = highest_activating_tokens(tokens, model, sae, feature_idx=feature_idx, autoencoder_B=False)
display_top_sequences(top_acts_indices, top_acts_values, tokens)
del top_acts_indices, tokens, top_acts_values

Feature index is 12608


In [9]:
'''
    1. 1115 - successful, winning, accomplishments
    2. 1200 - more

    3. 22373 - is/are
    4. 21424 - from

    5. 18664 - admin, authority

    6. 13397 - Man, 2015

    7. 19826 - related to laws, rules, code, constitution

    8, 11943 - fires on the token after 'every'

    9. 19606 - should be interesting to explore
'''

"\n    1. 1115 - successful, winning, accomplishments\n    2. 1200 - more\n\n    3. 22373 - is/are\n    4. 21424 - from\n\n    5. 18664 - admin, authority\n\n    6. 13397 - Man, 2015\n\n    7. 19826 - related to laws, rules, code, constitution\n\n    8, 11943 - fires on the token after 'every'\n\n    9. 19606 - should be interesting to explore\n"

# Love and Hate

In [18]:
text = "What is love? It's the force that drives all of human actions, along sides greed and curiosity. I hate you"
toks = model.to_tokens(text)
for i, txt in enumerate(model.to_str_tokens(text)):
    print(i, txt)

0 <|endoftext|>
1 What
2  is
3  love
4 ?
5  It
6 's
7  the
8  force
9  that
10  drives
11  all
12  of
13  human
14  actions
15 ,
16  along
17  sides
18  greed
19  and
20  curiosity
21 .
22  I
23  hate
24  you


In [19]:
cache = model.run_with_cache(toks, names_filter=["blocks.8.hook_resid_post"])[1]
post = cache["blocks.8.hook_resid_post"]
print(f"Shape of post --> {post.shape}")
post_reshaped = einops.rearrange(post, "batch seq d_model -> (batch seq) d_model")
print(f"Shape of post_reshaped --> {post_reshaped.shape}")
del cache, post

Shape of post --> torch.Size([1, 25, 768])
Shape of post_reshaped --> torch.Size([25, 768])


In [20]:
h_cent = post_reshaped - sae.b_dec
acts = einops.einsum(
    h_cent, sae.W_enc,
    "batch_size n_input_ae, n_input_ae d_sae-> batch_size d_sae"
)
del h_cent, post_reshaped
acts.shape

torch.Size([25, 24576])

In [21]:
# acts[3].topk(5)

top_acts_values, top_acts_indices = acts[3].topk(5)
top_acts_values, top_acts_indices

(tensor([10.6604,  4.2406,  4.1642,  3.6488,  3.4874], device='cuda:3',
        grad_fn=<TopkBackward0>),
 tensor([21741, 17760, 24448, 11318,  8294], device='cuda:3'))

In [22]:
top_acts_values, top_acts_indices = acts[23].topk(5)
top_acts_values, top_acts_indices

(tensor([5.8020, 5.6348, 3.5676, 3.3164, 3.0523], device='cuda:3',
        grad_fn=<TopkBackward0>),
 tensor([16487, 22262,  3104, 16993, 24153], device='cuda:3'))

In [34]:
feature_idx = 24448

In [35]:
max_act = 20  # find_max_activation(gemma_2_2b, gemma_2_2b_sae, gemma_2_2b_act_store, feature_idx)

prompt = "When I think about the future,"

no_steering_output = model.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

# steering_output = generate_with_steering(
#     model, sae, prompt, feature_idx, steering_coefficient=10.0
# )

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal text", no_steering_output)
for i in tqdm(range(3)):
    table.add_row(
        f"Steered text {i}",
        generate_with_steering(model, sae, prompt, feature_idx, steering_coefficient=80.0),
    )
rprint(table)

  0%|          | 0/3 [00:00<?, ?it/s]

# Golden


In [65]:
text = [
    "her performance on the golden globe award.",
    "my favourite team is the golden state warriors",
    'eventually, it kills the golden goose'
]

text_toks = model.to_tokens(text)
text_toks.shape, text_toks[0].unsqueeze(0).shape, model.to_str_tokens(text[2])

(torch.Size([3, 9]),
 torch.Size([1, 9]),
 ['<|endoftext|>',
  'event',
  'ually',
  ',',
  ' it',
  ' kills',
  ' the',
  ' golden',
  ' goose'])

In [66]:
def get_act_and_val(text_toks):
    cache = model.run_with_cache(text_toks, names_filter=["blocks.8.hook_resid_post"])[1]
    post = cache["blocks.8.hook_resid_post"]
    print(f"Shape of post --> {post.shape}")
    post_reshaped = einops.rearrange(post, "batch seq d_model -> (batch seq) d_model")
    print(f"Shape of post_reshaped --> {post_reshaped.shape}")
    del cache, post
    h_cent = post_reshaped - sae.b_dec
    acts = einops.einsum(
        h_cent, sae.W_enc,
        "batch_size n_input_ae, n_input_ae d_sae-> batch_size d_sae"
    )
    del h_cent, post_reshaped
    acts.shape
    act_val, idx = acts[7].topk(5)
    print(act_val, idx)

In [67]:
get_act_and_val(text_toks[2].unsqueeze(0))

Shape of post --> torch.Size([1, 9, 768])
Shape of post_reshaped --> torch.Size([9, 768])
tensor([12.4077,  2.9669,  2.9256,  2.8421,  2.2421], device='cuda:2',
       grad_fn=<TopkBackward0>) tensor([22693, 11632, 24323, 14709,  7713], device='cuda:2')


In [None]:
[22693, 14709, 11632, 24323,  4822]
[22693, 14709, 11632, 23671, 24323]
[22693, 11632, 24323, 14709,  7713]

In [68]:
find_max_activation(model=model, sae=sae, act_store=actstore, feature_idx=22693, num_batches=25)

  0%|          | 0/25 [00:00<?, ?it/s]

tensor([15.5017], device='cuda:2', grad_fn=<TopkBackward0>)

In [78]:
max_act = 5  # find_max_activation(gemma_2_2b, gemma_2_2b_sae, gemma_2_2b_act_store, feature_idx)
feature_idx = 22693
prompt = "How about that spoon?,"

no_steering_output = model.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

steering_output = generate_with_steering(
    model, sae, prompt, feature_idx, max_act, steering_strength=3.0
)

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal text", no_steering_output)
for i in tqdm(range(3)):
    table.add_row(
        f"Steered text {i}",
        generate_with_steering(model, sae, prompt, feature_idx, max_act, steering_strength=2.5),
    )
rprint(table)

  0%|          | 0/3 [00:00<?, ?it/s]

# Gemma 2B

In [3]:
model = HookedTransformer.from_pretrained('gemma-2-2b-it').to('cuda:2')



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b-it into HookedTransformer
Moving model to device:  cuda:2


In [4]:
sae, cfg_dict, sparsity = SAE.from_pretrained("israel-adewuyi/Gemma2-2B-SAE", "resid_post/layer_23/width_37K/blocks.23.hook_resid_post")
sae.to('cuda:2')

(…)h_37K/blocks.23.hook_resid_post/cfg.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/680M [00:00<?, ?B/s]

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [9]:
total_training_steps = 50_000
batch_size = 512
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

layer = 23
width = 37

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="gemma-2-2b-it", 
    hook_name="blocks.23.hook_resid_post", 
    hook_layer=23,  
    d_in=2304,  
    dataset_path="Skylion007/openwebtext",  
    # dataset_path="NeelNanda/c4-code-20k",
    is_dataset_tokenized=False,
    streaming=True, 
    # SAE Parameters
    architecture="gated",
    mse_loss_normalization=None,  
    expansion_factor=16,  
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  
    adam_beta1=0.9, 
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  
    lr_decay_steps=lr_decay_steps,  
    l1_coefficient=5,  
    l1_warm_up_steps=l1_warm_up_steps, 
    lp_norm=1.0,  
    train_batch_size_tokens=batch_size,
    context_size=512, #Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64, 
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000, 
    dead_feature_window=1000,  
    dead_feature_threshold=1e-4,  
    # WANDB
    log_to_wandb=True,
    wandb_project="Autoencoders_sae-lens",
    wandb_log_frequency=20,
    eval_every_n_wandb_logs=10,
    # Misc
    device='cuda:2',
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)

Run name: 36864-L1-5-LR-5e-05-Tokens-2.560e+07
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 50000
Total wandb updates: 2500
n_tokens_per_feature_sampling_window (millions): 262.144
n_tokens_per_dead_feature_window (millions): 262.144
We will reset the sparsity calculation 50 times.
Number tokens in sparsity calculation window: 5.12e+05


In [11]:
actstore = ActivationsStore.from_config(cfg=cfg,model=model)



In [24]:
tokens = actstore.get_batch_tokens()

In [30]:
feature_idx = random.randint(0, sae.cfg.d_sae)
top_acts_indices, top_acts_values = highest_activating_tokens(tokens, model, sae, feature_idx=feature_idx, autoencoder_B=False)
display_top_sequences(top_acts_indices, top_acts_values, tokens)
del top_acts_indices, top_acts_values

Feature index is 21789


In [37]:
text = [
    "her performance on the golden globe award.",
    "my favourite team is the golden state warriors",
    'eventually, it kills the golden goose'
]

text_toks = model.to_tokens(text)
text_toks.shape, text_toks[0].unsqueeze(0).shape, model.to_str_tokens(text[0])

(torch.Size([3, 9]),
 torch.Size([1, 9]),
 ['<bos>',
  'her',
  ' performance',
  ' on',
  ' the',
  ' golden',
  ' globe',
  ' award',
  '.'])

In [38]:
def get_act_and_val(text_toks):
    cache = model.run_with_cache(text_toks, names_filter=["blocks.8.hook_resid_post"])[1]
    post = cache["blocks.8.hook_resid_post"]
    print(f"Shape of post --> {post.shape}")
    post_reshaped = einops.rearrange(post, "batch seq d_model -> (batch seq) d_model")
    print(f"Shape of post_reshaped --> {post_reshaped.shape}")
    del cache, post
    h_cent = post_reshaped - sae.b_dec
    acts = einops.einsum(
        h_cent, sae.W_enc,
        "batch_size n_input_ae, n_input_ae d_sae-> batch_size d_sae"
    )
    del h_cent, post_reshaped
    acts.shape
    act_val, idx = acts[5].topk(5)
    print(act_val, idx)

In [None]:
[12326, 35741, 22727, 26171, 11608]
[35741, 12326, 22727, 26171,  7116]
[12326, 35741, 22727, 26171, 15298]

In [39]:
get_act_and_val(text_toks[0].unsqueeze(0))

Shape of post --> torch.Size([1, 9, 2304])
Shape of post_reshaped --> torch.Size([9, 2304])
tensor([1.1850, 1.0809, 0.8252, 0.7400, 0.7004], device='cuda:2',
       grad_fn=<TopkBackward0>) tensor([12326, 35741, 22727, 26171, 15298], device='cuda:2')


In [68]:
find_max_activation(model=model, sae=sae, act_store=actstore, feature_idx=22727, num_batches=25)

  0%|          | 0/25 [00:00<?, ?it/s]

tensor([1.2520], device='cuda:2', grad_fn=<TopkBackward0>)

In [76]:
max_act = 80  # find_max_activation(gemma_2_2b, gemma_2_2b_sae, gemma_2_2b_act_store, feature_idx)
feature_idx = 12326
prompt = "When I look at myself in the mirror, I see"

no_steering_output = model.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal", no_steering_output)
for i in tqdm(range(3), "Generating steered examples..."):
    table.add_row(
        f"Steered #{i}",
        generate_with_steering(
            model,
            sae,
            prompt,
            feature_idx,
            steering_coefficient=60,  # roughly 1.5-2x the latent's max activation
        ).replace("\n", "↵"),
    )
rprint(table)

Generating steered examples...:   0%|          | 0/3 [00:00<?, ?it/s]

# GemmaScope

In [4]:
# USING_GEMMA = os.environ.get("HUGGINGFACE_KEY")
device='cuda:3'

In [5]:
gemma_2_2b = HookedSAETransformer.from_pretrained("gemma-2-9b-it", device=device)

gemmascope_sae_release = "gemma-scope-9b-it-res"
gemmascope_sae_id = "layer_20/width_16k/average_l0_14"

gemma_2_2b_sae = SAE.from_pretrained(gemmascope_sae_release, gemmascope_sae_id, device=str(device))[0]



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-9b-it into HookedTransformer


params.npz:   0%|          | 0.00/470M [00:00<?, ?B/s]

In [6]:
latent_idx = 12082

In [8]:
prompt = "When I look at myself in the mirror, I see"

no_steering_output = gemma_2_2b.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

table = Table(show_header=False, show_lines=True, title="Steering Output")
table.add_row("Normal", no_steering_output)
for i in tqdm(range(3), "Generating steered examples..."):
    table.add_row(
        f"Steered #{i}",
        generate_with_steering(
            gemma_2_2b,
            gemma_2_2b_sae,
            prompt,
            latent_idx,
            steering_coefficient=240.0,  # roughly 1.5-2x the latent's max activation
        ).replace("\n", "↵"),
    )
rprint(table)

Generating steered examples...:   0%|          | 0/3 [00:00<?, ?it/s]