# SAE Quality of Life Improvements

- Log the run config to the wandb run.
- Start saving the SAE weights mid-run.


In [None]:
# Autoreload
# IPython autoreload extension: mode 2 reloads imported modules before running each cell,
# so edits to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
# Make the parent directory importable (presumably the repo root that contains
# `sparse_autoencoder` — NOTE(review): relative to the kernel's cwd, confirm).
sys.path.append('../')

In [None]:
from sparse_autoencoder import TensorActivationStore, SparseAutoencoder, pipeline
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset
from sparse_autoencoder.train.sweep_config import SweepParametersRuntime
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device, test_prompt
from transformers import PreTrainedTokenizerBase
import torch

# Torch device chosen by transformer_lens; used below for the activation store and autoencoder.
device = get_device()

## Get Model

In [None]:
src_model = HookedTransformer.from_pretrained(
    "tiny-stories-instruct-1M", dtype="float32"
)

# Test the model: the story should continue the quote with " me".
# The prompt is a single paragraph with no leading/trailing newline, so the expected answer
# token follows "belongs to" directly (a trailing "\n" would put the answer on a new line).
# Plain ASCII quotes are used rather than a typographic quote copied from a PDF.
example_prompt = (
    "Once upon a time, there lived a black cat. The cat belonged to a little girl called "
    "Katie. Every day, Katie would take her cat for a walk in the park. One day, as Katie "
    "and her cat were walking around, they saw a mean looking man. He said he wanted to "
    'take the cat, to which she replied "This cat belongs to'
)
example_answer = " me"


test_prompt(example_prompt, example_answer, src_model, prepend_bos=True)

# To train on TinyStories, we need the TinyStories dataset

In [None]:
from typing import TypedDict, final
from sparse_autoencoder.source_data.abstract_dataset import (
    SourceDataset,
    TokenizedPrompts,
)


class TinyStoriesSourceDataBatch(TypedDict):
    """TinyStories source data batch.

    https://huggingface.co/datasets/roneneldan/TinyStories
    """

    # Raw (untokenized) story texts; this is the only key TinyStoriesDataset.preprocess reads.
    text: list[str]
    # NOTE(review): this field mirrors another dataset's batch schema; confirm whether
    # TinyStories batches actually carry a `meta` column.
    meta: list[dict[str, dict[str, str]]]


@final
class TinyStoriesDataset(SourceDataset[TinyStoriesSourceDataBatch]):
    """Tiny Stories Dataset.

    Tokenizes each story and splits it into fixed-length prompts of `context_size` tokens.

    https://huggingface.co/datasets/roneneldan/TinyStories
    """

    tokenizer: PreTrainedTokenizerBase

    def preprocess(
        self,
        source_batch: TinyStoriesSourceDataBatch,
        *,
        context_size: int,
    ) -> TokenizedPrompts:
        """Preprocess a batch of prompts.

        For each prompt's `text`, tokenize it and chunk into a list of tokenized prompts of length
        `context_size`. For the last item in the chunk, throw it away if the length is less than
        `context_size` (i.e. if it would otherwise require padding). Then finally flatten all
        batches to a single list of tokenized prompts.

        Args:
            source_batch: A batch of source data. For TinyStories this is a dict including the
                key "text" with a value of a list of strings (not yet tokenized).
            context_size: The context size to use when returning a list of tokenized prompts.

        Returns:
            A dict with key "input_ids" holding a flat list of token-id chunks, each exactly
            `context_size` long. Stories shorter than `context_size` contribute nothing.
        """
        prompts: list[str] = source_batch["text"]

        tokenized_prompts = self.tokenizer(prompts)

        # Chunk each tokenized prompt into blocks of context_size. Stopping the start indices at
        # len - context_size + 1 yields only full blocks, so the short tail is dropped without
        # slicing it just to measure its length.
        context_size_prompts = []
        for encoding in tokenized_prompts["input_ids"]:  # type: ignore
            stop = len(encoding) - context_size + 1
            context_size_prompts.extend(
                encoding[start : start + context_size]
                for start in range(0, stop, context_size)
            )

        return {"input_ids": context_size_prompts}

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        context_size: int = 250,
        buffer_size: int = 1000,
        preprocess_batch_size: int = 1000,
        dataset_path: str = "roneneldan/TinyStories",
        dataset_split: str = "train",
    ):
        """Initialise the TinyStories dataset.

        Args:
            tokenizer: Tokenizer that `preprocess` uses to turn story text into token ids.
            context_size: Prompt length in tokens, passed to the `SourceDataset` base class.
            buffer_size: Passed through to the `SourceDataset` base class.
            preprocess_batch_size: Passed through to the `SourceDataset` base class.
            dataset_path: Dataset path (defaults to the TinyStories dataset linked above).
            dataset_split: Dataset split to load.
        """
        # Assigned before the base initialiser runs in case it already calls `preprocess`
        # (which reads `self.tokenizer`) — NOTE(review): confirm against SourceDataset.
        self.tokenizer = tokenizer

        super().__init__(
            dataset_path=dataset_path,
            dataset_split=dataset_split,
            context_size=context_size,
            buffer_size=buffer_size,
            preprocess_batch_size=preprocess_batch_size,
        )

