### Pip install

In [1]:
!pip install datasets transformer_lens wandb plotly line_profiler

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting transformer_lens
  Downloading transformer_lens-2.2.2-py3-none-any.whl.metadata (12 kB)
Collecting wandb
  Downloading wandb-0.17.5-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting plotly
  Downloading plotly-5.23.0-py3-none-any.whl.metadata (7.3 kB)
Collecting line_profiler
  Downloading line_profiler-4.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylin

### Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from ActivationStoreParallel import ActivationsStore
from sparse_transcoder import SparseTranscoder
from transcoder_training_parallel import train_transcoder_on_language_model_parallel
from transcoder_runner_parallel import language_model_transcoder_runner_parallel
from dataclasses import dataclass
import transformer_lens
import torch
import wandb
from typing import Optional
from test_config import test_cfg
from sparsify_feature_map import sparsity_transcoder

In [3]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbiggs[0m ([33mbiggs-University College London (UCL)[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

### Query/Key transcoders

In [4]:
@dataclass
class UnifiedConfig():
    # Common settings
    model_name: str = "gpt2-small"
    hook_point: str = "blocks.10.hook_resid_pre"
    ln: str = 'blocks.10.ln1.hook_scale'
    hook_point_layer: int = 10
    layer: int = 10
    d_in: int = 768
    d_out: int = 768
    n_head: int = 12
    d_head: int = 64
    dataset_path: str = "Skylion007/openwebtext"
    is_dataset_tokenized: bool = False
    training: bool = True
    attn_scores_normed = True
    
    # SAE Parameters
    as_sae: bool = True
    expansion_factor: int = 12   # TODO: NOT being used??
    d_hidden: int = 2400
    b_dec_init_method: str = "mean"
    norming_decoder_during_training = True
    
    # Training Parameters
    lr: float = 1e-5
    reg_coefficient: float = 4e-6
    lr_scheduler_name: Optional[str] = None
    train_batch_size: int = 2048
    context_size: int = 256
    lr_warm_up_steps: int = 5000
    
    # Activation Store Parameters
    n_batches_in_buffer: int = 128
    total_training_tokens: int = 10_000 * 1_000
    store_batch_size: int = 32
    use_cached_activations: bool = False
    
    # Resampling protocol
    feature_sampling_method: str = 'none'
    feature_sampling_window: int = 1000
    feature_reinit_scale: float = 0.2
    resample_batches: int = 1028
    dead_feature_window: int = 50000
    dead_feature_threshold: float = 1e-6
    
    # WANDB
    log_to_wandb: bool = True
    log_final_model_to_wandb: bool = False
    wandb_project: str = "sparsification"
    wandb_entity: Optional[str] = None
    wandb_log_frequency: int = 50
    entity: str = "biggs-University College London (UCL)"
    
    # Misc
    device: str = "cuda"
    eps: float = 1e-7
    seed: int = 42
    reshape_from_heads: bool = True
    n_checkpoints: int = 10
    checkpoint_path: str = "checkpoints"
    dtype: torch.dtype = torch.float32
    run_name: str = "qk_parallel"
    
    # Query-specific settings
    hook_transcoder_in_q: str = "blocks.10.hook_resid_pre"
    hook_transcoder_out_q: str = "blocks.10.attn.hook_q"
    target_q: str = "blocks.10.attn.hook_q"
    type_q: str = "resid_to_queries"
    
    # Key-specific settings
    hook_transcoder_in_k: str = "blocks.10.hook_resid_pre"
    hook_transcoder_out_k: str = "blocks.10.attn.hook_k"
    target_k: str = "blocks.10.attn.hook_k"
    type_k: str = "resid_to_keys"

cfg = UnifiedConfig()
cfg.run_name = f"{cfg.d_hidden}_{cfg.reg_coefficient}_{cfg.lr}"

In [None]:
sparse_transcoder_Q, sparse_transcoder_K = language_model_transcoder_runner_parallel(cfg)

Running...
Loaded pretrained model gpt2-small into HookedTransformer
Dataset is not tokenized! Updating config.
Reinitializing b_dec with mean of activations
Previous distances: 27.712806701660156
New distances: 19.674325942993164
Reinitializing b_dec_out with mean of activations
Previous distances: 36.17184066772461
New distances: 28.02600860595703
Reinitializing b_dec with mean of activations
Previous distances: 27.712806701660156
New distances: 19.674325942993164
Reinitializing b_dec_out with mean of activations
Previous distances: 41.437225341796875
New distances: 31.178068161010742


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/frac_acc,▁▅▅▅▅▇▄▇▇▆▄▃▅▆▆▇▂▆▆▆▆▅▆▇▇▆▆▇▆▇▇█▅▅▇▆█▃▆▂
details/lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
details/patt_max_diff,█▅▅▄▄▃▄▄▃▃▃▃▂▃▃▂▃▂▂▂▁▂▂▂▂▂▂▁▂▁▂▂▁▁▂▁▁▁▂▁
details/pred_key_mean,▁▃▄▄▆▆▆▆▇▇▇▇▇▇█▇▇▇▆▇▇██▇▇▇▇▇█▇▇▇█▆▇▇▇█▆▇
details/pred_query_mean,█▆█▆▆▄▄▄▄▃▃▃▂▃▃▃▃▃▁▃▂▄▄▂▃▃▃▂▃▃▂▂▄▃▃▃▃▄▃▄
losses/mse_lossK,█▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁
losses/mse_lossQ,█▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/patt_lossK,█▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▃▂▂▂▁▂▂▁▂▁▂▁▂▁▁▁▁▁▁▁▁▂▁▂
losses/patt_lossQ,█▄▄▃▃▃▃▃▂▂▃▂▂▂▂▂▃▂▂▂▁▂▂▁▂▁▂▁▂▁▁▁▁▁▁▁▁▂▁▂

0,1
details/frac_acc,0.91162
details/lr,1e-05
details/n_training_tokens,74442752.0
details/patt_max_diff,0.06217
details/pred_key_mean,-0.00502
details/pred_query_mean,0.02367
losses/mse_lossK,1048.0647
losses/mse_lossQ,738.23828
losses/patt_lossK,0.10476
losses/patt_lossQ,0.10892


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112962099852868, max=1.0…

Initialising...


Training Transcoders:  10%|█         | 10020864/100000000 [04:00<1:33:58, 15958.83it/s, Loss=4.795]

Saved model to checkpoints/10000384_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/10000384_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt


Training Transcoders:  20%|██        | 20023296/100000000 [08:00<35:47, 37238.46it/s, Loss=1.298]  

Saved model to checkpoints/20000768_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/20000768_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt


Training Transcoders:  30%|███       | 30023680/100000000 [12:00<18:06, 64383.98it/s, Loss=1.113]  

Saved model to checkpoints/30001152_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/30001152_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt


Training Transcoders:  40%|████      | 40022016/100000000 [15:58<11:10, 89413.00it/s, Loss=1.035]  

Saved model to checkpoints/40001536_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/40001536_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt


Training Transcoders:  50%|█████     | 50024448/100000000 [19:55<09:39, 86237.97it/s, Loss=0.961]  

Saved model to checkpoints/50001920_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt
Saved model to checkpoints/50001920_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt


Training Transcoders:  51%|█████     | 50855936/100000000 [20:10<07:53, 103768.02it/s, Loss=0.922]

In [5]:
# train_transcoder_on_language_model_parallel(cfg, model, query_sae, key_sae, activations_store)

In [17]:
!cp checkpoints/final_sparse_transcoder_gpt2-small_resid_to_keys_2400.pt checkpoints/keys_sae_normed_final.pt
!cp checkpoints/final_sparse_transcoder_gpt2-small_resid_to_queries_2400.pt checkpoints/queries_sae_normed_final.pt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [5]:
query_sae = SparseTranscoder.load_from_pretrained("checkpoints/queries_sae_normed_final.pt", True)
key_sae = SparseTranscoder.load_from_pretrained("checkpoints/keys_sae_normed_final.pt", False)
cfg = query_sae.cfg
model = transformer_lens.HookedTransformer.from_pretrained(cfg.model_name, fold_ln=True)
activations_store = ActivationsStore(cfg, model)

Loaded pretrained model gpt2-small into HookedTransformer
Dataset is not tokenized! Updating config.


In [6]:
cfg.log_to_wandb = True
cfg.mask_reg_coeff = 4e-6
cfg.lr = 1e-3
mask = sparsity_transcoder(cfg, model, query_sae, key_sae, activations_store)

Training Mask:  16%|█▌        | 15728640/100000000 [09:36<51:29, 27280.36it/s, Patt Loss=0.125, Sparsity=32552.631, Fraction Zeros=1.000]    


KeyboardInterrupt: 

In [9]:
%config InteractiveShell

UsageError: Invalid config statement: 'InteractiveShell', should be `Class.trait = value`.
