In [None]:
## Attention:


In [134]:
import pytest
import torch
from jaxtyping import Float
from torch.testing import assert_close
import torch.nn as nn
from transformer_lens.components import Attention
from transformer_lens.components import LayerNorm
from esm.layers.attention import MultiHeadAttention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
import functools
import einops
import torch.nn.functional as F

In [135]:
ATOL = 1e-4
def create_multi_head_attention_params(d_model, n_heads, bias=False, qk_layernorm=False):
    params = {
        "layernorm_qkv_weight": torch.rand(d_model),  # Weight of LayerNorm
        "layernorm_qkv_bias": torch.rand(d_model) if bias else None,    # Bias of LayerNorm
        "W_qkv_weight": torch.rand(d_model * 3, d_model),  # Weight of Linear layer
        "W_qkv_bias": torch.rand(d_model * 3) if bias else None,  # Bias of Linear layer
        "out_proj_weight": torch.rand(d_model, d_model),  # Output projection weight
        "out_proj_bias": torch.rand(d_model) if bias else None,  # Output projection bias
    }
    
    if qk_layernorm:
        params.update({
            "q_ln_weight": torch.rand(d_model),
            "q_ln_bias": torch.rand(d_model) if bias else None,
            "k_ln_weight": torch.rand(d_model),
            "k_ln_bias": torch.rand(d_model) if bias else None,
        })
    return params
    

def assign_params_to__esm_attention_layer(layer, params, bias=True):
    with torch.no_grad():
        # Assign LayerNorm for QKV
        layer.layernorm_qkv[0].weight.copy_(params["layernorm_qkv_weight"])
        if bias:
            layer.layernorm_qkv[0].bias.copy_(params["layernorm_qkv_bias"])
        
        # Assign Weights and Bias for QKV Projection
        layer.layernorm_qkv[1].weight.copy_(params["W_qkv_weight"])
        if bias:
            layer.layernorm_qkv[1].bias.copy_(params["W_qkv_bias"])
        
        # Assign Output Projection
        layer.out_proj.weight.copy_(params["out_proj_weight"])
        if bias:
            layer.out_proj.bias.copy_(params["out_proj_bias"])
        
        # Assign LayerNorm for Q
        if isinstance(layer.q_ln, nn.LayerNorm):
            layer.q_ln.weight.copy_(params["q_ln_weight"])
            if bias:
                layer.q_ln.bias.copy_(params["q_ln_bias"])
        
        # Assign LayerNorm for K
        if isinstance(layer.k_ln, nn.LayerNorm):
            layer.k_ln.weight.copy_(params["k_ln_weight"])
            if bias:
                layer.k_ln.bias.copy_(params["k_ln_bias"])

def assign_params_to_transformer_lens_attention_layer(attention_layer, pre_layer_norm, params, cfg, bias=True):
    with torch.no_grad():
        # Assign LayerNorm QKV
        pre_layer_norm.w.copy_(params["layernorm_qkv_weight"])
        if bias and "layernorm_qkv_bias" in params:
            pre_layer_norm.b.copy_(params["layernorm_qkv_bias"])

        # Extract and split QKV weights
        qkv_matrix = params["W_qkv_weight"].clone()  # Shape: (d_model * 3, d_model)
        assert qkv_matrix.shape == (cfg.d_model * 3, cfg.d_model), "QKV weight shape mismatch."

        qkv_reshaped = qkv_matrix.T  # Shape: (d_model, d_model * 3)
        q, k, v = torch.chunk(qkv_reshaped, 3, dim=-1)  # Split into Q, K, V
        
        reshaper = functools.partial(
            einops.rearrange, pattern="d_model (n_head d_head) -> n_head d_model d_head", n_head=cfg.n_heads
        )
        q, k, v = map(reshaper, (q, k, v))
        
        # Copy Q, K, V weights
        attention_layer.W_Q.copy_(q)
        attention_layer.W_K.copy_(k)
        attention_layer.W_V.copy_(v)

        # Handle QKV bias
        if bias and "W_qkv_bias" in params:
            qkv_bias = params["W_qkv_bias"].clone()  # Shape: (d_model * 3)
            b_q, b_k, b_v = torch.chunk(qkv_bias, 3, dim=-1)
            reshaper_bias = functools.partial(
                einops.rearrange, pattern="(n_head d_head) -> n_head d_head", n_head=cfg.n_heads
            )
            attention_layer.b_Q.copy_(reshaper_bias(b_q))
            attention_layer.b_K.copy_(reshaper_bias(b_k))
            attention_layer.b_V.copy_(reshaper_bias(b_v))

        # Assign Output Projection
        out_proj = params["out_proj_weight"].clone()  # Shape: (d_model, d_model)
        assert out_proj.shape == (cfg.d_model, cfg.d_model), "Output projection weight shape mismatch."
        out_proj_reshaped = einops.rearrange(out_proj.T, "(n_head d_head) d_model -> n_head d_head d_model", n_head=cfg.n_heads)
        attention_layer.W_O.copy_(out_proj_reshaped)

        # Assign Output Bias
        if bias and "out_proj_bias" in params:
            attention_layer.b_O.copy_(params["out_proj_bias"])

        # Assign LayerNorms for Q and K if qk_layernorm is enabled
        if cfg.qk_layernorm:
            attention_layer.q_ln.w.copy_(params["q_ln_weight"])
            attention_layer.k_ln.w.copy_(params["k_ln_weight"])
            if bias:
                attention_layer.q_ln.b.copy_(params.get("q_ln_bias", torch.zeros(cfg.d_model)))
                attention_layer.k_ln.b.copy_(params.get("k_ln_bias", torch.zeros(cfg.d_model)))

