## Sliding window attention

In [1]:
import torch
import einops
import torch.nn as nn
import torch.nn.functional as F

In [2]:
batch_size=12
d_model=56
seq_len=32
n_heads=4
d_head=14

window_size=5

x=torch.rand(batch_size,seq_len,d_model)
W_q=nn.Linear(in_features=d_model,out_features=d_model)
W_k=nn.Linear(in_features=d_model,out_features=d_model)
W_v=nn.Linear(in_features=d_model,out_features=d_model)

q=W_q(x).view(batch_size,seq_len,n_heads,d_head).transpose(1,2)
k=W_k(x).view(batch_size,seq_len,n_heads,d_head).transpose(1,2)
v=W_v(x).view(batch_size,seq_len,n_heads,d_head).transpose(1,2)
assert q.shape==k.shape
#print(q.shape) # batch_size, n_heads, seq_len, d_head
half_window=window_size//2
pad_k=F.pad(k,(0,0,half_window,half_window))
pad_v=F.pad(v,(0,0,half_window,half_window))
#print(pad_k.shape) # batch_size, n_heads, seq_len+2*half_window, d_head

k_unf=pad_k.unfold(dimension=2,size=window_size,step=1).transpose(3,4) #(batch_size,n_heads,seq_len,window_size,d_head)
v_unf=pad_v.unfold(dimension=2,size=window_size,step=1).transpose(3,4)#(batch_size,n_heads,seq_len,window_size,d_head)
#print('k_unf',k_unf.shape)
#print('q',q.shape)
q=q.unsqueeze(-2)
#print('q',q.shape)

attn_scores=einops.einsum(q,k_unf,'b h s w d, b h s w d -> b h s w ')
#print(attn_scores.shape) # batch_size, n_heads, seq_len, window_size
mask=torch.tril(torch.ones(window_size,window_size))
mask=mask[-1]
mask=mask.view(1,1,1,window_size).expand(batch_size,n_heads,seq_len,window_size)
print(mask.shape)
print(attn_scores.shape)
attn_scores=attn_scores.masked_fill(mask==0,-float('inf'))
attn_scores=attn_scores/(d_head**0.5)

attn_weights=F.softmax(attn_scores,dim=-1)

attn_output=einops.einsum(attn_weights,v_unf,'b h s w, b h s w d ->b h s d')
print(attn_output.shape)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, n_heads * d_head)




  from .autonotebook import tqdm as notebook_tqdm


torch.Size([12, 4, 32, 5])
torch.Size([12, 4, 32, 5])
torch.Size([12, 4, 32, 14])


## Grouped Query Head

In [3]:
x=torch.rand(12,16,112)
n_heads=8
n_kv_heads=2
d_head=14
d_model=112
batch_size=12
seq_len=16
#each 8//2=4 query group shares 1 kv head

W_q=nn.Linear(d_model,n_heads*d_head)
W_k=nn.Linear(d_model,n_kv_heads*d_head)
W_v=nn.Linear(d_model,n_kv_heads*d_head)
q=W_q(x).view(batch_size,seq_len,n_heads,d_head)
k=W_k(x).view(batch_size,seq_len,n_kv_heads, d_head)
v=W_v(x).view(batch_size,seq_len,n_kv_heads,d_head)


n=n_heads//n_kv_heads
k=torch.repeat_interleave(k,dim=2,repeats=n) #  batch,seq_len,n_heads,d_head
v=torch.repeat_interleave(v,dim=2,repeats=n)#  batch,seq_len,n_heads,d_head

#transpose


## KV cache

In [4]:
self.cache_k=None
self.cache_v=None

q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

q=q.transpose(1,2)
k=k.transpose(1,2)
v=v.transpose(1,2)

if is_training:
    attn_scores=torch.matmul(q,k.transpose(-1,-2)) / (self.head_dim**0.5)
    attn_probs=nn.Softmax(dim=-1)(attn_scores)
    output = torch.matmul(attn_probs, v)  # [batch_size, num_heads, seq_len, head_dim]
else:
    if self.cache_k is not None and self.cache_v is not None:
        k=torch.cat((self.cache_k,k),dim=2)
        v=torch.cat((self.cache_v,v),dim=2)
    
    attn_scores=torch.matmul(q,k.transpose(-1,-2)) / (self.head_dim**0.5)
    attn_probs=nn.Softmax(dim=-1)(attn_scores)
    output = torch.matmul(attn_probs, v)  # [batch_size, num_heads, seq_len, head_dim]
    self.cache_k=k
    self.cache_v=v
