In [3]:
import os
import wandb
import torch
from main import SparseAutoencoder, input_dim, hidden_dim, hook_point, model
# from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig
# from sae_vis.data_storing_fns import SaeVisData
# from sae_vis.data_config_classes import SaeVisConfig
from sae_lens import SAE, SAEConfig
from sae_dashboard import sae_vis_runner
from sae_dashboard.feature_data_generator import FeatureDataGenerator
from sae_dashboard.data_writing_fns import save_feature_centric_vis
from sae_dashboard.sae_vis_runner import SaeVisConfig, SaeVisRunner
# from sae_dashboard.data_parsing_fns import 
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
from transformer_lens import utils
from sae_lens import ActivationsStore, SAE, run_evals
from sae_lens.evals import EvalConfig
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model gpt2 into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [4]:
# --- Weights & Biases sweep that the cells below read runs from ---
entity = "PEAR-ML"
project_name = "sae-expected-l0-sweep-norm"
sweep_id = "itifyaiz"

# --- Hook point and dataset shared by the cells below ---
# NOTE: this rebinds the `hook_point` name that the import cell pulled in from `main`.
dataset_path = "apollo-research/Skylion007-openwebtext-tokenizer-gpt2"
hook_point = "blocks.6.hook_resid_post"

In [5]:
# NOTE(review): SEQ_LEN is only referenced by the commented-out tokenization call below.
SEQ_LEN = 128

# Load the first 2048 rows of the train split (asserted below to be a Dataset object)
data = load_dataset(dataset_path, split="train[:2048]")
assert isinstance(data, Dataset)

# Tokenization is currently disabled: token ids are read straight from the dataset's
# "input_ids" column further down instead of being produced here.
# tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
# tokenized_data = tokenized_data.shuffle(42)

# Get the tokens as a tensor
# NOTE(review): presumably every row has the same length, otherwise torch.tensor would
# reject the ragged input — confirm against the dataset card.
all_tokens = torch.tensor(data["input_ids"])
# assert isinstance(all_tokens, torch.Tensor)

In [6]:
# Handle to the wandb API; used below to look up the sweep and its runs.
api = wandb.Api()

In [7]:
# Look the sweep up by its "<entity>/<project>/sweeps/<sweep id>" path.
sweep = api.sweep("/".join([entity, project_name, "sweeps", sweep_id]))

In [8]:
# Local directory that holds one checkpoint per sweep run.
save_dir = f'{sweep_id}-files'
os.makedirs(save_dir, exist_ok=True)

# Every download lands at <save_dir>/sae.pth (the remote file name) before it is
# moved to a per-run name, so this path is the same for all runs.
downloaded_file_path = os.path.join(save_dir, 'sae.pth')

for run in sweep.runs:
    # Per-run destination: "<run name>_sae.pth" inside save_dir.
    file_path = os.path.join(save_dir, f"{run.name}_sae.pth")

    try:
        run.file('sae.pth').download(root=save_dir, replace=True)
        # os.replace overwrites an existing destination on every platform, so
        # re-running this cell refreshes checkpoints instead of failing where
        # os.rename refuses to clobber an existing file (e.g. on Windows).
        os.replace(downloaded_file_path, file_path)
        print(f"Downloaded {file_path} from run {run.name}")
    except Exception as e:
        # Best effort: a run without an accessible sae.pth is reported and skipped
        # so the remaining runs are still fetched.
        print(f"Failed to download sae.pth from run {run.name}: {str(e)}")

Failed to download sae.pth from run lunar-sweep-8: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 404: Not Found)
Downloaded itifyaiz-files/desert-sweep-7_sae.pth from run desert-sweep-7
Downloaded itifyaiz-files/worldly-sweep-6_sae.pth from run worldly-sweep-6
Downloaded itifyaiz-files/flowing-sweep-5_sae.pth from run flowing-sweep-5
Downloaded itifyaiz-files/laced-sweep-4_sae.pth from run laced-sweep-4
Downloaded itifyaiz-files/brisk-sweep-3_sae.pth from run brisk-sweep-3
Downloaded itifyaiz-files/vague-sweep-2_sae.pth from run vague-sweep-2
Failed to download sae.pth from run stellar-sweep-1: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues 

