In [1]:
from ActivationStoreParallel import ActivationsStore
from sparse_transcoder import SparseTranscoder
from transcoder_training_parallel import train_transcoder_on_language_model_parallel
from transcoder_runner_parallel import language_model_transcoder_runner_parallel
from dataclasses import dataclass
import transformer_lens
import torch
import wandb
from typing import Optional

In [4]:
@dataclass
class UnifiedConfig():
    # Common settings
    model_name: str = "gpt2-small"
    hook_point: str = "blocks.10.hook_resid_pre"
    ln: str = 'blocks.10.ln1.hook_scale'
    hook_point_layer: int = 10
    layer: int = 10
    d_in: int = 768
    d_out: int = 768
    n_head: int = 12
    d_head: int = 64
    dataset_path: str = "Skylion007/openwebtext"
    is_dataset_tokenized: bool = False
    training: bool = True
    attn_scores_normed = False
    
    # SAE Parameters
    expansion_factor: int = 12   # TODO: NOT being used??
    d_hidden: int = 2400
    b_dec_init_method: str = "mean"
    
    # Training Parameters
    lr: float = 1e-5
    reg_coefficient: float = 4e-6
    lr_scheduler_name: Optional[str] = None
    train_batch_size: int = 2048
    context_size: int = 256
    lr_warm_up_steps: int = 5000
    
    # Activation Store Parameters
    n_batches_in_buffer: int = 128
    total_training_tokens: int = 20_000 * 10
    store_batch_size: int = 32
    use_cached_activations: bool = False
    
    # Resampling protocol
    feature_sampling_method: str = 'none'
    feature_sampling_window: int = 1000
    feature_reinit_scale: float = 0.2
    resample_batches: int = 1028
    dead_feature_window: int = 50000
    dead_feature_threshold: float = 1e-6
    
    # WANDB
    log_to_wandb: bool = False
    wandb_project: str = "sparsification"
    wandb_entity: Optional[str] = None
    wandb_log_frequency: int = 1000
    entity: str = "kwyn390"
    
    # Misc
    device: str = "cuda"
    eps: float = 1e-7
    seed: int = 42
    reshape_from_heads: bool = True
    n_checkpoints: int = 10
    checkpoint_path: str = "checkpoints"
    dtype: torch.dtype = torch.float32
    run_name: str = "qk_parallel"
    
    # Query-specific settings
    hook_transcoder_in_q: str = "blocks.10.hook_resid_pre"
    hook_transcoder_out_q: str = "blocks.10.attn.hook_q"
    target_q: str = "blocks.10.attn.hook_q"
    type_q: str = "resid_to_queries"
    
    # Key-specific settings
    hook_transcoder_in_k: str = "blocks.10.hook_resid_pre"
    hook_transcoder_out_k: str = "blocks.10.attn.hook_k"
    target_k: str = "blocks.10.attn.hook_k"
    type_k: str = "resid_to_keys"

cfg = UnifiedConfig()
cfg.run_name = f"{cfg.d_hidden}_{cfg.reg_coefficient}_{cfg.lr}"

In [5]:
sparse_transcoder_Q, sparse_transcoder_K = language_model_transcoder_runner_parallel(cfg)

285| MSE Loss 22.481:   0%|          | 585728/200000000 [00:50<3:02:51, 18175.88it/s]

Loaded pretrained model gpt2-small into HookedTransformer
Dataset is not tokenized! Updating config.
TRAIN STARTED
gonna schedule!
Reinitializing b_dec with mean of activations
Previous distances: 27.712806701660156
New distances: 19.674325942993164
Reinitializing b_dec_out with mean of activations
Previous distances: 36.17184066772461
New distances: 28.02600860595703
Reinitializing b_dec with mean of activations
Previous distances: 27.712806701660156
New distances: 19.674325942993164
Reinitializing b_dec_out with mean of activations
Previous distances: 41.437225341796875
New distances: 31.178068161010742
gonna progress bar!



