In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

In [2]:
with open("access.tok", "r") as file:
    access_token = file.read()
    login(token=access_token)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/cole/.cache/huggingface/token
Login successful
Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

torch.set_grad_enabled(False)

NEG_SET_SIZE = 200
POS_SET_SIZE = 200

In [4]:
# load gemma model 
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device = device)

num_layer = 10
# load sae on res stream of gemma model, plus cfg and sparsity val
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res",
    sae_id = f"layer_{num_layer}/width_16k/average_l0_77",
    device = device
)
# get layers from here: https://jbloomaus.github.io/SAELens/sae_table/#gemma-scope-2b-pt-res



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [5]:
#neg_dataset = 'dataset/spanish_harmful.csv'
neg_dataset = 'dataset/original_harmful.csv'

df = pd.read_csv(neg_dataset)

columns_as_arrays = [df[col].values for col in df.columns]

array_dict = {col: df[col].values for col in df.columns}

negative_set = columns_as_arrays[0]
negative_set = negative_set[:NEG_SET_SIZE]
print(len(negative_set))

200


In [6]:
#pos_dataset = 'dataset/alpaca_spanish.json'
pos_dataset = 'dataset/alpaca_data.json'
positive = pd.read_json(pos_dataset)

positive_set = positive['output'].values
positive_set = positive_set[:POS_SET_SIZE]
print(len(positive_set))

200


In [7]:
sae.use_error_term

top_neurons_neg = defaultdict(list)
top_neurons_pos = defaultdict(list)

for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)

    for datapoint in zip(inds, vals):
        top_neurons_neg[int(datapoint[0])].append(datapoint[1].item())
    
    del cache
    

for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get top 15 firing sae neurons
    vals, inds = torch.topk(cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :], 15)
    for datapoint in zip(inds, vals):
        top_neurons_pos[int(datapoint[0])].append(datapoint[1].item())
    
    del cache

print(top_neurons_neg)
print(top_neurons_pos)

