In [1]:
import os
import wandb
import torch
from main import SparseAutoencoder, input_dim, hidden_dim, hook_point, model
from sae_lens import SAE, SAEConfig
from sae_dashboard import sae_vis_runner
from sae_dashboard.feature_data_generator import FeatureDataGenerator
from sae_dashboard.data_writing_fns import save_feature_centric_vis
from sae_dashboard.sae_vis_runner import SaeVisConfig, SaeVisRunner
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
from transformer_lens import utils
from sae_lens import ActivationsStore, SAE, run_evals
from sae_lens.evals import EvalConfig
from pathlib import Path



Loaded pretrained model gpt2 into HookedTransformer


Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/75 [00:00<?, ?it/s]

In [2]:
# --- Weights & Biases sweep coordinates -------------------------------------
entity = "PEAR-ML"
project_name = "sae-expected-l0-sweep-norm"
sweep_id = "1ds8ouf5"

# --- Model hook / dataset settings used by the cells below -------------------
# NOTE(review): `hook_point` is also imported from `main` in the first cell;
# this assignment replaces that value -- confirm the two are meant to agree.
hook_point = "blocks.6.hook_resid_post"
dataset_path = "apollo-research/Skylion007-openwebtext-tokenizer-gpt2"

In [3]:
# NOTE(review): in the cells visible here, SEQ_LEN is only referenced by the
# disabled tokenization call below, so it has no effect on `all_tokens`.
SEQ_LEN = 128

# Load the first 2048 rows of the train split. The assert guards the code below
# against `load_dataset` returning something other than a plain `Dataset`.
data = load_dataset(dataset_path, split="train[:2048]")
assert isinstance(data, Dataset)

# Disabled tokenize + shuffle step, kept for reference. The tensor below is
# built from the dataset's existing `input_ids` column instead.
# tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
# tokenized_data = tokenized_data.shuffle(42)

# Build the token tensor straight from the `input_ids` column, in dataset order
# (no shuffling is applied since the step above is disabled).
# NOTE(review): presumably every row has the same length, since torch.tensor
# needs a rectangular nested list -- TODO confirm.
all_tokens = torch.tensor(data["input_ids"])
# assert isinstance(all_tokens, torch.Tensor)

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

In [4]:
# Client for the Weights & Biases public API; used below to look up the sweep.
# NOTE(review): presumably relies on W&B credentials already being configured
# in the environment -- confirm before running on a fresh machine.
api = wandb.Api()

In [5]:
# Look the sweep up by its "<entity>/<project>/sweeps/<sweep id>" path.
sweep = api.sweep("/".join((entity, project_name, "sweeps", sweep_id)))

In [10]:
# Download every run's `sae.pth` checkpoint from the sweep into a local folder,
# storing each one as "<run name>_sae.pth" so runs do not overwrite each other.
save_dir = f'{sweep_id}-files'
os.makedirs(save_dir, exist_ok=True)

for run in sweep.runs:
    # Final per-run location of the checkpoint.
    file_path = os.path.join(save_dir, f"{run.name}_sae.pth")

    try:
        file = run.file('sae.pth')
        file.download(root=save_dir, replace=True)
        # The download is expected at <save_dir>/sae.pth; move it to the
        # per-run name before the next iteration downloads to the same path.
        downloaded_file_path = os.path.join(save_dir, 'sae.pth')
        # os.replace overwrites an existing target on every platform, whereas
        # os.rename raises on Windows when this cell is re-run and the per-run
        # file already exists.
        os.replace(downloaded_file_path, file_path)
        print(f"Downloaded {file_path} from run {run.name}")
    except Exception as e:
        # Best effort: a run without a usable checkpoint is reported and
        # skipped so the remaining runs are still downloaded.
        print(f"Failed to download sae.pth from run {run.name}: {str(e)}")

