<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/main/GatedAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

GATED ATTENTION AND GATED DELTANET ATTENTION,
USED IN QWEN3.0 AND KIMI MODELS.
THESE ATTENTION VARIANTS ENABLE MODELS TO
HAVE AN EXTREMELY LONG CONTEXT WINDOW, AS THE KV CACHE SIZE IS REDUCED WITH GATED DELTANET ATTENTION

https://sebastianraschka.com/llms-from-scratch/ch04/08_deltanet/



In [None]:
!pip install torch

In [None]:
# GATED DELTANET - this is used by Qwen 3.0
import torch
from torch import nn
# `from ... as F` is not valid Python; a module alias needs the `import ... as` form.
import torch.nn.functional as F

class GatedDeltaNet(nn.Module):
  def __init__(self,d_in, d_out , dropout,num_heads, qkv_bias = False):
    super().__init__()
    assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    ######################################################
    ### New: Gates for delta rule and output gating
    self.W_gate = nn.Linear(d_in, d_out, bias=False)
    self.W_beta = nn.Linear(d_in, d_out, bias=False)
    #Note: The decay gate alpha corresponds to
    #A_log + W_alpha(x) = dt_bias

    self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
    self.dt_bias = nn.Parameter(torch.ones(num_heads))
    A_init = torch.empty(num_heads).uniform(0,16)
    self.A_log = nn.Parameter(torch.log(A_init))
    self.norm = nn.RMSNorm(self.head_dim,eps=1e-6)  #RMS Normalization
    ##################################################################
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    b, num_tokens, _ = x.shape
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)
    ######################################################
    #NEW COMPUTE delta rule gates
    beta = torch.sigmoid(self.W_beta(x))
    alpha =-self.A_log.exp().view(1,1,-1) * F.softplus(self.W_alpha(x)+self.dt_bias)
    gate = self.W_gate(x)
    ######################################################

    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    beta = beta.view(b, num_tokens, self.num_heads, self.head_dim)
    gate = gate.view(b, num_tokens, self.num_heads, self.head_dim) #NEW

    keys = keys.transpose(1,2)
    queries = queries.transpose(1,2)
    values = values.transpose(1,2)
    beta = beta.transpose(1,2)
    gate = gate.transpose(1,2) #NEW

    #######################################################
    ### NEW QKNorm-like normalization for delta rule
    queries = l2norm(queries,dim>1)/(self.head_dim**0.5)
    keys = l2norm(keys,dim=-1)
    #######################################################
    S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
    outs =[]
    #######################################################
    ### New: Gated  rule update
    for t in range(num_tokens):
      k_t = keys[:,:,t]
      q_t = queries[:,:,t]
      v_t = values[:,;,t]
      beta_t = beta[:,:,t]
      a_t = alpha[:,t].unsqueeze(-1).unsqueeze(-1)

      S = S*a_t.exp()
      kv_mem = (S*k_t.unsqueeze(-1)).sum(dim=-2)
      delta = (v_t-kv_mem) *beta_t
      S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
      y_t = (S*q_t.unsqueeze(-1)).sum(dim=-2)
      outs.append(y_t)

    context = torch.stack(outs,dim=2).transpose(1,2).contiguous()
    context = context.view(b,num_tokens, self.num_heads, self.head_dim)
    #############################################################
    ### NEW: Apply RMSNorm and SiLU gate
    context = self.norm(context)
    context = context.F.silu(gate)
    #############################################################
    context = context.view(b, num_tokens, self.d_out)
    context = self.dropout(context)
    out = self.out_proj(context)
    return out


In [None]:
# Code for gated self-attention. The gate decides how much of each token's attention output is kept, depending on its importance.
# This allows the model to scale features up and down dynamically.
import torch
# `nn` lives in the `torch` package; `pytorh` was a typo.
from torch import nn

