<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/LLaMA.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# LLAVA in TransformerLens

## Setup (skip)

In [19]:
# NBVAL_IGNORE_OUTPUT
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = False
IN_VSCODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")
    
# %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8
# %pip install sentencepiece # Llama tokenizer requires sentencepiece

if IN_COLAB or IN_GITHUB:
    %pip install torch
    %pip install transformer_lens
    %pip install circuitsvis
    
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

import circuitsvis as cv

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using renderer: colab








In [20]:
# Import stuff
import torch
import tqdm.auto as tqdm
import plotly.express as px

from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)
# from transformers import ChameleonModel, AutoTokenizer
from tqdm import tqdm
from jaxtyping import Float



import sys
sys.path.append('/aifs4su/yaodong/changye/TransformerLens')
from transformer_lens import HookedTransformer
from transformer_lens.HookedLlava import HookedLlava
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Loading LLAVA

Trying to load local chameleon model...

In [21]:
# MODEL_PATH = "/aifs4su/yaodong/projects/hantao/dev_cham/align-anything/outputs/0830_4k_sft_flux"
MODEL_PATH = "llava-hf/llava-v1.6-mistral-7b-hf"
# MODEL_PATH = "/aifs4su/yaodong/models/chameleon-7b-hf"
# MODEL_PATH="/aifs4su/yaodong/projects/hantao/anole/facilitating_image_generation/model/chameleon_hf_0830_4k"

processor = LlavaNextProcessor.from_pretrained(MODEL_PATH)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_PATH, 
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True
)

hf_model=vision_model.language_model

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [22]:
model = HookedLlava.from_pretrained(
    MODEL_PATH, 
    hf_model=hf_model,
    torch_dtype=torch.float32, 
    low_cpu_mem_usage=True,
    device="cuda:2",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=None,
    )

Loaded pretrained model llava-hf/llava-v1.6-mistral-7b-hf into HookedTransformer


In [23]:
blocks_and_idxs = list(zip(range(model.cfg.n_layers), model.blocks))
for i, block in blocks_and_idxs:
    print(f"Block {i} is: {block}")

