In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

In [2]:
# Authenticate with the Hugging Face Hub using a token stored outside the notebook.
with open("access.tok", "r") as file:
    # Token files commonly end with a newline; strip surrounding whitespace so it
    # is not passed along as part of the credential.
    access_token = file.read().strip()

# Log in only after the file has been closed; nothing below needs the handle.
login(token=access_token)

# Pick the compute device: MPS when available, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/cole/.cache/huggingface/token
Login successful
Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

# Disable autograd globally for the torch calls made in the rest of the notebook.
torch.set_grad_enabled(False)

# Number of examples taken from the negative and positive datasets respectively;
# later cells use these to slice the datasets and to size/normalise the result arrays.
NEG_SET_SIZE = 100
POS_SET_SIZE = 100

In [4]:
# Base language model, wrapped so that SAEs can be attached during a forward pass.
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

# Layer whose residual-stream SAE is loaded; later cells read
# `blocks.{num_layer}.hook_resid_post.hook_sae_acts_post` from the cache.
num_layer = 10

# SAE on the residual stream of the gemma model. `from_pretrained` also returns
# the SAE config dict and a sparsity value.
# Valid (layer, width, l0) combinations are listed at:
# https://jbloomaus.github.io/SAELens/sae_table/#gemma-scope-2b-pt-res
sae_source = dict(
    release="gemma-scope-2b-pt-res",
    sae_id=f"layer_{num_layer}/width_16k/average_l0_77",
    device=device,
)
sae, cfg_dict, sparsity = SAE.from_pretrained(**sae_source)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [5]:
# #neg_dataset = 'dataset/spanish_harmful.csv'
# neg_dataset = 'dataset/original_harmful.csv'

# df = pd.read_csv(neg_dataset)

# columns_as_arrays = [df[col].values for col in df.columns]

# array_dict = {col: df[col].values for col in df.columns}

# negative_set = columns_as_arrays[0]
# negative_set = negative_set[:NEG_SET_SIZE]
# print(len(negative_set))

In [6]:
neg_dataset = 'dataset/harmful_wiki_cleaned.csv'

# Row offsets of the two windows taken from the first CSV column:
#   - negative_set   : used for feature selection and for training the probes
#   - negative_set_2 : used only by the threshold-rule evaluation cell
NEG_SELECT_OFFSET = 30
NEG_EVAL_OFFSET = 2030

df = pd.read_csv(neg_dataset)

# One array per column, both as a list (column order) and keyed by column name.
columns_as_arrays = [df[col].values for col in df.columns]
array_dict = {col: df[col].values for col in df.columns}

# The examples are read from the first column of the CSV.
harmful_texts = columns_as_arrays[0]
negative_set = harmful_texts[NEG_SELECT_OFFSET:NEG_SELECT_OFFSET + NEG_SET_SIZE]
negative_set_2 = harmful_texts[NEG_EVAL_OFFSET:NEG_EVAL_OFFSET + NEG_SET_SIZE]

# Slicing past the end of an array silently returns fewer rows, and later cells
# size their result arrays with NEG_SET_SIZE, so fail loudly here instead.
assert len(negative_set) == NEG_SET_SIZE and len(negative_set_2) == NEG_SET_SIZE, (
    f"{neg_dataset} has {len(harmful_texts)} rows; need at least "
    f"{NEG_EVAL_OFFSET + NEG_SET_SIZE} to fill both negative windows"
)
# The evaluation window must not overlap the window used to pick the indicator features.
assert NEG_SELECT_OFFSET + NEG_SET_SIZE <= NEG_EVAL_OFFSET, (
    "negative_set and negative_set_2 overlap; lower NEG_SET_SIZE or raise NEG_EVAL_OFFSET"
)

print(len(negative_set))
print(len(negative_set_2))


100
100


In [7]:
#pos_dataset = 'dataset/alpaca_spanish.json'
pos_dataset = 'dataset/alpaca_data.json'

