In [1]:
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return dataloader


with open("/kaggle/input/small-text-sample-txt/small-text-sample.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

tokenizer = tiktoken.get_encoding("gpt2")
encoded_text = tokenizer.encode(raw_text)

vocab_size = 50257
output_dim = 256
max_len = 1024
context_length = max_len


token_embedding_layer = nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

max_length = 4
dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=max_length)

In [2]:
for batch in dataloader:
    x, y = batch

    token_embeddings = token_embedding_layer(x)
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))

    input_embeddings = token_embeddings + pos_embeddings

    break

# Simple Multi-Head Attention

In [3]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self,x):
        b,n_tokens ,d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_query(x)
        # attention scores = QK^T
        attn_scores =  queries @ keys.transpose(1,2)
        attn_scores.masked_fill(self.mask.bool()[:n_tokens,:n_tokens],-torch.inf)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights@values
        return context_vec

class MultiHeadAttentionWrapper(nn.Module):
     def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
         super().__init__()
         self.heads = nn.ModuleList(SelfAttention(d_in,d_out,context_length,dropout) for _ in range(num_heads))
         self.out_proj = nn.Linear(d_out*num_heads,d_out*num_heads)

     def forward(self,x):
        context_vec = torch.cat([head(x) for head in self.heads],dim=-1)
        return context_vec
        
            

In [9]:
import time
torch.manual_seed(123)
context_length = max_length
d_in = output_dim
num_heads = 2
d_out = d_in//num_heads
mha = MultiHeadAttentionWrapper(d_in,d_out,context_length,0.0,num_heads)
batch=input_embeddings
context_vec = mha(batch)
print(context_vec.shape)

torch.Size([8, 4, 256])


# Multi-Head Attention with KV cache

