In [2]:
%load_ext autoreload
%autoreload 2

In [31]:
import torch
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    repeat_kv,
)

In [4]:
# model_name = "openai-community/gpt2"
model_name = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

In [5]:
from hooked_transformer.auto_hooked_model import AutoHookedModelForCausalLM

hooked_model = AutoHookedModelForCausalLM.from_pretrained(model_name)

Added alias pre_attn_norm for model.layers[*].pre_attn_norm -> input_layernorm in LlamaDecoderLayer
Added alias eps for model.layers[*].pre_attn_norm.eps -> variance_epsilon in LlamaRMSNorm
Added alias pre_mlp_norm for model.layers[*].pre_mlp_norm -> post_attention_layernorm in LlamaDecoderLayer


In [6]:
hooked_model.model.config

LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "float32",
  "transformers_version": "4.53.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [7]:
hooked_model.model.io_keys

{'embed_tokens': {'output': ['hidden_states']},
 'layers': {'self': {'output': ['hidden_states']},
  'pre_attn_norm': {'output': ['hidden_states']},
  'self_attn': {'output': ['hidden_states', 'attn_weights']},
  'q_proj': {'output': ['hidden_states']},
  'k_proj': {'output': ['hidden_states']},
  'v_proj': {'output': ['hidden_states']},
  'pre_mlp_norm': {'output': ['hidden_states']},
  'mlp': {'output': ['hidden_states']}},
 'norm': {'output': ['hidden_states']},
 'lm_head': {'output': ['logits']}}

In [8]:
hook_config = {
    "model.embed_tokens": ["out_hidden_states", "in_input"],
    "lm_head": ["out_logits", "in_input"],
    "model.norm": ["out_hidden_states", "in_hidden_states"],
    "model.model.layers[0]": ["out_hidden_states", "in_hidden_states"],
    "model.model.layers[0].pre_attn_norm": ["out_hidden_states"],
    "model.model.layers[0].self_attn": ["out_hidden_states", "in_attention_mask"],
    "model.model.layers[0].self_attn.q_proj": ["out_hidden_states"],
    "model.model.layers[0].self_attn.k_proj": ["out_hidden_states"],
    "model.model.layers[0].self_attn.v_proj": ["out_hidden_states"],
    "model.model.layers[0].pre_mlp_norm": ["out_hidden_states"],
    "model.model.layers[0].mlp": ["out_hidden_states"],
}

if hooked_model.model.config.pos_type == "ape":
    hook_config["model.embed_positions"] = ["out_hidden_states", "in_input"]
if hooked_model.model.config.pos_type == "rope":
    hook_config["model.model.layers[0].self_attn"].append("in_position_embeddings")

In [9]:
hooked_model.register_hooks(hook_config)

[{'self': ['out_hidden_states', 'in_hidden_states'], 'pre_attn_norm': ['out_hidden_states'], 'self_attn': ['out_hidden_states', 'in_attention_mask', 'in_position_embeddings'], 'self_attn.q_proj': ['out_hidden_states'], 'self_attn.k_proj': ['out_hidden_states'], 'self_attn.v_proj': ['out_hidden_states'], 'pre_mlp_norm': ['out_hidden_states'], 'mlp': ['out_hidden_states']}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]


