In [1]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

# Inference only: nothing in this notebook needs autograd, so disable it globally.
torch.set_grad_enabled(False)

# Functions and classes are imported close to where they are first used,
# so it stays obvious which library each one comes from.

# Pick the compute device: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [2]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

# GPT-2 small, loaded through the SAE-aware transformer wrapper onto `device`.
model = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# Pretrained SAE from the "gpt2-small-res-jb" release. Besides the SAE itself,
# from_pretrained also returns:
#   cfg_dict - whatever config was stored in the HF repo. This is NOT the SAE's own
#              config dict, but the SAE config can be extracted from it. It is useful
#              for analysing the SAE, e.g. when instantiating an activation store.
#   sparsity - the per-feature sparsities stored on HF, returned for convenience.
# The sae_id selects an SAE within the release; it is not always a hook point.
SAE_RELEASE = "gpt2-small-res-jb"
SAE_ID = "blocks.7.hook_resid_pre"
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=SAE_RELEASE,
    sae_id=SAE_ID,
    device=device,
)



Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
import json
import os

def find_json_file(n, directory='.'):
    """Locate the ``<start>-<end>.json`` shard in *directory* whose range contains *n*.

    Shard files are expected to be named ``"<start>-<end>.json"`` with inclusive
    integer bounds. Files that do not follow that pattern are ignored. When several
    shards contain ``n``, the one with the largest ``end`` wins. Among shards with
    the same ``end``, a later file (in sorted order) replaces an earlier one only
    when ``n == end``.

    Args:
        n: Index to look up (in this notebook, an SAE feature index).
        directory: Directory to scan for shard files.

    Returns:
        A ``(filename, offset)`` tuple. ``filename`` is the bare shard name (not
        joined with *directory*), and ``offset = n - start`` is the position of
        ``n`` relative to the shard's first index.

    Raises:
        FileNotFoundError: if no shard in *directory* covers ``n``.
    """
    suffix = '.json'
    best_file = None
    best_start = None
    best_end = None

    # Sort so the choice among tied shards does not depend on os.listdir order.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(suffix):
            continue
        # Slice off the exact suffix. str.rstrip('.json') would strip any trailing
        # '.', 'j', 's', 'o', 'n' characters, not the '.json' suffix.
        stem = filename[:-len(suffix)]
        try:
            start, end = map(int, stem.split('-'))
        except ValueError:
            # Not a "<start>-<end>.json" shard (e.g. a config/notes file); skip it.
            continue

        if start <= n <= end:
            if best_file is None or end > best_end or (end == best_end and n == end):
                best_file, best_start, best_end = filename, start, end

    if best_file is None:
        raise FileNotFoundError(
            f"No '<start>-<end>.json' file in {directory!r} covers index {n}"
        )

    # Reuse the parsed start bound instead of re-parsing the filename.
    return best_file, n - best_start

def load_json_from_file(filename):
    """Parse and return the JSON document stored in *filename*.

    The file is decoded as UTF-8 explicitly. JSON text is UTF-8 by specification,
    while ``open()``'s default text encoding depends on the platform locale.
    Without this, non-ASCII strings can fail to decode on some systems.

    Args:
        filename: Path to the JSON file.

    Returns:
        The deserialized JSON value (whatever ``json.load`` produces).
    """
    with open(filename, 'r', encoding='utf-8') as file:
        return json.load(file)

def load_json_from_feature(feature, directory='.'):
    """Return the JSON record for *feature* from its shard file in *directory*.

    ``find_json_file`` picks the shard and the feature's offset within it. The
    shard is then parsed in full and indexed with that offset.

    NOTE: the whole shard is re-read and re-parsed on every call, which is costly
    when this is called once per feature in a loop.
    """
    shard_name, offset = find_json_file(feature, directory)
    shard_path = os.path.join(directory, shard_name)
    # NOTE(review): assumes each shard holds a sequence with one entry per feature,
    # ordered from the shard's lower bound — confirm against the data export.
    records = load_json_from_file(shard_path)
    return records[offset]


In [4]:
# Number of SAE features. NOTE(review): must match the loaded SAE's width
# (the spot-check cell below re-derives it from the encoded activations).
total_features = 24576

# Square matrix with one row per feature. Row i is later filled with the SAE
# activations produced by feature i's top token. A float32 tensor of this size
# is ~2.4 GB (24576 * 24576 * 4 bytes).
neg_feature_mat = torch.zeros((total_features, total_features))

In [6]:
# Spot-check one feature (index 14) before running the full sweep.
i = 14
directory = "./data/"
jsondata = load_json_from_feature(i, directory)

# First entry of the feature's 'pos_str' list, reduced to its first token only.
word = jsondata['pos_str'][0]
all_ids = model.to_tokens(word, prepend_bos=False)
input_id = all_ids[:, [0]]

# Cache just the SAE's hook activation and stop the forward pass early
# (stop_at_layer = hook_layer + 1); nothing past that point is needed.
hook_name = sae.cfg.hook_name
_, cache = model.run_with_cache(
    input_id,
    names_filter=[hook_name],
    stop_at_layer=sae.cfg.hook_layer + 1,
)
sae_in = cache[hook_name]
feature_acts = sae.encode(sae_in).squeeze()