In [9]:
def SAEofSparseAutoencoder(sae: SparseAutoencoder) -> SAE:
    """Convert a project `SparseAutoencoder` into an `sae_lens` `SAE`.

    The SAE is created from a config sized after `sae.encoder.weight` and then
    filled with copies of the encoder/decoder weights and biases.

    Args:
        sae: Source autoencoder. Its `encoder` and `decoder` attributes must expose
            `weight` and `bias` tensors.
            NOTE(review): assumes both are `torch.nn.Linear`-style layers whose
            `weight` is laid out as (out_features, in_features), which is why the
            weights are transposed below — confirm against `main.SparseAutoencoder`.

    Returns:
        A new `SAE` whose `W_enc`, `b_enc`, `W_dec` and `b_dec` hold independent,
        contiguous copies of the source parameters (same dtype as the source).
    """
    d_hidden, d_in = sae.encoder.weight.shape
    conf = SAEConfig(
        architecture="standard",
        d_in=d_in,
        d_sae=d_hidden,
        activation_fn_str="relu",
        apply_b_dec_to_input=False,
        finetuning_scaling_factor=False,
        context_size=1024,  # TODO: what is this? does it matter?
        model_name="gpt2",
        hook_name=hook_point,
        # NOTE: hard-coded; keep in sync with the layer index inside `hook_point`.
        hook_layer=6,
        hook_head_index=None,
        prepend_bos=False,
        dataset_path=dataset_path,
        dataset_trust_remote_code=False,
        normalize_activations=False,
        dtype="bfloat16",
        device="cpu",
        sae_lens_training_version=None,
    )
    result = SAE(conf)

    def _owned_copy(tensor: torch.Tensor) -> torch.Tensor:
        # A transposed weight is a non-contiguous view onto the source module's
        # storage. Copying into a fresh contiguous tensor keeps the returned SAE
        # independent of `sae` (later edits to either do not leak into the other)
        # and safe for serializers that reject non-contiguous tensors.
        return tensor.detach().clone(memory_format=torch.contiguous_format)

    # All four parameters go through the same `.data` path so the SAE keeps its
    # own Parameter objects instead of adopting the source module's.
    result.W_enc.data = _owned_copy(sae.encoder.weight.T)
    result.b_enc.data = _owned_copy(sae.encoder.bias)
    result.W_dec.data = _owned_copy(sae.decoder.weight.T)
    result.b_dec.data = _owned_copy(sae.decoder.bias)
    return result

In [14]:
# Pick a single run to inspect by hand.
# NOTE(review): presumably index 1 because the download log shows the first listed run
# (lunar-sweep-8) had no accessible sae.pth — confirm the ordering of sweep.runs is stable.
run = sweep.runs[1]

In [15]:
# Checkpoint for the selected run, stored as "<run name>_sae.pth" inside save_dir.
file_path = os.path.join(save_dir, f"{run.name}_sae.pth")
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary
# objects — only load checkpoints from trusted runs (or consider weights_only=True).
state_dict = torch.load(file_path, map_location="cpu")

In [16]:
# Rebuild the project autoencoder for this run and load the checkpoint into it.
sparse_ae = SparseAutoencoder(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    stddev_prior=run.config['stddev_prior'],
)
sparse_ae.load_state_dict(state_dict)

# Output path for this run's visualisation ("<run name>_vis.html" inside save_dir).
filename = os.path.join(save_dir, f"{run.name}_vis.html")

# Convert to the sae_lens SAE class that the eval and dashboard calls below take.
sae = SAEofSparseAutoencoder(sparse_ae)

In [17]:
# Build the activation store that run_evals consumes, from the model and the converted SAE.
# NOTE(review): presumably the dataset and hook come from the SAE's config
# (dataset_path / hook_name set in SAEofSparseAutoencoder) — confirm in sae_lens.
activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
    device="cpu",
)

In [18]:
# Run the sae_lens evaluation with every metric group switched on
# (KL, CE loss, L2 norms, sparsity and variance).
eval_metrics = run_evals(
    sae=sae,
    activation_store=activations_store,
    model=model,
    eval_config=EvalConfig(
        compute_kl=True,
        compute_ce_loss=True,
        compute_l2_norms=True,
        compute_sparsity_metrics=True,
        compute_variance_metrics=True
    ),
)

In [19]:
# Bare expression so the notebook displays the metrics returned by run_evals.
eval_metrics

{'metrics/kl_div_with_sae': 0.4081176817417145,
 'metrics/kl_div_with_ablation': 10.571954727172852,
 'metrics/ce_loss_with_sae': 3.4435088634490967,
 'metrics/ce_loss_without_sae': 3.0437071323394775,
 'metrics/ce_loss_with_ablation': 13.629880905151367,
 'metrics/kl_div_score': 0.9613961947176392,
 'metrics/ce_loss_score': 0.962233594527192,
 'metrics/l2_norm_in': 90.38226318359375,
 'metrics/l2_norm_out': 85.30855560302734,
 'metrics/l2_ratio': 0.942299485206604,
 'metrics/l0': 592.2353515625,
 'metrics/l1': 1976.57470703125,
 'metrics/explained_variance': 0.8029875159263611,
 'metrics/mse': 1013.3504638671875,
 'metrics/total_tokens_evaluated': 81920}

