<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/main/Multiheadlatentattention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Multiheadlatentattention(nn.Module):
  def __init__(self,d_model,n_heads,kv_latent_dim):
    super().__init__()
    self.d_model = d_model #input embedding dimension
    self.n_heads = n_heads #no of attention heads
    self.dh = d_model // n_heads #dimension per head

    #Projection layers
    self.W_q = nn.Linear(d_model, d_model, bias=False) #Query Projection
    self.W_dkv = nn.Linear(d_model, kv_latent_dim , bias = False) #Compress into latent KV space
    self.W_uk = nn.Linear(kv_latent_dim, d_model , bias = False) # Decompress K
    self.W_uv = nn.Linear(kv_latent_dim, d_model , bias = False) # Decompress V
    self.W_o = nn.Linear(d_model, d_model, bias = False) #Final Output projection

    self.ln = nn.LayerNorm(kv_latent_dim)
    self.register_buffer('absorbed_k', None) #Holds W_q @ W_uk

  def forward(self,x,kv_cache=None,past_length=0):
      B,S,D = x.size() # batch size , #no of tokens and size of embedding of the token
      #compute absorbed_k once: W_q @ W_uk , shape:(D,latent_dim)
      if self.absorbed_k is None:
        absorbed = torch.matmul(self.W_q.weight, self.W_uk.weight) #(D , latend_dim)
        self.absorbed_k = absorbed.view(self.n_heads,self.dh,-1) #(n_heads,dh,latent_dim)

      #compress x into latent KV space
      new_c_kv = self.ln(self.W_dkv(x)) #(B,S,latent_dim). does two things multiplying and doing layer normalization
      if kv_cache is None:
        c_kv = new_c_kv
      else:
        c_kv = torch.cat([kv_cache,new_c_kv],dim=1) #(B,S_total,latent_dim)

      S_full = c_kv.size(1)
      #decompress V to full d_model and split into heads
      v_full = self.W_uv(c_kv) #(B,S_full,D)
      v = v_full.view(B,S_full,self.n_heads,self.dh).transpose(1,2) #(B,S_full,n_heads,dh)

      #Use input X directly (since W_q is absorbed)
      q = x.view(B,S, self.n_heads, self.dh) #(B,n_heads,S,dh)

      #Compute attention scores
      attn_scores = torch.zeros(B,self.n_heads,S,S_full,device=x.device) #(B,S,n_heads,dh)
      for h in range(self.n_heads): # attention heads
        tmp = torch.matmul(q[:,:,h],self.absorbed_k[h])
        attn_scores[:,h] = torch.bmm(tmp,c_kv.transpose(1,2)) #batch matrix multiplication

      #Scale and apply causal mask
      attn_scores = attn_scores / (self.dh**0.5)
      mask = torch.tril(torch.ones((S,S_full),device = x.device), diagonal = past_length)
      attn_scores = attn_scores.masked_fill(mask.view(1,1,S,S_full) == 0, float('-inf'))

      #Softmax to get attention weights
      attn_weights = F.softmax(attn_scores,dim=-1) #(B, n_heads ,S,S_full)

      #Apply attention weights to each head's V separately
      out_heads = []
      for h in range(self.n_heads):
        context_h = torch.matmul(attn_weights[:,h],v[:,h]) #(B,S,dh)
        out_heads.append(context_h)

      #concatenate all head outputs along the feature dimension
      out = torch.cat(out_heads,dim=-1) #(B,S,D)
      return self.W_o(out), c_kv #Final output projection + updated latent cache

In [2]:
#Step 2:Memory testing
def demo():
  model = Multiheadlatentattention(d_model=512, n_heads=8, kv_latent_dim=256)
  x = torch.randn(1,5,512) #Batch = 2, Sequence  = 10 , d_model = 512
  out , cache = model(x)
  print(f"Output shape: {out.shape}")
  print(f"Cache shape: {cache.shape}")
  #memory comparison
  std_size = 2*2*10*512*4/1024 #KB
  latent_size = 1*2*10*256*4/1024 #KB
  print(f"Memory: Standard = {std_size:.1f}KB , Latent = {latent_size:.1f}KB , Reduction = {std_size/latent_size:.1f}x")

if __name__ == "__main__":
  demo()


Output shape: torch.Size([1, 5, 512])
Cache shape: torch.Size([1, 5, 256])
Memory: Standard = 80.0KB , Latent = 20.0KB , Reduction = 4.0x
