[Bug Report] Unable to train Mamba SAE #311
Comments
Thanks for flagging this. This might be an issue with MambaLens, or we may not have been testing something we needed to during recent changes. Will follow up shortly.
On Tue, Oct 1, 2024, 9:56 PM Joel Burget wrote:
I was also able to repro with tutorials/mamba_train_example.py after a
fresh clone:
Fixing a couple of other errors first:
Error 1
***@***.***:/workspace/SAELens# python3 tutorials/mamba_train_example.py
Run name: 131072-L1-1.2e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
Traceback (most recent call last):
File "/workspace/SAELens/tutorials/mamba_train_example.py", line 57, in <module>
SAETrainingRunner(cfg).run()
File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 57, in __init__
self.model = load_model(
File "/usr/local/lib/python3.10/dist-packages/sae_lens/load_model.py", line 39, in load_model
HookedMamba.from_pretrained(
TypeError: HookedMamba.from_pretrained() got an unexpected keyword argument 'center_writing_weights'
Diff
diff --git a/tutorials/mamba_train_example.py b/tutorials/mamba_train_example.py
index d76a673..a055581 100644
--- a/tutorials/mamba_train_example.py
+++ b/tutorials/mamba_train_example.py
@@ -52,6 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
"fast_ssm": True,
"fast_conv": True,
},
+ model_from_pretrained_kwargs={}
)
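For anyone hitting the same TypeError, here is a generic sketch (not SAELens code) of how one could check which keyword arguments a given `from_pretrained` actually accepts before forwarding loader defaults like `center_writing_weights`; the helper name is made up for illustration:

```python
import inspect

def filter_supported_kwargs(from_pretrained_fn, kwargs):
    """Drop kwargs that the target from_pretrained() does not declare,
    e.g. HookedTransformer-only options like center_writing_weights."""
    params = inspect.signature(from_pretrained_fn).parameters
    # If the function accepts **kwargs, everything is nominally allowed.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}
```

Passing `model_from_pretrained_kwargs={}` as in the diff above sidesteps the problem by not forwarding any HookedTransformer-specific defaults to `HookedMamba.from_pretrained`.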
Error 2
Traceback (most recent call last):
File "/workspace/SAELens/tutorials/mamba_train_example.py", line 58, in <module>
SAETrainingRunner(cfg).run()
File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 66, in __init__
self.activations_store = ActivationsStore.from_config(
File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 69, in from_config
return cls(
File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 218, in __init__
raise ValueError(
ValueError: pretokenized dataset has context_size 1024, but the provided context_size is 128.
Diff
--- a/tutorials/mamba_train_example.py
+++ b/tutorials/mamba_train_example.py
@@ -27,7 +27,7 @@ cfg = LanguageModelSAERunnerConfig(
l1_coefficient=0.00006 * 0.2,
lr_scheduler_name="cosineannealingwarmrestarts",
train_batch_size_tokens=4096,
- context_size=128,
+ context_size=1024,
lr_warm_up_steps=5000,
# Activation Store Parameters
n_batches_in_buffer=128,
@@ -52,6 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
"fast_ssm": True,
"fast_conv": True,
},
+ model_from_pretrained_kwargs={}
)
SAETrainingRunner(cfg).run()
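Before picking a `context_size`, one can sanity-check the sequence length of the pretokenized dataset directly. A small sketch, where the dataset id and column name are placeholders (use whatever the tutorial config actually points at):

```python
from datasets import load_dataset

# Placeholder dataset id; substitute the pretokenized dataset from the config.
ds = load_dataset("your-org/pretokenized-dataset", split="train", streaming=True)
row = next(iter(ds))
tokens = row.get("tokens") or row.get("input_ids")  # column name varies by dataset
print(len(tokens))  # should equal the context_size passed to the config (1024 here)
```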
***@***.***:/workspace/SAELens# python3 tutorials/mamba_train_example.py
Run name: 131072-L1-1.2e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 4.194304
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 20971.52
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: huggingface/transformers#31884
warnings.warn(
Moving model to device: cuda
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:00<00:00, 254.39it/s]
Traceback (most recent call last):
File "/workspace/SAELens/tutorials/mamba_train_example.py", line 58, in <module>
SAETrainingRunner(cfg).run()
File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
self._init_sae_group_b_decs()
File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
self._storage_buffer = self.get_buffer(self.half_buffer_size)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
refill_batch_tokens = self.get_batch_tokens(
File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
return torch.stack(sequences, dim=0).to(self.model.W_E.device)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?
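The failing line in activations_store.py assumes the wrapped model is a HookedTransformer exposing a `W_E` embedding matrix and reads its device from it. A hedged sketch of a model-agnostic alternative (one possible approach for illustration, not the library's actual fix):

```python
import torch

def stack_on_model_device(sequences, model):
    """Stack token sequences and move them to the model's device without
    relying on a HookedTransformer-style W_E attribute."""
    # Prefer an explicit device from the model's config if present,
    # otherwise fall back to the device of the first parameter.
    device = getattr(getattr(model, "cfg", None), "device", None)
    if device is None:
        device = next(model.parameters()).device
    return torch.stack(sequences, dim=0).to(device)
```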
Describe the bug
Error running with Mamba: 'HookedMamba' object has no attribute 'W_E'.
Code example
Full code: https://github.com/joelburget/mamba-sae/blob/2f87fb99660516c47121aa7a0f65d8944c42778b/hyperparam_sweep.py
Full output
System Info
Checklist