### What is stored in `k` and `v`?

**Do they store FFN results?**
**In this specific code: Yes.**
If you look at the `forward` method:
```python
x = self.ffn(self.norm2(x))
if use_cache:
    return x, (x.clone().detach(), x.clone().detach())
```
The code saves `x` (the FFN output, which here is also the block's final output) into the cache, using the same tensor for both `k` and `v`.

**Is this how standard Transformers work?**
**No, this is a simplified demonstration.**
In a real Transformer (like GPT or Llama):
*   The KV cache stores the **internal Key and Value projections** computed inside each attention layer, for every layer and every past token.
*   It does **not** store the output of the FFN.
*   It does **not** store the output of the block.

*Note: Implementing a "real" KV cache typically means writing a custom attention layer instead of using `nn.MultiheadAttention`, because the PyTorch module does not expose its internal K/V projections.*
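
To make the note above concrete, here is a minimal sketch of such a custom layer. `CachedSelfAttention` is a hypothetical module (not a PyTorch class), assuming a single head, no dropout, and a cache shaped `(batch, seq_len, embed_dim)`; a multi-head layer would typically keep a head dimension as well. Its cache holds the *projected* Keys and Values, so each call only projects the new tokens.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedSelfAttention(nn.Module):
    """Hypothetical single-head causal self-attention that caches projected K and V."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        x: torch.Tensor,  # (batch, new_len, embed_dim): only the NEW tokens
        cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Project only the new tokens; past K/V come from the cache
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=1)  # (batch, past_len + new_len, embed_dim)
            v = torch.cat([past_v, v], dim=1)

        new_len, total_len = q.size(1), k.size(1)
        past_len = total_len - new_len
        scores = q @ k.transpose(1, 2) / math.sqrt(q.size(-1))  # (batch, new_len, total_len)
        # New token i sits at absolute position past_len + i and must not see later keys
        blocked = torch.ones(new_len, total_len, dtype=torch.bool, device=x.device).triu(past_len + 1)
        scores = scores.masked_fill(blocked, float("-inf"))
        out = self.out_proj(F.softmax(scores, dim=-1) @ v)
        return out, (k, v)  # updated cache: projected K/V for every position so far


torch.manual_seed(0)
attn = CachedSelfAttention(embed_dim=16)
prompt = torch.randn(1, 4, 16)
next_token = torch.randn(1, 1, 16)

with torch.no_grad():
    _, cache = attn(prompt)                                     # prefill: cache covers 4 positions
    step_out, cache = attn(next_token, cache=cache)             # decode: projects 1 new token
    full_out, _ = attn(torch.cat([prompt, next_token], dim=1))  # recompute everything

print(cache[0].shape)                                         # torch.Size([1, 5, 16])
print(torch.allclose(step_out, full_out[:, -1:], atol=1e-5))  # True
```

The last check compares the decode-step output with the last row of a full recomputation. They agree because the cached Keys and Values are exactly what the full pass computes for those positions.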

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import string
from typing import List, Tuple

In [None]:
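from typing import Optional, Tuple

class SimpleDecoderBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # x: (batch_size, new_seq_len, embed_dim), i.e. only the tokens of this call
        h = self.norm1(x)

        if kv_cache is not None:
            # Decoding: attend over the cached history AND the new tokens themselves
            past_k, past_v = kv_cache
            k = torch.cat([past_k, h], dim=1)
            v = torch.cat([past_v, h], dim=1)
        else:
            # Prefill: ordinary self-attention over the prompt
            k = v = h

        # Causal mask (True = blocked): new token i sits at absolute position past_len + i
        # and may only attend to key positions <= past_len + i
        past_len = k.size(1) - h.size(1)
        causal_mask = torch.ones(h.size(1), k.size(1), dtype=torch.bool, device=x.device).triu(diagonal=past_len + 1)
        x_attn, _ = self.self_attn(h, k, v, attn_mask=causal_mask, need_weights=False)

        x = x + x_attn
        x = self.ffn(self.norm2(x))

        if use_cache:
            # Simplification: cache this block's output (not the real K/V projections)
            return x, (x.clone().detach(), x.clone().detach())

        return x, None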

### How KV Cache Works in `forward`

The `if kv_cache is not None:` block is the core of the optimization.

1.  **First Call (Prefill)**: When the model sees the prompt for the first time, `kv_cache` is usually `None`. The model processes all prompt tokens in parallel (`else` block).
2.  **Subsequent Calls (Decoding)**: When generating new tokens (or processing a new turn in a conversation), we pass the **new tokens only** as `x`.
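    *   Instead of re-running the model over the entire history to recompute its Keys and Values, we pass the precomputed Key and Value tensors via `kv_cache`.
    *   The new tokens in `x` provide the Queries, which attend over the Keys and Values of the history in `kv_cache` (in a real KV cache, together with the new tokens' own Keys and Values).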
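
A rough way to see the saving: the sketch below counts how many token positions get their Key/Value projections computed, assuming a single layer and one new token per decoding step. `kv_positions_computed` is a hypothetical helper, and the count ignores all other work in the model.

```python
def kv_positions_computed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count token positions whose Key/Value projections are computed (one layer)
    for a prefill of prompt_len tokens followed by new_tokens one-token decoding steps."""
    total = prompt_len  # prefill projects every prompt token once
    seq_len = prompt_len
    for _ in range(new_tokens):
        seq_len += 1  # the sequence grows by the token being fed in
        # With a cache only that token is projected; without one, the whole sequence is re-run
        total += 1 if use_cache else seq_len
    return total


print(kv_positions_computed(100, 50, use_cache=True))   # 150
print(kv_positions_computed(100, 50, use_cache=False))  # 6375
```

With the cache the count grows linearly in the number of generated tokens; without it, every step re-projects the whole history, so the count grows quadratically.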

**Is it a different batch?**
Usually, no. In inference, "batch size" refers to how many independent sequences we are generating in parallel.
*   `x` represents the **new time steps** for the *current* batch.
*   `kv_cache` represents the **past time steps** for the *current* batch.

If you were processing a completely unrelated request (a different user), you would indeed start with an empty cache.
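
A minimal sketch of keeping one cache per independent conversation, assuming a hypothetical `SessionCaches` class and plain token strings standing in for the per-layer K/V tensors a real server would store:

```python
from typing import Dict, List


class SessionCaches:
    """Hypothetical store mapping a conversation id to its cached positions.
    Token strings stand in for the real per-layer K/V tensors."""

    def __init__(self) -> None:
        self._caches: Dict[str, List[str]] = {}

    def get(self, session_id: str) -> List[str]:
        # An unseen session (e.g. a different user) starts with an empty cache
        return self._caches.setdefault(session_id, [])

    def extend(self, session_id: str, new_tokens: List[str]) -> None:
        # Append the positions processed in this call to that session only
        self.get(session_id).extend(new_tokens)

    def drop(self, session_id: str) -> None:
        # Free a finished conversation's cache
        self._caches.pop(session_id, None)


store = SessionCaches()
store.extend("alice", ["Hi", "Hel", "lo"])
print(store.get("alice"))  # ['Hi', 'Hel', 'lo']
print(store.get("bob"))    # []
```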

### Example: Chatbot Conversation

Let's visualize this with a conversation to clarify "Generating Tokens" vs "New Turn".

**1. Turn 1 (User says "Hi") -> Prefill Phase**
*   **Input (`x`):** "Hi" (The whole prompt)
*   **Cache:** Empty (`None`)
*   **Action:** Model processes "Hi" from scratch.
*   **Result:**
    *   Predicts first token: "Hel"
    *   Returns Cache: KV("Hi")

**2. Generating Response (AI continues "Hello") -> Decoding Phase**
*   **Input (`x`):** "Hel" (Just the **newly generated** token)
*   **Cache:** KV("Hi") (From Step 1)
*   **Action:** Model attends from "Hel" to the cached "Hi" (and to "Hel" itself).
*   **Result:**
    *   Predicts next token: "lo"
    *   Returns Cache: KV("Hi", "Hel")

**3. Turn 2 (User says "How are you?") -> New Turn Prefill**
*   **Context:** The history is "Hi" (User) + "Hello" (AI).
*   **Input (`x`):** "How are you?" (New user input)
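*   **Cache:** KV("Hi", "Hello") (saved from the previous turn, once "lo" has itself been fed through the model, for example as a final decoding step)
*   **Action:**
    *   We **do not** re-process "Hi" and "Hello".
    *   We pass `x="How are you?"` and `kv_cache=KV("Hi", "Hello")`.
    *   The model computes attention for "How are you?" attending to the cached "Hi" and "Hello" (and, causally, to the earlier tokens of "How are you?" itself).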
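
The conversation above can be traced step by step. In this sketch, `run_step` is a hypothetical stand-in for one forward call, token strings stand in for their cached K/V, and splitting "How are you?" into four tokens is an assumption. "lo" is fed as its own decoding step, which is what puts it in the cache before Turn 2.

```python
from typing import List, Tuple


def run_step(cache: List[str], new_tokens: List[str]) -> Tuple[int, List[str]]:
    """Hypothetical forward call: only new_tokens are processed, and their K/V
    (represented here by the token strings) are appended to the cache."""
    return len(new_tokens), cache + new_tokens


steps = [
    ("Turn 1 prefill", ["Hi"]),                      # cache starts empty
    ("Decode", ["Hel"]),                             # feed the token predicted by the prefill
    ("Decode", ["lo"]),                              # feed the next generated token
    ("Turn 2 prefill", ["How", "are", "you", "?"]),  # new user input; history stays cached
]

cache: List[str] = []
trace: List[Tuple[str, int, int]] = []
for name, new_tokens in steps:
    processed, cache = run_step(cache, new_tokens)
    trace.append((name, processed, len(cache)))
    print(f"{name}: processed {processed} token(s), cache covers {len(cache)}: {cache}")
```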

In this notebook's loop, round 1 is the prefill, and every later round (`round_id >= 2`) plays the role of **Turn 2**: we feed only the new tokens (`token_tensors`) and pass the history (`kv_cache`) so the model sees the earlier context without recomputing it.

In [None]:
from typing import List, Optional, Tuple

class KVCacheManager:
    def __init__(self, max_cache_size: int = 64):
        self.cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.token_labels: List[str] = []  # One label per cached token, e.g. "Round2"
        self.max_cache_size = max_cache_size

    def get_cache(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        if not self.cache:
            return None

        # Concatenate the per-round chunks along the sequence dimension
        k = torch.cat([item[0] for item in self.cache], dim=1)
        v = torch.cat([item[1] for item in self.cache], dim=1)
        return (k, v)  # each: (batch_size, total_sequence_length, embed_dim)

    def update_cache(self, new_kv: Tuple[torch.Tensor, torch.Tensor], tokens: List[str], current_round: int):
        # new_kv[0] shape: (batch_size, seq_len, embed_dim); `tokens` is currently unused
        self.cache.append(new_kv)
        self.token_labels += [f"Round{current_round}"] * new_kv[0].size(1)

        if len(self.token_labels) > self.max_cache_size:
            # Cache full: drop all earlier rounds and keep only the chunk just added
            self.cache = [new_kv]
            self.token_labels = [f"Round{current_round}"] * new_kv[0].size(1)

In [None]:
def generate_tokens(prompt: str, vocab: List[str], num_tokens: int = 5) -> List[str]:
    # Stand-in for a real model: ignores `prompt` and samples random tokens from `vocab`
    return [random.choice(vocab) for _ in range(num_tokens)]

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
decoder = SimpleDecoderBlock(embed_dim=64, num_heads=4).to(device)
kv_cache_manager = KVCacheManager(max_cache_size=30)
vocab = list(string.ascii_lowercase)  # Example vocabulary

for round_id in range(1, 6):
    prompt = f"[Round {round_id}] User Input: write a function"
    tokens = generate_tokens(prompt, vocab)
    print(f"Round {round_id} generated tokens: {' '.join(tokens)}")

    # Simulate token embeddings: shape (1, seq_len, embed_dim)
    token_tensors = torch.randn(1, len(tokens), 64, device=device)

    # Retrieve the cached history (None in round 1) and run the block on the new tokens only
    kv_cache = kv_cache_manager.get_cache()
    with torch.no_grad():
        output, new_kv = decoder(token_tensors, kv_cache=kv_cache, use_cache=True)

    if new_kv is not None:
        kv_cache_manager.update_cache(new_kv, tokens, current_round=round_id)

    # Placeholder "summary" for display only (random letters)
    summary = ''.join(random.choices(string.ascii_lowercase, k=10))
    print(f"Round {round_id} summary: {summary}")

print("\n=== Final KV Cache State ===")
print(f"Number of tokens in cache: {len(kv_cache_manager.token_labels)}")
print(f"Round labels in cache: {kv_cache_manager.token_labels}")