In [9]:
import torch
from datasets import load_dataset
import webbrowser
import os
from transformer_lens import utils, HookedTransformer
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
import time

# Library imports
from sae_vis.utils_fns import get_device
from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig
# from sae_lens.training.sparse_autoencoder import SparseAutoencoder

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
# Compute device chosen by sae_vis' helper; later cells move the model/SAE onto it.
device = get_device()

# Port constant kept alongside the http.server / socketserver / threading imports;
# NOTE(review): nothing visible in this notebook currently reads PORT.
PORT = 8000

# Turn gradient tracking off globally for the session — nothing here calls backward().
torch.set_grad_enabled(False);

In [42]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Open a saved vis HTML file in the system's default web browser.

    Despite the name, nothing is rendered inline in the notebook: the file is handed to
    `webbrowser.open`. Local paths are first converted to absolute ``file://`` URIs, because
    `webbrowser.open` expects a URL and passing a bare (especially relative) filename is not
    supported or portable across platforms.

    Args:
        filename: Path to the HTML file (relative paths are resolved against the current
            working directory), or a string that already contains a URL scheme separator
            (``"://"``), which is passed through unchanged.
        height: Currently unused; kept so existing callers keep working.

    Raises:
        FileNotFoundError: if `filename` is a local path that does not point to a file.
    '''
    from pathlib import Path

    if "://" in filename:
        # Already a URL (http://, file://, ...): hand it over as-is.
        url = filename
    else:
        path = Path(filename).expanduser().resolve()
        if not path.is_file():
            # Fail loudly rather than opening a browser tab on a non-existent file.
            raise FileNotFoundError(f"No vis file found at {path}")
        url = path.as_uri()
    webbrowser.open(url)

In [5]:
# Location of the trained SAE checkpoint. Override with the SAE_PATH environment variable
# instead of editing this machine-specific default.
SAE_PATH = os.environ.get("SAE_PATH", "/Users/joel/Downloads/sae.pth")
# The checkpoint is a plain dict of tensors (see the key/shape checks below), so load it with
# weights_only=True rather than unpickling arbitrary Python objects.
state_dict = torch.load(SAE_PATH, map_location=torch.device('cpu'), weights_only=True)

In [13]:
# Checkpoint sanity check: encoder weight, unpacked later as (d_hidden, d_in).
state_dict["encoder.weight"].size()

torch.Size([512, 64])

In [14]:
# Checkpoint sanity check: encoder bias (one entry per SAE latent).
state_dict["encoder.bias"].size()

torch.Size([512])

In [15]:
# Checkpoint sanity check: decoder weight (transposed into W_dec below).
state_dict["decoder.weight"].size()

torch.Size([64, 512])

In [16]:
# Checkpoint sanity check: decoder bias (one entry per input dimension).
state_dict["decoder.bias"].size()

torch.Size([64])

In [17]:
# Checkpoint key -> sae_vis AutoEncoder parameter name (insertion order kept: W_enc, b_enc, W_dec, b_dec).
CKPT_TO_SAE_VIS_KEYS = {
    "encoder.weight": "W_enc",
    "encoder.bias": "b_enc",
    "decoder.weight": "W_dec",
    "decoder.bias": "b_dec",
}

# The checkpoint stores its weight matrices transposed relative to AutoEncoder
# (encoder.weight is (512, 64) while W_enc is (64, 512) — presumably nn.Linear's
# (out, in) layout), so ".weight" entries are transposed; biases are copied unchanged.
new_state_dict = {
    sae_key: state_dict[ckpt_key].T if ckpt_key.endswith(".weight") else state_dict[ckpt_key]
    for ckpt_key, sae_key in CKPT_TO_SAE_VIS_KEYS.items()
}

In [18]:
# Size the sae_vis AutoEncoder from the checkpoint: encoder.weight is (d_hidden, d_in).
d_hidden, d_in = state_dict["encoder.weight"].shape
cfg = AutoEncoderConfig(d_in=d_in, d_hidden=d_hidden)
encoder = AutoEncoder(cfg)
encoder.load_state_dict(new_state_dict)
# Place the SAE on the selected compute device (the same `device` the model is moved to),
# so later cells don't need an ad-hoc, machine-specific move.
encoder.to(device)

for k, v in encoder.named_parameters():
    print(f"{k}: {tuple(v.shape)}")

W_enc: (64, 512)
W_dec: (512, 64)
b_enc: (512,)
b_dec: (64,)


In [19]:
# Transformer whose activations the SAE is applied to (hooked via the vis config's hook point).
MODEL_NAME = "roneneldan/TinyStories-1M"
model = HookedTransformer.from_pretrained(MODEL_NAME)
model.to(device);



Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer
Moving model to device:  mps


In [21]:
SEQ_LEN = 128
SHUFFLE_SEED = 42

# Requesting a single split gives back one Dataset (checked by the assert), not a DatasetDict.
data = load_dataset("roneneldan/TinyStories", split="train")
assert isinstance(data, Dataset)

# Tokenize and concatenate into rows of SEQ_LEN tokens, then shuffle with a fixed seed
# so the sampled sequences are reproducible.
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
tokenized_data = tokenized_data.shuffle(seed=SHUFFLE_SEED)

# Token ids as a single tensor (one row per sequence).
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

torch.Size([3714797, 128])


In [23]:
# Sanity check: the compute device returned by sae_vis' get_device() in the setup cell.
device

device(type='mps')

In [35]:
# Sanity check: the device recorded in the model's config after `model.to(device)`.
model.cfg.device

'mps'

In [40]:
# Sanity check: the device holding the SAE's encoder weights.
# NOTE(review): this cell (In [40]) was executed after the `encoder.to(...)` cell further down
# (In [39]), so the recorded output reflects out-of-order execution; a fresh top-to-bottom
# run may show a different device here.
encoder.W_enc.device

device(type='mps', index=0)

In [39]:
# Move the SAE onto the device picked by get_device() (the one the model uses) instead of
# hard-coding "mps", which only exists on Apple-silicon machines.
encoder.to(device)

AutoEncoder(d_in=64, dict_mult=8)

In [None]:
# Which hook point the SAE reads from, and which SAE features to gather data for.
HOOK_POINT = "blocks.4.hook_resid_post"
# NOTE(review): the SAE has d_hidden latents (512 per the shape checks above) but only the
# first 64 are analysed; 64 is also d_in, so confirm this subset is intentional.
FEATURES = range(64)
N_SEQS = 2048            # number of shuffled token sequences run through the model
FEATURE_TO_DISPLAY = 8   # feature index passed to save_feature_centric_vis
VIS_FILENAME = "_feature_vis_demo.html"

# Fail fast on configuration mistakes rather than deep inside sae_vis.
assert max(FEATURES) < d_hidden, "FEATURES contains indices outside the SAE's hidden dimension"
assert FEATURE_TO_DISPLAY in FEATURES, "FEATURE_TO_DISPLAY must be one of the analysed FEATURES"

sae_vis_config = SaeVisConfig(
    hook_point = HOOK_POINT,
    features = FEATURES,
    verbose = True,
)

# Gather the feature data
sae_vis_data = SaeVisData.create(
    encoder = encoder,
    model = model,
    tokens = all_tokens[:N_SEQS],
    cfg = sae_vis_config,
)

# Save as HTML file & display vis
filename = VIS_FILENAME
sae_vis_data.save_feature_centric_vis(filename, feature_idx=FEATURE_TO_DISPLAY)

display_vis_inline(filename)