
Introducing HookedEncoder #276

Merged: 27 commits merged into TransformerLensOrg:main from rusheb-bert-WIP on May 19, 2023
Conversation

@rusheb (Collaborator) commented on May 18, 2023

Description

This feature was co-authored with @MatthewBaggins. Thanks also to @luciaquirke and @jbloomAus for helpful discussions throughout.

Closes issue #258.

This PR introduces HookedEncoder, a BERT-style encoder that inherits from HookedRootModule. Weights can be loaded from the Hugging Face bert-base-cased pretrained model.

Unlike HookedTransformer, it does not (yet) do any pre-processing of the weights (e.g. folding LayerNorm). Another difference is that the model can currently only be run on tokens, not strings or lists of strings. The supported task/architecture is masked language modelling; next sentence prediction, causal language modelling, and other tasks are not supported. HookedEncoder also does not contain dropout layers, which may lead to inconsistent results if it is used for pretraining.
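
For orientation, here is a minimal usage sketch based on the description above. It assumes that from_pretrained and run_with_cache mirror the HookedTransformer interface and that the forward pass returns masked-LM logits of shape [batch, pos, d_vocab]; since strings aren't accepted directly yet, it uses the matching Hugging Face tokenizer to produce token ids. Treat it as an illustration rather than the exact API.

```python
from transformers import AutoTokenizer
from transformer_lens import HookedEncoder

# Strings aren't accepted directly yet, so tokenize with the matching
# Hugging Face tokenizer and pass token ids to the model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
bert = HookedEncoder.from_pretrained("bert-base-cased")

prompt = "The capital of France is [MASK]."
tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]

# Assumed to return masked-LM logits of shape [batch, pos, d_vocab]
# plus a cache of all hook activations, as with HookedTransformer.
logits, cache = bert.run_with_cache(tokens)

mask_pos = (tokens[0] == tokenizer.mask_token_id).nonzero().item()
predicted_token = logits[0, mask_pos].argmax().item()
print(tokenizer.decode(predicted_token))  # hopefully "Paris"
```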

This is an MVP implementation that serves as a starting point to iterate on. I've tried to keep the scope as small as possible for a number of reasons:

  • Getting earlyish feedback on this PR
  • Getting earlier user feedback
  • Reducing the risk of getting distracted by other priorities. (Personally, I won't have much time to work on this over the next two months.)

Notes

  • In the end, based on the discussion in #262 (Introduce HookedEncoderConfig, issue #258), I decided to reuse HookedTransformerConfig rather than creating a new HookedEncoderConfig class. The configs could still be separated later if the need arises; a rough sketch of a reused config follows this list.
  • I chose the masked language modelling task based on feedback from people who want to do research using BERT.
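
To make the config-reuse decision concrete, here is an illustrative bert-base-sized configuration expressed with HookedTransformerConfig. The field values are the standard bert-base hyperparameters rather than values copied from this PR, so treat it as an example of the pattern, not the canonical config.

```python
from transformer_lens import HookedTransformerConfig

# Illustrative bert-base-style config reusing HookedTransformerConfig
# (values are the usual bert-base hyperparameters, not taken from the PR).
bert_base_cfg = HookedTransformerConfig(
    n_layers=12,
    d_model=768,
    n_ctx=512,
    d_head=64,
    n_heads=12,
    d_mlp=3072,
    d_vocab=28996,  # bert-base-cased vocabulary size
    act_fn="gelu",
)
```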

Key uncertainties

  • I'm really not sure about the naming. Currently I have named the main transformer class HookedEncoder, but I've prefixed the new component names with Bert.
  • I'm hopeful that, like HookedTransformer, this will be easy to extend to support more variations on the architecture, but I'd like to hear if anybody thinks otherwise.

Summary of Changes

  • Add new class HookedEncoder
  • Add new components
    • TokenTypeEmbed
    • BertEmbed
    • BertMLMHead
    • BertBlock
  • Add an additive_attention_mask parameter to the forward method of the Attention component (see the sketch after this list)
  • Add BERT config and state dict to loading_from_pretrained
  • Extract methods from HookedTransformer for reuse:
    • devices.move_to_and_update_config
    • loading.fill_missing_keys
  • Add demo notebook demos/BERT.ipynb
  • Update Available Models list in Main Demo
  • Testing
    • Unit and acceptance tests for HookedEncoder and sub-components
    • New demo in demos/BERT.ipynb also acts as a test
    • I also added some tests for existing components e.g. HookedTransformerConfig
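
As a pointer for the additive_attention_mask change listed above, the general pattern for an additive attention mask is sketched below: a [batch, pos] padding mask is turned into a tensor of zeros (attend) and large negative values (ignore) that is added to the attention scores before the softmax. The helper name and shapes here are illustrative, not the PR's exact signature.

```python
import torch

def additive_attention_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Convert a [batch, pos] mask (1 = real token, 0 = padding) into an
    additive mask broadcastable over attention scores of shape
    [batch, head, query_pos, key_pos]."""
    # 0 where attention is allowed, a very large negative number where it
    # isn't, so masked positions vanish after the softmax.
    neg_inf = torch.finfo(torch.float32).min
    return (1 - padding_mask.float())[:, None, None, :] * neg_inf

# Usage inside an attention layer (illustrative):
# scores = scores + additive_attention_mask(padding_mask)
# pattern = scores.softmax(dim=-1)
```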

Future work

  • Add support for different tasks, e.g. Next Sentence Prediction, Causal Language Modelling
  • Add more models: bert-base-uncased, bert-large-cased, bert-large-uncased
  • Add preprocessing of weights, including LayerNorm folding (a sketch of the folding idea follows this list)
  • Accept strings as input and add tokenization helpers from HookedTransformer
  • Add support for training/finetuning (most notably, dropouts)
  • Add examples of research using BERT to the demo notebooks
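
For reference, the idea behind the LayerNorm-folding item above (not implemented in this PR) is that a LayerNorm's learned scale and bias can be absorbed into the linear layer that reads from it, leaving only the centring and normalisation behind. A minimal sketch of that algebra, with illustrative names and shapes:

```python
import torch

def fold_layer_norm_into_linear(gamma, beta, W, b):
    """Fold LayerNorm's learned scale (gamma) and bias (beta) into the
    linear layer y = x @ W + b that consumes the LayerNorm output.
    Shapes: gamma, beta: [d_in]; W: [d_in, d_out]; b: [d_out]."""
    W_folded = W * gamma[:, None]   # scale folds into the weight rows
    b_folded = b + beta @ W         # bias folds into the linear layer's bias
    return W_folded, b_folded

# Sanity check: folding leaves the composed function unchanged.
d_in, d_out = 4, 3
gamma, beta = torch.rand(d_in), torch.rand(d_in)
W, b = torch.rand(d_in, d_out), torch.rand(d_out)
x_hat = torch.rand(2, d_in)  # pretend this is already centred and normalised
W_f, b_f = fold_layer_norm_into_linear(gamma, beta, W, b)
assert torch.allclose((gamma * x_hat + beta) @ W + b, x_hat @ W_f + b_f, atol=1e-6)
```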

Type of change


  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@jbloomAus (Collaborator) commented:

Thanks @rusheb and @MatthewBaggins ! This is stellar work! I can't wait to see subsequent investigations!

@rusheb mentioned this pull request on May 19, 2023
@rusheb merged commit c268a71 into TransformerLensOrg:main on May 19, 2023
4 checks passed
@rusheb deleted the rusheb-bert-WIP branch on May 19, 2023 at 09:31