<a href="https://colab.research.google.com/github/favalosdev/feature-analysis/blob/main/feature_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports & Installs

In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-vis==0.2.14
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import torch
from datasets import load_dataset
import webbrowser
import os
from transformer_lens import utils, HookedTransformer
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
import time
import numpy as np

# New import
%pip install einops
from einops import einsum

# Library imports
from sae_vis.utils_fns import get_device
from sae_vis.model_fns import AutoEncoder
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig
# from sae_lens.training.sparse_autoencoder import SparseAutoencoder

import random
import gc
import zipfile

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

# Visualisation & chart-making
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns

# Import collection
from collections import Counter, defaultdict
from datetime import datetime

device = get_device()
torch.set_grad_enabled(False)

Collecting sae-vis==0.2.14
  Downloading sae_vis-0.2.14-py3-none-any.whl.metadata (4.1 kB)
Collecting einops (from sae-vis==0.2.14)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting eindex-callum (from sae-vis==0.2.14)
  Downloading eindex_callum-0.1.2-py3-none-any.whl.metadata (377 bytes)
Collecting transformer-lens (from sae-vis==0.2.14)
  Downloading transformer_lens-2.3.0-py3-none-any.whl.metadata (12 kB)
Collecting datasets (from sae-vis==0.2.14)
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting dataclasses-json (from sae-vis==0.2.14)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting jaxtyping (from sae-vis==0.2.14)
  Downloading jaxtyping-0.2.33-py3-none-any.whl.metadata (6.4 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json->sae-vis==0.2.14)
  Downloading marshmallow-3.21.3-py3-none-any.whl.metadata (7.1 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json->sae-vis==0.2.14)
  

<torch.autograd.grad_mode.set_grad_enabled at 0x7fd523b8a5f0>

## Visualisation interface

In [None]:
def display_vis_inline(filename: str, height: int = 850, directory: str = "/content"):
    '''
    Displays an HTML vis file.

    Outside Colab the file is opened in the default web browser. In Colab, `directory` is served
    over HTTP on the global `PORT` (defined in the imports cell) and `filename` is embedded as an
    iframe; `PORT` is then incremented so each vis gets its own port without the caller picking one.

    Args:
        filename: HTML file to show; in Colab it must be a path relative to `directory`.
        height: iframe height in pixels (Colab only).
        directory: directory served over HTTP in Colab.

    Raises:
        OSError: in Colab, if no free port is found after several attempts.
    '''
    global PORT

    if not COLAB:
        from pathlib import Path
        # webbrowser.open expects a URL; a bare relative path is not opened reliably on every platform
        webbrowser.open(Path(filename).resolve().as_uri())
        return

    from functools import partial

    # Serve `directory` through the handler's `directory` argument instead of os.chdir(), which would
    # change the working directory of the whole kernel as a side effect.
    handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)

    # Bind here, synchronously, so the port is already listening when the iframe is created, and keep
    # the bound port in a local so it cannot diverge from the one handed to the iframe. Ports that are
    # already taken (e.g. by servers left over after `PORT` was reset by re-running the imports cell)
    # are skipped.
    max_attempts = 50
    for _ in range(max_attempts):
        port = PORT
        try:
            httpd = socketserver.TCPServer(("", port), handler)
            break
        except OSError:
            PORT += 1
    else:
        raise OSError(f"No free port found in {PORT - max_attempts}-{PORT - 1}")

    print(f"Serving files from {directory} on port {port}")
    # Daemon thread: the server must not keep the interpreter alive at shutdown
    threading.Thread(target=httpd.serve_forever, daemon=True).start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

    PORT += 1

# Setup

## Sparse Autoencoders

In [None]:
# Load the "run1" SAE weights and move them onto the compute device
encoder = AutoEncoder.load_from_hf(version="run1").to(device)