defaultdict(<class 'list'>, {4963: [24.0179386138916, 23.06588363647461, 21.496688842773438, 22.32155418395996, 23.060426712036133, 20.59315299987793], 3986: [23.90171241760254, 21.68891716003418, 22.37017250061035, 24.464557647705078, 24.9793643951416, 26.926427841186523, 26.843320846557617, 28.34942626953125, 26.69489097595215, 23.20603370666504, 20.902984619140625, 27.447139739990234, 24.578201293945312, 26.200450897216797, 19.871973037719727, 21.55316734313965, 24.719552993774414, 19.797082901000977, 27.145912170410156, 20.900497436523438, 21.93437957763672, 21.25170135498047, 26.867002487182617, 20.0645809173584, 22.901731491088867, 19.55942153930664, 22.893247604370117, 25.29959487915039, 21.884765625, 20.044668197631836, 24.7183780670166, 22.632762908935547, 24.746137619018555, 22.314617156982422, 17.981735229492188, 24.641687393188477, 19.466712951660156, 27.515056610107422, 24.040307998657227, 24.352563858032227, 22.81491470336914, 20.94618034362793, 21.01361656188965, 22.1327

In [8]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.
    """
    
    filtered_neurons_neg = {}
    filtered_neurons_pos = {}

    for neuron, activations in top_neurons_neg.items():
        if neuron in top_neurons_pos and any(val >= threshold for val in activations) and any(val >= threshold for val in top_neurons_pos[neuron]):
            continue 
        else:
            filtered_neurons_neg[neuron] = activations

    for neuron, activations in top_neurons_pos.items():
        if neuron not in top_neurons_neg or not any(val >= threshold for val in top_neurons_neg[neuron]):
            filtered_neurons_pos[neuron] = activations

    return filtered_neurons_neg, filtered_neurons_pos

filtered_neg, filtered_pos = filter_neurons(top_neurons_neg, top_neurons_pos, 0)
print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 876. Filtered negative neurons: {4963: [24.0179386138916, 23.06588363647461, 21.496688842773438, 22.32155418395996, 23.060426712036133, 20.59315299987793], 15292: [20.291065216064453, 25.548664093017578, 13.24570083618164, 10.399188041687012, 10.81445598602295, 9.0546293258667], 1314: [16.83120346069336, 16.960464477539062, 17.126995086669922, 20.903648376464844, 16.333263397216797, 20.508583068847656], 3379: [12.051023483276367, 9.689680099487305, 15.222034454345703, 9.319575309753418, 11.833881378173828, 9.466838836669922, 11.769444465637207, 10.221564292907715], 14449: [11.401235580444336, 8.925599098205566, 8.27865219116211, 11.177252769470215, 10.571102142333984, 8.916620254516602, 18.247278213500977], 3233: [10.811705589294434, 12.399224281311035, 14.903839111328125, 10.406905174255371], 10228: [8.446449279785156, 11.107491493225098, 14.597237586975098, 11.245314598083496, 12.972159385681152, 14.028641700744629], 2416: [8.201991081237793, 6.6592183113098145, 8.67320537567138

In [9]:
# average activations over each top case, sends to
# top_neurons_neg/pos = {idx: avg_act, idx2:avg_act2, ...}
top_neurons_neg_mean = {}
for entry in filtered_neg:
    top_neurons_neg_mean[entry] = len(filtered_neg[entry])

top_neurons_pos_mean = {}
for entry in filtered_pos:
    top_neurons_pos_mean[entry] = len(filtered_pos[entry])

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)

# sort by avg activation
top_neurons_neg_mean = {k: v for k, v in sorted(top_neurons_neg_mean.items(), key=lambda item: item[1], reverse=True)}
top_neurons_pos_mean = {k: v for k, v in sorted(top_neurons_pos_mean.items(), key=lambda item: item[1], reverse=True)}


# print first few
print(list(top_neurons_neg_mean.items())[:200])
print(list(top_neurons_pos_mean.items())[:200])

{4963: 6, 15292: 6, 1314: 6, 3379: 8, 14449: 7, 3233: 4, 10228: 6, 2416: 5, 5154: 38, 6770: 14, 1643: 2, 13399: 12, 13125: 1, 3380: 2, 4412: 4, 9514: 1, 9796: 4, 12933: 1, 2051: 1, 7790: 3, 6192: 1, 2136: 5, 2471: 6, 12841: 1, 14838: 2, 2056: 1, 11299: 1, 7799: 4, 12314: 1, 16200: 6, 11194: 2, 10623: 6, 1495: 40, 8711: 8, 3895: 1, 6087: 2, 9534: 4, 7063: 1, 14803: 1, 9882: 9, 2135: 1, 13323: 2, 12446: 16, 14622: 10, 13407: 9, 1955: 4, 12573: 2, 8622: 5, 10511: 1, 8676: 2, 15837: 1, 2712: 1, 7995: 2, 6532: 2, 2813: 1, 3043: 1, 12779: 1, 12564: 1, 10639: 1, 976: 1, 305: 1, 2224: 1, 3280: 3, 13304: 3, 12666: 2, 410: 1, 10917: 1, 12677: 1, 10823: 5, 4154: 1, 1193: 3, 2439: 1, 13430: 10, 15882: 1, 7175: 1, 9240: 2, 12478: 1, 4393: 1, 3077: 2, 7037: 3, 1525: 1, 13379: 1, 2780: 5, 2412: 1, 2405: 13, 6571: 6, 13061: 1, 3329: 2, 12910: 3, 1492: 3, 16271: 8, 14326: 19, 7671: 1, 15701: 2, 4710: 1, 13321: 1, 3114: 1, 11235: 2, 9163: 1, 6451: 4, 10211: 2, 10337: 1, 13863: 6, 15661: 1, 16352: 1, 130

In [21]:
# train classifier on sae activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()
    #print(activations.shape)

    del cache

    activations_list.append(activations)
    labels_list.append(0)

for example_txt in positive_set:
    _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
    activations = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :].cpu().numpy()

    del cache

    activations_list.append(activations)
    labels_list.append(1)   

# data
X = np.array(activations_list)
y = np.array(labels_list)

# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale activation features
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')
with open('run_log.txt', 'a') as file:
    file.write(f'SAE activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {accuracy:.4f}\n')

Test Accuracy: 0.9250


In [22]:
# train classifier on base activations
activations_list = []
labels_list = []

# 0 = negative, 1 = positive
for example_txt in negative_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy() # layer batch pos d_model
    #print(seven_out.shape)

    del cache

    activations_list.append(seven_out)
    labels_list.append(0)

#print("done w bad")

for example_txt in positive_set:
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=15, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    seven_out = res_stream[-1, 0, -1, :].cpu().numpy()  # layer batch pos d_model

    del cache

    activations_list.append(seven_out)
    labels_list.append(1)

# data and split
X = np.array(activations_list)
y = np.array(labels_list)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale features 
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# train classifier
clf = LogisticRegression(max_iter=1000, solver='lbfgs') 
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

print(f'Test Accuracy: {accuracy:.4f}')
with open('run_log.txt', 'a') as file:
    file.write(f'Base activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {accuracy:.4f}\n')

Test Accuracy: 0.9875
Test Accuracy: 0.9875


In [10]:
# reject sample if activation at any of top N neurons is above threshold T
THRESHOLD = 4
NEURONS_TO_CONSIDER = 50

top_neurons_neg_occ = list(top_neurons_neg_mean.items())[:NEURONS_TO_CONSIDER]
top_neurons_pos_mean = []
neg_indicators = []

for entry in top_neurons_neg_occ:
    neg_indicators.append(entry[0])

print(neg_indicators)

# run an inference, if hits threshold on any of the neurons classify as bad
e = 0
neg_classifications = np.zeros(NEG_SET_SIZE)
for example in negative_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]
    
    del cache

    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            neg_classifications[e] = 1
            break
    
    e += 1

pos_classifications = np.zeros(POS_SET_SIZE)


e = 0
for example in positive_set:
    _, cache = model.run_with_cache_with_saes(example, saes=[sae])

    # get data from indicator neurons
    val = cache[f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]

    del cache
    
    for critical_neuron_idx in neg_indicators:
        if (val[critical_neuron_idx] > THRESHOLD):
            pos_classifications[e] = 1
            break
    e += 1


print(" RATS ! ", np.sum(neg_classifications) / 2)
print(" GENIUSES ! ", np.sum(pos_classifications) / 2)

with open('run_log.txt', 'a') as file:
    file.write(f'SAE Thresholding true negative rate on pos={pos_dataset} neg={neg_dataset} layer={num_layer}: {np.sum(neg_classifications) / 2}; false negative rate: {np.sum(pos_classifications) / 2}\n')

[1495, 5154, 14326, 12446, 6770, 2405, 840, 13399, 15794, 8420, 3069, 14622, 13430, 4603, 8382, 9882, 13407, 11992, 11630, 4956, 9189, 13887, 14118, 3379, 8711, 16271, 3965, 14449, 14660, 15268, 13474, 6659, 4963, 15292, 1314, 10228, 2471, 16200, 10623, 6571, 13863, 10430, 7273, 7681, 5840, 9125, 5409, 11820, 4228, 9447]
 RATS !  97.0
 GENIUSES !  14.5


In [12]:
from IPython.display import IFrame


html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="10-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

for feature_idx in neg_indicators:
    html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="10-gemmascope-res-16k", feature_idx=feature_idx)
    frame = IFrame(html, width=800, height=400)
    display(frame)