In [1]:
import json
from dataclasses import dataclass, fields

import torch as t
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device

device(type='cuda')

In [3]:
if t.cuda.is_available():
    gpu_id = 0  # Set to your target GPU ID
    total_memory = t.cuda.get_device_properties(gpu_id).total_memory
    allocated_memory = t.cuda.memory_allocated(gpu_id)
    cached_memory = t.cuda.memory_reserved(gpu_id)

    print(f"Total GPU Memory: {total_memory / 1024**2:.2f} MB")
    print(f"Allocated GPU Memory: {allocated_memory / 1024**2:.2f} MB")
    print(f"Cached GPU Memory: {cached_memory / 1024**2:.2f} MB")
else:
    print("CUDA is not available.")

Total GPU Memory: 81037.75 MB
Allocated GPU Memory: 0.00 MB
Cached GPU Memory: 0.00 MB


export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [4]:
# # Create a large random tensor on the GPU
# tensor = t.randn(10000, 10000, device=device)
# del tensor

In [5]:
sae_gemma_2b_it, cfg_gemma_2b_it, sparsity_gemma_2b_it  = SAE.from_pretrained(
  "gemma-2b-it-res-jb", # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
  "blocks.12.hook_resid_post" # change this to another specific SAE ID in the release if desired. 
)

del sae_gemma_2b_it, sparsity_gemma_2b_it

In [6]:
sae_gemma_scope, cfg_gemma_scope, sparsity_gemma_scope  = SAE.from_pretrained(
  release = "gemma-scope-2b-pt-res-canonical", # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
  sae_id = "layer_5/width_16k/canonical" # change this to another specific SAE ID in the release if desired. 
)

del sae_gemma_scope, sparsity_gemma_scope

t.cuda.empty_cache()


In [7]:
# Function to load the JSON configuration into the dataclass

def load_config(config_path: str) -> LanguageModelSAERunnerConfig:
    with open(config_path, 'r') as file:
        data = json.load(file)

    # Ensure no extra fields are present
    valid_fields = {f.name for f in fields(LanguageModelSAERunnerConfig)}

    # Filter the data to only include valid fields
    filtered_data = {k: v for k, v in data.items() if k in valid_fields}

    if "expansion_factor" in filtered_data.keys() and "d_sae" in filtered_data.keys():
        filtered_data = {k: v for k, v in filtered_data.items() if k != 'expansion_factor'}
    
    return LanguageModelSAERunnerConfig(**filtered_data)  # Unpack the JSON dictionary into the dataclass

# Usage example
cfg = load_config('gemma_2b_it_blocks.12.hook_resid_post_16384_cfg.json')

Run name: 16384-L1-2-LR-5e-05-Tokens-1.229e+09
n_tokens_per_buffer (millions): 0.131072
Lower bound: n_contexts_per_buffer (millions): 0.000128
Total training steps: 300000
Total wandb updates: 6000
n_tokens_per_feature_sampling_window (millions): 20971.52
n_tokens_per_dead_feature_window (millions): 20971.52
We will reset the sparsity calculation 60 times.
Number tokens in sparsity calculation window: 2.05e+07


some things to change from jbloom/Gemma-2b-IT-Residual-Stream-SAEs:
- change model to gemma-2-2b-it
- change layer
- change architecture from standard --> gated or jumprelu
- apply_b_dec_to_input --> depending on the SAE method
- maybe: dataset
- make training longer

parameters to play around:
- l1 coefficient

In [8]:
cfg