In [11]:
def SAEofSparseAutoencoder(sae: SparseAutoencoder) -> SAE:
    """Wrap a trained ``SparseAutoencoder`` in an sae_lens ``SAE``.

    The SAE config is built from fixed GPT-2 settings plus the notebook globals
    ``hook_point`` and ``dataset_path``; ``d_sae`` and ``d_in`` are read from
    the shape of the source encoder weight.

    Args:
        sae: Source autoencoder. This function reads ``sae.encoder.weight``,
            ``sae.encoder.bias``, ``sae.decoder.weight`` and
            ``sae.decoder.bias``; the encoder weight is unpacked as
            ``(d_hidden, d_in)``.

    Returns:
        A new ``SAE`` whose ``W_enc``, ``b_enc``, ``W_dec`` and ``b_dec`` hold
        detached, contiguous copies of the source values (weights transposed).
        The source module is left untouched and shares no storage with the
        result.
    """
    d_hidden, d_in = sae.encoder.weight.shape
    conf = SAEConfig(
        architecture="standard",
        d_in=d_in,
        d_sae=d_hidden,
        activation_fn_str="relu",
        apply_b_dec_to_input=False,
        finetuning_scaling_factor=False,
        context_size=1024,  # TODO: what is this? does it matter?
        model_name="gpt2",
        hook_name=hook_point,
        # NOTE(review): hard-coded; must be kept in sync with the layer named
        # in `hook_point` by hand.
        hook_layer=6,
        hook_head_index=None,
        prepend_bos=False,
        dataset_path=dataset_path,
        dataset_trust_remote_code=False,
        normalize_activations=False,
        # NOTE(review): the copies below keep the *source* dtype, so this only
        # matches the actual parameters if the checkpoint is bfloat16 -- confirm.
        dtype="bfloat16",
        device="cpu",
        sae_lens_training_version=None,
    )
    result = SAE(conf)

    def _own_copy(tensor: torch.Tensor) -> torch.Tensor:
        # Detached, freshly allocated, contiguous copy. Without this the SAE
        # would alias the source module (the biases were the very same
        # Parameter objects, the weights strided `.T` views of its storage).
        return tensor.detach().clone(memory_format=torch.contiguous_format)

    # The source encoder weight is (d_hidden, d_in) and is stored transposed;
    # the decoder weight is transposed the same way.
    result.W_enc.data = _own_copy(sae.encoder.weight.T)
    result.b_enc.data = _own_copy(sae.encoder.bias)
    result.W_dec.data = _own_copy(sae.decoder.weight.T)
    result.b_dec.data = _own_copy(sae.decoder.bias)
    return result

In [None]:
# For each run in the sweep: rebuild its SAE from the downloaded checkpoint,
# print the sae_lens evaluation metrics, and write a feature-centric dashboard
# to "<save_dir>/<run name>_vis.html". A run is skipped when its dashboard
# already exists or its checkpoint cannot be loaded.
for run in sweep.runs:
    filename = os.path.join(save_dir, f"{run.name}_vis.html")
    if os.path.exists(filename):
        # Already rendered: skip before paying for the checkpoint load,
        # SAE construction and evaluation.
        continue

    file_path = os.path.join(save_dir, f"{run.name}_sae.pth")
    try:
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from a trusted W&B project; consider `weights_only=True` if the
        # installed torch version supports it.
        state_dict = torch.load(file_path, map_location=torch.device('cpu'))
    except Exception as e:
        # Best effort: runs whose checkpoint is missing or unreadable are
        # reported and skipped instead of being silently ignored.
        print(f"Skipping run {run.name}: could not load {file_path}: {str(e)}")
        continue

    sparse_ae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim, stddev_prior=run.config['stddev_prior'])
    sparse_ae.load_state_dict(state_dict)
    sae = SAEofSparseAutoencoder(sparse_ae)

    activations_store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        streaming=True,
        store_batch_size_prompts=8,
        n_batches_in_buffer=8,
        device="cpu",
    )

    eval_metrics = run_evals(
        sae=sae,
        activation_store=activations_store,
        model=model,
        eval_config=EvalConfig(
            compute_kl=True,
            compute_ce_loss=True,
            compute_l2_norms=True,
            compute_sparsity_metrics=True,
            compute_variance_metrics=True
        ),
    )

    # CE Loss score should be high for residual stream SAEs
    # ce loss without SAE should be fairly low < 3.5 suggesting the Model is being run correctly
    # ce loss with SAE shouldn't be massively higher
    print(run.name, eval_metrics)

    feature_vis_config_gpt = sae_vis_runner.SaeVisConfig(
        hook_point=hook_point,
        features=list(range(25)),  # only the first 25 features are rendered
        minibatch_size_features=2,
        minibatch_size_tokens=64,  # this is really prompt with the number of tokens determined by the sequence length
        verbose=False,
        device="cpu",
        cache_dir=Path(
            "demo_activations_cache"
        ),  # TODO: this will enable us to skip running the model for subsequent features.
        dtype="bfloat16",
    )

    runner = sae_vis_runner.SaeVisRunner(feature_vis_config_gpt)

    # Named `vis_data` so the loop does not clobber the global `data` (the
    # Dataset loaded in the data-loading cell).
    vis_data = runner.run(
        encoder=sae,
        model=model,
        tokens=all_tokens,
    )

    save_feature_centric_vis(sae_vis_data=vis_data, filename=filename)

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

devout-sweep-18 {'metrics/kl_div_with_sae': 0.09504439681768417, 'metrics/kl_div_with_ablation': 10.571810722351074, 'metrics/ce_loss_with_sae': 3.1370737552642822, 'metrics/ce_loss_without_sae': 3.0437216758728027, 'metrics/ce_loss_with_ablation': 13.629889488220215, 'metrics/kl_div_score': 0.9910096388108104, 'metrics/ce_loss_score': 0.9911816928423716, 'metrics/l2_norm_in': 90.38226318359375, 'metrics/l2_norm_out': 87.69502258300781, 'metrics/l2_ratio': 0.9695269465446472, 'metrics/l0': 1194.3394775390625, 'metrics/l1': 3965.449951171875, 'metrics/explained_variance': 0.9269226789474487, 'metrics/mse': 372.55401611328125, 'metrics/total_tokens_evaluated': 81920}


