In [3]:
import torch
import time

class StandardKVCache:
    """Fixed-capacity key/value cache stored at full precision.

    Two buffers of shape ``(n_heads, seq_len, head_dim)`` are allocated up
    front and filled left to right along the sequence axis (dim 1).

    Args:
        n_heads: Size of dim 0 of the buffers.
        seq_len: Capacity along dim 1, i.e. the maximum number of positions.
        head_dim: Size of dim 2 of the buffers.
        dtype: Storage dtype of both buffers.
        device: Device on which both buffers are allocated.
    """

    def __init__(self, n_heads, seq_len, head_dim, dtype=torch.float32, device="cpu"):
        self.cache_k = torch.zeros(n_heads, seq_len, head_dim, dtype=dtype, device=device)
        self.cache_v = torch.zeros(n_heads, seq_len, head_dim, dtype=dtype, device=device)
        self.n_heads = n_heads
        self.seq_len = seq_len
        self.head_dim = head_dim
        # Number of positions written so far; also the next write offset.
        self.cur_len = 0

    def append(self, k_new, v_new):
        """Write ``k_new``/``v_new`` directly after the positions already stored.

        Args:
            k_new: Keys whose dim 1 holds the positions to append.
            v_new: Values for the same positions as ``k_new``.

        Raises:
            RuntimeError: If the new positions do not fit in the remaining
                capacity. The cache is left unchanged in that case.
        """
        n = k_new.size(1)
        end = self.cur_len + n
        # Slices that run past the end of a tensor are truncated rather than
        # rejected, so an overflowing write has to be caught here instead of
        # relying on the slice assignment below to complain.
        if end > self.seq_len:
            raise RuntimeError(
                f"KV cache overflow: cannot append {n} position(s) at offset "
                f"{self.cur_len} with capacity {self.seq_len}"
            )
        self.cache_k[:, self.cur_len:end] = k_new
        self.cache_v[:, self.cur_len:end] = v_new
        self.cur_len = end

    def get(self):
        """Return ``(keys, values)`` views covering the first ``cur_len`` positions."""
        return self.cache_k[:, :self.cur_len], self.cache_v[:, :self.cur_len]

class QuantizedKVCache:
    """Fixed-capacity key/value cache stored as int8 with abs-max scaling.

    Every ``append`` call quantizes its chunk symmetrically: for each head the
    largest absolute value in the chunk is mapped to ``max_val``. The scale
    used for a chunk is recorded for each position it covers, so tokens
    appended in different calls are each dequantized with the scale they were
    quantized with.

    Attributes kept for backward compatibility:
        scale_k, scale_v: ``(n_heads, 1, 1)`` scales of the most recently
            appended chunk (all ones before the first append).

    Args:
        n_heads: Size of dim 0 of the buffers.
        seq_len: Capacity along dim 1, i.e. the maximum number of positions.
        head_dim: Size of dim 2 of the int8 buffers.
        device: Device on which all buffers are allocated.
    """

    def __init__(self, n_heads, seq_len, head_dim, device="cpu"):
        self.n_heads = n_heads
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.max_val = 127  # largest magnitude used in the int8 code range
        self.device = device
        # Scale of the latest chunk, one value per head.
        self.scale_k = torch.ones(n_heads, 1, 1, device=device)
        self.scale_v = torch.ones(n_heads, 1, 1, device=device)
        # Scale that each stored position was quantized with. Keeping these
        # per position is what makes incremental appends dequantize correctly.
        self.scales_k = torch.ones(n_heads, seq_len, 1, device=device)
        self.scales_v = torch.ones(n_heads, seq_len, 1, device=device)
        self.cache_k = torch.zeros(n_heads, seq_len, head_dim, dtype=torch.int8, device=device)
        self.cache_v = torch.zeros(n_heads, seq_len, head_dim, dtype=torch.int8, device=device)
        # Number of positions written so far; also the next write offset.
        self.cur_len = 0

    def _quantize(self, x):
        """Return ``(int8 codes, per-head scale)`` for one chunk ``x``."""
        # clamp_min keeps an all-zero head from dividing by zero.
        scale = x.abs().amax(dim=(1, 2), keepdim=True).clamp_min(1e-6)
        codes = (x / scale * self.max_val).round().clamp(-self.max_val, self.max_val)
        return codes.to(torch.int8), scale

    def append(self, k_new, v_new):
        """Quantize ``k_new``/``v_new`` and store them after the existing positions.

        Args:
            k_new: Keys whose dim 1 holds the positions to append.
            v_new: Values for the same positions as ``k_new``.

        Raises:
            RuntimeError: If the new positions do not fit in the remaining
                capacity. The cache is left unchanged in that case.
        """
        n = k_new.size(1)
        if n == 0:
            # Nothing to store; also avoids reducing over an empty chunk.
            return
        end = self.cur_len + n
        # Slices that run past the end of a tensor are truncated rather than
        # rejected, so an overflowing write has to be caught here.
        if end > self.seq_len:
            raise RuntimeError(
                f"KV cache overflow: cannot append {n} position(s) at offset "
                f"{self.cur_len} with capacity {self.seq_len}"
            )

        qk, k_scale = self._quantize(k_new)
        qv, v_scale = self._quantize(v_new)

        self.cache_k[:, self.cur_len:end] = qk
        self.cache_v[:, self.cur_len:end] = qv
        # The (n_heads, 1, 1) chunk scale broadcasts over the chunk's positions.
        self.scales_k[:, self.cur_len:end] = k_scale
        self.scales_v[:, self.cur_len:end] = v_scale
        self.scale_k = k_scale
        self.scale_v = v_scale
        self.cur_len = end

    def get(self):
        """Return dequantized float32 ``(keys, values)`` for the stored positions."""
        filled = slice(0, self.cur_len)
        k = self.cache_k[:, filled].float() * self.scales_k[:, filled] / self.max_val
        v = self.cache_v[:, filled].float() * self.scales_v[:, filled] / self.max_val
        return k, v