def reset_cache(self):
    """Clear cached keys and values after each sequence generation (for inference)."""
    self.cache_k = None
    self.cache_v = None

#model.reset_cache() at the beginning of inference   and under range(batch_size) 

NameError: name 'self' is not defined

In [None]:
class AttentionWithKVCache(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int = 2048):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.max_seq_len = max_seq_len

        # Projection layers
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.W_o = nn.Linear(dim, dim)

        # Initialize KV cache
        self.register_buffer('cache_k', torch.zeros(
            (1, max_seq_len, num_heads, self.head_dim)  # batch=1 for simplicity
        ))
        self.register_buffer('cache_v', torch.zeros(
            (1, max_seq_len, num_heads, self.head_dim)
        ))

    def forward(self, x: torch.Tensor, start_pos: int = 0, is_training: bool = False):
        batch_size, seq_len, _ = x.shape
        
        # Project queries, keys, values
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if is_training:
            # Training mode - full attention
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn_probs = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_probs, v)
        else:
            # Inference mode - use KV cache
            # Update cache
            self.cache_k[:batch_size, start_pos:start_pos+seq_len] = k
            self.cache_v[:batch_size, start_pos:start_pos+seq_len] = v
            
            # Get keys/values up to current position
            keys = self.cache_k[:batch_size, :start_pos+seq_len]
            values = self.cache_v[:batch_size, :start_pos+seq_len]
            
            # Compute attention
            attn_scores = torch.matmul(q, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
            
            # Apply causal mask for new tokens
            if seq_len > 1:  # Only needed when processing multiple new tokens
                mask = torch.ones((seq_len, start_pos+seq_len), dtype=torch.bool, device=x.device)
                mask = torch.tril(mask, diagonal=start_pos)
                attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
            
            attn_probs = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_probs, values)

        # Output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(output)

    def reset_cache(self):
        """Reset the KV cache between sequences"""
        self.cache_k.zero_()
        self.cache_v.zero_()

## Sliding Window, GQA and  KV cache

In [None]:
import math
class AttentionWithKVCache(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int = 2048, num_kv_heads:int=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.num_kv_heads=num_kv_heads
        self.max_seq_len = max_seq_len
        self.repeats=num_heads//num_kv_heads
        
        
        # Projection layers
        self.W_q = nn.Linear(dim,self.num_heads*self.head_dim)
        self.W_k = nn.Linear(dim, self.num_kv_heads*self.head_dim)
        self.W_v = nn.Linear(dim, self.num_kv_heads*self.head_dim)
        self.W_o = nn.Linear(dim, dim)

        # Initialize KV cache
        # self.register_buffer('cache_k', torch.zeros(
        #     (1, max_seq_len, self.num_kv_heads, self.head_dim)  # batch=1 for simplicity
        # ))
        # self.register_buffer('cache_v', torch.zeros(
        #     (1, max_seq_len, self.num_kv_heads, self.head_dim)
        # ))
        
        self.cache_k=None
        self.cache_v=None

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        batch_size, seq_len, _ = x.shape
        
        # Project queries, keys, values
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)


        if self.training:
            # Training mode - full attention
            q=q.transpose(1,2)
            k=k.transpose(1,2)
            v=v.transpose(1,2)
            k=torch.repeat_interleave(k,dim=1,repeats=self.repeats) #batch_size,num_heads,seq_len,head_dim
            v=torch.repeat_interleave(v,dim=1,repeats=self.repeats)#batch_size,num_heads,seq_len,head_dim
           
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) #batch_size,num_heads,seq_len,seq_len
            
            mask = torch.tril(torch.ones((seq_len, seq_len), device=x.device)).unsqueeze(0).unsqueeze(0)  # (1,1,seq_len,seq_len)
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
            attn_probs = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_probs, v)
        else:
            # Inference mode - use KV cache
            # Update cache
            if self.cache_k is None or self.cache_v is None:
                self.cache_k=torch.zeros((batch_size,self.max_seq_len,self.num_kv_heads,self.head_dim))
                self.cache_v=torch.zeros((batch_size,self.max_seq_len,self.num_kv_heads,self.head_dim))
            print(k.shape)
            print(self.cache_k.shape)
            self.cache_k[:batch_size, start_pos:start_pos+seq_len] = k
            self.cache_v[:batch_size, start_pos:start_pos+seq_len] = v
            
            # Get keys/values up to current position

            keys = self.cache_k[:batch_size, :start_pos+seq_len]
            values = self.cache_v[:batch_size, :start_pos+seq_len]
            
            q=q.transpose(1,2)
            keys=keys.transpose(1,2)
            values=values.transpose(1,2)            # Compute attention
            keys=keys.repeat_interleave(self.repeats,dim=1)
            values=values.repeat_interleave(self.repeats,dim=1)
            print(keys.shape)
            print(q.shape)
            attn_scores = torch.matmul(q, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
            
            # Apply causal mask for new tokens
            if seq_len > 1:  # Only needed when processing multiple new tokens
                mask = torch.ones((seq_len, start_pos+seq_len), dtype=torch.bool, device=x.device)
                mask = torch.tril(mask, diagonal=start_pos)
                attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
            
            attn_probs = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_probs, values)

        # Output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(output)

    def reset_cache(self):
        """Reset the KV cache between sequences"""
        if self.cache_k is not None and self.cache_v is not None:
            self.cache_k.zero()
            self.cache_v.zero()