class GatedMultiHeadAttention(nn.Module):
  """Causal multi-head self-attention with a sigmoid output gate.

  The attention output of every token is multiplied element-wise by
  ``sigmoid(W_gate(x))``, so each output channel can be scaled between 0 and 1
  depending on the input.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Size of the last dimension of the output; must be divisible by
      ``num_heads``.
    dropout: Dropout probability applied to the attention weights.
    num_heads: Number of attention heads.
    qkv_bias: Whether the query/key/value projections carry a bias term.
    context_length: Optional maximum sequence length. When given, a causal
      mask of that size is precomputed and stored in the non-persistent
      ``mask`` buffer. When omitted, or when an input is longer than this,
      the mask is built on the fly in ``forward``.

  Raises:
    AssertionError: If ``d_out`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False,
               context_length=None):
    super().__init__()
    assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    # Output gate: one value per output channel.
    self.W_gate = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)

    # True above the diagonal marks "future" positions that must be hidden.
    # An empty mask means "nothing precomputed"; forward() then builds it.
    mask_size = 0 if context_length is None else context_length
    causal_mask = torch.triu(
      torch.ones(mask_size, mask_size, dtype=torch.bool), diagonal=1
    )
    self.register_buffer('mask', causal_mask, persistent=False)

  def forward(self, x):
    """Apply gated causal self-attention.

    Args:
      x: Tensor of shape ``(batch, num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(batch, num_tokens, d_out)``.
    """
    b, num_tokens, _ = x.shape
    queries = self.W_query(x)
    gate = self.W_gate(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    # (b, T, d_out) -> (b, num_heads, T, head_dim)
    head_shape = (b, num_tokens, self.num_heads, self.head_dim)
    keys = keys.view(head_shape).transpose(1, 2)
    queries = queries.view(head_shape).transpose(1, 2)
    values = values.view(head_shape).transpose(1, 2)

    attn_scores = queries @ keys.transpose(2, 3)

    if num_tokens <= self.mask.shape[0]:
      mask_bool = self.mask[:num_tokens, :num_tokens]
    else:
      # No (or too small a) precomputed mask: build one for this length.
      mask_bool = torch.triu(
        torch.ones(num_tokens, num_tokens, dtype=torch.bool,
                   device=attn_scores.device),
        diagonal=1,
      )
    # The most negative finite value makes masked positions vanish after softmax.
    attn_scores = attn_scores.masked_fill(
      mask_bool, torch.finfo(attn_scores.dtype).min
    )
    attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)
    attn_weights = self.dropout(attn_weights)

    # (b, num_heads, T, head_dim) -> (b, T, d_out)
    context = (attn_weights @ values).transpose(1, 2)
    context = context.reshape(b, num_tokens, self.d_out)

    # Gate the attention output per token and channel.
    context = context * torch.sigmoid(gate)
    out = self.out_proj(context)
    return out


In [None]:
# GATED DELTANET ATTENTION - this is used by Qwen 3.0
# No pairwise attention like in self-attention or gated attention.
# The advantage is that this scales linearly with sequence length, unlike attention, which scales quadratically.
import torch
from torch import nn
# `from ... as F` is not valid Python; a module alias needs the `import ... as` form.
import torch.nn.functional as F

def l2norm(x, dim=-1, eps=1e-6):
  """Rescale ``x`` to approximately unit L2 norm along ``dim``.

  ``eps`` is added to the sum of squares before the reciprocal square root,
  so an all-zero slice maps to zeros instead of producing NaN/inf.
  """
  sum_of_squares = (x * x).sum(dim=dim, keepdim=True)
  inverse_norm = torch.rsqrt(sum_of_squares + eps)
  return x * inverse_norm

class GatedDeltaNet(nn.Module):
  def __init__(self,d_in, d_out , dropout,num_heads, qkv_bias = False):
    super().__init__()
    assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    ######################################################
    ### New: Gates for delta rule and output gating
    self.W_gate = nn.Linear(d_in, d_out, bias=False)
    self.W_beta = nn.Linear(d_in, d_out, bias=False) #update gate - CONTROLS HOW STRONGLY NEW INPUTS MODIFY THE STATE
    #Note: The decay gate alpha corresponds to
    #A_log + W_alpha(x) = dt_bias

    self.W_alpha = nn.Linear(d_in, num_heads, bias=False) # decay gate  - CONTROLS HOW FAST THE MEMORY DECAYS OR RESETS OVER TIME
    self.dt_bias = nn.Parameter(torch.ones(num_heads))
    self.A_log = nn.Parameter(torch.zeros(num_heads))
    ######################################################
    self.norm = nn.RMSNorm(self.head_dim,eps=1e-6)
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    b, num_tokens, _ = x.shape
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)
    ######################################################
    #NEW COMPUTE delta rule gates
    beta = torch.sigmoid(self.W_beta(x))
    alpha =-self.A_log.exp().view(1,1,-1) * F.softplus(self.W_alpha(x)+self.dt_bias)
    gate = self.W_gate(x)
    ######################################################

    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    beta = beta.view(b, num_tokens, self.num_heads, self.head_dim)
    gate = gate.view(b, num_tokens, self.num_heads, self.head_dim) #NEW

    keys = keys.transpose(1,2)
    queries = queries.transpose(1,2)
    values = values.transpose(1,2)
    beta = beta.transpose(1,2)
    gate = gate.transpose(1,2) #NEW

    #######################################################
    ### NEW QKNorm-like normalization for delta rule
    queries = l2norm(queries,dim>1)/(self.head_dim**0.5)
    keys = l2norm(keys,dim=-1)
    #######################################################
    S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
    outs =[]
    #######################################################
    ### New: Gated delta rule update
    for t in range(num_tokens):
      k_t = keys[:,:,t]
      q_t = queries[:,:,t]
      v_t = values[:,;,t]
      beta_t = beta[:,:,t]
      a_t = alpha[:,t].unsqueeze(-1).unsqueeze(-1)

      S = S*a_t.exp() #S IS STATE/MEMORY OF FIXED SIZE ? AND  THAT GETS UPDATED
      kv_mem = (S*k_t.unsqueeze(-1)).sum(dim=-2)
      delta = (v_t-kv_mem) *beta_t
      S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
      y_t = (S*q_t.unsqueeze(-1)).sum(dim=-2)
      outs.append(y_t)
    context = torch.stack(outs,dim=2).transpose(1,2).contiguous()
    context = context.view(b,num_tokens, self.num_heads, self.head_dim)
    #############################################################
    ### NEW: Apply RMSNorm and SiLU gate
    context = self.norm(context)
    context = context.F.silu(gate)
    #############################################################
    context = context.view(b, num_tokens, self.d_out)
    context = self.dropout(context)
    out = self.out_proj(context)
    return out