# The positive examples are the first POS_SET_SIZE entries of the 'output' column.
positive = pd.read_json(pos_dataset)
positive_set = positive['output'].values[:POS_SET_SIZE]
print(len(positive_set))

100


In [8]:
# How many of the strongest SAE features to record per example.
TOP_K_FEATURES = 15


def collect_top_features(texts, k=TOP_K_FEATURES):
    """Record the k strongest SAE features at the final position of each text.

    Args:
        texts: iterable of prompts, each run through `model` with `sae` attached.
        k: number of top activations kept per prompt.

    Returns:
        defaultdict(list) mapping feature index (int) to the list of activation
        values (float) it had whenever it was among the top k for a prompt.
    """
    hook_name = f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'
    top_features = defaultdict(list)

    for text in texts:
        _, cache = model.run_with_cache_with_saes(text, saes=[sae])

        # [0, -1, :] -> first batch element, last position, every SAE feature
        # (assumes the cached tensor is (batch, pos, d_sae) - TODO confirm).
        last_token_acts = cache[hook_name][0, -1, :]
        del cache  # drop the reference to the full cache once the slice is taken

        vals, inds = torch.topk(last_token_acts, k)
        for feature_idx, activation in zip(inds.tolist(), vals.tolist()):
            top_features[feature_idx].append(activation)

    return top_features


top_neurons_neg = collect_top_features(negative_set)
top_neurons_pos = collect_top_features(positive_set)

print(top_neurons_neg)
print(top_neurons_pos)

