In [1]:
from functools import partial
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
import os

os.environ['HF_TOKEN'] = "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"

layer = 6

model = HookedTransformer.from_pretrained("gemma-2b", device = "cuda")

sae, cfg_dict, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb",
    sae_id = f"blocks.{layer}.hook_resid_post",
    device = "cuda"
)

hook_point = sae.cfg.hook_name
print(hook_point)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer
blocks.6.hook_resid_post


In [2]:
sv_prompt = " The Golden Gate Bridge"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 3))

# https://www.neuronpedia.org/gemma-2b/6-res-jb/5192

tensor([[    2,   714, 17489, 22352, 16125]], device='cuda:0')
torch.return_types.topk(
values=tensor([[[72.5033, 70.9076, 68.8184],
         [37.9054, 31.1921, 15.6136],
         [65.9162, 14.2040, 13.3026],
         [22.8027, 21.7161, 17.3864],
         [43.6440, 12.6738, 10.7451]]], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([[[ 3390, 15881,  5347],
         [ 6518, 13743,  1959],
         [ 1571, 12529, 15173],
         [12773, 10200, 15173],
         [ 5192, 15173, 12030]]], device='cuda:0'))


In [8]:
FEATURE_ID = 5192 # of or relating to bridges
steering_vector = sae.W_dec[FEATURE_ID]

example_prompt = "The new bill apportions $3 million towards"

coeff = 300
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

sae_out.shape[1]

5

In [24]:
# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out

def no_op_hook(mlp_out, hook):
    return mlp_out

def steering_hook(resid_pre, hook, *, steering_vector: torch.Tensor, steering_on: bool = True, coeff: float = 1.0):
    if resid_pre.shape[1] == 1:
        return
    
    if steering_on:
        resid_pre[:, :resid_pre.shape[1] - 1, :] += coeff * steering_vector


string = "The new bill apportions $3 million towards" 
for i in range(10):
    tokens = model.to_tokens(string)
    out = model.run_with_hooks(tokens, fwd_hooks=[
        (
            sae.cfg.hook_name,
            partial(
                steering_hook, 
                steering_vector=steering_vector, 
                steering_on=True, 
                coeff=coeff
            )
            # partial(no_op_hook)
        )
    ])

    next_token = model.to_string(out[0][-1].argmax(-1).item())
    string += next_token

    print(string)

The new bill apportions $3 million towards the
The new bill apportions $3 million towards the bridge
The new bill apportions $3 million towards the bridge,
The new bill apportions $3 million towards the bridge, which
The new bill apportions $3 million towards the bridge, which is
The new bill apportions $3 million towards the bridge, which is the
The new bill apportions $3 million towards the bridge, which is the only
The new bill apportions $3 million towards the bridge, which is the only way
The new bill apportions $3 million towards the bridge, which is the only way for
The new bill apportions $3 million towards the bridge, which is the only way for the