In [59]:
def test_attention_with_kv_cache():
    torch.manual_seed(42)  # for reproducibility

    # Model params
    dim = 32
    num_heads = 8
    num_kv_heads = 2
    max_seq_len = 16

    # Make model
    model = AttentionWithKVCache(dim, num_heads, max_seq_len, num_kv_heads)

    # ---------------- Training mode ----------------
    model.train()
    x_train = torch.randn(2, 10, dim)  # (batch_size=2, seq_len=10, dim)

    out_train = model(x_train)
    print("Training output shape:", out_train.shape)

    assert out_train.shape == (2, 10, dim), "Training output has wrong shape!"

    # ---------------- Inference mode ----------------
    model.eval()
    model.reset_cache()

    x_infer = torch.randn(2, 5, dim)  # (batch_size=2, seq_len=5, dim)
    out_infer = model(x_infer, start_pos=0)
    print("Inference output shape:", out_infer.shape)

    assert out_infer.shape == (2, 5, dim), "Inference output has wrong shape!"

    print("\n✅ Test passed: Training and Inference outputs are correct!")

test_attention_with_kv_cache()


Training output shape: torch.Size([2, 10, 32])
torch.Size([2, 5, 2, 4])
torch.Size([2, 16, 2, 4])
torch.Size([2, 8, 5, 4])
torch.Size([2, 8, 5, 4])
Inference output shape: torch.Size([2, 5, 32])

✅ Test passed: Training and Inference outputs are correct!


## Adding sliding window to gqa,kv cache

In [5]:
device="mps"

