In [1]:
from sae_lens import SAE
import torch
from datasets import load_dataset
from transformer_lens import HookedTransformer
import os
import plotly.express as px
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Language model that supplies the activations fed to the SAE in later cells.
model = HookedTransformer.from_pretrained("gpt2-small", device=DEVICE)

# Pretrained sparse autoencoder, selected by release name and hook id.
# NOTE(review): when this cell runs, the loader warns that this SAE has non-empty
# model_from_pretrained_kwargs and recommends loading the model with
# HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs).
# Confirm whether the default-processed model used here is acceptable.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",  # plain literal: nothing is interpolated
    device=DEVICE,
)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
from transformer_lens.utils import tokenize_and_concatenate

# Load the train split of the Pile-10k sample eagerly (not as a stream).
dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)

# Tokenize with the model's tokenizer. Sequence length and BOS handling are
# taken from the SAE's config metadata.
token_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=True,
    max_length=sae.cfg.metadata.context_size,
    add_bos_token=sae.cfg.metadata.prepend_bos,
)

In [4]:
# Switch the SAE to evaluation mode before measuring its activations.
sae.eval()

with torch.no_grad():
    # Take the first 32 rows of the tokenized dataset as the evaluation batch.
    batch_tokens = token_dataset[:32]["tokens"]

    # Only the SAE's own hook point is read below, so cache just that one
    # activation instead of every activation in the model.
    sae_hook_name = sae.cfg.metadata.hook_name
    _, cache = model.run_with_cache(
        batch_tokens,
        prepend_bos=True,
        names_filter=sae_hook_name,
    )

    # Encode the cached activations into SAE features, then reconstruct from them.
    feature_acts = sae.encode(cache[sae_hook_name])
    sae_out = sae.decode(feature_acts)

    # The cache is no longer needed; drop it to free memory.
    del cache

    # L0: number of features that are active (> 0) at each token position.
    # Position 0 is skipped to ignore the BOS token; the printed value is the
    # mean across batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(
        l0.flatten().cpu().numpy(),
        title="Active SAE features per token (L0, BOS excluded)",
    ).show()

average l0 61.42002868652344


In [5]:
# Decoder weights of the SAE: one row per dictionary atom.
decoder_matrix = sae.W_dec
print(decoder_matrix.shape[0], "dictionary atoms")
# The SAE was loaded for blocks.7.hook_resid_pre, so the second dimension is the
# residual-stream width, not a count of MLP neurons.
print(decoder_matrix.shape[1], "residual stream dimensions")

24576 dictionary atoms
768 mlp neurons


In [6]:
# Drop the names of the model and the SAE to save memory.
# decoder_matrix still refers to the SAE's decoder weights, so those stay alive.
del model, sae

In [7]:
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
import einops

# Cosine-similarity cutoff: only atom pairs strictly above this value keep an edge.
SIMILARITY_THRESHOLD = 0.5

# decoder_matrix was taken directly from sae.W_dec.
# NOTE(review): presumably a trainable parameter, in which case every operation
# below would be tracked by autograd — confirm. Under no_grad the large
# (atoms x atoms) intermediates are plain tensors with no graph attached.
with torch.no_grad():
    # Unit-normalize each atom (row) so dot products become cosine similarities.
    decoder_normalized = F.normalize(decoder_matrix, p=2, dim=1)

    # Dense pairwise cosine similarity between all dictionary atoms.
    similarity_matrix = decoder_normalized @ decoder_normalized.T

    # Zero out weak similarities. The diagonal (self-similarity of a non-zero
    # atom is 1) stays above the threshold, so self-loops remain in the graph.
    pruned_matrix = torch.where(
        similarity_matrix > SIMILARITY_THRESHOLD, similarity_matrix, 0.0
    )

    # Keep only the surviving entries, in COO sparse format.
    sparse_graph = pruned_matrix.to_sparse_coo()

