### Imports

In [1]:
import torch
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

### Load Hugging Face model

In [2]:
# Downloading this checkpoint requires being logged in to Hugging Face.
# Run this once in a shell before executing the cell:
#   huggingface-cli login
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Load the reference Hugging Face model, then move it onto the GPU.
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
hf_model = hf_model.cuda()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Bare expression: the notebook shows the model's module tree as this cell's output.
hf_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_

### Load TransformerLens model

In [5]:
# A second copy of the Hugging Face checkpoint; unlike `hf_model` it is not
# moved to the GPU here. It is handed to HookedTransformer as the weight source.
hf_model_cpu = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)

# Build the TransformerLens model on the GPU, reusing the tokenizer loaded earlier.
# NOTE(review): fold_ln / center_writing_weights / center_unembed are all turned
# off, presumably so the weights stay as close as possible to the Hugging Face
# checkpoint for the comparison that follows (confirm against TransformerLens docs).
tl_model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    hf_model=hf_model_cpu,
    tokenizer=tokenizer,
    device="cuda",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


### Check outputs 

In [16]:
# Shared prompt and its token ids (on the GPU) for both models.
PROMPT = "The capital of Germany is"
TOKENS = tokenizer(PROMPT, return_tensors="pt")["input_ids"].cuda()

# Run one forward pass through each model without tracking gradients.
with torch.no_grad():
    # Hugging Face returns an output object; keep a copy of its logits.
    LOGITS_HF = hf_model(TOKENS).logits.clone()

    # Calling the TransformerLens model on the same tokens gives the tensor
    # that is compared against the Hugging Face logits below.
    LOGITS_TL = tl_model(TOKENS)

print(LOGITS_HF.shape, LOGITS_TL.shape)

# Both implementations must agree to within an absolute tolerance of 1e-4.
assert torch.allclose(
    LOGITS_HF,
    LOGITS_TL,
    atol=1e-4,
)

torch.Size([1, 6, 32000]) torch.Size([1, 6, 32000])


In [41]:
max_new_tokens = 300

with torch.no_grad():
    # Hugging Face side: generate with sampling disabled, then decode the
    # first returned sequence with the TransformerLens model's to_string.
    hf_output_ids = hf_model.generate(
        TOKENS, max_new_tokens=max_new_tokens, do_sample=False
    )
    COMPL_HF = tl_model.to_string(hf_output_ids[0])
    print(COMPL_HF)

    # TransformerLens side: generate from the raw prompt string at temperature 0.
    COMPL_TL = tl_model.generate(
        PROMPT,
        temperature=0,
        max_new_tokens=max_new_tokens,
    )

# The decoded Hugging Face text must begin with "<s> "; everything after that
# prefix must match the TransformerLens completion exactly.
assert COMPL_HF.startswith("<s> ")
assert COMPL_HF[len("<s> "):] == COMPL_TL



<s> The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions, and vibrant nightlife. The city is home to many famous landmarks, including the Berlin Wall, the Brandenburg Gate, and the Berlin Cathedral. Berlin is also a hub for art and culture, with numerous museums, galleries, and festivals throughout the year.
The city has a long and complex history, having been the capital of Prussia, the German Empire, and the Weimar Republic before becoming the capital of the Federal Republic of Germany in 1949. Berlin has been influenced by many different cultures throughout its history, including the Slavic peoples, the Franks, the Teutonic Knights, and the Soviet Union.
Today, Berlin is a thriving metropolis with a diverse population and a strong economy. The city is home to many international companies and organizations, and is a popular destination for tourists and expats alike. Berlin is also known for its vibrant nightli

  0%|          | 0/300 [00:00<?, ?it/s]