In [6]:
import math
import torch.nn.functional as F
class AttentionWithKVCache(nn.Module):
    def __init__(self, dim: int, num_heads: int, window_size: int, max_seq_len: int = 2048, num_kv_heads:int=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.num_kv_heads=num_kv_heads
        self.max_seq_len = max_seq_len
        self.repeats=num_heads//num_kv_heads
        self.window_size=window_size
        self.half_window=self.window_size//2
        
        # Projection layers
        self.W_q = nn.Linear(dim,self.num_heads*self.head_dim)
        self.W_k = nn.Linear(dim, self.num_kv_heads*self.head_dim)
        self.W_v = nn.Linear(dim, self.num_kv_heads*self.head_dim)
        self.W_o = nn.Linear(dim, dim)

        # Initialize KV cache
        # self.register_buffer('cache_k', torch.zeros(
        #     (1, max_seq_len, self.num_kv_heads, self.head_dim)  # batch=1 for simplicity
        # ))
        # self.register_buffer('cache_v', torch.zeros(
        #     (1, max_seq_len, self.num_kv_heads, self.head_dim)
        # ))
        
        self.cache_k=None
        self.cache_v=None

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        batch_size, seq_len, _ = x.shape
        
        # Project queries, keys, values
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)


        if self.training:
            # Training mode - full attention
            q=q.transpose(1,2)
            k=k.transpose(1,2)
            v=v.transpose(1,2)
            k=torch.repeat_interleave(k,dim=1,repeats=self.repeats) #batch_size,num_heads,seq_len,head_dim
            v=torch.repeat_interleave(v,dim=1,repeats=self.repeats)#batch_size,num_heads,seq_len,head_dim
           
            pad_k=F.pad(k,(0,0,self.half_window,self.half_window))
            pad_v=F.pad(v,(0,0,self.half_window,self.half_window))

            k_unf=pad_k.unfold(dimension=2,size=self.window_size,step=1).transpose(3,4) #(batch_size,num_heads,seq_len,self.window_size,d_head)
            v_unf=pad_v.unfold(dimension=2,size=self.window_size,step=1).transpose(3,4) #(batch_size,num_heads,seq_len,self.window_size,d_head)
            q=q.unsqueeze(-2)
            
            attn_scores=einops.einsum(q,k_unf,'b h s w d, b h s w d -> b h s w ')
            
            mask=torch.tril(torch.ones(self.window_size,self.window_size,device=device))
            mask=mask[-1]
            mask=mask.view(1,1,1,self.window_size).expand(batch_size,self.num_heads,seq_len,self.window_size)
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
            
            attn_weights = F.softmax(attn_scores, dim=-1)
            output=einops.einsum(attn_weights,v_unf,'b h s w, b h s w d ->b h s d')
            
        else:
            # Inference mode - use KV cache
            # Update cache
            if self.cache_k is None or self.cache_v is None:
                self.cache_k=torch.zeros((batch_size,self.max_seq_len,self.num_kv_heads,self.head_dim))
                self.cache_v=torch.zeros((batch_size,self.max_seq_len,self.num_kv_heads,self.head_dim))
            print(k.shape)
            print(self.cache_k.shape)
            self.cache_k[:batch_size, start_pos:start_pos+seq_len] = k
            self.cache_v[:batch_size, start_pos:start_pos+seq_len] = v
            
            # Get keys/values up to current position

            keys = self.cache_k[:batch_size, :start_pos+seq_len]
            values = self.cache_v[:batch_size, :start_pos+seq_len]
            
            q=q.transpose(1,2)
            keys=keys.transpose(1,2)
            values=values.transpose(1,2)            # Compute attention
            keys=keys.repeat_interleave(self.repeats,dim=1)
            values=values.repeat_interleave(self.repeats,dim=1)
            print(keys.shape)
            print(q.shape)
            attn_scores = torch.matmul(q, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
            
            # Apply causal mask for new tokens
            if seq_len > 1:  # Only needed when processing multiple new tokens
                mask = torch.ones((seq_len, start_pos+seq_len), dtype=torch.bool, device=x.device)
                mask = torch.tril(mask, diagonal=start_pos)
                attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
            
            attn_probs = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_probs, values)

        # Output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(output)

    def reset_cache(self):
        """Reset the KV cache between sequences"""
        if self.cache_k is not None and self.cache_v is not None:
            self.cache_k.zero()
            self.cache_v.zero()

In [7]:
import torch

batch_size = 2
seq_len = 10
dim = 32
num_heads = 4
num_kv_heads = 2
window_size = 5

model = AttentionWithKVCache(dim=dim, num_heads=num_heads, num_kv_heads=num_kv_heads, window_size=window_size)
model.train()
model=model.to(device)
x = torch.randn(batch_size, seq_len, dim)
x=x.to(device)
out = model(x)

print(out.shape)
# Should print: (2, 10, 32)

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([2, 10, 32])


## Adding Rolling buffer cache to gqa,sliding window,kv cache

In [12]:
class RollingBufferCache:
    def __init__(self,max_size:int,dim:int):
        self.max_size=max_size
        self.dim=dim
        self.cache=[]
        
    
    def insert(self,new_data):
        if len(self.cache)>=self.max_size:
            self.cache.pop(0)
        self.cache.append(new_data)
        #self.index=(self.index+1) % self.max_size
    
    def get_cache(self):
        return self.cache


max_size=5
dim=3
cache=RollingBufferCache(max_size,dim)

for i in range(7):
    data=torch.randn(dim)
    cache.insert(data)
    print(f"After inserting data {i+1}:")
    print(cache.get_cache())
        

