In [10]:
import jax
import jax.numpy as jnp
import math

import flax.linen as nn

import sys
sys.path.append('/home/wuhao/md4')


from md4.networks.transformer import Attention, apply_rotary_emb, Dropout1d, precompute_freqs_cis

In [17]:
class FusedHeadAttention(nn.Module):
    """
    Drop‑in replacement for Attention that is *mathematically identical*
    but packs 4 heads → 1 MXU‑friendly 256‑wide head (v6e TPUs).

    Handles the grouped‑/multi‑query case exactly like Llama:
        n_rep   = n_heads // n_kv_heads
        K,V are projected for n_kv_heads and repeated n_rep times.
    
    """
    # ------------------------------ public arguments --------------------------
    dim:          int
    n_heads:      int
    n_kv_heads:   int | None = None      # None ≡ n_heads (full MHA)
    dropout_rate: float = 0.0
    causal:       bool  = False
    qkv_bias:     bool  = False
    group:        int   = 4              # 4×64 = 256

    # ------------------------------ setup -------------------------------------
    def setup(self):
        self._n_kv = self.n_heads if self.n_kv_heads is None else self.n_kv_heads
        assert self.n_heads % self._n_kv == 0, "n_heads must be multiple of n_kv_heads"

        self.n_rep     = self.n_heads // self._n_kv      # replication factor
        self.head_dim  = self.dim // self.n_heads        # 64  (must divide)
        self.mega_dim  = self.group * self.head_dim      # 256

        # Projections: note the shapes match the original implementation
        self.wq = nn.Dense(self.n_heads * self.head_dim,  use_bias=self.qkv_bias)
        self.wk = nn.Dense(self._n_kv   * self.head_dim,  use_bias=self.qkv_bias)
        self.wv = nn.Dense(self._n_kv   * self.head_dim,  use_bias=self.qkv_bias)
        self.wo = nn.Dense(self.dim, use_bias=False)

        if self.dropout_rate:
            self.attn_dropout  = nn.Dropout(self.dropout_rate)
            self.resid_dropout = Dropout1d(self.dropout_rate)

    # ------------------------------ helpers -----------------------------------
    @staticmethod
    def _pack(x, group=4):
        """(B,T,H,64) → (B,T,G,256), where H = G*group."""
        B, T, H, D = x.shape
        G = H // group
        x = x.reshape(B, T, G, group, D)                    # (B,T,G,4,64)

        # gather indices  [[0..63],[64..127],[128..191],[192..255]]
        idx = (jnp.arange(group)[:, None] * D) + jnp.arange(D)   # (4,64)
        buf = jnp.zeros((B, T, G, group * D), dtype=x.dtype)
        buf = buf.at[..., idx].set(x)                       # scatter into zeros
        return buf                                          # (B,T,G,256)

    @staticmethod
    def _unpack(y, group=4):
        """(B,T,G,256) → (B,T,H,64), H = G*group."""
        B, T, G, _ = y.shape
        D = 64
        idx = (jnp.arange(group)[:, None] * D) + jnp.arange(D)
        y = y[..., idx]                                    # gather back
        return y.reshape(B, T, G * group, D)

    # ------------------------------ call --------------------------------------
    def __call__(self, x, freqs_cos, freqs_sin, *, train=False):
        B, T, _ = x.shape

        # 1) Linear projections -------------------------------------------------
        q = self.wq(x).reshape(B, T, self.n_heads, self.head_dim)      # (B,T,H,64)
        k = self.wk(x).reshape(B, T, self._n_kv,  self.head_dim)       # (B,T,H_kv,64)
        v = self.wv(x).reshape(B, T, self._n_kv,  self.head_dim)

        # 2) RoPE ---------------------------------------------------------------
        q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)

        # 3) Repeat K,V across heads if GQA/MQA ---------------------------------
        if self.n_rep > 1:
            k = repeat_kv(k, self.n_rep)   # (B,T,H,64)
            v = repeat_kv(v, self.n_rep)

        # 4) Pack four heads → one 256‑wide mega‑head ---------------------------
        q = self._pack(q, self.group)          # (B,T,G,256)
        k = self._pack(k, self.group)
        v = self._pack(v, self.group)

        # 5) MXU‑friendly dot‑product ------------------------------------------
        q = q.transpose(0, 2, 1, 3)            # (B,G,T,256)
        k = k.transpose(0, 2, 3, 1)            # (B,G,256,T)

        scale = 1.0 / math.sqrt(self.head_dim) # still √64
        attn  = jnp.matmul(q, k) * scale       # (B,G,T,T)

        if self.causal:
            mask = jnp.triu(jnp.full((T, T), -jnp.inf, attn.dtype), k=1)
            attn = attn + mask                 # broadcast over (B,G)

        attn = nn.softmax(attn, axis=-1)
        if self.dropout_rate:
            attn = self.attn_dropout(attn, deterministic=not train)

        # 6) context ------------------------------------------------------------
        v = v.transpose(0, 2, 1, 3)            # (B,G,T,256)
        ctx = jnp.matmul(attn, v)              # (B,G,T,256)

        # 7) Un‑pack and merge heads -------------------------------------------
        ctx = ctx.transpose(0, 2, 1, 3)        # (B,T,G,256)
        ctx = self._unpack(ctx, self.group)    # (B,T,H,64)
        ctx = ctx.reshape(B, T, self.dim)      # (B,T,768)

        # 8) Output projection --------------------------------------------------
        out = self.wo(ctx)
        if self.dropout_rate:
            out = self.resid_dropout(out, deterministic=not train)
        return out


