The main goal is to produce feature dashboards. 

Along the way, other kinds of analysis may also be needed: for example, inspecting the norms of the decoder columns (since these SAEs are being trained with the recipe from Anthropic's April update).

In [1]:
import os

import torch
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Disable autograd globally for everything that runs in this notebook.
torch.set_grad_enabled(False)

# Set to avoid tokenizer-related warnings (presumably the Hugging Face
# `tokenizers` parallelism/fork warning -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model and activation (attention `hook_z` of one layer) the SAEs are attached to.
model_name = "gelu-2l"
hook_point_layer = 1
hook_point = f"blocks.{hook_point_layer}.attn.hook_z"

# Per-head SAE sizes: input width and the dictionary expansion factor.
d_in = 64
expansion_factor = 32

# Name used both as the last component of the checkpoint path and as the key
# into `sae.autoencoders` when the SAEs are loaded.
sae_name = f"{model_name}_{hook_point}_{d_in * expansion_factor}_"

In [2]:
# Checkpoint sub-folder for each attention head, keyed by head index
# (list position == head index).
ckpt_subfolders = dict(
    enumerate(
        [
            "rovi1lwe",
            "p7113j0v",
            "rjc53kjg",
            "hibm6x1l",
            "4xima76s",
            "jq26bfpa",
            "b8e2a9w5",
            "smfws6mc",
        ]
    )
)

def get_ckpt_dir(hook_point_head_index):
    """Return the checkpoint directory of the SAE trained on one attention head.

    The path has the form ``checkpoints/<sub-folder>/983044096/<sae_name>``,
    where the sub-folder is looked up in the module-level ``ckpt_subfolders``
    mapping and ``sae_name`` is the module-level SAE name.

    Args:
        hook_point_head_index: Head index; must be a key of ``ckpt_subfolders``.

    Returns:
        The (relative) checkpoint path as a string.

    Raises:
        KeyError: If ``hook_point_head_index`` has no entry in ``ckpt_subfolders``.
    """
    head_subfolder = ckpt_subfolders[hook_point_head_index]
    path_parts = ("checkpoints", head_subfolder, "983044096", sae_name)
    return os.path.join(*path_parts)

In [3]:
# Load one pretrained SAE per attention head, keyed by head index.
n_heads = 8
saes = {}
for hook_point_head_index in range(n_heads):
    print(f"Loading SAE for head # {hook_point_head_index}")
    ckpt_dir = get_ckpt_dir(hook_point_head_index)
    # Each call also returns the language model and an activation store. They
    # are rebound on every iteration, so the ones from the last head are what
    # `model` / `activations_loader` refer to in later cells.
    model, sae, activations_loader = LMSparseAutoencoderSessionloader.load_pretrained_sae(
        path=ckpt_dir,
        device=device,
    )
    # The loader hands back a group of autoencoders; keep the one stored
    # under this notebook's SAE name.
    saes[hook_point_head_index] = sae.autoencoders[sae_name]

Loading SAE for head # 0
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 1
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 2
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 3
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 4
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 5
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 6
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loading SAE for head # 7
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

In [4]:
# Concatenate the per-head SAEs into one SAE that decomposes the concatenated
# (all-heads) activations.
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from dataclasses import asdict

# Per-head SAEs in ascending head order. The same order is used for every
# parameter below so that weights and biases stay aligned with each other.
head_saes = [head_sae for _, head_sae in sorted(saes.items())]

# Start from head 0's config and widen it. The sizes are derived from the
# notebook settings (`d_in`, `expansion_factor`) and the number of loaded
# heads instead of being repeated as literals, so they cannot drift apart.
cat_sae_cfg_dict = asdict(saes[0].cfg)
cat_sae_cfg_dict["hook_point_head_index"] = None  # all heads, not a single one
cat_sae_cfg_dict["d_in"] = d_in * len(head_saes)
cat_sae_cfg_dict["d_sae"] = d_in * expansion_factor * len(head_saes)
cat_sae_cfg = LanguageModelSAERunnerConfig(**cat_sae_cfg_dict)

cat_sae = SparseAutoencoder(cat_sae_cfg)

# Weights are block-diagonal (each head's SAE only touches its own slice of the
# input and its own slice of the features); biases are plain concatenations.
# `.data` is used throughout so only raw tensors are combined.
cat_sae.W_dec = torch.nn.Parameter(torch.block_diag(*[head_sae.W_dec.data for head_sae in head_saes]))
cat_sae.W_enc = torch.nn.Parameter(torch.block_diag(*[head_sae.W_enc.data for head_sae in head_saes]))
cat_sae.b_dec = torch.nn.Parameter(torch.cat([head_sae.b_dec.data for head_sae in head_saes]))
cat_sae.b_enc = torch.nn.Parameter(torch.cat([head_sae.b_enc.data for head_sae in head_saes]))

In [5]:
# Load Connor and Rob's SAE and convert it to the sae-lens format.
from utils import CR_AutoEncoder
from sae_lens.toolkit.pretrained_saes import convert_connor_rob_sae_to_our_saelens_format

auto_encoder_run = "concat-z-gelu-21-l1-lr-sweep-3/gelu-2l_L1_Hcat_z_lr1.00e-03_l12.00e+00_ds16384_bs4096_dc1.00e-07_rie50000_nr4_v78"
cr_autoencoder = CR_AutoEncoder.load_from_hf(auto_encoder_run)

# The newer sae-lens state dict requires a `scaling_factor` entry, which
# Connor and Rob's SAE does not have; add an all-ones vector with one entry
# per dictionary feature.
cr_sae_state_dict = cr_autoencoder.state_dict()
cr_sae_state_dict["scaling_factor"] = torch.ones(cr_autoencoder.cfg["dict_size"])

cr_sae = convert_connor_rob_sae_to_our_saelens_format(
    state_dict=cr_sae_state_dict,
    config=cr_autoencoder.cfg,
)

In [38]:
# The activation store can hand out token batches; run the model on one small
# batch and cache its activations.
batch_tokens = activations_loader.get_batch_tokens(batch_size=8)
B, T = batch_tokens.shape
_, cache = model.run_with_cache(batch_tokens, prepend_bos=True)


def _flat_hook_acts(sae_module):
    """Return the cached activation at `sae_module`'s hook point, viewed as (B, T, -1)."""
    return cache[sae_module.cfg.hook_point].view(B, T, -1)


# Each SAE call returns
#   (sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss);
# only the feature activations are kept.
_, cat_feature_acts, _, _, _, _ = cat_sae(_flat_hook_acts(cat_sae))
_, cr_feature_acts, _, _, _, _ = cr_sae(_flat_hook_acts(cr_sae))

# Drop the reference to the activation cache now that the features are computed.
del cache

In [41]:
# For one feature of Connor and Rob's SAE, list the features of the
# concatenated per-head SAE whose activations correlate with it most strongly.
from utils import compute_corr_matrix

feature_id = 99  # index of the cr_sae feature being looked up
features_per_head = d_in * expansion_factor  # dictionary size of each per-head SAE

# Flatten (batch, position) so each row is one token position. The selected
# cr_sae feature becomes a single column and is compared against every
# cat_sae feature.
corr_matrix = compute_corr_matrix(
    cr_feature_acts[:, :, feature_id].view(-1)[:, None],
    cat_feature_acts.view(-1, cat_feature_acts.shape[-1]),
)
corr_matrix = corr_matrix.squeeze(dim=0)
# NaN correlations (presumably from features that are constant on this batch --
# TODO confirm) are pushed to -inf so that top-k never selects them.
corr_matrix[corr_matrix.isnan()] = float('-inf')

k = 20
values, indices = corr_matrix.topk(k=k)
# `.tolist()` yields plain Python numbers. The loop variable is deliberately
# not called `id`, which would overwrite the builtin for the rest of the session.
for corr_val, flat_idx in zip(values.tolist(), indices.tolist()):
    # Features are laid out head by head, so the flat index splits into
    # (head, feature-within-head).
    head_idx, head_feature_idx = divmod(flat_idx, features_per_head)
    print(f"corr val: {corr_val:.2f}, head # {head_idx}, feature # {head_feature_idx}")