# **Setup** (No need to read. Run in High-RAM mode if you can)

## Dependencies

In [None]:
# !python3 -V

In [None]:
# !pip install transformer_lens
# !pip install gradio
# !pip install datasets

In [None]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch as t
import numpy as np
import gradio as gr

import torch as t
#from google.colab import drive

# This will prompt for authorization.
#drive.mount('/content/drive')

import einops
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import tqdm
from functools import partial
from datasets import load_dataset
from IPython.display import display

## load model and tokens

In [None]:
# load gpt2-small
# Later cells use this global `model`. It is moved to CPU, so tensors combined with its
# activations (e.g. SAE weights) are expected to be on CPU too.
model = HookedTransformer.from_pretrained("gpt2-small").to('cpu')

In [None]:
# 10k OpenWebText documents, tokenized and packed into rows of 128 tokens, then shuffled
# with a fixed seed so the test set is reproducible.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = (
    utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
    .shuffle(seed=22)
)

In [None]:
def get_feature_acts(point, layer, dic, num_batches = 1000, minibatch_size = 50, num_random_rows = 10):
  """Run gpt2-small over the test tokens and encode the activations with an SAE.

  Args:
    point: activation name accepted by `utils.get_act_name`, e.g. "resid_pre" or "mlp_out".
    layer: layer index of the activation point.
    dic: SAE weights dict with keys "W_enc", "b_enc" and "b_dec".
    num_batches: number of 128-token rows taken from the global `tokenized_data`.
    minibatch_size: rows per forward pass (controls peak memory).
    num_random_rows: rows of each minibatch that are also encoded with a random
      Gaussian encoder, as a baseline (fewer than all rows, to save RAM).

  Returns:
    toks: the token rows used.
    feature_acts: SAE feature activations, flattened over [row, position]; BOS positions zeroed.
    random_feature_acts: random-encoder activations for the baseline rows, rescaled per token
      to the norm of the learnt feature activations, flattened the same way.

  Raises:
    ValueError: if no token rows are selected.
  """
  W_enc, b_enc, b_dec = dic["W_enc"], dic["b_enc"], dic["b_dec"]

  # get however many tokens we need
  toks = tokenized_data["tokens"][:num_batches]
  n_rows = toks.size(0)
  if n_rows == 0:
    raise ValueError("num_batches must select at least one token row")

  # random encoder with the same shape as the learnt one
  random_W_enc = t.randn(W_enc.size())

  feature_chunks, random_chunks = [], []
  with t.no_grad():
    # Step through every row, including a final partial minibatch, so feature_acts always
    # has exactly one entry per token in `toks` (feature_analysis relies on this).
    for start in tqdm.tqdm(range(0, n_rows, minibatch_size)):
      toks_batch = toks[start : start + minibatch_size, :]
      # Only the hook of interest is cached, and the forward pass stops right after its layer.
      _, cache = model.run_with_cache(toks_batch, stop_at_layer=layer+1, names_filter=utils.get_act_name(point, layer))
      act_batch = cache[point, layer]
      del cache

      feature_act_batch = t.relu((act_batch - b_dec) @ W_enc + b_enc)

      random_act_batch = t.relu((act_batch[:num_random_rows] - b_dec) @ random_W_enc + b_enc)
      # Rescale each random activation vector to the norm of the learnt one at the same token.
      # clamp_min avoids 0/0 = nan when every random feature is inactive on a token.
      random_norms = random_act_batch.norm(dim=-1, keepdim=True).clamp_min(1e-8)
      random_act_batch = random_act_batch / random_norms * feature_act_batch[:num_random_rows].norm(dim=-1, keepdim=True)
      del act_batch

      feature_chunks.append(feature_act_batch)
      random_chunks.append(random_act_batch)

  # Concatenate once; repeated t.cat inside the loop copies everything seen so far on every step.
  feature_acts = t.cat(feature_chunks, dim=0)
  random_feature_acts = t.cat(random_chunks, dim=0)
  del feature_chunks, random_chunks

  # set BOS acts to zero
  feature_acts[:, 0, :] = 0
  random_feature_acts[:, 0, :] = 0

  # flatten [batch n_seq] dimensions
  feature_acts = feature_acts.reshape(-1, feature_acts.size(2))
  random_feature_acts = random_feature_acts.reshape(-1, random_feature_acts.size(2))

  print("feature_acts has size:", feature_acts.size())

  return toks, feature_acts, random_feature_acts

## Neel utils (get_recons_loss, get_freqs)

In [None]:
def replacement_hook(mlp_post, hook, encoder):
    """Hook that returns the encoder's reconstruction in place of the activation.

    `encoder(mlp_post)` must return an indexable result whose element 1 is the
    reconstruction; `hook` is unused.
    """
    encoder_outputs = encoder(mlp_post)
    return encoder_outputs[1]

def mean_ablate_hook(mlp_post, hook):
    """Overwrite the activation in place with its mean over dims 0 and 1, and return it."""
    position_mean = mlp_post.mean(dim=(0, 1))
    mlp_post.copy_(position_mean.expand_as(mlp_post))
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    """Zero the activation in place and return it."""
    mlp_post.zero_()
    return mlp_post

@t.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None, hook_name=None):
    """Measure how well an SAE reconstruction preserves the language-model loss.

    For each of `num_batches` random batches (`cfg["model_batch_size"]` rows of the global
    `all_tokens`), compute three losses:
    - the clean loss;
    - the loss with the hooked activation replaced by the encoder's reconstruction;
    - the loss with the hooked activation zero-ablated.

    Args:
        num_batches: number of random batches to average over (must be >= 1).
        local_encoder: encoder passed to `replacement_hook`; defaults to the global `encoder`.
        hook_name: activation to patch; defaults to the layer-0 MLP post activation.

    Returns:
        (score, loss, recons_loss, zero_abl_loss), where
        score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss), or nan when the
        zero-ablated and clean losses are equal.

    Raises:
        ValueError: if num_batches < 1.

    Note: `encoder`, `all_tokens` and `cfg` are not defined in this notebook's visible cells.
    """
    if local_encoder is None:
        local_encoder = encoder
    if hook_name is None:
        hook_name = utils.get_act_name("post", 0)
    if num_batches < 1:
        raise ValueError("num_batches must be at least 1")
    replace_hooks = [(hook_name, partial(replacement_hook, encoder=local_encoder))]
    # mean_ablate_hook can be swapped in here for a mean-ablation baseline
    zero_hooks = [(hook_name, zero_ablate_hook)]

    loss_list = []
    for _ in range(num_batches):
        tokens = all_tokens[t.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=replace_hooks)
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=zero_hooks)
        # store Python floats so the averaging tensor does not depend on the losses' device
        loss_list.append((loss.item(), recons_loss.item(), zero_abl_loss.item()))
    loss, recons_loss, zero_abl_loss = t.tensor(loss_list).mean(0).tolist()

    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")
    denom = zero_abl_loss - loss
    score = (zero_abl_loss - recons_loss) / denom if denom != 0 else float("nan")
    print(f"Reconstruction Score: {score:.2%}")
    return score, loss, recons_loss, zero_abl_loss


: 

### Get Frequencies

In [None]:
# Frequency
@t.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    """Estimate the fraction of tokens on which each SAE hidden unit is active.

    Samples `num_batches` random batches of `cfg["model_batch_size"]` sequences from the global
    `all_tokens`, reads the layer-0 MLP post activations, and counts hidden > 0.

    `local_encoder` defaults to the global `encoder`. Its output is indexed with [2] and treated
    as the hidden activations (assumption from the variable name; TODO confirm). `all_tokens`,
    `cfg` and `encoder` are not defined in this notebook's visible cells.

    Returns:
        1-D float tensor of firing frequencies, on the device of the hidden activations.
        Also prints the fraction of units that never fired.

    Raises:
        ValueError: if num_batches < 1 (the frequencies would be 0/0).
    """
    if local_encoder is None:
        local_encoder = encoder
    if num_batches < 1:
        raise ValueError("num_batches must be at least 1")
    hook_name = utils.get_act_name("post", 0)
    # Accumulate on whatever device the hidden activations live on, rather than assuming CUDA
    # (the model in this notebook is on CPU).
    act_freq_scores = None
    total = 0
    for _ in tqdm.trange(num_batches):
        tokens = all_tokens[t.randperm(len(all_tokens))[:cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=hook_name)
        mlp_acts = cache[hook_name]
        # flatten batch and position; the last dim is the MLP width, so no `d_mlp` global is needed
        mlp_acts = mlp_acts.reshape(-1, mlp_acts.size(-1))

        hidden = local_encoder(mlp_acts)[2]

        fired = (hidden > 0).sum(0).to(t.float32)
        act_freq_scores = fired if act_freq_scores is None else act_freq_scores + fired
        total += hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

## Neel's visualisation Utils

In [None]:
from html import escape
import colorsys


from IPython.display import display

SPACE = "·"
NEWLINE="↩"
TAB = "→"

def create_html(strings, values, max_value=None, saturation=0.5, allow_different_length=False, return_string=False):
    """Render tokens as HTML spans whose background colour encodes one value per token.

    Positive values are shaded blue and negative values red. The colour saturation is
    |value| / max_value * saturation, clipped to 1.

    Args:
        strings: token strings. They are HTML-escaped, and tabs, newlines and spaces are made visible.
        values: one number per string, as a list, numpy array or torch tensor of any rank
            (flattened). Missing trailing values count as 0 when allow_different_length is True.
        max_value: value mapped to full `saturation`. Defaults to max |value| + 1e-3.
        saturation: saturation reached at max_value.
        allow_different_length: skip the check that strings and values have equal length.
        return_string: return the HTML string instead of displaying it.

    Returns:
        The HTML string if return_string, else None (the HTML is displayed).
    """
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", f"{NEWLINE}<br/>").replace("\t", f"{TAB}&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    # normalise tensors/arrays of any rank to a flat list of Python numbers
    if isinstance(values, (t.Tensor, np.ndarray)):
        values = values.flatten().tolist()
    else:
        values = list(values)

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    # scale values
    if max_value is None:
        max_value = max((abs(v) for v in values), default=0) + 1e-3
    scaled_values = [v / max_value * saturation for v in values]

    spans = []
    for i, s in enumerate(processed_strings):
        v = scaled_values[i] if i < len(scaled_values) else 0
        hue = 0 if v < 0 else 0.66  # red for negative values, blue otherwise
        # hsv_to_rgb needs a saturation in [0, 1]. A negative or >1 saturation gives channels
        # outside [0, 1] and therefore malformed hex colours, so use |v| and clip it.
        rgb_color = colorsys.hsv_to_rgb(hue, min(abs(v), 1.0), 1)
        hex_color = "#%02x%02x%02x" % tuple(int(c * 255) for c in rgb_color)
        spans.append(f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>')
    html = "".join(spans)
    if return_string:
        return html
    # HTML is not imported in the notebook's import cells, so import it where it is used
    from IPython.display import HTML, display
    display(HTML(html))

def basic_feature_vis(text, feature_index, max_val=0):
    """Return HTML showing one SAE feature's activation on every token of `text`.

    Reads the layer-0 MLP post activations of the global `model` and encodes them with the
    global `encoder` (not defined in this notebook's visible cells; TODO).

    Args:
        text: input text.
        feature_index: index of the SAE hidden unit to show.
        max_val: colour-scale maximum. 0 (the default) means use the largest activation.
            None is passed through, and the downstream helper then uses the max value.
    """
    feature_in = encoder.W_enc[:, feature_index]
    feature_bias = encoder.b_enc[feature_index]
    _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)][0]
    # torch.nn.functional (`F`) is never imported in this notebook; t.relu is the same op
    feature_acts = t.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
    if max_val == 0:
        max_val = max(1e-7, feature_acts.max().item())
    return basic_token_vis_make_str(text, feature_acts, max_val)
def basic_token_vis_make_str(strings, values, max_val=None):
    """Build an HTML header (observed and chosen value range) followed by coloured token spans.

    `strings` may be a list of token strings; anything else is converted with
    `model.to_str_tokens`. When `max_val` is None, the largest value sets the colour scale.
    """
    token_strs = strings if isinstance(strings, list) else model.to_str_tokens(strings)
    vals = utils.to_numpy(values)
    scale = vals.max() if max_val is None else max_val
    header = (
        f"<h4>Max Range <b>{vals.max():.4f}</b> Min Range: <b>{vals.min():.4f}</b></h4>"
        f"<h4>Set Max Range <b>{scale:.4f}</b></h4>"
    )
    return header + create_html(token_strs, vals, max_value=scale, return_string=True)
# display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
# # %%
# The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
import gradio as gr


def _close_previous_demo():
    """Best-effort close of the demo stored in the global `demos` list, if there is one.

    A missing or None demo is skipped explicitly. A failure while closing is reported rather
    than hidden, and KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    previous = globals().get("demos")
    if not previous or previous[0] is None:
        return
    try:
        previous[0].close()
    except Exception as err:
        print(f"Could not close previous gradio demo: {err!r}")


_close_previous_demo()
demos = [None]
def make_feature_vis_gradio(feature_id, starting_text=None, batch=None, pos=None):
    """Launch a shareable gradio app showing one SAE feature's activations on editable text.

    If `starting_text` is None, it is decoded from `all_tokens[batch, 1:pos+1]`. The global
    `all_tokens` is not defined in this notebook's visible cells (TODO). Any previously
    launched demo is closed first, and the new one is stored in `demos[0]`.
    """
    if starting_text is None:
        starting_text = model.to_string(all_tokens[batch, 1:pos+1])
    _close_previous_demo()
    with gr.Blocks() as demo:
        gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
        # The input elements
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Text", value=starting_text)
                # Precision=0 makes it an int, otherwise it's a float
                # Value sets the initial default value
                feature_index = gr.Number(
                    label="Feature Index", value=feature_id, precision=0
                )
                # If empty, this maps to None
                max_val = gr.Number(label="Max Value", value=None)
                inputs = [text, feature_index, max_val]
        with gr.Row():
            with gr.Column():
                # The output element
                out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(starting_text, feature_id))
        # re-render whenever any input changes
        for inp in inputs:
            inp.change(basic_feature_vis, inputs, out)
    demo.launch(share=True)
    demos[0] = demo

### Inspecting Top Logits

In [None]:
SPACE = "·"
NEWLINE="↩"
TAB = "→"
def process_token(s):
    """Convert one token to a display string with visible whitespace.

    Accepts a string, a Python or NumPy integer token id, or a single-element tensor. Ids are
    decoded with `model.to_string`. Spaces become SPACE, tabs become TAB, and newlines become
    NEWLINE followed by a real newline.
    """
    if isinstance(s, t.Tensor):
        s = s.item()
    # np.integer covers every NumPy integer width (int32, int64, ...), not only int64
    if isinstance(s, np.integer):
        s = s.item()
    if isinstance(s, int):
        s = model.to_string(s)
    return s.replace(" ", SPACE).replace("\n", NEWLINE + "\n").replace("\t", TAB)

def process_tokens(l):
    """Convert a string, a sequence of tokens, or a (possibly [1, seq]) tensor of ids to display strings.

    Strings are split with `model.to_str_tokens`. Tensors with more than one dim are squeezed on
    dim 0. Each item is then passed through `process_token`.
    """
    if isinstance(l, str):
        items = model.to_str_tokens(l)
    elif isinstance(l, t.Tensor) and l.dim() > 1:
        items = l.squeeze(0)
    else:
        items = l
    return list(map(process_token, items))

def process_tokens_index(l):
    """Like `process_tokens`, but suffix each display string with "/<position>"."""
    return [f"{tok}/{idx}" for idx, tok in enumerate(process_tokens(l))]

def create_vocab_df(logit_vec, make_probs=False, full_vocab=None):
    """Tabulate a logit vector over the vocabulary, sorted by descending logit.

    `full_vocab` defaults to display strings for all `model.cfg.d_vocab` tokens. With
    `make_probs`, "log_prob" and "prob" columns (softmax over the last dim) are added.
    """
    if full_vocab is None:
        full_vocab = process_tokens(model.to_str_tokens(t.arange(model.cfg.d_vocab)))
    columns = {"token": full_vocab, "logit": utils.to_numpy(logit_vec)}
    if make_probs:
        columns["log_prob"] = utils.to_numpy(logit_vec.log_softmax(dim=-1))
        columns["prob"] = utils.to_numpy(logit_vec.softmax(dim=-1))
    return pd.DataFrame(columns).sort_values("logit", ascending=False)

### Make Token DataFrame

In [None]:
def list_flatten(nested_list):
    """Concatenate an iterable of iterables into one flat list (one level only)."""
    flat = []
    for inner in nested_list:
        flat.extend(inner)
    return flat
def make_token_df(tokens, len_prefix=5, len_suffix=1):
    """Build a DataFrame with one row per (batch, position) token.

    Args:
        tokens: 2-D tensor of token ids, indexed as [batch, pos].
        len_prefix: number of preceding tokens shown in the `context` column.
        len_suffix: number of following tokens shown in the `context` column.

    Returns:
        DataFrame with these columns:
        - str_tokens
        - unique_token ("<tok>/<pos>")
        - context ("<prefix>|<token>|<suffix>")
        - batch
        - pos
        - label ("<batch>/<pos>")
    """
    # `seq` rather than `t`, so the torch alias is not shadowed for readers
    str_tokens = [process_tokens(model.to_str_tokens(seq)) for seq in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    n_batch, n_pos = tokens.shape[0], tokens.shape[1]
    context, batch, pos, label = [], [], [], []
    for b in range(n_batch):
        for p in range(n_pos):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix):p])
            # Slicing past the end is safe, so the last token of the sequence is included in
            # the preceding suffixes and the final position simply gets an empty suffix.
            suffix = "".join(str_tokens[b][p + 1:p + 1 + len_suffix])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))

In [None]:
def feature_analysis(feature_id, toks, feature_acts, len_prefix=30, len_suffix=2, num_rows = 30, seq_length_to_check=128):
  """Display the `num_rows` tokens on which feature `feature_id` is most active.

  `feature_acts` is a flattened [row * position, n_features] tensor (e.g. from
  get_feature_acts). It must contain exactly toks.size(0) * toks.size(1) rows, or the reshape
  fails. Only the first `seq_length_to_check` positions of each row are considered. The mean
  activation of the feature is printed first.
  """
  n_rows, seq_len = toks.size(0), toks.size(1)
  acts = feature_acts.reshape(n_rows, seq_len, feature_acts.size(-1))[:, :seq_length_to_check, :]
  acts = acts.reshape(-1, acts.size(2))
  this_feature = acts[:, feature_id]
  print("avg feature activation: ", this_feature.mean())

  token_df = make_token_df(toks[:, :seq_length_to_check], len_prefix = len_prefix, len_suffix=len_suffix)
  token_df["feature"] = utils.to_numpy(this_feature)
  top_rows = token_df.sort_values("feature", ascending=False).head(num_rows)
  display(top_rows.style.background_gradient("coolwarm"))

# **Intro**

The following is a quick demo of some SAEs I trained on the activations of gpt2-small, using Neel Nanda's [open source code](https://github.com/neelnanda-io/1L-Sparse-Autoencoder) and [visualisation tools](https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn?usp=sharing#scrollTo=GbM9UZsJN0Uy).
I've written it in the spirit of messing around and having fun, with little by way of technical commentary: see e.g. [here](https://transformer-circuits.pub/2023/monosemantic-features), [here](https://www.lesswrong.com/posts/fKuugaxt2XLTkASkk/open-source-replication-and-commentary-on-anthropic-s) and [here](https://arxiv.org/abs/2309.08600) for nice technical explanations of SAEs.

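The SAE weights are available on Hugging Face at [jacobcd52/gpt2-small-sparse-autoencoders](https://huggingface.co/jacobcd52/gpt2-small-sparse-autoencoders). There is a separate SAE for each activation point and layer, stored in a file named `gpt2-small_6144_{point}_{layer}.pt` (e.g. `gpt2-small_6144_resid_pre_9.pt`).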

Each SAE has hidden dimension $8 d_\text{model} = 6144$. Most aspects of training were copied directly from Nanda. The only noteworthy differences are:
- Each SAE was trained on (the same) 1B tokens from OpenWebText.
- Batch size was reduced from 4096 to 1024 in order to fit on a V100.
- The coefficient of the $l_1$ norm term in the loss was set as follows:

      1) Before training the SAE, sample some activations at the point of interest.
      2) Throw away activations on BOS tokens.
      3) Calculate the mean residual vector in this sample, and subtract it from each of the individual residual activations to get a list of deviation vectors from the mean.
      4) Calculate the average norm of the sample of deviation vectors. Call it avg_norm.
      5) Set l1_coeff = avg_norm * 1e-5.

This was necessary because late-layer activations are larger than early-layer activations, which increases the size of the $l_2$ loss term relative to the $l_1$ loss term (since the former is quadratic in the activations whereas the latter is linear). This was therefore compensated for by increasing l1_coeff.

For some early-layer SAEs, l1_coeff was reduced further, since this was found to help the reconstruction score stay high during training (i.e. overwriting the original activation by patching in the SAE's reconstruction of it did not make the next-token loss much worse). The layer 1 MLP was particularly hard to train, and I was not able to find hyperparameters that led to a good reconstruction score; hence the features found at this point should probably be ignored.









# **Analysis at resid_pre 9**

# Feature interpretation

Let's look for interpretable features at the resid_pre, layer 9 in gpt2-small. First we load the weights of our SAE trained at that point:

In [None]:
# load encoder weights
# SAE for resid_pre at layer 9. The unpacked weights are globals read by later cells
# (e.g. get_elicited_features).
point, layer = "resid_pre", 9
sae_file = f"gpt2-small_6144_{point}_{layer}.pt"
dic = utils.download_file_from_hf("jacobcd52/gpt2-small-sparse-autoencoders", sae_file, force_is_torch=True)
W_dec, b_dec, W_enc, b_enc = (dic[key] for key in ("W_dec", "b_dec", "W_enc", "b_enc"))

In [None]:
# get the reconstruction loss



Now get the feature activations (i.e. hidden layer activations of the SAE) on this dataset. This will take a few minutes. You might need to decrease num_batches if you have limited RAM: 1000 batches require about 25GB RAM, so scale accordingly, but note that having fewer test tokens will make feature interpretation harder.

In [None]:
num_test_batches = 1000  # 128-token rows; ~25GB RAM at 1000, so scale down if RAM is limited
toks, feature_acts, random_feature_acts = get_feature_acts(point, layer, dic, num_batches=num_test_batches)

Let's first look at some features with high average activation:

In [None]:
# feature indices ordered by mean activation over all test tokens, largest first
mean_feature_acts = feature_acts.mean(dim=0)
sorted_feature_inds = mean_feature_acts.argsort(descending=True)
print("some common features are: \n ", sorted_feature_inds[:20])

The feature with highest average activation is number 5337. The following is a list of examples of text where the feature activates the most. The specific token on which the feature activation is high is shown between vertical bars| like this|. But remember that  the model moves information forward through the text using attention, so it may be the case that a high feature activation at some token position says less about the token at that position and more about the text preceding it.

In [None]:
#####   For interactive use! Set feature_id to be an integer from 0 to 6143 #####
# 5337 is the highest-average-activation feature discussed in the text around this cell
feature_id = 5337
feature_analysis(feature_id, toks, feature_acts, len_prefix = 30, len_suffix = 2, num_rows = 30) # reduce len_prefix to fit better on screen if needed

Feature 5337 seems to just activate on the token directly after a BOS token. Many of the very common features are similarly boring. E.g. 1224 activates on the ↩ symbol, and 130 activates on lists of things (although, as per usual, there's some junk mixed in).

If you change the feature_id above and start randomly looking at features, you'll find that there are a lot of boring ones that activate on one specific token, and also a lot of annoyingly polysemantic features. E.g. one feature I came across activates a lot on capitalised proper nouns, but it also activates really strongly on | sho|ebox and | adjud|ication. Could there be some subtle link between all these words that I didn't spot? Maybe. But it seems more likely that the small dictionary size is to blame. If the SAE had a larger hidden layer, I predict that the features it would find would be more monosemantic. (Another very plausible explanation is that I messed up something in training).

But there are also some gems. In my completely non-scientific search for fun features, I found it useful to sort features by their variance in activation across all the tokens. More precisely, for each feature we calculate its mean activation $\mu$ and its variance $\sigma^2$, and we sort the features by $\sigma/\mu$. For example, here are the 50 features with largest $\sigma/\mu$:

In [None]:
# Rank features by sigma/mu. std(x / mu) == std(x) / mu for mu > 0, so dividing the per-feature
# std by the per-feature mean gives the same ratio without materialising a second full-size
# copy of feature_acts. Features that never fire have mean 0, and their ratio is nan.
_ , sorted_feature_inds = (feature_acts.std(0) / feature_acts.mean(0)).sort(descending=True)

print(sorted_feature_inds[:50])

They're all super boring! In fact, anecdotally, the top 1000 or so highest variance features are all boring, activating on one specific token. This is no surprise: roughly speaking, high-variance means a feature doesn't activate much for the most part, but activates really strongly on something very specific.  Here are some features from around the 2000 mark, where things start getting more interesting:

In [None]:
# features ranked 2000-2049 by sigma/mu
print(sorted_feature_inds[2000:2050])

A large proportion of these features are super nice. For example:

- 2550: Places where things are located
- 319: (Mostly) things you can *spend*, e.g. time and money -- even in some cases where the word "spend" is not itself present, e.g. "risked their| lives|" and "take their| time|".
- 4138: Climbing/high places
- 2895: American football terms
- 6117: Instances of the word "the" which are *followed by* a location. This is more interesting than it seems, since the model can only look backwards! So this feature is really capturing some abstract property of the preceding tokens that implies a location is coming next. This example is also a good reminder to look at the full context of the token when trying to interpret a feature: if you just scanned your eyes down the leftmost column of "the"s (as I did at first), you'd have missed the point!

Here are some indices from around the 3000 mark to play around with

In [None]:
# features ranked 3000-3049 by sigma/mu
print(sorted_feature_inds[3000:3050])

Some nice ones are:
- 634: Implications (e.g. "A led to B").
- 5109: American football *and* English football (not sure if sports in general or just these two; would need to look at more varied test data to find out).
- 898: Greatest
- 4730: "Blah blah blah, *with* noun *verb*ing", and similar grammatical constructions, usually involving a clause starting with "with".
- 4498: Parts of a castle or dungeon. I'm going out on a limb here, but this seems to be the common theme linking the words "cell", "crypt", "wall", and "battery" (I had to look the last one up - it's part of a castle). "Mob", "crowd", "hex", "victor" and "bomb" all seem like things that might show up a lot in the medieval/fantasy genre, but I might be clutching at straws here. This example is a good case study about how hard it is to interpret abstract features.

From around the 4000 mark (when sorted by $\sigma/\mu$), features seem to get a lot harder to interpret (reminder: this is all totally anecdotal - I've only looked at a few percent of the features!). These are features that tend to fire a small amount across a large amount of text, rather than being peaked on very specific things like certain grammatical constructions or word groups. We might expect that these features are capturing something about the broad context of the text (e.g. "this text is a news article from a website that contains a lot of ↩ tokens"). These features are hard to interpret. Part of the issue might be the relatively small number of test tokens we're using for this analysis: it's hard to tell if a feature is detecting news articles, say, or if there just happen to be a lot of news articles in the test set. Another part of the issue is the small dictionary size: increasing the size (and training on more tokens) would hopefully let the SAE learn sharper features, rather than muddled mixtures of vague properties of the text. It may also be the case that these low-variance features are just fundamentally uninterpretable.


Another way we can try to isolate features that are nice and interpretable is by deliberately eliciting a feature we think ought to exist in the model. For example, say we've just been looking at feature 4369: "developed countries", and we have a hunch that the model should also contain a feature representing "developing countries". We can just run gpt2-small on the text "Ghana, Bangladesh, Syria, Thailand, Nigeria, Vietnam, Yemen", then feed the resulting activations into our SAE and see what features are most activated.

In [None]:
def get_elicited_features(eliciting_text, num_to_show=5, pos=None):
  '''Takes text meant to elicit a certain type of feature.
     Returns the top `num_to_show` feature activations and their feature indices.
     Uses the globals `point` and `layer` (the current SAE's activation point) and the SAE
     weights `W_enc`, `b_enc` and `b_dec`.
     If pos is set to some integer, only the activations at that token position are used.
     pos counts tokens after BOS, so pos=0 is the first real token.
     If pos=None (default), activations are averaged over position.
     '''
  # Get some tokens meant to elicit a certain feature
  eliciting_toks = model.to_tokens(eliciting_text)

  # Cache only the activation of interest and stop the forward pass right after its layer
  hook_name = utils.get_act_name(point, layer)
  _, cache = model.run_with_cache(eliciting_toks, stop_at_layer=layer + 1, names_filter=hook_name)
  eliciting_acts = cache[hook_name][:, 1:, :]  # the indexing here is to get rid of the BOS activation

  # Feed those activations into the SAE and return the most highly activated features (averaged over the batch dimension)
  eliciting_feature_acts = t.relu((eliciting_acts - b_dec) @ W_enc + b_enc)
  if pos is None:
    per_feature = eliciting_feature_acts.mean(0).mean(0)
  else:
    per_feature = eliciting_feature_acts[:, pos, :].mean(0)
  top_eliciting_feature_acts, top_eliciting_feature_inds = per_feature.topk(num_to_show)

  return top_eliciting_feature_acts, top_eliciting_feature_inds

# Names of developing countries, to look for a "developing countries" feature
get_elicited_features("Ghana Thailand Nigeria Vietnam Yemen")

Now we go back and plug these feature IDs into the visualisation tool above. 4721 seems to be "lists of (mostly European) countries". Close, but not quite what we wanted. 6044 is the feature we're after: it fires on Somalia, Bangladesh, Syria, Gaza, Azerbaijan, Zimbabwe, Venezuela, and many more countries/territories from the developing world.

Before moving on, here's a list of some other miscellaneous features I've stumbled across:
- 5495: French text, especially proper nouns
- 1:  World leaders
- 4445: Death
- 1927: Feminine words and names
- 1071: Things being done to body parts

In [None]:
# Words that all start with "A"; presumably probing for a first-letter feature (TODO: confirm intent)
get_elicited_features("Apple Attic Ant After Antelope Alexander Article Absent Absynthe Articulate Aardvark Aardwolf Arabia")

# Statistical properties

As well as looking at each feature separately, we can look at some global, statistical properties of the feature dictionary.

Our learnt features are meant to be sparse: a given token should only activate a small number of features. To test this, let's first look at how many features tend to activate on a given token. Here and throughout, we'll also show the corresponding plot with random features, for comparison. First, let's see how many features have activation > 0.2:



In [None]:
# @title Plot number of features with activation > 0.2 on a given token

# Use as many learnt-feature token rows as there are random-baseline rows, so both
# histograms count the same number of tokens.
n_compare = random_feature_acts.size(0)
num_features = (feature_acts[:n_compare] > 0.2).float().sum(dim=1)
random_num_features = (random_feature_acts > 0.2).float().sum(dim=1)

hist_bins = dict(start=0, end=6000, size=10)
trace1 = go.Histogram(x=num_features.detach(), opacity=0.98, name='Learnt features', xbins=hist_bins)
trace2 = go.Histogram(x=random_num_features.detach(), opacity=0.2, name='Random features', xbins=hist_bins)

# Overlay both histograms in one figure and show it
fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title="Number of features with activation > 0.2 on given token in the test dataset",
    xaxis_title='Number of  activated features',
    yaxis_title='Token count',
)
fig.show()

This suggests that the learnt features are much sparser than random features, if "sparse" is suitably interpreted. We also expect, though, that when a given feature activates, it should activate pretty strongly. So if we set the activation threshold a bit higher (say at 2.0), then our learnt features ought to activate above this threshold more often than random ones:

In [None]:
# @title Plot number of features with activation > 2.0 on a given token

# Use as many learnt-feature token rows as there are random-baseline rows, so both
# histograms count the same number of tokens.
n_compare = random_feature_acts.size(0)
num_features = (feature_acts[:n_compare] > 2).float().sum(dim=1)
random_num_features = (random_feature_acts > 2).float().sum(dim=1)

hist_bins = dict(start=0, end=400, size=1)
trace1 = go.Histogram(x=num_features.detach(), opacity=0.98, name='Learnt features', xbins=hist_bins)
trace2 = go.Histogram(x=random_num_features.detach(), opacity=0.2, name='Random features', xbins=hist_bins)

# Overlay both histograms in one figure and show it
fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title="Number of features with activation > 2.0 on given token in the test dataset",
    xaxis_title='Number of activated features',
    yaxis_title='Token count',
)
fig.show()

(Note: the spike at 0 is an artifact of my lazy way of dealing with BOS activations in this notebook, i.e. simply setting them to 0).

The transition between the two behaviours occurs at an activation of around 1.5, which is therefore a natural order-of-magnitude scale at which to start calling an activation "strong".


Next, we can look at how often our learnt features are nonzero, noting that we'd expect a random feature to activate about half the time:

In [None]:
# @title Plot activation frequencies of features

# fraction of (equally many) token rows on which each feature is nonzero
n_compare = random_feature_acts.size(0)
freqs = (feature_acts[:n_compare] > 0).float().mean(dim=0)
random_freqs = (random_feature_acts > 0).float().mean(dim=0)

hist_bins = dict(start=0, end=1, size=0.005)
trace1 = go.Histogram(x=freqs.detach(), opacity=0.98, name='Learnt features', xbins=hist_bins)
trace2 = go.Histogram(x=random_freqs.detach(), opacity=0.2, name='Random features', xbins=hist_bins)

# Overlay both histograms in one figure and show it
fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title="Proportion of tokens on which feature activates",
    xaxis_title='Proportion of tokens on which feature activates',
    yaxis_title='Feature count',
)
fig.show()

Along the same lines, we can look at the average feature activations:

In [None]:
# @title Plot average activations of features

# per-feature mean activation over all stored token rows
avg_act = feature_acts.mean(dim=0)
random_avg_act = random_feature_acts.mean(dim=0)

hist_bins = dict(start=0, end=0.5, size=0.005)
trace1 = go.Histogram(x=avg_act.detach(), opacity=0.98, name='Learnt features', xbins=hist_bins)
trace2 = go.Histogram(x=random_avg_act.detach(), opacity=0.2, name='Random features', xbins=hist_bins)

# Overlay both histograms in one figure and show it
fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title="Average activation of feature across all tokens",
    xaxis_title='avg activation',
    yaxis_title='Feature count',
)
fig.show()

The average activations and activation frequencies for our learnt features look pretty unimodal: there's no evidence here of the ultra low density cluster found by Anthropic.

Whilst some of the features we've found are reasonably abstract, we might expect that some features are present in the model because they give a strong, direct prediction of the next token. Let's check this by directly unembedding the learnt features and softmaxing to get a probability distribution. If the probabilities are concentrated on a small number of tokens, the entropy of that distribution will be close to 0, and the maximum probability appearing in the distribution will be close to 1.

In [None]:
# @title Plots

# Nothing here needs gradients. no_grad stops autograd from keeping the large
# (n_features, d_vocab) intermediates of the unembedding alive.
with t.no_grad():
    # get average hidden-layer norm of the SAE across many activations.
    avg_hidden_norm = feature_acts.norm(dim=-1).mean()

    # Scale every learnt decoder direction (row of W_dec) by the norm found above, add b_dec,
    # and directly unembed it to get next-token log-probs.
    log_probs = model.unembed(model.ln_final((avg_hidden_norm * W_dec + b_dec).unsqueeze(1))).squeeze().log_softmax(dim=-1)

    # Do the same with random directions of the same norm. b_dec is added exactly once, as for
    # the learnt features (it used to be added both here and again before unembedding).
    random_features = t.randn(W_dec.size())
    random_features = avg_hidden_norm * random_features / random_features.norm(dim=-1, keepdim=True)
    random_log_probs = model.unembed(model.ln_final((random_features + b_dec).unsqueeze(1))).squeeze().log_softmax(dim=-1)

    # Entropies in bits, computed from log-probs. log2(probs) gives 0 * -inf = nan wherever a
    # probability underflows to 0.
    ln2 = float(np.log(2))
    entropies = -(log_probs.exp() * log_probs).sum(dim=-1) / ln2
    random_entropies = -(random_log_probs.exp() * random_log_probs).sum(dim=-1) / ln2

    # calculate maxprobs (exp is monotonic, so the max log-prob gives the max prob)
    maxprobs = log_probs.max(dim=-1)[0].exp()
    random_maxprobs = random_log_probs.max(dim=-1)[0].exp()

del log_probs
del random_log_probs


# PLOT ENTROPIES

trace1 = go.Histogram(x=entropies.detach(), opacity=0.98, name='Learnt features', xbins=dict(start=0, end=15, size=0.1)  )
trace2 = go.Histogram(x=random_entropies.detach(), opacity=0.2, name='Random features', xbins=dict(start=0, end=15, size=0.1)  )

# Create the figure and add the histograms
fig_ent = go.Figure()
fig_ent.add_trace(trace1)
fig_ent.add_trace(trace2)

# Update layout
fig_ent.update_layout(
    barmode='overlay',  # Overlay the histograms
    title=None,
    xaxis_title='Entropy',
    yaxis_title='Feature count',
    legend_title=None
)

# Show the plot
fig_ent.show()



# PLOT MAXPROB

trace1 = go.Histogram(x=maxprobs.detach(), opacity=0.98, name='Learnt features', xbins=dict(start=0, end=1, size=0.01)  )
trace2 = go.Histogram(x=random_maxprobs.detach(), opacity=0.2, name='Random features', xbins=dict(start=0, end=1, size=0.01)  )

# Create the figure and add the histograms
fig_max = go.Figure()
fig_max.add_trace(trace1)
fig_max.add_trace(trace2)

# Update layout
fig_max.update_layout(
    barmode='overlay',  # Overlay the histograms
    title=None,
    xaxis_title='max prob',
    yaxis_title='Feature count',
)

# Show the plot
fig_max.show()

We see that some of our learnt features are directly responsible for next-token predictions, with entropy near 0 and max prob near 1. But a significant proportion of the features are more abstract than that.

# **Analysis at mlp_out 6**

# Feature Interpretation

Now let's do the same thing for mlp_out at layer 6. As before, decrease num_batches if your RAM is limited.

In [None]:
# load encoder weights
point, layer = "mlp_out", 6
dic = utils.download_file_from_hf("jacobcd52/gpt2-small-sparse-autoencoders", f"gpt2-small_6144_{point}_{layer}.pt", force_is_torch=True)
W_dec , b_dec, W_enc, b_enc = dic["W_dec"], dic["b_dec"], dic["W_enc"], dic["b_enc"]

# Decrease this if your RAM is limited. It is passed to get_feature_acts as
# `num_batches`, which slices that many rows (128-token sequences) off the dataset;
# they are processed in minibatches of 50, so keep it a multiple of 50.
num_batches = 1000

# get feature activations
toks, feature_acts, random_feature_acts = get_feature_acts(point, layer, dic, num_batches=num_batches)

Let's first look at some features with high average activation:

In [None]:
# Order feature indices by average activation over all tokens, largest first.
sorted_feature_inds = feature_acts.mean(dim=0).argsort(descending=True)
print("some common features are: \n ", sorted_feature_inds[:50])

As before, just edit the feature_id variable below to explore different features.

In [None]:
# For interactive use: set feature_id to any integer from 0 to 6143.
feature_id = 5217
feature_analysis(
    feature_id,
    toks,
    feature_acts,
    len_prefix=30,  # reduce to fit better on screen if needed
    len_suffix=2,
    num_rows=30,
)

Compared to resid_pre, there seem to be fewer "obvious" features, like "words that begin with a capital letter" or other trivial grammatical features. To find interesting features, we can again sort by the standard deviation of a feature's activation divided by its mean activation (i.e. its coefficient of variation).

Here are features 1-50 when sorted from high to low normalised standard deviation:

In [None]:
# Rank features by std(activation) / mean(activation) (coefficient of variation).
# Mathematically this equals the std of feature_acts / feature_acts.mean(0), but it
# avoids materialising a normalised copy of the whole (tokens, features) matrix.
# NOTE: a feature that never fires has mean 0 and std 0, so its ratio is NaN, and
# torch.sort orders NaN as the largest value. Such dead features therefore come first.
feature_cv = feature_acts.std(dim=0) / feature_acts.mean(dim=0)
_ , sorted_feature_inds = feature_cv.sort(descending=True)

print(sorted_feature_inds[:50])

As before, they're boring. Here are features 1000-1050:

In [None]:
# Features at ranks 1000-1049 of the normalised-std ordering.
print(sorted_feature_inds[slice(1000, 1050)])

Getting better. There seem to be a lot of features for synonyms and common short phrases, which are a bit dull:
- 4979: "In [some location]"
- 6027: "Everything"/"everyone"/"all"
- 9: "At the time"/"as things stand"/"as of..."
- 6130: "Who"/"which"/"whom"/"that" - more generally, tokens that follow a proper noun that has just been introduced, and is about to be described.

But we also get the occasional fun feature:
- 4412: Fantasy/sci-fi/fiction.

Here are features 2000-2050:

In [None]:
# Features at ranks 2000-2049 of the normalised-std ordering.
print(sorted_feature_inds[slice(2000, 2050)])

There are some nice ones here:

- 4432: Government/public policy
- 2705: Ranges of times or quantities, usually in the form "A to B", e.g. "three to four months".
- 2946: Temporal adverbials, often in the form "for X", e.g. "for several weeks" or "for the first time". This feature also activates on some other phrases like "worth a thought" or "worth every penny" - I can't tell if these are somehow grammatically related.
- 5306: Dates followed by a comma, e.g. "In the summer of 2001,".
- 1749: Secrets/classified information/lack of knowledge. Some other stuff mixed in, but mostly along these lines.
- 5598: Football (soccer)
- 5484: Phrases denoting uncertain belief, e.g. "rumored to", "alleged to", "appear to"

Here are some features from around the 3000 mark:


In [None]:
# Features at ranks 3000-3049 of the normalised-std ordering.
print(sorted_feature_inds[slice(3000, 3050)])

Nice ones include:

- 569: War/military
- 1911: Titles, especially of books
- 3662: I'll leave this one as a puzzle. It's a great example of a feature that may seem uninterpretable at first, but whose meaning becomes clear after a bit of head-scratching.
- 5062: Repeated words/phrases, especially used informally/as a rhetorical device/for emphasis. E.g. "what works for you may not work for someone else", and "it may not be good for America, but it's damn good for CBS".

Let's quickly run a check on that last one. If I had to come up with a sentence that would activate 5062 as much as possible, I'd probably go with the famous JFK quote: "Ask not what your country can do for you, but what you can do for your country". And indeed, we see that 5062 has a very high activation of 2.29 on the fourth-from-last token (" do"):

In [None]:
jfk_quote = "Ask not what your country can do for you, but what you can do for your country"
# pos=-4 picks the fourth-from-last token of the prompt.
get_elicited_features(jfk_quote, pos=-4)

In fact, the only four features that activate more than 5062 on the " do" token are all just annoying features that activate near the start of the input (perhaps I could've handled BOS tokens differently to avoid such features). Feature 5062 seems to behave as expected.

# Statistical properties

Let's plot the same graphs as we did for resid_pre:



In [None]:
# @title Plot number of features with activation > 0.1 on a given token

# For every token, count how many features have activation above the threshold.
# The learnt activations are truncated to as many tokens as the random baseline
# has, so both histograms are built from the same number of tokens.
act_threshold = 0.1
num_features = (feature_acts[:random_feature_acts.size(0)] > act_threshold).type(t.float32).sum(dim=1)
random_num_features = (random_feature_acts > act_threshold).type(t.float32).sum(dim=1)

trace1 = go.Histogram(x=num_features.detach(), opacity=0.98, name='Learnt features', xbins=dict(start=0, end=6000, size=1) )
trace2 = go.Histogram(x=random_num_features.detach(), opacity=0.2, name='Random features', xbins=dict(start=0, end=6000, size=1)  )

# Create the figure and add the histograms
fig = go.Figure()
fig.add_trace(trace1)
fig.add_trace(trace2)

# Update layout for better visualization; the title is derived from act_threshold
# so that the label cannot drift from the value actually used.
fig.update_layout(
    barmode='overlay',  # Overlay the histograms
    title=f"Number of features with activation > {act_threshold} on given token in the test dataset",
    xaxis_title='Number of activated features',
    yaxis_title='Token count'
)

# Show the plot
fig.show()

In [None]:
# @title Plot number of features with activation > 0.5 on a given token

# Per-token count of features whose activation exceeds the threshold. The learnt
# activations are cut to the random baseline's token count for a like-for-like comparison.
act_threshold = 0.5
n_baseline_tokens = random_feature_acts.size(0)
num_features = (feature_acts[:n_baseline_tokens] > act_threshold).type(t.float32).sum(dim=1)
random_num_features = (random_feature_acts > act_threshold).type(t.float32).sum(dim=1)

# Unit-width bins on [0, 200], shared by both traces.
count_bins = dict(start=0, end=200, size=1)
trace1 = go.Histogram(x=num_features.detach(), name='Learnt features', opacity=0.98, xbins=count_bins)
trace2 = go.Histogram(x=random_num_features.detach(), name='Random features', opacity=0.2, xbins=count_bins)

fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title=f"Number of features with activation > {act_threshold} on given token in the test dataset",
    xaxis_title='Number of activated features',
    yaxis_title='Token count',
)
fig.show()

The above plots suggest that the rough scale at which we should start calling an activation "large" is somewhere between 0.1 and 0.5.

In [None]:
# @title Plot activation frequencies of features

# Fraction of tokens on which each feature fires at all (activation strictly > 0).
# The learnt activations use the same number of tokens as the random baseline.
n_baseline_tokens = random_feature_acts.size(0)
freqs = feature_acts[:n_baseline_tokens].gt(0).float().mean(dim=0)
random_freqs = random_feature_acts.gt(0).float().mean(dim=0)

# 0.005-wide bins covering the full [0, 1] range of frequencies.
freq_bins = dict(start=0, end=1, size=0.005)
trace1 = go.Histogram(x=freqs.detach(), name='Learnt features', opacity=0.98, xbins=freq_bins)
trace2 = go.Histogram(x=random_freqs.detach(), name='Random features', opacity=0.2, xbins=freq_bins)

fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title="Proportion of tokens on which feature activates",
    xaxis_title='Proportion of tokens on which feature activates',
    yaxis_title='Feature count',
)
fig.show()

In [None]:
# @title Plot average activations of features

# Mean activation of every feature over all tokens in each activation tensor.
avg_act = feature_acts.mean(dim=0)
random_avg_act = random_feature_acts.mean(dim=0)

# 0.002-wide bins on [0, 0.5], shared by both traces.
avg_act_bins = dict(start=0, end=0.5, size=0.002)
trace1 = go.Histogram(x=avg_act.detach(), name='Learnt features', opacity=0.98, xbins=avg_act_bins)
trace2 = go.Histogram(x=random_avg_act.detach(), name='Random features', opacity=0.2, xbins=avg_act_bins)

fig = go.Figure(data=[trace1, trace2])
fig.update_layout(
    barmode='overlay',
    title="Average activation of feature across all tokens",
    xaxis_title='avg activation',
    yaxis_title='Feature count',
)
fig.show()

These distributions look vaguely bimodal, I guess?

In [None]:
# @title Plots

# Directly unembed every learnt feature direction (and, as a baseline, random
# directions of the same norm), then measure how concentrated the resulting
# next-token distribution is via its entropy and its largest probability.
# This is analysis only, so autograd is disabled: the (n_features, d_vocab)
# tensors built here are large and need no computation graph.
with t.no_grad():
    # get average hidden-layer norm of the SAE across many activations.
    avg_hidden_norm = feature_acts.norm(dim=-1).mean()

    # Scale every decoder row by that norm, add the decoder bias, then apply the
    # final LayerNorm and the unembedding. unsqueeze(1)/squeeze(1) add and remove a
    # length-1 position axis; squeeze(1) (not squeeze()) keeps the feature axis even
    # for a single feature. log_softmax is used instead of softmax followed by log2,
    # so that probabilities underflowing to exactly 0 cannot produce 0 * -inf = NaN.
    log_probs = model.unembed(model.ln_final((avg_hidden_norm * W_dec + b_dec).unsqueeze(1))).squeeze(1).log_softmax(dim=-1)

    # Baseline: random directions rescaled to norm avg_hidden_norm. b_dec is added
    # exactly once, as for the learnt features above.
    # NOTE(review): the two are only norm-matched if W_dec rows are unit-norm — TODO confirm.
    random_features = t.randn(W_dec.size())
    random_features = avg_hidden_norm * random_features / random_features.norm(dim=-1, keepdim=True)
    random_log_probs = model.unembed(model.ln_final((random_features + b_dec).unsqueeze(1))).squeeze(1).log_softmax(dim=-1)

    # Entropy in bits: H = -sum_tok p * ln(p) / ln(2).
    entropies = - einops.einsum(log_probs.exp(), log_probs, "f tok, f tok -> f") / np.log(2)
    random_entropies = - einops.einsum(random_log_probs.exp(), random_log_probs, "f tok, f tok -> f") / np.log(2)

    # exp is monotonic, so the largest probability is exp of the largest log-probability.
    maxprobs = log_probs.max(dim=-1).values.exp()
    random_maxprobs = random_log_probs.max(dim=-1).values.exp()

# Free the large (n_features, d_vocab) tensors.
del log_probs
del random_log_probs


# PLOT ENTROPIES

trace1 = go.Histogram(x=entropies.detach(), opacity=0.98, name='Learnt features', xbins=dict(start=0, end=15, size=0.1)  )
trace2 = go.Histogram(x=random_entropies.detach(), opacity=0.2, name='Random features', xbins=dict(start=0, end=15, size=0.1)  )

# Create the figure and add the histograms
fig_ent = go.Figure()
fig_ent.add_trace(trace1)
fig_ent.add_trace(trace2)

# Update layout
fig_ent.update_layout(
    barmode='overlay',  # Overlay the histograms
    title=None,
    xaxis_title='Entropy',
    yaxis_title='Feature count',
    legend_title=None
)

# Show the plot
fig_ent.show()



# PLOT MAXPROB

trace1 = go.Histogram(x=maxprobs.detach(), opacity=0.98, name='Learnt features', xbins=dict(start=0, end=1, size=0.01)  )
trace2 = go.Histogram(x=random_maxprobs.detach(), opacity=0.2, name='Random features', xbins=dict(start=0, end=1, size=0.01)  )

# Create the figure and add the histograms
fig_max = go.Figure()
fig_max.add_trace(trace1)
fig_max.add_trace(trace2)

# Update layout
fig_max.update_layout(
    barmode='overlay',  # Overlay the histograms
    title=None,
    xaxis_title='max prob',
    yaxis_title='Feature count',
)

# Show the plot
fig_max.show()

# Outlook

Most of the features learnt by the SAE are not as clean and monosemantic as the cherry-picked examples above. There are two obvious improvements to be made: make the SAEs bigger, and train them on more tokens.

Hopefully I'll get around to doing this soon, with a beefier GPU than a V100. But in the meantime, I think there are some interesting avenues to explore with the feature dictionaries we already have. I'll conclude by listing a few ideas I find interesting:

1) My intuition is that MLPs look at old features and do computation with them to produce new ones. So I'd expect something like "{features at resid_pre, n+1}  = {features at resid_pre, n} $\cup$ {features at mlp_out, n}". Obviously, this isn't true as written (since each of those sets has the same size, to name just one reason), but can we check whether something along these lines holds?

2) Having feature dictionaries should allow us to be far more surgical with activation patching. Rather than patching in/corrupting an entire residual stream activation, we could just patch in the projection along certain feature directions.

3) Similarly, we can be more surgical with attribution patching, measuring the effect of varying an upstream feature component on some downstream feature activation.

4) I'm most excited about automating ideas 2) and 3), along the lines of Automated Circuit DisCovery ([Conmy et al](https://arxiv.org/abs/2304.14997),  [Syed-Rager-Conmy](https://arxiv.org/abs/2310.10348)). In its current form, ACDC tells us how information flows through the network when it does a given task, but not what information is flowing. On the other hand, feature dictionaries tell us what information exists in the network but not how it flows. It seems natural to combine the ideas. This strikes me as a promising direction towards full start-to-finish automation of mech interp research.