In [1]:
import torch
from torch import nn
from torch.nn import functional as F

  cpu = _conversion_method_template(device=torch.device("cpu"))


# KV Cache Attention

In [2]:
C = 2
B = 3

Wk = nn.Linear(C, C).float()
Wq = nn.Linear(C, C).float()
Wv = nn.Linear(C, C).float()

cached_K = torch.empty((B,0,C), dtype=torch.int64)
cached_V = torch.empty((B,0,C), dtype=torch.int64)
print(cached_K)

# a = torch.randn(B, 1, C)

# print(a.size())
# cached_K = torch.cat((cached_K,a), dim=1) # add a tensor to the cache
# cached_V = torch.cat((cached_V,a), dim=1)
# print(cached_K.size())

tensor([], size=(3, 0, 2), dtype=torch.int64)


In [3]:
# simulate the prefill phase, three requests arrive...

T = 3
X = torch.randn(B, T, C)

Q = Wq(X)
K = Wk(X)
V = Wv(X)

# store the K and V in cache
cached_K = torch.cat((cached_K,K), dim=1) # add a tensor to the cache
cached_V = torch.cat((cached_V,V), dim=1)
print(cached_K.size())

# calculate the attention score as always...
KQ = (
    Q @ K.transpose(-2, -1) * C**-0.5
)  # (B, T, C) @ (B, C, T) -> (B, T, T)
KQ = KQ.masked_fill(
    torch.tril(
        torch.ones(
            T, T, dtype=float
        )
    )
    == 0,
    float("-inf"),
)  # (B, T, T)
KQ = F.softmax(KQ, dim=-1)  # (B, T, T)

# perform the weighted aggregation of the values

O = KQ @ V  # (B, T, T) @ (B, T, C) -> (B, T, C)
print(O.size())

torch.Size([3, 3, 2])
torch.Size([3, 3, 2])


In [4]:
# Now we simulate one autoregressive step...

x_new = torch.randn(B, 1, C)
X_new = torch.cat((X, x_new), dim=1)
print(X_new.size())
x_new = X_new[:, -1:, :]
print(x_new.shape)

q_new = Wq(x_new)
k_new = Wk(x_new)
v_new = Wv(x_new)

q_new.shape

torch.Size([3, 4, 2])
torch.Size([3, 1, 2])


torch.Size([3, 1, 2])

In [5]:
cached_K = torch.cat((cached_K,k_new), dim=1)
cached_V = torch.cat((cached_V,v_new), dim=1)

print(cached_K.size())

torch.Size([3, 4, 2])


In [6]:
KQ = (
    q_new @ cached_K.transpose(-2, -1) * C**-0.5
)
KQ = F.softmax(KQ, dim=-1)
print(KQ.size())

O = KQ @ cached_V
O.size()

torch.Size([3, 1, 4])


torch.Size([3, 1, 2])

In [7]:
# Keep simulating the decoding phase

# cached_K = torch.empty((0,C), dtype=torch.int64)
# cached_V = torch.empty((0,C), dtype=torch.int64)

n_iters = 3
X = torch.randn(B, n_iters, C)

for i in range(n_iters):
    
    x_new = X[:, -1:, :]
    print(f"\n\nIter {i}: {x_new.size()}")

    q_new = Wq(x_new)
    k_new = Wk(x_new)
    v_new = Wv(x_new)

    cached_K = torch.cat((cached_K,k_new), dim=1)
    cached_V = torch.cat((cached_V,v_new), dim=1)

    print(f"cache size: 2 x {cached_K.size()}")

    kq = (
        q_new @ cached_K.transpose(-2, -1) * C**-0.5
    ) # [T+i+1, T+i+1]

    kq = kq.masked_fill(
        torch.tril(
            torch.ones(
                cached_K.size()[1], cached_K.size()[1], dtype=float
            )
        )
        == 0,
        float("-inf"),
    )
    kq = F.softmax(kq, dim=-1)

    out = kq @ cached_V
    print(out.size())