Block 0 is: MistralBlock(
  (ln1): MistralRMSNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (ln2): MistralRMSNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (attn): MistralAttention(
    (hook_k): HookPoint()
    (hook_q): HookPoint()
    (hook_v): HookPoint()
    (hook_z): HookPoint()
    (hook_attn_scores): HookPoint()
    (hook_pattern): HookPoint()
    (hook_result): HookPoint()
    (hook_rot_k): HookPoint()
    (hook_rot_q): HookPoint()
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )
  (mlp): GatedMLP(
    (hook_pre): HookPoint()
    (hook_pre_linear): HookPoint()
    (hook_post): HookPoint()
  )
  (hook_attn_in): HookPoint()
  (hook_q_input): HookPoint()
  (hook_k_input): HookPoint()
  (hook_v_input): HookPoint()
  (hook_mlp_in): HookPoint()
  (hook_attn_out): HookPoint()
  (hook_mlp_out): HookPoint()
  (hook_resid_pre): HookPoint()
  (hook_resid_mid): HookPoint()
  (hook_resid_post): HookPoint()

In [24]:
hf_blocks_and_idxs = list(zip(range(hf_model.config.num_hidden_layers), hf_model.model.layers))

for i, block in hf_blocks_and_idxs:
    print(f"Block {i} is: {block}")
# print(hf_model)

Block 0 is: MistralDecoderLayer(
  (self_attn): MistralAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
  (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
)
Block 1 is: MistralDecoderLayer(
  (self_attn): MistralAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj)

In [25]:
block_params = model.state_dict()
hf_block_params = hf_model.state_dict()
print(block_params.keys())
print(hf_block_params.keys())


odict_keys(['embed.W_E', 'blocks.0.ln1.w', 'blocks.0.ln2.w', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_O', 'blocks.0.attn.b_Q', 'blocks.0.attn.b_O', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_K', 'blocks.0.attn.b_V', 'blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.0.attn.rotary_sin', 'blocks.0.attn.rotary_cos', 'blocks.0.attn.o_proj.weight', 'blocks.0.mlp.W_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.W_gate', 'blocks.0.mlp.b_in', 'blocks.0.mlp.b_out', 'blocks.1.ln1.w', 'blocks.1.ln2.w', 'blocks.1.attn.W_Q', 'blocks.1.attn.W_O', 'blocks.1.attn.b_Q', 'blocks.1.attn.b_O', 'blocks.1.attn.W_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_K', 'blocks.1.attn.b_V', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.1.attn.rotary_sin', 'blocks.1.attn.rotary_cos', 'blocks.1.attn.o_proj.weight', 'blocks.1.mlp.W_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.W_gate', 'blocks.1.mlp.b_in', 'blocks.1.mlp.b_out', 'blocks.2.ln1.w', 'blocks.2.ln2.w', 'blocks.2.attn.W_Q', 'blocks.2.attn.W_O', 'bloc

In [39]:
# for (i, block), (j, hf_block) in zip(blocks_and_idxs, hf_blocks_and_idxs):
#     assert block == hf_block, f"Block {i} does not match: {block} vs {hf_block}"
# print(model.blocks[0].attn.query.weight.shape)
import einops
for i in range(model.cfg.n_layers):
    W_Q=einops.rearrange(block_params[f"blocks.{i}.attn.W_Q"], "n m h -> (n h) m")
    W_K=einops.rearrange(block_params[f"blocks.{i}.attn.W_K"], "n m h -> (n h) m")
    W_V=einops.rearrange(block_params[f"blocks.{i}.attn.W_V"], "n m h -> (n h) m")
    W_O=einops.rearrange(block_params[f"blocks.{i}.attn.W_O"], "n h m -> m (n h)")
    
    device = "cuda:2"
    if not torch.equal(W_Q.to(device),hf_block_params[f"model.layers.{i}.self_attn.q_proj.weight"].to(device)):
        print(f"Block {i} W_Q does not match")
    if not torch.equal(W_K.to(device),hf_block_params[f"model.layers.{i}.self_attn.k_proj.weight"].to(device)):
        print(f"Block {i} W_K does not match")
    if not torch.equal(W_V.to(device),hf_block_params[f"model.layers.{i}.self_attn.v_proj.weight"].to(device)):
        print(f"Block {i} W_V does not match")
    if not torch.equal(W_O.to(device),hf_block_params[f"model.layers.{i}.self_attn.o_proj.weight"].to(device)):
        print(f"Block {i} W_O does not match")
    # print(torch.equal(W_Q.to(device),hf_block_params[f"model.layers.{i}.self_attn.q_proj.weight"].to(device)))
# W_Q=einops.rearrange(block_params[f"blocks.{0}.attn.W_Q"], "n m h -> (n h) m")
# print(W_Q.shape)
# device = "cuda:2"
# print(torch.equal(W_Q.to(device),hf_block_params[f"model.layers.{0}.self_attn.q_proj.weight"].to(device)))
    

In [27]:
# print(model.blocks[0].attn.norm_Q.weight)
# print("="*10)
# print(model.blocks[0].attn.norm_Q.bias)
# print("="*10)
# print(model.blocks[0].attn.norm_K.weight)
# print("="*10)
# print(model.blocks[0].attn.norm_K.bias)

In [28]:
prompt = "Where is the capital of Germany?"
input = processor(prompt, return_tensors="pt")
input_ids = input.input_ids
print(input_ids)
output = model.generate(input_ids, max_new_tokens=20, temperature=0)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

tensor([[    1,  6926,   349,   272,  5565,   302,  7293, 28804]])


  0%|          | 0/20 [00:00<?, ?it/s]

Where is the capital of Germany? making a love a or a or a or a or or or or or or or or Jie


In [29]:
print(model.blocks[0])

MistralBlock(
  (ln1): MistralRMSNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (ln2): MistralRMSNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (attn): MistralAttention(
    (hook_k): HookPoint()
    (hook_q): HookPoint()
    (hook_v): HookPoint()
    (hook_z): HookPoint()
    (hook_attn_scores): HookPoint()
    (hook_pattern): HookPoint()
    (hook_result): HookPoint()
    (hook_rot_k): HookPoint()
    (hook_rot_q): HookPoint()
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )
  (mlp): GatedMLP(
    (hook_pre): HookPoint()
    (hook_pre_linear): HookPoint()
    (hook_post): HookPoint()
  )
  (hook_attn_in): HookPoint()
  (hook_q_input): HookPoint()
  (hook_k_input): HookPoint()
  (hook_v_input): HookPoint()
  (hook_mlp_in): HookPoint()
  (hook_attn_out): HookPoint()
  (hook_mlp_out): HookPoint()
  (hook_resid_pre): HookPoint()
  (hook_resid_mid): HookPoint()
  (hook_resid_post): HookPoint()
)


In [30]:
prompts = [
        "The capital of Germany is",
        "2 * 42 = ", 
        "My favorite", 
        "aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe usaho uaotsnhuaosntuhaosntu haouaoshat u saotheu saonuh aoesntuhaosut aosu thaosu thaoustaho usaothusaothuao sutao sutaotduaoetudet uaosthuao uaostuaoeu aostouhsaonh aosnthuaoscnuhaoshkbaoesnit haosuhaoe uasotehusntaosn.p.uo ksoentudhao ustahoeuaso usant.hsa otuhaotsi aostuhs",
    ]
    
    # 切换到评估模式
model.eval()
hf_model.eval()
tokenizer=AutoTokenizer.from_pretrained(MODEL_PATH)
# print(tokenizer)
# 将模型参数移动到 GPU 上
model_device ="cuda:0"
hf_model_device = "cuda:1"
model=model.to(model_device)
hf_model=hf_model.to(hf_model_device)
    
    # 分别处理每一个 prompt，避免一次性加载太多
# 分别处理每一个 prompt，避免一次性加载太多

Moving model to device:  cuda:0


In [31]:
# Step 1: 定义 prompt 和 tokenizer
prompt = "What is the capital of France?"

# 假设我们使用 Hugging Face 模型的 tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# 分别对 HookedTransformer 和 Hugging Face 模型进行编码
prompt_id_tl = tokenizer.encode(prompt, return_tensors="pt").to(model_device)  # HookedTransformer 模型输入
prompt_id_hf = tokenizer.encode(prompt, return_tensors="pt").to(hf_model_device)  # Hugging Face 模型输入


In [32]:
# Step 2: 定义钩子函数和输出字典

# 定义一个钩子函数来捕获子模块的输出
def hook_fn(module_name, module, input, output):
    if isinstance(output, tuple):
        output = output[0]  # 如果输出是元组，取第一个输出
    return {module_name: output.detach().cpu()}

# 创建字典来存储 HookedTransformer 和 Hugging Face 模型的输出
tl_internal_outputs = {}
hf_internal_outputs = {}


In [33]:
# Step 3: 注册钩子

# 为 Hugging Face 模型的主要子模块（input_layernorm, self_attn, mlp）添加钩子
def register_hf_hooks(hf_model):
    hf_model.model.layers[0].input_layernorm.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("input_layernorm", m, i, o)))
    hf_model.model.layers[0].self_attn.q_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("self_attn.q_proj", m, i, o)))
    hf_model.model.layers[0].self_attn.o_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("self_attn.o_proj", m, i, o)))
    hf_model.model.layers[0].mlp.gate_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("mlp.gate_proj", m, i, o)))
    hf_model.model.layers[0].mlp.down_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("mlp.down_proj", m, i, o)))

