In [1]:
import re
from collections import Counter
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from typing import Optional, Tuple, List

# Fix PyTorch's global RNG so weight initialisation (and anything else drawing
# from the default generator, e.g. DataLoader shuffling) is reproducible.
torch.manual_seed(seed=1337)

@dataclass
class LlamaConfig:
    """Hyper-parameters for the Llama-2-style model.

    ``head_dim`` defaults to ``None`` and is derived in ``__post_init__`` as
    ``dim // n_heads`` from the *instance's* values. Previously it was a class-level
    default computed once from the default ``dim``/``n_heads`` (512 // 8 = 64), so
    overriding ``dim`` or ``n_heads`` silently left it stale. Pass ``head_dim``
    explicitly to override the derived value.
    """
    vocab_size: int = 1000
    dim: int = 512           # Embedding dimension
    n_layers: int = 4        # Number of transformer blocks
    n_heads: int = 8         # Number of query heads
    n_kv_heads: int = 4      # Number of key/value heads (Grouped Query Attention)
    multiple_of: int = 32    # MLP hidden layer multiple
    norm_eps: float = 1e-5
    max_seq_len: int = 128   # Max context window
    head_dim: Optional[int] = None  # Per-head width; None -> dim // n_heads

    # RoPE Config
    rope_theta: float = 10000.0

    def __post_init__(self):
        # Derive from the values actually passed to this instance, not the class defaults.
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Library-default hyper-parameters (rebuilt later once the vocabulary size is known).
config = LlamaConfig()
print("Using device: " + device)

Using device: cpu


