
[WIP] feat: add mlp transcoders #183

Status: Draft
Wants to merge 11 commits into base: main
Conversation

@dtch1997 (Contributor) commented Jun 15, 2024

Description

Add support for training, loading, and running inference on MLP transcoders. A rough sketch of the intended class hierarchy is included below.

  • Add a Transcoder subclass of SAE
  • Add a TrainingTranscoder subclass of TrainingSAE
  • Add a TranscoderTrainer subclass of SAETrainer
  • Add a LanguageModelTranscoderTrainingRunner subclass of LanguageModelSAETrainingRunner

Fixes #182
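How the four classes are intended to relate (a sketch only: the class names come from the list above, but the import paths, docstrings, and bodies here are placeholders inferred from the touched files, not the PR's actual implementation):

```python
# Sketch only: class names are from this PR's description; import paths are
# inferred from the modified files (sae.py, training_sae.py, sae_trainer.py,
# sae_training_runner.py) and may not match the repo exactly.
from sae_lens.sae import SAE
from sae_lens.training.training_sae import TrainingSAE
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.sae_training_runner import LanguageModelSAETrainingRunner


class Transcoder(SAE):
    """Encodes one activation stream (e.g. the MLP input) and decodes into a
    different one (e.g. the MLP output), so d_out need not equal d_in."""


class TrainingTranscoder(TrainingSAE):
    """Training-time variant of Transcoder, mirroring TrainingSAE vs. SAE."""


class TranscoderTrainer(SAETrainer):
    """Computes the reconstruction loss against target activations rather
    than against the encoder's own input."""


class LanguageModelTranscoderTrainingRunner(LanguageModelSAETrainingRunner):
    """Wires up two hook points: one for encoder inputs, one for targets."""
```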

Type of change

Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

Performance Check.

If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:

  • L0
  • CE Loss
  • MSE Loss
  • Feature Dashboard Interpretability

Please link to wandb dashboards with a control and a test group.

codecov bot commented Jun 15, 2024

Codecov Report

Attention: Patch coverage is 49.71751% with 89 lines in your changes missing coverage. Please review.

Project coverage is 58.48%. Comparing base (227d208) to head (7a22a75).
Report is 11 commits behind head on main.

Current head 7a22a75 differs from pull request most recent head 1e020d5

Please upload reports for the commit 1e020d5 to get more accurate results.

Files Patch % Lines
sae_lens/training/sae_trainer.py 11.32% 47 Missing ⚠️
sae_lens/training/training_sae.py 47.61% 22 Missing ⚠️
sae_lens/sae_training_runner.py 34.61% 17 Missing ⚠️
sae_lens/sae.py 95.83% 1 Missing and 1 partial ⚠️
sae_lens/config.py 87.50% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #183      +/-   ##
==========================================
- Coverage   59.59%   58.48%   -1.11%     
==========================================
  Files          25       25              
  Lines        2636     2780     +144     
  Branches      445      466      +21     
==========================================
+ Hits         1571     1626      +55     
- Misses        987     1075      +88     
- Partials       78       79       +1     


@dtch1997 (Contributor, Author) commented Jun 15, 2024

Initial Wandb run here: https://wandb.ai/dtch1997/benchmark/workspace

Benchmarked using the following commands:

poetry run pytest tests/benchmark/test_language_model_transcoder_runner.py --profile-svg -s
poetry run pytest tests/benchmark/test_language_model_sae_runner.py --profile-svg -s

[image: Wandb training loss curves for the two runs]

Green = MLP-out SAE, red = MLP transcoder

  • Red loss goes down = the implementation kind of works!

@dtch1997 changed the title from "[WIP] refactor: sae forward pass" to "[WIP] feat: add mlp transcoders" on Jun 15, 2024
@dtch1997 (Contributor, Author) commented
Current status of the PR:

  • We can train transcoders from scratch

The next priority might be to support pre-trained MLP transcoders.
Going to work on this tomorrow-ish

@dtch1997 (Contributor, Author) commented Jun 19, 2024

Some notes on architecture.

  • minimal coupling, maximal duplication: Implement Transcoder as a totally separate class with code copy-pasted from SAE. (suggested at some point by Joseph)
  • maximal coupling, minimal duplication: Implement both Transcoder and SAE in the same class via a config option. (suggested by Jacob).
  • somewhere in the middle: inheritance / polymorphism, which is what I've started doing.

The right balance depends on a few things:

  • Do we expect that we'll need to change Transcoder independently of SAE? (If so, we'd prefer less coupling.)
  • Can relevant functionality be reused between the two modules? (If we can refactor certain things into external functions, or if it turns out we can use the same Trainer class for both SAE and Transcoder, then we'll have a lot less duplication.)

Feedback on the last two points above would be very useful.

Edit:

  • I think inheritance / polymorphism is a good design choice for implementing Transcoder and TrainingTranscoder, provided we use dependency inversion to make sure changes in SAE won't break Transcoder.
  • Having discussed this with Jacob Dunefsky, I think it's probably possible to merge the SAETrainer / TranscoderTrainer classes (+ corresponding runners); a sketch of what that could look like is below.
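To illustrate that second bullet (not code from this PR, just a hypothetical sketch of the dependency-inversion idea): a merged trainer could depend only on a minimal encode/decode interface, with the transcoder case differing from the SAE case only in which activations are used as the reconstruction target.

```python
from typing import Protocol

import torch
import torch.nn.functional as F


class Encodes(Protocol):
    """Minimal interface a merged trainer would rely on (hypothetical)."""

    def encode(self, x: torch.Tensor) -> torch.Tensor: ...
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor: ...


def reconstruction_loss(
    model: Encodes,
    source_acts: torch.Tensor,  # activations fed to the encoder (e.g. MLP input)
    target_acts: torch.Tensor,  # activations to reconstruct; equal to source_acts for a plain SAE
) -> torch.Tensor:
    # The trainer only touches the abstract interface, so changes to SAE
    # internals cannot break the transcoder path (and vice versa).
    recon = model.decode(model.encode(source_acts))
    return F.mse_loss(recon, target_acts)
```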

Comment on lines +483 to +487
# NOTE: Transcoders have an additional b_dec_out parameter.
# Reference: https://github.com/jacobdunefsky/transcoder_circuits/blob/7b44d870a5a301ef29eddfd77cb1f4dca854760a/sae_training/sparse_autoencoder.py#L93C1-L97C14
self.b_dec_out = nn.Parameter(
    torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
)


I don't understand why the extra bias is needed. I'm probably just confused and missing something, but it would make the implementation simpler if you don't need it.

I understand that in normal SAEs people sometimes subtract b_dec from the input. This isn't really necessary but has a nice interpretation of choosing a new "0 point" which you can consider as the origin in the feature basis.

For transcoders this makes less sense: since you aren't reconstructing the same activations, you probably don't want to tie the pre-encoder bias to the post-decoder bias.

Thus, in the current implementation we compute
$$z = \mathrm{ReLU}(W_{enc}(x - b_{dec}) + b_{enc})$$
and
$$\text{out} = W_{dec}\,z + b_{\text{dec,out}}.$$
This isn't any more expressive: you can always fold the first two biases ($b_{dec}$ and $b_{enc}$ above) into a single bias term. I don't see a good reason why it would result in a more interpretable zero point for the encoder basis either.
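(Added for concreteness, not part of the original comment: the fold is just the rearrangement
$$W_{enc}(x - b_{dec}) + b_{enc} = W_{enc}x + (b_{enc} - W_{enc}b_{dec}),$$
so a checkpoint trained with the pre-encoder bias can be converted to one without it by replacing $b_{enc}$ with $b_{enc} - W_{enc}b_{dec}$.)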

Overall I'd recommend dropping the complexity here, which maybe means you can just eliminate the Transcoder class entirely.

@dtch1997 (Contributor, Author) replied:

This makes sense! I'll try dropping the extra b_dec term when training. I was initially concerned about supporting the previously-trained checkpoints, but as you say, weight folding should solve that.
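A minimal sketch of what that folding could look like when loading an older checkpoint (illustrative only: the state-dict keys and the x @ W_enc shape convention are assumptions, not the actual sae_lens loading code):

```python
import torch


def fold_pre_encoder_bias(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Fold the pre-encoder bias b_dec into b_enc so a checkpoint trained with
    z = ReLU((x - b_dec) @ W_enc + b_enc) loads into a model without b_dec.

    Assumes W_enc has shape (d_in, d_sae), b_dec shape (d_in,), b_enc shape
    (d_sae,); the key names here are hypothetical.
    """
    sd = dict(state_dict)
    # (x - b_dec) @ W_enc + b_enc == x @ W_enc + (b_enc - b_dec @ W_enc)
    sd["b_enc"] = sd["b_enc"] - sd["b_dec"] @ sd["W_enc"]
    # The pre-encoder bias is now redundant; zero it out (or drop the key).
    sd["b_dec"] = torch.zeros_like(sd["b_dec"])
    return sd
```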

Development

Successfully merging this pull request may close these issues: [Proposal] Add MLP transcoders (#182)
2 participants