In [1]:
import os
from tqdm import tqdm
from huggingface_hub import login
import torch
import torch.nn as nn
import math
import statistics
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
from jaxtyping import Float
from functools import partial
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
from collections import defaultdict

In [2]:
# with open("/Users/cole/cs2822r/saes2822r/access.tok", "r") as file:
#     access_token = file.read()
#     login(token=access_token)

# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: mps


In [3]:
from datasets import load_dataset  
import transformer_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

# Switch autograd off globally. The cells visible here only run forward passes
# and read cached activations (NOTE(review): confirm no later cell needs gradients).
# As the last expression in the cell, the returned object's repr is displayed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x17e576590>

In [4]:
# GPT-2 small, loaded through the SAE-aware wrapper (for testing).
model = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# Pretrained SAE for the `blocks.7.hook_resid_pre` hook of GPT-2 small (for testing).
# The call is unpacked into the SAE itself, its config dict and a sparsity value.
# NOTE(review): loading this SAE prints a hint recommending
# `HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)`;
# decide whether the model above should be loaded that way instead.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:
import pandas as pd

df = pd.read_csv('dataset/harmful_strings.csv')

# Expose every CSV column as an array, both by name and by position.
array_dict = {name: df[name].values for name in df.columns}
columns_as_arrays = [df[name].values for name in df.columns]

# The negative examples are the first 200 entries of the first CSV column.
negative_set = columns_as_arrays[0][:200]
print(len(negative_set))

200


In [6]:
positive = pd.read_json('dataset/alpaca_data.json')

# The positive examples are the first 200 values of the 'output' column.
positive_set = positive['output'].values[:200]
print(len(positive_set))

200


In [7]:
# Cache key of the SAE feature activations read for every example.
SAE_ACTS_HOOK = 'blocks.7.hook_resid_pre.hook_sae_acts_post'
# Number of strongest SAE features recorded per example.
TOP_K = 15


def collect_top_sae_features(model, sae, examples, k=TOP_K, hook_name=SAE_ACTS_HOOK):
    """Record the k strongest SAE feature activations of each example.

    Each example is run through ``model`` with ``sae`` attached, and the
    activations cached under ``hook_name`` are read at index ``[0, -1, :]``
    (first batch entry, final position -- presumably the last token; assumes
    the cached tensor is laid out as (batch, position, feature), TODO confirm).

    Args:
        model: Object exposing ``run_with_cache_with_saes(example, saes=[...])``.
        sae: The SAE passed to ``run_with_cache_with_saes``.
        examples: Iterable of inputs, each forwarded to the model on its own.
        k: How many of the largest activations to keep per example.
        hook_name: Cache key holding the SAE activations.

    Returns:
        defaultdict(list) mapping feature index (int) to the list of
        activation values (float) it had whenever it was in an example's top k.
    """
    top_features = defaultdict(list)
    for example in examples:
        _, cache = model.run_with_cache_with_saes(example, saes=[sae])
        vals, inds = torch.topk(cache[hook_name][0, -1, :], k)
        # Convert once per example instead of calling .item() on every element,
        # so only one transfer from the device is needed per tensor.
        for feature_idx, activation in zip(inds.tolist(), vals.tolist()):
            top_features[feature_idx].append(activation)
    return top_features


# Same procedure for both datasets; only the inputs differ.
top_neurons_neg = collect_top_sae_features(model, sae, negative_set)
top_neurons_pos = collect_top_sae_features(model, sae, positive_set)

print(top_neurons_neg)
print(top_neurons_pos)

