# Sparsity-faithfulness SAE and transcoder evaluations

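This notebook shows how to run the sparsity-faithfulness evaluations of SAEs and transcoders reported in Section 3.2.2 of our paper. For each SAE or transcoder, we measure sparsity as the mean L0 (the average number of active features per token) and faithfulness as the cross-entropy loss of Pythia-410M when the SAE or transcoder is spliced into the model.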

# Setup

Import the standard `transcoder_circuits` code.

In [1]:
from transcoder_circuits.circuit_analysis import *
from transcoder_circuits.feature_dashboards import *
from transcoder_circuits.replacement_ctx import *

Import the SAE/transcoder code and load the model that we'll be analyzing.

In [2]:
from sae_training.sparse_autoencoder import SparseAutoencoder
from transformer_lens import HookedTransformer, utils
import os
import torch

model = HookedTransformer.from_pretrained('pythia-410m')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-410m into HookedTransformer


Next, load the text corpus for the evaluation: 25,600 sequences of 128 tokens each, drawn from a shuffled stream of OpenWebText.

In [3]:
# This function was stolen from one of Neel Nanda's exploratory notebooks
# Thanks, Neel!
import einops
import numpy as np  # used by tokenize_function to build the BOS prefix
def tokenize_and_concatenate(
    dataset,
    tokenizer,
    streaming = False,
    max_length = 1024,
    column_name = "text",
    add_bos_token = True,
):
    """Helper function to tokenize and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (batch, sequence_length), dropping any leftover tokens that do not fill a complete sequence. Tokenizers are much faster if parallelised, so we chop the string into 20 chunks, feed them into the tokenizer in parallel with padding, then remove the padding at the end.

    This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
        streaming (bool, optional): Whether the dataset is being streamed. Unused in this version, which never passes num_proc to dataset.map. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
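        add_bos_token (bool, optional): Whether to prepend a BOS token to each sequence, in which case each row holds max_length - 1 text tokens. Defaults to True.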

    Returns:
        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"

    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why
    """
    for key in dataset.features:
        if key != column_name:
            dataset = dataset.remove_columns(key)

    if tokenizer.pad_token is None:
        # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    # Define the length to chop things up into - leaving space for a bos_token if required
    if add_bos_token:
        seq_len = max_length - 1
    else:
        seq_len = max_length

    def tokenize_function(examples):
        text = examples[column_name]
        # Concatenate it all into an enormous string, separated by eos_tokens
        full_text = tokenizer.eos_token.join(text)
        # Divide into 20 chunks of ~ equal length
        num_chunks = 20
        chunk_length = (len(full_text) - 1) // num_chunks + 1
        chunks = [
            full_text[i * chunk_length : (i + 1) * chunk_length]
            for i in range(num_chunks)
        ]
        # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
        tokens = tokenizer(chunks, return_tensors="np", padding=True)[
            "input_ids"
        ].flatten()
        # Drop padding tokens
        tokens = tokens[tokens != tokenizer.pad_token_id]
        num_tokens = len(tokens)
        num_batches = num_tokens // (seq_len)
        # Drop the final tokens if not enough to make a full sequence
        tokens = tokens[: seq_len * num_batches]
        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=[column_name],
    )
    #tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
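The reshaping inside `tokenize_function` is easier to see on a toy input. The sketch below skips the real tokenizer and the 20-way chunking: it takes documents that are already lists of token IDs, joins them with an EOS id, cuts the stream into rows of `max_length - 1` tokens, drops the leftover tail, and prepends a BOS id to each row. `pack_tokens` and the token IDs are hypothetical, chosen only for illustration.

```python
import numpy as np


def pack_tokens(token_lists, eos_id, bos_id, max_length):
    """Hypothetical helper mirroring the reshaping in tokenize_function.

    Joins pre-tokenized documents with eos_id, cuts the stream into rows of
    max_length - 1 tokens, drops the leftover tail, and prepends bos_id.
    """
    seq_len = max_length - 1  # reserve one position per row for the BOS token
    stream = []
    for i, doc in enumerate(token_lists):
        if i > 0:
            stream.append(eos_id)  # EOS separates consecutive documents
        stream.extend(doc)
    stream = np.array(stream)
    num_rows = len(stream) // seq_len
    # Keep only whole rows; the ragged tail is discarded
    rows = stream[: num_rows * seq_len].reshape(num_rows, seq_len)
    prefix = np.full((num_rows, 1), bos_id)
    return np.concatenate([prefix, rows], axis=1)


docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
packed = pack_tokens(docs, eos_id=0, bos_id=1, max_length=4)
print(packed.tolist())  # [[1, 5, 6, 7], [1, 0, 8, 9], [1, 0, 10, 11]]
```

With `max_length=4`, each row holds one BOS token plus three text tokens, and the trailing `[12, 13]` is discarded because it cannot fill a row. With `max_length=128`, as in the next cell, each row holds 127 text tokens.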


In [4]:
from datasets import load_dataset
import numpy as np

dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800*2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])

In [5]:
owt_tokens_torch = torch.from_numpy(owt_tokens).cuda()
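Both evaluation functions below slice `owt_tokens_torch` by batch index without a bounds check, so `num_batches * batch_size` must not exceed the number of rows loaded. A quick check of the numbers used in this notebook (values copied from the cells; the variable names are only for illustration):

```python
# Numbers copied from the cells in this notebook
max_length = 128                      # tokens per row, including the BOS token
num_rows = 12800 * 2                  # rows kept by tokenized_owt.take(...)
num_batches, batch_size = 200, 128    # arguments passed to every eval call

rows_needed = num_batches * batch_size
# Slicing past the end of a tensor silently returns fewer (or zero) rows, so check explicitly
assert rows_needed <= num_rows, "evaluation would run past the end of owt_tokens_torch"
print(f"rows used: {rows_needed} of {num_rows}")  # rows used: 25600 of 25600
print(f"token positions per evaluation: {rows_needed * max_length:,}")  # token positions per evaluation: 3,276,800
```

So each evaluation below uses exactly the 25,600 rows loaded above, with no row reused.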

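# SAE sweep evaluation

`eval_sae` splices an SAE into the model by replacing the activations at the SAE's hook point (here `blocks.15.hook_mlp_out`) with the SAE's reconstruction. It records the resulting cross-entropy loss and the mean L0 of the SAE's feature activations, averaged over `num_batches` batches.
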
In [7]:
import tqdm  # provides tqdm.tqdm for the progress bar

def eval_sae(model, owt_tokens_torch, sae, num_batches=100, batch_size=128):
    """Splice `sae` into `model` at sae.cfg.hook_point; return its mean L0 and mean CE loss."""
    l0s = []
    losses = []
    
    with torch.no_grad():
        for batch in tqdm.tqdm(range(0, num_batches)):
            cur_tokens = owt_tokens_torch[batch*batch_size:(batch+1)*batch_size]
            
            sae_acts = []
            def replacement_hook(acts, hook):
                # sae(acts) returns a tuple: [0] is the reconstruction, [1] the feature activations
                sae_out = sae(acts)
                activations = sae_out[0].to(acts.dtype)
                sae_acts.append(sae_out[1])
                # Returning a tensor replaces the original activations for the rest of the forward pass
                return activations
            
            loss = model.run_with_hooks(cur_tokens, return_type="loss", fwd_hooks=[(sae.cfg.hook_point, replacement_hook)])
            # L0: number of nonzero features at each token position, averaged over all positions in the batch
            binarized_acts = 1.0*(sae_acts[0] > 0)
            l0s.append(
                (binarized_acts.reshape(-1, binarized_acts.shape[-1])).sum(dim=1).mean().item()
            )
            losses.append(utils.to_numpy(loss))
    
    return {
        'l0': np.mean(l0s),
        'sae_loss': np.mean(losses)
    }
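`replacement_hook` relies on `sae(acts)` returning the reconstruction at index 0 and the feature activations at index 1, and the L0 is the per-position count of nonzero features averaged over every token in the batch. The NumPy sketch below reproduces both steps on a toy SAE. It assumes the common ReLU SAE formulation (subtract the decoder bias, encode with a ReLU, decode linearly); the weights are random, and `toy_sae` and `mean_l0` are hypothetical helpers, not the `sae_training` API, whose `SparseAutoencoder` may differ in details.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_sae = 8, 32  # toy sizes; the SAEs in this notebook have 32768 features

# Hypothetical ReLU SAE parameters, randomly initialised for illustration only
W_enc = rng.normal(size=(d_model, d_sae))
b_enc = rng.normal(size=d_sae) - 1.0  # shifted negative so some features stay inactive
W_dec = rng.normal(size=(d_sae, d_model))
b_dec = np.zeros(d_model)


def toy_sae(x):
    """Return (reconstruction, feature_acts), the two entries eval_sae reads from sae(acts)."""
    feature_acts = np.maximum((x - b_dec) @ W_enc + b_enc, 0.0)
    reconstruction = feature_acts @ W_dec + b_dec
    return reconstruction, feature_acts


def mean_l0(feature_acts):
    """Average number of nonzero features per token position, as computed in eval_sae."""
    flat = feature_acts.reshape(-1, feature_acts.shape[-1])  # (batch * seq, d_sae)
    return (flat > 0).sum(axis=1).mean()


acts = rng.normal(size=(2, 5, d_model))  # (batch, seq, d_model)
recon, feats = toy_sae(acts)
print(recon.shape, feats.shape, mean_l0(feats))

# Two token positions with 2 and 0 active features give a mean L0 of 1.0
print(mean_l0(np.array([[[1.0, 0.0, 2.0], [0.0, 0.0, 0.0]]])))  # 1.0
```

Because L0 is averaged over token positions rather than over sequences, it is directly comparable across SAEs and transcoders evaluated on the same batches.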

In [8]:
sae_template = "pythia-mlpout-saes/l1_4e-05/ajqvp8fc/final_sparse_autoencoder_pythia-410m_blocks.15.hook_mlp_out_32768"
sae = SparseAutoencoder.load_from_pretrained(f"{sae_template}.pt").eval()
print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:49<00:00,  1.15s/it]

{'l0': 505.2548745727539, 'sae_loss': 3.3368855}





In [7]:
sae_template = "pythia-mlpout-saes/l1_7e-05/wzktf3zm/final_sparse_autoencoder_pythia-410m_blocks.15.hook_mlp_out_32768"
sae = SparseAutoencoder.load_from_pretrained(f"{sae_template}.pt").eval()
print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:49<00:00,  1.15s/it]

{'l0': 109.87291198730469, 'sae_loss': 3.351243}





In [7]:
sae_template = "pythia-mlpout-saes/l1_8.5e-05/k761159s/final_sparse_autoencoder_pythia-410m_blocks.15.hook_mlp_out_32768"
sae = SparseAutoencoder.load_from_pretrained(f"{sae_template}.pt").eval()
print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:48<00:00,  1.14s/it]

{'l0': 55.06920654296875, 'sae_loss': 3.3596144}





In [7]:
sae_template = "pythia-mlpout-saes/l1_0.0001/b2ezwp1x/final_sparse_autoencoder_pythia-410m_blocks.15.hook_mlp_out_32768"
sae = SparseAutoencoder.load_from_pretrained(f"{sae_template}.pt").eval()
print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:46<00:00,  1.13s/it]

{'l0': 31.498437805175783, 'sae_loss': 3.367786}





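# Transcoder sweep evaluation

`eval_transcoder_l0_ce` performs the same measurement for transcoders. The model runs inside `TranscoderReplacementContext`, which splices the transcoder in for the MLP it was trained to imitate, and the transcoder's input (`blocks.15.ln2.hook_normalized`) is cached so that its feature activations can be recomputed for the L0 measurement.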

In [6]:
def eval_transcoder_l0_ce(model, all_tokens, transcoder, num_batches=100, batch_size=128):
    """Splice `transcoder` into `model`; return its mean L0 and mean CE loss."""
    l0s = []
    transcoder_losses = []
    
    with torch.no_grad():
        for batch in tqdm.tqdm(range(0, num_batches)):
            torch.cuda.empty_cache()
            cur_batch_tokens = all_tokens[batch*batch_size:(batch+1)*batch_size]
            with TranscoderReplacementContext(model, [transcoder]):
                # Cache only the transcoder's input hook (the normalized MLP input)
                cur_losses, cache = model.run_with_cache(cur_batch_tokens, return_type="loss", names_filter=[transcoder.cfg.hook_point])
                # measure losses
                transcoder_losses.append(utils.to_numpy(cur_losses))
                # measure l0s: re-encode the cached input; index 1 of the output holds the feature activations
                acts = cache[transcoder.cfg.hook_point]
                binarized_transcoder_acts = 1.0*(transcoder(acts)[1] > 0)
                l0s.append(
                    (binarized_transcoder_acts.reshape(-1, binarized_transcoder_acts.shape[-1])).sum(dim=1).mean().item()
                )

    return {
        'l0s': np.mean(l0s),
        'ce_loss': np.mean(transcoder_losses)
    }
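To make the transcoder setup concrete, here is a toy analogue in NumPy. We assume that `TranscoderReplacementContext` temporarily swaps each transcoder in for the MLP whose input hook it reads (here `blocks.15.ln2.hook_normalized`) and restores the original MLP on exit; the real implementation in `transcoder_circuits` may work differently, for example through hooks. `ToyBlock`, `true_mlp`, `make_toy_transcoder`, and `toy_replacement_context` are all hypothetical; the block omits attention, and its LayerNorm has no learned parameters.

```python
import contextlib

import numpy as np


class ToyBlock:
    """Stand-in for the MLP sublayer of one transformer block: resid + mlp(ln2(resid))."""

    def __init__(self, mlp):
        self.mlp = mlp

    @staticmethod
    def ln2(x):
        # LayerNorm without learned scale or bias (a simplification)
        x = x - x.mean(axis=-1, keepdims=True)
        return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + 1e-5)

    def __call__(self, resid):
        return resid + self.mlp(self.ln2(resid))


def true_mlp(x):
    # Placeholder for the model's real MLP
    return np.tanh(x)


def make_toy_transcoder(d_model, d_hidden, rng):
    """Hypothetical transcoder: ReLU features of the MLP input, decoded into the MLP output space."""
    W_enc = rng.normal(size=(d_model, d_hidden))
    W_dec = rng.normal(size=(d_hidden, d_model))

    def transcoder(x):
        return np.maximum(x @ W_enc, 0.0) @ W_dec

    return transcoder


@contextlib.contextmanager
def toy_replacement_context(block, transcoder):
    """Hypothetical analogue of TranscoderReplacementContext for a single block."""
    original_mlp = block.mlp
    block.mlp = transcoder
    try:
        yield block
    finally:
        # Restore the original MLP even if the forward pass raises
        block.mlp = original_mlp


rng = np.random.default_rng(0)
block = ToyBlock(true_mlp)
transcoder = make_toy_transcoder(d_model=4, d_hidden=16, rng=rng)
resid = rng.normal(size=(3, 4))

with toy_replacement_context(block, transcoder):
    spliced = block(resid)  # MLP output replaced by the transcoder's output
restored = block(resid)  # the original MLP is back outside the context
print(np.allclose(restored, resid + true_mlp(ToyBlock.ln2(resid))))  # True
```

Wrapping the swap in a context manager with `try`/`finally` guarantees that the model is restored even if a forward pass fails partway, which matters here because the same `model` object is reused for every transcoder in the sweep.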

In [7]:
transcoder_template = "./pythia-transcoders/lr_0.0002_l1_2.5e-05/pk60eijx/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768"
transcoder = SparseAutoencoder.load_from_pretrained(f"{transcoder_template}.pt").eval()
print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:26<00:00,  1.03s/it]

{'l0s': 203.77172332763672, 'ce_loss': 3.341213}





In [8]:
transcoder_template = "pythia-transcoders/lr_0.0002_l1_3e-05/67jdp0mv/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768"
transcoder = SparseAutoencoder.load_from_pretrained(f"{transcoder_template}.pt").eval()
print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:25<00:00,  1.03s/it]

{'l0s': 148.1538818359375, 'ce_loss': 3.3440711}





In [7]:
transcoder_template = "pythia-transcoders/lr_0.0002_l1_4e-05/pze62n3h/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768"
transcoder = SparseAutoencoder.load_from_pretrained(f"{transcoder_template}.pt").eval()
print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:26<00:00,  1.03s/it]

{'l0s': 82.748544921875, 'ce_loss': 3.3491273}





In [8]:
transcoder_template = "pythia-transcoders/lr_0.0002_l1_5.5e-05/szsvunrm/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768"
transcoder = SparseAutoencoder.load_from_pretrained(f"{transcoder_template}.pt").eval()
print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:26<00:00,  1.03s/it]

{'l0s': 44.042958984375, 'ce_loss': 3.3549356}





In [9]:
transcoder_template = "pythia-transcoders/lr_0.0002_l1_7e-05/v4gqmaoc/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768"
transcoder = SparseAutoencoder.load_from_pretrained(f"{transcoder_template}.pt").eval()
print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))

100%|██████████| 200/200 [03:26<00:00,  1.03s/it]

{'l0s': 27.454230651855468, 'ce_loss': 3.3682058}
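To compare the two sweeps, it helps to put the (L0, CE loss) pairs side by side. The sketch below copies the numbers printed above, rounded, into plain Python lists, prints each sweep sorted from sparsest to densest, and checks that within each sweep the CE loss never rises as L0 increases. The list names and helper functions are hypothetical, not part of `transcoder_circuits`.

```python
# (l1 coefficient, mean L0, CE loss), rounded from the printed outputs above
sae_sweep = [
    (4e-05, 505.25, 3.3369),
    (7e-05, 109.87, 3.3512),
    (8.5e-05, 55.07, 3.3596),
    (1e-04, 31.50, 3.3678),
]
transcoder_sweep = [
    (2.5e-05, 203.77, 3.3412),
    (3e-05, 148.15, 3.3441),
    (4e-05, 82.75, 3.3491),
    (5.5e-05, 44.04, 3.3549),
    (7e-05, 27.45, 3.3682),
]


def print_sweep(name, sweep):
    """Print one sweep as a table sorted from sparsest (lowest L0) to densest."""
    print(name)
    print(f"{'l1 coeff':>9} {'mean L0':>8} {'CE loss':>8}")
    for l1, l0, ce in sorted(sweep, key=lambda row: row[1]):
        print(f"{l1:>9.1e} {l0:>8.2f} {ce:>8.4f}")


def loss_falls_with_l0(sweep):
    """True if CE loss never increases as mean L0 increases across the sweep."""
    rows = sorted(sweep, key=lambda row: row[1])
    return all(later[2] <= earlier[2] for earlier, later in zip(rows, rows[1:]))


print_sweep("SAEs (blocks.15.hook_mlp_out)", sae_sweep)
print_sweep("Transcoders (blocks.15.ln2.hook_normalized)", transcoder_sweep)
print(loss_falls_with_l0(sae_sweep), loss_falls_with_l0(transcoder_sweep))  # True True
```

The same lists can be passed to a plotting library to draw L0 against CE loss for both families on shared axes.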



