## imports & definitions

In [1]:
!pip install git+https://github.com/jacobcd52/sae_vis.git --quiet

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
# Notebook-wide settings.
ctx_length = 128    # passed as `max_length` when tokenising each dataset
DTYPE = "bfloat16"  # dtype string handed to both the model loader and the SAE loader

In [3]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically.
%load_ext autoreload
%autoreload 2
import os
import sys
# Extend the import path before the sae_lens imports below.
# NOTE(review): hard-coded absolute path; presumably this checkout provides the
# customised sae_lens (including `sae_lens.jacob`) — confirm, and consider
# making the location configurable.
sys.path.append("/root/specialised-SAEs/")
from huggingface_hub import hf_hub_download
from sae_lens.sae import SAE
from sae_lens.training.training_sae import TrainingSAE
import sae_vis
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae_training_runner import SAETrainingRunner
from sae_lens.jacob.load_sae_from_hf import load_sae_from_hf
from tqdm import tqdm
import gc
import torch
# Switch autograd off globally; none of the cells visible in this notebook
# trains or backpropagates. As the last expression of the cell, the returned
# object is echoed as cell output.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f7109411de0>

In [4]:
# callum imports
from IPython import get_ipython # type: ignore
# Grab the active IPython shell and stop immediately if there is none
# (get_ipython() returns None outside an IPython/Jupyter session).
ipython = get_ipython(); assert ipython is not None

# Standard imports
import torch
from datasets import load_dataset
import webbrowser
import os
from transformer_lens import utils, HookedTransformer
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
import time

# Library imports
from sae_vis.utils_fns import get_device
from sae_vis.model_fns import AutoEncoder
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig

## Load model, token datasets, and SAE

In [5]:
# Load "gemma-2b-it" into a HookedTransformer on the GPU, using the
# notebook-wide DTYPE for the weights.
model = HookedTransformer.from_pretrained_no_processing(
    "gemma-2b-it",
    dtype=DTYPE,
    device="cuda",
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b-it into HookedTransformer


In [6]:
# get OWT tokens
# Tokenise the train split of stas/openwebtext-10k with max_length=ctx_length,
# shuffle the result with seed 42, and move the first 20,000 rows to the GPU.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=ctx_length,
).shuffle(42)
owt_tokens = tokenized_data["tokens"][:20_000].cuda()
print("owt_tokens has shape", owt_tokens.shape)
# Whole millions of tokens, rounded down.
print("total number of tokens:", owt_tokens.numel() // 1_000_000, "million")
print()

# get physics-papers tokens
# The split string restricts the load to the first 10% of the train split.
data = load_dataset(
    "jacobcd52/physics-papers",
    split="train[:10%]",
)
# Filter predicate used to drop rows with missing fields
def remove_null_entries(example):
    """Return True when no field of `example` is None or the empty string.

    Args:
        example: mapping of field name to value (one dataset row).

    Returns:
        bool: False as soon as a value is None or '', True otherwise
        (including for an empty mapping).
    """
    for field_value in example.values():
        if field_value is None or field_value == '':
            return False
    return True
# Drop rows with null/empty fields, then tokenise, shuffle (seed 42) and keep
# the first 20,000 rows on the GPU — the same recipe used for the OWT tokens.
data = data.filter(remove_null_entries)
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=ctx_length,
).shuffle(42)
phys_tokens = tokenized_data["tokens"][:20_000].cuda()
print("phys_tokens has shape", phys_tokens.shape)
# Whole millions of tokens, rounded down.
print("total number of tokens:", phys_tokens.numel() // 1_000_000, "million")

# clean up: drop the dataset handles (owt_tokens / phys_tokens are kept),
# then run a garbage-collection pass; its return value is shown as cell output.
del data
del tokenized_data
gc.collect()

owt_tokens has shape torch.Size([20000, 128])
total number of tokens: 2 million

phys_tokens has shape torch.Size([20000, 128])
total number of tokens: 2 million


0

In [7]:
# Load the SAE from the "jacobcd52/gemma2-ssae-phys" Hugging Face repo.
# The weights file and the config file share one stem, so it is spelled once
# and the two file names are derived from it (previously both were f-strings
# with no placeholders).
sae_run_name = "l1_coeff=30_tokens=40960000_lr=0.001"
sae = load_sae_from_hf(
    "jacobcd52/gemma2-ssae-phys",
    f"{sae_run_name}.safetensors",
    f"{sae_run_name}_cfg.json",
    device="cuda",
    dtype=DTYPE,
)

Downloading weights from Hugging Face Hub


(…)=30_tokens=40960000_lr=0.001.safetensors:   0%|          | 0.00/302M [00:00<?, ?B/s]

GSAE weights file saved as temp_sae/sae_weights.safetensors
Downloading cfg from Hugging Face Hub


(…)eff=30_tokens=40960000_lr=0.001_cfg.json:   0%|          | 0.00/2.57k [00:00<?, ?B/s]

GSAE cfg file saved as temp_sae/cfg.json
Loading weights into GSAE from temp_sae/sae_weights.safetensors
temp_sae/cfg.json temp_sae/sae_weights.safetensors


# Generate feature dashboards

In [1]:
# features with freq > 1e-5 on phys data, and freq_phys/freq_owt > 100

finetune_phys_ids = [749, 1274, 2346, 2353, 4585, 6447, 8114, 9646, 11407, 11540, 12451, 14503, 14684, 15171, 15246, 15314, 15344, 15536, 16238, 16261, 16545, 16570, 20752, 20916, 22401, 23048, 26239, 29236, 30120, 31045, 31198, 31205, 31530, 31834, 32261]
print(len(finetune_phys_ids))

gsae_phys_ids = [1274, 1436, 2346, 3212, 4384, 4585, 8114, 10329, 11407, 11540, 12360, 12451, 14503, 14684, 15144, 15246, 15314, 15344, 15536, 16238, 18282, 20752, 21016, 21162, 23048, 23979, 26239, 27357, 27514, 28035, 29218, 29236, 31045, 31198, 31205, 31729, 32084, 32261]
print(len(gsae_phys_ids))

# Ids from the first 20 entries of finetune_phys_ids that do not appear in
# gsae_phys_ids, in their original order.
# NOTE(review): only finetune_phys_ids[:20] is compared — confirm that the
# truncation is intentional.
# A set gives constant-time membership tests; the loop variable no longer
# shadows the builtin `id`.
gsae_phys_id_set = set(gsae_phys_ids)
new_ids = [
    feature_id
    for feature_id in finetune_phys_ids[:20]
    if feature_id not in gsae_phys_id_set
]
print(len(new_ids))

35
38
6


In [8]:
# Free cached GPU memory and collect garbage before building the dashboards.
torch.cuda.empty_cache()
import gc
gc.collect()
torch.cuda.empty_cache()

# Dashboard configuration: visualise features 0..49 at the SAE's own hook point.
feature_vis_config_gpt = SaeVisConfig(
    hook_point=sae.cfg.hook_name,
    features=list(range(50)),
    batch_size=4096,
    verbose=True,
)

# Build the visualisation data from the physics-papers tokens.
sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=model,
    tokens=phys_tokens,  # type: ignore
    cfg=feature_vis_config_gpt,
)

# Write the feature-centric dashboard to an HTML file in the working directory.
filename = "phys_features.html"
sae_vis_data_gpt.save_feature_centric_vis(filename)


Forward passes to cache data for vis:   0%|          | 0/64 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/50 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/50 [00:00<?, ?it/s]