- Plot frequency vs. rank on OWT (OpenWebText)
- Plot frequency histograms for:
    - GSAE-phys
    - SSAE-phys-widegsae
    - SSAE-phys-narrowgsae

## Setup

In [None]:
import sys
sys.path.append("/root/specialised-SAEs")
from datasets import load_dataset
from transformer_lens import utils, HookedTransformer
import gc
import torch
from sae_lens.jacob.load_sae_from_hf import load_sae_from_hf
from config import DTYPE_MAP
from tqdm import tqdm

In [None]:
# Precision shared by the model and every SAE loaded below.
DTYPE = "float32"

# This notebook only runs forward passes, so autograd is switched off globally.
torch.set_grad_enabled(False)

# GPT-2 small via TransformerLens, placed on the GPU at DTYPE precision.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    dtype=DTYPE,
    device="cuda",
)

In [None]:
# L1 coefficients of the SSAE / GSAE-phys sweeps; one checkpoint is loaded per value.
L1_COEFFS = [2, 3, 4, 5, 6]


def _load_sae(repo_id, stem):
    """Load one SAE from HF repo `repo_id` onto the GPU at DTYPE precision.

    The checkpoint's weights and config are expected at
    `<stem>.safetensors` and `<stem>_cfg.json` respectively.
    """
    return load_sae_from_hf(repo_id,
                            f"{stem}.safetensors",
                            f"{stem}_cfg.json",
                            device="cuda",
                            dtype=DTYPE)


# General SAEs (GSAEs) at expansion factors 64 ("wide") and 16 ("narrow").
gsae_64 = _load_sae("jacobcd52/gpt2-gsae", "expansion=64")
gsae_16 = _load_sae("jacobcd52/gpt2-gsae", "expansion=16")

# Physics SSAEs from the "widegsae" repo, one per L1 coefficient.
# Assumption from naming: these pair with gsae_64.
ssae_64_list = [
    _load_sae("jacobcd52/gpt2-ssae-phys-widegsae",
              f"l1_coeff={l1_coeff}_expansion=2_control=0.0")
    for l1_coeff in L1_COEFFS
]

# Physics SSAEs from the "narrowgsae" repo, one per L1 coefficient.
# Assumption from naming: these pair with gsae_16.
# Previously this list was also bound to `ssae_64_list`, which clobbered the
# wide-GSAE list above; it now gets its own name.
ssae_16_list = [
    _load_sae("jacobcd52/gpt2-ssae-phys-narrowgsae",
              f"l1_coeff={l1_coeff}_expansion=2_control=0.0")
    for l1_coeff in L1_COEFFS
]

# SAEs from the "gpt2-gsae-phys" repo (the "GSAE-phys" baseline in the plan
# above), one per L1 coefficient.
ssae_0_list = [
    _load_sae("jacobcd52/gpt2-gsae-phys",
              f"l1_coeff={l1_coeff}_expansion=2")
    for l1_coeff in L1_COEFFS
]

In [None]:
def _tokenize_and_sample(dataset, n_seqs=20_000):
    """Tokenize `dataset` into 256-token rows, shuffle them, and return the first `n_seqs` rows.

    Uses TransformerLens's `tokenize_and_concatenate` with `model.tokenizer`.
    The shuffle uses seed 42, and the returned `tokens` column is moved to the GPU.
    """
    rows = utils.tokenize_and_concatenate(dataset, model.tokenizer, max_length=256)
    rows = rows.shuffle(42)
    return rows["tokens"][:n_seqs].cuda()


def remove_null_entries(example):
    """Return True only if no field of `example` is None or the empty string."""
    return not any(value is None or value == '' for value in example.values())


# OpenWebText sample.
owt_tokens = _tokenize_and_sample(load_dataset("stas/openwebtext-10k", split="train"))
print("owt_tokens has shape", owt_tokens.shape)
print("total number of tokens:", int(owt_tokens.numel()//1e6), "million")
print()

# Physics-papers sample: first 10% of the train split, with incomplete rows dropped.
phys_raw = load_dataset("jacobcd52/physics-papers", split="train[:10%]")
phys_raw = phys_raw.filter(remove_null_entries)
phys_tokens = _tokenize_and_sample(phys_raw)
print("phys_tokens has shape", phys_tokens.shape)
print("total number of tokens:", int(phys_tokens.numel()//1e6), "million")

# Drop the raw dataset now that only the GPU token tensors are needed.
del phys_raw
gc.collect()

## Experiments

In [None]:
def get_freqs()