In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


ImportError: cannot import name 'load_dataset' from partially initialized module 'datasets' (most likely due to a circular import) (/Users/anaiskillian/cs2822r/datasets.py)

In [2]:
# Read the Hugging Face access token from a local file and log in.
# file.read() returns the whole file, including any trailing newline, so the
# value is stripped before it is handed to login().
with open("access.tok", "r") as file:
    access_token = file.read().strip()

# Log in after the file has been closed; only the token string is needed here.
login(token=access_token)

# Device preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/cole/.cache/huggingface/token
Login successful
Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

# Disable autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Maximum number of examples kept from the negative / positive datasets.
NEG_SET_SIZE, POS_SET_SIZE = 500, 500

In [4]:
# Layer whose residual-stream SAE is loaded below; also used to build hook
# names in later cells.
num_layer = 10

# load gemma model
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

# Load the SAE on the residual stream of the gemma model; from_pretrained also
# hands back a config dict and a sparsity value.
# Available layers / ids: https://jbloomaus.github.io/SAELens/sae_table/#gemma-scope-2b-pt-res
# NOTE(review): the "average_l0_77" suffix is hard-coded; presumably it has to
# be looked up again in the table above if num_layer changes -- confirm.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id=f"layer_{num_layer}/width_16k/average_l0_77",
    device=device,
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [5]:
# Negative dataset (alternative file: 'dataset/spanish_harmful.csv').
neg_dataset = 'dataset/original_harmful.csv'

df = pd.read_csv(neg_dataset)

# The frame's columns as numpy arrays: once by position, once keyed by name.
columns_as_arrays = [df[name].values for name in df.columns]
array_dict = {name: df[name].values for name in df.columns}

# The negative examples are the first NEG_SET_SIZE entries of the first column.
negative_set = columns_as_arrays[0][:NEG_SET_SIZE]
print(len(negative_set))

500


In [None]:
# neg2_dataset = 'dataset/harmful_wiki_cleaned.csv'

# df = pd.read_csv(neg2_dataset)

# columns_as_arrays = [df[col].values for col in df.columns]

# array_dict = {col: df[col].values for col in df.columns}

# negative_set_2 = columns_as_arrays[0]
# negative_set_2 = negative_set_2[30:NEG_SET_SIZE+30]

# negative_set_3 = columns_as_arrays[0]
# negative_set_3 = negative_set_3[2030:NEG_SET_SIZE+2030]

# print(len(negative_set_2))
# print(len(negative_set_3))

In [6]:
# Positive dataset (alternative file: 'dataset/alpaca_spanish.json').
pos_dataset = 'dataset/alpaca_data.json'

positive = pd.read_json(pos_dataset)

# The positive examples are the first POS_SET_SIZE entries of the 'output' column.
positive_set = positive['output'].values[:POS_SET_SIZE]
print(len(positive_set))

500


In [7]:
def collect_top_sae_features(examples, k=15):
    """Collect the k largest SAE activations at the last position of each example.

    Each example is run through the model with the SAE attached. The activations
    at `blocks.{num_layer}.hook_resid_post.hook_sae_acts_post` are read for
    batch item 0 at position -1, and the k largest values are recorded.

    Args:
        examples: iterable of inputs accepted by model.run_with_cache_with_saes.
        k: number of top activations to keep per example.

    Returns:
        defaultdict(list) mapping SAE feature index -> list of activation values,
        one entry for every example in which that feature was among the top k.
    """
    hook_name = f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'
    top_features = defaultdict(list)

    for example in examples:
        _, cache = model.run_with_cache_with_saes(example, saes=[sae])
        vals, inds = torch.topk(cache[hook_name][0, -1, :], k)

        # The cache is no longer needed once the top-k values are extracted.
        del cache

        # One bulk conversion to Python numbers instead of one per element.
        for feature_idx, activation in zip(inds.tolist(), vals.tolist()):
            top_features[feature_idx].append(activation)

    return top_features


top_neurons_neg = collect_top_sae_features(negative_set)
top_neurons_pos = collect_top_sae_features(positive_set)

print(top_neurons_neg)
print(top_neurons_pos)

