In [10]:
## Attention: ESM MultiHeadAttention vs. TransformerLens Attention


In [1]:
import pytest
import torch
from jaxtyping import Float
from torch.testing import assert_close
import torch.nn as nn
from transformer_lens.components import Attention
from transformer_lens.components import LayerNorm
from esm.layers.attention import MultiHeadAttention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components import ESM3_Hooked_MLP, swiglu_correction_fn
from esm.layers.blocks import swiglu_ln_ffn
import functools
import einops
import torch.nn.functional as F
from transformer_lens.components import ESM3_Hooked_MLP

In [2]:
# Absolute tolerance passed as `atol` to the torch.allclose checks below that reference it.
ATOL = 1e-4
def create_multi_head_attention_params(d_model, n_heads, bias=False, qk_layernorm=False):
    """Create random stand-in weights for one fused-QKV attention block.

    Args:
        d_model: Width of the model; every tensor is sized from it.
        n_heads: Accepted for call-site symmetry; it does not affect any shape here.
        bias: When truthy, the linear / q-k LayerNorm bias slots hold random tensors;
            otherwise those slots hold ``None``.
        qk_layernorm: When truthy, also emit ``q_ln_*`` and ``k_ln_*`` entries.

    Returns:
        dict mapping parameter names to ``torch.rand`` tensors (or ``None`` for
        disabled bias slots).
    """
    fused_width = d_model * 3

    def maybe_bias(width):
        # A bias slot is only materialised when biases are enabled.
        return torch.rand(width) if bias else None

    # Entries are inserted (and random numbers drawn) in a fixed order.
    params = {}
    params["layernorm_qkv_weight"] = torch.rand(d_model)
    params["layernorm_qkv_bias"] = torch.rand(d_model)
    params["W_qkv_weight"] = torch.rand(fused_width, d_model)
    params["W_qkv_bias"] = maybe_bias(fused_width)
    params["out_proj_weight"] = torch.rand(d_model, d_model)
    params["out_proj_bias"] = maybe_bias(d_model)

    if not qk_layernorm:
        return params

    params["q_ln_weight"] = torch.rand(d_model)
    params["q_ln_bias"] = maybe_bias(d_model)
    params["k_ln_weight"] = torch.rand(d_model)
    params["k_ln_bias"] = maybe_bias(d_model)
    return params
def assign_params_to__esm_attention_layer(layer, params, bias=True):
    """Load the tensors in `params` into an ESM-style attention module, in place.

    Args:
        layer: Module exposing ``layernorm_qkv`` (index 0: LayerNorm, index 1: fused
            QKV linear), ``out_proj``, ``q_ln`` and ``k_ln``.
        params: Mapping with the ``layernorm_qkv_*``, ``W_qkv_*``, ``out_proj_*``
            and (when q/k LayerNorms are present) ``q_ln_*`` / ``k_ln_*`` entries.
        bias: When truthy, bias tensors of the linear layers and of the q/k
            LayerNorms are copied as well.
    """
    with torch.no_grad():
        fused = layer.layernorm_qkv

        # Pre-projection LayerNorm: weight and bias are always copied.
        fused[0].weight.copy_(params["layernorm_qkv_weight"])
        fused[0].bias.copy_(params["layernorm_qkv_bias"])

        # Fused QKV projection.
        fused[1].weight.copy_(params["W_qkv_weight"])
        if bias:
            fused[1].bias.copy_(params["W_qkv_bias"])

        # Output projection.
        out_linear = layer.out_proj
        out_linear.weight.copy_(params["out_proj_weight"])
        if bias:
            out_linear.bias.copy_(params["out_proj_bias"])

        # q/k LayerNorms are only touched when they really are nn.LayerNorm instances.
        for attr, weight_key, bias_key in (
            ("q_ln", "q_ln_weight", "q_ln_bias"),
            ("k_ln", "k_ln_weight", "k_ln_bias"),
        ):
            norm = getattr(layer, attr)
            if not isinstance(norm, nn.LayerNorm):
                continue
            norm.weight.copy_(params[weight_key])
            if bias:
                norm.bias.copy_(params[bias_key])