def print_memory(tensor, label):
    """Print the element count, per-element size and total size of ``tensor``.

    The total is reported in decimal megabytes (bytes / 1e6), prefixed by
    ``label``.
    """
    count, width = tensor.numel(), tensor.element_size()
    megabytes = (count * width) / 1e6
    report = "{}: {} elements, {} bytes/elem, {:.2f} MB".format(label, count, width, megabytes)
    print(report)

In [4]:
def _sync(device):
    """Wait for queued CUDA work to finish; does nothing for other devices."""
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()


def _timed_roundtrip(cache, k, v, device):
    """Append ``k``/``v`` to ``cache``, read them back, and time the round trip.

    Returns:
        ``((k_out, v_out), seconds)`` where ``seconds`` is the elapsed
        wall-clock time of ``append`` followed by ``get``.
    """
    # CUDA kernels are launched asynchronously, so the device is drained before
    # starting and before stopping the clock; otherwise only launch overhead
    # would be measured.
    _sync(device)
    start = time.perf_counter()
    cache.append(k, v)
    retrieved = cache.get()
    _sync(device)
    return retrieved, time.perf_counter() - start


def main():
    """Compare memory use, round-trip time and error of the two KV caches.

    Fills a float32 cache and an int8 cache with the same random keys and
    values, prints the size of each key buffer and the time taken to append
    and read back, then prints the mean absolute error introduced by
    quantization.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    n_heads = 16
    seq_len = 1024
    head_dim = 64

    # Random keys and values for one sequence: (n_heads, seq_len, head_dim).
    k = torch.randn(n_heads, seq_len, head_dim, device=device)
    v = torch.randn(n_heads, seq_len, head_dim, device=device)

    print("==> Standard (float32) KV Cache")
    std_cache = StandardKVCache(n_heads, seq_len, head_dim, device=device)
    _, elapsed = _timed_roundtrip(std_cache, k, v, device)
    print_memory(std_cache.cache_k, "Standard cache_k")
    print(f"Runtime: {elapsed:.4f}s")

    print("\n==> Quantized (int8) KV Cache")
    q_cache = QuantizedKVCache(n_heads, seq_len, head_dim, device=device)
    (kq_retr, vq_retr), elapsed = _timed_roundtrip(q_cache, k, v, device)
    print_memory(q_cache.cache_k, "Quantized cache_k")
    print(f"Runtime: {elapsed:.4f}s")

    # Approximation quality: mean absolute difference after the int8 round trip.
    error_k = torch.mean((k - kq_retr).abs()).item()
    error_v = torch.mean((v - vq_retr).abs()).item()
    print(f"\nQuantization Error: mean(|orig - dequant|) for k: {error_k:.6f}, for v: {error_v:.6f}")

# Entry-point guard: run the comparison only when this code executes as the
# top-level module, not when it is imported.
if __name__ == "__main__":
    main()


==> Standard (float32) KV Cache
Standard cache_k: 1048576 elements, 4 bytes/elem, 4.19 MB
Runtime: 0.0021s

==> Quantized (int8) KV Cache
Quantized cache_k: 1048576 elements, 1 bytes/elem, 1.05 MB
Runtime: 0.0181s

Quantization Error: mean(|orig - dequant|) for k: 0.008725, for v: 0.008838