KeyboardInterrupt: 

: 

In [16]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.

    A neuron is "highly activated" in a set when at least one of its recorded
    activations in that set is >= threshold. A neuron is removed from BOTH
    outputs only when it is highly activated in both sets; the same rule is
    applied to the negative and the positive side.

    Args:
        top_neurons_neg: mapping neuron index -> list of activations (negative set).
        top_neurons_pos: mapping neuron index -> list of activations (positive set).
        threshold: minimum activation for a neuron to count as highly activated.

    Returns:
        (filtered_neurons_neg, filtered_neurons_pos): plain dicts holding the
        surviving entries of each input, in the inputs' iteration order.
    """

    def is_high(table, neuron):
        # .get() rather than [] so a defaultdict input never grows new keys.
        return any(val >= threshold for val in table.get(neuron, ()))

    def is_shared(neuron):
        return is_high(top_neurons_neg, neuron) and is_high(top_neurons_pos, neuron)

    filtered_neurons_neg = {
        neuron: activations
        for neuron, activations in top_neurons_neg.items()
        if not is_shared(neuron)
    }
    filtered_neurons_pos = {
        neuron: activations
        for neuron, activations in top_neurons_pos.items()
        if not is_shared(neuron)
    }

    return filtered_neurons_neg, filtered_neurons_pos

# With a threshold of 0, every recorded activation that is >= 0 counts as
# "highly activated" for the purposes of the filter.
filtered_neg, filtered_pos = filter_neurons(top_neurons_neg, top_neurons_pos, threshold=0)

print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 513. Filtered negative neurons: {15292: [7.964665412902832, 16.30695343017578, 8.415923118591309, 6.3382487297058105, 8.156566619873047, 11.515135765075684], 3379: [7.7817606925964355, 6.007218837738037, 12.98498821258545, 7.247117519378662, 10.880072593688965, 7.22714376449585, 7.082031726837158, 6.972873210906982], 13546: [7.021005153656006, 6.136992931365967, 6.902019023895264, 8.175981521606445, 5.787267208099365, 11.107040405273438, 6.286914348602295, 7.55910062789917, 9.964818954467773, 8.012052536010742, 9.028322219848633, 12.947704315185547, 8.307945251464844, 6.299863338470459, 9.153085708618164], 2404: [8.085074424743652, 12.471602439880371], 8866: [7.907463550567627, 5.492731094360352], 91: [10.679744720458984, 6.5384321212768555, 7.526253700256348, 7.507866859436035, 7.335987091064453, 5.882230758666992, 9.04326057434082, 6.656044006347656], 11992: [8.675148963928223, 11.662571907043457, 6.6029486656188965, 7.13411283493042, 9.865836143493652], 3965: [7.616941928863525

In [17]:
# Count, for every surviving neuron, how many activation values were recorded
# for it. Despite the *_mean names, the stored values are occurrence counts
# (len of the activation list), not averages:
# top_neurons_neg/pos_mean = {idx: n_occurrences, idx2: n_occurrences2, ...}
top_neurons_neg_mean = {idx: len(acts) for idx, acts in filtered_neg.items()}
top_neurons_pos_mean = {idx: len(acts) for idx, acts in filtered_pos.items()}

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)

# Re-order both dicts by count, largest first.
top_neurons_neg_mean = dict(
    sorted(top_neurons_neg_mean.items(), key=lambda pair: pair[1], reverse=True)
)
top_neurons_pos_mean = dict(
    sorted(top_neurons_pos_mean.items(), key=lambda pair: pair[1], reverse=True)
)

# Show the first 200 (index, count) pairs of each.
print(list(top_neurons_neg_mean.items())[:200])
print(list(top_neurons_pos_mean.items())[:200])

{15292: 6, 3379: 8, 13546: 15, 2404: 2, 8866: 2, 91: 8, 11992: 5, 3965: 7, 5734: 1, 10623: 8, 11116: 1, 2471: 1, 14622: 2, 12446: 12, 1955: 1, 4059: 1, 4956: 3, 4603: 4, 3738: 1, 5525: 9, 13504: 10, 13676: 2, 9866: 7, 976: 1, 10872: 1, 5363: 6, 12962: 2, 13321: 1, 16333: 1, 15759: 3, 13304: 2, 26: 1, 12411: 1, 13399: 5, 607: 3, 4565: 1, 2780: 4, 10248: 1, 4460: 2, 10319: 3, 1495: 34, 5458: 13, 3870: 2, 12505: 1, 15701: 2, 6286: 1, 3394: 1, 6451: 3, 11203: 1, 10468: 2, 1323: 2, 13087: 1, 11148: 2, 4664: 4, 2974: 4, 7241: 1, 4255: 1, 13474: 5, 5370: 2, 13307: 1, 1530: 8, 13388: 1, 8420: 11, 9223: 3, 189: 1, 16203: 4, 820: 1, 5351: 2, 10606: 2, 10805: 2, 6638: 1, 13862: 1, 9198: 2, 3069: 2, 6447: 1, 10443: 1, 15794: 2, 12825: 6, 12570: 1, 9780: 1, 10547: 1, 5873: 1, 7047: 5, 9020: 3, 11274: 1, 1389: 1, 6231: 1, 11620: 1, 3568: 1, 1648: 1, 5785: 3, 7110: 1, 550: 1, 6163: 1, 3556: 3, 4322: 1, 6292: 1, 6770: 11, 10930: 2, 8800: 4, 13218: 5, 13350: 1, 6948: 1, 9000: 1, 12772: 1, 12042: 2, 144

In [21]:
# Logistic-regression classifier trained on SAE activations.
# Features: SAE activations of batch item 0 at position -1 for each example.
sae_hook = f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'

activations_list = []
labels_list = []

# Labels: 0 = negative, 1 = positive. Negatives are processed first.
for label, texts in ((0, negative_set), (1, positive_set)):
    for example_txt in texts:
        _, cache = model.run_with_cache_with_saes(example_txt, saes=[sae])
        activations = cache[sae_hook][0, -1, :].cpu().numpy()

        # Drop the cache as soon as the feature vector has been copied out.
        del cache

        activations_list.append(activations)
        labels_list.append(label)

# Feature matrix and label vector.
X = np.array(activations_list)
y = np.array(labels_list)

# 80/20 train/test split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardise features; the scaler is fitted on the training split only.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000, solver='lbfgs')
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

# Append the result to the shared run log.
with open('run_log.txt', 'a') as log_file:
    log_file.write(f'SAE activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={num_layer} dataset_size={NEG_SET_SIZE}: {accuracy:.4f}\n')

Test Accuracy: 0.9250


In [22]:
# Logistic-regression classifier trained on base (non-SAE) model activations.

# `layer` argument passed to cache.decompose_resid below. This is the value
# written to the run log, so the log reflects what was actually computed
# (previously the log reported num_layer although 15 was used here).
BASE_LAYER = 15


def base_features(example_txt):
    """Return the base-model feature vector for one example as a numpy array.

    Runs the model with a cache, calls decompose_resid with mode='attn' and no
    embeddings, and takes the last entry along the first axis for batch item 0.

    NOTE(review): pos_slice=slice(-2, -1) presumably restricts the result to the
    second-to-last token position, whereas the SAE classifier reads position -1
    -- confirm this difference is intended.
    """
    _, cache = model.run_with_cache(example_txt)
    res_stream = cache.decompose_resid(layer=BASE_LAYER, return_labels=False, mode='attn', incl_embeds=False, pos_slice=slice(-2, -1))
    # axis order per the original code's comment: layer, batch, pos, d_model
    return res_stream[-1, 0, -1, :].cpu().numpy()


activations_list = []
labels_list = []

# Labels: 0 = negative, 1 = positive. Negatives are processed first.
for label, texts in ((0, negative_set), (1, positive_set)):
    for example_txt in texts:
        activations_list.append(base_features(example_txt))
        labels_list.append(label)

# data and split
X = np.array(activations_list)
y = np.array(labels_list)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# scale features (scaler fitted on the training split only)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# train classifier
clf = LogisticRegression(max_iter=1000, solver='lbfgs')
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test Accuracy: {accuracy:.4f}')

with open('run_log.txt', 'a') as file:
    file.write(f'Base activations linear classifier accuracy on pos={pos_dataset} neg={neg_dataset} layer={BASE_LAYER} dataset_size={NEG_SET_SIZE}: {accuracy:.4f}\n')

Test Accuracy: 0.9875
Test Accuracy: 0.9875


In [10]:
# Reject a sample if the activation at any of the top N neurons is above threshold T.
THRESHOLD = 4
NEURONS_TO_CONSIDER = 50

# First NEURONS_TO_CONSIDER (index, count) pairs of top_neurons_neg_mean, in
# that dict's current order. top_neurons_pos_mean is deliberately left
# untouched here so re-running this cell does not destroy earlier results.
top_neurons_neg_occ = list(top_neurons_neg_mean.items())[:NEURONS_TO_CONSIDER]
neg_indicators = [neuron_idx for neuron_idx, _ in top_neurons_neg_occ]

print(neg_indicators)


def flag_examples(examples):
    """Return a float array with 1 for every example flagged by an indicator neuron.

    An example is flagged when the SAE activation (batch item 0, position -1) of
    at least one neuron in neg_indicators is strictly greater than THRESHOLD.
    The array has one entry per example, in input order.
    """
    hook_name = f'blocks.{num_layer}.hook_resid_post.hook_sae_acts_post'
    flags = np.zeros(len(examples))

    for i, example in enumerate(examples):
        _, cache = model.run_with_cache_with_saes(example, saes=[sae])

        # get data from indicator neurons
        val = cache[hook_name][0, -1, :]
        del cache

        # any() stops at the first indicator neuron that exceeds the threshold.
        if any(val[critical_neuron_idx] > THRESHOLD for critical_neuron_idx in neg_indicators):
            flags[i] = 1

    return flags


# run an inference, if hits threshold on any of the neurons classify as bad
neg_classifications = flag_examples(negative_set)
pos_classifications = flag_examples(positive_set)

# Percentages are taken over each set's own size (the positive rate was
# previously divided by the negative set's size).
neg_flagged_pct = np.sum(neg_classifications) / (len(negative_set) / 100)
pos_flagged_pct = np.sum(pos_classifications) / (len(positive_set) / 100)

print(" RATS ! ", neg_flagged_pct)
print(" GENIUSES ! ", pos_flagged_pct)

with open('run_log.txt', 'a') as file:
    file.write(f'SAE Thresholding true negative rate on pos={pos_dataset} neg={neg_dataset} layer={num_layer} dataset_size={NEG_SET_SIZE}: {neg_flagged_pct}; false negative rate: {pos_flagged_pct}\n')

[1495, 5154, 14326, 12446, 6770, 2405, 840, 13399, 15794, 8420, 3069, 14622, 13430, 4603, 8382, 9882, 13407, 11992, 11630, 4956, 9189, 13887, 14118, 3379, 8711, 16271, 3965, 14449, 14660, 15268, 13474, 6659, 4963, 15292, 1314, 10228, 2471, 16200, 10623, 6571, 13863, 10430, 7273, 7681, 5840, 9125, 5409, 11820, 4228, 9447]
 RATS !  97.0
 GENIUSES !  14.5


In [None]:
from IPython.display import IFrame


# Embed URL with three positional placeholders, filled by get_dashboard_html.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="10-gemmascope-res-16k", feature_idx=0):
    """Return the embed URL for one feature.

    The three arguments are substituted into html_template in the order
    sae_release, sae_id, feature_idx. Despite the function's name, the return
    value is a URL string, not HTML markup.
    """
    path_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*path_parts)

# Show one embedded dashboard per indicator neuron.
# NOTE(review): the sae_id is hard-coded to "10-..."; presumably it should
# track num_layer if that value is ever changed -- confirm.
for feature_idx in neg_indicators:
    dashboard_url = get_dashboard_html(sae_release="gemma-2-2b", sae_id="10-gemmascope-res-16k", feature_idx=feature_idx)
    display(IFrame(dashboard_url, width=800, height=400))