defaultdict(<class 'list'>, {6670: [16.4168643951416, 2.801954984664917, 16.919387817382812, 5.190169334411621, 4.305880546569824, 2.3614017963409424], 9506: [15.13806438446045, 19.836523056030273, 4.30758810043335], 13665: [12.060338973999023, 13.8031005859375], 23106: [10.655967712402344, 6.715331077575684, 11.300751686096191], 23503: [6.506248474121094, 4.010358810424805, 4.525677680969238], 2312: [6.471706867218018, 2.4618632793426514, 19.484525680541992, 13.123251914978027, 9.765414237976074], 9481: [6.12713623046875, 3.1801211833953857], 22777: [4.396451473236084], 22930: [4.015675067901611, 2.495692014694214, 2.9327199459075928, 3.286095380783081, 2.484334945678711, 3.275991678237915, 2.3608810901641846, 2.4663054943084717, 3.0160131454467773, 2.763338327407837, 1.8625926971435547], 19045: [3.978228807449341, 2.195633888244629, 2.4856648445129395, 2.6643054485321045], 8023: [3.9625518321990967, 3.374295473098755, 3.1345767974853516, 2.184896945953369, 2.685012102127075, 2.509566

In [8]:
# Summary counts: distinct features seen per dataset, then how many are shared.
print(len(top_neurons_neg))
print(len(top_neurons_pos))
# Iterating a dict yields its keys, so set(d) is the set of feature indices.
pos_set = set(top_neurons_pos)
neg_set = set(top_neurons_neg)
print(len(pos_set & neg_set))

1572
1003
126


In [9]:
def filter_neurons(top_neurons_neg, top_neurons_pos, threshold=5.0):
    """
    Filters out neurons that are highly activated in both the negative and positive sets.

    A neuron is "highly activated" in a set when at least one of its recorded
    activations in that set is >= ``threshold``. Only neurons that meet this
    condition in BOTH sets are removed, and they are removed from both results,
    so the two returned dicts are filtered by the same rule.

    Args:
        top_neurons_neg: Mapping of neuron index -> list of activations from the negative set.
        top_neurons_pos: Mapping of neuron index -> list of activations from the positive set.
        threshold: Minimum activation for a neuron to count as highly activated.

    Returns:
        Tuple ``(filtered_neurons_neg, filtered_neurons_pos)`` of plain dicts
        containing the surviving entries of each input, in their original order.
    """

    def is_high(activations):
        # True when any single activation reaches the threshold.
        return any(val >= threshold for val in activations)

    # Membership is tested before indexing so a defaultdict input is never
    # mutated by a lookup of a missing key.
    shared_high = {
        neuron
        for neuron, activations in top_neurons_neg.items()
        if neuron in top_neurons_pos
        and is_high(activations)
        and is_high(top_neurons_pos[neuron])
    }

    filtered_neurons_neg = {
        neuron: activations
        for neuron, activations in top_neurons_neg.items()
        if neuron not in shared_high
    }
    filtered_neurons_pos = {
        neuron: activations
        for neuron, activations in top_neurons_pos.items()
        if neuron not in shared_high
    }

    return filtered_neurons_neg, filtered_neurons_pos

# With threshold=0 every non-negative activation passes the `>= threshold` test,
# so features present in both dicts are dropped
# (assumes recorded activations are non-negative -- TODO confirm).
filtered_neg, filtered_pos = filter_neurons(
    top_neurons_neg,
    top_neurons_pos,
    threshold=0,
)
print(f"Len: {len(filtered_neg)}. Filtered negative neurons: {filtered_neg}")
print(f"Len: {len(filtered_pos)}. Filtered positive neurons: {filtered_pos}")

Len: 1446. Filtered negative neurons: {6670: [16.4168643951416, 2.801954984664917, 16.919387817382812, 5.190169334411621, 4.305880546569824, 2.3614017963409424], 9506: [15.13806438446045, 19.836523056030273, 4.30758810043335], 13665: [12.060338973999023, 13.8031005859375], 23106: [10.655967712402344, 6.715331077575684, 11.300751686096191], 23503: [6.506248474121094, 4.010358810424805, 4.525677680969238], 2312: [6.471706867218018, 2.4618632793426514, 19.484525680541992, 13.123251914978027, 9.765414237976074], 9481: [6.12713623046875, 3.1801211833953857], 22777: [4.396451473236084], 22930: [4.015675067901611, 2.495692014694214, 2.9327199459075928, 3.286095380783081, 2.484334945678711, 3.275991678237915, 2.3608810901641846, 2.4663054943084717, 3.0160131454467773, 2.763338327407837, 1.8625926971435547], 19045: [3.978228807449341, 2.195633888244629, 2.4856648445129395, 2.6643054485321045], 8023: [3.9625518321990967, 3.374295473098755, 3.1345767974853516, 2.184896945953369, 2.685012102127075

In [17]:
# Summarise each filtered feature by the number of activation values recorded for it.
# NOTE(review): despite the `_mean` suffix, the values stored below are
# len(activations) -- counts, not average activations. Confirm which one the
# later steps expect before relying on the names.
top_neurons_neg_mean = {idx: len(acts) for idx, acts in filtered_neg.items()}
top_neurons_pos_mean = {idx: len(acts) for idx, acts in filtered_pos.items()}

print(top_neurons_neg_mean)
print(top_neurons_pos_mean)

# Re-order both dicts by that count, largest first (ties keep their insertion order).
top_neurons_neg_mean = dict(
    sorted(top_neurons_neg_mean.items(), key=lambda pair: pair[1], reverse=True)
)
top_neurons_pos_mean = dict(
    sorted(top_neurons_pos_mean.items(), key=lambda pair: pair[1], reverse=True)
)

# Show the 100 highest-ranked entries of each.
print(list(top_neurons_neg_mean.items())[:100])
print(list(top_neurons_pos_mean.items())[:100])

{6670: 6, 9506: 3, 13665: 2, 23106: 3, 23503: 3, 2312: 5, 9481: 2, 22777: 1, 22930: 11, 19045: 4, 8023: 41, 19837: 2, 18248: 3, 3045: 6, 127: 4, 21720: 2, 1003: 2, 9307: 6, 133: 6, 24045: 1, 3879: 4, 21309: 4, 11946: 17, 2674: 1, 22721: 6, 5571: 1, 13678: 1, 21632: 1, 17282: 5, 3552: 1, 48: 2, 15068: 2, 5298: 8, 20505: 4, 22894: 5, 24444: 3, 21916: 18, 8976: 1, 1460: 10, 24391: 7, 17723: 15, 15782: 1, 6136: 21, 792: 1, 13364: 2, 11832: 2, 21714: 2, 14914: 1, 17671: 1, 18742: 1, 15423: 5, 6302: 5, 20960: 7, 6056: 23, 6636: 1, 17228: 4, 3087: 14, 21485: 6, 14179: 10, 11177: 1, 13368: 14, 6026: 1, 16218: 2, 10156: 1, 16000: 2, 5066: 1, 19967: 16, 513: 6, 17645: 2, 18349: 3, 20746: 2, 3816: 1, 17970: 1, 5368: 1, 15267: 1, 7552: 1, 20420: 1, 19986: 1, 4808: 5, 23486: 1, 21957: 1, 15828: 1, 10588: 1, 5039: 1, 303: 5, 8124: 1, 20483: 9, 12214: 3, 1681: 2, 12467: 1, 21328: 2, 4088: 1, 22967: 1, 16466: 4, 17960: 3, 7111: 2, 17899: 1, 3261: 23, 2908: 5, 16743: 3, 18470: 1, 15560: 1, 20938: 2, 12