In [1]:
from functools import partial
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
import os

# SECURITY: a Hugging Face access token used to be hardcoded on this line.
# Credentials must never be committed in a notebook; the leaked token should be
# revoked. Provide the token from outside the notebook instead, e.g. export
# HF_TOKEN before starting the kernel or run `huggingface-cli login` once.
# Nothing is written to os.environ here, so an externally supplied HF_TOKEN
# (or cached login) is used as-is.

# Residual-stream layer whose SAE is loaded below.
layer = 12

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("gemma-2b-it", device=device)

sae, cfg_dict, _ = SAE.from_pretrained(
    release="gemma-2b-it-res-jb",
    sae_id=f"blocks.{layer}.hook_resid_post",
    device=device,
)

# Name of the model hook this SAE reads from; later cells index the cache with it.
hook_point = sae.cfg.hook_name
print(hook_point)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
blocks.12.hook_resid_post


In [8]:
import json

def get_relevant_features(*, prompt: str, n: int = 3, unique_keys: bool = False):
    """Return the top-``n`` SAE feature indices for every token of ``prompt``.

    The prompt is tokenized once (with a BOS token prepended), run through the
    global ``model`` and the activation at the global ``hook_point`` is encoded
    with the global ``sae``. For each position the indices of the ``n`` largest
    feature activations are collected.

    Args:
        prompt: Text to analyse.
        n: Number of top-activating feature indices to keep per token.
        unique_keys: When False (default, original behaviour) the result is
            keyed by the token string, so a token that occurs several times in
            the prompt keeps only its LAST occurrence. When True, keys are
            ``"<position>:<token>"`` so every position is preserved.

    Returns:
        dict mapping token key -> list of ``n`` feature indices (ints), in the
        order returned by ``torch.topk``.
    """
    # Tokenize exactly once so token strings and activations cannot get out of
    # step with each other (previously the text was tokenized twice with
    # different prepend_bos settings).
    tokens = model.to_tokens(prompt, prepend_bos=True)

    # Inference only: skip autograd bookkeeping and cache just the one
    # activation the SAE needs instead of every hook point in the model.
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=hook_point)
        feature_acts = sae.encode(cache[hook_point])
    del cache

    # Batch element 0 -> one list of n feature indices per position.
    top_feature_ids = torch.topk(feature_acts, n).indices[0].cpu().tolist()

    token_strs = [model.to_string(tok) for tok in tokens[0]]
    if unique_keys:
        token_strs = [f"{pos}:{tok}" for pos, tok in enumerate(token_strs)]

    # Example Neuronpedia feature page (note: this URL is for a different
    # model/layer than the SAE loaded in this notebook):
    # https://www.neuronpedia.org/gemma-2b/6-res-jb/5192
    return dict(zip(token_strs, top_feature_ids))
    
# Smoke test on a tiny code snippet: pretty-print the top-3 feature ids per token.
print(
    json.dumps(
        get_relevant_features(prompt="def x(y: int): return y[0]", n=3),
        indent=2,
    )
)

{
  "<bos>": [
    11609,
    15572,
    13161
  ],
  "def": [
    15956,
    11667,
    10327
  ],
  " x": [
    5195,
    11476,
    6692
  ],
  "(": [
    5250,
    13726,
    15165
  ],
  "y": [
    10593,
    8026,
    6692
  ],
  ":": [
    3147,
    2776,
    5530
  ],
  " int": [
    11391,
    12019,
    5530
  ],
  "):": [
    10591,
    978,
    5530
  ],
  " return": [
    3579,
    5530,
    13350
  ],
  " y": [
    13350,
    6692,
    14694
  ],
  "[": [
    13141,
    8514,
    5530
  ],
  "0": [
    334,
    13141,
    13350
  ],
  "]": [
    154,
    13350,
    2747
  ]
}


In [9]:
# Sample phishing / spam message used as the analysis prompt. The text is kept
# verbatim (including the defanged "hxxps" link) because it is tokenized as-is.
spam_email = """From: Shafaq <Chyannestudio867@hotmail.com(link sends e-mail)>
Subject: Attention: website.berkeley.edu DMCA Copyright Infringement Notice
To: Recipient@berkeley.edu(link sends e-mail)


Hello!

My name is Shafaq.

Your website or a website that your company hosts is infringing on a
copyright-protected images owned by myself.

Take a look at this document with the links to my images you used at
website.berkeley.edu and my earlier publications to get the evidence of
my copyrights.

Download it right now and check this out for yourself:


hxxps://sites.google.com/view/a0hf49gj29g-i4jb48n5/drive/folders/shared/1/download?ID=308682351554855915

I believe you have willfully infringed my rights under 17 U.S.C. Section
101 et seq. and could be liable for statutory damages as high as
$150,000 as set forth in Section 504(c)(2) of the Digital Millennium
Copyright Act (”DMCA”) therein.

This letter is official notification. I seek the removal of the
infringing material referenced above. Please take note as a service
provider, the Digital Millennium Copyright Act requires you, to remove
or disable access to the infringing materials upon receipt of this
notice. If you do not cease the use of the aforementioned copyrighted
material a lawsuit will be commenced against you.

I have a good faith belief that use of the copyrighted materials
described above as allegedly infringing is not authorized by the
copyright owner, its agent, or the law.

I swear, under penalty of perjury, that the information in the
notification is accurate and that I am the copyright owner or am
authorized to act on behalf of the owner of an exclusive right that is
allegedly infringed.


Best regards,
Shafaq Chyanne"""

