Add support for Llama 3 (and Llama-2-70b-hf) #549

Merged
merged 7 commits into TransformerLensOrg:main from joelburget:llama3 on Apr 24, 2024

Conversation

@joelburget
Contributor

Description

This adds all four current Llama 3 models and enables Llama-2-70b-hf. No new dependencies are required for this change (but you must have been granted access on Hugging Face to download the weights).
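
For example (a minimal sketch, assuming your Hugging Face account has been granted access and you are logged in locally), any of the new names can be passed straight to from_pretrained:

import transformer_lens

# Any of the newly supported names works the same way, e.g.
# "meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B-Instruct",
# "meta-llama/Meta-Llama-3-70B", "meta-llama/Meta-Llama-3-70B-Instruct",
# or "meta-llama/Llama-2-70b-hf" (the 70B models need a lot of memory).
model = transformer_lens.HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B")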

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

I didn't add tests, but I did write a sanity check:

test.py:

import transformer_lens

# Load one of the newly added Llama 3 models into a HookedTransformer.
model = transformer_lens.HookedTransformer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct"
)

prompts = [
    "Hey how are you doing today?",
    "Two households, both alike in dignity\n(In fair Verona, where we lay our scene),",
    "The Times 03/Jan/2009",
]

for prompt in prompts:
    print(prompt)
    print(model.generate(prompt))
    tokens = model.to_tokens(prompt)
    logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
    print(type(logits))
    print(type(cache))

output:

(transformer-lens-py3.11) root@db2b43c33b4c:/workspace/TransformerLens# python3 test.py
/workspace/TransformerLens/.venv/lib/python3.11/site-packages/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
Downloading shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 161.14it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:18<00:00, 49.55s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping
Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer
Hey how are you doing today?
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:14<00:00,  1.46s/it]
Hey how are you doing today? I am having a pretty good day so far.
<class 'torch.Tensor'>
<class 'transformer_lens.ActivationCache.ActivationCache'>
Two households, both alike in dignity
(In fair Verona, where we lay our scene),
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:17<00:00,  1.74s/it]
Two households, both alike in dignity
(In fair Verona, where we lay our scene), - a grand and beautiful phrase. These words are
<class 'torch.Tensor'>
<class 'transformer_lens.ActivationCache.ActivationCache'>
The Times 03/Jan/2009
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:15<00:00,  1.53s/it]
The Times 03/Jan/2009
Breast cancer trial hears of ‘miracle
<class 'torch.Tensor'>
<class 'transformer_lens.ActivationCache.ActivationCache'>

@joelburget marked this pull request as draft April 20, 2024 18:57
@joelburget
Contributor Author

Looks like the docs fail to build because "Repo model meta-llama/Llama-2-7b-hf is gated. You must be authenticated to access it." Might be the same as #548.

@joelburget marked this pull request as ready for review April 20, 2024 19:32
@neelnanda-io
Collaborator

neelnanda-io commented Apr 20, 2024 via email

@joelburget
Contributor Author

joelburget commented Apr 21, 2024

Sure thing! Hopefully what I have now looks okay. Note that the docs are failing to build, but I think it's because of #548:

Cannot access gated repo for url https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json.
Repo model mistralai/Mistral-7B-v0.1 is gated. You must be authenticated to access it.
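
(Aside for anyone hitting the same gated-repo error locally: authenticating with a Hugging Face token that has access to the repo clears it. A minimal sketch using huggingface_hub; the token value is a placeholder:)

from huggingface_hub import login

# Use a token from https://huggingface.co/settings/tokens for an account that
# has been granted access to the gated repo; "hf_..." is a placeholder.
login(token="hf_...")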

@joelburget
Contributor Author

joelburget commented Apr 21, 2024

A note in case someone needs to do this in the future:

Since the Llama 3 configs have exactly the same structure as the other Llama models (all LlamaForCausalLM), copy the shared config construction into a helper:

def mk_cfg(hf_config):
    return {
        "d_model": hf_config.hidden_size,
        "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
        "n_heads": hf_config.num_attention_heads,
        "d_mlp": hf_config.intermediate_size,
        "n_layers": hf_config.num_hidden_layers,
        "n_ctx": hf_config.max_position_embeddings,
        "eps": hf_config.rms_norm_eps,
        "d_vocab": hf_config.vocab_size,
        "act_fn": hf_config.hidden_act,
        # The current implementation uses Grouped-Query Attention whenever
        # n_key_value_heads is not None, but hf_config.num_key_value_heads is
        # sometimes specified as the same as hf_config.num_attention_heads, in
        # which case GQA should not be used.
        "n_key_value_heads": (
            hf_config.num_key_value_heads
            if hf_config.num_key_value_heads != hf_config.num_attention_heads
            else None
        ),
        "normalization_type": "RMS",
        "positional_embedding_type": "rotary",
        "rotary_adjacent_pairs": False,
        "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
        "final_rms": True,
        "gated_mlp": True,
    }

Then, on a machine signed in to an HF account with access:

>>> from transformers import AutoConfig
>>> mk_cfg(AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B"))
{'d_model': 4096, 'd_head': 128, 'n_heads': 32, 'd_mlp': 14336, 'n_layers': 32, 'n_ctx': 8192, 'eps': 1e-05, 'd_vocab': 128256, 'act_fn': 'silu', 'n_key_value_heads': 8, 'normalization_type': 'RMS', 'positional_embedding_type': 'rotary', 'rotary_adjacent_pairs': False, 'rotary_dim': 128, 'final_rms': True, 'gated_mlp': True}
>>> mk_cfg(AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-70B"))
{'d_model': 8192, 'd_head': 128, 'n_heads': 64, 'd_mlp': 28672, 'n_layers': 80, 'n_ctx': 8192, 'eps': 1e-05, 'd_vocab': 128256, 'act_fn': 'silu', 'n_key_value_heads': 8, 'normalization_type': 'RMS', 'positional_embedding_type': 'rotary', 'rotary_adjacent_pairs': False, 'rotary_dim': 128, 'final_rms': True, 'gated_mlp': True}
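
As a quick check of the GQA condition above (a small sketch, again assuming access to the gated repos): Llama-2-7b-hf has equal query and key/value head counts, so its n_key_value_heads ends up None, while the Llama 3 models genuinely use Grouped-Query Attention:

from transformers import AutoConfig

# Models with fewer key/value heads than attention heads use Grouped-Query Attention.
for name in ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3-8B"]:
    cfg = AutoConfig.from_pretrained(name)
    uses_gqa = cfg.num_key_value_heads != cfg.num_attention_heads
    print(name, cfg.num_attention_heads, cfg.num_key_value_heads, uses_gqa)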

@bryce13950
Collaborator

Yeah, that config is definitely a bit unruly. Revising it to eliminate duplicated code, or finding other ways to make it more manageable, is a worthwhile undertaking. The docs issue is resolved. As long as everything is still passing with the recent changes, I should be able to get this merged shortly.

@bryce13950 merged commit 2092dc9 into TransformerLensOrg:main Apr 24, 2024
8 checks passed
@joelburget deleted the llama3 branch April 25, 2024 00:03