In [136]:
d_model = 512
n_heads = 8
d_head = d_model // n_heads
bias = False
batch_size = 10
seq_len = 10
qk_layernorm= True

In [137]:
fake_params = create_multi_head_attention_params(d_model, n_heads, bias=bias, qk_layernorm=qk_layernorm)

# Create ESM original attention component
esm_original_component = MultiHeadAttention(d_model, n_heads, bias, qk_layernorm).to(torch.float32)

# Assign the explicit parameters to the model
assign_params_to__esm_attention_layer(esm_original_component, fake_params, bias)

#Now we want to create attention of transformer lens for comparing...

cfg = HookedTransformerConfig(
n_layers=1,           
d_model=d_model,           
n_ctx=20,            
d_head=d_head,                     
n_heads=n_heads,
attention_dir="bidirectional",
init_weights=False,
positional_embedding_type="rotary",
rotary_dim=d_head,
default_prepend_bos=False,
qk_layernorm=qk_layernorm,
dtype=torch.float32,
attn_only=True,
use_attn_result=False
)

#create transformer lens attention and initialize: 
tested_attention_layer = Attention(cfg)
pre_layer_norm = LayerNorm(cfg, d_model)
assign_params_to_transformer_lens_attention_layer(tested_attention_layer, pre_layer_norm,fake_params, cfg, bias)

In [148]:
x= torch.rand((batch_size, seq_len, d_model))
with torch.no_grad():
    layer_norm1= pre_layer_norm(x.clone())
    q1, k1, v1 = tested_attention_layer.calculate_qkv_matrices(layer_norm1, layer_norm1, layer_norm1)
    q1_flattened = einops.rearrange(q1, "batch pos head_index d_head -> batch pos (head_index d_head)")
    k1_flattened = einops.rearrange(k1, "batch pos head_index d_head -> batch pos (head_index d_head)")
    v1_flattened = einops.rearrange(v1, "batch pos head_index d_head -> batch pos (head_index d_head)")
    if cfg.qk_layernorm:
        q1 = tested_attention_layer.q_ln(q1_flattened)
        k1 = tested_attention_layer.k_ln(k1_flattened)
        q1 = einops.rearrange(q1, "batch pos (head_index d_head) -> batch pos head_index d_head", 
                                    head_index=cfg.n_heads, d_head=cfg.d_head)
        k1 = einops.rearrange(k1, "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head", 
                        head_index=cfg.n_heads, d_head=cfg.d_head)
    if cfg.positional_embedding_type == "rotary":
        print("enter")
        kv_cache_pos_offset = 0
        q1 = tested_attention_layer.apply_rotary(q1, kv_cache_pos_offset, None)
        k1 = tested_attention_layer.apply_rotary(k1, 0, None)
    q1 = einops.rearrange(q1, "batch pos head_index d_head -> batch head_index pos d_head")
    k1 = einops.rearrange(k1, "batch pos head_index d_head -> batch head_index pos d_head")
    v1 = einops.rearrange(v1, "batch pos head_index d_head -> batch head_index pos d_head")
    # attn_scores = tested_attention_layer.calculate_attention_scores(q1, k1) 
    # pattern = F.softmax(attn_scores, dim=-1)
    # pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
    # pattern = pattern.to(tested_attention_layer.cfg.dtype)
    # pattern = pattern.to(v1.device)
    # z = tested_attention_layer.calculate_z_scores(v1, pattern)  # [batch, pos, head_index, d_head]
    z= F.scaled_dot_product_attention(
                    q1, k1, v1
                )
    if not tested_attention_layer.cfg.use_attn_result:
        w = einops.rearrange(
            tested_attention_layer.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
        )
        print(z.shape)
        # ctx1=z.reshape(z.shape[0], z.shape[1], tested_attention_layer.cfg.d_head * tested_attention_layer.cfg.n_heads)
        ctx1 = einops.rearrange(z, "b h s d -> b s (h d)")
        out1 = F.linear(
           ctx1,
            w,
            tested_attention_layer.b_O,
        )
        print(out1)