# Top-5 SAE feature ids per token of the email.
# NOTE(review): get_relevant_features builds a dict keyed by token string, so a
# token that appears more than once in this long text keeps a single entry.
per_token_features = get_relevant_features(prompt=spam_email, n=5)

In [13]:
from IPython.display import display, HTML
import pandas as pd
# Lay the per-token results out as a table: one row per token key, one column
# per slot of the top-k list produced by get_relevant_features.
df = pd.DataFrame(per_token_features).transpose()

# Neuronpedia URL pattern for a feature id (kept for reference, not yet used):
# url = f"https://www.neuronpedia.org/gemma-2b-it/12-res-jb/{feature_id}"

# Render through HTML so the notebook shows a rich table.
html_table = df.to_html()
display(HTML(html_table))

Unnamed: 0,0,1,2,3,4
<bos>,11609,15572,13161,7063,1111
From,12650,4173,15945,10454,11924
:,14886,10332,5728,10381,5237
Sha,12357,15269,115,6187,15467
faq,1462,10994,14568,15280,6101
<,8084,2722,2811,3083,12039
Chy,14589,2578,7040,13987,12696
ann,13527,13268,9288,331,5617
estudio,15768,3501,8993,16102,7707
8,3302,438,3389,9979,14991


In [None]:
def get_steering_vector(*, sae, feature_id) -> torch.Tensor:
    """Look up the steering direction for one SAE feature.

    Args:
        sae: Object exposing a ``W_dec`` attribute that can be indexed by
            feature id.
        feature_id: Index into ``sae.W_dec``.

    Returns:
        ``sae.W_dec[feature_id]``, i.e. the entry of the decoder weights
        selected by ``feature_id``.
    """
    decoder_weights = sae.W_dec
    return decoder_weights[feature_id]

# Index of the SAE feature used for steering; passed to generate() as feature_id.
FEATURE_ID = 10745 # of or relating to bridges

# Steering coefficient.
# NOTE(review): the visible cells never read this global -- steering_hook and
# generate take `coeff` as a parameter and the example call passes coeff=100.
# Confirm whether -300 or 100 is the intended strength.
coeff = -300

In [None]:
# Hook helpers, starting with the one intended for a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that swaps the activation for ``sae_out``.

    ``activation`` and ``hook`` are accepted to satisfy the hook signature but
    are ignored; ``sae_out`` is returned unchanged.
    """
    return sae_out

def no_op_hook(mlp_out, hook):
    """Identity forward hook: returns ``mlp_out`` untouched and ignores ``hook``."""
    return mlp_out

def steering_hook(resid_pre, hook, *, steering_vector: torch.Tensor, steering_on: bool = True, coeff: float = 1.0):
    """Add a scaled steering vector, in place, to all but the last position.

    Args:
        resid_pre: Activation tensor indexed as ``[batch, position, ...]``; it
            is modified in place.
        hook: Hook object supplied by the caller; unused.
        steering_vector: Direction that is broadcast-added to the activation.
        steering_on: When False the activation is left untouched.
        coeff: Multiplier applied to ``steering_vector``.

    Returns:
        Always ``None``; the effect is the in-place update of ``resid_pre``.
    """
    n_positions = resid_pre.shape[1]

    # A single-position input is never steered; neither is anything when
    # steering is switched off.
    if n_positions != 1 and steering_on:
        # Every position except the final one receives the offset.
        resid_pre[:, :-1, :] += coeff * steering_vector

def generate(*, max_length: int, prompt: str, coeff: float, feature_id: int):
    """Greedily continue ``prompt`` while steering with one SAE feature.

    At every step the whole token sequence is run through the global ``model``
    with ``steering_hook`` attached at ``sae.cfg.hook_name``; the arg-max token
    at the final position is appended. Generation stops after ``max_length``
    new tokens or as soon as the tokenizer's EOS token is produced.

    Args:
        max_length: Maximum number of new tokens to generate.
        prompt: Starting text.
        coeff: Multiplier applied to the steering vector.
        feature_id: Index of the SAE feature whose decoder direction is used.

    Returns:
        ``prompt`` followed by the decoded generated tokens.
    """
    # The steering vector and hook do not change between steps: build them once.
    fwd_hooks = [
        (
            sae.cfg.hook_name,
            partial(
                steering_hook,
                steering_vector=get_steering_vector(sae=sae, feature_id=feature_id),
                steering_on=True,
                coeff=coeff,
            ),
        )
    ]
    eos_token_id = getattr(getattr(model, "tokenizer", None), "eos_token_id", None)

    # Work on token ids throughout. Decoding each new token to text and
    # re-tokenizing the grown string every step can change the context
    # (tokenization is not guaranteed to round-trip) and garbles characters
    # that span several byte-fallback tokens.
    tokens = model.to_tokens(prompt)
    prompt_len = tokens.shape[1]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max_length):
            logits = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)
            next_token_id = logits[0, -1].argmax(-1)
            tokens = torch.cat([tokens, next_token_id.view(1, 1)], dim=-1)
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break

    new_tokens = tokens[0, prompt_len:]
    if new_tokens.numel() == 0:
        return prompt
    return prompt + model.to_string(new_tokens)

# Steered greedy continuation of a short question, up to 40 new tokens.
print(
    generate(
        prompt="What is the opposite of left-handed?",
        max_length=40,
        feature_id=FEATURE_ID,
        coeff=100,
    )
)