In [None]:
# Dashboard settings for the first 25 features of the selected run's SAE.
# SaeVisConfig / SaeVisRunner are the same classes the import cell also exposes
# through the `sae_vis_runner` module.
feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=[*range(25)],
    minibatch_size_features=2,
    # Author's note: this is really a number of prompts; the token count per prompt
    # is determined by the sequence length.
    minibatch_size_tokens=1024,
    verbose=False,
    device="cpu",
    # TODO: this will enable us to skip running the model for subsequent features.
    cache_dir=Path("demo_activations_cache"),
    dtype="bfloat16",
)

runner = SaeVisRunner(feature_vis_config_gpt)

# NOTE: rebinds `data`, which the dataset-loading cell used for the raw Dataset.
data = runner.run(encoder=sae, model=model, tokens=all_tokens)

In [None]:
# Guard for the full-sweep pass. This cell used to begin with a bare `return`, which is
# a SyntaxError at cell top level (so nothing in the cell could ever run). Flip this to
# True to evaluate every run in the sweep and write its dashboard.
RUN_FULL_SWEEP = False


def _evaluate_and_visualize(run):
    """Evaluate one sweep run's SAE and write its feature-centric dashboard.

    The run is skipped when its dashboard file already exists or when its
    checkpoint cannot be loaded from `save_dir`.

    Args:
        run: A run from `sweep.runs`; `run.name` and `run.config['stddev_prior']`
            are read.

    Returns:
        The metrics returned by `run_evals`, or None if the run was skipped.
    """
    filename = os.path.join(save_dir, f"{run.name}_vis.html")
    # Check for an existing dashboard first so finished runs cost nothing.
    if os.path.exists(filename):
        return None

    file_path = os.path.join(save_dir, f"{run.name}_sae.pth")
    try:
        state_dict = torch.load(file_path, map_location=torch.device('cpu'))
    except Exception as e:
        # Still best effort (e.g. runs whose checkpoint was never downloaded), but
        # the reason is reported and KeyboardInterrupt is no longer swallowed.
        print(f"Skipping run {run.name}: could not load {file_path}: {e}")
        return None

    sparse_ae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim, stddev_prior=run.config['stddev_prior'])
    sparse_ae.load_state_dict(state_dict)
    sae = SAEofSparseAutoencoder(sparse_ae)

    activations_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        n_batches_in_buffer=8,
        device="cpu",
    )

    eval_metrics = run_evals(
        sae=sae,
        activation_store=activations_store,
        model=model,
        eval_config=EvalConfig(
            compute_kl=True,
            compute_ce_loss=True,
            compute_l2_norms=True,
            compute_sparsity_metrics=True,
            compute_variance_metrics=True
        ),
    )

    # CE Loss score should be high for residual stream SAEs
    # ce loss without SAE should be fairly low < 3.5 suggesting the Model is being run correctly
    # ce loss with SAE shouldn't be massively higher
    print(eval_metrics)

    feature_vis_config_gpt = sae_vis_runner.SaeVisConfig(
        hook_point=hook_point,
        features=list(range(25)),
        minibatch_size_features=2,
        minibatch_size_tokens=1024,  # this is really prompt with the number of tokens determined by the sequence length
        verbose=False,
        device="cpu",
        cache_dir=Path(
            "demo_activations_cache"
        ),  # TODO: this will enable us to skip running the model for subsequent features.
        dtype="bfloat16",
    )

    runner = sae_vis_runner.SaeVisRunner(feature_vis_config_gpt)

    data = runner.run(
        encoder=sae,
        model=model,
        tokens=all_tokens,
    )

    save_feature_centric_vis(sae_vis_data=data, filename=filename)
    return eval_metrics


if RUN_FULL_SWEEP:
    for run in sweep.runs:
        _evaluate_and_visualize(run)

    # sae_vis version:

    # new_state_dict = {
    #     "W_enc": state_dict["encoder.weight"].T,
    #     "b_enc": state_dict["encoder.bias"],
    #     "W_dec": state_dict["decoder.weight"].T,
    #     "b_dec": state_dict["decoder.bias"],
    # }

    # d_hidden, d_in = state_dict["encoder.weight"].shape
    # cfg = AutoEncoderConfig(d_in=d_in, d_hidden=d_hidden)
    # encoder = AutoEncoder(cfg)
    # encoder.load_state_dict(new_state_dict)

    # sae_vis_config = SaeVisConfig(
    #     hook_point = "blocks.6.hook_resid_post",
    #     features = range(64),
    #     verbose = False,
    # )
    # 
    # sae_vis_data = SaeVisData.create(
    #     encoder = encoder,
    #     model = model,
    #     tokens = all_tokens,
    #     cfg = sae_vis_config,
    # )
    # 
    # sae_vis_data.save_feature_centric_vis(filename)