
[Bug Report] Unable to train Mamba SAE #311

Open · 1 task done

joelburget opened this issue Oct 1, 2024 · 2 comments

joelburget commented Oct 1, 2024

Describe the bug

Training an SAE on a Mamba model fails with: AttributeError: 'HookedMamba' object has no attribute 'W_E'.

Code example

cfg = LanguageModelSAERunnerConfig(
    model_name="state-spaces/mamba-2.8b",
    model_class_name="HookedMamba",
    ...
)
sae = SAETrainingRunner(cfg).run()

Full code: https://github.com/joelburget/mamba-sae/blob/2f87fb99660516c47121aa7a0f65d8944c42778b/hyperparam_sweep.py

Traceback (most recent call last):
  File "/workspace/mamba-sae/hyperparam_sweep.py", line 70, in <module>
    sae = SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
    self._init_sae_group_b_decs()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
    layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
    self._storage_buffer = self.get_buffer(self.half_buffer_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
    refill_batch_tokens = self.get_batch_tokens(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
    return torch.stack(sequences, dim=0).to(self.model.W_E.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?
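
The failure comes from get_batch_tokens hard-coding self.model.W_E.device, an attribute only HookedTransformer-style models expose. As a rough local workaround (a sketch, not an official sae_lens fix; model_device is a hypothetical helper), the device could instead be resolved from the model's parameters:

import torch
import torch.nn as nn

def model_device(model: nn.Module) -> torch.device:
    # Any parameter's device will do; this avoids relying on a transformer-specific
    # weight such as W_E, which HookedMamba does not have.
    return next(model.parameters()).device

# e.g. in ActivationsStore.get_batch_tokens, instead of
#     torch.stack(sequences, dim=0).to(self.model.W_E.device)
# one could use
#     torch.stack(sequences, dim=0).to(model_device(self.model))
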
Full output
root@39393c1a11a5:/workspace/mamba-sae# wandb agent PEAR-ML/mamba-sae-sweep/i0mhj40q
wandb: Starting wandb agent 🕵️
2024-10-01 15:28:00,877 - wandb.wandb_agent - INFO - Running runs: []
2024-10-01 15:28:01,254 - wandb.wandb_agent - INFO - Agent received command: run
2024-10-01 15:28:01,254 - wandb.wandb_agent - INFO - Agent starting run with config:
	learning_rate: 0.0012748220954754614
	sparsity_penalty: 0.0922662545654348
2024-10-01 15:28:01,259 - wandb.wandb_agent - INFO - About to run command: /usr/bin/env python hyperparam_sweep.py --learning_rate=0.0012748220954754614 --sparsity_penalty=0.0922662545654348
Resolving data files:   0%|                                                          | 0/37 [00:00<?, ?it/s]
2024-10-01 15:28:06,273 - wandb.wandb_agent - INFO - Running runs: ['ffh3gcsk']
Resolving data files: 100%|█████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.95it/s]
wandb: Currently logged in as: joelb (PEAR-ML). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.18.2
wandb: Run data is saved locally in /workspace/mamba-sae/wandb/run-20241001_152811-ffh3gcsk
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run dauntless-sweep-4
wandb: ⭐️ View project at https://wandb.ai/PEAR-ML/mamba-sae-sweep
wandb: 🧹 View sweep at https://wandb.ai/PEAR-ML/mamba-sae-sweep/sweeps/i0mhj40q
wandb: 🚀 View run at https://wandb.ai/PEAR-ML/mamba-sae-sweep/runs/ffh3gcsk
Run name: 12288-L1-0.0922662545654348-LR-0.0012748220954754614-Tokens-3.000e+07
n_tokens_per_buffer (millions): 4.194304
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 7324
Total wandb updates: 73
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 20971.52
We will reset the sparsity calculation 7 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
Moving model to device:  cuda
Resolving data files: 100%|████████████████████████████████████████████████| 37/37 [00:00<00:00, 114.64it/s]
Traceback (most recent call last):
  File "/workspace/mamba-sae/hyperparam_sweep.py", line 70, in <module>
    sae = SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
    self._init_sae_group_b_decs()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
    layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
    self._storage_buffer = self.get_buffer(self.half_buffer_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
    refill_batch_tokens = self.get_batch_tokens(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
    return torch.stack(sequences, dim=0).to(self.model.W_E.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?

System Info

root@39393c1a11a5:/workspace/mamba-sae# cat requirements.txt
torch>=2.2.0
transformers[sentencepiece]>=4.39.2
accelerate>=0.27.2
datasets>=2.15.0
wandb
sae-lens[mamba]
root@39393c1a11a5:/workspace/mamba-sae# uname -a
Linux 39393c1a11a5 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux
root@39393c1a11a5:/workspace/mamba-sae# python --version
Python 3.10.12

Checklist

  • I have checked that there is no similar issue in the repo (required)
joelburget (Author) commented:

I was also able to reproduce this with tutorials/mamba_train_example.py after a fresh clone, after first fixing a couple of other errors:

Error 1

root@39393c1a11a5:/workspace/SAELens# python3 tutorials/mamba_train_example.py
Run name: 131072-L1-1.2e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
Traceback (most recent call last):
  File "/workspace/SAELens/tutorials/mamba_train_example.py", line 57, in <module>
    SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 57, in __init__
    self.model = load_model(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/load_model.py", line 39, in load_model
    HookedMamba.from_pretrained(
TypeError: HookedMamba.from_pretrained() got an unexpected keyword argument 'center_writing_weights'

Diff

diff --git a/tutorials/mamba_train_example.py b/tutorials/mamba_train_example.py
index d76a673..a055581 100644
--- a/tutorials/mamba_train_example.py
+++ b/tutorials/mamba_train_example.py
@@ -52,6 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
         "fast_ssm": True,
         "fast_conv": True,
     },
+    model_from_pretrained_kwargs={}
 )
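
For context, here is a hedged sketch of the resulting config (not the full tutorial file; the model name is taken from the snippet at the top of this issue, and the remaining fields are elided). Passing an empty model_from_pretrained_kwargs stops load_model from forwarding HookedTransformer-only options such as center_writing_weights to HookedMamba.from_pretrained:

from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    model_name="state-spaces/mamba-2.8b",  # from the snippet above; the tutorial uses its own model
    model_class_name="HookedMamba",
    model_kwargs={"fast_ssm": True, "fast_conv": True},
    model_from_pretrained_kwargs={},  # avoid defaults meant for HookedTransformer
    # ... other fields as in tutorials/mamba_train_example.py
)
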

Error 2

Traceback (most recent call last):
  File "/workspace/SAELens/tutorials/mamba_train_example.py", line 58, in <module>
    SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 66, in __init__
    self.activations_store = ActivationsStore.from_config(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 69, in from_config
    return cls(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 218, in __init__
    raise ValueError(
ValueError: pretokenized dataset has context_size 1024, but the provided context_size is 128.
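
The diff below simply raises context_size to match. To confirm a pretokenized dataset's sequence length before setting the config, a rough check along these lines works (the dataset path and the "tokens" column name are placeholders, not the tutorial's actual values):

from datasets import load_dataset

ds = load_dataset("your-org/your-pretokenized-dataset", split="train", streaming=True)
first_example = next(iter(ds))
print(len(first_example["tokens"]))  # should equal the configured context_size (1024 per the error above)
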

Diff

--- a/tutorials/mamba_train_example.py
+++ b/tutorials/mamba_train_example.py
@@ -27,7 +27,7 @@ cfg = LanguageModelSAERunnerConfig(
     l1_coefficient=0.00006 * 0.2,
     lr_scheduler_name="cosineannealingwarmrestarts",
     train_batch_size_tokens=4096,
-    context_size=128,
+    context_size=1024,
     lr_warm_up_steps=5000,
     # Activation Store Parameters
     n_batches_in_buffer=128,
@@ -52,6 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
         "fast_ssm": True,
         "fast_conv": True,
     },
+    model_from_pretrained_kwargs={}
 )

 SAETrainingRunner(cfg).run()
root@39393c1a11a5:/workspace/SAELens# python3 tutorials/mamba_train_example.py
Run name: 131072-L1-1.2e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 4.194304
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 20971.52
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
Moving model to device:  cuda
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:00<00:00, 254.39it/s]
Traceback (most recent call last):
  File "/workspace/SAELens/tutorials/mamba_train_example.py", line 58, in <module>
    SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
    self._init_sae_group_b_decs()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
    layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
    self._storage_buffer = self.get_buffer(self.half_buffer_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
    refill_batch_tokens = self.get_batch_tokens(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
    return torch.stack(sequences, dim=0).to(self.model.W_E.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?

jbloomAus (Owner) commented Oct 1, 2024 via email
