In [1]:
import sys
sys.path.append("/root/specialised-SAEs")
from datasets import load_dataset
from transformer_lens import utils, HookedTransformer
import gc
import torch
from sae_lens.jacob.load_sae_from_hf import load_sae_from_hf
from config import DTYPE_MAP
from tqdm import tqdm

In [2]:
# Inference-only notebook: turn autograd off globally.
torch.set_grad_enabled(False)

# DTYPE is the precision name handed to both the model and the SAE loaders;
# ctx_length is the max_length used when tokenising the datasets.
DTYPE, ctx_length = "bfloat16", 128

model = HookedTransformer.from_pretrained_no_processing(
    "gemma-2b-it",
    device="cuda",
    dtype=DTYPE,
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model gemma-2b-it into HookedTransformer


In [3]:
# General-purpose SAE: a single weights/cfg pair in its own repo.
gsae = load_sae_from_hf(
    "jacobcd52/gemma2-gsae",
    "sae_weights.safetensors",
    "cfg.json",
    device="cuda",
    dtype=DTYPE,
)

# Specialised (physics) SAEs: one per L1 coefficient. The weights file and its cfg
# share a filename stem and differ only in the suffix.
ssae_list = [
    load_sae_from_hf(
        "jacobcd52/gemma2-ssae-phys",
        f"{stem}.safetensors",
        f"{stem}_cfg.json",
        device="cuda",
        dtype=DTYPE,
    )
    for stem in (
        f"l1_coeff={coeff}_tokens=40960000_lr=0.001" for coeff in (5, 10, 20)
    )
]

Downloading weights from Hugging Face Hub


sae_weights.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

GSAE weights file saved as temp_sae/sae_weights.safetensors
Downloading cfg from Hugging Face Hub


cfg.json:   0%|          | 0.00/620 [00:00<?, ?B/s]

GSAE cfg file saved as temp_sae/cfg.json
Loading weights into GSAE from temp_sae/sae_weights.safetensors
temp_sae/cfg.json temp_sae/sae_weights.safetensors
Downloading weights from Hugging Face Hub


(…)f=5_tokens=40960000_lr=0.001.safetensors:   0%|          | 0.00/302M [00:00<?, ?B/s]

GSAE weights file saved as temp_sae/sae_weights.safetensors
Downloading cfg from Hugging Face Hub


(…)oeff=5_tokens=40960000_lr=0.001_cfg.json:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

GSAE cfg file saved as temp_sae/cfg.json
Loading weights into GSAE from temp_sae/sae_weights.safetensors
temp_sae/cfg.json temp_sae/sae_weights.safetensors
Downloading weights from Hugging Face Hub


(…)=10_tokens=40960000_lr=0.001.safetensors:   0%|          | 0.00/302M [00:00<?, ?B/s]

GSAE weights file saved as temp_sae/sae_weights.safetensors
Downloading cfg from Hugging Face Hub


(…)eff=10_tokens=40960000_lr=0.001_cfg.json:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

GSAE cfg file saved as temp_sae/cfg.json
Loading weights into GSAE from temp_sae/sae_weights.safetensors
temp_sae/cfg.json temp_sae/sae_weights.safetensors
Downloading weights from Hugging Face Hub


(…)=20_tokens=40960000_lr=0.001.safetensors:   0%|          | 0.00/302M [00:00<?, ?B/s]

GSAE weights file saved as temp_sae/sae_weights.safetensors
Downloading cfg from Hugging Face Hub


(…)eff=20_tokens=40960000_lr=0.001_cfg.json:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

GSAE cfg file saved as temp_sae/cfg.json
Loading weights into GSAE from temp_sae/sae_weights.safetensors
temp_sae/cfg.json temp_sae/sae_weights.safetensors


In [4]:
# OpenWebText sample: tokenise into rows of ctx_length tokens, shuffle with
# .shuffle(42), and move the first 20k rows to the GPU.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=ctx_length,
).shuffle(42)
owt_tokens = tokenized_data["tokens"][:20_000].to("cuda")
print("owt_tokens has shape", owt_tokens.shape)
# Floor-divides by 1e6, so the reported figure is truncated to whole millions.
print("total number of tokens:", int(owt_tokens.numel()//1e6), "million")
print()

# Physics-papers corpus: only the first 10% of the train split is requested.
data = load_dataset(
    "jacobcd52/physics-papers",
    split="train[:10%]",
)
def remove_null_entries(example):
    """Row predicate: True only if no field of ``example`` is None or ''.

    ``example`` is a mapping of column name -> value. An example with no
    fields at all passes.
    """
    for field_value in example.values():
        if field_value is None or field_value == '':
            return False
    return True
# Keep only the rows accepted by remove_null_entries, then tokenise into rows of
# ctx_length tokens, shuffle with .shuffle(42), and move the first 20k rows to the GPU.
data = data.filter(remove_null_entries)
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=ctx_length,
).shuffle(42)
phys_tokens = tokenized_data["tokens"][:20_000].to("cuda")
print("phys_tokens has shape", phys_tokens.shape)
# Floor-divides by 1e6, so the reported figure is truncated to whole millions.
print("total number of tokens:", int(phys_tokens.numel()//1e6), "million")

# Drop the references to the dataset objects and run a garbage-collection pass.
del data
del tokenized_data
gc.collect()

owt_tokens has shape torch.Size([20000, 128])
total number of tokens: 2 million

phys_tokens has shape torch.Size([20000, 128])
total number of tokens: 2 million


0

In [20]:
from einops import rearrange
def get_freqs_and_mean_acts(sae_list, tokens, num_tokens=10_000, batch_size=16):
    """Estimate per-latent firing frequency and mean activation for several SAEs.

    The model is run once per batch and the activations cached at
    ``sae_list[0].cfg.hook_name`` are fed to every SAE, so all SAEs in
    ``sae_list`` are expected to read from that same hook point. The first
    position of every sequence is excluded from the statistics.

    Relies on the notebook-level globals ``model``, ``DTYPE`` and ``DTYPE_MAP``.

    Args:
        sae_list: SAEs exposing ``cfg.hook_name``, ``cfg.d_sae`` and ``encode_fn``.
        tokens: 2-D tensor of token ids, ``(n_sequences, seq_len)``.
        num_tokens: token budget; ``num_tokens // (seq_len * batch_size)`` full
            batches are processed.
        batch_size: number of sequences per forward pass.

    Returns:
        ``(freqs, mean_acts)``: two lists aligned with ``sae_list``. Each entry
        is a CPU tensor of length ``d_sae``; ``freqs[i]`` is the fraction of
        processed positions on which each latent was > 0, ``mean_acts[i]`` the
        mean (post-ReLU) activation, returned in ``DTYPE_MAP[DTYPE]``.

    Raises:
        ValueError: if the budget or ``tokens`` is too small to process a
            single position (previously this silently produced NaN/inf).
    """
    hook_pt = sae_list[0].cfg.hook_name
    # Use the real sequence length of `tokens` rather than a global constant.
    seq_len = tokens.shape[1]
    num_batches = num_tokens // (seq_len * batch_size)
    if num_batches == 0:
        raise ValueError(
            f"num_tokens={num_tokens} is smaller than one batch "
            f"({seq_len} x {batch_size} tokens); nothing to process"
        )

    counts = [torch.zeros(sae.cfg.d_sae, dtype=torch.int32, device="cuda")
              for sae in sae_list]
    # Accumulate in float32: a running sum kept in bfloat16 stops absorbing
    # small increments once it grows, which biases the means downwards.
    total_acts = [torch.zeros(sae.cfg.d_sae, dtype=torch.float32, device="cuda")
                  for sae in sae_list]

    positions_seen = 0
    for b in tqdm(range(num_batches)):
        batch = tokens[b*batch_size:(b+1)*batch_size]
        if batch.shape[0] == 0:
            # `tokens` ran out before the requested budget was reached.
            break
        # Only the cached activations are needed, so skip the logits/loss computation.
        _, cache = model.run_with_cache(
            batch,
            return_type=None,
            names_filter=[hook_pt],
            prepend_bos=False
        )
        # Drop position 0 of every sequence, then merge batch and sequence axes.
        flat_acts = rearrange(cache[hook_pt][:, 1:, :], "b s d -> (b s) d")
        positions_seen += flat_acts.shape[0]
        for i, sae in enumerate(sae_list):
            feature_acts = torch.relu(sae.encode_fn(flat_acts))
            counts[i] += (feature_acts > 0).sum(dim=0, dtype=torch.int32)
            total_acts[i] += feature_acts.sum(dim=0, dtype=torch.float32)

        del cache
        gc.collect()

    if positions_seen == 0:
        raise ValueError("no token positions were processed; `tokens` is empty or too short")

    # Normalise by the number of positions actually processed.
    freqs = [count.cpu().detach() / positions_seen for count in counts]
    # Cast back so callers receive the same dtype as before.
    mean_acts = [(total_act.cpu().detach() / positions_seen).to(DTYPE_MAP[DTYPE])
                 for total_act in total_acts]

    return freqs, mean_acts

In [21]:
# Index 0 is the general SAE; the specialised SAEs follow in ssae_list order.
all_sae_list = [gsae, *ssae_list]
freqs, mean_acts = get_freqs_and_mean_acts(
    all_sae_list,
    owt_tokens,
    num_tokens=40_000,
    batch_size=32,
)

  0%|          | 0/9 [00:00<?, ?it/s]

100%|██████████| 9/9 [00:03<00:00,  2.94it/s]


In [16]:
# Per-latent mean activation scaled by the squared norm of W_dec along its last axis.
# NOTE(review): despite the name, `freqs` does not enter this product - confirm intended.
weighted_freqs = [
    mean_act.cuda() * sae.W_dec.norm(dim=-1) ** 2
    for sae, mean_act in zip(all_sae_list, mean_acts)
]

In [22]:
# Sum of per-latent firing frequencies for each SAE, i.e. the average number of
# latents active per token position; one value per line.
print(*(freq.sum() for freq in freqs), sep="\n")

tensor(21.1346)
tensor(22.1796)
tensor(11.9037)
tensor(1.7342)


In [31]:
import numpy as np
import plotly.express as px
# Histogram of log10 firing frequency for the first two entries of `freqs`.
for i in [0, 1]:
    sae_freqs = freqs[i].to(torch.float32).cpu()
    fired = sae_freqs > 0
    # log10(0) is -inf, so latents that never fired are kept out of the plotted
    # data and reported in the title instead.
    log_freqs = sae_freqs[fired].log10().numpy()
    num_silent = int((~fired).sum())
    px.histogram(
        x=log_freqs,
        title=f"SAE {i}: log10 firing frequency ({num_silent} latents never fired)",
        labels={"x": "log10(firing frequency)"},
    ).show()

In [14]:
# Quick sanity check of the loaded model: show its top next-token predictions for
# the prompt and how it ranks the supplied answer tokens.
# NOTE(review): "nlah" looks like a placeholder answer - confirm it is intentional.
utils.test_prompt("What is 2+2?", "nlah",  model)

Tokenized prompt: ['<bos>', 'What', ' is', ' ', '2', '+', '2', '?']
Tokenized answer: [' n', 'lah']


Top 0th token. Logit:  9.94 Prob: 89.84% Token: |

|
Top 1th token. Logit:  7.59 Prob:  8.59% Token: | |
Top 2th token. Logit:  5.75 Prob:  1.36% Token: |
|
Top 3th token. Logit:  3.14 Prob:  0.10% Token: |


|
Top 4th token. Logit:  2.56 Prob:  0.06% Token: |  |
Top 5th token. Logit:  2.53 Prob:  0.05% Token: | Is|
Top 6th token. Logit:  1.93 Prob:  0.03% Token: | How|
Top 7th token. Logit:  1.80 Prob:  0.03% Token: | It|
Top 8th token. Logit:  0.76 Prob:  0.01% Token: | The|
Top 9th token. Logit:  0.14 Prob:  0.01% Token: | I|


Top 0th token. Logit:  4.47 Prob: 37.50% Token: |=|
Top 1th token. Logit:  3.45 Prob: 13.57% Token: |+|
Top 2th token. Logit:  3.36 Prob: 12.40% Token: |?|
Top 3th token. Logit:  2.97 Prob:  8.40% Token: |

|
Top 4th token. Logit:  2.09 Prob:  3.49% Token: | |
Top 5th token. Logit:  1.95 Prob:  3.02% Token: | the|
Top 6th token. Logit:  1.84 Prob:  2.72% Token: |-|
Top 7th token. Logit:  1.52 Prob:  1.98% Token: | =|
Top 8th token. Logit:  1.38 Prob:  1.71% Token: | is|
Top 9th token. Logit:  1.28 Prob:  1.55% Token: |apping|
