# SAE Quality of Life Improvements

Add config to wandb run. 
Start saving the SAE weights mid-run. 


In [189]:
# Autoreload
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [190]:
from sparse_autoencoder import TensorActivationStore, SparseAutoencoder, pipeline
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset
from sparse_autoencoder.train.sweep_config import SweepParametersRuntime
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device, test_prompt
from transformers import PreTrainedTokenizerBase
import torch

device = get_device()

## Get Model

In [191]:
src_model = HookedTransformer.from_pretrained(
    "tiny-stories-instruct-1M", dtype="float32"
)

# test the model
example_prompt = """
Once upon a time, there lived a black cat. The cat belonged to a little girl called Katie. Every day, Katie
would take her cat for a walk in the park.
One day, as Katie and her cat were walking around, they saw a mean looking man. He said he wanted to
take the cat, to which she replied ”This cat belongs to
"""
example_answer = " me"


test_prompt(example_prompt, example_answer, src_model, prepend_bos=True)

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-instruct-1M into HookedTransformer
Tokenized prompt: ['<|endoftext|>', '\n', 'Once', ' upon', ' a', ' time', ',', ' there', ' lived', ' a', ' black', ' cat', '.', ' The', ' cat', ' belonged', ' to', ' a', ' little', ' girl', ' called', ' Katie', '.', ' Every', ' day', ',', ' Katie', '\n', 'would', ' take', ' her', ' cat', ' for', ' a', ' walk', ' in', ' the', ' park', '.', '\n', 'One', ' day', ',', ' as', ' Katie', ' and', ' her', ' cat', ' were', ' walking', ' around', ',', ' they', ' saw', ' a', ' mean', ' looking', ' man', '.', ' He', ' said', ' he', ' wanted', ' to', '\n', 'take', ' the', ' cat', ',', ' to', ' which', ' she', ' replied', ' �', '�', 'This', ' cat', ' belongs', ' to', '\n']
Tokenized answer: [' me']


Top 0th token. Logit: 20.20 Prob: 72.30% Token: |
|
Top 1th token. Logit: 18.09 Prob:  8.78% Token: |Kat|
Top 2th token. Logit: 17.37 Prob:  4.28% Token: |Summary|
Top 3th token. Logit: 17.29 Prob:  3.94% Token: |<|endoftext|>|
Top 4th token. Logit: 16.76 Prob:  2.33% Token: |The|
Top 5th token. Logit: 15.99 Prob:  1.08% Token: |John|
Top 6th token. Logit: 15.49 Prob:  0.65% Token: |"|
Top 7th token. Logit: 15.31 Prob:  0.55% Token: |She|
Top 8th token. Logit: 14.87 Prob:  0.35% Token: |Story|
Top 9th token. Logit: 14.57 Prob:  0.26% Token: |G|


# To train on Tiny Stories, we're going to need the tiny stories dataset

In [192]:
from typing import TypedDict, final
from sparse_autoencoder.source_data.abstract_dataset import (
    SourceDataset,
    TokenizedPrompts,
)


class TinyStoriesSourceDataBatch(TypedDict):
    """Pile Uncopyrighted Source Data.

    https://huggingface.co/datasets/roneneldan/TinyStories
    """

    text: list[str]
    meta: list[dict[str, dict[str, str]]]