# Training a Sparse Autoencoder on TinyStories

In [None]:
import wandb


src_model = HookedTransformer.from_pretrained(
    "tiny-stories-instruct-1M", dtype="float32"
)
src_d_model: int = src_model.cfg.d_model  # type: ignore

# Tokenizer shared by the source dataset below and the token-analysis cells later on.
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore

# Activation store holding up to `max_items` activation vectors of width d_model.
max_items = 2_000_000
store = TensorActivationStore(max_items, src_d_model, device)

# Make Autoencoder (d_model inputs -> 8 * d_model features).
src_model_activation_hook_point = "blocks.0.hook_resid_pre"
autoencoder = SparseAutoencoder(src_d_model, src_d_model * 8, torch.zeros(src_d_model))
autoencoder.to(device)

# Make Source Data
source_data = TinyStoriesDataset(tokenizer=tokenizer)

# hyper parameters
num_iterations = 3


sweep_config = SweepParametersRuntime(
    lr=1e-3,
    batch_size=2048,
    l1_coefficient=1e-3,
)

wandb.init(
    project="sparse-autoencoder", dir=".cache/wandb", config=sweep_config.__dict__
)

# Always close the wandb run, even when training fails or is interrupted, so a broken run is
# not left open for the next `wandb.init`; failures are marked with a non-zero exit code.
try:
    pipeline(
        src_model=src_model,
        src_model_activation_hook_point=src_model_activation_hook_point,
        # NOTE(review): appears redundant with the hook point above (blocks.0 -> layer 0);
        # check whether `pipeline` could derive it.
        src_model_activation_layer=0,
        source_dataset=source_data,
        activation_store=store,
        num_activations_before_training=max_items,
        sweep_parameters=sweep_config,
        num_iterations=num_iterations,
        log_artifacts=True,
        autoencoder=autoencoder,
        device=device,
    )
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
import wandb

# Report whether a wandb run is currently active (the training cell finishes its run).
print(
    "Weights & Biases (wandb) is not initialized."
    if wandb.run is None
    else "Weights & Biases (wandb) is initialized."
)

# Analysis of the trained autoencoder's weights

In [None]:
autoencoder.encoder.Linear.weight.size()  # encoder weight shape, shown as a torch.Size

In [None]:
import plotly.express as px

In [None]:
# Transposed encoder weights on CPU, detached from autograd.
# NOTE(review): assumes encoder.Linear is an nn.Linear (weight is (out, in)), so after the
# transpose each column is one feature's direction — confirm against SparseAutoencoder.
encoder_weights = autoencoder.encoder.Linear.weight.detach().T.cpu()
print(encoder_weights.shape)
# Subtract the per-column mean (broadcast over dim 0).
centred_weights = encoder_weights - encoder_weights.mean(dim=0, keepdim=True)
# Per-column L2 norm.
px.bar(torch.norm(encoder_weights, dim=0))