In [10]:
prompts = ["Tokyo is the capital of the", "Hello"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

generation_args = {
    "pad_token_id": tokenizer.eos_token_id,
    "do_sample": False,
}

outputs = hooked_model(**inputs, **generation_args)

In [11]:
hooked_model.hooks.layers[0]

_LayerHookResult
	self: Hook
	result: LayerHookResult
		in_hidden_states: torch.Size([2, 8, 2048])
		out_hidden_states: torch.Size([2, 8, 2048])

	pre_attn_norm: Hook
	result: NormHookResult
		out_hidden_states: torch.Size([2, 8, 2048])

	self_attn: Hook
	result: AttnHookResult
		in_position_embeddings: (torch.Size([1, 8, 64]), torch.Size([1, 8, 64]), )
		in_attention_mask: torch.Size([2, 1, 8, 8])
		out_hidden_states: torch.Size([2, 8, 2048])

	pre_mlp_norm: Hook
	result: NormHookResult
		out_hidden_states: torch.Size([2, 8, 2048])

	mlp: Hook
	result: MLPHookResult
		out_hidden_states: torch.Size([2, 8, 2048])

	self_attn_q_proj: Hook
	result: AttnQProjHookResult
		out_hidden_states: torch.Size([2, 8, 2048])

	self_attn_k_proj: Hook
	result: AttnKProjHookResult
		out_hidden_states: torch.Size([2, 8, 512])

	self_attn_v_proj: Hook
	result: AttnVProjHookResult
		out_hidden_states: torch.Size([2, 8, 512])


In [12]:
result = hooked_model.hook_results()
result

CausalLMObservationBatchResult
	embed_tokens: EmbedTokensObservationBatchResult
	embed_positions: None
	layers: LayerObservationBatchResult x 16
	norm: NormObservationBatchResult
	lm_head: LMHeadObservationBatchResult

In [13]:
result.embed_tokens

EmbedTokensObservationBatchResult
	input_tokens: torch.Size([2, 8])
	out_hidden_states: torch.Size([2, 8, 2048])

In [14]:
result.embed_positions

In [15]:
result.layers[0]

LayerObservationBatchResult
	pre_attn_norm_out: torch.Size([2, 8, 2048])
	attention_mask: torch.Size([2, 1, 8, 8])
	position_embeddings: (torch.Size([1, 8, 64]), torch.Size([1, 8, 64]), )
	attn_out_hidden_states: torch.Size([2, 8, 2048])
	q_proj_output: torch.Size([2, 8, 2048])
	k_proj_output: torch.Size([2, 8, 512])
	v_proj_output: torch.Size([2, 8, 512])
	pre_mlp_norm_out: torch.Size([2, 8, 2048])
	mlp_out_hidden_states: torch.Size([2, 8, 2048])
	out_hidden_states: torch.Size([2, 8, 2048])

In [16]:
result.layers[0].attention_mask

tensor([[[[ True, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True,  True,  True]]],


        [[[False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False

In [17]:
hidden_state_reconstructed = (
    result.embed_tokens.out_hidden_states
    + result.layers[0].attn_out_hidden_states
    + result.layers[0].mlp_out_hidden_states
)

if hooked_model.model.config.pos_type == "ape":
    hidden_state_reconstructed += result.embed_positions.out_hidden_states

In [18]:
torch.allclose(
    hidden_state_reconstructed, result.layers[0].out_hidden_states, atol=1e-5, rtol=1e-5
)

True

In [19]:
result.layers[0].position_embeddings[1].shape

torch.Size([1, 8, 64])

In [None]:
def attention(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    head_dim: int,
    attention_mask: torch.Tensor,
    o_proj: torch.nn.Module,
    precompute_ov: bool = False,
    rope: bool = False,
    position_embeddings: torch.Tensor = None,
):
    if rope:
        assert position_embeddings is not None, (
            "Position embeddings must be provided for RoPE"
        )
    shape_q = (*query_states.shape[:-1], -1, head_dim)
    shape_kv = (*key_states.shape[:-1], -1, head_dim)

    query_states = query_states.view(shape_q).transpose(1, 2)
    key_states = key_states.view(shape_kv).transpose(1, 2)

    if rope:
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            q=query_states,
            k=key_states,
            sin=sin,
            cos=cos,
        )

    value_states = value_states.view(shape_kv).transpose(1, 2)

    # prepare for gqa.
    # can be done inside spda, but do it here for precompute_ov=True
    batch_size, num_kheads, seq_length, _ = key_states.shape
    _, num_qheads, _, head_dim = query_states.shape
    key_states = repeat_kv(
        key_states,
        n_rep=num_qheads // num_kheads,
    )
    value_states = repeat_kv(
        value_states,
        n_rep=num_qheads // num_kheads,
    )
    if precompute_ov:
        if hasattr(o_proj, "li_weight"):
            weight = o_proj.li_weight
        else:
            weight = o_proj.weight.T
        o_proj_by_head = weight.view(num_qheads, head_dim, -1)
        value_states = torch.einsum("bhsi,hid->bhsd", value_states, o_proj_by_head)

    attn_weighted = torch.nn.functional.scaled_dot_product_attention(
        query=query_states,
        key=key_states,
        value=value_states,
        attn_mask=attention_mask,
    )

    if precompute_ov:
        if hasattr(o_proj, "bias") and o_proj.bias is not None:
            return attn_weighted.sum(dim=1) + o_proj.bias
        return attn_weighted.sum(dim=1)
    attn_weighted = attn_weighted.transpose(1, 2).reshape(batch_size, seq_length, -1)
    attn_weighted = o_proj(attn_weighted)
    return attn_weighted


torch.allclose(
    attention(
        query_states=result.layers[0].q_proj_output,
        key_states=result.layers[0].k_proj_output,
        value_states=result.layers[0].v_proj_output,
        head_dim=hooked_model.model.config.hidden_size
        // hooked_model.model.config.num_attention_heads,
        attention_mask=result.layers[0].attention_mask,
        o_proj=hooked_model.model.model.layers[0].self_attn.o_proj,
        precompute_ov=True,
        rope=hooked_model.model.config.pos_type == "rope",
        position_embeddings=result.layers[0].position_embeddings
        if hooked_model.model.config.pos_type == "rope"
        else None,
    )[:, -1, :],
    result.layers[0].attn_out_hidden_states[:, -1, :],
    atol=1e-5,
    rtol=1e-5,
)

True

In [43]:
(
    attention(
        query_states=result.layers[0].q_proj_output,
        key_states=result.layers[0].k_proj_output,
        value_states=result.layers[0].v_proj_output,
        head_dim=hooked_model.model.config.hidden_size
        // hooked_model.model.config.num_attention_heads,
        attention_mask=result.layers[0].attention_mask,
        o_proj=hooked_model.model.model.layers[0].self_attn.o_proj,
        precompute_ov=True,
        rope=hooked_model.model.config.pos_type == "rope",
        position_embeddings=result.layers[0].position_embeddings
        if hooked_model.model.config.pos_type == "rope"
        else None,
    )[:, -1, :],
)

(tensor([[ 1.7236e-03,  3.3289e-04, -5.5137e-02,  ...,  2.7408e-03,
           2.7539e-03,  1.0591e-02],
         [-1.3816e-05,  7.0120e-04,  9.2162e-03,  ...,  1.8067e-03,
           8.1766e-03,  2.4078e-04]], grad_fn=<SliceBackward0>),)

In [39]:
(result.layers[0].attn_out_hidden_states[:, -1, :],)

(tensor([[ 1.7236e-03,  3.3289e-04, -5.5137e-02,  ...,  2.7408e-03,
           2.7539e-03,  1.0591e-02],
         [-1.3808e-05,  7.0120e-04,  9.2163e-03,  ...,  1.8067e-03,
           8.1766e-03,  2.4078e-04]]),)

In [23]:
def attention_precompute_ov(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    head_dim: int,
    o_proj,
):
    shape_q = (*query_states.shape[:-1], -1, head_dim)
    shape_kv = (*key_states.shape[:-1], -1, head_dim)

    query_states = query_states.view(shape_q).transpose(1, 2)
    key_states = key_states.view(shape_kv).transpose(1, 2)
    value_states = value_states.view(shape_kv).transpose(1, 2)
    batch_size, num_heads, seq_length, _ = query_states.shape

    if hasattr(o_proj, "li_weight"):
        pass
    else:
        pass
    o_proj_by_head = o_proj.weight.view(num_heads, head_dim, -1)
    value_states = torch.einsum("bhsi,hid->bhsd", value_states, o_proj_by_head)

    attn_weighted = torch.nn.functional.scaled_dot_product_attention(
        query=query_states,
        key=key_states,
        value=value_states,
        attn_mask=result.layers[0].attention_mask,
    )
    return attn_weighted.sum(dim=1) + o_proj.bias


torch.allclose(
    attention_precompute_ov(
        query_states=result.layers[0].q_proj_output,
        key_states=result.layers[0].k_proj_output,
        value_states=result.layers[0].v_proj_output,
        head_dim=hooked_model.model.config.n_embd // hooked_model.model.config.n_head,
        o_proj=hooked_model.model.model.layers[0].self_attn.c_proj,
    )[:, -1, :],
    result.layers[0].attn_out_hidden_states[:, -1, :],
    rtol=1e-04,
    atol=1e-04,
)

AttributeError: 'LlamaConfig' object has no attribute 'n_embd'

In [None]:
attention_precompute_ov(
    query_states=result.layers[0].q_proj_output,
    key_states=result.layers[0].k_proj_output,
    value_states=result.layers[0].v_proj_output,
    head_dim=hooked_model.model.config.n_embd // hooked_model.model.config.n_head,
    o_proj=hooked_model.model.model.layers[0].self_attn.c_proj,
)[:, -1, :]

tensor([[-1.2031,  0.0857,  0.1764,  ..., -0.0763,  0.0583, -0.1197],
        [ 0.2686, -0.4273, -0.7308,  ...,  0.0217, -0.0387,  0.1454]],
       grad_fn=<SliceBackward0>)

In [None]:
(result.layers[0].attn_out_hidden_states[:, -1, :],)

(tensor([[-1.2031,  0.0857,  0.1764,  ..., -0.0763,  0.0583, -0.1197],
         [ 0.2686, -0.4273, -0.7308,  ...,  0.0217, -0.0387,  0.1454]]),)

In [None]:
attn_weghted = torch.nn.functional.scaled_dot_product_attention(
    query=result.layers[0].q_proj_output,
    key=result.layers[0].k_proj_output,
    value=result.layers[0].v_proj_output,
)
attn_weghted.shape

torch.Size([2, 7, 768])