@final
class TinyStoriesDataset(SourceDataset[TinyStoriesSourceDataBatch]):
    """Tiny Stories Dataset.

    https://huggingface.co/datasets/roneneldan/TinyStories
    """

    tokenizer: PreTrainedTokenizerBase

    def preprocess(
        self,
        source_batch: TinyStoriesSourceDataBatch,
        *,
        context_size: int,
    ) -> TokenizedPrompts:
        """Preprocess a batch of prompts.

        For each prompt's `text`, tokenize it and chunk into a list of tokenized prompts of length
        `context_size`. For the last item in the chunk, throw it away if the length is less than
        `context_size` (i.e. if it would otherwise require padding). Then finally flatten all
        batches to a single list of tokenized prompts.

        Args:
            source_batch: A batch of source data. For example, with The Pile dataset this would be a
                dict including the key "text" with a value of a list of strings (not yet tokenized).
            context_size: The context size to use when returning a list of tokenized prompts.
        """
        prompts: list[str] = source_batch["text"]

        tokenized_prompts = self.tokenizer(prompts)

        # Chunk each tokenized prompt into blocks of context_size, discarding the last block if too
        # small.
        context_size_prompts = []
        for encoding in list(tokenized_prompts["input_ids"]):  # type: ignore
            chunks = [
                encoding[i : i + context_size]
                for i in range(0, len(encoding), context_size)
                if len(encoding[i : i + context_size]) == context_size
            ]
            context_size_prompts.extend(chunks)

        return {"input_ids": context_size_prompts}

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        context_size: int = 250,
        buffer_size: int = 1000,
        preprocess_batch_size: int = 1000,
        dataset_path: str = "roneneldan/TinyStories",
        dataset_split: str = "train",
    ):
        self.tokenizer = tokenizer

        super().__init__(
            dataset_path=dataset_path,
            dataset_split=dataset_split,
            context_size=context_size,
            buffer_size=buffer_size,
            preprocess_batch_size=preprocess_batch_size,
        )

# Training an AutoEncoder for Tiny Stories

In [193]:
import wandb


src_model = HookedTransformer.from_pretrained(
    "tiny-stories-instruct-1M", dtype="float32"
)
src_d_model: int = src_model.cfg.d_model  # type: ignore

tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore
max_items = 2_000_000
store = TensorActivationStore(max_items, src_d_model, device)

# Make Autoencoder
src_model_activation_hook_point = "blocks.0.hook_resid_pre"
autoencoder = SparseAutoencoder(src_d_model, src_d_model * 8, torch.zeros(src_d_model))
autoencoder.to(device)

# Make Source Data
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore
source_data = TinyStoriesDataset(tokenizer=tokenizer)

# hyper parameter
max_activations = 30 * max_items


sweep_config = SweepParametersRuntime(
    lr=1e-3,
    batch_size=2048,
    l1_coefficient=1e-3,
)

wandb.init(
    project="sparse-autoencoder", dir=".cache/wandb", config=sweep_config.__dict__
)

pipeline(
    src_model=src_model,
    src_model_activation_hook_point=src_model_activation_hook_point,
    src_model_activation_layer=0,  # why do we need to specify this as well?
    source_dataset=source_data,
    activation_store=store,
    num_activations_before_training=max_items,
    sweep_parameters=sweep_config,
    log_artifacts=True,
    autoencoder=autoencoder,
    device=device,
    max_activations=max_activations,
)

wandb.finish()

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-instruct-1M into HookedTransformer


