The main goal is to produce feature dashboards. 

Along the way, other kinds of analysis may be needed — for example, looking at the norms of the decoder vectors (rows of `W_dec` here), since these SAEs are trained with Anthropic's April update recipe, which does not constrain decoder norms to 1.

In [39]:
import os
from functools import partial

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from tqdm import tqdm
from transformer_lens import utils

device = "cuda" if torch.cuda.is_available() else "cpu"


# Uncomment to cap this process at 20% of GPU memory (~8 GB on an A100), leaving room
# for SAE training jobs (~24 GB) on the same GPU. Not needed otherwise.
# if device == "cuda":
#     torch.cuda.set_per_process_memory_fraction(0.2, 0)

# Everything below is inference/analysis, so autograd is switched off for the session.
torch.set_grad_enabled(False)

# Presumably silences the HuggingFace `tokenizers` warning about parallelism being
# used before the process forks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# One SAE per attention head of layer 1, keyed by head index; values are the W&B run IDs
# (also the checkpoint subfolder names). The trailing comment is the expected LM loss
# with that head's SAE spliced in.
# W&B report: https://wandb.ai/shehper/gelu-2l-attn-1-sae/reports/gelu-2l-layer-1-attn-heads--Vmlldzo4MDA1NzE4/edit
ckpt_subfolders = dict(enumerate([
    "rovi1lwe",  # head 0: 3.785
    "p7113j0v",  # head 1: 3.807
    "rjc53kjg",  # head 2: 3.768
    "hibm6x1l",  # head 3: 3.738
    "4xima76s",  # head 4: 3.746
    "jq26bfpa",  # head 5: 3.729
    "b8e2a9w5",  # head 6: 3.75
    "smfws6mc",  # head 7: 3.748
]))

model_name = "gelu-2l"
hook_point_layer = 1
hook_point = "blocks.{}.attn.hook_z".format(hook_point_layer)

d_in = 64  # size of one head's z vector
expansion_factor = 32
# e.g. "gelu-2l_blocks.1.attn.hook_z_2048_"; used as the key into saes.autoencoders
sae_name = "_".join([model_name, hook_point, str(d_in * expansion_factor), ""])

In [2]:
# TODO: perhaps give load_pretrained_sae a list of paths so that it can load several SAEs at once.

### Code to replace a single head with an SAE

In [4]:
# Index of the attention head (in layer `hook_point_layer`) whose SAE we load.
hook_point_head_index = 0
# Checkpoints live at checkpoints/<run id>/<checkpoint>/<sae name>.
# TODO: pick the latest checkpoint subdirectory automatically instead of hard-coding it.
ckpt_dir = os.path.join(
    "checkpoints",
    ckpt_subfolders[hook_point_head_index],
    "983044096",
    sae_name,
)

model, saes, activations_loader = LMSparseAutoencoderSessionloader.load_pretrained_sae(
    path=ckpt_dir,
    device=device,
)

# saes.autoencoders is keyed by SAE name, e.g. 'gelu-2l_blocks.1.attn.hook_z_2048_'.
sparse_autoencoder = saes.autoencoders[sae_name]
# eval() mode; per the original note this avoids an error from the forward pass
# expecting a dead-neuron mask for ghost grads.
sparse_autoencoder.eval()

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

SparseAutoencoder(
  (activation_fn): ReLU()
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

In [5]:
# The activation store's default prompt batch (store_batch_size_prompts, 16 here) is
# too large for a laptop CPU, so use a single prompt unless a GPU is available.
batch_size = activations_loader.store_batch_size_prompts if device == "cuda" else 1

with torch.no_grad():
    # The activation store also hands out batches of tokens.
    batch_tokens = activations_loader.get_batch_tokens(batch_size=batch_size)
    # `cache` is kept: later cells reuse it for the W_O analysis.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # hook_z for the whole layer: (B, T, n_heads, d_head). The SAE was trained on one
    # head's z, so run it on that head's slice only.
    z_all_heads = cache[sparse_autoencoder.cfg.hook_point]
    # The forward pass returns (sae_out, feature_acts, loss, mse_loss, l1_loss,
    # ghost_grad_loss); only the first two are needed here.
    sae_forward = sparse_autoencoder(z_all_heads[:, :, hook_point_head_index, :])
    feature_acts = sae_forward[1]  # (B, T, d_sae)

    # The reconstruction hook expects a full (B, T, n_heads, d_head) tensor and reads
    # only this head's slice, so embed the reconstruction into a copy of hook_z.
    sae_out = z_all_heads.clone()
    sae_out[:, :, hook_point_head_index, :] = sae_forward[0]

    # L0: number of active features per token, ignoring the BOS position,
    # averaged over batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1)
    print("average l0", l0.mean().item())

In [5]:
# Reconstruction test: splice the SAE output back into the model in place of one head's z.
def reconstr_hook(activation, hook, sae_out, hook_point_head_index):
    """Overwrite one head's slice of ``activation`` with the same slice of ``sae_out``.

    Dimension 2 of both tensors indexes heads. ``activation`` is modified in place and
    returned; ``hook`` (the TransformerLens hook point) is not used.
    """
    head = hook_point_head_index
    reconstruction = sae_out[:, :, head, :]
    activation[:, :, head, :] = reconstruction
    return activation


def zero_abl_hook(activation, hook, hook_point_head_index):
    """Zero-ablate one head's slice of an attention ``hook_z`` activation, in place.

    Args:
        activation: tensor whose dimension 2 indexes heads, e.g. (B, T, n_heads, d_head).
        hook: the TransformerLens hook point (unused).
        hook_point_head_index: head whose slice is set to zero.

    Returns:
        The modified ``activation``.
    """
    # Assigning a scalar keeps the activation's device and dtype; the previous
    # torch.zeros(...) allocated a fresh float32 CPU tensor on every call.
    activation[:, :, hook_point_head_index, :] = 0.0
    return activation


# Compare the LM loss on `batch_tokens` three ways:
# - unmodified;
# - with head `hook_point_head_index`'s z replaced by the SAE reconstruction;
# - with that head's z zero-ablated.
# The hooks attach to the whole layer's hook_z, shape (B, T, n_heads, d_head). The head
# is selected inside the hook functions via `hook_point_head_index`.
print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                hook_point,  # same as utils.get_act_name("z", hook_point_layer)
                partial(reconstr_hook, sae_out=sae_out, hook_point_head_index=hook_point_head_index),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(hook_point, partial(zero_abl_hook, hook_point_head_index=hook_point_head_index))],
    ).item(),
)

Orig 3.699457883834839
reconstr 3.7783544063568115
Zero 3.8725802898406982


In [6]:
# Rows of W_dec are the dictionary vectors: W_dec has shape (2048, 64).
# These SAEs are not constrained to unit-norm decoder vectors, so look at how long they are.
dictionary_norms = (
    torch.linalg.vector_norm(sparse_autoencoder.W_dec, dim=1)
    .cpu()
    .numpy()
)
print(f"average norm: {dictionary_norms.mean():.4f}")

fig = px.histogram(dictionary_norms, nbins=50, title='Histogram of Dictionary Vector Lengths')
fig.update_layout(xaxis_title='Length', yaxis_title='Frequency', bargap=0.2)
fig.show()


average norm: 2.0519


In [7]:
# Unit-normalise each dictionary vector using a (d_sae, 1) torch norm on W_dec's device;
# dividing by the NumPy `dictionary_norms` only works while the SAE is on the CPU.
(sparse_autoencoder.W_dec / torch.linalg.vector_norm(sparse_autoencoder.W_dec, dim=1, keepdim=True)).shape

torch.Size([2048, 64])

In [8]:
# TODO: does SAELens have code that normalizes weights, biases and activations as in April update?

In [9]:
# Divide each dictionary vector by its length; as a sanity check, every resulting norm
# should be 1. keepdim=True gives a (d_sae, 1) torch tensor on W_dec's device, so this also
# works on the GPU (dividing by the NumPy `dictionary_norms` does not).
normalized_dictionary = sparse_autoencoder.W_dec / torch.linalg.vector_norm(
    sparse_autoencoder.W_dec, dim=1, keepdim=True
)
normalized_dictionary_norms = torch.linalg.vector_norm(normalized_dictionary, dim=1).cpu().numpy()
print(f"average norm: {normalized_dictionary_norms.mean():.4f}")

fig = px.histogram(normalized_dictionary_norms, nbins=50, title='Histogram of Normalized Dictionary Vector Lengths')
fig.update_layout(xaxis_title='Length', yaxis_title='Frequency', bargap=0.2)
fig.show()


average norm: 1.0000


### Let's play around with TransformerLens cache

In [6]:
## Goal: reproduce attn_out = z @ W_O + b_O (TransformerLens convention) from the cache.

# To use a fresh batch instead of the cache computed above, uncomment:
# batch_tokens = activations_loader.get_batch_tokens(batch_size=batch_size)
# _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
hook_z = cache[hook_point]  # (B, T, nh, dh)
B, T, nh, dh = hook_z.shape
hook_attn_out = cache[f"blocks.{hook_point_layer}.hook_attn_out"]  # (B, T, C)

In [7]:
# model.W_O.shape = (n_layers, nh, dh, n_embd)
# model.b_O.shape = (n_layers, n_embd)
# (The "layer1" names reflect hook_point_layer == 1 for this analysis.)
layer1_WO = model.W_O[hook_point_layer]  # (nh, dh, n_embd)
layer1_bO = model.b_O[hook_point_layer]  # (n_embd,)
manual_attn_out = torch.einsum("BTNH,NHC->BTC", hook_z, layer1_WO) + layer1_bO

# Loop form of the einsum above:
# total = torch.zeros(B, T, layer1_WO.shape[-1])  # (B, T, n_embd)
# for i in range(nh):
#     total += hook_z[:, :, i, :] @ layer1_WO[i]
# total += layer1_bO
torch.allclose(manual_attn_out, hook_attn_out, atol=1e-5)

True

In [8]:
# Same check in the "concatenated" picture: stack the heads' z vectors into one
# (nh*dh)-vector per position, and W_O into a single (nh*dh, n_embd) matrix.
cat_hook_z = hook_z.reshape(*hook_z.shape[:-2], -1)  # (B, T, nh*dh)
cat_WO = layer1_WO.reshape(-1, layer1_WO.shape[-1])  # (nh*dh, n_embd)
cat_attn_out = torch.einsum("BTC,CD->BTD", cat_hook_z, cat_WO) + layer1_bO  # (B, T, n_embd)
torch.allclose(cat_attn_out, hook_attn_out, atol=1e-5)

True

In [24]:
# Per-head contribution to attn_out, before summing over heads and adding b_O: (B, T, nh, C).
individual_head_out = torch.einsum("BTNH,NHC->BTNC", hook_z, layer1_WO)
# Summing over heads and adding the bias must recover the cached attention output.
assert torch.allclose(individual_head_out.sum(dim=2) + layer1_bO, hook_attn_out, atol=1e-5)
individual_head_out.shape

torch.Size([1, 1024, 8, 512])

In [33]:
# Pick two heads and look at their outputs on the first sequence of the batch.
# TODO: generalize over the batch dimension.
head1, head2 = 0, 1
head1_out, head2_out = (individual_head_out[0, :, h, :] for h in (head1, head2))  # each (T, C)

In [35]:
# Per-position cosine similarity between the two heads' outputs, shape (T,).
head1_out_norm = torch.linalg.vector_norm(head1_out, dim=-1)  # (T,)
head2_out_norm = torch.linalg.vector_norm(head2_out, dim=-1)  # (T,)
# Clamp the denominator so a position where a head outputs a zero vector gives 0, not NaN.
dot_product = torch.einsum("TC,TC->T", head1_out, head2_out) / (
    head1_out_norm * head2_out_norm
).clamp_min(1e-8)
# Should match F.cosine_similarity(head1_out, head2_out, dim=-1) up to its eps handling.

In [50]:
# Per-position cosine similarity (line plot) next to its distribution (box plot).
# .cpu() first so this also works when the model runs on the GPU.
y = dot_product.cpu().numpy()

# One row, two columns: line plot on the left, box plot on the right.
fig = make_subplots(rows=1, cols=2, shared_xaxes=False,
                    subplot_titles=('Line Plot', 'Box Plot'))

fig.add_trace(go.Scatter(y=y, mode='lines+markers', marker=dict(size=3), line=dict(width=1), name='Line Plot'), row=1, col=1)
fig.add_trace(go.Box(y=y, name='Box Plot'), row=1, col=2)

fig.update_layout(
    title=f'Cosine similarity of head {head1} and head {head2} outputs ({y.shape[0]} positions)',
    xaxis_title='Position',
    yaxis_title='Cosine similarity',
    template='plotly_white'
)
fig.show()

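We note that the cosine similarities range from -0.05 to 0.38, with a median of 0.21. Since the outputs of different heads are not orthogonal, the images of their feature directions under W_O cannot all be orthogonal either, i.e. feature directions from different heads must interfere.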

Let's think about what this means. If different heads wrote to mutually *orthogonal subspaces* of the residual stream, their outputs would always be orthogonal. The non-zero cosine similarities above show that this is definitely *not* the case.

Interference should presumably affect model performance, and we should be able to measure that somehow — perhaps by training models of two different sizes and comparing them.

Another question is whether we can interpret *which* features interfere with each other.

In [None]:
# TODO: It would be nice to see the rank of a single head's n_features dictionary vectors after the action of W_O.
# Before W_O, the rank is d_head. After W_O, is it n_embd? (No: it can be at most d_head, since W_O restricted
# to one head is a (d_head, n_embd) matrix.)

In [None]:
# There is always interference between the outputs. That is, there is not a single case where the lift to the
# n_embd-dimensional space is not 

In [10]:
# Lift head 0's z into the concatenated (nh*dh) space: head 0's slot holds its z, and all
# other slots are zero. new_zeros keeps the padding on hook_z's device and dtype, so
# torch.cat also works on the GPU.
head_0_hook_z_lift = torch.cat([hook_z[:, :, 0, :], hook_z.new_zeros(B, T, (nh - 1) * dh)], dim=-1)
# head_0_hook_z_lift lies in a dh-dimensional subspace of the concatenated space.
# Its image head_0_hook_z_lift @ cat_WO lies in the (at most dh-dimensional) row space
# of W_O[0] inside the n_embd-dimensional residual stream. That subspace is in general
# not orthogonal to the other heads' output subspaces:
# print((head_0_hook_z_lift @ cat_WO)[0, 0, :])
# Can we interpret the projection of one head's output onto another head's output subspace?
# This might be easier with feature vectors, e.g. by decomposing one head's feature vector
# as a linear combination of the feature vectors of all heads.

In [11]:
# Do a head's dictionary vectors span that head's whole dh-dimensional z space?
# Equivalently: does W_dec (d_sae x dh) have rank dh? If every head's dictionary is
# full rank, Kissane et al.'s concatenated feature vectors can be decomposed into
# per-head feature directions. TODO: is such a decomposition interpretable?
#
# Example of torch.linalg.matrix_rank on a 4x3 matrix of rank 3:
# A = torch.tensor([[1, 1, 1],
#                   [3, 2, 1],
#                   [1, 1, 0],
#                   [1, 0, 0]], dtype=torch.float32)
# print(torch.linalg.matrix_rank(A))
print(int(torch.linalg.matrix_rank(sparse_autoencoder.W_dec)))

64


In [107]:
# Take feature 0 of the loaded SAE and lift it to the concatenated (nh*dh) z space.
# It goes in the slot of head `hook_point_head_index` (zeros elsewhere; allocated on
# W_dec's device), and then W_O is applied.
lifted_feat_vec = sparse_autoencoder.W_dec.new_zeros(nh * dh)
lifted_feat_vec[hook_point_head_index * dh : (hook_point_head_index + 1) * dh] = sparse_autoencoder.W_dec[0]
lifted_feat_out = lifted_feat_vec @ cat_WO  # (n_embd,), a residual-stream vector

# Entries dh:2*dh of the W_O output (originally read as "the projection onto head 1's space").
# NOTE(review): after W_O the vector lives in the residual stream, whose coordinates are not
# organised by head; this slice only has the right size because n_embd == nh*dh here.
# The head-1 component would instead be its projection onto the row space of layer1_WO[1]
# (shape (dh, n_embd)), with coordinates lifted_feat_out @ torch.linalg.pinv(layer1_WO[1]).
lifted_feat_out[dh : 2 * dh]

# How hard is it to write this as a linear combination of head 1's feature directions?
# In general the combination is not unique; perhaps we need a sparse decomposition.

tensor([-0.0446, -0.0459, -0.1217, -0.1061, -0.0331,  0.1629,  0.1702, -0.1390,
         0.1699, -0.0371, -0.0685,  0.0325,  0.1299,  0.1173,  0.1376, -0.0299,
         0.0787,  0.3107, -0.0155,  0.0416,  0.0492,  0.0835, -0.0504, -0.1963,
        -0.2972,  0.0910,  0.0516,  0.1121, -0.0101, -0.1786, -0.0916, -0.1545,
        -0.0197, -0.1067,  0.0210, -0.2395,  0.0058, -0.0413,  0.0809, -0.0477,
         0.1629, -0.0127,  0.0745, -0.2009, -0.2790, -0.0052, -0.0664, -0.0006,
        -0.0401, -0.1222, -0.0414, -0.0921, -0.0655,  0.0230,  0.1810, -0.0552,
        -0.0809, -0.1069, -0.0894,  0.1107,  0.1168,  0.1726,  0.0180, -0.0350])

In [None]:
# let's load all SAEs, look at the ranks of dictionaries and see if we can decompose the W_O-output of a feature
# vector into feature vectors from all heads. 

In [None]:
# TODO: does W_O cause interference between dictionary vectors from different heads?
# take feature vectors from head 0 and feature vectors from head 1. 
# apply W_O to both and calculate the cosine similarities between outputs. 

# TODO: but this has the problem that features don't often fire together. 

In [None]:
# TODO: how do the reconstructed losses compare if we splice in Kissane et al's SAE vs all of our SAEs?

In [None]:
# load Kissane et al's SAE, check the rank of its dictionary and decompose their dictionary vectors into our basis. 
# Extra points if we can do it in an interpretable way. 

### Loading all of our SAEs

In [2]:
# Load the SAE trained on each head of layer `hook_point_layer` into a dict keyed by head index.
# Loop variables use local-looking names so this cell does not overwrite the globals
# `hook_point_head_index`, `sparse_autoencoder`, `ckpt_dir` and `saes` used by earlier cells.
# NOTE: load_pretrained_sae also reloads the language model and activation store on every
# call (discarded here), which makes this loop slow.
n_heads = model.cfg.n_heads  # 8 for gelu-2l
sparse_autoencoders = {}
for head_index in tqdm(range(n_heads), desc="loading SAEs"):
    head_ckpt_dir = os.path.join(
        "checkpoints",
        ckpt_subfolders[head_index],
        "983044096",  # TODO: pick the latest checkpoint subdirectory automatically
        sae_name,
    )
    _, head_saes, _ = LMSparseAutoencoderSessionloader.load_pretrained_sae(path=head_ckpt_dir,
                                                                           device=device)
    head_sae = head_saes.autoencoders[sae_name]
    head_sae.eval()  # per the original note, avoids expecting a dead-neuron mask for ghost grads
    sparse_autoencoders[head_index] = head_sae

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


KeyboardInterrupt: 

In [129]:
# Every head's dictionary has full rank d_head (64); to re-check:
# for head_id in range(n_heads):
#     print(torch.linalg.matrix_rank(sparse_autoencoders[head_id].W_dec))

features_per_head, dh = sparse_autoencoders[0].W_dec.shape

# Dictionary for the concatenated z space, block-diagonal in the heads:
# - head h's decoder vectors occupy rows [h*features_per_head, (h+1)*features_per_head)
#   and columns [h*dh, (h+1)*dh);
# - the full shape is (n_heads*features_per_head, n_heads*dh);
# - torch.block_diag builds it on the SAEs' device in one call.
cat_dict = torch.block_diag(*(sparse_autoencoders[head_id].W_dec for head_id in range(n_heads)))

# Expected rank: n_heads * dh (= n_embd = 512 for gelu-2l).
print(torch.linalg.matrix_rank(cat_dict))

In [130]:
# manually investigate a few features
# maybe instead of generating feature pages for each head individually 
# one can just use the concatenated dict and concatenated feature activations
# this can also be used to perform forward pass through the 

tensor(512)

In [131]:
# TODO: W_O is a mixing matrix between different heads.
# I think that in the circuits paper, they interpret attention heads through W_O * W_V matrix.
# What if we only use W_V to get an interpretation for each head and W_O to get an interpretation 
# for the mixing between heads? 

torch.Size([16384, 512])

### Load Kissane et al.'s gelu-2l SAE

In [145]:
# Their SAE code is taken from the Colab Notebook
# https://colab.research.google.com/drive/10zBOdozYR2Aq2yV9xKs-csBH2olaFnsq?usp=sharing#scrollTo=feJOqPeoPjvX
# The SAE name is in "Loading models and data section"

In [155]:
from pathlib import Path

from torch import nn

# Maps the `enc_dtype` string stored in Kissane et al.'s SAE configs to a torch dtype.
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

# Checkpoint directory from the original training setup; AutoEncoder below does not use it.
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")


class AutoEncoder(nn.Module):
    """Sparse autoencoder in the format of Kissane et al.'s Colab notebook.

    ``W_enc`` has shape ``(act_size, dict_size)``. The dictionary (decoder) vectors are
    the rows of ``W_dec``, shape ``(dict_size, act_size)``, initialised to unit norm.

    Args:
        cfg: config dict with keys ``dict_size``, ``l1_coeff``, ``enc_dtype`` (a key of
            ``DTYPES``), ``seed`` and ``act_size``. Stored as ``self.cfg``.
    """

    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        # Reproducible random init (load_from_hf overwrites it with the saved weights).
        torch.manual_seed(cfg["seed"])
        self.cfg = cfg
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        # Start every dictionary vector (row of W_dec) at unit norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """Load a saved autoencoder from the ``ckkissane/tinystories-1M-SAES`` HuggingFace repo.

        Args:
            version: file stem in the repo (``{version}_cfg.json`` and ``{version}.pt``).
                Either an int checkpoint number, a run-name string, or the alias
                ``"run1"`` (-> 25) / ``"run2"`` (-> 47). 25 and 47 are the final
                checkpoints of the first and second autoencoder runs.
            device_override: if given, stored as ``cfg["device"]`` and the loaded module
                is moved to that device.

        Returns:
            The loaded ``AutoEncoder``.
        """
        if version == "run1":
            version = 25
        elif version == "run2":
            version = 47

        cfg = utils.download_file_from_hf("ckkissane/tinystories-1M-SAES", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("ckkissane/tinystories-1M-SAES", f"{version}.pt", force_is_torch=True))
        if device_override is not None:
            # Actually move the weights; recording the device in cfg alone has no effect.
            self.to(device_override)
        return self

In [158]:
# Kissane et al.'s concatenated-z SAE for layer 1 of gelu-2l (dictionary size 16384).
auto_encoder_run = (
    "concat-z-gelu-21-l1-lr-sweep-3/"
    "gelu-2l_L1_Hcat_z_lr1.00e-03_l12.00e+00_ds16384_bs4096_dc1.00e-07_rie50000_nr4_v78"
)
encoder = AutoEncoder.load_from_hf(auto_encoder_run)

# Many of their config choices differ from ours; see `encoder.cfg`.
print(f"shape: {encoder.W_dec.shape}")
print(f"rank: {torch.linalg.matrix_rank(encoder.W_dec)}")

shape: torch.Size([16384, 512])
rank: 512


In [159]:
# feature # 99 is supposed to correspond to the Local Context Feature: In questions starting with Which
# https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs#Local_Context_Feature__In_questions_starting_with__Which_

# I have confirmed by generating feature card for this feature in their Colab notebook

In [175]:
which_feature = encoder.W_dec[99]  # (n_heads * dh,)

# Split the feature's decoder vector into per-head chunks of the concatenated z space.
# Report each head's share of the squared norm; the shares sum to ||feature||^2.
print(f"norm of this feature: {torch.linalg.vector_norm(which_feature):.2f}")

per_head_sq_norms = which_feature.reshape(n_heads, dh).pow(2).sum(dim=-1)  # (n_heads,)
for head_id, sq_norm in enumerate(per_head_sq_norms.tolist()):
    print(f"for head {head_id}, norm-SQUARED contribution: {sq_norm:.2f}")

# Named total_sq_norm (not `sum`) so the builtin sum() is not shadowed for later cells.
total_sq_norm = per_head_sq_norms.sum()
assert torch.isclose(total_sq_norm, torch.linalg.vector_norm(which_feature) ** 2)
print(f"total norm-SQUARED: {total_sq_norm.item():.2f}")

norm of this feature: 1.00
for head 0, norm-SQUARED contribution: 0.09
for head 1, norm-SQUARED contribution: 0.02
for head 2, norm-SQUARED contribution: 0.18
for head 3, norm-SQUARED contribution: 0.22
for head 4, norm-SQUARED contribution: 0.01
for head 5, norm-SQUARED contribution: 0.01
for head 6, norm-SQUARED contribution: 0.00
for head 7, norm-SQUARED contribution: 0.48


In [176]:
# heads in decreasing importance: 7, 3, 2, 0, 1, 4, 5, 6
# TODO: in their work, Kissane et al are wondering why the feature activations get a large contribution 
# from head 5 while decoder weights do not. Our analysis can likely shine a light on that. 

In [182]:
# let's decompose the projection of feature 99 on the last head space into the dictionary for this head
# Shapes of head 7's slice of feature 99 and of head 7's dictionary.
which_feature[7 * dh : 8 * dh].shape, sparse_autoencoders[7].W_dec.shape

(torch.Size([64]), torch.Size([2048, 64]))

In [None]:
# I want to find a vector of coefficients C such that which_feature[7*dh:] = C @ sparse_autoencoders[7].W_dec
# C should have shape (2048,)

In [186]:
# Let W = sparse_autoencoders[7].W_dec, with shape (2048, 64), and let
# v = which_feature[7*dh:] be head 7's slice of feature 99.
#
# W is tall (2048 rows, 64 columns), so it cannot be inverted directly. Its columns are
# linearly independent (rank 64, checked above), so its Moore–Penrose pseudoinverse
# pinv(W), of shape (64, 2048), is a left inverse: pinv(W) @ W = I_64.
#
# Hence C = v @ pinv(W) satisfies C @ W = v. There are infinitely many solutions
# (2048 unknowns, 64 equations); this one is the minimum-norm solution.
# https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Definition
sparse_autoencoders[7].W_dec.shape

# Left-inverse check:
# torch.allclose(torch.linalg.pinv(sparse_autoencoders[7].W_dec) @ sparse_autoencoders[7].W_dec, torch.eye(dh, dh), atol=1e-6)

torch.Size([2048, 64])

In [213]:
# Minimum-norm coefficients C, shape (2048,), with C @ W_dec(head 7) = head 7's slice of feature 99.
C = which_feature[7 * dh : 8 * dh] @ torch.linalg.pinv(sparse_autoencoders[7].W_dec)
C.shape

torch.Size([2048])

In [216]:
# C is an exact solution: C @ W_dec(head 7) reproduces head 7's slice of feature 99.
torch.allclose(
    C @ sparse_autoencoders[7].W_dec,
    which_feature[7 * dh:],
    atol=1e-6,
)

True

In [223]:
# Largest decomposition coefficient.
torch.max(C)

tensor(0.0079)

In [219]:
# C[i] is the amount ith feature of the head contributes to which_feature
# let's plot C[i]

# but the decomposition cannot be unique, since the 2048 feature directions form an overcomplete spanning set (not a basis) of the 64-dimensional head space
# it would be nice if we could find a sparse decomposition instead?
# TODO: especially for interpretability, C would probably need to be sparse and non-negative.

tensor([-1.9657e-04,  1.4956e-04,  7.3280e-05,  ...,  1.0016e-03,
         3.3254e-06, -1.8092e-03])

In [224]:
# TODO: generalize the above to all heads

# This seems useful for decomposition
# https://mathoverflow.net/questions/145688/how-to-project-a-vector-onto-a-very-large-non-orthogonal-subspace

In [221]:
# Plot the decomposition coefficients C[i], one point per feature of head 7's dictionary.
# .cpu() first so this also works when the SAEs live on the GPU.
coeffs = C.detach().cpu().numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(y=coeffs, mode='lines+markers', marker=dict(size=3), line=dict(width=1)))
fig.update_layout(
    title=f'Decomposition coefficients C ({coeffs.shape[0]} features)',
    xaxis_title='Feature index',
    yaxis_title='Coefficient',
    template='plotly_white'
)
fig.show()
