In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

In [2]:
with open("access.tok", "r") as file:
    access_token = file.read()
    login(token=access_token)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/cole/.cache/huggingface/token
Login successful
Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

torch.set_grad_enabled(False)

NEG_SET_SIZE = 200
POS_SET_SIZE = 200

In [4]:
# load gemma model 
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device = device)

# load sae on res stream of gemma model, plus cfg and sparsity val
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res",
    sae_id = "layer_14/width_16k/average_l0_83",
    device = device
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [5]:
neg_dataset = 'dataset/harmful_strings.csv'
df = pd.read_csv(neg_dataset)

columns_as_arrays = [df[col].values for col in df.columns]

array_dict = {col: df[col].values for col in df.columns}

negative_set = columns_as_arrays[0]
negative_set = negative_set[:NEG_SET_SIZE]
print(len(negative_set))

200


In [6]:
pos_dataset = 'dataset/alpaca_data.json'
positive = pd.read_json(pos_dataset)

positive_set = positive['output'].values
positive_set = positive_set[:POS_SET_SIZE]
print(len(positive_set))

200


In [7]:
num_layer = 10
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res",
    sae_id = f"layer_{num_layer}/width_16k/average_l0_39",
    device = device
)
# get layers from here: https://jbloomaus.github.io/SAELens/sae_table/#gemma-scope-2b-pt-res

In [8]:
sae.use_error_term

top_neurons_neg = defaultdict(list)
top_neurons_pos = defaultdict(list)

for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)

    for datapoint in zip(inds, vals):
        top_neurons_neg[int(datapoint[0])].append(datapoint[1].item())
    

for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    for datapoint in zip(inds, vals):
        top_neurons_pos[int(datapoint[0])].append(datapoint[1].item())

print(top_neurons_neg)
print(top_neurons_pos)

KeyboardInterrupt: 