In [6]:
import time
import tiktoken
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False,max_seq_len=None,window_size=None):
        super().__init__()
        self.d_out = d_out
        self.num_heads=num_heads
        self.head_dim = d_out//self.num_heads ## Reduce the projection dim to match desired output dim
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.out_proj = nn.Linear(d_out,d_out)  #linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)


        ##
        self.max_seq_len = max_seq_len or context_length
        self.window_size = window_size or self.max_seq_len
        self.register_buffer("cache_k",None,persistent=False)
        self.register_buffer("cache_v",None,persistent=False)
        ##

    def forward(self,x, use_cache=False):
        b,num_tokens,d_in = x.shape
        if use_cache:
            # to prevent self.ptr_cur became negative
            assert num_tokens <= self.window_size, (
                f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size}). "
            )
            
        keys_new = self.W_key(x)
        values_new = self.W_value(x)
        queries_new = self.W_query(x)

        keys_new = keys_new.view(b,num_tokens,self.num_heads,self.head_dim)
        values_new = values_new.view(b,num_tokens,self.num_heads,self.head_dim)
        queries_new = queries_new.view(b,num_tokens,self.num_heads,self.head_dim)

        keys_new = keys_new.transpose(1,2)
        values_new = values_new.transpose(1,2)
        queries_new = queries_new.transpose(1,2)

        #####
        #new
        ########
        if use_cache:
            if self.cache_k is None or self.cache_k.size(0)!=b:
                self.cache_k = torch.zeros(b,self.num_heads,self.window_size,self.head_dim,device = x.device)
                self.cache_v = torch.zeros_like(self.cache_k)
                self.ptr_cur = 0
            # if incoming chunk would overlfow, discard oldest token

            if self.ptr_cur +num_tokens > self.window_size:
                overflow = self.ptr_cur + num_tokens - self.window_size
                # shift everythign left by overflow
                self.cache_k[:,:,:-overflow,:] = self.cache_k[:,:,overflow:,:].clone()
                self.cache_v[:,:,:-overflow,:] = self.cache_v[:,:,overflow:,:].clone()
                self.ptr_cur -= overflow
            
            self.cache_k[:,:,self.ptr_cur:self.ptr_cur+num_tokens,:] = keys_new
            self.cache_v[:,:,self.ptr_cur:self.ptr_cur+num_tokens,:] = values_new
            self.ptr_cur+=num_tokens


            keys = self.cache_k[:,:,:self.ptr_cur,:]
            values = self.cache_v[:,:,:self.ptr_cur,:]
        else:
            keys, values = keys_new,values_new
            self.ptr_cur=0

        attn_scores = queries_new@keys.transpose(2,3)

        K = attn_scores.size(-1)

        if num_tokens==K:
            casual_mask = torch.triu(torch.ones(num_tokens,K,device=x.device,dtype=torch.bool),diagonal=1)
        else:
            offset = K-num_tokens
            row_idx = torch.arange(num_tokens,device=x.device).unsqueeze(1)
            col_idx = torch.arange(K,device=x.device).unsqueeze(0)
            casual_mask = row_idx+offset<col_idx


        # use the mask to fill the attention scores
        attn_scores.masked_fill(casual_mask.unsqueeze(0).unsqueeze(0),-torch.inf)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights@values).transpose(1,2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

    def reset_cache(self):
        self.cache_k,self.cache_v = None,None
        

In [10]:
torch.manual_seed(123)
max_len = 2048
context_length = max_len
d_in = output_dim
num_heads = 2
d_out = d_in//num_heads
mha = MultiHeadAttention(d_in,d_out,context_length,0.0,num_heads)
batch=input_embeddings
context_vec = mha(batch,False)
print(context_vec.shape)

torch.Size([8, 4, 128])


# Quantization


In [11]:
import time
import tiktoken
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False,max_seq_len=None,window_size=None):
        super().__init__()
        self.d_out = d_out
        self.num_heads=num_heads
        self.head_dim = d_out//self.num_heads 
        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.out_proj = nn.Linear(d_out,d_out)
        self.dropout = nn.Dropout(dropout)

        # Cache buffers
        self.max_seq_len = max_seq_len or context_length
        self.window_size = window_size or self.max_seq_len
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.register_buffer("cache_kq", None, persistent=False)
        self.register_buffer("cache_vq", None, persistent=False)
        self.ptr_cur = 0
        
        # Quant
        self.quant_bits = 8
        self.register_buffer("k_scale", torch.tensor(1.0))
        self.register_buffer("v_scale", torch.tensor(1.0))
        
    def quantize(self, tensor):
        scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127.0
        scale = scale.clamp(min=1e-5)
        q = torch.round(tensor / scale).clamp(-128, 127).to(torch.int8)
        return q, scale
    
    def dequantize(self, q_int8, scale):
        return q_int8.to(torch.float16) * scale
        
    def reset_cache(self):
        self.cache_k = self.cache_v = self.cache_kq = self.cache_vq = None
        self.ptr_cur = 0
        
    def forward(self, x, use_cache=False, use_quantize=False):
        b, num_tokens, _ = x.shape
        
        # Compute QKV
        queries_new = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys_new    = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values_new  = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            assert num_tokens <= self.window_size, f"Chunk {num_tokens} > window {self.window_size}"
            
            if self.cache_k is None or self.cache_k.shape[0] != b:
                self.cache_k  = torch.zeros(b, self.num_heads, self.window_size, self.head_dim, device=x.device, dtype=torch.float16)
                self.cache_v  = torch.zeros_like(self.cache_k)
                self.cache_kq = torch.zeros(b, self.num_heads, self.window_size, self.head_dim, device=x.device, dtype=torch.int8)
                self.cache_vq = torch.zeros_like(self.cache_kq)
                self.ptr_cur = 0

            # Overflow handling
            if self.ptr_cur + num_tokens > self.window_size:
                overflow = self.ptr_cur + num_tokens - self.window_size
                self.cache_k [:,:,:-overflow,:] = self.cache_k [:,:,overflow:,:].clone()
                self.cache_v [:,:,:-overflow,:] = self.cache_v [:,:,overflow:,:].clone()
                self.cache_kq[:,:,:-overflow,:] = self.cache_kq[:,:,overflow:,:].clone()
                self.cache_vq[:,:,:-overflow,:] = self.cache_vq[:,:,overflow:,:].clone()
                self.ptr_cur -= overflow

            # Store new keys/values
            if use_quantize:
                kq_new, k_scale_new = self.quantize(keys_new.float())
                vq_new, v_scale_new = self.quantize(values_new.float())
                self.cache_kq[:,:,self.ptr_cur:self.ptr_cur+num_tokens,:] = kq_new
                self.cache_vq[:,:,self.ptr_cur:self.ptr_cur+num_tokens,:] = vq_new
                keys = self.dequantize(self.cache_kq[:,:,:self.ptr_cur+num_tokens,:], self.k_scale)
                values = self.dequantize(self.cache_vq[:,:,:self.ptr_cur+num_tokens,:], self.v_scale)
                self.k_scale = 0.9 * self.k_scale + 0.1 * k_scale_new.mean()
                self.v_scale = 0.9 * self.v_scale + 0.1 * v_scale_new.mean()
            else:
                keys_new_fp16 = keys_new.to(torch.float16)
                values_new_fp16 = values_new.to(torch.float16)
                self.cache_k[:,:,self.ptr_cur:self.ptr_cur+num_tokens,:] = keys_new_fp16
                self.cache_v[:,:,self.ptr_cur:self.ptr_cur+num_tokens,:] = values_new_fp16
                keys = self.cache_k[:,:,:self.ptr_cur+num_tokens,:]
                values = self.cache_v[:,:,:self.ptr_cur+num_tokens,:]

            self.ptr_cur += num_tokens
        else:
            keys, values = keys_new.to(torch.float16), values_new.to(torch.float16)
            self.ptr_cur = 0

        # Cast queries to FP16
        queries_new = queries_new.to(torch.float16)

        # Attention (all FP16)
        attn_scores = torch.matmul(queries_new, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Causal mask
        K = keys.size(-2)
        if num_tokens == K:
            causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
        else:
            offset = K - num_tokens
            row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)
            col_idx = torch.arange(K, device=x.device).unsqueeze(0)
            causal_mask = row_idx + offset < col_idx

        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = torch.matmul(attn_weights, values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        
        # *** FIX: Cast back to FP32 for out_proj ***
        context_vec = context_vec.to(torch.float32)
        context_vec = self.out_proj(context_vec)
        
        return context_vec

if __name__ == "__main__":
    torch.manual_seed(123)
    context_length = 1024
    d_in = 512
    num_heads = 8
    d_out = d_in
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    attn = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads).to(device)
    long_input = torch.randn(1, 1024, d_in, device=device)
    
    print("=== NON-QUANTIZED (FP16 Cache) ===")
    attn.reset_cache()
    out_nonquant = attn(long_input, use_cache=True, use_quantize=False)
    fp16_mem = attn.cache_k.numel() * attn.cache_k.element_size() / 1e6
    print(f"Cache K: {attn.cache_k.dtype} {fp16_mem:.1f}MB (K+V: {fp16_mem*2:.1f}MB)")
    print(f"Output: {out_nonquant.shape}")

    print("\n=== QUANTIZED (INT8 Cache) ===")
    attn.reset_cache()
    out_quant = attn(long_input, use_cache=True, use_quantize=True)
    int8_mem = attn.cache_kq.numel() * attn.cache_kq.element_size() / 1e6
    print(f"Cache Kq: {attn.cache_kq.dtype} {int8_mem:.1f}MB (Kq+Vq: {int8_mem*2:.1f}MB)")
    print(f"Output: {out_quant.shape}")
    
    match = torch.allclose(out_nonquant, out_quant, atol=1e-2)
    print(f"\nOutputs match: {match}")
    print(f"Memory savings: {100*(1-int8_mem/fp16_mem):.0f}%")


=== NON-QUANTIZED (FP16 Cache) ===
Cache K: torch.float16 1.0MB (K+V: 2.1MB)
Output: torch.Size([1, 1024, 512])

=== QUANTIZED (INT8 Cache) ===
Cache Kq: torch.int8 0.5MB (Kq+Vq: 1.0MB)
Output: torch.Size([1, 1024, 512])

Outputs match: False
Memory savings: 50%
