In [54]:
import os
import wandb
import torch
from main import SparseAutoencoder, input_dim, expansion_factor, tokenizer, hook_point, model
from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
from transformer_lens import utils

In [3]:
# Identify the W&B sweep to inspect, from broadest to narrowest scope.
entity = "PEAR-ML"                            # W&B entity that owns the project
project_name = "sae-expected-l0-sweep-norm"   # project containing the sweep
sweep_id = "uo0dzeuj"                         # id of the sweep itself

In [4]:
# Client for the W&B public API; used below to look up the sweep and its runs.
api = wandb.Api()

In [6]:
# Fetch the sweep object; the path has the form "<entity>/<project>/sweeps/<sweep_id>".
sweep = api.sweep("/".join([entity, project_name, "sweeps", sweep_id]))

In [8]:
# Show the sweep inline in the notebook (the recorded output of this cell is `True`).
sweep.display()



True

In [10]:
# Names of every run in the sweep, in the order the API returns them.
[sweep_run.name for sweep_run in sweep.runs]

['fallen-sweep-10',
 'logical-sweep-9',
 'brisk-sweep-8',
 'pleasant-sweep-7',
 'fancy-sweep-6',
 'robust-sweep-5',
 'solar-sweep-4',
 'skilled-sweep-3',
 'confused-sweep-2',
 'fancy-sweep-1']

In [11]:
# Download every run's `sae.pth` checkpoint into a per-sweep directory,
# renaming each one to `<run name>_sae.pth` so the runs do not overwrite
# each other. Failures (e.g. a run that never uploaded a checkpoint) are
# reported and skipped so the remaining runs still download.
save_dir = f'{sweep_id}-files'
os.makedirs(save_dir, exist_ok=True)

for run in sweep.runs:
    file_path = os.path.join(save_dir, f"{run.name}_sae.pth")

    try:
        file = run.file('sae.pth')
        # Every run's file lands at the same `<save_dir>/sae.pth` path first.
        file.download(root=save_dir, replace=True)
        downloaded_file_path = os.path.join(save_dir, 'sae.pth')
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises on Windows when the target already exists, which
        # made re-running this cell fail there.
        os.replace(downloaded_file_path, file_path)
        print(f"Downloaded {file_path} from run {run.name}")
    except Exception as e:
        # Best effort: report the problem and move on to the next run.
        print(f"Failed to download sae.pth from run {run.name}: {e}")

Failed to download sae.pth from run fallen-sweep-10: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 404: Not Found)
Downloaded uo0dzeuj-files/logical-sweep-9_sae.pth from run logical-sweep-9
Downloaded uo0dzeuj-files/brisk-sweep-8_sae.pth from run brisk-sweep-8
Downloaded uo0dzeuj-files/pleasant-sweep-7_sae.pth from run pleasant-sweep-7
Downloaded uo0dzeuj-files/fancy-sweep-6_sae.pth from run fancy-sweep-6
Downloaded uo0dzeuj-files/robust-sweep-5_sae.pth from run robust-sweep-5
Downloaded uo0dzeuj-files/solar-sweep-4_sae.pth from run solar-sweep-4
Downloaded uo0dzeuj-files/skilled-sweep-3_sae.pth from run skilled-sweep-3
Downloaded uo0dzeuj-files/confused-sweep-2_sae.pth from run confused-sweep-2
Downloaded uo0dzeuj-files/fancy-sweep-1_sae.pth from run fancy-sweep-1


In [53]:
# Sequence length (in tokens) of each row produced by the tokenizer helper.
SEQ_LEN = 128

# Pull the TinyStories training split; the assert narrows the type to Dataset.
data = load_dataset("roneneldan/TinyStories", split="train")
assert isinstance(data, Dataset)

# Tokenize + concatenate with the transformer_lens helper, then shuffle with
# a fixed seed so the same rows come first on every run.
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,  # type: ignore
    max_length=SEQ_LEN,
).shuffle(seed=42)

# The "tokens" column is expected to come back as a single tensor.
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

In [55]:
# For every run in the sweep that has a downloaded checkpoint, rebuild the
# SAE weights in sae_vis's AutoEncoder format and write a feature-centric
# HTML visualisation next to the checkpoint.
for run in sweep.runs:
    file_path = os.path.join(save_dir, f"{run.name}_sae.pth")

    # Runs whose download failed have no checkpoint on disk; skip them
    # explicitly instead of relying on a blanket exception handler.
    if not os.path.isfile(file_path):
        print(f"Skipping {run.name}: no checkpoint at {file_path}")
        continue

    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted runs (consider
        # weights_only=True if the installed torch version supports it).
        state_dict = torch.load(file_path, map_location=torch.device('cpu'))
    except Exception as e:
        # Still best effort, but report the reason and let KeyboardInterrupt
        # / SystemExit propagate (the previous bare `except:` swallowed them).
        print(f"Skipping {run.name}: could not load {file_path}: {e}")
        continue

    # Loading into the project's own model checks that the checkpoint matches
    # the expected architecture; `sae` itself is not used afterwards.
    sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=input_dim*expansion_factor, stddev_prior=run.config['stddev_prior'])
    sae.load_state_dict(state_dict)

    # Rename and transpose the weights into the keys sae_vis's AutoEncoder
    # loads (presumably it stores W_enc / W_dec transposed relative to the
    # checkpoint's layout; verify against sae_vis).
    new_state_dict = {
        "W_enc": state_dict["encoder.weight"].T,
        "b_enc": state_dict["encoder.bias"],
        "W_dec": state_dict["decoder.weight"].T,
        "b_dec": state_dict["decoder.bias"],
    }

    # The encoder weight is unpacked as (d_hidden, d_in).
    d_hidden, d_in = state_dict["encoder.weight"].shape
    cfg = AutoEncoderConfig(d_in=d_in, d_hidden=d_hidden)
    encoder = AutoEncoder(cfg)
    encoder.load_state_dict(new_state_dict)

    # NOTE(review): `hook_point` is imported from `main`, but the hook name is
    # hard-coded here; confirm the two agree.
    sae_vis_config = SaeVisConfig(
        hook_point = "blocks.4.hook_resid_post",
        features = range(64),
        verbose = False,
    )

    # Only the first 2048 token sequences are used to build the visualisation.
    sae_vis_data = SaeVisData.create(
        encoder = encoder,
        model = model,
        tokens = all_tokens[:2048],
        cfg = sae_vis_config,
    )

    filename = os.path.join(save_dir, f"{run.name}_vis.html")
    sae_vis_data.save_feature_centric_vis(filename)