LanguageModelSAERunnerConfig(model_name='gemma-2b-it', model_class_name='HookedTransformer', hook_name='blocks.12.hook_resid_post', hook_eval='NOT_IN_USE', hook_layer=12, hook_head_index=None, dataset_path='chanind/openwebtext-gemma', dataset_trust_remote_code=True, streaming=False, is_dataset_tokenized=True, context_size=1024, use_cached_activations=False, cached_activations_path=None, architecture='standard', d_in=2048, d_sae=16384, b_dec_init_method='zeros', expansion_factor=None, activation_fn='relu', activation_fn_kwargs={}, normalize_sae_decoder=False, noise_scale=0.0, from_pretrained_path=None, apply_b_dec_to_input=False, decoder_orthogonal_init=False, decoder_heuristic_init=True, init_encoder_as_decoder_transpose=True, n_batches_in_buffer=16, training_tokens=1228800000, finetuning_tokens=0, store_batch_size_prompts=8, train_batch_size_tokens=4096, normalize_activations='none', seqpos_slice=(None,), device='cuda', act_store_device='cuda', seed=42, dtype='float32', prepend_bos=Tr

In [9]:
cfg_gemma_scope

{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 16384,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.5.hook_resid_post',
 'hook_layer': 5,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': 'cpu'}

In [10]:
cfg.model_name = 'gemma-2-2b-it'
hook_layer = 5
cfg.hook_layer = hook_layer
cfg.hook_name = f"blocks.{hook_layer}.hook_resid_post"

cfg.d_in = cfg_gemma_scope["d_in"]
cfg.d_sae = 16384

cfg.dataset_path= cfg_gemma_scope["dataset_path"]
# cfg.dataset_path='lmsys/lmsys-chat-1m'
cfg.streaming = True

## for architecture
cfg.architecture = cfg_gemma_scope["architecture"]
cfg.apply_b_dec_to_input=cfg_gemma_scope["apply_b_dec_to_input"] ## let's check if this is true for gated models

# Logging / evals
cfg.log_to_wandb=True  # always use wandb unless you are just testing code.
cfg.wandb_project=f"{cfg.model_name}-SAEs-trial"
cfg.wandb_log_frequency=30
cfg.eval_every_n_wandb_logs=20

In [11]:
divide_batch_size = 1

total_training_steps = 40_000  # probably we should do more
batch_size = int(4096 / divide_batch_size)
total_training_tokens = total_training_steps * batch_size * divide_batch_size

lr_warm_up_steps = l1_warm_up_steps = total_training_steps // 10  # 10% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

print(total_training_tokens)

cfg.training_tokens = total_training_tokens
cfg.train_batch_size_tokens = batch_size
cfg.lr_warm_up_steps = lr_warm_up_steps
cfg.lr_decay_steps = lr_decay_steps
cfg.store_batch_size_prompts = 6

163840000


In [12]:
cfg

LanguageModelSAERunnerConfig(model_name='gemma-2-2b-it', model_class_name='HookedTransformer', hook_name='blocks.5.hook_resid_post', hook_eval='NOT_IN_USE', hook_layer=5, hook_head_index=None, dataset_path='monology/pile-uncopyrighted', dataset_trust_remote_code=True, streaming=True, is_dataset_tokenized=True, context_size=1024, use_cached_activations=False, cached_activations_path=None, architecture='jumprelu', d_in=2304, d_sae=16384, b_dec_init_method='zeros', expansion_factor=None, activation_fn='relu', activation_fn_kwargs={}, normalize_sae_decoder=False, noise_scale=0.0, from_pretrained_path=None, apply_b_dec_to_input=False, decoder_orthogonal_init=False, decoder_heuristic_init=True, init_encoder_as_decoder_transpose=True, n_batches_in_buffer=16, training_tokens=163840000, finetuning_tokens=0, store_batch_size_prompts=6, train_batch_size_tokens=4096, normalize_activations='none', seqpos_slice=(None,), device='cuda', act_store_device='cuda', seed=42, dtype='float32', prepend_bos=Tr

In [None]:
# print("Comment this code out to train! Otherwise, it will load in the already trained model.")
t.set_grad_enabled(True)
runner = SAETrainingRunner(cfg)
sae = runner.run()

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.87s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdajale423[0m ([33mboston[0m). Use [1m`wandb login --relogin`[0m to force relogin


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
3500| MSE Loss 331.244 | L1 704.431:   9%|████████▏                                                                                     | 14336000/163840000 [28:30<4:22:34, 9489.32it/s]