MLA computation using the steps as mentioned in figure:
![mla](./mla.png)

In [1]:
import torch
import torch.nn.functional as F

# Dimensions
seq_len = 6       # number of tokens
embed_dim = 6     # token embedding dim
d_kv = 4          # intermediate latent dim
d_model = 8       # final Q, K, V dim

# 1. Sample input (embedding matrix for 6 tokens)
X = torch.randn(seq_len, embed_dim)  # shape: (6, 6)

# 2. Define projection weights
Wq   = torch.randn(embed_dim, d_model)    # (6, 8)
Wdkv = torch.randn(embed_dim, d_kv)       # (6, 4)
Wuk  = torch.randn(d_kv, d_model)         # (4, 8)
Wuv  = torch.randn(d_kv, d_model)         # (4, 8)

# 3. KV Caching: latent matrix (Ckv)
Ckv = X @ Wdkv  # shape: (6, 4)

# 4. Projections
Q = X @ Wq         # (6, 8)
K = Ckv @ Wuk      # (6, 8)
V = Ckv @ Wuv      # (6, 8)

# 5. Attention score computation
attn_scores = (Q @ K.T) / (d_model ** 0.5)  # shape: (6, 6)

# Optional: Apply causal mask (prevent attending to future)
mask = torch.tril(torch.ones(seq_len, seq_len))  # lower triangular
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

# 6. Softmax to get attention weights
attn_weights = F.softmax(attn_scores, dim=-1)  # shape: (6, 6)

# 7. Context matrix (output of attention)
context = attn_weights @ V  # shape: (6, 8)

# Print shapes
print("Embedding X:", X.shape)
print("Q:", Q.shape)
print("K:", K.shape)
print("V:", V.shape)
print("Attention Weights:", attn_weights.shape)
print("Context:", context.shape)


Embedding X: torch.Size([6, 6])
Q: torch.Size([6, 8])
K: torch.Size([6, 8])
V: torch.Size([6, 8])
Attention Weights: torch.Size([6, 6])
Context: torch.Size([6, 8])


In [3]:
import torch
import torch.nn.functional as F

# Dimensions
seq_len = 6
embed_dim = 6
d_kv = 4
d_model = 8

# Seed for reproducibility
torch.manual_seed(42)

# Sample input embeddings (X)
X = torch.randn(seq_len, embed_dim)  # (6, 6)

# Projection matrices
Wq   = torch.randn(embed_dim, d_model)   # (6, 8)
Wdkv = torch.randn(embed_dim, d_kv)      # (6, 4)
Wuk  = torch.randn(d_kv, d_model)        # (4, 8)
Wuv  = torch.randn(d_kv, d_model)        # (4, 8)

# ========= Causal Mask =========
mask = torch.tril(torch.ones(seq_len, seq_len))  # (6, 6)

# ========== Original Computation ==========

Q = X @ Wq               # (6, 8)
Ckv = X @ Wdkv           # (6, 4)
K = Ckv @ Wuk            # (6, 8)
V = Ckv @ Wuv            # (6, 8)

attn_scores_orig = (Q @ K.T) / (d_model ** 0.5)
attn_scores_orig = attn_scores_orig.masked_fill(mask == 0, float('-inf'))
attn_weights_orig = F.softmax(attn_scores_orig, dim=-1)
context_orig = attn_weights_orig @ V

# ========== Optimized Computation ==========

# Precompute Wq @ Wuk^T
fused_qk_proj = Wq @ Wuk.T  # (6, 4)

attn_scores_opt = (X @ fused_qk_proj) @ Ckv.T
attn_scores_opt = attn_scores_opt / (d_model ** 0.5)
attn_scores_opt = attn_scores_opt.masked_fill(mask == 0, float('-inf'))
attn_weights_opt = F.softmax(attn_scores_opt, dim=-1)
context_opt = attn_weights_opt @ V

# ========== Comparison ==========

print("Masked Attention score diff (orig vs opt):", torch.norm(attn_scores_orig - attn_scores_opt).item())
print("Masked Context diff (orig vs opt):", torch.norm(context_orig - context_opt).item())


Masked Attention score diff (orig vs opt): nan
Masked Context diff (orig vs opt): 1.1026859283447266e-06


In [13]:
context_orig

tensor([[-0.9969, -9.4262, -2.8623, -2.8189, 13.2914, -7.2141, 11.3604, -1.5934],
        [-3.3808, -1.0989,  0.3626,  1.5451, -1.6427, -6.1897, -1.5837,  2.0744],
        [-3.3803, -1.0985,  0.3625,  1.5446, -1.6424, -6.1883, -1.5834,  2.0739],
        [-3.2030, -1.3739,  0.2528,  1.3568, -1.0646, -6.1085, -1.0487,  1.9079],
        [-0.9969, -9.4262, -2.8623, -2.8189, 13.2914, -7.2141, 11.3604, -1.5934],
        [-0.9964, -9.4248, -2.8617, -2.8185, 13.2895, -7.2130, 11.3591, -1.5931]])

In [14]:
context_opt

tensor([[-0.9969, -9.4262, -2.8623, -2.8189, 13.2914, -7.2141, 11.3604, -1.5934],
        [-3.3808, -1.0989,  0.3626,  1.5451, -1.6427, -6.1897, -1.5837,  2.0744],
        [-3.3803, -1.0985,  0.3625,  1.5446, -1.6424, -6.1883, -1.5834,  2.0739],
        [-3.2030, -1.3739,  0.2528,  1.3568, -1.0646, -6.1085, -1.0487,  1.9079],
        [-0.9969, -9.4262, -2.8623, -2.8189, 13.2914, -7.2141, 11.3604, -1.5934],
        [-0.9964, -9.4248, -2.8617, -2.8185, 13.2895, -7.2130, 11.3591, -1.5931]])

In [None]:
X[:5] 

tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345],
        [-0.0431, -1.6047, -0.7521,  1.6487, -0.3925, -1.4036],
        [-0.7279, -0.5594, -0.7688,  0.7624,  1.6423, -0.1596],
        [-0.4974,  0.4396,  0.3189, -0.4245,  0.3057, -0.7746],
        [ 0.0349,  0.3211,  1.5736, -0.8455, -1.2742,  2.1228]])

Compute the first 5 tokens first, then compute the 6th token

![mla 2](./mla_2.png)

In [30]:
seq_len_old = 5
X_old = X[:seq_len_old]  # Use first 5 tokens for old computation

Q = X_old @ Wq
Ckv = X_old @ Wdkv
K = Ckv @ Wuk
V = Ckv @ Wuv

mask = torch.tril(torch.ones(seq_len_old, seq_len_old)) 

attn_scores_orig = (Q @ K.T) / (d_model ** 0.5)
attn_scores_orig = attn_scores_orig.masked_fill(mask == 0, float('-inf'))
attn_weights_orig = F.softmax(attn_scores_orig, dim=-1)
context_orig_old = attn_weights_orig @ V



In [39]:
X_new = X[seq_len_old:]  # Use remaining tokens for new computation
Ckv_new = X_new @ Wdkv  # (1, 4)

# append new Ckv to existing Ckv
Ckv_combined = torch.cat((Ckv, Ckv_new), dim=0)

attn_score_new = X_new @ fused_qk_proj @ Ckv_combined.T / (d_model ** 0.5)
attn_weights_new = F.softmax(attn_score_new , dim=-1)

V_combined = Ckv_combined @ Wuv  

context_new = attn_weights_new @ V_combined

context_new


tensor([[-0.9964, -9.4248, -2.8617, -2.8185, 13.2895, -7.2130, 11.3591, -1.5931]])