# 为 HookedTransformer 模型的各个子模块添加钩子
def register_tl_hooks(model):
    model.blocks[0].hook_resid_pre.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_resid_pre", m, i, o)))
    model.blocks[0].attn.hook_q.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_attn_in", m, i, o)))
    model.blocks[0].attn.hook_z.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_attn_out", m, i, o)))
    model.blocks[0].mlp.hook_pre.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_mlp_in", m, i, o)))
    model.blocks[0].mlp.hook_post.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_mlp_out", m, i, o)))

# 注册 Hugging Face 模型钩子
register_hf_hooks(hf_model)

# 注册 HookedTransformer 模型钩子
register_tl_hooks(model)


In [34]:
print(hf_model.device)
print(prompt_id_hf.device)

cuda:1
cuda:1


In [35]:
# Step 4: 运行模型前向传播
tl_logits = model(prompt_id_tl).detach().cpu()
hf_logits = hf_model(prompt_id_hf).logits.detach().cpu()


In [36]:
# Step 5: 比较两个模型的中间输出
module_mapping = {
    "hook_attn_in": "self_attn.q_proj",
    "hook_attn_out": "self_attn.o_proj",
    "hook_mlp_in": "mlp.gate_proj",
    "hook_mlp_out": "mlp.down_proj"
}

for tl_key, hf_key in module_mapping.items():
    if tl_key in tl_internal_outputs and hf_key in hf_internal_outputs:
        tl_value = tl_internal_outputs[tl_key]
        hf_value = hf_internal_outputs[hf_key]
        print(tl_value.shape)
        
        print(hf_value.shape)
        if tl_key=="hook_attn_in" or tl_key=="hook_attn_out":
            tl_value=tl_value.reshape(1,8,4096)

        if not torch.allclose(tl_value, hf_value, atol=1e-4, rtol=1e-2):
            print(f"Difference found in {tl_key} (TL) vs {hf_key} (HF):")
            print(f"HookedTransformer output: {tl_value}")
            print(f"Hugging Face output: {hf_value}")
            print(f"Difference: {tl_value - hf_value}")