# SAE width, read off the encoded activations.
total_features = feature_acts.shape[0]

word, input_id, sae_in.shape, feature_acts.shape

(' artist',
 tensor([[6802]], device='cuda:0'),
 torch.Size([1, 1, 768]),
 torch.Size([24576]))

In [7]:
# Row shape of the matrix being filled should match one feature-activation vector.
# `feature_mat` is never defined in this notebook; the matrix allocated above is
# `neg_feature_mat`, so reference that to work on a fresh kernel.
neg_feature_mat[14, :].shape, feature_acts.shape

(torch.Size([24576]), torch.Size([24576]))

In [5]:
# For every feature i, take the first token of its first 'pos_str' entry, run the
# model up to the SAE hook, encode with the SAE, and store the activations in row i.
# NOTE(review): the matrix is named "neg" but is filled from 'pos_str' — confirm intent.
# NOTE: load_json_from_feature re-reads a whole shard per feature; caching parsed
# shards would likely speed this sweep up considerably.

# Loop-invariant configuration, hoisted out of the loop.
directory = "./data/"
hook_name = sae.cfg.hook_name
stop_at_layer = sae.cfg.hook_layer + 1

pbar = tqdm(range(total_features))
for i in pbar:
    jsondata = load_json_from_feature(i, directory)

    word = jsondata['pos_str'][0]
    input_id = model.to_tokens(word, prepend_bos=False)[:, [0]]
    _, cache = model.run_with_cache(
        input_id,
        stop_at_layer=stop_at_layer,
        names_filter=[hook_name],
    )
    sae_in = cache[hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    assert feature_acts.shape == (total_features,), (
        f"feature {i}: expected {total_features} activations, "
        f"got shape {tuple(feature_acts.shape)}"
    )

    # neg_feature_mat was allocated on the CPU (torch.zeros with no device), while
    # feature_acts may live on `device`; the indexed assignment copies across devices.
    neg_feature_mat[i, :] = feature_acts
    pbar.set_description(f"Feature {i}")

Feature 606:   2%|▏         | 607/24576 [01:49<1:12:08,  5.54it/s]


KeyboardInterrupt: 

In [9]:
import pickle

# Save the matrix filled by the sweep cell. `feature_mat` is not defined anywhere
# in this notebook (it only worked via leftover kernel state); the sweep fills
# `neg_feature_mat`.
# NOTE(review): output filename kept as 'feature_mat.pkl' — consider renaming it
# (e.g. 'neg_feature_mat.pkl') so it is not confused with other saved matrices.
with open('feature_mat.pkl', 'wb') as f:
    pickle.dump(neg_feature_mat, f)


In [6]:
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

# Load the dense positive-feature matrix saved by an earlier sweep.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
# NOTE(review): presumably a torch tensor (it is later converted via .numpy()) — confirm.
with open('pos_feature_mat.pkl', 'rb') as f:
    pos_feature_mat = pickle.load(f)

# Convert to a sparse tensor *before* opening the output file. Otherwise a failed
# conversion would leave behind a truncated 'pos_feature_mat_sparse.pkl'.
pos_feature_mat_sparse = pos_feature_mat.to_sparse()
with open('pos_feature_mat_sparse.pkl', 'wb') as f:
    pickle.dump(pos_feature_mat_sparse, f)

# Optional: inspect a slice of the matrix as a heatmap.
# sns.heatmap(pos_feature_mat[2000:3000, 2000:3000])

  return torch.load(io.BytesIO(b))


In [2]:
# Display the feature matrix as a chord diagram
import numpy as np
from chord import Chord

# Turn the tensor into plain nested Python lists (via a numpy view first),
# presumably so the chord library can JSON-serialise it.
# NOTE: despite its `_np` suffix, pos_feature_mat_np holds nested lists, not a numpy array.
# NOTE(review): if this matrix is 24576 x 24576 like neg_feature_mat, the list form
# holds ~600M Python floats and needs far more memory than the tensor.
pos_feature_mat_array = pos_feature_mat.numpy()
pos_feature_mat_np = pos_feature_mat_array.tolist()


In [21]:
# Render the positive-feature matrix as a chord diagram, with one integer-labelled
# node per row. The labels are derived from the matrix itself, so they always match
# its size instead of relying on a hard-coded 24576.
# NOTE(review): a previous run of this cell raised
#   "TypeError: Object of type ndarray is not JSON serializable",
# presumably because the matrix was still a numpy array at the time. Re-run the
# .tolist() conversion cell first.
# NOTE(review): a diagram with ~24.5k nodes is likely impractical to render; consider
# plotting a sub-block (e.g. rows/cols 2000:3000) instead.
feature_labels = list(range(len(pos_feature_mat_np)))
Chord(pos_feature_mat_np, feature_labels).to_html()

TypeError: Object of type ndarray is not JSON serializable