In [19]:
def test_attention_equivalence():
    """Test if Attention and FusedHeadAttention are mathematically equivalent."""
    print("=== Testing Attention Equivalence ===\n")
    
    # Test configuration
    batch_size = 4
    seq_len = 16
    dim = 12 * 64  # 768
    n_heads = 12
    dropout_rate = 0.0  # No dropout for deterministic comparison
    
    # Create models
    attention_ref = Attention(
        dim=dim, 
        n_heads=n_heads, 
        n_kv_heads=n_heads,
        dropout_rate=dropout_rate,
        qkv_bias=False,
        causal=False
    )
    
    attention_fused = FusedHeadAttention(
        dim=dim,
        n_heads=n_heads,
        dropout_rate=dropout_rate,
        causal=False,
        qkv_bias=False,
        group=4
    )
    
    # Create test data
    key = jax.random.PRNGKey(42)
    x = jax.random.normal(key, (batch_size, seq_len, dim))
    
    # Precompute RoPE frequencies
    head_dim = dim // n_heads
    freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len)
    
    # Initialize parameters
    key1, key2 = jax.random.split(key)
    params_ref = attention_ref.init(key1, x, freqs_cos, freqs_sin, train=False)
    params_fused = attention_fused.init(key2, x, freqs_cos, freqs_sin, train=False)
    
    # Copy weights from reference to fused to ensure they're identical
    params_fused_corrected = {
        'params': {
            'wq': params_ref['params']['wq'],
            'wk': params_ref['params']['wk'], 
            'wv': params_ref['params']['wv'],
            'wo': params_ref['params']['wo']
        }
    }
    
    print("Model configurations:")
    print(f"  Batch size: {batch_size}")
    print(f"  Sequence length: {seq_len}")
    print(f"  Dimension: {dim}")
    print(f"  Number of heads: {n_heads}")
    print(f"  Head dimension: {head_dim}")
    print(f"  Group size (fused): {attention_fused.group}")
    print()
    
    # Test 1: Without RoPE effect (identity rotation)
    print("Test 1: Without RoPE effect (freqs_cos=1, freqs_sin=0)")
    freqs_cos_identity = jnp.ones_like(freqs_cos)
    freqs_sin_identity = jnp.zeros_like(freqs_sin)
    
    # Forward pass
    output_ref = attention_ref.apply(params_ref, x, freqs_cos_identity, freqs_sin_identity, train=False)
    output_fused = attention_fused.apply(
        params_fused_corrected, x, freqs_cos_identity, freqs_sin_identity, train=False
    )
    
    # Compare outputs
    max_diff = jnp.max(jnp.abs(output_ref - output_fused))
    mean_diff = jnp.mean(jnp.abs(output_ref - output_fused))
    relative_error = mean_diff / (jnp.mean(jnp.abs(output_ref)) + 1e-8)
    
    print(f"  Max absolute difference: {max_diff:.2e}")
    print(f"  Mean absolute difference: {mean_diff:.2e}")
    print(f"  Relative error: {relative_error:.2e}")
    print(f"  Outputs are equivalent: {max_diff < 1e-5}")
    print()
    
    # Test 2: With actual RoPE
    print("Test 2: With RoPE applied")
    output_ref_rope = attention_ref.apply(params_ref, x, freqs_cos, freqs_sin, train=False)
    output_fused_rope = attention_fused.apply(
        params_fused_corrected, x, freqs_cos, freqs_sin, train=False
    )
    
    # Compare outputs with RoPE
    max_diff_rope = jnp.max(jnp.abs(output_ref_rope - output_fused_rope))
    mean_diff_rope = jnp.mean(jnp.abs(output_ref_rope - output_fused_rope))
    relative_error_rope = mean_diff_rope / (jnp.mean(jnp.abs(output_ref_rope)) + 1e-8)
    
    print(f"  Max absolute difference: {max_diff_rope:.2e}")
    print(f"  Mean absolute difference: {mean_diff_rope:.2e}")
    print(f"  Relative error: {relative_error_rope:.2e}")
    print(f"  Outputs are equivalent with RoPE: {max_diff_rope < 1e-5}")
    print()
    
    # Test 4: Detailed shape verification
    print("\nTest 4: Internal shape verification")
    
    # Manually test the packing/unpacking functions
    test_tensor = jax.random.normal(key, (2, 8, 12, 64))  # (B, T, H, D)
    packed = _pack_heads(test_tensor, group=4)
    unpacked = _unpack_heads(packed, group=4)
    
    pack_unpack_diff = jnp.max(jnp.abs(test_tensor - unpacked))
    print(f"  Pack/unpack round-trip error: {pack_unpack_diff:.2e}")
    print(f"  Pack/unpack preserves data: {pack_unpack_diff < 1e-10}")
    
    print(f"\n=== Summary ===")
    print(f"✓ Models are mathematically equivalent (no RoPE): {max_diff < 1e-5}")
    print(f"✓ Models are mathematically equivalent (with RoPE): {max_diff_rope < 1e-5}")
    print(f"✓ Pack/unpack functions preserve data: {pack_unpack_diff < 1e-10}")