torch.Size([1, 8, 32, 128])
torch.Size([1, 8, 4096])
torch.Size([1, 32, 8, 128])
torch.Size([1, 8, 4096])
Difference found in hook_attn_out (TL) vs self_attn.o_proj (HF):
HookedTransformer output: tensor([[[ 7.2933e-03,  5.3986e-03,  1.2297e-03,  ..., -9.5093e-04,
          -7.8503e-04, -8.8677e-04],
         [ 3.9709e-03, -7.0594e-03,  4.5734e-04,  ..., -1.6104e-02,
          -2.2545e-02,  3.1311e-02],
         [ 7.2933e-03,  5.3986e-03,  1.2297e-03,  ..., -1.3243e-03,
           1.0687e-04, -8.0034e-04],
         ...,
         [ 3.9709e-03, -7.0594e-03,  4.5734e-04,  ...,  1.8842e-02,
          -3.0881e-03,  6.8312e-03],
         [ 7.2933e-03,  5.3986e-03,  1.2297e-03,  ..., -2.8347e-03,
           2.3358e-03, -7.3878e-05],
         [ 3.9709e-03, -7.0594e-03,  4.5734e-04,  ...,  1.1225e-02,
          -8.1580e-03, -3.4510e-04]]])
Hugging Face output: tensor([[[ 0.0019,  0.0014,  0.0015,  ...,  0.0006,  0.0006,  0.0010],
         [-0.0017, -0.0008,  0.0002,  ...,  0.0019,  0.0004,  0.0

RuntimeError: The size of tensor a (14336) must match the size of tensor b (4096) at non-singleton dimension 2

In [155]:
# 定义一个钩子函数来捕获子模块的输出
def hook_fn(module_name, module, input, output):
    if isinstance(output, tuple):
        output = output[0]  # 如果输出是元组，取第一个输出
    return {module_name: output.detach().cpu()}

# 创建字典来存储 HookedTransformer 和 Hugging Face 模型的输出
tl_internal_outputs = {}
hf_internal_outputs = {}

# 映射表，定义 HookedTransformer 模型和 Hugging Face 模型的模块映射
module_mapping = {
    "hook_resid_pre": "input_layernorm",
    "hook_attn_in": "self_attn.q_proj",
    "hook_attn_out": "self_attn.o_proj",
    "hook_mlp_in": "mlp.gate_proj",
    "hook_mlp_out": "mlp.down_proj"
}

# 为 Hugging Face 模型的主要子模块（input_layernorm, self_attn, mlp）添加钩子
def register_hf_hooks(hf_model):
    hf_model.model.layers[0].input_layernorm.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("input_layernorm", m, i, o)))
    hf_model.model.layers[0].self_attn.q_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("self_attn.q_proj", m, i, o)))
    hf_model.model.layers[0].self_attn.o_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("self_attn.o_proj", m, i, o)))
    hf_model.model.layers[0].mlp.gate_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("mlp.gate_proj", m, i, o)))
    hf_model.model.layers[0].mlp.down_proj.register_forward_hook(lambda m, i, o: hf_internal_outputs.update(hook_fn("mlp.down_proj", m, i, o)))

# 为 HookedTransformer 模型的各个子模块添加钩子
def register_tl_hooks(model):
    model.blocks[0].hook_resid_pre.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_resid_pre", m, i, o)))
    model.blocks[0].attn.hook_q.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_attn_in", m, i, o)))
    model.blocks[0].attn.hook_z.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_attn_out", m, i, o)))
    model.blocks[0].mlp.hook_pre.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_mlp_in", m, i, o)))
    model.blocks[0].mlp.hook_post.register_forward_hook(lambda m, i, o: tl_internal_outputs.update(hook_fn("hook_mlp_out", m, i, o)))

# 注册 Hugging Face 模型钩子
register_hf_hooks(hf_model)

# 注册 HookedTransformer 模型钩子
register_tl_hooks(model)
# 定义你的 prompt
prompt = "What is the capital of France?"

# 使用 tokenizer 对 prompt 进行编码
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# 对 HookedTransformer 模型和 Hugging Face 模型分别进行编码
prompt_id_tl = tokenizer.encode(prompt, return_tensors="pt").to(model_device)  # HookedTransformer 模型的输入
prompt_id_hf = tokenizer.encode(prompt, return_tensors="pt").to(hf_model_device)  # Hugging Face 模型的输入
# 运行模型前向传播
tl_logits = model(prompt_id_tl).detach().cpu()
hf_logits = hf_model(prompt_id_hf).logits.detach().cpu()

# 比较两个模型的中间输出
for tl_key, hf_key in module_mapping.items():
    if tl_key in tl_internal_outputs and hf_key in hf_internal_outputs:
        tl_value = tl_internal_outputs[tl_key]
        hf_value = hf_internal_outputs[hf_key]
        if not torch.allclose(tl_value, hf_value, atol=1e-4, rtol=1e-2):
            print(f"Difference found in {tl_key} (TL) vs {hf_key} (HF):")
            print(f"HookedTransformer output: {tl_value}")
            print(f"Hugging Face output: {hf_value}")
            print(f"Difference: {tl_value - hf_value}")


TypeError: hook_fn() missing 1 required positional argument: 'output'

In [151]:

# 分别处理每一个 prompt，避免一次性加载太多
for i, prompt in enumerate(prompts):
    print(f"Processing prompt {i+1}/{len(prompts)}")

    # 将输入移动到模型所在的设备
    prompt_id = tokenizer.encode(prompt, return_tensors="pt").to(model_device)
    prompt_id_hf = tokenizer.encode(prompt, return_tensors="pt").to(hf_model_device)

    # 获取第 0 层的输入
    tl_input = 
    hf_input = prompt_id_hf
    
    # 执行第 0 层的前向传播
    tl_layer_output = model.blocks[0](tl_input)
    hf_layer_output = hf_model.model.layers[0](hf_input)

    # 比较第 0 层的输出
    if not torch.allclose(hf_layer_output, tl_layer_output, atol=1e-4, rtol=1e-2):
        print(f"Difference found at layer 0 for prompt {i}:")
        print(f"hf_layer_output: {hf_layer_output}")
        print(f"tl_layer_output: {tl_layer_output}")
        print(f"Difference: {hf_layer_output - tl_layer_output}")
        
        # 打印最大绝对误差和相对误差
        abs_diff = torch.max(torch.abs(hf_layer_output - tl_layer_output))
        rel_diff = torch.max(torch.abs((hf_layer_output - tl_layer_output) / (tl_layer_output + 1e-8)))
        print(f"Max absolute difference at layer 0: {abs_diff.item()}")
        print(f"Max relative difference at layer 0: {rel_diff.item()}")

        # 放宽误差条件后再检查
        if not torch.allclose(hf_layer_output, tl_layer_output, atol=1e-3, rtol=1e-2):
            print(f"Larger difference persists at layer 0 for prompt {i}, investigate further.")

    # 断言条件，严格验证差异
    assert torch.allclose(hf_layer_output, tl_layer_output, atol=1e-4, rtol=1e-2)


Moving model to device:  cuda:1


TypeError: hook_fn() missing 1 required positional argument: 'output'

In [None]:
torch.cuda.empty_cache()

In [None]:
# prompt = "<image> Tell me what is in the image?"
# image_path = "/aifs4su/yaodong/datasets/aaa_dataset/T2I-preference/0812_t2i_preference_dataset/image/285ed1502f68ad9737af9b4c059d76b82984421692a22f08eff793a0cc6301e3/36ad46bb0f79b5eedaed2818b2744a19d5f89e52e410b43140a875daabebf088.png"
# torch.set_printoptions(threshold=torch.inf)
# from PIL import Image