Iter 0: torch.Size([3, 1, 2])
cache size: 2 x torch.Size([3, 5, 2])
torch.Size([3, 5, 2])


Iter 1: torch.Size([3, 1, 2])
cache size: 2 x torch.Size([3, 6, 2])
torch.Size([3, 6, 2])


Iter 2: torch.Size([3, 1, 2])
cache size: 2 x torch.Size([3, 7, 2])
torch.Size([3, 7, 2])


In [19]:
class KVCache():
    def __init__(self, batch_size, max_tokens, num_heads, head_dim):
        self.kv_cache = torch.empty(
            2, batch_size, max_tokens, num_heads, head_dim,
        )
        self.max_tokens = max_tokens
        self.cumulative_length = 0

    def update(self, k, v):

        start = int(self.cumulative_length)
        end = start + k.size(-3)
        if end > self.max_tokens:
            raise ValueError("KVCache overflow: increase max_tokens")
        
        self.kv_cache[0, :, start:end, :, :] = k
        self.kv_cache[1, :, start:end, :, :] = v

        self.cumulative_length += k.size(-3)

        return self.kv_cache[0, :, :self.cumulative_length, :, :], self.kv_cache[1, :, :self.cumulative_length, :, :]
    
    def reset(self):
        """Clear the cache (start a new sequence)."""
        self.cumulative_length = 0


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product for causal attention.
    """

    def __init__(self, embed_dim: int, num_heads: int, attn_pdrop: float = 0.1, use_cache=False):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        
        self.use_cache = use_cache
        self.kv_cache = None

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.Wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        if self.use_cache:
            if self.kv_cache is None:
                # Prefill: full sequence
                self.kv_cache = KVCache(batch_size, 2048, self.num_heads, self.head_dim)
            else:
                # Decoding: only last token
                x = x[:, -1:, :]           # (B, 1, C)

    
        q = self.Wq(x).view(batch_size, -1, self.num_heads, self.head_dim)
        k = self.Wk(x).view(batch_size, -1, self.num_heads, self.head_dim)
        v = self.Wv(x).view(batch_size, -1, self.num_heads, self.head_dim)

        if self.use_cache:
            k, v = self.kv_cache.update(k, v)   # (B, T_total, nH, H)


        q = q.permute(0, 2, 1, 3)  # (B, nH, T, H)
        k = k.permute(0, 2, 1, 3)  # (B, nH, T, H)
        v = v.permute(0, 2, 1, 3)  # (B, nH, T, H)

        kq = (q @ k.transpose(-2, -1)) / (self.head_dim**0.5)  # (B, nH, T, T)

        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=kq.device, dtype=torch.bool), diagonal=1
        ) # (T, T)
        kq = kq.masked_fill(mask, float("-inf"))
        att = F.softmax(kq, dim=-1)

        o = att @ v

        o = o.permute(0, 2, 1, 3).contiguous()  # (B, T, nH, H)
        o = o.view(batch_size, seq_len, embed_dim)  # concat heads
        o = self.Wo(o)

        return o

B = 3
C = 8
nH = 2
T = 5

x = torch.randn(B, T, C)
multihead = MultiHeadAttention(C, nH, use_cache=1)  # 2 heads with 4 dimensions
multihead(x).size()

torch.Size([3, 5, 8])

In [20]:
B = 3
C = 8
nH = 2
T = 5
M = 2048
H = 4

multihead = MultiHeadAttention(C, nH, use_cache=1)
X = torch.randn(3, 3, C)

out = multihead(X)
X = torch.cat((X, out[:, -1:, :]), dim=1)
print('Prefill completed!')
for i in range(3):
    print(f'\nDecoding token {i+1}...')
    out = multihead(X)
    X = torch.cat((X, out[:, -1:, :]), dim=1)
    


Prefill completed!

Decoding token 1...

Decoding token 2...

Decoding token 3...


In [24]:
# ---------------------------
# Benchmark
# ---------------------------
import time

B, T, C, nH = 3, 3, 4, 2
x = torch.randn(B, T, C)

att_no_cache = MultiHeadAttention(C, nH, use_cache=False)
att_cache    = MultiHeadAttention(C, nH, use_cache=True)

# Measure
def bench(model, x, name):
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    t0 = time.time()
    for _ in range(1000):
        out = model(x)
        x = torch.cat((x, out[:, -1:, :]), dim=1)
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    print(f"{name}: { (time.time() - t0)*1000:.2f} ms")

bench(att_no_cache, x, "No cache")
bench(att_cache, x, "With cache")

No cache: 1388.83 ms
With cache: 1046.34 ms


In [None]:
# Many implementation  of KV-cache do not reflect how code is writen in production
# https://github.com/rasbt/LLMs-from-scratch/blob/main/ch04/03_kv-cache/gpt_with_kv_cache.py

import torch
import time

def grow_kv_cat(num_steps=2000, hidden_dim=1024, device="cpu"):
    """
    Naive KV growth using torch.cat.
    Every step:
       kv = torch.cat([kv, new_token], dim=0)
    which forces:
       - allocation of a new (seq_len+1, hidden_dim) tensor
       - full copy of old kv into it
    => O(T^2) memory copying
    """
    kv = torch.zeros((1, hidden_dim), device=device)
    byteqs_copied = 0
    bytes_written = 0

    if device == "cuda": torch.cuda.synchronize()
    start = time.perf_counter()

    for step in range(1, num_steps + 1):
        # new token row (1, hidden_dim)
        new_row = torch.full((1, hidden_dim), float(step), device=device)

        # bytes copied = size of old_kv
        bytes_copied += kv.numel() * kv.element_size()

        kv = torch.cat([kv, new_row], dim=0)

        # writing the new token row
        bytes_written += new_row.numel() * new_row.element_size()

    if device == "cuda": torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    return elapsed, bytes_copied, bytes_written


def prealloc_kv(num_steps=2000, hidden_dim=1024, device="cpu"):
    """
    Preallocation:
       kv = empty(max_len, hidden_dim)
    Just write into kv[step].
    => O(T) work, no copies of old KV.
    """
    max_len = num_steps + 1
    kv = torch.empty((max_len, hidden_dim), device=device)

    bytes_copied = 0
    bytes_written = 0

    if device == "cuda": torch.cuda.synchronize()
    start = time.perf_counter()

    for step in range(1, num_steps + 1):
        kv[step] = step
        bytes_written += hidden_dim * kv.element_size()

    if device == "cuda": torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    return elapsed, bytes_copied, bytes_written


if __name__ == "__main__":
    num_steps = 4000
    hidden_dim = 2048
    device = "cuda"  # or "cpu"

    t_cat, bytes_cat_copy, bytes_cat_write = grow_kv_cat(num_steps, hidden_dim, device)
    t_pre, bytes_pre_copy, bytes_pre_write = prealloc_kv(num_steps, hidden_dim, device)

    print(f"Naive grow-KV using torch.cat:")
    print(f"  time               : {t_cat:.4f} s")
    print(f"  bytes COPIED       : {bytes_cat_copy / 1e9:.3f} GB  (O(T^2))")
    print(f"  bytes WRITTEN      : {bytes_cat_write / 1e9:.3f} GB")

    print(f"\nPreallocated KV:")
    print(f"  time               : {t_pre:.4f} s")
    print(f"  bytes COPIED       : {bytes_pre_copy / 1e9:.3f} GB")
    print(f"  bytes WRITTEN      : {bytes_pre_write / 1e9:.3f} GB")

    print(f"\nSpeedup (prealloc vs cat): {t_cat / t_pre:.2f}x")


Naive grow-KV using torch.cat:
  time               : 0.4722 s
  bytes COPIED       : 65.552 GB  (O(T^2))
  bytes WRITTEN      : 0.033 GB

Preallocated KV:
  time               : 0.0252 s
  bytes COPIED       : 0.000 GB
  bytes WRITTEN      : 0.033 GB

Speedup (prealloc vs cat): 18.73x


# PagedAttention

In [26]:
import torch
from torch import nn
from torch.nn import functional as F

B = 3
C = 4
nH = 2
H = 2

num_blocks = 4
block_size = 5

kv_cache = torch.zeros((2, num_blocks, block_size, nH, H))  # 2 x (num_blocks, block_size, nH, H)
block_table = {i: [] for i in range(B)}
free_blocks = set(range(num_blocks))
block_table, free_blocks


({0: [], 1: [], 2: []}, {0, 1, 2, 3})

In [28]:
# simulate the prefill phase, three requests arrive...

T = 3
Wk = nn.Linear(C, C).float()
Wq = nn.Linear(C, C).float()
Wv = nn.Linear(C, C).float()

X = torch.randn(B, T, C)

K = Wk(X)  # (B,T,C)
V = Wv(X)  # (B,T,C)
Q = Wq(X)  # (B,T,C)

# compute attention as always
KQ = (
    Q @ K.transpose(-2, -1) * C**-0.5
)  # (B, T, C) @ (B, C, T) -> (B, T, T)
KQ = KQ.masked_fill(
    torch.tril(
        torch.ones(
            T, T, dtype=float
        )
    )
    == 0,
    float("-inf"),
)  # (B, T, T)
KQ = F.softmax(KQ, dim=-1)  # (B, T, T)
V = Wv(X)  # (B,T,C)
O = KQ @ V  # (B, T, T) @ (B, T, C) -> (B, T, C)

In [30]:
# DISCLAIMER!
# For didactive reason we'll set some values of the K matrix to be zero.
# This is meant to visualize padding in case of input sequences with 
# different legths. 
# In this way it is more intuitive seing how KV blocks are allocated.

# choose unique counts of 1s for each request
lengths = torch.tensor([1, 2, 5])

for i, L in enumerate(lengths):
    K[i, L:] = 0  # pad with zeros for vid√¨sualization 


print(K.size())
K

torch.Size([3, 3, 4])


tensor([[[-0.1408, -0.9432,  0.3591, -1.0831],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1292, -0.8815,  0.0480, -0.7154],
         [ 0.4501, -1.0432, -0.0692, -0.5334],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.7015, -1.1552,  0.1455, -0.8109],
         [ 0.1893,  0.7872, -0.1970,  0.2375],
         [-0.1401, -1.3495,  0.5314, -1.3789]]], grad_fn=<AsStridedBackward0>)

In [34]:
block_table = {i: [] for i in range(B)}
free_blocks = set(range(num_blocks))
req_lens = (K != 0).any(dim=-1).sum(dim=1)


# allocate in blocks
for b in range(B):
    t = 0
    while t < req_lens[b].item(): #K.size(1):
        # If there's a last block with free space, fill that first
        if block_table[b] and block_table[b][-1][1] < block_size:
            block_id, filled = block_table[b][-1]
        else:
            # Need a new block
            if not free_blocks:
                raise RuntimeError("No more free blocks. Implement eviction/preemption here.")
            block_id = free_blocks.pop()
            filled = 0
            block_table[b].append([block_id, 0])  # store mutable [block_id, filled]

        # take = min(block_size - filled, K.size(1) - t)
        take = min(block_size - filled, req_lens[b].item() - t)
        kv_cache[0, block_id, filled:filled+take, :, :] = K.view(B, T, nH, H)[b, t:t+take, :, :]
        kv_cache[1, block_id, filled:filled+take, :, :] = V.view(B, T, nH, H)[b, t:t+take, :, :]

        # Update filled count in block_table
        block_table[b][-1][1] = filled + take

        t += take

print(block_table, free_blocks)

kv_cache[0].view(4, 5, 4)

{0: [[0, 1]], 1: [[1, 2]], 2: [[2, 3]]} {3}


tensor([[[-0.1408, -0.9432,  0.3591, -1.0831],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1292, -0.8815,  0.0480, -0.7154],
         [ 0.4501, -1.0432, -0.0692, -0.5334],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.7015, -1.1552,  0.1455, -0.8109],
         [ 0.1893,  0.7872, -0.1970,  0.2375],
         [-0.1401, -1.3495,  0.5314, -1.3789],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<ViewBackward0>)

In [35]:
# now one autoregressive step....

x_new = torch.randn(B, 1, C)
X_new = torch.cat((X, x_new), dim=1)
print(f'X_new.size(): {X_new.size()}')
x_new = X_new[:, -1:, :]
req_lens = (x_new != 0).any(dim=-1).sum(dim=1)
print(f'req_lens: {req_lens}')

q_new = Wq(x_new)
k_new = Wk(x_new)
v_new = Wv(x_new)

for b in range(B):
    t = 0
    while t < req_lens[b].item():
        # If there's a last block with free space, fill that first
        if block_table[b] and block_table[b][-1][1] < block_size:
            block_id, filled = block_table[b][-1]
        else:
            # Need a new block
            if not free_blocks:
                raise RuntimeError("No more free blocks. Implement eviction/preemption here.")
            block_id = free_blocks.pop()
            filled = 0
            block_table[b].append([block_id, 0])  # store mutable [block_id, filled]

        take = min(block_size - filled, req_lens[b].item() - t)
        kv_cache[0, block_id, filled:filled+take, :, :] = K.view(B, T, nH, H)[b, t:t+take, :, :]
        kv_cache[1, block_id, filled:filled+take, :, :] = V.view(B, T, nH, H)[b, t:t+take, :, :]

        # Update filled count in block_table
        block_table[b][-1][1] = filled + take

        t += take
    
print(block_table, free_blocks)
kv_cache[0].view(4, 5, 4)

X_new.size(): torch.Size([3, 4, 4])
req_lens: tensor([1, 1, 1])
{0: [[0, 2]], 1: [[1, 3]], 2: [[2, 4]]} {3}


tensor([[[-0.1408, -0.9432,  0.3591, -1.0831],
         [-0.1408, -0.9432,  0.3591, -1.0831],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1292, -0.8815,  0.0480, -0.7154],
         [ 0.4501, -1.0432, -0.0692, -0.5334],
         [ 0.1292, -0.8815,  0.0480, -0.7154],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.7015, -1.1552,  0.1455, -0.8109],
         [ 0.1893,  0.7872, -0.1970,  0.2375],
         [-0.1401, -1.3495,  0.5314, -1.3789],
         [-0.7015, -1.1552,  0.1455, -0.8109],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<ViewBackward0>)

In [47]:
# Fetch dense K/V for each sequence
# Find max length across batch (for padding to common T_total)
max_len = max([sum(f for _, f in x) for x in block_table.values()])

# these two matrices will allow to reuse our attention code
# k_full = K.new_zeros(B, nH, max_len, H)
# v_full = V.new_zeros(B, nH, max_len, H)
k_full = K.new_zeros(B, max_len, nH, H)
v_full = V.new_zeros(B, max_len, nH, H)

for b in range(B):
    cur = 0
    for block_id, filled in block_table[b]:
        assert filled
        
        # kv_cache: (2, num_blocks, block_size, nH, H)
        k_block = kv_cache[0, block_id, :filled]  # (filled, nH, H)
        v_block = kv_cache[1, block_id, :filled]  # (filled, nH, H)

        k_full[b,  cur:cur+filled, :, :] = k_block
        v_full[b,  cur:cur+filled, :, :] = v_block
        cur += filled

print(B, max_len, nH*H)
k_full.view(B, max_len, nH*H)


3 4 4


tensor([[[-0.1408, -0.9432,  0.3591, -1.0831],
         [-0.1408, -0.9432,  0.3591, -1.0831],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1292, -0.8815,  0.0480, -0.7154],
         [ 0.4501, -1.0432, -0.0692, -0.5334],
         [ 0.1292, -0.8815,  0.0480, -0.7154],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.7015, -1.1552,  0.1455, -0.8109],
         [ 0.1893,  0.7872, -0.1970,  0.2375],
         [-0.1401, -1.3495,  0.5314, -1.3789],
         [-0.7015, -1.1552,  0.1455, -0.8109]]], grad_fn=<ViewBackward0>)

In [44]:
q_new.shape, k_full.shape

(torch.Size([3, 1, 4]), torch.Size([3, 4, 2, 2]))

In [50]:
kq = (q_new  @ k_full.view(B, max_len, nH*H).transpose(-2, -1)) / (C**0.5)  # (B, nH, T, T)

mask = torch.triu(
    torch.ones(1, 1, device=kq.device, dtype=torch.bool), diagonal=1
) # (T, T)
kq = kq.masked_fill(mask, float("-inf"))
att = F.softmax(kq, dim=-1)

o = att @ v_full.view(B, max_len, nH*H)
o.shape

torch.Size([3, 1, 4])

In [51]:
class PagedKVCache():
    def __init__(self, num_blocks, block_size, num_heads, head_dim):
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_table = None #{i: [] for i in range(B)}
        self.free_blocks = set(range(num_blocks))
        self.kv_cache = torch.empty((2, num_blocks, block_size, num_heads, head_dim)) # (2, nB, B, nH, H)

    def update(self, k, v):
        # allocate in blocks
        # batch_size, _, seq_len, _ = k.size()
        batch_size, seq_len, _, _ = k.size()
        if self.block_table is None:
            self.block_table = {i: [] for i in range(batch_size)}

        for b in range(batch_size):
            t = 0
            # while t < k.size(2):
            while t < k.size(1):
                # If there's a last block with free space, fill that first
                if self.block_table[b] and self.block_table[b][-1][1] < self.block_size:
                    block_id, filled = self.block_table[b][-1]
                else:
                    # Need a new block
                    if not self.free_blocks:
                        raise RuntimeError("No more free blocks. Implement eviction/preemption here.")
                    block_id = self.free_blocks.pop()
                    filled = 0
                    self.block_table[b].append([block_id, 0])  # store mutable [block_id, filled]

                # take = min(self.block_size - filled, k.size(2) - t)
                take = min(self.block_size - filled, k.size(1) - t)
                # print(k.shape)
                # print(batch_size, seq_len, self.num_heads, self.head_dim)
                self.kv_cache[0, block_id, filled:filled+take, :, :] = k.view(batch_size, seq_len, self.num_heads, self.head_dim)[b, t:t+take, :, :]
                self.kv_cache[1, block_id, filled:filled+take, :, :] = v.contiguous().view(batch_size, seq_len, self.num_heads, self.head_dim)[b, t:t+take, :, :]
                # self.kv_cache[0, block_id, filled:filled+take, :, :] = k.contiguous().view(batch_size, seq_len, self.num_heads, self.head_dim)[b, t:t+take, :, :]
                # self.kv_cache[1, block_id, filled:filled+take, :, :] = v.contiguous().view(batch_size, seq_len, self.num_heads, self.head_dim)[b, t:t+take, :, :]

                # Update filled count in block_table
                self.block_table[b][-1][1] = filled + take

                t += take

        # Fetch dense K/V for each sequence
        # Find max length across batch (for padding to common T_total)
        max_len = max([sum(f for _, f in x) for x in self.block_table.values()])

        # k_full = k.new_zeros(batch_size, self.num_heads, max_len, self.head_dim)
        # v_full = v.new_zeros(batch_size, self.num_heads, max_len, self.head_dim)

        k_full = k.new_zeros(batch_size, max_len, self.num_heads, self.head_dim)
        v_full = v.new_zeros(batch_size, max_len, self.num_heads, self.head_dim)

        for b in range(batch_size):
            cur = 0
            for block_id, filled in self.block_table[b]:
                assert filled
                
                # kv_cache: (2, num_blocks, block_size, nH, H)
                # slice: (filled, nH, H) -> permute to (nH, filled, H)
                k_block = self.kv_cache[0, block_id, :filled]  # (filled, nH, H)
                v_block = self.kv_cache[1, block_id, :filled]  # (filled, nH, H)

                # k_full[b, :, cur:cur+filled, :] = k_block.permute(1, 0, 2)
                # v_full[b, :, cur:cur+filled, :] = v_block.permute(1, 0, 2)

                # print(k_full[b, cur:cur+filled,:, :].shape, k_block.shape)

                k_full[b, cur:cur+filled,:, :] = k_block #.permute(1, 0, 2)
                v_full[b,  cur:cur+filled, :, :] = v_block#.permute(1, 0, 2)
                cur += filled

        return k_full, v_full

num_blocks, block_size = 5, 10

class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product for causal attention.
    """

    def __init__(self, embed_dim: int, num_heads: int, attn_pdrop: float = 0.1, use_cache=False):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        
        self.use_cache = use_cache
        self.kv_cache = None

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.Wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=False)

        self.dropout = nn.Dropout(attn_pdrop)

    def forward(self, x):
        if self.use_cache and self.kv_cache is not None:
            x = x[:, -1:, :]
        
        batch_size, seq_len, embed_dim = x.shape

        if self.use_cache:
            if self.kv_cache is None:
                # Prefill: full sequence
                self.kv_cache = PagedKVCache(num_blocks, block_size, self.num_heads, self.head_dim)
            else:
                # Decoding: only last token
                x = x[:, -1:, :]           # (B, 1, C)

    
        q = self.Wq(x).view(batch_size, -1, self.num_heads, self.head_dim)
        k = self.Wk(x).view(batch_size, -1, self.num_heads, self.head_dim)
        v = self.Wv(x).view(batch_size, -1, self.num_heads, self.head_dim)

        if self.use_cache:
            k, v = self.kv_cache.update(k, v)   # (B, T_total, nH, H)

        # print(k.shape)

        q = q.permute(0, 2, 1, 3)  # (B, nH, T, H)
        k = k.permute(0, 2, 1, 3)  # (B, nH, T, H)
        v = v.permute(0, 2, 1, 3)  # (B, nH, T, H)

        # q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        # k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        # v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # q = q.permute(0, 2, 1, 3)  # (B, nH, T, H)
        # k = k.permute(0, 2, 1, 3)  # (B, nH, T, H)
        # v = v.permute(0, 2, 1, 3)  # (B, nH, T, H)

        # if self.use_cache:
        #     if self.kv_cache is None:
        #         # Prefill: initialize kv cache
        #         self.kv_cache = PagedKVCache(num_blocks, block_size, self.num_heads, self.head_dim)
        #         k, v = self.kv_cache.update(k, v)  # 2 x (B, nH, T, H)
        #         print(self.kv_cache.block_table)
        #     else:
        #         # Decoding: add last token
        #         k, v = self.kv_cache.update(k[:, :, -1:, :], v[:, :, -1:, :])

        kq = (q @ k.transpose(-2, -1)) / (self.head_dim**0.5)  # (B, nH, T, T)

        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=kq.device, dtype=torch.bool), diagonal=1
        ) # (T, T)
        kq = kq.masked_fill(mask, float("-inf"))
        att = F.softmax(kq, dim=-1)

        att = self.dropout(att)

        o = att @ v

        o = o.permute(0, 2, 1, 3).contiguous()  # (B, T, nH, H)
        o = o.view(batch_size, seq_len, embed_dim)  # concat heads
        o = self.Wo(o)
        o = self.dropout(o)
        return o



B = 3
C = 4
nH = 2
T = 5
H = 2
multihead = MultiHeadAttention(C, nH, use_cache=1)  # 2 heads with 5 dimensions
X = torch.randn(B, T, C)

out = multihead(X)
print('Prefill completed!')
for i in range(3):
    print(f'\nDecoding token {i+1}...')
    out = multihead(X)
    X = torch.cat((X,out), dim=1)

Prefill completed!

Decoding token 1...

Decoding token 2...

Decoding token 3...