stoic-sweep-17 {'metrics/kl_div_with_sae': 0.09539174288511276, 'metrics/kl_div_with_ablation': 10.571810722351074, 'metrics/ce_loss_with_sae': 3.136850118637085, 'metrics/ce_loss_without_sae': 3.0437216758728027, 'metrics/ce_loss_with_ablation': 13.629889488220215, 'metrics/kl_div_score': 0.9909767829380983, 'metrics/ce_loss_score': 0.9912028182043685, 'metrics/l2_norm_in': 90.38226318359375, 'metrics/l2_norm_out': 87.60328674316406, 'metrics/l2_ratio': 0.9684170484542847, 'metrics/l0': 1235.131591796875, 'metrics/l1': 4039.60107421875, 'metrics/explained_variance': 0.929223895072937, 'metrics/mse': 360.93505859375, 'metrics/total_tokens_evaluated': 81920}


lilac-sweep-16 {'metrics/kl_div_with_sae': 0.0849115327000618, 'metrics/kl_div_with_ablation': 10.571810722351074, 'metrics/ce_loss_with_sae': 3.12406325340271, 'metrics/ce_loss_without_sae': 3.0437216758728027, 'metrics/ce_loss_with_ablation': 13.629889488220215, 'metrics/kl_div_score': 0.9919681183356281, 'metrics/ce_loss_score': 0.9924107024417089, 'metrics/l2_norm_in': 90.38226318359375, 'metrics/l2_norm_out': 88.2266845703125, 'metrics/l2_ratio': 0.9754922389984131, 'metrics/l0': 1132.5970458984375, 'metrics/l1': 3689.04296875, 'metrics/explained_variance': 0.9344365000724792, 'metrics/mse': 333.1988525390625, 'metrics/total_tokens_evaluated': 81920}


revived-sweep-15 {'metrics/kl_div_with_sae': 0.09201941639184952, 'metrics/kl_div_with_ablation': 10.571810722351074, 'metrics/ce_loss_with_sae': 3.1274807453155518, 'metrics/ce_loss_without_sae': 3.0437216758728027, 'metrics/ce_loss_with_ablation': 13.629889488220215, 'metrics/kl_div_score': 0.9912957752641843, 'metrics/ce_loss_score': 0.9920878762808715, 'metrics/l2_norm_in': 90.38226318359375, 'metrics/l2_norm_out': 87.69331359863281, 'metrics/l2_ratio': 0.969398021697998, 'metrics/l0': 1129.5069580078125, 'metrics/l1': 3590.392578125, 'metrics/explained_variance': 0.9258410930633545, 'metrics/mse': 375.0377197265625, 'metrics/total_tokens_evaluated': 81920}


absurd-sweep-14 {'metrics/kl_div_with_sae': 0.08990734070539474, 'metrics/kl_div_with_ablation': 10.571810722351074, 'metrics/ce_loss_with_sae': 3.127901554107666, 'metrics/ce_loss_without_sae': 3.0437216758728027, 'metrics/ce_loss_with_ablation': 13.629889488220215, 'metrics/kl_div_score': 0.9914955589854336, 'metrics/ce_loss_score': 0.9920481254664527, 'metrics/l2_norm_in': 90.38226318359375, 'metrics/l2_norm_out': 87.85115051269531, 'metrics/l2_ratio': 0.9712569117546082, 'metrics/l0': 1238.2529296875, 'metrics/l1': 4042.947021484375, 'metrics/explained_variance': 0.9323223829269409, 'metrics/mse': 343.72265625, 'metrics/total_tokens_evaluated': 81920}


summer-sweep-13 {'metrics/kl_div_with_sae': 0.08768874406814575, 'metrics/kl_div_with_ablation': 10.571810722351074, 'metrics/ce_loss_with_sae': 3.1241612434387207, 'metrics/ce_loss_without_sae': 3.0437216758728027, 'metrics/ce_loss_with_ablation': 13.629889488220215, 'metrics/kl_div_score': 0.9917054186485998, 'metrics/ce_loss_score': 0.9924014460197679, 'metrics/l2_norm_in': 90.38226318359375, 'metrics/l2_norm_out': 88.38250732421875, 'metrics/l2_ratio': 0.9772909879684448, 'metrics/l0': 1197.3472900390625, 'metrics/l1': 3912.5673828125, 'metrics/explained_variance': 0.9319438934326172, 'metrics/mse': 345.9559020996094, 'metrics/total_tokens_evaluated': 81920}