# image = Image.open(image_path)
# input = processor(prompt, image, return_tensors="pt").to(
#             hf_model.device, dtype=hf_model.dtype
#         )
# input_ids = input.input_ids
# print(input_ids)
input_ids = torch.tensor([[    0,  8197, 
                          3643,   285,  5867,   453,  6540,  6488,   376,  7347,  6468,  1656,
         1057,  1910,  6482,  4008,  5611,   376,  3706,  2941,  7444,  7444,
          712,  5664,  4476,  3234,  6307,   798,  7444,  6345,  5589,  3158,
           40,  1323,   376,  3643,  3181,  1641,   448,  3743,  2895,  1971,
          376,  3642,  5439,  6097,  6820,   226,  3666,   750,  5867,   269,
         7444,  1343,  4214,  2229,  1093,  3000,  6009,  3214,  7755,  4599,
         6420,  5409,  7873,  3396,  6614,   827,  3872,  1741,  7536,   581,
         7349,  4187,  7444,  7444,  3706,   376,  4317,  1303,  7518,  1062,
         1846,  3214,  6488,  4746,  2419,  4462,  6035,  6498,  5104,  7634,
         4129,  2991,  3356,   182,  2872,  4187,  7903,  4237,  2095,  2873,
         4359,  2128,  4405,  1527,  2895,  3772,   798,  1385,  3643,   362,
         7065,  6305,  4102,   376,  5906,  1910,   805,  4502,  2868,  3048,
         8015,  2217,  1596,  2029,  7303,  6045,  7626,  1930,  4472,   251,
         4591,  6811,  7162,  6684,  3048,  2359,  2394,  3214,  4134,  2941,
         7783,  1846,   610,  2953,   863,   805,  6306,    37,  6113,  1238,
         6486,  7991,  7729,  1689,  2290,   651,  6891,  5312,  2644,  5291,
         4384,  4274,  3195,  3296,  6435,  4983,   589,  1846,   481,    10,
         4707,  3320,  6529,   587,  1578,  2659,  6248,  6482,  3214,  4503,
         2881,  5021,   350,  5054,  4234,   685,  6805,  1424,    41,  5274,
         3911,   453,  4084,  5711,    56,  7528,  7720,  5813,  5918,  2348,
          792,  1787,  3025,  6037,  3706,  6820,   805,  3412,  7444,  7893,
         6291,  2547,  1846,  3214,  1324,  6078,  5267,   249,  1240,  6435,
         6761,  5366,  4488,   777,   238,  4554,  4436,  1541,  4187,  1930,
         5949,  3983,  1028,  4661,  5054,  4503,  3911,  4659,  3144,   481,
          792,  3144,  5592,   249,  7395,    77,   363,  4309,  1067,  1590,
         2202,  3643,  2001,  2982,  7201,   604,  4666,  2325,  3110,    63,
         4400,    10,  5730,  6778,   219,  4149,  7905,  5432,  2658,  4642,
         8044,  4790,  4440,   312,  1414,   625,  2182,   363,  2225,  4649,
         6874,  3163,  4791,  5441,  7229,   362,  2466,   191,   481,  4498,
         6922,  1324,  3605,  1272,   863,  2805,  2561,  7189,  4593,  5192,
         3859,  4180,  3218,  3188,  4149,  3252,  3437,  2317,  2820,    10,
         7749,  7569,   675,  1847,  3772,  3427,  3584,  7437,  5447,  3083,
         3248,   604,  4129,  5488,  1245,  2001,  6906,  5936,  6718,  1351,
         1022,  6750,  3609,  3549,   177,  1526,  7893,  7380,  1306,  2557,
         1237,  1218,  1106,  4276,  3610,  6358,  2132,  4440,  5137,   742,
         4869,  5296,   363,  5137,  3204,  7036,  4077,  3819,  2269,  3197,
         6512,  1748,  1292,  4077,  4681,  1091,  7647,  3819,  7489,  2886,
         6380,  4659,  2001,  4284,  5680,  3020,  6622,   930,  5124,  4489,
         1876,  3000,  3819,  2454,  3248,   363,   363,  5519,  2992,  6775,
         7334,  2132,  6068,  1022,  3221,  4077,  2347,   363,  1754,  4422,
         7647,  2886,  2132,  3452,  6725,  4676,  7102,  1028,  4086,  2066,
         4128,  7031,  2501,  4642,  6291,  5866,  1858,  3963,  2132,  3437,
          302,  4818,   445,  1106,  4529,  1681,  1641,  5540,  2001,   363,
         7489,  3605,  7647,  2132,  4354,  1930,  7402,  5192,   953,  5946,
         7524,  3810,  6072,  3181,  3020,  5486,  3382,  2672,   407,  3810,
         4671,  6522,  6782,  7547,  6115,  2132,  5716,  5375,  7450,  3661,
         4211,  1106,  2590,  1861,  3601,  3356,  1263,  3540,  1208,  6360,
         3900,  5278,   862,  6607,   103,  4537,  4446,  3206,  5239,  2908,
         6584,  1858,  2284,  3319,  5503,  4135,  4671,  5307,  7631,   354,
         5293,  4374,   221,  2749,  6861,  4211,  3665,  4029,  5307,  1682,
         3188,  1245,    19,   502,   390,  5798,  4239,  3020,  4686,  4184,
         4184,  2243,  3831,  3831,  5706,  6260,  2240,  3135,  4259,  2064,
         3168,  8081,  3020,  1974,  1016,  4180,  4146,  6085,   637,  1331,
         1079,   741,  6725,  3862,  3190,  1745,  5779,  1453,  5402,  4821,
         2036,  5866,  1578,  1298,  1706,  3631,   390,  5861,  1708,  1938,
         5034,  1049,  6524,  1876,  1648,  4574,  5074,  8044,  4317,  1882,
          975,  2900,  4251,  2900,  1420,   133,  1453,  1077,  2784,  1332,
         4837,  2773,  3244,  7514,  4416,   894,   346,  5483,  6345,  7622,
         5486,  1642,  2787,  5378,   973,  7099,   614,  7748,  4953,   777,
         5127,  5757,  1558,  4729,  7292,  1331,  3420,  2401,  2731,  1298,
          943,  6629,  5838,  5905,  7935,  4163,  6199,   551,  4968,  4951,
          141,  3181,  3155,   480,  4512,  3544,  2318,  2145,  4213,  6852,
         8147,  4300,  3446,  6511,  3891,  7374,  7774,  5906,  1659,  5592,
         3096,  4192,   229,  3242,  4477,  3688,  4298,  5618,  2445,  1628,
         1557,  3841,   833,   381,  7065,  3762,  3172,  3020,   689,  4146,
         1838,  7729,  7287,   768,  2291,  3008,  3543,  7864,  6028,   959,
         7603,  2508,   142,  7528,  4254,  6896,  7075,  2941,  1264,  1582,
         1930,  2048,   543,  6139,  2961,  4405,  2289,  2004,  5865,  8030,
         6524,  3008,  6438,  5278,  4266,  2291,  4407,  1181,  3296,  3110,
         5652,  5320,  6390,  2659,  1018,  4192,  7592,  5723,  5564,  1237,
         2268,  3296,  4029,  1057,   520,  2463,  4883,    79,  2514,  6390,
          693,  1611,  3610,  3946,  1832,  6800,  4478,   177,  7990,  2941,
         7154,  2895,  4877,  3606,  1522,   854,  3172,  6290,  3559,   905,
         4405,  1630,  7347,  3778,  4029,  4484,  5453,  6191,  7183,  3769,
         7510,  4686,  6629,  4615,  3042,  5681,  1440,   630,  7011,  5342,
         2380,  3645,   467,  3155,  5774,  1166,  1459,  7347,  5103,   249,
          346,  5664,  6291,  8044,  8067,  4986,  4462,  2245,  5578,  1215,
         3108,  4331,  5271,  1645,  1202,  1288,  5554,  1745,  1331,  7032,
         1028,  1067,  6462,  1630,   947,  6155,  3016,  5278,   242,  5785,
         2869,  2744,  2744,  7626,  8187,  1503,  3636,  2716,  4485,  7047,
         2443,  2372,  3008,  1846,  1363,  2036,  3108,   312,  5428,  4976,
           46,  2466,  3255,  1930,  5892,  4503,  5320,  2113,  1584,  2338,
          318,   476,  2998,  3907,  1747,   742,  2492,  3808,  5525,   386,
         5362,  4849,   930,  3767,  1213,  1067,  3103,  5026,  8160,  4821,
         2126,  5313,  4239,  2287,  3252,  3214,  1414,  1526,  2645,  6111,
         1317,  5478,   382,  4780,  2012,  6664,  5682,  3769,  1832,  4530,
         4512,   275,   714,   474,  2021,  6706,  2291,   355,  4276,  5376,
         7195,  5901,   610,  2268,  2941,  1223,  4574,  2287,  4008,  1810,
         1420,  7078,  2765,   635,  3867,  5662,    79,  8140,  7820,  7510,
         3181,  2998,  3220,   318,  5255,  1306,  6922,  7380,  1531,   429,
         3221,   543,  6761,   476,  2500,  3245,  6571,  6199,  2387,  3096,
         6814,  1354,   348,  6162,  4596,  1440,  2036,    37,  2460,  4590,
         4686,  1363,  3447,   472,  6199,  3373,  7487,  6486,  3645,  1199,
         5066,  4970,  2287,  3412,  1240,  4955,  3131,  1201,  6271,  7198,
          269,  4951,  3373,  6953,  7351,  5871,  4321,  5711,   881,  6861,
          486,  7192,  5103,  2289,  1198,  7394,  3459,  6045,  4550,  3345,
          931,  3868,  5146,  4837,  4780,   703,  2965,  1157,  4676,  6203,
         1610,  6941,   182,  7287,  4819,   226,  2869,  3686,  6024,  4145,
         7489,  1223,  6486,  3614,   312,   550,  5259,  2174,  1198,  6723,
          735,  2557,  5946,  1068,  6177,   974,  3898,  2788,  1610,  4409,
         7834,   789,  7098,  6198,  8060,  6198,   295,   224,    67,  2561,
         5737,  4407,  2384,   651,  7278,  4135,  2153,  5408,  3129,  7705,
          476,  6248,  3033,  3706,  8196,  28862, 16848, 17016,
         16704, 16672, 16647, 19521, 16414,  8710]])