def assign_params_to_transformer_lens_attention_layer(attention_layer, pre_layer_norm, params, cfg, bias=True):
    """Load fused-QKV attention weights into a TransformerLens-style attention layer, in place.

    The fused ``W_qkv_weight`` of shape ``(3 * d_model, d_model)`` is transposed, split
    into Q/K/V and rearranged to ``(n_head, d_model, d_head)``; ``out_proj_weight`` of
    shape ``(d_model, d_model)`` is transposed and rearranged to
    ``(n_head, d_head, d_model)``.

    Args:
        attention_layer: Object exposing ``W_Q``, ``W_K``, ``W_V``, ``W_O``, ``b_Q``,
            ``b_K``, ``b_V``, ``b_O`` and, when ``cfg.qk_layernorm`` is set,
            ``q_ln`` / ``k_ln`` with ``w`` and ``b`` tensors.
        pre_layer_norm: Object exposing ``w`` and ``b`` tensors for the LayerNorm that
            precedes the QKV projection.
        params: Mapping of parameter names to tensors. Bias entries may be missing
            or ``None``.
        cfg: Config exposing ``d_model``, ``n_heads`` and ``qk_layernorm``.
        bias: When truthy, bias tensors are copied if they are available.

    Raises:
        AssertionError: If the fused QKV or output-projection weight has an
            unexpected shape.
    """
    with torch.no_grad():
        # LayerNorm applied before the QKV projection.
        pre_layer_norm.w.copy_(params["layernorm_qkv_weight"])
        pre_layer_norm.b.copy_(params["layernorm_qkv_bias"])

        # Fused QKV weight -> three per-head matrices.
        qkv_matrix = params["W_qkv_weight"].clone()  # Shape: (d_model * 3, d_model)
        assert qkv_matrix.shape == (cfg.d_model * 3, cfg.d_model), "QKV weight shape mismatch."

        q, k, v = torch.chunk(qkv_matrix.T, 3, dim=-1)  # each (d_model, d_model)
        split_heads = functools.partial(
            einops.rearrange, pattern="d_model (n_head d_head) -> n_head d_model d_head", n_head=cfg.n_heads
        )
        attention_layer.W_Q.copy_(split_heads(q))
        attention_layer.W_K.copy_(split_heads(k))
        attention_layer.W_V.copy_(split_heads(v))

        # Fused QKV bias. The key may be present with a None value when the
        # parameters were generated without biases, so test the value itself.
        qkv_bias = params.get("W_qkv_bias")
        if bias and qkv_bias is not None:
            b_q, b_k, b_v = torch.chunk(qkv_bias.clone(), 3, dim=-1)
            split_bias = functools.partial(
                einops.rearrange, pattern="(n_head d_head) -> n_head d_head", n_head=cfg.n_heads
            )
            attention_layer.b_Q.copy_(split_bias(b_q))
            attention_layer.b_K.copy_(split_bias(b_k))
            attention_layer.b_V.copy_(split_bias(b_v))

        # Output projection.
        out_proj = params["out_proj_weight"].clone()  # Shape: (d_model, d_model)
        assert out_proj.shape == (cfg.d_model, cfg.d_model), "Output projection weight shape mismatch."
        out_proj_reshaped = einops.rearrange(
            out_proj.T, "(n_head d_head) d_model -> n_head d_head d_model", n_head=cfg.n_heads
        )
        attention_layer.W_O.copy_(out_proj_reshaped)

        out_bias = params.get("out_proj_bias")
        if bias and out_bias is not None:
            attention_layer.b_O.copy_(out_bias)

        # Optional LayerNorms on Q and K.
        if cfg.qk_layernorm:
            attention_layer.q_ln.w.copy_(params["q_ln_weight"])
            attention_layer.k_ln.w.copy_(params["k_ln_weight"])
            if bias:
                for norm, bias_key in (
                    (attention_layer.q_ln, "q_ln_bias"),
                    (attention_layer.k_ln, "k_ln_bias"),
                ):
                    norm_bias = params.get(bias_key)
                    if norm_bias is None:
                        # Missing or None bias falls back to an all-zero bias.
                        norm.b.zero_()
                    else:
                        norm.b.copy_(norm_bias)

