<a href="https://colab.research.google.com/github/davidcrispell/LLM-PC/blob/main/TL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Setup

In [None]:
# Log in to the Hugging Face Hub from inside the notebook (shell escape).
# NOTE(review): presumably required because the checkpoints loaded further
# down are gated/private — confirm, and never paste a token into a saved cell.
!huggingface-cli login

In [None]:
    %pip install transformer_lens
    %pip install einops
    %pip install jaxtyping

In [2]:
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

#### Load model using TransformerLens

In [None]:
# LLAMA_PATH: Hub id of the checkpoint whose weights are loaded.
# SKELETON_PATH: model name given to TransformerLens to pick the architecture.
# NOTE(review): LLAMA_PATH names a Llama 3.1 checkpoint while SKELETON_PATH
# names a Llama 3 model — confirm the two configs (e.g. rotary-embedding
# settings) really match before trusting downstream results.
LLAMA_PATH = "LLM-PBE/Llama3.1-8b-instruct-LLMPC-Red-Team"
SKELETON_PATH = "meta-llama/meta-Llama-3-8b-instruct"

tokenizer = AutoTokenizer.from_pretrained(LLAMA_PATH)

# The weights are loaded through Hugging Face first and then handed to
# TransformerLens via `hf_model=`, so that the wrapped model carries the
# LLAMA_PATH weights rather than those published under SKELETON_PATH.
hf_model = LlamaForCausalLM.from_pretrained(
    LLAMA_PATH,
    low_cpu_mem_usage=True,
)

# Build the hooked model on the CPU and leave the weights unprocessed
# (no LayerNorm folding, no centering of writing weights or the unembed).
model = HookedTransformer.from_pretrained(
    SKELETON_PATH,
    hf_model=hf_model,
    device="cpu",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Once both models exist, move them to the GPU if one is available.
if t.cuda.is_available():
    print("Using cuda")
    hf_model, model = hf_model.to("cuda"), model.to("cuda")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Quick smoke test of the hooked model: continue a simple prompt by at most
# 20 new tokens with temperature set to 0. Left as the cell's last expression
# so the generated text is displayed.
model.generate(
    "The capital of Germany is",
    max_new_tokens=20,
    temperature=0,
)

#### Check that logits are identical for Hugging Face and TL

In [None]:
# Prompts used to compare the two implementations: short factual/arithmetic
# text, an open-ended prefix, and a long run of gibberish.
prompts = [
    "The capital of Germany is",
    "2 * 42 = ",
    "My favorite",
    "aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe usaho uaotsnhuaosntuhaosntu haouaoshat u saotheu saonuh aoesntuhaosut aosu thaosu thaoustaho usaothusaothuao sutao sutaotduaoetudet uaosthuao uaostuaoeu aostouhsaonh aosnthuaoscnuhaoshkbaoesnit haosuhaoe uasotehusntaosn.p.uo ksoentudhao ustahoeuaso usant.hsa otuhaotsi aostuhs",
]

model.eval()
hf_model.eval()

# Each prompt is tokenized on its own, so every forward pass below sees a
# single sequence and no padding is involved.
prompt_ids = [tokenizer.encode(prompt, return_tensors="pt") for prompt in prompts]

# The tokenizer returns CPU tensors, but the loading cell moves both models to
# CUDA when it is available. Send the ids to wherever each model's weights
# live so the forward passes do not fail with a device mismatch.
tl_device = next(model.parameters()).device
hf_device = next(hf_model.parameters()).device

# Only the logits are needed, so skip autograd bookkeeping for the 8B models.
with t.no_grad():
    tl_logits = [model(ids.to(tl_device)).detach().cpu() for ids in tqdm(prompt_ids)]
    logits = [hf_model(ids.to(hf_device)).logits.detach().cpu() for ids in tqdm(prompt_ids)]

# Both implementations must agree on every prompt, within the given tolerance.
for prompt, hf_out, tl_out in zip(prompts, logits, tl_logits):
    assert t.allclose(hf_out, tl_out, atol=1e-4, rtol=1e-2), (
        f"HF and TransformerLens logits differ for prompt {prompt[:40]!r} "
        f"(max abs diff {(hf_out - tl_out).abs().max().item():.3e})"
    )