print(input_ids.shape)
output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
print(output[0])
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

## Loading LLAVA from transformers

Load a chameleon model from transformers, and compare the outputs, the logits, and the hidden states to ensure we did a good job integrating the model.

In [71]:
hf_model = hf_model.to("cuda:1" if torch.cuda.is_available() else "cpu")

In [None]:
prompt = "Where is the capital of Germany?"
# image_path = "/aifs4su/yaodong/datasets/aaa_dataset/T2I-preference/0812_t2i_preference_dataset/image/285ed1502f68ad9737af9b4c059d76b82984421692a22f08eff793a0cc6301e3/36ad46bb0f79b5eedaed2818b2744a19d5f89e52e410b43140a875daabebf088.png"

# from PIL import Image

# image = Image.open(image_path)
input = processor(prompt, return_tensors="pt").to(
            hf_model.device, dtype=hf_model.dtype
        )
print(input.input_ids)
input_ids = input.input_ids

output = hf_model.generate(input_ids.to(hf_model.device), max_new_tokens=20, do_sample=False)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
# get shape of the weights
# for name in hf_model.state_dict():
#     print(name, hf_model.state_dict()[name].shape)
    
# print(hf_model.state_dict()["model.layers.0.self_attn.q_norm.weight"])
# print(hf_model.state_dict()["model.layers.0.input_layernorm.weight"])
# print(hf_model.state_dict()["model.layers.0.post_attention_layernorm.weight"])
# print(hf_model.state_dict()["model.layers.0.self_attn.q_norm.weight"])
# print(hf_model.state_dict()["model.layers.0.self_attn.q_proj.weight"])
print(model.state_dict())