In [3]:
# Hyperparameters shared by the attention comparison cells.
d_model, n_heads = 512, 8
d_head = d_model // n_heads  # per-head width
batch_size, seq_len = 1, 10
bias = False
qk_layernorm = True

In [4]:
# One set of random weights, loaded into both implementations.
fake_params = create_multi_head_attention_params(
    d_model, n_heads, bias=bias, qk_layernorm=qk_layernorm
)

# Reference: the ESM attention component in float32, filled with the shared weights.
esm_original_component = MultiHeadAttention(d_model, n_heads, bias, qk_layernorm).to(torch.float32)
assign_params_to__esm_attention_layer(esm_original_component, fake_params, bias)

# Under test: the TransformerLens attention, configured to mirror the reference
# (bidirectional attention, rotary embeddings over the full head width).
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=d_model,
    n_ctx=20,
    d_head=d_head,
    n_heads=n_heads,
    attention_dir="bidirectional",
    init_weights=False,
    positional_embedding_type="rotary",
    rotary_dim=d_head,
    default_prepend_bos=False,
    qk_layernorm=qk_layernorm,
    dtype=torch.float32,
    attn_only=True,
    use_attn_result=False,
)

# The LayerNorm that precedes the projection lives outside the attention module here.
tested_attention_layer = Attention(cfg)
pre_layer_norm = LayerNorm(cfg, d_model)
assign_params_to_transformer_lens_attention_layer(
    tested_attention_layer, pre_layer_norm, fake_params, cfg, bias
)

In [5]:
# Step-by-step re-computation of the attention output using the TransformerLens layer's
# own parameters and helpers. Leaves q1/k1/v1, ctx1, w and out1 for the comparison cells.
x= torch.rand((batch_size, seq_len, d_model))
with torch.no_grad():
    # Pre-attention LayerNorm, then Q/K/V from the same normalised input.
    layer_norm1= pre_layer_norm(x.clone())
    q1, k1, v1 = tested_attention_layer.calculate_qkv_matrices(layer_norm1, layer_norm1, layer_norm1)
    # Merge the head axis into the feature axis so q_ln / k_ln see one
    # (head_index * d_head)-wide vector per position.
    q1_flattened = einops.rearrange(q1, "batch pos head_index d_head -> batch pos (head_index d_head)")
    k1_flattened = einops.rearrange(k1, "batch pos head_index d_head -> batch pos (head_index d_head)")
    v1_flattened = einops.rearrange(v1, "batch pos head_index d_head -> batch pos (head_index d_head)")
    if cfg.qk_layernorm:
        # Normalise the flattened Q and K, then split the heads back out.
        q1 = tested_attention_layer.q_ln(q1_flattened)
        k1 = tested_attention_layer.k_ln(k1_flattened)
        q1 = einops.rearrange(q1, "batch pos (head_index d_head) -> batch pos head_index d_head", 
                                    head_index=cfg.n_heads, d_head=cfg.d_head)
        k1 = einops.rearrange(k1, "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head", 
                        head_index=cfg.n_heads, d_head=cfg.d_head)
    if cfg.positional_embedding_type == "rotary":
        # Debug marker confirming the rotary branch ran.
        print("enter")
        kv_cache_pos_offset = 0
        # Both Q and K are rotated with a position offset of 0.
        q1 = tested_attention_layer.apply_rotary(q1, kv_cache_pos_offset, None)
        k1 = tested_attention_layer.apply_rotary(k1, 0, None)
    # Move the head axis in front of the position axis before attention.
    q1 = einops.rearrange(q1, "batch pos head_index d_head -> batch head_index pos d_head")
    k1 = einops.rearrange(k1, "batch pos head_index d_head -> batch head_index pos d_head")
    v1 = einops.rearrange(v1, "batch pos head_index d_head -> batch head_index pos d_head")
    # Disabled alternative: explicit scores -> softmax -> z path through the layer's
    # helpers; F.scaled_dot_product_attention below is used instead.
    # attn_scores = tested_attention_layer.calculate_attention_scores(q1, k1) 
    # pattern = F.softmax(attn_scores, dim=-1)
    # pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
    # pattern = pattern.to(tested_attention_layer.cfg.dtype)
    # pattern = pattern.to(v1.device)
    # z = tested_attention_layer.calculate_z_scores(v1, pattern)  # [batch, pos, head_index, d_head]
    # Called with q, k, v only (no mask or other optional arguments).
    z= F.scaled_dot_product_attention(
                    q1, k1, v1
                )
    if not tested_attention_layer.cfg.use_attn_result:
        # Flatten W_O to (d_model, head_index * d_head) so it can serve as the
        # weight argument of F.linear.
        w = einops.rearrange(
            tested_attention_layer.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
        )
        print(z.shape)
        # ctx1=z.reshape(z.shape[0], z.shape[1], tested_attention_layer.cfg.d_head * tested_attention_layer.cfg.n_heads)
        # Concatenate the heads per position, then apply the output projection.
        ctx1 = einops.rearrange(z, "b h s d -> b s (h d)")
        out1 = F.linear(
           ctx1,
            w,
            tested_attention_layer.b_O,
        )
        print(out1)