# Run the equivalence test
test_attention_equivalence()

=== Testing Attention Equivalence ===

xq: (4, 12, 16, 64), xk: (4, 12, 16, 64), xv: (4, 12, 16, 64)
scores: (4, 12, 16, 16)
Model configurations:
  Batch size: 4
  Sequence length: 16
  Dimension: 768
  Number of heads: 12
  Head dimension: 64
  Group size (fused): 4

Test 1: Without RoPE effect (freqs_cos=1, freqs_sin=0)
xq: (4, 12, 16, 64), xk: (4, 12, 16, 64), xv: (4, 12, 16, 64)
scores: (4, 12, 16, 16)
  Max absolute difference: 2.78e+00
  Mean absolute difference: 3.74e-01
  Relative error: 1.31e+00
  Outputs are equivalent: False

Test 2: With RoPE applied
xq: (4, 12, 16, 64), xk: (4, 12, 16, 64), xv: (4, 12, 16, 64)
scores: (4, 12, 16, 16)
  Max absolute difference: 2.15e+00
  Mean absolute difference: 3.66e-01
  Relative error: 1.29e+00
  Outputs are equivalent with RoPE: False


Test 4: Internal shape verification
  Pack/unpack round-trip error: 0.00e+00
  Pack/unpack preserves data: True

=== Summary ===
✓ Models are mathematically equivalent (no RoPE): False
✓ Models are ma