# Sanity check: list every SAE parameter together with its shape
for param_name, param in encoder.named_parameters():
    print(f"{param_name}: {tuple(param.shape)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


25.pt:   0%|          | 0.00/269M [00:00<?, ?B/s]

W_enc: (2048, 16384)
W_dec: (16384, 2048)
b_enc: (16384,)
b_dec: (2048,)


## Model

<!-- This library supports non-transformerlens models, provided you apply a wrapper around your model with a few specific methods (e.g. a modified `forward` function which returns a tuple of `(logits, activations, resid)`). However, it's much easier to just use a TransformerLens model in most cases! -->

<!-- The code below loads in our GELU-1l transformer model. You can create your transformer model any way you like; all that matters is that:

* Your model has a `forward` method which takes `tokens` and returns a tuple of `(logits, residual, post_activations)`.
* This forward method has a parameter `return_logits`, which is by default `True`, and when `False` it only returns `(residual, post_activations)`.

Provided this is the case, all other code here (including calculating the effect of ablating certain features) doesn't rely on any specific implementation details of the model.

If you're trying to use a particular model, we recommend **creating a wrapper class around your model which has an altered `forward` method** to match the required behaviour. In the case of this notebook, to make it clear that a `HookedTransformer` model is not necessary, we're using a `DemoTransformer` model (code in this repository), which is a very minimal version of the `HookedTransformer` model lacking the features like hooks, caches, etc. -->

In [None]:
# Load the "gelu-1l" TransformerLens model; `.to` returns the model itself, so it can be chained
model = HookedTransformer.from_pretrained("gelu-1l").to(device)

config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

model_final.pth:   0%|          | 0.00/213M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.04M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  cuda


## Data


In [None]:
SEQ_LEN = 128

# Load the raw text dataset; a single split comes back as a `Dataset`
data = load_dataset("NeelNanda/c4-code-20k", split="train")
assert isinstance(data, Dataset)

# Tokenize + concatenate into rows of SEQ_LEN tokens, then shuffle with a fixed seed for reproducibility
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN).shuffle(seed=42)  # type: ignore

# The "tokens" column is a single 2D tensor: one row of SEQ_LEN token ids per sequence
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

DATASET_LEN = all_tokens.shape[0]
print(all_tokens.shape)

Downloading readme:   0%|          | 0.00/754 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/42.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/20000 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/20000 [00:00<?, ? examples/s]

torch.Size([215402, 128])


# Utils

In [None]:
def clean_cache():
    '''
    Frees memory between experiments.

    Python's garbage collector runs first, so tensors kept alive only by reference cycles are
    released; only then is the CUDA caching allocator asked to return its now-unused blocks.
    Flushing the CUDA cache before collecting would leave those blocks cached.
    '''
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def generate_prompt_centric_visualisation(prompt: str, filename_prefix: str, feature_idxs: list[int] | int):
    '''
    Builds a prompt-centric sae_vis page for `feature_idxs` and saves it as HTML.

    Feature data is gathered over the whole `all_tokens` dataset at the layer-0 MLP "post" hook
    point, and `prompt` is then rendered against that data.

    Args:
        prompt: text to analyse.
        filename_prefix: output path without extension; `.html` is appended unless already present.
        feature_idxs: a single feature index or a list of them.
    '''
    sae_vis_config = SaeVisConfig(
        hook_point=utils.get_act_name("post", 0),
        features=feature_idxs,
        batch_size=128,
        verbose=True,
    )

    # Gather the feature data
    sae_vis_data = SaeVisData.create(
        encoder=encoder,
        model=model,
        tokens=all_tokens,
        cfg=sae_vis_config,
    )

    # Tolerate callers that pass a full `.html` filename instead of producing `<name>.html.html`
    filename = filename_prefix if filename_prefix.endswith('.html') else filename_prefix + '.html'

    sae_vis_data.save_prompt_centric_vis(
        prompt=prompt,
        filename=filename,
    )

    # Drop our reference to the (potentially large) vis data first: gc.collect() inside
    # clean_cache cannot free objects still referenced by a live local variable.
    del sae_vis_data
    clean_cache()

