In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

In [2]:
with open("access.tok", "r") as file:
    access_token = file.read()
    login(token=access_token)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/cole/.cache/huggingface/token
Login successful
Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

torch.set_grad_enabled(False)

NEG_SET_SIZE = 200
POS_SET_SIZE = 200

In [4]:
# load gemma model 
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device = device)

num_layer = 10
# load sae on res stream of gemma model, plus cfg and sparsity val
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res",
    sae_id = f"layer_{num_layer}/width_16k/average_l0_39",
    device = device
)
# get layers from here: https://jbloomaus.github.io/SAELens/sae_table/#gemma-scope-2b-pt-res



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [14]:
neg_dataset = 'dataset/spanish_harmful.csv'
df = pd.read_csv(neg_dataset)

columns_as_arrays = [df[col].values for col in df.columns]

array_dict = {col: df[col].values for col in df.columns}

negative_set = columns_as_arrays[0]
negative_set = negative_set[:NEG_SET_SIZE]
print(len(negative_set))

200


In [17]:
pos_dataset = 'dataset/alpaca_spanish.json'
positive = pd.read_json(pos_dataset)

positive_set = positive['output'].values
positive_set = positive_set[:POS_SET_SIZE]
print(len(positive_set))

200


In [18]:
sae.use_error_term

top_neurons_neg = defaultdict(list)
top_neurons_pos = defaultdict(list)

for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)

    for datapoint in zip(inds, vals):
        top_neurons_neg[int(datapoint[0])].append(datapoint[1].item())
    
    del cache
    

for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    for datapoint in zip(inds, vals):
        top_neurons_pos[int(datapoint[0])].append(datapoint[1].item())
    
    del cache

print(top_neurons_neg)
print(top_neurons_pos)

