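This notebook shows that the TransformerLens implementation of Llama 2 (7B chat) does not reproduce the logits of the Hugging Face implementation: with both models loaded in float32 on CPU, the logits for the same prompt differ by more than an absolute tolerance of 1.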

### Imports

In [1]:
import torch
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

### Load models

In [2]:
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Options shared by both Hugging Face checkpoint loads, so the reference model
# and the copy handed to TransformerLens are read with identical settings.
HF_LOAD_KWARGS = dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# TransformerLens weight-rewriting options, all switched off here.
# NOTE(review): presumably so the converted weights stay as close as possible
# to the Hugging Face checkpoint — confirm.
TL_PROCESSING_FLAGS = dict(
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Reference implementation, kept on the CPU.
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **HF_LOAD_KWARGS).cpu()

# The HookedTransformer is built from a second, separately loaded copy of the
# Hugging Face checkpoint rather than from `hf_model` itself.
tl_model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=AutoModelForCausalLM.from_pretrained(MODEL_NAME, **HF_LOAD_KWARGS),
    tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
    device="cpu",
    n_devices=1,
    move_to_device=True,
    torch_dtype=torch.float32,
    **TL_PROCESSING_FLAGS,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


In [7]:
def check_similarity_with_hf_model(
    tl_model: HookedTransformer,
    hf_model: AutoModelForCausalLM,
    atol: float,
    prompt: str = "Hello world!",
):
    """Assert that TransformerLens and Hugging Face produce matching logits.

    The prompt is tokenized once with ``tl_model.tokenizer`` and the same
    token tensor is fed to both models. Both logit tensors are printed and
    then compared with ``torch.allclose``.

    Args:
        tl_model: TransformerLens model; its ``tokenizer`` is used to encode
            the prompt.
        hf_model: Hugging Face causal LM whose output exposes ``.logits``.
        atol: Absolute tolerance passed to ``torch.allclose`` (the relative
            tolerance is left at the ``torch.allclose`` default).
        prompt: Text to run through both models.

    Raises:
        AssertionError: If the logits are not close within ``atol``. The
            message reports the largest absolute difference.
    """
    tokens = tl_model.tokenizer.encode(prompt, return_tensors="pt")
    # The tokens are passed through exactly as encoded.
    # NOTE(review): presumably the tokenizer already adds BOS, which is why
    # prepend_bos is disabled — confirm.
    tl_logits = tl_model(tokens, prepend_bos=False)
    hf_logits = hf_model(tokens).logits

    print(tl_logits, hf_logits)
    # The message is only evaluated when the check fails, so the passing
    # path does no extra work.
    assert torch.allclose(tl_logits, hf_logits, atol=atol), (
        f"TransformerLens and Hugging Face logits differ for prompt {prompt!r}: "
        f"max abs diff = {(tl_logits - hf_logits).abs().max().item():.4f} "
        f"(atol={atol})"
    )


# Only forward passes are compared, so gradient tracking is disabled.
# atol=1 is a very loose absolute tolerance for logits; a failure here
# indicates a substantial mismatch between the two implementations.
with torch.no_grad():
    check_similarity_with_hf_model(
        tl_model=tl_model,
        hf_model=hf_model,
        atol=1,
    )

tensor([[[ 0.2226,  0.0299,  0.2729,  ...,  1.4124,  1.9937,  0.7167],
         [-7.5956, -1.9652, -1.2686,  ..., -6.3465, -4.6871, -7.5035],
         [-3.2299,  2.0011,  6.4700,  ..., -1.7174, -2.2347, -2.4737],
         [-8.5657, -5.2156,  5.2663,  ..., -3.5372, -2.7532, -4.0162]]]) tensor([[[ 0.2707,  0.0165,  0.2806,  ...,  1.4403,  2.0234,  0.7647],
         [-7.5216, -2.1810, -1.1470,  ..., -6.3703, -4.6442, -7.4660],
         [-3.6662,  1.9481,  6.2992,  ..., -2.1683, -2.4888, -2.6284],
         [-8.7830, -5.1027,  5.4732,  ..., -3.7767, -2.9180, -4.4215]]])


AssertionError: 