In [12]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.
    """
    
    filtered_neurons_neg = {}
    filtered_neurons_pos = {}

    for neuron, activations in top_neurons_neg.items():
        if neuron in top_neurons_pos and any(val >= threshold for val in activations) and any(val >= threshold for val in top_neurons_pos[neuron]):
            continue 
        else:
            filtered_neurons_neg[neuron] = activations

    for neuron, activations in top_neurons_pos.items():
        if neuron not in top_neurons_neg or not any(val >= threshold for val in top_neurons_neg[neuron]):
            filtered_neurons_pos[neuron] = activations

    return filtered_neurons_neg, filtered_neurons_pos

filtered_neg, filtered_pos = filter_neurons(top_neurons_neg, top_neurons_pos, 0)
print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 943. Filtered negative neurons: {10852: [36.364524841308594, 33.745574951171875, 27.660058975219727, 31.44834327697754, 28.732820510864258, 30.4450626373291], 11881: [21.18616485595703, 26.493017196655273, 16.184385299682617, 11.58835220336914, 12.143588066101074, 9.535420417785645], 8927: [10.61850643157959, 13.717705726623535, 17.509796142578125, 10.843422889709473], 8618: [10.12113094329834, 11.236221313476562, 13.946182250976562, 8.968689918518066, 9.194867134094238, 9.863019943237305, 9.044962882995605, 10.01511001586914], 5281: [9.346602439880371, 12.718506813049316, 15.144003868103027, 12.45595932006836, 12.692011833190918, 15.042231559753418], 636: [8.615071296691895, 14.555859565734863, 21.22677993774414, 9.71804428100586, 6.798281192779541, 17.89031219482422, 29.510814666748047, 17.485065460205078, 8.380584716796875, 13.036564826965332, 22.579639434814453, 10.258950233459473, 15.066035270690918, 15.801443099975586, 23.23196029663086, 10.849113464355469, 22.59735107421875

In [13]:
# average activations over each top case, sends to
# top_neurons_neg/pos = {idx: avg_act, idx2:avg_act2, ...}
top_neurons_neg_mean = {}
for entry in filtered_neg:
    top_neurons_neg_mean[entry] = len(filtered_neg[entry])

top_neurons_pos_mean = {}
for entry in filtered_pos:
    top_neurons_pos_mean[entry] = len(filtered_pos[entry])

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)

# sort by avg activation
top_neurons_neg_mean = {k: v for k, v in sorted(top_neurons_neg_mean.items(), key=lambda item: item[1], reverse=True)}
top_neurons_pos_mean = {k: v for k, v in sorted(top_neurons_pos_mean.items(), key=lambda item: item[1], reverse=True)}


# print first few
print(list(top_neurons_neg_mean.items())[:200])
print(list(top_neurons_pos_mean.items())[:200])

{10852: 6, 11881: 6, 8927: 4, 8618: 8, 5281: 6, 636: 39, 15680: 5, 13341: 10, 10815: 123, 3642: 4, 4268: 2, 12594: 13, 14432: 4, 14937: 2, 1674: 1, 8557: 5, 11818: 1, 14051: 1, 2211: 3, 11310: 1, 8937: 1, 13576: 2, 16136: 6, 6215: 5, 2561: 2, 3719: 5, 11903: 2, 7719: 3, 10009: 1, 8959: 6, 10552: 1, 2867: 3, 6361: 42, 14058: 2, 8587: 3, 13407: 17, 4924: 2, 3367: 5, 5318: 1, 15379: 7, 11379: 1, 1351: 9, 5810: 4, 1887: 2, 11288: 10, 3902: 12, 9181: 4, 8952: 20, 2405: 1, 9061: 4, 5745: 1, 7427: 1, 5527: 2, 13530: 11, 2088: 1, 14382: 1, 15990: 4, 9389: 1, 16166: 1, 8063: 1, 1077: 1, 294: 1, 7491: 1, 3495: 1, 3525: 7, 8390: 4, 15894: 5, 12217: 1, 410: 1, 6211: 2, 8019: 2, 487: 1, 11483: 1, 15880: 2, 2030: 1, 13667: 11, 745: 4, 702: 1, 4040: 12, 12563: 1, 2968: 1, 11051: 4, 11173: 1, 12324: 1, 6299: 3, 484: 1, 11461: 2, 14757: 1, 6926: 1, 9688: 1, 2958: 2, 8897: 6, 13274: 11, 5050: 3, 13039: 2, 10484: 2, 8452: 2, 2094: 4, 564: 9, 4095: 19, 2230: 15, 14451: 1, 1147: 2, 6979: 1, 13828: 1, 5791:

In [14]:
# train classifier on sae activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()
    #print(activations.shape)

    activations_list.append(activations)
    labels_list.append(0)

for example_txt in positive_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()

    activations_list.append(activations)
    labels_list.append(1)   

# data
X = np.array(activations_list)
y = np.array(labels_list)

# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale activation features
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')
with open('run_log.txt', 'a') as file:
    file.write(f'SAE activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {accuracy:.4f}\n')

Test Accuracy: 0.9750


: 

In [None]:
# train classifier on base activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy() # layer batch pos d_model
    #print(seven_out.shape)

    activations_list.append(seven_out)
    labels_list.append(0)

#print("done w bad")

for example_txt in positive_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy()  # layer batch pos d_model

    activations_list.append(seven_out)
    labels_list.append(1)

# data and split
X = np.array(activations_list)
y = np.array(labels_list)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale features 
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# train classifier
clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

print(f'Test Accuracy: {accuracy:.4f}')
with open('run_log.txt', 'a') as file:
    file.write(f'Base activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {accuracy:.4f}%\n')

Test Accuracy: 0.9500


In [None]:
# reject sample if activation at any of top N neurons is above threshold T
THRESHOLD = 4
NEURONS_TO_CONSIDER = 50

top_neurons_neg_occ = list(top_neurons_neg_mean.items())[:NEURONS_TO_CONSIDER]
top_neurons_pos_mean = []
neg_indicators = []

for entry in top_neurons_neg_occ:
    neg_indicators.append(entry[0])

print(neg_indicators)

# run an inference, if hits threshold on any of the neurons classify as bad
e = 0
neg_classifications = np.zeros(NEG_SET_SIZE)
for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]
    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            neg_classifications[e] = 1
            break
    e += 1

pos_classifications = np.zeros(POS_SET_SIZE)


e = 0
for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]
    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            pos_classifications[e] = 1
            break
    e += 1


print(" RATS ! ", np.sum(neg_classifications) / 2)
print(" GENIUSES ! ", np.sum(pos_classifications) / 2)

with open('run_log.txt', 'a') as file:
    file.write(f'SAE Thresholding true negative rate on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {np.sum(neg_classifications) / 2}%; false negative rate: {np.sum(pos_classifications) / 2}%\n')

[9975, 12817, 12547, 13510, 8125, 13148, 3774, 11736, 5522, 6968, 10758, 10450, 140, 15710, 15008, 9066, 3511, 14157, 14274, 9642, 2834, 2871, 12076, 1160, 12823, 6041, 5946, 955, 5777, 3052, 4090, 12614, 8377, 377, 13483, 8127, 3562, 7224, 1116, 3654, 47, 5483, 11964, 5370, 1963, 12677, 4950, 4854, 13790, 1671]
 RATS !  97.0
 GENIUSES !  13.5