In [None]:
def generate_prompt_centric_visualisations(prompt: str, run_prefix: str, n_ranges: int = 8):
    '''
    Generates prompt-centric visualisations covering every SAE feature.

    Splits the features into `n_ranges` contiguous chunks and writes one HTML file per chunk
    (`{run_prefix}_{i}.html`). It then packs the files into `{run_prefix}.zip` and deletes the loose
    HTML files.

    Args:
        prompt: text to analyse.
        run_prefix: prefix for the per-chunk HTML files and the final zip archive.
        n_ranges: number of feature chunks (and therefore HTML files) to produce.
    '''
    n_features = encoder.W_dec.shape[0]
    range_size = n_features // n_ranges
    # The last chunk absorbs the remainder, so no feature is dropped when n_features % n_ranges != 0
    feature_ranges = [
        range(i * range_size, n_features if i == n_ranges - 1 else (i + 1) * range_size)
        for i in range(n_ranges)
    ]

    generated_files = []

    for i, feature_range in enumerate(feature_ranges, start=1):
        file_prefix = f"{run_prefix}_{i}"
        # Pass the prefix WITHOUT '.html': the callee appends the extension itself. Passing the full
        # filename would save `<name>.html.html`, and the zip step below would not find the file.
        generate_prompt_centric_visualisation(prompt, file_prefix, feature_idxs=list(feature_range))
        generated_files.append(f"{file_prefix}.html")

    zip_filename = f"{run_prefix}.zip"

    # HTML compresses well, so deflate instead of the default stored (uncompressed) mode
    with zipfile.ZipFile(zip_filename, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for f in generated_files:
            zipf.write(f)
            os.remove(f)

    print(f"All visualizations have been packed into {zip_filename}")
    clean_cache()

In [None]:
def generate_feature_centric_visualisation(filename_prefix: str, feature_idxs: int | list[int], n_samples: int = 16384, seed: int | None = None):
    '''
    Builds a feature-centric sae_vis page for `feature_idxs` and displays it.

    Feature data comes from a random subset of `all_tokens`. The page is saved as HTML and shown
    inline in Colab, or in a browser otherwise.

    Args:
        filename_prefix: output path without extension; `.html` is appended unless already present.
        feature_idxs: a single feature index or a list of them.
        n_samples: number of sequences sampled (without replacement) from `all_tokens`, capped at
            the dataset size.
        seed: optional seed for the sample. None uses the global `random` state, which is only
            reproducible if `random.seed` was set.
    '''
    sae_vis_config = SaeVisConfig(
        hook_point=utils.get_act_name("post", 0),
        features=feature_idxs,
        verbose=True,
    )

    rng = random if seed is None else random.Random(seed)
    n_rows = len(all_tokens)
    # Cap the sample size: random.sample raises ValueError when asked for more items than exist
    sample_indices = rng.sample(range(n_rows), min(n_samples, n_rows))
    selected_tokens = all_tokens[sample_indices]

    sae_vis_data = SaeVisData.create(
        encoder=encoder,
        model=model,
        tokens=selected_tokens,
        cfg=sae_vis_config,
    )

    filename = filename_prefix if filename_prefix.endswith('.html') else filename_prefix + '.html'
    sae_vis_data.save_feature_centric_vis(filename)

    display_vis_inline(filename)

    # Release our references first so gc.collect() / the CUDA cache flush in clean_cache can reclaim them
    del sae_vis_data, selected_tokens
    clean_cache()

In [None]:
# Sanity check of random sampling from the token dataset. Show the shape and a short preview rather
# than printing all sampled sequences (which floods the cell output).
random_token_indices = random.sample(range(len(all_tokens)), min(4096, len(all_tokens)))
selected_tokens = all_tokens[random_token_indices]
print(selected_tokens.shape)
print(selected_tokens[:2])

[tensor([    1,    27,   985,   210,   347, 34652,   985,  2769,   543,  3185,
         1724, 20560,    16, 14052,    65,  4352,  1551,  4352,    65,    27,
          985,   210,   347,    42,  2625,   985,  2769,   543,  3185,   534,
         1724, 20560,    16,  3079,    65,  9783,  1551, 45031,    65, 11400,
           65,   252,    65,  7384,   985,  5489,    11,  1800,    54,  8337,
         1724, 20560,    16, 14052,    65,  4352,    10,   347,  4352,    65,
           18,   985,   342,   347,  6618,   985,  4216, 15205, 45031,  6131,
          328,  4244,  9431,  2312,  1724, 20560,    16, 14052,    65,  4352,
         1551,  4352,    65,    19,   985,   210,   347, 26338,    65, 25390,
          985,   287,  8379,     4, 38305,    65, 11400,    65,  1190,   307,
           65,   252,    65,  9783,     4,  9252,  2312,  1724, 20560,    16,
        14052,    65,  4352,  1551,  4352,    65,    20,   985,   210,   347,
        13318,  4620,   985,  2962,   543,   495,    24,    18]

# Detailed investigation of features

## Feature mining

Generate prompt-centric visualisations to identify promising features. The prompts were written to elicit features related to code and grammatical structures.

In [None]:
# Prompt-centric vis over all SAE features for a short Python snippet, to surface code-related features.
# NOTE: the triple-quoted prompt keeps its leading newline and 4-space indentation as literal text.
generate_prompt_centric_visualisations(
    prompt='''
    import random
    N = 10
    xs = [random.randint(1, 100) for _ in range(N)]

    for x in xs:
        print(x)''',
    run_prefix='code_feats'
)

In [None]:
# Prompt-centric vis over all SAE features for punctuation-heavy prose, to surface punctuation/grammar features.
# NOTE: the triple-quoted prompt keeps its surrounding newlines and indentation as literal text.
generate_prompt_centric_visualisations(
    prompt='''
    "Hello, world!" she exclaimed. "How are you today? I'm feeling great; the sun is shining, birds are singing, and everything seems perfect. Isn't life wonderful? But wait - what's that noise? Could it be... rain?"
    ''',
    run_prefix='punctuation_feats')

After inspecting the visualisations, the following feature candidates were identified:

In [None]:
# Candidate SAE feature indices, picked by manually inspecting the prompt-centric visualisations above
full_stop_feat_idx: int = 12896
for_feat_idx: int = 1738
else_feat_idx: int = 2366
conjunction_feat_idx: int = 5407

## The "full stop" feature

In [None]:
# Feature-centric vis for the candidate "full stop" feature, saved as full_stop.html
generate_feature_centric_visualisation('full_stop', full_stop_feat_idx)

Forward passes to cache data for vis:   0%|          | 0/256 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Serving files from /content on port 8004


## The "for" detector feature

In [None]:
# Feature-centric vis for the candidate "for" detector feature, saved as for.html
generate_feature_centric_visualisation('for', for_feat_idx)

Forward passes to cache data for vis:   0%|          | 0/256 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Serving files from /content on port 8005


## The "else" detector feature

In [None]:
# Feature-centric vis for the candidate "else" detector feature, saved as else.html
generate_feature_centric_visualisation('else', else_feat_idx)

Forward passes to cache data for vis:   0%|          | 0/256 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Serving files from /content on port 8006


## The "conjunction" feature

In [None]:
# Feature-centric vis for the candidate conjunction feature, saved as conjunction.html
generate_feature_centric_visualisation('conjunction', conjunction_feat_idx)

Forward passes to cache data for vis:   0%|          | 0/256 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Serving files from /content on port 8007


# Cemetery (FAILURE)

## Dictionary vector weights

In [None]:
'''
PROCEDURE:
Most correlated features over a set of 40 million dataset examples
'''

def plot_dictionary_vector_weights(proxy):
    '''
    Unfinished experiment, kept in the "Cemetery" section.

    Fetches samples via `find_top_k_samples(proxy)` and runs the model over them with a hook on
    `MLP_ACTIVATIONS_LAYER`. As written it plots and returns nothing: the hook body is a no-op and
    `run_with_hooks` is called with `return_type=None`.

    NOTE(review): `find_top_k_samples` is not defined anywhere in the visible notebook, so calling
    this presumably raises NameError. `MLP_ACTIVATIONS_LAYER` is only assigned in a later cell.
    TODO implement or remove.
    '''
    samples = find_top_k_samples(proxy)

    # Placeholder hook: receives the activation and does nothing with it
    def hook(value, _):
        pass

    model.run_with_hooks(samples, return_type=None, fwd_hooks=[(MLP_ACTIVATIONS_LAYER, hook)], reset_hooks_end=True, clear_contexts=True)
    clean_cache()

## Correlation chart

### Feature activations

In [None]:
def plot_activations_correlation(acts1, acts2):
    '''Placeholder for a correlation plot between two sets of feature activations. Not implemented: does nothing and returns None.'''
    return None

### Logit weights

In [None]:
def plot_logit_weights_correlation(weights1, weights2):
    '''Placeholder for a correlation plot between two logit-weight vectors. Not implemented: does nothing and returns None.'''
    return None

## Activation density (DEPRECATED)

In [None]:
# DEPRECATED activation-density code, kept as an inert string for reference only.
# The trailing ';' stops Jupyter from echoing the whole string as this cell's output.
'''
gen_batches = lambda batch_size: [batch['tokens'] for batch in tokenized_data.iter(batch_size=batch_size)]
BATCHES = gen_batches(500)
def calc_activation_levels(feature_idx):
    activation_levels = []

    def accumulate_activations(value, hook):
        f = calc_features(value)
        feature_acts = einsum(f[:,:, feature_idx], 'bc->b')
        activation_levels.extend(feature_acts.cpu().tolist())
        return value

    model.add_hook(MLP_ACTIVATIONS_LAYER, accumulate_activations)

    for batch in BATCHES:
        model.forward(batch, return_type=None)
        clean_cache()

    model.reset_hooks()
    clean_cache()

    return activation_levels
TEST_DIST = np.random.uniform(0, 10, size=all_tokens.shape[0])
def aggregate_activations(activation_levels, calc_proxy):
    annotated = [(level, calc_proxy(sample)) for level, sample in zip(activation_levels, all_tokens.tolist())]
    proxy_values = list(map(lambda x: x[1], annotated))
    min_value = min(proxy_values)
    max_value = max(proxy_values)
    bound = max(abs(min_value), abs(max_value))

    num_classes = 7
    bins = np.linspace(-bound, bound, num_classes+1)
    indices = np.digitize(list(map(lambda x: x[1], annotated)), bins)

    levels = {(level_idx-1): [] for level_idx in indices}

    for (value, _), level_idx in zip(annotated, indices):
        levels[level_idx-1].append(value)

    stacked_levels = list(reversed([level for level in levels.values()]))
    return stacked_levels, bins
def plot_activation_levels(levels, bins, title, legend_title, x_label, y_label):
    plt.close('all')  # Close all figures
    plt.clf()

    fig, ax = plt.subplots(figsize=(10, 6))
    colors = ['#800000', '#FF4040', '#FFC0C0', '#D5DAE0', '#C0C0FF', '#4040FF', '#000080']
    n, _, patches = ax.hist(levels, bins='auto', stacked=True, color=colors)

    legend_labels = [f'{bins[i]:.2f} to {bins[i+1]:.2f}' for i in range(len(levels))]
    legend_labels = legend_labels[::-1]  # Reverse order to match stacked_classes
    ax.legend(patches, legend_labels, title=legend_title, loc='center left', bbox_to_anchor=(1, 0.5))

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    # Set x-ticks to whole numbers
    x_min, x_max = ax.get_xlim()
    x_ticks = np.arange(int(np.floor(x_min)), int(np.ceil(x_max)) + 1)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticks)

    plt.tight_layout()
    plt.show()
levels, bins = aggregate_activations(TEST_DIST, lambda x: random.uniform(-5, 5))
plot_activation_levels(
    levels,
    bins,
    title='Feature activation distribution',
    legend_title='Proxy',
    x_label='Feature activation level',
    y_label='Density'
)
''';

## Helper definitions for experiments

In [None]:
# Hook point name for the layer-0 MLP "post" activations
MLP_ACTIVATIONS_LAYER = utils.get_act_name("post", layer=0)

In [None]:
# Put the SAE in inference mode; the trailing ';' suppresses the module repr that eval() returns
encoder.eval();

In [None]:
def calc_features(mlp_activations):
    '''
    Encodes activations into SAE feature activations with the global `encoder`.

    Computes ReLU((mlp_activations - b_dec) @ W_enc + b_enc).
    '''
    centred = mlp_activations - encoder.b_dec
    return torch.relu(centred @ encoder.W_enc + encoder.b_enc)

## Expected value of activations

## Logit weight distribution

In [None]:
def calc_logit_weight_distribution(feature_idx):
    '''
    Zero-centred softmax distribution over the vocabulary induced by one SAE feature.

    Steps:
    1. Map row `feature_idx` of the SAE decoder through `model.W_out[0]` (layer-0 MLP output weights).
    2. Standardise the result: subtract its mean, divide by (std + 1e-5).
    3. Project onto the unembedding `model.W_U` and take a softmax over dim 0.
    4. Subtract the mean of the softmax output.
    '''
    eps = 1e-5
    resid_direction = encoder.W_dec[feature_idx, :] @ model.W_out[0]
    scale = (resid_direction.std() + eps).item()
    standardised = (resid_direction - resid_direction.mean()) / scale
    probs = torch.softmax(standardised @ model.W_U, dim=0)
    return probs - probs.mean(dim=0)