In [1]:
import torch
import torch.nn.functional as F
from utils import load_our_sae, load_concatenated_sae, import_connor_sae

In [2]:
# # load single head SAEs
# head_saes = {}
# for hook_point_head_index in [0, 1]:
#     model, sae, activations_loader = load_our_sae(hook_point_head_index=hook_point_head_index)
#     head_saes[hook_point_head_index] = sae

In [3]:
%%capture
model, our_concatenated_sae, activations_loader = load_concatenated_sae(l1_coeff=3)
connor_sae = import_connor_sae()

In [59]:
model, single_head_sae, activations_loader = load_our_sae(hook_point_head_index=6, l1_coeff=5)

Loading SAE checkpoint for l1_coeff = 5, head index = 6
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

## Getting feature activations

In [100]:
batch_tokens = activations_loader.get_batch_tokens(batch_size=2)
B, T = batch_tokens.shape
_, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

# why is the reconst loss > 5 instead of expected < 4?
out = analyze_sae(single_head_sae)
orig_loss, reconst_loss, zero_loss = out[-3], out[-2], out[-1]

average l0 0.09824047237634659


In [101]:
orig_loss, reconst_loss, zero_loss

(3.9270927906036377, 4.031135559082031, (4.045295715332031,))

In [9]:
# get model cache
import numpy as np
num_batches = 20
performance = {}
#for l1_coeff in [15, 3, 5, 10]:
for l1_coeff in [3]:
    for hook_point_head_index in range(8):
        model, single_head_sae, activations_loader = load_our_sae(hook_point_head_index=hook_point_head_index, 
                                                                  l1_coeff=l1_coeff)
        sparsities, orig_losses, zero_losses, reconst_losses = [], [], [], []
        for _ in range(num_batches):
            batch_tokens = activations_loader.get_batch_tokens(batch_size=2)
            B, T = batch_tokens.shape
            _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

            # why is the reconst loss > 5 instead of expected < 4?
            out = analyze_sae(single_head_sae)
            sparsity, orig_loss, reconst_loss, zero_loss = out[-4], out[-3], out[-2], out[-1]
            
            sparsities.append(sparsity)
            orig_losses.append(orig_loss)
            reconst_losses.append(reconst_loss)
            zero_losses.append(zero_loss)

        performance[(l1_coeff, hook_point_head_index)] = [np.mean(np.array(sparsities)),
                                                          np.mean(np.array(orig_losses)), 
                                                          np.mean(np.array(reconst_losses)), 
                                                          np.mean(np.array(zero_losses))]

Loading SAE checkpoint for l1_coeff = 3, head index = 0
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 3, head index = 1
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 3, head index = 2
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

In [None]:
good_saes = []
for (l1_coeff, head_id) in performance.keys():
    orig_, reconst_, zero_ = performance[(l1_coeff, head_id)] 
    print(head_id, (orig_-reconst_)/(orig_-zero_))
    # if reconst_ < zero_:
    #     good_saes.append((l1_coeff, head_id))

0 0.7557242050067783
1 0.6914234481020729
2 0.5121234344043092
3 0.8663394780454207
4 0.8402239927609593
5 1.2360755247713928
6 0.8630375144984015
7 0.8830042029864679


In [None]:
performance

{(3, 0): [3.140520143508911, 3.2711434423923493, 3.313365340232849],
 (3, 1): [3.140520143508911, 3.3660142064094543, 3.4666503369808197],
 (3, 2): [3.140520143508911, 3.251505768299103, 3.357236695289612],
 (3, 3): [3.140520143508911, 3.256797742843628, 3.274737274646759],
 (3, 4): [3.140520143508911, 3.315368390083313, 3.3486173272132875],
 (3, 5): [3.140520143508911, 3.2167920649051664, 3.2022250473499296],
 (3, 6): [3.140520143508911, 3.3195857763290406, 3.348003166913986],
 (3, 7): [3.140520143508911, 3.235272800922394, 3.247827285528183]}

In [5]:
# {(3, 0): [3.140520143508911, 3.2711434423923493, 3.313365340232849],
#  (3, 1): [3.140520143508911, 3.3660142064094543, 3.4666503369808197],
#  (3, 2): [3.140520143508911, 3.251505768299103, 3.357236695289612],
#  (3, 3): [3.140520143508911, 3.256797742843628, 3.274737274646759],
#  (3, 4): [3.140520143508911, 3.315368390083313, 3.3486173272132875],
#  (3, 5): [3.140520143508911, 3.2167920649051664, 3.2022250473499296],
#  (3, 6): [3.140520143508911, 3.3195857763290406, 3.348003166913986],
#  (3, 7): [3.140520143508911, 3.235272800922394, 3.247827285528183]}

0.7647058823529409

In [92]:
sorted(performance.keys(), key=lambda x: (x[1], x[0]))

[(3, 0),
 (5, 0),
 (10, 0),
 (15, 0),
 (3, 1),
 (5, 1),
 (10, 1),
 (15, 1),
 (3, 2),
 (5, 2),
 (10, 2),
 (15, 2),
 (3, 3),
 (5, 3),
 (10, 3),
 (15, 3),
 (3, 4),
 (5, 4),
 (10, 4),
 (15, 4),
 (3, 5),
 (5, 5),
 (10, 5),
 (15, 5),
 (3, 6),
 (5, 6),
 (10, 6),
 (15, 6),
 (3, 7),
 (5, 7),
 (10, 7),
 (15, 7)]

