## Code to evaluate SAEs when used in the model

In [1]:
# Use HookedTransformer hooks to replace the specified component's output with its SAE counterpart
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer, utils
import torch
import nnsight
device = 'cuda'
from tasks.ioi.IOITask import IOITask
from tasks.facts.SportsTask import SportsTask
from tasks.owt.OWTTask import OWTTask

import pandas as pd

In [8]:
# Load the model under evaluation. Larger checkpoints that can be swapped in:
#   'EleutherAI/pythia-1.4b-deduped'
#   'EleutherAI/pythia-2.8b-deduped'
model = HookedTransformer.from_pretrained('EleutherAI/pythia-70m-deduped', device=device)
model.set_use_hook_mlp_in(True)

tokenizer = model.tokenizer
batch_size = 500

# Keyword arguments shared by every task constructor below.
task_kwargs = dict(batch_size=batch_size, tokenizer=tokenizer, device=device)

ioi_task = IOITask(handle_multitoken_labels=True, num_data=1000, **task_kwargs)
sports_task = SportsTask(**task_kwargs)
owt_task = OWTTask(ctx_length=50, **task_kwargs)

Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [14]:
from dictionary_learning.dictionary import AutoEncoder

layer = 1
# NOTE(review): despite its name, this flag only switches the dictionary size and
# the checkpoint directory below; the hook position stays on "mlp_out" either way.
hidden_layer = False
submodule = model.blocks[layer].mlp  # MLP module of the chosen layer
# Name of the TransformerLens hook point whose output the SAE will replace.
hook_pos = utils.get_act_name("mlp_out", layer)
activation_dim = model.cfg.d_model  # width of the MLP output
dictionary_size = 16 * activation_dim * 4 if hidden_layer else 16 * activation_dim

model_type = "1_32768" if hidden_layer else "0_8192"
# map_location makes the load independent of the device the checkpoint was saved
# from. NOTE: torch.load unpickles the file, so only load checkpoints you trust.
sae_dict = torch.load(
    f"baulab.us/u/smarks/autoencoders/pythia-70m-deduped/mlp_out_layer{layer}/{model_type}/ae.pt",
    map_location=device,
)

sae = AutoEncoder(activation_dim, dictionary_size).to(device)
sae.load_state_dict(state_dict=sae_dict)

# Optional sinks that apply_sae_hook can append activations to.
pre_sae_acts = []
post_sae_acts = []

# sae = AutoEncoder(activation_dim, dictionary_size*4).to(device)
def apply_sae_hook(pattern, hook, sae, pre_sae_acts=None, post_sae_acts=None):
    """
    During inference time, run SAE on the output of the specified layer, and feed it back in.

    Args:
        pattern: activation tensor produced at the hooked position.
        hook: the hook point passed in by the hooking framework (unused here).
        sae: callable mapping an activation tensor to its reconstruction.
        pre_sae_acts: optional list; if given, a CPU copy of the input activation
            is appended to it.
        post_sae_acts: optional list; if given, a CPU copy of the SAE output is
            appended to it.

    Returns:
        The SAE output, which replaces the original activation in the forward pass.
    """
    if pre_sae_acts is not None:
        # detach() so the stored copy does not keep the autograd graph alive.
        pre_sae_acts.append(pattern.detach().clone().cpu())
    pattern = sae(pattern)
    if post_sae_acts is not None:
        post_sae_acts.append(pattern.detach().clone().cpu())
    return pattern