defaultdict(<class 'list'>, {13312: [34.030921936035156, 30.693401336669922, 35.73822021484375, 38.53155517578125, 35.6867790222168, 34.395694732666016, 35.42573928833008, 41.024513244628906, 38.2966194152832, 37.71751403808594, 39.57956314086914, 29.953304290771484, 38.96459197998047, 25.06528091430664, 35.07865905761719, 40.004852294921875, 36.503543853759766, 35.51307678222656, 36.349365234375, 33.933860778808594, 35.50429916381836, 36.18312454223633, 37.591182708740234, 34.53533935546875, 37.48564529418945, 40.0228271484375, 33.93170928955078, 38.50139236450195, 36.27558517456055, 34.71536636352539, 35.150978088378906, 34.939605712890625, 35.47886276245117, 37.81039810180664, 37.16013717651367, 36.82859420776367, 39.68489074707031, 36.34103775024414, 39.004634857177734, 38.80133056640625, 36.07017135620117, 37.75649642944336, 34.26935958862305, 34.355098724365234, 37.24619674682617, 38.68547439575195, 35.42176818847656, 40.77512741088867, 36.318763732910156, 43.41993713378906, 33.6

In [19]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.
    """
    
    filtered_neurons_neg = {}
    filtered_neurons_pos = {}

    for neuron, activations in top_neurons_neg.items():
        if neuron in top_neurons_pos and any(val >= threshold for val in activations) and any(val >= threshold for val in top_neurons_pos[neuron]):
            continue 
        else:
            filtered_neurons_neg[neuron] = activations

    for neuron, activations in top_neurons_pos.items():
        if neuron not in top_neurons_neg or not any(val >= threshold for val in top_neurons_neg[neuron]):
            filtered_neurons_pos[neuron] = activations

    return filtered_neurons_neg, filtered_neurons_pos

filtered_neg, filtered_pos = filter_neurons(top_neurons_neg, top_neurons_pos, 0)
print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 562. Filtered negative neurons: {11881: [8.430418014526367, 18.759553909301758, 10.585809707641602, 7.677215099334717, 8.759821891784668, 13.731945037841797], 10568: [6.073503017425537, 11.639068603515625, 9.350728988647461, 10.21215534210205, 9.803328514099121, 8.786177635192871, 9.79623031616211, 8.749348640441895, 6.412755966186523, 7.911864280700684, 6.8829121589660645, 10.279202461242676, 8.860928535461426], 8927: [5.92238712310791, 7.42421817779541, 6.225836753845215, 8.300820350646973, 7.707480430603027, 10.855594635009766], 1377: [6.84891939163208, 11.21091079711914], 3859: [6.514743804931641], 8511: [11.879244804382324, 7.144270896911621, 7.028419494628906, 8.462394714355469, 6.9543256759643555, 7.64968204498291, 8.295157432556152, 8.803706169128418, 7.39412784576416, 6.630208969116211, 9.767208099365234, 7.097625732421875], 11478: [10.087898254394531, 12.382061004638672, 8.588029861450195, 6.383830547332764, 6.365123748779297, 9.742301940917969, 6.9408183097839355, 9.290

In [20]:
# average activations over each top case, sends to
# top_neurons_neg/pos = {idx: avg_act, idx2:avg_act2, ...}
top_neurons_neg_mean = {}
for entry in filtered_neg:
    top_neurons_neg_mean[entry] = len(filtered_neg[entry])

top_neurons_pos_mean = {}
for entry in filtered_pos:
    top_neurons_pos_mean[entry] = len(filtered_pos[entry])

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)

# sort by avg activation
top_neurons_neg_mean = {k: v for k, v in sorted(top_neurons_neg_mean.items(), key=lambda item: item[1], reverse=True)}
top_neurons_pos_mean = {k: v for k, v in sorted(top_neurons_pos_mean.items(), key=lambda item: item[1], reverse=True)}


# print first few
print(list(top_neurons_neg_mean.items())[:200])
print(list(top_neurons_pos_mean.items())[:200])

{11881: 6, 10568: 13, 8927: 6, 1377: 2, 3859: 1, 8511: 12, 11478: 8, 14076: 11, 11392: 1, 16136: 1, 1351: 3, 636: 37, 5810: 1, 11805: 1, 13407: 6, 11288: 3, 4040: 2, 9274: 1, 10742: 10, 13052: 13, 7301: 7, 8952: 14, 294: 1, 1324: 1, 1077: 1, 4542: 1, 6211: 1, 12764: 2, 11577: 3, 8019: 2, 487: 1, 11454: 2, 7268: 1, 702: 1, 8956: 1, 2813: 2, 14874: 1, 436: 1, 14575: 3, 13330: 17, 6361: 35, 5050: 2, 3870: 2, 13828: 2, 11853: 2, 175: 1, 14093: 1, 15572: 3, 5504: 1, 4263: 1, 5381: 2, 916: 2, 10841: 1, 6231: 2, 5829: 1, 16299: 1, 6631: 2, 5756: 1, 2428: 2, 16311: 1, 8573: 3, 12186: 1, 3713: 4, 15905: 3, 5370: 3, 7853: 1, 16108: 1, 8829: 8, 3938: 2, 9746: 11, 4095: 12, 4920: 4, 6794: 1, 11051: 6, 10903: 2, 2805: 1, 5351: 2, 8779: 1, 6776: 1, 16143: 1, 6806: 2, 10390: 4, 16155: 1, 14520: 1, 1066: 5, 4477: 2, 10717: 1, 13137: 1, 12987: 1, 2901: 2, 5838: 1, 8799: 1, 2958: 1, 15035: 16, 6259: 1, 1082: 1, 1155: 1, 12474: 1, 3571: 4, 2251: 3, 7564: 1, 5492: 3, 9501: 1, 13818: 1, 1492: 1, 14558: 5, 

In [21]:
# train classifier on sae activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()
    #print(activations.shape)

    del cache

    activations_list.append(activations)
    labels_list.append(0)

for example_txt in positive_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()

    del cache

    activations_list.append(activations)
    labels_list.append(1)   

# data
X = np.array(activations_list)
y = np.array(labels_list)

# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale activation features
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')
with open('run_log.txt', 'a') as file:
    file.write(f'SAE activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {accuracy:.4f}\n')

Test Accuracy: 0.9250


In [22]:
# train classifier on base activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy() # layer batch pos d_model
    #print(seven_out.shape)

    del cache

    activations_list.append(seven_out)
    labels_list.append(0)

#print("done w bad")

for example_txt in positive_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy()  # layer batch pos d_model

    del cache

    activations_list.append(seven_out)
    labels_list.append(1)

# data and split
X = np.array(activations_list)
y = np.array(labels_list)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale features 
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# train classifier
clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

print(f'Test Accuracy: {accuracy:.4f}')
with open('run_log.txt', 'a') as file:
    file.write(f'Base activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {accuracy:.4f}\n')

Test Accuracy: 0.9875
Test Accuracy: 0.9875


In [23]:
# reject sample if activation at any of top N neurons is above threshold T
THRESHOLD = 4
NEURONS_TO_CONSIDER = 50

top_neurons_neg_occ = list(top_neurons_neg_mean.items())[:NEURONS_TO_CONSIDER]
top_neurons_pos_mean = []
neg_indicators = []

for entry in top_neurons_neg_occ:
    neg_indicators.append(entry[0])

print(neg_indicators)

# run an inference, if hits threshold on any of the neurons classify as bad
e = 0
neg_classifications = np.zeros(NEG_SET_SIZE)
for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]
    
    del cache

    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            neg_classifications[e] = 1
            break
    
    e += 1

pos_classifications = np.zeros(POS_SET_SIZE)


e = 0
for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]

    del cache
    
    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            pos_classifications[e] = 1
            break
    e += 1


print(" RATS ! ", np.sum(neg_classifications) / 2)
print(" GENIUSES ! ", np.sum(pos_classifications) / 2)

with open('run_log.txt', 'a') as file:
    file.write(f'SAE Thresholding true negative rate on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {np.sum(neg_classifications) / 2}; false negative rate: {np.sum(pos_classifications) / 2}\n')

[636, 6361, 13341, 13330, 15035, 8952, 10568, 13052, 8511, 4095, 7377, 14076, 9746, 10742, 6982, 11478, 8829, 6770, 7301, 9960, 11881, 8927, 13407, 11051, 300, 1465, 11746, 1066, 14558, 1823, 9522, 8649, 7304, 2230, 3713, 4920, 10390, 3571, 10229, 7546, 11779, 901, 15379, 10122, 14970, 1351, 11288, 11577, 14575, 15572]
 RATS !  97.0
 GENIUSES !  16.0


: 