In [104]:
single_head_sae.W_dec.shape

torch.Size([2048, 64])

In [1]:
for (l1_coeff, head_id) in sorted(performance.keys(), key=lambda x: (x[1], x[0])):
    orig_, reconst_, zero_ = performance[(l1_coeff, head_id)] 
    reconst_score = (orig_ - reconst_)/(orig_ - zero_)
    print(f"head={head_id}, l1={l1_coeff}, reconst score={1-reconst_score:.4f}")

NameError: name 'performance' is not defined

In [None]:
# good thing is that there is a trend. 
# Increasing L1-coefficient decreases L0-norm but it increases reconstruction score.
# WTF is going on!

In [75]:
# It seems that for the most part L1-coefficient=5 is fine.
# Let's check how it does when concatenated
_, our_concatenated_sae, _ = load_concatenated_sae(l1_coeff=5)

Loading SAE checkpoint for l1_coeff = 5, head index = 0
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 1
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 2
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 3
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 4
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 5
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 6
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Loading SAE checkpoint for l1_coeff = 5, head index = 7
Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Concatenating all SAEs from individual attention heads


In [None]:
# reconstruction of the concatenated SAE is pretty bad actually. 
# don't know why!

In [76]:
analyze_sae(our_concatenated_sae)

average l0 7.051319599151611


(tensor([[[ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          ...,
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939]],
 
         [[ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3090,  0.8454, -1.2604,  ...,  0.8405,  0.1942,  0.0939],
          ...,
          [ 0.3228,  0.8550, -1.2684,  ...,  0.7868,  0.2176,  0.1205],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939],
          [ 0.3228,  0.8550, -1.2684,  ...,  0.8405,  0.1942,  0.0939]]]),
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0

In [95]:
(2.6 - 5.57)/(2.6-6.3)

0.8027027027027028

In [77]:
(5.5 - 2.6)/(6.3-2.6)

0.7837837837837838

connor_dashboards, dashboards, training_scripts/wandb

In [12]:
analyze_sae(connor_sae)

average l0 16.297653198242188
Orig 2.682035207748413
reconstr 3.1920928955078125
Zero 6.446569919586182


In [85]:
(2.68-3.19)/(2.68-6.44)

0.13563829787234036

In [None]:
sae(
        cache[hook_point]
    )

In [8]:
from transformer_lens import utils
from functools import partial

# next we want to do a reconstruction test.
def analyze_concatenated_sae(sae):
    assert sae.cfg.hook_point_head_index is None
    hook_point = sae.cfg.hook_point
    sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss = sae(
        cache[hook_point].view(B, T, -1)
    )

    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    # print("average l0", l0.mean().item())

    nh, dh = model.cfg.n_heads, model.cfg.d_head
    def reconstr_hook(activation, hook, sae_out):
        return sae_out.view(B, T, nh, dh)


    def zero_abl_hook(activation, hook):
        return torch.zeros_like(activation)
    
    orig_loss = model(batch_tokens, return_type="loss").item()
    #print("Orig", orig_loss)
    reconst_loss = model.run_with_hooks(
                    batch_tokens,
                    fwd_hooks=[
                        (utils.get_act_name("z", 1), 
                        partial(reconstr_hook, sae_out=sae_out),)
                        ],
                    return_type="loss",
                    ).item(),
    # print("reconstr",reconst_loss)
    zero_abl_loss = model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[(utils.get_act_name("z", 1), 
                        partial(zero_abl_hook))],
        ).item()
    # print("Zero", zero_abl_loss)
    return sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss, l0, orig_loss, reconst_loss, zero_abl_loss

def analyze_single_head_sae(sae):
    assert sae.cfg.hook_point_head_index is not None
    hook_point = sae.cfg.hook_point
    hook_point_head_index = sae.cfg.hook_point_head_index
    sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss = sae(
        cache[hook_point][:, :, hook_point_head_index, :]
    )

    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    # print("average l0", l0.mean().item())

    
    def reconstr_hook(activation, hook, sae_out, hook_point_head_index):
        activation[:, :, hook_point_head_index, :] = sae_out
        return activation

    def zero_abl_hook(activation, hook, hook_point_head_index):
        activation[:, :, hook_point_head_index, :] = torch.zeros(activation.shape[0], activation.shape[1], activation.shape[3])
        # return torch.zeros_like(activation)
        return activation
    
    # print("Orig", model(batch_tokens, return_type="loss").item())
    orig_loss = model(batch_tokens, return_type="loss").item()
    reconst_loss = model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[
                (utils.get_act_name("z", 1), # TODO: is this correct? how do we specify head number?
                partial(reconstr_hook, sae_out=sae_out, hook_point_head_index=hook_point_head_index),)
                ],
            return_type="loss",
            ).item()
    # print("reconstr", reconst_loss)

    zero_abl_loss = model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[(utils.get_act_name("z", 1), 
                        partial(zero_abl_hook, hook_point_head_index=hook_point_head_index))],
            ).item(),
    # print("Zero", zero_abl_loss)
    return sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss, l0, orig_loss, reconst_loss, zero_abl_loss

def analyze_sae(sae):
    if sae.cfg.hook_point_head_index is None:
        return analyze_concatenated_sae(sae)
    else:
        return analyze_single_head_sae(sae)

In [98]:
utils.get_act_name("z", 1)

'blocks.1.attn.hook_z'