In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    #%pip install sae-lens==1.3.0 transformer-lens==1.17.0
    #%pip install --upgrade sae-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

torch.set_grad_enabled(False);

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

In [3]:
import plotly.express as px
from tqdm import tqdm
import torch
import os
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from transformer_lens import utils
from functools import partial

# Set to use 20% of GPU memory, i.e. 8GB on an A100 as ~24 GB is being used to train SAEs
# Can remove this if not training on the GPU at the same time
# if device == "cuda":
#     torch.cuda.set_per_process_memory_fraction(0.2, 0)

torch.set_grad_enabled(False)

# I don't fully understand this but it seems important to avoid some warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# wandb: https://wandb.ai/shehper/gelu-2l-attn-1-sae/reports/gelu-2l-layer-1-attn-heads--Vmlldzo4MDA1NzE4/edit
# expcted losses with SAEs are written in front of the labels below
ckpt_subfolders = { 
    0: "rovi1lwe", #3.785
    1: "p7113j0v", #3.807
    2: "rjc53kjg", #3.768
    3: "hibm6x1l", #3.738
    4: "4xima76s", #3.746
    5: "jq26bfpa", #3.729
    6: "b8e2a9w5", #3.75
    7: "smfws6mc" # 3.748
}

model_name = "gelu-2l"
hook_point_layer=1
hook_point=f"blocks.{hook_point_layer}.attn.hook_z"

d_in= 64
expansion_factor = 32
sae_name = f"{model_name}_{hook_point}_{d_in * expansion_factor}_"

In [13]:
hook_point_head_index = 1 # specify the head index
ckpt_dir = os.path.join("checkpoints", 
                        ckpt_subfolders[hook_point_head_index], 
                        "983044096", # TODO: pick the last ckpt subdir by sorting in
                        sae_name)

model, saes, activations_loader = LMSparseAutoencoderSessionloader.load_pretrained_sae(path=ckpt_dir,
                                                                                        device=device)

# print(saes.autoencoders.keys())
# saes.autoencoders['gelu-2l_blocks.1.attn.hook_z_2048_'].W_dec

sparse_autoencoder = saes.autoencoders[sae_name]
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who grads

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cpu


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

SparseAutoencoder(
  (activation_fn): ReLU()
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

In [15]:
# TODO: make things consistent with April updates --- normalize activations, normalize W_dec, etc.

In [16]:
from sae_lens.training.activations_store import ActivationsStore

n_batches = 2**5 if device == "cpu" else 2**10
n_prompts = 4096 * 2 if device == "cpu" else 4096*6

def get_tokens(
    activation_store: ActivationsStore,
    n_batches_to_sample_from: int = 2**10,
    n_prompts_to_select: int = 4096 * 6,
):
    all_tokens_list = []
    pbar = tqdm(range(n_batches_to_sample_from))
    for _ in pbar:
        batch_tokens = activation_store.get_batch_tokens()
        batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][
            : batch_tokens.shape[0]
        ]
        all_tokens_list.append(batch_tokens)

    all_tokens = torch.cat(all_tokens_list, dim=0)
    all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
    return all_tokens[:n_prompts_to_select]


all_tokens = get_tokens(activations_loader, 
                        n_batches_to_sample_from=n_batches,
                        n_prompts_to_select=n_prompts)  # TODO: keeping it small for cpu




[A[A[A


[A[A[A


100%|██████████| 32/32 [00:00<00:00, 123.16it/s]


In [17]:
# TODO: should I just concatenate all my SAEs?
# perhaps the simplest way would be precisely this.
# I concatenate all my SAEs, and just implement concatenation of 

In [18]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

test_feature_idx_gpt = list(range(10))

batch_size = 2 if device == "cpu" else 2048
minibatch_size_tokens = 16 if device == "cpu" else 128

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    batch_size=batch_size,
    minibatch_size_tokens=minibatch_size_tokens,
    verbose=True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder=sparse_autoencoder,
    model=model, # type: ignore
    tokens=all_tokens,  # type: ignore
    cfg=feature_vis_config_gpt,
)

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/10 [00:00<?, ?it/s]

In [20]:
os.makedirs(f"./head{hook_point_head_index}_features", exist_ok=True)
for feature in test_feature_idx_gpt:
    filename = f"./head{hook_point_head_index}_features/{feature}_feature_vis_demo_gpt.html"
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display_vis_inline(filename)

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/10 [00:00<?, ?it/s]