In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

In [2]:
with open("access.tok", "r") as file:
    access_token = file.read()
    login(token=access_token)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

torch.set_grad_enabled(False)

NEG_SET_SIZE = 200
POS_SET_SIZE = 200

In [4]:
# load gemma model 
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device = device)

# load sae on res stream of gemma model, plus cfg and sparsity val
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res",
    sae_id = "layer_14/width_16k/average_l0_83",
    device = device
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [7]:
df = pd.read_csv('dataset/spanish_harmful.csv')

columns_as_arrays = [df[col].values for col in df.columns]

array_dict = {col: df[col].values for col in df.columns}

negative_set = columns_as_arrays[0]
negative_set = negative_set[:NEG_SET_SIZE]
print(len(negative_set))

200


In [8]:
positive = pd.read_json('dataset/alpaca_spanish.json')

positive_set = positive['output'].values
positive_set = positive_set[:POS_SET_SIZE]
print(len(positive_set))

200


In [9]:
sae.use_error_term

top_neurons_neg = defaultdict(list)
top_neurons_pos = defaultdict(list)

for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache['blocks.14.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)

    for datapoint in zip(inds, vals):
        top_neurons_neg[int(datapoint[0])].append(datapoint[1].item())
    

for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache['blocks.14.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    for datapoint in zip(inds, vals):
        top_neurons_pos[int(datapoint[0])].append(datapoint[1].item())

print(top_neurons_neg)
print(top_neurons_pos)

KeyboardInterrupt: 

In [11]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.
    """
    
    filtered_neurons_neg = {}
    filtered_neurons_pos = {}

    for neuron, activations in top_neurons_neg.items():
        if neuron in top_neurons_pos and any(val >= threshold for val in activations) and any(val >= threshold for val in top_neurons_pos[neuron]):
            continue 
        else:
            filtered_neurons_neg[neuron] = activations

    for neuron, activations in top_neurons_pos.items():
        if neuron not in top_neurons_neg or not any(val >= threshold for val in top_neurons_neg[neuron]):
            filtered_neurons_pos[neuron] = activations

    return filtered_neurons_neg, filtered_neurons_pos

filtered_neg, filtered_pos = filter_neurons(top_neurons_neg, top_neurons_pos, 0)
print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 743. Filtered negative neurons: {377: [29.632680892944336, 27.406606674194336, 19.24976348876953, 19.62711524963379, 23.818645477294922, 16.40456199645996], 13483: [20.95055389404297, 23.529897689819336, 20.34111976623535, 23.316316604614258, 20.674488067626953, 25.418424606323242], 16302: [14.254205703735352, 17.667993545532227, 13.052112579345703], 8327: [12.928153991699219, 14.708047866821289, 10.956321716308594], 11736: [12.719222068786621, 13.789018630981445, 12.221495628356934, 11.717462539672852, 23.38585662841797, 12.421101570129395, 9.649845123291016, 11.653751373291016, 11.687721252441406, 16.139949798583984, 11.544108390808105, 9.306063652038574, 10.936362266540527, 14.885224342346191, 17.19536781311035], 9058: [11.718343734741211, 13.746081352233887, 10.389264106750488], 218: [11.625377655029297, 15.446069717407227], 2506: [11.403615951538086, 15.21877384185791, 11.223353385925293], 14403: [10.51242733001709], 1659: [10.12092399597168, 11.393745422363281], 12266: [50.7

In [12]:
# train classifier on top activations
# average activations over each top case, sends to
# top_neurons_neg/pos = {idx: avg_act, idx2:avg_act2, ...}
top_neurons_neg_mean = {}
for entry in filtered_neg:
    top_neurons_neg_mean[entry] = len(filtered_neg[entry])

top_neurons_pos_mean = {}
for entry in filtered_pos:
    top_neurons_pos_mean[entry] = len(filtered_pos[entry])

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)

# sort by avg activation
top_neurons_neg_mean = {k: v for k, v in sorted(top_neurons_neg_mean.items(), key=lambda item: item[1], reverse=True)}
top_neurons_pos_mean = {k: v for k, v in sorted(top_neurons_pos_mean.items(), key=lambda item: item[1], reverse=True)}


# print first few
print(list(top_neurons_neg_mean.items())[:200])
print(list(top_neurons_pos_mean.items())[:200])

{377: 6, 13483: 6, 16302: 3, 8327: 3, 11736: 15, 9058: 3, 218: 2, 2506: 3, 14403: 1, 1659: 2, 12266: 2, 10119: 3, 12817: 24, 8702: 1, 3648: 3, 14157: 10, 5370: 5, 1837: 1, 16279: 1, 1963: 5, 2658: 4, 8189: 2, 3511: 11, 12677: 5, 8127: 6, 15121: 3, 8125: 18, 9264: 3, 6968: 14, 931: 2, 15926: 1, 7830: 3, 1147: 3, 13308: 3, 12076: 9, 4090: 7, 3774: 16, 13592: 2, 3669: 4, 12547: 21, 15710: 12, 14274: 10, 11040: 3, 5282: 1, 13326: 1, 9975: 43, 15008: 12, 7308: 3, 1788: 3, 15213: 1, 8330: 1, 15135: 1, 7158: 1, 328: 1, 13784: 1, 3466: 1, 7300: 1, 2944: 1, 4950: 5, 11563: 2, 11072: 4, 12801: 2, 139: 2, 3078: 2, 13513: 1, 3187: 1, 15451: 1, 13348: 3, 11789: 1, 5443: 1, 10518: 1, 14380: 1, 6746: 1, 5946: 8, 10036: 1, 10979: 1, 14471: 1, 9066: 12, 3562: 6, 13148: 17, 10758: 14, 13409: 2, 12085: 1, 10450: 14, 13510: 20, 98: 2, 11566: 1, 9642: 10, 12254: 1, 760: 1, 8110: 1, 5507: 4, 918: 2, 12882: 1, 9137: 2, 1160: 9, 15391: 1, 958: 3, 14095: 1, 15486: 2, 13962: 3, 1284: 3, 13952: 1, 10697: 2, 8240

In [11]:
# train classifier on sae activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache['blocks.14.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()
    #print(activations.shape)

    activations_list.append(activations)
    labels_list.append(0)

for example_txt in positive_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache['blocks.14.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()

    activations_list.append(activations)
    labels_list.append(1)   

# data
X = np.array(activations_list)
y = np.array(labels_list)

# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale activation features
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

Test Accuracy: 1.0000


In [16]:
# train classifier on base activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy() # layer batch pos d_model
    #print(seven_out.shape)

    activations_list.append(seven_out)
    labels_list.append(0)

#print("done w bad")

for example_txt in positive_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy()  # layer batch pos d_model

    activations_list.append(seven_out)
    labels_list.append(1)

# data and split
X = np.array(activations_list)
y = np.array(labels_list)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale features 
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# train classifier
clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

Test Accuracy: 0.9500


In [15]:
# reject sample if activation at any of top N neurons is above threshold T
THRESHOLD = 4
NEURONS_TO_CONSIDER = 50

top_neurons_neg_occ = list(top_neurons_neg_mean.items())[:NEURONS_TO_CONSIDER]
top_neurons_pos_mean = []
neg_indicators = []

for entry in top_neurons_neg_occ:
    neg_indicators.append(entry[0])

print(neg_indicators)

# run an inference, if hits threshold on any of the neurons classify as bad
e = 0
neg_classifications = np.zeros(NEG_SET_SIZE)
for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache['blocks.14.hook_resid_post.hook_sae_acts_post'][0, -1, :]
    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            neg_classifications[e] = 1
            break
    e += 1

pos_classifications = np.zeros(POS_SET_SIZE)


e = 0
for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache['blocks.14.hook_resid_post.hook_sae_acts_post'][0, -1, :]
    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            pos_classifications[e] = 1
            break
    e += 1

print(" RATS ! ", np.sum(neg_classifications) / 2)
print(" GENIUSES ! ", np.sum(pos_classifications) / 2)

[9975, 12817, 12547, 13510, 8125, 13148, 3774, 11736, 5522, 6968, 10758, 10450, 140, 15710, 15008, 9066, 3511, 14157, 14274, 9642, 2834, 2871, 12076, 1160, 12823, 6041, 5946, 955, 5777, 3052, 4090, 12614, 8377, 377, 13483, 8127, 3562, 7224, 1116, 3654, 47, 5483, 11964, 5370, 1963, 12677, 4950, 4854, 13790, 1671]
 RATS !  97.0
 GENIUSES !  13.5


In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
from sklearn.tree import _tree

# replace Logistic Regression with Decision Tree for classification
clf = DecisionTreeClassifier(max_depth=5, random_state=42)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy with Decision Tree: {accuracy:.4f}')

def print_decision_thresholds(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("Decision thresholds for neurons:\n")
    for i in range(tree_.node_count):
        if tree_.feature[i] != _tree.TREE_UNDEFINED:
            print(f"{feature_name[i]} <= {tree_.threshold[i]:.4f}")

# feature names based on neuron index
feature_names = [f'Neuron_{i}' for i in range(X_train.shape[1])]

# decision thresholds
print_decision_thresholds(clf, feature_names)

# visualizing decision boundaries
plt.figure(figsize=(20, 10))
plot_tree(clf, filled=True, feature_names=[f'Neuron_{i}' for i in range(X_train.shape[1])], class_names=['Negative', 'Positive'])
plt.title("Decision Tree for SAE Activations")
plt.show()

# Feature importance
importance = clf.feature_importances_
importance_df = pd.DataFrame({"Neuron": [f'Neuron_{i}' for i in range(X_train.shape[1])], "Importance": importance})
importance_df = importance_df[importance_df["Importance"] > 0].sort_values(by="Importance", ascending=False)

print("Top neurons and their significance in decision making:")
print(importance_df)

def print_importance(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("Top neurons and their significance in decision making:\n")
    for i in range(tree_.node_count):
        if tree_.feature[i] != _tree.TREE_UNDEFINED:
            feature_index = tree_.feature[i]
            neuron_name = f"Neuron_{feature_index}"
            if neuron_name in importance_df["Neuron"].values:
                print(f"{neuron_name} Importance: {importance[feature_index]:.4f}")

print_importance(clf, feature_names)