After inserting data 1:
[tensor([ 1.5755, -2.1257, -1.3995])]
After inserting data 2:
[tensor([ 1.5755, -2.1257, -1.3995]), tensor([ 0.2451, -1.1397,  0.7424])]
After inserting data 3:
[tensor([ 1.5755, -2.1257, -1.3995]), tensor([ 0.2451, -1.1397,  0.7424]), tensor([-1.8523, -0.0405, -0.0153])]
After inserting data 4:
[tensor([ 1.5755, -2.1257, -1.3995]), tensor([ 0.2451, -1.1397,  0.7424]), tensor([-1.8523, -0.0405, -0.0153]), tensor([-0.4123, -1.3378, -1.3015])]
After inserting data 5:
[tensor([ 1.5755, -2.1257, -1.3995]), tensor([ 0.2451, -1.1397,  0.7424]), tensor([-1.8523, -0.0405, -0.0153]), tensor([-0.4123, -1.3378, -1.3015]), tensor([0.6482, 1.3867, 0.8134])]
After inserting data 6:
[tensor([ 0.2451, -1.1397,  0.7424]), tensor([-1.8523, -0.0405, -0.0153]), tensor([-0.4123, -1.3378, -1.3015]), tensor([0.6482, 1.3867, 0.8134]), tensor([ 1.0984, -0.2068, -0.2628])]
After inserting data 7:
[tensor([-1.8523, -0.0405, -0.0153]), tensor([-0.4123, -1.3378, -1.3015]), tensor([0.6482, 1

In [None]:
import math
import torch.nn.functional as F
class AttentionWithKVCache(nn.Module):
    def __init__(self, dim: int, num_heads: int, window_size: int, max_seq_len: int = 2048, num_kv_heads:int=2):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.num_kv_heads=num_kv_heads
        self.max_seq_len = max_seq_len
        self.repeats=num_heads//num_kv_heads
        self.window_size=window_size
        self.half_window=self.window_size//2
        
        # Projection layers
        self.W_q = nn.Linear(dim,self.num_heads*self.head_dim)
        self.W_k = nn.Linear(dim, self.num_kv_heads*self.head_dim)
        self.W_v = nn.Linear(dim, self.num_kv_heads*self.head_dim)
        self.W_o = nn.Linear(dim, dim)

        # Initialize KV cache
        # self.register_buffer('cache_k', torch.zeros(
        #     (1, max_seq_len, self.num_kv_heads, self.head_dim)  # batch=1 for simplicity
        # ))
        # self.register_buffer('cache_v', torch.zeros(
        #     (1, max_seq_len, self.num_kv_heads, self.head_dim)
        # ))
        self.register_buffer('cache_k', torch.zeros((max_seq_len, self.num_kv_heads, self.head_dim)))
        self.register_buffer('cache_v', torch.zeros((max_seq_len, self.num_kv_heads, self.head_dim)))
        
        # self.cache_k=None
        # self.cache_v=None
        self.cache_pos=0
    
    def update_cache(self,seq_len,k,v):
        seq_len=k.size(1)
        
        if self.cache_pos + seq_len > self.max_seq_len: #check if cache has enough space
            #roll the cache to make space
            roll_amount=seq_len
            self.cache_k=torch.roll(self.cache_k,shifts=-roll_amount,dim=0)
            self.cache_v=torch.roll(self.cache_v,shifts=-roll_amount,dim=0)
            self.cache_pos-=roll_amount
        
        self.cache_k[self.cache_pos:self.cache_pos+seq_len]=k.squeeze(0)
        self.cache_v[self.cache_pos:self.cache_pos+seq_len]=v.squeeze(0)
        self.cache_pos+=seq_len

    def forward(self, x: torch.Tensor, start_pos: int = 0):
        batch_size, seq_len, _ = x.shape
        
        # Project queries, keys, values
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)


        if self.training:
            # Training mode - full attention
            q=q.transpose(1,2)
            k=k.transpose(1,2)
            v=v.transpose(1,2)
            k=torch.repeat_interleave(k,dim=1,repeats=self.repeats) #batch_size,num_heads,seq_len,head_dim
            v=torch.repeat_interleave(v,dim=1,repeats=self.repeats)#batch_size,num_heads,seq_len,head_dim
           
            pad_k=F.pad(k,(0,0,self.half_window,self.half_window))
            pad_v=F.pad(v,(0,0,self.half_window,self.half_window))

            k_unf=pad_k.unfold(dimension=2,size=self.window_size,step=1).transpose(3,4) #(batch_size,num_heads,seq_len,self.window_size,d_head)
            v_unf=pad_v.unfold(dimension=2,size=self.window_size,step=1).transpose(3,4) #(batch_size,num_heads,seq_len,self.window_size,d_head)
            q=q.unsqueeze(-2)
            
            attn_scores=einops.einsum(q,k_unf,'b h s w d, b h s w d -> b h s w ')
            
            mask=torch.tril(torch.ones(self.window_size,self.window_size,device=device))
            mask=mask[-1]
            mask=mask.view(1,1,1,self.window_size).expand(batch_size,self.num_heads,seq_len,self.window_size)
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
            
            attn_weights = F.softmax(attn_scores, dim=-1)
            output=einops.einsum(attn_weights,v_unf,'b h s w, b h s w d ->b h s d')
            
        else:
            #batch_size must be 1 for inference
            assert batch_size==1, "batch size must be 1"
            # Inference mode - use KV cache
            # Update cache
            self.update_cache(seq_len,k,v)
            current_len=min(self.cache_pos,self.max_seq_len)
            valid_cache_len=min(current_len,self.window_size)
            start_window=max(0,current_len-seq_len-self.half_window)
                
            if self.cache_k is None or self.cache_v is None:
                self.cache_k=torch.zeros((self.max_seq_len,self.num_kv_heads,self.head_dim))
                self.cache_v=torch.zeros((self.max_seq_len,self.num_kv_heads,self.head_dim))
            
            cached_k=self.cache_k[start_window:current_len].unsqueeze(0)
            cached_v=self.cache_v[start_window:current_len].unsqueeze(0)
            
 
            
            print(k.shape)
            print(self.cache_k.shape)
            

            q=q.transpose(1,2)
            cached_k=cached_k.transpose(1,2)
            cached_v=cached_v.transpose(1,2)            # Compute attention
            cached_k=cached_k.repeat_interleave(self.repeats,dim=1)
            cached_v=cached_v.repeat_interleave(self.repeats,dim=1)
           
            attn_scores = torch.matmul(q, cached_k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            
            # Apply causal mask for new tokens
        
            mask = torch.ones((seq_len, valid_cache_len), dtype=torch.bool, device=x.device)
            mask = torch.tril(mask, diagonal=valid_cache_len - seq_len)

            # Expand mask to (batch_size, num_heads, seq_len, valid_cache_len)
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1,1,seq_len,valid_cache_len)
            mask = mask.expand(batch_size, model.num_heads, seq_len, valid_cache_len)

            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
            attn_probs = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_probs, cached_v)

        # Output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(output)

    def reset_cache(self):
        """Reset the KV cache between sequences"""
        self.cache_k.zero_()
        self.cache_v.zero_()
        self.cache_pos = 0

In [33]:
import torch
import torch.nn as nn


def test_attention_inference_loop():
    # Setup parameters
    dim = 16
    num_heads = 4
    num_kv_heads = 2
    window_size = 6
    max_seq_len = 20

    # Create the model
    model = AttentionWithKVCache(
        dim=dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        window_size=window_size,
        max_seq_len=max_seq_len
    )

    model.eval()  # Important! Set to inference mode
    model.reset_cache()
    batch_size = 1
    seq_len = 3
    input_tensor = torch.randn(batch_size, seq_len, dim)

    # Save old cache state
    old_cache_pos = model.cache_pos

    # Forward pass
    output = model(input_tensor)

    # Check output shape
    assert output.shape == (batch_size, seq_len, dim), "Output shape mismatch"

    # Check cache update
    assert model.cache_pos == old_cache_pos + seq_len, "Cache position not updated correctly"
    assert torch.all(model.cache_k[old_cache_pos:old_cache_pos+seq_len] != 0), "Keys not updated in cache"
    assert torch.all(model.cache_v[old_cache_pos:old_cache_pos+seq_len] != 0), "Values not updated in cache"

    # Additional checks
    # Test that values outside updated range are still zero
    if old_cache_pos > 0:
        assert torch.all(model.cache_k[0:old_cache_pos] == 0) or torch.all(model.cache_v[0:old_cache_pos] == 0), "Previous cache values modified incorrectly"

    # Make another forward pass to check rolling
    input_tensor2 = torch.randn(batch_size, seq_len, dim)
    output2 = model(input_tensor2)

    # Output2 shape
    assert output2.shape == (batch_size, seq_len, dim), "Second output shape mismatch after second inference"

In [34]:
test_attention_inference_loop()

torch.Size([1, 3, 2, 4])
torch.Size([20, 2, 4])
torch.Size([1, 3, 2, 4])
torch.Size([20, 2, 4])


In [35]:
import torch

batch_size = 2
seq_len = 10
dim = 32
num_heads = 4
num_kv_heads = 2
window_size = 5

model = AttentionWithKVCache(dim=dim, num_heads=num_heads, num_kv_heads=num_kv_heads, window_size=window_size)
model.train()
model=model.to(device)
x = torch.randn(batch_size, seq_len, dim)
x=x.to(device)
out = model(x)

print(out.shape)
# Should print: (2, 10, 32)

torch.Size([2, 10, 32])


## Sparse mixture of experts

In [38]:
class SwiGLUFFN(nn.Module):
    # a single expert
    def __init__(self,input_dim,hidden_dim):
        super().__init__()
        self.w1=nn.Linear(input_dim,hidden_dim)
        self.w2=nn.Linear(input_dim,hidden_dim)
        self.out=nn.Linear(hidden_dim,input_dim)
    
    def forward(self,x):
        return self.out(self.w1(x) * F.silu(self.w2(x)))



In [45]:
x=torch.randn(2*10,32)

torch.topk(x,k=2,dim=-1)[0].shape


torch.Size([20, 2])

In [48]:
class SparseMOE(nn.Module):
    def __init__(self,d_model,d_hidden,num_experts=8,top_k=2):
        super().__init__()
        self.num_experts=num_experts
        self.top_k=top_k
        
        self.experts=nn.ModuleList(
            [
                SwiGLUFFN(input_dim=d_model,hidden_dim=d_hidden)
                for _ in range(num_experts)
            ]
        )
        
        self.router=nn.Linear(input_dim,num_experts)
    
    
    def forward(self,x):
        
        batch_size,seq_len,d_model=x.shape
        x_flat=x.view(-1,d_model) # (batch_size * seq_len, d_model)
        
        
        #Step 1: get router scores for each token
        router_logits=self.router(x_flat)
        router_probs=F.softmax(router_logits,dim=-1)
        
        #Step 2: get top-k experts
        topk_probs,topk_indices=torch.topk(router_probs,self.top_k,dim=-1) #(batch_size*seq_len, top_k)
        
        #Step 3: Compute outputs from selected experts
        expert_outputs=[]
        for i in range(self.top_k):
            expert_idx=topk_indices[:,i] 
            outputs=torch.zeros_like(x_flat)
            
            for expert_id in range(self.num_experts):
                mask=(expert_id==expert_idx)
                if mask.any():
                    selected_x=x_flat[mask]
                    expert_out=self.experts[expert_id](selected_x)
                    outputs[mask]=expert_out
            
            weighted_output = topk_probs[:, i].unsqueeze(-1) * outputs
            expert_outputs.append(weighted_output)

        
        # 4. Sum the expert outputs
        final_output = sum(expert_outputs)

        # 5. Reshape back to (batch_size, seq_len, d_model)
        final_output = final_output.view(batch_size, seq_len, d_model)
        
        router_probs_mean = router_probs.mean(dim=0)
        load_balancing_loss = (router_probs_mean * router_probs_mean).sum() * self.num_experts

        return final_output, load_balancing_loss
    
#final_loss = task_loss + router_loss_weight * router_loss

## ROPE

In [None]:
import torch
def precompute_theta_pos_frequencies(d_head,seq_len,theta=1000.0):
    assert d_head %2==0
    theta_nr=torch.arange(0,head_dim,2).float()
    theta=1.0 / ( theta**(theta_nr/d_head))
    m=torch.arange(seq_len)
    
    freqs=torch.outer(m,theta).float()
    
    freqs_complex=torch.polar(torch.ones_like(freqs),freqs)
    
    return freqs_complex


def apply_rotary_embeddings(x,freqs_complex):
    x_complex=torch.view()
    
    
    
    