Repo card metadata block was not found. Setting CardData to empty.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjbloom[0m. Use [1m`wandb login --relogin`[0m to force relogin


Total activations trained on:   0%|          | 0/60000000 [00:00<?, ?it/s, Current mode=initializing]

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

In [195]:
# write the autoencoder to a .pt file
# torch.save(autoencoder.state_dict(), "autoencoder.pt")

# Later, load this autoencoder
autoencoder = SparseAutoencoder(src_d_model, src_d_model * 8, torch.zeros(src_d_model))
autoencoder.load_state_dict(torch.load("autoencoder.pt"))

<All keys matched successfully>

# Some analysis

In [196]:
import plotly.express as px

encoder_weights = autoencoder.encoder.Linear.weight.T.detach().cpu()
print(encoder_weights.shape)
centred_encoder_weights = encoder_weights - encoder_weights.mean(dim=0)
px.bar(encoder_weights.norm(dim=0))

torch.Size([64, 512])


In [258]:
# # get a cosine similarity matrix
import torch.nn.functional as F
from scipy.cluster import hierarchy
import numpy as np
import pandas as pd


def get_cosine_sim_heatmap(centred_weights):
    data_array = F.cosine_similarity(
        centred_weights.T.unsqueeze(1), centred_weights.T.unsqueeze(0), dim=2
    )
    df = pd.DataFrame(data_array.numpy())

    linkage = hierarchy.linkage(data_array)
    dendrogram = hierarchy.dendrogram(linkage, no_plot=True, color_threshold=-np.inf)
    reordered_ind = dendrogram["leaves"]
    # reorder df by ind
    df = df.iloc[reordered_ind, reordered_ind]
    data_array = df.to_numpy()
    fig = px.imshow(
        data_array, color_continuous_scale="RdBu", color_continuous_midpoint=0
    )
    return fig


# fig = get_cosine_sim_heatmap(centred_weights)
# fig.show()
# fig.write_html("cosine_sim.html")

In [197]:
decoder_weights = autoencoder.decoder[0].weight.detach().cpu()
decoder_weights.shape

centred_decoder_weights = decoder_weights - decoder_weights.mean(dim=0)
# px.bar(decoder_weights.norm(dim=0))

In [None]:
# fig = get_cosine_sim_heatmap(centred_weights)
# fig.show()

# Look at token intersection with encoder weights

In [53]:
centred_embedding = src_model.W_E.cpu() - src_model.W_E.cpu().mean(dim=0)
centred_embedding = centred_embedding / centred_embedding.norm(dim=1).unsqueeze(1)
token_alignment = centred_embedding @ encoder_weights.cpu()
token_alignment = token_alignment.T.detach()
token_alignment.shape

torch.Size([512, 50257])

In [52]:
src_model.W_E.norm(dim=1)

tensor([0.6388, 0.7038, 0.5022,  ..., 0.3607, 0.6348, 0.6243], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)

In [76]:
# token

tokenizer.vocab["cat"]

9246

In [None]:
# px.bar(token_alignment.norm(dim=1))


token_strings = tokenizer.convert_ids_to_tokens(list(tokenizer.vocab.values()))
# sort the strings by the keys in tokenizer vocab
token_strings = [x for _, x in sorted(zip(tokenizer.vocab.keys(), token_strings))]
[i for i in token_strings if "cat" in i]

In [86]:
import pandas as pd 
df = pd.DataFrame(
    {
        "token": token_strings,
        "projection": token_alignment[196],
    }
)
df.sort_values("projection", ascending=False).head(10)

Unnamed: 0,token,projection
32435,Ġbrains,0.242673
28163,ĠSites,0.242108
30252,ĠZe,0.233666
22025,ĠForbes,0.232709
11335,izing,0.229393
8108,compatible,0.228056
6234,age,0.220957
27623,ĠSV,0.219029
17579,Ġ269,0.217539
855,409,0.213913


## Analysis: Try to interpret a feature



In [199]:
src_model.reset_hooks()

In [248]:
# let's run a forward pass and look at the autoencoder

# example_prompt = """
# Once upon a time, there lived a black cat. The cat belonged to a little girl called Katie. Every day, Katie
# would take her cat for a walk in the park.
# One day, as Katie and her cat were walking around, they saw a mean looking man. He said he wanted to
# take the cat, to which she replied ”This cat belongs to
# """

example_prompts = [
    " cat", " kitten", " cats", " dog",  " dogs",  " puppy", " man", " boy", 
    " girl", " woman", " grandma", " grandpa", " aunt", " uncle",
    " old", " young", " male", " female"
    # "bear", "cloud", "dragon", "frog", "grape", "house", "island", "jungle",
    # "kangaroo", "leaf", "mushroom", "night", "ocean", "panda", "quilt", "rainbow",
    # "star", "turtle", "unicorn", "volcano", "wolf", "fox", "yarn", "zeppelin"
]


# Prepare a list to collect the data
data = []

for prompt in example_prompts:
    tokens = src_model.to_tokens(prompt)

    logits, cache = src_model.run_with_cache(tokens)
    activations = cache[src_model_activation_hook_point]

    autoencoder.eval()
    learned_activations, decoded_activations = autoencoder(activations.cpu())

    topk = 10
    topk_activations, topk_indices = torch.topk(learned_activations[0,-1], topk)

    # Collect data in the list
    for activation, index in zip(topk_activations, topk_indices):
        data.append({
            "Prompt": prompt,
            "Activation": activation.item(),
            "Index": index.item()
        })

# Create DataFrame from the collected data
results_df = pd.DataFrame(data)

# Display the DataFrame
print(results_df[results_df.Prompt == " cat"])

# ok let's see if we can work out what 311 is doing. 

# df = pd.DataFrame(
#     {
#         "token": token_strings,
#         "projection": token_alignment[311],
#     }
# )
# df.sort_values("projection", ascending=False).head(40)

  Prompt  Activation  Index
0    cat    0.350422    499
1    cat    0.278190    464
2    cat    0.212369    153
3    cat    0.128653    185
4    cat    0.127985    212
5    cat    0.101748    335
6    cat    0.093093    381
7    cat    0.067556    415
8    cat    0.052384     53
9    cat    0.022608    306


In [249]:
results_df[results_df.Index == 196].sort_values("Activation", ascending=False).head(10)

Unnamed: 0,Prompt,Activation,Index


In [250]:
wide_df = results_df.pivot_table(index='Prompt', columns='Index', values='Activation', fill_value=0).T
# wide_df.reset_index(inplace=True)

# fig = px.scatter(wide_df, x=" dog", y=" puppy", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

# fig = px.scatter(wide_df, x=" cat", y=" kitten", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

# fig = px.scatter(wide_df, x=" man", y=" boy", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

In [251]:
fig = px.scatter(wide_df, x=" girl", y=" boy", hover_name=wide_df.index, text=wide_df.index)
fig.show()

fig = px.scatter(wide_df, x=" woman", y=" man", hover_name=wide_df.index, text=wide_df.index)
fig.show()

fig = px.scatter(wide_df, x=" aunt", y=" uncle", hover_name=wide_df.index, text=wide_df.index)
fig.show()

In [252]:
fig = px.scatter(wide_df, x=" grandma", y=" grandpa", hover_name=wide_df.index, text=wide_df.index)
fig.show()

In [253]:
fig = px.scatter(wide_df, x=" old", y=" grandpa", hover_name=wide_df.index, text=wide_df.index)
fig.show()

In [254]:
fig = px.scatter(wide_df, x=" old", y=" young", hover_name=wide_df.index, text=wide_df.index)
fig.show()

In [264]:
# wide_df.T.sort_values(464)

get_cosine_sim_heatmap(torch.tensor(wide_df.T.to_numpy()))

In [233]:
# fig = px.scatter(wide_df, x=" cat", y=" cats", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

# fig = px.scatter(wide_df, x=" dog", y=" dogs", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

# fig = px.scatter(wide_df, x=" cat", y=" cats", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

# fig = px.scatter(wide_df, x=" dog", y=" dogs", hover_name=wide_df.index, text=wide_df.index)
# fig.show()

In [234]:
reference_column = " cat"
df = wide_df
# Calculate the difference
diff_df = df.sub(df[reference_column], axis=0).reset_index()
diff_df.head()

Prompt,Index,apple,baby,ball,bear,boy,car,cat,cats,cub,...,nest,owl,puppy,queen,rabbit,snake,tree,umbrella,violin,whale
0,5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,26,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.074598,0.0,0.0,0.0
2,30,0.0,0.0,0.069887,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,36,0.0,0.097327,0.0,0.066217,0.0,0.0,0.0,0.0,0.126138,...,0.097607,0.0,0.0,0.202236,0.0,0.0,0.0,0.0,0.0,0.0
4,39,0.0,0.05495,0.0,0.0,0.092238,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [235]:
long_diff_df = diff_df.melt(id_vars = "Index", value_vars = example_prompts)
long_diff_df["value"] = long_diff_df["value"]*-1
long_diff_df = long_diff_df[long_diff_df.value > 0 ]
long_diff_df.head()
#

Unnamed: 0,Index,Prompt,value
70,53,kitten,0.052384
78,153,kitten,0.154617
84,185,kitten,0.128653
102,306,kitten,0.022608
115,415,kitten,0.067556


In [218]:
px.strip(long_diff_df.sort_values("Index"), x = "Prompt", y="value", color="Index")

In [209]:
px.strip(long_diff_df.sort_values("Prompt"), x = "Index", y="value", color="Prompt")