enter
torch.Size([10, 8, 10, 64])
tensor([[[ 3.0243e+02,  2.9030e+02,  3.1657e+02,  ...,  3.1700e+02,
           3.0445e+02,  2.9163e+02],
         [ 3.3427e+02,  3.2994e+02,  3.4836e+02,  ...,  3.5302e+02,
           3.3805e+02,  3.2873e+02],
         [ 3.8418e+02,  3.8537e+02,  4.1591e+02,  ...,  4.0084e+02,
           4.0166e+02,  3.8421e+02],
         ...,
         [ 2.9701e+02,  2.8587e+02,  3.0819e+02,  ...,  3.1114e+02,
           3.0583e+02,  2.8690e+02],
         [ 3.1077e+02,  3.0824e+02,  3.3486e+02,  ...,  3.3666e+02,
           3.2267e+02,  3.0997e+02],
         [ 3.3301e+02,  3.2017e+02,  3.5319e+02,  ...,  3.5345e+02,
           3.3967e+02,  3.2283e+02]],

        [[-1.0949e+02, -1.1138e+02, -1.2704e+02,  ..., -1.1981e+02,
          -1.1537e+02, -1.1542e+02],
         [-1.0969e+02, -1.0966e+02, -1.2139e+02,  ..., -1.1902e+02,
          -1.1746e+02, -1.1306e+02],
         [-9.4633e+01, -9.6794e+01, -1.1104e+02,  ..., -1.0608e+02,
          -1.0997e+02, -1.0186e+02],
     

In [149]:
with torch.no_grad():
    qkv=esm_original_component.layernorm_qkv(x.clone())
    q2, k2, v2 = torch.chunk(qkv, 3, dim=-1)
    q2,k2 = (
                esm_original_component.q_ln(q2).to(q2.dtype),
                esm_original_component.k_ln(k2).to(q2.dtype),
            )
    q2, k2 = esm_original_component._apply_rotary(q2, k2)
    n_heads = esm_original_component.n_heads
    reshaper = functools.partial(
        einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
    )
    q2, k2, v2 = map(
        reshaper, (q2, k2, v2)
    )
    
    ctx2 = F.scaled_dot_product_attention(
        q2, k2, v2
    )
    ctx2 = einops.rearrange(ctx2, "b h s d -> b s (h d)")
    out2=esm_original_component.out_proj(ctx2)
    print(out2)

tensor([[[ 3.0243e+02,  2.9030e+02,  3.1657e+02,  ...,  3.1700e+02,
           3.0445e+02,  2.9163e+02],
         [ 3.3427e+02,  3.2994e+02,  3.4836e+02,  ...,  3.5302e+02,
           3.3804e+02,  3.2872e+02],
         [ 3.8418e+02,  3.8536e+02,  4.1591e+02,  ...,  4.0084e+02,
           4.0166e+02,  3.8421e+02],
         ...,
         [ 2.9701e+02,  2.8587e+02,  3.0819e+02,  ...,  3.1114e+02,
           3.0583e+02,  2.8690e+02],
         [ 3.1077e+02,  3.0823e+02,  3.3485e+02,  ...,  3.3666e+02,
           3.2267e+02,  3.0997e+02],
         [ 3.3301e+02,  3.2016e+02,  3.5319e+02,  ...,  3.5345e+02,
           3.3967e+02,  3.2283e+02]],

        [[-1.0948e+02, -1.1138e+02, -1.2704e+02,  ..., -1.1981e+02,
          -1.1537e+02, -1.1542e+02],
         [-1.0969e+02, -1.0966e+02, -1.2139e+02,  ..., -1.1902e+02,
          -1.1746e+02, -1.1306e+02],
         [-9.4632e+01, -9.6793e+01, -1.1104e+02,  ..., -1.0608e+02,
          -1.0997e+02, -1.0186e+02],
         ...,
         [-5.4704e+01, -6

In [153]:
#print(torch.allclose(q1, q2, atol=ATOL))
#print(torch.allclose(k1, k2, atol=ATOL))
#print(torch.allclose(v1_flattened, v2, atol=ATOL))
print(torch.allclose(ctx1,ctx2, atol=ATOL))
print(torch.allclose(w, esm_original_component.out_proj.weight, atol=ATOL))
print(torch.allclose(out1, out2, atol=1e-2))

True
True
True


In [141]:
print(torch.allclose(torch.zeros(10), torch.zeros(10), atol=ATOL))

True


In [142]:
x= torch.rand((batch_size, seq_len, d_model))

layer_norm1= pre_layer_norm(x.clone())
res1=tested_attention_layer.forward(layer_norm1,layer_norm1,layer_norm1)

In [143]:
#seq_id = torch.tensor([[1,1,1,1,0,0,0,0,0,0]])
seq_id = None
res2=esm_original_component.forward(x,seq_id )

In [144]:
res1

tensor([[[ 8.6625e+01,  1.0004e+02,  8.5801e+01,  ...,  8.7451e+01,
           9.1848e+01,  7.8041e+01],
         [ 8.8997e+01,  1.0819e+02,  9.4817e+01,  ...,  8.4767e+01,
           9.4909e+01,  7.5189e+01],
         [ 3.7773e+01,  5.9993e+01,  3.8074e+01,  ...,  3.0051e+01,
           4.2727e+01,  3.2113e+01],
         ...,
         [ 7.2632e+01,  8.7934e+01,  6.9827e+01,  ...,  6.4970e+01,
           7.3000e+01,  6.3227e+01],
         [ 1.4556e+01,  3.9058e+01,  2.0168e+01,  ...,  1.6914e-01,
           2.0829e+01,  1.2688e+01],
         [ 7.2108e+01,  8.9306e+01,  7.1060e+01,  ...,  6.4950e+01,
           9.0310e+01,  5.9327e+01]],

        [[ 6.7271e+01,  8.0738e+01,  9.4497e+01,  ...,  7.3725e+01,
           7.8386e+01,  9.5042e+01],
         [ 4.2411e+01,  5.2446e+01,  5.8948e+01,  ...,  3.9208e+01,
           4.4769e+01,  6.2315e+01],
         [ 2.5081e+01,  3.2241e+01,  3.6895e+01,  ...,  1.4308e+01,
           2.8353e+01,  4.3983e+01],
         ...,
         [ 2.0624e+01,  2

In [145]:
res2

tensor([[[ 8.6623e+01,  1.0004e+02,  8.5799e+01,  ...,  8.7449e+01,
           9.1846e+01,  7.8039e+01],
         [ 8.8996e+01,  1.0819e+02,  9.4816e+01,  ...,  8.4765e+01,
           9.4907e+01,  7.5187e+01],
         [ 3.7771e+01,  5.9991e+01,  3.8072e+01,  ...,  3.0049e+01,
           4.2725e+01,  3.2111e+01],
         ...,
         [ 7.2630e+01,  8.7932e+01,  6.9825e+01,  ...,  6.4968e+01,
           7.2998e+01,  6.3225e+01],
         [ 1.4554e+01,  3.9056e+01,  2.0167e+01,  ...,  1.6744e-01,
           2.0827e+01,  1.2686e+01],
         [ 7.2106e+01,  8.9304e+01,  7.1058e+01,  ...,  6.4948e+01,
           9.0308e+01,  5.9325e+01]],

        [[ 6.7270e+01,  8.0736e+01,  9.4496e+01,  ...,  7.3723e+01,
           7.8385e+01,  9.5040e+01],
         [ 4.2410e+01,  5.2445e+01,  5.8947e+01,  ...,  3.9207e+01,
           4.4767e+01,  6.2313e+01],
         [ 2.5079e+01,  3.2240e+01,  3.6893e+01,  ...,  1.4306e+01,
           2.8351e+01,  4.3981e+01],
         ...,
         [ 2.0622e+01,  2

In [147]:
print(torch.allclose(res1, res2, atol=1e-2))

True