In [None]:
# get a cosine similarity matrix
import torch.nn.functional as F
from scipy.cluster import hierarchy
import numpy as np
import pandas as pd


def get_cosine_sim_heatmap(centred_weights):
    """Plot a clustered heatmap of pairwise cosine similarities between columns.

    Args:
        centred_weights: 2D CPU tensor (not requiring grad) whose columns are the vectors to
            compare, e.g. the centred encoder/decoder weights with one feature per column.

    Returns:
        A plotly figure of the (n_columns, n_columns) cosine-similarity matrix, with rows and
        columns reordered by the leaf order of a hierarchical clustering so that similar
        columns sit next to each other.
    """
    # Unit-normalise every column, then a single matmul gives all pairwise cosine
    # similarities. Broadcasting F.cosine_similarity instead materialises an
    # (n, n, d) tensor.
    unit_columns = F.normalize(centred_weights, dim=0)
    similarity = (unit_columns.T @ unit_columns).numpy()

    # Treat each row of the similarity matrix as an observation (scipy's default single
    # linkage on euclidean distance) and take the dendrogram's leaf order.
    linkage = hierarchy.linkage(similarity)
    reordered_ind = hierarchy.dendrogram(linkage, no_plot=True)["leaves"]
    reordered = similarity[reordered_ind][:, reordered_ind]

    return px.imshow(
        reordered, color_continuous_scale="RdBu", color_continuous_midpoint=0
    )


fig = get_cosine_sim_heatmap(centred_weights)
fig.show()

In [None]:
# Display the module structure (the following cells index into `encoder.Linear` and `decoder[0]`).
autoencoder

In [None]:
# Decoder weights on CPU, detached from autograd. No transpose here: NOTE(review) this assumes
# decoder[0] is an nn.Linear mapping features -> d_model, so its (out, in) weight already has
# one feature per column — confirm against SparseAutoencoder.
decoder_weights = autoencoder.decoder[0].weight.detach().cpu()
# A bare `decoder_weights.shape` in the middle of a cell is never displayed; print it instead
# (matching the encoder cell).
print(decoder_weights.shape)

# Subtract the per-column mean, then plot the per-column L2 norm.
centred_weights = decoder_weights - decoder_weights.mean(dim=0)
px.bar(decoder_weights.norm(dim=0))

In [None]:
# Clustered cosine-similarity heatmap, this time for the centred decoder columns.
fig = get_cosine_sim_heatmap(centred_weights=centred_weights)
fig.show()

# Look at how the token embeddings align with the encoder weights

In [None]:
# Dot product of every token embedding with every encoder column, transposed so each row is
# one encoder column. NOTE(review): assumes W_E is (d_vocab, d_model), giving a
# (n_features, d_vocab) result — confirm with the printed shape.
# W_E is detached *before* the matmul so autograd does not record a graph for this analysis.
token_alignment = src_model.W_E.detach().cpu() @ encoder_weights.cpu()
token_alignment = token_alignment.T
token_alignment.shape

In [None]:
px.bar(torch.norm(token_alignment, dim=1))  # L2 norm of each row of token_alignment

In [None]:
# Tokens whose embeddings project most negatively onto one encoder feature.
feature_idx = 442  # feature chosen for inspection; change to look at another one
projection = token_alignment[feature_idx].numpy()
# Column j of token_alignment corresponds to token id j (row j of W_E), so label by id.
# Built here so this cell does not depend on a cell that appears after it.
token_labels = tokenizer.convert_ids_to_tokens(list(range(len(projection))))
df = pd.DataFrame(
    {
        "token": token_labels,
        "projection": projection,
    }
)
df.sort_values("projection", ascending=True).head(10)

In [None]:
# Token string for every token id, in id order, so token_strings[i] labels row i of W_E (and
# column i of token_alignment). Sorting by token text instead would order them alphabetically
# and misalign labels with ids.
token_strings = tokenizer.convert_ids_to_tokens(list(range(src_model.W_E.shape[0])))
token_strings[:10]