defaultdict(<class 'list'>, {3031: [22.933948516845703, 45.79437255859375, 7.637748718261719, 17.374561309814453, 15.602561950683594, 16.739315032958984, 10.65848159790039, 35.40010070800781, 11.425506591796875, 8.893932342529297, 9.040657043457031], 9451: [22.772489547729492], 1958: [16.48214340209961, 9.02778434753418], 3986: [16.382078170776367, 17.570863723754883, 16.269214630126953, 20.256839752197266, 24.029714584350586, 28.893091201782227, 23.138090133666992, 23.52032470703125, 21.862524032592773, 21.300235748291016, 27.13837242126465, 20.04159927368164, 20.791757583618164, 18.609920501708984, 12.317091941833496, 16.986427307128906, 22.39775848388672, 19.737062454223633, 28.86036491394043, 29.63515281677246, 14.716588020324707, 22.400209426879883, 18.93925666809082, 20.891313552856445, 33.286312103271484, 16.658809661865234, 19.85835075378418, 19.309133529663086, 26.593854904174805, 25.009687423706055, 22.48514747619629, 28.16208267211914, 11.851886749267578, 14.016938209533691,

In [9]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.

    A neuron counts as "highly activated" in a set when at least one of its
    recorded activations in that set is >= threshold. Only neurons that are
    highly activated in BOTH sets are removed, and they are removed from both
    outputs, so the two results are filtered by the same rule.

    Args:
        top_neurons_neg: mapping neuron index -> list of activations (negative set).
        top_neurons_pos: mapping neuron index -> list of activations (positive set).
        threshold: minimum activation for a neuron to count as highly activated.

    Returns:
        (filtered_neurons_neg, filtered_neurons_pos): plain dicts holding the
        surviving entries of each input, in their original order. The inputs
        are not modified.
    """

    def is_high(activations):
        return any(val >= threshold for val in activations)

    # Membership is tested before indexing so that defaultdict inputs do not
    # grow empty entries as a side effect of the lookup.
    shared_high = {
        neuron
        for neuron, activations in top_neurons_neg.items()
        if neuron in top_neurons_pos
        and is_high(activations)
        and is_high(top_neurons_pos[neuron])
    }

    filtered_neurons_neg = {
        neuron: activations
        for neuron, activations in top_neurons_neg.items()
        if neuron not in shared_high
    }
    filtered_neurons_pos = {
        neuron: activations
        for neuron, activations in top_neurons_pos.items()
        if neuron not in shared_high
    }

    return filtered_neurons_neg, filtered_neurons_pos

# threshold=0: any activation >= 0 counts as "high" for the filter.
filtered_neg, filtered_pos = filter_neurons(
    top_neurons_neg,
    top_neurons_pos,
    threshold=0,
)
print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 455. Filtered negative neurons: {9451: [22.772489547729492], 1958: [16.48214340209961, 9.02778434753418], 5956: [13.635176658630371], 3939: [12.567086219787598, 9.575501441955566, 8.769563674926758, 6.474125385284424], 1992: [10.946975708007812, 10.79904556274414, 8.923402786254883], 4638: [10.895687103271484], 833: [9.319798469543457, 7.47442626953125], 8338: [8.041302680969238], 12331: [7.081801414489746], 9498: [6.869698524475098], 1483: [10.436823844909668], 16084: [7.053130149841309, 6.642577171325684], 12501: [5.489943504333496, 7.019606113433838], 7241: [26.590560913085938], 15653: [16.955577850341797, 10.907196998596191, 17.647422790527344], 4957: [15.752717971801758], 14918: [15.361218452453613, 14.949971199035645], 12428: [13.478342056274414, 13.791582107543945, 19.001632690429688], 14326: [11.357827186584473, 20.123531341552734, 6.840255260467529, 30.38844108581543, 21.979856491088867, 11.30158519744873, 17.464014053344727, 11.989712715148926, 11.966527938842773, 24.252

In [10]:
# Count how many top-k activations were recorded for each surviving neuron.
# NOTE: despite the "_mean" suffix, these dicts hold occurrence counts
# (len of the activation list), not mean activations:
# top_neurons_neg/pos_mean = {idx: n_occurrences, idx2: n_occurrences2, ...}
top_neurons_neg_mean = {neuron: len(acts) for neuron, acts in filtered_neg.items()}
top_neurons_pos_mean = {neuron: len(acts) for neuron, acts in filtered_pos.items()}

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)


def _occurrences(item):
    """Sort key: the count stored as the value of a (neuron, count) pair."""
    return item[1]


# sort by occurrence count, most frequent first (ties keep their existing order)
top_neurons_neg_mean = dict(sorted(top_neurons_neg_mean.items(), key=_occurrences, reverse=True))
top_neurons_pos_mean = dict(sorted(top_neurons_pos_mean.items(), key=_occurrences, reverse=True))


# print first few
print(list(top_neurons_neg_mean.items())[:200])
print(list(top_neurons_pos_mean.items())[:200])

{9451: 1, 1958: 2, 5956: 1, 3939: 4, 1992: 3, 4638: 1, 833: 2, 8338: 1, 12331: 1, 9498: 1, 1483: 1, 16084: 2, 12501: 2, 7241: 1, 15653: 3, 4957: 1, 14918: 2, 12428: 3, 14326: 13, 13746: 1, 11630: 3, 8279: 1, 15092: 7, 1012: 1, 14228: 2, 5273: 1, 10748: 3, 1623: 1, 14988: 20, 8957: 1, 2943: 7, 2537: 9, 4953: 7, 904: 4, 14176: 3, 9787: 13, 692: 1, 11881: 5, 532: 3, 1495: 24, 7671: 1, 6087: 2, 1683: 1, 9882: 5, 9189: 1, 4357: 2, 9514: 3, 1045: 1, 9223: 8, 10303: 1, 1328: 1, 10883: 1, 6004: 1, 1147: 8, 9163: 1, 16200: 3, 1726: 1, 5213: 1, 1677: 1, 13863: 5, 6571: 1, 16118: 1, 14015: 1, 9905: 2, 651: 1, 1193: 3, 3409: 5, 7997: 10, 3864: 3, 3329: 14, 10186: 4, 4945: 2, 13546: 16, 2416: 4, 5409: 1, 15268: 2, 8213: 1, 8430: 2, 8973: 2, 14660: 6, 2691: 1, 14212: 2, 15701: 2, 14683: 1, 2565: 1, 16319: 1, 14449: 2, 1530: 6, 3379: 2, 12446: 3, 7001: 1, 2561: 5, 15969: 1, 14838: 1, 16271: 1, 11992: 6, 1721: 3, 2357: 1, 6700: 1, 9025: 1, 3034: 4, 2974: 1, 12051: 7, 12543: 2, 6450: 3, 8382: 1, 7867: 

In [11]:
# Linear probe trained on the SAE feature activations.
from sklearn.preprocessing import StandardScaler

sae_hook = f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'

activations_list = []
labels_list = []

# 0 = negative, 1 = positive; negatives are processed first, then positives
for label, texts in ((0, negative_set), (1, positive_set)):
    for example_txt in texts:
        _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
        # SAE feature vector at index [0, -1, :], moved to CPU as a numpy array
        activations = cache[sae_hook][0, -1, :].cpu().numpy()

        del cache

        activations_list.append(activations)
        labels_list.append(label)

# feature matrix and label vector
X = np.array(activations_list)
y = np.array(labels_list)

# 80/20 train/test split with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# standardise features; the scaler is fitted on the training split only
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, solver='lbfgs')
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

# append the result to the shared run log
with open('run_log.txt', 'a') as file:
    file.write(f'SAE activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer} dataset_size={NEG_SET_SIZE}: {accuracy:.4f}\n')

Test Accuracy: 0.9750


In [12]:
# train classifier on base (non-SAE) model activations
from sklearn.preprocessing import StandardScaler

# Layer argument passed to decompose_resid. This is NOT num_layer: the run log
# previously reported layer={num_layer} for this experiment although the
# features are extracted with layer=15, so the log now records BASE_LAYER.
BASE_LAYER = 15


def base_activation(text):
    """Return the base-model feature vector used by this probe for one prompt.

    NOTE(review): the SAE probe reads `blocks.{num_layer}.hook_resid_post` SAE
    activations at index [0, -1, :], whereas this uses decompose_resid with
    mode='attn' at layer BASE_LAYER and pos_slice=slice(-2, -1). That slice
    appears to select the second-to-last token rather than the last one, and
    [-1, 0, -1, :] presumably picks the last stacked attention component -
    confirm that this comparison is the intended one.
    """
    _, cache = model.run_with_cache(text)
    res_stream = cache.decompose_resid(
        layer=BASE_LAYER,
        return_labels=False,
        mode='attn',
        incl_embeds=False,
        pos_slice=slice(-2, -1),
    )
    features = res_stream[-1, 0, -1, :].cpu().numpy()  # layer batch pos d_model

    del cache

    return features


activations_list = []
labels_list = []

# 0 = negative, 1 = positive; negatives are processed first, then positives
for label, texts in ((0, negative_set), (1, positive_set)):
    for example_txt in texts:
        activations_list.append(base_activation(example_txt))
        labels_list.append(label)

# data and split
X = np.array(activations_list)
y = np.array(labels_list)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale features; the scaler is fitted on the training split only
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# train classifier
clf = LogisticRegression(max_iter=1000, solver='lbfgs')
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
# The accuracy used to be printed twice; once is enough.
print(f'Test Accuracy: {accuracy:.4f}')

with open('run_log.txt', 'a') as file:
    file.write(f'Base activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={BASE_LAYER} dataset_size={NEG_SET_SIZE}: {accuracy:.4f}\n')

Test Accuracy: 0.9750
Test Accuracy: 0.9750


In [13]:
# reject sample if activation at any of top N neurons is above threshold T
THRESHOLD = 4
NEURONS_TO_CONSIDER = 50

# top_neurons_neg_mean maps neuron index -> occurrence count, sorted with the
# most frequent first, so the first N keys are the indicator neurons.
# (The earlier `top_neurons_pos_mean = []` assignment was removed: it was never
# read and it overwrote the dict built by the counting cell.)
top_neurons_neg_occ = list(top_neurons_neg_mean.items())[:NEURONS_TO_CONSIDER]
neg_indicators = [neuron_idx for neuron_idx, _ in top_neurons_neg_occ]

print(neg_indicators)


def flag_examples(texts):
    """Apply the threshold rule to each text.

    Returns a float numpy array with one entry per text: 1 if any indicator
    neuron's SAE activation at index [0, -1, :] exceeds THRESHOLD, else 0.
    """
    hook_name = f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'
    flags = np.zeros(len(texts))

    for i, example in enumerate(texts):
        _, cache = model.run_with_cache_with_saes(example, saes=[sae])

        # get data from indicator neurons
        val = cache[hook_name][0, -1, :]

        del cache

        # run an inference, if hits threshold on any of the neurons classify as bad
        if any(val[critical_neuron_idx] > THRESHOLD for critical_neuron_idx in neg_indicators):
            flags[i] = 1

    return flags


neg_classifications = flag_examples(negative_set_2)
pos_classifications = flag_examples(positive_set)

# Percentage of each set that was flagged. Each rate is normalised by the size
# of its own set; the positive rate used to be divided by NEG_SET_SIZE.
neg_flagged_pct = np.sum(neg_classifications) / (len(neg_classifications) / 100)
pos_flagged_pct = np.sum(pos_classifications) / (len(pos_classifications) / 100)

print(" RATS ! ", neg_flagged_pct)
print(" GENIUSES ! ", pos_flagged_pct)

with open('run_log.txt', 'a') as file:
    file.write(f'SAE Thresholding true negative rate on pos={pos_dataset} neg={neg_dataset} layer={num_layer} dataset_size={NEG_SET_SIZE}: {neg_flagged_pct}; false negative rate: {pos_flagged_pct}\n')

[1495, 14988, 13546, 3329, 14326, 9787, 7997, 2537, 9223, 1147, 12269, 15092, 2943, 4953, 12051, 12109, 14660, 1530, 11992, 8420, 11881, 9882, 13863, 3409, 2561, 7655, 3939, 904, 10186, 2416, 3034, 10466, 9564, 11, 14118, 13556, 5111, 1833, 1992, 15653, 12428, 11630, 10748, 14176, 532, 9514, 16200, 1193, 3864, 12446]
 RATS !  100.0
 GENIUSES !  16.0


In [14]:
from IPython.display import IFrame


# Neuronpedia embed URL; the three `{}` slots are filled, in order, with the
# release/model id, the SAE id and the feature index.
html_template = (
    "https://neuronpedia.org/{}/{}/{}"
    "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
)

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="10-gemmascope-res-16k", feature_idx=0):
    """Return the Neuronpedia embed URL for one SAE feature.

    Args:
        sae_release: first path segment of the URL.
        sae_id: second path segment of the URL.
        feature_idx: feature index, used as the third path segment.

    Returns:
        The formatted URL string (despite the function name, this is a URL, not HTML).
    """
    url_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*url_parts)

# Build the Neuronpedia SAE id from num_layer so the dashboards follow the layer
# of the SAE that produced neg_indicators instead of a hard-coded "10".
# Assumes ids follow the "<layer>-gemmascope-res-16k" pattern, which matches the
# previously hard-coded value for layer 10 - confirm before using other layers.
neuronpedia_sae_id = f"{num_layer}-gemmascope-res-16k"

# Show one embedded dashboard per indicator feature.
for feature_idx in neg_indicators:
    html = get_dashboard_html(sae_release="gemma-2-2b", sae_id=neuronpedia_sae_id, feature_idx=feature_idx)
    frame = IFrame(html, width=800, height=400)
    display(frame)