Training SAE:   0%|          | 0/200000 [00:00<?, ?it/s][A
0| MSE Loss 132.460:   0%|          | 0/200000 [00:00<?, ?it/s][A
1| MSE Loss 129.409:   1%|          | 2048/200000 [00:00<00:11, 17822.29it/s][A
1| MSE Loss 129.409:   2%|▏         | 4096/200000 [00:00<00:05, 35352.00it/s][A
2| MSE Loss 122.982:   2%|▏         | 4096/200000 [00:00<00:05, 35352.00it/s][A
3| MSE Loss 116.592:   3%|▎         | 6144/200000 [00:00<00:05, 35352.00it/s][A
3| MSE Loss 116.592:   4%|▍         | 8192/200000 [00:00<00:08, 22629.59it/s][A
4| MSE Loss 116.331:   4%|▍         | 8192/200000 [00:00<00:08, 22629.59it/s][A
5| MSE Loss 109.755:   5%|▌         | 10240/200000 [00:00<00:08, 22629.59it/s][A
5| MSE Loss 109.755:   6%|▌         | 12288/200000 [00:00<00:09, 20646.22it/s][A
6| MSE Loss 109.612:   6%|▌         | 12288/200000 [00:00<00:09, 20646.22it/s][A
7| MSE Loss 101.517:   7%|▋         | 14336/200000 [00:00<00:08, 20646.22it/s][A
7| MSE Loss 101.517:   8%|▊         | 16384/200000 [00:00<

Saved model to checkpoints/20480_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/20480_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



12| MSE Loss 85.221:  12%|█▏        | 24576/200000 [00:01<00:08, 20860.55it/s][A
13| MSE Loss 82.023:  13%|█▎        | 26624/200000 [00:01<00:08, 20860.55it/s][A
13| MSE Loss 82.023:  14%|█▍        | 28672/200000 [00:01<00:07, 21527.46it/s][A
14| MSE Loss 78.346:  14%|█▍        | 28672/200000 [00:01<00:07, 21527.46it/s][A
15| MSE Loss 75.706:  15%|█▌        | 30720/200000 [00:01<00:07, 21527.46it/s][A
15| MSE Loss 75.706:  16%|█▋        | 32768/200000 [00:01<00:07, 21100.59it/s][A
16| MSE Loss 72.907:  16%|█▋        | 32768/200000 [00:01<00:07, 21100.59it/s][A
17| MSE Loss 69.024:  17%|█▋        | 34816/200000 [00:01<00:07, 21100.59it/s][A
17| MSE Loss 69.024:  18%|█▊        | 36864/200000 [00:01<00:06, 23353.52it/s][A
18| MSE Loss 67.870:  18%|█▊        | 36864/200000 [00:01<00:06, 23353.52it/s][A
19| MSE Loss 64.194:  19%|█▉        | 38912/200000 [00:01<00:06, 23353.52it/s][A
19| MSE Loss 64.194:  20%|██        | 40960/200000 [00:01<00:07, 22181.89it/s][A
20| MSE Loss 65

Saved model to checkpoints/40960_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/40960_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



21| MSE Loss 62.601:  22%|██▏       | 43008/200000 [00:02<00:07, 22181.89it/s][A
21| MSE Loss 62.601:  23%|██▎       | 45056/200000 [00:02<00:09, 16979.68it/s][A
22| MSE Loss 60.093:  23%|██▎       | 45056/200000 [00:02<00:09, 16979.68it/s][A
23| MSE Loss 57.891:  24%|██▎       | 47104/200000 [00:02<00:09, 16979.68it/s][A
24| MSE Loss 58.132:  25%|██▍       | 49152/200000 [00:02<00:08, 16979.68it/s][A
24| MSE Loss 58.132:  26%|██▌       | 51200/200000 [00:02<00:08, 18394.91it/s][A
25| MSE Loss 56.156:  26%|██▌       | 51200/200000 [00:02<00:08, 18394.91it/s][A
25| MSE Loss 56.156:  27%|██▋       | 53248/200000 [00:02<00:07, 18552.95it/s][A
26| MSE Loss 55.351:  27%|██▋       | 53248/200000 [00:02<00:07, 18552.95it/s][A
27| MSE Loss 54.460:  28%|██▊       | 55296/200000 [00:02<00:07, 18552.95it/s][A
27| MSE Loss 54.460:  29%|██▊       | 57344/200000 [00:02<00:08, 17800.71it/s][A
28| MSE Loss 51.159:  29%|██▊       | 57344/200000 [00:03<00:08, 17800.71it/s][A
28| MSE Loss 51

Saved model to checkpoints/61440_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/61440_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



32| MSE Loss 50.164:  33%|███▎      | 65536/200000 [00:03<00:08, 16288.04it/s][A
32| MSE Loss 50.164:  34%|███▍      | 67584/200000 [00:03<00:09, 14439.25it/s][A
33| MSE Loss 47.666:  34%|███▍      | 67584/200000 [00:03<00:09, 14439.25it/s][A
34| MSE Loss 49.246:  35%|███▍      | 69632/200000 [00:03<00:09, 14439.25it/s][A
34| MSE Loss 49.246:  36%|███▌      | 71680/200000 [00:03<00:06, 19637.60it/s][A
35| MSE Loss 48.970:  36%|███▌      | 71680/200000 [00:03<00:06, 19637.60it/s][A
36| MSE Loss 46.536:  37%|███▋      | 73728/200000 [00:04<00:06, 19637.60it/s][A
36| MSE Loss 46.536:  38%|███▊      | 75776/200000 [00:04<00:07, 15736.79it/s][A
37| MSE Loss 45.769:  38%|███▊      | 75776/200000 [00:04<00:07, 15736.79it/s][A
37| MSE Loss 45.769:  39%|███▉      | 77824/200000 [00:04<00:08, 14307.59it/s][A
38| MSE Loss 44.629:  39%|███▉      | 77824/200000 [00:04<00:08, 14307.59it/s][A
39| MSE Loss 45.258:  40%|███▉      | 79872/200000 [00:04<00:08, 14307.59it/s][A
39| MSE Loss 45

Saved model to checkpoints/81920_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/81920_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



41| MSE Loss 44.505:  42%|████▏     | 83968/200000 [00:04<00:09, 12718.85it/s][A
42| MSE Loss 43.803:  43%|████▎     | 86016/200000 [00:05<00:08, 12718.85it/s][A
42| MSE Loss 43.803:  44%|████▍     | 88064/200000 [00:05<00:08, 13091.67it/s][A
43| MSE Loss 42.733:  44%|████▍     | 88064/200000 [00:05<00:08, 13091.67it/s][A
43| MSE Loss 42.733:  45%|████▌     | 90112/200000 [00:05<00:09, 12167.24it/s][A
44| MSE Loss 42.957:  45%|████▌     | 90112/200000 [00:05<00:09, 12167.24it/s][A
45| MSE Loss 42.975:  46%|████▌     | 92160/200000 [00:05<00:08, 12167.24it/s][A
45| MSE Loss 42.975:  47%|████▋     | 94208/200000 [00:05<00:08, 12946.33it/s][A
46| MSE Loss 39.935:  47%|████▋     | 94208/200000 [00:05<00:08, 12946.33it/s][A
46| MSE Loss 39.935:  48%|████▊     | 96256/200000 [00:05<00:07, 13128.75it/s][A
47| MSE Loss 38.973:  48%|████▊     | 96256/200000 [00:05<00:07, 13128.75it/s][A
47| MSE Loss 38.973:  49%|████▉     | 98304/200000 [00:05<00:08, 12330.11it/s][A
48| MSE Loss 42

Saved model to checkpoints/100352_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/100352_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



50| MSE Loss 41.992:  51%|█████     | 102400/200000 [00:06<00:06, 14303.02it/s][A
50| MSE Loss 41.992:  52%|█████▏    | 104448/200000 [00:06<00:07, 12786.04it/s][A
51| MSE Loss 37.548:  52%|█████▏    | 104448/200000 [00:06<00:07, 12786.04it/s][A
52| MSE Loss 39.318:  53%|█████▎    | 106496/200000 [00:06<00:07, 12786.04it/s][A
52| MSE Loss 39.318:  54%|█████▍    | 108544/200000 [00:06<00:06, 14756.21it/s][A
53| MSE Loss 42.886:  54%|█████▍    | 108544/200000 [00:06<00:06, 14756.21it/s][A
53| MSE Loss 42.886:  55%|█████▌    | 110592/200000 [00:06<00:05, 15471.92it/s][A
54| MSE Loss 36.828:  55%|█████▌    | 110592/200000 [00:06<00:05, 15471.92it/s][A
54| MSE Loss 36.828:  56%|█████▋    | 112640/200000 [00:06<00:06, 14097.09it/s][A
55| MSE Loss 40.333:  56%|█████▋    | 112640/200000 [00:07<00:06, 14097.09it/s][A
55| MSE Loss 40.333:  57%|█████▋    | 114688/200000 [00:07<00:06, 13506.08it/s][A
56| MSE Loss 37.123:  57%|█████▋    | 114688/200000 [00:07<00:06, 13506.08it/s][A
56|

Saved model to checkpoints/120832_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/120832_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



62| MSE Loss 37.551:  63%|██████▎   | 126976/200000 [00:07<00:04, 16194.72it/s][A
62| MSE Loss 37.551:  65%|██████▍   | 129024/200000 [00:07<00:03, 20013.63it/s][A
63| MSE Loss 38.101:  65%|██████▍   | 129024/200000 [00:07<00:03, 20013.63it/s][A
64| MSE Loss 35.760:  66%|██████▌   | 131072/200000 [00:07<00:03, 20013.63it/s][A
64| MSE Loss 35.760:  67%|██████▋   | 133120/200000 [00:07<00:02, 22606.92it/s][A
65| MSE Loss 39.102:  67%|██████▋   | 133120/200000 [00:08<00:02, 22606.92it/s][A
66| MSE Loss 35.109:  68%|██████▊   | 135168/200000 [00:08<00:02, 22606.92it/s][A
66| MSE Loss 35.109:  69%|██████▊   | 137216/200000 [00:08<00:03, 18002.42it/s][A
67| MSE Loss 33.011:  69%|██████▊   | 137216/200000 [00:08<00:03, 18002.42it/s][A
68| MSE Loss 34.862:  70%|██████▉   | 139264/200000 [00:08<00:03, 18002.42it/s][A
69| MSE Loss 36.677:  71%|███████   | 141312/200000 [00:08<00:03, 18002.42it/s][A
69| MSE Loss 36.677:  72%|███████▏  | 143360/200000 [00:08<00:02, 22984.27it/s][A
70|

Saved model to checkpoints/141312_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/141312_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



72| MSE Loss 32.924:  74%|███████▎  | 147456/200000 [00:08<00:02, 22984.27it/s][A
72| MSE Loss 32.924:  75%|███████▍  | 149504/200000 [00:08<00:02, 23891.57it/s][A
73| MSE Loss 33.623:  75%|███████▍  | 149504/200000 [00:08<00:02, 23891.57it/s][A
74| MSE Loss 33.817:  76%|███████▌  | 151552/200000 [00:08<00:02, 23891.57it/s][A
74| MSE Loss 33.817:  77%|███████▋  | 153600/200000 [00:09<00:02, 18147.58it/s][A
75| MSE Loss 32.806:  77%|███████▋  | 153600/200000 [00:09<00:02, 18147.58it/s][A
76| MSE Loss 34.422:  78%|███████▊  | 155648/200000 [00:09<00:02, 18147.58it/s][A
76| MSE Loss 34.422:  79%|███████▉  | 157696/200000 [00:09<00:02, 18200.18it/s][A
77| MSE Loss 34.426:  79%|███████▉  | 157696/200000 [00:09<00:02, 18200.18it/s][A
77| MSE Loss 34.426:  80%|███████▉  | 159744/200000 [00:09<00:02, 16557.77it/s][A
78| MSE Loss 34.294:  80%|███████▉  | 159744/200000 [00:09<00:02, 16557.77it/s][A
78| MSE Loss 34.294:  81%|████████  | 161792/200000 [00:09<00:02, 15788.98it/s][A
79|

Saved model to checkpoints/161792_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/161792_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



82| MSE Loss 30.866:  84%|████████▍ | 167936/200000 [00:09<00:01, 16037.84it/s][A
82| MSE Loss 30.866:  85%|████████▍ | 169984/200000 [00:09<00:01, 19086.35it/s][A
83| MSE Loss 30.952:  85%|████████▍ | 169984/200000 [00:10<00:01, 19086.35it/s][A
83| MSE Loss 30.952:  86%|████████▌ | 172032/200000 [00:10<00:01, 15849.28it/s][A
84| MSE Loss 34.458:  86%|████████▌ | 172032/200000 [00:10<00:01, 15849.28it/s][A
85| MSE Loss 29.644:  87%|████████▋ | 174080/200000 [00:10<00:01, 15849.28it/s][A
86| MSE Loss 30.522:  88%|████████▊ | 176128/200000 [00:10<00:01, 15849.28it/s][A
86| MSE Loss 30.522:  89%|████████▉ | 178176/200000 [00:10<00:00, 21947.55it/s][A
87| MSE Loss 32.424:  89%|████████▉ | 178176/200000 [00:10<00:00, 21947.55it/s][A
88| MSE Loss 33.976:  90%|█████████ | 180224/200000 [00:10<00:00, 21947.55it/s][A
88| MSE Loss 33.976:  91%|█████████ | 182272/200000 [00:10<00:00, 20826.84it/s][A
89| MSE Loss 33.516:  91%|█████████ | 182272/200000 [00:10<00:00, 20826.84it/s][A
90|

Saved model to checkpoints/180224_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/180224_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt



92| MSE Loss 29.061:  94%|█████████▍| 188416/200000 [00:10<00:00, 24291.27it/s][A
93| MSE Loss 31.397:  95%|█████████▌| 190464/200000 [00:10<00:00, 24291.27it/s][A
93| MSE Loss 31.397:  96%|█████████▋| 192512/200000 [00:10<00:00, 27273.01it/s][A
94| MSE Loss 31.165:  96%|█████████▋| 192512/200000 [00:11<00:00, 27273.01it/s][A
95| MSE Loss 34.666:  97%|█████████▋| 194560/200000 [00:11<00:00, 27273.01it/s][A
95| MSE Loss 34.666:  98%|█████████▊| 196608/200000 [00:11<00:00, 21734.33it/s][A
96| MSE Loss 29.405:  98%|█████████▊| 196608/200000 [00:11<00:00, 21734.33it/s][A
97| MSE Loss 30.004:  99%|█████████▉| 198656/200000 [00:11<00:00, 21734.33it/s][A
97| MSE Loss 30.004: : 200704it [00:11, 17817.05it/s]                          [A


Saved model to checkpoints/final_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/final_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt


In [7]:
sparse_transcoder_K.get_name()

'sparse_transcoder_gpt2-small_resid_to_keys_2400'

In [2]:
kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target = True)


In [54]:
a = torch.rand(12, 256, 256)
b = torch.rand(12, 256, 256)


a_patt = a.softmax(-1)
b_patt = b.softmax(-1)

print(kl_loss(a_patt, b_patt).mean())

tensor(0.0819)


In [None]:
kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target = True)

kl_loss(a_patt, b_patt)