enter
torch.Size([1, 8, 10, 64])
tensor([[[31908.0078, 33322.4922, 34319.7070,  ..., 31437.3828,
          32462.5820, 33189.7734],
         [31918.7949, 33333.3008, 34328.0820,  ..., 31449.6289,
          32476.2598, 33202.3984],
         [31947.8242, 33366.5820, 34361.9141,  ..., 31477.5039,
          32507.8770, 33234.5000],
         ...,
         [31920.2891, 33335.3477, 34329.9609,  ..., 31447.3828,
          32478.0410, 33205.3125],
         [31963.3281, 33383.1719, 34379.3242,  ..., 31495.5664,
          32523.2949, 33248.3828],
         [31946.0703, 33367.8633, 34360.9961,  ..., 31476.6504,
          32503.8379, 33232.1367]]])


In [6]:
# Step-by-step re-computation of the attention output using the ESM component's own
# submodules, on the same input `x`. Leaves q2/k2/v2, ctx2 and out2 for the comparison cell.
with torch.no_grad():
    # layernorm_qkv applies the LayerNorm and the fused QKV projection in one call.
    qkv=esm_original_component.layernorm_qkv(x.clone())
    q2, k2, v2 = torch.chunk(qkv, 3, dim=-1)
    # Q/K LayerNorm, cast back to q2's dtype.
    q2,k2 = (
                esm_original_component.q_ln(q2).to(q2.dtype),
                esm_original_component.k_ln(k2).to(q2.dtype),
            )
    q2, k2 = esm_original_component._apply_rotary(q2, k2)
    # NOTE: rebinds the notebook-level `n_heads` to the component's own head count.
    n_heads = esm_original_component.n_heads
    # Split the feature axis into heads and move the head axis ahead of the sequence axis.
    reshaper = functools.partial(
        einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
    )
    q2, k2, v2 = map(
        reshaper, (q2, k2, v2)
    )
    
    # Called with q, k, v only (no mask or other optional arguments).
    ctx2 = F.scaled_dot_product_attention(
        q2, k2, v2
    )
    # Concatenate the heads per position, then apply the output projection.
    ctx2 = einops.rearrange(ctx2, "b h s d -> b s (h d)")
    out2=esm_original_component.out_proj(ctx2)
    print(out2)

tensor([[[31908.0176, 33322.4883, 34319.7148,  ..., 31437.3789,
          32462.5781, 33189.7734],
         [31918.7969, 33333.2969, 34328.0859,  ..., 31449.6289,
          32476.2578, 33202.3984],
         [31947.8281, 33366.5820, 34361.9141,  ..., 31477.5098,
          32507.8750, 33234.4922],
         ...,
         [31920.2910, 33335.3438, 34329.9453,  ..., 31447.3867,
          32478.0410, 33205.3125],
         [31963.3281, 33383.1680, 34379.3359,  ..., 31495.5684,
          32523.2969, 33248.3828],
         [31946.0684, 33367.8789, 34361.0000,  ..., 31476.6484,
          32503.8320, 33232.1406]]])