In [72]:
hf_blocks_and_idxs = list(zip(range(hf_model.config.num_hidden_layers), hf_model.named_modules()))
for i, block in hf_blocks_and_idxs:
    print(f"Block {i} is: {block}")

Block 0 is: ('', MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32064, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps

### Compare logits with HuggingFace model

In [73]:
prompts = [
    "Where is the capital of Germany?",
    "Calculate 2 * 42 = ", 
    "My favorite", 
    "My favorite place is",
]

model.eval()
hf_model.eval()
tokenizer = processor.tokenizer
prompt_ids = [tokenizer.encode(prompt, return_tensors="pt") for prompt in prompts]
tl_logits = [model(prompt_ids).detach().cpu() for prompt_ids in tqdm(prompt_ids)]

# hf logits are really slow as it's on CPU. If you have a big/multi-GPU machine, run `hf_model = hf_model.to("cuda")` to speed this up
logits = [hf_model(prompt_ids.to(hf_model.device)).logits.detach().cpu() for prompt_ids in tqdm(prompt_ids)]

for i in range(len(prompts)): 
    if not torch.allclose(logits[i], tl_logits[i], atol=1e-2, rtol=1e-2):
        print(f"Logits for prompt {i} are not close")
        print(f"Logits from HuggingFace: shape {logits[i].shape}")
        print(f"Logits from TransformerLens: shape {tl_logits[i].shape}")
        diff = torch.abs(logits[i] - tl_logits[i]) > 1e-2
        indices = torch.nonzero(diff)
        for index in indices:
            row, col, loc = index[0], index[1], index[2]
            print(f"Diff at {index}: HuggingFace={logits[i][row, col, loc]}, TransformerLens={tl_logits[i][row, col, loc]}")

  0%|          | 0/4 [00:35<?, ?it/s]


UnboundLocalError: local variable 'extreme_negative_value' referenced before assignment

In [None]:
# compare hidden states

tl_hidden_states = [model(prompt_ids, return_type="hidden_states", stop_at_layer=1).detach().cpu() for prompt_ids in tqdm(prompt_ids)]
hf_hidden_states = [hf_model(prompt_ids.to(hf_model.device), output_hidden_states=True, output_attentions=True).hidden_states[1].detach().cpu() for prompt_ids in tqdm(prompt_ids)]

for i in range(len(prompts)): 
    print(f"Shape of hf hidden states: {hf_hidden_states[i].shape}")
    print(f"Shape of tl hidden states: {tl_hidden_states[i].shape}")
    if not torch.allclose(hf_hidden_states[i], tl_hidden_states[i], atol=1e-4, rtol=1e-2):
        print(f"Hidden states for prompt {i} are not close")
    print(f"Hidden states from HuggingFace: {hf_hidden_states[i]}")
    print(f"Hidden states from TransformerLens: {tl_hidden_states[i]}")

In [None]:
# compare attentions

tl_attentions = [model(prompt_ids, return_type="attentions")[2].detach().cpu() for prompt_ids in tqdm(prompt_ids)]
hf_attentions = [hf_model(prompt_ids.to(hf_model.device), output_hidden_states=True, output_attentions=True).attentions[2].detach().cpu() for prompt_ids in tqdm(prompt_ids)]

for i in range(len(prompts)): 
    print(f"Shape of hf attentions: {hf_attentions[i].shape}")
    print(f"Shape of tl attentions: {tl_attentions[i].shape}")
    if not torch.allclose(hf_attentions[i], tl_attentions[i], atol=1e-4, rtol=1e-2):
        print(f"Attentions for prompt {i} are not close")
        print(f"Attentions from HuggingFace: {hf_attentions[i]}")
        print(f"Attentions from TransformerLens: {tl_attentions[i]}")

# 

## TransformerLens Demo

### Reading from hooks

In [None]:
llama_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
llama_tokens = model.to_tokens(llama_text)
llama_logits, llama_cache = model.run_with_cache(llama_tokens, remove_batch_dim=True)

attention_pattern = llama_cache["pattern", 0, "attn"]
llama_str_tokens = model.to_str_tokens(llama_text)

print("Layer 0 Head Attention Patterns:")
display(cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern))

### Writing to hooks

In [None]:
layer_to_ablate = 0
head_index_to_ablate = 31

# We define a head ablation hook
# The type annotations are NOT necessary, they're just a useful guide to the reader
# 
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.
    return value

original_loss = model(llama_tokens, return_type="loss")
ablated_loss = model.run_with_hooks(
    llama_tokens, 
    return_type="loss", 
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate), 
        head_ablation_hook
        )]
    )
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")