# Set Up

In [None]:
# !huggingface-cli login

import torch
import os
import sys
import yaml

# Pick the accelerator when one is visible, otherwise fall back to the CPU.
# torch device type strings are lowercase; "CPU" is rejected by torch.device().
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sparse Autoencoder

In [None]:
# Environment switches that must be in place before the training libraries load.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB__SERVICE_WAIT": "300",
    }
)

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

In [None]:
# Training configuration for a single SAE on the layer-18 residual stream.
# Reference: https://jbloomaus.github.io/SAELens/training_saes/
cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name = "qwen1.5-0.5b",
    hook_point = "blocks.18.hook_resid_pre",
    hook_point_layer = 18,  # must agree with the block index in hook_point
    d_in = 1024, # https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,

    # SAE Parameters
    expansion_factor = 64,
    b_dec_init_method = "geometric_median",

    # Training Parameters
    lr = 0.0004,
    l1_coefficient = 0.00008,
    lr_scheduler_name="constant", # constantwithwarmup not supported?
    train_batch_size = 4096,
    context_size = 128,
    feature_sampling_window = 1000,
    dead_feature_window=5000,
    dead_feature_threshold = 1e-6,

    # WANDB
    log_to_wandb = True,
    wandb_project= "mats_sae_training_qwen",
    wandb_entity = None,
    wandb_log_frequency=100,

    # Misc
    # Reuse the device chosen in the set-up cell instead of hard-coding "cuda".
    device = device,
    seed = 42,
    n_checkpoints = 10,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
    )

# Launch training with the configuration above and keep the runner's result.
sparse_autoencoder = language_model_sae_runner(cfg)


## Upload to HuggingFace


In [None]:
from huggingface_hub import HfApi

# Push one local checkpoint folder to the model repo on the Hugging Face Hub,
# stored under a folder of the same name.
uuid_str = "pqs59n3e"
repo_id = "kcoopermiller/qwen1.5-0.5b-saes"
local_folder = f"checkpoints/{uuid_str}"
hf_folder = uuid_str

api = HfApi()
upload_kwargs = dict(
    folder_path=local_folder,
    path_in_repo=hf_folder,
    repo_id=repo_id,
    repo_type="model",
)
api.upload_folder(**upload_kwargs)

## Evaluating the SAEs

In [None]:
import json
import plotly.express as px
from transformer_lens import utils
from datasets import load_dataset
from typing import Dict
from pathlib import Path
from huggingface_hub import hf_hub_download
from functools import partial
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from sae_vis.data_fetching_fns import get_feature_data, FeatureData
from sae_vis.data_config_classes import SaeVisConfig
# Evaluation only from here on: switch autograd off globally.
torch.set_grad_enabled(mode=False)
# Make the parent directory importable as well.
sys.path += [".."]

In [None]:
REPO_ID = "kcoopermiller/qwen1.5-0.5b-saes"

# Checkpoint folder inside the Hub repo for each layer that has an uploaded SAE.
CHECKPOINT_BY_LAYER = {2: "crujwafo", 18: "pqs59n3e"}

layer = 2 # 2 or 18
if layer not in CHECKPOINT_BY_LAYER:
    # Fail early with a clear message instead of silently requesting a file
    # that does not exist in the repo.
    raise ValueError(
        f"No SAE checkpoint for layer {layer}; choose one of {sorted(CHECKPOINT_BY_LAYER)}"
    )
checkpoint = CHECKPOINT_BY_LAYER[layer]
FILENAME = f"{checkpoint}/final_sae_group_qwen1.5-0.5b_blocks.{layer}.hook_resid_pre_65536.pt"

# Download the checkpoint file (or reuse the local Hub cache) and get its local path.
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# Restore the model, the SAE group and the activation store from the checkpoint.
model, sparse_autoencoders, activation_store = (
    LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path)
)
sparse_autoencoders.eval()
# The rest of the notebook works with the first SAE of the group.
sparse_autoencoder = list(sparse_autoencoders)[0]

### L0 Test and Reconstruction Test

In [None]:
# Switch the SAE to eval mode first; per the original note this prevents an
# error when the forward pass would otherwise expect a dead-neuron mask.
sparse_autoencoder.eval()
with torch.no_grad():
    # One batch of tokens from the activation store, run through the model
    # while caching activations.
    batch_tokens = activation_store.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Feed the activation at the SAE's hook point through the SAE.
    hooked_acts = cache[sparse_autoencoder.cfg.hook_point]
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(hooked_acts)
    del hooked_acts, cache

    # L0 per token: how many features are active at each position, skipping
    # the BOS position. The mean is taken across batch and position.
    l0 = feature_acts[:, 1:].gt(0).float().sum(dim=-1).detach()
    print("average l0", l0.mean().item())

    l0_hist = px.histogram(l0.flatten().cpu().numpy())
    l0_hist.show()

# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that swaps the incoming activation for ``sae_out``.

    ``activation`` and ``hook`` are accepted to match the hook call signature
    but are not used; ``sae_out`` is returned unchanged.
    """
    replacement = sae_out
    return replacement


def zero_abl_hook(activation, hook):
    """Forward hook that zero-ablates the activation.

    Returns a new all-zero tensor with the same shape, dtype and device as
    ``activation``; ``hook`` is not used.
    """
    zeroed = torch.zeros_like(activation)
    return zeroed


# Compare the model's loss on the same batch in three settings: untouched,
# with the hooked activation replaced by the SAE reconstruction, and with the
# hooked activation zeroed out.
#
# The hooks must be attached at the hook point the SAE was fed from: sae_out
# was computed from cache[sparse_autoencoder.cfg.hook_point], so splicing it
# in at any other layer would not measure reconstruction quality.
sae_hook_point = sparse_autoencoder.cfg.hook_point

print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae_hook_point,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae_hook_point, zero_abl_hook)],
    ).item(),
)

### Specific Capability Test

In [None]:
# A single prompt/answer pair used to probe the model.
example_prompt = "What is a common dog name?"
example_answer = "A common dog name is Max."
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Tokenize the prompt, then rerun it with caching and pass the activation at
# the SAE's hook point through the SAE.
tokens = model.to_tokens(example_prompt)
logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
prompt_acts = cache[sparse_autoencoder.cfg.hook_point]
sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(prompt_acts)

### Generating Feature Interfaces


In [None]:
# Ten largest feature activations at the last position of the first sequence.
vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)

# Iterating a tensor yields 0-d tensors whose str() is "tensor(123)", so
# convert to plain Python numbers to get readable feature indices on the axis.
px.bar(
    x=[str(i) for i in inds.tolist()],
    y=vals.tolist(),
    labels={"x": "feature index", "y": "activation"},
).show()

# Invert the tokenizer vocabulary into {token id: token text}, replacing the
# "Ġ" marker with a space and showing newlines as a literal "\n".
vocab_dict = {
    token_id: token.replace("Ġ", " ").replace("\n", "\\n")
    for token, token_id in model.tokenizer.vocab.items()
}

# Write the mapping next to the notebook once; an existing file is left alone.
vocab_dict_filepath = Path.cwd() / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with vocab_dict_filepath.open("w") as f:
        json.dump(vocab_dict, f)


os.environ["TOKENIZERS_PARALLELISM"] = "false"

# currently use this dataset to avoid deal with tokenization while streaming
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenize into fixed-length rows of 128 tokens, then shuffle with seed 42.
tokenized_data = utils.tokenize_and_concatenate(
    data, model.tokenizer, max_length=128
).shuffle(seed=42)
all_tokens = tokenized_data["tokens"]


# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

# Number of token rows to draw from the dataset, and the features to visualize
# (the top-k feature indices found for the example prompt).
total_batch_size = 5 * 4096
feature_idx = list(inds.cpu().numpy().ravel())
# Alternative settings kept for reference:
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

vis_kwargs = dict(
    hook_point=sparse_autoencoder.cfg.hook_point,
    features=feature_idx,
    minibatch_size_features=256,
    minibatch_size_tokens=64,
    verbose=True,
)
feature_vis_params = SaeVisConfig(**vis_kwargs)



# Restrict the visualization data to the first `total_batch_size` token rows.
tokens = all_tokens[:total_batch_size]

# The returned object is used below through `.model` and
# `.save_feature_centric_vis(...)`, so it is not a plain dict; the previous
# `Dict[int, FeatureData]` annotation was misleading and has been dropped.
feature_data = get_feature_data(
    encoder=sparse_autoencoder,
    model=model,
    tokens=tokens,
    cfg=feature_vis_params,
)

feature_data.model = model

# Write one feature-centric HTML page per feature. Iterating `feature_idx`
# guarantees the pages match the features the data was computed for.
for test_idx in feature_idx:
    feature_data.save_feature_centric_vis(
        f"data_{test_idx:04}.html",
        feature_idx=test_idx
    )