In [114]:
# Compare the two step-by-step computations: attention context, output-projection
# weights, and final outputs (one True/False per line).
# Intermediate checks currently disabled:
#   print(torch.allclose(q1, q2, atol=ATOL))
#   print(torch.allclose(k1, k2, atol=ATOL))
#   print(torch.allclose(v1_flattened, v2, atol=ATOL))
print(
    torch.allclose(ctx1, ctx2, atol=ATOL),
    torch.allclose(w, esm_original_component.out_proj.weight, atol=ATOL),
    torch.allclose(out1, out2, atol=1e-4),
    sep="\n",
)

True
True
True


In [115]:
# Sanity check: two all-zero tensors must compare equal under ATOL.
print(
    torch.allclose(torch.zeros(10), torch.zeros(10), atol=ATOL)
)

True


In [147]:
# Fresh random input for the end-to-end forward comparison below.
x = torch.rand(batch_size, seq_len, d_model)

In [152]:
# End-to-end check through each module's own forward().
# The TransformerLens layer is fed ESM's LayerNorm output (layernorm_qkv[0]) rather
# than pre_layer_norm(x), so only the attention computation itself is compared.
# Alternative input, currently disabled: layer_norm1 = pre_layer_norm(x.clone())
layer_norm1 = esm_original_component.layernorm_qkv[0](x.clone())
res1 = tested_attention_layer.forward(layer_norm1, layer_norm1, layer_norm1)

seq_id = None
res2 = esm_original_component.forward(x, seq_id)

print(torch.allclose(res1, res2, atol=1e-10, rtol=1e-6))

True


In [162]:
def create_mlp_params(d_model, expansion_ratio, bias):
    """Create random stand-in weights for a LayerNorm + two-linear-layer MLP.

    Args:
        d_model: Model width.
        expansion_ratio: Passed with `d_model` to ``swiglu_correction_fn`` to obtain
            the hidden width.
        bias: When truthy, ``l1_bias`` / ``l2_bias`` hold random tensors; otherwise
            they hold ``None``.

    Returns:
        dict with ``layernorm_*``, ``l1_*`` and ``l2_*`` entries. The first linear
        layer is twice the hidden width.
    """
    hidden_dim = swiglu_correction_fn(expansion_ratio, d_model)
    doubled_dim = hidden_dim * 2

    # Entries are inserted (and random numbers drawn) in a fixed order.
    params = {}
    params["layernorm_weight"] = torch.rand(d_model)
    params["layernorm_bias"] = torch.rand(d_model)
    params["l1_weight"] = torch.rand(doubled_dim, d_model)
    params["l1_bias"] = torch.rand(doubled_dim) if bias else None
    params["l2_weight"] = torch.rand(d_model, hidden_dim)
    params["l2_bias"] = torch.rand(d_model) if bias else None
    return params


def assign_params_to_swiglu_mlp(mdl, params, bias):
    """Load `params` into an indexable MLP in place.

    Args:
        mdl: Indexable container whose element 0 is the LayerNorm and whose
            elements 1 and 3 are the two linear layers.
        params: Mapping with ``layernorm_*``, ``l1_*`` and ``l2_*`` entries.
        bias: When truthy, the linear-layer biases are copied as well.
    """
    with torch.no_grad():
        norm = mdl[0]
        norm.weight.copy_(params["layernorm_weight"])
        norm.bias.copy_(params["layernorm_bias"])

        # (position in mdl, weight key, bias key) for each linear layer.
        for position, weight_key, bias_key in (
            (1, "l1_weight", "l1_bias"),
            (3, "l2_weight", "l2_bias"),
        ):
            linear = mdl[position]
            linear.weight.copy_(params[weight_key])
            if bias:
                linear.bias.copy_(params[bias_key])