In [2]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, scaled by a
    learnable per-feature gain (no bias, no mean subtraction).

    The statistic is computed in float32 and cast back to the input dtype
    before the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised

    def extra_repr(self):
        """Show the normalised width and epsilon in ``print(model)``."""
        return f"dim={self.weight.shape[0]}, eps={self.eps}"

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Build the RoPE rotation table as unit complex numbers.

    Entry ``[t, k]`` is ``exp(i * t * theta ** (-2k / dim))`` for positions
    ``t in [0, end)`` and frequency index ``k in [0, dim // 2)``.
    Returns a complex64 tensor of shape ``(end, dim // 2)``.
    """
    even_channels = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = 1.0 / (theta ** (even_channels / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    # Magnitude 1, phase = angle  ->  cos(angle) + i*sin(angle)
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned view
    keeps those extents at dims 1 and -1 and has size 1 on every other dim of ``x``.
    Raises AssertionError if ``x`` has fewer than 2 dims or the shapes disagree.
    """
    rank = x.ndim
    assert 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target_shape = [1] * rank
    target_shape[1] = x.shape[1]
    target_shape[-1] = x.shape[-1]
    return freqs_cis.view(*target_shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query and key vectors by their RoPE position angles.

    The last dim of ``xq``/``xk`` is viewed as ``head_dim // 2`` complex pairs and
    multiplied by ``freqs_cis`` (broadcast via ``reshape_for_broadcast`` against the
    complex query). Outputs keep the input shapes and dtypes. ``.flatten(3)``
    assumes 4-D inputs — the attention layer passes (batch, seq, heads, head_dim).
    """
    def as_complex_pairs(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex_pairs(xq)
    k_complex = as_complex_pairs(xk)

    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

In [3]:
class LlamaAttention(nn.Module):
    """Causal self-attention with RoPE, grouped-query attention (GQA) and an
    optional append-only KV cache.

    Raises:
        ValueError: at construction, if ``cfg.n_kv_heads`` is not positive or does not
            divide ``cfg.n_heads``. GQA repeats every K/V head ``n_heads // n_kv_heads``
            times, so a remainder would leave keys/values with fewer heads than the
            queries and only fail later with an opaque shape error in ``forward``.
    """
    def __init__(self, cfg: LlamaConfig):
        super().__init__()
        if cfg.n_kv_heads <= 0 or cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({cfg.n_heads}) must be a positive multiple of "
                f"n_kv_heads ({cfg.n_kv_heads}) for grouped-query attention"
            )
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.head_dim
        # Number of query heads sharing each key/value head.
        self.n_rep = self.n_heads // self.n_kv_heads

        # Projections (K/V are narrower than Q under GQA)
        self.wq = nn.Linear(cfg.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(cfg.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(cfg.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, cfg.dim, bias=False)

        # KV cache: (batch, cached_seq, n_kv_heads, head_dim) once populated, else None.
        self.cache_k = None
        self.cache_v = None

    def forward(
        self, 
        x: torch.Tensor, 
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        use_cache: bool = False
    ):
        """Attend over ``x`` (and, with ``use_cache``, over every cached position).

        Args:
            x: hidden states, unpacked below as (batch, seq, dim).
            freqs_cis: RoPE rotations for the positions in ``x``; must be
                (seq, head_dim // 2) (asserted inside ``reshape_for_broadcast``).
            mask: additive mask broadcastable to (batch, n_heads, seq, n_keys), or None.
            use_cache: append this call's K/V to the cache and attend over the whole
                cache; when False the cache is discarded.

        Returns:
            Output of the ``wo`` projection, (batch, seq, dim).
        """
        bsz, seqlen, _ = x.shape
        print(f"Attention forward: bsz={bsz}, seqlen={seqlen}")
        
        # 1. QKV Projections
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        print(f"QKV projections: xq.shape={xq.shape}, xk.shape={xk.shape}, xv.shape={xv.shape}")

        # 2. Reshape for heads
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        print(f"After reshape: xq.shape={xq.shape}, xk.shape={xk.shape}, xv.shape={xv.shape}")
        
        # 3. Apply RoPE (Only to Query and Key)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
        print("RoPE applied")
        
        # 4. KV Cache Management (cache stores the un-repeated n_kv_heads K/V)
        if use_cache:
            if self.cache_k is None or self.cache_k.shape[1] == 0:
                self.cache_k = xk
                self.cache_v = xv
                print("KV cache initialized")
            else:
                self.cache_k = torch.cat((self.cache_k, xk), dim=1)
                self.cache_v = torch.cat((self.cache_v, xv), dim=1)
            keys, values = self.cache_k, self.cache_v
            print(f"KV cache updated: cache_k.shape={self.cache_k.shape}")
        else:
            self.cache_k, self.cache_v = None, None
            keys, values = xk, xv
            print("No KV cache used")

        # 5. Grouped Query Attention (GQA): repeat each K/V head n_rep times
        # keys shape: (B, Seq, n_kv_heads, D) -> (B, Seq, n_heads, D)
        keys = torch.repeat_interleave(keys, repeats=self.n_rep, dim=2)
        values = torch.repeat_interleave(values, repeats=self.n_rep, dim=2)
        print(f"After GQA repeat: keys.shape={keys.shape}, values.shape={values.shape}")

        # 6. Transpose for Attention: (B, H, Seq, D)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        print(f"After transpose: xq.shape={xq.shape}, keys.shape={keys.shape}")

        # 7. Scaled dot-product attention; scores: (B, H, Seq_Q, Seq_K)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        print(f"Scores computed: scores.shape={scores.shape}")
        
        if mask is not None:
            scores = scores + mask
            print("Mask applied")
        
        # Softmax in float32 for stability, then back to the activation dtype
        probs = F.softmax(scores.float(), dim=-1).type_as(xq)
        print(f"Probs shape: {probs.shape}")
        
        # output: (B, H, Seq_Q, D)
        output = torch.matmul(probs, values)
        print(f"Attention output shape: {output.shape}")

        # 8. Restore shape -- flatten heads: (B, Seq, H * D)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

In [4]:
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width starts at ``4 * dim``, is scaled by 2/3 (SwiGLU uses three
    weight matrices instead of two), then rounded up to a multiple of
    ``cfg.multiple_of``.
    """
    def __init__(self, cfg: LlamaConfig):
        super().__init__()
        step = cfg.multiple_of
        width = int(2 * (4 * cfg.dim) / 3)
        width = step * ((width + step - 1) // step)

        self.w1 = nn.Linear(cfg.dim, width, bias=False)
        self.w2 = nn.Linear(width, cfg.dim, bias=False)
        self.w3 = nn.Linear(cfg.dim, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)

class LlamaBlock(nn.Module):
    """Pre-norm transformer block.

    ``h = x + attention(attention_norm(x))`` followed by
    ``h + feed_forward(ffn_norm(h))``.
    """
    def __init__(self, cfg: LlamaConfig):
        super().__init__()
        self.attention = LlamaAttention(cfg)
        self.feed_forward = LlamaMLP(cfg)
        self.attention_norm = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.ffn_norm = RMSNorm(cfg.dim, eps=cfg.norm_eps)

    def forward(self, x, freqs_cis, mask, use_cache):
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, use_cache)
        residual = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(residual))
        return residual + ffn_out

In [5]:
class Llama2(nn.Module):
    """Llama-2-style decoder-only language model.

    Token embedding -> ``cfg.n_layers`` x ``LlamaBlock`` -> RMSNorm -> linear LM head.
    Each block's attention layer keeps its own KV cache (see ``clear_kv_cache`` and
    ``get_kv_cache_stats``).
    """
    def __init__(self, cfg: LlamaConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_embeddings = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.layers = nn.ModuleList([LlamaBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.output = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        
        # Precompute RoPE rotations for 2 * max_seq_len positions so cached generation
        # can run past the training context. Built from cfg.head_dim -- the same
        # per-head width LlamaAttention projects q/k to -- so the shapes always agree.
        # Kept as a plain attribute; forward() moves it to the input's device, so no
        # notebook-global `device` is needed here.
        self.freqs_cis = precompute_freqs_cis(
            cfg.head_dim, cfg.max_seq_len * 2, cfg.rope_theta
        )
        print(f"Llama2 initialized: vocab_size={cfg.vocab_size}, dim={cfg.dim}, n_layers={cfg.n_layers}")

    def forward(
        self, 
        tokens: torch.Tensor, 
        start_pos: int = 0, 
        mask: Optional[torch.Tensor] = None,
        use_cache: bool = False
    ):
        """Compute next-token logits.

        Args:
            tokens: (batch, seq) token ids.
            start_pos: absolute position of ``tokens[:, 0]`` (non-zero when decoding
                with the KV cache).
            mask: additive attention mask passed to every layer, or None.
            use_cache: forwarded to each attention layer.

        Returns:
            (batch, seq, vocab_size) logits.

        Raises:
            ValueError: if ``[start_pos, start_pos + seq)`` falls outside the
                precomputed RoPE table (length ``2 * max_seq_len``).
        """
        bsz, seqlen = tokens.shape
        print(f"Forward pass start: batch_size={bsz}, seq_len={seqlen}, start_pos={start_pos}, use_cache={use_cache}")
        max_positions = self.freqs_cis.shape[0]
        if start_pos < 0 or start_pos + seqlen > max_positions:
            # Without this check the slice below comes back short/empty and fails
            # later with an opaque assertion inside reshape_for_broadcast.
            raise ValueError(
                f"positions [{start_pos}, {start_pos + seqlen}) exceed the RoPE table "
                f"of {max_positions} positions (2 * max_seq_len)"
            )
        h = self.tok_embeddings(tokens)
        print(f"After embeddings: h.shape={h.shape}")
                
        # Fetch appropriate RoPE frequencies
        # During inference (use_cache=True), we only need freqs for the current position
        # During training (use_cache=False), we need freqs for 0..seqlen
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        print(f"RoPE freqs_cis shape: {freqs_cis.shape}")

        for i, layer in enumerate(self.layers):
            print(f"Layer {i} forward start")
            h = layer(h, freqs_cis, mask, use_cache)
            print(f"Layer {i} output shape: {h.shape}")
        h = self.norm(h)
        print(f"After norm: h.shape={h.shape}")
        logits = self.output(h)
        print(f"Logits shape: {logits.shape}")
        return logits

    def clear_kv_cache(self):
        """Drop the cached keys/values in every layer."""
        for layer in self.layers:
            layer.attention.cache_k = None
            layer.attention.cache_v = None
        print("KV cache cleared")

    def get_kv_cache_stats(self):
        """Print statistics about KV cache in all layers"""
        print("\n--- KV Cache Statistics ---")
        for i, layer in enumerate(self.layers):
            cache_k = layer.attention.cache_k
            cache_v = layer.attention.cache_v
            
            if cache_k is not None:
                print(f"Layer {i}:")
                print(f"  cache_k shape: {cache_k.shape}")
                print(f"  cache_v shape: {cache_v.shape}")
                print(f"  cache_k memory: {cache_k.element_size() * cache_k.nelement() / 1024:.2f} KB")
            else:
                print(f"Layer {i}: No cache")
        print("--- End KV Cache Stats ---\n")

In [6]:
# ============================================================================
# STEP 1: Simple word-level tokenizer (whole words + single punctuation marks)
# ============================================================================

class SimpleTokenizer:
    """
    Simple word-level tokenizer for training.

    Text is lower-cased and split into runs of word characters and individual
    punctuation marks; punctuation is kept as its own token. The vocabulary is
    built from a training corpus, and tokens missing from it map to ``<UNK>``.

    Args:
        texts: corpus used to build the vocabulary. ``None`` (the default) gives a
            tokenizer that only knows the four special tokens.
        min_freq: a corpus token is added only if it occurs at least this many
            times (the default 0 keeps every token).
    """
    def __init__(self, texts: Optional[List[str]] = None, min_freq: int = 0):
        self.unk_token = "<UNK>"
        self.pad_token = "<PAD>"
        self.bos_token = "<BOS>" # Beginning of Sequence
        self.eos_token = "<EOS>" # End of Sequence
        
        # Special tokens take the first indices, in this order: UNK=0, PAD=1, BOS=2, EOS=3
        self.specials = [self.unk_token, self.pad_token, self.bos_token, self.eos_token]
        
        self.vocab = {token: i for i, token in enumerate(self.specials)}
        self.inverse_vocab = {i: token for i, token in enumerate(self.specials)}
        if texts is not None:
            self._build_vocab(texts=texts, min_freq=min_freq)
        
    @property
    def pad_id(self):
        return self.vocab[self.pad_token]
    
    @property
    def bos_id(self):
        return self.vocab[self.bos_token]
    
    @property
    def eos_id(self):
        return self.vocab[self.eos_token]
    
    def _tokenize(self, text: str) -> List[str]:
        """
        Lower-case ``text`` and split it into words and single punctuation marks.
        Does NOT remove punctuation; whitespace is discarded.
        """
        tokens = re.findall(r"\w+|[^\w\s]", text.lower(), re.UNICODE)
        return tokens

    def _build_vocab(self, texts: List[str], min_freq: int = 0):
        """Append every corpus token seen at least ``min_freq`` times, in first-seen order."""
        print("Building vocabulary...")
        # Punctuation tokens are counted as well -- nothing is filtered out here.
        counter = Counter(token for text in texts for token in self._tokenize(text=text))

        # Non-special tokens get ids after the specials
        idx = len(self.vocab)
        for token, freq in counter.items():
            if freq >= min_freq:
                self.vocab[token] = idx
                self.inverse_vocab[idx] = token
                idx += 1
        print(f"Vocab size: {len(self.vocab)}")
    
    def encode_raw(self, text: str) -> List[int]:
        """
        Returns raw list of IDs with BOS/EOS but NO padding/truncation.
        Used for packing.
        """
        tokens = self._tokenize(text=text)
        ids = [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
        return [self.bos_id] + ids + [self.eos_id]
    
    def encode_inference(self, text: str) -> torch.Tensor:
        """
        Encodes text for INFERENCE: BOS prepended, no EOS, no padding.
        Returns a 1-D LongTensor; callers add the batch dimension.
        """
        tokens = self._tokenize(text=text)
        ids = [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
        
        # Inference Prompt: Add BOS, do NOT add EOS (model must generate it), do NOT Pad.
        ids = [self.bos_id] + ids
        return torch.tensor(ids, dtype=torch.long)
    
    def decode(self, ids: list) -> str:
        """
        Convert token IDs (list or tensor) to text. Special tokens, including the
        <UNK> substituted for unknown ids, are dropped.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        tokens = [self.inverse_vocab.get(token_id, self.unk_token) for token_id in ids]
        
        # Specialized decoding: filter out special tokens
        filtered = [token for token in tokens if token not in self.specials]
        
        # Simple heuristic to join punctuation nicely
        # (For a real BPE tokenizer, this is handled by the subword merge logic)
        out_str = " ".join(filtered)
        # Cleanup spaces before punctuation (simple hack for readability)
        out_str = re.sub(r'\s+([?.!,:;])', r'\1', out_str)
        return out_str

In [7]:
class PackedDataset(Dataset):
    """Packs tokenized texts into fixed-length rows for causal-LM training.

    Every text is encoded with ``tokenizer.encode_raw`` (BOS ... EOS). The results
    are concatenated into one token stream and cut into rows of ``max_seq_len``;
    the final row is right-padded with ``tokenizer.pad_id``. Each token also carries
    the index of the text it came from (``seq_ids``) and its offset inside that text
    (``position_ids``). ``__getitem__`` uses these to build an attention mask that is
    both causal and confined to a single text.
    """

    # seq_id stored in padding slots. Real texts are numbered 0, 1, 2, ..., so a
    # negative sentinel can never collide with one. Using pad_id here (a vocab id,
    # 1 for SimpleTokenizer) would make text #1 indistinguishable from padding.
    PAD_SEQ_ID = -1

    def __init__(
        self,
        texts: List[str],
        tokenizer: "SimpleTokenizer",
        max_seq_len: int = 128
    ):
        super().__init__()
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        
        # Pack sequences once during initialization
        self.input_ids, self.position_ids, self.seq_ids = self.pack_all_sequences()
    
    def __len__(self):
        # Number of packed rows
        return len(self.input_ids)
    
    def __getitem__(self, index):
        input_ids = self.input_ids[index]
        position_ids = self.position_ids[index]
        seq_ids = self.seq_ids[index]
        pad_id = self.tokenizer.pad_id
        
        # Next-token targets: labels[t] = input_ids[t + 1]. The last slot has no
        # successor in this row, so it gets pad_id (ignored by the loss).
        labels = torch.cat([input_ids[1:], torch.tensor([pad_id])])
        # Never train a token to predict the first token of a *different* text (or
        # padding): the packing mask hides that text, so the target is unpredictable.
        crosses_boundary = torch.cat(
            [seq_ids[1:] != seq_ids[:-1], torch.zeros(1, dtype=torch.bool)]
        )
        labels = labels.masked_fill(crosses_boundary, pad_id)
        
        # Combined mask: causal (no future tokens) + packing (no other texts)
        combined_mask = self._generate_combined_mask(seq_ids)
        
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "seq_ids": seq_ids,
            "labels": labels,
            "combined_mask": combined_mask
        }
    
    def pack_all_sequences(self):
        """Return (input_ids, position_ids, seq_ids), each a (n_rows, max_seq_len) LongTensor."""
        eos_id = self.tokenizer.eos_id
        stream_input_ids = []
        stream_position_ids = []
        stream_seq_ids = []
        
        for seq_id, ids in enumerate(self.encode_text(texts=self.texts)):
            # encode_raw already terminates each text with EOS; only add one if it
            # is missing, otherwise the model is trained on a spurious "EOS EOS".
            if not ids or ids[-1] != eos_id:
                ids = ids + [eos_id]
            stream_input_ids.extend(ids)
            stream_position_ids.extend(range(len(ids)))
            stream_seq_ids.extend([seq_id] * len(ids))
        
        packed_input_batches = []
        packed_pos_batches = []
        packed_seq_batches = []
        for start in range(0, len(stream_input_ids), self.max_seq_len):
            stop = start + self.max_seq_len
            chunk_input = stream_input_ids[start:stop]
            pad_len = self.max_seq_len - len(chunk_input)
            packed_input_batches.append(chunk_input + [self.tokenizer.pad_id] * pad_len)
            packed_pos_batches.append(stream_position_ids[start:stop] + [0] * pad_len)
            packed_seq_batches.append(stream_seq_ids[start:stop] + [self.PAD_SEQ_ID] * pad_len)
        
        def to_rows(rows):
            # view(-1, L) also yields a well-formed (0, L) tensor for an empty corpus.
            return torch.tensor(rows, dtype=torch.long).view(-1, self.max_seq_len)
        
        return to_rows(packed_input_batches), to_rows(packed_pos_batches), to_rows(packed_seq_batches)

    def encode_text(self, texts):
        """Encode each text with ``tokenizer.encode_raw``; returns a list of id lists."""
        return [self.tokenizer.encode_raw(text=text) for text in texts]
    
    def _generate_combined_mask(self, seq_ids: torch.Tensor, dtype=torch.float32):
        """Additive (seq_len, seq_len) mask: 0 where query i may attend key j, -1e9 elsewhere.

        Attention is allowed iff j <= i, both tokens belong to the same text, and key j
        is not padding. Padding query rows end up fully masked; their outputs are
        unused because their labels are pad_id.
        """
        seq_len = seq_ids.shape[0]
        same_seq = seq_ids.unsqueeze(-1) == seq_ids.unsqueeze(-2)
        real_key = (seq_ids != self.PAD_SEQ_ID).unsqueeze(-2)
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_ids.device)
        )
        allowed = same_seq & real_key & causal
        
        final_mask = torch.zeros((seq_len, seq_len), dtype=dtype, device=seq_ids.device)
        return final_mask.masked_fill(~allowed, value=-1e9)

In [8]:
raw_texts = [
        # Shakespeare-style
        """To be, or not to be: that is the question. Whether 'tis nobler in the mind to suffer
        The slings and arrows of outrageous fortune, Or to take arms against a sea of troubles,
        And by opposing end them? To die: to sleep; No more;""",

        # Technical / Math
        """Machine learning models like transformers have revolutionized natural language processing.
        The attention mechanism computes a weighted sum of value vectors:
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V.""",

        """Self-supervised learning on large corpora enables emergent capabilities.""",

        """This is an advanced optimization technique. Training Large Language Models (LLMs)
        like Llama 2 is expensive; Sequence Packing (also known as Sample Packing or Multipacking)
        is crucial because it eliminates the wasted computation caused by padding tokens.""",

        # Code
        """def train_step(model, batch, optimizer):
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                loss = model(batch)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            return loss.item()""",

        # General Prose
        """The quick brown fox jumps over the lazy dog. 
        Artificial intelligence is rapidly evolving, impacting sectors from healthcare to finance.
        Robust data pipelines are essential for stable model training.""",

        """I can outline an end-to-end approach and provide a cohesive code scaffold for training 
        Llama2-like KV-cache aware PyTorch model with optional KV cache during inference and sequence
        packing during training."""
    ]

# 80/20 train/validation split by document, in list order
# (7 texts -> 5 for training, 2 held out for validation).
TRAIN_FRACTION = 0.8
split_idx = int(len(raw_texts) * TRAIN_FRACTION)
train_texts = raw_texts[:split_idx]
val_texts = raw_texts[split_idx:]

In [9]:
# Vocabulary comes from the training split only; validation words it never saw become <UNK>.
tokenizer = SimpleTokenizer(texts=train_texts)

# Deliberately tiny packing window / batch so the toy corpus yields several rows.
max_seq_len = 15
batch_size = 6

train_dataset, val_dataset = (
    PackedDataset(texts=split, tokenizer=tokenizer, max_seq_len=max_seq_len)
    for split in (train_texts, val_texts)
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print("Packed Batches - Train: " + str(len(train_loader)) + " | Val: " + str(len(val_loader)))

Building vocabulary...
Vocab size: 127
Packed Batches - Train: 3 | Val: 1


In [10]:
# Size the model to the tokenizer's vocabulary and to the packing window.
config = LlamaConfig(
    vocab_size=len(tokenizer.vocab),
    max_seq_len=max_seq_len,
    dim=256,
    n_layers=4,
    n_heads=4,
    n_kv_heads=2,      # GQA: each K/V head is shared by 2 query heads
    multiple_of=2,
)

model = Llama2(cfg=config).to(device)
# Labels equal to pad_id (padding and masked targets) contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_id)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

Llama2 initialized: vocab_size=127, dim=256, n_layers=4


In [11]:
print(repr(model))  # module tree, including RMSNorm's extra_repr (dim, eps)

Llama2(
  (tok_embeddings): Embedding(127, 256)
  (layers): ModuleList(
    (0-3): 4 x LlamaBlock(
      (attention): LlamaAttention(
        (wq): Linear(in_features=256, out_features=256, bias=False)
        (wk): Linear(in_features=256, out_features=128, bias=False)
        (wv): Linear(in_features=256, out_features=128, bias=False)
        (wo): Linear(in_features=256, out_features=256, bias=False)
      )
      (feed_forward): LlamaMLP(
        (w1): Linear(in_features=256, out_features=682, bias=False)
        (w2): Linear(in_features=682, out_features=256, bias=False)
        (w3): Linear(in_features=256, out_features=682, bias=False)
      )
      (attention_norm): RMSNorm(dim=256, eps=1e-05)
      (ffn_norm): RMSNorm(dim=256, eps=1e-05)
    )
  )
  (norm): RMSNorm(dim=256, eps=1e-05)
  (output): Linear(in_features=256, out_features=127, bias=False)
)


In [12]:
# --- Training Loop ---
# Each epoch: one pass over the packed training rows, then a gradient-free pass over
# the held-out rows (val_loader was built earlier but never used before).
NUM_EPOCHS = 2
LOG_EVERY = 5  # print a per-batch progress line every LOG_EVERY training batches

print("\n--- Starting Training (with Packed Sequences) ---")
model.train()

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch} start")
    total_loss = 0.0
    for i, input_data in enumerate(train_loader):
        input_ids = input_data["input_ids"].to(device)
        labels = input_data["labels"].to(device)
        # (B, S, S) -> (B, 1, S, S) so the additive mask broadcasts over heads
        combined_mask = input_data["combined_mask"].unsqueeze(1).to(device)
        print(f"Batch {i}: input_ids.shape={input_ids.shape}, labels.shape={labels.shape}, mask.shape={combined_mask.shape}")
        
        # Forward pass with combined mask
        logits = model(input_ids, mask=combined_mask, use_cache=False)
        print(f"Logits from model: shape={logits.shape}")
        
        # Flatten (B, S, V) -> (B*S, V) for cross-entropy
        loss = loss_fn(logits.view(-1, config.vocab_size), labels.view(-1))
        print(f"Loss computed: {loss.item():.4f}")

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        if i % LOG_EVERY == 0:
            print(f"Epoch {epoch} | Batch {i} | Loss: {loss.item():.4f}")

    mean_train_loss = total_loss / max(len(train_loader), 1)

    # Validation: same packed-sequence loss on held-out texts, without gradients.
    model.eval()
    val_total = 0.0
    with torch.no_grad():
        for val_batch in val_loader:
            val_ids = val_batch["input_ids"].to(device)
            val_labels = val_batch["labels"].to(device)
            val_mask = val_batch["combined_mask"].unsqueeze(1).to(device)
            val_logits = model(val_ids, mask=val_mask, use_cache=False)
            val_total += loss_fn(val_logits.view(-1, config.vocab_size), val_labels.view(-1)).item()
    mean_val_loss = val_total / max(len(val_loader), 1)
    model.train()

    print(f"Epoch {epoch} | Mean train loss: {mean_train_loss:.4f} | Mean val loss: {mean_val_loss:.4f}")

print("Training complete.")


--- Starting Training (with Packed Sequences) ---
Epoch 0 start
Batch 0: input_ids.shape=torch.Size([6, 15]), labels.shape=torch.Size([6, 15]), mask.shape=torch.Size([6, 1, 15, 15])
Forward pass start: batch_size=6, seq_len=15, start_pos=0, use_cache=False
After embeddings: h.shape=torch.Size([6, 15, 256])
RoPE freqs_cis shape: torch.Size([15, 32])
Layer 0 forward start
Attention forward: bsz=6, seqlen=15
QKV projections: xq.shape=torch.Size([6, 15, 256]), xk.shape=torch.Size([6, 15, 128]), xv.shape=torch.Size([6, 15, 128])
After reshape: xq.shape=torch.Size([6, 15, 4, 64]), xk.shape=torch.Size([6, 15, 2, 64]), xv.shape=torch.Size([6, 15, 2, 64])
RoPE applied
No KV cache used
After GQA repeat: keys.shape=torch.Size([6, 15, 4, 64]), values.shape=torch.Size([6, 15, 4, 64])
After transpose: xq.shape=torch.Size([6, 4, 15, 64]), keys.shape=torch.Size([6, 4, 15, 64])
Scores computed: scores.shape=torch.Size([6, 4, 15, 15])
Mask applied
Probs shape: torch.Size([6, 4, 15, 15])
Attention outpu

In [13]:
@torch.inference_mode()
def generate(model, prompt, max_new_tokens, stop_at_eos: bool = True):
    """Greedy-decode a continuation of ``prompt`` using the model's KV cache.

    Uses the notebook-level ``tokenizer``. The prompt is run once (prefill) to fill
    every layer's KV cache; after that only the newest token is fed per step.

    Args:
        model: a ``Llama2`` instance; it is switched to eval mode and its KV cache is reset.
        prompt: text to continue.
        max_new_tokens: maximum number of tokens to generate; ``<= 0`` generates nothing.
        stop_at_eos: stop as soon as the EOS token is produced. The EOS id stays in the
            returned list; ``tokenizer.decode`` drops it from the text.

    Returns:
        ``(generated_ids, decoded_text)``.

    Raises:
        ValueError: if prompt length + max_new_tokens needs more positions than the
            model's precomputed RoPE table provides.
    """
    model.eval()
    model.clear_kv_cache() # Reset cache for new generation
    print("Generation start: clearing KV cache")
    if max_new_tokens <= 0:
        return [], ""

    # Run on the device that holds the model's weights instead of a global.
    run_device = next(model.parameters()).device
    
    # Encode the prompt text using the tokenizer
    tokens = tokenizer.encode_inference(prompt).unsqueeze(0).to(device=run_device)  # Shape: (1, seq_len)
    print(f"Encoded prompt tokens: {tokens}")
    seq_len = tokens.shape[1]

    # Positions used: 0 .. seq_len + max_new_tokens - 2 (the last sampled token is never fed back).
    max_positions = model.freqs_cis.shape[0]
    if seq_len + max_new_tokens - 1 > max_positions:
        raise ValueError(
            f"prompt ({seq_len} tokens) + max_new_tokens ({max_new_tokens}) exceeds the "
            f"model's RoPE table of {max_positions} positions"
        )
    
    # 1. Prefill Phase
    # We pass the whole prompt to fill the KV cache
    # Causal mask for the prompt: lower triangle (including diagonal) = 0, upper = -inf
    mask = torch.full((seq_len, seq_len), float("-inf"), device=run_device)
    mask = torch.triu(mask, diagonal=1)
    print(f"Prefill mask shape: {mask.shape}")
    
    logits = model(tokens, start_pos=0, mask=mask, use_cache=True)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    
    generated = [next_token.item()]
    print(f"Initial generated sequence: {generated}")
    
    # 2. Generation Phase (Token by Token)
    cur_pos = seq_len
    input_token = next_token
    
    for i in range(max_new_tokens-1):
        if stop_at_eos and generated[-1] == tokenizer.eos_id:
            break
        print(f"Generating token {i+1}: cur_pos={cur_pos}")
        # Single new token: it may attend to every cached position, so no mask is needed
        logits = model(input_token, start_pos=cur_pos, mask=None, use_cache=True)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        
        input_token = next_token
        cur_pos += 1
        generated.append(next_token.item())
        print(f"Token {i+1}: {next_token.item()}")
        print(f"generated: {generated}")
        
    # Decode the generated tokens to text
    decoded_text = tokenizer.decode(generated)
    print(f"Generated text: {decoded_text}")
    return generated, decoded_text

In [14]:
# How a prompt is encoded for inference: BOS first, words outside the vocabulary -> <UNK> (id 0).
tokens = tokenizer.encode_inference("an end-to-end approach")[None, :].to(device)  # (1, seq_len)
print(f"Encoded prompt tokens: {tokens}")

Encoded prompt tokens: tensor([[ 2, 84, 36, 75,  4, 75, 36,  0]])


In [15]:
# ============================================================================
# TESTING THE GENERATE FUNCTION
# ============================================================================

# Smoke-test generation (after training) on a prompt that is mostly out-of-vocabulary.
test_prompt = "Hello, how are you?"
generated_tokens, generated_text = generate(model, test_prompt, max_new_tokens=10)

print("\nFinal generated tokens: " + str(generated_tokens))
print("Final generated text: " + str(generated_text))

KV cache cleared
Generation start: clearing KV cache
Encoded prompt tokens: tensor([[ 2,  0,  6,  0,  0,  0, 38]])
Prefill mask shape: torch.Size([7, 7])
Forward pass start: batch_size=1, seq_len=7, start_pos=0, use_cache=True
After embeddings: h.shape=torch.Size([1, 7, 256])
RoPE freqs_cis shape: torch.Size([7, 32])
Layer 0 forward start
Attention forward: bsz=1, seqlen=7
QKV projections: xq.shape=torch.Size([1, 7, 256]), xk.shape=torch.Size([1, 7, 128]), xv.shape=torch.Size([1, 7, 128])
After reshape: xq.shape=torch.Size([1, 7, 4, 64]), xk.shape=torch.Size([1, 7, 2, 64]), xv.shape=torch.Size([1, 7, 2, 64])
RoPE applied
KV cache initialized
KV cache updated: cache_k.shape=torch.Size([1, 7, 2, 64])
After GQA repeat: keys.shape=torch.Size([1, 7, 4, 64]), values.shape=torch.Size([1, 7, 4, 64])
After transpose: xq.shape=torch.Size([1, 4, 7, 64]), keys.shape=torch.Size([1, 4, 7, 64])
Scores computed: scores.shape=torch.Size([1, 4, 7, 7])
Mask applied
Probs shape: torch.Size([1, 4, 7, 7])
A

In [16]:
# Inspect how many positions each layer's KV cache holds after the generation above.
model.get_kv_cache_stats();


--- KV Cache Statistics ---
Layer 0:
  cache_k shape: torch.Size([1, 16, 2, 64])
  cache_v shape: torch.Size([1, 16, 2, 64])
  cache_k memory: 8.00 KB
Layer 1:
  cache_k shape: torch.Size([1, 16, 2, 64])
  cache_v shape: torch.Size([1, 16, 2, 64])
  cache_k memory: 8.00 KB
Layer 2:
  cache_k shape: torch.Size([1, 16, 2, 64])
  cache_v shape: torch.Size([1, 16, 2, 64])
  cache_k memory: 8.00 KB
Layer 3:
  cache_k shape: torch.Size([1, 16, 2, 64])
  cache_v shape: torch.Size([1, 16, 2, 64])
  cache_k memory: 8.00 KB
--- End KV Cache Stats ---