In [15]:
import torch
def display_memory(device=0):
    """Print allocated, reserved and total memory of a CUDA device, in GB.

    Args:
        device: CUDA device to report on (index or torch.device). Defaults to 0.

    Prints a short notice instead of raising when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available; no GPU memory to report")
        return
    total = torch.cuda.get_device_properties(device).total_memory
    r = torch.cuda.memory_reserved(device)
    a = torch.cuda.memory_allocated(device)
    # Byte counts are scaled by 1e-9, i.e. reported in decimal gigabytes.
    print(f"{a*1e-9} allocated, {r*1e-9} reserved, {total*1e-9} total")
display_memory()

0.5874611200000001 allocated, 81.45128652800001 reserved, 84.986691584 total


In [16]:
# Untrained SAE used as a random-initialisation baseline. The seed is fixed so the
# baseline numbers are reproducible; fork_rng restores the CPU RNG state afterwards
# so later cells are unaffected.
# NOTE(review): assumes AutoEncoder initialises its weights on the CPU before .to(device) — confirm.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(0)
    fresh_sae = AutoEncoder(activation_dim, dictionary_size).to(device)

In [17]:
# Sanity check: run one example through the model, list the available hook
# names, and show the shape of the activation the SAE will replace.
sample_tokens = tokenizer(next(ioi_task.train_iter)['text'], return_tensors='pt').input_ids[0]

# No gradients are needed for inspection, so avoid building an autograd graph.
with torch.no_grad():
    _, test_cache = model.run_with_cache(sample_tokens)

print(test_cache.keys())
# Index with hook_pos rather than a hard-coded name so this follows `layer`.
print(test_cache[hook_pos].shape)

dict_keys(['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_mlp_in', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_mlp_in', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post'

In [19]:
# I'm pretty sure that these tasks also take an inference function, not just a model. That makes it more convenient to use run_with_hooks
def sae_inference_fn(tokens, model=model, hook_name=hook_pos, sae=sae):
    """Run `model` on `tokens` with the activation at `hook_name` replaced by `sae`'s output.

    The defaults are bound when this cell is executed, so re-run the cell after
    changing `model`, `hook_pos` or `sae`.
    """
    def _splice_in_sae(pattern, hook):
        # Activations are not recorded here; pass pre_sae_acts / post_sae_acts
        # to apply_sae_hook to collect them.
        return apply_sae_hook(pattern, hook, sae)

    forward_hooks = [(hook_name, _splice_in_sae)]
    return model.run_with_hooks(tokens, fwd_hooks=forward_hooks)

def fresh_sae_inference_fn(tokens):
    """Same hooked forward pass as sae_inference_fn, but with the untrained SAE."""
    return sae_inference_fn(tokens, model=model, hook_name=hook_pos, sae=fresh_sae)


# One entry per table row, in the same order as the 'Model' labels below.
inference_fns = [model, sae_inference_fn, fresh_sae_inference_fn]

# (column name, metric) pairs; each metric takes a model or inference function.
# Columns are evaluated in this order, and rows in `inference_fns` order within
# each column.
metrics = [
    ('IOI Loss', lambda fn: ioi_task.get_test_loss(fn).item()),
    ('IOI Accuracy', lambda fn: ioi_task.get_test_accuracy(fn, check_all_logits=False)),
    ('Sports Loss', lambda fn: sports_task.get_test_loss(fn).item()),
    ('Sports Accuracy', lambda fn: sports_task.get_test_accuracy(fn, check_all_logits=False)),
    ('OWT Loss', lambda fn: owt_task.get_test_loss(fn).item()),
]

results = {
    'Model': ['default model', 'pretrained sae', 'random init 16x sae'],
    # 'Model': ['default model', 'random 64x sae', 'random 16x sae'],
}
assert len(results['Model']) == len(inference_fns), "one label per evaluated model"

for column, metric in metrics:
    results[column] = [metric(fn) for fn in inference_fns]

results_df = pd.DataFrame(results)
display(results_df)

# print(f"IOI Loss: {ioi_task.get_test_loss(sae_inference_fn)}")
# print(f"IOI Accuracy: {ioi_task.get_test_accuracy(sae_inference_fn, check_all_logits=False)}")
# print(f"Sports Loss: {sports_task.get_test_loss(sae_inference_fn)}")
# print(f"Sports Accuracy: {sports_task.get_test_accuracy(sae_inference_fn, check_all_logits=False)}")
# print(f"OWT Loss: {owt_task.get_test_loss(sae_inference_fn)}")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Unnamed: 0,Model,IOI Loss,IOI Accuracy,Sports Loss,Sports Accuracy,OWT Loss
0,default model,5.787771,0.18,4.635437,0.347134,4.227315
1,pretrained sae,6.319211,0.22,4.951598,0.363057,4.373736
2,random init 16x sae,6.25904,0.495,7.754953,0.340764,7.589589


## Prewritten SAE evaluation code

In [15]:
from dictionary_learning.evaluation import loss_recovered, evaluate
from nnsight import LanguageModel
from dictionary_learning.buffer import ActivationBuffer
from dictionary_learning.training import trainSAE
from datasets import load_dataset
import torch

# nnsight wrapper around the same checkpoint, placed on the first GPU.
# Any Hugging Face model name can be used here.
nn_model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0')


# Load the dataset
# train_dataset = load_dataset('wikitext', 'wikitext-103-v1', split='train[:1000000]')
# The split expression restricts the load to the first 100 training examples.
train_dataset = load_dataset(
    'Skylion007/openwebtext',
    split='train[:100]',
)
def yield_sentences(data_split):
    """Lazily yield each non-empty line from the 'text' field of every example.

    Args:
        data_split: iterable of mappings that each have a 'text' string.

    Yields:
        The newline-separated pieces of each text, in order, skipping empty ones.
    """
    for record in data_split:
        yield from (line for line in record['text'].split('\n') if line)

# Lazy generator over the non-empty lines of train_dataset. It is consumed as it
# is iterated, so re-run this cell to start again from the first line.
train_sentences = yield_sentences(train_dataset)


In [16]:
# Peek at the next training sentence (this advances the generator by one item).
next(train_sentences)

'Port-au-Prince, Haiti (CNN) -- Earthquake victims, writhing in pain and grasping at life, watched doctors and nurses walk away from a field hospital Friday night after a Belgian medical team evacuated the area, saying it was concerned about security.'