def assign_params_to_esm_mlp(mdl, params, bias, pre_layer_norm):
    """Load `params` into an MLP with ``l1`` / ``l2`` layers and a separate LayerNorm, in place.

    Args:
        mdl: Object exposing linear layers ``l1`` and ``l2``.
        params: Mapping with ``layernorm_*``, ``l1_*`` and ``l2_*`` entries.
        bias: When truthy, the linear-layer biases are copied as well.
        pre_layer_norm: Object exposing ``w`` and ``b`` tensors, which receive the
            LayerNorm weight and bias.
    """
    with torch.no_grad():
        pre_layer_norm.w.copy_(params["layernorm_weight"])
        pre_layer_norm.b.copy_(params["layernorm_bias"])

        # (attribute on mdl, weight key, bias key) for each linear layer.
        for attr, weight_key, bias_key in (
            ("l1", "l1_weight", "l1_bias"),
            ("l2", "l2_weight", "l2_bias"),
        ):
            linear = getattr(mdl, attr)
            linear.weight.copy_(params[weight_key])
            if bias:
                linear.bias.copy_(params[bias_key])

def test_compare_esm_and_swiglu_mlp(bias=False, expansion_ratio=4.0):
    """Check that ESM3_Hooked_MLP reproduces swiglu_ln_ffn when both share weights.

    Both modules are loaded with the same random parameters and run on the same
    random input. The difference statistics are printed before the assertion so
    that a failing run still reports how far apart the outputs are.

    Args:
        bias: Whether the linear layers carry (and are assigned) biases.
        expansion_ratio: Hidden-width expansion ratio passed to both MLPs.

    Raises:
        AssertionError: If the outputs differ beyond atol=1e-5 / rtol=1e-3.
    """
    d_model = 512
    batch_size = 1
    seq_len = 10

    # Shared random parameters.
    fake_params = create_mlp_params(d_model, expansion_ratio, bias)

    # Reference SwiGLU MLP (LayerNorm included in the module).
    swiglu_mlp = swiglu_ln_ffn(d_model, expansion_ratio, bias)
    assign_params_to_swiglu_mlp(swiglu_mlp, fake_params, bias)

    # Hooked MLP under test; its LayerNorm is a separate module.
    cfg = HookedTransformerConfig(
        n_layers=1,
        d_head=64,
        d_model=d_model,
        n_ctx=20,
        init_weights=False,
        dtype=torch.float32,
        esm3_mlp_expansion_ratio=expansion_ratio,
        act_fn="swiglu",
    )
    esm_mlp = ESM3_Hooked_MLP(cfg)
    pre_layer_norm = LayerNorm(cfg, d_model)
    assign_params_to_esm_mlp(esm_mlp, fake_params, bias, pre_layer_norm)

    x = torch.rand((batch_size, seq_len, d_model))

    with torch.no_grad():
        original_output = swiglu_mlp(x.clone())
        # TODO: feed pre_layer_norm(x) instead, so the separate LayerNorm is tested too.
        # For now the hooked MLP gets the reference module's own LayerNorm output,
        # so pre_layer_norm is populated above but not exercised here.
        layer_norm1 = swiglu_mlp[0](x.clone())
        hooked_output = esm_mlp(layer_norm1)

    # Report the mismatch first; the assertion below would otherwise hide it on failure.
    diff = torch.abs(original_output - hooked_output)
    print("Maximum absolute difference:", torch.max(diff))
    print("Mean absolute difference:", torch.mean(diff))
    assert torch.allclose(original_output, hooked_output, atol=1e-5, rtol=1e-3), "Outputs do not match!"

In [163]:
# Run the MLP equivalence check with its defaults (bias=False, expansion_ratio=4.0).
test_compare_esm_and_swiglu_mlp()

Maximum absolute difference: tensor(0.)
Mean absolute difference: tensor(0.)


In [164]:
# Notebook-scope copies of the MLP fixtures so the following cells can inspect them.
# NOTE: this rebinds `bias`, `fake_params`, `cfg` and `pre_layer_norm`, which the
# attention section also defined.
d_model, batch_size, seq_len = 512, 1, 10
expansion_ratio = 4
bias = False

# Shared random parameters.
fake_params = create_mlp_params(d_model, expansion_ratio, bias)

# Reference SwiGLU MLP, loaded with the shared parameters.
swiglu_mlp = swiglu_ln_ffn(d_model, expansion_ratio, bias)
assign_params_to_swiglu_mlp(swiglu_mlp, fake_params, bias)

# Hooked MLP plus its separate LayerNorm, loaded with the same parameters.
cfg = HookedTransformerConfig(
    n_layers=1,
    d_head=64,
    d_model=d_model,
    n_ctx=20,
    init_weights=False,
    dtype=torch.float32,
    esm3_mlp_expansion_ratio=expansion_ratio,
    act_fn="swiglu",
)
esm_mlp = ESM3_Hooked_MLP(cfg)
pre_layer_norm = LayerNorm(cfg, d_model)
assign_params_to_esm_mlp(esm_mlp, fake_params, bias, pre_layer_norm)

In [165]:
# Parameter parity between the hooked MLP and the reference SwiGLU MLP.
# Tensor lines print the number of mismatching elements (0 means identical);
# when a layer has no bias, the `==` comparison of the two bias attributes is printed.
print(torch.sum(esm_mlp.l1.weight != swiglu_mlp[1].weight))
if esm_mlp.l1.bias is not None:
    print(torch.sum(esm_mlp.l1.bias != swiglu_mlp[1].bias))
else:
    print(esm_mlp.l1.bias == swiglu_mlp[1].bias)
print(torch.sum(pre_layer_norm.w != swiglu_mlp[0].weight))
print(torch.sum(pre_layer_norm.b != swiglu_mlp[0].bias))
print(torch.sum(esm_mlp.l2.weight != swiglu_mlp[3].weight))
# Guard on l2's own bias rather than l1's, so each layer's bias is checked independently.
if esm_mlp.l2.bias is not None:
    print(torch.sum(esm_mlp.l2.bias != swiglu_mlp[3].bias))
else:
    print(esm_mlp.l2.bias == swiglu_mlp[3].bias)

tensor(0)
True
tensor(0)
tensor(0)
tensor(0)
True


In [178]:
# Same comparison at notebook scope, this time through the separate pre_layer_norm.
x = torch.rand(batch_size, seq_len, d_model)
a = esm_mlp(pre_layer_norm(x.clone()))  # hooked path: LayerNorm, then MLP
b = swiglu_mlp(x.clone())               # reference path
assert torch.allclose(a, b, atol=1e-5, rtol=1e-3), "Outputs do not match!"

In [179]:
# Display the hooked path's output (pre_layer_norm followed by esm_mlp) from the previous cell.
a

tensor([[[16574672., 16639844., 16665702.,  ..., 16493135., 17035316.,
          16245276.],
         [17376684., 17435892., 17479734.,  ..., 17285440., 17864320.,
          17032906.],
         [18500416., 18574904., 18611504.,  ..., 18419540., 19019404.,
          18134496.],
         ...,
         [18008828., 18080424., 18109598.,  ..., 17913314., 18493892.,
          17644062.],
         [16654212., 16736838., 16763836.,  ..., 16571718., 17121532.,
          16326286.],
         [18564806., 18627442., 18658598.,  ..., 18467940., 19047638.,
          18182764.]]], grad_fn=<UnsafeViewBackward0>)

In [180]:
# Display the reference swiglu_mlp output for the same input x.
b

tensor([[[16574667., 16639840., 16665696.,  ..., 16493129., 17035308.,
          16245270.],
         [17376684., 17435892., 17479734.,  ..., 17285440., 17864320.,
          17032906.],
         [18500420., 18574904., 18611504.,  ..., 18419544., 19019404.,
          18134498.],
         ...,
         [18008832., 18080430., 18109604.,  ..., 17913320., 18493896.,
          17644068.],
         [16654214., 16736840., 16763838.,  ..., 16571720., 17121534.,
          16326287.],
         [18564810., 18627446., 18658602.,  ..., 18467942., 19047642.,
          18182768.]]], grad_fn=<UnsafeViewBackward0>)